**CBAM-ResNet fine-tuning**

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import RandAugment
from tqdm import tqdm
import numpy as np
import random
import os
from sklearn.metrics import classification_report, confusion_matrix

from cbam_resnet import cbam_resnet50  # Ensure this file is in your working directory

# ---------- SEED & DEVICE ----------
def seed_all(seed=42):
    """Seed every random number generator used by this script.

    Seeds PyTorch (CPU), NumPy, Python's ``random`` and all CUDA devices with
    the same value, then switches cuDNN to deterministic mode with its
    auto-tuner (benchmark) turned off.

    Args:
        seed: integer seed applied to every generator.
    """
    # Same order as before: torch CPU, numpy, stdlib random, torch CUDA.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed once, with the default seed (42), before any dataset, sampler or model is built.
seed_all()
# Module-level device used by the model, mixup and the train/eval loops below:
# CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- PATHS & CLASSES ----------
# Dataset root; it is expected to contain "train", "val" and "test" sub-folders.
data_root = "/kaggle/input/minida/mini_output1"
# Checkpoint loaded by FineTuneCBAM.
pretrained_path = "/kaggle/working/simsiam_cbam_pretrained_final.pth"
train_dir = os.path.join(data_root, "train")
val_dir = os.path.join(data_root, "val")
test_dir = os.path.join(data_root, "test")
# Display names passed to classification_report as target_names.
# NOTE(review): ImageFolder assigns label indices from the sorted folder names;
# this list must be in that same order -- confirm against the dataset folders.
class_names = ['Alternaria', 'Healthy Leaf', 'straw_mite']
num_classes = len(class_names)

