In [18]:
# ===============================
# Standard library
# ===============================
import os
import shutil
import zipfile
import tarfile
import urllib.request
from pathlib import Path
from typing import List, Optional, Dict, Any

# ===============================
# Third-party
# ===============================
import cv2
import numpy as np
import matplotlib.pyplot as plt

import torch
from tqdm import tqdm


In [19]:
# ===============================
# DEVICE
# ===============================
# Use CUDA whenever PyTorch can see a GPU, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[INFO] Using device: {DEVICE}")

# ===============================
# PROJECT ROOT
# ===============================
# Resolved from the current working directory, so the notebook is expected to
# be launched from a folder that sits directly inside the project root.
PROJECT_ROOT = Path("..").resolve()

# ===============================
# DATA DIRECTORIES
# ===============================
DATA_ROOT = PROJECT_ROOT / "data"
DATASETS_DIR, RAW_VIDEO_DIR, FRAME_DIR = (
    DATA_ROOT / sub for sub in ("datasets", "raw_videos", "extracted_frames")
)

# ===============================
# OUTPUT DIRECTORIES
# ===============================
OUTPUT_DIR = PROJECT_ROOT / "outputs"
MODELS_DIR, LOGS_DIR, PLOTS_DIR = (
    OUTPUT_DIR / sub for sub in ("models", "logs", "plots")
)

# ===============================
# CREATE ALL DIRS SAFELY
# ===============================
ALL_DIRS = [
    DATA_ROOT, DATASETS_DIR, RAW_VIDEO_DIR, FRAME_DIR,
    OUTPUT_DIR, MODELS_DIR, LOGS_DIR, PLOTS_DIR,
]

# parents=True + exist_ok=True keeps this cell safe to re-run.
for d in ALL_DIRS:
    d.mkdir(parents=True, exist_ok=True)

print("[INFO] Directory structure initialized")


[INFO] Using device: cuda
[INFO] Directory structure initialized


In [20]:
def dataset_exists(
    dataset_dir: Path,
    required_subdirs: Optional[List[str]] = None,
    min_files: int = 1
) -> bool:
    """
    Check whether a dataset appears to be present on disk.

    Args:
        dataset_dir: Root directory of the dataset.
        required_subdirs: Relative paths that must exist as *directories*
            under ``dataset_dir`` (e.g. ``["train", "test"]``).
        min_files: Minimum number of regular files that must be found
            anywhere under ``dataset_dir`` (searched recursively).

    Returns:
        True if ``dataset_dir`` is a directory, every required subdirectory
        is a directory, and at least ``min_files`` files exist below it.
    """
    dataset_dir = Path(dataset_dir)
    if not dataset_dir.is_dir():
        return False

    # A stray *file* named e.g. "train" must not count as the train split.
    for sub in required_subdirs or []:
        if not (dataset_dir / sub).is_dir():
            return False

    if min_files <= 0:
        return True

    # Stop walking as soon as enough files are seen: image datasets can hold
    # tens of thousands of files, and the full listing is never needed.
    found = 0
    for entry in dataset_dir.rglob("*"):
        if entry.is_file():
            found += 1
            if found >= min_files:
                return True
    return False


In [21]:
def download_file(url: str, output_path: Path):
    """
    Cross-platform downloader (Windows / Linux / Mac).
    Downloads only if file does not exist.

    The payload is first written to ``<output_path>.part`` and renamed into
    place only after the transfer succeeds. An interrupted or failed download
    therefore never leaves a truncated file at ``output_path``, which would
    otherwise be mistaken for a finished download on the next run.

    Args:
        url: Source URL (any scheme supported by ``urllib.request``).
        output_path: Destination file; missing parent directories are created.

    Raises:
        RuntimeError: If the download (or the final rename) fails. The
            original exception is chained as ``__cause__``.
    """
    output_path = Path(output_path)
    if output_path.exists():
        print(f"[INFO] File already exists: {output_path.name}")
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = output_path.with_name(output_path.name + ".part")

    print(f"[DOWNLOAD] {url}")
    try:
        urllib.request.urlretrieve(url, tmp_path)
        tmp_path.replace(output_path)
    except Exception as e:
        # Drop any partial payload so a retry starts from scratch.
        if tmp_path.exists():
            tmp_path.unlink()
        raise RuntimeError(f"Failed to download {url}\nReason: {e}") from e

    print(f"[INFO] Downloaded to {output_path}")


