In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# ====================================================
# Recod.ai/LUC - Scientific Image Forgery Detection
# Multi-GPU (2× T4) Training Notebook
# ====================================================

import os, gc, random
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
import cv2

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# ====================================================
# 1. Paths and Globals
# ====================================================
BASE_PATH = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
TRAIN_IMG_PATH = os.path.join(BASE_PATH, "train_images")
TRAIN_MASK_PATH = os.path.join(BASE_PATH, "train_masks")
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "OFF"  # set "DETAIL" for debug

# ====================================================
# 2. Helper functions
# ====================================================
def is_valid_image(img_path):
    """Return True when *img_path* decodes to a usable image.

    An image counts as usable when it can be read, is non-empty and the
    smaller of its first two dimensions exceeds 8. Any failure while reading
    or inspecting the file is reported as False rather than raised.
    """
    try:
        pixels = plt.imread(img_path)
        if pixels is None or pixels.size <= 0:
            return False
        shortest_side = min(pixels.shape[:2])
        return shortest_side > 8
    except Exception:
        # Best-effort probe: an unreadable file is simply "not valid".
        return False

def is_valid_mask(mask_path):
    """Return True when *mask_path* holds a non-empty 2-D NumPy array.

    The file is loaded with ``allow_pickle=False`` so that probing a mask can
    never execute pickled code, and so that this check agrees with loaders
    that call ``np.load`` with its default (pickle-free) settings: a file
    that can only be read through pickle is reported as invalid instead of
    passing here and failing later.

    Any error while loading (missing file, corrupt data, pickled payload)
    yields False rather than an exception.
    """
    try:
        mask = np.load(mask_path, allow_pickle=False)
    except Exception:
        # Best-effort probe: an unreadable file is simply "not valid".
        return False
    # np.load returns a non-array archive object for .npz files; reject those too.
    return isinstance(mask, np.ndarray) and mask.size > 0 and mask.ndim == 2

