In [None]:
!pip install lpips
!pip install torchmetrics
!pip install colormath
!pip install torch_fidelity

In [None]:
import torch
# Report which accelerator (if any) PyTorch can see in this runtime.
_cuda_ready = torch.cuda.is_available()
print("CUDA Available:", _cuda_ready)
if _cuda_ready:
    print("Device:", torch.cuda.get_device_name(0))
else:
    print("Device:", "CPU")
print("Current device:", torch.device("cuda" if _cuda_ready else "cpu"))


In [None]:
import os
import cv2
import gc
import math
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid,save_image
from torch.cuda.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
from google.colab import files
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import torch_fidelity
from torch_fidelity import calculate_metrics
# LPIPS, FID, and color metrics
import lpips
from torchmetrics.image.fid import FrechetInceptionDistance
from colormath.color_diff import delta_e_cie2000
from colormath.color_objects import LabColor, sRGBColor


# CUDA configuration
# Request expandable segments from PyTorch's CUDA caching allocator.
# NOTE(review): this runs after `import torch` and after the earlier cell already
# queried CUDA; presumably the variable must be set before the allocator is first
# used — confirm it actually takes effect at this point.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Device setup
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)



In [None]:
class NoiseScheduler:
    """Forward-diffusion schedule with linearly spaced betas.

    Precomputes, for every timestep ``t`` in ``[0, num_timesteps)``, the cumulative
    product of ``alpha = 1 - beta`` together with the two square-root coefficients
    used to mix a clean sample with noise.

    Args:
        num_timesteps: Number of diffusion steps.
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule.
    """

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
        self.num_timesteps = num_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

    def add_noise(self, original, noise, timesteps):
        """Return ``sqrt(a_bar_t) * original + sqrt(1 - a_bar_t) * noise``.

        Args:
            original: Clean batch, indexed as ``(B, C, H, W)``.
            noise: Noise tensor with the same shape as ``original``.
            timesteps: Integer tensor of shape ``(B,)`` with one timestep per sample.

        Returns:
            The noised batch, same shape as ``original``.
        """
        # The tables built in __init__ live on the CPU, while callers pass timesteps
        # that sit on the training device. Indexing a CPU tensor with a CUDA index
        # raises, so gather on the tables' own device and move only the B selected
        # coefficients to where the data is.
        table_device = self.sqrt_alphas_cumprod.device
        idx = timesteps.to(table_device).long()
        sqrt_alpha_prod = self.sqrt_alphas_cumprod[idx].to(original.device)
        sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod[idx].to(original.device)

        # Expand (B,) -> (B, 1, 1, 1) so the coefficients broadcast over C, H, W.
        sqrt_alpha_prod = sqrt_alpha_prod[:, None, None, None]
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None, None, None]

        return sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise


In [None]:
class CosineNoiseScheduler(NoiseScheduler):
    """Noise schedule whose cumulative alpha follows a squared-cosine curve.

    ``alphas_cumprod[t] = f(t) / f(0)`` with
    ``f(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** 2`` for ``t = 0 .. T``
    (``T = num_timesteps``), so the cumulative tables hold ``T + 1`` entries.

    Args:
        num_timesteps: Number of diffusion steps ``T``.
        s: Small offset inside the cosine.
        device: Device on which the schedule tables are created. If a CUDA device
            is requested but CUDA is unavailable, the tables are built on the CPU
            instead of raising.
    """

    def __init__(self, num_timesteps=1000, s=0.008, device='cuda'):
        super().__init__(num_timesteps)
        self.s = s
        if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'
        self.device = device
        self._build_schedule()

    def _build_schedule(self):
        """Replace every table inherited from the linear schedule with cosine values."""
        steps = torch.arange(self.num_timesteps + 1, device=self.device)
        f_t = torch.cos((steps / self.num_timesteps + self.s) / (1 + self.s) * math.pi * 0.5) ** 2
        self.alphas_cumprod = f_t / f_t[0]
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # The parent constructor filled betas/alphas with the *linear* schedule.
        # Derive them from the cosine curve so all attributes describe the same
        # process. Betas are clipped below 1 to avoid a singular final step.
        self.betas = (1. - self.alphas_cumprod[1:] / self.alphas_cumprod[:-1]).clamp(max=0.999)
        self.alphas = 1. - self.betas

    def remove_noise(self, x_t, predicted_noise, timesteps):
        """Estimate x_0 from x_t and the predicted noise.

        Computes ``(x_t - sqrt(1 - a_bar_t) * predicted_noise) / sqrt(a_bar_t)`` and
        clamps the result to ``[-1, 1]``.

        Args:
            x_t: Noised batch, indexed as ``(B, C, H, W)``.
            predicted_noise: Model output with the same shape as ``x_t``.
            timesteps: Integer tensor of shape ``(B,)``.

        Returns:
            The clamped estimate of the clean batch.
        """
        # Gather on the tables' device, then move the B coefficients to the data, so
        # a scheduler built on one device can serve tensors living on another.
        idx = timesteps.to(self.sqrt_alphas_cumprod.device).long()
        sqrt_alpha_prod = self.sqrt_alphas_cumprod[idx].to(x_t.device)[:, None, None, None]
        sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod[idx].to(x_t.device)[:, None, None, None]

        x_0 = (x_t - sqrt_one_minus_alpha_prod * predicted_noise) / sqrt_alpha_prod
        return x_0.clamp(-1, 1)  # Ensure valid pixel range


