# Hybrid Forgery Training + Evaluation
This notebook lets you tune every training hyperparameter, toggle ablation flags, and launch `run_training` with tqdm progress bars. After training, reuse the same configuration to load checkpoints, compute metrics (Dice/IoU/precision/recall/F1 + confusion matrix), and visualize 10 qualitative test samples with image / ground-truth / prediction / overlay columns.

In [None]:
import os
import subprocess
from pathlib import Path

REPO_URL = "https://github.com/DeogenesMaranan/LIFD"
COLAB_REPO_DIR = Path("/content") / "LIFD"


def running_in_colab() -> bool:
    """Return True when the ``google.colab`` package can be imported.

    Importability of ``google.colab`` is the signal used throughout this
    notebook to decide whether it is running on a Google Colab runtime.
    """
    try:
        import google.colab  # type: ignore  # noqa: F401
    except ImportError:
        return False
    return True


USE_COLAB = running_in_colab()

# On Colab, make sure the repository is checked out and run from inside it so the
# project modules imported later resolve; locally, keep the current directory.
if not USE_COLAB:
    print("Colab environment not detected; using current local working directory.")
else:
    if COLAB_REPO_DIR.exists():
        # An existing checkout is reused as-is (it is not pulled/updated).
        print(f"Repository already exists at {COLAB_REPO_DIR}")
    else:
        print(f"Cloning {REPO_URL} -> {COLAB_REPO_DIR}")
        subprocess.run(["git", "clone", REPO_URL, str(COLAB_REPO_DIR)], check=True)
    os.chdir(COLAB_REPO_DIR)
    print(f"Working directory set to {Path.cwd()}")

In [None]:
# Base directory for notebook outputs: a folder on the mounted Google Drive when
# running in Colab, otherwise the current working directory.
DRIVE_BASE_DIR = "/content/drive/MyDrive/LIFD" if USE_COLAB else "."
if USE_COLAB:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except ModuleNotFoundError as exc:
        raise RuntimeError("google.colab is not available. Set USE_COLAB=False to bypass Drive mounting.") from exc

In [None]:
%load_ext autoreload
%autoreload 2

from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pprint
import os
import shutil
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import torch

try:
    import pandas as pd
except ImportError:
    pd = None

try:
    import requests
except ImportError:
    requests = None

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None

try:
    from google.colab import userdata
except ImportError:
    userdata = None

torch.set_float32_matmul_precision("high")

# --- Hardware-dependent defaults ---------------------------------------------
CPU_COUNT = os.cpu_count() or 4
# Leave one core for the kernel, but always use between 2 and 8 loader workers.
DATA_WORKERS = min(max(CPU_COUNT - 1, 2), 8)
GPU_AVAILABLE = torch.cuda.is_available()
if GPU_AVAILABLE:
    TRAIN_BATCH_SIZE, GRAD_ACCUM_STEPS = 32, 1
else:
    # Smaller per-step batches on CPU, accumulated over 2 steps.
    TRAIN_BATCH_SIZE, GRAD_ACCUM_STEPS = 8, 2
# Only set a prefetch factor when worker processes are used. DATA_WORKERS is
# clamped to >= 2 above, so this currently always resolves to 4.
PREFETCH_FACTOR = None if DATA_WORKERS <= 0 else 4

# --- Paths -------------------------------------------------------------------
BASE_PATH = Path(DRIVE_BASE_DIR)
# On Colab the dataset lives under /content/data (outside the mounted Drive);
# local runs expect it under ./prepared.
DATASET_ROOT = Path("/content/data/CASIA2") if USE_COLAB else Path("prepared", "CASIA2")

HF_REPO_ID = "juhenes/image-forgery-detection"
HF_FILE_NAME = "CASIA2.zip"
# Optional direct download URL; when set it takes precedence over huggingface_hub.
HF_FILE_URL = None


def _resolve_hf_token(name="HUGGINGFACE_TOKEN"):
    """Return the Hugging Face token named ``name``, or None if none is configured.

    The Colab secret store (``userdata``) is consulted first when available, then
    the environment variable of the same name. ``userdata.get`` raises when the
    secret does not exist or notebook access to it was not granted; that is
    reported and treated as "no secret" rather than aborting the notebook,
    because a public dataset can be downloaded without a token.
    Empty strings are normalized to None.
    """
    token = None
    if userdata is not None:
        try:
            token = userdata.get(name)
        except Exception as exc:  # best-effort: missing secret / access not granted
            print(f"Colab secret {name!r} unavailable ({type(exc).__name__}); falling back to environment.")
    if not token:
        token = os.environ.get(name)
    return token or None


HF_TOKEN = _resolve_hf_token()


def _download_hf_zip(dest_path: Path, timeout: float = 300.0) -> Path:
    """Download the CASIA2 archive and return the path of the local zip file.

    When ``HF_FILE_URL`` is set, the file is streamed with ``requests`` (sending
    ``HF_TOKEN`` as a bearer token when present) and written to ``dest_path``.
    Otherwise ``hf_hub_download`` fetches ``HF_FILE_NAME`` from the ``HF_REPO_ID``
    dataset repo. In that case the returned path is wherever huggingface_hub stored
    the file, which may differ from ``dest_path``.

    Args:
        dest_path: Target file for the direct-URL download.
        timeout: Seconds ``requests`` waits to connect and for each read before
            giving up (direct-URL mode only).

    Returns:
        Path to the downloaded zip archive.

    Raises:
        RuntimeError: If the library required by the selected download mode is missing.
        requests.HTTPError: If the direct-URL request returns an error status.
    """
    if HF_FILE_URL:
        if requests is None:
            raise RuntimeError("requests is required to download from HF_FILE_URL.")
        headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None
        # Stream into a sibling temp file so an interrupted download never leaves
        # a truncated archive at dest_path.
        partial_path = dest_path.with_name(dest_path.name + ".part")
        try:
            # The timeout keeps a stalled server from hanging the notebook; the
            # `with` block closes the connection even if writing fails.
            with requests.get(HF_FILE_URL, stream=True, headers=headers, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as out_file:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        if chunk:
                            out_file.write(chunk)
            # Only a completely written download is moved into place.
            partial_path.replace(dest_path)
        finally:
            # No-op after a successful replace; removes the leftover on failure.
            partial_path.unlink(missing_ok=True)
        return dest_path
    if hf_hub_download is None:
        raise RuntimeError("huggingface_hub is required to download CASIA2.zip when HF_FILE_URL is not provided.")
    return Path(
        hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_FILE_NAME,
            token=HF_TOKEN,
            repo_type="dataset",
            force_download=False,
        )
    )


def _extract_zip_flat(zip_path: Path, destination: Path) -> None:
    with ZipFile(zip_path) as archive:
        for member in archive.infolist():
            member_path = Path(member.filename)
            if not member_path.parts:
                continue
            if member_path.parts[0] == "__MACOSX":
                continue
            relative_parts = member_path.parts[1:] if len(member_path.parts) > 1 else member_path.parts
            if not relative_parts:
                continue
            target_path = destination.joinpath(*relative_parts)
            if member.is_dir():
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            target_path.parent.mkdir(parents=True, exist_ok=True)
            with archive.open(member) as src, open(target_path, "wb") as dst:
                shutil.copyfileobj(src, dst)


def ensure_dataset_ready() -> None:
    """Download and extract CASIA2 into ``DATASET_ROOT`` when running on Colab.

    Local runs return immediately and expect ``DATASET_ROOT`` to be prepared already.
    On Colab, a non-empty ``DATASET_ROOT`` is treated as a finished extraction and
    reused. Otherwise the archive is fetched with ``_download_hf_zip`` and unpacked
    into a sibling staging directory. The staging directory is renamed to
    ``DATASET_ROOT`` only after extraction succeeds, so an interrupted run cannot
    leave a half-populated dataset that the non-empty check would mistake for a
    complete one.
    """
    if not USE_COLAB:
        return
    DATASET_ROOT.parent.mkdir(parents=True, exist_ok=True)
    if DATASET_ROOT.exists() and any(DATASET_ROOT.iterdir()):
        print(f"Dataset already available at {DATASET_ROOT}")
        return
    zip_file_path = DATASET_ROOT.parent / HF_FILE_NAME
    # In hub mode this can be a cache path rather than zip_file_path; extracting
    # straight from it avoids keeping a second copy of the archive on disk.
    downloaded_path = _download_hf_zip(zip_file_path)

    staging_dir = DATASET_ROOT.with_name(DATASET_ROOT.name + ".partial")
    if staging_dir.exists():
        # Leftover from an earlier interrupted extraction.
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(parents=True)
    _extract_zip_flat(downloaded_path, staging_dir)
    if DATASET_ROOT.exists():
        DATASET_ROOT.rmdir()  # empty, otherwise we would have returned above
    staging_dir.rename(DATASET_ROOT)
    print(f"Dataset extracted to {DATASET_ROOT}")


ensure_dataset_ready()

from train import TrainConfig, run_training
from model.hybrid_forgery_detector import HybridForgeryConfig
from evaluation.eval_utils import (
    collect_visual_samples,
    evaluate_split,
    load_model_from_checkpoint,
 )

In [None]:
# Show the library defaults so the overrides in the following cells are easy to compare.
default_train_cfg = TrainConfig()
default_model_cfg = HybridForgeryConfig()
for heading, cfg in (
    ("Default TrainConfig:", default_train_cfg),
    ("Default HybridForgeryConfig:", default_model_cfg),
):
    print(heading)
    pprint(asdict(cfg))

In [None]:
# Training hyperparameters; hardware-dependent values come from the setup cell.
train_config = TrainConfig(
    prepared_root=str(DATASET_ROOT),
    train_split="train",
    val_split="val",
    target_size=128,
    batch_size=TRAIN_BATCH_SIZE,
    num_epochs=10,
    learning_rate=1e-4,
    weight_decay=1e-2,
    num_workers=DATA_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    persistent_workers=DATA_WORKERS > 0,
    # Pinned host memory only helps host-to-GPU copies; on CPU-only machines it is
    # useless and triggers a DataLoader warning, so tie it to GPU availability.
    pin_memory=GPU_AVAILABLE,
    grad_accumulation_steps=GRAD_ACCUM_STEPS,
    grad_clip_norm=1.0,
    log_interval=10,
    checkpoint_dir=str(BASE_PATH / "checkpoints"),
    checkpoint_interval=1,
    save_best_only=True,
    use_amp=True,  # NOTE(review): confirm run_training disables AMP on CPU-only runs.
    resume_from=None,
)

In [None]:
# Ablation switches for the hybrid detector; edit these to train a different variant.
train_config.model_config = HybridForgeryConfig(
    # Backbone flags
    use_efficientnet=True,
    use_swin=True,
    use_segformer=False,
    pretrained_backbones=True,
    # Decoder / fusion flags
    use_unet_decoder=True,
    use_skip_connections=True,
    fused_channels=256,
)

In [None]:
print("Resolved device:", train_config.resolved_device())
# Dump the effective training and model configuration before launching.
for heading, config_obj in (
    ("TrainConfig overrides:", train_config),
    ("HybridForgeryConfig overrides:", train_config.model_config),
):
    print(heading)
    pprint(asdict(config_obj))

### Optional Dry Run
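Set `ENABLE_DRY_RUN = True` in the next cell to run a one-epoch smoke test before the full training run. The smoke test processes a single training batch and a single validation batch (batch size at most 2, via `max_train_batches` / `max_val_batches`) and writes its checkpoints to a separate `dry_run` subdirectory.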

In [None]:
ENABLE_DRY_RUN = False
if not ENABLE_DRY_RUN:
    print("Dry run skipped. Set ENABLE_DRY_RUN = True to execute the smoke test.")
else:
    # Work on a copy so the full-run configuration in `train_config` is not mutated.
    dry_run_config = deepcopy(train_config)
    dry_run_overrides = {
        "num_epochs": 1,
        "batch_size": min(2, train_config.batch_size),
        "max_train_batches": 1,
        "max_val_batches": 1,
        # Smoke-test checkpoints go to their own subdirectory.
        "checkpoint_dir": str(Path(train_config.checkpoint_dir) / "dry_run"),
    }
    for attr_name, attr_value in dry_run_overrides.items():
        setattr(dry_run_config, attr_name, attr_value)
    print("Dry run settings:", {attr_name: getattr(dry_run_config, attr_name) for attr_name in dry_run_overrides})
    dry_run_history = run_training(dry_run_config)

In [None]:
# Launch the full training run with the configuration assembled above.
# The trailing bare `history` expression makes Jupyter display what run_training returned.
history = run_training(train_config)
history

## Evaluation, Samples, and Ablations
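Use the helpers below to load a checkpoint, compute aggregate metrics and a confusion matrix on any split, and visualize 10 qualitative test samples laid out as image / ground-truth / prediction / overlay columns. Each evaluation also appends a row to the `ablation_results` table, so re-running with different checkpoints or ablation flags builds up a side-by-side comparison.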

In [None]:
# Checkpoint and split to evaluate; point checkpoint_path at another file to compare runs.
# NOTE(review): assumes run_training saves its best model as best.pt in checkpoint_dir.
checkpoint_path = Path(train_config.checkpoint_dir) / "best.pt"
evaluation_split = "test"
# The checkpoint stem alone ("best") is the same for every run, which makes rows in the
# accumulated ablation table indistinguishable. Append the enabled `use_*` toggles of the
# current model config so each ablation variant gets its own label.
_enabled_flags = [
    name[len("use_"):]
    for name, value in asdict(train_config.model_config).items()
    if name.startswith("use_") and value is True
]
ablation_label = f"{checkpoint_path.stem}[{'+'.join(_enabled_flags) or 'none'}]"
eval_device = train_config.resolved_device()
# Passed to evaluate_split as max_batches; None presumably means "no cap" (verify).
max_eval_batches = None
if not checkpoint_path.exists():
    print(f"Warning: checkpoint not found at {checkpoint_path}; train first or update checkpoint_path.")
checkpoint_path

In [None]:
# Rebuild the model and its training config from the checkpoint, then score the split.
model, trained_config = load_model_from_checkpoint(checkpoint_path, device=eval_device)
evaluation_summary = evaluate_split(
    model=model,
    train_config=trained_config,
    split=evaluation_split,
    batch_size=trained_config.batch_size,
    device=eval_device,
    max_batches=max_eval_batches,
)

print("Aggregate metrics:")
pprint(evaluation_summary.metrics)
print("\nConfusion matrix (rows=actual clean/tampered, cols=predicted clean/tampered):")
if pd is None:
    print(evaluation_summary.confusion_matrix)
else:
    display(
        pd.DataFrame(
            evaluation_summary.confusion_matrix,
            index=["Actual clean", "Actual tampered"],
            columns=["Pred clean", "Pred tampered"],
        )
    )

# One row per evaluation, kept in a notebook-global list that survives cell re-runs.
ablation_results = globals().setdefault("ablation_results", [])
ablation_results.append({"label": ablation_label, **evaluation_summary.metrics})
if pd is None:
    print(ablation_results)
else:
    display(pd.DataFrame(ablation_results))

In [None]:
num_preview_samples = 10
preview_samples = collect_visual_samples(
    model=model,
    train_config=trained_config,
    split=evaluation_split,
    num_samples=num_preview_samples,
    device=eval_device,
)

columns = ["image", "ground_truth", "prediction", "overlay"]
# Mask-like panels use a grayscale colormap; the others use imshow's default.
grayscale_keys = {"ground_truth", "prediction"}
rows = len(preview_samples)
if rows == 0:
    raise RuntimeError("No samples with ground-truth masks were found in the requested split.")
# squeeze=False keeps `axes` 2-D (rows x columns) even when only one sample is returned.
fig, axes = plt.subplots(rows, len(columns), figsize=(15, 3 * rows), squeeze=False)
for row_idx, sample in enumerate(preview_samples):
    for col_idx, key in enumerate(columns):
        ax = axes[row_idx, col_idx]
        ax.imshow(sample[key], cmap="gray" if key in grayscale_keys else None)
        ax.set_title(f"{key.replace('_', ' ').title()} #{row_idx + 1}")
        ax.axis("off")
plt.tight_layout()
plt.show()