# Scientific Image Forgery Detection - Copy-Move Forgery Detection

This notebook implements exploratory data analysis (EDA), U-Net model training, and test-set inference (submission generation) for detecting copy-move forgeries in biomedical images.


In [None]:
# Install required packages
# pip install torch torchvision torchaudio
# pip install numpy pandas matplotlib seaborn opencv-python tqdm jupyter ipykernel
# pip install pillow scikit-image

### Required Packages:
- `torch` - PyTorch deep learning framework
- `torchvision` - Computer vision utilities for PyTorch
- `numpy` - Numerical computing
- `pandas` - Data manipulation and analysis
- `matplotlib` - Plotting and visualization
- `seaborn` - Statistical data visualization
- `opencv-python` - Computer vision library
- `tqdm` - Progress bars
- `jupyter` - Jupyter notebook environment
- `ipykernel` - IPython kernel for Jupyter
- `Pillow` - Image processing
- `scikit-image` - Image processing algorithms

**Note:** For GPU support, install PyTorch with CUDA. Visit [pytorch.org](https://pytorch.org/get-started/locally/) for the correct installation command for your system.


In [None]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import cv2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Deep learning libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

# Seed NumPy and PyTorch so shuffles, augmentations and weight init repeat across runs.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")


## 1. Exploratory Data Analysis (EDA)


In [None]:
# Define data paths for Kaggle
# Input data is in /kaggle/input/ - adjust the dataset name as needed
# Example: if your dataset is named 'copy-move-forgery-detection', use:
# input_dir = Path('/kaggle/input/copy-move-forgery-detection')
# Or if data is directly in input, use:
input_dir = Path('/kaggle/input/recodai-luc-scientific-image-forgery-detection')
base_dir = input_dir  # Change this to your specific dataset path if needed

train_authentic_dir = base_dir / 'train_images' / 'authentic'
train_forged_dir = base_dir / 'train_images' / 'forged'
train_masks_dir = base_dir / 'train_masks'
test_images_dir = base_dir / 'test_images'

# Working directory for outputs (models, submissions, etc.)
working_dir = Path('/kaggle/working')
model_path = working_dir / 'best_model.pth'
submission_path = working_dir / 'submission.csv'

# Count files
def count_files(directory, extension='.png'):
    if not directory.exists():
        return 0
    return len(list(directory.glob(f'*{extension}')))

# Tally each split: images are .png, ground-truth masks are .npy.
authentic_count, forged_count, mask_count, test_count = (
    count_files(folder, suffix)
    for folder, suffix in (
        (train_authentic_dir, '.png'),
        (train_forged_dir, '.png'),
        (train_masks_dir, '.npy'),
        (test_images_dir, '.png'),
    )
)

print(f"\nFile counts:")
print(f"Authentic images: {authentic_count}")
print(f"Forged images: {forged_count}")
print(f"Mask files: {mask_count}")
print(f"Test images: {test_count}")


In [None]:
# Load sample images and analyze
def load_sample_images(authentic_dir, forged_dir, masks_dir, n_samples=5):
    """Load up to `n_samples` authentic and forged images (as RGB) plus forged masks.

    Files are taken in sorted name order so the same samples appear on every
    run (raw glob order is filesystem-dependent). A forged image's mask is
    looked up as <masks_dir>/<stem>.npy, the same naming create_dataset uses.

    Returns:
        (authentic_images, forged_images, masks): lists of (file name, array)
        tuples. Masks are keyed by the forged image's file name. Unreadable
        images are skipped, and so are masks whose file does not exist.
    """
    def _read_rgb(path):
        # cv2.imread returns None (rather than raising) for unreadable files.
        bgr = cv2.imread(str(path))
        return None if bgr is None else cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    authentic_images = []
    for img_path in sorted(authentic_dir.glob('*.png'))[:n_samples]:
        img = _read_rgb(img_path)
        if img is not None:
            authentic_images.append((img_path.name, img))

    forged_images = []
    masks = []
    for img_path in sorted(forged_dir.glob('*.png'))[:n_samples]:
        img = _read_rgb(img_path)
        if img is None:
            continue
        forged_images.append((img_path.name, img))

        # Use the stem rather than str.replace, which would also rewrite a
        # '.png' appearing inside the file name itself.
        mask_path = masks_dir / f"{img_path.stem}.npy"
        if mask_path.exists():
            masks.append((img_path.name, np.load(mask_path)))

    return authentic_images, forged_images, masks

# Grab a few examples per class for the visual sanity checks that follow.
if all(folder.exists() for folder in (train_authentic_dir, train_forged_dir)):
    (authentic_samples,
     forged_samples,
     mask_samples) = load_sample_images(
        train_authentic_dir,
        train_forged_dir,
        train_masks_dir,
        n_samples=3,
    )
    print(f"Loaded {len(authentic_samples)} authentic samples")
    print(f"Loaded {len(forged_samples)} forged samples")
    print(f"Loaded {len(mask_samples)} mask samples")


In [None]:
# Analyze image statistics
def analyze_image_statistics(image_dir, label='', max_samples=100):
    """Summarize image heights and widths for the .png files in `image_dir`.

    Only the first `max_samples` files (glob order) are decoded, to keep EDA fast.

    Args:
        image_dir: directory to scan (pathlib.Path).
        label: name shown in the progress bar and stored under 'label'.
        max_samples: number of files to decode; None decodes every file.

    Returns:
        dict with 'label', 'count' (all .png files found), 'sampled' (files
        actually decoded and measured), and mean/std/min/max of heights and
        widths. None if no file is found or none could be decoded.
    """
    image_files = list(image_dir.glob('*.png'))
    if not image_files:
        return None

    to_measure = image_files if max_samples is None else image_files[:max_samples]
    heights = []
    widths = []
    for img_path in tqdm(to_measure, desc=f"Analyzing {label}"):
        img = cv2.imread(str(img_path))
        if img is None:
            continue  # unreadable/corrupt file: skip rather than abort the scan
        h, w = img.shape[:2]
        heights.append(h)
        widths.append(w)

    if not heights:
        return None

    return {
        'label': label,
        'count': len(image_files),
        # The size statistics below describe only this many files, not 'count'.
        'sampled': len(heights),
        'height_mean': np.mean(heights),
        'height_std': np.std(heights),
        'width_mean': np.mean(widths),
        'width_std': np.std(widths),
        'height_min': np.min(heights),
        'height_max': np.max(heights),
        'width_min': np.min(widths),
        'width_max': np.max(widths),
    }

# Report size statistics for each training class whose directory exists.
if train_authentic_dir.exists():
    authentic_stats = analyze_image_statistics(train_authentic_dir, 'Authentic')
    if authentic_stats:
        print("\nAuthentic Images Statistics:")
        print("\n".join(f"  {key}: {value}" for key, value in authentic_stats.items()))

if train_forged_dir.exists():
    forged_stats = analyze_image_statistics(train_forged_dir, 'Forged')
    if forged_stats:
        print("\nForged Images Statistics:")
        print("\n".join(f"  {key}: {value}" for key, value in forged_stats.items()))


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _to_rgb(img):
    """Ensure image is displayable by plt.imshow (H,W) or (H,W,3/4)."""
    arr = np.asarray(img)
    # If CHW, move to HWC
    if arr.ndim == 3 and arr.shape[0] in (1,3,4) and arr.shape[-1] not in (3,4):
        arr = np.moveaxis(arr, 0, -1)  # CHW -> HWC
    return arr

def _to_2d_mask(mask):
    """Make mask 2D (H,W) and normalized to [0,1]."""
    m = np.asarray(mask)

    # If CHW format with multiple channels -> take first channel or combine
    if m.ndim == 3 and m.shape[0] in (1, 2, 3, 4) and m.shape[-1] not in (1, 2, 3, 4):
        # (C, H, W) format where C is small
        if m.shape[0] == 1:
            m = np.squeeze(m, axis=0)  # (1,H,W) -> (H,W)
        elif m.shape[0] == 2:
            # For 2-channel masks, take the maximum (union) or first channel
            # Using maximum to combine both channels
            m = np.max(m, axis=0)  # (2,H,W) -> (H,W)
        else:
            # For 3+ channels, take first channel
            m = m[0]  # (C,H,W) -> (H,W)

    # If HWC with a single channel -> squeeze last
    elif m.ndim == 3 and m.shape[-1] == 1:
        m = np.squeeze(m, axis=-1)  # (H,W,1) -> (H,W)

    # If still 3D (e.g., RGB mask), convert to grayscale
    elif m.ndim == 3 and m.shape[-1] in (3,4):
        m = m[..., :3].mean(axis=-1)

    # Now expect 2D
    if m.ndim != 2:
        raise ValueError(f"Mask must be 2D after processing, got shape {m.shape}")

    # Normalize to [0,1]
    m = m.astype(np.float32)
    m_min, m_max = np.min(m), np.max(m)
    if m_max > m_min:
        m = (m - m_min) / (m_max - m_min)
    else:
        m = np.zeros_like(m, dtype=np.float32)

    return m

def visualize_samples(authentic_samples, forged_samples, mask_samples):
    """Plot up to three examples per row in a 3x3 grid.

    Row 0 shows authentic images and row 1 forged images. Row 2 shows each
    mask (collapsed to 2-D and, if its size differs, resized to the image by
    nearest-neighbour index sampling) overlaid at alpha 0.5 on the forged image
    with the same file name. Cells without data stay blank.
    """
    per_row = 3
    forged_by_name = dict(forged_samples)

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    # Hide every axis frame up front; populated cells keep it hidden too.
    for ax in axes.flat:
        ax.axis('off')

    for ax, (name, img) in zip(axes[0], authentic_samples[:per_row]):
        ax.imshow(_to_rgb(img))
        ax.set_title(f'Authentic: {name}', fontsize=10)

    for ax, (name, img) in zip(axes[1], forged_samples[:per_row]):
        ax.imshow(_to_rgb(img))
        ax.set_title(f'Forged: {name}', fontsize=10)

    for ax, (name, mask) in zip(axes[2], mask_samples[:per_row]):
        forged_img = forged_by_name.get(name)
        if forged_img is None:
            ax.set_title(f'No forged image for: {name}', fontsize=10)
            continue

        rgb = _to_rgb(forged_img)
        overlay = _to_2d_mask(mask)  # (H, W) float in [0, 1]

        img_h, img_w = rgb.shape[:2]
        if overlay.shape[:2] != (img_h, img_w):
            # Nearest-neighbour resize via index sampling (no extra dependency).
            src_h, src_w = overlay.shape
            rows = np.linspace(0, src_h - 1, img_h).astype(int)
            cols = np.linspace(0, src_w - 1, img_w).astype(int)
            overlay = overlay[rows][:, cols]

        ax.imshow(rgb)
        ax.imshow(overlay, alpha=0.5, cmap='Reds')
        ax.set_title(f'Forged + Mask: {name}', fontsize=10)

    plt.tight_layout()
    plt.show()


In [None]:
# Inspect raw array layouts, then visualize sample images and masks.
# Everything is inside the guard: when the sampling cell was skipped (dataset
# missing) or returned nothing, this cell no longer raises NameError/IndexError.
if 'authentic_samples' in locals() and 'forged_samples' in locals() and 'mask_samples' in locals():
    if authentic_samples:
        print("auth sample shape:", np.asarray(authentic_samples[0][1]).shape)
    if forged_samples:
        print("forged sample shape:", np.asarray(forged_samples[0][1]).shape)
    if mask_samples:
        print("mask sample shape:", np.asarray(mask_samples[0][1]).shape)
    visualize_samples(authentic_samples, forged_samples, mask_samples)

## 2. Data Loading and Preprocessing


In [None]:
# Data augmentation functions
def augment_image(image, mask=None):
    """Randomly flip, rotate and brighten `image`, applying the same geometry to `mask`.

    Random draws come from NumPy's global RNG in a fixed order: horizontal
    flip, vertical flip, rotation angle, brightness, so np.random.seed makes
    them reproducible.
      * each flip happens with probability 0.5
      * a rotation in [-15, 15] degrees about the centre is always applied
        (bilinear for the image, nearest-neighbour for the mask, reflected borders)
      * with probability 0.5 the image (never the mask) is scaled by a factor
        in [0.8, 1.2] and clipped to [0, 255] as uint8

    Returns:
        (image, mask); mask stays None when none was given.
    """
    has_mask = mask is not None

    # flipCode 1 = around the vertical axis (horizontal flip), 0 = vertical flip.
    for flip_code in (1, 0):
        if np.random.random() > 0.5:
            image = cv2.flip(image, flip_code)
            if has_mask:
                mask = cv2.flip(mask, flip_code)

    height, width = image.shape[:2]
    rotation = cv2.getRotationMatrix2D((width // 2, height // 2), np.random.uniform(-15, 15), 1.0)
    image = cv2.warpAffine(image, rotation, (width, height),
                           flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    if has_mask:
        # Nearest-neighbour keeps mask values from being blended.
        mask = cv2.warpAffine(mask, rotation, (width, height),
                              flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT)

    if np.random.random() > 0.5:
        factor = np.random.uniform(0.8, 1.2)
        image = np.clip(image * factor, 0, 255).astype(np.uint8)

    return image, mask

def preprocess_image(image, target_size=(512, 512)):
    """Resize a channel-last image and return it as float32, channel-first, divided by 255.

    `target_size` is handed to cv2.resize, so it is (width, height). The /255
    scaling presumes 8-bit input, which yields values in [0, 1].
    """
    resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    scaled = resized.astype(np.float32) / 255.0
    return scaled.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)

def preprocess_mask(mask, target_size=(512, 512)):
    """Resize a 2-D mask with nearest-neighbour sampling and binarize it at 0.5.

    `target_size` is (width, height), following cv2.resize. Returns float32 0/1.
    """
    resized = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
    return np.where(resized > 0.5, 1.0, 0.0).astype(np.float32)


In [None]:
# Dataset index builder (plain list of dicts; loading happens per item later)
def create_dataset(image_dir, mask_dir=None, is_train=True, target_size=(512, 512)):
    """Index the .png files in `image_dir` as a list of record dicts.

    Each record holds:
      * 'image_path': str path of the image
      * 'image_id':   file stem
      * 'mask_path':  str path of <mask_dir>/<stem>.npy when `mask_dir` is given
                      and that file exists, otherwise None

    `is_train` and `target_size` are accepted but not used here; records are
    only paths, and decoding/augmentation happen in load_image_and_mask.
    """
    def _mask_for(stem):
        if mask_dir is None:
            return None
        candidate = mask_dir / f"{stem}.npy"
        return str(candidate) if candidate.exists() else None

    return [
        {
            'image_path': str(path),
            'image_id': path.stem,
            'mask_path': _mask_for(path.stem),
        }
        for path in image_dir.glob('*.png')
    ]

def _collapse_mask_channels(mask):
    """Reduce a 3-D mask to 2-D: squeeze a singleton axis, else take the per-pixel max over channels."""
    if mask.ndim != 3:
        return mask
    if mask.shape[0] == 1:
        return mask.squeeze(0)
    if mask.shape[-1] == 1:
        return mask.squeeze(-1)
    # NOTE(review): a leading axis smaller than the last is read as (C, H, W),
    # presumably one layer per forged region; otherwise as (H, W, C). The union
    # keeps every marked region (taking only channel 0 would drop the others).
    if mask.shape[0] < mask.shape[-1]:
        return mask.max(axis=0)
    return mask.max(axis=-1)


def load_image_and_mask(item, is_train=True, target_size=(512, 512)):
    """Load one dataset record and return model-ready (image, mask) arrays.

    Args:
        item: record from create_dataset with 'image_path' and 'mask_path'.
        is_train: when True, apply random augmentation. Authentic and forged
            images are augmented alike, so augmentation artifacts (e.g.
            reflected rotation borders) cannot act as a label shortcut.
        target_size: (width, height), as passed to cv2.resize.

    Returns:
        image: float32 array (3, height, width) scaled to [0, 1].
        mask: float32 binary array (height, width); all zeros when the record
            has no mask (authentic images).

    Raises:
        ValueError: if the image file cannot be read.
    """
    image = cv2.imread(item['image_path'])
    if image is None:
        raise ValueError(f"Could not load image: {item['image_path']}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = None
    mask_path = item.get('mask_path')
    if mask_path is not None and os.path.exists(mask_path):
        mask = _collapse_mask_channels(np.load(mask_path)).astype(np.float32)

    if is_train:
        image, mask = augment_image(image, mask)

    image = preprocess_image(image, target_size)
    if mask is not None:
        mask = preprocess_mask(mask, target_size)
    else:
        # cv2 sizes are (width, height); numpy shapes are (rows, cols).
        width, height = target_size
        mask = np.zeros((height, width), dtype=np.float32)

    return image, mask

# Index both classes (only forged images have masks) and pool them together.
if train_authentic_dir.exists() and train_forged_dir.exists():
    authentic_dataset = create_dataset(train_authentic_dir, mask_dir=None, is_train=True)
    forged_dataset = create_dataset(train_forged_dir, mask_dir=train_masks_dir, is_train=True)

    print(f"Authentic dataset size: {len(authentic_dataset)}")
    print(f"Forged dataset size: {len(forged_dataset)}")

    full_dataset = [*authentic_dataset, *forged_dataset]
    print(f"Total dataset size: {len(full_dataset)}")


## 3. Model Architecture


In [None]:
# U-Net building blocks (layers kept in a plain dict, wired up by forward_unet)
def double_conv(in_channels, out_channels):
    """Two stages of 3x3 conv (padding=1) -> BatchNorm -> ReLU; spatial size is preserved."""
    layers = []
    for stage_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)

def create_unet_model(in_channels=3, n_classes=1):
    """Build the layers of a 4-level U-Net and return them as a name -> module dict.

    Encoder widths are 64/128/256/512 with a 1024-channel bottleneck. Each
    decoder level upsamples with a stride-2 transposed conv, and its
    double_conv takes 2*w channels because forward_unet concatenates the
    matching encoder output. Layers are created in the order enc1..enc4,
    bottleneck, up4/dec4 .. up1/dec1, final_conv; that order fixes both the
    dict order and the sequence of random weight initialisation.
    """
    widths = (64, 128, 256, 512)
    layers = {}

    prev_width = in_channels
    for level, width in enumerate(widths, start=1):
        layers[f'enc{level}'] = double_conv(prev_width, width)
        prev_width = width

    layers['bottleneck'] = double_conv(widths[-1], 2 * widths[-1])

    for level in range(len(widths), 0, -1):
        width = widths[level - 1]
        layers[f'up{level}'] = nn.ConvTranspose2d(2 * width, width, 2, stride=2)
        layers[f'dec{level}'] = double_conv(2 * width, width)

    # 1x1 projection to per-pixel class logits.
    layers['final_conv'] = nn.Conv2d(widths[0], n_classes, 1)
    return layers

def forward_unet(model_dict, x):
    """Run `x` through the dict-based U-Net and return per-pixel sigmoid probabilities.

    Four encoder levels each save a skip tensor and then 2x max-pool. After
    the bottleneck, every decoder level upsamples, concatenates the matching
    skip on the channel axis, and applies its double conv. With four poolings,
    H and W must be divisible by 16 for the concatenations to line up.
    """
    skips = []
    h = x
    for level in (1, 2, 3, 4):
        h = model_dict[f'enc{level}'](h)
        skips.append(h)
        h = nn.functional.max_pool2d(h, 2)

    h = model_dict['bottleneck'](h)

    for level, skip in zip((4, 3, 2, 1), reversed(skips)):
        h = model_dict[f'up{level}'](h)
        h = torch.cat([h, skip], dim=1)
        h = model_dict[f'dec{level}'](h)

    return torch.sigmoid(model_dict['final_conv'](h))

# nn.Module wrapper so the dict-based U-Net works with optimizers, .to() and state_dict
class UNet(nn.Module):
    """U-Net whose layers come from create_unet_model and whose forward is forward_unet.

    Each layer is registered as a named attribute, which makes its parameters
    visible to nn.Module. The same objects are also kept in `model_dict` for
    forward_unet.
    """

    def __init__(self, in_channels=3, n_classes=1):
        super().__init__()
        layers = create_unet_model(in_channels, n_classes)
        for layer_name, layer in layers.items():
            setattr(self, layer_name, layer)
        self.model_dict = layers

    def forward(self, x):
        return forward_unet(self.model_dict, x)

# Instantiate the U-Net and move its parameters to the selected device.
model = UNet(in_channels=3, n_classes=1).to(device)
print(f"Model created and moved to {device}")
print(f"Total parameters: {sum(param.numel() for param in model.parameters()):,}")


## 4. Dataset and Loss Functions


In [None]:
# torch Dataset adapter over create_dataset records
class ForgeryDataset(Dataset):
    """Yield (image, mask) float tensors for records built by create_dataset.

    image: (3, H, W) from load_image_and_mask. mask: (1, H, W); the channel
    axis is added so the mask matches the model's single-channel output.
    """

    def __init__(self, dataset, is_train=True, target_size=(512, 512)):
        self.dataset = dataset
        self.is_train = is_train
        self.target_size = target_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, mask = load_image_and_mask(self.dataset[idx], self.is_train, self.target_size)
        return torch.from_numpy(image).float(), torch.from_numpy(mask).float()[None]

# Loss function - Dice Loss + BCE Loss
def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss over every element of the batch.

    loss = 1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)

    `pred` is expected in [0, 1] (probabilities) and `target` to hold 0/1
    values with the same number of elements. `smooth` avoids 0/0, so two empty
    tensors give a loss of 0. reshape (not view) is used so non-contiguous
    tensors, e.g. transposed or sliced ones, are accepted as well.
    """
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return 1 - dice

def combined_loss(pred, target, bce_weight=0.5):
    """Weighted sum of binary cross-entropy and Dice loss.

    Returns bce_weight * BCE + (1 - bce_weight) * Dice. `pred` must already be
    probabilities (forward_unet ends with a sigmoid), which is why the
    probability-space BCE is used rather than the with-logits variant.
    """
    dice_weight = 1 - bce_weight
    return (bce_weight * nn.functional.binary_cross_entropy(pred, target)
            + dice_weight * dice_loss(pred, target))

# Metrics
def calculate_iou(pred, target, threshold=0.5):
    """Intersection-over-Union of `pred > threshold` against `target`, over all elements.

    Returns a Python float. The result is 1.0 when both are empty (nothing
    predicted and nothing to find).
    """
    hits = (pred > threshold).float()
    truth = target.float()
    overlap = (hits * truth).sum()
    union = hits.sum() + truth.sum() - overlap
    return 1.0 if union == 0 else (overlap / union).item()

def calculate_dice_score(pred, target, threshold=0.5):
    """Dice coefficient of `pred > threshold` against `target`, over all elements.

    Returns a Python float. When both the prediction and the target are empty
    the score is 1.0, matching calculate_iou's convention. Without this, a
    correctly predicted authentic (empty-mask) batch would score ~0 and drag
    the averaged Dice down.
    """
    pred_binary = (pred > threshold).float()
    target_binary = target.float()

    total = pred_binary.sum() + target_binary.sum()
    if total == 0:
        return 1.0

    intersection = (pred_binary * target_binary).sum()
    dice = (2. * intersection) / (total + 1e-6)
    return dice.item()


In [None]:
# Split dataset into train and validation
def split_dataset(dataset, val_ratio=0.2):
    """Randomly split `dataset` into (train, val) lists without modifying it.

    A shuffled copy is split so the first int(len * (1 - val_ratio)) items
    form the training set. The shuffle uses NumPy's global RNG, so
    np.random.seed makes the split reproducible.

    Raises:
        ValueError: if val_ratio is outside [0, 1].
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    shuffled = list(dataset)  # never reorder the caller's list in place
    np.random.shuffle(shuffled)
    split_idx = int(len(shuffled) * (1 - val_ratio))
    return shuffled[:split_idx], shuffled[split_idx:]

# Split the pooled records and wrap both halves in DataLoaders (only the training half is augmented).
if 'full_dataset' in locals() and len(full_dataset) > 0:
    train_data, val_data = split_dataset(full_dataset.copy(), val_ratio=0.2)

    train_dataset = ForgeryDataset(train_data, is_train=True, target_size=(512, 512))
    val_dataset = ForgeryDataset(val_data, is_train=False, target_size=(512, 512))

    # Smaller batches when not running on a GPU.
    batch_size = 4 if device.type != 'cuda' else 8
    train_loader, val_loader = (
        DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0)
        for ds, shuffle in ((train_dataset, True), (val_dataset, False))
    )

    print(f"Train samples: {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")
    print(f"Batch size: {batch_size}")


## 5. Training Loop


In [None]:
# Training function
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one training pass over `train_loader`.

    For each batch: forward pass, `criterion` loss, backward pass, optimizer
    step. IoU and Dice are computed from that batch's outputs (obtained
    before the step).

    Returns:
        (mean loss, mean IoU, mean Dice), each a plain average over batches.
    """
    model.train()
    loss_sum = iou_sum = dice_sum = 0.0
    batches = 0

    for batches, (images, masks) in enumerate(tqdm(train_loader, desc="Training"), start=1):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        with torch.no_grad():  # metrics need no autograd graph
            batch_iou = calculate_iou(outputs, masks)
            batch_dice = calculate_dice_score(outputs, masks)

        loss_sum += loss.item()
        iou_sum += batch_iou
        dice_sum += batch_dice

    return loss_sum / batches, iou_sum / batches, dice_sum / batches

# Validation function
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate `model` on `val_loader` in eval mode without gradient tracking.

    Returns:
        (mean loss, mean IoU, mean Dice), each a plain average over batches.
    """
    model.eval()
    loss_sum = iou_sum = dice_sum = 0.0
    batches = 0

    with torch.no_grad():
        for batches, (images, masks) in enumerate(tqdm(val_loader, desc="Validating"), start=1):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, masks)
            batch_iou = calculate_iou(outputs, masks)
            batch_dice = calculate_dice_score(outputs, masks)

            loss_sum += batch_loss.item()
            iou_sum += batch_iou
            dice_sum += batch_dice

    return loss_sum / batches, iou_sum / batches, dice_sum / batches