In [None]:
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep encoding followed by a two-layer MLP.

    Args:
        dim: Width of the sinusoidal encoding. The MLP maps it to ``dim * 4``.
    """

    def __init__(self, dim=64):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim * 4)
        )

    def forward(self, t):
        """Embed a ``(B,)`` tensor of timesteps into a ``(B, dim * 4)`` tensor."""
        half_dim = self.dim // 2
        # max(..., 1) keeps dim in {2, 3} from dividing by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        angles = t.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        # sin/cos halves give 2 * (dim // 2) features; zero-pad one column for odd
        # dim so the width matches the first Linear layer.
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return self.proj(emb)


In [None]:
class TimestepResidualBlock(nn.Module):
    """Residual conv block with an additive, per-channel timestep bias.

    Takes and returns a ``(features, timestep_embedding)`` tuple so the embedding
    can be threaded through a chain of blocks.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        time_embed_dim: Width of the timestep embedding fed to ``time_proj``.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim):
        super().__init__()
        # Defined once, and first, so the submodule/parameter order stays
        # time_proj -> block -> skip (the order existing checkpoints were saved in).
        self.time_proj = nn.Linear(time_embed_dim, out_channels)
        self.block = nn.Sequential(
            nn.GroupNorm(1, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(1, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )
        # A 1x1 conv aligns channel counts for the residual sum when they differ.
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, inputs):
        """Apply the block.

        Args:
            inputs: Tuple ``(x, temb)`` of a ``(B, in_channels, H, W)`` feature map
                and a ``(B, time_embed_dim)`` timestep embedding.

        Returns:
            Tuple ``(features, temb)`` with ``temb`` passed through unchanged.
        """
        x, temb = inputs
        h = self.block(x)
        # Broadcast the projected embedding over the spatial dimensions.
        h = h + self.time_proj(temb)[:, :, None, None]
        return h + self.skip(x), temb  # Pass through timestep embedding



In [None]:
# class ConditionalUNet(nn.Module):
#     def __init__(self, in_channels=4, base_channels=64, time_embed_dim=64):
#         super().__init__()
#         self.time_embed = TimestepEmbedding(time_embed_dim)

#         self.down1 = nn.Sequential(
#             TimestepResidualBlock(in_channels, base_channels, time_embed_dim),
#             TimestepResidualBlock(base_channels, base_channels, time_embed_dim)
#         )

#     def forward(self, x, cond, t):
#         temb = self.time_embed(t)


In [None]:
class TimestepUNet(nn.Module):
    """Two-level U-Net conditioned on a timestep and a conditioning image.

    ``forward`` concatenates ``x`` and ``cond`` along the channel axis (so
    ``in_channels`` must equal their combined channel count), runs two
    down-sampling stages, a middle stage and two up-sampling stages with skip
    connections, and projects to 3 output channels.

    Args:
        in_channels: Channels of ``torch.cat([x, cond], dim=1)``.
        base_channels: Width of the first stage; later stages use 2x and 4x.
        time_dim: Width of the sinusoidal timestep encoding (the MLP widens it
            to ``time_dim * 4``).
    """

    def __init__(self, in_channels=4, base_channels=64, time_dim=64):
        super().__init__()
        self.time_dim = time_dim
        self.time_embed = nn.Sequential(
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim * 4)
        )

        # Downsample blocks
        self.down1 = nn.ModuleList([
            TimestepResidualBlock(in_channels, base_channels, time_dim*4),
            TimestepResidualBlock(base_channels, base_channels, time_dim*4)
        ])
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = nn.ModuleList([
            TimestepResidualBlock(base_channels, base_channels*2, time_dim*4),
            TimestepResidualBlock(base_channels*2, base_channels*2, time_dim*4)
        ])
        self.pool2 = nn.MaxPool2d(2)

        # Middle blocks
        self.mid = nn.ModuleList([
            TimestepResidualBlock(base_channels*2, base_channels*4, time_dim*4),
            TimestepResidualBlock(base_channels*4, base_channels*4, time_dim*4)
        ])

        # Upsample blocks: index 0 is the upsampler, the rest are residual blocks
        # that consume the upsampled features concatenated with the skip tensor.
        self.up2 = nn.ModuleList([
            nn.Upsample(scale_factor=2, mode='nearest'),
            TimestepResidualBlock(base_channels*4 + base_channels*2, base_channels*2, time_dim*4)
        ])

        self.up1 = nn.ModuleList([
            nn.Upsample(scale_factor=2, mode='nearest'),
            TimestepResidualBlock(base_channels*2 + base_channels, base_channels, time_dim*4)
        ])

        self.out = nn.Conv2d(base_channels, 3, kernel_size=1)

    @staticmethod
    def _run_blocks(blocks, x, temb):
        """Apply residual blocks in order, feeding each the same timestep embedding."""
        for block in blocks:
            x, _ = block((x, temb))
        return x

    @staticmethod
    def _match_spatial(x, skip):
        """Resize ``x`` to ``skip``'s H x W when pooling rounded an odd size down."""
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
        return x

    def forward(self, x, cond, t):
        """Run the network.

        Args:
            x: Noised image batch ``(B, C_x, H, W)``.
            cond: Conditioning batch ``(B, C_cond, H, W)``.
            t: Integer timesteps ``(B,)``.

        Returns:
            Tensor ``(B, 3, H, W)``.
        """
        temb = self.time_embed(get_timestep_embedding(t, self.time_dim))
        x = torch.cat([x, cond], dim=1)

        # Downsample path; keep the pre-pool activations for the skip connections.
        d1 = self._run_blocks(self.down1, x, temb)
        x = self.pool1(d1)

        d2 = self._run_blocks(self.down2, x, temb)
        x = self.pool2(d2)

        # Middle path
        x = self._run_blocks(self.mid, x, temb)

        # Upsample path. MaxPool2d floors odd sizes, so a plain 2x upsample can come
        # out one pixel short of the skip tensor; align before concatenating.
        x = self._match_spatial(self.up2[0](x), d2)
        x = torch.cat([x, d2], dim=1)  # skip carries base_channels*2 channels
        x = self._run_blocks(self.up2[1:], x, temb)

        x = self._match_spatial(self.up1[0](x), d1)
        x = torch.cat([x, d1], dim=1)  # skip carries base_channels channels
        x = self._run_blocks(self.up1[1:], x, temb)

        return self.out(x)


