**hybrid SupCon+CE**

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import RandAugment
from tqdm import tqdm
import numpy as np
import random
import os
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, ConfusionMatrixDisplay, roc_curve, auc

from cbam_resnet import cbam_resnet50

# ---------- SEED & DEVICE ----------
def seed_all(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_all()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- PATHS & CLASSES ----------
data_root = "/kaggle/input/minida/mini_output1"
pretrained_path = "/kaggle/working/simsiam_cbam_pretrained_final.pth"
train_dir, val_dir, test_dir = [os.path.join(data_root, x) for x in ["train", "val", "test"]]
class_names = ['Alternaria', 'Healthy Leaf', 'straw_mite']
num_classes = len(class_names)

# ---------- MIXUP ----------
def mixup_data(x, y, alpha=0.3):
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# ---------- SUPERVISED CONTRASTIVE LOSS ----------
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.eps = 1e-8
    def forward(self, features, labels):
        device = features.device
        batch_size = features.size(0)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        anchor_dot_contrast = torch.div(torch.matmul(features, features.T), self.temperature)
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(batch_size, device=device))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + self.eps)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + self.eps)
        loss = -mean_log_prob_pos.mean()
        return loss

# ---------- DATALOADERS ----------
def get_loaders(batch_size=32):
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_ds = ImageFolder(train_dir, train_transform)
    val_ds = ImageFolder(val_dir, val_transform)
    test_ds = ImageFolder(test_dir, val_transform)

    class_counts = np.bincount(train_ds.targets)
    weights = 1. / class_counts[train_ds.targets]
    sampler = WeightedRandomSampler(weights, len(train_ds), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader, test_loader, test_ds

# ---------- MODEL ----------
class FineTuneCBAM(nn.Module):
    def __init__(self, pretrained_path, num_classes=3):
        super().__init__()
        backbone = cbam_resnet50(num_classes=1000)
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        ckpt = torch.load(pretrained_path, map_location=device)
        if 'backbone' in ckpt:
            self.backbone.load_state_dict(ckpt['backbone'], strict=False)
        else:
            self.backbone.load_state_dict(ckpt, strict=False)
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        self.feature_layer = nn.Linear(2048, 128)  # For SupCon

    def forward(self, x, return_features=False):
        feats = self.backbone(x).flatten(1)
        features = F.normalize(self.feature_layer(feats), dim=1)
        logits = self.classifier(feats)
        if return_features:
            return logits, features
        return logits

# ---------- LR SCHEDULER ----------
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
def get_scheduler(optimizer, total_epochs, warmup_epochs=5):
    warmup = LinearLR(optimizer, start_factor=0.2, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=total_epochs-warmup_epochs)
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

# ---------- TRAINING & EVAL ----------
def train_epoch(model, loader, ce_loss_fn, supcon_loss_fn, optimizer,
                use_mixup=True, mixup_alpha=0.3, supcon_weight=0.5):
    model.train()
    total_loss, total_ce, total_supcon, correct = 0, 0, 0, 0
    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        if use_mixup:
            imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha)
            logits, features = model(imgs, return_features=True)
            loss_ce = mixup_criterion(ce_loss_fn, logits, y_a, y_b, lam)
            loss_supcon = supcon_loss_fn(features, labels)
            loss = (1 - supcon_weight) * loss_ce + supcon_weight * loss_supcon
            preds = logits.argmax(1)
            correct += (lam * preds.eq(y_a).sum().item() + (1 - lam) * preds.eq(y_b).sum().item())
        else:
            logits, features = model(imgs, return_features=True)
            loss_ce = ce_loss_fn(logits, labels)
            loss_supcon = supcon_loss_fn(features, labels)
            loss = (1 - supcon_weight) * loss_ce + supcon_weight * loss_supcon
            correct += logits.argmax(1).eq(labels).sum().item()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        total_ce += loss_ce.item() * imgs.size(0)
        total_supcon += loss_supcon.item() * imgs.size(0)
    n = len(loader.dataset)
    return total_loss / n, correct / n, total_ce / n, total_supcon / n

def eval_epoch(model, loader, criterion):
    model.eval()
    total_loss, correct, all_labels, all_probs = 0, 0, [], []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Eval", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            probs = torch.softmax(outputs, dim=1)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * imgs.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    acc = correct / len(loader.dataset)
    return total_loss / len(loader.dataset), acc, np.array(all_labels), np.array(all_probs)

# ---------- EARLY STOPPING ----------
class EarlyStopping:
    def __init__(self, patience=8, verbose=True):
        self.patience = patience
        self.counter = 0
        self.best_acc = None
        self.best_state = None
        self.verbose = verbose
    def __call__(self, val_acc, model):
        if (self.best_acc is None) or (val_acc > self.best_acc):
            self.best_acc = val_acc
            self.best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            self.counter = 0
            if self.verbose: print("Validation accuracy improved, saving best state.")
            return False
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                return True
            return False