In [None]:
# Main training loop
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-4):
    """Train `model` with Adam on combined_loss for `num_epochs` epochs.

    After each epoch, validation loss drives a ReduceLROnPlateau scheduler
    (factor 0.5, patience 3). The weights are written to the global
    `model_path` whenever validation loss reaches a new minimum.

    Returns:
        dict of per-epoch lists under 'train_loss', 'train_iou', 'train_dice',
        'val_loss', 'val_iou', 'val_dice'.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    criterion = combined_loss

    history = {
        key: []
        for key in ('train_loss', 'train_iou', 'train_dice', 'val_loss', 'val_iou', 'val_dice')
    }
    best_val_loss = float('inf')
    lr_now = learning_rate

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Device: {device}")
    print(f"Initial learning rate: {learning_rate}")
    print("-" * 50)

    for epoch_no in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch_no}/{num_epochs}")

        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device)
        val_metrics = validate_epoch(model, val_loader, criterion, device)
        train_loss, train_iou, train_dice = train_metrics
        val_loss, val_iou, val_dice = val_metrics

        # The scheduler may lower the LR in place; report when it does.
        previous_lr = lr_now
        scheduler.step(val_loss)
        lr_now = optimizer.param_groups[0]['lr']
        if lr_now < previous_lr:
            print(f"Learning rate reduced from {previous_lr:.2e} to {lr_now:.2e}")

        # history keys are ordered train_* then val_*, matching the metric tuples.
        for key, value in zip(history, (*train_metrics, *val_metrics)):
            history[key].append(value)

        print(f"Train Loss: {train_loss:.4f}, Train IoU: {train_iou:.4f}, Train Dice: {train_dice:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}, Val Dice: {val_dice:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), str(model_path))
            print(f"Saved best model (Val Loss: {val_loss:.4f})")

    return history

# Start training (requires the DataLoader cell to have run).
if 'train_loader' not in locals() or 'val_loader' not in locals():
    print("Data loaders not available. Please run the previous cells first.")
else:
    print("\n" + "=" * 50)
    print("Starting Model Training")
    print("=" * 50)
    history = train_model(model, train_loader, val_loader, num_epochs=20, learning_rate=1e-4)
    print("\nTraining completed!")


## 6. Visualization and Evaluation


In [None]:
# Plot training history
def plot_training_history(history):
    """Draw train-vs-val curves for loss, IoU and Dice side by side (x axis = epoch index)."""
    panels = (
        # (train key, train label, val key, val label, title, y label)
        ('train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss', 'Loss'),
        ('train_iou', 'Train IoU', 'val_iou', 'Val IoU', 'IoU', 'IoU'),
        ('train_dice', 'Train Dice', 'val_dice', 'Val Dice', 'Dice Score', 'Dice'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (train_key, train_label, val_key, val_label, title, ylabel) in zip(axes, panels):
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# Plot the curves once a training run has produced `history`.
if 'history' not in locals():
    print("Training history not available. Please train the model first.")
else:
    plot_training_history(history)


In [None]:
# Visualize predictions
def visualize_predictions(model, val_loader, device, num_samples=5):
    """Plot predictions for the first image of each of the first `num_samples` batches.

    Each row shows the input image, the ground-truth mask, the predicted
    probability map ('hot' colormap) and the prediction thresholded at 0.5.
    Rows left over when the loader has fewer batches than `num_samples` are
    blank.

    Raises:
        ValueError: if num_samples < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    model.eval()

    # squeeze=False keeps `axes` 2-D, so axes[idx, col] also works for num_samples == 1.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    with torch.no_grad():
        for idx, (images, masks) in enumerate(val_loader):
            if idx >= num_samples:
                break

            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            pred_masks = (outputs > 0.5).float()

            # First item of the batch, as HWC / HW numpy arrays for imshow.
            img = images[0].cpu().numpy().transpose(1, 2, 0)
            true_mask = masks[0, 0].cpu().numpy()
            pred_mask = pred_masks[0, 0].cpu().numpy()
            prob_mask = outputs[0, 0].cpu().numpy()

            panels = (
                (img, None, 'Original Image'),
                (true_mask, 'gray', 'True Mask'),
                (prob_mask, 'hot', 'Predicted Probability'),
                (pred_mask, 'gray', 'Predicted Mask'),
            )
            for col, (data, cmap, title) in enumerate(panels):
                axes[idx, col].imshow(data, cmap=cmap)
                axes[idx, col].set_title(title)

    plt.tight_layout()
    plt.show()

