In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input mount and echo the full path of each file.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
#!/usr/bin/env python3
"""
Scientific Image Forgery Detection - Complete Working Solution
Detects and segments copy-move forgeries in biomedical images
"""

import numpy as np
import pandas as pd
import os
import json
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import random
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import cv2
from PIL import Image
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, every GPU generator is seeded as well and cuDNN
    is switched to deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Report the name and total memory of GPU 0 when running on CUDA.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Central bundle of dataset paths and hyper-parameters used by the pipeline."""

    # --- Dataset locations -------------------------------------------------
    BASE_PATH = Path('/kaggle/input/recodai-luc-scientific-image-forgery-detection')
    TRAIN_IMAGES_DIR = BASE_PATH.joinpath('train_images')
    TRAIN_MASKS_DIR = BASE_PATH.joinpath('train_masks')
    TEST_IMAGES_DIR = BASE_PATH.joinpath('test_images')
    SAMPLE_SUB_PATH = BASE_PATH.joinpath('sample_submission.csv')

    # --- Input size and data loading ---------------------------------------
    IMAGE_SIZE = 384  # images and masks are resized to IMAGE_SIZE x IMAGE_SIZE
    # Larger batches on GPU, tiny batches on CPU.
    if torch.cuda.is_available():
        BATCH_SIZE = 8
        VAL_BATCH_SIZE = 12
    else:
        BATCH_SIZE = 2
        VAL_BATCH_SIZE = 2
    NUM_WORKERS = 0

    # --- Optimisation ------------------------------------------------------
    EPOCHS = 12
    LEARNING_RATE = 2e-3
    WEIGHT_DECAY = 1e-5
    VALIDATION_SPLIT = 0.15  # fraction of each class held out for validation

    # --- Segmentation loss mix ---------------------------------------------
    DICE_WEIGHT = 0.6
    BCE_WEIGHT = 0.4

    # --- Inference thresholds ----------------------------------------------
    CLASSIFICATION_THRESHOLD = 0.35  # below this probability -> 'authentic'
    SEGMENTATION_THRESHOLD = 0.45    # per-pixel probability cut-off
    MIN_AREA = 100                   # masks with this many pixels or fewer are discarded

    # --- Test-time augmentation --------------------------------------------
    TTA_ENABLED = True

# ============================================================================
# DATA DISCOVERY
# ============================================================================

def discover_data():
    """Locate training images, their mask files and the test images on disk.

    Directories are taken from ``Config``. Missing ``authentic`` / ``forged`` /
    mask directories simply yield empty results. A summary is printed along
    the way.

    Returns:
        A 4-tuple ``(authentic_images, forged_images, mask_mapping, test_images)``:
        three sorted lists of ``Path`` objects plus a dict mapping a forged
        image stem to the list of ``.npy`` mask paths sharing that stem.
    """
    config = Config()
    # Matches suffixes such as .jpg / .png in any letter case (and anything
    # after them). NOTE(review): '.jpeg' is NOT matched by this pattern.
    image_glob = '*.[jpJP][npNP][gG]*'

    print("\n" + "="*70)
    print("DATA DISCOVERY")
    print("="*70)

    # Discover training images
    authentic_images = []
    forged_images = []

    authentic_dir = config.TRAIN_IMAGES_DIR / 'authentic'
    forged_dir = config.TRAIN_IMAGES_DIR / 'forged'

    if authentic_dir.exists():
        authentic_images = sorted(authentic_dir.glob(image_glob))
        print(f"\nAuthentic images found: {len(authentic_images)}")

    if forged_dir.exists():
        forged_images = sorted(forged_dir.glob(image_glob))
        print(f"Forged images found: {len(forged_images)}")

    # Discover mask files and attach each one to the forged image whose stem
    # it shares.
    mask_mapping = {}

    if config.TRAIN_MASKS_DIR.exists():
        mask_files = sorted(config.TRAIN_MASKS_DIR.glob('*.npy'))
        print(f"Mask files (.npy) found: {len(mask_files)}")

        # A set of stems makes each lookup O(1) instead of rescanning the
        # whole forged-image list for every mask file.
        forged_stems = {forged_img.stem for forged_img in forged_images}
        for mask_file in mask_files:
            if mask_file.stem in forged_stems:
                mask_mapping.setdefault(mask_file.stem, []).append(mask_file)

    print(f"Images with masks: {len(mask_mapping)}")

    # Discover test images
    test_images = sorted(config.TEST_IMAGES_DIR.glob(image_glob))
    print(f"Test images found: {len(test_images)}")

    print("\n" + "="*70)
    print(f"Total training images: {len(authentic_images) + len(forged_images)}")
    print(f"  - Authentic: {len(authentic_images)}")
    print(f"  - Forged: {len(forged_images)}")
    print("="*70)

    return authentic_images, forged_images, mask_mapping, test_images

# ============================================================================
# DATASET - FIXED VERSION
# ============================================================================

class ForgeryDataset(Dataset):
    """Image / mask / label dataset for copy-move forgery training.

    Each item is a tuple ``(image, mask, label)``:

    * ``image`` - float tensor ``(C, image_size, image_size)`` scaled to [0, 1]
    * ``mask``  - float tensor ``(1, image_size, image_size)`` with values 0/1
    * ``label`` - float tensor ``(1,)``; 1 for forged, 0 for authentic

    NOTE(review): a path listed in ``forged_paths`` that has no entry in
    ``mask_mapping`` is returned with an empty mask and label 0, exactly like
    an authentic image - confirm this is intended.
    """

    def __init__(self, authentic_paths: List[Path], forged_paths: List[Path],
                 mask_mapping: Dict, image_size: int = 384, augment: bool = False):
        """Store the sample lists and print a short summary.

        Args:
            authentic_paths: Paths of unmodified images.
            forged_paths: Paths of forged images.
            mask_mapping: Maps an image stem to a list of ``.npy`` mask paths.
            image_size: Side length that images and masks are resized to.
            augment: Whether to apply random augmentation in ``__getitem__``.
        """
        # Store authentic and forged separately
        self.authentic_paths = authentic_paths
        self.forged_paths = forged_paths
        self.all_paths = authentic_paths + forged_paths
        # Set copy so the per-sample "is this image forged?" check is O(1)
        # rather than a scan of the whole forged list.
        self._forged_lookup = set(forged_paths)
        self.mask_mapping = mask_mapping
        self.image_size = image_size
        self.augment = augment

        print(f"\nDataset initialized:")
        print(f"  Total images: {len(self.all_paths)}")
        print(f"  Authentic images: {len(self.authentic_paths)}")
        print(f"  Forged images: {len(self.forged_paths)}")

    def __len__(self):
        """Return the total number of images (authentic + forged)."""
        return len(self.all_paths)

    @staticmethod
    def _binarize_mask(raw_mask, height, width):
        """Turn a loaded mask array into a ``(height, width)`` uint8 0/1 map.

        Values greater than zero count as foreground. Arrays with more than
        two dimensions are collapsed with a logical OR: over axis 0 when a
        3-D array only fits the image as ``(N, height, width)``, otherwise
        over axis 2 (channel-last). A result whose size still differs from
        the image is resized with nearest-neighbour interpolation.
        """
        # Binarise first so OpenCV only ever sees uint8, whatever dtype
        # (bool, int64, float, ...) the file was saved with.
        foreground = np.asarray(raw_mask) > 0

        if foreground.ndim > 2:
            channel_first = (
                foreground.ndim == 3
                and foreground.shape[:2] != (height, width)
                and foreground.shape[1:] == (height, width)
            )
            foreground = foreground.any(axis=0 if channel_first else 2)

        binary = foreground.astype(np.uint8)
        if binary.shape[:2] != (height, width):
            binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
        return binary

    def augment_image(self, image, mask):
        """Apply random flips, a 90-degree-step rotation and brightness/contrast jitter.

        Geometric transforms are applied to both ``image`` and ``mask``; the
        brightness/contrast change touches the image only. Each transform is
        applied independently with probability 0.5.

        Returns:
            The augmented ``(image, mask)`` pair.
        """
        # Random horizontal flip
        if random.random() > 0.5:
            image = cv2.flip(image, 1)
            mask = cv2.flip(mask, 1)

        # Random vertical flip
        if random.random() > 0.5:
            image = cv2.flip(image, 0)
            mask = cv2.flip(mask, 0)

        # Random rotation (90, 180, 270)
        if random.random() > 0.5:
            k = random.choice([1, 2, 3])
            # np.rot90 returns a strided view; materialise it so downstream
            # OpenCV / torch calls receive an ordinary contiguous buffer.
            image = np.ascontiguousarray(np.rot90(image, k))
            mask = np.ascontiguousarray(np.rot90(mask, k))

        # Random brightness/contrast
        if random.random() > 0.5:
            alpha = random.uniform(0.8, 1.2)  # Contrast
            beta = random.uniform(-20, 20)     # Brightness
            image = np.clip(alpha * image + beta, 0, 255).astype(np.uint8)

        return image, mask

    def __getitem__(self, idx):
        """Load, optionally augment, resize and tensorise sample ``idx``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_path = self.all_paths[idx]

        # Load image; cv2.imread signals failure by returning None.
        img = cv2.imread(str(img_path))
        if img is None:
            raise OSError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        # Determine if this is an authentic or forged image
        is_forged = img_path in self._forged_lookup
        has_mask = img_path.stem in self.mask_mapping

        # Default: authentic image, empty mask.
        mask = np.zeros((orig_h, orig_w), dtype=np.uint8)
        label = 0

        if is_forged and has_mask:
            # Union of every mask file registered for this image.
            for mask_path in self.mask_mapping[img_path.stem]:
                single_mask = self._binarize_mask(np.load(str(mask_path)), orig_h, orig_w)
                mask = np.maximum(mask, single_mask)
            label = 1  # Forged

        # Apply augmentation
        if self.augment:
            img, mask = self.augment_image(img, mask)

        # Resize (nearest-neighbour keeps the mask strictly binary)
        img = cv2.resize(img, (self.image_size, self.image_size))
        mask = cv2.resize(mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)

        # Normalize image
        img = img.astype(np.float32) / 255.0

        # Convert to tensors
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float()
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).float()
        label_tensor = torch.tensor([label], dtype=torch.float32)

        return img_tensor, mask_tensor, label_tensor

