# 🧠 Lightweight Mask R-CNN for Scientific Image Forgery Detection

# 🧾 Summary

This notebook implements an end-to-end training pipeline for scientific image forgery detection, using a PyTorch (torchvision) Mask R-CNN with a MobileNetV3-Small backbone to keep training and inference lightweight.

Data Analysis & Visualization

The dataset consists of authentic and forged scientific images, plus NumPy (`.npy`) masks that mark the forged regions.

Functions analyze folder structure, image size distributions, and visualize samples alongside corresponding masks.

Dataset Preparation

A custom `ForgeryDataset` class loads each image and its mask, checks that their shapes match, and converts every sufficiently large forged region into a bounding box and an instance mask.

It supports both authentic and forged samples, automatically creating appropriate targets for Mask R-CNN training.

Model Definition

A lightweight Mask R-CNN model is built using MobileNetV3-Small as the backbone for efficiency.

The architecture uses a single-feature-map anchor generator (sizes 16–128 px, aspect ratios 0.5, 1.0 and 2.0) and RoIAlign poolers for the box (7×7) and mask (14×14) heads.

Training Pipeline

The `Trainer` class manages training and validation using the AdamW optimizer and a StepLR learning-rate scheduler.

The training loop records the average training and validation loss for each epoch and plots the loss curves.

Data augmentations (horizontal/vertical flips, 90° rotations, Gaussian blur, brightness/contrast adjustment) are applied with Albumentations.

Visualization and Evaluation

Sample batches are visualized to confirm correct image-mask alignment.

Training and validation loss curves demonstrate learning progression.

In [None]:
# Scientific Image Forgery Detection Pipeline
# Complete implementation with Mask R-CNN

import os
import cv2
import json
import torch
import torchvision
import numpy as np
import pandas as pd
import torch.nn as nn
import albumentations as A
import matplotlib.pyplot as plt
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm
from collections import defaultdict
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import MaskRCNN
from sklearn.model_selection import train_test_split
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F_transforms

import warnings
warnings.filterwarnings('ignore')

