# In [None]:
# train_tstc.py
# TSTC (Triple-branch Swin Transformer with CBP and Deep Supervision) for disease + severity
# Implements full five-term loss: l1 (disease), l2 (final severity from feat23), l3 (severity from feat13),
# l4 (severity from feat2), l5 (severity from feat3), with a filtered AI-Challenger dataloader.

import os
import json
from typing import Tuple, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import swin_t, Swin_T_Weights
from PIL import Image

from sklearn.metrics import confusion_matrix
import numpy as np

# ------------------------------
# Compact Bilinear Pooling (CBP)
# ------------------------------

class CompactBilinearPooling(nn.Module):
    """Compact bilinear pooling of two feature matrices (Count Sketch + FFT).

    Each input of shape ``[B, input_dim]`` is projected to ``[B, output_dim]``
    with a fixed random count sketch (one bucket index ``h`` and one sign ``s``
    per input channel). The two sketches are multiplied in the Fourier domain,
    transformed back, and post-processed with a signed square root, an L2
    normalisation over dim 1 and a final ``1 / sqrt(output_dim)`` scaling.

    Args:
        input_dim1: Number of channels of the first input.
        input_dim2: Number of channels of the second input.
        output_dim: Size of the sketch and of the pooled output.
    """

    def __init__(self, input_dim1: int, input_dim2: int, output_dim: int):
        super().__init__()
        self.output_dim = output_dim
        # Bucket indices and signs are drawn once and stored as buffers, so they
        # are saved in the state_dict and follow the module across devices.
        self.register_buffer('h1', torch.randint(0, output_dim, (input_dim1,)))
        self.register_buffer('s1', 2 * torch.randint(0, 2, (input_dim1,)) - 1)
        self.register_buffer('h2', torch.randint(0, output_dim, (input_dim2,)))
        self.register_buffer('s2', 2 * torch.randint(0, 2, (input_dim2,)) - 1)

    @staticmethod
    def _fft_safe(x: torch.Tensor) -> torch.Tensor:
        """Upcast half/bfloat16 inputs to float32; other dtypes pass through."""
        # Reduced-precision FFT support is restricted, so the pooling is carried
        # out in float32 for such inputs.
        if x.dtype in (torch.float16, torch.bfloat16):
            return x.float()
        return x

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return the pooled ``[B, output_dim]`` feature for ``x1`` and ``x2``."""
        fx1 = self._count_sketch(self._fft_safe(x1), self.h1, self.s1)
        fx2 = self._count_sketch(self._fft_safe(x2), self.h2, self.s2)
        # Element-wise product of the spectra == circular convolution of sketches.
        cbp = torch.fft.ifft(torch.fft.fft(fx1) * torch.fft.fft(fx2)).real
        # Signed square root; the clamp keeps sqrt away from exact zero.
        cbp = torch.sign(cbp) * torch.sqrt(torch.clamp(cbp.abs(), min=1e-8))
        cbp = F.normalize(cbp, p=2, dim=1)
        cbp = cbp / (self.output_dim ** 0.5)  # scale by 1/sqrt(d)
        return cbp

    def _count_sketch(self, x: torch.Tensor, h: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Project ``x`` ([B, C]) to ``[B, output_dim]``: out[:, h[j]] += s[j] * x[:, j]."""
        h = h.to(x.device)
        s = s.to(device=x.device, dtype=x.dtype)
        # Allocate the accumulator in the input's dtype/device; a hard-coded
        # float32 buffer makes the accumulation fail for float64 inputs.
        out = x.new_zeros(x.size(0), self.output_dim)
        # index_add_ shares the 1-D bucket index across the batch, so the index
        # does not have to be replicated per row.
        out.index_add_(1, h, x * s)
        return out

# --------------
# TSTC Backbone
# --------------

