# Copy-Move Forgery Detection - Kaggle Submission

This notebook runs inference on the evaluation set and generates `submission.csv`.

**Requirements:**
- Attach a Kaggle Dataset containing model weights
- Internet access disabled (all dependencies must already be pre-installed)
- GPU runtime recommended (target: complete inference in under 4 hours)

In [None]:
import sys
import os
from pathlib import Path
import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Add src to path
sys.path.insert(0, '../src')

from model import CMFDNet
from dataset import CMFDDataset, collate_fn
from infer import infer_image, create_submission
from utils import set_all_seeds, memory_stats

## Configuration

In [None]:
# Input/output locations -- adjust these for the Kaggle environment.
WEIGHTS_PATH = "/kaggle/input/luc-cmfd-weights/best_model.pth"  # checkpoint from the attached dataset
IMAGE_DIR = "/kaggle/input/cmfd-competition/test_images"  # images to run inference on
OUTPUT_PATH = "submission.csv"

# Model and inference settings.
# backbone/freeze_backbone/patch/stride/top_k are forwarded to the CMFDNet
# constructor; the whole mapping is handed to infer_image, and 'tta' also
# decides whether test-time augmentation is enabled ('none' disables it).
CONFIG = dict(
    backbone='dinov2_vits14',
    freeze_backbone=True,
    patch=12,
    stride=4,
    top_k=5,
    ransac_model='similarity',
    inlier_thresh_px=1.5,
    tta='flip',
    post=dict(
        thr=0.5,
        min_area=24,
        morph=dict(close=3, open=0),
    ),
)

BATCH_SIZE = 4  # DataLoader batch size
SEED = 42  # passed to set_all_seeds

## Environment Check

In [None]:
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Check AMP support.
# torch.cuda.amp has no is_autocast_available attribute, so the previous
# expression raised AttributeError exactly when a GPU was present (the
# short-circuit hid the problem on CPU). The availability query is
# torch.amp.is_autocast_available(device_type); look it up defensively because
# it is only present in newer PyTorch releases.
_autocast_query = getattr(getattr(torch, 'amp', None), 'is_autocast_available', None)
if not torch.cuda.is_available():
    amp_available = False
elif _autocast_query is not None:
    amp_available = bool(_autocast_query('cuda'))
else:
    # NOTE(review): no query function on this PyTorch version; assuming CUDA
    # autocast is usable whenever CUDA itself is -- confirm for the target image.
    amp_available = True
print(f"AMP Available: {amp_available}")

# Seed the RNGs via the project helper.
set_all_seeds(SEED)
print(f"\nSeed set to: {SEED}")

