<a href="https://colab.research.google.com/github/Monson2002/IT585-Advanced_ML_Project/blob/main/Diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the SID Sony archive to ./sony.zip (the saved output reports ~25 GB).
# NOTE(review): the extraction cell further down reads
# /content/drive/MyDrive/Sony.zip, not this file - confirm which copy is intended.
!wget -O sony.zip "https://storage.googleapis.com/isl-datasets/SID/Sony.zip"

--2025-05-05 18:01:55--  https://storage.googleapis.com/isl-datasets/SID/Sony.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.101.207, 142.250.141.207, 142.251.2.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.101.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26926662016 (25G) [application/zip]
Saving to: ‘sony.zip’


2025-05-05 18:03:10 (342 MB/s) - ‘sony.zip’ saved [26926662016/26926662016]



In [None]:
# Install the packages imported by the training/testing cells below.
# NOTE(review): versions are unpinned, so a re-run may pick up different releases.
!pip install torch torchvision rawpy numpy

Collecting rawpy
  Downloading rawpy-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x8

In [None]:
# Mount Google Drive under /content/drive; the dataset archive is read from there later.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# # The following code will only execute
# # successfully when compression is complete

# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("monsonrejiverghese/aml-sld")

# print("Path to dataset files:", path)

In [None]:
# Install the Kaggle CLI and copy the API token into ~/.kaggle.
# Expects kaggle.json to already be present in the current working directory.
!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to owner read/write.
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the project dataset from Kaggle.
# NOTE(review): the saved output of this cell shows "404 Client Error: Not Found"
# for this dataset slug - confirm that it is published and accessible.
!kaggle datasets download -d monsonrejiverghese/aml-sld
# Example: !kaggle datasets download -d zynicide/wine-reviews

Dataset URL: https://www.kaggle.com/datasets/monsonrejiverghese/aml-sld
License(s): Community Data License Agreement - Permissive - Version 1.0
404 Client Error: Not Found for url: https://www.kaggle.com/api/v1/datasets/download/monsonrejiverghese/aml-sld?raw=false


In [None]:
import os
import numpy as np
import rawpy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import random
from torch.amp import GradScaler, autocast
from google.colab import drive
import zipfile
import io

# Mount Google Drive so the dataset archive under /content/drive is readable
drive.mount('/content/drive')

# Ask PyTorch's CUDA allocator for expandable segments to reduce fragmentation.
# Keep this assignment ahead of any CUDA work in the notebook.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Clear GPU memory
torch.cuda.empty_cache()

# Extract SID dataset from Google Drive without saving zip locally
def extract_zip_from_drive(zip_path, extract_to):
    """Unpack every member of the archive at ``zip_path`` into ``extract_to``.

    Args:
        zip_path: Path of the zip archive to read.
        extract_to: Directory that receives the extracted files.

    Prints a confirmation line once extraction has finished.
    """
    archive = zipfile.ZipFile(zip_path, mode='r')
    try:
        archive.extractall(path=extract_to)
    finally:
        # Always release the file handle, even if extraction fails part-way.
        archive.close()
    print(f"Extracted dataset to {extract_to}")

zip_path = '/content/drive/MyDrive/Sony.zip'  # Adjust path as needed
extract_to = '/content/aml-sld'
os.makedirs(extract_to, exist_ok=True)

# Unpacking the archive is expensive, so only do it once per runtime. The marker
# file is written after extractall() returns, so an interrupted extraction is
# retried on the next run. Delete the marker to force a fresh extraction.
_extract_marker = os.path.join(extract_to, '.extraction_complete')
if os.path.exists(_extract_marker):
    print(f"Dataset already extracted to {extract_to}; skipping extraction")
else:
    extract_zip_from_drive(zip_path, extract_to)
    with open(_extract_marker, 'w'):
        pass