# Pick the compute device once: CUDA when a GPU is visible, otherwise CPU.
# The classes and helpers below use this module-level `device` as their default.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class ScientificForgeryPipeline:
    """Exploratory data analysis for the scientific image forgery dataset.

    ``base_path`` is expected to contain ``train_images/authentic``,
    ``train_images/forged``, ``train_masks`` and ``test_images``.
    Call :meth:`analyze_data_structure` first: it fills ``self.paths``,
    which the ``visualize_*`` methods read.
    """

    # Lower-case file suffixes treated as images by get_unique_sizes().
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, base_path):
        self.base_path = base_path
        self.device = device
        self.model = None
        self.paths = {}  # filled by analyze_data_structure()

    def _require_paths(self):
        """Raise RuntimeError if analyze_data_structure() has not populated self.paths."""
        if not self.paths:
            raise RuntimeError(
                "Dataset paths are unknown; call analyze_data_structure() first."
            )

    def analyze_data_structure(self):
        """Print file counts per dataset folder and the format of one sample mask.

        Returns:
            dict: Folder paths keyed by 'train_authentic', 'train_forged',
            'train_masks' and 'test_images' (also stored in ``self.paths``).
        """
        base_path = self.base_path

        train_authentic_path = os.path.join(base_path, 'train_images/authentic')
        train_forged_path = os.path.join(base_path, 'train_images/forged')
        train_masks_path = os.path.join(base_path, 'train_masks')
        test_images_path = os.path.join(base_path, 'test_images')

        print("üìä Dataset Analysis:")
        print(f"‚úÖ Authentic images: {len(os.listdir(train_authentic_path))}")
        print(f"‚ùå Forged images: {len(os.listdir(train_forged_path))}")
        print(f"üé≠ Masks: {len(os.listdir(train_masks_path))}")
        print(f"üß™ Test images: {len(os.listdir(test_images_path))}")

        # Only .npy files can be inspected with np.load; sorting makes the
        # printed examples reproducible (os.listdir order is arbitrary).
        mask_files = sorted(
            f for f in os.listdir(train_masks_path) if f.endswith('.npy')
        )[:3]
        print(f"\nüîç Mask examples: {mask_files}")

        if mask_files:
            sample_mask = np.load(os.path.join(train_masks_path, mask_files[0]))
            print(f"üìê Mask shape: {sample_mask.shape}, dtype: {sample_mask.dtype}")

        self.paths = {
            'train_authentic': train_authentic_path,
            'train_forged': train_forged_path,
            'train_masks': train_masks_path,
            'test_images': test_images_path
        }

        return self.paths

    def get_unique_sizes(self, directory, extensions=None):
        """Count images per (width, height) under *directory*, searched recursively.

        Args:
            directory: Root folder passed to ``os.walk``.
            extensions: Suffixes to include (matched case-insensitively);
                defaults to ``IMAGE_EXTENSIONS``.

        Returns:
            defaultdict(int) mapping ``(width, height)`` to an image count.
            Files that cannot be opened are reported with a message and skipped.
        """
        if extensions is None:
            extensions = self.IMAGE_EXTENSIONS
        # The file name is lower-cased before matching, so the suffixes must be too.
        extensions = tuple(ext.lower() for ext in extensions)

        size_counts = defaultdict(int)
        for root, _, files in os.walk(directory):
            for file in files:
                if not file.lower().endswith(extensions):
                    continue
                try:
                    with Image.open(os.path.join(root, file)) as img:
                        size_counts[img.size] += 1
                except Exception as e:  # best-effort scan: report and keep going
                    print(f"Error {file}: {e}")

        return size_counts

    def visualize_size_distribution(self):
        """Print the five most common sizes per folder and plot width/height histograms.

        Raises:
            RuntimeError: If analyze_data_structure() has not been called.
        """
        self._require_paths()
        folders = [
            self.paths['train_authentic'],
            self.paths['train_forged'],
            self.paths['test_images']
        ]

        all_sizes = []

        for folder in folders:
            print(f"\nüìÇ Folder: {folder}")
            sizes = self.get_unique_sizes(folder)

            if not sizes:
                print("No images or mistake in code")
                continue

            sorted_sizes = sorted(sizes.items(), key=lambda x: x[1], reverse=True)

            print("‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê")
            print("‚îÇ  Width (px)  ‚îÇ Height (px) ‚îÇ Quantity ‚îÇ")
            print("‚îú‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î§")
            for (w, h), count in sorted_sizes[:5]:  # Show top 5
                print(f"‚îÇ {w:<13} ‚îÇ {h:<13} ‚îÇ {count:<7} ‚îÇ")
            print("‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò")

            # Expand counts so describe()/hist() see one row per image.
            for (w, h), count in sizes.items():
                all_sizes.extend([(w, h)] * count)

        if all_sizes:
            sizes_df = pd.DataFrame(all_sizes, columns=['width', 'height'])
            print("\nüìà Overall Image Size Statistics:")
            print(sizes_df.describe())

            plt.figure(figsize=(12, 4))
            plt.subplot(1, 2, 1)
            plt.hist(sizes_df['width'], bins=20, alpha=0.7, color='blue')
            plt.title('Width Distribution')
            plt.xlabel('Width (px)')
            plt.ylabel('Frequency')

            plt.subplot(1, 2, 2)
            plt.hist(sizes_df['height'], bins=20, alpha=0.7, color='red')
            plt.title('Height Distribution')
            plt.xlabel('Height (px)')
            plt.ylabel('Frequency')

            plt.tight_layout()
            plt.show()

    def visualize_samples(self, num_samples=3):
        """Plot sample authentic images, forged images and the forged images' masks.

        Row 0 shows the first *num_samples* authentic images (sorted by name).
        Row 1 shows the first *num_samples* forged images. Row 2 shows the mask
        of each forged image in row 1, looked up as ``<stem>.npy`` in the masks
        folder (the naming ForgeryDataset uses). Missing masks leave a blank cell.

        Raises:
            RuntimeError: If analyze_data_structure() has not been called.
            ValueError: If num_samples < 1.
        """
        self._require_paths()
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        authentic_dir = self.paths['train_authentic']
        forged_dir = self.paths['train_forged']
        masks_dir = self.paths['train_masks']
        authentic_files = sorted(os.listdir(authentic_dir))[:num_samples]
        forged_files = sorted(os.listdir(forged_dir))[:num_samples]

        # squeeze=False keeps `axes` 2-D, so axes[row, col] also works for num_samples == 1.
        _, axes = plt.subplots(3, num_samples, figsize=(15, 10), squeeze=False)
        for ax in axes.flat:
            ax.axis('off')  # cells without a sample stay blank

        for i, file in enumerate(authentic_files):
            # convert() returns a loaded copy, so the file handle can be closed right away.
            with Image.open(os.path.join(authentic_dir, file)) as img:
                axes[0, i].imshow(img.convert('RGB'))
            axes[0, i].set_title(f'Authentic: {file}')

        for i, file in enumerate(forged_files):
            with Image.open(os.path.join(forged_dir, file)) as img:
                axes[1, i].imshow(img.convert('RGB'))
            axes[1, i].set_title(f'Forged: {file}')

            mask_name = f"{os.path.splitext(file)[0]}.npy"
            mask_path = os.path.join(masks_dir, mask_name)
            if not os.path.exists(mask_path):
                continue
            mask = np.squeeze(np.load(mask_path))
            if mask.ndim == 3:
                # Merge per-region channels into one map, using the same
                # channels-first/last rule as ForgeryDataset.
                mask = np.any(mask, axis=0 if mask.shape[0] <= 10 else -1)
            axes[2, i].imshow(mask.astype(np.uint8), cmap='gray')
            axes[2, i].set_title(f'Mask: {mask_name}')

        plt.tight_layout()
        plt.show()

