# Phase 2: Image Enhancement Evaluation
Evaluating Restormer, FFA-Net, and Zero-DCE++ on degraded outdoor images.

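**Metrics**: PSNR and SSIM (full-reference, computed against the degraded input image since no clean ground truth is used), NIQE (no-reference), and inference latency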
**Goal**: Select the best enhancement model for the pipeline

In [None]:
# Colab setup: results are written to Google Drive (RESULTS_DIR), while the
# project code and datasets live on the local VM disk for faster I/O.
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_DIR = '/content/drive/MyDrive/computer_vision'
RESULTS_DIR = f'{PROJECT_DIR}/results/phase2'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Clone repo and download datasets to LOCAL disk (fast SSD, not Drive)
# Any previous checkout is deleted first, so every run starts from a fresh
# clone (local edits inside it are discarded). The "expirement" spelling
# matches the repository URL and must not be "corrected".
%cd /content
!rm -rf computer_vision_expirement
!git clone https://github.com/Ib-Programmer/computer_vision_expirement.git
# Later relative paths (scripts/, weights/, cloned model repos) resolve
# against this directory.
%cd computer_vision_expirement
!pip install -q -r requirements.txt

# Download and preprocess datasets locally
print("\n--- Downloading datasets to local disk ---")
!python scripts/download_datasets.py rtts lfw widerface
print("\n--- Preprocessing ---")
!python scripts/preprocess_data.py rtts lfw widerface

# Absolute path of ./datasets inside the clone above; presumably where the
# download/preprocess scripts write their output — confirm against the scripts.
DATASETS_DIR = '/content/computer_vision_expirement/datasets'
print(f"\nDatasets ready at: {DATASETS_DIR}")
print(f"Results will be saved to Drive: {RESULTS_DIR}")

In [None]:
# pyiqa provides the NIQE metric created in section 2.3; basicsr and einops
# are presumably needed by the Restormer / FFA-Net model code — confirm.
!pip install -q pyiqa basicsr einops

## 2.1 Load Test Images

In [None]:
import cv2
import numpy as np
import glob
import time
from pathlib import Path

# Load sample test images from each dataset
def load_test_samples(dataset_dir, max_samples=50, extensions=('.jpg', '.png'),
                      imread=None):
    """Load up to ``max_samples`` test images, interleaved across datasets.

    Every ``<dataset_dir>/*_processed/test`` folder is treated as one dataset.
    Files are taken round-robin (1st file of each dataset, then the 2nd of
    each, ...), in sorted order within a dataset. The quota is therefore shared
    across datasets instead of being filled entirely by whichever folder name
    sorts first.

    Args:
        dataset_dir: Root folder containing the ``*_processed`` dataset folders.
        max_samples: Maximum total number of images returned.
        extensions: File suffixes to include (compared case-insensitively).
        imread: Loader ``path -> image or None``. Defaults to ``cv2.imread``,
            which yields a BGR array, or ``None`` for unreadable files.

    Returns:
        List of ``(path, image)`` tuples. Files the loader cannot read are
        skipped and do not count toward ``max_samples``.
    """
    import itertools

    if imread is None:
        imread = cv2.imread
    suffixes = tuple(ext.lower() for ext in extensions)

    per_dataset = []
    for test_dir in sorted(glob.glob(f'{dataset_dir}/*_processed/test')):
        files = sorted(p for p in glob.glob(f'{test_dir}/*')
                       if p.lower().endswith(suffixes))
        if files:
            per_dataset.append(files)

    images = []
    # Each group holds the k-th file of every dataset (None once a dataset
    # has run out of files).
    for rank_group in itertools.zip_longest(*per_dataset):
        for p in rank_group:
            if p is None:
                continue
            if len(images) >= max_samples:
                return images
            img = imread(p)
            if img is not None:
                images.append((p, img))
    return images

# At most 100 images in total from the *_processed/test folders.
test_images = load_test_samples(dataset_dir=DATASETS_DIR, max_samples=100)
print("Loaded", len(test_images), "test images")

## 2.2 Setup Enhancement Models

In [None]:
# Zero-DCE++ for low-light enhancement
# Weights are included in the official GitHub repo (no Google Drive needed)
!git clone https://github.com/Li-Chongyi/Zero-DCE_extension.git 2>/dev/null || echo "Zero-DCE++ already cloned"

import os
os.makedirs('weights', exist_ok=True)

# Copy weights from the cloned repo
zero_dce_src = 'Zero-DCE_extension/snapshots_Zero_DCE++/Epoch99.pth'
if os.path.exists(zero_dce_src) and not os.path.exists('weights/zero_dce_pp.pth'):
    import shutil
    shutil.copy(zero_dce_src, 'weights/zero_dce_pp.pth')
    print(f"Zero-DCE++ weights copied from repo: {os.path.getsize('weights/zero_dce_pp.pth') / 1e6:.1f} MB")
elif os.path.exists('weights/zero_dce_pp.pth'):
    print("Zero-DCE++ weights already available")
else:
    print("[WARNING] Zero-DCE++ weights not found in cloned repo")

In [None]:
# Restormer for general image restoration
# Download from Hugging Face (no Google Drive needed)
!git clone https://github.com/swz30/Restormer.git 2>/dev/null || echo "Restormer already cloned"

if not os.path.exists('weights/restormer_deraining.pth'):
    print("Downloading Restormer deraining weights from Hugging Face...")
    !wget -q -O weights/restormer_deraining.pth "https://huggingface.co/deepinv/Restormer/resolve/main/deraining.pth"
    if os.path.exists('weights/restormer_deraining.pth'):
        print(f"Restormer weights downloaded: {os.path.getsize('weights/restormer_deraining.pth') / 1e6:.1f} MB")
    else:
        print("[WARNING] Restormer weights download failed")
else:
    print("Restormer weights already available")

In [None]:
# FFA-Net for dehazing
# Download from Kaggle (no Google Drive needed)
!git clone https://github.com/zhilin007/FFA-Net.git 2>/dev/null || echo "FFA-Net already cloned"

if not os.path.exists('weights/ffa_net.pk'):
    print("Downloading FFA-Net weights from Kaggle...")
    try:
        !pip install -q kaggle
        !kaggle datasets download -d balraj98/ffanet-pretrained-weights -p weights/ --unzip -q
        # Kaggle dataset may have a different filename, find and rename
        import glob
        ffa_files = glob.glob('weights/*.pk') + glob.glob('weights/**/*.pk', recursive=True)
        if ffa_files:
            import shutil
            shutil.copy(ffa_files[0], 'weights/ffa_net.pk')
            print(f"FFA-Net weights downloaded: {os.path.getsize('weights/ffa_net.pk') / 1e6:.1f} MB")
        else:
            print("[WARNING] FFA-Net weights not found in Kaggle download")
            print("Trying wget from GitHub release...")
            !wget -q -O weights/ffa_net.pk "https://github.com/zhilin007/FFA-Net/releases/download/v1.0/ffa_net.pk" 2>/dev/null || echo "GitHub release not available"
    except Exception as e:
        print(f"Kaggle download failed: {e}")
        print("Manual download: https://www.kaggle.com/datasets/balraj98/ffanet-pretrained-weights")
else:
    print("FFA-Net weights already available")

# Summary
# Report the state of each expected weights file. A 0-byte file (e.g. left
# behind by a failed `wget -O`) is flagged as EMPTY instead of OK, so it is
# not mistaken for usable weights.
print("\n" + "=" * 40)
print("Model Weights Status:")
for name, path in [("Zero-DCE++", "weights/zero_dce_pp.pth"),
                    ("Restormer", "weights/restormer_deraining.pth"),
                    ("FFA-Net", "weights/ffa_net.pk")]:
    if not os.path.exists(path):
        print(f"  {name}: MISSING")
        continue
    n_bytes = os.path.getsize(path)
    if n_bytes > 0:
        print(f"  {name}: OK ({n_bytes / 1e6:.1f} MB)")
    else:
        print(f"  {name}: EMPTY (0 bytes) - delete {path} and re-run the download")

## 2.3 Run Enhancement & Measure Quality

In [None]:
import pyiqa
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Initialize no-reference metrics
# NIQE scores an image without a reference. It runs on the GPU when CUDA is
# available, otherwise on the CPU. A regular import replaces the inline
# __import__('torch') call.
import torch
niqe_device = 'cuda' if torch.cuda.is_available() else 'cpu'
niqe_metric = pyiqa.create_metric('niqe', device=niqe_device)

def evaluate_image_quality(original, enhanced):
    """Score ``enhanced`` against ``original`` with PSNR and SSIM.

    Both inputs are converted to float64 and scaled by 1/255, so 8-bit images
    land in [0, 1] to match ``data_range=1.0``. ``original`` is passed as the
    reference image. The scores therefore describe how far the enhancement
    moved away from its input, not absolute quality. SSIM treats axis 2 as the
    colour channel, i.e. images are expected as H x W x C.

    Returns:
        dict with keys ``'psnr'`` and ``'ssim'``.
    """
    reference, candidate = (img.astype(np.float64) / 255.0
                            for img in (original, enhanced))
    return {
        'psnr': psnr(reference, candidate, data_range=1.0),
        'ssim': ssim(reference, candidate, data_range=1.0, channel_axis=2),
    }

def measure_inference_time(model_fn, image, n_runs=10, warmup=1, synchronize=None):
    """Return the mean latency of ``model_fn(image)`` in milliseconds.

    Args:
        model_fn: Callable run on ``image``; its return value is discarded.
        image: Input passed unchanged to every call.
        n_runs: Number of timed calls (must be >= 1).
        warmup: Untimed calls made first, so one-off costs (lazy init, CUDA
            context creation, autotuning) do not inflate the mean.
        synchronize: Zero-argument callable invoked after each call so that
            asynchronous GPU work is included in the measured time. Defaults to
            ``torch.cuda.synchronize`` when torch is importable and CUDA is
            available, otherwise a no-op. Pass ``lambda: None`` to disable.

    Returns:
        Mean latency per timed call, in milliseconds.

    Raises:
        ValueError: if ``n_runs`` < 1 (the mean would be undefined).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    if synchronize is None:
        try:
            import torch
        except ImportError:  # no torch -> no asynchronous CUDA work to wait for
            torch = None
        if torch is not None and torch.cuda.is_available():
            synchronize = torch.cuda.synchronize
        else:
            def synchronize():
                pass

    for _ in range(warmup):
        model_fn(image)
    synchronize()  # drain warm-up work before the first timer starts

    times = []
    for _ in range(n_runs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        model_fn(image)
        synchronize()
        times.append(time.perf_counter() - start)
    return np.mean(times) * 1000  # ms

In [None]:
import torch
import pandas as pd

# Collector for per-image / per-model measurements (not populated in this cell).
results = []

# Process each test image through each model.
# Model loading/inference is added per model once its weights are confirmed
# available; this cell only prepares the evaluation framework (adjust model
# paths as needed).
print("Running enhancement evaluation...\nThis may take 10-30 minutes depending on GPU...")

# Empty summary table: one row per model once its results are in.
evaluation_df = pd.DataFrame(
    columns=['Model', 'Avg_PSNR', 'Avg_SSIM', 'Avg_NIQE', 'Avg_Latency_ms'],
)
print("\nEvaluation framework ready.\nRun each model section below to populate results.")

## 2.4 Results Comparison

In [None]:
import matplotlib.pyplot as plt

# Create comparison visualization
def show_enhancement_comparison(original, enhanced_dict, title="Enhancement Comparison",
                                save_path=None):
    """Plot the original image next to each enhanced variant, save and show it.

    Args:
        original: BGR image shown in the first panel (converted to RGB for display).
        enhanced_dict: Mapping of model name -> BGR enhanced image, one panel
            each in iteration order. May be empty.
        title: Figure super-title.
        save_path: Output PNG path. Defaults to
            ``{RESULTS_DIR}/enhancement_comparison.png``; pass a distinct path
            to keep several comparisons instead of overwriting one file.
    """
    if save_path is None:
        save_path = f'{RESULTS_DIR}/enhancement_comparison.png'

    panels = [('Original', original), *enhanced_dict.items()]
    n = len(panels)
    # squeeze=False keeps `axes` 2-D even for a single panel (empty
    # enhanced_dict). Otherwise plt.subplots returns a bare Axes and indexing it fails.
    _, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for ax, (name, img) in zip(axes[0], panels):
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_title(name)
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Confirm the helpers are defined and where figures will be written.
print("Visualization functions ready.", f"Results will be saved to: {RESULTS_DIR}", sep="\n")

In [None]:
# Save evaluation results.
# `evaluation_df` is created empty in section 2.3 and is only filled by the
# per-model sections. Write the CSV (and report success) only when it holds
# rows; otherwise say clearly that nothing was saved.
if 'evaluation_df' in globals() and not evaluation_df.empty:
    evaluation_df.to_csv(f'{RESULTS_DIR}/enhancement_benchmark.csv', index=False)
    print(f"\nPhase 2 results saved to: {RESULTS_DIR}")
else:
    print("\n[WARNING] evaluation_df is empty - no Phase 2 benchmark results were saved")
print("Next: Open Phase3_Object_Detection.ipynb")