# Echo the top-level configuration entries for the run log.
print("\nConfiguration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## Load Model

In [None]:
# Build the network from the architecture-related CONFIG entries.
print("Creating model...")
model = CMFDNet(**{
    key: CONFIG[key]
    for key in ('backbone', 'freeze_backbone', 'patch', 'stride', 'top_k')
})

# Restore trained weights when the checkpoint is present; otherwise keep the
# freshly constructed model so the notebook can still be smoke-tested.
if not Path(WEIGHTS_PATH).exists():
    print(f"WARNING: Weights not found at {WEIGHTS_PATH}")
    print("Using randomly initialized weights (for testing only)")
else:
    print(f"Loading weights from {WEIGHTS_PATH}...")
    # NOTE(review): torch.load unpickles the file; only point WEIGHTS_PATH at
    # checkpoints you trust.
    state_dict = torch.load(WEIGHTS_PATH, map_location=device)
    model.load_state_dict(state_dict)
    print("Weights loaded successfully")

# Move to the target device, switch to channels-last memory format
# (performance optimization), and put the model in evaluation mode.
model = model.to(device).to(memory_format=torch.channels_last)
model.eval()

print(f"Model ready on {device}")

## Load Dataset

In [None]:
print(f"Loading dataset from {IMAGE_DIR}...")

# Dataset over the image directory, with normalization enabled.
dataset = CMFDDataset(image_dir=IMAGE_DIR, normalize=True)

# Batches are produced in dataset order (shuffle disabled) by two worker
# processes, assembled with the project's collate_fn into pinned memory.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"Dataset loaded: {len(dataset)} images")
print(f"Batches: {len(dataloader)}")

## Run Inference

In [None]:
# Run the model over every image, collect one mask per case_id, and report timing.
import cv2  # hoisted: was re-imported inside the per-image loop

print("Starting inference...")
print("Target: Complete in < 4 hours")

results = []
total_time = 0.0  # accumulated per-batch compute time (seconds), excludes data loading
start_time = time.time()

# Loop-invariant: test-time augmentation is on unless CONFIG['tta'] is 'none'.
use_tta = CONFIG['tta'] != 'none'

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Inference"):
        images = batch['image'].to(device, memory_format=torch.channels_last)
        case_ids = batch['case_id']
        original_sizes = batch['original_size']

        batch_start = time.perf_counter()

        # Process each image of the batch individually.
        for img, case_id, orig_size in zip(images, case_ids, original_sizes):
            mask = infer_image(model, img, CONFIG, device, use_tta=use_tta)

            # Normalise the target size to a (height, width) tuple of ints before
            # comparing. The container type produced by collate_fn is not visible
            # here; a list would never equal the shape tuple (forcing a needless
            # resize), and a tensor would make the comparison ill-defined.
            target_hw = (int(orig_size[0]), int(orig_size[1]))
            if tuple(mask.shape[:2]) != target_hw:
                # cv2.resize expects (width, height); nearest-neighbour keeps the
                # mask values unchanged.
                mask = cv2.resize(
                    mask.astype(np.uint8),
                    (target_hw[1], target_hw[0]),
                    interpolation=cv2.INTER_NEAREST
                )

            results.append({
                'case_id': case_id,
                'mask': mask
            })

        batch_time = time.perf_counter() - batch_start
        total_time += batch_time

elapsed = time.time() - start_time
# Guard against an empty dataset: previously this divided by zero.
avg_time_per_image = (total_time / len(results)) * 1000 if results else 0.0  # ms

print("\nInference complete!")
print(f"Total time: {elapsed:.2f}s ({elapsed/3600:.2f}h)")
print(f"Processed: {len(results)} images")
print(f"Avg time per image: {avg_time_per_image:.2f}ms")

# Memory stats
if torch.cuda.is_available():
    mem = memory_stats(device)
    print(f"Peak GPU memory: {mem['allocated_gb']:.2f} GB")

## Create Submission

In [None]:
# Write the submission file from the collected results.
print(f"Creating submission file: {OUTPUT_PATH}")
create_submission(results, OUTPUT_PATH)

# Read the file back so the summary reflects what was actually written.
df = pd.read_csv(OUTPUT_PATH)
print("\nSubmission created successfully!")
print(f"Total submissions: {len(df)}")
print(f"Columns: {list(df.columns)}")

# Split rows into those labelled 'authentic' and everything else.
n_authentic = df['annotation'].eq('authentic').sum()
n_forged = len(df) - n_authentic
print(f"\nAuthentic: {n_authentic} ({n_authentic/len(df)*100:.1f}%)")
print(f"Forged: {n_forged} ({n_forged/len(df)*100:.1f}%)")

# Show the first rows for a visual sanity check.
print("\nFirst 10 rows:")
display(df.head(10))

## Verify Submission Format

In [None]:
# Both columns must be present in the submission frame.
for required in ('case_id', 'annotation'):
    assert required in df.columns, f"Missing '{required}' column"

# No cell may be null.
assert df.isnull().sum().sum() == 0, "Submission contains null values"

# Every annotation is either the literal 'authentic' or a bracketed string.
for idx, annot in df['annotation'].items():
    if annot == 'authentic':
        continue
    assert annot.startswith('[') and annot.endswith(']'), \
        f"Invalid RLE format at row {idx}: {annot}"

print("✓ Submission format validation passed!")
print(f"\nSubmission file ready: {OUTPUT_PATH}")
print(f"Total runtime: {elapsed:.2f}s ({elapsed/3600:.2f}h)")