In [22]:
def _safe_extract_tar(archive_path: Path, extract_to: Path, mode: str) -> None:
    """Extract a tar archive, rejecting members (or link targets) outside ``extract_to``."""
    root = extract_to.resolve()
    with tarfile.open(archive_path, mode) as t:
        # Validate every member before writing anything, so a malicious
        # archive ("../x", absolute paths, escaping links) is rejected whole.
        for member in t.getmembers():
            member_path = root / member.name
            targets = [member_path]
            if member.issym():
                targets.append(member_path.parent / member.linkname)
            elif member.islnk():
                targets.append(root / member.linkname)
            for target in targets:
                resolved = target.resolve()
                if resolved != root and root not in resolved.parents:
                    raise ValueError(f"Unsafe path in archive: {member.name}")
        # Extraction filters exist on recent Python releases; "data" also
        # strips dangerous metadata (setuid bits, device files, ...).
        if hasattr(tarfile, "data_filter"):
            t.extractall(root, filter="data")
        else:
            t.extractall(root)


def extract_archive(archive_path: Path, extract_to: Path):
    """
    Extract a ``.zip``, ``.tar.gz`` / ``.tgz`` or ``.tar`` archive.

    The format is chosen from the file name (case-insensitive). Tar archives
    are checked for path traversal before extraction. ``ZipFile.extractall``
    already strips absolute paths and ``..`` components itself.

    Args:
        archive_path: Archive file to extract.
        extract_to: Destination directory (created if needed).

    Raises:
        ValueError: If the format is unsupported, or a tar member would be
            written outside ``extract_to``.
    """
    archive_path = Path(archive_path)
    extract_to = Path(extract_to)
    print(f"[EXTRACT] {archive_path.name}")

    name = archive_path.name.lower()
    if name.endswith(".zip"):
        with zipfile.ZipFile(archive_path, "r") as z:
            z.extractall(extract_to)
    elif name.endswith((".tar.gz", ".tgz")):
        _safe_extract_tar(archive_path, extract_to, "r:gz")
    elif name.endswith(".tar"):
        _safe_extract_tar(archive_path, extract_to, "r:")
    else:
        raise ValueError("Unsupported archive format")

    print("[INFO] Extraction complete")


In [23]:
GOPRO_DIR = DATASETS_DIR / "GoPro"

# Only the top-level split folders are verified; the blur/sharp layout in the
# message is guidance for the manual download.
if not dataset_exists(GOPRO_DIR, ["train", "test"]):
    print(
        "\n❗ GoPro dataset NOT found.\n"
        "ACTION REQUIRED:\n"
        "1. Download from: https://seungjunnah.github.io/Datasets/gopro\n"
        f"2. Extract into: {GOPRO_DIR}\n"
        "Expected structure:\n"
        "GoPro/train/blur, GoPro/train/sharp, GoPro/test/blur, GoPro/test/sharp"
    )
else:
    print("✅ GoPro dataset ready.")



❗ GoPro dataset NOT found.
ACTION REQUIRED:
1. Download from: https://seungjunnah.github.io/Datasets/gopro
2. Extract into: C:\Users\Manas Mehta\Desktop\PROJECTS\AIDTM\data\datasets\GoPro
Expected structure:
GoPro/train/blur, GoPro/train/sharp, GoPro/test/blur, GoPro/test/sharp


In [24]:
REALBLUR_DIR = DATASETS_DIR / "RealBlur"

