In [2]:
# ================================================================
# FathomNet - ResNet50 Multilabel Classifier
# With Focal Loss, Tunable Threshold, Per-Class Metrics,
# Confusion Matrix visualization and sample predictions + t-SNE
# ================================================================

import os
import json
import random
import math
from collections import defaultdict, Counter
from tqdm import tqdm

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

from sklearn.metrics import (
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
    multilabel_confusion_matrix,
)
import seaborn as sns

# -----------------------
# CONFIG
# -----------------------
# Input locations. All annotation files live under DATA_ROOT; images are read
# from a separate directory.
DATA_ROOT = "/kaggle/input/fathomnet-out-of-sample-detection"
IMG_DIR = "/kaggle/input/fathomnetimages/kaggle/working/images"
JSON_PATH, TRAIN_CSV, CATEGORY_KEY = (
    os.path.join(DATA_ROOT, relative)
    for relative in (
        "object_detection/train.json",
        "multilabel_classification/train.csv",
        "category_key.csv",
    )
)

# Training hyper-parameters.
RANDOM_SEED = 42
EPOCHS = 40
BATCH_SIZE = 8
LR = 1e-4
NUM_WORKERS = 2
IMG_SIZE = (224, 224)
THRESH = 0.5          # Tunable sigmoid probability threshold
MODEL_SAVE_PATH = "best_model_sampleiou.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel statistics used both to normalize inputs and to undo the
# normalization for display.
MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]

# Seed every RNG the script relies on (Python, NumPy, PyTorch).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)
del _seed_fn

