In [None]:
!pip install -q torch torchvision timm mlflow pyngrok pytorch-lightning==2.4.0
!pip install -q "huggingface-hub<1.0" "transformers<4.44" torchmetrics==1.4.0 scikit-learn numpy matplotlib seaborn

import os, shutil
# Remove the Hugging Face hub cache directory when it is present
# (presumably to avoid cache entries that clash with the versions pinned
# above -- TODO confirm the motivation).
_HF_HUB_CACHE_DIR = "/root/.cache/huggingface/hub"
if os.path.exists(_HF_HUB_CACHE_DIR):
    shutil.rmtree(_HF_HUB_CACHE_DIR)

import os
import shutil
from pathlib import Path


from PIL import Image


import torch
from torch.utils.data import DataLoader


import pytorch_lightning as pl


from torchvision import datasets, transforms
import timm


import mlflow
import mlflow.pytorch

from pyngrok import ngrok

In [None]:

# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running on:", device)

In [None]:
from google.colab import drive

# Mount Google Drive; the dataset, MLflow and checkpoint paths configured in
# this notebook all live under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:

# --- Run configuration -------------------------------------------------------

# The ngrok auth token is a credential and must never be stored in the
# notebook. It is read from the environment instead: set NGROK_AUTH_TOKEN
# (e.g. os.environ["NGROK_AUTH_TOKEN"] = getpass.getpass()) before running
# this cell. The value is None when the variable is not set.
# NOTE(review): any token that was ever committed with this notebook should be
# treated as leaked and revoked in the ngrok dashboard.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN")

# Presumably the MLflow tracking directory (not referenced by the cells
# visible in this notebook section -- verify against the MLflow setup).
MLRUNS_DIR = "/content/drive/MyDrive/BongoDev/Project-1/mlruns"

# Presumably the zipped dataset that EXTRACT_DIR was unpacked from -- the
# extraction step is not part of the visible cells.
DATA_ZIP = "/content/drive/MyDrive/BongoDev/Project-1/archive.zip"

# Dataset root passed to PlantDataModule; must contain a `train/` folder.
EXTRACT_DIR = "/content/drive/MyDrive/BongoDev/Project-1/RiceLeafsDisease"

# Side length (pixels) every image is resized to before entering the model.
IMG_SIZE = 256

# Mini-batch size used by all dataloaders.
BATCH_SIZE = 16

# NOTE(review): the Trainer cell passes max_epochs=25 directly and does not
# read this constant -- decide which value is intended.
MAX_EPOCHS = 2

In [None]:
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class PlantDataModule(pl.LightningDataModule):
    """LightningDataModule over an ``ImageFolder``-style directory tree.

    Expected layout under ``data_dir``::

        train/<class_name>/...        required
        validation/<class_name>/...   optional
        test/<class_name>/...         optional; ``validation/`` is reused
                                      as the test split when it is missing

    Args:
        data_dir: Root directory containing the split folders.
        batch_size: Batch size used by every dataloader.
        img_size: Images are resized to ``(img_size, img_size)``.

    Attributes set by :meth:`setup`: ``train_ds``, ``val_ds``, ``test_ds``
    (the latter two are ``None`` when their folder does not exist).
    """

    def __init__(self, data_dir, batch_size=32, img_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size

        # Populated by setup(); defined here so the optional-split checks in
        # the dataloader hooks never hit a missing attribute.
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None

        # Training pipeline: resize + random geometric/colour augmentation.
        # NOTE(review): no mean/std normalisation is applied although the
        # backbones are pretrained; adding it would change what existing
        # checkpoints expect -- confirm before changing.
        self.train_tfms = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=30),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.8, 1.2),
                shear=10
            ),
            transforms.ColorJitter(
                brightness=0.3,
                contrast=0.3,
                saturation=0.3,
                hue=0.1
            ),
            transforms.ToTensor()
        ])

        # Evaluation pipeline: deterministic resize only.
        self.val_tfms = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor()
        ])

    def setup(self, stage=None):
        """Create the datasets for every split that exists on disk.

        Raises:
            FileNotFoundError: if ``<data_dir>/train`` does not exist.
        """
        train_dir = os.path.join(self.data_dir, "train")
        val_dir = os.path.join(self.data_dir, "validation")
        test_dir = os.path.join(self.data_dir, "test")

        if not os.path.exists(train_dir):
            raise FileNotFoundError(f"Expected train folder at {train_dir}")

        # Prefer a dedicated held-out folder. Only when there is none do we
        # fall back to the validation images (the previous behaviour), in
        # which case "test" metrics are really validation metrics.
        if not os.path.exists(test_dir):
            test_dir = val_dir

        self.train_ds = datasets.ImageFolder(train_dir, transform=self.train_tfms)
        self.val_ds = (
            datasets.ImageFolder(val_dir, transform=self.val_tfms)
            if os.path.exists(val_dir) else None
        )
        self.test_ds = (
            datasets.ImageFolder(test_dir, transform=self.val_tfms)
            if os.path.exists(test_dir) else None
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2
        )

    def val_dataloader(self):
        """Loader over the validation split, or ``None`` when there is none."""
        # Explicit None check: truthiness of a dataset goes through __len__.
        if self.val_ds is None:
            return None
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=2
        )

    def test_dataloader(self):
        """Loader over the test split, or ``None`` when there is none."""
        if self.test_ds is None:
            return None
        return DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            num_workers=2
        )