# No required sub-folders: any file under RealBlur/ counts as "present".
if not dataset_exists(REALBLUR_DIR):
    print(
        "\n❗ RealBlur dataset NOT found.\n"
        "Download from: https://cg.postech.ac.kr/research/realblur/\n"
        f"Extract into: {REALBLUR_DIR}"
    )
else:
    print("✅ RealBlur dataset ready.")



❗ RealBlur dataset NOT found.
Download from: https://cg.postech.ac.kr/research/realblur/
Extract into: C:\Users\Manas Mehta\Desktop\PROJECTS\AIDTM\data\datasets\RealBlur


In [25]:
LOL_DIR = DATASETS_DIR / "LOL"

# Both LOL splits (our485 for training, eval15 for evaluation) must exist.
if not dataset_exists(LOL_DIR, ["our485", "eval15"]):
    print(
        "\n❗ LOL dataset NOT found.\n\n"
        "ACTION REQUIRED:\n"
        "1. Download the LOL dataset manually from one of these sources:\n"
        "   - https://daooshee.github.io/BMVC2018website/\n"
        "   - https://github.com/daooshee/Low-light-image-enhancement\n\n"
        "2. Extract it into:\n"
        f"   {LOL_DIR}\n\n"
        "Expected structure:\n"
        "LOL/\n"
        " ├── our485/\n"
        " │    ├── low/\n"
        " │    └── high/\n"
        " └── eval15/\n"
        "      ├── low/\n"
        "      └── high/\n"
    )
else:
    print("✅ LOL dataset ready.")



❗ LOL dataset NOT found.

ACTION REQUIRED:
1. Download the LOL dataset manually from one of these sources:
   - https://daooshee.github.io/BMVC2018website/
   - https://github.com/daooshee/Low-light-image-enhancement

2. Extract it into:
   C:\Users\Manas Mehta\Desktop\PROJECTS\AIDTM\data\datasets\LOL

Expected structure:
LOL/
 ├── our485/
 │    ├── low/
 │    └── high/
 └── eval15/
      ├── low/
      └── high/



In [26]:
ICDAR_DIR = DATASETS_DIR / "ICDAR2015"

# ICDAR 2015 requires registration, so it can only be fetched manually.
if not dataset_exists(ICDAR_DIR):
    print(
        "\n❗ ICDAR 2015 NOT found.\n"
        "Register & download from: https://rrc.cvc.uab.es/?ch=4\n"
        f"Extract into: {ICDAR_DIR}"
    )
else:
    print("✅ ICDAR 2015 dataset ready.")



❗ ICDAR 2015 NOT found.
Register & download from: https://rrc.cvc.uab.es/?ch=4
Extract into: C:\Users\Manas Mehta\Desktop\PROJECTS\AIDTM\data\datasets\ICDAR2015


In [27]:
RAILSEM_DIR = DATASETS_DIR / "RailSem19"

# Presence check only: any file under RailSem19/ counts.
if not dataset_exists(RAILSEM_DIR):
    print(
        "\n❗ RailSem19 NOT found.\n"
        "Download from: https://www.railsense.org/datasets/railsem19\n"
        f"Extract into: {RAILSEM_DIR}"
    )
else:
    print("✅ RailSem19 dataset ready.")



❗ RailSem19 NOT found.
Download from: https://www.railsense.org/datasets/railsem19
Extract into: C:\Users\Manas Mehta\Desktop\PROJECTS\AIDTM\data\datasets\RailSem19