# -----------------------
# Focal Loss Implementation
# -----------------------
class FocalLoss(nn.Module):
    """
    Focal loss for independent binary (multi-label) targets, computed on logits.

    Each element's binary cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the probability the model assigns to the target, so
    well-classified elements contribute less to the total.

    Args:
        alpha: scalar multiplier applied to every element's loss.
        gamma: focusing exponent; ``gamma=0`` reduces to ``alpha * BCE``.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced, element-wise loss tensor.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, targets):
        """
        Args:
            inputs: raw logits.
            targets: float tensor of the same shape as ``inputs`` with values
                in [0, 1] (hard 0/1 labels or soft labels).

        Returns:
            A scalar tensor for "mean"/"sum" reduction, otherwise a tensor with
            the same shape as ``inputs``.
        """
        bce_loss = self.bce(inputs, targets)
        # BCE is -log(p_t), hence 1 - p_t = 1 - exp(-BCE) = -expm1(-BCE).
        # Deriving it from the BCE term (instead of testing `targets == 1`)
        # gives the same value for hard labels, stays consistent for soft or
        # smoothed labels, and avoids cancellation when p_t is close to 1.
        one_minus_pt = -torch.expm1(-bce_loss)
        focal_weight = self.alpha * one_minus_pt ** self.gamma
        loss = focal_weight * bce_loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

# -----------------------
# Load category mapping
# -----------------------
# The category key CSV supplies the "id" and "name" columns read below.
# Row order in that file defines the column order of every label vector.
cat_df = pd.read_csv(CATEGORY_KEY)
cat_df["id"] = cat_df["id"].astype(int)
cat_ids = cat_df["id"].tolist()
NUM_CLASSES = len(cat_ids)

# category id -> column index in the multi-hot label vector
cat_id_to_idx = {category_id: position for position, category_id in enumerate(cat_ids)}
# category id -> human-readable name
cat_id_to_name = {category_id: name for category_id, name in zip(cat_df["id"], cat_df["name"])}

print(f"Loaded {NUM_CLASSES} categories.")

# -----------------------
# JSON + CSV linking
# -----------------------
def _parse_categories(raw):
    """
    Parse one value of the CSV "categories" column into a list of int ids.

    A string starting with "[" is treated as a bracketed, comma-separated
    list; anything else is treated as a single value. Entries are converted
    with int(float(...)). Any value that cannot be converted (including NaN
    and infinity) makes the whole result an empty list.
    """
    if isinstance(raw, str) and raw.startswith("["):
        tokens = [tok.strip() for tok in raw.strip("[]").split(",") if tok.strip()]
    else:
        tokens = [raw]
    try:
        return [int(float(tok)) for tok in tokens]
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are swallowed; a bare `except:` would also
        # hide KeyboardInterrupt and genuine programming errors.
        return []


with open(JSON_PATH, "r", encoding="utf-8") as f:
    coco = json.load(f)
images_by_id = {img["id"]: img for img in coco["images"]}

# Group annotations by the file stem (file name without extension) of their image.
anns_by_stem = defaultdict(list)
for ann in coco["annotations"]:
    img_info = images_by_id.get(ann["image_id"])
    if img_info:
        stem = os.path.splitext(img_info["file_name"])[0]
        anns_by_stem[stem].append(ann)

# Map each CSV image id (used as the stem) to its list of category ids.
# Columns are iterated directly: iterrows() builds a Series per row, which is
# slow and can coerce a numeric id to float before it is stringified.
df = pd.read_csv(TRAIN_CSV)
csv_image_to_cats = {
    str(image_id).strip(): _parse_categories(raw)
    for image_id, raw in zip(df["id"], df["categories"])
}

# find stems present in CSV, JSON (anns), and directory
dir_stems = {
    os.path.splitext(f)[0]
    for f in os.listdir(IMG_DIR)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
}
common_stems = set(csv_image_to_cats) & set(anns_by_stem) & dir_stems
# Sort before shuffling so the split depends only on the seed, not on set order.
common_list = sorted(common_stems)
random.shuffle(common_list)
split_idx = int(0.8 * len(common_list))
train_stems = common_list[:split_idx]
val_stems   = common_list[split_idx:]

print(f"Total usable stems: {len(common_list)} | Train: {len(train_stems)} | Val: {len(val_stems)}")

# -----------------------
# Dataset
# -----------------------
class FathomNetDataset(Dataset):
    """
    Multi-label image dataset addressed by file stem.

    Each item is a tuple ``(image, label, stem)``:
      - image: the RGB image after ``transform`` (a PIL image if no transform is given)
      - label: float32 multi-hot tensor of length NUM_CLASSES
      - stem:  the file stem the item was built from

    Args:
        stems: sequence of file stems (file names without extension).
        img_dir: directory searched for ``stem + extension``.
        csv_map: dict stem -> list of category ids.
        anns_map: dict stem -> annotations; stored but not read by this class.
        transform: optional callable applied to the PIL image.
        img_size: size of the black placeholder used when no file is found.
    """

    # Extensions probed, in this order, when resolving a stem to a file.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, stems, img_dir, csv_map, anns_map, transform=None, img_size=IMG_SIZE):
        self.stems = stems
        self.img_dir = img_dir
        self.csv_map = csv_map
        self.anns_map = anns_map
        self.transform = transform
        self.img_size = img_size

    def __len__(self):
        return len(self.stems)

    def _resolve_path(self, stem):
        """Return the first existing ``img_dir/stem.ext`` path, or None."""
        for ext in self.IMAGE_EXTENSIONS:
            candidate = os.path.join(self.img_dir, stem + ext)
            if os.path.exists(candidate):
                return candidate
        return None

    def __getitem__(self, idx):
        stem = self.stems[idx]
        path = self._resolve_path(stem)
        if path is None:
            # Best-effort fallback: a black placeholder instead of failing the batch.
            img = Image.new("RGB", self.img_size, (0, 0, 0))
        else:
            # The context manager closes the file even if decoding raises;
            # convert() returns an independent, fully loaded image.
            with Image.open(path) as handle:
                img = handle.convert("RGB")

        if self.transform:
            img = self.transform(img)

        label = torch.zeros(NUM_CLASSES, dtype=torch.float32)
        for c in self.csv_map.get(stem, []):
            # Category ids missing from the category key are ignored.
            if c in cat_id_to_idx:
                label[cat_id_to_idx[c]] = 1.0

        return img, label, stem

# transforms
# Training: resize, light augmentation, then tensor conversion + normalization.
train_tfms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

# Validation: deterministic resize + normalization only.
val_tfms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

# create dataset objects (so we can reuse them later for visualization)
_dataset_sources = (IMG_DIR, csv_image_to_cats, anns_by_stem)
train_dataset = FathomNetDataset(train_stems, *_dataset_sources, train_tfms)
val_dataset   = FathomNetDataset(val_stems, *_dataset_sources, val_tfms)

# Both loaders share every setting except shuffling.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

# -----------------------
# Model
# -----------------------
# ResNet-50 initialised from the IMAGENET1K_V1 weights; the final FC layer is
# replaced by a fresh linear head that emits one logit per category.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=NUM_CLASSES)
model.to(DEVICE)  # nn.Module.to moves parameters in place

optimizer = optim.AdamW(params=model.parameters(), lr=LR)
criterion = FocalLoss(alpha=1.0, gamma=2.0)

# -----------------------
# Utility helpers
# -----------------------
def unnormalize_tensor(tensor, mean=MEAN, std=STD):
    """
    Undo per-channel normalization on a copy of `tensor` and clamp to [0, 1].

    The first three entries along dim 0 are rescaled as ``x * std[c] + mean[c]``.
    The input is left untouched; the result lives on the CPU.
    """
    restored = tensor.clone().cpu()
    for channel in range(3):
        # In-place on the channel view: multiply by std, then shift by mean.
        restored[channel].mul_(std[channel]).add_(mean[channel])
    return restored.clamp(min=0.0, max=1.0)

def tensor_to_numpy_img(tensor):
    """Convert a normalized CHW tensor into an un-normalized HWC NumPy array."""
    # CHW -> HWC after undoing the normalization (values end up in [0, 1]).
    return unnormalize_tensor(tensor).permute(1, 2, 0).numpy()

# -----------------------
# Evaluation helper
# -----------------------
def evaluate(model, loader, thresh=THRESH):
    """
    Run `model` over `loader` and score thresholded sigmoid predictions.

    Returns a 10-tuple, in this order:
        sample_iou, exact_acc, f1_micro, prec_micro, rec_micro,
        all_true, all_pred, per_class_f1, per_class_prec, per_class_rec

    `all_true` / `all_pred` are stacked (N, NUM_CLASSES) arrays. For a loader
    that yields no batches, the scalar metrics are 0 and the arrays are zeros.
    The model is left in eval mode.
    """
    model.eval()
    truth_batches, pred_batches = [], []
    with torch.no_grad():
        for batch_imgs, batch_labels, _ in loader:
            batch_probs = torch.sigmoid(model(batch_imgs.to(DEVICE)))
            pred_batches.append((batch_probs > thresh).float().cpu().numpy())
            truth_batches.append(batch_labels.cpu().numpy())

    if not truth_batches:
        return (
            0, 0, 0, 0, 0,
            np.zeros((0, NUM_CLASSES)), np.zeros((0, NUM_CLASSES)),
            np.zeros(NUM_CLASSES), np.zeros(NUM_CLASSES), np.zeros(NUM_CLASSES),
        )

    all_true = np.vstack(truth_batches)
    all_pred = np.vstack(pred_batches)

    def _f1_precision_recall(average):
        # Same three scorers, evaluated with one averaging mode.
        options = dict(average=average, zero_division=0)
        return (
            f1_score(all_true, all_pred, **options),
            precision_score(all_true, all_pred, **options),
            recall_score(all_true, all_pred, **options),
        )

    f1_micro, prec_micro, rec_micro = _f1_precision_recall("micro")
    # average=None yields one score per class
    per_class_f1, per_class_prec, per_class_rec = _f1_precision_recall(None)

    # Per-sample IoU between the true and predicted label sets; the union is
    # floored at 1 so a sample with no labels at all scores 0 instead of 0/0.
    overlap = (all_true * all_pred).sum(axis=1)
    combined = ((all_true + all_pred) > 0).sum(axis=1)
    sample_iou = np.mean(overlap / np.maximum(combined, 1))
    # Fraction of samples whose full label vector is predicted exactly.
    exact_acc = accuracy_score(all_true, all_pred)

    return (
        sample_iou, exact_acc, f1_micro, prec_micro, rec_micro,
        all_true, all_pred, per_class_f1, per_class_prec, per_class_rec,
    )

# -----------------------
# t-SNE helper (correct feature extraction + multi-label handling)
# -----------------------
def extract_backbone_features(model, loader, device):
    """
    Extracts backbone features (the inputs to the final FC layer) for all
    samples in loader. The model is put in eval mode.

    Args:
      model: network whose last child module is the classification head
      loader: iterable yielding (imgs, labels, stems) batches
      device: device the forward pass runs on

    Returns:
      feats: (N, D) numpy array
      labels: (N, NUM_CLASSES) numpy array (multi-hot)
      stems: list of stems
    For a loader that yields no batches, feats is (0, 0), labels is
    (0, NUM_CLASSES) and stems is empty.
    """
    model.eval()
    # Every child module except the last one. For a torchvision ResNet that
    # drops only the final fc layer, so the global average pool is kept.
    backbone = nn.Sequential(*list(model.children())[:-1]).to(device)
    all_feats = []
    all_labels = []
    all_stems = []
    with torch.no_grad():
        for imgs, labels, stems in tqdm(loader, desc="Extracting backbone features"):
            feats = backbone(imgs.to(device))
            # flatten() also handles non-contiguous outputs, where view() would raise
            all_feats.append(torch.flatten(feats, 1).cpu())
            all_labels.append(labels.cpu())
            all_stems.extend(stems)

    if not all_feats:
        # torch.cat raises on an empty list, so return empty arrays explicitly.
        return (
            np.zeros((0, 0), dtype=np.float32),
            np.zeros((0, NUM_CLASSES), dtype=np.float32),
            all_stems,
        )

    feats = torch.cat(all_feats, dim=0).numpy()
    labels = torch.cat(all_labels, dim=0).numpy()
    return feats, labels, all_stems

def compute_sample_primary_label_per_sample(labels, global_class_freq):
    """
    Reduce each multi-hot row to a single "primary" class index.

    labels: (N, C) multi-hot array
    global_class_freq: dict class index -> frequency (missing classes count as 0)

    Per row:
      - no active class   -> -1
      - one active class  -> that class
      - several           -> the active class with the highest global frequency
                             (ties go to the lowest class index)
    Returns: int array of shape (N,)
    """
    def _primary(active):
        # `active` holds the indices of the classes set in one row, ascending.
        if active.size == 0:
            return -1
        if active.size == 1:
            return active[0]
        return int(max(active, key=lambda cls: global_class_freq.get(cls, 0)))

    chosen = [_primary(np.nonzero(sample > 0.5)[0]) for sample in labels]
    return np.array(chosen, dtype=int)

def tsne_for_top_classes(model, loader, class_names, device, top_k=10, title_suffix="After Training", max_points=2000, perplexity=30):
    """
    Plot a 2-D t-SNE of backbone features for the top_k most frequent classes.

    Class frequencies are counted over the labels yielded by `loader`. Each
    sample is assigned one primary label (see
    compute_sample_primary_label_per_sample) and only samples whose primary
    label is among the top classes are plotted.

    Args:
        model: network passed to extract_backbone_features.
        loader: yields (imgs, labels, stems) batches.
        class_names: sequence mapping class index -> display name.
        device: device used for feature extraction.
        top_k: number of most frequent classes to show.
        title_suffix: text inserted in the plot title.
        max_points: samples are randomly sub-sampled (seeded) above this count.
        perplexity: upper bound for the t-SNE perplexity; it is lowered
            automatically for small sample counts.

    Returns None. Prints a message and returns early when there is nothing
    to plot.
    """
    # Feature extraction cannot concatenate zero batches, so bail out first
    # when the loader reports a length of zero.
    try:
        loader_is_empty = len(loader) == 0
    except TypeError:
        loader_is_empty = False  # length unknown: let the extraction decide
    if loader_is_empty:
        print("No classes found for t-SNE (empty dataset).")
        return

    # A single pass over the loader provides both the features and the labels
    # from which the class frequencies are counted.
    feats, labels, _ = extract_backbone_features(model, loader, device)
    class_counts = {c: int(n) for c, n in enumerate(labels.sum(axis=0))}

    # Rank only classes that actually occur; a class with zero samples is
    # never a "top" class.
    ranked = sorted((c for c, n in class_counts.items() if n > 0), key=lambda c: -class_counts[c])
    top_classes = ranked[:top_k]
    if len(top_classes) == 0:
        print("No classes found for t-SNE (empty dataset).")
        return

    print(f"Top-{top_k} classes by frequency: {[ (c, class_counts[c], class_names[c]) for c in top_classes ]}")

    # compute primary label per sample
    primaries = compute_sample_primary_label_per_sample(labels, class_counts)

    # select samples whose primary label is in top_classes
    mask = np.isin(primaries, top_classes)
    feats_sel = feats[mask]
    primaries_sel = primaries[mask]

    if feats_sel.shape[0] == 0:
        print("No selected samples for top classes (t-SNE).")
        return

    # subsample if too many
    if feats_sel.shape[0] > max_points:
        rng = np.random.RandomState(RANDOM_SEED)
        idxs = rng.choice(np.arange(feats_sel.shape[0]), size=max_points, replace=False)
        feats_sel = feats_sel[idxs]
        primaries_sel = primaries_sel[idxs]

    n_points = feats_sel.shape[0]
    if n_points < 2:
        print("Not enough samples for t-SNE (need at least 2).")
        return

    # run t-SNE; scikit-learn requires perplexity < n_samples
    effective_perplexity = min(perplexity, max(5, n_points // 3), n_points - 1)
    tsne = TSNE(n_components=2, random_state=RANDOM_SEED, perplexity=effective_perplexity)
    reduced = tsne.fit_transform(feats_sel)

    plt.figure(figsize=(10, 8))
    palette = sns.color_palette("tab10", n_colors=len(top_classes))
    # A name -> colour mapping ties every class to one fixed colour. A plain
    # list would be matched to hue levels by order of appearance in the data.
    name_to_color = {class_names[c]: palette[i % len(palette)] for i, c in enumerate(top_classes)}
    present = set(primaries_sel.tolist())
    # Legend follows frequency order, restricted to classes that are plotted.
    hue_order = list(dict.fromkeys(class_names[c] for c in top_classes if c in present))
    sns.scatterplot(x=reduced[:,0], y=reduced[:,1], hue=[class_names[p] for p in primaries_sel],
                    hue_order=hue_order, palette=name_to_color, s=25, alpha=0.8)
    plt.title(f"t-SNE ({title_suffix}) for Top-{top_k} Frequent Classes")
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()

# -----------------------
# Sample Predictions Visualization
# -----------------------
def visualize_sample_predictions_stacked(model, dataset, class_names, device, n_samples=20, thresh=THRESH):
    """
    Shows up to n_samples randomly chosen images stacked vertically (1 per row),
    each titled with its stem, ground-truth labels and predicted labels.

    Args:
        model: classifier returning one logit per class; put in eval mode.
        dataset: indexable dataset yielding (normalized CHW tensor, multi-hot label, stem).
        class_names: sequence mapping class index -> display name.
        device: device used for inference.
        n_samples: maximum number of images to show.
        thresh: a class is predicted when its sigmoid probability exceeds this.

    Returns None. Prints a message and returns early when there is nothing to show.
    """
    model.eval()
    n = min(n_samples, len(dataset))
    if n <= 0:
        # Nothing to draw: avoid a zero-height figure and a failing random.sample.
        print("No samples available to visualize.")
        return

    indices = random.sample(range(len(dataset)), n)
    plt.figure(figsize=(8, n * 3))  # width x height: each image 3" tall
    with torch.no_grad():
        for i, idx in enumerate(indices):
            img_t, labels, stem = dataset[idx]  # img is normalized tensor
            img_input = img_t.unsqueeze(0).to(device)
            outputs = torch.sigmoid(model(img_input)).cpu().numpy().flatten()
            preds = outputs > thresh

            gt_labels = [class_names[j] for j in torch.nonzero(labels == 1).flatten().tolist()]
            pred_labels = [class_names[j] for j in np.flatnonzero(preds).tolist()]

            img_np = tensor_to_numpy_img(img_t)  # unnormalized numpy HWC in [0,1]
            ax = plt.subplot(n, 1, i + 1)
            ax.imshow(img_np)
            ax.axis("off")
            title = (
                f"{stem}  |  GT: {', '.join(gt_labels) if gt_labels else 'None'}"
                f"  |  Pred (> {thresh:.2f}): {', '.join(pred_labels) if pred_labels else 'None'}"
            )
            ax.set_title(title, fontsize=9)
    plt.tight_layout()
    plt.show()

# -----------------------
# Training loop (save best by sample IoU)
# -----------------------
# NOTE: the value printed as "Accuracy" below is the mean per-sample IoU
# returned by evaluate(); "ExactAcc" is the exact-match accuracy.
print("Starting training...\n")
best_iou = -1.0
history = []  # one (train_loss, sample_iou, exact_acc, f1, prec, rec) tuple per epoch
for epoch in range(1, EPOCHS + 1):
    model.train()  # evaluate() leaves the model in eval mode
    running_loss = 0.0
    for imgs, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # The loss is a batch mean; weight it by batch size for a per-sample average.
        running_loss += loss.item() * imgs.size(0)

    train_loss = running_loss / max(1, len(train_loader.dataset))
    sample_iou, exact_acc, f1, prec, rec, _, _, _, _, _ = evaluate(model, val_loader, thresh=THRESH)
    history.append((train_loss, sample_iou, exact_acc, f1, prec, rec))

    print(
        f"Epoch {epoch:02d} | Loss: {train_loss:.4f} | Accuracy: {sample_iou:.4f} | "
        f"ExactAcc: {exact_acc:.4f} | F1: {f1:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f}"
    )

    # Save best by sample IoU
    if sample_iou > best_iou:
        best_iou = sample_iou
        # state_dict() hands out references to the live tensors and
        # dict.copy() is shallow, so take real copies. Otherwise best_state
        # would keep changing with further training. Storing CPU copies also
        # keeps the checkpoint loadable on machines without a GPU.
        best_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in model.state_dict().items()
        }
        torch.save(best_state, MODEL_SAVE_PATH)
        print(f"✅ Saved new best model (Accuracy={best_iou:.4f}) → {MODEL_SAVE_PATH}")

# -----------------------
# Final evaluation + confusion matrix + per-class metrics
# -----------------------
print("\nEvaluating Confusion Matrices...")
if os.path.exists(MODEL_SAVE_PATH):
    # map_location places the checkpoint tensors on the current device, so a
    # checkpoint written on a GPU still loads on a CPU-only machine.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    print(f"Loaded best model from {MODEL_SAVE_PATH} (Accuracy={best_iou:.4f})")
else:
    print("No saved model found — using current model state for final eval.")

# "Accuracy" in the messages below is the mean per-sample IoU from evaluate().
sample_iou, exact_acc, f1, prec, rec, y_true, y_pred, per_f1, per_prec, per_rec = evaluate(model, val_loader, thresh=THRESH)
print(f"Final val: Accuracy={sample_iou:.4f} | ExactAcc={exact_acc:.4f} | F1={f1:.4f} | Prec={prec:.4f} | Rec={rec:.4f}")

# per-class metrics table
print("\nTop-10 Classes by F1:\n")
cls_metrics = pd.DataFrame({
    "ClassID": cat_ids,
    "ClassName": [cat_id_to_name[c] for c in cat_ids],
    "Precision": per_prec,
    "Recall": per_rec,
    "F1": per_f1
}).sort_values("F1", ascending=False)
print(cls_metrics.head(10))

# confusion matrices
# multilabel_confusion_matrix returns one 2x2 matrix per class
# (rows: true negative/positive, columns: predicted negative/positive).
cm = multilabel_confusion_matrix(y_true, y_pred)
n_plot = min(9, NUM_CLASSES)  # only the first classes (in category-key order) are drawn
cols = 3
rows = math.ceil(n_plot / cols)
fig, axes = plt.subplots(rows, cols, figsize=(cols*4, rows*4))
axes = axes.flatten()
for i in range(n_plot):
    ax = axes[i]
    sns.heatmap(cm[i], annot=True, fmt="d", cmap="Blues", ax=ax, cbar=True)
    ax.set_title(cat_id_to_name.get(cat_ids[i], f"Class {cat_ids[i]}"))
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
# hide the unused cells of the grid
for j in range(n_plot, len(axes)):
    axes[j].axis("off")
plt.tight_layout()
plt.show()

# -----------------------
# Sample Predictions (20 stacked)
# -----------------------
# Display names in label-vector order, shared by all visualizations below.
ordered_class_names = [cat_id_to_name[c] for c in cat_ids]

print("\nShowing 20 sample predictions (stacked vertically)...")
visualize_sample_predictions_stacked(
    model, val_dataset, ordered_class_names, DEVICE, n_samples=20, thresh=THRESH
)

# -----------------------
# t-SNE: before and after training
# -----------------------
# Reference model for the "before training" plot: same architecture, loaded
# with the IMAGENET1K_V1 weights and a freshly initialised FC head, i.e. the
# state before any training on this dataset.
print("\nComputing t-SNE BEFORE training (untrained weights)...")
untrained_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
untrained_model.fc = nn.Linear(
    in_features=untrained_model.fc.in_features, out_features=NUM_CLASSES
)
untrained_model.to(DEVICE)
tsne_for_top_classes(
    untrained_model, val_loader, ordered_class_names, DEVICE,
    top_k=10, title_suffix="Before Training",
)

print("\nComputing t-SNE AFTER training (best model)...")
tsne_for_top_classes(
    model, val_loader, ordered_class_names, DEVICE,
    top_k=10, title_suffix="After Training",
)

# -----------------------
# Epoch summary
# -----------------------
# Replays the per-epoch metrics collected in `history` during training.
print("\nEpoch summary:")
for epoch_no, epoch_record in enumerate(history, start=1):
    loss_e, iou_e, exact_e, f1_e, prec_e, rec_e = epoch_record
    print(f"Epoch {epoch_no:02d} | Loss: {loss_e:.4f} | Accuracy: {iou_e:.4f} | ExactAcc: {exact_e:.4f} | "
          f"F1: {f1_e:.4f} | Prec: {prec_e:.4f} | Rec: {rec_e:.4f}")

# best_iou keeps its negative initial value if no epoch ever improved on it.
if best_iou >= 0:
    print(f"\nBest model saved at: {MODEL_SAVE_PATH} (Accuracy={best_iou:.4f})")
else:
    print("\nNo best model saved.")
print("\nDone.")


Loaded 290 categories.
Total usable stems: 5950 | Train: 4760 | Val: 1190
Starting training...



                                                             

Epoch 01 | Loss: 0.0066 | Accuracy: 0.5119 | ExactAcc: 0.3992 | F1: 0.5503 | Prec: 0.6826 | Rec: 0.4609
✅ Saved new best model (Accuracy=0.5119) → best_model_sampleiou.pt


                                                             

KeyboardInterrupt: 