# ResNet Encoder-Decoder Model for Copy-Move Forgery Detection

This notebook implements and trains a ResNet-based encoder-decoder model for detecting copy-move manipulated regions in images.

## Model Architecture
- **Encoder**: Pre-trained ResNet50 (ImageNet weights) - provides good feature extraction
- **Decoder**: Custom decoder with transposed convolutions - randomly initialized, needs training
- **Output**: Binary segmentation mask (1 channel)

## Why Training is Needed
Even though the encoder is pre-trained on ImageNet, we still need to train the entire model because:
1. **Decoder is not pre-trained** - it's randomly initialized and needs to learn how to reconstruct masks
2. **Task-specific adaptation** - the encoder needs to adapt from general object recognition to copy-move detection
3. **End-to-end learning** - the encoder and decoder need to work together for this specific task

**Pre-trained = Good starting point, not a finished model!**

## Dataset
- Uses processed data from `data/processed/` (already split into train/val/test 70/15/15)
- Images are 512x512 RGB PNG files
- Masks are binary segmentation masks


## Setup and Imports


In [None]:
import sys
from pathlib import Path
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Locate the project and method roots from the current working directory.
# When the notebook runs from a folder named `notebooks`, the method root is
# its parent and the project root is three levels up; otherwise the working
# directory itself is treated as the project root.
current_dir = Path.cwd()
if current_dir.name == 'notebooks':
    project_root = current_dir.parent.parent.parent
    method_root = current_dir.parent
else:
    project_root = current_dir
    method_root = project_root / 'methods' / 'deep_learning'