# ---------- PLOTTING ----------
def plot_confusion_matrix(y_true, y_pred, class_names, save_path=None):
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    fig, ax = plt.subplots(figsize=(6,6))
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    plt.title('Normalized Confusion Matrix (CBAM)')
    if save_path: plt.savefig(save_path)
    plt.close()

def plot_roc_per_class(y_true, y_score, n_classes, class_names, save_path=None):
    plt.figure(figsize=(8,6))
    for i in range(n_classes):
        if np.sum(y_true==i) == 0: continue
        try:
            fpr, tpr, _ = roc_curve((y_true==i).astype(int), y_score[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
        except Exception as e:
            print(f"ROC error for class {class_names[i]}: {e}")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Per-class ROC Curves (CBAM)')
    plt.legend()
    if save_path: plt.savefig(save_path)
    plt.close()

def plot_reliability(y_true, y_prob, n_bins=10):
    from sklearn.calibration import calibration_curve
    plt.figure(figsize=(5,5))
    for i, name in enumerate(class_names):
        try:
            prob_true, prob_pred = calibration_curve((y_true==i).astype(int), y_prob[:,i], n_bins=n_bins, strategy='uniform')
            plt.plot(prob_pred, prob_true, marker='o', label=f"{name}")
        except Exception as e:
            print(f"Reliability curve failed for {name}: {e}")
    plt.plot([0,1],[0,1],'--', color='gray')
    plt.xlabel("Mean Predicted Probability")
    plt.ylabel("Fraction of Positives")
    plt.title("Reliability Diagram (CBAM)")
    plt.legend()
    plt.tight_layout()
    plt.savefig("reliability_diagram_cbam.png")
    plt.close()

def plot_loss_acc_curves(train_losses, val_losses, train_accs, val_accs, train_ce_losses, train_supcon_losses):
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    plt.plot(train_losses, label='Train Total')
    plt.plot(val_losses, label='Val Loss')
    plt.plot(train_ce_losses, label='Train CE')
    plt.plot(train_supcon_losses, label='Train SupCon')
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Loss Curves (CBAM Hybrid)")
    plt.legend()
    plt.subplot(1,2,2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.title("Accuracy Curves (CBAM Hybrid)")
    plt.legend()
    plt.tight_layout()
    plt.savefig("loss_acc_curves_cbam_hybrid.png")
    plt.close()

# ---------- TTA ----------
def tta_eval(model, test_ds, batch_size, class_names):
    tta_transforms = [
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomHorizontalFlip(1.0),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomVerticalFlip(1.0),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomRotation(15),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(280), transforms.CenterCrop(224), transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.GaussianBlur(3),
                            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    ]
    tta_probs = []
    model.eval()
    for t in tta_transforms:
        test_ds.transform = t
        loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
        preds = []
        with torch.no_grad():
            for imgs, _ in loader:
                imgs = imgs.to(device)
                outputs = model(imgs)
                preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        tta_probs.append(np.concatenate(preds))
    tta_probs = np.array(tta_probs)
    mean_probs = np.mean(tta_probs, axis=0)
    final_preds = np.argmax(mean_probs, axis=1)
    return mean_probs, final_preds

# ---------- MAIN ----------
def main(epochs=40, batch_size=32, patience=8, mixup_alpha=0.3, supcon_weight=0.5):
    train_loader, val_loader, test_loader, test_ds = get_loaders(batch_size)
    model = FineTuneCBAM(pretrained_path, num_classes=num_classes).to(device)
    ce_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
    supcon_loss_fn = SupConLoss(temperature=0.07)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = get_scheduler(optimizer, epochs, warmup_epochs=5)
    early_stopper = EarlyStopping(patience=patience)

    # (OPTIONAL) Freeze backbone for warmup
    for param in model.backbone.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True
    for param in model.feature_layer.parameters():
        param.requires_grad = True

    train_losses, train_accs, val_losses, val_accs = [], [], [], []
    train_ce_losses, train_supcon_losses = [], []

    for epoch in range(epochs):
        if epoch == 5:
            for param in model.backbone.parameters():
                param.requires_grad = True
        t_loss, t_acc, t_ce, t_sup = train_epoch(
            model, train_loader, ce_loss_fn, supcon_loss_fn, optimizer,
            use_mixup=True, mixup_alpha=mixup_alpha, supcon_weight=supcon_weight)
        v_loss, v_acc, _, _ = eval_epoch(model, val_loader, ce_loss_fn)
        scheduler.step()
        train_losses.append(t_loss)
        train_accs.append(t_acc)
        train_ce_losses.append(t_ce)
        train_supcon_losses.append(t_sup)
        val_losses.append(v_loss)
        val_accs.append(v_acc)
        print(f"Epoch {epoch+1}: Train Loss {t_loss:.4f}, Acc {t_acc:.4f} | Val Loss {v_loss:.4f}, Acc {v_acc:.4f}")
        if early_stopper(v_acc, model):
            print(f"Early stopping at epoch {epoch+1}")
            break

    plot_loss_acc_curves(train_losses, val_losses, train_accs, val_accs, train_ce_losses, train_supcon_losses)
    model.load_state_dict(early_stopper.best_state)

    print("\nTest set results (CBAM Hybrid):")
    test_loss, test_acc, test_labels, test_probs = eval_epoch(model, test_loader, ce_loss_fn)
    test_preds = np.argmax(test_probs, axis=1)
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    print(classification_report(test_labels, test_preds, target_names=class_names))
    plot_confusion_matrix(test_labels, test_preds, class_names, save_path="cbam_norm_confmat_hybrid.png")

    try:
        test_labels_onehot = np.eye(num_classes)[test_labels]
        roc_macro = roc_auc_score(test_labels_onehot, test_probs, average='macro', multi_class='ovr')
        print(f"Test ROC-AUC (macro): {roc_macro:.4f}")
        plot_roc_per_class(test_labels, test_probs, num_classes, class_names, save_path="cbam_perclass_roc_hybrid.png")
    except Exception as e:
        print(f"ROC-AUC calculation failed: {e}")

    plot_reliability(test_labels, test_probs)

    # -------- TTA --------
    print("\nTest-Time Augmentation (TTA) Evaluation (CBAM Hybrid):")
    mean_probs, final_preds = tta_eval(model, test_ds, batch_size, class_names)
    print(classification_report(test_ds.targets, final_preds, target_names=class_names))
    plot_confusion_matrix(test_ds.targets, final_preds, class_names, save_path="cbam_tta_confmat_hybrid.png")

if __name__ == '__main__':
    main(epochs=40, batch_size=32, patience=8, mixup_alpha=0.3, supcon_weight=0.5)


                                                      

Epoch 1: Train Loss 2.2487, Acc 0.3898 | Val Loss 1.1088, Acc 0.3131
Validation accuracy improved, saving best state.


                                                      

Epoch 2: Train Loss 2.2073, Acc 0.4732 | Val Loss 1.1009, Acc 0.3131
EarlyStopping counter: 1 / 8


                                                      

Epoch 3: Train Loss 2.1854, Acc 0.4149 | Val Loss 1.0241, Acc 0.4646
Validation accuracy improved, saving best state.


                                                      

Epoch 4: Train Loss 2.1509, Acc 0.3909 | Val Loss 0.9630, Acc 0.5556
Validation accuracy improved, saving best state.




Epoch 5: Train Loss 2.1198, Acc 0.4880 | Val Loss 0.9547, Acc 0.5051
EarlyStopping counter: 1 / 8


                                                      

Epoch 6: Train Loss 2.2027, Acc 0.4493 | Val Loss 0.9060, Acc 0.5859
Validation accuracy improved, saving best state.


                                                      

Epoch 7: Train Loss 2.1879, Acc 0.4928 | Val Loss 0.8403, Acc 0.6465
Validation accuracy improved, saving best state.


                                                      

Epoch 8: Train Loss 2.1510, Acc 0.5059 | Val Loss 0.7801, Acc 0.7677
Validation accuracy improved, saving best state.


                                                      

Epoch 9: Train Loss 2.1323, Acc 0.5324 | Val Loss 0.7574, Acc 0.7273
EarlyStopping counter: 1 / 8


                                                      

Epoch 10: Train Loss 2.1112, Acc 0.5755 | Val Loss 0.7450, Acc 0.6970
EarlyStopping counter: 2 / 8


                                                      

Epoch 11: Train Loss 2.1601, Acc 0.5789 | Val Loss 0.6835, Acc 0.7273
EarlyStopping counter: 3 / 8


                                                      

Epoch 12: Train Loss 2.1311, Acc 0.5866 | Val Loss 0.5978, Acc 0.8384
Validation accuracy improved, saving best state.


                                                      

Epoch 13: Train Loss 2.1179, Acc 0.6101 | Val Loss 0.5575, Acc 0.8081
EarlyStopping counter: 1 / 8


                                                      

Epoch 14: Train Loss 2.0786, Acc 0.6883 | Val Loss 0.4622, Acc 0.8889
Validation accuracy improved, saving best state.


                                                      

Epoch 15: Train Loss 2.0899, Acc 0.6630 | Val Loss 0.4747, Acc 0.8283
EarlyStopping counter: 1 / 8


                                                      

Epoch 16: Train Loss 2.0270, Acc 0.6823 | Val Loss 0.4400, Acc 0.8889
EarlyStopping counter: 2 / 8


                                                      

Epoch 17: Train Loss 2.0749, Acc 0.6422 | Val Loss 0.4603, Acc 0.8283
EarlyStopping counter: 3 / 8


                                                      

Epoch 18: Train Loss 2.0704, Acc 0.6573 | Val Loss 0.4909, Acc 0.8586
EarlyStopping counter: 4 / 8


                                                      

Epoch 19: Train Loss 2.0773, Acc 0.6651 | Val Loss 0.4470, Acc 0.8788
EarlyStopping counter: 5 / 8


                                                      

Epoch 20: Train Loss 2.0559, Acc 0.6503 | Val Loss 0.4231, Acc 0.9192
Validation accuracy improved, saving best state.


                                                      

Epoch 21: Train Loss 2.0119, Acc 0.7370 | Val Loss 0.4326, Acc 0.8990
EarlyStopping counter: 1 / 8


                                                      

Epoch 22: Train Loss 2.0867, Acc 0.6377 | Val Loss 0.4136, Acc 0.8889
EarlyStopping counter: 2 / 8


                                                      

Epoch 23: Train Loss 2.0443, Acc 0.6928 | Val Loss 0.4408, Acc 0.9293
Validation accuracy improved, saving best state.


                                                      

Epoch 24: Train Loss 2.0399, Acc 0.6767 | Val Loss 0.4642, Acc 0.8485
EarlyStopping counter: 1 / 8


                                                      

Epoch 25: Train Loss 1.9983, Acc 0.6738 | Val Loss 0.4130, Acc 0.9192
EarlyStopping counter: 2 / 8


                                                      

Epoch 26: Train Loss 2.0241, Acc 0.6965 | Val Loss 0.4057, Acc 0.9192
EarlyStopping counter: 3 / 8


                                                      

Epoch 27: Train Loss 1.9852, Acc 0.6989 | Val Loss 0.4227, Acc 0.9091
EarlyStopping counter: 4 / 8


                                                      

Epoch 28: Train Loss 1.9733, Acc 0.7095 | Val Loss 0.3969, Acc 0.9091
EarlyStopping counter: 5 / 8


                                                      

Epoch 29: Train Loss 2.0262, Acc 0.6961 | Val Loss 0.3916, Acc 0.9192
EarlyStopping counter: 6 / 8


                                                      

Epoch 30: Train Loss 1.9875, Acc 0.7056 | Val Loss 0.3999, Acc 0.8889
EarlyStopping counter: 7 / 8


                                                      

Epoch 31: Train Loss 1.9295, Acc 0.7617 | Val Loss 0.3832, Acc 0.9394
Validation accuracy improved, saving best state.


                                                      

Epoch 32: Train Loss 1.9527, Acc 0.7137 | Val Loss 0.3812, Acc 0.9293
EarlyStopping counter: 1 / 8


                                                      

Epoch 33: Train Loss 1.9037, Acc 0.7407 | Val Loss 0.3880, Acc 0.9293
EarlyStopping counter: 2 / 8


                                                      

Epoch 34: Train Loss 1.9424, Acc 0.7320 | Val Loss 0.3910, Acc 0.9192
EarlyStopping counter: 3 / 8


                                                      

Epoch 35: Train Loss 1.9574, Acc 0.7187 | Val Loss 0.3765, Acc 0.9394
EarlyStopping counter: 4 / 8


                                                      

Epoch 36: Train Loss 1.8872, Acc 0.7782 | Val Loss 0.3792, Acc 0.9394
EarlyStopping counter: 5 / 8


                                                      

Epoch 37: Train Loss 1.8645, Acc 0.7565 | Val Loss 0.3805, Acc 0.9192
EarlyStopping counter: 6 / 8


                                                      

Epoch 38: Train Loss 1.9698, Acc 0.7004 | Val Loss 0.3763, Acc 0.9293
EarlyStopping counter: 7 / 8


                                                      

Epoch 39: Train Loss 1.9289, Acc 0.7224 | Val Loss 0.3786, Acc 0.9293
EarlyStopping counter: 8 / 8
Early stopping at epoch 39

Test set results (CBAM Hybrid):


                                                   

Test Loss: 0.3125, Test Acc: 0.9596
              precision    recall  f1-score   support

  Alternaria       1.00      0.89      0.94        37
Healthy Leaf       0.89      1.00      0.94        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.96        99
   macro avg       0.96      0.96      0.96        99
weighted avg       0.96      0.96      0.96        99

Test ROC-AUC (macro): 0.9988

Test-Time Augmentation (TTA) Evaluation (CBAM Hybrid):
              precision    recall  f1-score   support

  Alternaria       0.97      0.89      0.93        37
Healthy Leaf       0.88      0.97      0.92        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.95        99
   macro avg       0.95      0.95      0.95        99
weighted avg       0.95      0.95      0.95        99

