In [1]:
from pathlib import Path
from typing import Callable, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("[INFO] Using device:", DEVICE)


[INFO] Using device: cuda


In [3]:
def read_image(path: Path) -> np.ndarray:
    """Load an image from disk and return it in RGB channel order.

    OpenCV decodes to BGR, so the result is converted before returning.

    Raises:
        ValueError: if OpenCV cannot read the file (cv2.imread signals this by
            returning None rather than raising).
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Failed to read image: {path}")

def to_tensor(img: np.ndarray) -> torch.Tensor:
    """Convert an (H, W, C) 0-255 image into a (1, C, H, W) float32 tensor in [0, 1]."""
    # HWC -> CHW, then add a leading batch axis.
    chw = np.transpose(img.astype(np.float32) / 255.0, (2, 0, 1))
    return torch.from_numpy(chw)[None]

def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Convert a (1, C, H, W) float tensor in [0, 1] back to an (H, W, C) uint8 image.

    This is the inverse of `to_tensor`: drop the batch axis, go CHW -> HWC and
    rescale by 255. Values outside [0, 1] are clipped to the valid uint8 range.
    """
    # detach() so tensors that still require grad can be converted too.
    img = t.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round before the uint8 cast: float32 error can leave a value just below
    # an integer (e.g. 126.9999), which plain truncation would turn into 126.
    return np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)


In [4]:
def compute_brightness(img: np.ndarray) -> float:
    """Return the mean of the HSV value (V) channel, scaled by 1/255.

    The input is treated as RGB (it is converted with COLOR_RGB2HSV).
    NOTE(review): the /255 scaling assumes 8-bit input such as `read_image`
    produces; for float images OpenCV's V channel is already in [0, 1] — confirm
    callers only pass uint8 frames.
    """
    value_channel = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)[..., 2]
    return float((value_channel / 255.0).mean())


In [5]:
def needs_low_light_enhancement(
    img: np.ndarray,
    threshold: float = 0.35
) -> Tuple[bool, float]:
    """Decide whether `img` is dark enough to be enhanced.

    Returns:
        (is_dark, brightness): brightness is `compute_brightness(img)`, and
        is_dark is True when brightness is strictly below `threshold`.
    """
    level = compute_brightness(img)
    return (level < threshold), level


In [6]:
class ZeroDCE(nn.Module):
    """Stack of seven 3x3 convolutions that predicts per-pixel curve parameters.

    Input: a 3-channel image batch. Output: 24 channels squashed into [-1, 1]
    by tanh, which `apply_curves` consumes in groups of 3 per iteration.
    Padding=1 keeps the spatial size unchanged.

    NOTE(review): the reference Zero-DCE DCE-Net feeds concatenated skip
    features into its last three convolutions; this variant is purely
    sequential. Confirm that is intentional — changing it would break loading
    of checkpoints trained with this class.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        # The attribute names conv1..conv7 are the state_dict keys; keep them
        # (and their creation order) stable so saved checkpoints still load.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(32, 24, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the curve-parameter map for image batch `x`."""
        feat = x
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            feat = self.relu(conv(feat))
        # Final layer has no ReLU; tanh bounds the curve parameters to [-1, 1].
        return torch.tanh(self.conv7(feat))


In [7]:
def apply_curves(img, curves, n_iters=8):
    """Iteratively apply the quadratic light-enhancement curves to `img`.

    Iteration i takes channels [3*i, 3*i + 3) of `curves` as a per-pixel
    alpha map and updates  x <- x + alpha * (x**2 - x).
    With x in [0, 1] and alpha in [-1, 1] (tanh output) each step stays in
    [0, 1], since the update equals x * (1 - alpha * (1 - x)).
    """
    out = img
    for step in range(n_iters):
        alpha = curves[:, 3 * step:3 * step + 3, :, :]
        out = out + alpha * (out.pow(2) - out)
    return out


In [8]:
MODEL_PATH = Path("../outputs/models/low_light/best_model.pth")
# Fixed seed for ZeroDCE's random initialisation, so the random-weight fallback
# below produces the same (untrained) outputs on every Restart & Run All.
RANDOM_INIT_SEED = 0

torch.manual_seed(RANDOM_INIT_SEED)
low_light_model = ZeroDCE().to(DEVICE)
low_light_model.eval()

if MODEL_PATH.exists():
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load trusted checkpoints (consider weights_only=True if the
    # checkpoint holds only tensors and plain containers).
    state = torch.load(MODEL_PATH, map_location=DEVICE)
    # Accept both {"model_state": state_dict, ...} checkpoints (the layout this
    # cell originally expected) and a bare state_dict.
    if isinstance(state, dict) and "model_state" in state:
        model_state = state["model_state"]
    else:
        model_state = state
    low_light_model.load_state_dict(model_state)
    print("✅ Loaded pretrained low-light model")
else:
    print("⚠️ No pretrained model found. Using random weights.")


⚠️ No pretrained model found. Using random weights.


In [9]:
@torch.no_grad()
def enhance_low_light(frame: np.ndarray) -> np.ndarray:
    """Enhance one RGB frame with the global `low_light_model`.

    The frame is converted to a batched tensor on DEVICE, the model predicts
    curve parameters, `apply_curves` applies them, and the result is converted
    back to a uint8 image. Autograd is disabled by the decorator.
    """
    batch = to_tensor(frame).to(DEVICE)
    return to_numpy(apply_curves(batch, low_light_model(batch)))


In [10]:
def visualize_low_light(img_path: Path):
    """Display the image at `img_path` side by side with its enhanced version.

    Enhancement is always applied here, regardless of the image's brightness.
    """
    original = read_image(img_path)
    result = enhance_low_light(original)

    _, axes = plt.subplots(1, 2, figsize=(10, 4))
    panels = (("Original (Low-Light)", original), ("Enhanced", result))
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    plt.show()


In [11]:
def low_light_enhancement_step(frame: np.ndarray):
    """Enhance `frame` only if `needs_low_light_enhancement` flags it as dark.

    Returns:
        dict with
          "brightness": the brightness measured by needs_low_light_enhancement,
          "enhanced":   True if enhance_low_light was applied,
          "frame":      the enhanced frame, or the input frame object unchanged.
    """
    is_dark, brightness = needs_low_light_enhancement(frame)
    output_frame = enhance_low_light(frame) if is_dark else frame
    return {"brightness": brightness, "enhanced": is_dark, "frame": output_frame}


In [12]:
def full_enhancement_pipeline(
    frame: np.ndarray,
    blur_flag: bool,
    deblur_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> np.ndarray:
    """Optionally deblur `frame`, then apply conditional low-light enhancement.

    Args:
        frame: input frame.
        blur_flag: when True, the frame is deblurred before the low-light step.
        deblur_fn: callable used for deblurring. When omitted, a global
            `deblur_frame` (from Notebook 03) is looked up at call time; this
            notebook does not define it itself.

    Returns:
        The frame produced by `low_light_enhancement_step` (enhanced only if it
        was judged low-light).

    Raises:
        NameError: if blur_flag is True, no deblur_fn is given and no global
            `deblur_frame` exists (e.g. on a fresh kernel).
    """
    if blur_flag:
        if deblur_fn is None:
            # Resolved at call time so a deblur_frame imported/defined after
            # this cell is still picked up.
            deblur_fn = globals().get("deblur_frame")
        if deblur_fn is None:
            raise NameError(
                "blur_flag=True but no deblur function is available: pass "
                "deblur_fn=... or define/import `deblur_frame` (Notebook 03) first."
            )
        frame = deblur_fn(frame)

    return low_light_enhancement_step(frame)["frame"]
