# Simple EDA

This section re-uses https://www.kaggle.com/code/jirkaborovec/forgery-detection-eda-visual-annotations

In [None]:
import os
import glob
import warnings

PATH_DATASET = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
authentic_images = glob.glob(os.path.join(PATH_DATASET, 'train_images', 'authentic', '*.png'))
forged_images = glob.glob(os.path.join(PATH_DATASET, 'train_images', 'forged', '*.png'))

warnings.simplefilter(action='ignore', category=FutureWarning)
print(f"Found {len(authentic_images)} authentic images.")
print(f"Found {len(forged_images)} forged images.")

In [None]:
# Masks are expected in a 'train_masks' directory inside the dataset root.
mask_dir = os.path.join(PATH_DATASET, 'train_masks')

# Map each mask's stem (file name without ".npy") to its full path.
mask_files_dict = {
    os.path.splitext(os.path.basename(npy_path))[0]: npy_path
    for npy_path in glob.glob(os.path.join(mask_dir, '*.npy'))
}

print(f"Found {len(mask_files_dict)} mask files and stored in a dictionary with keys as filenames without extensions.")
# Preview the first few "stem: path" entries.
mask_dict = [f"{stem}: {npy_path}" for stem, npy_path in mask_files_dict.items()]
print("\n".join(mask_dict[:5]))

In [None]:
def load_mask(mask_path: str):
    """Load an instance mask from a ``.npy`` file and flatten it to one label map.

    The file may hold a stack of per-instance layers with shape
    ``(num_instances, H, W)`` or a single ``(H, W)`` layer (treated as one
    instance). Pixels where layer ``c`` is positive get label ``c + 1``; where
    layers overlap, the later layer wins. Background stays 0.

    Args:
        mask_path: Path to the ``.npy`` mask file.

    Returns:
        ``uint8`` array of shape ``(H, W)`` with labels ``0..num_instances``.

    Raises:
        ValueError: If the stored array is neither 2-D nor 3-D.
    """
    mask_raw = np.load(mask_path)
    if mask_raw.ndim == 2:
        # A single 2-D layer is one instance.
        mask_raw = mask_raw[np.newaxis]
    if mask_raw.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D mask array, got shape {mask_raw.shape} from {mask_path}")
    # Shape comes from the trailing axes so an empty (0, H, W) stack yields all-background.
    mask = np.zeros(mask_raw.shape[1:], dtype=np.uint8)
    for c, layer in enumerate(mask_raw):
        mask[layer > 0] = c + 1
    return mask

# Contour colours for the non-background mask levels, in ascending level order.
# Add more entries if masks can contain more instance levels than listed here.
mask_colors = [
    'red', 'blue', 'green', 'purple', 'orange',
    'brown', 'pink', 'gray', 'olive', 'cyan',
]

## Overlap authentic and forged cases

In [None]:
# Compare the two folders by bare file name (directory stripped).
authentic_filenames = list(map(os.path.basename, authentic_images))
forged_filenames = list(map(os.path.basename, forged_images))

# File names present in both the authentic and the forged folder.
overlapping_filenames = list(set(authentic_filenames) & set(forged_filenames))

print(f"Found {len(overlapping_filenames)} overlapping filenames in authentic and forged folders.")

In [None]:
# Index both image sets by file stem (name without extension) for direct lookup.
authentic_image_dict = {
    os.path.splitext(os.path.basename(path))[0]: path for path in authentic_images
}
forged_image_dict = {
    os.path.splitext(os.path.basename(path))[0]: path for path in forged_images
}

# Stems that occur in both folders.
overlapping_filenames_without_extension = list(
    authentic_image_dict.keys() & forged_image_dict.keys()
)

# Collect (authentic_path, forged_path, mask_path) triples for stems that also have a mask.
matching_pairs_with_mask = []
for stem in overlapping_filenames_without_extension:
    stem_mask_path = mask_files_dict.get(stem)
    if stem_mask_path is None:
        print(f"Warning: No mask found for forged image with filename (without extension): {stem}")
        continue
    matching_pairs_with_mask.append(
        (authentic_image_dict[stem], forged_image_dict[stem], stem_mask_path)
    )

print(f"Found {len(matching_pairs_with_mask)} matching image-mask pairs with the same filename (without extension).")

In [None]:
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

def _show_pair_with_mask(auth_img_path, forged_img_path, mask_path):
    """Plot an authentic/forged image pair side by side, both overlaid with the mask contours.

    The mask is the one matched to the forged image by file stem; drawing it on
    both panels makes the forged region easy to compare with the original.
    """
    auth_img = mpimg.imread(auth_img_path)
    forged_img = mpimg.imread(forged_img_path)
    mask = load_mask(mask_path)  # uint8 label map, 0 = background
    # One contour level halfway between each pair of consecutive labels present.
    levels = np.unique(mask)[:-1] + 0.5

    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    panels = (
        (axes[0], auth_img, f"Authentic + Mask: {os.path.basename(auth_img_path)}"),
        (axes[1], forged_img, f"Forged + Mask: {os.path.basename(forged_img_path)}"),
    )
    for ax, img, title in panels:
        # Single-channel images would otherwise be rendered with a colour map.
        ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        if len(levels) > 0:
            # An all-background mask yields no levels; skip contour() in that case.
            ax.contour(mask, levels=levels, colors=mask_colors, linewidths=1)
        ax.axis('off')
        ax.set_title(title, fontsize=8)
    # Lay out and show each figure individually (one figure per pair).
    fig.tight_layout()
    plt.show()


# Show up to 12 randomly chosen authentic/forged pairs, one figure per pair.
num_pairs_to_show = min(12, len(matching_pairs_with_mask))
random_matching_pairs_with_mask = random.sample(matching_pairs_with_mask, num_pairs_to_show)
for auth_img_path, forged_img_path, mask_path in random_matching_pairs_with_mask:
    _show_pair_with_mask(auth_img_path, forged_img_path, mask_path)

## Count how many instances each image has

In [None]:
from tqdm.auto import tqdm

# Number of instances (mask layers) in every mask file.
# mmap_mode='r' memory-maps the array instead of reading it, so only the .npy
# header is parsed eagerly — enough to get the shape.
all_mask_instances = []
for mask_path in tqdm(mask_files_dict.values()):
    mask = np.load(mask_path, mmap_mode='r')
    # A 2-D mask is a single instance; otherwise the first axis indexes instances.
    all_mask_instances.append(mask.shape[0] if mask.ndim == 3 else 1)

print(f"Loaded shapes for {len(all_mask_instances)} masks.")
print("Instance counts found:", set(all_mask_instances))

In [None]:
import collections
import seaborn as sns

# How many masks have 1, 2, 3, ... instances, in ascending instance count.
instance_counts = collections.Counter(all_mask_instances)
sorted_instance_counts = dict(sorted(instance_counts.items()))

# Bar plot of the instance-count distribution.
plt.figure(figsize=(8, 3))
sns.barplot(x=list(sorted_instance_counts.keys()), y=list(sorted_instance_counts.values()))
plt.title("Mask Instance Counts")
plt.xlabel("Number of Instances in Mask")
plt.ylabel("Occurrences")
plt.grid(axis='y', alpha=0.75)
plt.show()

## Explore the ratios of object to image size

In [None]:
# Fraction of the image covered by each individual instance, pooled over all masks.
all_area_ratios = []

for mask_path in tqdm(mask_files_dict.values()):
    # Use the raw layer stack (not load_mask): every instance is measured separately.
    mask_raw = np.load(mask_path)
    if mask_raw.ndim == 2:
        mask_raw = mask_raw[np.newaxis]  # single-layer mask = one instance
    # NOTE(review): assumes the mask has the same H x W as its image — confirm.
    # Mean of the boolean foreground over (H, W) == foreground pixels / total pixels.
    instance_ratios = (mask_raw > 0).mean(axis=(1, 2))
    all_area_ratios.extend(instance_ratios.tolist())

print(f"Calculated area ratios for {len(all_area_ratios)} instances across all masks.")

In [None]:
# Histogram (50 bins + KDE) of how much of the image each instance covers.
plt.figure(figsize=(10, 3))
sns.histplot(all_area_ratios, bins=50, kde=True)
plt.title("Distribution of Segmented Area Ratios")
plt.xlabel("Area Ratio (Segmented Area / Total Image Area)")
plt.ylabel("Frequency")
plt.grid(axis='y', alpha=0.75)
plt.show()

# Summary statistics over the same ratios.
area_ratio_array = np.asarray(all_area_ratios)
print("\nBasic statistics for area ratios:")
print(f"Mean: {area_ratio_array.mean():.4f}")
print(f"Median: {np.median(area_ratio_array):.4f}")
print(f"Standard Deviation: {area_ratio_array.std():.4f}")
print(f"Min: {area_ratio_array.min():.4f}")
print(f"Max: {area_ratio_array.max():.4f}")

# Feature-Based Self-Similarity Analysis for Copy-Move Detection

### Copy-Move Forgery Detection Solution

#### Problem
Detect copy-move forgeries in biomedical images where parts of an image are copied and pasted elsewhere in the same image to manipulate results.

#### Solution Overview

##### Architecture
**Self-Similarity Detection with Background Suppression**
```
Input Image (H×W×3)
    ↓
TIMM Backbone (ResNet/EfficientNet/etc.)
    ↓
Feature Extraction (C×H'×W')
    ↓
Self-Similarity Matrix (compare all spatial locations)
    ↓
Background Suppressor (filter uniform/repetitive regions)
    ↓
Copy-Move Detection (find distant similar regions)
    ↓
Refinement Network
    ↓
Output Heatmap (H×W) [0-1 probability]
```

##### Key Components

**1. TIMM Feature Extractor**
- Uses pretrained backbones (ResNet50, EfficientNet, ConvNeXt, etc.)
- Extracts normalized deep features
- Transfer learning from ImageNet

**2. Self-Similarity Matrix**
- Computes cosine similarity between all spatial locations
- High similarity = potential copied regions
- Formula: `similarity[i,j] = feature[i] · feature[j]` (features are L2-normalized, so this dot product equals the cosine similarity)

**3. Background Suppressor** (Unsupervised)
- **Problem**: Uniform backgrounds (white/black) create false matches
- **Solution**: Compute feature complexity using:
  - Feature variance across channels
  - Local diversity (patch variations)
  - Gradient magnitude
- **Output**: Suppress matches in low-complexity regions
- **No labels needed**: Works automatically

**4. Copy-Move Detection**
- Filters out self-matches and nearby locations
- Only keeps high similarity at distant locations
- Spatial distance threshold prevents matching adjacent pixels
- Max similarity score = forgery probability

**5. Refinement Network**
- Small CNN to refine the heatmap
- Upsamples to original resolution
- Applies sigmoid for [0,1] probability

#### Training Strategy

##### Framework: PyTorch Lightning
- Automatic training loops
- Built-in callbacks (checkpointing, early stopping)
- Mixed precision training (16-bit)
- TorchMetrics for evaluation

##### Loss Function
**Combined Loss = α × Dice + (1-α) × BCE + β × Complexity**

- **Dice Loss**: Handles class imbalance (small forged regions)
- **BCE**: Binary cross-entropy for pixel-wise classification
- **Complexity Regularization**: Encourages high complexity in forged regions

##### Metrics (TorchMetrics)
- **IoU (Jaccard)**: Intersection over Union
- **F1 Score**: Harmonic mean of precision/recall
- **Precision**: Accuracy of forgery predictions
- **Recall**: Coverage of actual forgeries
- **Accuracy**: Overall correctness

##### Data Augmentation
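- Horizontal/Vertical flips
- Rotation (±30°) and small shift/scale jitter
- Brightness/Contrast and hue/saturation/value jitter
- Mild grid distortion; noise and blur are deliberately excluded because they can break self-similarity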
- **Preserves spatial relationships** (important for copy-move detection)

#### Why This Works

**1. Self-Similarity is the Key Signal**
Copy-move creates identical patterns at different locations - this is unnatural and detectable through feature similarity.

**2. Background Suppression Prevents False Positives**
Scientific images often have uniform backgrounds that would otherwise match everywhere. By measuring feature complexity, we ignore these regions without needing labels.

**3. Deep Features are Robust**
- Reasonably robust to small transformations (rotation, scaling, brightness changes)
- Capture semantic content, not just raw pixels
- ImageNet pretraining provides a strong initialization

**4. Spatial Distance Filtering**
Only flag similarities at distant locations - prevents matching a pixel with itself or immediate neighbors.

**5. Heatmap Output > Binary Masks**
- Provides uncertainty/confidence
- Easier to threshold based on use case
- More interpretable for analysis

#### Technical Advantages

✅ **No foreground/background labels needed** - fully unsupervised suppression  
✅ **Flexible backbone** - easy to swap TIMM models  
✅ **End-to-end trainable** - loss directly optimizes detection  
✅ **Scalable** - PyTorch Lightning handles distributed training  
✅ **Production ready** - proper logging, checkpointing, metrics  

#### Limitations

❌ **Computational cost**: O(N²) similarity matrix for N feature-map locations (N = H'×W')  
❌ **Small forgeries**: May miss very small copied regions  
❌ **Heavily compressed images**: JPEG artifacts can interfere  
❌ **Non-rigid transformations**: Assumes copy is similar in appearance  

#### Output

**Training**: Best model checkpoint saved based on validation IoU  
**Inference**: 
- `predictions/heatmaps/*.npy` - Probability maps [0-1]
- `predictions/masks/*.npy` - Binary masks (thresholded)
- `submission.csv` - Kaggle submission format

#### Performance Tips

1. **Increase image size** (512→768) for better detail
2. **Try different backbones** (efficientnet_b4, convnext_tiny)
3. **Adjust distance_threshold** (3→5) for larger images
4. **Ensemble multiple models** for robustness
5. **Lower threshold** (0.5→0.3) to catch more forgeries (higher recall)

In [None]:
# ! pip install -q pytorch_lightning torchmetrics timm albumentations

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
import torchmetrics
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryJaccardIndex, BinaryAccuracy
import timm

import cv2
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import Tuple, Optional
from pathlib import Path

In [None]:
# --- Data and output locations -------------------------------------------------
TRAIN_IMAGES_SUBDIR = 'train_images'  # relative to PATH_DATASET
TRAIN_MASKS_SUBDIR = 'train_masks'    # relative to PATH_DATASET
TEST_IMAGES_DIR = f'{PATH_DATASET}/test_images'
PATH_MODELS = "/kaggle/input/timm-resnets-weights/pytorch/default/2"  # presumably local timm weights
SAVE_DIR = 'checkpoints'
OUTPUT_DIR = 'predictions'

# --- Model ---------------------------------------------------------------------
BACKBONE = 'resnet50'
IMAGE_SIZE = 512  # square side length images are resized to
THRESHOLD = 0.5   # probability cut-off for binary masks (consumer not shown here — confirm)

# --- Training --------------------------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 45
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
TRAIN_VAL_SPLIT = 0.8  # fraction of the training images used for training
NUM_WORKERS = 4        # DataLoader worker processes

# --- Runtime ---------------------------------------------------------------------
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data and DataModule

In [None]:
def load_mask(mask_path: str):
    """Load an instance mask from a ``.npy`` file and flatten it to one label map.

    The file may hold a stack of per-instance layers with shape
    ``(num_instances, H, W)`` or a single ``(H, W)`` layer (treated as one
    instance). Pixels where layer ``c`` is positive get label ``c + 1``; where
    layers overlap, the later layer wins. Background stays 0.

    Args:
        mask_path: Path to the ``.npy`` mask file.

    Returns:
        ``uint8`` array of shape ``(H, W)`` with labels ``0..num_instances``.

    Raises:
        ValueError: If the stored array is neither 2-D nor 3-D.
    """
    mask_raw = np.load(mask_path)
    if mask_raw.ndim == 2:
        # A single 2-D layer is one instance.
        mask_raw = mask_raw[np.newaxis]
    if mask_raw.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D mask array, got shape {mask_raw.shape} from {mask_path}")
    # Shape comes from the trailing axes so an empty (0, H, W) stack yields all-background.
    mask = np.zeros(mask_raw.shape[1:], dtype=np.uint8)
    for c, layer in enumerate(mask_raw):
        mask[layer > 0] = c + 1
    return mask


def get_train_transforms(image_size=512):
    """Training augmentation pipeline for copy-move detection.

    Geometric transforms are applied jointly to the image and the ``mask``
    target. Rotate and ShiftScaleRotate pad newly exposed pixels with a constant
    (zero) border rather than a reflected one: reflection (the albumentations 1.x
    default) mirrors real image content into the frame, creating duplicated
    regions that the mask does not mark as forged. Noise and blur are left out
    because they can break self-similarity.

    Args:
        image_size: Side length of the square output image.

    Returns:
        An ``A.Compose`` producing a normalised ``torch`` image tensor and mask.
    """
    pad_mode = cv2.BORDER_CONSTANT
    return A.Compose([
        A.Resize(image_size, image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=30, border_mode=pad_mode, p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.2),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=0, border_mode=pad_mode, p=0.2),
        # Mild non-rigid warp; note it can distort copied regions slightly differently.
        A.GridDistortion(p=0.1),
        # ImageNet statistics, matching the pretrained backbone.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])


def get_val_transforms(image_size=512):
    """Deterministic validation pipeline: resize, ImageNet-normalise, convert to tensor.

    Args:
        image_size: Side length of the square output image.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        A.Resize(image_size, image_size),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
class CopyMoveDataset(Dataset):
    """Dataset of training images for copy-move forgery detection.

    Collects ``*.png`` files from ``<image_dir>/authentic`` and
    ``<image_dir>/forged`` (each folder sorted, authentic first). Each item is
    ``(image, mask, filename)``. The mask is a binary float map of forged pixels
    with a leading channel axis; it is all zeros for authentic images and for
    forged images without a ``<mask_dir>/<stem>.npy`` file.
    """

    def __init__(self, image_dir, mask_dir=None, transform=None, image_size=512):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir) if mask_dir else None
        self.transform = transform
        self.image_size = image_size  # used for the default transform
        self.image_files = []

        for subdir in ('authentic', 'forged'):
            folder = self.image_dir / subdir
            if folder.exists():
                self.image_files.extend(sorted(folder.glob('*.png')))

        if self.transform is None:
            # Default: resize, ImageNet-normalise, convert to tensor.
            self.transform = A.Compose([
                A.Resize(image_size, image_size),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.image_files)

    def _forgery_mask(self, img_path, is_authentic, shape):
        """Return a float32 {0, 1} forgery mask with spatial size ``shape`` (H, W)."""
        if self.mask_dir and not is_authentic:
            mask_path = self.mask_dir / (img_path.stem + '.npy')
            if mask_path.exists():
                mask = (load_mask(str(mask_path)) > 0).astype(np.float32)
                if mask.shape[:2] != tuple(shape):
                    # Albumentations requires image and mask to share H x W.
                    mask = cv2.resize(mask, (shape[1], shape[0]), interpolation=cv2.INTER_NEAREST)
                return mask
        # Authentic image, no mask directory, or no mask file: nothing is forged.
        return np.zeros(shape, dtype=np.float32)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread reports unreadable files by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Label by the containing folder rather than a substring of the full path,
        # so a parent directory whose name contains 'authentic' cannot flip labels.
        is_authentic = img_path.parent.name == 'authentic'
        mask = self._forgery_mask(img_path, is_authentic, image.shape[:2])

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        if len(mask.shape) == 2:
            # (H, W) -> (1, H, W); indexing works for both numpy arrays and tensors.
            mask = mask[None]

        return image, mask, str(img_path.name)

In [None]:
class TransformDataset(Dataset):
    """Wrap a ``Subset`` of a ``CopyMoveDataset`` with its own transform.

    Images and masks are re-read from the wrapped dataset's ``image_files`` and
    ``mask_dir`` rather than via ``subset[idx]``, so the wrapped dataset's own
    transform is bypassed and ``transform`` (e.g. train vs. validation
    augmentations) is applied instead.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def _forgery_mask(self, img_path, is_authentic, shape):
        """Return a float32 {0, 1} forgery mask with spatial size ``shape`` (H, W)."""
        mask_dir = self.subset.dataset.mask_dir
        if not is_authentic and mask_dir:
            mask_path = mask_dir / (img_path.stem + '.npy')
            if mask_path.exists():
                mask = (load_mask(str(mask_path)) > 0).astype(np.float32)
                if mask.shape[:2] != tuple(shape):
                    # Albumentations requires image and mask to share H x W.
                    mask = cv2.resize(mask, (shape[1], shape[0]), interpolation=cv2.INTER_NEAREST)
                return mask
        # Authentic image, no mask directory, or no mask file: nothing is forged.
        return np.zeros(shape, dtype=np.float32)

    def __getitem__(self, idx):
        # Map the subset position to an index into the full dataset's file list.
        img_path = self.subset.dataset.image_files[self.subset.indices[idx]]
        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread reports unreadable files by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Label by the containing folder rather than a substring of the full path.
        is_authentic = img_path.parent.name == 'authentic'
        mask = self._forgery_mask(img_path, is_authentic, image.shape[:2])

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        if len(mask.shape) == 2:
            # (H, W) -> (1, H, W); indexing works for both numpy arrays and tensors.
            mask = mask[None]

        return image, mask, str(img_path.name)

In [None]:
class CopyMoveDataModule(pl.LightningDataModule):
    """Lightning DataModule splitting the training folder into train/val loaders.

    The split is a seeded shuffle over file stems. Files with the same stem (an
    authentic image and a forged image sharing its name, as paired in the EDA)
    always land in the same split. This prevents validation from seeing a
    forgery of an image used for training. It also means validation is not
    simply the tail of the sorted file list, which would be almost entirely
    forged images.
    """

    def __init__(
        self,
        data_dir,
        train_images_subdir='train_images',
        train_masks_subdir='train_masks',
        image_size=512,
        batch_size=4,
        num_workers=4,
        train_val_split=0.8,
        seed=42,
    ):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.train_images_dir = self.data_dir / train_images_subdir
        self.train_masks_dir = self.data_dir / train_masks_subdir
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Approximate fraction of images used for training (the split is per stem group).
        self.train_val_split = train_val_split
        # Seed for the reproducible train/val shuffle.
        self.seed = seed

    @staticmethod
    def _grouped_split(image_files, train_fraction, seed):
        """Return sorted (train_indices, val_indices), keeping same-stem files together."""
        import random

        groups = {}
        for idx, path in enumerate(image_files):
            groups.setdefault(Path(path).stem, []).append(idx)
        stems = sorted(groups)  # sort first so the shuffle is reproducible
        random.Random(seed).shuffle(stems)
        n_train = int(train_fraction * len(stems))
        train_indices = sorted(i for stem in stems[:n_train] for i in groups[stem])
        val_indices = sorted(i for stem in stems[n_train:] for i in groups[stem])
        return train_indices, val_indices

    def setup(self, stage=None):
        full_dataset = CopyMoveDataset(
            self.train_images_dir,
            self.train_masks_dir,
            image_size=self.image_size
        )

        train_indices, val_indices = self._grouped_split(
            full_dataset.image_files, self.train_val_split, self.seed
        )

        # Each split re-reads files from full_dataset and applies its own transforms.
        self.train_dataset = TransformDataset(
            Subset(full_dataset, train_indices),
            get_train_transforms(self.image_size)
        )
        self.val_dataset = TransformDataset(
            Subset(full_dataset, val_indices),
            get_val_transforms(self.image_size)
        )
        print(f"Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            # Keep workers alive between epochs (only valid with worker processes).
            persistent_workers=self.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

In [None]:
# Create data module
data_module = CopyMoveDataModule(
    data_dir=PATH_DATASET,
    train_images_subdir=TRAIN_IMAGES_SUBDIR,
    train_masks_subdir=TRAIN_MASKS_SUBDIR,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_val_split=TRAIN_VAL_SPLIT
)

data_module.setup()

# You can now access the dataloaders
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Example of getting a batch (optional)
train_images, train_masks, train_filenames = next(iter(train_loader))
print(f"Shape of training images batch: {train_images.shape}")
print(f"Shape of training masks batch: {train_masks.shape}")
print(f"Shape of training filenames batch: {len(train_filenames)}")

In [None]:
# Visual sanity check: augmented training images (de-normalised) next to their masks.
train_dataloader = data_module.train_dataloader()
images, masks, filenames = next(iter(train_dataloader))

num_samples_to_show = min(4, images.shape[0])

# Inverse of A.Normalize with the ImageNet statistics used in the transforms.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# squeeze=False keeps `axes` 2-D even when only one row is plotted.
fig, axes = plt.subplots(
    num_samples_to_show, 2, figsize=(10, num_samples_to_show * 5), squeeze=False
)

for i in range(num_samples_to_show):
    img = images[i].permute(1, 2, 0).cpu().numpy()  # CHW tensor -> HWC array
    img = np.clip(std * img + mean, 0, 1)
    mask = masks[i].squeeze(0).cpu().numpy()  # drop only the channel axis

    axes[i, 0].imshow(img)
    axes[i, 0].set_title(f"Image: {filenames[i]}", fontsize=8)
    axes[i, 0].axis('off')

    axes[i, 1].imshow(mask, cmap='gray')  # binary mask
    axes[i, 1].set_title(f"Mask: {filenames[i]}", fontsize=8)
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()

## Model architecture

The model used for copy-move forgery detection is a custom architecture named `CopyMoveHeatmapDetector`. This model is designed to take an image as input and output a heatmap indicating potential forged regions. The architecture consists of the following main components:

1.  **TIMMFeatureExtractor**: This module uses a pre-trained convolutional neural network (specified by the `backbone` parameter, e.g., ResNet50) to extract rich features from the input image. It takes the features from a specific layer of the backbone and L2-normalizes them along the channel dimension.

2.  **BackgroundSuppressor**: This module reduces the influence of background regions on the self-similarity calculation. It either uses a learnable network (`informativeness_net`) or computes feature complexity from variance, local diversity, and gradient magnitude. The output is a complexity map used to weight the self-similarity matrix.

3.  **Self-Similarity Computation**: The extracted features are used to compute a self-similarity matrix. This matrix represents the similarity between every pair of feature vectors in the image.

4.  **Spatial Distance Filtering**: A spatial distance mask is applied to the self-similarity matrix. This filter removes matches between feature vectors that are too close to each other in the image, focusing on potential copy-move regions that are spatially separated.

5.  **Forgery Detection**: The filtered similarity matrix is then processed to obtain initial forgery scores, which are reshaped into a heatmap.

6.  **Refinement Module**: A small convolutional network is used to refine the initial heatmap, further enhancing the detection of forged areas.

7.  **Upsampling**: The refined heatmap and the complexity map are upsampled to the desired output size.

Here is a simplified diagram of the architecture:

```
Input Image
      |
      V
TIMMFeatureExtractor
      |
      +-----------------------+
      V                       V
  Features           (for learnable suppression)
      |                       |
      +-----> BackgroundSuppressor
      |                       |
      V                       V
Self-Similarity <---+  Complexity Map
      |               |
      V               |
Spatial Distance Filter
      |
      V
Filtered Similarity
      |
      V
Forgery Detection
      |
      V
Initial Heatmap
      |
      V
Refinement Module
      |
      V
Final Heatmap
```

In [None]:
# # Define the model name you want to download
# model_name_to_download = 'resnet50' # You can change this to other timm models

# # Define the directory and filename for saving the weights
# save_dir = 'pretrained_models'
# model_filename = f'{model_name_to_download}_weights.pth'
# save_path = os.path.join(save_dir, model_filename)

# # Create the directory if it doesn't exist
# os.makedirs(save_dir, exist_ok=True)

# # Download the pretrained model
# print(f"Downloading {model_name_to_download} with pretrained weights...")
# model = timm.create_model(model_name_to_download, pretrained=True)
# print("Download successful.")

# # Save the state dictionary
# print(f"Saving model weights to {save_path}...")
# torch.save(model.state_dict(), save_path)
# print("Saving complete.")

In [None]:
class TIMMFeatureExtractor(nn.Module):
    """Backbone wrapper returning L2-normalised feature maps from a timm model.

    Args:
        model_name: timm model name.
        pretrained: Use timm's pretrained weights (ignored when ``weights_path`` is given).
        features_only: Build the backbone as a timm feature extractor.
        out_indices: Feature levels to return; the last one is used.
        weights_path: Optional local ``state_dict`` file to load into the backbone.

    Attributes:
        feature_dim: Channel count of the feature map returned by ``forward``.
    """

    def __init__(
        self,
        model_name: str = 'resnet50',
        pretrained: bool = True,
        features_only: bool = True,
        out_indices: Tuple[int, ...] = (3,),
        weights_path: Optional[str] = None
    ):
        super().__init__()

        # Download timm weights only when no local file is supplied.
        use_timm_pretrained = pretrained and (weights_path is None)

        self.backbone = timm.create_model(
            model_name,
            pretrained=use_timm_pretrained,
            features_only=features_only,
            out_indices=out_indices
        )

        if weights_path is not None:
            self._load_local_weights(weights_path)

        # Feature extractors expose their channel counts directly; fall back to a
        # dummy forward pass otherwise.
        feature_info = getattr(self.backbone, 'feature_info', None)
        if features_only and hasattr(feature_info, 'channels'):
            self.feature_dim = feature_info.channels()[-1]
        else:
            self.feature_dim = self._probe_feature_dim()

    def _load_local_weights(self, weights_path: str) -> None:
        """Load a local state dict into the backbone, ignoring keys it does not have.

        A checkpoint saved from the full classification model can contain modules
        (e.g. the classifier head or later stages) that a ``features_only``
        backbone does not build; a strict load would reject those extra keys.
        Every parameter/buffer the backbone does have must still be present.

        Raises:
            RuntimeError: If the file lacks keys the backbone needs.
        """
        print(f"Loading weights from {weights_path}")
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
        state_dict = torch.load(weights_path, map_location='cpu')
        own_keys = list(self.backbone.state_dict().keys())
        missing = [k for k in own_keys if k not in state_dict]
        if missing:
            raise RuntimeError(
                f"Weights file {weights_path} is missing {len(missing)} backbone keys, e.g. {missing[:5]}"
            )
        self.backbone.load_state_dict({k: state_dict[k] for k in own_keys})
        print("Successfully loaded weights from local file.")

    def _probe_feature_dim(self) -> int:
        """Infer the channel count of the last feature map with one dummy forward pass.

        The pass runs in eval mode. In train mode, BatchNorm layers would fold the
        random dummy batch into their running statistics (``no_grad`` does not
        prevent that), corrupting pretrained normalisation.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                out = self.backbone(torch.randn(1, 3, 224, 224))
        finally:
            self.backbone.train(was_training)
        if isinstance(out, (list, tuple)):
            out = out[-1]  # same feature map that forward() uses
        return out.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the last backbone feature map, L2-normalised over channels (dim 1)."""
        features = self.backbone(x)

        if isinstance(features, (list, tuple)):
            features = features[-1]

        return F.normalize(features, p=2, dim=1)

In [None]:
class BackgroundSuppressor(nn.Module):
    """Down-weight self-similarity between low-information ("background") locations.

    A per-location score map ``c`` (shape (B, 1, H, W), values in [0, 1]) is
    computed from the features, and every entry ``S[b, i, j]`` of the
    self-similarity matrix is multiplied by ``c[b, i] * c[b, j]``.

    Args:
        feature_dim: Number of channels of the incoming feature map.
        learnable: If True, ``c`` is predicted by a small 1x1-conv network ending
            in a sigmoid; otherwise the hand-crafted score from
            :meth:`compute_feature_complexity` is used.
    """

    def __init__(self, feature_dim: int, learnable: bool = True):
        super().__init__()
        self.learnable = learnable

        # Built unconditionally, but only used in forward() when learnable=True.
        self.informativeness_net = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim // 4, 1),  # 1x1 convolution to reduce channels
            nn.BatchNorm2d(feature_dim // 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim // 4, feature_dim // 8, 1),  # Another 1x1 convolution
            nn.BatchNorm2d(feature_dim // 8),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim // 8, 1, 1),  # Single-channel score map
            nn.Sigmoid()  # Values between 0 and 1
        )

    def compute_feature_complexity(self, features: torch.Tensor) -> torch.Tensor:
        """Hand-crafted per-location complexity score in [0, 1].

        Averages three cues computed on a (B, C, H, W) feature map:
        variance across channels, variance inside each 3x3 neighbourhood
        (averaged over channels) and Sobel gradient magnitude (averaged over
        channels). The result is min-max normalised independently for each
        sample, so an image's map does not depend on the rest of its batch.

        Args:
            features: Feature map of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, 1, H, W).
        """
        B, C, H, W = features.shape

        # Method 1: Feature variance across channels
        variance = torch.var(features, dim=1, keepdim=True)

        # Method 2: Local feature diversity (zero-padded 3x3 neighbourhoods)
        kernel_size = 3
        padding = kernel_size // 2
        unfold = nn.Unfold(kernel_size=kernel_size, padding=padding)
        # Unfold yields (B, C*k*k, H*W) with the channel as the outer index,
        # so it can be viewed as (B, C, k*k, H*W).
        patches = unfold(features).view(B, C, kernel_size * kernel_size, H * W)
        local_diversity = torch.var(patches, dim=2).view(B, C, H, W)
        local_diversity = torch.mean(local_diversity, dim=1, keepdim=True)

        # Method 3: Gradient magnitude (depthwise Sobel filtering)
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=features.dtype, device=features.device)
        sobel_y = sobel_x.t()
        grad_x = F.conv2d(features, sobel_x.view(1, 1, 3, 3).repeat(C, 1, 1, 1),
                          padding=1, groups=C)
        grad_y = F.conv2d(features, sobel_y.view(1, 1, 3, 3).repeat(C, 1, 1, 1),
                          padding=1, groups=C)
        # Clamp before sqrt: d/du sqrt(u) is infinite at u == 0, which would
        # produce NaN gradients wherever both Sobel responses are exactly zero.
        gradient_mag = torch.sqrt((grad_x ** 2 + grad_y ** 2).clamp_min(1e-12))
        gradient_mag = torch.mean(gradient_mag, dim=1, keepdim=True)

        # Average the three cues
        complexity = (variance + local_diversity + gradient_mag) / 3.0

        # Per-sample min-max normalisation to [0, 1]
        c_min = complexity.amin(dim=(1, 2, 3), keepdim=True)
        c_max = complexity.amax(dim=(1, 2, 3), keepdim=True)
        return (complexity - c_min) / (c_max - c_min + 1e-8)

    def forward(
        self,
        features: torch.Tensor,
        similarity_matrix: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the self-similarity matrix by pairwise location scores.

        Args:
            features: Feature map of shape (B, C, H, W).
            similarity_matrix: Self-similarity of shape (B, H*W, H*W).

        Returns:
            ``(suppressed_similarity, complexity)``, where ``complexity`` has
            shape (B, 1, H, W).
        """
        B, C, H, W = features.shape
        N = H * W

        if self.learnable:
            complexity = self.informativeness_net(features)
        else:
            complexity = self.compute_feature_complexity(features)

        complexity_flat = complexity.view(B, 1, N)

        # Outer product (B, N, 1) x (B, 1, N) -> (B, N, N) of pairwise scores
        complexity_product = torch.bmm(
            complexity_flat.transpose(1, 2),
            complexity_flat
        )

        suppressed_similarity = similarity_matrix * complexity_product
        return suppressed_similarity, complexity

In [None]:
class CopyMoveHeatmapDetector(nn.Module):
    """Copy-move forgery detector that outputs a per-pixel heatmap of logits.

    Pipeline: backbone features -> feature self-similarity -> background
    suppression -> best match per location among locations farther than
    ``distance_threshold`` feature cells -> convolutional refinement ->
    bilinear upsampling to ``output_size``.

    Args:
        backbone: timm model name used for feature extraction.
        pretrained: Whether the backbone should use pretrained weights.
        output_size: (height, width) of the returned maps.
        distance_threshold: Matches closer than this (Euclidean distance on the
            feature grid) are ignored.
        learnable_suppression: Use the learnable background suppressor.
        weights_path: Optional local weights file for the backbone.
    """

    def __init__(
        self,
        backbone: str = 'resnet50',
        pretrained: bool = True,
        output_size: Tuple[int, int] = (512, 512),
        distance_threshold: int = 3,
        learnable_suppression: bool = True,
        weights_path: Optional[str] = None
    ):
        super().__init__()

        self.output_size = output_size
        self.distance_threshold = distance_threshold

        # Only the last backbone stage (out index 3) is extracted.
        self.feature_extractor = TIMMFeatureExtractor(
            model_name=backbone,
            pretrained=pretrained,
            features_only=True,
            out_indices=(3,),
            weights_path=weights_path,
        )

        self.bg_suppressor = BackgroundSuppressor(
            feature_dim=self.feature_extractor.feature_dim,
            learnable=learnable_suppression,
        )

        # 1 -> 32 -> 16 -> 1 channels. No final activation: the output is a
        # logit map meant for BCEWithLogits-style losses.
        refine_layers = [
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, 1),
        ]
        self.refine = nn.Sequential(*refine_layers)

    @staticmethod
    def _bilinear(tensor: torch.Tensor, size) -> torch.Tensor:
        """Bilinear resize of a (B, C, H, W) tensor to ``size``."""
        return F.interpolate(tensor, size=size, mode='bilinear', align_corners=False)

    def compute_self_similarity(
        self,
        features: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Return the (B, H*W, H*W) matrix of pairwise feature dot products and (H, W)."""
        batch, channels, height, width = features.shape
        flat = features.view(batch, channels, height * width)
        # (B, N, C) @ (B, C, N) -> (B, N, N)
        similarity = torch.bmm(flat.transpose(1, 2), flat)
        return similarity, (height, width)

    def create_spatial_distance_mask(
        self,
        H: int,
        W: int,
        device: torch.device,
        threshold: int
    ) -> torch.Tensor:
        """Return an (H*W, H*W) float mask: 1 where grid distance > threshold, else 0."""
        # Row-major (y, x) coordinates of every grid cell
        ys = torch.arange(H, device=device).repeat_interleave(W)
        xs = torch.arange(W, device=device).repeat(H)
        grid = torch.stack((ys, xs), dim=1).float()
        return (torch.cdist(grid, grid, p=2) > threshold).float()

    def detect_copy_move(
        self,
        similarity_matrix: torch.Tensor,
        spatial_shape: Tuple[int, int]
    ) -> torch.Tensor:
        """Score each location by its strongest similarity to a far-enough location.

        Returns a (B, 1, H, W) map.
        """
        batch = similarity_matrix.shape[0]
        H, W = spatial_shape

        far_enough = self.create_spatial_distance_mask(
            H, W, similarity_matrix.device, self.distance_threshold
        )
        # Near pairs (including each location with itself) are zeroed out; the
        # mask broadcasts over the batch dimension.
        masked = similarity_matrix * far_enough.unsqueeze(0)
        best_match = masked.max(dim=2).values
        return best_match.view(batch, 1, H, W)

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(heatmap_logits, complexity)``, both resized to ``output_size``."""
        in_h, in_w = x.shape[2:]

        features = self.feature_extractor(x)
        similarity, spatial_shape = self.compute_self_similarity(features)
        suppressed, complexity = self.bg_suppressor(features, similarity)
        coarse = self.detect_copy_move(suppressed, spatial_shape)

        # Refine at half the input resolution, then resize to the output size.
        refined = self.refine(self._bilinear(coarse, (in_h // 2, in_w // 2)))
        heatmap = self._bilinear(refined, self.output_size)
        complexity_upsampled = self._bilinear(complexity, self.output_size)
        return heatmap, complexity_upsampled

In [None]:
class CopyMoveLightningModule(pl.LightningModule):
    """PyTorch Lightning wrapper that trains a :class:`CopyMoveHeatmapDetector`.

    Loss = ``alpha * Dice + (1 - alpha) * BCE`` on the heatmap logits, plus
    ``beta`` times an MSE term that pushes the complexity map towards 1 inside
    ground-truth forged pixels (pixels outside the mask contribute zero).
    Pixel-level precision / recall / F1 / IoU / accuracy at a 0.5 probability
    threshold are logged under ``train/`` and ``val/``.
    """

    def __init__(
        self,
        backbone='resnet50',  # Name of the backbone model
        pretrained=False,
        image_size=512,  # Target (square) output size
        learning_rate=1e-4,  # Learning rate for the optimizer
        weight_decay=1e-4,  # Weight decay for the optimizer
        alpha=0.5,  # Weight for Dice loss
        beta=0.1,  # Weight for complexity regularization
        weights_path: Optional[str] = None  # Path to local pretrained weights file
    ):
        super().__init__()
        self.save_hyperparameters()  # Saves all __init__ parameters as hyperparameters

        self.model = CopyMoveHeatmapDetector(
            backbone=backbone,
            pretrained=pretrained,
            output_size=(image_size, image_size),
            distance_threshold=3,
            learnable_suppression=True,
            weights_path=weights_path,
        )

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.alpha = alpha
        self.beta = beta

        metrics = MetricCollection({
            'precision': BinaryPrecision(),
            'recall': BinaryRecall(),
            'f1': BinaryF1Score(),
            'iou': BinaryJaccardIndex(),
            'accuracy': BinaryAccuracy()
        })

        self.train_metrics = metrics.clone(prefix='train/')
        self.val_metrics = metrics.clone(prefix='val/')

    def forward(self, x):
        """Return ``(heatmap_logits, complexity_map)`` from the detector."""
        return self.model(x)

    def dice_loss(self, pred, target, smooth=1.0):
        """Soft Dice loss over the whole batch; ``pred`` are logits."""
        pred_flat = torch.sigmoid(pred).contiguous().view(-1)
        target_flat = target.contiguous().view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        return 1 - dice

    def compute_loss(self, pred_heatmap, gt_mask, complexity_map):
        """Return ``(total_loss, dice_loss, bce_loss)``.

        Assumes ``gt_mask`` has the same shape as ``pred_heatmap`` (e.g. (B, 1, H, W))
        with values in {0, 1} -- TODO confirm against the datamodule.
        """
        gt = gt_mask.float()
        # Logit-based BCE is numerically stable under autocast.
        bce_loss = F.binary_cross_entropy_with_logits(pred_heatmap, gt)
        dice_loss = self.dice_loss(pred_heatmap, gt)

        main_loss = self.alpha * dice_loss + (1 - self.alpha) * bce_loss

        if complexity_map is not None:
            # Penalises (complexity - 1)^2 inside forged pixels; zero elsewhere.
            complexity_reg = F.mse_loss(complexity_map * gt, gt)
            main_loss = main_loss + self.beta * complexity_reg

        return main_loss, dice_loss, bce_loss

    def _shared_step(self, batch, stage, metric_collection):
        """Forward a batch, compute the loss, and log loss terms and metrics under ``stage/``."""
        images, masks, _ = batch
        pred_heatmap, complexity_map = self(images)

        loss, dice, bce = self.compute_loss(pred_heatmap, masks, complexity_map)

        # Metrics use hard predictions at a 0.5 probability threshold.
        pred_binary = (torch.sigmoid(pred_heatmap) > 0.5).long()
        metrics = metric_collection(pred_binary, masks.long())

        self.log(f'{stage}/loss', loss, prog_bar=True, on_step=True)
        self.log(f'{stage}/dice_loss', dice, on_step=True)
        self.log(f'{stage}/bce_loss', bce, on_step=True)
        self.log_dict(metrics, prog_bar=True, on_step=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train', self.train_metrics)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val', self.val_metrics)

    def on_train_epoch_end(self):
        self.train_metrics.reset()

    def on_validation_epoch_end(self):
        self.val_metrics.reset()

    def configure_optimizers(self):
        """AdamW with a per-epoch cosine-annealing schedule over ``max_epochs``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }

In [None]:
# Local backbone weights file, forwarded to the timm feature extractor
backbone_weights_path = f"{PATH_MODELS}/{BACKBONE}_weights.pth"

# Build the Lightning module that wraps the detector
model = CopyMoveLightningModule(
    backbone=BACKBONE,
    image_size=IMAGE_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    weights_path=backbone_weights_path,
)

## Training

In [None]:
# Setup callbacks for the PyTorch Lightning trainer.
# The monitored metric is named "val/iou". With the default
# auto_insert_metric_name=True, Lightning copies that name (including the "/")
# into the checkpoint file name, which creates an unintended sub-directory.
# Disable the automatic "name=" insertion and spell the fields out explicitly.
checkpoint_callback = ModelCheckpoint(
    filename='best-epoch={epoch:03d}-val_iou={val/iou:.4f}',
    auto_insert_metric_name=False,
    # Keep the checkpoints with the highest validation IoU
    monitor='val/iou',
    mode='max',
    save_top_k=3,
    # Also keep the most recent checkpoint
    save_last=True,
)

# Stop training when validation IoU has not improved for `patience` epochs
early_stop_callback = EarlyStopping(
    monitor='val/iou',
    patience=10,
    mode='max',
    verbose=True,  # Print a message when early stopping is triggered
)

# Log the learning rate at the end of each epoch
lr_monitor = LearningRateMonitor(
    logging_interval='epoch',
)

# Logger to save training logs in CSV format
logger = CSVLogger(
    save_dir='logs',
    name='copy_move',
)

In [None]:
print(f"Backbone: {BACKBONE} | Epochs: {NUM_EPOCHS} | Batch: {BATCH_SIZE}")

# Mixed precision only on CUDA; full fp32 otherwise
trainer_precision = '16-mixed' if DEVICE == 'cuda' else 32

trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    accelerator='auto',            # let Lightning choose GPU/CPU
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    logger=logger,
    log_every_n_steps=10,
    precision=trainer_precision,
    gradient_clip_val=1.0,         # guard against exploding gradients
    accumulate_grad_batches=6,     # optimizer step every 6 batches
)

# Train, then report where the best checkpoint was written
trainer.fit(model, data_module)
print(f"✓ Training complete. Best model: {checkpoint_callback.best_model_path}")

In [None]:
import seaborn as sns
import pandas as pd
sns.set()

# CSVLogger writes metrics.csv into the run's log directory
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")

# Preview only the columns that contain at least one value
display(metrics.dropna(axis=1, how="all").head())

# Long format (epoch, metric, value) for seaborn
metrics_melted = metrics.reset_index().melt(id_vars='epoch', var_name='metric', value_name='value')

# Group logged columns into losses and evaluation metrics
metric_keywords = ["precision", "recall", "f1", "iou", "accuracy"]
metric_groups = {
    'Loss': [column for column in metrics.columns if "loss" in column],
    'Metrics': [column for column in metrics.columns if any(word in column for word in metric_keywords)],
}

# One chart per group
for title, metric_list in metric_groups.items():
    group_metrics = metrics_melted.loc[metrics_melted['metric'].isin(metric_list)]

    plt.figure(figsize=(10, 5))
    sns.lineplot(data=group_metrics, x='epoch', y='value', hue='metric')
    plt.title(f'{title} over Epochs', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel(title, fontsize=12)
    plt.grid(True, alpha=0.3)
    # Legend outside the axes, on the right
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

## Inference

In [None]:
# Load the best checkpoint written by ModelCheckpoint into a bare detector.
best_checkpoint = checkpoint_callback.best_model_path
if not best_checkpoint:
    raise RuntimeError("No best checkpoint was saved during training; cannot run inference.")

model = CopyMoveHeatmapDetector(
    backbone=BACKBONE,
    pretrained=False,  # weights come from the checkpoint below
    output_size=(IMAGE_SIZE, IMAGE_SIZE),
    distance_threshold=3,
    learnable_suppression=True
)

checkpoint = torch.load(best_checkpoint, map_location=DEVICE)
# The Lightning module stores the detector under the attribute `model`, so its
# parameters are prefixed with "model.". Strip only that leading prefix
# (str.replace would also rewrite any later "model." occurrence in a key) and
# ignore entries that do not belong to the detector.
prefix = 'model.'
state_dict = {
    key[len(prefix):]: value
    for key, value in checkpoint['state_dict'].items()
    if key.startswith(prefix)
}
model.load_state_dict(state_dict)
model = model.to(DEVICE)
model.eval()

In [None]:
def post_process_mask(mask, kernel_size=5, min_area_percentage=0.01):
    """
    Apply morphological closing and remove small connected components from a binary mask.

    Args:
        mask (np.ndarray): 2-D mask; every non-zero pixel counts as foreground
            (e.g. 0/1 or 0/255 uint8, or bool).
        kernel_size (int | None): Size of the elliptical structuring element used for
            morphological closing. ``None`` or ``<= 0`` skips the closing step.
        min_area_percentage (float | None): Minimum connected-component area as a
            FRACTION of the total image area (``0.01`` keeps components covering at
            least 1% of the pixels). ``None`` or ``<= 0`` skips the filtering step.

    Returns:
        np.ndarray: uint8 mask with values in {0, 1}.
    """
    # Normalise to a 0/1 uint8 image: OpenCV's morphology and connected-component
    # routines require an 8-bit single-channel input.
    processed_mask = (np.asarray(mask) > 0).astype(np.uint8)

    if kernel_size is not None and kernel_size > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        processed_mask = cv2.morphologyEx(processed_mask, cv2.MORPH_CLOSE, kernel)

    if min_area_percentage is not None and min_area_percentage > 0:
        # 8-connectivity; `labels` assigns each pixel its component index (0 = background)
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(processed_mask, 8, cv2.CV_32S)

        min_area_pixels = processed_mask.shape[0] * processed_mask.shape[1] * min_area_percentage

        # Lookup table indexed by label: one vectorised pass over the image
        # instead of one full-image comparison per component.
        keep = stats[:, cv2.CC_STAT_AREA] >= min_area_pixels
        keep[0] = False  # never keep the background label
        processed_mask = keep[labels].astype(np.uint8)

    return processed_mask

In [None]:
# Output locations for test-set predictions
output_dir = Path(OUTPUT_DIR)
masks_dir = output_dir / 'masks'
heatmaps_dir = output_dir / 'heatmaps'

output_dir.mkdir(exist_ok=True, parents=True)
for sub_dir in (masks_dir, heatmaps_dir):
    sub_dir.mkdir(exist_ok=True)

# Same preprocessing as used for validation
transform = get_val_transforms(IMAGE_SIZE)

### Validation Set

In [None]:
# Prepare validation DataLoader
val_loader = data_module.val_dataloader()

# Create output directories for validation results if they don't exist
val_output_dir = Path(OUTPUT_DIR) / 'validation'
val_masks_dir = val_output_dir / 'masks'
val_heatmaps_dir = val_output_dir / 'heatmaps'
val_output_dir.mkdir(exist_ok=True, parents=True)
val_masks_dir.mkdir(exist_ok=True)
val_heatmaps_dir.mkdir(exist_ok=True)

# Prepare transform (using validation transforms)
transform = get_val_transforms(IMAGE_SIZE)

# List to store results for threshold optimization
threshold_results = []

# Probability thresholds to test: 0.00, 0.02, ..., 1.00
thresholds_to_test = np.linspace(0, 1, 51)

threshold_metrics = MetricCollection({
    'precision': BinaryPrecision(),
    'recall': BinaryRecall(),
    'f1': BinaryF1Score(),
    'iou': BinaryJaccardIndex(),
    'accuracy': BinaryAccuracy()
}).to(DEVICE)

print("Optimizing threshold on Validation Set...")

# Run the network ONCE over the validation set and cache the sigmoid
# probabilities; the sweep below only re-binarises the cached maps instead of
# repeating full inference for each of the 51 thresholds.
all_probs = []
all_targets = []
with torch.no_grad():
    for images, masks, _ in tqdm(val_loader, desc="Predicting Validation Set"):
        pred_heatmap_tensor, _ = model(images.to(DEVICE))
        all_probs.append(torch.sigmoid(pred_heatmap_tensor).cpu())
        all_targets.append(masks.long().cpu())

all_probs = torch.cat(all_probs, dim=0).view(-1)
all_targets = torch.cat(all_targets, dim=0).view(-1).to(DEVICE)

for threshold in tqdm(thresholds_to_test, desc="Testing Thresholds"):
    all_preds = (all_probs > threshold).long()
    current_metrics = threshold_metrics(all_preds.to(DEVICE), all_targets)

    threshold_results.append({
        'threshold': threshold,
        'precision': current_metrics['precision'].item(),
        'recall': current_metrics['recall'].item(),
        'f1': current_metrics['f1'].item(),
        'iou': current_metrics['iou'].item(),
        'accuracy': current_metrics['accuracy'].item()
    })
    threshold_metrics.reset()  # Fresh state for the next threshold

# Pick the threshold that maximises IoU (F1 would be another option)
threshold_df = pd.DataFrame(threshold_results)
best_threshold_row = threshold_df.loc[threshold_df['iou'].idxmax()]
best_threshold = best_threshold_row['threshold']

print(f"\nBest Threshold found based on IoU: {best_threshold:.4f}")
print("Metrics at best threshold:")
print(best_threshold_row)

In [None]:
# Validation metrics as a function of the binarisation threshold
metric_columns = ['precision', 'recall', 'f1', 'iou', 'accuracy']

plt.figure(figsize=(12, 6))
ax = plt.gca()
threshold_df.plot(x='threshold', y=metric_columns, ax=ax)

ax.set_title('Validation Metrics vs. Threshold', fontsize=14, fontweight='bold')
ax.set_xlabel('Threshold', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Run inference again using the best threshold to save masks and heatmaps
results = []
display_samples = []
num_display_samples = 12  # Number of samples to display

# Normalisation constants used to undo the validation transform for display
# (ImageNet statistics -- assumed to match get_val_transforms).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

for images, masks, filenames in tqdm(val_loader, desc="Predicting on Validation Set (Best Threshold)"):
    with torch.no_grad():
        images = images.to(DEVICE)
        pred_heatmap_tensor, _ = model(images)
    # The detector outputs logits, but best_threshold was selected on sigmoid
    # probabilities, so convert before thresholding (and saving).
    prob_maps = torch.sigmoid(pred_heatmap_tensor).cpu().numpy()

    for i in range(images.shape[0]):
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        img = np.clip(std * img + mean, 0, 1)  # Denormalise for display

        gt_mask = masks[i].squeeze().cpu().numpy()
        heatmap = prob_maps[i, 0]

        # Match the heatmap to the displayed image size
        original_size = img.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (original_size[1], original_size[0]))
        # 0/255 mask as input for post-processing (which returns 0/1)
        binary_mask = (heatmap_resized > best_threshold).astype(np.uint8) * 255
        processed_mask = post_process_mask(binary_mask, kernel_size=None, min_area_percentage=None)

        if len(display_samples) < num_display_samples:
            display_samples.append({
                'image': img,
                'gt_mask': gt_mask,
                'heatmap': heatmap_resized,
                'predicted_mask': processed_mask,
                'filename': filenames[i]
            })

        filename_stem = Path(filenames[i]).stem
        np.save(val_heatmaps_dir / f"{filename_stem}.npy", heatmap_resized)
        np.save(val_masks_dir / f"{filename_stem}.npy", processed_mask)

        results.append({
            'image_name': filenames[i],
            'has_forgery': processed_mask.sum() > 0,
            # processed_mask is 0/1, so its mean is the forged-pixel fraction
            'forgery_percentage': processed_mask.mean() * 100
        })

results_df = pd.DataFrame(results)
results_df.to_csv(val_output_dir / 'validation_predictions_summary.csv', index=False)

# Use the tuned threshold for test-set inference
THRESHOLD = best_threshold
print(f"Updated global THRESHOLD to: {THRESHOLD:.4f}")
print(f"\n✓ Forged: {results_df['has_forgery'].sum()}, Authentic: {(~results_df['has_forgery']).sum()}")

In [None]:
print(f"✓ Validation Inference complete. Forged: {results_df['has_forgery'].sum()}, Authentic: {(~results_df['has_forgery']).sum()}")

for sample in display_samples:
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))  # GT mask | image with contours | predicted mask

    # Ground-truth mask
    axes[0].imshow(sample['gt_mask'], cmap='gray')
    axes[0].set_title(f"Ground Truth Mask\n{sample['filename']}", fontsize=10)
    axes[0].axis('off')

    # Image with GT (green) and predicted (red) outlines. Both masks are drawn
    # at level 0.5, the boundary between background (0) and any foreground
    # value: post_process_mask returns 0/1 masks, so a level of 128 would never
    # be crossed.
    axes[1].imshow(sample['image'])
    axes[1].contour(sample['gt_mask'], levels=[0.5], colors=['green'], linewidths=1)
    axes[1].contour(sample['predicted_mask'], levels=[0.5], colors=['red'], linewidths=1)
    axes[1].set_title(f"Image with Contours (GT: Green, Pred: Red)\n{sample['filename']}", fontsize=10)
    axes[1].axis('off')

    # Predicted binary mask
    axes[2].imshow(sample['predicted_mask'], cmap='gray')
    axes[2].set_title("Predicted Binary Mask", fontsize=10)
    axes[2].axis('off')

    plt.tight_layout()
plt.show()

### Test Set

In [None]:
# Predict on test set
test_img_dir = Path(TEST_IMAGES_DIR)
print(f"Checking test image directory: {test_img_dir}")
test_image_files = sorted(test_img_dir.glob('*.png'))
print(f"Found {len(test_image_files)} test image files.")
results = []

print(f"Running Inference on {len(test_image_files)} images...")

for img_path in tqdm(test_image_files, desc="Predicting"):
    image = cv2.imread(str(img_path))
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"cv2 could not read test image: {img_path}")
    original_size = image.shape[:2]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    transformed = transform(image=image_rgb)
    image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        heatmap_tensor, _ = model(image_tensor)

    # The detector outputs logits, while THRESHOLD was tuned on sigmoid
    # probabilities: convert before thresholding (the saved heatmap is then in [0, 1]).
    heatmap = torch.sigmoid(heatmap_tensor)[0, 0].cpu().numpy()
    heatmap = cv2.resize(heatmap, (original_size[1], original_size[0]))
    binary_mask = (heatmap > THRESHOLD).astype(np.uint8) * 255  # 0/255 input for post-processing

    # Closing with the default kernel, then drop components smaller than
    # 0.05% of the image area (min_area_percentage is a fraction).
    processed_mask = post_process_mask(binary_mask, min_area_percentage=0.0005)

    np.save(heatmaps_dir / f"{img_path.stem}.npy", heatmap)
    np.save(masks_dir / f"{img_path.stem}.npy", processed_mask)  # 0/1 mask

    results.append({
        'image_name': img_path.name,
        'has_forgery': processed_mask.sum() > 0,
        # processed_mask is 0/1, so its mean is the forged-pixel fraction
        'forgery_percentage': processed_mask.mean() * 100
    })

results_df = pd.DataFrame(results)
results_df.to_csv(output_dir / 'predictions_summary.csv', index=False)

print(f"✓ Inference complete. Forged: {results_df['has_forgery'].sum()}, Authentic: {(~results_df['has_forgery']).sum()}")

In [None]:
# Locations written by the test-inference cell
output_dir = Path(OUTPUT_DIR)
masks_dir = output_dir / 'masks'
heatmaps_dir = output_dir / 'heatmaps'
test_img_dir = Path(TEST_IMAGES_DIR)

# Test images plus whatever predictions were saved
test_image_files = sorted(test_img_dir.glob('*.png'))
saved_mask_files = sorted(masks_dir.glob('*.npy'))
saved_heatmap_files = sorted(heatmaps_dir.glob('*.npy'))

# (image, heatmap, mask) triples, only where both prediction files exist
display_files = []
for img_path in test_image_files:
    heatmap_path, mask_path = (folder / f"{img_path.stem}.npy" for folder in (heatmaps_dir, masks_dir))
    if all(path.exists() for path in (heatmap_path, mask_path)):
        display_files.append((img_path, heatmap_path, mask_path))

In [None]:
# Show up to 5 randomly chosen test predictions
num_samples_to_show = min(5, len(display_files))
random_display_files = random.sample(display_files, num_samples_to_show)

for img_path, heatmap_path, mask_path in random_display_files:
    image = cv2.imread(str(img_path))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Saved predictions
    heatmap_resized = np.load(heatmap_path)
    binary_mask = np.load(mask_path)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))  # image | heatmap | mask
    ax_image, ax_heatmap, ax_mask = axes

    ax_image.imshow(image_rgb)
    ax_image.set_title(f"Original Image\n{img_path.name}", fontsize=10)
    ax_image.axis('off')

    heatmap_display = ax_heatmap.imshow(heatmap_resized, cmap='viridis', vmin=0, vmax=1)
    ax_heatmap.set_title("Predicted Heatmap", fontsize=10)
    ax_heatmap.axis('off')
    fig.colorbar(heatmap_display, ax=ax_heatmap)

    ax_mask.imshow(binary_mask, cmap='gray')
    ax_mask.set_title("Predicted Binary Mask", fontsize=10)
    ax_mask.axis('off')

    plt.tight_layout()
plt.show()

## Submission

You must create a row in the submission file for each image. For images that do not contain a copy-move forgery, predict the string "authentic". For all other images, submit the run-length-encoded mask, serialized with the `rle_encode` function from the competition metric. The file must contain a header and have the following format:

```
case_id,annotation
1,authentic
2,"[123, 4]"
```

In [None]:
def rle_encode(mask, fg_val=1):
    """Convert a mask to RLE using the competition metric format.

    Pixels are scanned in column-major order (the mask is transposed before
    flattening). Runs of pixels equal to ``fg_val`` are emitted as
    ``[start_1, length_1, start_2, length_2, ...]`` with 1-based starts.

    Args:
        mask: Array-like mask; pixels equal to ``fg_val`` are foreground.
        fg_val: Value marking a foreground pixel. Defaults to 1.

    Returns:
        A list of plain Python ``int`` values, empty when there is no
        foreground. Plain ints keep ``repr``/``json.dumps`` output in the
        ``[123, 4]`` form. NumPy 2 integer scalars would instead repr as
        ``np.int64(123)``, and ``json.dumps`` rejects them.
    """
    is_fg = np.asarray(mask).T.flatten() == fg_val
    # Pad with background on both ends so every run has a start and an end transition.
    padded = np.concatenate(([False], is_fg, [False]))
    # Index i is set when flat pixel i differs from pixel i-1 (pixel -1 = background).
    transitions = np.flatnonzero(padded[1:] != padded[:-1])
    starts = transitions[0::2]  # first foreground pixel of each run (0-based)
    ends = transitions[1::2]    # first background pixel after each run (0-based)
    runs = np.column_stack((starts + 1, ends - starts))
    return runs.ravel().tolist()

In [None]:
import json

# Build one submission row per test image from the binary masks saved by inference.
# Assumes inference has been run and its outputs are saved under OUTPUT_DIR.
output_dir = Path(OUTPUT_DIR)
masks_dir = output_dir / 'masks'
test_img_dir = Path(TEST_IMAGES_DIR)  # every test image needs a row, so enumerate them here

# Sorted stems of all test images (file names without the .png extension).
test_image_stems = sorted([img_path.stem for img_path in test_img_dir.glob('*.png')])

submissions = []

for img_stem in tqdm(test_image_stems, desc="Creating submission"):
    mask_path = masks_dir / f"{img_stem}.npy"

    if not mask_path.exists():
        # No saved mask for this image: fall back to predicting "authentic".
        print(f"Warning: No mask found for test image {img_stem}. Assuming authentic.")
        submissions.append({'case_id': img_stem, 'annotation': 'authentic'})
        continue

    mask = np.load(mask_path)
    rle_annotation = rle_encode(mask)
    if rle_annotation:  # at least one pixel is predicted as forged
        # Serialize as plain ints ("[123, 4]"). NumPy integer scalars would repr as
        # "np.int64(123)" under NumPy 2. The enclosing double quotes are CSV quoting,
        # needed because the serialized list itself contains commas.
        encoded = json.dumps([int(v) for v in rle_annotation])
        submissions.append({'case_id': img_stem, 'annotation': f'"{encoded}"'})
    else:  # empty mask -> authentic, written without quotes
        submissions.append({'case_id': img_stem, 'annotation': 'authentic'})

# NOTE: annotations here already carry their CSV quotes. DataFrame.to_csv would
# quote them again, so the CSV must be written with the annotations verbatim.
submission_df = pd.DataFrame(submissions)

In [None]:
from pprint import pprint

def write_submission_csv(data, filename):
    """Write submission rows to a CSV file with a ``case_id,annotation`` header.

    Each row's ``annotation`` is written verbatim. Any CSV quoting it needs
    (e.g. ``"[123, 4]"``) must therefore already be part of the string.
    Routing it through the ``csv`` module would quote it a second time.

    Args:
        data: Iterable of mappings with ``'case_id'`` and ``'annotation'`` keys.
        filename: Destination path (``str`` or path-like); overwritten if it exists.
    """
    lines = ["case_id,annotation\n"]  # Header line
    lines.extend(f"{row['case_id']},{row['annotation']}\n" for row in data)
    # Explicit UTF-8 and newline='' make the output identical on every platform:
    # "\n" line endings and no locale-dependent encoding.
    with open(filename, 'w', encoding='utf-8', newline='') as f:
        f.writelines(lines)


# Write the rows collected in `submissions` (built by the submission-creation loop)
# to submission.csv in the current working directory.
write_submission_csv(data=submissions, filename='submission.csv')

In [None]:
# Preview the first lines of the written submission file (IPython shell escape).
!head submission.csv