In [28]:
def read_image(path: Path) -> np.ndarray:
    """
    Load an image file and return it in RGB channel order.

    ``cv2.imread`` is called with its default flags and yields BGR data, so
    the array is converted with ``COLOR_BGR2RGB`` before it is returned.

    Raises:
        ValueError: If OpenCV could not read/decode the file.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # imread signals failure by returning None instead of raising.
    # NOTE(review): imread is reported to fail on non-ASCII paths on Windows;
    # np.fromfile + cv2.imdecode is the usual workaround if that shows up.
    raise ValueError(f"Failed to read image: {path}")


def normalize_img(img: np.ndarray) -> np.ndarray:
    """
    Convert an image to ``float32`` and divide by 255.

    For 8-bit input this maps [0, 255] onto [0, 1]. ``astype`` always copies,
    so the caller's array is never modified.
    """
    scaled = img.astype(np.float32)
    scaled /= 255.0
    return scaled


def denormalize_img(img: np.ndarray) -> np.ndarray:
    """
    Map a [0, 1] float image back to 8-bit ``uint8`` pixels.

    Values are multiplied by 255, rounded to the nearest integer and clipped
    to [0, 255]. A bare ``astype(np.uint8)`` truncates, so a value such as
    254.9999 (float error from a normalize/denormalize round trip) would
    become 254. Rounding keeps ``denormalize_img(normalize_img(x)) == x`` for
    ``uint8`` input and removes the downward bias.
    """
    return np.rint(img * 255.0).clip(0, 255).astype(np.uint8)


In [29]:
def show_random_image(root: Path, extensions: Optional[List[str]] = None):
    """
    Display one randomly chosen image found (recursively) under ``root``.

    Args:
        root: Directory to search.
        extensions: File suffixes treated as images, compared
            case-insensitively. Defaults to ``.jpg``, ``.jpeg`` and ``.png``.

    The pick uses NumPy's global RNG (``np.random.randint``), so it is
    reproducible under ``np.random.seed``. Prints a warning and returns None
    when no image is found.
    """
    exts = {e.lower() for e in (extensions or [".jpg", ".jpeg", ".png"])}
    # Sorting makes the candidate order, and thus the seeded pick, independent
    # of filesystem enumeration order.
    images = sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in exts
    )
    if not images:
        print("[WARN] No images found.")
        return

    img_path = images[np.random.randint(len(images))]
    img = read_image(img_path)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(img)
    ax.set_title(img_path.name)
    ax.axis("off")
    plt.show()


In [30]:
def extract_frames_from_video(
    video_path: Path,
    output_dir: Path,
    frame_interval: int = 1,
    max_frames: Optional[int] = None
):
    """
    Save every ``frame_interval``-th frame of a video as a JPEG.

    Frames are written to ``output_dir`` as ``frame_XXXXXX.jpg``, where the
    number is the frame's index in the source video.

    Args:
        video_path: Video file to read.
        output_dir: Destination directory (created if missing).
        frame_interval: Keep one frame out of every ``frame_interval``; must be >= 1.
        max_frames: Stop after saving this many frames. ``None`` means no
            limit; ``0`` saves nothing.

    Returns:
        Number of frames written.

    Raises:
        ValueError: If ``frame_interval`` < 1 or ``max_frames`` is negative.
        OSError: If the video cannot be opened or a frame cannot be written.
    """
    # Validate up front: frame_interval == 0 would otherwise crash with a
    # modulo-by-zero only after the capture was opened.
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if max_frames is not None and max_frames < 0:
        raise ValueError(f"max_frames must be >= 0 or None, got {max_frames}")

    output_dir = Path(output_dir)
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        # Without this check a bad path silently reports "Extracted 0 frames".
        cap.release()
        raise OSError(f"Cannot open video: {video_path}")

    output_dir.mkdir(parents=True, exist_ok=True)

    frame_id = 0
    saved = 0
    try:
        # `is None` rather than truthiness, so max_frames=0 is not "unlimited".
        while max_frames is None or saved < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_id % frame_interval == 0:
                out_path = output_dir / f"frame_{frame_id:06d}.jpg"
                # imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(str(out_path), frame):
                    raise OSError(f"Failed to write frame: {out_path}")
                saved += 1

            frame_id += 1
    finally:
        cap.release()

    print(f"[INFO] Extracted {saved} frames")
    return saved


In [31]:
def save_epoch_model(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    save_dir: Path,
    metrics: Optional[Dict[str, Any]] = None
):
    """
    Save a per-epoch snapshot to ``save_dir/epoch_XXX.pth``.

    The saved dict holds ``epoch``, ``model_state``, ``optimizer_state``,
    ``loss`` and ``metrics``.

    Args:
        model: Model whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        epoch: Epoch number; also used (zero-padded to 3 digits) in the file name.
        loss: Loss value recorded alongside the weights.
        save_dir: Destination directory (created if missing).
        metrics: Optional extra metrics stored as-is.

    Returns:
        Path of the written file.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / f"epoch_{epoch:03d}.pth"

    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": loss,
        "metrics": metrics,
    }

    # Write to a sibling temp file, then rename it over the target
    # (os.replace semantics), so an interrupted save never leaves a
    # truncated .pth behind.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(state, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    return path