# ============================================================================
# MODEL ARCHITECTURE
# ============================================================================

class ConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Padding of 1 keeps the spatial size unchanged; the convolutions carry no
    bias term because BatchNorm follows immediately.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps the width.
        for source_channels in (in_channels, out_channels):
            stages.append(nn.Conv2d(source_channels, out_channels,
                                    kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.conv(x)

class ForgeryDetectionModel(nn.Module):
    """U-Net style encoder/decoder with an extra image-level classifier.

    ``forward`` returns ``(seg_out, cls_out)``: per-pixel segmentation logits
    from the decoder and a single classification logit per image computed
    from the globally pooled bottleneck features.

    Input height and width should be multiples of 16 so that each upsampled
    map lines up with its skip connection after four 2x poolings.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()

        widths = (32, 64, 128, 256)

        # Encoder: enc1..enc4, each followed by a 2x max-pool (pool1..pool4).
        channels = in_channels
        for level, width in enumerate(widths, start=1):
            setattr(self, f'enc{level}', ConvBlock(channels, width))
            setattr(self, f'pool{level}', nn.MaxPool2d(2))
            channels = width

        # Bottleneck doubles the deepest encoder width (256 -> 512).
        bottleneck_width = channels * 2
        self.bottleneck = ConvBlock(channels, bottleneck_width)

        # Decoder: upconv4/dec4 down to upconv1/dec1. Every dec block takes
        # the upsampled map concatenated with the same-level encoder output,
        # hence 2 * width input channels.
        channels = bottleneck_width
        for level in range(len(widths), 0, -1):
            width = widths[level - 1]
            setattr(self, f'upconv{level}', nn.ConvTranspose2d(channels, width, 2, stride=2))
            setattr(self, f'dec{level}', ConvBlock(2 * width, width))
            channels = width

        # Segmentation head: 1x1 conv to num_classes logits per pixel.
        self.seg_head = nn.Conv2d(channels, num_classes, 1)

        # Classification head on the pooled bottleneck features.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(bottleneck_width, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Compute segmentation logits and the classification logit for ``x``."""
        encoder_stages = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
            (self.enc4, self.pool4),
        )
        decoder_stages = (
            (self.upconv4, self.dec4),
            (self.upconv3, self.dec3),
            (self.upconv2, self.dec2),
            (self.upconv1, self.dec1),
        )

        # Contracting path: remember every encoder output for the skips.
        skips = []
        features = x
        for encode, pool in encoder_stages:
            encoded = encode(features)
            skips.append(encoded)
            features = pool(encoded)

        bottleneck = self.bottleneck(features)

        # Expanding path: consume the skips deepest-first.
        features = bottleneck
        for upsample, decode in decoder_stages:
            merged = torch.cat([upsample(features), skips.pop()], dim=1)
            features = decode(merged)

        seg_out = self.seg_head(features)

        # Image-level logit from globally averaged bottleneck features.
        pooled = self.global_pool(bottleneck)
        cls_out = self.classifier(torch.flatten(pooled, 1))

        return seg_out, cls_out

# ============================================================================
# LOSS FUNCTIONS
# ============================================================================

class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Constant added to numerator and denominator so the ratio
            stays defined (and equals 1) when both prediction and target
            are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for raw logits ``pred`` against ``target``.

        Args:
            pred: Logits; a sigmoid is applied here.
            target: Tensor with the same number of elements as ``pred``.

        Returns:
            Scalar tensor holding the loss.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g.
        # transposed, permuted or sliced inputs.
        probs = torch.sigmoid(pred).reshape(-1)
        truth = target.reshape(-1)

        intersection = (probs * truth).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + truth.sum() + self.smooth)

        return 1 - dice