# ====================================================
# 3. Dataset
# ====================================================
class ForgeryDataset(Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs for forgery segmentation.

    Args:
        image_paths: indexable collection of image file paths.
        mask_paths: mapping from an image's stem (file name without
            extension) to the path of its ``.npy`` mask. Images without an
            entry get an all-zero mask.
        size: side length both image and mask are resized to.
        augment: when True the training pipeline (flips, brightness/contrast,
            affine) is applied; otherwise only the resize.

    Each item is ``(img, mask)`` where ``img`` is a float32 tensor in
    channel-first layout scaled by 1/255 and ``mask`` is a float32 tensor
    with a leading singleton channel.
    """

    def __init__(self, image_paths, mask_paths, size=512, augment=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.size = size
        self.augment = augment

        self.tf_train = A.Compose([
            A.Resize(self.size, self.size),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.05), rotate=(-10, 10), p=0.4)
        ], is_check_shapes=False)
        self.tf_val = A.Compose([A.Resize(self.size, self.size)], is_check_shapes=False)

    def __len__(self):
        return len(self.image_paths)

    def safe_img(self, p):
        """Read image *p* as an ``(H, W, 3)`` uint8 array.

        An unreadable file is replaced by a black 256x256 image so one bad
        sample does not abort a whole epoch.
        """
        try:
            img = plt.imread(p)
        except Exception:
            img = np.zeros((256, 256, 3), np.uint8)
        if img.ndim == 2:
            # Grayscale: replicate the single channel three times.
            img = np.stack([img] * 3, -1)
        if img.shape[2] == 4:
            # Drop the fourth (alpha) channel.
            img = img[:, :, :3]
        if img.dtype != np.uint8:
            # Non-uint8 data whose maximum is <= 1 is treated as normalized
            # and rescaled to 0..255; anything else is cast directly.
            img = (img * 255).astype(np.uint8) if img.max() <= 1 else img.astype(np.uint8)
        return img

    def safe_mask(self, p, shape):
        """Load mask *p* as a binary uint8 array matching ``shape[:2]``.

        Returns an all-zero mask when *p* is empty/None, the file is missing,
        the stored array is not 2-D, or loading fails.
        """
        if not p or not os.path.exists(p):
            return np.zeros(shape[:2], np.uint8)
        try:
            m = np.load(p)
            if m.ndim != 2:
                return np.zeros(shape[:2], np.uint8)
            m = (m > 0).astype(np.uint8)
            if m.shape != shape[:2]:
                # Nearest-neighbour keeps the mask strictly binary.
                m = cv2.resize(m, (shape[1], shape[0]), interpolation=cv2.INTER_NEAREST)
            return m
        except Exception:
            return np.zeros(shape[:2], np.uint8)

    def __getitem__(self, i):
        imgp = self.image_paths[i]
        # Stem of the file name, whatever its extension.
        name = os.path.splitext(os.path.basename(imgp))[0]
        img = self.safe_img(imgp)
        # Use the mapping given to the constructor, not a module-level global,
        # so the dataset works wherever it is instantiated.
        mask = self.safe_mask(self.mask_paths.get(name), img.shape)
        tfm = self.tf_train if self.augment else self.tf_val
        aug = tfm(image=img, mask=mask)
        img, mask = aug["image"], aug["mask"]
        img = torch.tensor(img.transpose(2, 0, 1), dtype=torch.float32) / 255.
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)
        return img, mask

# ====================================================
# 4. U-Net Model
# ====================================================
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and in-place ReLU.

    The first convolution maps ``c1`` channels to ``c2``; the second keeps
    ``c2`` channels. Padding of 1 preserves the spatial size.
    """

    def __init__(self, c1, c2):
        super().__init__()
        stages = []
        for channels_in in (c1, c2):
            stages.append(nn.Conv2d(channels_in, c2, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(c2))
            stages.append(nn.ReLU(inplace=True))
        self.seq = nn.Sequential(*stages)

    def forward(self, x):
        return self.seq(x)

class UNet(nn.Module):
    """Four-level U-Net built from ``ConvBlock`` stages.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (raw logits, no activation).
        base: channel width of the first level; each deeper level doubles it.

    The encoder halves the resolution four times with 2x2 max pooling and the
    decoder doubles it back with stride-2 transposed convolutions, so input
    height and width should be multiples of 16 for the skip connections to
    line up.
    """

    def __init__(self, in_ch=3, out_ch=1, base=32):
        super().__init__()
        # Channel widths: base, 2x, 4x, 8x for the encoder levels, 16x at the bottleneck.
        w1, w2, w3, w4, w5 = (base * 2 ** level for level in range(5))

        # Contracting path.
        self.enc1 = ConvBlock(in_ch, w1)
        self.enc2 = ConvBlock(w1, w2)
        self.enc3 = ConvBlock(w2, w3)
        self.enc4 = ConvBlock(w3, w4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(w4, w5)

        # Expanding path: each decoder block sees upsampled features
        # concatenated with the matching encoder output, hence twice the
        # channels coming in.
        self.up4 = nn.ConvTranspose2d(w5, w4, 2, stride=2)
        self.dec4 = ConvBlock(w5, w4)
        self.up3 = nn.ConvTranspose2d(w4, w3, 2, stride=2)
        self.dec3 = ConvBlock(w4, w3)
        self.up2 = nn.ConvTranspose2d(w3, w2, 2, stride=2)
        self.dec2 = ConvBlock(w3, w2)
        self.up1 = nn.ConvTranspose2d(w2, w1, 2, stride=2)
        self.dec1 = ConvBlock(w2, w1)
        self.out_conv = nn.Conv2d(w1, out_ch, 1)

    def forward(self, x):
        # Encoder: remember each level's output for the skip connections.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder: consume the skips deepest-first.
        decoder_levels = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for upsample, decoder in decoder_levels:
            x = decoder(torch.cat([upsample(x), skips.pop()], 1))
        return self.out_conv(x)

# ====================================================
# 5. Loss, optimizer, AMP, etc.
# ====================================================
def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss computed per sample and averaged over the batch.

    Args:
        pred: raw logits; a sigmoid is applied here. The first axis is the
            batch axis and at least one more axis is required.
        target: tensor broadcastable against ``pred`` holding the reference
            mask values.
        eps: smoothing term added to numerator and denominator so that a
            sample with an empty prediction and empty target scores a Dice
            of 1 instead of dividing by zero.

    Returns:
        Scalar tensor ``1 - mean(dice)``.
    """
    pred = torch.sigmoid(pred)
    # Reduce over every non-batch axis, so (B, H, W), (B, C, H, W) and
    # higher-rank inputs are all accepted; for 4-D input this is (1, 2, 3).
    reduce_dims = tuple(range(1, pred.dim()))
    inter = (pred * target).sum(reduce_dims)
    denom = pred.sum(reduce_dims) + target.sum(reduce_dims)
    return 1 - ((2 * inter + eps) / (denom + eps)).mean()

def make_loss(device, pos_weight=8.0):
    """Build the training criterion: 0.7 * weighted BCE + 0.3 * Dice.

    Args:
        device: device on which the positive-class weight tensor is created.
        pos_weight: weight applied to positive targets in the BCE term.

    Returns:
        A callable ``loss_fn(p, t)`` taking logits ``p`` and targets ``t``.
    """
    positive_weight = torch.tensor([pos_weight], device=device)
    weighted_bce = nn.BCEWithLogitsLoss(pos_weight=positive_weight)

    def loss_fn(p, t):
        bce_term = 0.7 * weighted_bce(p, t)
        dice_term = 0.3 * dice_loss(p, t)
        return bce_term + dice_term

    return loss_fn

# ====================================================
# 6. DDP Setup
# ====================================================
def setup(rank, world_size):
    """Join the NCCL process group and bind this process to GPU *rank*.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of participating processes.

    The default ``env://`` rendezvous needs ``MASTER_ADDR`` and
    ``MASTER_PORT``. They are given single-node defaults here when the
    launcher has not provided them; values already present in the
    environment are left untouched.
    """
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when distributed support is unavailable or no group has
    been initialised (for example from an error path, or twice in a row);
    in those cases it does nothing.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# ====================================================
# 7. Training function (per GPU process)
# ====================================================
def train_ddp(rank, world_size):
    """Train the U-Net as one process of a DistributedDataParallel job.

    Args:
        rank: index of this process; also the CUDA device it trains on.
        world_size: total number of processes in the job.

    Every rank lists the dataset itself, so the listing and the train/val
    split must be identical on all ranks. The file list is therefore sorted
    and split only with the fixed ``random_state`` (no unseeded shuffle).

    Training runs one stage per resolution in ``size_stages``. The validation
    loss is summed across ranks before it drives checkpointing and early
    stopping, so every rank takes the same decision and none is left waiting
    in a collective call. Rank 0 writes the best weights to
    ``model_final_ddp.pth``.
    """
    # Also published at module level, since dataset code may resolve masks
    # through the global name.
    global mask_paths

    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        print(f"Rank {rank} ready.")

        # Data
        forged = sorted(glob(os.path.join(TRAIN_IMG_PATH, "forged", "*.png")))
        mask_files = glob(os.path.join(TRAIN_MASK_PATH, "*.npy"))
        mask_paths = {os.path.basename(p).replace(".npy",""):p for p in mask_files if is_valid_mask(p)}

        forged_valid = [p for p in forged if is_valid_image(p)]
        # train_test_split shuffles with the fixed seed; the input order is
        # deterministic, so all ranks obtain the same partition.
        forged_train, forged_val = train_test_split(forged_valid, test_size=0.15, random_state=42)

        # Datasets & Samplers
        size_stages = [512, 768]
        BATCH = 8
        EPOCHS = 6
        EARLY_STOP = 3

        model = UNet(base=32).to(device)
        model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scaler = GradScaler()
        loss_fn = make_loss(device, pos_weight=8.0)

        best_val = 1e9

        for size in size_stages:
            if rank == 0:
                print(f"\n=== Stage {size} ===")
            train_ds = ForgeryDataset(forged_train, mask_paths, size=size, augment=True)
            val_ds   = ForgeryDataset(forged_val, mask_paths, size=size, augment=False)

            train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True)
            val_sampler   = torch.utils.data.distributed.DistributedSampler(val_ds,   num_replicas=world_size, rank=rank, shuffle=False)
            train_dl = DataLoader(train_ds, batch_size=BATCH, sampler=train_sampler, num_workers=2, pin_memory=True)
            val_dl   = DataLoader(val_ds,   batch_size=BATCH, sampler=val_sampler,   num_workers=2, pin_memory=True)

            early = 0
            for epoch in range(EPOCHS):
                # Re-seed the sampler so each epoch sees a different shuffle.
                train_sampler.set_epoch(epoch)
                model.train()
                total_train = 0
                for imgs, masks in tqdm(train_dl, disable=(rank!=0)):
                    imgs, masks = imgs.to(device, non_blocking=True), masks.to(device, non_blocking=True)
                    optimizer.zero_grad()
                    with autocast():
                        preds = model(imgs)
                        loss = loss_fn(preds, masks)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    total_train += loss.item()
                # Rank-local average; only used for rank 0's progress line.
                avg_train = total_train / len(train_dl)

                # Validation
                model.eval()
                total_val = 0.0
                with torch.no_grad():
                    for imgs, masks in val_dl:
                        imgs, masks = imgs.to(device, non_blocking=True), masks.to(device, non_blocking=True)
                        with autocast():
                            preds = model(imgs)
                            vloss = loss_fn(preds, masks)
                        total_val += vloss.item()

                # Sum loss and batch count over all ranks so that avg_val, and
                # every decision derived from it, is identical everywhere.
                val_stats = torch.tensor([total_val, float(len(val_dl))], dtype=torch.float64, device=device)
                dist.all_reduce(val_stats, op=dist.ReduceOp.SUM)
                avg_val = (val_stats[0] / val_stats[1].clamp(min=1)).item()

                if rank == 0:
                    print(f"[{size}] Epoch {epoch+1}/{EPOCHS} - Train={avg_train:.4f}, Val={avg_val:.4f}")
                if avg_val < best_val:
                    best_val = avg_val
                    if rank == 0:
                        # Save the wrapped module so keys carry no "module." prefix.
                        torch.save(model.module.state_dict(), "model_final_ddp.pth")
                        print("‚úÖ Saved checkpoint.")
                    early = 0
                else:
                    early += 1
                if early >= EARLY_STOP:
                    if rank == 0: print("‚èπÔ∏è Early stop triggered.")
                    break
            torch.cuda.empty_cache()
            gc.collect()
    finally:
        # Always leave the process group, even when training raised.
        cleanup()

    if rank == 0:
        print("\nüéØ Training complete. model_final_ddp.pth saved.")

# ====================================================
# 8. Launch DataParallel (simpler for notebooks)
# ====================================================
# Schedule for the single-process run: one training stage per resolution.
stages = [512, 768]
EPOCHS = 6
BATCH_SIZE = 16
EARLY_STOP = 3

# Use CUDA when present, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device(s): {torch.cuda.device_count()} GPU(s) detected")

model = UNet(base=32)
model = model.to(DEVICE)
if torch.cuda.device_count() >= 2:
    # One process drives all GPUs; each batch is split across them.
    print("‚úÖ Using nn.DataParallel across GPUs")
    model = nn.DataParallel(model)

# Mixed-precision scaler, criterion and optimizer for the loop below.
scaler = torch.amp.GradScaler("cuda")
loss_fn = make_loss(DEVICE, pos_weight=8.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ====================================================
# Prepare training/validation splits (forged-first)
# ====================================================

# Locate all image and mask files again (if not already loaded)
# glob returns paths in arbitrary order; sorting makes the seeded
# train/val split below reproducible from run to run.
forged_images = sorted(glob(os.path.join(TRAIN_IMG_PATH, "forged", "*.png")))
auth_images   = sorted(glob(os.path.join(TRAIN_IMG_PATH, "authentic", "*.png")))
mask_files    = sorted(glob(os.path.join(TRAIN_MASK_PATH, "*.npy")))

# Map each mask's stem (file name without ".npy") to its path.
# Only existence is checked here; the file contents are not inspected.
mask_paths = {
    os.path.basename(p).replace(".npy", ""): p
    for p in mask_files
    if os.path.exists(p)
}

# Keep only forged images that can actually be decoded
forged_valid = [p for p in forged_images if is_valid_image(p)]
print(f"‚úÖ Found {len(forged_valid)} valid forged images")

# Split into train/val (forged-only first phase); fixed seed + sorted input
# gives the same partition every time.
from sklearn.model_selection import train_test_split
forged_train, forged_val = train_test_split(forged_valid, test_size=0.15, random_state=42)

print(f"Train forged images: {len(forged_train)} | Val forged images: {len(forged_val)}")


# Reuse your forged_train, forged_val definitions
# Progressive-resolution training: one stage per entry in `stages`. The best
# validation loss and the early-stopping counter are reset for each stage.
for size in stages:
    print(f"\n=== Stage {size} ===")
    train_ds = ForgeryDataset(forged_train, mask_paths, size=size, augment=True)
    val_ds   = ForgeryDataset(forged_val,   mask_paths, size=size, augment=False)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

    best_val = np.inf
    early = 0

    for epoch in range(EPOCHS):
        model.train()
        total_train = 0
        for imgs, masks in tqdm(train_dl):
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            optimizer.zero_grad()
            with torch.amp.autocast("cuda"):
                preds = model(imgs)
                loss = loss_fn(preds, masks)
            # Scale the loss before backward; the scaler steps the optimizer
            # and then updates its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_train += loss.item()
        train_loss = total_train / len(train_dl)

        # Validation
        model.eval()
        total_val = 0
        with torch.no_grad():
            for imgs, masks in val_dl:
                imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
                with torch.amp.autocast("cuda"):
                    preds = model(imgs)
                    # Separate name so the per-batch tensor is not confused
                    # with the epoch average computed below.
                    batch_val_loss = loss_fn(preds, masks)
                total_val += batch_val_loss.item()
        val_loss = total_val / len(val_dl)
        print(f"[{size}] Epoch {epoch+1}/{EPOCHS} | Train={train_loss:.4f} | Val={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            # Save the underlying network: a DataParallel wrapper would
            # prefix every key with "module." and the checkpoint would not
            # load into a plain UNet.
            torch.save(
                (model.module if isinstance(model, nn.DataParallel) else model).state_dict(),
                "model_final.pth",
            )
            print("‚úÖ Model improved and saved.")
            early = 0
        else:
            early += 1
        if early >= EARLY_STOP:
            print("‚èπÔ∏è Early stopping triggered.")
            break

print("\nüéØ Training complete. model_final.pth saved.")
