In [None]:
import os
import random
from typing import Iterable, Tuple, Optional, Callable, List

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import models as tv_models
from torchvision.models.resnet import ResNet, Bottleneck
from torchvision.models import ResNet50_Weights
from PIL import Image, ImageFilter


# ============================================================
# 0. Global config
# ============================================================

PRETRAIN_ROOT = "/kaggle/input/minida/mini_output1/pretrain" 
IMG_SIZE      = 224
SEED          = 42
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ============================================================
# 1. Reproducibility
# ============================================================

def seed_all(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all(SEED)


# ============================================================
# 2. CBAM modules and CBAM-ResNet50
# ============================================================

class ChannelAttention(nn.Module):
    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        hidden = max(in_planes // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, kernel_size=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(x_cat)
        return self.sigmoid(out)


class CBAMBottleneck(Bottleneck):
    """
    ResNet Bottleneck with CBAM after the third BN and before the residual add.
    This means CBAM is applied inside *every* residual bottleneck block.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        planes_out = self.conv3.out_channels
        self.ca = ChannelAttention(planes_out)
        self.sa = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out); out = self.relu(out)
        out = self.conv3(out); out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


def cbam_resnet50(*, norm_layer: Optional[Callable] = None, **kwargs) -> ResNet:
    """
    Return a ResNet-50 that uses CBAMBottleneck blocks everywhere.
    """
    model = ResNet(CBAMBottleneck, [3, 4, 6, 3], norm_layer=norm_layer, **kwargs)
    # Zero-init last BN in each residual branch (improves training stability).
    for m in model.modules():
        if isinstance(m, Bottleneck):
            nn.init.constant_(m.bn3.weight, 0)
    return model


# ============================================================
# 3. SimSiam projection/prediction heads
# ============================================================

class MLPHead(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=2048, out_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False),
        )

    def forward(self, x):
        return self.net(x)


class PredictionHead(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)


class SimSiamCBAM(nn.Module):
    """
    SimSiam with CBAM-ResNet50 backbone.

    use_imagenet_init:
      - False -> CBAM-ResNet50 from scratch
      - True  -> CBAM-ResNet50 initialized from vanilla ResNet-50 IN1K weights
    """
    def __init__(self, fix_backbone_bn: bool = True, use_imagenet_init: bool = False):
        super().__init__()

        if use_imagenet_init:
            base = tv_models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            base_sd = base.state_dict()

            cbam = cbam_resnet50(num_classes=1000)
            cbam_sd = cbam.state_dict()

            copied = 0
            for k in cbam_sd.keys():
                if k in base_sd and cbam_sd[k].shape == base_sd[k].shape:
                    cbam_sd[k] = base_sd[k]
                    copied += 1
            cbam.load_state_dict(cbam_sd)
            print(f"[SimSiam CBAM] IN1K init: copied {copied} parameters from vanilla ResNet-50.")
            resnet = cbam
        else:
            resnet = cbam_resnet50(num_classes=1000)
            print("[SimSiam CBAM] Scratch init: CBAM-ResNet50 without ImageNet weights.")

        # children: conv1, bn1, relu, maxpool, layer1..4, avgpool, fc
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # up to avgpool
        self.projector = MLPHead(2048)
        self.predictor = PredictionHead()
        self.fix_backbone_bn = fix_backbone_bn

        # Keep track of BN layers to freeze
        self._frozen_bn_modules: List[nn.BatchNorm2d] = []
        if self.fix_backbone_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.requires_grad_(False)
                    self._frozen_bn_modules.append(m)

    def _forward_backbone(self, x):
        x = self.backbone(x)           # (B, 2048, 1, 1)
        x = torch.flatten(x, 1)        # (B, 2048)
        return x

    def forward(self, x1, x2):
        z1 = self.projector(self._forward_backbone(x1))
        z2 = self.projector(self._forward_backbone(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        # stop-grad on targets (SimSiam)
        return p1, p2, z1.detach(), z2.detach()

    def train(self, mode: bool = True):
        super().train(mode)
        if self.fix_backbone_bn:
            for m in self._frozen_bn_modules:
                m.eval()
        return self


# ============================================================
# 4. Dataset & Transforms (two views)
# ============================================================

class UnlabeledDataset(Dataset):
    """
    Recursively load all images under PRETRAIN_ROOT. No labels.
    """
    def __init__(self, root_dir: str, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".tif", ".tiff")
        filepaths = []
        for dp, _, fns in os.walk(root_dir):
            for fn in fns:
                if fn.lower().endswith(exts):
                    filepaths.append(os.path.join(dp, fn))
        self.filepaths = sorted(filepaths)
        if len(self.filepaths) == 0:
            raise RuntimeError(f"No images found under {root_dir}")

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx: int):
        img = Image.open(self.filepaths[idx]).convert("RGB")
        if self.transform is None:
            raise RuntimeError("Transform must be provided.")
        v1 = self.transform(img)
        v2 = self.transform(img)
        return v1, v2


class PILGaussianBlur(object):
    """
    Gaussian blur using PIL.ImageFilter to avoid numpy/torchvision np.bool issues.
    """
    def __init__(self, radius_min=0.1, radius_max=2.0):
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img: Image.Image):
        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))


def get_ssl_transform(img_size: int = 224):
    """
    SimSiam-style augmentations:
    - RandomResizedCrop
    - RandomHorizontalFlip
    - ColorJitter (p=0.8)
    - RandomGrayscale
    - Random Gaussian blur
    - ToTensor + Normalize
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    color_jitter = transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.1,
    )

    transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        PILGaussianBlur(radius_min=0.1, radius_max=2.0),
        transforms.ToTensor(),
        normalize,
    ])
    return transform


# ============================================================
# 5. Loss & optimizer helpers
# ============================================================

