# LUC-CMFD: Copy-Move Forgery Detection for Kaggle Competition

This notebook trains a DINOv2-based forgery detection model and generates submissions.

**Competition:** Recod.ai/LUC - Scientific Image Forgery Detection  
**Objective:** Detect and segment copy-move forgeries in biomedical images

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install -q torch torchvision tqdm scikit-image pillow numpy pandas pyyaml

In [None]:
# Import libraries
import os
import sys
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import autocast, GradScaler

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# Configuration
CONFIG = {
    # Paths (adjust for Kaggle environment)
    'data_root': '/kaggle/input/recodai-luc-scientific-image-forgery-detection',
    'output_dir': '/kaggle/working/weights',
    'submission_path': '/kaggle/working/submission.csv',

    # Training
    'seed': 42,
    # Side length that images and masks are resized to. It must be a multiple
    # of the ViT patch size (14 for dinov2_vits14): DINOv2's patch embedding
    # asserts divisibility, so the previous value of 256 aborted the forward
    # pass. 252 = 18 * 14 is the closest valid size.
    'image_size': 252,
    'batch_size': 16,
    'val_split': 0.2,   # fraction of the training set held out for validation
    'epochs': 50,
    'lr': 1e-4,
    'patience': 10,     # early-stopping patience, in epochs

    # Model
    'backbone': 'dinov2_vits14',
    'freeze_backbone': True,
    # NOTE(review): 'patch' and 'stride' are not read by any code visible in
    # this notebook - confirm whether they are still needed.
    'patch': 12,
    'stride': 4,
    'top_k': 5,         # number of strongest self-correlations kept per position
}

# Create output directory
Path(CONFIG['output_dir']).mkdir(exist_ok=True, parents=True)

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 3. Utility Functions