In [None]:
class CIFAR10ColorGrayUpscaled(Dataset):
    """CIFAR-10 wrapper that yields a resized colour image and its grayscale twin.

    Each item is a dict with keys ``'color'`` and ``'gray'``. Both tensors come
    from the same ``Resize`` + ``ToTensor`` pipeline; the gray one is first
    converted to a single channel.

    Args:
        root: Directory where CIFAR-10 is stored / downloaded.
        train: Select the training split when True, the test split otherwise.
        upscale_size: Side length both images are resized to.
    """

    def __init__(self, root='./data', train=True, upscale_size=128):
        self.upscale_size = upscale_size

        resize_then_tensor = [
            transforms.Resize((upscale_size, upscale_size)),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(resize_then_tensor)
        self.gray_transform = transforms.Grayscale(num_output_channels=1)

        self.dataset = datasets.CIFAR10(
            root=root, train=train, download=True, transform=self.transform
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Colour tensor comes straight from the wrapped dataset (label discarded).
        color_tensor = self.dataset[idx][0]

        # Gray tensor is rebuilt from the raw array: grayscale first, then the
        # shared resize/to-tensor pipeline.
        raw_rgb = Image.fromarray(self.dataset.data[idx])
        gray_pil = self.gray_transform(raw_rgb)

        return {'color': color_tensor, 'gray': self.transform(gray_pil)}

# Memory-Optimized DataLoader
def get_cifar10_loader(batch_size=64, upscale_size=128, train=True, num_workers=2):
    """Build a DataLoader over ``CIFAR10ColorGrayUpscaled``.

    Args:
        batch_size: Samples per batch.
        upscale_size: Side length passed to the dataset.
        train: Use the training split (and shuffle) when True.
        num_workers: Worker processes for loading; 0 loads in the main process.

    Returns:
        A ``DataLoader`` yielding dicts with ``'color'`` and ``'gray'`` batches.
    """
    dataset = CIFAR10ColorGrayUpscaled(train=train, upscale_size=upscale_size)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        # Pinned host memory only helps (and only avoids a warning) with CUDA.
        pin_memory=torch.cuda.is_available(),
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0
    )


class CombinedLoss(nn.Module):
    """Weighted sum of MSE, LPIPS and a CIE76 colour-difference (Delta E) term.

    Args:
        lpips_weight: Weight of the LPIPS (VGG) term; 0 skips it.
        deltae_weight: Weight of the Delta E term; 0 skips it. Delta E lives on
            the CIE Lab scale (black vs. white is 100), so it is numerically much
            larger than the other two terms.
        mse_weight: Weight of the MSE term; 0 skips it.
        device: Device the LPIPS network is moved to.
    """

    # Linear sRGB -> CIE XYZ (D65) and the matching reference white.
    _RGB_TO_XYZ = (
        (0.4124564, 0.3575761, 0.1804375),
        (0.2126729, 0.7151522, 0.0721750),
        (0.0193339, 0.1191920, 0.9503041),
    )
    _WHITE_D65 = (0.95047, 1.0, 1.08883)

    def __init__(self, lpips_weight=1.0, deltae_weight=1.0, mse_weight=1.0, device='cuda'):
        super().__init__()
        self.lpips_fn = lpips.LPIPS(net='vgg').to(device)
        self.lpips_weight = lpips_weight
        self.deltae_weight = deltae_weight
        self.mse_weight = mse_weight

    def forward(self, pred, target):
        """Return the weighted loss as a scalar tensor.

        Args:
            pred: Prediction batch ``(B, 3, H, W)``.
            target: Reference batch with the same shape.

        NOTE(review): LPIPS is called without ``normalize=True`` and so presumably
        expects inputs in [-1, 1], whereas the Delta E term clamps to [0, 1] —
        confirm which range callers actually provide.
        """
        # Start from a tensor so the result is a tensor even when every weight is 0
        # (callers invoke .item() / .backward() on it).
        loss = pred.new_zeros(())

        # MSE
        if self.mse_weight > 0:
            loss = loss + self.mse_weight * F.mse_loss(pred, target)

        # LPIPS
        if self.lpips_weight > 0:
            loss = loss + self.lpips_weight * self.lpips_fn(pred, target).mean()

        # Delta E using Lab colour difference
        if self.deltae_weight > 0:
            loss = loss + self.deltae_weight * self.batch_delta_e_loss(pred, target)

        return loss

    @classmethod
    def _rgb_to_lab(cls, rgb):
        """Convert sRGB in [0, 1], shape ``(B, 3, H, W)``, to CIE Lab (D65) in torch."""
        # Inverse sRGB companding. The clamp keeps pow() away from zero so the
        # branch masked out by torch.where still has a finite gradient.
        linear = torch.where(
            rgb > 0.04045,
            ((rgb.clamp(min=0.04045) + 0.055) / 1.055) ** 2.4,
            rgb / 12.92,
        )
        matrix = rgb.new_tensor(cls._RGB_TO_XYZ)
        white = rgb.new_tensor(cls._WHITE_D65).view(1, 3, 1, 1)
        xyz = torch.einsum('ij,bjhw->bihw', matrix, linear) / white

        eps = 216.0 / 24389.0
        kappa = 24389.0 / 27.0
        f = torch.where(
            xyz > eps,
            xyz.clamp(min=eps) ** (1.0 / 3.0),
            (kappa * xyz + 16.0) / 116.0,
        )
        fx, fy, fz = f[:, 0], f[:, 1], f[:, 2]
        return torch.stack(
            (116.0 * fy - 16.0, 500.0 * (fx - fy), 200.0 * (fy - fz)), dim=1
        )

    def batch_delta_e_loss(self, pred, target):
        """Mean CIE76 Delta E between two batches, differentiable w.r.t. ``pred``.

        The conversion is done with torch ops rather than NumPy/OpenCV; a
        detach-to-NumPy round trip would cut the autograd graph and leave this
        term without any gradient.

        Args:
            pred: ``(B, 3, H, W)`` batch; values are clamped to [0, 1].
            target: ``(B, 3, H, W)`` batch; values are clamped to [0, 1].

        Returns:
            Scalar tensor on ``pred``'s device.
        """
        # Do the colour math in float32 even when called under fp16 autocast.
        with torch.amp.autocast(device_type=pred.device.type, enabled=False):
            lab_pred = self._rgb_to_lab(pred.float().clamp(0, 1))
            lab_target = self._rgb_to_lab(target.float().clamp(0, 1))
            sq_dist = (lab_pred - lab_target).pow(2).sum(dim=1)
            # The epsilon keeps sqrt's gradient finite where both colours are equal.
            delta_e = torch.sqrt(sq_dist + 1e-8)
        return delta_e.mean()

    def compute_delta_e(self, pred_img, target_img):
        """Per-pixel Lab distance for one image pair via OpenCV (not differentiable).

        Args:
            pred_img: Either a ``(3, H, W)`` torch tensor or an ``(H, W, 3)`` array,
                values in [0, 1] (clipped).
            target_img: Same layout as ``pred_img``.

        Returns:
            ``(H, W)`` float array of Euclidean distances between the two images
            after conversion to 8-bit and ``cv2.COLOR_RGB2LAB``.
        """
        if isinstance(pred_img, torch.Tensor):
            pred_np = pred_img.permute(1, 2, 0).detach().cpu().numpy()  # [H, W, 3]
        else:
            pred_np = pred_img

        if isinstance(target_img, torch.Tensor):
            target_np = target_img.permute(1, 2, 0).detach().cpu().numpy()
        else:
            target_np = target_img

        pred_uint8 = (np.clip(pred_np, 0.0, 1.0) * 255).astype(np.uint8)
        target_uint8 = (np.clip(target_np, 0.0, 1.0) * 255).astype(np.uint8)

        pred_lab = cv2.cvtColor(pred_uint8, cv2.COLOR_RGB2LAB).astype(np.float32)
        target_lab = cv2.cvtColor(target_uint8, cv2.COLOR_RGB2LAB).astype(np.float32)

        return np.linalg.norm(pred_lab - target_lab, axis=-1)  # [H, W]


In [None]:
def train_diffusion_model(model, noise_scheduler, criterion, optimizer, num_epochs, device,accumulation_steps=4):
    """Train a conditional noise-prediction model on CIFAR-10 colour/gray pairs.

    Per batch: draw one random timestep per sample, noise the colour image with
    ``noise_scheduler.add_noise``, let ``model(noised, gray, t)`` predict the noise
    and evaluate ``criterion(pred_noise, noise)``. Gradients are accumulated over
    ``accumulation_steps`` batches. After each epoch the model is validated at the
    fixed timestep ``num_timesteps // 2`` and saved to ``best_model.pth`` whenever
    the validation loss improves.

    Uses notebook-level helpers ``get_cifar10_loader``, ``compute_metrics``,
    ``show_images``, ``print_gpu_mem``, ``save_checkpoint`` and the global
    ``lpips_fn``.

    Args:
        model: Network called as ``model(noised_color, gray, t)``.
        noise_scheduler: Object providing ``num_timesteps``, ``add_noise`` and
            ``remove_noise``.
        criterion: Loss applied to ``(pred_noise, noise)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over the training set.
        device: Device the batches are moved to.
        accumulation_steps: Batches per optimizer step.
    """
    # Mixed precision only makes sense on CUDA; elsewhere scaler/autocast are no-ops.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    best_val_loss = float('inf')

    train_loader = get_cifar10_loader(batch_size=64)
    val_loader = get_cifar10_loader(batch_size=64, train=False)
    num_train_batches = len(train_loader)

    for epoch in range(num_epochs):
        sys.stdout.write(f"\nStarting epoch {epoch+1}\n")
        sys.stdout.write(f"Training batches: {num_train_batches}\n")
        sys.stdout.flush()

        model.train()
        running_loss = 0.0
        running_psnr = 0.0
        running_ssim = 0.0
        running_lpips = 0.0

        optimizer.zero_grad()

        for i, batch in enumerate(train_loader):
            gray = batch['gray'].to(device, non_blocking=True)
            color = batch['color'].to(device, non_blocking=True)

            if i % 1000 == 0:
                print_gpu_mem()

            # Forward diffusion at a random timestep per sample.
            t = torch.randint(0, noise_scheduler.num_timesteps, (color.size(0),), device=device)
            noise = torch.randn_like(color)
            noised_color = noise_scheduler.add_noise(color, noise, t)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                pred_noise = model(noised_color, gray, t)
                loss = criterion(pred_noise, noise) / accumulation_steps

            scaler.scale(loss).backward()

            # Step every `accumulation_steps` batches AND on the final batch.
            # Otherwise a trailing partial group would be wiped by the zero_grad()
            # at the start of the next epoch without ever being applied.
            is_last_batch = (i + 1) == num_train_batches
            if (i + 1) % accumulation_steps == 0 or is_last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Undo the accumulation scaling for logging.
            batch_loss = loss.item() * accumulation_steps
            running_loss += batch_loss

            # Image-space metrics on the one-step x_0 estimate.
            with torch.no_grad():
                denoised = noise_scheduler.remove_noise(noised_color, pred_noise, t)
                batch_psnr, batch_ssim = compute_metrics(denoised, color)
                batch_lpips = lpips_fn(denoised, color).mean().item()

            running_psnr += batch_psnr
            running_ssim += batch_ssim
            running_lpips += batch_lpips

            sys.stdout.write(
                f"\r[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{num_train_batches}] "
                f"Loss: {batch_loss:.4f} PSNR: {batch_psnr:.2f} "
                f"SSIM: {batch_ssim:.4f} LPIPS: {batch_lpips:.4f}"
            )
            sys.stdout.flush()

            if (i + 1) % 100 == 0:
                idx = random.randint(0, gray.size(0) - 1)
                show_images(gray, denoised, color, idx)

        # Epoch statistics
        avg_loss = running_loss / num_train_batches
        avg_psnr = running_psnr / num_train_batches
        avg_ssim = running_ssim / num_train_batches
        avg_lpips = running_lpips / num_train_batches

        sys.stdout.write('\n')
        sys.stdout.flush()
        print(
            f"[Epoch {epoch+1}] Train Loss: {avg_loss:.4f} PSNR: {avg_psnr:.2f} "
            f"SSIM: {avg_ssim:.4f} LPIPS: {avg_lpips:.4f}"
        )

        # Validation at a fixed mid-schedule timestep.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                gray = batch['gray'].to(device)
                color = batch['color'].to(device)

                t = torch.full((color.size(0),), noise_scheduler.num_timesteps//2, device=device)
                noise = torch.randn_like(color)
                noised_color = noise_scheduler.add_noise(color, noise, t)

                pred_noise = model(noised_color, gray, t)
                val_loss += criterion(pred_noise, noise).item()

                if i == 0:
                    denoised = noise_scheduler.remove_noise(noised_color, pred_noise, t)
                    show_images(gray, denoised, color, idx=0)

        val_loss /= len(val_loader)
        print(f"[Epoch {epoch+1}] Validation Loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, epoch, path="best_model.pth")
            sys.stdout.write("Best model saved!\n")

        torch.cuda.empty_cache()


In [None]:
def save_checkpoint(model, optimizer, epoch, path):
    """Write the epoch number plus model and optimizer state dicts to ``path``."""
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, path)

def compute_metrics(pred, target):
    """Return ``(psnr, ssim)`` for a batch as Python floats.

    PSNR is computed over the whole batch as ``10 * log10(1 / mse)``; SSIM comes
    from torchmetrics with ``data_range=1.0``.

    NOTE: a later cell in this notebook defines another ``compute_metrics`` that
    replaces this one once that cell has run.
    """
    # Imported here because the notebook only imports `torchmetrics.image.fid` at
    # the top, which does not bind the bare name `torchmetrics`; referring to
    # `torchmetrics.functional...` directly raised NameError.
    from torchmetrics.functional import structural_similarity_index_measure

    # PSNR
    mse = torch.mean((pred - target) ** 2)
    psnr_value = 10 * torch.log10(1.0 / mse)

    # SSIM
    ssim_value = structural_similarity_index_measure(pred, target, data_range=1.0)
    return psnr_value.item(), ssim_value.item()

def show_images(gray, pred, color, idx=0):
    """Plot sample ``idx`` of the gray input, the prediction and the ground truth."""
    # Channel-first tensors -> channel-last arrays for matplotlib.
    gray_hwc = gray[idx].cpu().permute(1, 2, 0).numpy()
    pred_hwc = pred[idx].detach().cpu().permute(1, 2, 0).numpy()
    truth_hwc = color[idx].cpu().permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (gray_hwc.squeeze(), 'Input (Gray)', {'cmap': 'gray'}),
        (pred_hwc, 'Predicted', {}),
        (truth_hwc, 'Ground Truth', {}),
    )
    for axis, (image, caption, imshow_kwargs) in zip(ax, panels):
        axis.imshow(image, **imshow_kwargs)
        axis.set_title(caption)
    plt.show()

def print_gpu_mem():
    """Print allocated and reserved CUDA memory in MB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    bytes_per_mb = 1024**2
    allocated = torch.cuda.memory_allocated() / bytes_per_mb
    reserved = torch.cuda.memory_reserved() / bytes_per_mb
    print(f"Allocated: {allocated:.2f} MB")
    print(f"Reserved: {reserved:.2f} MB")


In [None]:
# Global LPIPS (AlexNet backbone) instance; the training loops read `lpips_fn`
# from module scope for their per-batch LPIPS logging.
lpips_fn = lpips.LPIPS(net='alex').to(device)
# FID metric configured with feature=2048.
# NOTE(review): `fid_metric` is not referenced anywhere else in the visible notebook.
fid_metric = FrechetInceptionDistance(feature=2048).to(device)

In [None]:
# Metric Functions
def compute_metrics(pred, target):
    """Return batch-mean PSNR and SSIM (skimage, ``data_range=1.0``).

    Both inputs are channel-first batches; each image is scored separately in
    channel-last layout and the per-image scores are averaged.
    """
    # (B, C, H, W) -> (B, H, W, C) once, instead of transposing per image.
    pred_hwc = pred.detach().cpu().numpy().transpose(0, 2, 3, 1)
    target_hwc = target.detach().cpu().numpy().transpose(0, 2, 3, 1)

    image_pairs = list(zip(target_hwc, pred_hwc))
    psnr_scores = [psnr(ref, est, data_range=1.0) for ref, est in image_pairs]
    ssim_scores = [
        ssim(ref, est, channel_axis=-1, data_range=1.0) for ref, est in image_pairs
    ]
    return np.mean(psnr_scores), np.mean(ssim_scores)

# Visualization
def show_images(gray, pred, color, idx, title=""):
    """Show sample ``idx`` as three panels: grayscale input, prediction, ground truth.

    ``title`` is appended to each panel heading.
    """
    fig, axs = plt.subplots(1, 3, figsize=(9, 3))
    gray_ax, pred_ax, truth_ax = axs

    gray_ax.imshow(gray[idx].squeeze().cpu().numpy(), cmap='gray')
    gray_ax.set_title(f"Grayscale {title}")

    pred_hwc = pred[idx].permute(1, 2, 0).detach().cpu().numpy()
    pred_ax.imshow(pred_hwc)
    pred_ax.set_title(f"Predicted Color {title}")

    truth_hwc = color[idx].permute(1, 2, 0).cpu().numpy()
    truth_ax.imshow(truth_hwc)
    truth_ax.set_title(f"Ground Truth {title}")

    for panel in (gray_ax, pred_ax, truth_ax):
        panel.axis('off')
    plt.tight_layout()
    plt.show()

# Save Checkpoint
def save_checkpoint(model, optimizer, epoch, path='best_model.pth', download=True):
    """Save model/optimizer state and optionally hand the file to the browser.

    Args:
        model: Module whose ``state_dict`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict`` is stored under
            ``'optimizer_state_dict'``.
        epoch: Value stored under ``'epoch'``.
        path: Destination file.
        download: When True (the default, matching the previous behaviour) also
            call ``files.download(path)``. Pass False to save without triggering a
            browser download, e.g. for checkpoints written inside a training loop.
    """
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }, path)
    if download:
        files.download(path)

def get_timestep_embedding(timesteps, embedding_dim: int):
    """Sinusoidal encoding of a ``(B,)`` timestep tensor.

    The result is created on ``timesteps``' device and has shape
    ``(B, embedding_dim)``: sines in the first half, cosines in the second, plus
    one zero column on the right when ``embedding_dim`` is odd.
    """
    half = embedding_dim // 2
    decay = math.log(10000) / (half - 1)
    freqs = torch.exp(torch.arange(half, device=timesteps.device) * -decay)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    table = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)

    if embedding_dim % 2 != 1:
        return table
    return F.pad(table, (0, 1, 0, 0))



In [None]:
# Initialize components
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TimestepUNet(base_channels=48, time_dim=64).to(device)
noise_scheduler = CosineNoiseScheduler(num_timesteps=1000, device=device)
# train_diffusion_model evaluates the criterion on (predicted noise, sampled
# Gaussian noise). LPIPS and Delta E are image-space measures and are not
# meaningful on noise tensors, so this run uses plain MSE — the same
# configuration the dual-model cell uses for its noise criterion.
criterion = CombinedLoss(lpips_weight=0.0, deltae_weight=0.0, mse_weight=1.0).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Start training
train_diffusion_model(
    model=model,
    noise_scheduler=noise_scheduler,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=20,
    device=device,
    accumulation_steps=4
)

In [None]:
# Save final model and download it to the local machine.
# The save_checkpoint defined in the helper cell above already calls
# files.download(path) after saving, so a second explicit download (and the
# repeated google.colab import) is not needed here.
final_path = "/content/final_model.pth"
save_checkpoint(model, optimizer, 20, path=final_path)
sys.stdout.write("Final model saved!\n")

In [None]:
def train_dual_diffusion_models(model_noise, model_rgb, noise_scheduler,
                                criterion_noise, criterion_rgb,
                                optimizer_noise, optimizer_rgb,
                                num_epochs, device, accumulation_steps=4):
    """Train two models side by side on the same noised batches.

    ``model_noise`` is scored with ``criterion_noise(pred_noise, noise)``.
    ``model_rgb`` also predicts noise, but its output is first turned into an
    image estimate with ``noise_scheduler.remove_noise`` and scored with
    ``criterion_rgb(denoised_rgb, color)``. Both share one ``GradScaler``; each
    has its own optimizer. After every epoch both are validated at the fixed
    timestep ``num_timesteps // 2`` and checkpointed independently
    (``best_model_noise.pth`` / ``best_model_rgb.pth``) when their validation
    loss improves.

    Uses notebook-level helpers ``get_cifar10_loader``, ``compute_metrics``,
    ``show_images``, ``print_gpu_mem``, ``save_checkpoint`` and the global
    ``lpips_fn``.

    Args:
        model_noise: Network trained with the noise-space loss.
        model_rgb: Network trained with the image-space loss.
        noise_scheduler: Provides ``num_timesteps``, ``add_noise``, ``remove_noise``.
        criterion_noise: Loss for ``(pred_noise, noise)``.
        criterion_rgb: Loss for ``(denoised_rgb, color)``.
        optimizer_noise: Optimizer for ``model_noise``.
        optimizer_rgb: Optimizer for ``model_rgb``.
        num_epochs: Number of passes over the training set.
        device: Device the batches are moved to.
        accumulation_steps: Batches per optimizer step.
    """
    # Mixed precision only makes sense on CUDA; elsewhere scaler/autocast are no-ops.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    best_val_loss_noise = float('inf')
    best_val_loss_rgb = float('inf')

    train_loader = get_cifar10_loader(batch_size=64)
    val_loader = get_cifar10_loader(batch_size=64, train=False)
    num_train_batches = len(train_loader)

    for epoch in range(num_epochs):
        sys.stdout.write(f"\nStarting epoch {epoch+1}\n")
        sys.stdout.flush()

        model_noise.train()
        model_rgb.train()

        running_loss_noise = 0.0
        running_loss_rgb = 0.0
        running_psnr = 0.0
        running_ssim = 0.0
        running_lpips = 0.0

        optimizer_noise.zero_grad()
        optimizer_rgb.zero_grad()

        for i, batch in enumerate(train_loader):
            gray = batch['gray'].to(device, non_blocking=True)
            color = batch['color'].to(device, non_blocking=True)

            if i % 1000 == 0:
                print_gpu_mem()

            t = torch.randint(0, noise_scheduler.num_timesteps, (color.size(0),), device=device)
            noise = torch.randn_like(color)
            noised_color = noise_scheduler.add_noise(color, noise, t)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                pred_noise = model_noise(noised_color, gray, t)
                loss_noise = criterion_noise(pred_noise, noise) / accumulation_steps

                pred_noise_rgb = model_rgb(noised_color, gray, t)
                denoised_rgb = noise_scheduler.remove_noise(noised_color, pred_noise_rgb, t)
                loss_rgb = criterion_rgb(denoised_rgb, color) / accumulation_steps

            scaler.scale(loss_noise).backward()
            scaler.scale(loss_rgb).backward()

            # Step every `accumulation_steps` batches AND on the final batch.
            # Otherwise a trailing partial group would be wiped by the zero_grad()
            # calls at the start of the next epoch without ever being applied.
            is_last_batch = (i + 1) == num_train_batches
            if (i + 1) % accumulation_steps == 0 or is_last_batch:
                scaler.step(optimizer_noise)
                scaler.step(optimizer_rgb)
                scaler.update()
                optimizer_noise.zero_grad()
                optimizer_rgb.zero_grad()

            # Undo the accumulation scaling for logging.
            batch_loss_noise = loss_noise.item() * accumulation_steps
            batch_loss_rgb = loss_rgb.item() * accumulation_steps
            running_loss_noise += batch_loss_noise
            running_loss_rgb += batch_loss_rgb

            # Per-batch image metrics are reported for the noise model only.
            with torch.no_grad():
                denoised = noise_scheduler.remove_noise(noised_color, pred_noise, t)
                batch_psnr, batch_ssim = compute_metrics(denoised, color)
                batch_lpips = lpips_fn(denoised, color).mean().item()

            running_psnr += batch_psnr
            running_ssim += batch_ssim
            running_lpips += batch_lpips

            sys.stdout.write(
                f"\r[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{num_train_batches}] "
                f"LossN: {batch_loss_noise:.4f} "
                f"LossRGB: {batch_loss_rgb:.4f} "
                f"PSNR: {batch_psnr:.2f} SSIM: {batch_ssim:.4f} LPIPS: {batch_lpips:.4f}"
            )
            sys.stdout.flush()

            if (i + 1) % 100 == 0:
                idx = random.randint(0, gray.size(0) - 1)
                # Reuse the estimates computed above instead of recomputing them.
                show_images(gray, denoised, color, idx, title="Noise Model")
                show_images(gray, denoised_rgb.detach(), color, idx, title="RGB Model")

        avg_loss_noise = running_loss_noise / num_train_batches
        avg_loss_rgb = running_loss_rgb / num_train_batches
        avg_psnr = running_psnr / num_train_batches
        avg_ssim = running_ssim / num_train_batches
        avg_lpips = running_lpips / num_train_batches

        sys.stdout.write('\n')
        sys.stdout.flush()
        print(
            f"[Epoch {epoch+1}] Train Loss (Noise): {avg_loss_noise:.4f} | "
            f"Train Loss (RGB): {avg_loss_rgb:.4f} | PSNR: {avg_psnr:.2f} "
            f"SSIM: {avg_ssim:.4f} LPIPS: {avg_lpips:.4f}"
        )

        # Validation at a fixed mid-schedule timestep.
        model_noise.eval()
        model_rgb.eval()
        val_loss_noise = 0.0
        val_loss_rgb = 0.0

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                gray = batch['gray'].to(device)
                color = batch['color'].to(device)
                t = torch.full((color.size(0),), noise_scheduler.num_timesteps//2, device=device)
                noise = torch.randn_like(color)
                noised_color = noise_scheduler.add_noise(color, noise, t)

                pred_noise = model_noise(noised_color, gray, t)
                val_loss_noise += criterion_noise(pred_noise, noise).item()
                denoised_noise = noise_scheduler.remove_noise(noised_color, pred_noise, t)

                pred_noise_rgb = model_rgb(noised_color, gray, t)
                denoised_rgb = noise_scheduler.remove_noise(noised_color, pred_noise_rgb, t)
                val_loss_rgb += criterion_rgb(denoised_rgb, color).item()

                if i == 0:
                    show_images(gray, denoised_noise, color, idx=0, title="Val Noise Model")
                    show_images(gray, denoised_rgb, color, idx=0, title="Val RGB Model")

        val_loss_noise /= len(val_loader)
        val_loss_rgb /= len(val_loader)
        print(f"[Epoch {epoch+1}] Val Loss (Noise): {val_loss_noise:.4f} | Val Loss (RGB): {val_loss_rgb:.4f}")

        if val_loss_noise < best_val_loss_noise:
            best_val_loss_noise = val_loss_noise
            save_checkpoint(model_noise, optimizer_noise, epoch, path="best_model_noise.pth")
            print("Saved best Noise Model")

        if val_loss_rgb < best_val_loss_rgb:
            best_val_loss_rgb = val_loss_rgb
            save_checkpoint(model_rgb, optimizer_rgb, epoch, path="best_model_rgb.pth")
            print("Saved best RGB Model")

        torch.cuda.empty_cache()


In [None]:
# Two identically configured UNets: one trained with a noise-space loss, one with
# an image-space loss.
model_noise, model_rgb = (
    TimestepUNet(base_channels=48, time_dim=64).to(device) for _ in range(2)
)

noise_scheduler = CosineNoiseScheduler(num_timesteps=1000, device=device)

# Global LPIPS instance read by the training loop for metric logging.
lpips_fn = lpips.LPIPS(net='alex').to(device).eval()

# Noise prediction is scored with MSE only.
criterion_noise = CombinedLoss(lpips_weight=0.0, deltae_weight=0.0, mse_weight=1.0).to(device)

# RGB prediction mixes LPIPS, Delta E and a small MSE term.
criterion_rgb = CombinedLoss(lpips_weight=0.5, deltae_weight=2.0, mse_weight=0.1).to(device)

optimizer_noise, optimizer_rgb = (
    torch.optim.AdamW(net.parameters(), lr=1e-4) for net in (model_noise, model_rgb)
)

# Start dual-model training
train_dual_diffusion_models(
    model_noise,
    model_rgb,
    noise_scheduler,
    criterion_noise,
    criterion_rgb,
    optimizer_noise,
    optimizer_rgb,
    num_epochs=20,
    device=device,
    accumulation_steps=4,
)

In [None]:
from tqdm import tqdm
def save_generated_images(model, loader, device, save_dir, noise_scheduler, t_value=500):
    """Write one-step reconstructions of every image in ``loader`` to ``save_dir``.

    Each colour batch is noised at the fixed timestep ``t_value``, the model's
    noise prediction is removed again with ``noise_scheduler.remove_noise`` and
    the results are saved as consecutively numbered ``NNNNNN.png`` files.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    image_count = 0

    progress = tqdm(loader, desc=f"Generating images for {save_dir}")
    with torch.no_grad():
        for batch in progress:
            gray = batch['gray'].to(device)
            color = batch['color'].to(device)

            noise = torch.randn_like(color)
            t = torch.full((color.size(0),), t_value, device=device)
            noised_color = noise_scheduler.add_noise(color, noise, t)

            pred_color = noise_scheduler.remove_noise(
                noised_color, model(noised_color, gray, t), t
            )

            # Continue the running file index across batches.
            for image_count, img in enumerate(pred_color, start=image_count):
                save_image(img, f"{save_dir}/{image_count:06}.png")
            image_count += 1 if len(pred_color) else 0

def validate_models(model_noise, model_rgb, val_loader, device):
    """Compare the two trained models on the validation loader and print the results.

    Steps: write each model's reconstructions to ``/content/gen_noise`` and
    ``/content/gen_rgb``, compute FID/KID of each directory against
    ``/content/real_val`` with ``torch_fidelity.calculate_metrics``, then compute
    LPIPS, PSNR, SSIM and Delta E over ``val_loader`` at timestep 500.

    Reads the module-level ``noise_scheduler``; ``/content/real_val`` must already
    contain the reference images.
    """
    print("Saving generated images...")
    save_generated_images(model_noise, val_loader, device, "/content/gen_noise", noise_scheduler)
    save_generated_images(model_rgb, val_loader, device, "/content/gen_rgb", noise_scheduler)

    def inception_scores(generated_dir):
        # Same torch_fidelity settings for both models; only the input dir varies.
        return calculate_metrics(
            input1=generated_dir,
            input2="/content/real_val",
            cuda=torch.cuda.is_available(),
            isc=False, fid=True, kid=True, ppl=False,
            verbose=False
        )

    print("\nCalculating FID...")
    fid_noise = inception_scores("/content/gen_noise")
    fid_rgb = inception_scores("/content/gen_rgb")
    print("Noise Model FID:", fid_noise['frechet_inception_distance'])
    print("RGB Model FID:", fid_rgb['frechet_inception_distance'])
    print("Noise Model KID:", fid_noise['kernel_inception_distance_mean'])
    print("RGB Model KID:", fid_rgb['kernel_inception_distance_mean'])

    print("\nCalculating perceptual and pixel metrics...")
    lpips_fn = lpips.LPIPS(net='alex').to(device).eval()
    criterion = CombinedLoss().to(device)

    def calculate_other_metrics(model):
        model.eval()
        collected = {'LPIPS': [], 'PSNR': [], 'SSIM': [], 'DeltaE': []}
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Metric Evaluation"):
                gray = batch['gray'].to(device)
                color = batch['color'].to(device)
                t = torch.full((color.size(0),), 500, device=device)
                noise = torch.randn_like(color)
                noised_color = noise_scheduler.add_noise(color, noise, t)
                pred_color = noise_scheduler.remove_noise(
                    noised_color, model(noised_color, gray, t), t
                )

                collected['LPIPS'].extend(lpips_fn(pred_color, color).cpu().numpy())
                batch_psnr, batch_ssim = compute_metrics(pred_color, color)
                collected['PSNR'].append(batch_psnr)
                collected['SSIM'].append(batch_ssim)
                collected['DeltaE'].append(
                    criterion.batch_delta_e_loss(pred_color, color).cpu().numpy()
                )

        return {metric: np.mean(values) for metric, values in collected.items()}

    metrics_noise = calculate_other_metrics(model_noise)
    metrics_rgb = calculate_other_metrics(model_rgb)

    print("\n--- Additional Metrics ---")
    for k, v in metrics_noise.items():
        print(f"Noise Model {k}: {v:.4f}")
    for k, v in metrics_rgb.items():
        print(f"RGB Model {k}: {v:.4f}")


model_noise = TimestepUNet(base_channels=48, time_dim=64).to(device)
model_rgb = TimestepUNet(base_channels=48, time_dim=64).to(device)

# map_location lets checkpoints written on a GPU load on whatever device this
# runtime actually has.
model_noise.load_state_dict(
    torch.load("/content/best_model_noise.pth", map_location=device)['model_state_dict']
)
model_rgb.load_state_dict(
    torch.load("/content/best_model_rgb.pth", map_location=device)['model_state_dict']
)

# validate_models reads `noise_scheduler` from module scope; build it here (same
# settings as the training cell) so this cell does not depend on training having
# been run in the same session.
noise_scheduler = CosineNoiseScheduler(num_timesteps=1000, device=device)

val_loader = get_cifar10_loader(batch_size=64, train=False)

# FID/KID reference set: the ground-truth colour images themselves. Running a
# model here would compare generated images against other generated images.
real_dir = "/content/real_val"
os.makedirs(real_dir, exist_ok=True)
real_count = 0
for batch in tqdm(val_loader, desc=f"Saving real images to {real_dir}"):
    for img in batch['color']:
        save_image(img, f"{real_dir}/{real_count:06}.png")
        real_count += 1

validate_models(model_noise, model_rgb, val_loader, device)