# ---------- DATALOADERS ----------
def get_loaders(batch_size=32):
    """Build the train, validation and test DataLoaders.

    Training images get random-resized-crop, horizontal flip and RandAugment;
    validation and test images get a resize to 256 followed by a 224 centre
    crop. Both pipelines end with the same mean/std normalisation. The training
    loader draws samples with replacement, weighted by inverse class frequency,
    so each class is drawn with equal probability in expectation.

    Args:
        batch_size: batch size shared by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, test_ds)``.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    augmenting_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        normalize,
    ])
    plain_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = ImageFolder(train_dir, augmenting_tf)
    val_ds = ImageFolder(val_dir, plain_tf)
    test_ds = ImageFolder(test_dir, plain_tf)

    # One weight per training sample: the reciprocal of its class's sample count.
    per_class = np.bincount(train_ds.targets)
    sample_weights = 1. / per_class[train_ds.targets]
    balanced_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

    shared = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_ds, sampler=balanced_sampler, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)
    test_loader = DataLoader(test_ds, shuffle=False, **shared)
    return train_loader, val_loader, test_loader, test_ds

# ---------- MODEL ----------
class FineTuneCBAM(nn.Module):
    """CBAM-ResNet50 feature extractor followed by a small MLP classifier.

    The backbone consists of every child module of ``cbam_resnet50`` except the
    last one (presumably the final fully connected layer -- confirm against
    cbam_resnet.py), wrapped in ``nn.Sequential``. Pretrained weights are read
    from ``pretrained_path``; the checkpoint may be a plain state dict or a dict
    holding the state dict under the ``'backbone'`` key.

    Because the checkpoint is loaded with ``strict=False``, key mismatches do
    not raise. The load result is therefore inspected and a warning is issued
    when nothing (or only part) of the checkpoint matched, so that training
    from random weights cannot go unnoticed.

    Args:
        pretrained_path: path of the checkpoint passed to ``torch.load``.
        num_classes: number of output logits of the classifier head.
    """
    def __init__(self, pretrained_path, num_classes=3):
        import warnings  # local import: only needed for the checkpoint sanity check

        super().__init__()
        backbone = cbam_resnet50(num_classes=1000)
        # Wrapping in nn.Sequential renames parameters to "0.weight", "4.0.conv1.weight", ...
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        ckpt = torch.load(pretrained_path, map_location=device)
        state = ckpt['backbone'] if 'backbone' in ckpt else ckpt
        result = self.backbone.load_state_dict(state, strict=False)

        # strict=False swallows key mismatches, so verify that weights were actually loaded.
        n_expected = len(self.backbone.state_dict())
        n_matched = n_expected - len(result.missing_keys)
        if n_matched == 0:
            warnings.warn(
                f"No backbone weights were loaded from {pretrained_path!r}: none of the "
                f"{n_expected} expected keys are present in the checkpoint. "
                "The backbone is still randomly initialised."
            )
        elif result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Partial checkpoint load from {pretrained_path!r}: "
                f"{n_matched}/{n_expected} backbone tensors matched, "
                f"{len(result.missing_keys)} missing, "
                f"{len(result.unexpected_keys)} unexpected."
            )

        # The head expects 2048 features per image from the flattened backbone output.
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.backbone(x).flatten(1)
        return self.classifier(x)

# ---------- LR SCHEDULER ----------
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
def get_scheduler(optimizer, total_epochs, warmup_epochs=5):
    """Linear warmup followed by cosine annealing, stepped once per epoch.

    For the first ``warmup_epochs`` steps the learning rate ramps linearly from
    20% of its base value up to the base value; afterwards it follows a cosine
    schedule over the remaining epochs.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        total_epochs: total number of epochs (warmup included).
        warmup_epochs: number of warmup epochs.

    Returns:
        A ``SequentialLR`` chaining the two schedulers.
    """
    # Clamp to >= 1: with warmup_epochs >= total_epochs the cosine stage would
    # otherwise get T_max <= 0 and divide by zero when SequentialLR switches to it.
    cosine_epochs = max(1, total_epochs - warmup_epochs)
    warmup = LinearLR(optimizer, start_factor=0.2, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=cosine_epochs)
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

# ---------- MIXUP ----------
def mixup_data(x, y, alpha=0.3):
    """Mix each sample of a batch with a randomly chosen partner from the same batch.

    Args:
        x: input batch; the first dimension is the batch dimension.
        y: labels, indexable along the same batch dimension.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing
            (``lam`` is 1, so ``mixed_x`` equals ``x``).

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y``,
        ``y_b`` is ``y[perm]`` and ``lam`` is a Python float.
    """
    # Plain float rather than np.float64, so tensor arithmetic below uses a Python scalar.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    batch_size = x.size(0)
    # Follow the batch's own device instead of a module-level global; the
    # permutation is still drawn from the CPU generator before being moved.
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss: ``lam`` times the loss against ``y_a`` plus
    ``1 - lam`` times the loss against ``y_b``, both computed on ``pred``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# ---------- TRAINING & EVAL ----------
def train_epoch(model, loader, criterion, optimizer, use_mixup=True, mixup_alpha=0.3):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    With ``use_mixup`` the batch is mixed by ``mixup_data`` and the loss is the
    ``lam``-weighted combination from ``mixup_criterion``. Accuracy is then a
    soft count: predictions matching the first label set count ``lam``, those
    matching the permuted label set count ``1 - lam``.

    Both averages are taken over the samples actually processed this epoch, not
    over ``len(loader.dataset)``; the two differ whenever the sampler draws a
    different number of samples or the loader drops the last batch.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        use_mixup: whether to apply mixup to every batch.
        mixup_alpha: Beta concentration forwarded to ``mixup_data``.

    Returns:
        ``(mean_loss, accuracy)``; both are 0 if the loader yields no samples.
    """
    model.train()
    total_loss, correct, seen = 0, 0, 0
    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        n = imgs.size(0)
        optimizer.zero_grad()
        if use_mixup:
            imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha)
            outputs = model(imgs)
            loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
            preds = outputs.argmax(1)
            correct += (lam * preds.eq(y_a).sum().item() + (1 - lam) * preds.eq(y_b).sum().item())
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            correct += outputs.argmax(1).eq(labels).sum().item()
        loss.backward()
        optimizer.step()
        # criterion returns a batch mean; re-weight by batch size for a per-sample average.
        total_loss += loss.item() * n
        seen += n
    denom = max(seen, 1)  # avoid dividing by zero on an empty loader
    return total_loss / denom, correct / denom

