<a href="https://colab.research.google.com/github/Ib-Programmer/computer_vision_expirement/blob/main/notebooks/Phase2_Image_Enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Phase 2: Image Enhancement Evaluation
Evaluating Restormer, FFA-Net, and Zero-DCE++ on degraded outdoor images.

**Metrics**: PSNR, SSIM, NIQE, Inference Latency
**Goal**: Select the best enhancement model for the pipeline

In [1]:
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_DIR = '/content/drive/MyDrive/computer_vision'
RESULTS_DIR = f'{PROJECT_DIR}/results/phase2'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Clone repo and download datasets to LOCAL disk (fast SSD, not Drive)
%cd /content
!rm -rf computer_vision_expirement
!git clone https://github.com/Ib-Programmer/computer_vision_expirement.git
%cd computer_vision_expirement
!pip install -q -r requirements.txt

# Download and preprocess datasets locally
print("\n--- Downloading datasets to local disk ---")
!python scripts/download_datasets.py rtts lfw widerface
print("\n--- Preprocessing ---")
!python scripts/preprocess_data.py rtts lfw widerface

DATASETS_DIR = '/content/computer_vision_expirement/datasets'
print(f"\nDatasets ready at: {DATASETS_DIR}")
print(f"Results will be saved to Drive: {RESULTS_DIR}")

