# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl
from torchvision import models
import os

torch.set_float32_matmul_precision('medium')


# Data Module

## Custom Dataset

In [35]:
class LoLDataset(Dataset):
    def __init__(self, dark_dir, bright_dir, transform):
        self.dark_dir = dark_dir
        self.bright_dir = bright_dir
        self.transform = transform
        
        # Read and sort both folders independently
        self.dark_images = sorted([f for f in os.listdir(dark_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
        self.bright_images = sorted([f for f in os.listdir(bright_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
        
        # Quick safety print if there's a mismatch in folder sizes
        if len(self.dark_images) != len(self.bright_images):
            print(f"WARNING: Found {len(self.dark_images)} dark but {len(self.bright_images)} bright in {dark_dir.split('/')[-2]}")

    def __len__(self): 
        # Safely return the smaller count so it never asks for an index that doesn't exist
        return min(len(self.dark_images), len(self.bright_images))

    def __getitem__(self, idx):
        # Uses the specific filename found in each respective folder
        dark_path = os.path.join(self.dark_dir, self.dark_images[idx])
        bright_path = os.path.join(self.bright_dir, self.bright_images[idx])
        
        dark_img = Image.open(dark_path).convert("RGB")
        bright_img = Image.open(bright_path).convert("RGB")
        
        return self.transform(dark_img), self.transform(bright_img)

## Lightning DataModule

In [36]:
class LoLDataModule(pl.LightningDataModule):
    def __init__(self, base_path, batch_size=8):
        super().__init__()
        # Pointing to the specific subfolders from your image
        self.train_low = os.path.join(base_path, "Real_captured/Train/Low")
        self.train_high = os.path.join(base_path, "Real_captured/Train/Normal")
        self.val_low = os.path.join(base_path, "Real_captured/Test/Low")
        self.val_high = os.path.join(base_path, "Real_captured/Test/Normal")
        
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def setup(self, stage=None):
        # Training set 
        self.train_ds = LoLDataset(self.train_low, self.train_high, self.transform)
        # Validation set
        self.val_ds = LoLDataset(self.val_low, self.val_high, self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4, persistent_workers=True, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, num_workers=4, persistent_workers=True, pin_memory=True)

In [37]:
# --- TEST ---
def test_datamodule():
    print("Testing DataModule...")
    dm = LoLDataModule(base_path="/home/sanjeet/ai_workspace/Image Enhancement/lol v2 dataset") 
    dm.setup()
    
    # Check Training Set
    train_loader = dm.train_dataloader()
    dark_batch, bright_batch = next(iter(train_loader))
    
    print(f"Total training pairs found: {len(dm.train_ds)}")
    print(f"Total validation pairs found: {len(dm.val_ds)}")
    print(f"Batch Shape (Dark): {dark_batch.shape}")
    print(f"Batch Shape (Bright): {bright_batch.shape}")
    
    # Check if pixels are normalized (should be between 0 and 1)
    print(f"Pixel Range: {dark_batch.min():.2f} to {dark_batch.max():.2f}")
    print("DataModule is ready to feed the model.\n")

test_datamodule()

Testing DataModule...
Total training pairs found: 689
Total validation pairs found: 100
Batch Shape (Dark): torch.Size([8, 3, 224, 224])
Batch Shape (Bright): torch.Size([8, 3, 224, 224])
Pixel Range: 0.00 to 0.60
DataModule is ready to feed the model.



# Architecture

In [38]:
class NanoLILY(nn.Module):
    def __init__(self):
        super(NanoLILY, self).__init__()

        # --- BRANCH A: THE ARTIST ---
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.2)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2)
        )
        self.dec1 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.final_spatial = nn.Conv2d(32, 3, kernel_size=3, padding=1)

        # --- BRANCH B: THE SIEVE ---
        # Initialize with SMALL values (0.01) so it doesn't overwhelm the image early on
        self.freq_mask = nn.Parameter(torch.full((1, 3, 224, 113), 0.01))

        # --- THE FUSION ---
        self.fusion = nn.Sequential(
            nn.Conv2d(6, 3, kernel_size=1),
            nn.Tanh() # Constrains the residual to a stable range [-1, 1]
        )

    def forward(self, x):
        identity = x
        
        # Spatial Pass
        s1 = self.enc1(x)
        s2 = self.enc2(s1)
        up1 = self.dec1(s2)
        spatial_out = self.final_spatial(up1 + s1) 

        # Frequency Pass
        x_fft = torch.fft.rfft2(x)
        x_fft_filtered = x_fft * self.freq_mask
        freq_out = torch.fft.irfft2(x_fft_filtered, s=(224, 224))

        # Fusion
        combined = torch.cat([spatial_out, freq_out], dim=1)
        residual = self.fusion(combined)

        return identity + residual

In [39]:
# --- TEST ---
model = NanoLILY()
sample_input = torch.randn(1, 3, 224, 224)
output = model(sample_input)
print(f"Success! Output Shape: {output.shape}")

Success! Output Shape: torch.Size([1, 3, 224, 224])


# The Perceptual Loss

In [40]:
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # We only need the feature extraction layers
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        # Slicing up to layer 18 (ReLU 4_2)
        self.slice = nn.Sequential(*list(vgg.children())[:18]).eval()
        # Freezing parameters so we don't train VGG
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, output, target):
        output_feat = self.slice(output)
        target_feat = self.slice(target)
        return F.mse_loss(output_feat, target_feat)

# The Training System

In [41]:
class NanoLILYSystem(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def forward(self, x):
        return self.model(x)

# Evaluation

In [None]:
# --- 1. DEFINE THE MISSING LOSS FUNCTIONS ---
def compute_charbonnier_loss(output, target, eps=1e-3):
    diff = output - target
    return torch.mean(torch.sqrt(diff * diff + eps * eps))

def compute_color_loss(output, target):
    return torch.mean(1.0 - F.cosine_similarity(output, target, dim=1))

def compute_tv_loss(x):
    tv_h = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]))
    tv_w = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]))
    return tv_h + tv_w