In [32]:
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[Any],
    epoch: int,
    best_metric: float,
    save_dir: Path
):
    """
    Write the resumable training state to ``save_dir/checkpoint.pth``.

    The file is overwritten on every call. It holds ``epoch``,
    ``model_state``, ``optimizer_state``, ``scheduler_state`` (None when no
    scheduler is given) and ``best_metric``, which is the layout
    ``load_checkpoint_if_exists`` reads.

    Args:
        model: Model whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        scheduler: Optional LR scheduler exposing ``state_dict()``.
        epoch: Last completed epoch.
        best_metric: Best validation metric seen so far.
        save_dir: Destination directory (created if missing).

    Returns:
        Path of the written checkpoint.
    """
    save_dir = Path(save_dir)
    # torch.save does not create directories; without this the first
    # checkpoint in a fresh output folder fails.
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / "checkpoint.pth"

    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
        "best_metric": best_metric,
    }

    # Save to a temp file and rename it over the target (os.replace
    # semantics): a crash mid-save must not clobber the previous checkpoint
    # that resuming depends on.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(state, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    return path


In [33]:
def save_best_model(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    metric: float,
    save_dir: Path
):
    """
    Save the current model as ``save_dir/best_model.pth``.

    The saved dict holds ``epoch``, ``model_state``, ``optimizer_state`` and
    ``best_metric`` (set to ``metric``). Each call overwrites the previous
    best model.

    Args:
        model: Model whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        epoch: Epoch at which this model was obtained.
        metric: Metric value that made this the best model.
        save_dir: Destination directory (created if missing).

    Returns:
        Path of the written file.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / "best_model.pth"

    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_metric": metric,
    }

    # Temp file + rename so an interrupted save keeps the previous best model.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(state, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    return path


In [34]:
def load_checkpoint_if_exists(
    model,
    optimizer,
    scheduler,
    checkpoint_path: Path
):
    """
    Restore training state from ``checkpoint_path`` if the file exists.

    Args:
        model: Module whose weights are restored from ``"model_state"``.
        optimizer: Optimizer restored from ``"optimizer_state"``.
        scheduler: Optional LR scheduler. It is restored only when it is not
            None and the file holds a non-empty ``"scheduler_state"``.
        checkpoint_path: File written by ``torch.save`` holding at least
            ``"epoch"``, ``"model_state"`` and ``"optimizer_state"``.
            ``"scheduler_state"`` and ``"best_metric"`` are optional, so
            files from ``save_best_model`` / ``save_epoch_model`` also load.

    Returns:
        ``(model, optimizer, scheduler, start_epoch, best_metric)``. If the
        file is missing, ``start_epoch`` is 0 and ``best_metric`` is None.
        ``best_metric`` is also None when the file has no such entry.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        return model, optimizer, scheduler, 0, None

    # Tensors are mapped to CPU; load_state_dict copies them into the
    # existing parameters, so they end up on the model's current device.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])

    # .get(): best_model.pth / epoch_XXX.pth files have no scheduler_state.
    scheduler_state = checkpoint.get("scheduler_state")
    if scheduler is not None and scheduler_state:
        scheduler.load_state_dict(scheduler_state)

    start_epoch = checkpoint["epoch"] + 1
    best_metric = checkpoint.get("best_metric")

    print(f"[INFO] Resuming from epoch {start_epoch}")
    return model, optimizer, scheduler, start_epoch, best_metric