def eval_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: DataLoader yielding ``(images, labels)``; its ``dataset`` length
            is used as the averaging denominator.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy, labels, probs)`` where ``labels`` and ``probs``
        are NumPy arrays of the true labels and the softmax outputs collected
        over the whole loader.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    seen_labels = []
    seen_probs = []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Eval", leave=False):
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            batch_probs = torch.softmax(logits, dim=1)
            # criterion returns a batch mean; weight by batch size before summing.
            loss_sum += criterion(logits, labels).item() * imgs.size(0)
            n_correct += (logits.argmax(1) == labels).sum().item()
            seen_labels.extend(labels.cpu().numpy())
            seen_probs.extend(batch_probs.cpu().numpy())
    n_total = len(loader.dataset)
    accuracy = n_correct / n_total
    return loss_sum / n_total, accuracy, np.array(seen_labels), np.array(seen_probs)

# ---------- EARLY STOPPING ----------
class EarlyStopping:
    """Track the best validation accuracy and signal when training should stop.

    Call the instance once per epoch with the validation accuracy and the
    model. On a strict improvement, a CPU copy of the model's state dict is
    stored in ``best_state`` and the patience counter resets; otherwise the
    counter increases. The call returns ``True`` once the counter reaches
    ``patience``.

    Args:
        patience: number of consecutive non-improving epochs tolerated.
        verbose: print a message on every call.
    """
    def __init__(self, patience=8, verbose=True):
        self.patience = patience
        self.counter = 0        # consecutive epochs without improvement
        self.best_acc = None    # best validation accuracy seen so far
        self.best_state = None  # CPU snapshot of the state dict at best_acc
        self.verbose = verbose

    def __call__(self, val_acc, model):
        """Record ``val_acc``; return ``True`` when training should stop."""
        if (self.best_acc is None) or (val_acc > self.best_acc):
            self.best_acc = val_acc
            # copy=True forces a real copy: a bare .cpu() returns the same tensor
            # when the model already lives on the CPU, so the "best" snapshot
            # would silently follow every later optimizer step.
            self.best_state = {
                k: v.detach().to("cpu", copy=True)
                for k, v in model.state_dict().items()
            }
            self.counter = 0
            if self.verbose:
                print("Validation accuracy improved, saving best state.")
            return False

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} / {self.patience}")
        return self.counter >= self.patience

