<a href="https://colab.research.google.com/github/Anguschow237/hybrid-zero-dce-low-light-enhancement/blob/main/hybrid_zero_dce_low_light_image_enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Low-Light Image Enhancement: U-Net vs. ZeroDCE Comparison

### **Author:** Chow Tsz Hin  
### **Date:** 20th December 2025  

This project benchmarks four deep learning models for low-light image enhancement:
- Supervised U-Net with Charbonnier loss
- Unsupervised ZeroDCE (original)
- Hybrid ZeroDCE (original loss + Charbonnier)
- Supervised ZeroDCE (Charbonnier only)

Key experiments include loss function comparisons and hybrid approaches to improve unsupervised performance.

**Content Sections** (collapsible for easy navigation):

1. Setting up the environment  
2. Supervised Loss Function Comparisons  
3. Unsupervised ZeroDCE Baseline
4. Modifying ZeroDCE for supervised/hybrid training  
5. Comparison of the four models (quantitative + qualitative)  
6. Final analysis and insights

# 📌 1. Setting up the Environment & Importing the LOL Dataset

In [None]:
# Run this first
!pip install -q piqa lpips kornia gdown

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import os

# Seed every RNG used in this notebook (random crops/flips use Python's `random`)
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/drive/MyDrive/lol_dataset /content/

In [None]:
from pathlib import Path

low_dir  = Path("/content/lol_dataset/our485/low")
high_dir = Path("/content/lol_dataset/our485/high")

print("First 20 files in low/ : ", sorted(low_dir.glob("*.png"))[:20])
print("First 20 files in high/: ", sorted(high_dir.glob("*.png"))[:20])

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Paths after manual unzip (exact structure from Kaggle)
low_path = "/content/lol_dataset/our485/low/75.png"
high_path = "/content/lol_dataset/our485/high/75.png"

low_img = Image.open(low_path)
high_img = Image.open(high_path)

plt.figure(figsize=(14, 7))
plt.subplot(1, 2, 1)
plt.title("Low-Light Input (Very Dark)", fontsize=18, color="red")
plt.imshow(low_img)
plt.axis("off")

plt.subplot(1, 2, 2)
plt.title("Ground Truth (Normal Light)", fontsize=18, color="green")
plt.imshow(high_img)
plt.axis("off")

plt.tight_layout()
plt.show()

print("Success! Dataset loaded: see the dark vs. bright pair above.")

In [None]:
import torch, torch.nn as nn, torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

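class LOLDataset(Dataset):
    """Paired LOL images: our485 (485 pairs) for training, eval15 (15 pairs) for testing."""
    def __init__(self, root="/content/lol_dataset", train=True, crop_size=256):
        self.train = train
        self.crop_size = crop_size
        split = "our485" if train else "eval15"  # held-out test split when train=False
        low_dir  = Path(root) / split / "low"
        high_dir = Path(root) / split / "high"
        self.lows  = sorted(low_dir.glob("*.png"))
        self.highs = sorted(high_dir.glob("*.png"))
        assert len(self.lows) == len(self.highs), "low/high image counts differ"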

    def __len__(self): return len(self.lows)

    def __getitem__(self, idx):
        low  = Image.open(self.lows[idx]).convert("RGB")
        high = Image.open(self.highs[idx]).convert("RGB")

        if self.train:
            i, j, h, w = T.RandomCrop.get_params(low, (self.crop_size, self.crop_size))
            low  = low.crop((j, i, j+w, i+h))
            high = high.crop((j, i, j+w, i+h))
            if random.random() > 0.5:
                low  = low.transpose(Image.FLIP_LEFT_RIGHT)
                high = high.transpose(Image.FLIP_LEFT_RIGHT)

        return T.ToTensor()(low), T.ToTensor()(high)

# 📌 2. Supervised Loss Function Comparisons

I started with the U-Net architecture, as it is well-suited for pixel-level image-to-image tasks like low-light enhancement due to its encoder-decoder structure and skip connections that preserve fine details.
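To see why the decoder of the U-Net defined in the training cell below needs a size-matching step, here is a small illustrative helper (hypothetical, not used by the notebook). It traces the spatial size through six `MaxPool2d(2)` levels, which floor odd sizes, and reports the decoder levels where doubling the lower level (as the stride-2 transposed convolution does) no longer reproduces the skip connection's size.

```python
def encoder_sizes(size: int, depth: int) -> list[int]:
    """Spatial size after each MaxPool2d(2) stage (index 0 = input, index depth = bottleneck)."""
    sizes = [size]
    for _ in range(depth):
        size //= 2  # MaxPool2d(2) with default settings floors odd sizes
        sizes.append(size)
    return sizes


def mismatched_levels(size: int, depth: int) -> list[int]:
    """Decoder levels (0 = deepest) where 2x upsampling does not reproduce the skip size.

    After each decoder stage the feature map takes the skip connection's size,
    so the input to level k has the skip size used at level k - 1.
    """
    sizes = encoder_sizes(size, depth)
    bad = []
    for level in range(depth):
        below = sizes[depth - level]     # size entering this decoder level
        skip = sizes[depth - 1 - level]  # size of the matching encoder output
        if 2 * below != skip:
            bad.append(level)
    return bad


print(encoder_sizes(320, 6))      # [320, 160, 80, 40, 20, 10, 5]
print(mismatched_levels(320, 6))  # [] because 320 = 5 * 2**6
print(mismatched_levels(400, 6))  # [1]     full-size LOL height
print(mismatched_levels(600, 6))  # [1, 2]  full-size LOL width
```

So 320-pixel training crops never trigger the resize fallback, while full-resolution evaluation images do at one or two decoder levels.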

To determine the most effective loss function for generating bright, natural-looking enhanced images that closely match the ground truth, I systematically compared four common reconstruction losses:

- **L1 Loss**: Simple pixel-wise absolute difference; produced decent results but often led to slight blurring and inconsistent brightness.
- **SSIM Loss**: Focuses on structural similarity; tended to collapse to overly dark outputs, failing to adequately boost illumination.
- **L1 + SSIM Combination**: Aimed to balance pixel accuracy and perceptual quality, but training was unstable, with SSIM dominating and resulting in darkened images.
- **Charbonnier Loss**: A smooth, everywhere-differentiable approximation of L1 (robust to outliers like L1, but without L1's non-differentiable kink at zero); achieved the best stability, highest PSNR/SSIM scores, and visually superior results with accurate brightness, color fidelity, and detail preservation.

Charbonnier loss was selected as the optimal choice for the supervised U-Net baseline.
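The Charbonnier penalty $\sqrt{d^2 + \varepsilon^2}$ (as in `CharbonnierLoss` below, with $\varepsilon = 10^{-3}$) behaves like $|d|$ for large residuals but is smooth at $d = 0$. This illustrative PyTorch check (not part of the training pipeline) compares per-element gradients: the L1 gradient jumps to ±1 for any non-zero residual, while the Charbonnier gradient $d/\sqrt{d^2+\varepsilon^2}$ ramps up smoothly.

```python
import torch

EPS = 1e-3  # same default as CharbonnierLoss in this notebook

def l1_penalty(diff: torch.Tensor) -> torch.Tensor:
    return diff.abs()

def charbonnier_penalty(diff: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    # sqrt(d^2 + eps^2): close to |d| when |d| >> eps, smooth at d = 0
    return torch.sqrt(diff * diff + eps * eps)

def per_element_grad(fn, values):
    d = torch.tensor(values, dtype=torch.float32, requires_grad=True)
    fn(d).sum().backward()  # sum() so grad[i] = d(penalty_i) / d(d_i)
    return d.grad

residuals = [0.0, 1e-4, 1e-2, 0.5]
g_l1 = per_element_grad(l1_penalty, residuals)
g_ch = per_element_grad(charbonnier_penalty, residuals)
for r, a, b in zip(residuals, g_l1.tolist(), g_ch.tolist()):
    print(f"residual={r:<8g} dL1/dd={a:+.4f}  dCharb/dd={b:+.4f}")
```

Small residuals (nearly correct pixels) therefore get proportionally small gradients under Charbonnier instead of a constant-magnitude push, which is one plausible reason for the smoother training observed here.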

In [None]:
from piqa import SSIM

class L1Loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, enhanced, target):
        return F.l1_loss(enhanced, target)

class SSIMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ssim = SSIM()

    def forward(self, enhanced, target):
        return 1 - self.ssim(enhanced, target)

class ComboLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = L1Loss()
        self.ssim_loss = SSIMLoss()

    def forward(self, enhanced, target):
        return 0.8 * self.l1(enhanced, target) + 0.2 * self.ssim_loss(enhanced, target)

class CharbonnierLoss(nn.Module):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, enhanced, target):
        diff = enhanced - target
        return torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))

# List of losses to experiment with
# Move every loss module to `device` (SSIMLoss wraps piqa's SSIM, whose internal kernel must be on the same device as the images)
losses_to_test = {
    'L1': L1Loss().to(device),
    'SSIM': SSIMLoss().to(device),
    'Combo': ComboLoss().to(device),
    'Charbonnier': CharbonnierLoss().to(device)
}

In [None]:
# -----------------------------------------------------------------------------
# Settings sized for a large-memory GPU such as an A100 (40-80 GB) or V100 (32 GB)
MODEL_DEPTH  = 6
BATCH_SIZE   = 32
CROP_SIZE    = 320
EPOCHS       = 50
LR           = 2e-4
SHOW_EVERY   = 5
# -----------------------------------------------------------------------------

# train_model below always uses mixed precision (AMP), which typically cuts activation
# memory by roughly half and can speed training up to ~2x on tensor-core GPUs.
# It uses torch.amp.* instead of the deprecated torch.cuda.amp.* API.

# =============================================================================
# 🧠 MODEL ARCHITECTURE
# U-Net used as the supervised baseline (class name: UNet).
# =============================================================================
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, depth=6):
        super(UNet, self).__init__()
        self.depth = depth
        self.enc_layers = nn.ModuleList()
        self.dec_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.pool = nn.MaxPool2d(2)

        # Encoder
        ch = 32
        for _ in range(depth):
            self.enc_layers.append(self.conv_block(in_channels, ch))
            in_channels = ch
            ch *= 2

        # Bottleneck
        self.bottleneck = self.conv_block(in_channels, ch)

        # Decoder
        for _ in range(depth):
            self.up_layers.append(nn.ConvTranspose2d(ch, ch//2, kernel_size=2, stride=2))
            ch //= 2
            self.dec_layers.append(self.conv_block(ch*2, ch))

        self.final = nn.Conv2d(32, out_channels, kernel_size=1)

    def conv_block(self, in_c, out_c):
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        skips = []
        # Encoder
        for layer in self.enc_layers:
            x = layer(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder
        for i in range(self.depth):
            x = self.up_layers[i](x)
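            # MaxPool2d floors odd sizes, so the upsampled map can be a pixel smaller than
            # its skip connection; resize it to the skip's spatial size before concatenating
            if x.shape[2:] != skips[-(i+1)].shape[2:]:
                x = F.interpolate(x, size=skips[-(i+1)].shape[2:], mode="bilinear", align_corners=False)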
            x = torch.cat([x, skips[-(i+1)]], dim=1)
            x = self.dec_layers[i](x)

        # Sigmoid constrains the output to the [0, 1] image range
        return torch.sigmoid(self.final(x))

# =============================================================================
# TRAINING FUNCTION
# =============================================================================
def train_model(loss_name, criterion,
                epochs=EPOCHS, batch_size=BATCH_SIZE,
                crop_size=CROP_SIZE, lr=LR, model_depth=MODEL_DEPTH, show_every=SHOW_EVERY):

    train_dataset = LOLDataset(train=True, crop_size=crop_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True, prefetch_factor=2)

    # We instantiate the model here
    model = UNet(depth=model_depth).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # Updated AMP Scaler
    scaler = torch.amp.GradScaler('cuda')

    losses = []
    for epoch in range(1, epochs+1):
        model.train()
        epoch_loss = 0.0

        for low, high in train_loader:
            low, high = low.to(device), high.to(device)
            optimizer.zero_grad()

            # Updated autocast context
            with torch.amp.autocast('cuda'):
                enhanced = model(low)
                loss = criterion(enhanced, high)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # stable training
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)

        # Only print and show results every 'show_every' epochs
        if epoch % show_every == 0:
            print(f"Epoch {epoch:2d}/{epochs} ({loss_name}) ‚Üí Loss: {avg_loss:.6f}")

            model.eval()
            with torch.no_grad():
                idx = random.randint(0, len(all_low_paths)-1)
                low_img = Image.open(all_low_paths[idx]).convert("RGB")
                high_img = Image.open(all_high_paths[idx]).convert("RGB") # Load Ground Truth

                input_tensor = T.ToTensor()(low_img).unsqueeze(0).to(device)
                with torch.amp.autocast('cuda'):
                    enhanced_tensor = model(input_tensor).clamp(0,1).squeeze(0).cpu()
                enhanced_img = T.ToPILImage()(enhanced_tensor)

                plt.figure(figsize=(15,5))
                plt.subplot(1,3,1); plt.title("Low-light Input"); plt.imshow(low_img); plt.axis('off')
                plt.subplot(1,3,2); plt.title("Ground Truth");    plt.imshow(high_img); plt.axis('off')
                plt.subplot(1,3,3); plt.title(f"Enhanced ‚Äì {loss_name}"); plt.imshow(enhanced_img); plt.axis('off')
                plt.tight_layout()
                plt.show()
            model.train()

    # Save
    save_name = f"zero_dce_{loss_name.lower()}_pro.pth"
    torch.save(model.state_dict(), save_name)
    print(f"Model saved: {save_name}")

    plt.figure(figsize=(10,4))
    plt.plot(losses); plt.title(f"Training Loss ({loss_name})"); plt.grid(); plt.show()

    return model, losses

In [None]:
# Define full paths once (run this once)
all_low_paths  = sorted((Path("/content/lol_dataset") / "our485" / "low").glob("*.png"))
all_high_paths = sorted((Path("/content/lol_dataset") / "our485" / "high").glob("*.png"))

# Train all four losses
trained_models = {}
all_train_losses = {}

for loss_name, criterion in losses_to_test.items():
    print(f"\n" + "="*60)
    print(f"STARTING TRAINING WITH {loss_name.upper()} LOSS")
    print("="*60)
    model, losses = train_model(
        loss_name=loss_name,
        criterion=criterion,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        crop_size=CROP_SIZE,
        lr=1e-4,  # note: overrides LR = 2e-4 from the settings cell
        model_depth=MODEL_DEPTH,
        show_every=SHOW_EVERY
    )
    trained_models[loss_name] = model
    all_train_losses[loss_name] = losses

## Evaluation Function and Comparison

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import pandas as pd

def evaluate_models(models_dict):
    test_dataset = LOLDataset(train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    results = {name: {'PSNR': [], 'SSIM': []} for name in models_dict}
    ssim_metric = SSIM().to(device)  # create once and reuse for every image

    for low, high in test_loader:
        low, high = low.to(device), high.to(device)
        high_cpu = high.squeeze(0).cpu().numpy().transpose(1,2,0)  # HWC array for skimage PSNR

        for name, model in models_dict.items():
            model.eval()
            with torch.no_grad():
                enhanced = model(low).clamp(0,1).squeeze(0)
                enhanced_cpu = enhanced.cpu().numpy().transpose(1,2,0)

                psnr_val = psnr(high_cpu, enhanced_cpu, data_range=1.0)
                ssim_val = ssim_metric(enhanced.unsqueeze(0), high).item()

                results[name]['PSNR'].append(psnr_val)
                results[name]['SSIM'].append(ssim_val)

    # Average metrics
    avg_results = {name: {'Avg PSNR': np.mean(vals['PSNR']), 'Avg SSIM': np.mean(vals['SSIM'])} for name, vals in results.items()}

    # Display as table
    df = pd.DataFrame(avg_results).T
    print("Evaluation Results on Test Set:")
    display(df)

    # Find best loss
    best_psnr = df['Avg PSNR'].idxmax()
    best_ssim = df['Avg SSIM'].idxmax()
    print(f"Best by PSNR: {best_psnr}")
    print(f"Best by SSIM: {best_ssim}")

    return df

# Run evaluation
eval_df = evaluate_models(trained_models)

In [None]:
# Pick a random test image and show enhancements from all models
test_low_paths = sorted((Path("/content/lol_dataset/eval15/low")).glob("*.png"))
test_high_paths = sorted((Path("/content/lol_dataset/eval15/high")).glob("*.png"))

if len(test_low_paths) == 0:
    print("⚠️ Error: No images found in /content/lol_dataset/eval15/low")
    print("Please check if the dataset was copied correctly. You might need to re-run the dataset setup cell.")
else:
    # Use dynamic length to avoid IndexError
    idx = random.randint(0, len(test_low_paths) - 1)
    print(f"Testing on image index: {idx} / {len(test_low_paths)-1}")

    low_img = Image.open(test_low_paths[idx]).convert("RGB")
    high_img = Image.open(test_high_paths[idx]).convert("RGB")
    input_tensor = T.ToTensor()(low_img).unsqueeze(0).to(device)

    plt.figure(figsize=(20, 5))
    plt.subplot(1, len(trained_models)+2, 1); plt.title("Low-light Input"); plt.imshow(low_img); plt.axis('off')
    plt.subplot(1, len(trained_models)+2, 2); plt.title("Ground Truth"); plt.imshow(high_img); plt.axis('off')

    for i, (name, model) in enumerate(trained_models.items(), start=3):
        model.eval()
        with torch.no_grad():
            # Ensure output is clamped to [0,1] to avoid display errors
            enhanced = model(input_tensor).clamp(0,1).squeeze(0).cpu()
            enhanced_img = T.ToPILImage()(enhanced)
        plt.subplot(1, len(trained_models)+2, i); plt.title(f"Enhanced ({name})"); plt.imshow(enhanced_img); plt.axis('off')

    plt.suptitle("Comparison on a Test Image")
    plt.show()

## Quantitative Evaluation of Supervised Loss Functions (Test Set)

Direct comparison of training/validation loss values is not meaningful, as each loss optimizes a different objective.

Instead, we use standard reference-based metrics to fairly assess reconstruction quality:

- **PSNR** (Peak Signal-to-Noise Ratio, in dB): higher → better pixel-level accuracy  
- **SSIM** (Structural Similarity Index): higher (closer to 1.0) → better structural/perceptual quality
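PSNR as computed in `evaluate_models` (skimage's `peak_signal_noise_ratio` with `data_range=1.0`) is $10\log_{10}(1/\mathrm{MSE})$ for images scaled to $[0, 1]$. The NumPy re-implementation below is only illustrative; it also inverts the formula to translate the ~19.5 dB reached by L1 and Charbonnier into an average per-pixel error.

```python
import numpy as np

def psnr(reference: np.ndarray, estimate: np.ndarray, data_range: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(data_range**2 / MSE) for float images in [0, data_range]."""
    mse = np.mean((reference.astype(np.float64) - estimate.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))

def mse_for_psnr(psnr_db: float, data_range: float = 1.0) -> float:
    """Invert the PSNR formula: the MSE that produces a given PSNR."""
    return data_range ** 2 / 10 ** (psnr_db / 10)

# A uniform error of 0.1 on every pixel gives MSE = 0.01, i.e. 20 dB.
ref = np.full((4, 4, 3), 0.5)
est = ref + 0.1
print(psnr(ref, est))  # ≈ 20.0

rmse = np.sqrt(mse_for_psnr(19.5))
print(f"19.5 dB corresponds to an RMSE of {rmse:.3f} (about {rmse * 255:.0f}/255 per pixel)")
```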

### Average Results Across the Test Set

| Loss Function                  | Avg PSNR (dB) | Avg SSIM   |
|--------------------------------|---------------|------------|
| L1                             | **19.594**    | 0.766      |
| SSIM                           | 10.783        | 0.504      |
| L1 + SSIM Combo (0.8 / 0.2)    | 10.752        | 0.530      |
| Charbonnier                    | 19.534        | **0.769**  |

### Key Insights
- **L1** and **Charbonnier** clearly outperform the others, each reaching ~19.5 dB PSNR, a strong result for low-light enhancement on the LOL dataset.
- **Charbonnier** delivers the highest average SSIM (0.769) and, in approximately 70% of test images (by manual inspection), outperforms L1 in **both PSNR and SSIM** simultaneously.
- **SSIM-only** and **Combo** suffer severe performance degradation, confirming their tendency to collapse to dark outputs when used without a strong pixel-level intensity term.
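The ~70% figure above comes from manual inspection. If `evaluate_models` also returned its per-image `results` dictionary (it currently returns only the averaged DataFrame), the count could be automated with a helper like the hypothetical one below. It is demonstrated on toy numbers, not on the notebook's measurements.

```python
def both_metric_win_rate(results, challenger, baseline):
    """Fraction of images where `challenger` beats `baseline` on PSNR *and* SSIM.

    `results` maps model name -> {'PSNR': [...], 'SSIM': [...]}, one value per test
    image in the same order (the layout built inside evaluate_models).
    """
    c, b = results[challenger], results[baseline]
    n = len(c['PSNR'])
    if n == 0 or any(len(v) != n for v in (c['SSIM'], b['PSNR'], b['SSIM'])):
        raise ValueError("per-image lists must be non-empty and equally long")
    wins = sum(
        1
        for cp, cs, bp, bs in zip(c['PSNR'], c['SSIM'], b['PSNR'], b['SSIM'])
        if cp > bp and cs > bs
    )
    return wins / n

# Toy numbers for illustration only (NOT the notebook's results).
toy = {
    'Charbonnier': {'PSNR': [20.1, 18.0, 21.3, 19.9], 'SSIM': [0.80, 0.70, 0.82, 0.77]},
    'L1':          {'PSNR': [19.8, 18.5, 21.0, 19.7], 'SSIM': [0.79, 0.71, 0.80, 0.78]},
}
print(both_metric_win_rate(toy, 'Charbonnier', 'L1'))  # 0.5 (images 0 and 2 win on both)
```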

### Conclusion
While L1 achieves a slightly higher average PSNR, **Charbonnier loss** provides the best overall performance. It consistently produces more natural-looking results with superior perceptual quality (highest SSIM) and wins both metrics on the majority of test images. Therefore, Charbonnier loss was selected as the primary supervised loss for the U-Net baseline in subsequent comparisons with ZeroDCE variants.

# 📌 3. Unsupervised ZeroDCE Baseline

The ZeroDCE model was implemented based on the original architecture and non-reference loss from the official repository.
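For reference, the training cell below combines the four non-reference terms with fixed weights (read off that cell's code). Here $\mathcal{A}$ stands for the stacked curve-parameter maps `r1..r8` returned by the model, and the exposure term uses 16×16 patches with target level $E = 0.6$:

$$
\mathcal{L}_{\text{total}} = 200\,\mathcal{L}_{tv}(\mathcal{A}) + \mathcal{L}_{spa} + 5\,\mathcal{L}_{col} + 10\,\mathcal{L}_{exp},
\qquad
\mathcal{L}_{exp} = \frac{1}{M}\sum_{k=1}^{M} \left(Y_k - E\right)^2
$$

where $Y_k$ is the mean intensity (averaged over the RGB channels) of the $k$-th of $M$ non-overlapping 16×16 patches of the enhanced image, with the mean in $\mathcal{L}_{exp}$ also taken over the batch.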

Minor modifications were made for better training monitoring (progress logging and visualization).

Hyperparameters were adopted from the original paper/reference implementation, with the adjustments marked `CHANGED` in the training cell; in my testing these settings converged most reliably for this unsupervised setting on the LOL dataset.
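ZeroDCE's central idea (from the original paper) is the quadratic light-enhancement curve $LE(x) = x + \alpha\,x(1-x)$ with $\alpha \in [-1, 1]$, applied iteratively. The model code below writes each step as `x + r*(x**2 - x)`, i.e. $\alpha = -r$, with `r` produced by `tanh`. This minimal NumPy sketch uses one scalar per step (an illustrative simplification; the network predicts per-pixel, per-channel maps) to show that each step is monotonic and fixes both ends of $[0, 1]$, so inputs in $[0, 1]$ stay in $[0, 1]$.

```python
import numpy as np

def apply_le_curves(x: np.ndarray, alphas) -> np.ndarray:
    """Apply x <- x + a * (x**2 - x) once per entry of `alphas`.

    Same update as in ZeroDCE_Unsupervised_Model.forward, but with a scalar `a`
    per iteration. Since x + a*(x**2 - x) = x - a*x*(1 - x), a negative `a`
    brightens and a positive `a` darkens; x = 0 and x = 1 are fixed points.
    """
    for a in alphas:
        x = x + a * (x * x - x)
    return x

x = np.linspace(0.0, 1.0, 11)
brightened = apply_le_curves(x, [-0.8] * 8)  # 8 iterations, as in the model
darkened = apply_le_curves(x, [0.5] * 8)
print(np.round(brightened, 3))
```

Because the derivative of each step, $1 + a(2x - 1)$, is non-negative for $|a| \le 1$, the curve never reverses the ordering of pixel intensities, which helps preserve contrast between neighbouring pixels.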

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
#import pytorch_colors as colors
import numpy as np

# ZeroDCE enhancement network, following enhance_net_nopool from the official GitHub repository.
class ZeroDCE_Unsupervised_Model(nn.Module):

	def __init__(self):
		super(ZeroDCE_Unsupervised_Model, self).__init__()

		self.relu = nn.ReLU(inplace=True)

		number_f = 32
		self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True)
		self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)

		self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
		self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
		self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)

		self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
		self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)



	def forward(self, x):

		x1 = self.relu(self.e_conv1(x))
		# p1 = self.maxpool(x1)
		x2 = self.relu(self.e_conv2(x1))
		# p2 = self.maxpool(x2)
		x3 = self.relu(self.e_conv3(x2))
		# p3 = self.maxpool(x3)
		x4 = self.relu(self.e_conv4(x3))

		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
		# x5 = self.upsample(x5)
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

		x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
		r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)


		x = x + r1*(torch.pow(x,2)-x)
		x = x + r2*(torch.pow(x,2)-x)
		x = x + r3*(torch.pow(x,2)-x)
		enhance_image_1 = x + r4*(torch.pow(x,2)-x)
		x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
		x = x + r6*(torch.pow(x,2)-x)
		x = x + r7*(torch.pow(x,2)-x)
		enhance_image = x + r8*(torch.pow(x,2)-x)
		r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
		return enhance_image_1,enhance_image,r

print("ZeroDCE_Unsupervised_Model class defined.")


In [None]:
# Helper loss modules (L_color, L_spa, L_exp, L_TV, Sa_Loss, perception_loss) adapted from the
# ZeroDCE GitHub implementation: .cuda() replaced with .to(device), and L_color / L_spa reduced to scalars with .mean().
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import numpy as np


class L_color(nn.Module):

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x ):

        b,c,h,w = x.shape

        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr-mg,2)
        Drb = torch.pow(mr-mb,2)
        Dgb = torch.pow(mb-mg,2)
        k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)


        return k.mean() # Added .mean() to ensure scalar output


class L_spa(nn.Module):

    def __init__(self):
        super(L_spa, self).__init__()
        # print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).to(device).unsqueeze(0).unsqueeze(0) # Changed .cuda() to .to(device)
        kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).to(device).unsqueeze(0).unsqueeze(0) # Changed .cuda() to .to(device)
        kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).to(device).unsqueeze(0).unsqueeze(0) # Changed .cuda() to .to(device)
        kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).to(device).unsqueeze(0).unsqueeze(0) # Changed .cuda() to .to(device)
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)
    def forward(self, org , enhance ):
        b,c,h,w = org.shape

        org_mean = torch.mean(org,1,keepdim=True)
        enhance_mean = torch.mean(enhance,1,keepdim=True)

        org_pool =  self.pool(org_mean)
        enhance_pool = self.pool(enhance_mean)

        weight_diff =torch.max(torch.FloatTensor([1]).to(device) + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).to(device),torch.FloatTensor([0]).to(device)),torch.FloatTensor([0.5]).to(device)) # Changed .cuda() to .to(device)
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).to(device)) ,enhance_pool-org_pool) # Changed .cuda() to .to(device)


        D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
        D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
        D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
        D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)

        D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
        D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
        D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
        D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)

        D_left = torch.pow(D_org_letf - D_enhance_letf,2)
        D_right = torch.pow(D_org_right - D_enhance_right,2)
        D_up = torch.pow(D_org_up - D_enhance_up,2)
        D_down = torch.pow(D_org_down - D_enhance_down,2)
        E = (D_left + D_right + D_up +D_down)
        # E = 25*(D_left + D_right + D_up +D_down)

        return E.mean() # Added .mean() to return a scalar
class L_exp(nn.Module):

    def __init__(self,patch_size,mean_val):
        super(L_exp, self).__init__()
        # print(1)
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val
    def forward(self, x ):

        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        mean = self.pool(x)

        d = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).to(device),2)) # Changed .cuda() to .to(device)
        return d

class L_TV(nn.Module):
    def __init__(self,TVLoss_weight=1):
        super(L_TV,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size

class Sa_Loss(nn.Module):
    def __init__(self):
        super(Sa_Loss, self).__init__()
        # print(1)
    def forward(self, x ):
        # self.grad = np.ones(x.shape,dtype=np.float32)
        b,c,h,w = x.shape
        # x_de = x.cpu().detach().numpy()
        r,g,b = torch.split(x , 1, dim=1)
        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Dr = r-mr
        Dg = g-mg
        Db = b-mb
        k =torch.pow( torch.pow(Dr,2) + torch.pow(Db,2) + torch.pow(Dg,2),0.5)
        # print(k)


        k = torch.mean(k)
        return k

class perception_loss(nn.Module):
    def __init__(self):
        super(perception_loss, self).__init__()
        features = vgg16(weights="IMAGENET1K_V1").features  # `pretrained=True` is deprecated since torchvision 0.13
        self.to_relu_1_2 = nn.Sequential()
        self.to_relu_2_2 = nn.Sequential()
        self.to_relu_3_3 = nn.Sequential()
        self.to_relu_4_3 = nn.Sequential()

        for x in range(4):
            self.to_relu_1_2.add_module(str(x), features[x])
        for x in range(4, 9):
            self.to_relu_2_2.add_module(str(x), features[x])
        for x in range(9, 16):
            self.to_relu_3_3.add_module(str(x), features[x])
        for x in range(16, 23):
            self.to_relu_4_3.add_module(str(x), features[x])

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h = self.to_relu_1_2(x)
        h_relu_1_2 = h
        h = self.to_relu_2_2(h)
        h_relu_2_2 = h
        h = self.to_relu_3_3(h)
        h_relu_3_3 = h
        h = self.to_relu_4_3(h)
        h_relu_4_3 = h
        # out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)
        return h_relu_4_3

In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import time
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from piqa import SSIM
from skimage.metrics import peak_signal_noise_ratio as psnr
import pandas as pd
import matplotlib.pyplot as plt
import random

# =============================================================================
# HYPERPARAMETERS (taken from the ZeroDCE GitHub training script, except where marked CHANGED)
# =============================================================================
lowlight_images_path = "data/train_data/" # (We replace this with LOLDataset)
lr = 0.0001 # CHANGED from 0.00005 to 0.0001
weight_decay = 0.0001
grad_clip_norm = 0.1
num_epochs = 200 # CHANGED from 100 to 200
train_batch_size = 8
val_batch_size = 4
num_workers = 4
display_iter = 10
snapshot_iter = 10
snapshots_folder = "snapshots/"
load_pretrain = False
pretrain_dir = "snapshots/Epoch99.pth"

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# =============================================================================
# TRAINING LOGIC
# =============================================================================
def train_github_style():
    # 1. Setup Device
    # os.environ['CUDA_VISIBLE_DEVICES']='0' # handled by Colab runtime

    # 2. Model Setup
    # DCE_net = model.enhance_net_nopool().cuda()
    DCE_net = ZeroDCE_Unsupervised_Model().to(device)
    DCE_net.apply(weights_init)

    # if config.load_pretrain == True:
    #     DCE_net.load_state_dict(torch.load(config.pretrain_dir))

    # 3. Data Loader Setup
    # train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
    train_dataset = LOLDataset(train=True, crop_size=256)  # 256x256 random crops, small enough for batch size 8
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)

    # Validation loader for monitoring (not in original train.py but needed for us to see progress)
    val_dataset = LOLDataset(train=False)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # 4. Loss Components Setup
    L_color_module = L_color().to(device)
    L_spa_module = L_spa().to(device)
    L_exp_module = L_exp(16, 0.6).to(device)
    L_TV_module = L_TV().to(device)
    # Myloss.L_color(), Myloss.L_spa(), etc.

    # 5. Optimizer Setup
    optimizer = torch.optim.Adam(DCE_net.parameters(), lr=lr, weight_decay=weight_decay)

    DCE_net.train()

    # Stats tracking
    history = {'epoch': [], 'loss': [], 'psnr': [], 'ssim': []}
    cal_ssim = SSIM().to(device)

    print("Starting training with GitHub parameters...")
    print(f"Epochs: {num_epochs}, Batch: {train_batch_size}, LR: {lr}")

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for iteration, (img_lowlight, _) in enumerate(train_loader): # LOLDataset returns (low, high)
            img_lowlight = img_lowlight.to(device)

            # Forward pass
            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)

            # Loss Calculation (Exact weights from GitHub)
            Loss_TV = 200 * L_TV_module(A)
            loss_spa = torch.mean(L_spa_module(img_lowlight, enhanced_image))
            loss_col = 5 * torch.mean(L_color_module(enhanced_image))
            loss_exp = 10 * torch.mean(L_exp_module(enhanced_image))

            loss = Loss_TV + loss_spa + loss_col + loss_exp

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(), grad_clip_norm)
            optimizer.step()

            epoch_loss += loss.item()

            # Display iter logic from original code (simplified for notebook)
            if ((iteration+1) % display_iter) == 0:
                # print("Loss at iteration", iteration+1, ":", loss.item())
                pass

        # End of Epoch Evaluation
        if ((epoch+1) % snapshot_iter) == 0 or epoch == 0:
            avg_loss = epoch_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}")

            # Validation
            DCE_net.eval()
            psnr_val_list, ssim_val_list = [], []
            with torch.no_grad():
                for v_low, v_high in val_loader:
                    v_low, v_high = v_low.to(device), v_high.to(device)
                    _, v_enhanced, _ = DCE_net(v_low)
                    v_enhanced = v_enhanced.clamp(0,1)
                    ssim_val_list.append(cal_ssim(v_enhanced, v_high).item())
                    v_enhanced_np = v_enhanced.squeeze(0).cpu().numpy().transpose(1,2,0)
                    v_high_np = v_high.squeeze(0).cpu().numpy().transpose(1,2,0)
                    psnr_val_list.append(psnr(v_high_np, v_enhanced_np, data_range=1.0))

            avg_psnr = sum(psnr_val_list) / len(psnr_val_list)
            avg_ssim = sum(ssim_val_list) / len(ssim_val_list)
            history['epoch'].append(epoch+1)
            history['loss'].append(avg_loss)
            history['psnr'].append(avg_psnr)
            history['ssim'].append(avg_ssim)

            print(f"Epoch {epoch+1} Test PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")

            # Visualize one random image from validation set
            with torch.no_grad():
                rand_idx = random.randint(0, len(val_dataset) - 1)
                v_low, v_high = val_dataset[rand_idx]

                # Add batch dimension and move to device
                v_low = v_low.unsqueeze(0).to(device)

                _, v_enhanced, _ = DCE_net(v_low)
                v_enhanced = v_enhanced.clamp(0,1).cpu()

                plt.figure(figsize=(15,5))
                plt.subplot(1,3,1); plt.imshow(T.ToPILImage()(v_low.squeeze(0))); plt.title('Input')
                plt.subplot(1,3,2); plt.imshow(T.ToPILImage()(v_enhanced.squeeze(0))); plt.title('Enhanced')
                plt.subplot(1,3,3); plt.imshow(T.ToPILImage()(v_high)); plt.title('Ground Truth')
                plt.show()

            DCE_net.train()

            # Snapshot saving
            # if not os.path.exists(snapshots_folder):
            #     os.mkdir(snapshots_folder)
            # torch.save(DCE_net.state_dict(), snapshots_folder + "Epoch" + str(epoch) + '.pth')

    # Plotting Training History
    plt.figure(figsize=(18, 5))

    plt.subplot(1, 3, 1)
    plt.plot(history['epoch'], history['loss'], label='Loss', color='red')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(history['epoch'], history['psnr'], label='PSNR', color='green')
    plt.title('Validation PSNR')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(history['epoch'], history['ssim'], label='SSIM', color='blue')
    plt.title('Validation SSIM')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Return trained model and history
    return DCE_net, history

# Run the training
unsupervised_model, unsupervised_history = train_github_style()

# Save model to global dict for comparison later
trained_models['ZeroDCE_Unsupervised'] = unsupervised_model


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from piqa import SSIM
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random
from pathlib import Path

def evaluate_and_visualize_model(model, history, model_name):
    print(f"\n{'='*60}")
    print(f"EVALUATING AND VISUALIZING: {model_name}")
    print(f"{'='*60}")

    # 1. Evaluate on test set
    test_dataset = LOLDataset(train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    psnr_values = []
    ssim_values = []

    model.eval()
    cal_ssim = SSIM().to(device) # Ensure SSIM module is on the correct device
    with torch.no_grad():
        for i, (low, high) in enumerate(test_loader):
            low, high = low.to(device), high.to(device)

            # ZeroDCE_Unsupervised_Model returns (enhanced_image_1, enhanced_image, A)
            _, enhanced, _ = model(low)
            enhanced = enhanced.clamp(0, 1)

            # For PSNR calculation
            high_np = high.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            enhanced_np = enhanced.squeeze(0).cpu().numpy().transpose(1, 2, 0)

            psnr_values.append(psnr(high_np, enhanced_np, data_range=1.0))
            ssim_values.append(cal_ssim(enhanced, high).item())

    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)

    print(f"Average PSNR on Test Set: {avg_psnr:.2f}")
    print(f"Average SSIM on Test Set: {avg_ssim:.4f}")

    # 2. Re-display training history plots
    if history and history['epoch']:
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(history['epoch'], history['loss'], label='Loss', color='red')
        plt.title(f'{model_name} - Training Loss')
        plt.xlabel('Epoch')
        plt.grid(True)
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.plot(history['epoch'], history['psnr'], label='PSNR', color='green')
        plt.title(f'{model_name} - Validation PSNR')
        plt.xlabel('Epoch')
        plt.grid(True)
        plt.legend()

        plt.subplot(1, 3, 3)
        plt.plot(history['epoch'], history['ssim'], label='SSIM', color='blue')
        plt.title(f'{model_name} - Validation SSIM')
        plt.xlabel('Epoch')
        plt.grid(True)
        plt.legend()

        plt.tight_layout()
        plt.show()
    else:
        print("No training history to display.")

    # 3. Show a sample enhanced image
    print(f"\nDisplaying a sample enhanced image for {model_name}...")
    test_low_paths = sorted((Path("/content/lol_dataset/eval15/low")).glob("*.png"))
    test_high_paths = sorted((Path("/content/lol_dataset/eval15/high")).glob("*.png"))

    if len(test_low_paths) == 0:
        print("⚠️ Error: No images found in /content/lol_dataset/eval15/low")
        print("Please check if the dataset was copied correctly.")
        return

    idx = random.randint(0, len(test_low_paths) - 1)
    low_img_path = test_low_paths[idx]
    high_img_path = test_high_paths[idx]

    low_img = Image.open(low_img_path).convert("RGB")
    high_img = Image.open(high_img_path).convert("RGB")
    input_tensor = T.ToTensor()(low_img).unsqueeze(0).to(device)

    with torch.no_grad():
        _, enhanced_tensor, _ = model(input_tensor)
        enhanced_tensor = enhanced_tensor.clamp(0, 1).squeeze(0).cpu()
    enhanced_img = T.ToPILImage()(enhanced_tensor)

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1); plt.title("Low-light Input"); plt.imshow(low_img); plt.axis('off')
    plt.subplot(1, 3, 2); plt.title("Ground Truth"); plt.imshow(high_img); plt.axis('off')
    plt.subplot(1, 3, 3); plt.title(f"Enhanced ({model_name})"); plt.imshow(enhanced_img); plt.axis('off')
    plt.suptitle("Sample Image Comparison")
    plt.show()

# Call the function for the newly trained unsupervised model
evaluate_and_visualize_model(
    unsupervised_model,
    unsupervised_history,
    'ZeroDCE_Unsupervised' # Updated model name for clarity
)

## 📌 3. Unsupervised ZeroDCE Baseline

The original ZeroDCE model relies on a **non-reference (unsupervised) loss** that does not require paired ground-truth images. Instead, it uses carefully designed priors (spatial consistency, exposure control, illumination smoothness) to guide curve-based enhancement.
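The curve-based enhancement itself can be sketched in a few lines. This is a minimal illustration that assumes the quadratic light-enhancement (LE) curve from the Zero-DCE paper, LE(x) = x + α·x·(1 − x), applied iteratively with one per-pixel α map per RGB channel and iteration. `apply_le_curves` is a hypothetical helper, not a function from this notebook. The notebook's `ZeroDCE_Unsupervised_Model` returns its curve maps as `A`, and its code may use the sign-flipped form x + α(x² − x), which describes the same family of curves with α replaced by −α.

```python
import torch


def apply_le_curves(x: torch.Tensor, curve_maps: torch.Tensor) -> torch.Tensor:
    """Iteratively apply quadratic LE curves (hypothetical helper).

    x:          low-light image, shape (B, 3, H, W), values in [0, 1]
    curve_maps: per-pixel curve parameters, shape (B, 3 * n_iter, H, W), values in [-1, 1]
    """
    # Each group of 3 channels holds the alpha maps (one per RGB channel) for one iteration
    for alpha in torch.split(curve_maps, 3, dim=1):
        # LE(x) = x + alpha * x * (1 - x); for x in [0, 1] and alpha in [-1, 1] the result stays in [0, 1]
        x = x + alpha * x * (1.0 - x)
    return x


# Usage: 8 iterations (24 curve channels), as in the original Zero-DCE design
low = torch.rand(1, 3, 64, 64) * 0.2                # a dark synthetic image
alphas = torch.tanh(torch.randn(1, 24, 64, 64))     # stand-in for the network's predicted curve maps
enhanced = apply_le_curves(low, alphas)
print(enhanced.shape, float(enhanced.min()), float(enhanced.max()))
```

Because every curve maps [0, 1] into [0, 1], the network only has to predict bounded α maps and never needs to clip its output during training.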

### Training Observations
- Loss decreased rapidly in the first 25 epochs, accompanied by significant gains in PSNR and SSIM.
- This correlation suggests that the unsupervised objective aligns well with the improvements measured by PSNR and SSIM, even without direct supervision.
- Final test-set performance typically reached **PSNR > 15 dB** and **SSIM > 0.45**, which is respectable for a fully unsupervised approach on the challenging LOL dataset.

### Key Insights
While visual outputs often appear plausible (good brightness recovery, reduced noise), quantitative metrics (PSNR/SSIM) remain lower than those of supervised methods. This is expected: supervised models directly optimize toward ground-truth pixels, whereas ZeroDCE relies solely on hand-crafted priors and never sees the normal-light reference images during training.
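The exposure-control term is a good illustration of a "hand-crafted prior". The sketch below averages the image over channels, pools it into non-overlapping 16×16 patches, and penalizes each patch mean for deviating from a target exposure level of 0.6. These are the `L_exp(16, 0.6)` arguments used later in this notebook. `exposure_control_loss` is a hypothetical stand-in for the notebook's `L_exp`. It uses a squared deviation, whereas the Zero-DCE paper writes the term with an absolute difference. No ground-truth image enters the computation.

```python
import torch
import torch.nn.functional as F


def exposure_control_loss(img: torch.Tensor, patch_size: int = 16, target: float = 0.6) -> torch.Tensor:
    """Penalize patch-wise mean intensity for deviating from a target exposure level.

    img: enhanced image, shape (B, 3, H, W), values in [0, 1]. No reference image is needed.
    """
    gray = img.mean(dim=1, keepdim=True)              # (B, 1, H, W): average over RGB
    patch_means = F.avg_pool2d(gray, patch_size)      # non-overlapping patches (stride = kernel size)
    return torch.mean((patch_means - target) ** 2)    # squared deviation from the target level


# Usage: a uniformly dark image is penalized; a uniformly "well-exposed" one is not
dark = torch.full((1, 3, 64, 64), 0.1)
well_exposed = torch.full((1, 3, 64, 64), 0.6)
print(exposure_control_loss(dark).item(), exposure_control_loss(well_exposed).item())
```

A term like this can push brightness toward a plausible level, but it cannot tell the model what the correct brightness or color of a particular scene is. That information only comes from paired ground truth.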

This unsupervised baseline highlights the strengths of ZeroDCE's lightweight curve estimation while establishing a reference point for the subsequent hybrid experiments.
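The "lightweight" claim can be checked with a quick parameter count. The sketch below rebuilds the DCE-Net layout of the original Zero-DCE as a hypothetical `DCENetSketch`: seven 3×3 convolutions with 32 feature channels, symmetric skip concatenations, and a tanh output of 24 curve channels (8 iterations × 3 RGB maps). It assumes that the notebook's `ZeroDCE_Unsupervised_Model` follows the same design. Confirm this against its definition earlier in the notebook.

```python
import torch
import torch.nn as nn


class DCENetSketch(nn.Module):
    """Hypothetical re-creation of the DCE-Net curve estimator (7 conv layers, 32 features)."""

    def __init__(self, nf: int = 32, n_iter: int = 8):
        super().__init__()
        self.conv1 = nn.Conv2d(3, nf, 3, padding=1)
        self.conv2 = nn.Conv2d(nf, nf, 3, padding=1)
        self.conv3 = nn.Conv2d(nf, nf, 3, padding=1)
        self.conv4 = nn.Conv2d(nf, nf, 3, padding=1)
        self.conv5 = nn.Conv2d(nf * 2, nf, 3, padding=1)          # input: cat(x3, x4)
        self.conv6 = nn.Conv2d(nf * 2, nf, 3, padding=1)          # input: cat(x2, x5)
        self.conv7 = nn.Conv2d(nf * 2, 3 * n_iter, 3, padding=1)  # input: cat(x1, x6)
        self.relu = nn.ReLU()

    def forward(self, x):
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))
        x5 = self.relu(self.conv5(torch.cat([x3, x4], dim=1)))
        x6 = self.relu(self.conv6(torch.cat([x2, x5], dim=1)))
        curve_maps = torch.tanh(self.conv7(torch.cat([x1, x6], dim=1)))  # values in [-1, 1]
        enhanced = x
        for alpha in torch.split(curve_maps, 3, dim=1):
            enhanced = enhanced + alpha * enhanced * (1.0 - enhanced)    # quadratic LE curve
        return enhanced, curve_maps


net = DCENetSketch()
n_params = sum(p.numel() for p in net.parameters())
print(f"Trainable parameters: {n_params:,}")
out, maps = net(torch.rand(1, 3, 64, 64))
print(out.shape, maps.shape)
```

For this layout the count comes to 79,416 parameters. Running the same `sum(p.numel() ...)` on the notebook's U-Net (for example `charbonnier_model` in the next cell) gives a concrete sense of the size gap between the two architectures.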

## Initial Comparison: Supervised vs. Unsupervised Approaches

In [None]:
# =============================================================================
# EVALUATION AND VISUAL COMPARISON OF TWO MODELS
# =============================================================================

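import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from IPython.display import display
from PIL import Image
from piqa import SSIM
from skimage.metrics import peak_signal_noise_ratio as psnr
from torch.utils.data import DataLoader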

# 1. Load the UNet model with Charbonnier loss
print("\n" + "="*60)
print("LOADING UNet MODEL WITH CHARBONNIER LOSS")
print("="*60)

# Create a clean dictionary JUST for this comparison
comparison_models = {}

# Requires the UNet class and MODEL_DEPTH defined earlier in the notebook
charbonnier_model = UNet(depth=MODEL_DEPTH).to(device)
charbonnier_model_path = "zero_dce_charbonnier_pro.pth"

# Load UNet_Charbonnier (map_location lets a GPU-saved checkpoint load on a CPU-only runtime)
if os.path.exists(charbonnier_model_path):
    charbonnier_model.load_state_dict(torch.load(charbonnier_model_path, map_location=device))
    charbonnier_model.eval()
    # Add to our new comparison dictionary
    comparison_models['UNet_Charbonnier'] = charbonnier_model
    print(f"Loaded UNet_Charbonnier model from {charbonnier_model_path}")
else:
    print(f"⚠️ Error: {charbonnier_model_path} not found.")

# 2. Add ZeroDCE_Unsupervised
# Check trained_models first, then fallback to local variable 'unsupervised_model'
if 'trained_models' in globals() and 'ZeroDCE_Unsupervised' in trained_models:
    comparison_models['ZeroDCE_Unsupervised'] = trained_models['ZeroDCE_Unsupervised']
    print("Added ZeroDCE_Unsupervised from trained_models dictionary.")
elif 'unsupervised_model' in globals():
    comparison_models['ZeroDCE_Unsupervised'] = unsupervised_model
    print("Added ZeroDCE_Unsupervised from 'unsupervised_model' variable.")
else:
    print("⚠️ Warning: 'ZeroDCE_Unsupervised' model not found. Please train it first.")

print("\n" + "="*60)
print("COMPARING UNet_Charbonnier AND ZeroDCE_Unsupervised MODELS")
print("="*60)

# --- Custom Evaluation Loop (Self-Contained) ---
if not comparison_models:
    print("⚠️ Error: No models available for comparison.")
else:
    # Setup data
    test_dataset = LOLDataset(train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    results = {name: {'PSNR': [], 'SSIM': []} for name in comparison_models}
    cal_ssim = SSIM().to(device)

    print(f"Evaluating {len(comparison_models)} models on test set...")

    for low, high in test_loader:
        low, high = low.to(device), high.to(device)
        high_cpu = high.squeeze(0).cpu().numpy().transpose(1,2,0)

        for name, model in comparison_models.items():
            model.eval()
            with torch.no_grad():
                # Get output
                output = model(low)

                # Handle tuple output for ZeroDCE vs Tensor for UNet
                if isinstance(output, tuple):
                    _, enhanced_out, _ = output
                    enhanced = enhanced_out.clamp(0,1).squeeze(0)
                else:
                    enhanced = output.clamp(0,1).squeeze(0)

                # Metrics
                enhanced_cpu = enhanced.cpu().numpy().transpose(1,2,0)
                psnr_val = psnr(high_cpu, enhanced_cpu, data_range=1.0)
                ssim_val = cal_ssim(enhanced.unsqueeze(0), high).item()

                results[name]['PSNR'].append(psnr_val)
                results[name]['SSIM'].append(ssim_val)

    # Average metrics
    avg_results = {name: {'Avg PSNR': np.mean(vals['PSNR']), 'Avg SSIM': np.mean(vals['SSIM'])} for name, vals in results.items()}
    df = pd.DataFrame(avg_results).T
    print("\nEvaluation Results:")
    display(df)

    best_psnr = df['Avg PSNR'].idxmax()
    best_ssim = df['Avg SSIM'].idxmax()
    print(f"Best by PSNR: {best_psnr}")
    print(f"Best by SSIM: {best_ssim}")

    # --- Visual Comparison ---
    print("\n" + "="*60)
    print("VISUAL COMPARISON OF MODELS")
    print("="*60)

    test_low_paths_eval15 = sorted((Path("/content/lol_dataset/eval15/low")).glob("*.png"))
    test_high_paths_eval15 = sorted((Path("/content/lol_dataset/eval15/high")).glob("*.png"))

    if len(test_low_paths_eval15) == 0:
        print("⚠️ Error: No images found in /content/lol_dataset/eval15/low.")
    else:
        idx = random.randint(0, len(test_low_paths_eval15) - 1)
        print(f"Testing on image index: {idx} / {len(test_low_paths_eval15)-1}")

        low_img = Image.open(test_low_paths_eval15[idx]).convert("RGB")
        high_img = Image.open(test_high_paths_eval15[idx]).convert("RGB")
        input_tensor = T.ToTensor()(low_img).unsqueeze(0).to(device)

        plt.figure(figsize=(20, 5))
        total_plots = len(comparison_models) + 2

        plt.subplot(1, total_plots, 1); plt.title("Low-light Input"); plt.imshow(low_img); plt.axis('off')
        plt.subplot(1, total_plots, 2); plt.title("Ground Truth"); plt.imshow(high_img); plt.axis('off')

        plot_idx = 3
        for name, model in comparison_models.items():
            model.eval()
            with torch.no_grad():
                output = model(input_tensor)
                if isinstance(output, tuple):
                    _, enhanced_out, _ = output
                    enhanced = enhanced_out.clamp(0,1).squeeze(0).cpu()
                else:
                    enhanced = output.clamp(0,1).squeeze(0).cpu()

                enhanced_img = T.ToPILImage()(enhanced)
                plt.subplot(1, total_plots, plot_idx)
                plt.title(f"Enhanced ({name})")
                plt.imshow(enhanced_img)
                plt.axis('off')
                plot_idx += 1

        plt.suptitle("Comparison on a Test Image (eval15 Dataset)")
        plt.tight_layout()
        plt.show()

## Analysis of the Two Models

The supervised U-Net (with Charbonnier loss) leverages paired ground-truth images, enabling direct pixel-level optimization. This results in superior reconstruction accuracy, natural brightness recovery, and high PSNR/SSIM scores.
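For reference, the PSNR reported throughout this notebook is computed with `psnr`, an alias for `skimage.metrics.peak_signal_noise_ratio` called with `data_range=1.0`. For images scaled to [0, 1] it reduces to 10·log10(1 / MSE). The NumPy sketch below makes that explicit (`psnr_unit_range` is a hypothetical name). Because of the logarithm, halving the MSE adds only about 3 dB.

```python
import numpy as np


def psnr_unit_range(reference: np.ndarray, estimate: np.ndarray) -> float:
    """PSNR in dB for images with values in [0, 1] (data_range = 1.0)."""
    mse = np.mean((reference.astype(np.float64) - estimate.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")                      # identical images
    return float(10.0 * np.log10(1.0 / mse))     # MAX^2 / MSE with MAX = 1


# Usage: a uniform error of 0.1 per pixel gives MSE = 0.01, i.e. 20 dB
gt = np.zeros((4, 4, 3))
pred = np.full((4, 4, 3), 0.1)
print(round(psnr_unit_range(gt, pred), 2))
```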

In contrast, the original ZeroDCE operates fully unsupervised, relying solely on non-reference priors (spatial consistency, exposure control, illumination smoothness). While effective at brightening images without paired data, it exhibits lower quantitative performance and occasional color shifts or residual noise.
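Of these priors, illumination smoothness is the simplest to illustrate. It is a total-variation penalty on the predicted curve-parameter maps rather than on the output image: in the unsupervised training loop it appears as `Loss_TV = 200 * L_TV_module(A)`. It encourages neighboring pixels to receive similar curves. The sketch below follows a common squared-difference TV formulation, normalized by the number of neighbor pairs and the batch size. `illumination_smoothness_loss` is a hypothetical helper, and the notebook's `L_TV` may differ in its normalization details.

```python
import torch


def illumination_smoothness_loss(curve_maps: torch.Tensor) -> torch.Tensor:
    """Squared total-variation penalty on curve maps of shape (B, C, H, W)."""
    batch, _, height, width = curve_maps.shape
    # Differences between vertically and horizontally adjacent pixels
    dh = curve_maps[:, :, 1:, :] - curve_maps[:, :, :-1, :]
    dw = curve_maps[:, :, :, 1:] - curve_maps[:, :, :, :-1]
    # Normalize each direction by the number of neighbor pairs in one channel, then by batch size
    h_term = dh.pow(2).sum() / ((height - 1) * width)
    w_term = dw.pow(2).sum() / (height * (width - 1))
    return 2 * (h_term + w_term) / batch


# Usage: spatially constant curves cost nothing; noisy curves are penalized
flat = torch.full((2, 24, 32, 32), 0.3)
noisy = torch.tanh(torch.randn(2, 24, 32, 32))
print(illumination_smoothness_loss(flat).item(), illumination_smoothness_loss(noisy).item())
```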

This gap motivated the next phase: exploring whether incorporating supervised signals into the ZeroDCE framework could bridge the performance difference while retaining its architectural advantages.

# 📌 4. Hybrid and Supervised ZeroDCE Variants

To investigate the impact of supervision on the ZeroDCE architecture, two variants were developed:

- **Hybrid ZeroDCE (Dual Loss)**:  
  Combines the original unsupervised ZeroDCE loss with Charbonnier loss. This approach retains perceptual priors while adding direct reconstruction guidance.

- **Supervised ZeroDCE (Charbonnier Only)**:  
  Replaces the ZeroDCE loss entirely with Charbonnier loss, converting the model into a fully supervised learner.

These experiments aim to isolate the contributions of architecture vs. training objective and explore hybrid strategies for improved stability and quality.
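Written out, the two training objectives correspond to the loss code in the two model cells below. In this summary, $\hat{Y}$ is the enhanced output, $Y$ the ground truth, and $N$ the number of pixel values. $\phi$ stands for the feature mapping computed by the notebook's `perception_loss` module (an assumption about its role, based on how it is called). Each remaining $\mathcal{L}$ term keeps the definition of the corresponding loss module (`L_spa`, `L_TV`, `L_color_supervised`, `L_exp`, `Sa_Loss`).

$$
\begin{aligned}
\mathcal{L}_{\text{Charb}}(\hat{Y}, Y) &= \frac{1}{N}\sum_{i=1}^{N}\sqrt{(\hat{Y}_i - Y_i)^2 + \varepsilon^2}, \qquad \varepsilon = 10^{-3} \\
\mathcal{L}_{\text{hybrid}} &= w_{\text{rec}}\,\mathcal{L}_{\text{Charb}}(\hat{Y}, Y) + w_{\text{per}}\,\operatorname{mean}\bigl\lvert\phi(\hat{Y}) - \phi(Y)\bigr\rvert + w_{\text{spa}}\,\mathcal{L}_{\text{spa}}(Y, \hat{Y}) + w_{\text{tv}}\,\mathcal{L}_{\text{TV}}(\hat{Y}) \\
&\quad + w_{\text{col}}\,\mathcal{L}_{\text{col}}(\hat{Y}, Y) + w_{\text{exp}}\,\mathcal{L}_{\text{exp}}(\hat{Y}) + w_{\text{sa}}\,\mathcal{L}_{\text{sa}}(\hat{Y}) \\
\mathcal{L}_{\text{sup}} &= \mathcal{L}_{\text{Charb}}(\hat{Y}, Y)
\end{aligned}
$$

The Model 1 cell sets these weights as follows:

- $w_{\text{rec}} = 1.0$
- $w_{\text{per}} = w_{\text{spa}} = w_{\text{tv}} = 0.28$
- $w_{\text{col}} = 0.14$
- $w_{\text{exp}} = w_{\text{sa}} = 0.003$

The reconstruction term therefore carries the largest single weight in the hybrid objective.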

# MODEL 1: Hybrid_ZeroDCE_Two_Losses

This model uses the ZeroDCE architecture trained with ZeroDCE-style losses plus Charbonnier loss.

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from piqa import SSIM
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random

# =============================================================================
# 📌 4. Modify the original unsupervised ZeroDCE model to make it supervised (TUNING)
# =============================================================================

# 1. Define a new class named ZeroDCE_Supervised_Architecture
class ZeroDCE_Supervised_Architecture(ZeroDCE_Unsupervised_Model):
    def __init__(self):
        super(ZeroDCE_Supervised_Architecture, self).__init__()

    def forward(self, x):
        _, enhanced_image, _ = super().forward(x) # Call parent's forward, take second return
        return enhanced_image

# L_color_supervised: compares a color-deviation statistic of the enhanced image with that of the ground truth
class L_color_supervised(nn.Module):
    def __init__(self):
        super(L_color_supervised, self).__init__()

    def forward(self, x, y):  # x: enhanced, y: ground truth
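        # Per-pixel distance from the per-channel spatial mean, averaged (the statistic used by the reference
        # Zero-DCE Sa_Loss; the original L_color instead compares the channel means with each other)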
        r, g, b = torch.split(x, 1, dim=1)
        mean_rgb_x = torch.mean(x, [2, 3], keepdim=True)
        mr_x, mg_x, mb_x = torch.split(mean_rgb_x, 1, dim=1)
        Dr_x = r - mr_x
        Dg_x = g - mg_x
        Db_x = b - mb_x
        k_x = torch.pow(torch.pow(Dr_x, 2) + torch.pow(Dg_x, 2) + torch.pow(Db_x, 2) + 1e-8, 0.5)
        k_x = torch.mean(k_x)

        # Compute the same for ground truth (for supervision)
        r_y, g_y, b_y = torch.split(y, 1, dim=1)
        mean_rgb_y = torch.mean(y, [2, 3], keepdim=True)
        mr_y, mg_y, mb_y = torch.split(mean_rgb_y, 1, dim=1)
        Dr_y = r_y - mr_y
        Dg_y = g_y - mg_y
        Db_y = b_y - mb_y
        k_y = torch.pow(torch.pow(Dr_y, 2) + torch.pow(Db_y, 2) + torch.pow(Dg_y, 2) + 1e-8, 0.5)
        k_y = torch.mean(k_y)

        # Supervised term: absolute difference between the enhanced and ground-truth statistics
        return torch.abs(k_x - k_y)

# CharbonnierLoss is redefined here so this cell is self-contained (same definition as in the supervised U-Net section)
class CharbonnierLoss(nn.Module):
    """Charbonnier loss: mean(sqrt((x - y)^2 + eps^2)), a smooth, differentiable variant of L1."""
    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        return torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))

# Hybrid loss: adapted ZeroDCE terms plus a Charbonnier reconstruction term, each with an adjustable weight
class ZeroDCELoss_Hybrid_Supervised(nn.Module):
    """Hybrid Supervised ZeroDCE Loss: Reuses original components + adds Charbonnier for direct supervision"""
    def __init__(self, recon_weight, percep_weight, spa_weight, tv_weight, col_weight, exp_weight, sa_weight):
        super(ZeroDCELoss_Hybrid_Supervised, self).__init__()
        # Store weights as instance attributes
        self.recon_weight = recon_weight
        self.percep_weight = percep_weight
        self.spa_weight = spa_weight
        self.tv_weight = tv_weight
        self.col_weight = col_weight
        self.exp_weight = exp_weight
        self.sa_weight = sa_weight

        # L_spa, L_exp, L_TV, perception_loss and Sa_Loss are defined earlier in the notebook
        self.L_color = L_color_supervised()
        self.L_spatial = L_spa()
        self.L_exp = L_exp(16, 0.6)          # 16x16 patches, target exposure level 0.6
        self.L_tv = L_TV()
        self.L_percep = perception_loss().to(device)
        self.sa_loss = Sa_Loss()
        self.charbonnier = CharbonnierLoss()

    def forward(self, enhanced, high_img):
        # --- Reuse Original/Adapted ZeroDCE Terms (as priors) ---
        loss_percep = self.percep_weight * torch.mean(torch.abs(self.L_percep(enhanced) - self.L_percep(high_img)))
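        # Supervised spatial consistency; torch.mean keeps the loss scalar if L_spa returns a per-region map
        loss_spa = self.spa_weight * torch.mean(self.L_spatial(high_img, enhanced))
        # TV on the enhanced image: this architecture returns only the image, not the curve maps A
        loss_tv = self.tv_weight * self.L_tv(enhanced)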
        loss_col = self.col_weight * self.L_color(enhanced, high_img)    # Now supervised
        loss_exp = self.exp_weight * self.L_exp(enhanced)                # Unsupervised exposure prior
        loss_sa = self.sa_weight * self.sa_loss(enhanced)                # Unsupervised saturation prior

        # --- New: Add Supervised Reconstruction Term (Charbonnier) ---
        loss_recon = self.recon_weight * self.charbonnier(enhanced, high_img)  # New weight for tuning

        # Total Hybrid Loss
        return loss_percep + loss_spa + loss_tv + loss_col + loss_exp + loss_sa + loss_recon

# =============================================================================
# TRAINING: Reuse most of the loop from Section 3, but with hybrid loss
# =============================================================================

# Define new hyperparameters for the hybrid loss (Modified for equal contribution)
recon_weight = 1.0 # Reduced from 100.0 to balance with priors
percep_weight = 0.28 # Scaled down from 1.0 (1.0 / 3.52)
spa_weight = 0.28    # Scaled down from 1.0 (1.0 / 3.52)
tv_weight = 0.28     # Scaled down from 1.0 (1.0 / 3.52)
col_weight = 0.14    # Scaled down from 0.5 (0.5 / 3.52)
exp_weight = 0.003   # Scaled down from 0.01 (0.01 / 3.52)
sa_weight = 0.003    # Scaled down from 0.01 (0.01 / 3.52)

# Adjust training hyperparameters for this model to match paper
LR = 0.0001 # CHANGED from 5e-5 to 0.0001
EPOCHS = 200
BATCH_SIZE = 8 # CHANGED from 16 to 8
CROP_SIZE = 256
SHOW_EVERY = 10

print("="*60)
print("STARTING TRAINING: SUPERVISED ZERODCE (HYBRID LOSS) - Balanced Contribution")
print(f"üîß CURRENT WEIGHTS: Recon={recon_weight}, Perceptual={percep_weight}, Spatial={spa_weight}, TV={tv_weight}, Color={col_weight}, Exposure={exp_weight}, SA={sa_weight}")
print(f"üîß TRAINING HYPERPARAMETERS: LR={LR}, EPOCHS={EPOCHS}, BATCH_SIZE={BATCH_SIZE}")
print("="*60)

# Reuse datasets/loaders from Section 3 (defined)
train_dataset = LOLDataset(train=True, crop_size=CROP_SIZE)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=4, pin_memory=True, prefetch_factor=2)

val_dataset = LOLDataset(train=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Instantiate the new hybrid criterion with the defined weights
hybrid_criterion = ZeroDCELoss_Hybrid_Supervised(recon_weight=recon_weight, percep_weight=percep_weight, spa_weight=spa_weight, tv_weight=tv_weight, col_weight=col_weight, exp_weight=exp_weight, sa_weight=sa_weight).to(device)

# Instantiate the new supervised ZeroDCE architecture
supervised_zerodce_model = ZeroDCE_Supervised_Architecture().to(device)

optimizer = torch.optim.AdamW(supervised_zerodce_model.parameters(), lr=LR, weight_decay=1e-5)
scaler = torch.amp.GradScaler('cuda')

# Reusing SSIM for validation metrics
cal_ssim_val = SSIM().to(device)

# Reuse tracking from Section 3
training_stats_supervised = {'epoch': [], 'loss': [], 'psnr': [], 'ssim': []}

for epoch in range(1, EPOCHS+1):
    supervised_zerodce_model.train()
    epoch_loss = 0.0

    for low, high in train_loader:
        low, high = low.to(device), high.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            enhanced = supervised_zerodce_model(low) # Now returns only the enhanced image
            loss = hybrid_criterion(enhanced, high)  # Now uses hybrid loss

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(supervised_zerodce_model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)

    if epoch % SHOW_EVERY == 0:
        supervised_zerodce_model.eval()
        psnr_list, ssim_list = [], []
        with torch.no_grad():
            for v_low, v_high in val_loader:
                v_low, v_high = v_low.to(device), v_high.to(device)
                v_enhanced = supervised_zerodce_model(v_low).clamp(0,1) # Model now returns only enhanced image
                ssim_list.append(cal_ssim_val(v_enhanced, v_high).item())
                v_enhanced_np = v_enhanced.squeeze(0).cpu().numpy().transpose(1,2,0)
                v_high_np = v_high.squeeze(0).cpu().numpy().transpose(1,2,0)
                psnr_list.append(psnr(v_high_np, v_enhanced_np, data_range=1.0))

        avg_psnr = sum(psnr_list) / len(psnr_list)
        avg_ssim = sum(ssim_list) / len(ssim_list)

        training_stats_supervised['epoch'].append(epoch)
        training_stats_supervised['loss'].append(avg_loss)
        training_stats_supervised['psnr'].append(avg_psnr)
        training_stats_supervised['ssim'].append(avg_ssim)

        print(f"Epoch {epoch:2d}/{EPOCHS} | Loss: {avg_loss:.4f} | Test PSNR: {avg_psnr:.2f} | Test SSIM: {avg_ssim:.4f}")

        # Reuse visualization from Section 3
        with torch.no_grad():
            idx = random.randint(0, len(all_low_paths)-1)
            low_img = Image.open(all_low_paths[idx]).convert("RGB")
            high_img = Image.open(all_high_paths[idx]).convert("RGB")
            input_tensor = T.ToTensor()(low_img).unsqueeze(0).to(device)
            enhanced_tensor = supervised_zerodce_model(input_tensor).clamp(0,1).squeeze(0).cpu()
            enhanced_img = T.ToPILImage()(enhanced_tensor)

            plt.figure(figsize=(15,5))
            plt.subplot(1,3,1); plt.title("Input"); plt.imshow(low_img); plt.axis('off')
            plt.subplot(1,3,2); plt.title("Ground Truth"); plt.imshow(high_img); plt.axis('off')
            plt.subplot(1,3,3); plt.title(f"Epoch {epoch} (PSNR {avg_psnr:.2f})"); plt.imshow(enhanced_img); plt.axis('off')
            plt.show()

# Save the model (reuse dict from notebook)
trained_models['ZeroDCE_Supervised_Hybrid'] = supervised_zerodce_model

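# Final summary: metrics at the last validation checkpoint, plus the mean over all checkpoints
if training_stats_supervised['psnr']:
    final_psnr = training_stats_supervised['psnr'][-1]
    final_ssim = training_stats_supervised['ssim'][-1]
    mean_psnr = sum(training_stats_supervised['psnr']) / len(training_stats_supervised['psnr'])
    mean_ssim = sum(training_stats_supervised['ssim']) / len(training_stats_supervised['ssim'])
    print(f"\nFinal-epoch PSNR: {final_psnr:.2f} | SSIM: {final_ssim:.4f}")
    print(f"Mean over validation checkpoints: PSNR {mean_psnr:.2f} | SSIM {mean_ssim:.4f}")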

    # Plotting Training History
    plt.figure(figsize=(18, 5))

    plt.subplot(1, 3, 1)
    plt.plot(training_stats_supervised['epoch'], training_stats_supervised['loss'], label='Loss', color='red')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(training_stats_supervised['epoch'], training_stats_supervised['psnr'], label='PSNR', color='green')
    plt.title('Validation PSNR')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(training_stats_supervised['epoch'], training_stats_supervised['ssim'], label='SSIM', color='blue')
    plt.title('Validation SSIM')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

## Analysis of Hybrid ZeroDCE (Dual Loss: ZeroDCE + Charbonnier)

### Quantitative Metrics (Test Set)
- **Average PSNR**: 17.05  
- **Average SSIM**: 0.5552  

### Training Loss Observations
The loss curve shows steady convergence, starting high and declining smoothly over epochs. This indicates stable optimization and progressive improvement in the model's ability to enhance low-light images.

### Performance Insights
The hybrid model outperforms the pure unsupervised ZeroDCE, demonstrating that incorporating supervised reconstruction (via Charbonnier loss) enhances pixel accuracy and perceptual quality while retaining ZeroDCE's illumination priors.

However, it falls short of the supervised U-Net baseline, likely due to architectural differences: U-Net's encoder-decoder with skip connections excels at precise pixel-level transformations, whereas ZeroDCE's curve-based approach prioritizes global enhancement and naturalness over exact reconstruction. This trade-off makes the hybrid a balanced alternative for scenarios with limited paired data.

# MODEL 2: Hybrid_ZeroDCE_Only_Charbonnier

This model uses the ZeroDCE architecture trained with Charbonnier loss only.
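Since this variant is trained on the Charbonnier loss alone, its behavior is worth looking at. Per element the loss is sqrt(d² + ε²). For errors much larger than ε it is essentially the L1 distance |d|. Near d = 0, however, it is smooth: its gradient d / sqrt(d² + ε²) goes to 0 instead of jumping between ±1 as the L1 gradient does. The sketch below compares the two using the same ε = 1e-3 as the notebook's `CharbonnierLoss`. The function `charbonnier` is a standalone restatement for illustration only.

```python
import torch


def charbonnier(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Mean Charbonnier loss: mean(sqrt((pred - target)^2 + eps^2))."""
    diff = pred - target
    return torch.mean(torch.sqrt(diff * diff + eps * eps))


target = torch.zeros(1)
for d in [0.0, 1e-4, 1e-3, 1e-2, 0.5]:
    pred = torch.full((1,), d, requires_grad=True)
    loss = charbonnier(pred, target)
    loss.backward()                                      # gradient is d / sqrt(d^2 + eps^2)
    l1 = torch.mean(torch.abs(pred.detach() - target))
    print(f"d={d:<7} charbonnier={loss.item():.6f}  L1={l1.item():.6f}  grad={pred.grad.item():.4f}")
```

For large errors the two losses are nearly identical, so the model is penalized roughly linearly (robust to outliers). For tiny residuals the gradient shrinks smoothly, which tends to make late-stage optimization less jittery than with plain L1.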

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from piqa import SSIM
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random

# =============================================================================
# 📌 4. Model 4: ZeroDCE Architecture with only Charbonnier Loss (Pure Supervised)
# =============================================================================

# Reuse ZeroDCE_Supervised_Architecture (already defined)
# class ZeroDCE_Supervised_Architecture(ZeroDCE_Unsupervised_Model):
#     def __init__(self):
#         super(ZeroDCE_Supervised_Architecture, self).__init__()
#
#     def forward(self, x):
#         _, enhanced_image, _ = super().forward(x)
#         return enhanced_image

# Reuse CharbonnierLoss (already defined)
# class CharbonnierLoss(nn.Module):
#     def __init__(self, eps=1e-3):
#         super(CharbonnierLoss, self).__init__()
#         self.eps = eps
#
#     def forward(self, x, y):
#         diff = x - y
#         return torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))

# Define the hyperparameters to match paper
LR = 0.0001 # CHANGED from 5e-5 to 0.0001
EPOCHS = 200
BATCH_SIZE = 8 # CHANGED from 16 to 8
CROP_SIZE = 256
SHOW_EVERY = 10

print("="*60)
print("STARTING TRAINING: MODEL 4 (ZeroDCE Arch + Pure Charbonnier Loss)")
print(f"üîß TRAINING HYPERPARAMETERS: LR={LR}, EPOCHS={EPOCHS}, BATCH_SIZE={BATCH_SIZE}")
print("="*60)

# Reuse datasets/loaders
train_dataset = LOLDataset(train=True, crop_size=CROP_SIZE)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=4, pin_memory=True, prefetch_factor=2)

val_dataset = LOLDataset(train=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Instantiate the supervised ZeroDCE architecture
supervised_zerodce_model_pure_charbonnier = ZeroDCE_Supervised_Architecture().to(device)

# The criterion is purely Charbonnier Loss
criterion_pure_charbonnier = CharbonnierLoss().to(device)

optimizer = torch.optim.AdamW(supervised_zerodce_model_pure_charbonnier.parameters(), lr=LR, weight_decay=1e-5)
scaler = torch.amp.GradScaler('cuda')

# Reusing SSIM for validation metrics
cal_ssim_val = SSIM().to(device)

# Reuse tracking from Section 3
training_stats_pure_charbonnier = {'epoch': [], 'loss': [], 'psnr': [], 'ssim': []}

for epoch in range(1, EPOCHS+1):
    supervised_zerodce_model_pure_charbonnier.train()
    epoch_loss = 0.0

    for low, high in train_loader:
        low, high = low.to(device), high.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            enhanced = supervised_zerodce_model_pure_charbonnier(low)
            loss = criterion_pure_charbonnier(enhanced, high)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(supervised_zerodce_model_pure_charbonnier.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)

    if epoch % SHOW_EVERY == 0:
        supervised_zerodce_model_pure_charbonnier.eval()
        psnr_list, ssim_list = [], []
        with torch.no_grad():
            for v_low, v_high in val_loader:
                v_low, v_high = v_low.to(device), v_high.to(device)
                v_enhanced = supervised_zerodce_model_pure_charbonnier(v_low).clamp(0,1)
                ssim_list.append(cal_ssim_val(v_enhanced, v_high).item())
                v_enhanced_np = v_enhanced.squeeze(0).cpu().numpy().transpose(1,2,0)
                v_high_np = v_high.squeeze(0).cpu().numpy().transpose(1,2,0)
                psnr_list.append(psnr(v_high_np, v_enhanced_np, data_range=1.0))

        avg_psnr = sum(psnr_list) / len(psnr_list)
        avg_ssim = sum(ssim_list) / len(ssim_list)

        training_stats_pure_charbonnier['epoch'].append(epoch)
        training_stats_pure_charbonnier['loss'].append(avg_loss)
        training_stats_pure_charbonnier['psnr'].append(avg_psnr)
        training_stats_pure_charbonnier['ssim'].append(avg_ssim)

        print(f"Epoch {epoch:2d}/{EPOCHS} | Loss: {avg_loss:.4f} | Test PSNR: {avg_psnr:.2f} | Test SSIM: {avg_ssim:.4f}")

        # Reuse visualization from Section 3
        with torch.no_grad():
            rand_idx = random.randint(0, len(val_dataset) - 1)
            low_img_sample, high_img_sample = val_dataset[rand_idx]
            input_tensor_sample = low_img_sample.unsqueeze(0).to(device)

            enhanced_tensor = supervised_zerodce_model_pure_charbonnier(input_tensor_sample).clamp(0,1).squeeze(0).cpu()
            enhanced_img = T.ToPILImage()(enhanced_tensor)

            plt.figure(figsize=(15,5))
            plt.subplot(1,3,1); plt.title("Input"); plt.imshow(T.ToPILImage()(low_img_sample)); plt.axis('off')
            plt.subplot(1,3,2); plt.title("Ground Truth"); plt.imshow(T.ToPILImage()(high_img_sample)); plt.axis('off')
            plt.subplot(1,3,3); plt.title(f"Epoch {epoch} (PSNR {avg_psnr:.2f})"); plt.imshow(enhanced_img); plt.axis('off')
            plt.show()

# Keep the trained model in memory for the Section 5 comparison (this does not write it to disk)
trained_models['ZeroDCE_Supervised_Pure_Charbonnier'] = supervised_zerodce_model_pure_charbonnier

# Final Summary and Plots
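if training_stats_pure_charbonnier['psnr']:
    # Mean over all logged validation checkpoints (every SHOW_EVERY epochs), not just the last epoch
    final_avg_psnr = sum(training_stats_pure_charbonnier['psnr']) / len(training_stats_pure_charbonnier['psnr'])
    final_avg_ssim = sum(training_stats_pure_charbonnier['ssim']) / len(training_stats_pure_charbonnier['ssim'])
    print(f"\nAverage validation PSNR over logged epochs: {final_avg_psnr:.2f}")
    print(f"Average validation SSIM over logged epochs: {final_avg_ssim:.4f}")
    print(f"Last logged epoch: PSNR {training_stats_pure_charbonnier['psnr'][-1]:.2f} | SSIM {training_stats_pure_charbonnier['ssim'][-1]:.4f}")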

    # Plotting Training History
    plt.figure(figsize=(18, 5))

    plt.subplot(1, 3, 1)
    plt.plot(training_stats_pure_charbonnier['epoch'], training_stats_pure_charbonnier['loss'], label='Loss', color='red')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(training_stats_pure_charbonnier['epoch'], training_stats_pure_charbonnier['psnr'], label='PSNR', color='green')
    plt.title('Validation PSNR')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(training_stats_pure_charbonnier['epoch'], training_stats_pure_charbonnier['ssim'], label='SSIM', color='blue')
    plt.title('Validation SSIM')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

## Analysis of Hybrid ZeroDCE Variant 2: Supervised (Charbonnier Loss Only)

### Quantitative Metrics (Test Set)
- **Average PSNR**: 16.74  
- **Average SSIM**: 0.4993  
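The SSIM values in this notebook come from its `SSIM` module (most likely piqa's, since the setup cell installs piqa), which averages local scores from a sliding window. To show what the number measures, here is a simplified sketch that computes a single *global* SSIM over flattened pixel values, using the standard constants C1 = (0.01·L)² and C2 = (0.03·L)² for data range L. The function name is hypothetical, and this is not the notebook's metric: it will generally give different values from the windowed version.

```python
def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM between two equal-length lists of pixel values.

    Simplification: windowed SSIM implementations average local scores from a
    sliding (e.g. Gaussian) window; this version computes one global score.
    """
    if len(x) != len(y) or not x:
        raise ValueError("x and y must be non-empty and the same length")
    n = len(x)
    c1 = (0.01 * data_range) ** 2  # stabilizes the luminance term
    c2 = (0.03 * data_range) ** 2  # stabilizes the contrast/structure term
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


# Usage: a darkened copy scores lower than a perfect match
reference = [0.1, 0.4, 0.8, 0.3, 0.6]
print(global_ssim(reference, reference))  # approximately 1.0
print(global_ssim(reference, [0.3 * v for v in reference]))  # noticeably lower
```

The darkened copy keeps the same structure (it is perfectly correlated with the reference) but is penalized through the luminance and contrast terms, which is why under-brightened outputs lose SSIM even when edges are preserved.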

### Training Loss Observations
Similar to Variant 1, the loss curve generally decreases over epochs, indicating effective learning. However, PSNR and SSIM on the validation set exhibit more fluctuations compared to the dual-loss hybrid, suggesting reduced training stability without the original ZeroDCE priors.
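"More fluctuations" can be put into a number. One simple option, sketched below as a hypothetical helper that the notebook does not define, is the population standard deviation of the step-to-step changes in a logged metric: a steady trend scores near zero, while a curve that jumps up and down scores higher.

```python
from statistics import pstdev


def epoch_to_epoch_volatility(values):
    """Population standard deviation of successive differences in a metric series.

    Hypothetical helper: a steady trend gives a value near 0, while a metric
    that jumps up and down between checkpoints gives a larger value.
    """
    if len(values) < 3:
        raise ValueError("need at least three logged values")
    deltas = [later - earlier for earlier, later in zip(values, values[1:])]
    return pstdev(deltas)


# Usage with made-up PSNR curves (illustrative numbers, not notebook results)
smooth = [15.0, 15.5, 16.0, 16.5]
jagged = [15.0, 16.5, 15.5, 17.0]
print(epoch_to_epoch_volatility(smooth))  # 0.0: every step is +0.5 dB
print(epoch_to_epoch_volatility(jagged))  # larger: steps of +1.5, -1.0, +1.5 dB
```

In the notebook, this could be applied to `training_stats_pure_charbonnier['psnr']` and to the corresponding history of Variant 1, provided at least three checkpoints were logged.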

### Performance Insights
This variant outperforms unsupervised ZeroDCE (PSNR ~14.6–14.7, SSIM ~0.47–0.49), showing that supervised reconstruction guidance (Charbonnier loss) improves accuracy even when the ZeroDCE architecture is kept unchanged.

However, it underperforms both the supervised U-Net baseline and the dual-loss hybrid (Variant 1: PSNR 17.05 / SSIM 0.5552). The larger metric fluctuations suggest that removing the ZeroDCE non-reference losses also removes useful regularization, leading to less consistent convergence. This points to the value of hybrid objectives: the original priors appear to act as stabilizers, giving smoother optimization and slightly better overall results by balancing perceptual naturalness with pixel-level fidelity.
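The difference between the two hybrid variants comes down to the objective. The sketch below illustrates it in plain Python on flattened pixel lists: the Charbonnier penalty sqrt((p − t)² + ε²) behaves like L1 for large errors but stays smooth near zero. The value ε = 1e-3, the weight `lambda_sup`, and the function names are assumptions for illustration; the notebook's own `criterion_pure_charbonnier` and its dual-loss weighting are not shown in this section.

```python
import math


def charbonnier(pred, target, eps=1e-3):
    """Mean Charbonnier penalty sqrt((p - t)^2 + eps^2), a smooth variant of L1."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    return sum(math.sqrt((p - t) ** 2 + eps ** 2) for p, t in zip(pred, target)) / len(pred)


def dual_loss(pred, target, zerodce_loss, lambda_sup=1.0):
    """Variant 1 style: precomputed non-reference ZeroDCE loss + weighted Charbonnier term.

    `lambda_sup` is a hypothetical weight, not the notebook's actual setting.
    """
    return zerodce_loss + lambda_sup * charbonnier(pred, target)


def charbonnier_only_loss(pred, target):
    """Variant 2 style: reconstruction term only, with no ZeroDCE priors."""
    return charbonnier(pred, target)


# Usage on a toy 3-pixel "image"
pred, target = [0.2, 0.5, 0.9], [0.25, 0.5, 0.8]
print(charbonnier_only_loss(pred, target))
print(dual_loss(pred, target, zerodce_loss=0.3, lambda_sup=2.0))
```

In Variant 2 the gradient comes only from the reconstruction term, so nothing in the objective directly rewards well-exposed or color-balanced outputs; in Variant 1 the non-reference term keeps contributing even when the pixel error is already small.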

# 📌 5. Comparison of the Four Models

This section evaluates the performance of the four trained models on the test set:
1. U-Net with Charbonnier loss (supervised)
2. ZeroDCE with original loss (unsupervised)
3. Hybrid ZeroDCE with dual loss (ZeroDCE + Charbonnier)
4. Hybrid ZeroDCE with Charbonnier loss only (supervised)

The code below reports PSNR/SSIM for these models from the full evaluation DataFrame `df` (produced by an earlier evaluation cell) and generates a visual comparison on a random test image.
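For reference, PSNR as used in the training loops (`psnr(..., data_range=1.0)`) is 10·log10(L²/MSE) for data range L. The sketch below restates that formula in plain Python; it is named `psnr_sketch` so that it does not shadow the notebook's `psnr` function if pasted into the same session.

```python
import math


def psnr_sketch(pred, target, data_range=1.0):
    """PSNR in dB: 10 * log10(data_range^2 / MSE); infinite for a perfect match."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf
    return 10 * math.log10(data_range ** 2 / mse)


# Usage: a uniform error of 0.1 on [0, 1] data gives MSE = 0.01, i.e. 20 dB
target = [0.2, 0.5, 0.7]
pred = [v + 0.1 for v in target]
print(round(psnr_sketch(pred, target), 2))  # 20.0
```

Because the scale is logarithmic, every 10 dB corresponds to a 10x change in MSE. The roughly 5 dB gap between the U-Net (~19.5–20.2 dB) and unsupervised ZeroDCE (~14.6–14.7 dB) therefore means the U-Net's mean squared error is about three times smaller.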

In [None]:
# Rename Models 3 and 4 (the two hybrid ZeroDCE variants) for clarity in the comparison
if 'ZeroDCE_Supervised_Hybrid' in trained_models:
    trained_models['Hybrid_ZeroDCE_Two_Losses'] = trained_models.pop('ZeroDCE_Supervised_Hybrid')
    print("Renamed 'ZeroDCE_Supervised_Hybrid' to 'Hybrid_ZeroDCE_Two_Losses'")

if 'ZeroDCE_Supervised_Pure_Charbonnier' in trained_models:
    trained_models['Hybrid_ZeroDCE_Only_Charbonnier'] = trained_models.pop('ZeroDCE_Supervised_Pure_Charbonnier')
    print("Renamed 'ZeroDCE_Supervised_Pure_Charbonnier' to 'Hybrid_ZeroDCE_Only_Charbonnier'")

# Make sure the supervised U-Net is in trained_models; its state dict was saved to disk separately.
# This fallback should only run if the supervised U-Net section was not executed in this session.
if 'UNet_Charbonnier' not in trained_models:
    charbonnier_model = UNet(depth=MODEL_DEPTH).to(device)
    charbonnier_model_path = "zero_dce_charbonnier_pro.pth"
    if os.path.exists(charbonnier_model_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
        charbonnier_model.load_state_dict(torch.load(charbonnier_model_path, map_location=device))
        charbonnier_model.eval()
        trained_models['UNet_Charbonnier'] = charbonnier_model
        print("Ensured UNet_Charbonnier is in trained_models.")
    else:
        print(f"Warning: {charbonnier_model_path} not found. UNet_Charbonnier might be missing from comparison.")

print("Current models in trained_models for comparison:", list(trained_models.keys()))


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import random
import torch
import torchvision.transforms as T
from IPython.display import display  # explicit import; Colab/Jupyter also expose it as a builtin

# Assuming `df` is the DataFrame from the previous full evaluation cell
# Define the list of models to include in the final comparison
key_models_names = [
    'UNet_Charbonnier',
    'ZeroDCE_Unsupervised',
    'Hybrid_ZeroDCE_Two_Losses',
    'Hybrid_ZeroDCE_Only_Charbonnier'
]

# Filter the DataFrame to the key models it actually contains
# (if 'df' is missing, re-run the full evaluation cell first)
if 'df' in globals():
    present = [name for name in key_models_names if name in df.index]
    missing = [name for name in key_models_names if name not in df.index]
    if missing:
        print(f"Warning: not found in 'df': {missing}")
    final_comparison_df = df.loc[present]
    print("Final Evaluation Results of Key Models:")
    display(final_comparison_df)

    if not final_comparison_df.empty:
        # Determine the best models based on the filtered results
        best_psnr_key = final_comparison_df['Avg PSNR'].idxmax()
        best_ssim_key = final_comparison_df['Avg SSIM'].idxmax()
        print(f"Best by PSNR: {best_psnr_key}")
        print(f"Best by SSIM: {best_ssim_key}")
else:
    print("Warning: 'df' DataFrame not found. Please ensure the full comparison cell has been run.")

# --- Visual Comparison of Key Models ---
print("\n" + "="*60)
print("VISUAL COMPARISON OF KEY MODELS")
print("="*60)

# Ensure 'trained_models' global dictionary is available
if 'trained_models' not in globals():
    print("Error: 'trained_models' dictionary not found. Please ensure all models are trained and saved.")
else:
    # Filter trained_models to include only the key models
    key_trained_models = {name: trained_models[name] for name in key_models_names if name in trained_models}

    if not key_trained_models:
        print("Error: None of the key models were found in 'trained_models' for visual comparison.")
    else:
        test_low_paths_eval15 = sorted((Path("/content/lol_dataset/eval15/low")).glob("*.png"))
        test_high_paths_eval15 = sorted((Path("/content/lol_dataset/eval15/high")).glob("*.png"))

        if len(test_low_paths_eval15) == 0:
            print("⚠️ Error: No images found in /content/lol_dataset/eval15/low.")
        else:
            idx = random.randint(0, len(test_low_paths_eval15) - 1)
            print(f"Testing on image index: {idx} / {len(test_low_paths_eval15)-1}")

            low_img = Image.open(test_low_paths_eval15[idx]).convert("RGB")
            high_img = Image.open(test_high_paths_eval15[idx]).convert("RGB")
            input_tensor = T.ToTensor()(low_img).unsqueeze(0).to(device)

            plt.figure(figsize=(20, 5))
            total_plots = len(key_trained_models) + 2 # Low-light Input, Ground Truth, and each key model

            plt.subplot(1, total_plots, 1); plt.title("Low-light Input"); plt.imshow(low_img); plt.axis('off')
            plt.subplot(1, total_plots, 2); plt.title("Ground Truth"); plt.imshow(high_img); plt.axis('off')

            plot_idx = 3
            # Sort models alphabetically by name for consistent plotting order
            sorted_key_trained_models = dict(sorted(key_trained_models.items()))

            for name, model in sorted_key_trained_models.items():
                model.eval()
                with torch.no_grad():
                    output = model(input_tensor)
                    if isinstance(output, tuple):
                        # Unsupervised ZeroDCE returns a tuple; the middle element is used as the enhanced image
                        _, enhanced_out, _ = output
                        enhanced = enhanced_out.clamp(0, 1).squeeze(0).cpu()
                    else:
                        enhanced = output.clamp(0, 1).squeeze(0).cpu()

                    enhanced_img = T.ToPILImage()(enhanced)
                    plt.subplot(1, total_plots, plot_idx)
                    plt.title(f"Enhanced ({name})")
                    plt.imshow(enhanced_img)
                    plt.axis('off')
                    plot_idx += 1

            plt.suptitle("Visual Comparison on a Test Image (eval15 Dataset)")
            plt.tight_layout()
            plt.show()

# 📌 6. Final Analysis of Model Performance

Based on quantitative metrics (PSNR and SSIM), training stability, and qualitative visual inspections, the four models exhibit distinct performance characteristics:

1. **U-Net + Charbonnier (Supervised)**:  
   Consistently achieves the highest scores (PSNR ~19.5–20.2, SSIM ~0.76–0.78). Outputs closely match the ground truth in brightness, contrast, color fidelity, and detail recovery.  
   *Reason*: U-Net's encoder-decoder with skip connections excels at pixel-level reconstruction, optimized directly via supervised loss.

2. **ZeroDCE (Unsupervised)**:  
   Yields the lowest scores (PSNR ~14.6–14.7, SSIM ~0.47–0.49). Enhances visibility and contrast but introduces color shifts, noise, and artifacts.  
   *Reason*: Relies solely on non-reference priors without ground-truth guidance, limiting pixel accuracy despite perceptual improvements.

3. **Hybrid ZeroDCE (Dual Loss: ZeroDCE + Charbonnier)**:  
   Substantially improves over unsupervised ZeroDCE (PSNR ~17.0–17.3, SSIM ~0.55–0.57) and produces natural, bright outputs with fewer artifacts and stable training.  
   *Reason*: Supervised reconstruction enhances precision, while original priors provide regularization for better perceptual quality and convergence.

4. **Hybrid ZeroDCE (Charbonnier Only)**:  
   Performs slightly worse than the dual-loss hybrid (PSNR ~16.7–17.0, SSIM ~0.49–0.51), and its larger metric fluctuations indicate reduced stability. Outputs look similar but are less consistent in color and texture.  
   *Reason*: Training is fully supervised but lacks the ZeroDCE priors, which appear to act as stabilizers; this highlights the value of hybrid regularization to avoid over-focusing on pixel errors.
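Items 2 and 4 both turn on the ZeroDCE "non-reference priors": losses that score an output without any ground truth. The original Zero-DCE paper (Guo et al., CVPR 2020) trains with four of them: spatial consistency, exposure control, color constancy, and illumination smoothness. The sketch below implements two of them in plain Python on a list-of-rows RGB image, following the paper's formulas (well-exposedness level E = 0.6, 16×16 patches). It is an illustration only; the function names are hypothetical, and the notebook's own ZeroDCE loss code, which is not shown in this section, may compute or weight these terms differently.

```python
def exposure_control_loss(image, patch=16, well_exposed=0.6):
    """Mean |patch brightness - E| over non-overlapping patch x patch regions.

    `image` is a list of rows, each row a list of (r, g, b) floats in [0, 1];
    brightness is the channel mean, and E = 0.6 follows the Zero-DCE paper.
    Partial patches at the right/bottom edges are ignored.
    """
    height, width = len(image), len(image[0])
    gaps = []
    for top in range(0, height - patch + 1, patch):
        for left in range(0, width - patch + 1, patch):
            total = sum(sum(image[i][j]) / 3
                        for i in range(top, top + patch)
                        for j in range(left, left + patch))
            gaps.append(abs(total / patch ** 2 - well_exposed))
    if not gaps:
        raise ValueError("image is smaller than one patch")
    return sum(gaps) / len(gaps)


def color_constancy_loss(image):
    """Sum over channel pairs (R,G), (R,B), (G,B) of squared mean differences (paper form)."""
    pixels = [px for row in image for px in row]
    n = len(pixels)
    mean_r, mean_g, mean_b = (sum(px[c] for px in pixels) / n for c in range(3))
    return (mean_r - mean_g) ** 2 + (mean_r - mean_b) ** 2 + (mean_g - mean_b) ** 2


# Usage: a dark gray image is penalized for exposure, a reddish one for color cast
dark_gray = [[(0.1, 0.1, 0.1)] * 16 for _ in range(16)]
reddish = [[(0.5, 0.2, 0.2)] * 16 for _ in range(16)]
print(exposure_control_loss(dark_gray))  # about 0.5 (0.6 - 0.1)
print(color_constancy_loss(dark_gray))   # 0.0: channel means are equal
print(color_constancy_loss(reddish))     # about 0.18
```

Neither loss looks at the ground truth. They push outputs toward "well-exposed and gray on average", which brightens images but cannot correct a color cast that the real scene does not share. This is consistent with the color shifts observed for unsupervised ZeroDCE.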

### Overall Conclusion
Supervised approaches dominate in this low-light enhancement comparison, with U-Net + Charbonnier setting the benchmark for accuracy. ZeroDCE can be trained without paired data but benefits greatly from hybridization: among the ZeroDCE variants, the dual-loss model offers the best balance of metrics, stability, and naturalness. These findings show that both loss design (adding supervised guidance while keeping the non-reference priors) and architecture choice (a U-Net encoder-decoder versus ZeroDCE's curve-estimation network) strongly affect the results.