In [None]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    import random as py_random

    # Same seeding order as before: Python, NumPy, torch (CPU), torch (all GPUs).
    seeders = (
        py_random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def rle_encode(mask: np.ndarray) -> str:
    """Encode a mask as a run-length string (1-indexed, row-major order).

    Any non-zero element counts as foreground, so masks holding values other
    than 0/1 (for example 255 or instance labels) are encoded as their binary
    footprint instead of yielding spurious runs.

    Args:
        mask: Array of any shape; it is flattened in row-major (C) order.

    Returns:
        ``"authentic"`` when the mask has no foreground, otherwise the string
        form of a flat list ``[start_1, length_1, start_2, length_2, ...]``
        with 1-indexed start positions.
    """
    # NOTE(review): flattening order and output format must match the
    # competition's scoring code - confirm against the official metric.
    foreground = np.asarray(mask).ravel() != 0
    if not foreground.any():
        return "authentic"

    # Pad with background on both sides so every run has a rising and a
    # falling edge, including runs that touch the first or last pixel.
    padded = np.concatenate(([False], foreground, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])

    # Edges alternate: even entries open a run, odd entries close it.
    starts = edges[::2]
    lengths = edges[1::2] - starts

    # Interleave (start, length) pairs, converting starts to 1-indexed.
    rle_pairs = np.empty(starts.size * 2, dtype=np.int64)
    rle_pairs[::2] = starts + 1
    rle_pairs[1::2] = lengths

    return str(rle_pairs.tolist())

set_seed(CONFIG['seed'])
print("✓ Utilities loaded")

## 4. Dataset

In [None]:
class CMFDDataset(Dataset):
    """Image/mask dataset for the CMFD competition directory layout.

    Expected layout under ``root``::

        train_images/authentic/*   images without a mask
        train_images/forged/*      images whose mask is train_masks/<stem>.npy
        test_images/*              images for inference

    Each sample is a dict with:
        * ``image``: float32 tensor ``(3, image_size, image_size)`` in [0, 1]
        * ``mask``: float32 tensor ``(1, image_size, image_size)`` of 0/1
          (all zeros when no mask file exists)
        * ``case_id``: file stem of the image
        * ``original_size``: ``(W, H)`` tuple of the image before resizing.
          Note that PyTorch's default collate function turns a batch of such
          tuples into ``[widths, heights]``.
    """

    # Lower-cased file suffixes that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root, split='train', image_size=256):
        """Index the files of one split.

        Args:
            root: Dataset root directory.
            split: ``'train'`` or ``'test'``.
            image_size: Side length images and masks are resized to.

        Raises:
            ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
        """
        if split not in ('train', 'test'):
            # An unknown split used to yield a silently empty dataset.
            raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")
        self.root = Path(root)
        self.split = split
        self.image_size = image_size
        self.items = self._build_index()

    def _list_images(self, directory):
        """Return the image files directly inside ``directory``, sorted by path."""
        if not directory.exists():
            return []
        return [
            path for path in sorted(directory.glob('*'))
            if path.suffix.lower() in self.IMAGE_EXTENSIONS
        ]

    def _build_index(self):
        """Build the list of sample records for the configured split."""
        items = []

        if self.split == 'train':
            # Authentic images carry no mask.
            for img_path in self._list_images(self.root / 'train_images' / 'authentic'):
                items.append({
                    'image_path': img_path,
                    'mask_path': None,
                    'case_id': img_path.stem,
                    'is_forged': False
                })

            # Forged images: the mask is looked up by case id; a missing mask
            # file is recorded as None and later loaded as an all-zero mask.
            mask_dir = self.root / 'train_masks'
            for img_path in self._list_images(self.root / 'train_images' / 'forged'):
                case_id = img_path.stem
                mask_path = mask_dir / f"{case_id}.npy"
                items.append({
                    'image_path': img_path,
                    'mask_path': mask_path if mask_path.exists() else None,
                    'case_id': case_id,
                    'is_forged': True
                })
        else:  # 'test'
            for img_path in self._list_images(self.root / 'test_images'):
                items.append({
                    'image_path': img_path,
                    'mask_path': None,
                    'case_id': img_path.stem,
                    'is_forged': None
                })

        return items

    @staticmethod
    def _merge_mask_channels(mask, height, width):
        """Collapse a 3-D mask to ``(H, W)`` by taking the max over channels.

        The channel axis is found by matching the remaining two axes against
        the image size, so ``(1, H, W)`` and ``(N, H, W)`` with any ``N`` are
        handled. If neither layout matches, a small leading axis (1-4) is
        assumed to be the channel axis, otherwise the trailing axis.
        """
        if mask.shape[1:] == (height, width):
            axis = 0
        elif mask.shape[:2] == (height, width):
            axis = -1
        else:
            axis = 0 if mask.shape[0] in (1, 2, 3, 4) else -1
        return mask.max(axis=axis)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        target_size = (self.image_size, self.image_size)

        # Load image; the context manager releases the file handle promptly.
        with Image.open(item['image_path']) as raw:
            orig_size = raw.size  # (W, H)
            img = raw.convert('RGB').resize(target_size, Image.BILINEAR)
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)  # (C, H, W)

        # Load mask
        mask_path = item['mask_path']
        if mask_path is not None and mask_path.exists():
            mask = np.load(mask_path)
            # Multi-channel masks are merged into one binary mask.
            if mask.ndim == 3:
                mask = self._merge_mask_channels(mask, orig_size[1], orig_size[0])
            mask = (mask > 0).astype(np.uint8)
            # NEAREST keeps the resized mask strictly binary.
            mask_img = Image.fromarray(mask).resize(target_size, Image.NEAREST)
            mask = np.array(mask_img).astype(np.float32)
        else:
            mask = np.zeros(target_size, dtype=np.float32)

        mask_tensor = torch.from_numpy(mask).unsqueeze(0)  # (1, H, W)

        return {
            'image': img_tensor,
            'mask': mask_tensor,
            'case_id': item['case_id'],
            'original_size': orig_size
        }

print("✓ Dataset class defined")

## 5. Model Architecture