class ForgeryDataset(Dataset):
    """Mask R-CNN dataset of authentic and forged scientific images.

    Every visible file in *authentic_path* and *forged_path* becomes one sample.
    Hidden files such as ``.DS_Store`` and sub-directories are skipped. Files
    are visited in sorted order, so indices are reproducible. A sample's mask is
    looked up as ``<masks_path>/<image stem>.npy``.

    A forged sample with a non-empty mask is split into regions. Each connected
    region whose bounding box is wider and taller than *min_region_size* pixels
    becomes one instance with label 1. All other samples get an empty target.

    Args:
        authentic_path: Folder of authentic images.
        forged_path: Folder of forged images.
        masks_path: Folder of ``.npy`` masks named after the image stem.
        transform: Optional callable used as ``transform(image=..., mask=...)``
            returning a dict with tensor ``'image'`` and ``'mask'`` entries
            (e.g. an Albumentations pipeline ending in ToTensorV2).
        is_train: Stored on the instance; not read by this class.
        min_region_size: Regions must be strictly larger than this (pixels)
            in both width and height to be kept.
    """

    def __init__(self, authentic_path, forged_path, masks_path, transform=None,
                 is_train=True, min_region_size=5):
        self.transform = transform
        self.is_train = is_train
        self.min_region_size = min_region_size

        # Authentic samples first, then forged ones.
        self.samples = (
            self._collect_samples(authentic_path, masks_path, is_forged=False)
            + self._collect_samples(forged_path, masks_path, is_forged=True)
        )

    @staticmethod
    def _collect_samples(image_dir, masks_path, is_forged):
        """Return one sample record per visible file in *image_dir*, sorted by name."""
        samples = []
        for file in sorted(os.listdir(image_dir)):
            img_path = os.path.join(image_dir, file)
            if file.startswith('.') or not os.path.isfile(img_path):
                continue
            # splitext keeps multi-dot names intact ("fig.v2.png" -> "fig.v2");
            # split('.')[0] would truncate them to "fig" and miss the mask.
            base_name = os.path.splitext(file)[0]
            samples.append({
                'image_path': img_path,
                'mask_path': os.path.join(masks_path, f"{base_name}.npy"),
                'is_forged': is_forged,
                'image_id': base_name,
            })
        return samples

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_mask(mask_path, image_hw):
        """Load *mask_path* as an (H, W) uint8 array of 0/1 values.

        If the file is missing, return zeros of shape *image_hw*. A 3-D mask is
        collapsed with a logical OR over its channel axis. That axis is taken to
        be axis 0 when it has <= 10 entries, otherwise the last axis.
        """
        if not os.path.exists(mask_path):
            return np.zeros(image_hw, dtype=np.uint8)

        mask = np.load(mask_path)
        if mask.ndim == 3:
            if mask.shape[0] <= 10:  # channels first (C, H, W)
                mask = np.any(mask, axis=0)
            elif mask.shape[-1] <= 10:  # channels last (H, W, C)
                mask = np.any(mask, axis=-1)
            else:
                raise ValueError(f"Ambiguous 3D mask shape: {mask.shape}")
        return (mask > 0).astype(np.uint8)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample *idx* in torchvision detection format.

        Raises:
            ValueError: If the mask's (H, W) differs from the image's, or the
                mask has an ambiguous 3-D shape.
        """
        sample = self.samples[idx]

        # convert() returns a loaded copy, so the file handle is closed on exit.
        with Image.open(sample['image_path']) as pil_image:
            image = np.array(pil_image.convert('RGB'))

        mask = self._load_mask(sample['mask_path'], image.shape[:2])

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if image.shape[:2] != mask.shape:
            raise ValueError(f"Shape mismatch: img {image.shape}, mask {mask.shape}")

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        else:
            image = F_transforms.to_tensor(image)
            mask = torch.tensor(mask, dtype=torch.uint8)

        if sample['is_forged'] and mask.sum() > 0:
            boxes, labels, masks = self.mask_to_boxes(mask)

            target = {
                'boxes': boxes,
                'labels': labels,
                'masks': masks,
                'image_id': torch.tensor([idx]),
                'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64)
            }
        else:
            # Authentic images, and forged images whose mask is empty or
            # missing, become background-only samples.
            target = {
                'boxes': torch.zeros((0, 4), dtype=torch.float32),
                'labels': torch.zeros(0, dtype=torch.int64),
                'masks': torch.zeros((0, image.shape[1], image.shape[2]), dtype=torch.uint8),
                'image_id': torch.tensor([idx]),
                'area': torch.zeros(0, dtype=torch.float32),
                'iscrowd': torch.zeros((0,), dtype=torch.int64)
            }

        return image, target

    def mask_to_boxes(self, mask):
        """Split a binary mask into per-region instances for Mask R-CNN.

        Args:
            mask: (H, W) array or tensor; any non-zero pixel counts as forged.

        Returns:
            ``(boxes, labels, masks)``:
            - float32 ``[N, 4]`` boxes as ``[x_min, y_min, x_max, y_max]``;
            - int64 ``[N]`` labels, all 1;
            - uint8 ``[N, H, W]`` instance masks.
            N may be 0 when every region is filtered out.
        """
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        # Binarise into a fresh C-contiguous uint8 array, as cv2.findContours expects.
        mask_np = np.ascontiguousarray(np.asarray(mask) > 0, dtype=np.uint8)
        min_size = self.min_region_size

        # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
        # OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        boxes = []
        masks = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w <= min_size or h <= min_size:
                continue  # drop tiny regions
            region = np.zeros_like(mask_np)
            cv2.drawContours(region, [contour], -1, 1, thickness=cv2.FILLED)
            # A filled external contour also covers holes inside the region;
            # AND-ing with the source mask keeps only genuinely forged pixels.
            region &= mask_np
            boxes.append([x, y, x + w, y + h])
            masks.append(region)

        if boxes:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            masks = torch.tensor(np.array(masks), dtype=torch.uint8)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)
            masks = torch.zeros((0, mask_np.shape[0], mask_np.shape[1]), dtype=torch.uint8)

        return boxes, labels, masks

class MaskRCNNModel:
    """Factory for a lightweight Mask R-CNN with a MobileNetV3-Small backbone.

    Args:
        num_classes: Number of classes including background
            (2 = background + forged region).
        device: Device the model is moved to by initialize_model().
    """

    def __init__(self, num_classes=2, device=device):
        self.device = device
        self.model = None
        self.num_classes = num_classes

    @staticmethod
    def _mobilenet_features():
        """Return ImageNet-pretrained MobileNetV3-Small feature layers (576 output channels)."""
        try:
            weights = torchvision.models.MobileNet_V3_Small_Weights.DEFAULT
        except AttributeError:
            # torchvision < 0.13 has no Weights enums; fall back to the legacy flag.
            return torchvision.models.mobilenet_v3_small(pretrained=True).features
        return torchvision.models.mobilenet_v3_small(weights=weights).features

    def create_light_mask_rcnn(self, image_mean=(0.0, 0.0, 0.0), image_std=(1.0, 1.0, 1.0)):
        """Build (but do not move) a lightweight Mask R-CNN.

        Args:
            image_mean, image_std: Per-channel normalisation applied inside the
                model. The defaults are the identity because create_transforms()
                already applies ImageNet ``Normalize``. MaskRCNN's own default
                would normalise the images a second time. Pass ImageNet
                statistics here if you feed raw [0, 1] tensors instead.

        Returns:
            torchvision ``MaskRCNN`` with ``self.num_classes`` classes.
        """
        # 1x1 conv reduces the 576 MobileNetV3-Small channels to 256 for the heads.
        backbone = nn.Sequential(
            self._mobilenet_features(),
            nn.Conv2d(576, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        backbone.out_channels = 256  # attribute MaskRCNN reads to size its heads

        # One feature map, so a single tuple of anchor sizes / aspect ratios.
        anchor_generator = AnchorGenerator(
            sizes=((16, 32, 64, 128),),
            aspect_ratios=((0.5, 1.0, 2.0),)
        )

        # The plain Sequential backbone returns one tensor, which torchvision exposes as feature map '0'.
        roi_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=7,
            sampling_ratio=2
        )

        mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=14,
            sampling_ratio=2
        )

        model = MaskRCNN(
            backbone,
            num_classes=self.num_classes,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
            mask_roi_pool=mask_roi_pooler,
            min_size=256,
            max_size=256,
            image_mean=list(image_mean),
            image_std=list(image_std),
            rpn_pre_nms_top_n_train=1000,
            rpn_pre_nms_top_n_test=1000,
            rpn_post_nms_top_n_train=200,
            rpn_post_nms_top_n_test=200,
            box_detections_per_img=100
        )

        return model

    def initialize_model(self):
        """Create the model, move it to ``self.device``, store it and return it."""
        self.model = self.create_light_mask_rcnn()
        self.model.to(self.device)

        print(f"‚úÖ Model initialized with {sum(p.numel() for p in self.model.parameters()):,} parameters")
        return self.model

class Trainer:
    """Train and validate a torchvision detection model with AdamW + StepLR.

    The model must return a dict of losses when called as
    ``model(images, targets)`` in train mode, as torchvision's Mask R-CNN does.
    Loaders yield ``(images, targets)``: sequences with one tensor / one dict of
    tensors per image.

    Args:
        model: Model to optimise (already on *device*).
        train_loader, val_loader: Batch iterables; ``len()`` is used for averaging.
        device: Device that images and targets are moved to.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        step_size, gamma: StepLR schedule; the LR is multiplied by *gamma*
            every *step_size* epochs.
    """

    # Layers whose running statistics must not be updated by validation data.
    _BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def __init__(self, model, train_loader, val_loader, device, learning_rate=0.001,
                 weight_decay=0.0001, step_size=5, gamma=0.1):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)

        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def _freeze_batchnorm(model):
        """Switch every BatchNorm layer of *model* to eval mode; other layers are untouched."""
        for module in model.modules():
            if isinstance(module, Trainer._BATCHNORM_TYPES):
                module.eval()

    def _to_device(self, images, targets):
        """Move a batch of images and target dicts to ``self.device``."""
        images = [img.to(self.device) for img in images]
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
        return images, targets

    def train_epoch(self):
        """Run one optimisation pass over train_loader; record and return the mean loss."""
        self.model.train()
        total_loss = 0.0

        for images, targets in tqdm(self.train_loader, desc="üöÄ Training"):
            images, targets = self._to_device(images, targets)

            loss_dict = self.model(images, targets)
            losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            total_loss += losses.item()

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_loss = total_loss / max(len(self.train_loader), 1)
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate_epoch(self):
        """Compute the mean loss over val_loader without updating weights; record and return it."""
        # torchvision detection models only return losses in train mode, so the
        # model stays in train mode. BatchNorm layers are switched to eval so that
        # validation batches do not overwrite running statistics learned from
        # training data; train_epoch() calls model.train() again next epoch.
        self.model.train()
        self._freeze_batchnorm(self.model)
        total_loss = 0.0

        with torch.no_grad():
            for images, targets in tqdm(self.val_loader, desc="üß™ Validating"):
                images, targets = self._to_device(images, targets)

                loss_dict = self.model(images, targets)
                total_loss += sum(loss_dict.values()).item()

        avg_loss = total_loss / max(len(self.val_loader), 1)
        self.val_losses.append(avg_loss)
        return avg_loss

    def train(self, epochs=10):
        """Train and validate for *epochs* epochs, stepping the LR scheduler each epoch."""
        print(f"üéØ Starting training for {epochs} epochs...")

        for epoch in range(epochs):
            print(f"\nüìç Epoch {epoch+1}/{epochs}")

            train_loss = self.train_epoch()
            val_loss = self.validate_epoch()
            self.scheduler.step()

            print(f"üìä Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        self.plot_training_history()

    def plot_training_history(self):
        """Plot the recorded per-epoch training and validation losses."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Training Loss', marker='o')
        plt.plot(self.val_losses, label='Validation Loss', marker='s')
        plt.title('Training History')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.show()