def negative_cosine_similarity(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()


def exclude_from_wd(named_params: Iterable[Tuple[str, torch.nn.Parameter]], wd: float = 1e-4):
    """
    Create two param groups: with and without weight decay (no wd for BN/bias/1D).
    """
    wd_params, no_wd_params = [], []
    for n, p in named_params:
        if not p.requires_grad:
            continue
        if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower():
            no_wd_params.append(p)
        else:
            wd_params.append(p)
    return [
        {"params": wd_params, "weight_decay": wd},
        {"params": no_wd_params, "weight_decay": 0.0},
    ]


# ============================================================
# 6. Pretrain function
# ============================================================

def pretrain_variant(
    root_path: str,
    checkpoint_dir: str,
    use_imagenet_init: bool,
    epochs: int = 200,
    batch_size: int = 64,
    num_workers: int = 2,
    accumulation_steps: int = 1,
    fix_backbone_bn: bool = True,
    seed: int = 42,
) -> str:
    os.makedirs(checkpoint_dir, exist_ok=True)
    seed_all(seed)

    device = DEVICE
    print(f"\n=================================================")
    print(f"SimSiam+CBAM pretrain | variant = {'IN1K-init' if use_imagenet_init else 'Scratch'}")
    print(f"Device: {device}")
    print(f"Root:   {root_path}")
    print(f"=================================================\n")

    transform = get_ssl_transform(IMG_SIZE)
    dataset = UnlabeledDataset(root_dir=root_path, transform=transform)
    print(f"Found {len(dataset)} unlabeled images.")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    model = SimSiamCBAM(
        fix_backbone_bn=fix_backbone_bn,
        use_imagenet_init=use_imagenet_init
    ).to(device)

    # Safe torch.compile (optional)
    use_compile = False
    if hasattr(torch, "compile") and torch.cuda.is_available():
        try:
            major, _ = torch.cuda.get_device_capability()
            if major >= 7:
                use_compile = True
        except Exception as e:
            print(f"Skipping torch.compile (capability query failed: {e})")

    if use_compile:
        try:
            model = torch.compile(model)
        except Exception as e:
            print(f"torch.compile failed, using eager mode. Reason: {e}")

    # SimSiam LR scaling rule
    global_bs = batch_size
    base_lr = 0.05 * max(min(global_bs, 1024), 64) / 256.0
    print(f"Base LR: {base_lr:.5f}")

    param_groups = exclude_from_wd(model.named_parameters(), wd=1e-4)
    optimizer = torch.optim.SGD(param_groups, lr=base_lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        micro_step = 0

        print(f"[Variant: {'IN1K' if use_imagenet_init else 'Scratch'}] Epoch {epoch+1}/{epochs}")
        for step, (x1, x2) in enumerate(tqdm(dataloader, desc="SSL", leave=False), start=1):
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
                p1, p2, z1, z2 = model(x1, x2)
                loss_full = 0.5 * (
                    negative_cosine_similarity(p1, z2) +
                    negative_cosine_similarity(p2, z1)
                )
                loss = loss_full / max(accumulation_steps, 1)

            scaler.scale(loss).backward()
            micro_step += 1

            if micro_step == accumulation_steps:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                micro_step = 0

            total_loss += loss_full.item()

        if micro_step > 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        avg_loss = total_loss / len(dataloader)
        scheduler.step()
        print(f"  -> Avg SSL loss: {avg_loss:.4f}")

    # Save final backbone & heads for fine-tuning
    final_path = os.path.join(checkpoint_dir, "simsiam_cbam_pretrained_final.pth")
    target_model = model._orig_mod if hasattr(model, "_orig_mod") else model
    torch.save(
        {
            "backbone": target_model.backbone.state_dict(),
            "projector": target_model.projector.state_dict(),
            "predictor": target_model.predictor.state_dict(),
        },
        final_path,
    )
    print(f"\n[Done] Pretraining ({'IN1K' if use_imagenet_init else 'Scratch'}) saved to:\n  {final_path}\n")
    return final_path


# ============================================================
# 7. Main: run Scratch & IN1K pretrain
# ============================================================

if __name__ == "__main__":
    scratch_ckpt = pretrain_variant(
        root_path=PRETRAIN_ROOT,
        checkpoint_dir="/kaggle/working/simsiam_cbam_scratch",
        use_imagenet_init=False,
        epochs=150,         
        batch_size=64,
        num_workers=2,
        accumulation_steps=1,
        fix_backbone_bn=True,
        seed=42,
    )

    # 2) IN1K-initialized SimSiam+CBAM (for fairness comparison)
    in1k_ckpt = pretrain_variant(
        root_path=PRETRAIN_ROOT,
        checkpoint_dir="/kaggle/working/simsiam_cbam_in1k",
        use_imagenet_init=True,
        epochs=150,       
        batch_size=64,
        num_workers=2,
        accumulation_steps=1,
        fix_backbone_bn=True,
        seed=42,
    )

    print("Scratch checkpoint:", scratch_ckpt)
    print("IN1K-init checkpoint:", in1k_ckpt)


In [2]:
%%writefile cbam_resnet.py
import torch
import torch.nn as nn
from typing import Optional, Callable
from torchvision.models.resnet import ResNet, Bottleneck


class ChannelAttention(nn.Module):
    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        hidden = max(in_planes // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, kernel_size=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(x_cat)
        return self.sigmoid(out)


class CBAMBottleneck(Bottleneck):
    """
    ResNet Bottleneck with CBAM after the third BN and before the residual add.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        planes_out = self.conv3.out_channels
        self.ca = ChannelAttention(planes_out)
        self.sa = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out); out = self.relu(out)
        out = self.conv3(out); out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


def cbam_resnet50(*, norm_layer: Optional[Callable] = None, **kwargs) -> ResNet:
    """
    Return a ResNet-50 that uses CBAMBottleneck blocks.
    kwargs forwarded to torchvision.models.resnet.ResNet (e.g., num_classes).
    """
    model = ResNet(CBAMBottleneck, [3, 4, 6, 3], norm_layer=norm_layer, **kwargs)
    # Zero-init last BN in each residual branch (improves training stability).
    for m in model.modules():
        if isinstance(m, Bottleneck):
            nn.init.constant_(m.bn3.weight, 0)
    return model

Writing cbam_resnet.py


In [6]:
import os
import random
from typing import Iterable, List, Tuple

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.transforms import RandAugment
from torchvision.datasets import ImageFolder

from sklearn.metrics import classification_report, confusion_matrix

from cbam_resnet import cbam_resnet50

# ============================================================
# 0. Paths & global config
# ============================================================

DATA_ROOT   = "/kaggle/input/minida/mini_output1"
SCRATCH_CKPT = "/kaggle/working/simsiam_cbam_scratch/simsiam_cbam_pretrained_final.pth"
IN1K_CKPT    = "/kaggle/working/simsiam_cbam_in1k/simsiam_cbam_pretrained_final.pth"

IMG_SIZE    = 256
NUM_EPOCHS  = 50
PATIENCE_CE = 10
PATIENCE_HY = 8
BATCH_SIZE  = 24
NUM_WORKERS = 2

BASE_LR_CE_BACKBONE = 3e-5
BASE_LR_CE_HEAD     = 1e-4
BASE_LR_HY          = 1e-4

MIXUP_ALPHA_CE = 0.3
MIXUP_ALPHA_HY = 0.3
SUPCON_WEIGHT  = 0.5    # lambda for hybrid loss

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ============================================================
# 1. Reproducibility & model stats
# ============================================================

def seed_all(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def compute_model_size_mb(model: nn.Module) -> float:
    total_bytes = 0
    for p in model.parameters():
        total_bytes += p.numel() * p.element_size()
    return total_bytes / (1024 ** 2)


seed_all(42)


# ============================================================
# 2. Common helpers (mixup, weight decay groups)
# ============================================================

def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.3, device=None):
    if alpha and alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    idx = torch.randperm(x.size(0), device=device)
    mixed_x = lam * x + (1.0 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam: float):
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)


def exclude_from_wd(named_params: Iterable, wd: float = 1e-4):
    wd_params, no_wd_params = [], []
    for n, p in named_params:
        if not p.requires_grad:
            continue
        if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower():
            no_wd_params.append(p)
        else:
            wd_params.append(p)
    return [
        {"params": wd_params,   "weight_decay": wd},
        {"params": no_wd_params,"weight_decay": 0.0},
    ]


# ============================================================
# 3. SupCon Loss (same as hybrid script)
# ============================================================

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.eps = 1e-8

    def forward(self, features, labels):
        device = features.device
        batch_size = features.size(0)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)

        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T), self.temperature
        )
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        exp_logits = torch.exp(logits) * (1 - torch.eye(batch_size, device=device))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + self.eps)

        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + self.eps)
        loss = -mean_log_prob_pos.mean()
        return loss


# ============================================================
# 4. Dataloaders (CE-style & Hybrid-style, matching your old code)
# ============================================================

def get_loaders_ce(
    data_root: str,
    batch_size: int = 32,
    num_workers: int = 2,
    img_size: int = 256,
):
    pin = torch.cuda.is_available()

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        transforms.RandomErasing(
            p=0.2, scale=(0.02, 0.15), ratio=(0.3, 3.3), value='random'
        ),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    resize_size = int(round(img_size * 1.14))
    eval_tf = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    train_ds = ImageFolder(os.path.join(data_root, "train"), transform=train_tf)
    val_ds   = ImageFolder(os.path.join(data_root, "val"),   transform=eval_tf)
    test_ds  = ImageFolder(os.path.join(data_root, "test"),  transform=eval_tf)

    class_names = train_ds.classes

    targets_np = np.array(train_ds.targets, dtype=np.int64)
    classes = np.unique(targets_np)
    class_counts = np.array([(targets_np == c).sum() for c in classes], dtype=np.float64)
    class_counts[class_counts == 0] = 1.0
    weights_per_class = 1.0 / class_counts
    sample_weights = weights_per_class[targets_np]
    sampler = WeightedRandomSampler(
        torch.as_tensor(sample_weights, dtype=torch.double),
        num_samples=len(sample_weights),
        replacement=True,
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=pin, drop_last=True
    )
    val_loader   = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    test_loader  = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )

    return train_loader, val_loader, test_loader, test_ds, class_names


def get_loaders_hybrid(
    data_root: str,
    batch_size: int = 32,
    num_workers: int = 2,
    img_size: int = 224,
):
    # This follows your hybrid script more closely (224 input)
    pin = torch.cuda.is_available()

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        RandAugment(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    train_ds = ImageFolder(os.path.join(data_root, "train"), transform=train_tf)
    val_ds   = ImageFolder(os.path.join(data_root, "val"),   transform=val_tf)
    test_ds  = ImageFolder(os.path.join(data_root, "test"),  transform=val_tf)

    class_names = train_ds.classes

    class_counts = np.bincount(train_ds.targets)
    weights = 1. / class_counts[train_ds.targets]
    sampler = WeightedRandomSampler(weights, len(train_ds), replacement=True)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=pin
    )
    val_loader   = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )
    test_loader  = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )

    return train_loader, val_loader, test_loader, test_ds, class_names


# ============================================================
# 5. Models (CE-only & Hybrid) matching your old scripts
# ============================================================

class FineTuneCBAM_CE(nn.Module):
    """
    CE-only: backbone + 512 FC + dropout + final classifier.
    Matches your finetune_cbam.py
    """
    def __init__(self, pretrained_path: str, num_classes: int):
        super().__init__()
        backbone = cbam_resnet50(num_classes=1000)
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])  # up to avgpool
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )
        ckpt = torch.load(pretrained_path, map_location="cpu")
        sd = ckpt.get("backbone", ckpt)
        self.backbone.load_state_dict(sd, strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x).flatten(1)
        return self.classifier(x)


class FineTuneCBAM_Hybrid(nn.Module):
    """
    Hybrid SupCon+CE: backbone + feature_layer(2048->128) + classifier(2048->512->3).
    Matches your hybrid script.
    """
    def __init__(self, pretrained_path: str, num_classes: int):
        super().__init__()
        backbone = cbam_resnet50(num_classes=1000)
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        ckpt = torch.load(pretrained_path, map_location="cpu")
        sd = ckpt.get("backbone", ckpt)
        self.backbone.load_state_dict(sd, strict=False)

        self.feature_layer = nn.Linear(2048, 128)  # for SupCon
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x, return_features=False):
        feats = self.backbone(x).flatten(1)
        features = F.normalize(self.feature_layer(feats), dim=1)
        logits = self.classifier(feats)
        if return_features:
            return logits, features
        return logits


# ============================================================
# 6. Schedulers (same as your code)
# ============================================================

from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

def get_scheduler(optimizer, total_epochs: int, warmup_epochs: int = 5):
    warmup = LinearLR(optimizer, start_factor=0.2, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=max(total_epochs - warmup_epochs, 1))
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])


# ============================================================
# 7. Training loops (CE-only)
# ============================================================

def train_epoch_ce(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.amp.GradScaler,
    use_mixup: bool = True,
    mixup_alpha: float = 0.3,
):
    model.train()
    total_loss, correct = 0.0, 0
    for imgs, labels in tqdm(loader, desc="Train (CE)", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
            if use_mixup and mixup_alpha > 0:
                imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha, device=device)
                logits = model(imgs)
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
                preds = logits.argmax(1)
                correct += int(
                    lam * preds.eq(y_a).sum().item()
                    + (1.0 - lam) * preds.eq(y_b).sum().item()
                )
            else:
                logits = model(imgs)
                loss = criterion(logits, labels)
                preds = logits.argmax(1)
                correct += int(preds.eq(labels).sum().item())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.item()) * imgs.size(0)

    n = len(loader.dataset)
    return total_loss / n, correct / n


@torch.no_grad()
def eval_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    device: torch.device,
    desc: str = "Eval",
):
    model.eval()
    total_loss, correct = 0.0, 0
    all_labels: List[int] = []
    all_preds:  List[int] = []

    for imgs, labels in tqdm(loader, desc=desc, leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        total_loss += float(loss.item()) * imgs.size(0)
        preds = logits.argmax(1)
        correct += int(preds.eq(labels).sum().item())
        all_labels.extend(labels.cpu().numpy().tolist())
        all_preds.extend(preds.cpu().numpy().tolist())

    n = len(loader.dataset)
    return total_loss / n, correct / n, np.array(all_labels), np.array(all_preds)


# ============================================================
# 8. Training loops (Hybrid CE+SupCon)
# ============================================================

def train_epoch_hybrid(
    model: nn.Module,
    loader: DataLoader,
    ce_loss_fn,
    supcon_loss_fn,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    mixup_alpha: float = 0.3,
    supcon_weight: float = 0.5,
):
    model.train()
    total_loss, total_ce, total_sup, correct = 0, 0, 0, 0

    for imgs, labels in tqdm(loader, desc="Train (Hybrid)", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        imgs, y_a, y_b, lam = mixup_data(imgs, labels, alpha=mixup_alpha, device=device)
        logits, features = model(imgs, return_features=True)
        loss_ce = mixup_criterion(ce_loss_fn, logits, y_a, y_b, lam)
        loss_sup = supcon_loss_fn(features, labels)
        loss = (1 - supcon_weight) * loss_ce + supcon_weight * loss_sup

        preds = logits.argmax(1)
        correct += int(
            lam * preds.eq(y_a).sum().item()
            + (1.0 - lam) * preds.eq(y_b).sum().item()
        )

        loss.backward()
        optimizer.step()

        bs = imgs.size(0)
        total_loss += loss.item() * bs
        total_ce   += loss_ce.item() * bs
        total_sup  += loss_sup.item() * bs

    n = len(loader.dataset)
    return (
        total_loss / n,
        correct / n,
        total_ce / n,
        total_sup / n,
    )


# ============================================================
# 9. Early stopping
# ============================================================

class EarlyStopping:
    def __init__(self, patience: int = 7, verbose: bool = True):
        self.patience = patience
        self.counter = 0
        self.best_acc = None
        self.best_state = None
        self.best_epoch = -1
        self.verbose = verbose

    def __call__(self, val_acc: float, model: nn.Module, epoch: int):
        if self.best_acc is None or val_acc > self.best_acc:
            self.best_acc = val_acc
            self.counter = 0
            self.best_epoch = epoch
            self.best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            if self.verbose:
                print("Validation accuracy improved, saving best state.")
            return False
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            return self.counter >= self.patience


# ============================================================
# 10. TTA (CE-style and Hybrid-style, simplified)
# ============================================================

@torch.no_grad()
def tta_ce(model: nn.Module, test_ds: ImageFolder, batch_size: int, device: torch.device, img_size: int):
    resize_base = int(round(img_size * 1.14))
    resize_up   = int(round(img_size * 1.25))

    tta_transforms = [
        transforms.Compose([
            transforms.Resize(resize_base),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(resize_base),
            transforms.CenterCrop(img_size),
            transforms.RandomHorizontalFlip(1.0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(resize_up),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(resize_base),
            transforms.CenterCrop(img_size),
            transforms.GaussianBlur(3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
    ]

    pin = torch.cuda.is_available()
    model.eval()
    probs_all = []

    for t in tta_transforms:
        test_ds.transform = t
        loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=pin)
        probs = []
        for imgs, _ in loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            probs.append(F.softmax(logits, dim=1).cpu().numpy())
        probs_all.append(np.concatenate(probs, axis=0))

    probs_all = np.stack(probs_all, axis=0)
    mean_probs = probs_all.mean(axis=0)
    preds = mean_probs.argmax(axis=1)
    return preds


@torch.no_grad()
def tta_hybrid(model: nn.Module, test_ds: ImageFolder, batch_size: int, device: torch.device):
    tta_transforms = [
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(1.0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomVerticalFlip(1.0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(280),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.GaussianBlur(3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]),
    ]
    model.eval()
    probs_all = []
    for t in tta_transforms:
        test_ds.transform = t
        loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
        preds_list = []
        for imgs, _ in loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            preds_list.append(F.softmax(outputs, dim=1).cpu().numpy())
        probs_all.append(np.concatenate(preds_list, axis=0))
    probs_all = np.array(probs_all)
    mean_probs = probs_all.mean(axis=0)
    final_preds = np.argmax(mean_probs, axis=1)
    return final_preds


# ============================================================
# 11. Variant runners
# ============================================================

def run_ce_variant(variant_name: str, ckpt_path: str):
    print("\n" + "#" * 60)
    print(f"[CE] Variant: {variant_name}")
    print(f"Checkpoint: {ckpt_path}")
    print("#" * 60)

    seed_all(42)
    train_loader, val_loader, test_loader, test_ds, class_names = get_loaders_ce(
        DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, img_size=IMG_SIZE
    )
    num_classes = len(class_names)

    model = FineTuneCBAM_CE(pretrained_path=ckpt_path, num_classes=num_classes).to(DEVICE)

    # Ensure BN uses batch stats during finetune
    for m in model.backbone.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()
            m.requires_grad_(True)

    # param groups: smaller LR for backbone
    backbone_named = [(n, p) for n, p in model.named_parameters() if "classifier" not in n]
    head_named     = [(n, p) for n, p in model.named_parameters() if "classifier" in n]
    groups_backbone = exclude_from_wd(backbone_named, wd=1e-4)
    groups_head     = exclude_from_wd(head_named, wd=1e-4)
    for g in groups_backbone:
        g["lr"] = BASE_LR_CE_BACKBONE
    for g in groups_head:
        g["lr"] = BASE_LR_CE_HEAD

    optimizer = torch.optim.AdamW(groups_backbone + groups_head, betas=(0.9, 0.999))
    scheduler = get_scheduler(optimizer, total_epochs=NUM_EPOCHS, warmup_epochs=5)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.02)
    scaler = torch.amp.GradScaler('cuda', enabled=(DEVICE.type == "cuda"))

    # head-only warmup
    for p in model.backbone.parameters():
        p.requires_grad = False

    early = EarlyStopping(patience=PATIENCE_CE, verbose=True)

    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    for epoch in range(NUM_EPOCHS):
        if epoch == 5:
            for p in model.backbone.parameters():
                p.requires_grad = True
            print(f"Unfroze backbone at epoch {epoch}")

        # mixup decay (as in your script)
        decay_start = int(NUM_EPOCHS * 2 / 3)
        if epoch < decay_start:
            cur_alpha = MIXUP_ALPHA_CE
        else:
            remaining = max(NUM_EPOCHS - epoch - 1, 0)
            span = max(NUM_EPOCHS - decay_start, 1)
            cur_alpha = MIXUP_ALPHA_CE * (remaining / span)

        print(f"\n[CE] Epoch {epoch+1}/{NUM_EPOCHS} | mixup_alpha={cur_alpha:.4f}")
        t_loss, t_acc = train_epoch_ce(
            model, train_loader, criterion, optimizer, DEVICE,
            scaler, use_mixup=(cur_alpha > 1e-6), mixup_alpha=cur_alpha
        )
        v_loss, v_acc, v_labels, v_preds = eval_epoch(
            model, val_loader, criterion, DEVICE, desc="Val (CE)"
        )
        scheduler.step()

        print(f"Train Loss: {t_loss:.4f}, Acc: {t_acc:.4f} | "
              f"Val Loss: {v_loss:.4f}, Acc: {v_acc:.4f}")

        train_losses.append(t_loss); train_accs.append(t_acc)
        val_losses.append(v_loss);   val_accs.append(v_acc)

        if early(v_acc, model, epoch):
            print("Early stopping triggered (CE).")
            break

    # load best
    model.load_state_dict(early.best_state)
    model.to(DEVICE).eval()
    best_val = early.best_acc

    # final test
    te_loss, te_acc, te_labels, te_preds = eval_epoch(
        model, test_loader, criterion, DEVICE, desc="Test (CE)"
    )
    print("\n[CE] Test report:")
    print(classification_report(te_labels, te_preds, target_names=class_names))

    # TTA
    tta_preds = tta_ce(model, test_ds, BATCH_SIZE, DEVICE, IMG_SIZE)
    tta_acc = (tta_preds == np.array(test_ds.targets)).mean()

    print(f"\n[{variant_name}] CE Best Val Acc: {best_val*100:.2f}%")
    print(f"[{variant_name}] CE Test Acc:     {te_acc*100:.2f}%")
    print(f"[{variant_name}] CE TTA  Acc:     {tta_acc*100:.2f}%")

    n_params = count_parameters(model)
    size_mb  = compute_model_size_mb(model)

    return {
        "variant": variant_name,
        "type": "CE",
        "best_val": best_val,
        "test_acc": te_acc,
        "tta_acc": tta_acc,
        "n_params": n_params,
        "size_mb": size_mb,
    }


def run_hybrid_variant(variant_name: str, ckpt_path: str):
    print("\n" + "#" * 60)
    print(f"[HYBRID] Variant: {variant_name}")
    print(f"Checkpoint: {ckpt_path}")
    print("#" * 60)

    seed_all(42)
    train_loader, val_loader, test_loader, test_ds, class_names = get_loaders_hybrid(
        DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
    )
    num_classes = len(class_names)

    model = FineTuneCBAM_Hybrid(pretrained_path=ckpt_path, num_classes=num_classes).to(DEVICE)

    ce_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
    supcon_loss_fn = SupConLoss(temperature=0.07)
    optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR_HY, weight_decay=1e-4)
    scheduler = get_scheduler(optimizer, total_epochs=NUM_EPOCHS, warmup_epochs=5)
    early = EarlyStopping(patience=PATIENCE_HY, verbose=True)

    # freeze backbone for first 5 epochs (as in your hybrid script)
    for p in model.backbone.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = True
    for p in model.feature_layer.parameters():
        p.requires_grad = True

    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    for epoch in range(NUM_EPOCHS):
        if epoch == 5:
            for p in model.backbone.parameters():
                p.requires_grad = True
            print("Unfroze backbone at epoch 5 (Hybrid).")

        print(f"\n[HYBRID] Epoch {epoch+1}/{NUM_EPOCHS}")
        t_loss, t_acc, t_ce, t_sup = train_epoch_hybrid(
            model, train_loader, ce_loss_fn, supcon_loss_fn,
            optimizer, DEVICE, mixup_alpha=MIXUP_ALPHA_HY,
            supcon_weight=SUPCON_WEIGHT
        )
        v_loss, v_acc, v_labels, v_preds = eval_epoch(
            model, val_loader, ce_loss_fn, DEVICE, desc="Val (Hybrid)"
        )
        scheduler.step()

        print(f"Train Loss: {t_loss:.4f}, Acc: {t_acc:.4f} | "
              f"Val Loss: {v_loss:.4f}, Acc: {v_acc:.4f}")

        train_losses.append(t_loss); train_accs.append(t_acc)
        val_losses.append(v_loss);   val_accs.append(v_acc)

        if early(v_acc, model, epoch):
            print("Early stopping triggered (Hybrid).")
            break

    model.load_state_dict(early.best_state)
    model.to(DEVICE).eval()
    best_val = early.best_acc

    te_loss, te_acc, te_labels, te_preds = eval_epoch(
        model, test_loader, ce_loss_fn, DEVICE, desc="Test (Hybrid)"
    )
    print("\n[HYBRID] Test report:")
    print(classification_report(te_labels, te_preds, target_names=class_names))

    tta_preds = tta_hybrid(model, test_ds, BATCH_SIZE, DEVICE)
    tta_acc = (tta_preds == np.array(test_ds.targets)).mean()

    print(f"\n[{variant_name}] Hybrid Best Val Acc: {best_val*100:.2f}%")
    print(f"[{variant_name}] Hybrid Test Acc:     {te_acc*100:.2f}%")
    print(f"[{variant_name}] Hybrid TTA  Acc:     {tta_acc*100:.2f}%")

    n_params = count_parameters(model)
    size_mb  = compute_model_size_mb(model)

    return {
        "variant": variant_name,
        "type": "Hybrid",
        "best_val": best_val,
        "test_acc": te_acc,
        "tta_acc": tta_acc,
        "n_params": n_params,
        "size_mb": size_mb,
    }


# ============================================================
# 12. Run all four variants & summary
# ============================================================

if __name__ == "__main__":
    results = []

    # 1) Scratch + CE-only
    results.append(run_ce_variant(
        variant_name="scratch_ce",
        ckpt_path=SCRATCH_CKPT,
    ))

    # 2) Scratch + Hybrid
    results.append(run_hybrid_variant(
        variant_name="scratch_hybrid",
        ckpt_path=SCRATCH_CKPT,
    ))

    # 3) IN1K-init + CE-only
    results.append(run_ce_variant(
        variant_name="in1k_ce",
        ckpt_path=IN1K_CKPT,
    ))

    # 4) IN1K-init + Hybrid
    results.append(run_hybrid_variant(
        variant_name="in1k_hybrid",
        ckpt_path=IN1K_CKPT,
    ))

    print("\n" + "=" * 60)
    print("Summary over CBAM SimSiam fine-tune variants")
    print("=" * 60)
    print(f"{'Variant':15} | {'Type':7} | {'Params':12} | {'Size(MB)':9} | "
          f"{'Best Val':8} | {'Test':8} | {'TTA':8}")
    print("-" * 60)
    for r in results:
        print(
            f"{r['variant']:15} | "
            f"{r['type']:7} | "
            f"{r['n_params']:12,d} | "
            f"{r['size_mb']:9.2f} | "
            f"{r['best_val']*100:8.2f}% | "
            f"{r['test_acc']*100:8.2f}% | "
            f"{r['tta_acc']*100:8.2f}%"
        )



############################################################
[CE] Variant: scratch_ce
Checkpoint: /kaggle/working/simsiam_cbam_scratch/simsiam_cbam_pretrained_final.pth
############################################################

[CE] Epoch 1/50 | mixup_alpha=0.3000


                                                           

Train Loss: 1.0529, Acc: 0.3340 | Val Loss: 1.0782, Acc: 0.6061
Validation accuracy improved, saving best state.

[CE] Epoch 2/50 | mixup_alpha=0.3000


                                                           

Train Loss: 1.0378, Acc: 0.3467 | Val Loss: 1.0071, Acc: 0.4343
EarlyStopping counter: 1 / 10

[CE] Epoch 3/50 | mixup_alpha=0.3000


                                                           

Train Loss: 1.0229, Acc: 0.3784 | Val Loss: 0.9730, Acc: 0.6162
Validation accuracy improved, saving best state.

[CE] Epoch 4/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.9935, Acc: 0.4313 | Val Loss: 0.9024, Acc: 0.5657
EarlyStopping counter: 1 / 10

[CE] Epoch 5/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.9712, Acc: 0.4461 | Val Loss: 0.8622, Acc: 0.7172
Validation accuracy improved, saving best state.
Unfroze backbone at epoch 5

[CE] Epoch 6/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.9357, Acc: 0.5264 | Val Loss: 0.7599, Acc: 0.7576
Validation accuracy improved, saving best state.

[CE] Epoch 7/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.9364, Acc: 0.4672 | Val Loss: 0.7097, Acc: 0.7778
Validation accuracy improved, saving best state.

[CE] Epoch 8/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8999, Acc: 0.4968 | Val Loss: 0.6762, Acc: 0.7071
EarlyStopping counter: 1 / 10

[CE] Epoch 9/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8934, Acc: 0.5011 | Val Loss: 0.6301, Acc: 0.7576
EarlyStopping counter: 2 / 10

[CE] Epoch 10/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8753, Acc: 0.5264 | Val Loss: 0.5882, Acc: 0.8182
Validation accuracy improved, saving best state.

[CE] Epoch 11/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7769, Acc: 0.6237 | Val Loss: 0.4903, Acc: 0.8485
Validation accuracy improved, saving best state.

[CE] Epoch 12/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8235, Acc: 0.5835 | Val Loss: 0.4893, Acc: 0.8283
EarlyStopping counter: 1 / 10

[CE] Epoch 13/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8043, Acc: 0.5920 | Val Loss: 0.4713, Acc: 0.8687
Validation accuracy improved, saving best state.

[CE] Epoch 14/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8199, Acc: 0.5899 | Val Loss: 0.4746, Acc: 0.8889
Validation accuracy improved, saving best state.

[CE] Epoch 15/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7902, Acc: 0.5941 | Val Loss: 0.4962, Acc: 0.8182
EarlyStopping counter: 1 / 10

[CE] Epoch 16/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7454, Acc: 0.6279 | Val Loss: 0.4056, Acc: 0.8586
EarlyStopping counter: 2 / 10

[CE] Epoch 17/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7591, Acc: 0.6004 | Val Loss: 0.4362, Acc: 0.8586
EarlyStopping counter: 3 / 10

[CE] Epoch 18/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7840, Acc: 0.6089 | Val Loss: 0.4046, Acc: 0.8889
EarlyStopping counter: 4 / 10

[CE] Epoch 19/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7656, Acc: 0.5983 | Val Loss: 0.4019, Acc: 0.8788
EarlyStopping counter: 5 / 10

[CE] Epoch 20/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8340, Acc: 0.5645 | Val Loss: 0.4323, Acc: 0.8485
EarlyStopping counter: 6 / 10

[CE] Epoch 21/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7475, Acc: 0.6089 | Val Loss: 0.4158, Acc: 0.9091
Validation accuracy improved, saving best state.

[CE] Epoch 22/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7814, Acc: 0.6258 | Val Loss: 0.3974, Acc: 0.8788
EarlyStopping counter: 1 / 10

[CE] Epoch 23/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7855, Acc: 0.6131 | Val Loss: 0.4382, Acc: 0.8182
EarlyStopping counter: 2 / 10

[CE] Epoch 24/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7780, Acc: 0.6004 | Val Loss: 0.3885, Acc: 0.8687
EarlyStopping counter: 3 / 10

[CE] Epoch 25/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.6749, Acc: 0.6850 | Val Loss: 0.3745, Acc: 0.8889
EarlyStopping counter: 4 / 10

[CE] Epoch 26/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7454, Acc: 0.6237 | Val Loss: 0.3934, Acc: 0.8485
EarlyStopping counter: 5 / 10

[CE] Epoch 27/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7450, Acc: 0.6321 | Val Loss: 0.3777, Acc: 0.8384
EarlyStopping counter: 6 / 10

[CE] Epoch 28/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7658, Acc: 0.6004 | Val Loss: 0.3783, Acc: 0.8384
EarlyStopping counter: 7 / 10

[CE] Epoch 29/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.6941, Acc: 0.6702 | Val Loss: 0.3753, Acc: 0.8384
EarlyStopping counter: 8 / 10

[CE] Epoch 30/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7441, Acc: 0.6364 | Val Loss: 0.3543, Acc: 0.9091
EarlyStopping counter: 9 / 10

[CE] Epoch 31/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7383, Acc: 0.6321 | Val Loss: 0.3782, Acc: 0.8485
EarlyStopping counter: 10 / 10
Early stopping triggered (CE).


                                                        


[CE] Test report:
              precision    recall  f1-score   support

  Alternaria       1.00      0.81      0.90        37
Healthy Leaf       0.82      1.00      0.90        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.93        99
   macro avg       0.94      0.94      0.93        99
weighted avg       0.94      0.93      0.93        99






[scratch_ce] CE Best Val Acc: 90.91%
[scratch_ce] CE Test Acc:     92.93%
[scratch_ce] CE TTA  Acc:     92.93%

############################################################
[HYBRID] Variant: scratch_hybrid
Checkpoint: /kaggle/working/simsiam_cbam_scratch/simsiam_cbam_pretrained_final.pth
############################################################

[HYBRID] Epoch 1/50


                                                               

Train Loss: 2.0654, Acc: 0.3890 | Val Loss: 1.0735, Acc: 0.5455
Validation accuracy improved, saving best state.

[HYBRID] Epoch 2/50


                                                               

Train Loss: 2.0239, Acc: 0.5137 | Val Loss: 0.9903, Acc: 0.5253
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 3/50


                                                               

Train Loss: 2.0084, Acc: 0.4355 | Val Loss: 0.9114, Acc: 0.7172
Validation accuracy improved, saving best state.

[HYBRID] Epoch 4/50


                                                               

Train Loss: 1.9669, Acc: 0.5074 | Val Loss: 0.9130, Acc: 0.5253
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 5/50




Train Loss: 1.9461, Acc: 0.5011 | Val Loss: 0.8198, Acc: 0.7172
EarlyStopping counter: 2 / 8
Unfroze backbone at epoch 5 (Hybrid).

[HYBRID] Epoch 6/50


                                                               

Train Loss: 1.9627, Acc: 0.5180 | Val Loss: 0.7536, Acc: 0.7475
Validation accuracy improved, saving best state.

[HYBRID] Epoch 7/50


                                                               

Train Loss: 1.9644, Acc: 0.5032 | Val Loss: 0.7065, Acc: 0.7980
Validation accuracy improved, saving best state.

[HYBRID] Epoch 8/50


                                                               

Train Loss: 1.9801, Acc: 0.5032 | Val Loss: 0.7052, Acc: 0.7475
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 9/50


                                                               

Train Loss: 1.9317, Acc: 0.5856 | Val Loss: 0.5972, Acc: 0.7879
EarlyStopping counter: 2 / 8

[HYBRID] Epoch 10/50


                                                               

Train Loss: 1.9087, Acc: 0.6364 | Val Loss: 0.5213, Acc: 0.8990
Validation accuracy improved, saving best state.

[HYBRID] Epoch 11/50


                                                               

Train Loss: 1.8880, Acc: 0.6512 | Val Loss: 0.4768, Acc: 0.8687
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 12/50


                                                               

Train Loss: 1.8719, Acc: 0.6195 | Val Loss: 0.4695, Acc: 0.8586
EarlyStopping counter: 2 / 8

[HYBRID] Epoch 13/50


                                                               

Train Loss: 1.9247, Acc: 0.5772 | Val Loss: 0.4606, Acc: 0.9091
Validation accuracy improved, saving best state.

[HYBRID] Epoch 14/50


                                                               

Train Loss: 1.9003, Acc: 0.6068 | Val Loss: 0.4617, Acc: 0.9192
Validation accuracy improved, saving best state.

[HYBRID] Epoch 15/50


                                                               

Train Loss: 1.9174, Acc: 0.5920 | Val Loss: 0.4644, Acc: 0.8687
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 16/50


                                                               

Train Loss: 1.8661, Acc: 0.6279 | Val Loss: 0.5092, Acc: 0.8182
EarlyStopping counter: 2 / 8

[HYBRID] Epoch 17/50


                                                               

Train Loss: 1.8688, Acc: 0.6448 | Val Loss: 0.4159, Acc: 0.8788
EarlyStopping counter: 3 / 8

[HYBRID] Epoch 18/50


                                                               

Train Loss: 1.8845, Acc: 0.6342 | Val Loss: 0.4601, Acc: 0.8384
EarlyStopping counter: 4 / 8

[HYBRID] Epoch 19/50


                                                               

Train Loss: 1.8961, Acc: 0.6321 | Val Loss: 0.4336, Acc: 0.8485
EarlyStopping counter: 5 / 8

[HYBRID] Epoch 20/50


                                                               

Train Loss: 1.8873, Acc: 0.6173 | Val Loss: 0.4218, Acc: 0.9091
EarlyStopping counter: 6 / 8

[HYBRID] Epoch 21/50


                                                               

Train Loss: 1.8328, Acc: 0.6660 | Val Loss: 0.4652, Acc: 0.8283
EarlyStopping counter: 7 / 8

[HYBRID] Epoch 22/50


                                                               

Train Loss: 1.8714, Acc: 0.6575 | Val Loss: 0.4034, Acc: 0.8889
EarlyStopping counter: 8 / 8
Early stopping triggered (Hybrid).


                                                            


[HYBRID] Test report:
              precision    recall  f1-score   support

  Alternaria       0.94      0.92      0.93        37
Healthy Leaf       0.91      0.97      0.94        31
  straw_mite       1.00      0.97      0.98        31

    accuracy                           0.95        99
   macro avg       0.95      0.95      0.95        99
weighted avg       0.95      0.95      0.95        99






[scratch_hybrid] Hybrid Best Val Acc: 91.92%
[scratch_hybrid] Hybrid Test Acc:     94.95%
[scratch_hybrid] Hybrid TTA  Acc:     95.96%

############################################################
[CE] Variant: in1k_ce
Checkpoint: /kaggle/working/simsiam_cbam_in1k/simsiam_cbam_pretrained_final.pth
############################################################

[CE] Epoch 1/50 | mixup_alpha=0.3000


                                                           

Train Loss: 1.0386, Acc: 0.4249 | Val Loss: 0.9950, Acc: 0.4848
Validation accuracy improved, saving best state.

[CE] Epoch 2/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.9968, Acc: 0.4820 | Val Loss: 0.8984, Acc: 0.7475
Validation accuracy improved, saving best state.

[CE] Epoch 3/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.9696, Acc: 0.4461 | Val Loss: 0.8014, Acc: 0.8081
Validation accuracy improved, saving best state.

[CE] Epoch 4/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.8937, Acc: 0.5751 | Val Loss: 0.7025, Acc: 0.8283
Validation accuracy improved, saving best state.

[CE] Epoch 5/50 | mixup_alpha=0.3000




Train Loss: 0.8222, Acc: 0.6406 | Val Loss: 0.6047, Acc: 0.8485
Validation accuracy improved, saving best state.
Unfroze backbone at epoch 5

[CE] Epoch 6/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.7597, Acc: 0.6490 | Val Loss: 0.3628, Acc: 0.9091
Validation accuracy improved, saving best state.

[CE] Epoch 7/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.6916, Acc: 0.6575 | Val Loss: 0.2476, Acc: 0.9596
Validation accuracy improved, saving best state.

[CE] Epoch 8/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.6312, Acc: 0.6998 | Val Loss: 0.2129, Acc: 0.9596
EarlyStopping counter: 1 / 10

[CE] Epoch 9/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4901, Acc: 0.7674 | Val Loss: 0.1518, Acc: 0.9697
Validation accuracy improved, saving best state.

[CE] Epoch 10/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.5288, Acc: 0.7611 | Val Loss: 0.1331, Acc: 0.9798
Validation accuracy improved, saving best state.

[CE] Epoch 11/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3556, Acc: 0.8330 | Val Loss: 0.1319, Acc: 0.9596
EarlyStopping counter: 1 / 10

[CE] Epoch 12/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4957, Acc: 0.7780 | Val Loss: 0.1386, Acc: 0.9798
EarlyStopping counter: 2 / 10

[CE] Epoch 13/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4727, Acc: 0.7801 | Val Loss: 0.1436, Acc: 0.9697
EarlyStopping counter: 3 / 10

[CE] Epoch 14/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.5226, Acc: 0.7611 | Val Loss: 0.1462, Acc: 0.9798
EarlyStopping counter: 4 / 10

[CE] Epoch 15/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4984, Acc: 0.7780 | Val Loss: 0.1271, Acc: 0.9899
Validation accuracy improved, saving best state.

[CE] Epoch 16/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4336, Acc: 0.8140 | Val Loss: 0.1462, Acc: 0.9596
EarlyStopping counter: 1 / 10

[CE] Epoch 17/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3928, Acc: 0.8414 | Val Loss: 0.1123, Acc: 0.9899
EarlyStopping counter: 2 / 10

[CE] Epoch 18/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4746, Acc: 0.7970 | Val Loss: 0.1301, Acc: 0.9798
EarlyStopping counter: 3 / 10

[CE] Epoch 19/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4434, Acc: 0.8013 | Val Loss: 0.1154, Acc: 0.9899
EarlyStopping counter: 4 / 10

[CE] Epoch 20/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.5205, Acc: 0.7653 | Val Loss: 0.1363, Acc: 0.9899
EarlyStopping counter: 5 / 10

[CE] Epoch 21/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4174, Acc: 0.8097 | Val Loss: 0.1219, Acc: 0.9899
EarlyStopping counter: 6 / 10

[CE] Epoch 22/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3923, Acc: 0.8224 | Val Loss: 0.1830, Acc: 0.9596
EarlyStopping counter: 7 / 10

[CE] Epoch 23/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4356, Acc: 0.8161 | Val Loss: 0.1173, Acc: 1.0000
Validation accuracy improved, saving best state.

[CE] Epoch 24/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4400, Acc: 0.8182 | Val Loss: 0.1358, Acc: 0.9899
EarlyStopping counter: 1 / 10

[CE] Epoch 25/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3038, Acc: 0.8732 | Val Loss: 0.1242, Acc: 0.9899
EarlyStopping counter: 2 / 10

[CE] Epoch 26/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4600, Acc: 0.7992 | Val Loss: 0.1235, Acc: 0.9899
EarlyStopping counter: 3 / 10

[CE] Epoch 27/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4117, Acc: 0.8288 | Val Loss: 0.1135, Acc: 0.9899
EarlyStopping counter: 4 / 10

[CE] Epoch 28/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3819, Acc: 0.8203 | Val Loss: 0.1095, Acc: 1.0000
EarlyStopping counter: 5 / 10

[CE] Epoch 29/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3294, Acc: 0.8478 | Val Loss: 0.1031, Acc: 1.0000
EarlyStopping counter: 6 / 10

[CE] Epoch 30/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4411, Acc: 0.8034 | Val Loss: 0.1087, Acc: 1.0000
EarlyStopping counter: 7 / 10

[CE] Epoch 31/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.3901, Acc: 0.8161 | Val Loss: 0.1161, Acc: 1.0000
EarlyStopping counter: 8 / 10

[CE] Epoch 32/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4402, Acc: 0.8055 | Val Loss: 0.1153, Acc: 1.0000
EarlyStopping counter: 9 / 10

[CE] Epoch 33/50 | mixup_alpha=0.3000


                                                           

Train Loss: 0.4144, Acc: 0.8203 | Val Loss: 0.1172, Acc: 0.9899
EarlyStopping counter: 10 / 10
Early stopping triggered (CE).


                                                        


[CE] Test report:
              precision    recall  f1-score   support

  Alternaria       1.00      0.97      0.99        37
Healthy Leaf       0.97      1.00      0.98        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.99        99
   macro avg       0.99      0.99      0.99        99
weighted avg       0.99      0.99      0.99        99






[in1k_ce] CE Best Val Acc: 100.00%
[in1k_ce] CE Test Acc:     98.99%
[in1k_ce] CE TTA  Acc:     98.99%

############################################################
[HYBRID] Variant: in1k_hybrid
Checkpoint: /kaggle/working/simsiam_cbam_in1k/simsiam_cbam_pretrained_final.pth
############################################################

[HYBRID] Epoch 1/50


                                                               

Train Loss: 1.9716, Acc: 0.4186 | Val Loss: 1.0337, Acc: 0.6465
Validation accuracy improved, saving best state.

[HYBRID] Epoch 2/50


                                                               

Train Loss: 1.9244, Acc: 0.5814 | Val Loss: 0.9286, Acc: 0.6465
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 3/50


                                                               

Train Loss: 1.8030, Acc: 0.5708 | Val Loss: 0.8046, Acc: 0.8384
Validation accuracy improved, saving best state.

[HYBRID] Epoch 4/50


                                                               

Train Loss: 1.7585, Acc: 0.6469 | Val Loss: 0.7420, Acc: 0.7374
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 5/50




Train Loss: 1.6655, Acc: 0.6258 | Val Loss: 0.6089, Acc: 0.8687
Validation accuracy improved, saving best state.
Unfroze backbone at epoch 5 (Hybrid).

[HYBRID] Epoch 6/50


                                                               

Train Loss: 1.7672, Acc: 0.6575 | Val Loss: 0.4761, Acc: 0.9192
Validation accuracy improved, saving best state.

[HYBRID] Epoch 7/50


                                                               

Train Loss: 1.6916, Acc: 0.6998 | Val Loss: 0.3569, Acc: 0.9495
Validation accuracy improved, saving best state.

[HYBRID] Epoch 8/50


                                                               

Train Loss: 1.7109, Acc: 0.7273 | Val Loss: 0.4039, Acc: 0.8586
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 9/50


                                                               

Train Loss: 1.6399, Acc: 0.7653 | Val Loss: 0.3025, Acc: 0.9495
EarlyStopping counter: 2 / 8

[HYBRID] Epoch 10/50


                                                               

Train Loss: 1.6661, Acc: 0.7970 | Val Loss: 0.2571, Acc: 0.9495
EarlyStopping counter: 3 / 8

[HYBRID] Epoch 11/50


                                                               

Train Loss: 1.6518, Acc: 0.8013 | Val Loss: 0.2392, Acc: 0.9798
Validation accuracy improved, saving best state.

[HYBRID] Epoch 12/50


                                                               

Train Loss: 1.5429, Acc: 0.8140 | Val Loss: 0.2760, Acc: 0.9697
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 13/50


                                                               

Train Loss: 1.6715, Acc: 0.7674 | Val Loss: 0.2457, Acc: 0.9899
Validation accuracy improved, saving best state.

[HYBRID] Epoch 14/50


                                                               

Train Loss: 1.6420, Acc: 0.8013 | Val Loss: 0.2273, Acc: 0.9798
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 15/50


                                                               

Train Loss: 1.6122, Acc: 0.8055 | Val Loss: 0.2200, Acc: 0.9798
EarlyStopping counter: 2 / 8

[HYBRID] Epoch 16/50


                                                               

Train Loss: 1.5808, Acc: 0.8414 | Val Loss: 0.2265, Acc: 0.9899
EarlyStopping counter: 3 / 8

[HYBRID] Epoch 17/50


                                                               

Train Loss: 1.5999, Acc: 0.8245 | Val Loss: 0.2432, Acc: 0.9899
EarlyStopping counter: 4 / 8

[HYBRID] Epoch 18/50


                                                               

Train Loss: 1.6432, Acc: 0.8140 | Val Loss: 0.2467, Acc: 0.9798
EarlyStopping counter: 5 / 8

[HYBRID] Epoch 19/50


                                                               

Train Loss: 1.5920, Acc: 0.8161 | Val Loss: 0.2331, Acc: 0.9899
EarlyStopping counter: 6 / 8

[HYBRID] Epoch 20/50


                                                               

Train Loss: 1.6033, Acc: 0.8161 | Val Loss: 0.2049, Acc: 1.0000
Validation accuracy improved, saving best state.

[HYBRID] Epoch 21/50


                                                               

Train Loss: 1.4985, Acc: 0.8605 | Val Loss: 0.2057, Acc: 0.9899
EarlyStopping counter: 1 / 8

[HYBRID] Epoch 22/50


                                                               

Train Loss: 1.5945, Acc: 0.8330 | Val Loss: 0.2123, Acc: 0.9899
EarlyStopping counter: 2 / 8

[HYBRID] Epoch 23/50


                                                               

Train Loss: 1.5679, Acc: 0.8140 | Val Loss: 0.2520, Acc: 0.9798
EarlyStopping counter: 3 / 8

[HYBRID] Epoch 24/50


                                                               

Train Loss: 1.4846, Acc: 0.8689 | Val Loss: 0.2061, Acc: 0.9899
EarlyStopping counter: 4 / 8

[HYBRID] Epoch 25/50


                                                               

Train Loss: 1.4883, Acc: 0.8414 | Val Loss: 0.2509, Acc: 0.9697
EarlyStopping counter: 5 / 8

[HYBRID] Epoch 26/50


                                                               

Train Loss: 1.5511, Acc: 0.7886 | Val Loss: 0.2243, Acc: 0.9899
EarlyStopping counter: 6 / 8

[HYBRID] Epoch 27/50


                                                               

Train Loss: 1.4620, Acc: 0.8626 | Val Loss: 0.2069, Acc: 1.0000
EarlyStopping counter: 7 / 8

[HYBRID] Epoch 28/50


                                                               

Train Loss: 1.4030, Acc: 0.8541 | Val Loss: 0.2072, Acc: 1.0000
EarlyStopping counter: 8 / 8
Early stopping triggered (Hybrid).


                                                            


[HYBRID] Test report:
              precision    recall  f1-score   support

  Alternaria       1.00      0.95      0.97        37
Healthy Leaf       0.94      1.00      0.97        31
  straw_mite       1.00      1.00      1.00        31

    accuracy                           0.98        99
   macro avg       0.98      0.98      0.98        99
weighted avg       0.98      0.98      0.98        99






[in1k_hybrid] Hybrid Best Val Acc: 100.00%
[in1k_hybrid] Hybrid Test Acc:     97.98%
[in1k_hybrid] Hybrid TTA  Acc:     100.00%

Summary over CBAM SimSiam fine-tune variants
Variant         | Type    | Params       | Size(MB)  | Best Val | Test     | TTA     
------------------------------------------------------------
scratch_ce      | CE      |   27,075,171 |    103.28 |    90.91% |    92.93% |    92.93%
scratch_hybrid  | Hybrid  |   27,337,443 |    104.28 |    91.92% |    94.95% |    95.96%
in1k_ce         | CE      |   27,075,171 |    103.28 |   100.00% |    98.99% |    98.99%
in1k_hybrid     | Hybrid  |   27,337,443 |    104.28 |   100.00% |    97.98% |   100.00%
