# Hybrid Forgery Training + Evaluation
This notebook lets you tune every training hyperparameter, toggle ablation flags, and launch `run_training` with tqdm progress bars. After training, load a checkpoint (together with the training configuration returned by `load_model_from_checkpoint`), compute metrics (Dice/IoU/precision/recall/F1 plus a confusion matrix), and visualize 10 qualitative test samples with image / ground-truth / prediction / overlay columns.

In [None]:
%load_ext autoreload
%autoreload 2

from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np

try:
    import pandas as pd
except ImportError:
    pd = None

from train import TrainConfig, run_training
from model.hybrid_forgery_detector import HybridForgeryConfig
from evaluation.eval_utils import (
    collect_visual_samples,
    evaluate_split,
    load_model_from_checkpoint,
 )

In [None]:
# Snapshot the library defaults so the overrides in the next cells can be compared against them.
default_train_cfg = TrainConfig()
default_model_cfg = HybridForgeryConfig()
for _heading, _cfg in (
    ("Default TrainConfig:", default_train_cfg),
    ("Default HybridForgeryConfig:", default_model_cfg),
):
    print(_heading)
    pprint(asdict(_cfg))

In [None]:
# Training overrides, grouped by concern. Any field not listed keeps its TrainConfig default.
train_config = TrainConfig(
    # --- data ---
    prepared_root="prepared/CASIA2",
    train_split="train",
    val_split="val",
    target_size=128,
    # --- optimisation ---
    batch_size=8,
    num_epochs=10,
    learning_rate=1e-4,
    weight_decay=1e-2,
    grad_clip_norm=1.0,
    use_amp=True,
    # --- logging & checkpointing ---
    log_interval=10,
    checkpoint_dir="checkpoints",
    checkpoint_interval=1,
    save_best_only=True,
    resume_from=None,
)

In [None]:
# Ablation switches for the hybrid detector: flip the use_* flags to compare variants.
hybrid_model_config = HybridForgeryConfig(
    # backbone toggles
    use_efficientnet=True,
    use_swin=True,
    use_segformer=False,
    # decoder toggles
    use_unet_decoder=True,
    use_skip_connections=True,
    # other model settings
    pretrained_backbones=True,
    fused_channels=256,
)
train_config.model_config = hybrid_model_config

In [None]:
# Show the device the trainer will use, then the effective (overridden) configs.
print("Resolved device:", train_config.resolved_device())
for _title, _config in (
    ("TrainConfig overrides:", train_config),
    ("HybridForgeryConfig overrides:", train_config.model_config),
):
    print(_title)
    pprint(asdict(_config))

### Optional Dry Run
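Set the flag in the next cell to `True` to run a quick 1-epoch sanity check before kicking off the full training run. The check uses one training batch and one validation batch (via `max_train_batches` / `max_val_batches`), with the batch size capped at 2. Dry-run checkpoints are written to a `dry_run/` subdirectory of the configured checkpoint directory.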

In [None]:
ENABLE_DRY_RUN = False  # flip to True to quickly verify data -> model -> optimizer plumbing
if not ENABLE_DRY_RUN:
    print("Dry run skipped. Set ENABLE_DRY_RUN = True to execute the smoke test.")
else:
    # Work on a copy so the full-run train_config is left untouched.
    dry_run_config = deepcopy(train_config)
    # Shrink the run to a single tiny epoch and isolate its checkpoints in a subdirectory.
    dry_run_overrides = {
        "num_epochs": 1,
        "batch_size": min(2, train_config.batch_size),
        "max_train_batches": 1,
        "max_val_batches": 1,
        "checkpoint_dir": str(Path(train_config.checkpoint_dir) / "dry_run"),
    }
    for _override_name, _override_value in dry_run_overrides.items():
        setattr(dry_run_config, _override_name, _override_value)
    print("Dry run settings:", {name: getattr(dry_run_config, name) for name in dry_run_overrides})
    dry_run_history = run_training(dry_run_config)

In [None]:
# Full training run. Set RUN_TRAINING = False to skip retraining on "Restart & Run All"
# and go straight to evaluating the checkpoint already on disk (history is then None).
RUN_TRAINING = True
if RUN_TRAINING:
    history = run_training(train_config)
else:
    history = None
    print("Training skipped. Set RUN_TRAINING = True to train with train_config.")
history

## Evaluation, Samples, and Ablations
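Use the helpers below to load a checkpoint, compute aggregate metrics and a confusion matrix on any split, and visualize 10 qualitative samples from that split as image / ground-truth / prediction / overlay columns. Each evaluation is also recorded in `ablation_results`. That list persists across re-runs of the evaluation cell, so you can compare ablation variants side by side.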

In [None]:
# Evaluate the best checkpoint from the training run. The directory follows
# train_config.checkpoint_dir so changing it in the config cell keeps this path in sync.
BEST_CHECKPOINT_NAME = "best.pt"  # assumes run_training saves the best model under this name — TODO confirm
checkpoint_path = Path(train_config.checkpoint_dir) / BEST_CHECKPOINT_NAME
evaluation_split = "test"
# Row label in the ablation table; set something unique per variant (e.g. "no_swin")
# when comparing ablations, otherwise every row is labelled after the checkpoint file.
ablation_label = checkpoint_path.stem
eval_device = train_config.resolved_device()
max_eval_batches = None  # small int for a quick partial evaluation; None evaluates the whole split
checkpoint_path

In [None]:
# Load the checkpoint; the config returned alongside the model (not the notebook's
# train_config) is what gets passed to the evaluation helper below.
model, trained_config = load_model_from_checkpoint(checkpoint_path, device=eval_device)
evaluation_summary = evaluate_split(
    model=model,
    train_config=trained_config,
    split=evaluation_split,
    batch_size=trained_config.batch_size,
    device=eval_device,
    max_batches=max_eval_batches,
)

print("Aggregate metrics:")
pprint(evaluation_summary.metrics)
print("\nConfusion matrix (rows=actual clean/tampered, cols=predicted clean/tampered):")
if pd is not None:
    display(pd.DataFrame(
        evaluation_summary.confusion_matrix,
        index=["Actual clean", "Actual tampered"],
        columns=["Pred clean", "Pred tampered"],
    ))
else:
    print(evaluation_summary.confusion_matrix)

# ablation_results deliberately survives re-runs of this cell so that evaluations of
# different checkpoints / ablation variants accumulate into one comparison table.
if "ablation_results" not in globals():
    ablation_results = []

# Record split and checkpoint path too, so rows that share a label stay distinguishable.
ablation_row = {
    "label": ablation_label,
    "split": evaluation_split,
    "checkpoint": str(checkpoint_path),
    **evaluation_summary.metrics,
}
# Re-running this cell without changing anything would otherwise append an identical
# duplicate row. Compare via repr so array-valued metrics don't trigger ambiguous truth values.
if any(repr(existing) == repr(ablation_row) for existing in ablation_results):
    print(f"Identical result for {ablation_label!r} already recorded; not appending a duplicate.")
else:
    ablation_results.append(ablation_row)

if pd is not None:
    display(pd.DataFrame(ablation_results))
else:
    print(ablation_results)

In [None]:
num_preview_samples = 10
preview_samples = collect_visual_samples(
    model=model,
    train_config=trained_config,
    split=evaluation_split,
    num_samples=num_preview_samples,
    device=eval_device,
)

# One row per sample, one column per view.
columns = ["image", "ground_truth", "prediction", "overlay"]
mask_columns = {"ground_truth", "prediction"}
rows = len(preview_samples)
if rows == 0:
    raise RuntimeError("No samples with ground-truth masks were found in the requested split.")
if rows < num_preview_samples:
    print(f"Only {rows} of {num_preview_samples} requested samples were returned for split {evaluation_split!r}.")

# squeeze=False keeps `axes` 2-D even when rows == 1, so axes[row, col] always works.
fig, axes = plt.subplots(rows, len(columns), figsize=(15, 3 * rows), squeeze=False)
for row_idx, sample in enumerate(preview_samples):
    for col_idx, key in enumerate(columns):
        ax = axes[row_idx, col_idx]
        if key in mask_columns:
            # Pin the gray scale to [0, 1]: with autoscaling, a constant mask (all clean OR
            # all tampered) is normalized to solid black, making the two indistinguishable.
            # Assumes masks are bool / 0-1 (binary 0/255 still renders correctly after clipping)
            # — TODO confirm against collect_visual_samples.
            ax.imshow(sample[key], cmap="gray", vmin=0.0, vmax=1.0)
        else:
            ax.imshow(sample[key])
        ax.set_title(f"{key.replace('_', ' ').title()} #{row_idx + 1}")
        ax.axis("off")
plt.tight_layout()
plt.show()