def create_transforms():
    """Return ``(train_transform, val_transform)`` Albumentations pipelines.

    Both pipelines resize to 256x256, apply ImageNet mean/std normalisation and
    convert to CHW tensors. The training pipeline additionally applies random
    horizontal/vertical flips, 90-degree rotations, Gaussian blur and
    brightness/contrast changes between the resize and the normalisation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def build(*augmentations):
        # Resize first, then augment, then normalise and convert to a tensor.
        return A.Compose([
            A.Resize(256, 256),
            *augmentations,
            A.Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2(),
        ])

    train_augmentations = (
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.GaussianBlur(blur_limit=3, p=0.3),
        A.RandomBrightnessContrast(p=0.3),
    )

    return build(*train_augmentations), build()

def visualize_batch_samples(dataloader, model=None, device=device, score_threshold=0.5):
    """Plot up to four samples from the first batch of *dataloader*.

    Row 0 shows the images, de-normalised with the ImageNet mean/std used in
    create_transforms(). Row 1 shows the union of each image's ground-truth
    instance masks. Row 2 appears only when *model* is given: it shows the union
    of the predicted masks with score >= *score_threshold*, binarised at 0.5.

    Args:
        dataloader: Yields ``(images, targets)``; images are CHW tensors and
            each target dict contains a ``'masks'`` tensor of shape [N, H, W].
        model: Optional torchvision detection model. It is run in eval mode on
            *device* and then restored to its previous train/eval mode.
        device: Device the images are moved to for prediction.
        score_threshold: Minimum detection score for a predicted mask to be drawn.
    """
    images, targets = next(iter(dataloader))
    num_cols = 4
    num_shown = min(num_cols, len(images))

    predictions = None
    if model is not None:
        was_training = model.training
        model.eval()
        with torch.no_grad():
            predictions = model([img.to(device) for img in images[:num_shown]])
        model.train(was_training)

    num_rows = 2 if predictions is None else 3
    # squeeze=False keeps `axes` 2-D regardless of the number of rows.
    _, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5 * num_rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # cells beyond the batch size stay blank

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for i in range(num_shown):
        height, width = images[i].shape[-2:]

        img = images[i].cpu().permute(1, 2, 0).numpy() * std + mean
        axes[0, i].imshow(np.clip(img, 0, 1))
        axes[0, i].set_title(f'Image {i+1}')

        gt_masks = targets[i]['masks'].cpu()
        if len(gt_masks):
            gt_union = gt_masks.amax(dim=0).float()
        else:
            gt_union = torch.zeros((height, width))
        axes[1, i].imshow(gt_union.numpy(), cmap='hot', alpha=0.7)
        axes[1, i].set_title(f'Ground Truth Mask {i+1}')

        if predictions is not None:
            pred = predictions[i]
            # torchvision returns masks as [K, 1, H, W] probabilities in [0, 1].
            pred_masks = pred['masks'][pred['scores'] >= score_threshold].cpu()
            if len(pred_masks):
                pred_union = (pred_masks[:, 0] > 0.5).any(dim=0).float()
            else:
                pred_union = torch.zeros((height, width))
            axes[2, i].imshow(pred_union.numpy(), cmap='hot', alpha=0.7)
            axes[2, i].set_title(f'Predicted Mask {i+1}')

    plt.tight_layout()
    plt.show()

# Main execution
if __name__ == "__main__":
    import copy

    def collate_fn(batch):
        """Transpose ``[(image, target), ...]`` into ``(images, targets)`` tuples.

        Detection targets differ in size per image, so they cannot be stacked
        by the default collate function.
        """
        return tuple(zip(*batch))

    # Initialize pipeline
    base_path = '/kaggle/input/recodai-luc-scientific-image-forgery-detection'
    pipeline = ScientificForgeryPipeline(base_path)

    # Analyze data
    paths = pipeline.analyze_data_structure()
    pipeline.visualize_size_distribution()
    pipeline.visualize_samples(num_samples=3)

    # Create datasets and dataloaders
    train_transform, val_transform = create_transforms()

    train_base = ForgeryDataset(
        paths['train_authentic'],
        paths['train_forged'],
        paths['train_masks'],
        transform=train_transform
    )
    # Subsets (from random_split or Subset) share their parent dataset, so
    # setting `val_dataset.dataset.transform` would also replace the training
    # transform. A shallow copy gives validation its own transform attribute
    # while reusing the same sample list, so the same indices are valid.
    val_base = copy.copy(train_base)
    val_base.transform = val_transform
    val_base.is_train = False

    # Reproducible 80/20 split. It is stratified on authentic vs forged when
    # each class has at least two samples, as train_test_split requires.
    labels = [int(sample['is_forged']) for sample in train_base.samples]
    can_stratify = min(labels.count(0), labels.count(1)) >= 2
    train_idx, val_idx = train_test_split(
        list(range(len(train_base))),
        test_size=0.2,
        random_state=42,
        stratify=labels if can_stratify else None,
    )
    train_dataset = torch.utils.data.Subset(train_base, train_idx)
    val_dataset = torch.utils.data.Subset(val_base, val_idx)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    print(f"üì¶ Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    # Visualize samples
    visualize_batch_samples(train_loader)

    # Initialize model
    model_handler = MaskRCNNModel(device=device)
    model = model_handler.initialize_model()

    # Train model
    trainer = Trainer(model, train_loader, val_loader, device, learning_rate=0.001)
    trainer.train(epochs=10)

    print("üéâ Training completed!")

# 🧩 Conclusion

This notebook successfully establishes an end-to-end framework for scientific image forgery detection.
It integrates data exploration, preprocessing, augmentation, model construction, and training in a cohesive pipeline.
The lightweight MobileNetV3 + Mask R-CNN architecture is designed to balance accuracy and computational cost, which should make it suitable for large-scale or resource-constrained environments (accuracy has not yet been measured with a mask metric).

Future improvements could include:

Fine-tuning hyperparameters and augmentations.

Incorporating more advanced backbones (e.g., Swin Transformer).

Adding metrics like IoU or F1-score for mask evaluation.

Overall, this project provides a strong foundation for automated forgery localization and authenticity verification in scientific imagery.