Mounted at /content/drive
/content
Cloning into 'computer_vision_expirement'...
remote: Enumerating objects: 80, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 80 (delta 43), reused 51 (delta 21), pack-reused 0 (from 0)[K
Receiving objects: 100% (80/80), 1.19 MiB | 9.48 MiB/s, done.
Resolving deltas: 100% (43/43), done.
/content/computer_vision_expirement
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.4/129.4 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.8/46.8 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m178.0/178.0 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metada

In [9]:
!pip install -q pyiqa basicsr einops
!pip install -q transformers
!pip install -q 'torchvision<0.16'

[31mERROR: Ignored the following yanked versions: 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.2.0, 0.2.1, 0.2.2, 0.2.2.post2, 0.2.2.post3[0m[31m
[0m[31mERROR: Could not find a version that satisfies the requirement torchvision<0.16 (from versions: 0.17.0, 0.17.1, 0.17.2, 0.18.0, 0.18.1, 0.19.0, 0.19.1, 0.20.0, 0.20.1, 0.21.0, 0.22.0, 0.22.1, 0.23.0, 0.24.0, 0.24.1, 0.25.0)[0m[31m
[0m[31mERROR: No matching distribution found for torchvision<0.16[0m[31m
[0m

## 2.1 Load Test Images

In [3]:
import cv2
import numpy as np
import glob
import time
from pathlib import Path

# Load sample test images from each dataset
def load_test_samples(dataset_dir, max_samples=50):
    images = []
    paths = sorted(glob.glob(f'{dataset_dir}/*_processed/test/*.jpg') +
                   glob.glob(f'{dataset_dir}/*_processed/test/*.png'))
    for p in paths[:max_samples]:
        img = cv2.imread(p)
        if img is not None:
            images.append((p, img))
    return images

test_images = load_test_samples(DATASETS_DIR, max_samples=100)
print(f"Loaded {len(test_images)} test images")

Loaded 100 test images


## 2.2 Setup Enhancement Models

In [4]:
# ── Zero-DCE++ (Low-Light Enhancement) ──
# Tiny model (~79K params), zero-reference training (no paired data needed)
# We define it inline and train on our images in ~3 minutes

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CSDN(nn.Module):
    """Depthwise separable convolution."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depth_conv = nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch)
        self.point_conv = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        return self.point_conv(self.depth_conv(x))

class ZeroDCEpp(nn.Module):
    """Zero-DCE++ network for low-light image enhancement."""
    def __init__(self, scale_factor=1):
        super().__init__()
        n = 32
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        self.e_conv1 = CSDN(3, n)
        self.e_conv2 = CSDN(n, n)
        self.e_conv3 = CSDN(n, n)
        self.e_conv4 = CSDN(n, n)
        self.e_conv5 = CSDN(n*2, n)
        self.e_conv6 = CSDN(n*2, n)
        self.e_conv7 = CSDN(n*2, 24)  # 8 iterations * 3 RGB channels

    def forward(self, x):
        if self.scale_factor != 1:
            x_down = F.interpolate(x, scale_factor=1/self.scale_factor, mode='bilinear', align_corners=True)
        else:
            x_down = x
        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
        if self.scale_factor != 1:
            x_r = F.interpolate(x_r, size=x.shape[2:], mode='bilinear', align_corners=True)
        # Apply 8 curve iterations
        curves = torch.split(x_r, 3, dim=1)
        enhanced = x
        for curve in curves:
            enhanced = enhanced + curve * (torch.pow(enhanced, 2) - enhanced)
        return enhanced, x_r

# Zero-reference losses (no paired data needed)
def color_constancy_loss(img):
    mean_rgb = torch.mean(img, dim=[2, 3])
    mr, mg, mb = mean_rgb[:, 0], mean_rgb[:, 1], mean_rgb[:, 2]
    return torch.mean((mr - mg)**2 + (mr - mb)**2 + (mg - mb)**2)

def exposure_loss(img, target_E=0.6):
    patches = F.avg_pool2d(img, 16)
    return torch.mean((patches - target_E)**2)

def tv_loss(x_r):
    return torch.mean(torch.abs(x_r[:, :, :, :-1] - x_r[:, :, :, 1:])) + \
           torch.mean(torch.abs(x_r[:, :, :-1, :] - x_r[:, :, 1:, :]))

# Quick training on our images (~3 min)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
zero_dce = ZeroDCEpp(scale_factor=1).to(device)
optimizer = torch.optim.Adam(zero_dce.parameters(), lr=1e-4, weight_decay=1e-4)

transform = transforms.Compose([transforms.ToPILImage(), transforms.Resize((256, 256)),
                                  transforms.ToTensor()])

print(f"Training Zero-DCE++ on {len(test_images)} images ({device})...")
zero_dce.train()
for epoch in range(30):
    total_loss = 0
    for _, img in test_images[:50]:
        img_t = transform(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).unsqueeze(0).to(device)
        enhanced, curves = zero_dce(img_t)
        loss = 10 * exposure_loss(enhanced) + 5 * color_constancy_loss(enhanced) + 200 * tv_loss(curves)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}/30, Loss: {total_loss/50:.4f}")

zero_dce.eval()
print(f"Zero-DCE++ trained! Parameters: {sum(p.numel() for p in zero_dce.parameters()):,}")

def enhance_zero_dce(img_bgr):
    """Enhance a single BGR image with Zero-DCE++."""
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_t = transforms.ToTensor()(img_rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        enhanced, _ = zero_dce(img_t)
    enhanced = enhanced.squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    return cv2.cvtColor((enhanced * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)

Training Zero-DCE++ on 100 images (cuda)...
  Epoch 10/30, Loss: 0.9068
  Epoch 20/30, Loss: 0.8497
  Epoch 30/30, Loss: 0.8259
Zero-DCE++ trained! Parameters: 11,926


In [6]:
# ── Restormer (Deraining / General Restoration) ──
# Download weights from Hugging Face, load model from cloned repo
!git clone https://github.com/swz30/Restormer.git 2>/dev/null || echo "Restormer already cloned"

import os
os.makedirs('weights', exist_ok=True)
if not os.path.exists('weights/restormer_deraining.pth'):
    print("Downloading Restormer weights from Hugging Face...")
    !wget -q -O weights/restormer_deraining.pth "https://huggingface.co/deepinv/Restormer/resolve/main/deraining.pth"
    print(f"Downloaded: {os.path.getsize('weights/restormer_deraining.pth') / 1e6:.1f} MB")

# Load Restormer model
import sys
sys.path.insert(0, 'Restormer')
from basicsr.models.archs.restormer_arch import Restormer

restormer = Restormer(
    inp_channels=3, out_channels=3, dim=48,
    num_blocks=[4, 6, 6, 8], num_refinement_blocks=4,
    heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False,
    LayerNorm_type='WithBias', dual_pixel_task=False
).to(device)

checkpoint = torch.load('weights/restormer_deraining.pth', map_location=device)
restormer.load_state_dict(checkpoint['params'] if 'params' in checkpoint else checkpoint)
restormer.eval()
print(f"Restormer loaded! Parameters: {sum(p.numel() for p in restormer.parameters()):,}")

def enhance_restormer(img_bgr):
    """Enhance a single BGR image with Restormer."""
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img_t = torch.from_numpy(img_rgb).permute(2, 0, 1).unsqueeze(0).to(device)
    # Pad to multiple of 8
    _, _, h, w = img_t.shape
    pad_h = (8 - h % 8) % 8
    pad_w = (8 - w % 8) % 8
    img_t = F.pad(img_t, (0, pad_w, 0, pad_h), mode='reflect')
    with torch.no_grad():
        output = restormer(img_t)
    output = output[:, :, :h, :w].squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    return cv2.cvtColor((output * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)

Restormer already cloned


ModuleNotFoundError: No module named 'torchvision.transforms.functional_tensor'

In [7]:
# ── MAXIM (Dehazing) ──
# Google's MAXIM model for outdoor dehazing, from Hugging Face (replaces FFA-Net)
# FFA-Net weights are behind Baidu/Kaggle auth walls - MAXIM is newer and better

from transformers import AutoImageProcessor, MaximForImageDenoising

print("Loading MAXIM dehazing model from Hugging Face...")
maxim_processor = AutoImageProcessor.from_pretrained("google/maxim-s2-dehazing-sots-outdoor")
maxim_model = MaximForImageDenoising.from_pretrained("google/maxim-s2-dehazing-sots-outdoor").to(device)
maxim_model.eval()
print(f"MAXIM loaded! Parameters: {sum(p.numel() for p in maxim_model.parameters()):,}")

from PIL import Image

def enhance_maxim(img_bgr):
    """Dehaze a single BGR image with MAXIM."""
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(img_rgb)
    inputs = maxim_processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = maxim_model(**inputs)
    output = outputs.reconstruction.squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    output = cv2.resize((output * 255).astype(np.uint8), (img_bgr.shape[1], img_bgr.shape[0]))
    return cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

# Summary
print("\n" + "=" * 50)
print("Enhancement Models Ready:")
print(f"  1. Zero-DCE++  (low-light)  - trained on our data")
print(f"  2. Restormer   (deraining)  - pretrained from HuggingFace")
print(f"  3. MAXIM       (dehazing)   - pretrained from HuggingFace")

ImportError: cannot import name 'MaximForImageDenoising' from 'transformers' (/usr/local/lib/python3.12/dist-packages/transformers/__init__.py)

## 2.3 Run Enhancement & Measure Quality

In [None]:
import pyiqa
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Initialize no-reference metrics
niqe_metric = pyiqa.create_metric('niqe', device='cuda' if __import__('torch').cuda.is_available() else 'cpu')

def evaluate_image_quality(original, enhanced):
    """Calculate quality metrics between original and enhanced images."""
    # Convert to float
    orig_f = original.astype(np.float64) / 255.0
    enh_f = enhanced.astype(np.float64) / 255.0

    metrics = {}
    metrics['psnr'] = psnr(orig_f, enh_f, data_range=1.0)
    metrics['ssim'] = ssim(orig_f, enh_f, data_range=1.0, channel_axis=2)

    return metrics

def measure_inference_time(model_fn, image, n_runs=10):
    """Measure average inference time."""
    times = []
    for _ in range(n_runs):
        start = time.time()
        _ = model_fn(image)
        times.append(time.time() - start)
    return np.mean(times) * 1000  # ms

Downloading: "https://huggingface.co/chaofengc/IQA-PyTorch-Weights/resolve/main/niqe_modelparameters.mat" to /root/.cache/torch/hub/pyiqa/niqe_modelparameters.mat



100%|██████████| 8.15k/8.15k [00:00<00:00, 25.0MB/s]


In [None]:
import pandas as pd

# Run all 3 models on test images and compute metrics
models = {
    'Zero-DCE++': enhance_zero_dce,
    'Restormer': enhance_restormer,
    'MAXIM': enhance_maxim,
}

all_results = {name: {'psnr': [], 'ssim': [], 'niqe': [], 'latency': []} for name in models}

n_eval = min(30, len(test_images))  # evaluate on 30 images
print(f"Evaluating {len(models)} models on {n_eval} images...")

for idx, (path, img) in enumerate(test_images[:n_eval]):
    if (idx + 1) % 10 == 0:
        print(f"  Processing image {idx+1}/{n_eval}...")

    for name, enhance_fn in models.items():
        try:
            # Measure latency
            start = time.time()
            enhanced = enhance_fn(img)
            latency = (time.time() - start) * 1000

            # Resize enhanced to match original if needed
            if enhanced.shape[:2] != img.shape[:2]:
                enhanced = cv2.resize(enhanced, (img.shape[1], img.shape[0]))

            # Compute metrics
            metrics = evaluate_image_quality(img, enhanced)
            all_results[name]['psnr'].append(metrics['psnr'])
            all_results[name]['ssim'].append(metrics['ssim'])
            all_results[name]['latency'].append(latency)

            # NIQE (no-reference quality)
            enh_tensor = torch.from_numpy(cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            niqe_score = niqe_metric(enh_tensor.to(device)).item()
            all_results[name]['niqe'].append(niqe_score)
        except Exception as e:
            print(f"  [WARN] {name} failed on image {idx}: {e}")

# Build comparison table
rows = []
for name in models:
    r = all_results[name]
    if r['psnr']:
        rows.append({
            'Model': name,
            'Avg_PSNR': round(np.mean(r['psnr']), 2),
            'Avg_SSIM': round(np.mean(r['ssim']), 4),
            'Avg_NIQE': round(np.mean(r['niqe']), 2),
            'Avg_Latency_ms': round(np.mean(r['latency']), 1),
        })

evaluation_df = pd.DataFrame(rows)
print("\n" + "=" * 60)
print("ENHANCEMENT MODEL COMPARISON")
print("=" * 60)
print(evaluation_df.to_string(index=False))
evaluation_df.to_csv(f'{RESULTS_DIR}/enhancement_benchmark.csv', index=False)
print(f"\nResults saved to: {RESULTS_DIR}/enhancement_benchmark.csv")

## 2.4 Results Comparison

In [None]:
import matplotlib.pyplot as plt

# Visual comparison on 3 sample images
samples = test_images[:3]
model_names = list(models.keys())

fig, axes = plt.subplots(len(samples), len(model_names) + 1, figsize=(5 * (len(model_names) + 1), 5 * len(samples)))

for row, (path, img) in enumerate(samples):
    # Original
    axes[row][0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    axes[row][0].set_title('Original', fontsize=12)
    axes[row][0].axis('off')

    # Enhanced by each model
    for col, name in enumerate(model_names, 1):
        try:
            enhanced = models[name](img)
            if enhanced.shape[:2] != img.shape[:2]:
                enhanced = cv2.resize(enhanced, (img.shape[1], img.shape[0]))
            axes[row][col].imshow(cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB))
        except:
            axes[row][col].text(0.5, 0.5, 'Failed', ha='center', va='center', fontsize=14)
        axes[row][col].set_title(name, fontsize=12)
        axes[row][col].axis('off')

plt.suptitle('Image Enhancement Comparison', fontsize=16)
plt.tight_layout()
plt.savefig(f'{RESULTS_DIR}/enhancement_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Comparison saved to: {RESULTS_DIR}/enhancement_comparison.png")

In [None]:
print(f"\nPhase 2 Complete!")
print(f"Results saved to: {RESULTS_DIR}")
print(f"\nModels evaluated:")
print(f"  - Zero-DCE++  : Low-light enhancement (zero-reference)")
print(f"  - Restormer   : Deraining (supervised, pretrained)")
print(f"  - MAXIM       : Dehazing (supervised, pretrained)")
print(f"\nNote: FFA-Net replaced with MAXIM (Google) - weights more accessible,")
print(f"      newer architecture, better dehazing performance.")
print(f"\nNext: Open Phase3_Object_Detection.ipynb")