# Visualize predictions on validation set
if 'val_loader' in locals() and 'model' in locals():
    # Prefer the best checkpoint from training. map_location lets the load work
    # even when the checkpoint was saved on a different device (as the other
    # load sites in this notebook already do).
    if os.path.exists(str(model_path)):
        model.load_state_dict(torch.load(str(model_path), map_location=device))
        print("Loaded best model for visualization")
    visualize_predictions(model, val_loader, device, num_samples=5)
else:
    print("Model or validation loader not available.")


## 7. Model Evaluation and Inference


In [None]:
# Run-Length Encoding (RLE) for mask encoding
def rle_encode(mask):
    """Encode a binary mask as a space-separated 'start length start length ...' string.

    Pixels are scanned in row-major order and starts are 1-based. A mask with
    no foreground encodes to ''. rle_decode is the inverse.
    """
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # 1-based pixel positions where the value changes: alternating run starts and run ends.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts = boundaries[0::2]
    lengths = boundaries[1::2] - starts
    return ' '.join(f"{start} {length}" for start, length in zip(starts, lengths))

def rle_decode(rle_str, shape):
    """Expand a 1-based 'start length ...' RLE string into a uint8 0/1 mask of `shape`.

    Runs are laid out in row-major order over shape[0] * shape[1] pixels.
    An empty string yields an all-zero mask.
    """
    tokens = np.asarray(rle_str.split(), dtype=int)
    starts = tokens[0::2] - 1  # to 0-based
    lengths = tokens[1::2]

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, starts + lengths):
        flat[begin:stop] = 1
    return flat.reshape(shape)