In [None]:
import torch.nn as nn
import pytorch_lightning as pl
import torch
import timm
from torchmetrics.classification import Accuracy, F1Score


class PlantDiseaseModel(pl.LightningModule):
    """Classifier that fuses SwinV2-Tiny and ConvNeXt-Tiny features.

    Both timm backbones are created with ``num_classes=0`` so they act as
    feature extractors. Their feature vectors are scaled by two learnable,
    softmax-normalised scalars, concatenated, and fed to an MLP head.

    Args:
        num_classes: Number of output classes.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        freeze_backbones: Freeze both backbones at construction time.
        unfreeze_epoch: First epoch (0-based) at which the last stage of each
            backbone is made trainable again; ``None`` disables unfreezing.

    Plot histories (plain Python lists):
        ``train_losses`` / ``train_accs`` / ``val_losses`` / ``val_accs`` hold
        the per-batch values of the *current* epoch only;
        ``*_loss_epoch`` / ``*_acc_epoch`` hold one mean value per epoch.
    """

    def __init__(self, num_classes=6, lr=1e-4, weight_decay=1e-5,
                 freeze_backbones=True, unfreeze_epoch=10):
        super().__init__()
        self.save_hyperparameters()

        self.swin = timm.create_model("swinv2_tiny_window8_256.ms_in1k", pretrained=True, num_classes=0)
        self.convnext = timm.create_model("convnext_tiny", pretrained=True, num_classes=0)

        # Per-batch buffers, emptied at the end of every epoch.
        self.train_losses = []
        self.train_accs   = []
        self.val_losses   = []
        self.val_accs     = []

        # One entry per completed epoch.
        self.train_loss_epoch = []
        self.train_acc_epoch  = []
        self.val_loss_epoch   = []
        self.val_acc_epoch    = []

        # Freeze backbones (initial)
        if freeze_backbones:
            self.freeze_all_backbones()
        # Guards the one-shot unfreeze in on_train_epoch_start.
        self._backbones_unfrozen = False

        # Learnable attention weights (model importance)
        self.attn_weights = nn.Parameter(torch.ones(2))

        # Fusion head over the concatenated backbone features.
        fusion_dim = self.swin.num_features + self.convnext.num_features

        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, num_classes)
        )

        # Loss & metrics.
        # NOTE(review): the torchmetrics objects below are not used by any
        # step in this class; the steps compute accuracy by hand.
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc   = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_f1    = F1Score(task="multiclass", num_classes=num_classes)

    # Freeze / unfreeze utilities
    def freeze_all_backbones(self):
        """Disable gradients for every parameter of both backbones."""
        for m in [self.swin, self.convnext]:
            for p in m.parameters():
                p.requires_grad = False

    def unfreeze_last_blocks(self):
        """Re-enable gradients for the last stage of each backbone."""
        # SwinV2: last stage
        for p in self.swin.layers[-1].parameters():
            p.requires_grad = True

        # ConvNeXt: last stage
        for p in self.convnext.stages[-1].parameters():
            p.requires_grad = True

    # Lightning hook, called automatically at the start of each train epoch.
    def on_train_epoch_start(self):
        """Unfreeze the last backbone stages once ``unfreeze_epoch`` is reached."""
        unfreeze_epoch = self.hparams.unfreeze_epoch
        # `>=` plus the flag (instead of `==`) also covers runs that start
        # at a later epoch, while still unfreezing only once.
        if (
            unfreeze_epoch is not None
            and not self._backbones_unfrozen
            and self.current_epoch >= unfreeze_epoch
        ):
            print("Unfreezing last backbone blocks")
            self.unfreeze_last_blocks()
            self._backbones_unfrozen = True

    def global_pool(self, x):
        """Average over dims 2 and 3 of a 4-D tensor; return other ranks as-is."""
        return x.mean(dim=[2, 3]) if x.dim() == 4 else x

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        swin_feat = self.global_pool(self.swin(x))
        conv_feat = self.global_pool(self.convnext(x))

        # Two scalars summing to 1, one per backbone.
        w = torch.softmax(self.attn_weights, dim=0)
        swin_feat = swin_feat * w[0]
        conv_feat = conv_feat * w[1]

        fused = torch.cat([swin_feat, conv_feat], dim=1)
        return self.fusion(fused)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch, log it, and buffer it for plotting."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Logging (per step / batch)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_acc",  acc,  prog_bar=True, sync_dist=True)

        # Collect for the per-epoch average
        self.train_losses.append(loss.item())
        self.train_accs.append(acc.item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss/accuracy for one validation batch and buffer them."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc",  acc,  prog_bar=True, sync_dist=True)

        self.val_losses.append(loss.item())
        self.val_accs.append(acc.item())

    def test_step(self, batch, batch_idx):
        """Log loss/accuracy for one test batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc",  acc,  prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return softmax class probabilities for a ``(images, _)`` batch."""
        x, _ = batch
        logits = self(x)
        return logits.softmax(dim=1)

    def on_train_epoch_end(self):
        """Record and log this epoch's mean train loss/accuracy."""
        if len(self.train_losses) > 0:
            avg_loss = sum(self.train_losses) / len(self.train_losses)
            avg_acc  = sum(self.train_accs)  / len(self.train_accs)

            self.train_loss_epoch.append(avg_loss)
            self.train_acc_epoch.append(avg_acc)

            self.log("train_loss_epoch", avg_loss, prog_bar=False)
            self.log("train_acc_epoch",  avg_acc,  prog_bar=False)

        # Reset the buffers: without this every "epoch" mean would be a
        # running mean over all batches seen since training started.
        self.train_losses.clear()
        self.train_accs.clear()

    def on_validation_epoch_end(self):
        """Record and log this epoch's mean validation loss/accuracy."""
        # The pre-training sanity check also runs validation batches; they
        # must not create an extra leading point in the history.
        if self.trainer.sanity_checking:
            self.val_losses.clear()
            self.val_accs.clear()
            return

        if len(self.val_losses) > 0:
            avg_loss = sum(self.val_losses) / len(self.val_losses)
            avg_acc  = sum(self.val_accs)  / len(self.val_accs)

            self.val_loss_epoch.append(avg_loss)
            self.val_acc_epoch.append(avg_acc)

            self.log("val_loss_epoch", avg_loss, prog_bar=False)
            self.log("val_acc_epoch",  avg_acc,  prog_bar=False)

        self.val_losses.clear()
        self.val_accs.clear()

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau stepping on ``val_loss`` once per epoch."""
        # All parameters are handed to the optimizer, including the ones that
        # are frozen for now, so later unfreezing needs no optimizer rebuild.
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.5,
                patience=3
            ),
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1
        }

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler
        }

In [None]:
# Build the data module from the notebook configuration and resolve its
# datasets right away, so a wrong dataset path fails in this cell.
print("Preparing data module and model (this will error if your dataset path is wrong)...")
dm = PlantDataModule(data_dir=EXTRACT_DIR, batch_size=BATCH_SIZE, img_size=IMG_SIZE)

try:
    dm.setup()
except Exception as setup_error:
    raise RuntimeError(f"Data setup failed: {setup_error}")

# The class list comes from the sub-folder names of the training split.
print("Classes:", dm.train_ds.classes)
num_classes = len(dm.train_ds.classes)

model = PlantDiseaseModel(num_classes=num_classes)

In [None]:
# Location of the saved weights on Drive. The previous
# os.path.join(os.getcwd(), "checkpoints", <absolute path>) evaluated to this
# same absolute path, because join() discards every component that precedes an
# absolute one; the unused components are dropped to avoid suggesting that a
# local "checkpoints" folder is involved.
checkpoint_path = "/content/drive/MyDrive/BongoDev/Project-1/rice1_model.pth"

In [None]:
# Warm-start from previously saved weights when the file exists.
if os.path.exists(checkpoint_path):
    # map_location="cpu" lets weights saved from a GPU session load on a
    # CPU-only runtime; load_state_dict copies the values into the model's
    # own parameters, so the model's current device is unaffected.
    saved_state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(saved_state)
    print("Model loaded successfully from checkpoints.")

In [None]:
# Train for a fixed number of epochs with Lightning's default Trainer settings.
trainer = pl.Trainer(max_epochs=25)

# Pass the data module by keyword, matching the trainer.test(...) call style.
trainer.fit(model, datamodule=dm)

In [None]:
# Run the model's test_step over the data module's test dataloader; the
# logged metrics are `test_loss` and `test_acc`.
trainer.test(model, datamodule=dm)

In [None]:

# Persist only the weights (state_dict), not the whole module object.
MODEL_PATH = "/content/drive/MyDrive/BongoDev/Project-1/rice1_model.pth"
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved at: {MODEL_PATH}")

# Show the number of classes as the cell output.
num_classes = len(dm.train_ds.classes)
num_classes

In [None]:
import matplotlib.pyplot as plt


# Loss and accuracy histories stored on the model. The plotted lists hold one
# mean value per epoch, so the legend and x-axis are labelled per epoch
# (they previously claimed "per batch").
# NOTE(review): if the validation history has one more entry than the train
# history (e.g. from a pre-training sanity check), the curves are offset by
# one epoch -- compare the list lengths when in doubt.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(model.train_loss_epoch, marker='o', label='Train loss (epoch mean)', alpha=0.6)
loss_ax.plot(model.val_loss_epoch, marker='o', label='Val loss (epoch mean)')
loss_ax.set_title("Loss Curves")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_ax.grid(True)

acc_ax.plot(model.train_acc_epoch, marker='o', label='Train acc (epoch mean)', alpha=0.6)
acc_ax.plot(model.val_acc_epoch, marker='o', label='Val acc (epoch mean)')
acc_ax.set_title("Accuracy Curves")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()
acc_ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
import torch
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
import matplotlib.pyplot as plt

# --- Confusion matrix on the held-out split ----------------------------------
# Put the model in inference mode on the GPU when available.
model.eval()
model.to('cuda' if torch.cuda.is_available() else 'cpu')

all_preds = []
all_labels = []

# Use the test loader, fall back to validation, and fail with a clear message
# when neither split exists (instead of trying to iterate over None).
test_dl = dm.test_dataloader()
if test_dl is None:
    print("No test set found → using validation set")
    test_dl = dm.val_dataloader()
if test_dl is None:
    raise ValueError("Neither validation nor test dataloader is available. Cannot compute confusion matrix.")

with torch.no_grad():
    for batch in test_dl:
        x, y = batch
        x = x.to(model.device)
        logits = model(x)
        preds = logits.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

# Passing `labels` fixes the matrix to one row/column per known class, so it
# always matches `display_labels` even if a class never occurs in the split.
cm_label_ids = list(range(len(dm.train_ds.classes)))
cm = confusion_matrix(all_labels, all_preds, labels=cm_label_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=dm.train_ds.classes)

# disp.plot() opens its own figure unless it is given an axis; drawing on an
# explicit axis applies the figsize and avoids a stray empty figure.
fig, ax = plt.subplots(figsize=(10, 8))
disp.plot(cmap='Blues', xticks_rotation=90, ax=ax)
ax.set_title("Confusion Matrix (Test set)")
fig.tight_layout()
plt.show()

# Same matrix rendered as a seaborn heatmap.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=dm.train_ds.classes,
            yticklabels=dm.train_ds.classes,
            ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

In [None]:
import numpy as np
import torch
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

# --- Gather logits and labels for the one-vs-rest ROC analysis ---------------
model.eval()
# Run on whatever device the model's parameters currently live on.
model_device = next(model.parameters()).device
model.to(model_device)

all_labels = []
all_logits = []

# Test split first, validation split as a fallback, error if neither exists.
current_dataloader = dm.test_dataloader()
if current_dataloader is not None:
    print("Using test set for ROC curve calculation.")
else:
    print("No test set found, using validation set for ROC curve calculation.")
    current_dataloader = dm.val_dataloader()
    if current_dataloader is None:
        raise ValueError("Neither validation nor test dataloader is available. Cannot compute ROC curve.")

with torch.no_grad():
    for inputs, targets in current_dataloader:
        batch_logits = model(inputs.to(model_device))
        all_logits.append(batch_logits.cpu())
        all_labels.append(targets.cpu())

# Concatenate the per-batch pieces; scores are the softmax of the logits.
logits_tensor = torch.cat(all_logits)
y_score = torch.softmax(logits_tensor, dim=1).numpy()
all_logits = logits_tensor.numpy()
all_labels = torch.cat(all_labels).numpy()

class_names = dm.train_ds.classes
num_classes = len(class_names)

# One indicator column per class index.
y_true = label_binarize(all_labels, classes=list(range(num_classes)))

print(f"Collected {len(all_labels)} samples for ROC analysis.")
print(f"Shape of y_true: {y_true.shape}")
print(f"Shape of y_score: {y_score.shape}")


In [None]:
# One-vs-rest ROC curve per class, drawn on a single explicit axis.
fig, ax = plt.subplots(figsize=(8, 6))

for class_idx in range(num_classes):
    fpr, tpr, _ = roc_curve(y_true[:, class_idx], y_score[:, class_idx])
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f"Class {class_idx} (AUC = {roc_auc:.3f})")

# Diagonal = chance level.
ax.plot([0, 1], [0, 1], "k--")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("Multiclass ROC Curve")
ax.legend(loc="lower right")
ax.grid(True)
plt.show()

In [None]:
!pip install grad-cam

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, precision_recall_fscore_support
import torch

# --- Per-class precision / recall / F1 and confusion matrix ------------------
model.eval()
all_preds = []
all_labels = []
all_probs = []

# Same fallback as the other evaluation cells: test split, else validation,
# else a clear error instead of iterating over None.
test_dl = dm.test_dataloader()
if test_dl is None:
    print("No test set found → using validation set")
    test_dl = dm.val_dataloader()
if test_dl is None:
    raise ValueError("Neither validation nor test dataloader is available. Cannot compute classification metrics.")

with torch.no_grad():
    for batch in test_dl:
        images, targets = batch
        images = images.to(model.device)
        logits = model(images)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        preds = logits.argmax(dim=1).cpu().numpy()

        all_preds.extend(preds)
        all_labels.extend(targets.cpu().numpy())
        all_probs.extend(probs)

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)

class_names = dm.train_ds.classes
# Explicit label ids keep every per-class array aligned with `class_names`
# even when a class is absent from both the labels and the predictions.
label_ids = np.arange(len(class_names))

print("\n" + "="*60)
print("Classification Report (Precision, Recall, F1-Score)")
print("="*60)
report = classification_report(
    all_labels, all_preds,
    labels=label_ids, target_names=class_names, digits=4, zero_division=0
)
print(report)

precision, recall, f1, support = precision_recall_fscore_support(
    all_labels, all_preds, labels=label_ids, average=None, zero_division=0
)

# Grouped bar chart: three bars (precision, recall, F1) per class.
fig, ax = plt.subplots(figsize=(14, 6))
bar_positions = np.arange(len(class_names))
width = 0.25

ax.bar(bar_positions - width, precision, width, label='Precision', color='skyblue')
ax.bar(bar_positions, recall, width, label='Recall', color='lightgreen')
ax.bar(bar_positions + width, f1, width, label='F1-Score', color='salmon')

ax.set_xlabel('Classes')
ax.set_ylabel('Score')
ax.set_title('Per-Class Precision, Recall & F1-Score')
ax.set_xticks(bar_positions)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.legend()
ax.grid(axis='y', alpha=0.3)
fig.tight_layout()
plt.show()

cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

# Draw on an explicit axis: disp.plot() would otherwise open a second figure,
# leaving this one empty and ignoring its figsize.
fig, ax = plt.subplots(figsize=(12, 10))
disp.plot(cmap='Blues', xticks_rotation=45, ax=ax)
ax.set_title("Confusion Matrix (Test Set)")
fig.tight_layout()
plt.show()

precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    all_labels, all_preds, labels=label_ids, average='macro', zero_division=0
)
precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
    all_labels, all_preds, labels=label_ids, average='weighted', zero_division=0
)

print("\n" + "="*60)
print("Summary Averages")
print("="*60)
print(f"Macro Average    → Precision: {precision_macro:.4f} | Recall: {recall_macro:.4f} | F1: {f1_macro:.4f}")
print(f"Weighted Average → Precision: {precision_weighted:.4f} | Recall: {recall_weighted:.4f} | F1: {f1_weighted:.4f}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from pytorch_grad_cam import GradCAM, ScoreCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
def swin_reshape_transform(tensor):
    """Bring a transformer activation into channels-first ``(B, C, H, W)`` form.

    Used as the ``reshape_transform`` of the CAM object so that activations
    can be treated as 2-D feature maps.

    Accepted inputs:
        * a list/tuple of tensors -- only the first element is used;
        * a 4-D tensor. When dims 1 and 2 are equal but dim 3 differs, e.g.
          ``(B, 8, 8, 768)``, it is taken to be channels-last ``(B, H, W, C)``
          and permuted to ``(B, C, H, W)``. Every other 4-D tensor is returned
          unchanged (treated as already channels-first). This is a shape
          heuristic for square feature maps -- NOTE(review): confirm the
          layout that the hooked layer actually emits;
        * a 3-D token tensor ``(B, N, C)`` where ``N`` is a perfect square
          (no class token) or ``N - 1`` is (first token dropped as the class
          token).

    Raises:
        ValueError: for 3-D input whose token count fits neither case, or for
            any rank other than 3 or 4.
    """
    import math  # local import keeps this cell independent of the import cell

    if isinstance(tensor, (list, tuple)):
        tensor = tensor[0]

    if tensor.ndim == 4:
        _, dim1, dim2, dim3 = tensor.shape
        if dim1 == dim2 and dim2 != dim3:
            # Channels-last: move the channel axis in front of the spatial axes.
            return tensor.permute(0, 3, 1, 2)
        return tensor

    if tensor.ndim == 3:
        B, N, C = tensor.shape

        # math.isqrt is exact; int(N ** 0.5) relies on float rounding.
        side = math.isqrt(N)
        if side * side == N:
            # No class token: (B, N, C) -> (B, C, H, W)
            return tensor.permute(0, 2, 1).reshape(B, C, side, side)

        if N > 1:
            side = math.isqrt(N - 1)
            if side * side == N - 1:
                # Drop the leading class token, then reshape as above.
                return tensor[:, 1:, :].permute(0, 2, 1).reshape(B, C, side, side)

        raise ValueError(f"Cannot infer 2D spatial dimensions from 3D tensor with N={N}. Shape: {tensor.shape}")

    raise ValueError(f"Unexpected tensor dimensions for reshape_transform: {tensor.ndim}. Shape: {tensor.shape}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The visualisation cell sends its input tensors to `device`; move the model
# there explicitly so this cell does not depend on an earlier cell having
# left the model on the right device.
model.to(device)
model.eval()

from pytorch_grad_cam import GradCAMPlusPlus

# --- Swin branch: Score-CAM ---
# Hooked on the final norm layer of the Swin backbone; activations are
# converted to (B, C, H, W) by swin_reshape_transform.
swin_cam = ScoreCAM(
    model=model,
    target_layers=[model.swin.norm],
    reshape_transform=swin_reshape_transform
)

# --- ConvNeXt branch: Grad-CAM++ ---
# Hooked on the last block of the last ConvNeXt stage.
# (The `eff_cam` name is kept because the visualisation cell refers to it.)
eff_cam = GradCAMPlusPlus(
    model=model,
    target_layers=[model.convnext.stages[-1].blocks[-1]]
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam import GradCAMPlusPlus # Import GradCAMPlusPlus

# --- Sample images for the CAM figures: (file name, human-readable label) ---
_CAM_IMAGE_DIR = "/content/drive/MyDrive/BongoDev/Project-1/image"
_CAM_SAMPLES = [
    ("bacterial_val.JPG", "Bacterial Leaf Blight"),
    ("brown_val.jpg", "Brown Spot"),
    ("healthy_val.jpg", "Healthy"),
    ("leaf_blast_val.jpg", "Leaf Blast"),
    ("leaf_scald_val.jpg", "Leaf Scald"),
    ("narrow_brown_val.jpg", "Narrow Brown Spot"),
]

# Parallel lists consumed by the visualisation loop (same order, same length).
image_paths = [f"{_CAM_IMAGE_DIR}/{file_name}" for file_name, _ in _CAM_SAMPLES]
image_labels = [label for _, label in _CAM_SAMPLES]

def norm(cam):
    """Min-max rescale ``cam`` to roughly [0, 1].

    A small epsilon (1e-8) is added to the range so a constant map yields
    zeros instead of a division by zero.
    """
    lowest = cam.min()
    value_range = cam.max() - lowest
    return (cam - lowest) / (value_range + 1e-8)

def add_cam_details(ax, vis_cam, cam_norm, title_text, pred_class_name, confidence):
    """Show one CAM overlay on ``ax`` with the prediction in its title.

    ``cam_norm`` is accepted for call compatibility but is not used.
    """
    pretty_name = pred_class_name.replace('_', ' ').title()
    ax.imshow(vis_cam)
    ax.set_title(f"{title_text}\nPred: {pretty_name}\nConf: {confidence:.2%}",
                 fontsize=12, fontweight='bold')
    ax.axis("off")


# --- One three-panel figure per sample image: original | Swin CAM | ConvNeXt CAM ---
for idx, img_path in enumerate(image_paths):
    try:
        img_pil = Image.open(img_path).convert("RGB")
        input_tensor = dm.val_tfms(img_pil).unsqueeze(0).to(device)

        # Predicted class and its softmax confidence.
        with torch.no_grad():
            logits = model(input_tensor)
            pred_idx = logits.argmax(dim=1).item()
            confidence = torch.softmax(logits, dim=1)[0, pred_idx].item()
        pred_class_name = dm.train_ds.classes[pred_idx]

        # Both CAMs explain the predicted class.
        targets = [ClassifierOutputTarget(pred_idx)]
        cam_swin_norm = norm(swin_cam(input_tensor, targets)[0])
        cam_eff_norm = norm(eff_cam(input_tensor, targets)[0])

        # Background image in [0, 1] at the model's input resolution.
        img_np = np.array(img_pil.resize((dm.img_size, dm.img_size))) / 255.0

        vis_swin = show_cam_on_image(img_np, cam_swin_norm, use_rgb=True)
        vis_eff = show_cam_on_image(img_np, cam_eff_norm, use_rgb=True)

        fig, axes = plt.subplots(1, 3, figsize=(18, 5), dpi=140)
        original_ax, swin_ax, convnext_ax = axes

        original_ax.imshow(img_np)
        original_ax.set_title("Original Image", fontsize=12, fontweight='bold')
        original_ax.axis("off")

        add_cam_details(swin_ax, vis_swin, cam_swin_norm, "Swin Transformer\n(Score-CAM)", pred_class_name, confidence)
        add_cam_details(convnext_ax, vis_eff, cam_eff_norm, "ConvNext\n(Grad-CAM++)", pred_class_name, confidence)

        fig.suptitle(
            f"Hybrid CAM Analysis for Rice Leaf Disease\nImage Shows: {image_labels[idx]} (Predicted: {pred_class_name.replace('_', ' ').title()})",
            fontsize=16,
            fontweight="bold",
            y=1.02  # lift the main title above the panel titles
        )

        # Leave room at the top so the suptitle does not overlap the panels.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
    except FileNotFoundError:
        print(f"Error: Image not found at path: {img_path}. Please check the path and try again.")
    except Exception as e:
        # Best effort: report the failure and continue with the next image.
        print(f"An unexpected error occurred while processing {img_path}: {e}")
        import traceback
        traceback.print_exc()