class TSTC(nn.Module):
    """Triple-branch Swin-T network for joint disease and severity prediction.

    The first two slices of the Swin-T feature extractor are shared; the last
    two are deep-copied into three independent branches (disease, severity,
    deep supervision). Branch features are average-pooled to ``[B, 768]`` and
    combined with compact bilinear pooling.

    Args:
        num_disease_classes: Output size of the disease head.
        num_severity_classes: Output size of every severity head.
        cbp_output_dim: Dimensionality of the compact-bilinear features.
        pretrained: When True (default) the backbone is initialised with
            ``Swin_T_Weights.IMAGENET1K_V1``; pass False to build it with
            ``weights=None``, e.g. when a checkpoint is loaded right after.
    """

    def __init__(
        self,
        num_disease_classes: int,
        num_severity_classes: int,
        cbp_output_dim: int = 7680,
        pretrained: bool = True,
    ):
        super().__init__()
        import copy  # local import: ``copy`` is not imported at file level

        # Swin-T backbone
        swin = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1 if pretrained else None)
        # NOTE(review): only ``swin.features`` is reused below; torchvision's own
        # forward appears to apply ``swin.norm`` after the features — confirm that
        # dropping it is intentional.

        # Shared stages (follow the provided notebook slicing of torchvision Swin blocks)
        self.stage1 = swin.features[:2]
        self.stage2 = swin.features[2:4]

        # Split stages for the three branches
        swin_stage3 = swin.features[4:6]
        swin_stage4 = swin.features[6:8]

        # Every branch gets its own copy so the three do not share weights.
        self.branch1_stage3 = copy.deepcopy(swin_stage3)  # disease
        self.branch1_stage4 = copy.deepcopy(swin_stage4)

        self.branch2_stage3 = copy.deepcopy(swin_stage3)  # severity
        self.branch2_stage4 = copy.deepcopy(swin_stage4)

        self.branch3_stage3 = copy.deepcopy(swin_stage3)  # deep supervision
        self.branch3_stage4 = copy.deepcopy(swin_stage4)

        feat_dim = 768  # Swin-T final stage embedding dim

        # CBP fusions
        self.cbp23 = CompactBilinearPooling(feat_dim, feat_dim, cbp_output_dim)   # for final severity
        self.cbp13 = CompactBilinearPooling(feat_dim, feat_dim, cbp_output_dim)   # for deep supervision (feat1 x feat3)

        # Heads:
        # - disease final (feat1)
        self.fc_disease = nn.Linear(feat_dim, num_disease_classes)
        # - severity final (feat23)
        self.fc_severity_final = nn.Linear(cbp_output_dim, num_severity_classes)
        # - deep supervision severity heads
        self.fc_sev_feat2 = nn.Linear(feat_dim, num_severity_classes)   # l4
        self.fc_sev_feat3 = nn.Linear(feat_dim, num_severity_classes)   # l5
        self.fc_sev_feat13 = nn.Linear(cbp_output_dim, num_severity_classes)  # l3

    def _pool(self, feat_4d: torch.Tensor) -> torch.Tensor:
        """Global-average-pool a channels-last ``[B, H, W, C]`` map to ``[B, C]``."""
        return F.adaptive_avg_pool2d(feat_4d.permute(0, 3, 1, 2), 1).flatten(1)

    def forward(self, x: torch.Tensor):
        """Run the network.

        Returns:
            In training mode the 5-tuple ``(disease_out, severity_final,
            sev_feat13, sev_feat2, sev_feat3)``; in eval mode only
            ``(disease_out, severity_final)``.
        """
        shared = self.stage2(self.stage1(x))

        # Three branches
        f1 = self._pool(self.branch1_stage4(self.branch1_stage3(shared)))  # disease-focused
        f2 = self._pool(self.branch2_stage4(self.branch2_stage3(shared)))  # severity-focused
        f3 = self._pool(self.branch3_stage4(self.branch3_stage3(shared)))  # deep-supervision-focused

        disease_out = self.fc_disease(f1)                            # l1
        severity_final = self.fc_severity_final(self.cbp23(f2, f3))  # l2

        if not self.training:
            # Inference only uses the final heads, so the auxiliary fusion and
            # heads are not evaluated at all.
            return disease_out, severity_final

        sev_feat13 = self.fc_sev_feat13(self.cbp13(f1, f3))  # l3
        sev_feat2 = self.fc_sev_feat2(f2)                    # l4
        sev_feat3 = self.fc_sev_feat3(f3)                    # l5
        return disease_out, severity_final, sev_feat13, sev_feat2, sev_feat3