# Make the method's `src` directory importable. The membership check keeps
# this cell idempotent: re-running it no longer appends duplicate entries
# to sys.path.
src_dir = str(method_root / 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Import custom modules (resolved through the sys.path entry added above)
from data_loader import get_data_loaders
from model2_resnet import create_resnet_model
from trainer import ModelTrainer
from evaluator import ModelEvaluator

print(f"Project root: {project_root}")
print(f"Method root: {method_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Check for MPS (Apple Silicon GPU); the hasattr guard covers PyTorch builds
# whose torch.backends has no `mps` attribute.
if hasattr(torch.backends, 'mps'):
    print(f"MPS (Apple Silicon) available: {torch.backends.mps.is_available()}")
else:
    print("MPS not available in this PyTorch version")


## Load Configuration


In [None]:
# Load the configuration file that lives under the method root.
config_path = method_root / 'configs' / 'dl_config.json'
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)


def _resolve_data_path(raw_path, root):
    """Anchor a relative config path at ``root``.

    Config paths are like "../../../data/processed/..."; the leading ``..``
    (and ``.``) components are dropped and the remainder is joined onto
    ``root``. Only *leading* components are stripped, so inputs such as
    "./../data" or "a/../b" are not mangled the way counting every "../"
    occurrence in the string would.

    Args:
        raw_path: Path string taken from the config.
        root: ``Path`` the relative remainder is joined onto.

    Returns:
        ``raw_path`` unchanged when it is absolute, otherwise the resolved
        path as a string.
    """
    candidate = Path(raw_path)
    if candidate.is_absolute():
        return raw_path
    parts = list(candidate.parts)
    while parts and parts[0] in ('..', '.'):
        parts.pop(0)
    return str(root.joinpath(*parts))


# Resolve data paths relative to the project root. After the first run every
# entry is absolute, so re-running this cell leaves the paths untouched.
config['data_paths'] = {
    key: _resolve_data_path(raw_path, project_root)
    for key, raw_path in config['data_paths'].items()
}

print("Configuration loaded:")
print(json.dumps(config, indent=2))


## Setup Device and Data Loaders


In [None]:
# Pick the compute device: CUDA (only when the config asks for it), then MPS,
# then CPU.
# NOTE(review): any configured device other than 'cuda' (including 'cpu')
# still selects MPS when it is available — confirm this is intended.
if torch.cuda.is_available() and config['model_settings']['device'] == 'cuda':
    device_name = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device_name = 'mps'
    print("Using MPS (Apple Silicon GPU)")
else:
    device_name = 'cpu'
    print("Using CPU (no GPU acceleration available)")
device = torch.device(device_name)

print(f"Using device: {device}")


def _png_inventory(folder):
    """Return ``(exists, png_count)`` for ``folder``; the count is 0 if it is missing."""
    present = folder.exists()
    return present, (len(list(folder.glob('*.png'))) if present else 0)


# Report, per split, whether the image/mask folders exist and how many PNGs
# each one holds.
print("\nVerifying data paths:")
for split in ['train', 'val', 'test']:
    img_exists, img_count = _png_inventory(Path(config['data_paths'][f'{split}_images']))
    mask_exists, mask_count = _png_inventory(Path(config['data_paths'][f'{split}_masks']))
    img_mark = '✓' if img_exists else '✗'
    mask_mark = '✓' if mask_exists else '✗'
    print(f"  {split}: Images={img_count} ({img_mark}), Masks={mask_count} ({mask_mark})")

# Build the three data loaders from the config
print("\nLoading datasets...")
train_loader, val_loader, test_loader = get_data_loaders(config)

print("\nData loaders created:")
for label, loader in (
    ('Train', train_loader),
    ('Validation', val_loader),
    ('Test', test_loader),
):
    print(f"  {label} batches: {len(loader)}")

# Pull one training batch to inspect shapes and value ranges
sample_images, sample_masks = next(iter(train_loader))
print("\nSample batch shape:")
for label, batch in (('Images', sample_images), ('Masks', sample_masks)):
    print(f"  {label}: {batch.shape}")
for label, batch in (('Image', sample_images), ('Mask', sample_masks)):
    print(f"  {label} value range: [{batch.min():.3f}, {batch.max():.3f}]")


## Create ResNet Encoder-Decoder Model


In [None]:
# Build the ResNet encoder-decoder from the config and move it to the device
model = create_resnet_model(config).to(device)

# Tally parameters in one pass: every parameter counts toward the total,
# and those with requires_grad also count as trainable.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

resnet_cfg = config['model2_resnet']
print(f"Model: {resnet_cfg['name']}")
print(f"Backbone: {resnet_cfg['backbone']}")
print(f"Decoder channels: {resnet_cfg['decoder_channels']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Smoke-test the forward pass on the first two samples of the inspected batch,
# in eval mode and without tracking gradients.
model.eval()
probe_batch = sample_images[:2]
with torch.no_grad():
    test_output = model(probe_batch.to(device))

print("\nTest forward pass:")
print(f"  Input shape: {probe_batch.shape}")
print(f"  Output shape: {test_output.shape}")


## Train Model


In [None]:
# Wrap the model in the project's training helper.
# Later cells look for files prefixed with 'resnet_model2' — presumably the
# trainer writes them under this model_name; verify against ModelTrainer.
trainer = ModelTrainer(model=model, config=config, device=device, model_name='resnet_model2')

# Run training for the number of epochs given in the config
num_epochs = config['model_settings']['num_epochs']
trainer.train(train_loader=train_loader, val_loader=val_loader, num_epochs=num_epochs)


## Evaluate Model on Test Set


In [None]:
# Restore the best checkpoint if one exists; otherwise evaluate the model
# in whatever state it is currently in.
checkpoint_path = method_root / 'outputs' / 'models' / 'resnet_model2_best.pth'
if not checkpoint_path.exists():
    print("Best model checkpoint not found, using current model state")
else:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']}")
    for label, key in (('loss', 'best_val_loss'), ('IoU', 'best_val_iou')):
        print(f"Best validation {label}: {checkpoint[key]:.4f}")

# Output locations come from the config and are joined onto the method root
output_dir = method_root / config['output_paths']['predictions']
results_dir = method_root / config['results_paths']['metrics']

evaluator = ModelEvaluator(model=model, device=device, output_dir=output_dir, results_dir=results_dir)

# Evaluate on the test set using the options from the config's evaluation section
eval_cfg = config['evaluation']
test_metrics = evaluator.evaluate(
    test_loader=test_loader,
    save_predictions=eval_cfg['save_predictions'],
    num_visualizations=eval_cfg['num_visualizations'],
)

# Report the metrics through the evaluator's own printer
evaluator.print_metrics(test_metrics)


## Training History Visualization


In [None]:
import matplotlib.pyplot as plt

# Load the saved training history (if present) and plot loss and IoU curves
# for the training and validation splits side by side.
history_path = method_root / 'outputs' / 'models' / 'resnet_model2_history.json'
if history_path.exists():
    with open(history_path, 'r') as f:
        history = json.load(f)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One panel per metric: (history key suffix, display label)
    panels = (
        ('loss', 'Loss'),
        ('iou', 'IoU'),
    )
    for ax, (key, label) in zip(axes, panels):
        ax.plot(history[f'train_{key}'], label=f'Train {label}')
        ax.plot(history[f'val_{key}'], label=f'Validation {label}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.set_title(f'Training and Validation {label}')
        ax.legend()
        ax.grid(True)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target
    # folder exists before writing the figure.
    figure_path = method_root / 'outputs' / 'visualizations' / 'training_history_resnet.png'
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(figure_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Training completed in {len(history['train_loss'])} epochs")
else:
    print("Training history not found")


## Summary

Training and evaluation of the ResNet encoder-decoder model are complete. Results are saved to the following locations (relative to the method root directory):
- **Models**: `outputs/models/`
- **Predictions**: `outputs/predictions/`
- **Visualizations**: `outputs/visualizations/`
- **Metrics**: `results/metrics/`
