In [None]:
!git clone https://github.com/Li-Chongyi/Zero-DCE_extension.git

Cloning into 'Zero-DCE_extension'...
remote: Enumerating objects: 2140, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 2140 (delta 0), reused 2 (delta 0), pack-reused 2136 (from 1)[K
Receiving objects: 100% (2140/2140), 100.92 MiB | 27.67 MiB/s, done.
Resolving deltas: 100% (15/15), done.


In [None]:
!pip install Pytorch
!pip install opencv
!pip install torchvision
!pip install cuda

Collecting Pytorch
  Downloading pytorch-1.0.2.tar.gz (689 bytes)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: Pytorch
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py bdist_wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for Pytorch (setup.py) ... [?25lerror
[31m  ERROR: Failed building wheel for Pytorch[0m[31m
[0m[?25h  Running setup.py clean for Pytorch
Failed to build Pytorch
[31mERROR: ERROR: Failed to build installable wheels for some pyproject.toml based projects (Pytorch)[0m[31m
[0m[31mERROR: Could not find a version that satisfies the requirement opencv (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for opencv[0m[31m
[31mERROR: Could not find a version that satisfi

In [None]:
import os
import sys
import time
import glob

import cv2
import numpy as np
import torch
import torchvision

# The cloned repository's Zero-DCE++ directory provides `model.py`, which defines
# the `enhance_net_nopool` network used by `load_model` below. Put it at the FRONT
# of sys.path: `model` is a generic module name, and with `append` any other
# `model` module found earlier on the path would be imported instead. The
# membership check keeps re-runs of this cell from adding duplicate entries.
_ZERO_DCE_PP_DIR = os.path.abspath("/content/Zero-DCE_extension/Zero-DCE++")
if _ZERO_DCE_PP_DIR not in sys.path:
    sys.path.insert(0, _ZERO_DCE_PP_DIR)
import model as zdce_model


def load_model(weight_path: str, scale_factor: int = 12, device: str | None = None):
    """Build a Zero-DCE++ network, load a checkpoint into it, and switch it to eval mode.

    Args:
        weight_path: Path to a checkpoint holding a ``state_dict`` for
            ``zdce_model.enhance_net_nopool``.
        scale_factor: Passed to ``enhance_net_nopool``.
        device: Torch device string. ``None`` selects ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(net, device)``: the model moved to ``device``, and the device string used.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    net = zdce_model.enhance_net_nopool(scale_factor).to(device)
    # The file only needs to contain tensors (it is fed to load_state_dict), so
    # restrict unpickling to weights: a full pickle load of a downloaded file can
    # execute arbitrary code.
    state = torch.load(weight_path, map_location=device, weights_only=True)
    net.load_state_dict(state)
    net.eval()
    print(f"Model loaded to {device} from {weight_path}")
    return net, device


def _crop_to_multiple(img_np: np.ndarray, factor: int):
    h, w = img_np.shape[:2]
    h2 = (h // factor) * factor
    w2 = (w // factor) * factor
    if h2 == h and w2 == w:
        return img_np
    return img_np[0:h2, 0:w2, :]


def _tensor_to_bgr_uint8(tensor: torch.Tensor):
    if tensor.dim() == 4 and tensor.shape[0] == 1:
        tensor = tensor[0]
    t = tensor.detach().cpu().clamp(0.0, 1.0)
    arr = (t.numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
    bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
    return bgr


def _enhance_frame(model, device, frame_bgr: np.ndarray, scale_factor: int):
    """Run one BGR frame through ``model``; return ``(enhanced_bgr_uint8, seconds)``.

    The frame is converted to RGB, cropped to multiples of ``scale_factor`` and
    scaled to ``[0, 1]`` before inference. The returned time covers only the
    forward pass. On CUDA the device is synchronized before and after it: kernel
    launches are asynchronous, so an unsynchronized timer mostly measures launch
    overhead rather than the actual computation.
    """
    frame_rgb = _crop_to_multiple(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB), scale_factor)
    img_t = (
        torch.from_numpy(frame_rgb.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    on_cuda = torch.device(device).type == "cuda"
    if on_cuda:
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    with torch.no_grad():
        enhanced_tensor, *_ = model(img_t)
    if on_cuda:
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start

    return _tensor_to_bgr_uint8(enhanced_tensor), elapsed


def enhance_video(
    input_video_path: str,
    output_video_path: str,
    model_weights: str,
    scale_factor: int = 12,
    sample_every_n: int = 1,
    codec: str = "mp4v",
    max_frames: int | None = None,
):
    """Enhance a video frame by frame with Zero-DCE++ and write the result.

    Each processed frame is cropped so its height and width are multiples of
    ``scale_factor``, enhanced by the model, and written at the cropped resolution.

    Args:
        input_video_path: Video file to read with OpenCV.
        output_video_path: Video file to write; its parent directory is created
            if it does not exist.
        model_weights: Checkpoint path passed to ``load_model``.
        scale_factor: Passed to the model; frames are cropped to multiples of it.
        sample_every_n: Process only every n-th input frame (values < 1 act like 1).
            The output frame rate is divided by the same factor so the output keeps
            roughly the input's duration instead of playing back sped up.
        codec: FourCC code for ``cv2.VideoWriter``.
        max_frames: Stop after this many processed frames (``None`` = no limit).

    Returns:
        dict with ``output_path``, ``processed_frames``, ``total_inference_time``
        (seconds spent in the forward pass) and ``avg_time_per_frame``.

    Raises:
        RuntimeError: if the input cannot be opened, the resolution is smaller than
            ``scale_factor``, or the output writer cannot be opened.
    """
    step = max(1, sample_every_n)
    model, device = load_model(model_weights, scale_factor=scale_factor)

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open input video: {input_video_path}")

    out = None
    processed = 0
    total_time = 0.0
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)

        cropped_h = (orig_h // scale_factor) * scale_factor
        cropped_w = (orig_w // scale_factor) * scale_factor
        if cropped_h == 0 or cropped_w == 0:
            raise RuntimeError("Video resolution too small for the chosen scale_factor.")

        # Only every `step`-th frame is kept, so lower the output rate accordingly.
        out_fps = fps / step

        out_dir = os.path.dirname(output_video_path)
        if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
            os.makedirs(out_dir, exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*codec)
        out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (cropped_w, cropped_h))
        if not out.isOpened():
            # VideoWriter does not raise for an unusable codec/path; without this
            # check every write() below would be silently dropped.
            raise RuntimeError(
                f"Cannot open output video for writing: {output_video_path} (codec={codec!r})"
            )

        print(f"Input video: {input_video_path} ({orig_w}x{orig_h}, fps={fps}, frames={total_frames})")
        print(
            f"Output video: {output_video_path} ({cropped_w}x{cropped_h}, fps={out_fps}), "
            f"sample_every_n={sample_every_n}"
        )

        frame_idx = 0
        try:
            while max_frames is None or processed < max_frames:
                if frame_idx % step:
                    # Skipped frame: grab() advances the stream without decoding it.
                    if not cap.grab():
                        break
                    frame_idx += 1
                    continue

                ret, frame_bgr = cap.read()
                if not ret:
                    break
                frame_idx += 1

                out_frame, elapsed = _enhance_frame(model, device, frame_bgr, scale_factor)
                total_time += elapsed
                if (out_frame.shape[1], out_frame.shape[0]) != (cropped_w, cropped_h):
                    out_frame = cv2.resize(out_frame, (cropped_w, cropped_h), interpolation=cv2.INTER_LINEAR)
                out.write(out_frame)

                processed += 1
                if processed % 50 == 0:
                    print(f"Processed {processed} frames (frame_idx={frame_idx}) — last frame time {elapsed:.3f}s")
        except KeyboardInterrupt:
            print("Interrupted by user. Finalizing...")
    finally:
        # Release both handles even when setup or processing fails.
        cap.release()
        if out is not None:
            out.release()

    avg_time = (total_time / processed) if processed else 0.0
    print(f"Done. Processed {processed} frames. Total inference time: {total_time:.4f}s")
    if processed:
        print(f"Average time/frame: {avg_time:.4f}s")

    return {
        "output_path": output_video_path,
        "processed_frames": processed,
        "total_inference_time": total_time,
        "avg_time_per_frame": avg_time,
    }


if __name__ == "__main__":
    # Colab locations: the pretrained Zero-DCE++ snapshot from the cloned repo,
    # the low-light clip to enhance, and where the enhanced clip is written.
    weights = "/content/Zero-DCE_extension/Zero-DCE++/snapshots_Zero_DCE++/Epoch99.pth"
    input_video = "/content/test/low_light_video.mp4"
    output_video = "/content/results_Zero_DCE++/enhanced_video.mp4"

    # Process every frame with the default model scale and MPEG-4 codec, no frame cap.
    result = enhance_video(
        input_video,
        output_video,
        weights,
        codec="mp4v",
        max_frames=None,
        sample_every_n=1,
        scale_factor=12,
    )
    print(result)


✅ Model loaded to cuda from /content/Zero-DCE_extension/Zero-DCE++/snapshots_Zero_DCE++/Epoch99.pth
Input video: /content/output_1.mp4 (1920x1080, fps=22.808764940239044, frames=229)
Output video: /content/enhanced_video.mp4 (1920x1080), sample_every_n=1
Processed 50 frames (frame_idx=50) — last frame time 0.002s
Processed 100 frames (frame_idx=100) — last frame time 0.002s
Processed 150 frames (frame_idx=150) — last frame time 0.002s
Processed 200 frames (frame_idx=200) — last frame time 0.002s
Done. Processed 229 frames. Total inference time: 0.5645s
Average time/frame: 0.0025s
{'output_path': '/content/enhanced_video.mp4', 'processed_frames': 229, 'total_inference_time': 0.5644688606262207, 'avg_time_per_frame': 0.0024649295223852434}