In [None]:
# Simple correlation module (placeholder - can be enhanced)
def self_correlation_simple(feats, top_k=5, exclude_self=False):
    """Top-k cosine self-correlation of a feature map.

    Every spatial position is compared (cosine similarity over channels) with
    every position of the same feature map, and the ``top_k`` largest
    similarities are kept per position, in descending order.

    Args:
        feats: Feature tensor of shape ``(B, C, H, W)``.
        top_k: Number of strongest correlations kept per position.
        exclude_self: When True, a position's similarity with itself (always
            the maximum) is ignored. Defaults to False, which keeps the
            original behaviour.

    Returns:
        Tensor of shape ``(B, top_k, H, W)``. If fewer than ``top_k``
        candidate positions exist, the missing channels are zero-filled
        (previously this case raised a shape error).
    """
    B, C, H, W = feats.shape
    N = H * W

    # L2-normalise over channels so the dot product is a cosine similarity,
    # then flatten the spatial grid: (B, C, N).
    feats_flat = F.normalize(feats, dim=1).reshape(B, C, N)

    # Pairwise similarity between all positions: (B, N, N).
    corr = torch.bmm(feats_flat.transpose(1, 2), feats_flat)

    if exclude_self:
        diagonal = torch.eye(N, dtype=torch.bool, device=corr.device)
        corr = corr.masked_fill(diagonal, float('-inf'))

    # Never ask topk for more entries than there are candidates.
    available = N - 1 if exclude_self else N
    k = min(top_k, available)
    topk_vals, _ = torch.topk(corr, k=k, dim=-1)  # (B, N, k)

    if k < top_k:
        # Zero-fill so the output always has exactly top_k channels.
        filler = topk_vals.new_zeros(B, N, top_k - k)
        topk_vals = torch.cat([topk_vals, filler], dim=-1)

    # Back to a spatial layout: (B, top_k, H, W).
    corr_map = topk_vals.transpose(1, 2).reshape(B, top_k, H, W)

    return corr_map


class DinoBackbone(nn.Module):
    """Feature extractor: DINOv2 ViT when available, otherwise a small conv net.

    The DINOv2 model is fetched through ``torch.hub``. If that fails for any
    reason (for example no internet access), a randomly initialised four-layer
    conv stack with 384 output channels and an overall stride of 8 is used.

    Attributes set by ``__init__``:
        backbone: The loaded ViT or the fallback ``nn.Sequential``.
        feat_dim: Channel count of the returned feature map.
        freeze: Whether the backbone runs in eval mode without gradients.
    """

    def __init__(self, model_name='dinov2_vits14', freeze=True):
        """Load the backbone.

        Args:
            model_name: ``torch.hub`` entry point in ``facebookresearch/dinov2``.
            freeze: If True, backbone parameters get ``requires_grad=False``
                and every forward pass runs under ``torch.no_grad()``.
        """
        super().__init__()
        self.freeze = freeze

        try:
            # Try loading DINOv2
            self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
            self.feat_dim = self.backbone.embed_dim
            print(f"✓ Loaded {model_name}")
        except Exception as e:
            # Deliberate best-effort fallback; the cause is reported so a
            # silent switch to random features does not go unnoticed.
            print(
                "Warning: Could not load DINOv2 "
                f"({type(e).__name__}: {e}), using simple conv backbone"
            )
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 64, 7, stride=2, padding=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 128, 3, stride=2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, 3, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 384, 3, stride=1, padding=1),
                nn.BatchNorm2d(384),
                nn.ReLU(inplace=True),
            )
            self.feat_dim = 384

        if freeze:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return a ``(B, feat_dim, H', W')`` feature map for images ``x``."""
        if self.freeze:
            # Re-assert eval mode on every call: an outer model.train() would
            # otherwise switch the frozen backbone back to training mode.
            self.backbone.eval()
            with torch.no_grad():
                return self._extract_features(x)
        return self._extract_features(x)

    def _extract_features(self, x):
        """Run the backbone and return features in ``(B, C, H', W')`` layout."""
        if hasattr(self.backbone, 'get_intermediate_layers'):
            # DINOv2 ViT. Its patch embedding requires both spatial sizes to
            # be multiples of the patch size, so the input is resized down to
            # the nearest valid size when necessary (e.g. 256 -> 252).
            # NOTE(review): DINOv2 is normally fed ImageNet-normalised inputs;
            # confirm whether mean/std normalisation should be applied here.
            patch = getattr(self.backbone, 'patch_size', 14)
            if isinstance(patch, (tuple, list)):
                patch = patch[0]
            in_h, in_w = x.shape[-2:]
            grid_h = max(in_h // patch, 1)
            grid_w = max(in_w // patch, 1)
            valid_size = (grid_h * patch, grid_w * patch)
            if valid_size != (in_h, in_w):
                x = F.interpolate(
                    x, size=valid_size, mode='bilinear', align_corners=False
                )

            out = self.backbone.get_intermediate_layers(x, n=1)[0]
            B, N, C = out.shape
            # Use the real token grid; the input need not be square.
            feats = out.reshape(B, grid_h, grid_w, C).permute(0, 3, 1, 2)
        else:
            # Simple conv
            feats = self.backbone(x)
        return feats


class CMFDNet(nn.Module):
    """Copy-move forgery segmentation network.

    Pipeline: backbone features -> top-k self-correlation -> correlation head
    (one-channel saliency) -> decoder -> logits resized to the input size.
    """

    def __init__(self, backbone='dinov2_vits14', freeze_backbone=True, top_k=5):
        """Build the backbone, correlation head and decoder.

        Args:
            backbone: Backbone name forwarded to ``DinoBackbone``.
            freeze_backbone: Forwarded to ``DinoBackbone`` as its freeze flag.
            top_k: Number of correlation channels fed to the head.
        """
        super().__init__()
        self.top_k = top_k

        def conv_bn_relu(c_in, c_out):
            # 3x3 conv (same padding) + batch norm + ReLU, as a flat layer list.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Feature extractor
        self.backbone = DinoBackbone(backbone, freeze_backbone)

        # Turns the top-k correlation stack into a one-channel saliency map.
        self.corr_head = nn.Sequential(
            *conv_bn_relu(top_k, 128),
            *conv_bn_relu(128, 128),
            nn.Conv2d(128, 1, 1),
        )

        # Refines the saliency map and doubles its resolution.
        self.decoder = nn.Sequential(
            *conv_bn_relu(1, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            *conv_bn_relu(64, 32),
            nn.Conv2d(32, 1, 1),
        )

    def forward(self, x):
        """Return ``{'logits': tensor}`` with logits at the input resolution."""
        features = self.backbone(x)
        topk_corr = self_correlation_simple(features, self.top_k)
        coarse_logits = self.decoder(self.corr_head(topk_corr))

        # Bring the prediction back to the spatial size of the input batch.
        full_res_logits = F.interpolate(
            coarse_logits,
            size=x.shape[2:],
            mode='bilinear',
            align_corners=False,
        )
        return {'logits': full_res_logits}

print("✓ Model architecture defined")

## 6. Training Functions

In [None]:
def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        pred: Raw logits; a sigmoid is applied internally.
        target: Ground-truth mask with the same number of elements as ``pred``.
        smooth: Additive constant in numerator and denominator that keeps the
            ratio defined when both prediction and target are empty.

    Returns:
        Scalar tensor ``1 - dice``.
    """
    probs = torch.sigmoid(pred).view(-1)
    truth = target.view(-1)
    overlap = (probs * truth).sum()
    denominator = probs.sum() + truth.sum() + smooth
    return 1.0 - (2.0 * overlap + smooth) / denominator