def _compute_label_counts(train_ds):
    """Tally how often each disease and severity index occurs in ``train_ds``.

    ``train_ds`` must expose ``samples`` (pairs whose second element is the raw
    numeric label), the ``label2disease`` / ``label2severity`` lookups and the
    ``num_disease_classes`` / ``num_severity_classes`` sizes.

    Returns:
        ``(disease_counts, severity_counts)`` as 1-D ``torch.long`` tensors.
    """
    disease_hist = torch.zeros(train_ds.num_disease_classes, dtype=torch.long)
    severity_hist = torch.zeros(train_ds.num_severity_classes, dtype=torch.long)
    to_disease = train_ds.label2disease
    to_severity = train_ds.label2severity
    for _image_id, raw_label in train_ds.samples:
        disease_hist[to_disease[raw_label]] += 1
        severity_hist[to_severity[raw_label]] += 1
    return disease_hist, severity_hist

def _class_weights_from_counts(counts: torch.Tensor, scheme: str = "effective", beta: float = 0.999, max_w: float = 10.0):
    """Turn per-class sample counts into loss weights.

    Args:
        counts: Per-class counts; values below 1 are treated as 1.
        scheme: ``"effective"`` uses the class-balanced formula
            ``(1 - beta) / (1 - beta ** n_c)``; any other value falls back to
            inverse frequency ``1 / n_c``.
        beta: Base of the effective-number formula.
        max_w: Upper clip applied after rescaling.

    Returns:
        Float tensor of weights, rescaled so their mean is 1 before clipping.
    """
    n = counts.float().clamp(min=1.0)
    use_effective_number = scheme == "effective"
    # Effective-number weighting, else plain inverse frequency.
    raw = (1.0 - beta) / (1.0 - torch.pow(beta, n)) if use_effective_number else 1.0 / n
    # Rescale to unit mean; the clamp guards the division against a zero sum.
    rescale = raw.numel() / raw.sum().clamp(min=1e-8)
    return (raw * rescale).clamp(max=max_w)
# ------------
# Five-term CE
# ------------

class TSTCLoss(nn.Module):
    def __init__(
        self,
        disease_weight: torch.Tensor | None = None,
        severity_weight: torch.Tensor | None = None,
        aux_lambda: float = 0.5,
        print_every: int = 0,
    ):
        super().__init__()
        # Register as buffers so they move with .to(device)
        if disease_weight is not None:
            self.register_buffer("w_disease", disease_weight.float())
        else:
            self.w_disease = None
        if severity_weight is not None:
            self.register_buffer("w_severity", severity_weight.float())
        else:
            self.w_severity = None

        self.aux_lambda = float(aux_lambda)
        self.print_every = int(print_every)
        self._step = 0

    def forward(self, outputs, targets):
        disease_out, severity_final, sev_feat13, sev_feat2, sev_feat3 = outputs
        disease_labels, severity_labels = targets

        # Ensure weights live on the same device as logits
        wd = self.w_disease.to(disease_out.device) if self.w_disease is not None else None
        ws = self.w_severity.to(severity_final.device) if self.w_severity is not None else None

        l1 = F.cross_entropy(disease_out, disease_labels, weight=wd)         # disease
        l2 = F.cross_entropy(severity_final, severity_labels, weight=ws)     # final severity (feat23)
        l3 = F.cross_entropy(sev_feat13, severity_labels, weight=ws)         # aux severity (feat13)
        l4 = F.cross_entropy(sev_feat2, severity_labels, weight=ws)          # aux severity (feat2)
        l5 = F.cross_entropy(sev_feat3, severity_labels, weight=ws)          # aux severity (feat3)

        total = l1 + l2 + self.aux_lambda * (l3 + l4 + l5)

        # Optional periodic logging
        if self.print_every > 0 and (self._step % self.print_every == 0):
            print(
                f"Losses => Disease: {l1.item():.4f}, "
                f"Severity_final: {l2.item():.4f}, "
                f"Sev_feat13: {l3.item():.4f}, "
                f"Sev_feat2: {l4.item():.4f}, "
                f"Sev_feat3: {l5.item():.4f}, "
                f"Total: {total.item():.4f}"
            )
        self._step += 1
        return total
# -------------------------------
# AI-Challenger filtered dataset
# -------------------------------

# Raw label ids kept by the loader: the ids that carry severity information
# (listed one group per line), followed by the healthy ids.
SELECTED_LABEL_IDS = [
    14, 15,
    39, 40,
    48, 49,
    52, 53,
    50, 51,
    34, 35, 44, 45,
    36, 37, 46, 47,
    10, 11,
    28, 29,
    7, 8, 42, 43,
    54, 55,
    22, 23,
    20, 21,
    12, 13,
    4, 5,
    56, 57,
    26, 27,
    2, 3,
    18, 19,
    # healthy labels
    0, 6, 9, 17, 24, 30, 33, 38, 41,
]

