<a href="https://colab.research.google.com/github/RitikKumar3/D2R_Detect_to_Restore/blob/main/D2R_full_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **D2R_Full_Pipeline**

### Setup, mount, and paths

In [None]:
# setup, mount, and paths

import os
from pathlib import Path

from google.colab import drive
# Mount Google Drive at /content/drive; force_remount=False leaves an existing mount alone.
drive.mount('/content/drive', force_remount=False)

# Root of the project on Drive. Every other path in this notebook is derived from it,
# so fail immediately if the folder is missing.
PROJECT_PATH = "/content/drive/MyDrive/D2R Model"
assert os.path.isdir(PROJECT_PATH), f"PROJECT_PATH not found: {PROJECT_PATH}"

# Input locations inside the project folder.
ROOT_DATASET_DIR = os.path.join(PROJECT_PATH, "D2R Dataset")
FRAMES_ROOT = os.path.join(PROJECT_PATH, "frames")
MASKS_ROOT  = os.path.join(PROJECT_PATH, "masks")

# JSON-lines manifests: the full manifest plus the train/val/test splits.
MANIFEST_PATH   = os.path.join(PROJECT_PATH, "manifest.jsonl")
TRAIN_MANIFEST  = os.path.join(PROJECT_PATH, "train_manifest.jsonl")
VAL_MANIFEST    = os.path.join(PROJECT_PATH, "val_manifest.jsonl")
TEST_MANIFEST   = os.path.join(PROJECT_PATH, "test_manifest.jsonl")

# Output folders (created on demand): localization JSON files and restored frames/videos.
LOCALIZATION_ROOT = os.path.join(PROJECT_PATH, "localization")
RESTORED_ROOT     = os.path.join(PROJECT_PATH, "restored")
os.makedirs(LOCALIZATION_ROOT, exist_ok=True)
os.makedirs(RESTORED_ROOT, exist_ok=True)

print("PROJECT_PATH:", PROJECT_PATH)
print("MANIFEST_PATH:", MANIFEST_PATH)

# Make modules stored in the project folder importable (d2r_dataloader is imported
# in a later cell); the membership check keeps re-runs from appending duplicates.
import sys
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)


Mounted at /content/drive
PROJECT_PATH: /content/drive/MyDrive/D2R Model
MANIFEST_PATH: /content/drive/MyDrive/D2R Model/manifest.jsonl


### Core imports and device

In [None]:
# imports and device

import json
import numpy as np
import cv2
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision import models as tvmodels

from d2r_dataloader import dataset_from_manifest, collate_video_batch, normalize_annotation

# Device string used for the classifier: GPU when torch can see one, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


Using device: cuda


In [None]:
# Cell 2A: Stable Diffusion inpainting setup (for Object Deletion restoration)

!pip install -q diffusers transformers accelerate safetensors

from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

# Device for the Stable Diffusion pipeline (same rule as DEVICE, kept as a separate name).
SD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Flag: turn SD inpainting ON/OFF for object deletion
USE_SD_INPAINT = True   # set to False to force OpenCV Telea only

# sd_pipe stays None and HAS_SD stays False unless the pipeline loads successfully;
# sd_inpaint_frame and restore_object_deletion check both before using SD.
sd_pipe = None
HAS_SD = False

if USE_SD_INPAINT:
    try:
        # Half precision on GPU, full precision on CPU.
        # NOTE(review): the recorded run printed "`torch_dtype` is deprecated! Use `dtype`
        # instead!" — consider switching once the installed diffusers version is confirmed.
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            torch_dtype=torch.float16 if SD_DEVICE == "cuda" else torch.float32,
        )
        sd_pipe = sd_pipe.to(SD_DEVICE)
        HAS_SD = True
        print("[SD] Stable Diffusion inpainting pipeline loaded on:", SD_DEVICE)
    except Exception as e:
        # Deliberate best-effort: any load failure (download, memory, ...) only disables SD,
        # and restoration then uses the OpenCV path.
        HAS_SD = False
        print("[SD][WARN] Failed to load SD inpainting pipeline; falling back to OpenCV. Error:", e)
else:
    print("[SD] USE_SD_INPAINT=False → using OpenCV Telea inpainting only.")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/748 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

text_encoder/pytorch_model.bin:   0%|          | 0.00/492M [00:00<?, ?B/s]

safety_checker/pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

unet/diffusion_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!
An error occurred while trying to fetch /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


[SD] Stable Diffusion inpainting pipeline loaded on: cuda


In [None]:
# Cell 2B: helper to inpaint one frame using Stable Diffusion

