In [None]:

# -------------------------
# Cell 1: Imports & Setup
# -------------------------
import os, gc, json, math, random
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [None]:

# NEW: Augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
import os
import random
import numpy as np
import torch

# -------------------------
# Config & Hyperparameters
# -------------------------
class CFG:
    """Central notebook configuration: data locations, checkpoint names, hyperparameters."""

    # ---- Reproducibility ----
    SEED = 42

    # ---- Competition data (Kaggle input mount) ----
    BASE_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
    AUTH_DIR = f"{BASE_DIR}/train_images/authentic"
    FORG_DIR = f"{BASE_DIR}/train_images/forged"
    MASK_DIR = f"{BASE_DIR}/train_masks"
    TEST_DIR = f"{BASE_DIR}/test_images"
    SAMPLE_SUB = f"{BASE_DIR}/sample_submission.csv"

    # ---- Pretrained DINOv2 backbone (local copy, loaded without network) ----
    DINO_PATH = "/kaggle/input/dinov2/pytorch/base/1"

    # ---- Checkpoints written during training ----
    BEST_SEG_MODEL = "best_seg_model.pth"
    BEST_CLS_MODEL = "best_cls_model.pth"

    # ---- Input / batching ----
    # Dataset transforms resize to IMG_SIZE; the model resizes again to the DINOv2
    # processor size before encoding, so this mainly sets the predicted-mask resolution.
    IMG_SIZE = 384
    BATCH_SEG, BATCH_CLS = 4, 32

    # ---- Schedule ----
    EPOCHS_SEG, EPOCHS_CLS = 5, 5
    LR_SEG, LR_CLS = 1e-4, 1e-3
    WEIGHT_DECAY = 1e-4
    # CosineAnnealingLR period; keep in sync with EPOCHS_SEG / EPOCHS_CLS.
    SCHEDULER_T_MAX = 5

# -------------------------
# Reproducibility & Device
# -------------------------
def seed_everything(s=CFG.SEED):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) and force deterministic cuDNN.

    Args:
        s: integer seed; defaults to CFG.SEED.
    """
    # Exported for any child processes; the current interpreter's hash seed is
    # fixed at startup and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(s)

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)

    # Deterministic kernels; benchmark autotuning could pick different algorithms per run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_everything(CFG.SEED)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


In [None]:
import numpy as np
import torch
import torch.nn as nn
import json

# -------------------------
# Utils (Metrics, RLE, Loss)
# -------------------------
def binarize(x, thr=0.5):
    """Return a uint8 array holding 1 where `x` is strictly greater than `thr`, else 0."""
    return np.greater(x, thr).astype(np.uint8)

def iou_score(pred, gt, eps=1e-7, thr=0.5):
    """Intersection-over-union of a thresholded prediction against a binary ground truth.

    Args:
        pred: array of scores/probabilities; entries > `thr` count as positive.
        gt: ground-truth array, binarized at 0.5.
        eps: added to the denominator to avoid division by zero.
        thr: threshold applied to `pred` (default 0.5).

    Returns:
        IoU as a float. If both masks are empty the (empty) prediction is exactly right,
        so 1.0 is returned rather than 0/eps == 0.
    """
    p = np.asarray(pred) > thr
    g = np.asarray(gt) > 0.5
    union = np.logical_or(p, g).sum()
    if union == 0:
        return 1.0
    inter = np.logical_and(p, g).sum()
    return float(inter) / (float(union) + eps)

def dice_score(pred, gt, eps=1e-7, thr=0.5):
    """Dice coefficient of a thresholded prediction against a binary ground truth.

    Args:
        pred: array of scores/probabilities; entries > `thr` count as positive.
        gt: ground-truth array, binarized at 0.5.
        eps: added to the denominator to avoid division by zero.
        thr: threshold applied to `pred` (default 0.5, the previous fixed value).

    Returns:
        Dice as a float. If both masks are empty (e.g. an authentic image correctly
        predicted clean) 1.0 is returned; otherwise such images would always score 0
        and be invisible to model selection and threshold tuning.
    """
    p = np.asarray(pred) > thr
    g = np.asarray(gt) > 0.5
    denom = p.sum() + g.sum()
    if denom == 0:
        return 1.0
    inter = np.logical_and(p, g).sum()
    return float(2 * inter) / (float(denom) + eps)

def pixel_acc(pred, gt, thr=0.5):
    """Fraction of positions where `pred` and `gt` agree after both are thresholded at `thr`.

    The denominator is the number of elements of `gt`.
    """
    agree = (pred > thr) == (gt > thr)
    return float(np.count_nonzero(agree)) / float(np.asarray(gt).size)

def rle_encode_numpy(mask):
    """Run-length encode a binary mask as a JSON list string.

    Pixels are read column-major (``mask.T.flatten()``) and only values equal to 1 are
    foreground. The result is ``"[start1, length1, start2, length2, ...]"`` with 1-based
    starts, or ``"[]"`` when the mask has no foreground.

    All numbers are converted to Python ``int`` before ``json.dumps``: NumPy integer
    scalars are not JSON-serializable and would raise ``TypeError``.
    """
    pixels = (np.asarray(mask).T.flatten() == 1).astype(np.int8)
    if not pixels.any():
        return "[]"
    # Zero-pad both ends so every run has a rising edge (start) and a falling edge (end).
    edges = np.flatnonzero(np.diff(np.concatenate(([0], pixels, [0]))))
    starts, ends = edges[0::2], edges[1::2]
    run_lengths = []
    for s, e in zip(starts, ends):
        run_lengths.extend((int(s) + 1, int(e - s)))
    return json.dumps(run_lengths)

# -------------------------
# Dice Loss
# -------------------------
class DiceLoss(nn.Module):
    """Soft Dice loss computed from logits, pooling every element of the batch together.

    Args:
        smooth: additive smoothing for numerator and denominator (keeps the loss finite
            when both prediction and target are empty).
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        p = torch.sigmoid(logits).flatten()
        t = targets.flatten()
        overlap = torch.sum(p * t)
        total = torch.sum(p) + torch.sum(t)
        return 1. - (2. * overlap + self.smooth) / (total + self.smooth)