# --- 2. SETUP DATA AND METRICS ---
dm = LoLDataModule(base_path="../dataset/lol_v2", batch_size=16)
dm.setup()
val_loader = dm.val_dataloader()

psnr_metric = PeakSignalNoiseRatio(data_range=1.0).cuda()
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
vgg_loss_fn = VGGPerceptualLoss().cuda() 

# --- 3. LOAD MODEL ---
model_128k = NanoLILY() 


ckpt_path_128k = "../model/NanoLILY_best_weights.ckpt" 
checkpoint = torch.load(ckpt_path_128k, map_location="cpu")

# Handle standard checkpoint dictionary or raw weights
if 'state_dict' in checkpoint:
    state_dict = {k.replace('model.', ''): v for k, v in checkpoint['state_dict'].items()}
    model_128k.load_state_dict(state_dict, strict=False)
else:
    model_128k.load_state_dict(checkpoint)

model_128k.cuda().eval()

# --- 4. EVALUATE ---
total_psnr, total_ssim, total_vgg, total_fft = 0.0, 0.0, 0.0, 0.0
total_char, total_color, total_tv = 0.0, 0.0, 0.0
batches = 0

print("Extracting full metrics for 128k model...")

with torch.no_grad():
    for dark, bright in val_loader:
        dark, bright = dark.cuda(), bright.cuda()
        
        # Forward pass
        output = model_128k(dark)
        
        # Calculate standard metrics
        total_psnr += psnr_metric(output, bright).item()
        
        ssim_val = ssim_metric(output, bright).item()
        total_ssim += ssim_val
        
        total_vgg += vgg_loss_fn(output, bright).item()
        
        # Calculate FFT
        out_fft = torch.fft.rfft2(output.float(), norm="ortho")
        tar_fft = torch.fft.rfft2(bright.float(), norm="ortho")
        total_fft += F.l1_loss(torch.abs(out_fft), torch.abs(tar_fft)).item()
        
        # Calculate the missing losses
        total_char += compute_charbonnier_loss(output, bright).item()
        total_color += compute_color_loss(output, bright).item()
        total_tv += compute_tv_loss(output).item()
        
        batches += 1

print("\n" + "="*40)
print("üèÜ 128k MODEL DIAGNOSTIC METRICS üèÜ")
print("="*40)
print(f"       val_psnr: {total_psnr / batches:.5f}")
print(f"       val_ssim: {total_ssim / batches:.5f}")
print(f"  val_char_loss: {total_char / batches:.5f}")
print(f"   val_vgg_loss: {total_vgg / batches:.5f}")
print(f"   val_fft_loss: {total_fft / batches:.5f}")
print(f"    val_tv_loss: {total_tv / batches:.5f}")
print(f" val_color_loss: {total_color / batches:.5f}")
print(f"  val_ssim_loss: {1.0 - (total_ssim / batches):.5f}")
print("="*40)

Extracting full metrics for 128k model...


  checkpoint = torch.load(ckpt_path_128k, map_location="cpu")



üèÜ 128k MODEL DIAGNOSTIC METRICS üèÜ
       val_psnr: 19.19457
       val_ssim: 0.82591
  val_char_loss: 0.09114
   val_vgg_loss: 0.66296
   val_fft_loss: 0.01738
    val_tv_loss: 0.05561
 val_color_loss: 0.01628
  val_ssim_loss: 0.17409
