# Hybrid Forgery Testing & Evaluation
This notebook is dedicated to evaluating trained Hybrid Forgery Detection models. It covers loading checkpoints, computing metrics, visualizing qualitative samples, and single image inference.

In [None]:
import os
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
try:
    import pandas as pd
except ImportError:
    pd = None
try:
    from huggingface_hub import snapshot_download
except ImportError:
    snapshot_download = None
try:
    from google.colab import userdata
except ImportError:
    userdata = None
def running_in_colab():
    """Report whether the ``google.colab`` package can be imported.

    Returns:
        ``True`` when ``import google.colab`` succeeds, ``False`` when it
        raises ``ImportError``.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    return True
# Detect the runtime once; later cells branch on this flag.
USE_COLAB = running_in_colab()
print("Colab environment detected." if USE_COLAB else "Running in local environment.")

In [None]:
# Base directory for project artifacts: the mounted Drive folder on Colab,
# the current working directory everywhere else.
if not USE_COLAB:
    DRIVE_BASE_DIR = "."
else:
    DRIVE_BASE_DIR = "/content/drive/MyDrive/LIFD"
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except ModuleNotFoundError as exc:
        raise RuntimeError("google.colab is not available. Set USE_COLAB=False to bypass Drive mounting.") from exc

In [None]:
BASE_PATH = Path(DRIVE_BASE_DIR)

# setdefault leaves any allocator configuration already present in the
# environment untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# os.cpu_count() may return None; fall back to 4 in that case.
CPU_COUNT = os.cpu_count() or 4
GPU_AVAILABLE = torch.cuda.is_available()

# Total memory of device 0 in GiB, or None when there is no GPU or the
# properties query fails.
GPU_TOTAL_MEM_GB = None
if GPU_AVAILABLE:
    try:
        _gpu_props = torch.cuda.get_device_properties(0)
        GPU_TOTAL_MEM_GB = _gpu_props.total_memory / (1024 ** 3)
    except Exception:
        # Best-effort probe: a failed query must not abort the notebook.
        GPU_TOTAL_MEM_GB = None

# Pick a preset from the available hardware.
if GPU_AVAILABLE and CPU_COUNT >= 8:
    PERFORMANCE_MODE = "throughput"
elif CPU_COUNT >= 6:
    PERFORMANCE_MODE = "balanced"
else:
    PERFORMANCE_MODE = "fast"

# Worker count: CPU-derived value clamped to a per-mode [low, high] range.
if PERFORMANCE_MODE == "throughput":
    DATA_WORKERS = min(8, max(4, CPU_COUNT - 1))
elif PERFORMANCE_MODE == "balanced":
    DATA_WORKERS = min(6, max(3, CPU_COUNT - 1))
else:
    half_cpus = max(1, CPU_COUNT // 2)
    DATA_WORKERS = min(4, max(1, half_cpus))

# Batch size depends on the mode only when a GPU is present; CPU runs use 8.
TRAIN_BATCH_SIZE = (
    {"throughput": 32, "balanced": 24, "fast": 16}[PERFORMANCE_MODE]
    if GPU_AVAILABLE
    else 8
)
# Prefetch factor is only set when worker processes are in use.
PREFETCH_FACTOR = (
    {"throughput": 6, "balanced": 4, "fast": 2}[PERFORMANCE_MODE]
    if DATA_WORKERS > 0
    else None
)

In [None]:
# Where the prepared dataset is expected locally; created if absent.
if USE_COLAB:
    DATASET_ROOT = Path("/content/data")
else:
    DATASET_ROOT = Path("prepared") / "CASIA2"
DATASET_ROOT.mkdir(parents=True, exist_ok=True)

# Hugging Face dataset coordinates; the revision can be overridden through
# the HF_DATA_REVISION environment variable.
HF_REPO_ID = "juhenes/lifd"
HF_REVISION = os.environ.get("HF_DATA_REVISION", "main")

# Token source: the Colab userdata store when that module was imported,
# otherwise the process environment.
# NOTE(review): userdata.get may raise when the secret is not defined -- confirm.
if userdata is None:
    HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
else:
    HF_TOKEN = userdata.get("HUGGINGFACE_TOKEN")

MANIFEST_FILENAME = "manifest.parquet"
def _locate_manifest_root(base: Path):
    """Find the directory under ``base`` that holds the dataset manifest.

    ``base`` itself is checked first, then ``base / "CASIA2"``.

    Returns:
        The first of those directories containing ``MANIFEST_FILENAME``, or
        ``None`` when neither does.
    """
    search_dirs = (base, base / "CASIA2")
    return next(
        (folder for folder in search_dirs if (folder / MANIFEST_FILENAME).exists()),
        None,
    )
def ensure_dataset_ready(force_sync: bool = False) -> Path:
    """Return a directory containing the dataset manifest, downloading if needed.

    Args:
        force_sync: When true, skip the local lookup under ``DATASET_ROOT``
            and always call ``snapshot_download``.

    Returns:
        The directory (local or inside the downloaded snapshot) that holds
        ``MANIFEST_FILENAME``.

    Raises:
        ImportError: ``huggingface_hub`` could not be imported.
        FileNotFoundError: The downloaded snapshot has no manifest.
    """
    cached_root = None if force_sync else _locate_manifest_root(DATASET_ROOT)
    if cached_root is not None:
        return cached_root

    if snapshot_download is None:
        raise ImportError("huggingface_hub is not installed.")

    download_args = {
        "repo_id": HF_REPO_ID,
        "revision": HF_REVISION,
        "token": HF_TOKEN,
        "repo_type": "dataset",
    }
    snapshot_dir = Path(snapshot_download(**download_args))

    manifest_root = _locate_manifest_root(snapshot_dir)
    if manifest_root is not None:
        return manifest_root
    raise FileNotFoundError(
        f"Downloaded dataset at {snapshot_dir} does not contain {MANIFEST_FILENAME}."
    )
# Rebind DATASET_ROOT to the directory that actually contains the manifest
# (the local copy if one is found, otherwise the downloaded snapshot).
DATASET_ROOT = ensure_dataset_ready()

In [None]:
from itertools import islice
def inspect_prepared_dataset(root: Path, split: str = "test", preview: int = 5):
    """Print a diagnostic summary of a prepared dataset directory.

    Reports whether the manifest exists, its row counts per split (when it
    can be read), and the ``*.tar`` shards found in the split directory.
    Nothing is returned; all output goes to stdout.

    Args:
        root: Dataset directory to inspect.
        split: Name of the sub-directory whose shards are listed.
        preview: Maximum number of shards to print individually.
    """
    root = Path(root)
    print(f"Resolved DATASET_ROOT: {root}")

    manifest_path = root / "manifest.parquet"
    has_manifest = manifest_path.exists()
    print(f"Manifest present: {has_manifest} ({manifest_path})")
    if not has_manifest:
        return

    # Reading the manifest is best-effort: a missing pandas/parquet backend or
    # a corrupt file is reported and the shard listing still runs.
    try:
        import pandas as pd
        frame = pd.read_parquet(manifest_path)
        print(f"Manifest rows: {len(frame):,}")
        rows_per_split = frame.groupby(["split"]).size().to_dict()
        print("Rows per split:", rows_per_split)
    except Exception as exc:
        print(f"Manifest read failed: {exc}")

    split_dir = root / split
    has_split_dir = split_dir.exists()
    print(f"Split directory exists: {has_split_dir} ({split_dir})")
    if not has_split_dir:
        return

    shards = sorted(split_dir.glob("*.tar"))
    print(f"Found {len(shards)} shard files under {split_dir}")
    for shard in islice(shards, preview):
        shard_present = shard.exists()
        # Only stat entries that resolve, so a dangling path cannot raise.
        shard_size = shard.stat().st_size if shard_present else 'missing'
        print(f" - {shard} | exists={shard_present} | size={shard_size}")

    missing = [str(shard) for shard in shards if not shard.exists()]
    if missing:
        print("Missing shards:", missing[:5], "...")
# Sanity check: print the manifest and shard summary for the resolved dataset
# (default "test" split) before running the evaluation cells.
inspect_prepared_dataset(DATASET_ROOT)

In [None]:
from train import TrainConfig
from model.hybrid_forgery_detector import HybridForgeryConfig
from evaluation.eval_utils import (
    collect_visual_samples,
    evaluate_split,
    load_model_from_checkpoint,
 )
# Checkpoint to evaluate: <BASE_PATH>/checkpoints/best.pt
checkpoint_dir = BASE_PATH.joinpath("checkpoints")
checkpoint_path = checkpoint_dir.joinpath("best.pt")

# Run evaluation on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    eval_device = "cuda"
else:
    eval_device = "cpu"

In [None]:
from pprint import pprint

# Restore the model together with the configuration stored in the checkpoint.
model, trained_config = load_model_from_checkpoint(checkpoint_path, device=eval_device)

# Evaluation settings. The candidate thresholds come from the stored config
# and fall back to a single 0.5 threshold when that attribute is empty.
auto_threshold = True
threshold_candidates = trained_config.eval_thresholds or [0.5]
threshold_metric = "f1"
evaluation_split = "test"
max_eval_batches = None

evaluation_summary = evaluate_split(
    model=model,
    train_config=trained_config,
    split=evaluation_split,
    batch_size=trained_config.batch_size,
    device=eval_device,
    max_batches=max_eval_batches,
    auto_threshold=auto_threshold,
    threshold_candidates=threshold_candidates,
    threshold_metric=threshold_metric,
)

# Use the reported best threshold when present, else the first candidate.
_best_threshold_entry = evaluation_summary.metrics.get("best_threshold", {})
best_threshold = _best_threshold_entry.get("value", threshold_candidates[0])
print("Aggregate metrics (threshold {:.2f}):".format(best_threshold))

# Show only the headline scalar metrics.
_headline_metric_names = {"loss", "dice", "iou", "precision", "recall", "f1"}
pprint({
    name: value
    for name, value in evaluation_summary.metrics.items()
    if name in _headline_metric_names
})

In [None]:
print("\nConfusion matrix (rows=actual clean/tampered, cols=predicted clean/tampered):")
if pd is None:
    # pandas is optional; fall back to the raw matrix.
    print(evaluation_summary.confusion_matrix)
else:
    confusion_df = pd.DataFrame(
        evaluation_summary.confusion_matrix,
        index=["Actual clean", "Actual tampered"],
        columns=["Pred clean", "Pred tampered"],
    )
    display(confusion_df)

# Per-threshold metrics are shown only when pandas is available and the
# evaluation produced a non-empty table.
threshold_table = evaluation_summary.metrics.get("thresholds", {})
if pd is not None and threshold_table:
    threshold_df = pd.DataFrame(threshold_table).T
    display(threshold_df.sort_index())

print("Best threshold ({}): {:.2f}".format(threshold_metric, best_threshold))

In [None]:
num_preview_samples = 10
# Reuse the threshold from the evaluation cell when it has been run.
preview_threshold = 0.5 if "best_threshold" not in locals() else best_threshold
preview_samples = collect_visual_samples(
    model=model,
    train_config=trained_config,
    split=evaluation_split,
    num_samples=num_preview_samples,
    device=eval_device,
    threshold=preview_threshold,
)

columns = ["image", "ground_truth", "prediction", "overlay"]
rows = len(preview_samples)
if rows == 0:
    raise RuntimeError("No samples with ground-truth masks were found in the requested split.")

# squeeze=False keeps `axes` two-dimensional even for a single sample row.
fig, axes = plt.subplots(rows, len(columns), figsize=(15, 3 * rows), squeeze=False)
grayscale_keys = {"ground_truth", "prediction"}
for sample_number, (axes_row, sample) in enumerate(zip(axes, preview_samples), start=1):
    for panel, key in zip(axes_row, columns):
        panel.imshow(sample[key], cmap="gray" if key in grayscale_keys else None)
        panel.set_title(f"{key.replace('_', ' ').title()} #{sample_number}")
        panel.axis("off")
plt.tight_layout()

### Single-Image Inference

In [None]:
from io import BytesIO
from PIL import Image, ImageOps
try:
    from google.colab import files as colab_files
except ImportError:
    colab_files = None
def _predict_image_array(image_arr: np.ndarray, checkpoint_path: Path, threshold: float) -> np.ndarray:
    """Predict a mask for a single image using the model stored in a checkpoint.

    The detector is rebuilt on every call from the checkpoint's
    ``config["model_config"]`` entry and its ``model_state`` weights, then
    moved to ``eval_device``.

    Args:
        image_arr: Image array with three axes; the last axis is moved to the
            front before inference (assumes height x width x channel layout
            -- TODO confirm against the training pipeline).
        checkpoint_path: Checkpoint file to load.
        threshold: Forwarded to ``predict_mask``.

    Returns:
        The predicted mask, squeezed and converted to a NumPy array.
    """
    from model.hybrid_forgery_detector import HybridForgeryDetector

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    saved = torch.load(checkpoint_path, map_location="cpu")
    model_cfg = HybridForgeryConfig(**saved["config"]["model_config"])
    device = eval_device

    detector = HybridForgeryDetector(model_cfg).to(device)
    detector.load_state_dict(saved["model_state"])
    detector.eval()

    # Channel axis first, then a leading batch axis of size 1.
    channels_first = np.transpose(image_arr, (2, 0, 1))
    batch = torch.from_numpy(channels_first).unsqueeze(0).to(device)

    # The residual noise input is an all-zero, 3-channel tensor with the same
    # spatial size as the image batch.
    height, width = batch.shape[-2], batch.shape[-1]
    noise_inputs = {"residual": torch.zeros(1, 3, height, width, device=device)}

    with torch.no_grad():
        mask = detector.predict_mask(batch, threshold=threshold, noise_features=noise_inputs)
    return mask.squeeze().cpu().numpy()
def _load_uploaded_image(local_image_path: "str | Path | None" = None) -> tuple[Image.Image, str]:
    """Load the image to run inference on and return it with a display label.

    On Colab (``USE_COLAB`` true and the ``files`` helper importable) an upload
    dialog is shown and the first uploaded file is used. Otherwise the image
    is read from ``local_image_path``.

    Args:
        local_image_path: Path of the image to load when not running on Colab.
            Ignored on Colab. Leaving it as ``None`` (or passing a path that
            still contains the ``REPLACE_WITH_IMAGE`` placeholder) raises
            ``ValueError``, which is what a call without arguments has
            always done outside Colab.

    Returns:
        ``(image, label)`` where ``image`` is an RGB ``PIL.Image.Image`` and
        ``label`` is the uploaded file name or the local path as a string.

    Raises:
        ValueError: Nothing was uploaded, or no usable local path was given.
        FileNotFoundError: ``local_image_path`` does not exist.
    """
    if USE_COLAB and colab_files is not None:
        uploaded = colab_files.upload()
        if not uploaded:
            raise ValueError("No file uploaded. Please select an image.")
        name, data = next(iter(uploaded.items()))
        # convert() returns an independent image, so the source can be closed.
        with Image.open(BytesIO(data)) as source:
            return source.convert("RGB"), name

    if local_image_path is None or "REPLACE_WITH_IMAGE" in str(local_image_path):
        raise ValueError("Set local_image_path to an actual image path when running outside Colab.")
    image_path = Path(local_image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Local image not found: {image_path}")
    # Close the file handle once the pixels are copied into the RGB image.
    with Image.open(image_path) as source:
        return source.convert("RGB"), str(image_path)
uploaded_image, uploaded_label = _load_uploaded_image()
original_image = uploaded_image.copy()

# Newer Pillow exposes the filter under Image.Resampling; older releases only
# provide it on the Image module itself.
resample_mode = getattr(Image, "Resampling", Image).BICUBIC

# Fit the upload to the square size stored in the checkpoint's config and
# scale pixel values into [0, 1].
target_hw = (trained_config.target_size,) * 2
resized_image = ImageOps.fit(uploaded_image, target_hw, method=resample_mode)
normalized = np.array(resized_image, dtype=np.float32) / 255.0

checkpoint_for_single = checkpoint_path
threshold_for_single = best_threshold
pred_mask = _predict_image_array(normalized, checkpoint_for_single, threshold_for_single)

# Three panels side by side: the upload, the resized input, the prediction.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (original_image, f"Original ({uploaded_label})", None),
    (resized_image, "Resized", None),
    (pred_mask, f"Prediction (thr={threshold_for_single:.2f})", "magma"),
)
for panel_ax, (panel_image, panel_title, panel_cmap) in zip(axes, panels):
    panel_ax.imshow(panel_image, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")
plt.tight_layout()
pred_mask.shape