def combined_loss(pred, target):
    """Segmentation loss: BCE-with-logits plus soft Dice, equally weighted.

    Args:
        pred: Raw logits.
        target: Ground-truth mask of the same shape as ``pred``.

    Returns:
        Scalar tensor holding the sum of both loss terms.
    """
    bce_criterion = nn.BCEWithLogitsLoss()
    bce_term = bce_criterion(pred, target)
    dice_term = dice_loss(pred, target)
    return bce_term + dice_term

def compute_metrics(pred, target):
    """Pixel-level precision, recall and F1 over the whole batch.

    Logits are passed through a sigmoid and thresholded at 0.5. A small
    epsilon in every denominator keeps the ratios defined when a count is 0.

    Args:
        pred: Raw logits.
        target: Ground-truth mask of the same shape as ``pred``.

    Returns:
        Dict with Python floats under ``'precision'``, ``'recall'``, ``'f1'``.
    """
    eps = 1e-7
    hard_pred = torch.sigmoid(pred).gt(0.5).float()

    true_pos = torch.sum(hard_pred * target)
    false_pos = torch.sum(hard_pred * (1 - target))
    false_neg = torch.sum((1 - hard_pred) * target)

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    scores = (('precision', precision), ('recall', recall), ('f1', f1))
    return {name: value.item() for name, value in scores}