# -------------------------
# Combo Loss (BCE + Dice)
# -------------------------
class ComboLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss, both taking raw logits.

    Args:
        bce_weight: multiplier for the BCE term.
        dice_weight: multiplier for the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        # Held as submodules so .to(device) on this loss moves them too.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self, logits, targets):
        weighted_bce = self.bce_weight * self.bce_loss(logits, targets)
        weighted_dice = self.dice_weight * self.dice_loss(logits, targets)
        return weighted_bce + weighted_dice


In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import cv2
import numpy as np

def get_transforms(img_size):
    """Albumentations pipelines keyed by split: "train", "val" and "test".

    This returns the same three keys as the get_transforms defined in the augmentation
    cell below. If it returned only "train", re-running this cell after that one would
    rebind the global to a train-only version, and `transforms['val']` / `['test']` in
    the dataloader and inference cells would raise KeyError.

    Args:
        img_size: side length every image (and mask) is resized to.
    """
    mean = [0.485, 0.456, 0.406]  # ImageNet statistics
    std = [0.229, 0.224, 0.225]

    train = A.Compose([
        A.Resize(img_size, img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])

    def _eval_pipeline():
        # Deterministic: resize + normalize only.
        return A.Compose([
            A.Resize(img_size, img_size),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

    return {"train": train, "val": _eval_pipeline(), "test": _eval_pipeline()}

def visualize_augmentation(image_path, img_size=384):
    """Display one random "train" augmentation of an image, un-normalized for viewing.

    Args:
        image_path: path to an image readable by OpenCV.
        img_size: side length passed to get_transforms.

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path` (cv2.imread returns None).
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    aug_image = get_transforms(img_size)["train"](image=image)["image"]

    # Tensor (C, H, W) -> numpy (H, W, C), then undo the ImageNet normalization.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    aug_image = np.clip(std * aug_image.numpy().transpose(1, 2, 0) + mean, 0, 1)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(aug_image)
    ax.axis("off")
    plt.show()

# Example usage:
# visualize_augmentation("/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images/authentic/0001.png")


In [None]:
# -------------------------
# Cell 4: Augmentations (NEW)
# -------------------------

# Must be careful with augs. Flips and rotations are safe.
# Avoid crops, shears, or elastic transforms.
def get_transforms(img_size):
    """Build albumentations pipelines for the "train", "val" and "test" splits.

    Train adds random flips and 90-degree rotations; val and test only resize and
    normalize. Every pipeline normalizes with ImageNet statistics and ends in ToTensorV2.
    Masks passed alongside images receive the same geometric transforms.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def _tail():
        # Fresh Normalize/ToTensorV2 instances for each pipeline.
        return [A.Normalize(mean=list(imagenet_mean), std=list(imagenet_std)), ToTensorV2()]

    def _resize_only():
        return A.Compose([A.Resize(img_size, img_size), *_tail()])

    train_aug = A.Compose([
        A.Resize(img_size, img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        *_tail(),
    ])
    return {"train": train_aug, "val": _resize_only(), "test": _resize_only()}

In [None]:
# -------------------------
# Cell 5: Datasets
# -------------------------

# --- MODIFIED: Segmentation Dataset with Augs ---
class ForgerySegDataset(Dataset):
    """Pixel-level forgery segmentation dataset.

    Each sample is ``(image_path, mask_path_or_None)``. Authentic images, and forged
    images with no ``<mask_dir>/<image stem>.npy`` file, get an all-zero mask.
    ``__getitem__`` returns ``(image_tensor, mask_tensor)``: both go through the same
    albumentations pipeline, and the mask is a float tensor of shape [1, H, W].

    Args:
        auth_paths: paths of authentic images.
        forg_paths: paths of forged images.
        mask_dir: directory holding ``.npy`` masks named after the forged image stems.
        transforms: albumentations pipeline called as ``transforms(image=..., mask=...)``.
    """

    def __init__(self, auth_paths, forg_paths, mask_dir, transforms):
        self.mask_dir = mask_dir
        self.transforms = transforms
        self.samples = [(p, None) for p in auth_paths]
        for p in forg_paths:
            m = os.path.join(mask_dir, Path(p).stem + ".npy")
            self.samples.append((p, m if os.path.exists(m) else None))

    def __len__(self):
        return len(self.samples)

    def _load_mask(self, mask_path, h, w):
        """Load a ``.npy`` mask and return it as a binary uint8 array of shape (h, w)."""
        m = np.load(mask_path)
        if m.ndim == 3:
            # Merge a stack of per-instance masks into one. Stacks are taken as (n, H, W);
            # a (H, W, n) stack whose leading dims match the image is reduced over the
            # last axis instead.
            channel_last = m.shape[:2] == (h, w) and m.shape[1:] != (h, w)
            m = m.max(axis=-1 if channel_last else 0)
        # Any nonzero value marks a forged pixel, so 0/255-style masks still give 0/1
        # targets for BCE/Dice. NOTE(review): assumes values > 1 carry no other meaning.
        mask = (m > 0).astype(np.uint8)
        if mask.shape != (h, w):
            mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        return mask

    def __getitem__(self, idx):
        img_path, mask_path = self.samples[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        if mask_path is None:
            mask = np.zeros((h, w), dtype=np.uint8)
        else:
            mask = self._load_mask(mask_path, h, w)

        augmented = self.transforms(image=img, mask=mask)
        img_t = augmented['image']
        mask_t = augmented['mask'].unsqueeze(0).float()  # [1, H, W]
        return img_t, mask_t

# --- MODIFIED: Classification Dataset with Augs ---
class ForgeryClsDataset(Dataset):
    """Image-level classification dataset: label 0 = authentic, 1 = forged.

    Args:
        auth_paths: paths of authentic images (label 0).
        forg_paths: paths of forged images (label 1).
        transforms: albumentations pipeline called as ``transforms(image=...)``.

    ``__getitem__`` returns ``(image_tensor, label)`` with the label as a torch.long scalar.
    """

    def __init__(self, auth_paths, forg_paths, transforms):
        self.items = [(p, 0) for p in auth_paths] + [(p, 1) for p in forg_paths]
        self.transforms = transforms

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        p, y = self.items[idx]
        img = cv2.imread(p)
        if img is None:
            # cv2.imread returns None instead of raising; fail with the offending path.
            raise FileNotFoundError(f"Could not read image: {p}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        x = self.transforms(image=img)['image']  # no mask for classification
        return x, torch.tensor(y, dtype=torch.long)


In [None]:

# -------------------------
# Cell 6: File Loading & Splitting
# -------------------------

def _list_images(folder):
    """Sorted full paths of the image files (by extension) directly inside `folder`."""
    exts = (".png",".jpg",".jpeg",".tif")
    root = Path(folder)
    return sorted(str(root / name) for name in os.listdir(folder) if name.lower().endswith(exts))


auth_imgs = _list_images(CFG.AUTH_DIR)
forg_imgs = _list_images(CFG.FORG_DIR)
print(f"Authentic images: {len(auth_imgs)}, Forged images: {len(forg_imgs)}")

# One split shared by the segmentation and classification tasks (same seed for both).
# NOTE(review): authentic and forged lists are split independently; if a forged image is
# derived from an authentic one with the same stem, the pair can land on opposite sides
# of the split (leakage) — confirm against the dataset description.
(train_auth, val_auth), (train_forg, val_forg) = (
    train_test_split(auth_imgs, test_size=0.2, random_state=CFG.SEED),
    train_test_split(forg_imgs, test_size=0.2, random_state=CFG.SEED),
)

# -------------------------
# Cell 7: DINOv2 Loader & Model Defs
# -------------------------

from transformers import AutoImageProcessor, AutoModel

# Load the DINOv2 processor and backbone from the local Kaggle input only (no network).
# Any loading failure is re-raised as RuntimeError naming the path.
try:
    processor = AutoImageProcessor.from_pretrained(CFG.DINO_PATH, local_files_only=True)
    dino_encoder = AutoModel.from_pretrained(CFG.DINO_PATH, local_files_only=True)
except Exception as err:
    raise RuntimeError(f"Could not load DINOv2 from {CFG.DINO_PATH}: {err}")

dino_encoder.eval()
dino_encoder.to(device)

# Token grid the encoder produces at the processor's input size:
# side = proc_size // patch_size (falling back to 224 and 14 when attributes are missing).
cfg = getattr(dino_encoder, "config", None)
patch = getattr(cfg, "patch_size", 14)
inp = getattr(processor, "size", {"shortest_edge": 224})
proc_size = int(inp["shortest_edge"]) if "shortest_edge" in inp else inp.get("height", 224)

grid_h = grid_w = proc_size // patch
print(f"DINOv2 loaded. Processor size: {proc_size}, Patch size: {patch}, Grid: {grid_h}x{grid_w}")

# --- Model definitions: DINOv2 decoder head and multi-task wrapper ---
class DinoTinyDecoder(nn.Module):
    """Small convolutional head turning a patch-feature map into per-pixel logits.

    The feature map is first bilinearly resized to `out_size`, then passed through
    conv3x3(256)-BN-ReLU -> conv3x3(64)-BN-ReLU -> conv1x1(out_ch).
    NOTE(review): the 3x3 convs therefore run on `in_ch` channels at full output
    resolution, which dominates memory/compute for large `out_size`.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: number of output logit channels.
    """

    def __init__(self, in_ch, out_ch=1):
        super().__init__()
        mid1, mid2 = 256, 64
        layers = [
            nn.Conv2d(in_ch, mid1, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid1, mid2, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid2),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid2, out_ch, kernel_size=1),
        ]
        # Kept as `self.conv` (a Sequential) so checkpoint keys stay "conv.<i>.*".
        self.conv = nn.Sequential(*layers)

    def forward(self, f, out_size):
        upsampled = nn.functional.interpolate(f, size=out_size, mode="bilinear", align_corners=False)
        return self.conv(upsampled)

class DinoMultiTask(nn.Module):
    """DINOv2 encoder shared by a segmentation head and an image-level classification head.

    Args:
        encoder: HF DINOv2 model, called as ``encoder(pixel_values=x)`` and expected to
            return an object with ``last_hidden_state`` of shape [B, N, C].
        seg_ch: token width C of the encoder (768 for DINOv2-base).
        freeze: when True the encoder's parameters get ``requires_grad=False``, it runs
            without autograd, and it stays in eval mode even after ``model.train()``.
            When False, gradients flow into the encoder.

    Note:
        ``forward_features`` reads the module-level ``proc_size``, ``grid_h`` and ``grid_w``
        computed when the DINOv2 processor was loaded.
    """

    def __init__(self, encoder, seg_ch=768, freeze=True):
        super().__init__()
        self.encoder = encoder
        self.freeze = freeze
        if freeze:
            for p in self.encoder.parameters():
                p.requires_grad = False

        self.seg_head = DinoTinyDecoder(in_ch=seg_ch, out_ch=1)
        self.cls_head = nn.Linear(seg_ch, 2)

    def train(self, mode=True):
        """Set train/eval mode; a frozen encoder is always kept in eval mode."""
        super().train(mode)
        if self.freeze:
            # Frozen backbone: keep any dropout / drop-path disabled while heads train.
            self.encoder.eval()
        return self

    def forward_features(self, images):
        """Encode normalized images into a patch-feature map and a global feature vector.

        Args:
            images: float tensor [B, 3, H, W], already resized/normalized by the dataset
                transforms (ImageNet mean/std). NOTE(review): assumes these match the
                checkpoint's preprocessor statistics — confirm.

        Returns:
            (fmap, cls_token): fmap is [B, C, gh, gw]; cls_token is [B, C], the encoder's
            CLS token when present, otherwise the mean of the patch tokens.

        Raises:
            ValueError: if the token count cannot be arranged into a square grid.
        """
        B = images.shape[0]
        # The tensors bypass the HF processor, so resize to the processor's size here.
        x = nn.functional.interpolate(images, size=(proc_size, proc_size), mode="bilinear", align_corners=False)

        # Build an autograd graph through the encoder only when it is trainable.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze):
            feats = self.encoder(pixel_values=x).last_hidden_state  # [B, N, C]

        _, N, C = feats.shape
        expected_tokens = grid_h * grid_w

        if N == expected_tokens + 1:
            feats_spatial, cls_token = feats[:, 1:, :], feats[:, 0, :]
            gh, gw = grid_h, grid_w
        elif N == expected_tokens:
            feats_spatial, cls_token = feats, feats.mean(dim=1)
            gh, gw = grid_h, grid_w
        else:
            # Unexpected token count: infer a square grid, with or without a leading CLS token.
            s = int(round(math.sqrt(N - 1)))
            if s * s == N - 1:
                feats_spatial, cls_token = feats[:, 1:, :], feats[:, 0, :]
            else:
                s = int(round(math.sqrt(N)))
                if s * s != N:
                    raise ValueError(
                        f"Cannot arrange {N} tokens into a square grid "
                        f"(expected {expected_tokens} or {expected_tokens + 1})."
                    )
                feats_spatial, cls_token = feats, feats.mean(dim=1)
            gh = gw = s
            print(f"Warning: Token mismatch. Expected {expected_tokens}, got {N}. Inferred grid {s}x{s}")

        fmap = feats_spatial.permute(0, 2, 1).reshape(B, C, gh, gw)
        return fmap, cls_token

    def forward_seg(self, images):
        """Per-pixel forgery logits [B, 1, H, W] at the input image resolution."""
        fmap, _ = self.forward_features(images)
        H, W = images.shape[-2:]
        return self.seg_head(fmap, out_size=(H, W))

    def forward_cls(self, images):
        """Image-level logits [B, 2] (index 0 = authentic, 1 = forged, as in ForgeryClsDataset)."""
        _, cls_token = self.forward_features(images)
        return self.cls_head(cls_token)


In [None]:

# -------------------------
# Cell 8: Dataloaders
# -------------------------

transforms = get_transforms(CFG.IMG_SIZE)


def _make_loader(dataset, batch_size, shuffle):
    """DataLoader with the notebook's shared worker and pinned-memory settings."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True)


# Segmentation: augmented train split, deterministic val split.
train_seg_ds = ForgerySegDataset(train_auth, train_forg, CFG.MASK_DIR, transforms=transforms['train'])
val_seg_ds = ForgerySegDataset(val_auth, val_forg, CFG.MASK_DIR, transforms=transforms['val'])
train_seg_loader = _make_loader(train_seg_ds, CFG.BATCH_SEG, shuffle=True)
val_seg_loader = _make_loader(val_seg_ds, CFG.BATCH_SEG, shuffle=False)

# Classification: same image split, image-level labels.
train_cls_ds = ForgeryClsDataset(train_auth, train_forg, transforms=transforms['train'])
val_cls_ds = ForgeryClsDataset(val_auth, val_forg, transforms=transforms['val'])
train_cls_loader = _make_loader(train_cls_ds, CFG.BATCH_CLS, shuffle=True)
val_cls_loader = _make_loader(val_cls_ds, CFG.BATCH_CLS, shuffle=False)

print(f"Seg loaders: {len(train_seg_loader)} train, {len(val_seg_loader)} val")
print(f"Cls loaders: {len(train_cls_loader)} train, {len(val_cls_loader)} val")


In [None]:

# -------------------------
# Cell 9: Train Segmentation Head
# -------------------------

print("\n--- Starting Segmentation Head Training ---")
# Backbone token width (768 for DINOv2-base), read from the HF config when available.
enc_dim = getattr(getattr(dino_encoder, "config", None), "hidden_size", 768)
model = DinoMultiTask(dino_encoder, seg_ch=enc_dim, freeze=True).to(device)
crit_seg = ComboLoss(bce_weight=0.5, dice_weight=0.5).to(device)
opt_seg = optim.AdamW(model.seg_head.parameters(), lr=CFG.LR_SEG, weight_decay=CFG.WEIGHT_DECAY)
# The scheduler steps once per epoch, so T_max is the epoch count: CosineAnnealingLR
# raises the LR again once more than T_max steps are taken.
sched_seg = CosineAnnealingLR(opt_seg, T_max=CFG.EPOCHS_SEG, eta_min=1e-6)

# Start below any achievable Dice so BEST_SEG_MODEL is always written; the
# classification cell loads it.
best_val_dice = -1.0

for epoch in range(1, CFG.EPOCHS_SEG + 1):
    model.train()
    # Only seg_head is optimized; keep the frozen backbone in eval mode (no dropout etc.).
    model.encoder.eval()
    tr_loss = 0.0
    for imgs, masks in tqdm(train_seg_loader, desc=f"[Seg] Epoch {epoch}/{CFG.EPOCHS_SEG}"):
        imgs, masks = imgs.to(device), masks.to(device)
        logits = model.forward_seg(imgs)
        loss = crit_seg(logits, masks)
        opt_seg.zero_grad(); loss.backward(); opt_seg.step()
        tr_loss += loss.item() * imgs.size(0)
    tr_loss /= len(train_seg_loader.dataset)
    sched_seg.step()

    # Validation: mean per-image Dice at the default 0.5 threshold.
    model.eval()
    val_loss, m_dice = 0.0, 0.0
    with torch.no_grad():
        for imgs, masks in val_seg_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model.forward_seg(imgs)
            loss = crit_seg(logits, masks)
            val_loss += loss.item() * imgs.size(0)

            probs = torch.sigmoid(logits).cpu().numpy()
            gts = masks.cpu().numpy()
            for p, g in zip(probs, gts):
                m_dice += dice_score(p[0], g[0])

    n_val = max(len(val_seg_loader.dataset), 1)
    val_loss /= n_val
    m_dice /= n_val

    print(f"[Seg] Epoch {epoch} | train_loss={tr_loss:.4f} | val_loss={val_loss:.4f} | val_Dice={m_dice:.3f}")

    if m_dice > best_val_dice:
        best_val_dice = m_dice
        torch.save(model.state_dict(), CFG.BEST_SEG_MODEL)
        print(f"  -> New best seg model saved with Dice: {best_val_dice:.4f}")

torch.cuda.empty_cache(); gc.collect()

In [None]:

# -------------------------
# Cell 10: Train Classification Head
# -------------------------

print("\n--- Loading Best Seg Model & Training Classification Head ---")
# Restore the best segmentation checkpoint BEFORE training the classifier. That checkpoint
# is the full model state (encoder + seg head + a cls head not yet trained), so loading it
# after cls training would throw the cls progress away. The encoder is frozen and only
# cls_head is optimized below, so the seg head is left untouched from here on.
try:
    model.load_state_dict(torch.load(CFG.BEST_SEG_MODEL, map_location=device))
    print(f"Loaded best segmentation model from {CFG.BEST_SEG_MODEL}")
except Exception as e:
    print(f"Warning: Could not load best seg model. Training cls head on last epoch. {e}")

crit_cls = nn.CrossEntropyLoss()
opt_cls = optim.AdamW(model.cls_head.parameters(), lr=CFG.LR_CLS, weight_decay=CFG.WEIGHT_DECAY)
# One scheduler step per epoch; T_max = epoch count (the cosine LR rises again past T_max).
sched_cls = CosineAnnealingLR(opt_cls, T_max=CFG.EPOCHS_CLS, eta_min=1e-6)

# Start below any achievable accuracy so BEST_CLS_MODEL is always written; the
# threshold-search cell loads it.
best_val_acc = -1.0

for epoch in range(1, CFG.EPOCHS_CLS + 1):
    model.train()
    # Only cls_head is optimized; keep the frozen backbone in eval mode.
    model.encoder.eval()
    tr_loss = 0.0
    for imgs, labels in tqdm(train_cls_loader, desc=f"[Cls] Epoch {epoch}/{CFG.EPOCHS_CLS}"):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model.forward_cls(imgs)
        loss = crit_cls(logits, labels)
        opt_cls.zero_grad(); loss.backward(); opt_cls.step()
        tr_loss += loss.item() * imgs.size(0)
    tr_loss /= len(train_cls_loader.dataset)
    sched_cls.step()

    # Validation accuracy (argmax over the two classes).
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_cls_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model.forward_cls(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_acc = 100.0 * correct / max(total, 1)
    print(f"[Cls] Epoch {epoch} | train_loss={tr_loss:.4f} | val_acc={val_acc:.2f}%")

    # Saved state = best seg head (loaded above) + best cls head so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), CFG.BEST_CLS_MODEL)
        print(f"  -> New best cls model saved with Acc: {best_val_acc:.2f}%")

torch.cuda.empty_cache(); gc.collect()



In [None]:


# -------------------------
# Cell 11: Find Optimal Threshold (NEW)
# -------------------------
print("\n--- Finding Optimal Segmentation Threshold ---")
# Load the best model (best seg head + best cls head) written by the classification cell.
try:
    model.load_state_dict(torch.load(CFG.BEST_CLS_MODEL, map_location=device))
    print(f"Loaded best multi-task model from {CFG.BEST_CLS_MODEL}")
except Exception as e:
    print(f"Warning: Could not load {CFG.BEST_CLS_MODEL}. Using last model state. {e}")
model.eval()

# Collect per-image probability maps and ground-truth masks over the validation set.
all_preds, all_gts = [], []
with torch.no_grad():
    for imgs, masks in tqdm(val_seg_loader, desc="Finding Threshold"):
        probs = torch.sigmoid(model.forward_seg(imgs.to(device))).cpu().numpy()
        gts = masks.cpu().numpy()
        for p, g in zip(probs, gts):
            all_preds.append(p[0])  # p and g are [1, H, W]
            all_gts.append(g[0])

# Sweep thresholds. Each map is binarized at `thr` BEFORE scoring: dice_score binarizes
# `pred` at 0.5 by default, so passing raw probabilities would give every threshold the
# same score and the sweep would always keep its first candidate.
thresholds = np.round(np.arange(0.2, 0.8, 0.05), 2)
best_dice = 0.0
OPTIMAL_THR = 0.5

if not all_preds:
    print("Warning: empty validation set; keeping the default threshold.")
else:
    for thr in thresholds:
        scores = [dice_score(binarize(p, thr), g, eps=1e-7) for p, g in zip(all_preds, all_gts)]
        mean_dice = float(np.mean(scores))
        print(f"Threshold {thr:.2f} -> Mean Dice: {mean_dice:.4f}")

        if mean_dice > best_dice:
            best_dice = mean_dice
            OPTIMAL_THR = float(thr)

print(f"\n==> Found Optimal Threshold: {OPTIMAL_THR:.2f} (Dice: {best_dice:.4f})")


In [None]:

# -------------------------
# Cell 12: Inference & Submission (NEW LOGIC)
# -------------------------
print("\n--- Starting Test Inference ---")

model.eval()
test_transforms = get_transforms(CFG.IMG_SIZE)['test']


def _to_input(img_pil):
    """PIL image -> normalized [1, 3, IMG_SIZE, IMG_SIZE] tensor on `device`."""
    img = np.array(img_pil.convert("RGB"))
    return test_transforms(image=img)['image'].unsqueeze(0).to(device)


def predict_class(img_pil):
    """Return the predicted class: 0 = authentic, 1 = forged."""
    img_t = _to_input(img_pil)
    with torch.no_grad():
        return model.forward_cls(img_t).argmax(dim=1).item()


def predict_mask_prob(img_pil):
    """Return the forgery probability map [IMG_SIZE, IMG_SIZE] as a numpy array."""
    img_t = _to_input(img_pil)
    with torch.no_grad():
        return torch.sigmoid(model.forward_seg(img_t))[0, 0].cpu().numpy()


# Two-stage inference: classify first; segment only images predicted as forged.
rows = []
if os.path.exists(CFG.TEST_DIR):
    test_files = sorted(os.listdir(CFG.TEST_DIR))
    print(f"Found {len(test_files)} test images.")

    for fname in tqdm(test_files, desc="Inference"):
        case_id = Path(fname).stem
        path = str(Path(CFG.TEST_DIR) / fname)

        try:
            pil = Image.open(path).convert("RGB")
            ow, oh = pil.size

            if predict_class(pil) == 0:
                annot = "authentic"
            else:
                prob = predict_mask_prob(pil)
                # Resize the continuous probability map with linear interpolation, then
                # threshold; nearest-neighbour on the probabilities gives blocky edges.
                mask = cv2.resize(prob, (ow, oh), interpolation=cv2.INTER_LINEAR)
                binm = (mask > OPTIMAL_THR).astype(np.uint8)
                # Classifier said forged but no pixel passed the threshold -> authentic.
                annot = rle_encode_numpy(binm) if binm.sum() > 0 else "authentic"

            rows.append({"case_id": case_id, "annotation": annot})

        except Exception as e:
            # Best-effort: one unreadable/failed image must not abort the submission.
            print(f"Error processing {fname}: {e}. Defaulting to 'authentic'.")
            rows.append({"case_id": case_id, "annotation": "authentic"})

else:
    print("Test directory not found. Using sample submission.")


In [None]:

# -------------------------
# Cell 13: Submission File Creation
# -------------------------

sub = pd.DataFrame(rows, columns=["case_id","annotation"])

if not sub.empty and os.path.exists(CFG.SAMPLE_SUB):
    # Align with sample_submission: rows follow the sample's case_id order, and ids with
    # no prediction (e.g. images that failed to process) default to 'authentic'.
    ss = pd.read_csv(CFG.SAMPLE_SUB)
    ss["case_id"] = ss["case_id"].astype(str)
    sub["case_id"] = sub["case_id"].astype(str)
    sub = ss[["case_id"]].merge(sub, on="case_id", how="left")
    sub["annotation"] = sub["annotation"].fillna("authentic")
elif not sub.empty:
    print("Sample submission not found. Saving as is.")
else:
    # Nothing was predicted: mark every sample case as authentic.
    print("No test images processed. Creating submission from sample.")
    sub = pd.read_csv(CFG.SAMPLE_SUB)
    sub["annotation"] = "authentic"

OUT_PATH = "submission.csv"
sub.to_csv(OUT_PATH, index=False)
print(f"\nWrote submission to: {OUT_PATH}")
print(sub.head())

print("\n--- Grandmaster Notebook Finished ---")