class CombinedLoss(nn.Module):
    """Sum of a weighted Dice+BCE segmentation loss and a BCE classification loss.

    Args:
        dice_weight: Weight of the Dice term inside the segmentation loss.
        bce_weight: Weight of the pixel-wise BCE term inside the segmentation loss.
    """

    def __init__(self, dice_weight=0.6, bce_weight=0.4):
        super().__init__()
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight

    def forward(self, seg_pred, cls_pred, seg_target, cls_target):
        """Compute the training objective from logits and targets.

        The segmentation terms are evaluated on every sample in the batch,
        authentic images (all-zero masks) included.

        Returns:
            ``(total_loss, dice_loss, cls_loss)``. Note that the second item
            is the raw Dice term only, not the weighted Dice+BCE mix that
            enters ``total_loss``.
        """
        dice_term = self.dice_loss(seg_pred, seg_target)
        pixel_bce = self.bce_loss(seg_pred, seg_target)
        cls_term = self.bce_loss(cls_pred, cls_target)

        segmentation = self.dice_weight * dice_term + self.bce_weight * pixel_bce
        return segmentation + cls_term, dice_term, cls_term

# ============================================================================
# TRAINING
# ============================================================================

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: forward, loss, backward, optimizer step and one
    scheduler step (the scheduler is advanced per batch, not per epoch).
    A tqdm bar shows the current total / segmentation / classification loss
    and learning rate.

    Returns:
        Mean total loss over the batches of this epoch.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc="Training")
    for batch in progress:
        images, masks, labels = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        seg_out, cls_out = model(images)
        loss, seg_loss, cls_loss = criterion(seg_out, cls_out, masks, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'seg': f'{seg_loss.item():.4f}',
            'cls': f'{cls_loss.item():.4f}',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })

    return running_loss / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Returns:
        Mean total loss over the batches of ``dataloader``.
    """
    model.eval()
    loss_sum = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            images, masks, labels = (tensor.to(device) for tensor in batch)
            seg_out, cls_out = model(images)
            # Only the total loss is needed here; the components are ignored.
            batch_loss, _dice, _cls = criterion(seg_out, cls_out, masks, labels)
            loss_sum += batch_loss.item()

    return loss_sum / len(dataloader)

# ============================================================================
# INFERENCE
# ============================================================================

def predict_with_tta(model, image, config):
    """Predict logits for one image, optionally averaging over flipped views.

    Args:
        model: Network returning ``(seg_logits, cls_logits)`` for a batch.
        image: Single unbatched image tensor; dimension 2 is flipped for the
            horizontal view and dimension 1 for the vertical view.
        config: Object exposing a boolean ``TTA_ENABLED`` attribute.

    Returns:
        ``(seg, cls)`` for the single image. With TTA enabled these are the
        means over the original, horizontally flipped and vertically flipped
        views, each segmentation map flipped back before averaging.
    """
    model.eval()

    # None stands for the untouched image; the lists are torch.flip dims.
    flip_dims = [None, [2], [1]] if config.TTA_ENABLED else [None]

    seg_votes = []
    cls_votes = []
    for dims in flip_dims:
        view = image if dims is None else torch.flip(image, dims)
        with torch.no_grad():
            seg_batch, cls_batch = model(view.unsqueeze(0))
            seg_map = seg_batch[0]
            if dims is not None:
                # Undo the flip so all maps share the original orientation.
                seg_map = torch.flip(seg_map, dims)
        seg_votes.append(seg_map)
        cls_votes.append(cls_batch[0])

    if len(seg_votes) == 1:
        return seg_votes[0], cls_votes[0]

    return torch.stack(seg_votes).mean(0), torch.stack(cls_votes).mean(0)

def rle_encode(mask):
    """Run-length encode the pixels of ``mask`` that equal 1.

    The mask is scanned in column-major order (transpose, then flatten) and
    positions are 1-based.

    Args:
        mask: Array-like binary mask; only entries equal to 1 are encoded.

    Returns:
        Flat list of plain Python ints ``[start1, length1, start2, length2, ...]``
        (empty when no pixel is set). Plain ints keep the result directly
        JSON-serialisable.
    """
    hits = np.flatnonzero(np.asarray(mask).T.flatten() == 1)
    if hits.size == 0:
        return []

    # A run ends wherever two consecutive foreground positions are not adjacent.
    gaps = np.flatnonzero(np.diff(hits) > 1)
    first = np.concatenate(([0], gaps + 1))          # index in `hits` of each run start
    last = np.concatenate((gaps, [hits.size - 1]))   # index in `hits` of each run end

    starts = hits[first] + 1                         # switch to 1-based positions
    lengths = hits[last] - hits[first] + 1

    # Interleave start/length pairs and convert NumPy scalars to Python ints.
    return np.column_stack((starts, lengths)).ravel().tolist()

# ============================================================================
# MAIN PIPELINE
# ============================================================================

def _shuffled_holdout(paths, fraction):
    """Shuffle ``paths`` in place and split off the first ``int(len * fraction)`` items.

    Returns:
        ``(train_paths, val_paths)``.
    """
    num_val = int(len(paths) * fraction)
    random.shuffle(paths)
    return paths[num_val:], paths[:num_val]


def _annotation_for_image(model, img_path, config):
    """Return the submission annotation for a single test image.

    The result is either the string ``'authentic'`` or a JSON list holding
    the run-length encoding of the predicted forgery mask at the image's
    original resolution. An image OpenCV cannot read is reported as
    ``'authentic'`` instead of aborting the whole submission.
    """
    img = cv2.imread(str(img_path))
    if img is None:  # cv2.imread signals failure by returning None
        return 'authentic'
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = img.shape[:2]

    # Preprocess exactly like the training data: resize, scale to [0, 1], CHW.
    img_resized = cv2.resize(img, (config.IMAGE_SIZE, config.IMAGE_SIZE))
    img_normalized = img_resized.astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_normalized).permute(2, 0, 1).float().to(device)

    seg_output, cls_output = predict_with_tta(model, img_tensor, config)

    cls_prob = torch.sigmoid(cls_output).cpu().item()
    if cls_prob < config.CLASSIFICATION_THRESHOLD:
        return 'authentic'

    # Threshold the probability map, then bring it back to the original size.
    seg_prob = torch.sigmoid(seg_output).cpu().numpy()[0]
    seg_mask = (seg_prob > config.SEGMENTATION_THRESHOLD).astype(np.uint8)
    seg_mask = cv2.resize(seg_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    # Morphological clean-up: close small holes, then remove small specks.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel)

    if seg_mask.sum() <= config.MIN_AREA:
        return 'authentic'

    run_lengths = rle_encode(seg_mask)
    if len(run_lengths) == 0:
        return 'authentic'
    return json.dumps([int(x) for x in run_lengths])


def main():
    """Train the forgery model, then write ``submission.csv`` for the test set.

    Steps: discover the data, make a per-class shuffled train/validation
    split, train for ``Config.EPOCHS`` epochs while checkpointing the weights
    with the lowest validation loss to ``best_model.pth``, and finally predict
    an annotation for every ``case_id`` in the sample submission.

    Returns:
        The submission ``DataFrame`` with columns ``case_id`` and ``annotation``.
    """
    config = Config()

    print("\n" + "="*70)
    print("SCIENTIFIC IMAGE FORGERY DETECTION")
    print("="*70)

    # Discover data
    authentic_images, forged_images, mask_mapping, test_images = discover_data()

    # Split each class separately so both appear in train and validation.
    train_authentic, val_authentic = _shuffled_holdout(authentic_images, config.VALIDATION_SPLIT)
    train_forged, val_forged = _shuffled_holdout(forged_images, config.VALIDATION_SPLIT)

    print(f"\nDataset split:")
    print(f"  Training: {len(train_authentic)} authentic + {len(train_forged)} forged = {len(train_authentic) + len(train_forged)}")
    print(f"  Validation: {len(val_authentic)} authentic + {len(val_forged)} forged = {len(val_authentic) + len(val_forged)}")

    # Create datasets
    train_dataset = ForgeryDataset(
        train_authentic, train_forged, mask_mapping,
        image_size=config.IMAGE_SIZE, augment=True
    )

    val_dataset = ForgeryDataset(
        val_authentic, val_forged, mask_mapping,
        image_size=config.IMAGE_SIZE, augment=False
    )

    # Pinned host memory only helps host-to-GPU copies, so request it on CUDA only.
    pin_memory = device.type == 'cuda'

    train_loader = DataLoader(
        train_dataset, batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset, batch_size=config.VAL_BATCH_SIZE,
        shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=pin_memory
    )

    # Initialize model
    print("\nInitializing model...")
    model = ForgeryDetectionModel().to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Loss and optimizer
    criterion = CombinedLoss(
        dice_weight=config.DICE_WEIGHT,
        bce_weight=config.BCE_WEIGHT
    )

    optimizer = AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )

    # The scheduler is stepped once per batch inside train_epoch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        epochs=config.EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.3
    )

    # Training loop
    print("\nStarting training...")
    checkpoint_path = 'best_model.pth'
    best_val_loss = float('inf')
    checkpoint_written = False

    for epoch in range(config.EPOCHS):
        print(f"\nEpoch {epoch + 1}/{config.EPOCHS}")

        train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
        val_loss = validate(model, val_loader, criterion, device)

        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_written = True
            print(f"  -> Best model saved (Val Loss: {val_loss:.4f})")

    # Restore the best weights, but only if this run actually wrote them:
    # a validation loss that is never below infinity (e.g. NaN) would
    # otherwise load a missing or stale file.
    if checkpoint_written:
        print("\nLoading best model for inference...")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("\nNo checkpoint was saved during training; using final weights for inference...")
    model.eval()

    # Generate predictions
    print("\nGenerating predictions...")

    sample_sub = pd.read_csv(config.SAMPLE_SUB_PATH)

    # Index test images by stem once (first occurrence wins) instead of
    # scanning the full list for every submission row.
    test_index = {}
    for img_path in test_images:
        test_index.setdefault(img_path.stem, img_path)

    results = []
    for raw_case_id in tqdm(sample_sub['case_id'], total=len(sample_sub), desc="Predicting"):
        case_id = str(raw_case_id)
        test_img_path = test_index.get(case_id)

        if test_img_path is not None and test_img_path.exists():
            annotation = _annotation_for_image(model, test_img_path, config)
        else:
            # No image for this case: fall back to the neutral answer.
            annotation = 'authentic'

        results.append({'case_id': int(case_id), 'annotation': annotation})

    # Explicit columns keep the CSV header (and the summary below) valid
    # even when there are no rows.
    submission_df = pd.DataFrame(results, columns=['case_id', 'annotation'])
    submission_df.to_csv('submission.csv', index=False)

    print("\n" + "="*70)
    print("SUBMISSION COMPLETE")
    print("="*70)
    print(f"Total predictions: {len(submission_df)}")
    print(f"Authentic: {(submission_df['annotation'] == 'authentic').sum()}")
    print(f"Forgeries: {(submission_df['annotation'] != 'authentic').sum()}")

    return submission_df

# Run the full train-and-predict pipeline only when this file is executed as
# a script (not when it is imported); the returned DataFrame is kept in
# `submission`.
if __name__ == "__main__":
    submission = main()