In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19

import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import time
from tqdm import tqdm
import zipfile
import gdown
from collections import defaultdict

# Metrics
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
import pandas as pd

# Set device: the first CUDA GPU when one is visible, otherwise the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')
print(f"üöÄ Using device: {device}")
if _cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   GPU: {gpu_name}")
    print(f"   Memory: {gpu_mem_gb:.2f} GB")


In [None]:
# 2. DOWNLOAD LOL-V2 DATASET
# ============================================================================

print("\n" + "="*70)
print("üì• DOWNLOADING LOL-V2 DATASET")
print("="*70)

# Everything lives under ./data; the archive itself is removed after extraction.
data_dir = Path('./data')
data_dir.mkdir(exist_ok=True)

# Google Drive file id of the archive (contains both Real and Synthetic subsets)
file_id = '1dzuLCk9_gE2bFF222n3-7GVUlSVHpMYC'
zip_path = data_dir / 'LOL-v2.zip'
extract_path = data_dir / 'LOL-v2'

if extract_path.exists():
    print("‚úÖ Dataset already exists!")
else:
    print("üì• Downloading from Google Drive...")
    gdown.download(f'https://drive.google.com/uc?id={file_id}', str(zip_path),
                   quiet=False, fuzzy=True)

    print("üì¶ Extracting dataset...")
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(data_dir)

    zip_path.unlink()
    print("‚úÖ Dataset ready!")

# Find Real and Synthetic folders inside the extracted archive.
def _find_dataset_dir(root, matches_name):
    """Return the shallowest directory under *root* whose lower-cased name
    satisfies *matches_name*, or None when there is no such directory.

    rglob order is not guaranteed, so picking "the last match seen" (as a plain
    loop would) is nondeterministic and lets a nested folder with a similar
    name win. Ties on depth are broken by path string for determinism.
    """
    candidates = [p for p in Path(root).rglob('*')
                  if p.is_dir() and matches_name(p.name.lower())]
    if not candidates:
        return None
    return min(candidates, key=lambda p: (len(p.parts), str(p)))


real_path = _find_dataset_dir(
    extract_path, lambda name: 'real' in name and 'captured' in name)
synthetic_path = _find_dataset_dir(
    extract_path, lambda name: name == 'synthetic')

print(f"\nüìÅ Dataset structure:")
print(f"   Real: {real_path}")
print(f"   Synthetic: {synthetic_path}")

# Fail fast with a readable message; a None path would otherwise only surface
# later as a TypeError inside the dataset class.
_missing_subsets = [label for label, p in (('Real', real_path), ('Synthetic', synthetic_path))
                    if p is None]
if _missing_subsets:
    raise FileNotFoundError(
        f"Could not find the {', '.join(_missing_subsets)} subset(s) under {extract_path}; "
        "check that the archive was extracted to the expected layout."
    )

In [None]:
# 3. DATASET CLASS & DATA LOADING
# ============================================================================


def extract_id(filename):
    """Return the first run of decimal digits in *filename* as a string.

    Returns None when the name contains no digits at all.
    """
    found = re.search(r'(\d+)', filename)
    if found is None:
        return None
    return found.group(1)