class AIChallengerSubset(Dataset):
    """Dataset over the annotation entries whose label is in ``SELECTED_LABEL_IDS``.

    The annotation file is a JSON list of objects with at least the keys
    ``"image_id"`` and ``"disease_class"``. Each item yields
    ``(image, disease_label, severity_label)``.

    Args:
        json_path: Path to the annotation JSON.
        img_root: Directory that ``image_id`` values are joined onto.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, json_path: str, img_root: str, transform=None):
        self.img_root = img_root
        self.transform = transform

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(json_path, 'r', encoding="utf-8") as f:
            all_data = json.load(f)

        # Filter samples; a set gives constant-time membership per annotation.
        selected = frozenset(SELECTED_LABEL_IDS)
        self.samples = [
            (item["image_id"], item["disease_class"])
            for item in all_data
            if item["disease_class"] in selected
        ]

        # Build maps: numeric label -> disease group index, severity index
        self.label2disease, self.label2severity = self._build_label_maps()

        # Class counts come from the full maps, not from the labels that
        # actually occur in ``samples``.
        self.num_disease_classes = len(set(self.label2disease.values()))
        self.num_severity_classes = len(set(self.label2severity.values()))

    def _build_label_maps(self) -> Tuple[Dict[int, int], Dict[int, int]]:
        """Return ``(label2disease, label2severity)`` keyed by raw label id."""
        # 19 disease groups indexed from 1; healthy -> 0
        disease_groups = [
            [14,15], [39,40], [48,49], [52,53], [50,51],
            [34,35,44,45], [36,37,46,47], [10,11], [28,29],
            [7,8,42,43], [54,55], [22,23], [20,21],
            [12,13], [4,5], [56,57], [26,27], [2,3], [18,19]
        ]
        label2disease, label2severity = {}, {}

        # General (1) vs Serious (2) by even/odd parity in the provided mapping
        # NOTE(review): the parity rule gives the first id of groups that start
        # on an odd number (e.g. 7 in [7,8], 39 in [39,40]) severity 2 — confirm
        # against the official label table that this is intended.
        for i, grp in enumerate(disease_groups, 1):
            for label in grp:
                label2disease[label] = i
                label2severity[label] = 2 if (label % 2 == 1) else 1

        # Healthy labels → disease 0, severity 0
        for label in [0, 6, 9, 17, 24, 30, 33, 38, 41]:
            label2disease[label] = 0
            label2severity[label] = 0

        return label2disease, label2severity

    def __len__(self):
        """Number of retained samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` and return ``(image, disease_label, severity_label)``."""
        img_name, numeric_label = self.samples[idx]
        img_path = os.path.join(self.img_root, img_name)

        # The context manager closes the underlying file promptly instead of
        # leaving it to garbage collection; convert() returns a separate image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        disease_label = self.label2disease[numeric_label]
        severity_label = self.label2severity[numeric_label]
        return image, disease_label, severity_label

# ------------------
# Training utilities
# ------------------

def get_loaders(
    train_json: str,
    train_images: str,
    val_json: str,
    val_images: str,
    batch_size: int = 32,
    num_workers: int = 4
):
    """Build the train/validation datasets and their data loaders.

    The training pipeline resizes to 224x224, applies a random horizontal flip
    and a random rotation of up to 10 degrees, then converts to a normalised
    tensor. The evaluation pipeline only resizes and normalises.

    Returns:
        ``(train_loader, val_loader, train_ds, val_ds)``; only the training
        loader shuffles.
    """
    def _tensor_and_normalise():
        # Shared tail of both pipelines (fresh transform objects per call).
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]

    augmenting = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        *_tensor_and_normalise(),
    ])
    deterministic = transforms.Compose([
        transforms.Resize((224, 224)),
        *_tensor_and_normalise(),
    ])

    train_ds = AIChallengerSubset(train_json, train_images, transform=augmenting)
    val_ds = AIChallengerSubset(val_json, val_images, transform=deterministic)

    # Settings common to both loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)

    return train_loader, val_loader, train_ds, val_ds

@torch.no_grad()
def evaluate(loader: DataLoader, model: nn.Module, device: torch.device) -> Tuple[float, float]:
    """Measure disease and severity accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left there) and is expected to
    return ``(disease_logits, severity_logits)``. Every batch of ``loader`` is
    an ``(images, disease_labels, severity_labels)`` triple.

    Returns:
        ``(disease_accuracy, severity_accuracy)`` as percentages.
    """
    model.eval()
    seen = 0
    hits_disease = 0
    hits_severity = 0
    for batch in loader:
        images, d_labels, s_labels = (part.to(device) for part in batch)
        d_logits, s_logits = model(images)
        seen += d_labels.size(0)
        hits_disease += (torch.argmax(d_logits, dim=1) == d_labels).sum().item()
        hits_severity += (torch.argmax(s_logits, dim=1) == s_labels).sum().item()
    disease_acc = 100.0 * hits_disease / seen
    severity_acc = 100.0 * hits_severity / seen
    return disease_acc, severity_acc

from tqdm.auto import tqdm  # add this import at top of your file

import os
from tqdm.auto import tqdm

def _save_checkpoint_atomically(payload: dict, tmp_path: str, out_ckpt: str) -> None:
    """Write ``payload`` to ``tmp_path`` and move it over ``out_ckpt``."""
    torch.save(payload, tmp_path)
    try:
        # os.replace overwrites an existing destination in one step, so the
        # previous checkpoint stays intact until the new one is in place.
        os.replace(tmp_path, out_ckpt)
    except OSError:
        # Fallback: copy then remove
        import shutil
        shutil.copyfile(tmp_path, out_ckpt)
        os.remove(tmp_path)


def train(
    train_json: str,
    train_images: str,
    val_json: str,
    val_images: str,
    epochs: int = 20,
    batch_size: int = 32,
    lr: float = 0.01,
    momentum: float = 0.9,
    weight_decay: float = 1e-4,
    cbp_output_dim: int = 7680,
    out_ckpt: str = "tstc_best.pt",
):
    """Train a TSTC model with SGD and keep the best checkpoint.

    After every epoch the model is evaluated on both loaders; whenever the mean
    of validation disease and severity accuracy improves, a checkpoint holding
    the model state_dict and the epoch's metrics is written to ``out_ckpt``.

    Args:
        train_json, train_images: Annotation file and image folder for training.
        val_json, val_images: Annotation file and image folder for validation.
        epochs: Number of passes over the training loader.
        batch_size: Batch size for both loaders.
        lr, momentum, weight_decay: SGD hyper-parameters.
        cbp_output_dim: Compact-bilinear feature size passed to ``TSTC``.
        out_ckpt: Destination of the best checkpoint; its directory is created
            if missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    train_loader, val_loader, train_ds, _ = get_loaders(
        train_json, train_images, val_json, val_images, batch_size=batch_size
    )

    num_disease = train_ds.num_disease_classes
    num_severity = train_ds.num_severity_classes
    print(f"num_disease = {num_disease} || num_severity = {num_severity}")

    model = TSTC(num_disease, num_severity, cbp_output_dim=cbp_output_dim).to(device)
    criterion = TSTCLoss()

    optimizer = optim.SGD(
        model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )

    best_val_score = -1.0

    # Prepare checkpoint directory
    ckpt_dir = os.path.dirname(out_ckpt) or "."
    os.makedirs(ckpt_dir, exist_ok=True)

    epoch_bar = tqdm(range(1, epochs + 1), desc="Epochs", position=0)
    for epoch in epoch_bar:
        model.train()
        running = 0.0

        batch_bar = tqdm(
            train_loader,
            desc=f"Train {epoch}/{epochs}",
            leave=False,
            position=1,
        )

        # Count steps explicitly: the bar's own counter is refreshed lazily and
        # lags behind the batches whose loss is already in ``running``.
        for step, (images, d_labels, s_labels) in enumerate(batch_bar, 1):
            images = images.to(device)
            d_labels = d_labels.to(device)
            s_labels = s_labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)  # returns 5 heads in train mode
            loss = criterion(outputs, (d_labels, s_labels))
            loss.backward()
            optimizer.step()

            running += loss.item()
            batch_bar.set_postfix(loss=f"{running / step:.4f}")

        epoch_loss = running / len(train_loader)
        train_d_acc, train_s_acc = evaluate(train_loader, model, device)
        val_d_acc, val_s_acc = evaluate(val_loader, model, device)

        epoch_bar.set_postfix(
            loss=f"{epoch_loss:.4f}",
            trainD=f"{train_d_acc:.2f}%",
            trainS=f"{train_s_acc:.2f}%",
            valD=f"{val_d_acc:.2f}%",
            valS=f"{val_s_acc:.2f}%"
        )

        print(
            f"Epoch {epoch:03d} | loss={epoch_loss:.4f} "
            f"| train D={train_d_acc:.2f}% S={train_s_acc:.2f}% "
            f"| val D={val_d_acc:.2f}% S={val_s_acc:.2f}%"
        )

        # Model selection: mean of the two validation accuracies.
        combined = 0.5 * (val_d_acc + val_s_acc)
        if combined > best_val_score:
            best_val_score = combined

            # Save to a temp path first to avoid partial writes. The previous
            # best file is deliberately NOT deleted beforehand: removing it
            # would open a window in which no valid checkpoint exists.
            tmp_path = os.path.join(ckpt_dir, f".tmp_epoch{epoch}.pt")
            payload = {
                "model": model.state_dict(),
                "num_disease": num_disease,
                "num_severity": num_severity,
                "epoch": epoch,
                "val_d_acc": val_d_acc,
                "val_s_acc": val_s_acc,
                "combined": combined,
            }
            _save_checkpoint_atomically(payload, tmp_path, out_ckpt)
            tqdm.write(f"  -> Saved best checkpoint to {out_ckpt} (epoch {epoch}, combined={combined:.3f})")


# --- Validate dataset paths before proceeding ---
import os, sys
def verify_path(json_path, img_dir, name="train"):
    """Check that a split's annotation file and image folder exist.

    Prints a status banner, then tests ``json_path`` (must be a file) and
    ``img_dir`` (must be a directory) in that order. On the first failure an
    error plus a hint is printed and the process exits with status 1.
    """
    print(f"\n[Checking {name} dataset paths...]")
    # (existence test, path, error message, remediation hint) — checked in order.
    requirements = (
        (
            os.path.isfile,
            json_path,
            f"❌ ERROR: The JSON file for {name} dataset was not found:\n   → {json_path}",
            "💡 Fix: Ensure TRAIN_JSON / VAL_JSON points to a valid file.\n",
        ),
        (
            os.path.isdir,
            img_dir,
            f"❌ ERROR: The image directory for {name} dataset was not found:\n   → {img_dir}",
            "💡 Fix: Ensure TRAIN_IMG_DIR / VAL_IMG_DIR points to an existing folder containing images.\n",
        ),
    )
    for exists, target, problem, hint in requirements:
        if exists(target):
            continue
        print(problem)
        print(hint)
        sys.exit(1)
    print(f"✅ {name.capitalize()} dataset paths verified successfully.")

if __name__ == "__main__":
    print("starting")
    # TODO: set these paths to the AI-Challenger 2018 dataset locations
    # Train/val split can be constructed by random split or using provided splits if available
    # NOTE(review): VAL_* below point at the dataset's test/ folder, so the
    # best-checkpoint selection in train() is driven by that split — confirm intended.
    TRAIN_JSON = "/kaggle/input/ai-challenger-dataset/ai_challenger_pdr2018/train/train_label.json"
    TRAIN_IMG_DIR = "/kaggle/input/ai-challenger-dataset/ai_challenger_pdr2018/train/images"
    VAL_JSON = "/kaggle/input/ai-challenger-dataset/ai_challenger_pdr2018/test/test_label.json"
    VAL_IMG_DIR = "/kaggle/input/ai-challenger-dataset/ai_challenger_pdr2018/test/images"

    # Abort early (exit status 1) if either split's paths are missing.
    for split_name, split_json, split_images in (
        ("train", TRAIN_JSON, TRAIN_IMG_DIR),
        ("validation", VAL_JSON, VAL_IMG_DIR),
    ):
        verify_path(split_json, split_images, split_name)

    # Full run configuration, forwarded to train() as keyword arguments.
    run_config = {
        "train_json": TRAIN_JSON,
        "train_images": TRAIN_IMG_DIR,
        "val_json": VAL_JSON,
        "val_images": VAL_IMG_DIR,
        "epochs": 100,
        "batch_size": 32,
        "lr": 0.001,
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "cbp_output_dim": 7680,
        "out_ckpt": "tstc_best.pt",
    }
    train(**run_config)