# ---------- MAIN ----------
def main(epochs=40, batch_size=32, patience=8, mixup_alpha=0.3, warmup_epochs=5):
    """Fine-tune the pretrained CBAM model, early-stop on validation accuracy, then test.

    The backbone is frozen for the first ``warmup_epochs`` epochs (only the
    classifier head trains) and unfrozen afterwards. The same value sets the
    length of the learning-rate warmup, so the two phases stay aligned. After
    training, the best-validation weights are restored, saved to
    ``best_finetuned_cbam.pth`` and evaluated on the test set.

    Args:
        epochs: maximum number of training epochs.
        batch_size: batch size for all loaders.
        patience: early-stopping patience, in epochs, on validation accuracy.
        mixup_alpha: Beta concentration used for mixup during training.
        warmup_epochs: epochs with a frozen backbone and linear LR warmup.
    """
    train_loader, val_loader, test_loader, _ = get_loaders(batch_size)
    model = FineTuneCBAM(pretrained_path, num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    # The optimizer is built over all parameters up front, so unfreezing later
    # only needs to flip requires_grad.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = get_scheduler(optimizer, epochs, warmup_epochs=warmup_epochs)
    early_stopper = EarlyStopping(patience=patience)

    # Freeze backbone for warmup
    for param in model.backbone.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True

    # Per-epoch history, kept for plotting.
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(epochs):
        # Unfreeze exactly when the LR warmup ends (previously a hard-coded 5).
        if epoch == warmup_epochs:
            for param in model.backbone.parameters():
                param.requires_grad = True
        t_loss, t_acc = train_epoch(model, train_loader, criterion, optimizer, use_mixup=True, mixup_alpha=mixup_alpha)
        v_loss, v_acc, _, _ = eval_epoch(model, val_loader, criterion)
        scheduler.step()
        train_losses.append(t_loss)
        train_accs.append(t_acc)
        val_losses.append(v_loss)
        val_accs.append(v_acc)
        print(f"Epoch {epoch+1}: Train Loss {t_loss:.4f}, Acc {t_acc:.4f} | Val Loss {v_loss:.4f}, Acc {v_acc:.4f}")
        if early_stopper(v_acc, model):
            print(f"Early stopping at epoch {epoch+1}")
            break

    # --- Always restore best state and eval before testing ---
    # best_state is None when no epoch ran (epochs <= 0); keep the current weights then.
    if early_stopper.best_state is not None:
        model.load_state_dict(early_stopper.best_state)
    model.eval()
    # --- Save model at its best state ---
    torch.save(model.state_dict(), "best_finetuned_cbam.pth")
    print("Model saved as best_finetuned_cbam.pth")

    # --- Test evaluation ---
    print("\nTest set results (CBAM):")
    test_loss, test_acc, test_labels, test_probs = eval_epoch(model, test_loader, criterion)
    test_preds = np.argmax(test_probs, axis=1)
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    print(classification_report(test_labels, test_preds, target_names=class_names))
    # (Add your plotting here as needed)

# Script entry point: run the full fine-tuning pipeline. The arguments passed
# here are the same values as main()'s own defaults.
if __name__ == '__main__':
    main(epochs=40, batch_size=32, patience=8, mixup_alpha=0.3)


                                                      

Epoch 1: Train Loss 1.0797, Acc 0.4428 | Val Loss 1.1021, Acc 0.3131
Validation accuracy improved, saving best state.


                                                      

Epoch 2: Train Loss 1.0514, Acc 0.4315 | Val Loss 1.1091, Acc 0.3131
EarlyStopping counter: 1 / 8


                                                      

Epoch 3: Train Loss 1.0452, Acc 0.4233 | Val Loss 1.0132, Acc 0.4747
Validation accuracy improved, saving best state.


                                                      

Epoch 4: Train Loss 1.0478, Acc 0.4563 | Val Loss 0.9460, Acc 0.6869
Validation accuracy improved, saving best state.




Epoch 5: Train Loss 1.0646, Acc 0.4371 | Val Loss 0.9650, Acc 0.6162
EarlyStopping counter: 1 / 8


                                                      

Epoch 6: Train Loss 1.0260, Acc 0.4961 | Val Loss 0.8190, Acc 0.6869
EarlyStopping counter: 2 / 8


                                                      

Epoch 7: Train Loss 0.9878, Acc 0.5293 | Val Loss 0.7635, Acc 0.7273
Validation accuracy improved, saving best state.


                                                      

Epoch 8: Train Loss 0.9395, Acc 0.5544 | Val Loss 0.5978, Acc 0.8384
Validation accuracy improved, saving best state.


                                                      

Epoch 9: Train Loss 0.9318, Acc 0.5686 | Val Loss 0.5133, Acc 0.8788
Validation accuracy improved, saving best state.


                                                      

Epoch 10: Train Loss 0.8408, Acc 0.6399 | Val Loss 0.4725, Acc 0.8384
EarlyStopping counter: 1 / 8


                                                      

Epoch 11: Train Loss 0.8235, Acc 0.6594 | Val Loss 0.5732, Acc 0.7778
EarlyStopping counter: 2 / 8


                                                      

Epoch 12: Train Loss 0.8280, Acc 0.6499 | Val Loss 0.4551, Acc 0.8485
EarlyStopping counter: 3 / 8


                                                      

Epoch 13: Train Loss 0.8032, Acc 0.6788 | Val Loss 0.4415, Acc 0.8889
Validation accuracy improved, saving best state.


                                                      

Epoch 14: Train Loss 0.7189, Acc 0.7159 | Val Loss 0.4325, Acc 0.8889
EarlyStopping counter: 1 / 8


                                                      

Epoch 15: Train Loss 0.8371, Acc 0.6578 | Val Loss 0.4500, Acc 0.8990
Validation accuracy improved, saving best state.


                                                      

Epoch 16: Train Loss 0.7392, Acc 0.7146 | Val Loss 0.4594, Acc 0.8182
EarlyStopping counter: 1 / 8


                                                      

Epoch 17: Train Loss 0.8212, Acc 0.6590 | Val Loss 0.4256, Acc 0.8990
EarlyStopping counter: 2 / 8


                                                      

Epoch 18: Train Loss 0.7327, Acc 0.7387 | Val Loss 0.4640, Acc 0.8283
EarlyStopping counter: 3 / 8


                                                      

Epoch 19: Train Loss 0.7776, Acc 0.6827 | Val Loss 0.3873, Acc 0.9091
Validation accuracy improved, saving best state.


                                                      

Epoch 20: Train Loss 0.7947, Acc 0.7034 | Val Loss 0.3819, Acc 0.9293
Validation accuracy improved, saving best state.


                                                      

Epoch 21: Train Loss 0.6738, Acc 0.7706 | Val Loss 0.4699, Acc 0.8485
EarlyStopping counter: 1 / 8


                                                      

Epoch 22: Train Loss 0.7487, Acc 0.7139 | Val Loss 0.3828, Acc 0.8889
EarlyStopping counter: 2 / 8


                                                      

Epoch 23: Train Loss 0.7398, Acc 0.7155 | Val Loss 0.4218, Acc 0.9192
EarlyStopping counter: 3 / 8


                                                      

Epoch 24: Train Loss 0.7984, Acc 0.6800 | Val Loss 0.3936, Acc 0.9091
EarlyStopping counter: 4 / 8


                                                      

Epoch 25: Train Loss 0.7769, Acc 0.6762 | Val Loss 0.4109, Acc 0.9091
EarlyStopping counter: 5 / 8


                                                      

Epoch 26: Train Loss 0.7512, Acc 0.7150 | Val Loss 0.4405, Acc 0.8788
EarlyStopping counter: 6 / 8


                                                      

Epoch 27: Train Loss 0.7297, Acc 0.7033 | Val Loss 0.3844, Acc 0.8788
EarlyStopping counter: 7 / 8


                                                      

Epoch 28: Train Loss 0.6994, Acc 0.7555 | Val Loss 0.3401, Acc 0.9495
Validation accuracy improved, saving best state.


                                                      

Epoch 29: Train Loss 0.7122, Acc 0.7373 | Val Loss 0.3407, Acc 0.9394
EarlyStopping counter: 1 / 8


                                                      

Epoch 30: Train Loss 0.7144, Acc 0.7273 | Val Loss 0.3748, Acc 0.9091
EarlyStopping counter: 2 / 8


                                                      

Epoch 31: Train Loss 0.6796, Acc 0.7512 | Val Loss 0.3562, Acc 0.9192
EarlyStopping counter: 3 / 8


                                                      

Epoch 32: Train Loss 0.7132, Acc 0.7086 | Val Loss 0.3366, Acc 0.9293
EarlyStopping counter: 4 / 8


                                                      

Epoch 33: Train Loss 0.6956, Acc 0.7399 | Val Loss 0.3485, Acc 0.9192
EarlyStopping counter: 5 / 8


                                                      

Epoch 34: Train Loss 0.6700, Acc 0.7500 | Val Loss 0.3634, Acc 0.9091
EarlyStopping counter: 6 / 8


                                                      

Epoch 35: Train Loss 0.7068, Acc 0.7411 | Val Loss 0.3515, Acc 0.9192
EarlyStopping counter: 7 / 8


                                                      

Epoch 36: Train Loss 0.6281, Acc 0.7872 | Val Loss 0.3346, Acc 0.9293
EarlyStopping counter: 8 / 8
Early stopping at epoch 36
Model saved as best_finetuned_cbam.pth

Test set results (CBAM):


                                                   

Test Loss: 0.2851, Test Acc: 0.9697
              precision    recall  f1-score   support

  Alternaria       1.00      0.92      0.96        37
Healthy Leaf       0.91      1.00      0.95        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.97        99
   macro avg       0.97      0.97      0.97        99
weighted avg       0.97      0.97      0.97        99