# Custom Dataset for SID (×100 amplification only)
class SIDDataset(Dataset):
    """Paired short/long-exposure Bayer patches listed in a SID ``.txt`` file.

    Only pairs whose exposure ratio (long / short, parsed from the file names)
    lies in [90, 110] are kept. The kept pairs are sorted by short path and
    split 90% / 10%; ``is_train`` selects which part this instance serves.
    A non-empty training part is oversampled with replacement until it holds
    at least 1000 entries.

    Args:
        txt_file: List file with four whitespace-separated fields per line
            (short path, long path, and two fields that are read but unused).
            Blank lines are ignored.
        data_root: Directory the listed paths are joined onto.
        patch_size: Side length of the square crop, in packed pixels.
        target_amplification: Only used in the summary message; the filter
            window itself is fixed at [90, 110].
        is_train: True -> first 90% of the pairs, random crop and random
            flips/rotation. False -> last 10%, centre crop, no augmentation.

    Each item is ``(short_patch, long_patch, short_exp, long_exp)``; the
    patches are float32 tensors laid out as (4, patch_size, patch_size).
    """

    def __init__(self, txt_file, data_root='/content/aml-sld', patch_size=256, target_amplification=100, is_train=True):
        self.data_root = data_root
        self.patch_size = patch_size
        self.target_amplification = target_amplification
        self.is_train = is_train
        self.pairs = []

        # Load pairs from the .txt file
        with open(txt_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank (e.g. trailing) line would otherwise break the unpack below.
                    continue
                short_path, long_path, iso, fstop = fields
                # Exposure time is the last "_"-separated token of the name, e.g. "0.1s.ARW".
                short_exp = float(short_path.split('_')[-1].replace('s.ARW', ''))
                long_exp = float(long_path.split('_')[-1].replace('s.ARW', ''))
                if short_exp > 0:
                    amplification = long_exp / short_exp
                    if 90 <= amplification <= 110:
                        self.pairs.append((short_path, long_path, short_exp, long_exp))

        # Split into train and test (90% train, 10% test)
        self.pairs = sorted(self.pairs, key=lambda x: x[0])
        split_idx = int(0.9 * len(self.pairs))
        self.pairs = self.pairs[:split_idx] if is_train else self.pairs[split_idx:]

        # If training, oversample x100 pairs. The emptiness check matters:
        # random.choices([], k=0) returns [] and the loop would never terminate.
        if self.is_train and self.pairs:
            x100_pairs = self.pairs[:]
            while len(self.pairs) < 1000:
                self.pairs.extend(random.choices(x100_pairs, k=len(x100_pairs)))

        print(f"{'Training' if is_train else 'Testing'} dataset: Found {len(self.pairs)} pairs with amplification ratio ~{target_amplification}")

        # NOTE(review): per the torchvision docs RandomRotation(degrees=90) draws an
        # arbitrary angle in [-90, 90], not a multiple of 90 degrees - confirm intended.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=90),
        ]) if is_train else None

    def __len__(self):
        return len(self.pairs)

    def _load_packed(self, rel_path):
        """Read one raw file and return its mosaic packed as an (H/2, W/2, 4) float32 array."""
        # The context manager releases the rawpy handle; astype() copies the data first.
        with rawpy.imread(os.path.join(self.data_root, rel_path)) as raw:
            bayer = raw.raw_image_visible.astype(np.float32) - 512

        # NOTE(review): SID's reference code normalises Sony data with a 14-bit white
        # level (16383); 4095 may saturate bright pixels - confirm. Any change must be
        # made in every dataset class of this notebook at the same time.
        bayer = np.clip(bayer / (4095 - 512), 0, 1)

        h, w = bayer.shape
        packed = np.zeros((h // 2, w // 2, 4), dtype=np.float32)
        # One channel per position inside each 2x2 mosaic cell.
        packed[..., 0] = bayer[0::2, 0::2]
        packed[..., 1] = bayer[0::2, 1::2]
        packed[..., 2] = bayer[1::2, 0::2]
        packed[..., 3] = bayer[1::2, 1::2]
        return packed

    def __getitem__(self, idx):
        short_path, long_path, short_exp, long_exp = self.pairs[idx]

        short_packed = self._load_packed(short_path)
        long_packed = self._load_packed(long_path)

        # Use the same crop window for both exposures so the patches stay aligned.
        h, w, _ = short_packed.shape
        if self.is_train:
            i = np.random.randint(0, h - self.patch_size + 1)
            j = np.random.randint(0, w - self.patch_size + 1)
        else:
            i = (h - self.patch_size) // 2
            j = (w - self.patch_size) // 2
        short_patch = short_packed[i:i+self.patch_size, j:j+self.patch_size, :]
        long_patch = long_packed[i:i+self.patch_size, j:j+self.patch_size, :]

        # HWC -> CHW
        short_patch = torch.from_numpy(short_patch).permute(2, 0, 1)
        long_patch = torch.from_numpy(long_patch).permute(2, 0, 1)

        if self.transform:
            # Stack so one set of random parameters is applied to both patches.
            stacked = torch.stack([short_patch, long_patch], dim=0)
            stacked = self.transform(stacked)
            short_patch, long_patch = stacked[0], stacked[1]

        return short_patch, long_patch, short_exp, long_exp

# Perceptual Loss using VGG16
class PerceptualLoss(nn.Module):
    """MSE between frozen VGG16 feature maps of two image batches.

    The feature extractor is the first 16 layers of torchvision's
    ImageNet-pretrained VGG16 ``features`` stack, put in eval mode with all
    parameters frozen. Inputs are reduced to a single plane (mean of their
    first three channels) that is repeated three times before entering VGG.

    Args:
        device: Device the VGG feature extractor is moved to.
    """

    def __init__(self, device):
        super().__init__()
        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        extractor = backbone.features[:16].eval().to(device)
        # Frozen: the extractor is a fixed feature space, never trained.
        extractor.requires_grad_(False)
        self.vgg = extractor
        self.criterion = nn.MSELoss()

    @staticmethod
    def _gray_triplet(batch):
        """Average the first three channels and replicate the result to three channels."""
        plane = batch[:, :3, :, :].mean(dim=1, keepdim=True)
        return plane.repeat(1, 3, 1, 1)

    def forward(self, x, y):
        """Return the mean squared difference between the VGG features of ``x`` and ``y``."""
        feats_x = self.vgg(self._gray_triplet(x))
        feats_y = self.vgg(self._gray_triplet(y))
        return self.criterion(feats_x, feats_y)

# NAFNet Architecture
class SimpleGate(nn.Module):
    """Parameter-free gate: split the channel axis in two and multiply the halves."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the element-wise product of the two channel halves of ``x``."""
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half

class SimplifiedChannelAttention(nn.Module):
    """Rescale each channel by a sigmoid gate computed from its global average.

    Args:
        channel: Number of input (and output) channels.
    """

    def __init__(self, channel):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(channel, channel, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate (broadcast over H and W)."""
        pooled = self.avg_pool(x)  # one value per channel and sample
        gate = self.sigmoid(self.conv(pooled))
        return x * gate

class NAFBlock(nn.Module):
    """Residual block: norm -> conv -> channel attention -> conv -> norm -> gate -> conv.

    Args:
        c: Number of input and output channels.
        drop_path_rate: 0.0 leaves the branch untouched (``nn.Identity``);
            any other value applies ``nn.Dropout`` with that probability to
            the branch output before the residual addition.
    """

    def __init__(self, c, drop_path_rate=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(c, 2 * c, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        self.norm1 = nn.LayerNorm(c)
        self.norm2 = nn.LayerNorm(2 * c)
        self.sca = SimplifiedChannelAttention(c)
        self.sg = SimpleGate()
        if drop_path_rate == 0.0:
            self.drop_path = nn.Identity()
        else:
            self.drop_path = nn.Dropout(drop_path_rate)

        # Kaiming-normal weights and zero biases for the three 3x3 convolutions.
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    @staticmethod
    def _channels_last_norm(norm, feat):
        """Apply ``norm`` over the channel axis of an NCHW tensor."""
        # LayerNorm normalises the last axis, so move channels there and back.
        return norm(feat.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Return ``x`` plus the transformed branch."""
        branch = self._channels_last_norm(self.norm1, x)
        branch = self.sca(self.conv1(branch))
        branch = self._channels_last_norm(self.norm2, self.conv2(branch))
        # The gate halves the 2*c channels back to c before the last convolution.
        branch = self.conv3(self.sg(branch))
        return self.drop_path(branch) + x

class NAFNet(nn.Module):
    """Four-level encoder/decoder built from NAFBlocks with additive skip connections.

    Every encoder level is followed by a stride-2 convolution that doubles the
    channel count; every decoder level is preceded by a stride-2 transposed
    convolution that halves it, and the matching encoder output is added.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        width: Channel count of the first level.
        enc_blocks: Number of NAFBlocks at each of the four encoder levels.
        dec_blocks: Number of NAFBlocks at each of the four decoder levels.
        middle_blocks: Number of NAFBlocks at the bottleneck.
        drop_path_rate: Forwarded to every NAFBlock.

    NOTE(review): the skip additions need equal spatial sizes, which appears
    to require input height and width divisible by 16 - confirm.
    """

    def __init__(self, in_channels=4, out_channels=4, width=32, enc_blocks=[2, 2, 2, 2], dec_blocks=[2, 2, 2, 2], middle_blocks=2, drop_path_rate=0.0):
        super().__init__()
        self.width = width
        w1, w2, w4, w8, w16 = (width * factor for factor in (1, 2, 4, 8, 16))

        def stage(channels, depth):
            return nn.ModuleList([NAFBlock(channels, drop_path_rate) for _ in range(depth)])

        def downsample(c_in, c_out):
            return nn.Conv2d(c_in, c_out, 3, stride=2, padding=1)

        def upsample(c_in, c_out):
            return nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1)

        self.conv_in = nn.Conv2d(in_channels, w1, 3, padding=1)

        # Encoder
        self.enc1 = stage(w1, enc_blocks[0])
        self.down1 = downsample(w1, w2)
        self.enc2 = stage(w2, enc_blocks[1])
        self.down2 = downsample(w2, w4)
        self.enc3 = stage(w4, enc_blocks[2])
        self.down3 = downsample(w4, w8)
        self.enc4 = stage(w8, enc_blocks[3])
        self.down4 = downsample(w8, w16)

        # Bottleneck
        self.middle = stage(w16, middle_blocks)

        # Decoder
        self.up1 = upsample(w16, w8)
        self.dec1 = stage(w8, dec_blocks[0])
        self.up2 = upsample(w8, w4)
        self.dec2 = stage(w4, dec_blocks[1])
        self.up3 = upsample(w4, w2)
        self.dec3 = stage(w2, dec_blocks[2])
        self.up4 = upsample(w2, w1)
        self.dec4 = stage(w1, dec_blocks[3])

        self.conv_out = nn.Conv2d(w1, out_channels, 3, padding=1)

        # Kaiming-normal weights and zero biases for the stem and head convolutions.
        for conv in (self.conv_in, self.conv_out):
            nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
        for conv in (self.conv_in, self.conv_out):
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    @staticmethod
    def _apply_blocks(blocks, feat):
        """Run ``feat`` through every block of ``blocks`` in order."""
        for blk in blocks:
            feat = blk(feat)
        return feat

    def forward(self, x):
        """Map an input batch to an output batch with ``out_channels`` channels."""
        run = self._apply_blocks

        # Encoder: keep each level's output for the skip connections.
        e1 = run(self.enc1, self.conv_in(x))
        e2 = run(self.enc2, self.down1(e1))
        e3 = run(self.enc3, self.down2(e2))
        e4 = run(self.enc4, self.down3(e3))

        m = run(self.middle, self.down4(e4))

        # Decoder: upsample, add the matching encoder output, refine.
        d1 = run(self.dec1, self.up1(m) + e4)
        d2 = run(self.dec2, self.up2(d1) + e3)
        d3 = run(self.dec3, self.up3(d2) + e2)
        d4 = run(self.dec4, self.up4(d3) + e1)

        return self.conv_out(d4)

# Diffusion Model Training Function
def train_diffusion_model():
    """Train NAFNet as a noise predictor on long-exposure SID patches and save its weights.

    Each step draws a random timestep of a 1000-step linear beta schedule,
    noises the long-exposure patch accordingly and trains the network to
    recover the added noise (0.9 * L1 + 0.1 * VGG feature MSE). Gradients are
    accumulated over 8 batches of size 1, clipped to norm 1.0 and applied with
    AdamW under mixed precision; the learning rate follows a cosine schedule
    stepped once per epoch.

    The network input is only the noised long exposure: neither the timestep
    nor the short exposure is passed to the model.

    Raises:
        RuntimeError: If no CUDA device is available.

    Side effects:
        Writes the final state dict to /content/diffusion_model.pth.
    """
    data_root = "/content/aml-sld"
    train_txt = os.path.join(data_root, "Sony_train_list.txt")

    train_dataset = SIDDataset(train_txt, data_root=data_root, patch_size=256, target_amplification=100, is_train=True)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    print(f"Using device: {device}")

    model = NAFNet(in_channels=4, out_channels=4, width=32, enc_blocks=[2, 2, 2, 2], dec_blocks=[2, 2, 2, 2], middle_blocks=2, drop_path_rate=0.0).to(device)
    criterion_l1 = nn.L1Loss()
    criterion_perceptual = PerceptualLoss(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    scaler = GradScaler('cuda')

    # Diffusion parameters
    num_timesteps = 1000
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    num_epochs = 100
    accum_steps = 8
    num_batches = len(train_loader)

    print("Starting diffusion model training...")
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        # The short exposure is not used by the objective, so it is not moved to the GPU.
        for i, (_, long_img, _, _) in enumerate(train_loader):
            long_img = long_img.to(device, non_blocking=True)

            # Sample random timesteps
            t = torch.randint(0, num_timesteps, (long_img.size(0),), device=device)

            # Add Gaussian noise: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
            noise = torch.randn_like(long_img)
            sqrt_alpha_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
            sqrt_one_minus_alpha_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
            noisy_img = sqrt_alpha_t * long_img + sqrt_one_minus_alpha_t * noise

            with autocast('cuda'):
                predicted_noise = model(noisy_img)
                loss_l1 = criterion_l1(predicted_noise, noise)
                loss_perceptual = criterion_perceptual(predicted_noise, noise)
                loss = 0.9 * loss_l1 + 0.1 * loss_perceptual
                # Scale so the accumulated gradient is the mean over the window.
                loss = loss / accum_steps

            # Log the un-divided loss.
            running_loss += loss.item() * accum_steps

            scaler.scale(loss).backward()

            if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
                # Gradients are still multiplied by the loss scale here; unscale them
                # first so max_norm=1.0 is applied to the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        scheduler.step()
        # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
        avg_loss = running_loss / max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

        torch.cuda.empty_cache()

    torch.save(model.state_dict(), "/content/diffusion_model.pth")
    print("Training completed and model saved.")

# Start training as soon as this cell is executed.
if __name__ == "__main__":
    train_diffusion_model()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Extracted dataset to /content/aml-sld
Training dataset: Found 1056 pairs with amplification ratio ~100
Using device: cuda:0


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 160MB/s]


Starting diffusion model training...
Epoch [1/100], Loss: 1.7574
Epoch [2/100], Loss: 0.4736
Epoch [3/100], Loss: 0.2997
Epoch [4/100], Loss: 0.2551
Epoch [5/100], Loss: 0.2389
Epoch [6/100], Loss: 0.2226
Epoch [7/100], Loss: 0.2165
Epoch [8/100], Loss: 0.2156
Epoch [9/100], Loss: 0.2253
Epoch [10/100], Loss: 0.2044
Epoch [11/100], Loss: 0.2018
Epoch [12/100], Loss: 0.1900
Epoch [13/100], Loss: 0.1825
Epoch [14/100], Loss: 0.1740
Epoch [15/100], Loss: 0.1727
Epoch [16/100], Loss: 0.1611
Epoch [17/100], Loss: 0.1489
Epoch [18/100], Loss: 0.1434
Epoch [19/100], Loss: 0.1293
Epoch [20/100], Loss: 0.1269
Epoch [21/100], Loss: 0.1273
Epoch [22/100], Loss: 0.1188
Epoch [23/100], Loss: 0.1259
Epoch [24/100], Loss: 0.1122
Epoch [25/100], Loss: 0.1283
Epoch [26/100], Loss: 0.1239
Epoch [27/100], Loss: 0.1135
Epoch [28/100], Loss: 0.1114
Epoch [29/100], Loss: 0.1121
Epoch [30/100], Loss: 0.1101
Epoch [31/100], Loss: 0.1056
Epoch [32/100], Loss: 0.1054
Epoch [33/100], Loss: 0.1171
Epoch [34/100],

In [None]:
# torchmetrics supplies the PSNR / SSIM / LPIPS metrics used by the test cells.
# NOTE(review): unpinned; the saved output shows version 1.7.1 was installed.
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.14.3 torchmetrics-1.7.1


In [None]:
import os
import numpy as np
import rawpy
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from google.colab import drive
import zipfile

# Test Function for Diffusion Model
def test_diffusion_model():
    """Sample from the trained NAFNet in 20 reverse steps and score the result against the long exposures.

    For every test pair a sample is generated from pure Gaussian noise and
    compared with the long-exposure patch using PSNR, SSIM and LPIPS. The model
    receives neither the timestep nor the short exposure. Per-sample and
    average metrics are printed and written to
    /content/diffusion_test_results.txt.

    The 20 sampling steps are an evenly spaced subsequence of the 1000-step
    linear schedule used in training; per-step alphas are derived from the
    ratio of consecutive cumulative products, so the chain still starts at
    (almost) pure noise.

    NOTE(review): the dataset is built with ``is_train=False`` so evaluation
    uses centre crops without augmentation or oversampling. With the dataset
    class defined for training, that selects the last 10% of the matching
    pairs in the test list - confirm this subset is what should be reported.

    Raises:
        RuntimeError: If no CUDA device is available or no sample was evaluated.
    """
    data_root = "/content/aml-sld"
    test_txt = os.path.join(data_root, "Sony_test_list.txt")
    output_file = "/content/diffusion_test_results.txt"

    test_dataset = SIDDataset(test_txt, data_root=data_root, patch_size=256, target_amplification=100, is_train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    print(f"Using device: {device}")

    model = NAFNet(in_channels=4, out_channels=4, width=32, enc_blocks=[2, 2, 2, 2], dec_blocks=[2, 2, 2, 2], middle_blocks=2, drop_path_rate=0.0).to(device)
    model.load_state_dict(torch.load("/content/diffusion_model.pth", weights_only=True))
    model.eval()

    # Initialize metrics
    psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True).to(device)

    def packed_to_rgb(packed):
        """Collapse a 4-channel packed batch to 3 channels (ch0, mean(ch1, ch2), ch3)."""
        # Defined locally so this cell does not depend on a helper from a later cell.
        middle = (packed[:, 1:2, :, :] + packed[:, 2:3, :, :]) / 2
        return torch.cat([packed[:, 0:1, :, :], middle, packed[:, 3:4, :, :]], dim=1)

    # Diffusion parameters: 20 sampling steps taken from the 1000-step training schedule
    train_timesteps = 1000
    num_timesteps = 20
    beta_start = 0.0001
    beta_end = 0.02
    train_betas = torch.linspace(beta_start, beta_end, train_timesteps, device=device)
    train_alphas_cumprod = torch.cumprod(1.0 - train_betas, dim=0)
    kept_steps = torch.linspace(0, train_timesteps - 1, num_timesteps, device=device).round().long()
    alpha_bar_t = train_alphas_cumprod[kept_steps]
    # alpha for a kept step = a_bar(step) / a_bar(previous kept step); 1 before the first.
    alpha_bar_prev = torch.cat([alpha_bar_t.new_ones(1), alpha_bar_t[:-1]])
    alphas = alpha_bar_t / alpha_bar_prev
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)

    total_psnr, total_ssim, total_lpips = 0.0, 0.0, 0.0
    num_samples = 0

    # Open text file to save results
    with open(output_file, 'w') as f:
        f.write("Diffusion Model Test Results\n")
        f.write("Sample | PSNR | SSIM | LPIPS\n")
        f.write("-" * 40 + "\n")

        print("Starting diffusion model testing...")
        with torch.no_grad():
            for i, (_, long_img, _, _) in enumerate(test_loader):
                long_img = long_img.to(device, non_blocking=True)

                # Initialize with pure noise (simulate x_T)
                x_t = torch.randn_like(long_img)

                # Reverse diffusion process
                for t in range(num_timesteps - 1, -1, -1):
                    # Predict noise
                    predicted_noise = model(x_t)

                    # Compute coefficients
                    alpha_t = alphas[t]
                    sqrt_one_minus_alpha_bar_t_val = sqrt_one_minus_alpha_bar_t[t]

                    # Denoise step (no fresh noise is added on the final step)
                    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
                    x_t = (1.0 / torch.sqrt(alpha_t)) * (x_t - ((1 - alpha_t) / sqrt_one_minus_alpha_bar_t_val) * predicted_noise) + torch.sqrt(1 - alpha_t) * noise

                # Clamp only the final image: intermediate x_t is a noisy latent that
                # legitimately leaves [0, 1], and clamping it mid-chain corrupts the process.
                pred_x0 = torch.clamp(x_t, 0, 1)

                # Convert to RGB-like for LPIPS
                pred_x0_rgb = packed_to_rgb(pred_x0)
                long_img_rgb = packed_to_rgb(long_img)

                # Compute metrics
                psnr = psnr_metric(pred_x0, long_img)
                ssim = ssim_metric(pred_x0, long_img)
                lpips = lpips_metric(pred_x0_rgb, long_img_rgb)

                total_psnr += psnr.item()
                total_ssim += ssim.item()
                total_lpips += lpips.item()
                num_samples += 1

                # Write per-sample metrics to file and print
                result_line = f"{i+1:03d} | {psnr.item():.4f} | {ssim.item():.4f} | {lpips.item():.4f}\n"
                f.write(result_line)
                print(f"Sample [{i+1}/{len(test_loader)}], PSNR: {psnr:.4f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}")

        if num_samples == 0:
            raise RuntimeError("No test samples were evaluated; check the test list and the amplification filter")

        # Compute and write average metrics
        avg_psnr = total_psnr / num_samples
        avg_ssim = total_ssim / num_samples
        avg_lpips = total_lpips / num_samples

        f.write("-" * 40 + "\n")
        f.write(f"Average PSNR: {avg_psnr:.4f}\n")
        f.write(f"Average SSIM: {avg_ssim:.4f}\n")
        f.write(f"Average LPIPS: {avg_lpips:.4f}\n")

    print(f"\nTest Results saved to {output_file}:")
    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average LPIPS: {avg_lpips:.4f}")

# Run the evaluation as soon as this cell is executed.
# The saved output shows this run was interrupted (KeyboardInterrupt) after a few samples.
if __name__ == "__main__":
    test_diffusion_model()

Training dataset: Found 1141 pairs with amplification ratio ~100
Using device: cuda:0
Starting diffusion model testing...
Sample [1/1141], PSNR: 10.8590, SSIM: 0.0964, LPIPS: 0.7473
Sample [2/1141], PSNR: 9.4116, SSIM: 0.0522, LPIPS: 0.7390
Sample [3/1141], PSNR: 11.0093, SSIM: 0.0929, LPIPS: 0.7423
Sample [4/1141], PSNR: 8.9169, SSIM: 0.0280, LPIPS: 0.7824
Sample [5/1141], PSNR: 9.5660, SSIM: 0.0542, LPIPS: 0.7282
Sample [6/1141], PSNR: 8.9662, SSIM: 0.0310, LPIPS: 0.7749
Sample [7/1141], PSNR: 9.7783, SSIM: 0.0533, LPIPS: 0.7428
Sample [8/1141], PSNR: 9.1156, SSIM: 0.0350, LPIPS: 0.7763
Sample [9/1141], PSNR: 10.4303, SSIM: 0.0773, LPIPS: 0.7375
Sample [10/1141], PSNR: 11.5059, SSIM: 0.0781, LPIPS: 0.7426
Sample [11/1141], PSNR: 12.4158, SSIM: 0.1602, LPIPS: 0.7238
Sample [12/1141], PSNR: 10.7593, SSIM: 0.1140, LPIPS: 0.7843
Sample [13/1141], PSNR: 10.5527, SSIM: 0.0991, LPIPS: 0.7500
Sample [14/1141], PSNR: 9.7275, SSIM: 0.0646, LPIPS: 0.7566
Sample [15/1141], PSNR: 11.6017, SSIM: 0

KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import rawpy
import torch
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Clear GPU memory left cached by earlier cells before loading the model for testing
torch.cuda.empty_cache()

# Custom Dataset for SID (×100 amplification, limited to 20 images)
class SIDDataset(Dataset):
    """Test-only SID dataset: the first ``max_images`` pairs with exposure ratio in [90, 110].

    This definition replaces the ``SIDDataset`` class from the earlier
    training cell once this cell has run; it has no train/test split and no
    augmentation, and its items carry the short path as an extra first
    element.

    Args:
        txt_file: List file with four whitespace-separated fields per line
            (short path, long path, and two fields that are read but unused).
            Blank lines are ignored.
        data_root: Directory the listed paths are joined onto.
        patch_size: Side length of the centre crop, in packed pixels.
        target_amplification: Only used in the summary message; the filter
            window itself is fixed at [90, 110].
        max_images: Reading stops once this many pairs have been collected.

    Each item is ``(short_path, short_patch, long_patch, short_exp, long_exp)``;
    the patches are float32 tensors laid out as (4, patch_size, patch_size).
    """

    def __init__(self, txt_file, data_root='/content/aml-sld', patch_size=256, target_amplification=100, max_images=20):
        self.data_root = data_root
        self.patch_size = patch_size
        self.target_amplification = target_amplification
        self.max_images = max_images
        self.pairs = []

        # Load pairs from the .txt file
        with open(txt_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank (e.g. trailing) line would otherwise break the unpack below.
                    continue
                short_path, long_path, iso, fstop = fields
                # Exposure time is the last "_"-separated token of the name, e.g. "0.1s.ARW".
                short_exp = float(short_path.split('_')[-1].replace('s.ARW', ''))
                long_exp = float(long_path.split('_')[-1].replace('s.ARW', ''))
                if short_exp > 0:
                    amplification = long_exp / short_exp
                    if 90 <= amplification <= 110:
                        self.pairs.append((short_path, long_path, short_exp, long_exp))
                        if len(self.pairs) >= self.max_images:
                            break

        print(f"Test dataset: Selected {len(self.pairs)} pairs with amplification ratio ~{target_amplification}")

    def __len__(self):
        return len(self.pairs)

    def _load_packed(self, rel_path):
        """Read one raw file and return its mosaic packed as an (H/2, W/2, 4) float32 array."""
        # The context manager releases the rawpy handle; astype() copies the data first.
        with rawpy.imread(os.path.join(self.data_root, rel_path)) as raw:
            bayer = raw.raw_image_visible.astype(np.float32) - 512

        # NOTE(review): SID's reference code normalises Sony data with a 14-bit white
        # level (16383); 4095 may saturate bright pixels - confirm. Any change must
        # match the normalisation the model was trained with.
        bayer = np.clip(bayer / (4095 - 512), 0, 1)

        h, w = bayer.shape
        packed = np.zeros((h // 2, w // 2, 4), dtype=np.float32)
        # One channel per position inside each 2x2 mosaic cell.
        packed[..., 0] = bayer[0::2, 0::2]
        packed[..., 1] = bayer[0::2, 1::2]
        packed[..., 2] = bayer[1::2, 0::2]
        packed[..., 3] = bayer[1::2, 1::2]
        return packed

    def __getitem__(self, idx):
        short_path, long_path, short_exp, long_exp = self.pairs[idx]

        short_packed = self._load_packed(short_path)
        long_packed = self._load_packed(long_path)

        # Center crop for testing
        h, w, _ = short_packed.shape
        i = (h - self.patch_size) // 2
        j = (w - self.patch_size) // 2
        short_patch = short_packed[i:i+self.patch_size, j:j+self.patch_size, :]
        long_patch = long_packed[i:i+self.patch_size, j:j+self.patch_size, :]

        # HWC -> CHW
        short_patch = torch.from_numpy(short_patch).permute(2, 0, 1)
        long_patch = torch.from_numpy(long_patch).permute(2, 0, 1)

        return short_path, short_patch, long_patch, short_exp, long_exp

# Helper function to convert 4-channel Bayer-packed data to 3-channel RGB-like data
def bayer_to_rgb(bayer_tensor):
    """Collapse a 4-channel packed batch (N, 4, H, W) to a 3-channel batch (N, 3, H, W).

    Output channels are: input channel 0, the mean of input channels 1 and 2,
    and input channel 3.

    NOTE(review): treating these as R, G, B presumes the packed channels are
    ordered R, G, G, B - confirm against the sensor's mosaic layout.
    """
    first, second, third, fourth = (bayer_tensor[:, k:k + 1, :, :] for k in range(4))
    return torch.cat([first, (second + third) / 2, fourth], dim=1)

# Test Function for Diffusion Model (20 images)
def test_diffusion_model():
    """Sample from the trained NAFNet with the full 1000-step reverse process on up to 20 test pairs.

    For every pair a sample is generated from pure Gaussian noise and compared
    with the long-exposure patch using PSNR, SSIM and LPIPS. The model
    receives neither the timestep nor the short exposure. Per-image and
    average metrics are printed and written to /content/test_metrics.txt.

    Raises:
        RuntimeError: If no CUDA device is available or no sample was evaluated.
    """
    data_root = "/content/aml-sld"
    test_txt = os.path.join(data_root, "Sony_test_list.txt")
    metrics_file = "/content/test_metrics.txt"

    test_dataset = SIDDataset(test_txt, data_root=data_root, patch_size=256, target_amplification=100, max_images=20)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    print(f"Using device: {device}")

    model = NAFNet(in_channels=4, out_channels=4, width=32, enc_blocks=[2, 2, 2, 2], dec_blocks=[2, 2, 2, 2], middle_blocks=2, drop_path_rate=0.0).to(device)
    model.load_state_dict(torch.load("/content/diffusion_model.pth", weights_only=True))
    model.eval()

    # Initialize metrics
    psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True).to(device)

    # Diffusion parameters (same as training)
    num_timesteps = 1000
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alphas_cumprod)

    total_psnr, total_ssim, total_lpips = 0.0, 0.0, 0.0
    num_samples = 0
    metrics_list = []

    print("Starting diffusion model testing on 20 images...")
    with torch.no_grad():
        for i, (short_path, _, long_img, _, _) in enumerate(test_loader):
            long_img = long_img.to(device, non_blocking=True)
            short_path = short_path[0]  # Extract string from tuple

            # Initialize with pure noise (simulate x_T)
            x_t = torch.randn_like(long_img)

            # Reverse diffusion process
            for t in range(num_timesteps - 1, -1, -1):
                # Predict noise
                predicted_noise = model(x_t)

                # Compute coefficients
                alpha_t = alphas[t]
                sqrt_one_minus_alpha_bar_t_val = sqrt_one_minus_alpha_bar_t[t]

                # Denoise step (no fresh noise is added on the final step)
                noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
                x_t = (1.0 / torch.sqrt(alpha_t)) * (x_t - ((1 - alpha_t) / sqrt_one_minus_alpha_bar_t_val) * predicted_noise) + torch.sqrt(1 - alpha_t) * noise

            # Clamp only the final image: intermediate x_t is a noisy latent that
            # legitimately leaves [0, 1], and clamping it mid-chain corrupts the process.
            pred_x0 = torch.clamp(x_t, 0, 1)

            # Convert to RGB-like for LPIPS
            pred_x0_rgb = bayer_to_rgb(pred_x0)
            long_img_rgb = bayer_to_rgb(long_img)

            # Compute metrics
            psnr = psnr_metric(pred_x0, long_img)
            ssim = ssim_metric(pred_x0, long_img)
            lpips = lpips_metric(pred_x0_rgb, long_img_rgb)

            total_psnr += psnr.item()
            total_ssim += ssim.item()
            total_lpips += lpips.item()
            num_samples += 1

            # Store metrics with image name
            metrics_list.append((short_path, psnr.item(), ssim.item(), lpips.item()))

            print(f"Sample [{i+1}/{len(test_loader)}], Image: {os.path.basename(short_path)}, PSNR: {psnr:.4f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}")

    if num_samples == 0:
        raise RuntimeError("No test samples were evaluated; check the test list and the amplification filter")

    # Compute average metrics
    avg_psnr = total_psnr / num_samples
    avg_ssim = total_ssim / num_samples
    avg_lpips = total_lpips / num_samples

    # Save metrics to text file
    with open(metrics_file, 'w') as f:
        f.write("Diffusion Model Test Metrics (20 Images)\n")
        f.write("-" * 50 + "\n")
        f.write("Sample Metrics:\n")
        for short_path, psnr, ssim, lpips in metrics_list:
            f.write(f"Image: {os.path.basename(short_path)}\n")
            f.write(f"PSNR: {psnr:.4f}\n")
            f.write(f"SSIM: {ssim:.4f}\n")
            f.write(f"LPIPS: {lpips:.4f}\n")
            f.write("-" * 50 + "\n")
        f.write("Average Metrics:\n")
        f.write(f"Average PSNR: {avg_psnr:.4f}\n")
        f.write(f"Average SSIM: {avg_ssim:.4f}\n")
        f.write(f"Average LPIPS: {avg_lpips:.4f}\n")

    print(f"\nTest Results (saved to {metrics_file}):")
    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average LPIPS: {avg_lpips:.4f}")

# Run the 20-image evaluation as soon as this cell is executed.
if __name__ == "__main__":
    test_diffusion_model()

Test dataset: Selected 20 pairs with amplification ratio ~100
Using device: cuda:0
Starting diffusion model testing on 20 images...
Sample [1/20], Image: 10003_00_0.1s.ARW, PSNR: 5.5281, SSIM: 0.0635, LPIPS: 0.7489
Sample [2/20], Image: 10003_01_0.1s.ARW, PSNR: 5.5880, SSIM: 0.0633, LPIPS: 0.7553
Sample [3/20], Image: 10003_02_0.1s.ARW, PSNR: 5.5704, SSIM: 0.0630, LPIPS: 0.7557
Sample [4/20], Image: 10003_03_0.1s.ARW, PSNR: 5.5868, SSIM: 0.0630, LPIPS: 0.7526
Sample [5/20], Image: 10003_04_0.1s.ARW, PSNR: 5.5557, SSIM: 0.0630, LPIPS: 0.7478
Sample [6/20], Image: 10003_05_0.1s.ARW, PSNR: 5.5531, SSIM: 0.0623, LPIPS: 0.7474
Sample [7/20], Image: 10003_06_0.1s.ARW, PSNR: 5.5501, SSIM: 0.0629, LPIPS: 0.7494
Sample [8/20], Image: 10003_07_0.1s.ARW, PSNR: 5.5873, SSIM: 0.0639, LPIPS: 0.7521
Sample [9/20], Image: 10003_08_0.1s.ARW, PSNR: 5.5712, SSIM: 0.0633, LPIPS: 0.7529
Sample [10/20], Image: 10003_09_0.1s.ARW, PSNR: 5.5865, SSIM: 0.0647, LPIPS: 0.7531
Sample [11/20], Image: 10006_00_0.1s.