# Inference function for test images
def predict_test_image(model, image_path, device, target_size=(512, 512), threshold=0.5):
    """Run the model on one image file and return its mask at the image's own resolution.

    Returns:
        (probabilities, binary): the model's probability map resized back to
        the input's (H, W) with bilinear interpolation, and its uint8 0/1
        version (pixels strictly above `threshold` are 1).

    Raises:
        ValueError: if the file cannot be read.
    """
    model.eval()

    bgr = cv2.imread(str(image_path))
    if bgr is None:
        raise ValueError(f"Could not load image: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = rgb.shape[:2]

    batch = torch.from_numpy(preprocess_image(rgb, target_size)).float()[None].to(device)
    with torch.no_grad():
        prob = model(batch)[0, 0].cpu().numpy()

    # cv2.resize takes (width, height).
    prob_full = cv2.resize(prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    return prob_full, (prob_full > threshold).astype(np.uint8)

def process_test_set(model, test_dir, device, output_file='submission.csv', threshold=0.5,
                     min_forgery_ratio=0.01):
    """Predict every .png in `test_dir` and write a submission CSV (case_id, annotation, rle).

    An image is labelled 'forged', with an RLE mask from rle_encode, when its
    thresholded mask covers at least `min_forgery_ratio` of the pixels.
    Otherwise it is 'authentic' with an empty rle. Images that fail to process
    are reported and recorded as 'authentic', so the submission still lists
    every case.

    Args:
        threshold: probability cut-off passed to predict_test_image.
        min_forgery_ratio: minimum flagged-pixel fraction for 'forged'
            (default 0.01 = 1% of the image).

    Returns:
        The submission DataFrame. It is empty, and no file is written, when
        `test_dir` contains no .png files.
    """
    test_files = list(test_dir.glob('*.png'))
    if not test_files:
        print(f"No test images found in {test_dir}")
        return pd.DataFrame()

    print(f"Processing {len(test_files)} test images...")
    results = []
    for test_file in tqdm(test_files, desc="Processing test images"):
        case_id = test_file.stem
        annotation, rle = 'authentic', ''
        try:
            _, binary_mask = predict_test_image(model, test_file, device, threshold=threshold)
            # Fraction of pixels flagged as forged; tiny areas are treated as noise.
            forgery_ratio = float(binary_mask.mean()) if binary_mask.size else 0.0
            if forgery_ratio >= min_forgery_ratio:
                annotation, rle = 'forged', rle_encode(binary_mask)
        except Exception as e:
            # Best-effort: one bad image must not abort the whole submission.
            print(f"Error processing {test_file}: {e}")
            annotation, rle = 'authentic', ''  # Default to authentic on error
        results.append({'case_id': case_id, 'annotation': annotation, 'rle': rle})

    submission_df = pd.DataFrame(results, columns=['case_id', 'annotation', 'rle'])
    submission_df.to_csv(output_file, index=False)
    print(f"\nSubmission file saved to {output_file}")

    counts = submission_df['annotation'].value_counts()
    print(f"Summary:")
    print(f"  Authentic: {int(counts.get('authentic', 0))}")
    print(f"  Forged: {int(counts.get('forged', 0))}")

    return submission_df

# Build the submission; this needs the model object, the test directory, and a saved best checkpoint.
if not ('model' in locals() and test_images_dir.exists()):
    print("Model or test directory not available.")
elif not os.path.exists(str(model_path)):
    print("Best model not found. Please train the model first.")
else:
    model.load_state_dict(torch.load(str(model_path), map_location=device))
    print("Loaded best model for inference")
    submission = process_test_set(model, test_images_dir, device, output_file=str(submission_path))
    if not submission.empty:
        print("\nSubmission preview:")
        print(submission.head(10))


In [None]:
# Model evaluation utility
def evaluate_model(model, val_loader, device, model_path='best_model.pth'):
    """Score `model` on `val_loader` (combined_loss, IoU, Dice) and print a report.

    If `model_path` names an existing file, its weights are loaded first
    (mapped onto `device`). A falsy `model_path` evaluates the current weights
    without looking for a file.

    Returns:
        (val_loss, val_iou, val_dice) as produced by validate_epoch.
    """
    if model_path:
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded model from {model_path}")
        else:
            print(f"Model file {model_path} not found. Using current model state.")

    model.eval()
    val_loss, val_iou, val_dice = validate_epoch(model, val_loader, combined_loss, device)

    rule = "=" * 50
    print("\n" + rule)
    print("Model Evaluation Results")
    print(rule)
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation IoU:  {val_iou:.4f}")
    print(f"Validation Dice: {val_dice:.4f}")
    print(rule)

    return val_loss, val_iou, val_dice

# Evaluate the best checkpoint if one exists, otherwise the in-memory weights.
if not ('val_loader' in locals() and 'model' in locals()):
    print("Model or validation loader not available for evaluation.")
elif os.path.exists(str(model_path)):
    evaluate_model(model, val_loader, device, model_path=str(model_path))
else:
    print("Best model not found. Evaluating current model state...")
    evaluate_model(model, val_loader, device, model_path=None)