def train_epoch(model, dataloader, optimizer, scaler, device):
    """Train ``model`` for one pass over ``dataloader`` with mixed precision.

    Args:
        model: Network whose forward pass returns a dict with ``'logits'``.
        dataloader: Yields dict batches with ``'image'`` and ``'mask'`` tensors.
        optimizer: Optimizer over the model's trainable parameters.
        scaler: AMP gradient scaler used for backward and the optimizer step.
        device: Device the batches are moved to; it also selects the autocast
            backend.

    Returns:
        Dict with the mean per-batch ``'loss'`` and ``'f1'``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError after the loop.
        raise ValueError("train_epoch() received an empty dataloader")

    # Autocast must follow the device the tensors actually live on, not
    # whether CUDA merely happens to be available on this machine.
    device_type = torch.device(device).type

    model.train()
    total_loss = 0.0
    total_f1 = 0.0

    pbar = tqdm(dataloader, desc="Training")
    for batch in pbar:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        optimizer.zero_grad()

        with autocast(device_type):
            output = model(images)
            loss = combined_loss(output['logits'], masks)

        # Scale the loss before backward; the scaler unscales the gradients
        # before stepping the optimizer.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            metrics = compute_metrics(output['logits'], masks)

        batch_loss = loss.item()  # read once; each .item() syncs with the device
        total_loss += batch_loss
        total_f1 += metrics['f1']

        pbar.set_postfix({'loss': f"{batch_loss:.4f}", 'f1': f"{metrics['f1']:.4f}"})

    return {'loss': total_loss / num_batches, 'f1': total_f1 / num_batches}

@torch.no_grad()
def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Network whose forward pass returns a dict with ``'logits'``.
        dataloader: Yields dict batches with ``'image'`` and ``'mask'`` tensors.
        device: Device the batches are moved to; it also selects the autocast
            backend.

    Returns:
        Dict with the mean per-batch ``'loss'`` and ``'f1'``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError after the loop.
        raise ValueError("validate() received an empty dataloader")

    # Autocast must follow the device the tensors actually live on, not
    # whether CUDA merely happens to be available on this machine.
    device_type = torch.device(device).type

    model.eval()
    total_loss = 0.0
    total_f1 = 0.0

    # NOTE(review): F1 is averaged per batch, so a batch without any
    # foreground scores 0 even when the prediction is correctly empty.
    for batch in tqdm(dataloader, desc="Validation"):
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        with autocast(device_type):
            output = model(images)
            loss = combined_loss(output['logits'], masks)

        metrics = compute_metrics(output['logits'], masks)
        total_loss += loss.item()
        total_f1 += metrics['f1']

    return {'loss': total_loss / num_batches, 'f1': total_f1 / num_batches}

print("✓ Training functions defined")

## 7. Load Data

In [None]:
# Create dataset
print("Loading dataset...")
full_dataset = CMFDDataset(
    root=CONFIG['data_root'],
    split='train',
    image_size=CONFIG['image_size']
)
print(f"Total samples: {len(full_dataset)}")

# Fail fast with a clear message: an empty dataset (typically a wrong
# data_root) would otherwise only surface later as an obscure sampler error.
if len(full_dataset) == 0:
    raise FileNotFoundError(
        f"No training images found under {CONFIG['data_root']!r}; "
        "check CONFIG['data_root']"
    )

# Train/val split (seeded so the partition is reproducible)
val_size = int(len(full_dataset) * CONFIG['val_split'])
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(CONFIG['seed'])
)
print(f"Train: {train_size}, Val: {val_size}")

# Data loaders: only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print("✓ Data loaders created")

## 8. Train Model

In [None]:
# Setup device: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model from the backbone/top-k settings in CONFIG.
model = CMFDNet(
    backbone=CONFIG['backbone'],
    freeze_backbone=CONFIG['freeze_backbone'],
    top_k=CONFIG['top_k']
)
model = model.to(device)
if torch.cuda.is_available():
    # Store parameters in channels-last memory format on GPU.
    model = model.to(memory_format=torch.channels_last)

# Optimizer and scheduler.
# Only parameters with requires_grad=True are optimised, so a frozen
# backbone is left out of the optimizer entirely.
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG['lr']
)
# mode='max': the scheduler is stepped with the validation F1 in the training
# loop; the learning rate is halved after 5 epochs without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5
)

# AMP gradient scaler, created for the same device type that was selected above.
scaler = GradScaler('cuda' if torch.cuda.is_available() else 'cpu')

print("✓ Model setup complete")

In [None]:
# Training loop
# Start below any reachable F1 so the first epoch always writes a checkpoint.
# With the previous start value of 0 and the strict '>' comparison, a run
# whose validation F1 stayed at 0 never saved 'best_model.pth', and the later
# checkpoint load failed.
best_f1 = float('-inf')
patience_counter = 0

print(f"Starting training for {CONFIG['epochs']} epochs...\n")

for epoch in range(CONFIG['epochs']):
    print(f"Epoch {epoch + 1}/{CONFIG['epochs']}")

    # Train
    train_stats = train_epoch(model, train_loader, optimizer, scaler, device)
    print(f"Train - Loss: {train_stats['loss']:.4f}, F1: {train_stats['f1']:.4f}")

    # Validate
    val_stats = validate(model, val_loader, device)
    print(f"Val - Loss: {val_stats['loss']:.4f}, F1: {val_stats['f1']:.4f}")

    # Update scheduler (it tracks validation F1, higher is better)
    scheduler.step(val_stats['f1'])

    # Save best model
    if val_stats['f1'] > best_f1:
        best_f1 = val_stats['f1']
        patience_counter = 0
        save_path = Path(CONFIG['output_dir']) / 'best_model.pth'
        torch.save(model.state_dict(), save_path)
        print(f"✓ Saved best model (F1: {best_f1:.4f})")
    else:
        patience_counter += 1

    # Early stopping after 'patience' epochs without a new best F1
    if patience_counter >= CONFIG['patience']:
        print(f"Early stopping after {epoch + 1} epochs")
        break

    print()

print(f"\n✓ Training complete! Best F1: {best_f1:.4f}")

## 9. Generate Predictions

In [None]:
# Load best model
best_model_path = Path(CONFIG['output_dir']) / 'best_model.pth'
# map_location makes the checkpoint loadable when the current device differs
# from the one it was saved on (e.g. a restarted CPU-only session);
# weights_only restricts unpickling to tensors and plain containers.
state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("✓ Loaded best model")

# Create test dataset
test_dataset = CMFDDataset(
    root=CONFIG['data_root'],
    split='test',
    image_size=CONFIG['image_size']
)
print(f"Test samples: {len(test_dataset)}")

# Keep file order (no shuffling) for inference.
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2
)

In [None]:
# Generate predictions
predictions = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        images = batch['image'].to(device)
        case_ids = batch['case_id']

        # The dataset returns original_size as a (W, H) tuple per sample. The
        # default collate function transposes such tuples, so the batch holds
        # two tensors, [all widths, all heights], rather than one (W, H) pair
        # per sample. Indexing it by sample position, as before, read the
        # wrong values or raised.
        orig_widths, orig_heights = batch['original_size']

        # Forward pass
        with autocast('cuda' if torch.cuda.is_available() else 'cpu'):
            output = model(images)

        # Get binary masks
        logits = output['logits']
        masks = (torch.sigmoid(logits) > 0.5).cpu().numpy()

        # Process each sample
        for i, case_id in enumerate(case_ids):
            mask = masks[i, 0]  # (H, W)

            # Resize to original size; plain ints are required by PIL.
            orig_w = int(orig_widths[i])
            orig_h = int(orig_heights[i])
            mask_img = Image.fromarray((mask * 255).astype(np.uint8))
            mask_img = mask_img.resize((orig_w, orig_h), Image.NEAREST)
            mask_resized = (np.array(mask_img) > 127).astype(np.uint8)

            # Encode to RLE
            rle = rle_encode(mask_resized)

            predictions.append({
                'id': case_id,
                'mask_rle': rle
            })

print(f"✓ Generated {len(predictions)} predictions")

## 10. Create Submission

In [None]:
# Create submission DataFrame: one row per prediction dict, with the columns
# 'id' and 'mask_rle' taken from the dict keys.
# NOTE(review): confirm these column names match the competition's
# sample_submission.csv before submitting.
submission_df = pd.DataFrame(predictions)
# Sorts by the 'id' column (string ids sort lexicographically).
submission_df = submission_df.sort_values('id')

# Save submission (without the DataFrame index column)
submission_df.to_csv(CONFIG['submission_path'], index=False)

# Summary: rows whose RLE is the literal 'authentic' are counted as authentic,
# all others as forged.
print(f"✓ Submission saved to {CONFIG['submission_path']}")
print(f"\nSubmission preview:")
print(submission_df.head(10))
print(f"\nTotal submissions: {len(submission_df)}")
print(f"Authentic images: {(submission_df['mask_rle'] == 'authentic').sum()}")
print(f"Forged images: {(submission_df['mask_rle'] != 'authentic').sum()}")

## Done!

Download the `submission.csv` file and submit it to the Kaggle competition.