def sd_inpaint_frame(
    frame_rgb: np.ndarray,
    mask_np: np.ndarray,
    prompt: str = "fill the missing region with realistic background",
    num_inference_steps: int = 25,
    guidance_scale: float = 7.5,
) -> np.ndarray:
    """
    Inpaint the masked region of one frame with the Stable Diffusion pipeline.

    frame_rgb : HxWx3 uint8 (RGB)
    mask_np   : HxW array (0 = keep, >0 = hole); resized with nearest-neighbour
                if its size differs from the frame
    prompt, num_inference_steps, guidance_scale : forwarded to the pipeline call
    Returns   : HxWx3 uint8 (RGB). Pixels outside the hole are copied from
                frame_rgb unchanged; only hole pixels come from the generated image.

    Raises RuntimeError when the pipeline was not loaded (sd_pipe is None / HAS_SD False).
    """
    if sd_pipe is None or not HAS_SD:
        raise RuntimeError("Stable Diffusion pipeline is not available. Check USE_SD_INPAINT / HAS_SD.")

    h, w = frame_rgb.shape[:2]

    image = Image.fromarray(frame_rgb)
    mask_image = Image.fromarray((mask_np > 0).astype(np.uint8) * 255)
    if mask_image.size != (w, h):
        # Keep the mask aligned with the frame so it can be used for compositing below.
        mask_image = mask_image.resize((w, h), resample=Image.NEAREST)
    hole = np.array(mask_image) > 0  # (h, w) boolean, True where pixels must be generated

    if not hole.any():
        # Nothing to fill: skip the expensive diffusion call entirely.
        return np.array(frame_rgb, dtype=np.uint8)

    # SD inpainting typically works at 512x512
    image_512 = image.resize((512, 512), resample=Image.LANCZOS)
    mask_512  = mask_image.resize((512, 512), resample=Image.NEAREST)

    with torch.autocast(device_type=SD_DEVICE, enabled=(SD_DEVICE == "cuda")):
        result = sd_pipe(
            prompt=prompt,
            image=image_512,
            mask_image=mask_512,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

    generated = result.images[0].resize((w, h), resample=Image.LANCZOS)
    generated_np = np.array(generated, dtype=np.uint8)

    # The generated image has been resampled twice (to 512x512 and back), so returning it
    # wholesale would also alter the pixels marked "keep". Composite instead: take
    # generated pixels only inside the hole and original pixels everywhere else.
    return np.where(hole[..., None], generated_np, frame_rgb).astype(np.uint8)


### Generic helpers (files, manifest, masks, annotations)

In [None]:
# helpers for IO, manifest, annotation, frames, masks

from pathlib import Path

def list_files_sorted(folder, exts=None):
    """
    Return the regular files directly inside `folder` as Path objects, sorted by file name.

    folder : directory path (str or Path); None, a missing path or a non-directory gives [].
    exts   : optional iterable of suffixes such as {".png"}; matching is case-insensitive.
             When None, every file is returned.
    """
    if folder is None:
        return []

    directory = Path(folder)
    if not (directory.exists() and directory.is_dir()):
        return []

    # None means "no filtering"; an empty collection means "match nothing".
    wanted = None if exts is None else {ext.lower() for ext in exts}

    matches = []
    for child in directory.iterdir():
        if not child.is_file():
            continue
        if wanted is not None and child.suffix.lower() not in wanted:
            continue
        matches.append(child)

    matches.sort(key=lambda item: item.name)
    return matches

def mask_to_bbox(mask_np, min_area=10):
    """
    Tight bounding box of the positive pixels of a 2-D mask.

    Returns (x1, y1, x2, y2) with inclusive corners as Python ints, or None when the mask
    has no positive pixel or the bounding-box area (not the pixel count) is below min_area.
    """
    rows, cols = np.nonzero(mask_np > 0)
    if cols.size == 0 or rows.size == 0:
        return None

    left, right = cols.min(), cols.max()
    top, bottom = rows.min(), rows.max()

    box_w = right - left + 1
    box_h = bottom - top + 1
    if box_w * box_h < min_area:
        return None

    return int(left), int(top), int(right), int(bottom)

def load_manifest_entries(manifest_path):
    """
    Read a JSON-lines manifest and return its records as a list, in file order.

    Lines that are empty after stripping whitespace are ignored.
    """
    with open(manifest_path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def index_manifest_by_vid(manifest_path):
    """
    Read a JSON-lines manifest and return {video_id: record}.

    Blank lines are ignored. When several records share a video_id, the one that
    appears last in the file is the one kept.
    """
    lookup = {}
    with open(manifest_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                record = json.loads(text)
                lookup[record["video_id"]] = record
    return lookup

def load_and_normalize_annotation(entry):
    """
    Read the JSON file at entry["annotation_path"] and pass it through
    normalize_annotation (imported from d2r_dataloader).

    Returns {"normalized_affected_frames": []} when the path is missing, does not exist,
    or the file cannot be read/parsed (a warning is printed in the last case).
    entry["n_frames"] is forwarded as n_frames_in_manifest when it is a positive number,
    otherwise None is forwarded.
    """
    annotation_file = entry.get("annotation_path")
    if not (annotation_file and os.path.exists(annotation_file)):
        return {"normalized_affected_frames": []}

    try:
        with open(annotation_file, "r", encoding="utf-8") as handle:
            raw_annotation = json.load(handle)
    except Exception as e:
        print(f"[WARN] Failed to load annotation for {entry['video_id']}: {e}")
        return {"normalized_affected_frames": []}

    frame_count = int(entry.get("n_frames") or 0)
    return normalize_annotation(
        raw_annotation,
        n_frames_in_manifest=frame_count if frame_count > 0 else None,
    )

def get_tampered_indices(ann_norm):
    """
    Return the "normalized_affected_frames" values of an annotation dict as a sorted
    list of unique ints ([] when the key is absent).
    """
    unique_indices = set()
    for value in ann_norm.get("normalized_affected_frames", []):
        unique_indices.add(int(value))
    return sorted(unique_indices)

def load_frames_for_entry(entry):
    """
    Load every .jpg/.jpeg/.png image in entry["frames_dir"], ordered by file name.

    Returns (frames, frame_files): frames is a list of RGB arrays (converted from the
    BGR arrays cv2.imread returns) and frame_files the matching list of paths.

    Raises FileNotFoundError when the directory yields no image files and RuntimeError
    when an individual image cannot be decoded.
    """
    frames_dir = entry.get("frames_dir")
    frame_files = list_files_sorted(frames_dir, exts={".jpg", ".jpeg", ".png"})
    if len(frame_files) == 0:
        raise FileNotFoundError(f"No frames found in {frames_dir} for video {entry['video_id']}")

    def _read_rgb(p):
        bgr = cv2.imread(str(p))
        if bgr is None:
            raise RuntimeError(f"Failed to read frame {p}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    return [_read_rgb(path) for path in frame_files], frame_files

def load_masks_for_entry(entry):
    """
    Return one binary mask per frame of the entry.

    Reads entry["mask_dir"] and entry["frames_dir"]. Mask files (.png, sorted by name)
    are paired with frames by position: frame i uses mask file i. Frames without a mask
    file, and mask files that cannot be decoded, get an all-zero mask with the height and
    width of the first frame.

    Returns (masks, mask_files):
      masks      : list of uint8 arrays, one per frame. Masks read from disk are
                   thresholded at 127 to values 0/1 and keep their own size
                   (they are not resized to the frame).
      mask_files : list of mask paths, or None when the entry has no usable mask_dir
                   (in which case every mask is all zeros).

    Raises FileNotFoundError when frames_dir contains no images, and RuntimeError when
    a zero mask is needed but the first frame cannot be decoded to obtain its size.
    """
    mask_dir = entry.get("mask_dir")
    frames_dir = entry.get("frames_dir")

    frame_files = list_files_sorted(frames_dir, exts={".jpg", ".jpeg", ".png"})
    if not frame_files:
        raise FileNotFoundError(f"No frames to infer mask shape for video {entry['video_id']}")

    shape_cache = []

    def _frame_hw():
        # Decode the first frame at most once, and only if a zero mask is actually needed.
        if not shape_cache:
            first = cv2.imread(str(frame_files[0]))
            if first is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise RuntimeError(f"Failed to read frame {frame_files[0]}")
            shape_cache.append(first.shape[:2])
        return shape_cache[0]

    def _zero_mask():
        return np.zeros(_frame_hw(), dtype=np.uint8)

    # If no mask_dir → all zeros
    if not mask_dir or not os.path.isdir(mask_dir):
        return [_zero_mask() for _ in frame_files], None

    mask_files = list_files_sorted(mask_dir, exts={".png"})

    masks = []
    for i in range(len(frame_files)):
        gray = None
        if i < len(mask_files):
            gray = cv2.imread(str(mask_files[i]), cv2.IMREAD_GRAYSCALE)
        if gray is None:
            masks.append(_zero_mask())
        else:
            masks.append((gray > 127).astype(np.uint8))

    return masks, mask_files


### Per-category restoration functions

In [None]:
# restoration functions using RAW annotations (no normalize_annotation)

def restore_object_deletion(entry, ann):
    """
    Object deletion restoration using the raw annotation structure:
    {
      "video_name": "vid1",
      "operation": "object_deletion",
      "affected_frames": [
        {
          "frame_number": 8,
          "mask_file": "mask_00008.png",
          "bounding_boxes": [
            { "x": 407, "y": 234, "width": 228, "height": 206 }
          ],
          "tampered_pixels": 32854
        },
        ...
      ]
    }

    A per-frame binary mask is built from the bounding_boxes and every frame with a
    non-empty mask is inpainted:

      - If USE_SD_INPAINT & HAS_SD → Stable Diffusion inpainting (sd_inpaint_frame),
        falling back to OpenCV Telea for a frame when the SD call raises
      - Else → OpenCV Telea inpainting

    This does not truly "recover" the deleted object; it only fills the tampered region.

    Side effects: writes one PNG per frame and an XVID .avi under RESTORED_ROOT/<category>,
    and a localization JSON under LOCALIZATION_ROOT/<category>.

    Returns (out_video_path, loc_path).
    """
    video_id = entry["video_id"]
    category = entry["category"]
    assert "Object_deletion" in category, f"Expected Object_deletion, got {category}"

    frames, frame_files = load_frames_for_entry(entry)
    n_frames = len(frames)
    h, w = frames[0].shape[:2]

    # ann is expected to be RAW (with 'affected_frames').
    # A missing or null value is treated as "no tampered frames".
    affected_frames = ann.get("affected_frames") or []
    if not affected_frames:
        print(f"[ObjectDeletion][WARN] No affected_frames in annotation for {video_id}.")

    # Start from an empty mask for every frame.
    masks = [np.zeros((h, w), dtype=np.uint8) for _ in range(n_frames)]
    touched = set()

    for af in affected_frames:
        fnum = af.get("frame_number", None)
        if fnum is None:
            continue

        # frame_number is used as the index as-is; only a value that falls past the last
        # frame is shifted down by one.
        # NOTE(review): confirm whether the annotations are 0- or 1-based.
        idx = int(fnum)
        if idx >= n_frames:
            idx -= 1

        if idx < 0 or idx >= n_frames:
            print(f"[ObjectDeletion][WARN] frame_number {fnum} -> idx {idx} out of range for {video_id}")
            continue

        touched.add(idx)

        for bb in af.get("bounding_boxes", []) or []:
            x  = int(round(bb.get("x", 0)))
            y  = int(round(bb.get("y", 0)))
            bw = int(round(bb.get("width", 0)))
            bh = int(round(bb.get("height", 0)))

            # Clamp the box to the image.
            x1 = max(0, x)
            y1 = max(0, y)
            x2 = min(w - 1, x + bw)
            y2 = min(h - 1, y + bh)

            if x2 <= x1 or y2 <= y1:
                continue

            # Fill mask rectangle with 255
            cv2.rectangle(masks[idx], (x1, y1), (x2, y2), 255, thickness=-1)

    used_indices = sorted(touched)
    print(f"[ObjectDeletion] Video {video_id}: non-zero masks for frames {used_indices}")

    # --- Inpaint each frame where mask > 0 ---

    out_dir = os.path.join(RESTORED_ROOT, category, f"{video_id}_restored_frames")
    os.makedirs(out_dir, exist_ok=True)

    restored_frames = []
    for i, (frame, mask) in enumerate(zip(frames, masks)):
        if mask.max() > 0:
            restored_rgb = None
            if USE_SD_INPAINT and HAS_SD:
                try:
                    restored_rgb = sd_inpaint_frame(frame, mask)
                except Exception as e:
                    # Best effort: one failed SD call must not abort the whole video.
                    print(f"[ObjectDeletion][WARN] SD inpainting failed on frame {i}: {e}")
            if restored_rgb is None:
                restored_rgb = _telea_inpaint_rgb(frame, mask)
        else:
            # Clean frame → keep as is
            restored_rgb = frame

        restored_frames.append(restored_rgb)

        out_path = os.path.join(out_dir, f"frame_{i:05d}.png")
        cv2.imwrite(out_path, cv2.cvtColor(restored_rgb, cv2.COLOR_RGB2BGR))

    # --- Write restored video (AVI only) ---

    fps = entry.get("fps", 25.0) or 25.0
    out_video_path = os.path.join(RESTORED_ROOT, category, f"{video_id}_restored.avi")
    fourcc = cv2.VideoWriter_fourcc(*"XVID")
    writer = cv2.VideoWriter(out_video_path, fourcc, fps, (w, h))
    try:
        for fr in restored_frames:
            writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
    finally:
        # Release the file handle even if a frame conversion/write raises.
        writer.release()

    # --- Localization JSON (indices that were touched + the raw per-frame annotation) ---

    loc_dir = os.path.join(LOCALIZATION_ROOT, category)
    os.makedirs(loc_dir, exist_ok=True)
    loc_path = os.path.join(loc_dir, f"{video_id}_localization.json")

    loc_data = {
        "video_id": video_id,
        "category": category,
        "affected_frame_indices": used_indices,
        "annotation_affected_frames": affected_frames
    }
    with open(loc_path, "w", encoding="utf-8") as f:
        json.dump(loc_data, f, indent=2)

    print(f"[ObjectDeletion] Restored video: {out_video_path}")
    print(f"[ObjectDeletion] Localization JSON: {loc_path}")
    return out_video_path, loc_path


def _telea_inpaint_rgb(frame_rgb, mask):
    """Inpaint the positive pixels of `mask` in an RGB frame with OpenCV Telea (radius 3); returns RGB."""
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    mask_255 = (mask > 0).astype(np.uint8) * 255
    inpainted = cv2.inpaint(frame_bgr, mask_255, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
    return cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB)


# --- Restoration for the remaining forgery categories ---

def restore_frame_insertion(entry, ann):
    """
    Undo a frame-insertion forgery by dropping the inserted frames.

    The positions listed in ann["output_insert_positions"] (those that fall inside the
    video) are removed; all other frames are kept in order.

    Side effects: writes the kept frames as PNGs and as an XVID .avi under
    RESTORED_ROOT/<category>, and a localization JSON under LOCALIZATION_ROOT/<category>.

    Returns (out_video_path, loc_path), or (None, None) when every frame would be dropped.
    """
    video_id = entry["video_id"]
    category = entry["category"]
    assert "frame_insertion" in category, f"Expected frame_insertion, got {category}"

    frames, frame_files = load_frames_for_entry(entry)
    total_frames = len(frames)

    # Collect the in-range insert positions; out-of-range values are ignored.
    inserted = set()
    for position in ann.get("output_insert_positions", []):
        if 0 <= int(position) < total_frames:
            inserted.add(int(position))

    kept_indices = [i for i in range(total_frames) if i not in inserted]
    restored_frames = [frames[i] for i in kept_indices]

    if len(restored_frames) == 0:
        print(f"[WARN][FrameInsertion] All frames treated as tampered for {video_id}, skipping restoration.")
        return None, None

    h, w = restored_frames[0].shape[:2]
    category_dir = os.path.join(RESTORED_ROOT, category)

    # Restored frames are renumbered consecutively from 0.
    out_dir = os.path.join(category_dir, f"{video_id}_restored_frames")
    os.makedirs(out_dir, exist_ok=True)
    out_idx = 0
    for fr in restored_frames:
        png_path = os.path.join(out_dir, f"frame_{out_idx:05d}.png")
        cv2.imwrite(png_path, cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
        out_idx += 1

    fps = entry.get("fps", 25.0) or 25.0
    out_video_path = os.path.join(category_dir, f"{video_id}_restored.avi")
    writer = cv2.VideoWriter(out_video_path, cv2.VideoWriter_fourcc(*"XVID"), fps, (w, h))
    for fr in restored_frames:
        writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
    writer.release()

    loc_dir = os.path.join(LOCALIZATION_ROOT, category)
    os.makedirs(loc_dir, exist_ok=True)
    loc_path = os.path.join(loc_dir, f"{video_id}_localization.json")

    with open(loc_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "video_id": video_id,
                "category": category,
                "dropped_frame_indices": sorted(inserted),
                "kept_original_indices": kept_indices,
            },
            f,
            indent=2,
        )

    print(f"[FrameInsertion] Restored video: {out_video_path}")
    print(f"[FrameInsertion] Localization JSON: {loc_path}")
    return out_video_path, loc_path


def restore_flip(entry, ann, flip_mode="horizontal"):
    """
    Undo a flipping forgery by flipping the affected frames back.

    ann["flip_range"] = {"start": s, "end": e} gives the inclusive index range to flip;
    start is clamped to 0 and end to the last frame. A missing or negative end means
    "through the last frame", so an annotation without flip_range flips every frame.
    flip_mode "horizontal" mirrors left-right, "vertical" mirrors top-bottom; any other
    value leaves the frames untouched.

    Side effects: writes PNG frames and an XVID .avi under RESTORED_ROOT/<category>, and
    a localization JSON under LOCALIZATION_ROOT/<category>.

    Returns (out_video_path, loc_path).
    """
    video_id = entry["video_id"]
    category = entry["category"]

    frames, frame_files = load_frames_for_entry(entry)
    total_frames = len(frames)
    last_index = total_frames - 1

    # Annotation format:
    # { "flip_range": { "start": 61, "end": 80 }, "original_frame_count": 100, ... }
    flip_range = ann.get("flip_range", {})
    raw_start = int(flip_range.get("start", 0))
    raw_end = int(flip_range.get("end", -1))

    first = max(0, raw_start)
    last = last_index if raw_end < 0 else min(last_index, raw_end)

    tampered_idxs = set(range(first, last + 1)) if last >= first else set()

    # cv2.flip code: 1 flips around the vertical axis, 0 around the horizontal axis.
    flip_code = 1 if flip_mode == "horizontal" else 0 if flip_mode == "vertical" else None

    restored_frames = [
        cv2.flip(frame, flip_code) if (flip_code is not None and i in tampered_idxs) else frame
        for i, frame in enumerate(frames)
    ]

    h, w = restored_frames[0].shape[:2]
    category_dir = os.path.join(RESTORED_ROOT, category)

    out_dir = os.path.join(category_dir, f"{video_id}_restored_frames")
    os.makedirs(out_dir, exist_ok=True)
    for idx in range(len(restored_frames)):
        png_path = os.path.join(out_dir, f"frame_{idx:05d}.png")
        cv2.imwrite(png_path, cv2.cvtColor(restored_frames[idx], cv2.COLOR_RGB2BGR))

    fps = entry.get("fps", 25.0) or 25.0
    out_video_path = os.path.join(category_dir, f"{video_id}_restored.avi")
    writer = cv2.VideoWriter(out_video_path, cv2.VideoWriter_fourcc(*"XVID"), fps, (w, h))
    for fr in restored_frames:
        writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
    writer.release()

    loc_dir = os.path.join(LOCALIZATION_ROOT, category)
    os.makedirs(loc_dir, exist_ok=True)
    loc_path = os.path.join(loc_dir, f"{video_id}_localization.json")

    with open(loc_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "video_id": video_id,
                "category": category,
                "tampered_indices": sorted(tampered_idxs),
                "operation": flip_mode,
            },
            f,
            indent=2,
        )

    print(f"[Flip-{flip_mode}] Restored video: {out_video_path}")
    print(f"[Flip-{flip_mode}] Localization JSON: {loc_path}")
    return out_video_path, loc_path


def restore_zoom(entry, ann):
    """
    "Restore" a zooming forgery by replacing every zoomed frame with a clean neighbour.

    ann["zoom_range"] = {"start": s, "end": e} gives the inclusive index range of zoomed
    frames; start is clamped to 0 and end to the last frame, and a missing or negative end
    means "through the last frame". All frames in the range are replaced by the frame
    just before the range; when the range starts at frame 0 the frame just after the
    range is used instead; when the range covers the whole video the frames are left
    unchanged. Other annotation fields (e.g. "factor") are not used.

    Side effects: writes PNG frames and an XVID .avi under RESTORED_ROOT/<category>, and
    a localization JSON under LOCALIZATION_ROOT/<category>.

    Returns (out_video_path, loc_path).
    """
    video_id = entry["video_id"]
    category = entry["category"]
    assert "zooming" in category, f"Expected zooming_frames, got {category}"

    frames, frame_files = load_frames_for_entry(entry)
    total_frames = len(frames)

    # Annotation:
    # { "zoom_range": { "start": 56, "end": 75 }, "factor": 1.216, ... }
    zoom_range = ann.get("zoom_range", {})
    start = int(zoom_range.get("start", 0))
    end   = int(zoom_range.get("end", -1))

    start = max(0, start)
    end   = min(total_frames - 1, end) if end >= 0 else total_frames - 1
    tampered_idxs = set(range(start, end + 1)) if end >= start else set()

    # The tampered set is one contiguous range, so every frame in it shares the same
    # clean neighbour: the frame before the range, else the frame after it. Resolve it
    # once instead of scanning outwards from each tampered frame.
    replacement = None
    if tampered_idxs:
        if start > 0:
            replacement = frames[start - 1]
        elif end + 1 < total_frames:
            replacement = frames[end + 1]

    restored_frames = [
        replacement if (replacement is not None and i in tampered_idxs) else frames[i]
        for i in range(total_frames)
    ]

    h, w = restored_frames[0].shape[:2]
    out_dir = os.path.join(RESTORED_ROOT, category, f"{video_id}_restored_frames")
    os.makedirs(out_dir, exist_ok=True)
    for idx, fr in enumerate(restored_frames):
        out_path = os.path.join(out_dir, f"frame_{idx:05d}.png")
        cv2.imwrite(out_path, cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))

    fps = entry.get("fps", 25.0) or 25.0
    out_video_path = os.path.join(RESTORED_ROOT, category, f"{video_id}_restored.avi")
    fourcc = cv2.VideoWriter_fourcc(*"XVID")
    writer = cv2.VideoWriter(out_video_path, fourcc, fps, (w, h))
    try:
        for fr in restored_frames:
            writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
    finally:
        # Release the file handle even if a frame conversion/write raises.
        writer.release()

    loc_dir = os.path.join(LOCALIZATION_ROOT, category)
    os.makedirs(loc_dir, exist_ok=True)
    loc_path = os.path.join(loc_dir, f"{video_id}_localization.json")

    loc_data = {
        "video_id": video_id,
        "category": category,
        "tampered_indices": sorted(tampered_idxs),
        "restoration_strategy": "neighbor_frame_replacement"
    }
    with open(loc_path, "w", encoding="utf-8") as f:
        json.dump(loc_data, f, indent=2)

    print(f"[Zooming] Restored video: {out_video_path}")
    print(f"[Zooming] Localization JSON: {loc_path}")
    return out_video_path, loc_path


### Dispatcher for localization + restoration

In [None]:
# dispatcher for localization + restoration

# Manifest category names that localize_and_restore_entry handles; an entry whose
# category is not listed here is skipped with an [INFO] message.
TAMPERED_CATEGORIES = [
    "Forgery_Object_deletion",
    "Forgery_frame_insertion",
    "Forgery_horizontal_flipping",
    "Forgery_vertical_flipping",
    "Forgery_zooming_frames",
]

def localize_and_restore_entry(entry):
    """
    Route one manifest entry to the restoration routine for its category.

    Returns whatever the routine returns (a (video_path, localization_path) pair), or
    (None, None) when the category is not in TAMPERED_CATEGORIES or has no routine.
    """
    category = entry["category"]
    video_id = entry["video_id"]

    if category not in TAMPERED_CATEGORIES:
        print(f"[INFO] Skipping unsupported/non-tampered category: {category} ({video_id})")
        return None, None

    # NOTE(review): the restore_* routines read raw annotation keys ("affected_frames",
    # "output_insert_positions", "flip_range", "zoom_range"), while this passes the output
    # of normalize_annotation — confirm that the normalized dict still carries those keys.
    ann_norm = load_and_normalize_annotation(entry)

    # Lambdas defer the lookup of each restore_* name until the chosen one is called.
    handlers = {
        "Forgery_Object_deletion": lambda: restore_object_deletion(entry, ann_norm),
        "Forgery_frame_insertion": lambda: restore_frame_insertion(entry, ann_norm),
        "Forgery_horizontal_flipping": lambda: restore_flip(entry, ann_norm, flip_mode="horizontal"),
        "Forgery_vertical_flipping": lambda: restore_flip(entry, ann_norm, flip_mode="vertical"),
        "Forgery_zooming_frames": lambda: restore_zoom(entry, ann_norm),
    }

    handler = handlers.get(category)
    if handler is None:
        # Only reachable if TAMPERED_CATEGORIES gains a name that has no routine above.
        print(f"[WARN] Unknown category {category} for video {video_id}")
        return None, None
    return handler()


### Classifier model (EfficientNet-B2 + LSTM) and label maps

In [None]:
# classifier model definition (EfficientNet-B2 + LSTM) + load checkpoint

# Dropout probability applied to the pooled LSTM features before the final linear layer.
DROPOUT = 0.4  # should match training

def build_label_map(manifest_path):
    """
    Build the class-index mapping from the categories found in a JSON-lines manifest.

    Categories are sorted alphabetically and numbered from 0.
    Returns (cat2idx, idx2cat): name → index and index → name.
    """
    with open(manifest_path, "r", encoding="utf-8") as handle:
        categories = {json.loads(raw)["category"] for raw in handle if raw.strip()}

    cat2idx = {name: position for position, name in enumerate(sorted(categories))}
    idx2cat = {position: name for name, position in cat2idx.items()}
    return cat2idx, idx2cat

# Label map derived from the TRAIN manifest. idx2cat is what predict_loader uses to turn
# a predicted class index back into a category name; num_classes sizes the classifier head.
cat2idx, idx2cat = build_label_map(TRAIN_MANIFEST)
num_classes = len(cat2idx)
print("Label map:", cat2idx)

class EfficientNetB2VideoClassifier(nn.Module):
    """
    Video classifier: a per-frame EfficientNet-B2 feature extractor, a bidirectional LSTM
    over the frame features, mean pooling over time, dropout and a linear head.

    Submodule names (backbone, rnn, dropout, classifier) are part of the checkpoint
    format loaded via load_state_dict and must not be renamed.
    """

    def __init__(self, num_classes, pretrained=True, dropout=DROPOUT):
        """
        num_classes : size of the output layer
        pretrained  : build the backbone with IMAGENET1K_V1 weights (True) or with no
                      pretrained weights (False)
        dropout     : dropout probability applied before the output layer
        """
        super().__init__()
        # backbone
        if pretrained:
            ef = tvmodels.efficientnet_b2(
                weights=tvmodels.EfficientNet_B2_Weights.IMAGENET1K_V1
            )
        else:
            ef = tvmodels.efficientnet_b2(weights=None)

        # Drop the last child (the classification head) and keep everything before it.
        self.backbone = nn.Sequential(*list(ef.children())[:-1])  # up to global pool
        if hasattr(ef, "classifier"):
            feat_dim = ef.classifier[1].in_features
        else:
            feat_dim = 1408  # default for EfficientNet-B2

        self.feat_dim = feat_dim

        # The two LSTM directions are concatenated, so each gets half of the output
        # width that the linear head expects.
        rnn_out_dim = 512
        self.rnn = nn.LSTM(
            input_size=self.feat_dim,
            hidden_size=rnn_out_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(rnn_out_dim, num_classes)

    def forward(self, x):
        """
        x : (B, T, C, H, W) clip batch.
        Returns logits of shape (B, num_classes).
        """
        B, T, C, H, W = x.shape
        # reshape (not view): a clip tensor that was permuted upstream is not contiguous,
        # and view would raise a RuntimeError on it.
        x = x.reshape(B * T, C, H, W)
        feats = self.backbone(x)                    # (B*T, feat_dim, 1, 1)
        feats = feats.reshape(B, T, self.feat_dim)  # (B, T, feat_dim)
        rnn_out, _ = self.rnn(feats)                # (B, T, 512)
        pooled = rnn_out.mean(dim=1)                # (B, 512)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)            # (B, num_classes)
        return logits

# Build the classifier without ImageNet weights; the checkpoint's state dict is loaded
# into it right below.
model = EfficientNetB2VideoClassifier(num_classes=num_classes, pretrained=False, dropout=DROPOUT).to(DEVICE)

# Use your best B2 checkpoint
CKPT_PATH = os.path.join(PROJECT_PATH, "models", "video_classifier_best_efficientnet6_b2.pth")
assert os.path.exists(CKPT_PATH), f"Checkpoint not found: {CKPT_PATH}"

# The checkpoint is a dict: "model_state" holds the weights, "epoch" (optional) the
# training epoch it was saved at.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
print(f"Loaded classifier checkpoint from epoch {ckpt.get('epoch', 'NA')}:", CKPT_PATH)


Label map: {'Forgery_Object_deletion': 0, 'Forgery_frame_insertion': 1, 'Forgery_horizontal_flipping': 2, 'Forgery_vertical_flipping': 3, 'Forgery_zooming_frames': 4, 'Original': 5}
Loaded classifier checkpoint from epoch 13: /content/drive/MyDrive/D2R Model/models/video_classifier_best_efficientnet6_b2.pth


### Build datasets + loaders for detection

In [None]:
# datasets + dataloaders for detection

# Clip geometry fed to the classifier: IMG_SIZE is (height, width) as unpacked by
# frame_transform, CLIP_LEN is the clip_len passed to dataset_from_manifest.
IMG_SIZE = (224, 224)   # must match training for EfficientNet-B2
CLIP_LEN = 16           # must match training

def frame_transform(arr: np.ndarray) -> np.ndarray:
    """Resize one frame to IMG_SIZE = (height, width) with bilinear interpolation."""
    target_h, target_w = IMG_SIZE
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return resized

BATCH_SIZE = 8      # you used 8 in B2 training; ok for inference
NUM_WORKERS = 4


def _make_split_dataset(manifest_path):
    """Build the clip dataset for one manifest with the settings shared by all splits."""
    return dataset_from_manifest(
        manifest_path,
        load_frames=True,
        clip_len=CLIP_LEN,
        stride=1,
        video_fallback=True,
        frame_transform=frame_transform
    )


def _make_eval_loader(dataset):
    """Wrap a dataset in a non-shuffling DataLoader with the settings shared by all splits."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=collate_video_batch,
        pin_memory=(DEVICE == "cuda")
    )


# Same construction order as before: train, val, test.
train_ds = _make_split_dataset(TRAIN_MANIFEST)
val_ds = _make_split_dataset(VAL_MANIFEST)
test_ds = _make_split_dataset(TEST_MANIFEST)

print("Dataset sizes -> train:", len(train_ds), "val:", len(val_ds), "test:", len(test_ds))

train_loader = _make_eval_loader(train_ds)
val_loader = _make_eval_loader(val_ds)
test_loader = _make_eval_loader(test_ds)


Dataset sizes -> train: 1260 val: 270 test: 270


### Prediction helper (predict_loader)

In [None]:
# prediction helper for detection

from contextlib import nullcontext

@torch.no_grad()
def predict_loader(model, loader, device):
    """
    Run the classifier over every batch of `loader` and collect category predictions.

    model  : classifier called as model(frames) and returning (B, num_classes) logits
    loader : iterable of batches; each usable batch is a mapping with "frames",
             "raw" (per-sample dicts carrying "category") and "video_ids".
             Batches that arrive as a plain list are skipped.
    device : device the frames are moved to; mixed precision is enabled only when this
             equals the string "cuda"

    Returns (preds_dict, trues_dict), both keyed by video_id: predicted category name
    (via the module-level idx2cat) and ground-truth category. Because the key is the
    video_id alone, samples sharing an id overwrite each other; a summary line is
    printed when that happens, and another when batches were skipped.
    """
    model.eval()
    preds_dict = {}
    trues_dict = {}

    # ImageNet normalization constants, shaped to broadcast over (B, T, C, H, W).
    # Created once: they do not change between batches.
    # NOTE(review): assumes "frames" is already scaled to [0, 1] — confirm against
    # collate_video_batch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 1, 3, 1, 1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 1, 3, 1, 1)

    use_amp = (device == "cuda")
    skipped_batches = 0
    overwritten = 0

    for batch in tqdm(loader, desc="Predicting"):
        if isinstance(batch, list):
            # Not the expected mapping: skip it, but keep count so it is not silent.
            skipped_batches += 1
            continue

        frames = batch["frames"].to(device).float()
        frames = (frames - mean) / std

        cats = [s["category"] for s in batch["raw"]]
        vids = batch["video_ids"]

        ctx = torch.amp.autocast(device_type="cuda", enabled=True) if use_amp else nullcontext()
        with ctx:
            logits = model(frames)
        preds_idx = logits.argmax(dim=1).tolist()

        for vid, true_cat, p_idx in zip(vids, cats, preds_idx):
            if vid in preds_dict:
                overwritten += 1
            preds_dict[vid] = idx2cat[p_idx]
            trues_dict[vid] = true_cat

    if skipped_batches:
        print(f"[predict_loader][WARN] Skipped {skipped_batches} batch(es) that were returned as a list.")
    if overwritten:
        print(
            f"[predict_loader][WARN] {overwritten} prediction(s) replaced an earlier result with the "
            "same video_id; the returned dicts hold fewer entries than samples seen."
        )

    return preds_dict, trues_dict

# Example: quick prediction on TEST set
# The returned dicts are keyed by video_id, so the count printed below is the number of
# distinct ids that received a prediction, not the number of samples in test_ds.
test_preds, test_trues = predict_loader(model, test_loader, DEVICE)
print("Test videos predicted:", len(test_preds))


Predicting: 100%|██████████| 34/34 [15:35<00:00, 27.51s/it]

Test videos predicted: 199





### Full D2R loop (Detection → Localization → Restoration)

In [None]:
# helpers for raw annotation + building AVI videos for GUI

import json
import numpy as np
import cv2
import os

def load_raw_annotation(entry):
    """
    Read the annotation JSON referenced by ``entry["annotation_path"]`` verbatim
    (normalize_annotation is not applied).

    Returns the parsed JSON content. When the path is absent, empty, does not
    exist, or reading/parsing fails, a warning is printed and ``{}`` is returned.
    """
    annotation_file = entry.get("annotation_path")

    file_available = bool(annotation_file) and os.path.exists(annotation_file)
    if not file_available:
        print(f"[ANN][WARN] No annotation file for video_id={entry.get('video_id')}")
        return {}

    try:
        with open(annotation_file, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except Exception as err:
        # Best effort by design: any read/parse problem degrades to "no annotation".
        print(f"[ANN][WARN] Failed to load annotation for {entry.get('video_id')}: {err}")
        return {}

    return parsed


def build_avi_from_frames(entry, out_dir):
    """
    Build an AVI video from frames_dir for visualization in the GUI.

    The video is written to ``<out_dir>/<video_id>_tampered.avi``. If that file
    already exists it is returned immediately, before any frames are loaded, so
    repeated GUI requests for the same video are cheap.

    Args:
        entry:   manifest entry. Must contain ``"video_id"``; may contain ``"fps"``
                 (25.0 is used when it is missing or falsy). The entry is passed
                 to ``load_frames_for_entry`` to obtain the frames.
        out_dir: directory for the AVI; created if it does not exist.

    Returns:
        out_video_path (str)

    Raises:
        RuntimeError: if the entry has no frames, or the video writer cannot be
                      opened for the target path.
    """
    video_id = entry["video_id"]
    out_video_path = os.path.join(out_dir, f"{video_id}_tampered.avi")

    # Cache hit: skip the (expensive) frame loading entirely.
    if os.path.exists(out_video_path):
        return out_video_path

    frames, _ = load_frames_for_entry(entry)
    if len(frames) == 0:
        raise RuntimeError(f"No frames for video_id={video_id}")

    h, w = frames[0].shape[:2]
    fps = entry.get("fps", 25.0) or 25.0

    os.makedirs(out_dir, exist_ok=True)

    fourcc = cv2.VideoWriter_fourcc(*"XVID")
    writer = cv2.VideoWriter(out_video_path, fourcc, fps, (w, h))
    if not writer.isOpened():
        writer.release()
        # The file did not exist before this call, so anything present now is
        # an unusable stub that must not be served as a cache hit later.
        if os.path.exists(out_video_path):
            os.remove(out_video_path)
        raise RuntimeError(f"Could not open video writer for {out_video_path}")

    finished = False
    try:
        for fr in frames:
            # Frames are handled as RGB here; the writer is fed BGR.
            writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
        finished = True
    finally:
        writer.release()
        # Never leave a truncated AVI behind: the existence check above would
        # otherwise return it on every subsequent call.
        if not finished and os.path.exists(out_video_path):
            os.remove(out_video_path)

    return out_video_path


def build_clip_for_classifier(entry, img_size, clip_len):
    """
    Assemble one normalized clip tensor of shape (1, T, C, H, W) for the classifier.

    Steps:
      - load the entry's frames via load_frames_for_entry
      - pick ``clip_len`` frame indices: evenly spaced when enough frames exist,
        otherwise cycling through the available frames
      - resize each picked frame to ``img_size`` = (H, W), scale to [0, 1]
      - move to DEVICE and normalize with ImageNet mean/std

    Raises:
        RuntimeError: if the entry yields no frames.

    NOTE(review): assumes each frame is an H x W x 3 RGB array with 0-255
    values — confirm against load_frames_for_entry.
    """
    frames, _ = load_frames_for_entry(entry)
    n_available = len(frames)
    if n_available == 0:
        raise RuntimeError(f"No frames for video_id={entry['video_id']}")

    target_h, target_w = img_size

    if n_available < clip_len:
        # too few frames: wrap around until the clip is full
        picks = np.array([k % n_available for k in range(clip_len)], dtype=int)
    else:
        # enough frames: sample evenly from first to last
        picks = np.linspace(0, n_available - 1, clip_len, dtype=int)

    # cv2.resize takes (width, height); each result is (H, W, 3)
    resized = [
        cv2.resize(frames[k], (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        for k in picks
    ]

    stacked = np.stack(resized, axis=0).astype(np.float32) / 255.0  # (T, H, W, 3)
    channels_first = np.transpose(stacked, (0, 3, 1, 2))            # (T, C, H, W)

    clip = torch.from_numpy(channels_first).unsqueeze(0).to(DEVICE)  # (1, T, C, H, W)

    # ImageNet statistics, broadcast along the channel axis
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 1, 3, 1, 1)
    imagenet_std  = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 1, 3, 1, 1)

    return (clip - imagenet_mean) / imagenet_std  # (1, T, C, H, W)


In [None]:
# Styled Gradio GUI – D2R (Detect to Restore) Pipeline

!pip install -q gradio

import gradio as gr
import os

# Build the lookup used by the GUI backend from the TEST manifest.
# run_d2r_single checks membership by video_id and reads the entry with
# test_manifest_index[video_id], so this must be defined before the GUI is used.
test_manifest_index = index_manifest_by_vid(TEST_MANIFEST)
print("Loaded", len(test_manifest_index), "test videos into manifest index.")

# Folder (under PROJECT_PATH) that holds the preview AVIs built from frames for the GUI
GUI_VIDEO_DIR = os.path.join(PROJECT_PATH, "gui_videos")
os.makedirs(GUI_VIDEO_DIR, exist_ok=True)

# --- Lightweight CSS theme: light sky-blue / orange, black text ---
# Passed to gr.Blocks(css=custom_css) below. The .d2r-* classes are attached to
# components through elem_classes (and to the title <div> in the Markdown block).
# The bare `button` rule is not scoped to a .d2r-* class.
custom_css = """
body {
  background: #f5f9ff;
  color: #000000;
  font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}

/* Main app container */
.d2r-container {
  max-width: 1200px;
  margin: 0 auto !important;
}

/* Title block */
.d2r-title {
  background: linear-gradient(135deg, #e3f2fd, #fff3e0);
  border-radius: 16px;
  padding: 16px 24px;
  border: 1px solid #e0e0e0;
  text-align: center;
  margin-bottom: 16px;
}

.d2r-title h1 {
  margin: 0;
  font-size: 26px;
  font-weight: 700;
  color: #000000;
}

.d2r-title p {
  margin: 6px 0 0 0;
  font-size: 14px;
  color: #000000;
}

/* Section blocks */
.d2r-section {
  background: #ffffff;
  border-radius: 14px;
  padding: 14px 16px;
  border: 1px solid #e0e0e0;
  margin-bottom: 12px;
}

/* Dropdown + button row */
.d2r-controls-row {
  gap: 12px;
}

/* Button styling */
button {
  border-radius: 999px !important;
  font-weight: 600 !important;
}

/* Video panels */
.d2r-video {
  background: #fafafa;
  border-radius: 12px;
  border: 1px solid #e0e0e0;
}
"""


def run_d2r_single(video_id: str):
    """
    Backend for the GUI "Run D2R" button.

    Pipeline for one test video:
      1. Build (or reuse) the input AVI from the entry's frames.
      2. Classify the video with the current global ``model``.
      3. If the ground truth is tampered AND the prediction matches it, run the
         restoration that corresponds to the ground-truth category
         (object deletion, frame insertion, horizontal/vertical flip, zoom).
         Original or misclassified videos are not restored.

    Every failure (unknown id, video build, classification, restoration) is
    reported through the returned log message instead of raising, so the GUI
    always receives three values. Classification and restoration failures also
    print their traceback to the notebook output.

    Args:
        video_id: key into ``test_manifest_index``.

    Returns:
        (input_avi_path_or_None, restored_avi_path_or_None, log_message)
    """
    import traceback  # local import: only needed on the error paths below

    if video_id not in test_manifest_index:
        return None, None, f"[ERROR] video_id {video_id} not found in TEST_MANIFEST."

    entry = test_manifest_index[video_id]
    true_cat = entry["category"]

    # 1) Build input AVI from frames_dir (won't recreate if already exists)
    try:
        input_avi = build_avi_from_frames(entry, GUI_VIDEO_DIR)
    except Exception as e:
        return None, None, f"[ERROR] Failed to build input video for {video_id}: {e}"

    # 2) Run classifier on this video (on-the-fly).
    #    Guarded like step 1 so a bad clip or model error becomes a log message
    #    while the input video is still shown.
    try:
        model.eval()
        with torch.no_grad():
            clip = build_clip_for_classifier(entry, IMG_SIZE, CLIP_LEN)  # (1, T, C, H, W)
            logits = model(clip)
            pred_idx = int(logits.argmax(dim=1).item())
        pred_cat = idx2cat[pred_idx]
    except Exception as e:
        traceback.print_exc()
        return input_avi, None, f"[ERROR] Classification failed for {video_id}: {e}"

    msg_head = f"Video ID: {video_id}\nGround Truth: {true_cat}\nPrediction: {pred_cat}\n"

    # --- Case 1: Ground truth is Original ---
    if true_cat == "Original":
        if pred_cat == "Original":
            msg = (
                msg_head
                + "\nResult: Correctly identified as original. No tampering detected → no restoration applied."
            )
        else:
            msg = (
                msg_head
                + "\nResult: False positive (classified as tampered, but GT is original). "
                  "Restoration is skipped."
            )
        return input_avi, None, msg

    # --- Case 2: Tampered but misclassified ---
    if pred_cat != true_cat:
        msg = (
            msg_head
            + "\nResult: Misclassified tampered video → restoration skipped "
              "(to avoid applying the wrong operation)."
        )
        return input_avi, None, msg

    # --- Case 3: Tampered AND correctly classified → Run restoration ---
    ann_raw = load_raw_annotation(entry)
    if not ann_raw:
        msg = (
            msg_head
            + "\nResult: Annotation missing or invalid → cannot localize tampering, "
              "so restoration is not applied."
        )
        return input_avi, None, msg

    restored_avi_path = None
    loc_path = None

    # Use TRUE category for restoration logic (equal to pred_cat at this point).
    try:
        if true_cat == "Forgery_Object_deletion":
            restored_avi_path, loc_path = restore_object_deletion(entry, ann_raw)
        elif true_cat == "Forgery_frame_insertion":
            restored_avi_path, loc_path = restore_frame_insertion(entry, ann_raw)
        elif true_cat == "Forgery_horizontal_flipping":
            restored_avi_path, loc_path = restore_flip(entry, ann_raw, flip_mode="horizontal")
        elif true_cat == "Forgery_vertical_flipping":
            restored_avi_path, loc_path = restore_flip(entry, ann_raw, flip_mode="vertical")
        elif true_cat == "Forgery_zooming_frames":
            restored_avi_path, loc_path = restore_zoom(entry, ann_raw)
        else:
            msg = (
                msg_head
                + "\nResult: Category not supported for restoration in this demo."
            )
            return input_avi, None, msg
    except Exception as e:
        traceback.print_exc()
        msg = (
            msg_head
            + f"\nResult: Restoration failed with an error: {e}\n"
              "(see notebook logs for details)."
        )
        return input_avi, None, msg

    if restored_avi_path is None or not os.path.exists(restored_avi_path):
        msg = (
            msg_head
            + "\nResult: Restoration function ran, but no output AVI was produced "
              "(see notebook logs for details)."
        )
        return input_avi, None, msg

    msg = (
        msg_head
        + "\nResult: Correctly classified tampered video.\n"
          "Localization + restoration pipeline applied.\n\n"
          f"Restored AVI saved at:\n  {restored_avi_path}\n"
          f"Localization JSON saved at:\n  {loc_path}"
    )

    # For Gradio we can directly return the AVI path
    return input_avi, restored_avi_path, msg


# --- Build dropdown of available test video_ids ---
# Sorted so the dropdown order is stable between runs.
test_video_ids = sorted(test_manifest_index.keys())

# Layout: title → controls row (dropdown + button) → two video panels → log box.
# Components appear in the order they are created inside each context manager.
with gr.Blocks(
    title="D2R (Detect to Restore) Pipeline",
    css=custom_css
) as demo:

    with gr.Column(elem_classes=["d2r-container"]):

        # Title section (raw HTML so the .d2r-title / .d2r-section CSS rules apply)
        gr.Markdown(
            """
<div class="d2r-title d2r-section">
  <h1>D2R (Detect to Restore) Pipeline</h1>
  <p><b>Classify tampered videos, localize manipulated regions, and restore them when possible.</b></p>
</div>
            """,
        )

        # Controls: choose video + button
        with gr.Row(elem_classes=["d2r-section", "d2r-controls-row"]):
            vid_dd = gr.Dropdown(
                choices=test_video_ids,
                label="Select a test video_id",
                # Preselect the first id; None when the manifest index is empty
                value=test_video_ids[0] if test_video_ids else None,
            )
            run_btn = gr.Button("Run D2R", scale=0)

        # Videos: input and restored side-by-side
        with gr.Row(elem_classes=["d2r-section"]):
            input_video_comp = gr.Video(
                label="Input (Original / Tampered)",
                elem_classes=["d2r-video"],
            )
            # Receives None from run_d2r_single whenever restoration is skipped
            restored_video_comp = gr.Video(
                label="Restored Video (only if correctly classified & tampered)",
                elem_classes=["d2r-video"],
            )

        # Log / explanation
        logs = gr.Textbox(
            label="Info / Logs",
            lines=10,
            interactive=False,
            elem_classes=["d2r-section"],
        )

        # run_d2r_single returns (input_avi, restored_avi_or_None, message);
        # the three values map onto `outputs` in this order.
        run_btn.click(
            fn=run_d2r_single,
            inputs=vid_dd,
            outputs=[input_video_comp, restored_video_comp, logs],
        )

# share=False: no public link is created.
# NOTE: per the launch output, Colab only shows handler errors with debug=True.
demo.launch(share=False)


Loaded 199 test videos into manifest index.


  with gr.Blocks(


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