class LOLv2Dataset(Dataset):
    """Paired low-light / normal-light images from one or more LOL-v2 subsets.

    Every entry of ``data_paths`` is a subset root containing ``Train/`` and
    ``Test/`` folders, each with ``Low/`` and ``Normal/`` sub-folders. A low
    image is paired with the normal image that has the *same file stem*. If
    there is none, pairing falls back to the first run of digits in the file
    name (``extract_id``), e.g. ``low00001.png`` <-> ``normal00001.png``.
    Digit ids shared by several normal images are ambiguous and are not used
    for pairing, so names such as ``r000ab.png`` / ``r000cd.png`` cannot be
    silently mismatched on their common ``000`` prefix.

    Args:
        data_paths: iterable of subset roots (str or Path). ``None`` entries
            (e.g. a subset that was not found) are skipped.
        split: 'train' or 'test' (case-insensitive).
        img_size: every image is resized to ``img_size x img_size``.
        verbose: print per-subset pair counts.

    Raises:
        ValueError: if ``split`` is not 'train' or 'test'.

    Items are ``(low, normal)`` float32 tensors of shape
    ``(3, img_size, img_size)``, RGB, scaled to [0, 1].
    """

    # Recognised image extensions (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(
        self,
        data_paths,
        split='train',
        img_size=256,
        verbose=True
    ):
        self.img_size = img_size
        self.split = split.lower()
        self.pairs = []

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.split not in ('train', 'test'):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")

        if verbose:
            print(f"\nüìä LOADING {self.split.upper()} DATASET")

        for data_path in data_paths:
            if data_path is None:
                if verbose:
                    print("   - Skipping a data path that is None (subset not found)")
                continue
            self._add_subset(Path(data_path), verbose)

        if verbose:
            print(f"‚úÖ TOTAL {self.split.upper()} PAIRS: {len(self.pairs)}")

    @classmethod
    def _list_images(cls, folder):
        """Image files directly inside *folder*, sorted so pairing is deterministic."""
        return sorted(p for p in folder.iterdir()
                      if p.is_file() and p.suffix.lower() in cls.IMAGE_EXTENSIONS)

    def _add_subset(self, data_path, verbose):
        """Append every matched (low, normal) path pair of one subset to ``self.pairs``."""
        dataset_name = data_path.name

        split_dir = data_path / ('Train' if self.split == 'train' else 'Test')
        low_dir = split_dir / 'Low'
        normal_dir = split_dir / 'Normal'

        if not low_dir.exists() or not normal_dir.exists():
            if verbose:
                print(f"‚ö†Ô∏è  Skip {dataset_name} ({self.split}) ‚Äì missing Low/Normal")
            return

        low_imgs = self._list_images(low_dir)
        normal_imgs = self._list_images(normal_dir)

        # Primary index: identical stem in Low/ and Normal/.
        normal_by_stem = {p.stem: p for p in normal_imgs}

        # Fallback index: first digit run; ids seen more than once are ambiguous.
        normal_by_id = {}
        ambiguous_ids = set()
        for p in normal_imgs:
            key = extract_id(p.name)
            if key is None:
                continue
            if key in normal_by_id:
                ambiguous_ids.add(key)
            normal_by_id[key] = p

        count_before = len(self.pairs)

        for low_img in low_imgs:
            match = normal_by_stem.get(low_img.stem)
            if match is None:
                key = extract_id(low_img.name)
                if key is not None and key not in ambiguous_ids:
                    match = normal_by_id.get(key)
            if match is not None:
                self.pairs.append((str(low_img), str(match)))

        if verbose:
            print(
                f"   - {dataset_name:<15}: "
                f"{len(self.pairs) - count_before} pairs"
            )

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load, resize and normalise one pair; returns (low, normal) CHW float tensors."""
        low_path, normal_path = self.pairs[idx]

        # Load images (cv2 returns BGR uint8, or None on failure)
        low_img = cv2.imread(low_path)
        normal_img = cv2.imread(normal_path)

        if low_img is None or normal_img is None:
            raise RuntimeError(f"Failed to load image pair:\n{low_path}\n{normal_path}")

        # BGR -> RGB
        low_img = cv2.cvtColor(low_img, cv2.COLOR_BGR2RGB)
        normal_img = cv2.cvtColor(normal_img, cv2.COLOR_BGR2RGB)

        # Resize to a fixed square so samples can be batched
        low_img = cv2.resize(low_img, (self.img_size, self.img_size))
        normal_img = cv2.resize(normal_img, (self.img_size, self.img_size))

        # Scale to [0, 1]
        low_img = low_img.astype(np.float32) / 255.0
        normal_img = normal_img.astype(np.float32) / 255.0

        # HWC -> CHW
        low_img = torch.from_numpy(low_img).permute(2, 0, 1)
        normal_img = torch.from_numpy(normal_img).permute(2, 0, 1)

        return low_img, normal_img

# Create combined dataset (Real + Synthetic)
print("\n" + "="*70)
print("üìä LOADING DATASET: Real + Synthetic Combined")
print("="*70)

IMG_SIZE = 256
DATA_PATHS = [real_path, synthetic_path]

train_dataset = LOLv2Dataset(data_paths=DATA_PATHS, split='train', img_size=IMG_SIZE)
test_dataset = LOLv2Dataset(data_paths=DATA_PATHS, split='test', img_size=IMG_SIZE)

# Fail here with a clear message: an empty split would otherwise surface later
# as a DataLoader ValueError (shuffle=True on zero samples) or an IndexError
# when the comparison figure indexes test samples.
for split_name, split_dataset in (('train', train_dataset), ('test', test_dataset)):
    if len(split_dataset) == 0:
        raise RuntimeError(
            f"No {split_name} image pairs found under {DATA_PATHS}; "
            "check the extracted dataset layout."
        )

print(f"\n‚úÖ Total pairs: {len(train_dataset) + len(test_dataset)}")

In [None]:
# 4. ZERO-DCE MODEL ARCHITECTURE
# ============================================================================

class DCENet(nn.Module):
    """Zero-DCE style curve-estimation network.

    A plain stack of seven 3x3 convolutions (this variant has no skip
    connections) predicts ``3 * num_iterations`` per-pixel curve parameters,
    squashed to [-1, 1] by tanh. The input is then refined ``num_iterations``
    times with the quadratic curve ``x <- x + a * x * (1 - x)``, using one
    3-channel slice of the parameters per step.

    Returns ``(enhanced, A)`` where ``A`` holds all curve parameters.
    """

    def __init__(self, num_iterations=8):
        super().__init__()
        self.num_iterations = num_iterations

        # Attribute names conv1..conv7 (and their creation order) are kept
        # stable: saved checkpoints index the state_dict by these names.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        for layer_no in range(2, 7):  # conv2-conv4 "encoder", conv5-conv6 "decoder"
            setattr(self, f'conv{layer_no}', nn.Conv2d(32, 32, 3, padding=1))
        self.conv7 = nn.Conv2d(32, 3 * num_iterations, 3, padding=1)

        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            features = self.relu(conv(features))
        curves = self.tanh(self.conv7(features))

        # Apply the light-enhancement curve once per predicted parameter slice
        enhanced = x
        for step in range(self.num_iterations):
            alpha = curves[:, 3 * step:3 * step + 3]
            enhanced = enhanced + alpha * enhanced * (1 - enhanced)

        return enhanced, curves

In [None]:
# 5. LOSS FUNCTIONS
# ============================================================================

class CharbonnierLoss(nn.Module):
    """Charbonnier (smooth L1) loss: ``mean(sqrt((pred - target)^2 + epsilon^2))``."""

    def __init__(self, epsilon=1e-3):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, pred, target):
        residual = pred - target
        eps_sq = self.epsilon * self.epsilon
        return (residual * residual + eps_sq).sqrt().mean()

class PerceptualLoss(nn.Module):
    """VGG19 feature-matching (perceptual) loss.

    Returns the sum of MSEs between VGG19 ``features`` activations of ``pred``
    and ``target``, taken after ``features[:4]``, ``[:9]``, ``[:18]`` and
    ``[:27]``. Inputs are RGB in [0, 1]. They are standardised with the
    ImageNet mean/std that torchvision's pretrained VGG weights expect, so the
    loss magnitude differs from feeding raw [0, 1] images. VGG parameters are
    frozen; gradients flow only into ``pred``.
    """

    # End indices of the tapped prefixes of vgg19().features.
    TAPS = (4, 9, 18, 27)

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        try:
            from torchvision.models import VGG19_Weights
            vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        except ImportError:
            # torchvision < 0.13 only offers the legacy `pretrained` flag.
            vgg = vgg19(pretrained=True).features

        # Consecutive, non-overlapping slices are chained in forward(), so each
        # prefix is computed once rather than re-running VGG from layer 0 per tap.
        starts = (0,) + self.TAPS[:-1]
        self.layers = nn.ModuleList(vgg[start:end] for start, end in zip(starts, self.TAPS))

        for param in self.layers.parameters():
            param.requires_grad = False

        # ImageNet statistics used to train torchvision's VGG weights.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.to(device)

    def forward(self, pred, target):
        pred_feat = (pred - self.mean) / self.std
        target_feat = (target - self.mean) / self.std
        loss = 0
        for stage in self.layers:
            pred_feat = stage(pred_feat)
            target_feat = stage(target_feat)
            loss = loss + F.mse_loss(pred_feat, target_feat)
        return loss

class ColorConstancyLoss(nn.Module):
    """Zero-DCE colour-constancy loss.

    With per-image channel means mr, mg, mb:
        D = sqrt(((mr-mg)^2)^2 + ((mr-mb)^2)^2 + ((mg-mb)^2)^2 + eps)
    and the loss is the batch mean of D.

    Args:
        eps: added under the square root. sqrt has an infinite derivative at 0,
            so without it a perfectly grey output (all channel means equal)
            back-propagates NaN gradients. It raises D by at most sqrt(eps)
            (1e-6 for the default).
    """

    def __init__(self, eps=1e-12):
        super(ColorConstancyLoss, self).__init__()
        self.eps = eps

    def forward(self, enhanced):
        mean_rgb = torch.mean(enhanced, dim=(2, 3), keepdim=True)
        mr, mg, mb = mean_rgb[:, 0:1], mean_rgb[:, 1:2], mean_rgb[:, 2:3]

        d_rg = torch.pow(mr - mg, 2)
        d_rb = torch.pow(mr - mb, 2)
        d_gb = torch.pow(mg - mb, 2)

        loss = torch.sqrt(torch.pow(d_rg, 2) + torch.pow(d_rb, 2) + torch.pow(d_gb, 2) + self.eps)
        return torch.mean(loss)

class ExposureControlLoss(nn.Module):
    """Penalise local brightness that deviates from a target exposure level.

    The image is reduced to luma (0.299 R + 0.587 G + 0.114 B), averaged over
    non-overlapping ``patch_size`` x ``patch_size`` patches, and the loss is
    the mean squared distance of those patch means from ``mean_val``.
    """

    def __init__(self, patch_size=16, mean_val=0.6):
        super(ExposureControlLoss, self).__init__()
        self.patch_size = patch_size
        self.mean_val = mean_val
        self.pool = nn.AvgPool2d(patch_size)

    def forward(self, enhanced):
        red, green, blue = enhanced[:, 0], enhanced[:, 1], enhanced[:, 2]
        luma = (0.299 * red + 0.587 * green + 0.114 * blue).unsqueeze(1)
        patch_means = self.pool(luma)
        return (patch_means - self.mean_val).pow(2).mean()

class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        self.charbonnier = CharbonnierLoss()
        self.perceptual = PerceptualLoss()
        self.color = ColorConstancyLoss()
        self.exposure = ExposureControlLoss()
    
    def forward(self, pred, target):
        losses = {}
        
        losses['charbonnier'] = self.charbonnier(pred, target)
        losses['perceptual'] = self.perceptual(pred, target)
        losses['color'] = self.color(pred)
        losses['exposure'] = self.exposure(pred)
        
        total_loss = losses['charbonnier'] + \
                     0.05 * losses['perceptual'] + \
                     0.5 * losses['color'] + \
                     0.1 * losses['exposure']
        
        losses['total'] = total_loss
        return total_loss, losses


# Loss configurations: experiment name -> CombinedLoss `loss_type`
# (currently each name is its own loss_type).
LOSS_CONFIGS = {config: config for config in (
    'l1',
    'charbonnier',
    'charbonnier_perceptual',
    'charbonnier_perceptual_ssim',
    'charbonnier_perceptual_color_exposure',
)}

print(f"\nüéØ Loss Configurations:")
for position, config_name in enumerate(LOSS_CONFIGS, start=1):
    print(f"   {position}. {config_name}")

In [None]:
# 6. METRICS CALCULATION
# ============================================================================

class MetricsCalculator:
    """Per-image PSNR, SSIM and LPIPS between an enhanced image and its reference.

    ``pred``/``target`` are single RGB images shaped (3, H, W) with values in
    [0, 1]; ``calculate_all`` adds the batch dimension that LPIPS needs. Values
    are clamped to [0, 1] so the fixed ``data_range=1.0`` (PSNR/SSIM) and the
    [-1, 1] LPIPS input mapping stay valid even with slight overshoot.

    Args:
        net: LPIPS backbone name passed to ``lpips.LPIPS`` (default 'alex').
    """

    def __init__(self, net='alex'):
        self._device = device
        self.lpips_model = lpips.LPIPS(net=net).to(self._device)
        self.lpips_model.eval()

    @staticmethod
    def _to_hwc_numpy(img):
        """(C, H, W) tensor -> (H, W, C) numpy array, detached and clamped to [0, 1].

        ``detach`` lets this accept tensors that still require grad.
        """
        return img.detach().clamp(0, 1).cpu().numpy().transpose(1, 2, 0)

    def calculate_psnr(self, pred, target):
        return psnr(self._to_hwc_numpy(target), self._to_hwc_numpy(pred), data_range=1.0)

    def calculate_ssim(self, pred, target):
        return ssim(self._to_hwc_numpy(target), self._to_hwc_numpy(pred),
                    data_range=1.0, channel_axis=2)

    def calculate_lpips(self, pred, target):
        """LPIPS for (N, 3, H, W) batches in [0, 1]; returns a Python float."""
        with torch.no_grad():
            pred_norm = pred.clamp(0, 1).to(self._device) * 2 - 1
            target_norm = target.clamp(0, 1).to(self._device) * 2 - 1
            return self.lpips_model(pred_norm, target_norm).item()

    def calculate_all(self, pred, target):
        return {
            'psnr': self.calculate_psnr(pred, target),
            'ssim': self.calculate_ssim(pred, target),
            'lpips': self.calculate_lpips(pred.unsqueeze(0), target.unsqueeze(0))
        }


In [None]:
# 7. TRAINING FUNCTION
# ============================================================================

class Trainer:
    """Training / evaluation loop for one enhancement experiment.

    Each epoch trains on ``train_loader`` and evaluates PSNR/SSIM/LPIPS on
    ``test_loader``, then appends the results to ``self.history``. Checkpoints
    go to ``save_dir`` every ``save_every`` epochs, and to
    ``<experiment_name>_best.pth`` whenever test PSNR improves.

    NOTE(review): the "best" checkpoint is selected on the test split, so its
    test metrics are optimistically biased; a separate validation split would
    avoid that.
    """

    def __init__(self, model, train_loader, test_loader, loss_fn,
                 optimizer, device, save_dir, experiment_name, save_every=20):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device
        self.save_dir = Path(save_dir)
        self.experiment_name = experiment_name
        self.save_every = save_every
        self.metrics_calc = MetricsCalculator()

        self.save_dir.mkdir(exist_ok=True, parents=True)

        self.history = defaultdict(list)
        # -inf so the first evaluated epoch always yields a "best" checkpoint.
        self.best_psnr = float('-inf')

    def _train_step(self, low, normal):
        """One optimisation step; returns (total loss, {term: value}) as floats."""
        low, normal = low.to(self.device), normal.to(self.device)

        self.optimizer.zero_grad()
        enhanced, _ = self.model(low)
        loss, loss_dict = self.loss_fn(enhanced, normal)

        loss.backward()
        self.optimizer.step()

        return loss.item(), {k: v.item() for k, v in loss_dict.items()}

    def train_epoch(self):
        """Run one training epoch.

        Batches that raise a CUDA out-of-memory error are skipped. Returns the
        mean of every loss term (including 'total') over the batches that
        actually completed, or an empty dict if none did.
        """
        self.model.train()
        epoch_losses = defaultdict(float)
        completed = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (low, normal) in enumerate(pbar):
            out_of_memory = False
            try:
                step_total, step_losses = self._train_step(low, normal)
            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                out_of_memory = True

            if out_of_memory:
                # Cleanup runs after the except block so the exception (and the
                # failed step's frames it references) can already be released.
                print(f"\n‚ö†Ô∏è  OOM at batch {batch_idx}. Clearing cache...")
                self.optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()
                continue

            for k, v in step_losses.items():
                epoch_losses[k] += v
            completed += 1

            pbar.set_postfix({'loss': f"{step_total:.4f}"})

        # Skipped batches contribute no loss, so average over completed ones only.
        if completed == 0:
            return {}
        return {k: v / completed for k, v in epoch_losses.items()}

    def evaluate(self):
        """Mean per-image PSNR/SSIM/LPIPS of the model over ``test_loader``."""
        self.model.eval()
        metrics = defaultdict(list)

        with torch.no_grad():
            for low, normal in tqdm(self.test_loader, desc='Evaluating'):
                low, normal = low.to(self.device), normal.to(self.device)

                enhanced, _ = self.model(low)

                for i in range(enhanced.shape[0]):
                    m = self.metrics_calc.calculate_all(enhanced[i], normal[i])
                    for k, v in m.items():
                        metrics[k].append(v)

        avg_metrics = {k: np.mean(v) for k, v in metrics.items()}
        return avg_metrics

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """Save periodic (every ``save_every`` epochs) and/or best checkpoints."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Plain Python floats rather than numpy scalars, so the file also
            # loads with torch.load(..., weights_only=True), the PyTorch >= 2.6 default.
            'metrics': {k: float(v) for k, v in metrics.items()},
            'history': {k: [float(x) for x in vals] for k, vals in self.history.items()}
        }

        if (epoch + 1) % self.save_every == 0:
            path = self.save_dir / f'{self.experiment_name}_epoch_{epoch+1}.pth'
            torch.save(checkpoint, path)
            print(f"üíæ Saved checkpoint: {path}")

        if is_best:
            path = self.save_dir / f'{self.experiment_name}_best.pth'
            torch.save(checkpoint, path)
            print(f"üèÜ Saved best model: {path}")

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs; returns the per-epoch history dict."""
        print(f"\nüöÄ Starting training: {self.experiment_name}")
        print(f"   Epochs: {num_epochs}")

        for epoch in range(num_epochs):
            print(f"\nüìÖ Epoch {epoch+1}/{num_epochs}")

            train_losses = self.train_epoch()
            test_metrics = self.evaluate()
            # NaN when every batch of the epoch was skipped (all OOM).
            train_loss = train_losses.get('total', float('nan'))

            print(f"\nüìä Results:")
            print(f"   Train Loss: {train_loss:.4f}")
            print(f"   Test PSNR: {test_metrics['psnr']:.2f} dB")
            print(f"   Test SSIM: {test_metrics['ssim']:.4f}")
            print(f"   Test LPIPS: {test_metrics['lpips']:.4f}")

            self.history['train_loss'].append(train_loss)
            for k, v in test_metrics.items():
                self.history[f'test_{k}'].append(v)

            is_best = test_metrics['psnr'] > self.best_psnr
            if is_best:
                self.best_psnr = test_metrics['psnr']

            self.save_checkpoint(epoch, test_metrics, is_best)
            torch.cuda.empty_cache()

        print(f"\n‚úÖ Training completed! Best PSNR: {self.best_psnr:.2f} dB")
        return self.history


In [None]:
# 8. EXPERIMENT RUNNER
# ============================================================================

def run_experiment(train_dataset, test_dataset, loss_config,
                   batch_size=16, num_epochs=100, lr=1e-4, num_workers=2):
    """Train and evaluate one DCENet with the loss configuration ``loss_config``.

    Args:
        train_dataset, test_dataset: datasets yielding ``(low, normal)`` tensors.
        loss_config: key of ``LOSS_CONFIGS``; its value is passed to
            ``CombinedLoss(loss_type=...)``.
        batch_size, num_epochs, lr: training hyper-parameters (Adam, betas 0.9/0.999).
        num_workers: DataLoader worker processes.

    Returns:
        ``(history, final_metrics, model)``: the per-epoch history from
        ``Trainer.train``, test metrics of the model *after the last epoch*
        (the best-PSNR checkpoint is saved separately under ./checkpoints),
        and the trained model.
    """
    experiment_name = f"real_synthetic_{loss_config}"
    print(f"\n{'='*70}")
    print(f"üî¨ EXPERIMENT: {experiment_name}")
    print(f"{'='*70}")

    # Constructing a DataLoader allocates no GPU memory, so no OOM handling is
    # needed here (catching RuntimeError would only hide unrelated errors).
    # Pinned host memory only helps host->GPU copies.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    model = DCENet(num_iterations=8).to(device)
    loss_fn = CombinedLoss(loss_type=LOSS_CONFIGS[loss_config])
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    save_dir = './checkpoints'
    trainer = Trainer(model, train_loader, test_loader, loss_fn,
                      optimizer, device, save_dir, experiment_name)

    history = trainer.train(num_epochs)
    final_metrics = trainer.evaluate()

    return history, final_metrics, model

In [None]:
# 9. RUN ALL EXPERIMENTS
# ============================================================================

print("\n" + "üéØ"*35)
print("STARTING EXPERIMENTS: Real + Synthetic Combined")
print("üéØ"*35)

all_histories = {}
all_metrics = {}
all_models = {}

for loss_config in LOSS_CONFIGS.keys():
    try:
        history, metrics, model = run_experiment(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            loss_config=loss_config,
            batch_size=16,
            num_epochs=100,
            lr=1e-4
        )

        exp_name = f"real_synthetic_{loss_config}"
        all_histories[exp_name] = history
        all_metrics[exp_name] = metrics
        all_models[loss_config] = model

        # Save results
        Path('./results').mkdir(exist_ok=True)
        results_file = f'./results/{exp_name}_results.json'
        with open(results_file, 'w') as f:
            json.dump({
                'history': {k: [float(v) for v in vals] for k, vals in history.items()},
                'final_metrics': {k: float(v) for k, v in metrics.items()}
            }, f, indent=2)

        print(f"‚úÖ {exp_name} completed!")

        del model
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"‚ùå Error in {loss_config}: {e}")
        continue

In [None]:
# 10. VISUALIZATION - BEST MODELS COMPARISON
# ============================================================================

print("\n" + "="*70)
print("üì∏ CREATING BEST MODELS COMPARISON")
print("="*70)

Path('./results').mkdir(exist_ok=True)

# Column label per loss configuration; unknown configs fall back to their key.
LOSS_LABELS = {
    'l1': 'L1 Loss',
    'charbonnier': 'Charbonnier',
    'charbonnier_perceptual': 'Char+Percep',
    'charbonnier_perceptual_ssim': 'Char+Percep+SSIM',
    'charbonnier_perceptual_color_exposure': 'Char+Percep+Color+Exp',
}
comparison_configs = list(LOSS_CONFIGS.keys())

# Select 3 fixed test samples for comparison: first, middle and last
sample_indices = [0, len(test_dataset)//2, len(test_dataset)-1]

# Grid: one row per sample; columns = Low input + one per model + Ground truth
n_cols = len(comparison_configs) + 2
fig, axes = plt.subplots(len(sample_indices), n_cols,
                         figsize=(3 * n_cols, 3 * len(sample_indices)), squeeze=False)

col_headers = (['Low Input']
               + [LOSS_LABELS.get(name, name) for name in comparison_configs]
               + ['Ground Truth'])
for col, header in enumerate(col_headers):
    axes[0, col].set_title(header, fontsize=10, fontweight='bold')


def _show_image(ax, img):
    """Draw a (3, H, W) image tensor on *ax*, clipped to [0, 1], without axes."""
    ax.imshow(np.clip(img.detach().cpu().numpy().transpose(1, 2, 0), 0, 1))
    ax.axis('off')


# Load the best-PSNR checkpoint of every configuration
loaded_models = {}
for loss_name in comparison_configs:
    checkpoint_path = f'./checkpoints/real_synthetic_{loss_name}_best.pth'
    try:
        # map_location: GPU-saved checkpoints must also load on CPU-only machines.
        # weights_only=False: checkpoints from earlier runs of this notebook
        # contain numpy scalars, which the weights_only=True default of
        # PyTorch >= 2.6 rejects. Only do this for files you produced yourself.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        best_model = DCENet(num_iterations=8).to(device)
        best_model.load_state_dict(checkpoint['model_state_dict'])
        best_model.eval()

        loaded_models[loss_name] = best_model
        print(f"‚úÖ Loaded {loss_name} model")
    except Exception as e:
        print(f"‚ö†Ô∏è  Could not load {loss_name}: {e}")

# Generate comparisons for the selected samples
with torch.no_grad():
    for row, idx in enumerate(sample_indices):
        low, normal = test_dataset[idx]
        _show_image(axes[row, 0], low)

        low_batch = low.unsqueeze(0).to(device)
        for col, loss_name in enumerate(comparison_configs, 1):
            ax = axes[row, col]
            if loss_name in loaded_models:
                enhanced, _ = loaded_models[loss_name](low_batch)
                _show_image(ax, enhanced[0])
            else:
                ax.text(0.5, 0.5, 'N/A', ha='center', va='center')
                ax.axis('off')

        _show_image(axes[row, n_cols - 1], normal)

plt.suptitle('Best Models Comparison (Same 3 Test Samples)',
             fontsize=14, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig('./results/best_models_comparison.png', dpi=150, bbox_inches='tight')
plt.close(fig)

print("‚úÖ Saved comparison: ./results/best_models_comparison.png")

In [None]:
# 11. TRAINING CURVES & METRICS COMPARISON
# ============================================================================

# Plot training curves
Path('./results').mkdir(exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

curve_keys = ['train_loss', 'test_psnr', 'test_ssim', 'test_lpips']
curve_titles = ['Training Loss', 'Test PSNR (dB)', 'Test SSIM', 'Test LPIPS']

# axes.flat walks the 2x2 grid row by row
for ax, curve_key, title in zip(axes.flat, curve_keys, curve_titles):
    for exp_name, history in all_histories.items():
        if curve_key in history:
            label = exp_name.replace('real_synthetic_', '')
            ax.plot(history[curve_key], label=label, linewidth=2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.set_title(title)
    if ax.get_legend_handles_labels()[0]:  # avoid "no artists" warning when empty
        ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('./results/training_curves.png', dpi=150, bbox_inches='tight')
plt.close(fig)
print("üìà Saved training curves: ./results/training_curves.png")


def _bar_panel(ax, labels, values, ylabel, title, color, value_fmt):
    """Bar chart of one metric per experiment, with each value printed on its bar."""
    positions = list(range(len(labels)))
    ax.bar(positions, values, color=color)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis='y', alpha=0.3)
    for i, v in enumerate(values):
        ax.text(i, v, format(v, value_fmt), ha='center', va='bottom', fontsize=9)


# Metrics comparison bar chart
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

exp_names = list(all_metrics.keys())
loss_names = [k.replace('real_synthetic_', '') for k in exp_names]

# (metric key, y label, title, colour, value format)
bar_specs = [
    ('psnr', 'PSNR (dB)', 'PSNR Comparison', 'skyblue', '.2f'),
    ('ssim', 'SSIM', 'SSIM Comparison', 'lightgreen', '.4f'),
    ('lpips', 'LPIPS', 'LPIPS Comparison (lower is better)', 'salmon', '.4f'),
]
for ax, (metric_key, ylabel, title, color, value_fmt) in zip(axes, bar_specs):
    values = [all_metrics[k][metric_key] for k in exp_names]
    _bar_panel(ax, loss_names, values, ylabel, title, color, value_fmt)

plt.tight_layout()
plt.savefig('./results/metrics_comparison.png', dpi=150, bbox_inches='tight')
plt.close(fig)
print("üìä Saved metrics comparison: ./results/metrics_comparison.png")

In [None]:
# 12. FINAL SUMMARY
# ============================================================================

print("\n" + "="*70)
print("üìä FINAL RESULTS SUMMARY")
print("="*70)

# max()/min() below raise an opaque ValueError on an empty dict.
if not all_metrics:
    raise RuntimeError("No finished experiments to summarize (all_metrics is empty).")

Path('./results').mkdir(exist_ok=True)

summary_data = [
    {
        'Loss Function': exp_name.replace('real_synthetic_', ''),
        'PSNR (dB)': f"{exp_metrics['psnr']:.2f}",
        'SSIM': f"{exp_metrics['ssim']:.4f}",
        'LPIPS': f"{exp_metrics['lpips']:.4f}"
    }
    for exp_name, exp_metrics in all_metrics.items()
]

df_summary = pd.DataFrame(summary_data)
print("\n" + df_summary.to_string(index=False))

df_summary.to_csv('./results/final_summary.csv', index=False)
print("\nüíæ Saved summary: ./results/final_summary.csv")

# Best models (PSNR/SSIM: higher is better; LPIPS: lower is better)
print("\n" + "="*70)
print("üèÜ BEST PERFORMING MODELS")
print("="*70)

best_by_metric = [
    ('PSNR', max(all_metrics.items(), key=lambda item: item[1]['psnr'])),
    ('SSIM', max(all_metrics.items(), key=lambda item: item[1]['ssim'])),
    ('LPIPS', min(all_metrics.items(), key=lambda item: item[1]['lpips'])),
]
for metric_label, (exp_name, exp_metrics) in best_by_metric:
    print(f"\nü•á Best {metric_label}: {exp_name.replace('real_synthetic_', '')}")
    print(f"   PSNR: {exp_metrics['psnr']:.2f} dB")
    print(f"   SSIM: {exp_metrics['ssim']:.4f}")
    print(f"   LPIPS: {exp_metrics['lpips']:.4f}")

print("\n" + "="*70)
print("‚úÖ ALL EXPERIMENTS COMPLETED!")
print("="*70)

print(f"""
üìÅ Results saved in:
   - Checkpoints: ./checkpoints/
   - Visualizations: ./results/
   - Best models comparison: ./results/best_models_comparison.png
   - Training curves: ./results/training_curves.png
   - Metrics comparison: ./results/metrics_comparison.png
   - Summary: ./results/final_summary.csv

üéâ Training pipeline completed successfully! üéâ
""")