# Import libraries

In [None]:
!pip install adjustText
!pip install torchmetrics[image]
!pip install pytorch-msssim torchvision

In [None]:
import os
import math
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from adjustText import adjust_text

from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import StepLR

from torchvision import transforms, models
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torchvision.models import inception_v3

from pytorch_msssim import ssim as ssim_loss

from torchmetrics.image.fid import FrechetInceptionDistance as FID
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.functional import peak_signal_noise_ratio as psnr_fn
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn

from tqdm import tqdm
from torch.cuda.amp import autocast

In [None]:
# Root of the paired GoPro deblurring data. Blurred inputs and sharp targets
# live in sibling "<subset>/images" folders underneath it.
base_path = "/root/nfs/hmj/ImP/MyTest/BUDA-Net/datasets/gopro_deblur"
blur_dir, sharp_dir = (
    os.path.join(base_path, subset, "images") for subset in ("blur", "sharp")
)

In [None]:
# GPU index this notebook prefers on the shared training machine.
_PREFERRED_GPU_INDEX = 4

# Select the compute device. The preferred index is only used when that GPU
# actually exists; otherwise fall back to GPU 0, or to the CPU when CUDA is
# unavailable, instead of failing later on the first `.to(device)` call.
if torch.cuda.is_available():
    _gpu_index = (
        _PREFERRED_GPU_INDEX
        if torch.cuda.device_count() > _PREFERRED_GPU_INDEX
        else 0
    )
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

# Dataset Preview

In [None]:
image_exts = (".png", ".jpg", ".jpeg", ".bmp")

# Pair every blurred image with the sharp image of the same file name.
# os.listdir() returns entries in arbitrary order, so the listing is sorted to
# make the pair order (and therefore image_pairs[0]) reproducible across runs.
image_pairs = []
for filename in sorted(os.listdir(blur_dir)):
    if not filename.lower().endswith(image_exts):
        continue

    blur_path = os.path.join(blur_dir, filename)
    sharp_path = os.path.join(sharp_dir, filename)

    # Keep only complete pairs: both sides must exist as regular files.
    if os.path.isfile(blur_path) and os.path.isfile(sharp_path):
        image_pairs.append((blur_path, sharp_path))

print(f"Total image pairs found: {len(image_pairs)}")

In [None]:
# Sanity check: show the first blurred/sharp pair side by side.
blur_path, sharp_path = image_pairs[0]
blur_img = Image.open(blur_path).convert("RGB")
sharp_img = Image.open(sharp_path).convert("RGB")

plt.figure(figsize=(12, 5))
for _slot, (_title, _picture) in enumerate(
    [("Blurry Image", blur_img), ("Sharp Image", sharp_img)], start=1
):
    plt.subplot(1, 2, _slot)
    plt.title(_title)
    plt.imshow(_picture)
    plt.axis("off")

plt.tight_layout()
plt.show()

# Dataset setup

In [None]:
class BlurSharpDataset(Dataset):
    """Paired dataset built from explicit ``(blur_path, sharp_path)`` tuples.

    Each item is the pair of RGB images loaded from the two paths. When a
    ``transform`` is given it is called once on the blurred image and once on
    the sharp image, independently of each other.
    """

    def __init__(self, image_pairs, transform=None):
        self.image_pairs = image_pairs
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(blur, sharp)``."""
        blur_path, sharp_path = self.image_pairs[idx]
        blur, sharp = (
            Image.open(path).convert("RGB") for path in (blur_path, sharp_path)
        )

        if not self.transform:
            return blur, sharp
        return self.transform(blur), self.transform(sharp)

# Our rough dataset split

In [None]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF


class GoProDataset(Dataset):
    """Paired blurred/sharp GoPro images with an index-based train/test split.

    Expected layout::

        root
          ├── blur/images/*.png
          └── sharp/images/*.png

    File names are sorted and the first ``train_ratio`` share forms the
    ``"train"`` split; the remainder forms the ``"test"`` split. Only blurred
    files that have a sharp file of the same name are used.

    Args:
        root: Directory containing the ``blur`` and ``sharp`` folders.
        split: ``"train"`` or ``"test"``.
        crop_size: Side length of the square training crop.
        training: When true, apply a random crop and random horizontal and
            vertical flips (identically to both images). When false, the
            full-size images are returned.
        train_ratio: Fraction of the sorted file list assigned to ``"train"``.
        extensions: File-name suffixes (lower case) accepted as images.
        normalize: When true (default), rescale tensors from [0, 1] to
            [-1, 1]. The rest of this notebook (``preprocess_for_vgg``,
            ``denorm``, the Tanh generator outputs) assumes that range.
            Pass ``False`` to get the raw ``to_tensor`` range [0, 1].

    Raises:
        FileNotFoundError: If either image directory is missing.
        ValueError: If ``split`` is not ``"train"`` or ``"test"``.
    """

    def __init__(
        self,
        root,
        split="train",
        crop_size=256,
        training=True,
        train_ratio=0.8,
        extensions=(".png", ".jpg", ".jpeg"),
        normalize=True,
    ):
        super().__init__()

        self.blur_dir = os.path.join(root, "blur", "images")
        self.sharp_dir = os.path.join(root, "sharp", "images")

        # Explicit raises instead of assert: asserts vanish under `python -O`.
        for directory in (self.blur_dir, self.sharp_dir):
            if not os.path.isdir(directory):
                raise FileNotFoundError(f"Not found: {directory}")

        # str.endswith needs a tuple; accept any iterable of suffixes.
        suffixes = tuple(extensions)
        # Skip blurred files without a sharp counterpart so __getitem__ cannot
        # fail on a missing target halfway through an epoch.
        names = sorted(
            f for f in os.listdir(self.blur_dir)
            if f.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(self.sharp_dir, f))
        )

        # ---- train / test split ----
        split_idx = int(len(names) * train_ratio)
        if split == "train":
            self.names = names[:split_idx]
        elif split == "test":
            self.names = names[split_idx:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        self.crop_size = crop_size
        self.training = training
        self.normalize = normalize

        print(
            f"[GoProDataset] split={split}, "
            f"samples={len(self.names)}"
        )

    def __len__(self):
        """Return the number of image pairs in this split."""
        return len(self.names)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _random_crop(self, img1, img2):
        """Cut the same random ``crop_size`` square out of both images.

        Images smaller than the crop in either dimension are resized to
        ``(crop_size, crop_size)`` instead of being cropped.
        """
        w, h = img1.size
        cs = self.crop_size

        if w < cs or h < cs:
            img1 = TF.resize(img1, (cs, cs))
            img2 = TF.resize(img2, (cs, cs))
            return img1, img2

        x = random.randint(0, w - cs)
        y = random.randint(0, h - cs)
        img1 = TF.crop(img1, y, x, cs, cs)
        img2 = TF.crop(img2, y, x, cs, cs)
        return img1, img2

    def __getitem__(self, idx):
        """Return the ``(blur, sharp)`` tensor pair for index ``idx``."""
        name = self.names[idx]

        blur = self._load_rgb(os.path.join(self.blur_dir, name))
        sharp = self._load_rgb(os.path.join(self.sharp_dir, name))

        if self.training:
            blur, sharp = self._random_crop(blur, sharp)

            # Each flip decision is drawn once and applied to both images so
            # the pair stays pixel-aligned.
            if random.random() < 0.5:
                blur = TF.hflip(blur)
                sharp = TF.hflip(sharp)
            if random.random() < 0.5:
                blur = TF.vflip(blur)
                sharp = TF.vflip(sharp)

        blur = TF.to_tensor(blur)
        sharp = TF.to_tensor(sharp)

        if self.normalize:
            # [0, 1] -> [-1, 1]
            blur = blur * 2.0 - 1.0
            sharp = sharp * 2.0 - 1.0
        return blur, sharp


In [None]:
# Deterministic preprocessing: resize to 512x512, convert to a tensor and
# normalise every channel with mean 0.5 / std 0.5.
# Random flip, rotation and colour-jitter augmentations are deliberately left
# out of this pipeline.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
def to_gray_np(img):
    """Return a 2-D luminance array for an image exposing ``convert("RGB")``.

    Pixel values are divided by 255 and the first three channels are combined
    with the weights 0.299 / 0.587 / 0.114.
    """
    rgb_unit = np.array(img.convert("RGB")) / 255.0
    luma_weights = [0.299, 0.587, 0.114]
    return np.dot(rgb_unit[..., :3], luma_weights)


# Pairs retained by the SSIM filter (the filter itself is currently disabled).
filtered_pairs = []

# for blur_path, sharp_path in tqdm(image_pairs, desc='Computing SSIM'):
#     blur_img = Image.open(blur_path)
#     sharp_img = Image.open(sharp_path)

#     blur_gray = to_gray_np(blur_img)
#     sharp_gray = to_gray_np(sharp_img)

#     ssim_score = ssim(blur_gray, sharp_gray, data_range=1.0)

#     if ssim_score <= 0.8:
#         filtered_pairs.append((blur_path, sharp_path))

# print(f"Total original pairs: {len(image_pairs)}")
# print(f"Filtered (SSIM <= 0.8): {len(filtered_pairs)}")

In [None]:
# full_dataset = BlurSharpDataset(image_pairs=filtered_pairs, transform=transform)

# train_size = int(0.8 * len(full_dataset))
# val_size = len(full_dataset) - train_size

# train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Training batches: random 256x256 crops with flips, reshuffled every epoch.
# drop_last discards the final incomplete batch. The dataset root is taken
# from `base_path` instead of repeating the path literal.
# (num_workers is left at its default.)
train_loader = DataLoader(
    GoProDataset(base_path, split="train", crop_size=256, training=True),
    batch_size=256,
    shuffle=True,
    drop_last=True,
)

# Validation: one un-augmented, full-size image pair per batch, fixed order.
val_loader = DataLoader(
    GoProDataset(base_path, split="test", crop_size=256, training=False),
    batch_size=1,
    shuffle=False,
)

# CNN Architecture

In [None]:
class DeblurCNN(nn.Module):
    """Plain fully-convolutional baseline that keeps the spatial size.

    Three 5x5 conv + ReLU stages (64 channels each) feed a 3x3 conv + ReLU
    (32 channels) and a final 3x3 conv back to 3 channels. There is no output
    activation.
    """

    def __init__(self):
        super().__init__()

        feature_layers = []
        channels_in = 3
        for _ in range(3):
            feature_layers.append(
                nn.Conv2d(channels_in, 64, kernel_size=5, padding=2)
            )
            feature_layers.append(nn.ReLU())
            channels_in = 64
        self.encoder = nn.Sequential(*feature_layers)

        self.decoder = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Map a ``(B, 3, H, W)`` batch to a ``(B, 3, H, W)`` prediction."""
        return self.decoder(self.encoder(x))

# UNet Architecture

In [None]:
class ConvBlock(nn.Module):
    """Two bias-free 3x3 conv -> BatchNorm -> ReLU stages.

    With ``use_dropout`` a ``Dropout(0.2)`` layer sits between the two stages
    (index 3 of ``conv_block``).
    """

    def __init__(self, in_channels, out_channels, use_dropout=False):
        super().__init__()

        def conv_bn_relu(channels_in):
            return [
                nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        first_stage = conv_bn_relu(in_channels)
        if use_dropout:
            first_stage.append(nn.Dropout(0.2))
        second_stage = conv_bn_relu(out_channels)

        self.conv_block = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Apply the block; spatial size is preserved."""
        return self.conv_block(x)

In [None]:
class Encoder(nn.Module):
    """U-Net down step: 2x2 max-pool followed by a ``ConvBlock``."""

    def __init__(self, in_channels, out_channels, use_dropout=False):
        super().__init__()
        downsample = nn.MaxPool2d(2)
        refine = ConvBlock(in_channels, out_channels, use_dropout)
        self.encoder = nn.Sequential(downsample, refine)

    def forward(self, x):
        """Halve the spatial size and map to ``out_channels`` features."""
        pooled_features = self.encoder(x)
        return pooled_features

In [None]:
class Decoder(nn.Module):
    """U-Net up step: transposed-conv upsampling, skip concat, ``ConvBlock``."""

    def __init__(self, in_channels, out_channels, use_dropout=False):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1
        )
        self.conv = ConvBlock(in_channels, out_channels, use_dropout)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse both.

        If the upsampled tensor's size differs from ``x2``'s, it is bilinearly
        resized to ``x2``'s spatial size before concatenation (skip first).
        """
        upsampled = self.upconv(x1)

        sizes_match = upsampled.size() == x2.size()
        if not sizes_match:
            target_hw = x2.size()[2:]
            upsampled = F.interpolate(
                upsampled, size=target_hw, mode="bilinear", align_corners=False
            )

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)

In [None]:
class UNet(nn.Module):
    """Four-level U-Net (64 -> 1024 channels) with a Tanh output layer.

    Sub-module names (``in_conv``, ``enc1..enc4``, ``dec1..dec4``,
    ``out_conv``) are part of the checkpoint format.
    """

    def __init__(self, in_channels=3, out_channels=3, use_dropout=False):
        super().__init__()

        widths = (64, 128, 256, 512, 1024)

        self.in_conv = ConvBlock(in_channels, widths[0], use_dropout)

        # enc1..enc4: each step doubles the channel count.
        for level, (narrow, wide) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"enc{level}", Encoder(narrow, wide, use_dropout))

        # dec1..dec4: mirror of the encoder, halving the channel count.
        for level, (wide, narrow) in enumerate(
            zip(widths[::-1], widths[-2::-1]), start=1
        ):
            setattr(self, f"dec{level}", Decoder(wide, narrow, use_dropout))

        self.out_conv = nn.Sequential(
            nn.Conv2d(widths[0], out_channels, kernel_size=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the encoder, then decode while consuming skips deepest-first."""
        skips = [self.in_conv(x)]
        for down in (self.enc1, self.enc2, self.enc3, self.enc4):
            skips.append(down(skips[-1]))

        features = skips.pop()  # bottleneck output
        for up in (self.dec1, self.dec2, self.dec3, self.dec4):
            features = up(features, skips.pop())

        return self.out_conv(features)

# ResNet Architecture

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + InstanceNorm stages.

    A ReLU separates the two stages and the block output is added to the
    input; channel count and spatial size are unchanged.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv_norm():
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, kernel_size=3),
                nn.InstanceNorm2d(in_features),
            ]

        self.block = nn.Sequential(
            *padded_conv_norm(),
            nn.ReLU(inplace=True),
            *padded_conv_norm(),
        )

    def forward(self, x):
        """Return ``x`` plus the block's transformation of ``x``."""
        residual = self.block(x)
        return x + residual

In [None]:
class ResNet(nn.Module):
    """ResNet-style image-to-image network.

    Layout: 7x7 stem -> two stride-2 downsampling convs -> ``num_resnet_blocks``
    residual blocks -> two stride-2 transposed convs -> 7x7 output conv,
    followed by the requested final activation.

    Raises:
        ValueError: If ``final_activation`` is not ``'tanh'``, ``'sigmoid'``
            or ``'none'``.
    """

    def __init__(self, input_nc=3, output_nc=3, num_resnet_blocks=9, final_activation='tanh'):
        super().__init__()

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: 64 -> 128 -> 256 channels
        channels = 64
        for _ in range(2):
            doubled = channels * 2
            layers.extend([
                nn.Conv2d(channels, doubled, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled

        # Residual trunk
        for _ in range(num_resnet_blocks):
            layers.append(ResidualBlock(channels))

        # Upsampling: 256 -> 128 -> 64 channels
        for _ in range(2):
            halved = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, halved, kernel_size=3,
                                   stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved

        # Output projection
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7),
        ])

        # Optional final activation
        if final_activation == 'none':
            pass
        elif final_activation == 'tanh':
            layers.append(nn.Tanh())
        elif final_activation == 'sigmoid':
            layers.append(nn.Sigmoid())
        else:
            raise ValueError("Final_activation must be 'tanh', 'sigmoid', or 'none'")

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full layer stack to ``x``."""
        return self.model(x)

# Perceptual Loss

In [None]:
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG-16 feature maps of two image batches.

    The first 16 layers of torchvision's pretrained VGG-16 ``features`` stack
    are used, with gradients disabled for their weights, in eval mode and on
    the global ``device``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(pretrained=True).features[:16]
        backbone.requires_grad_(False)  # feature extractor is never trained
        self.vgg = backbone.eval().to(device)

    def forward(self, input, target):
        """Return the mean absolute difference of the two feature maps.

        Both arguments are ``[B, 3, H, W]`` batches. The callers in this
        notebook pass tensors that already went through
        ``preprocess_for_vgg`` (ImageNet mean/std normalisation).
        """
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return F.l1_loss(input_features, target_features)

In [None]:
# Channel statistics used to normalise inputs for the pretrained VGG network.
imagenet_norm = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def preprocess_for_vgg(tensor):
    """Map a [-1, 1] image tensor to [0, 1], then apply ``imagenet_norm``."""
    unit_range = (tensor * 0.5) + 0.5
    return imagenet_norm(unit_range)

In [None]:
# Shared loss modules consumed by total_loss(): pixel-wise MSE and the
# VGG-feature perceptual loss.
mse_loss = nn.MSELoss()
percep_loss = PerceptualLoss()

In [None]:
def total_loss(pred, target):
    """Combine pixel and perceptual error: 0.2 * MSE + 0.8 * perceptual.

    The perceptual term is computed on ``preprocess_for_vgg`` versions of the
    two tensors; the MSE term uses them as given.
    """
    pixel_term = mse_loss(pred, target)
    feature_term = percep_loss(
        preprocess_for_vgg(pred), preprocess_for_vgg(target)
    )
    return 0.2 * pixel_term + 0.8 * feature_term

# CNN & UNet Execution

In [None]:
def evaluate_model(model, dataloader):
    """Return the mean ``total_loss`` of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and no gradients are
    tracked. The sum of per-batch losses is divided by the number of batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for blurred, sharp in dataloader:
            restored = model(blurred.to(device))
            batch_loss = total_loss(restored, sharp.to(device))  # Custom loss
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)


In [None]:
def train_model(model, train_loader, val_loader, optimizer, num_epochs):
    """Train ``model`` with ``total_loss`` and report losses once per epoch.

    Args:
        model: Network mapping a blurred batch to a restored batch.
        train_loader: Iterable of ``(blur, sharp)`` training batches.
        val_loader: Iterable of ``(blur, sharp)`` batches for ``evaluate_model``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Batches are moved to the global ``device``. Mixed precision (autocast plus
    gradient scaling) is used only when that device is a CUDA device; on CPU
    both are disabled so the loop runs in full precision.
    """
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(num_epochs):
        # evaluate_model() leaves the model in eval mode, so switch back to
        # train mode at the start of every epoch.
        model.train()
        running_loss = 0.0

        for blur_imgs, sharp_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            blur_imgs = blur_imgs.to(device, non_blocking=True)
            sharp_imgs = sharp_imgs.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            # Forward pass and loss under automatic mixed precision.
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(blur_imgs)
                loss = total_loss(outputs, sharp_imgs)

            # Scale the loss before backward, step through the scaler, then
            # let it update its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)

        avg_val_loss = evaluate_model(model, val_loader)

        print(
            f"Epoch [{epoch+1}/{num_epochs}] "
            f"- Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
        )


# def train_model(model, train_loader, val_loader, optimizer, num_epochs):
#     model.train()
#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         for blur_imgs, sharp_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
#             blur_imgs = blur_imgs.to(device)
#             sharp_imgs = sharp_imgs.to(device)

#             outputs = model(blur_imgs)
#             loss = total_loss(outputs, sharp_imgs) # Custom loss

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()

#         avg_train_loss = running_loss / len(train_loader)
#         avg_val_loss = evaluate_model(model, val_loader)

#         print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

In [None]:
# Train the plain CNN baseline with Adam (lr=1e-4) for 100 epochs on the
# loaders defined earlier in the notebook.
model_cnn = DeblurCNN().to(device)
optimizer_cnn = torch.optim.Adam(model_cnn.parameters(), lr=1e-4)

train_model(model_cnn, train_loader, val_loader, optimizer_cnn, num_epochs=100)

In [None]:
# Save the CNN weights (state dict only) to the current working directory.
torch.save(model_cnn.state_dict(), "CNN.pth")

In [None]:
# Train the U-Net with the same optimizer settings and schedule as the CNN.
model_unet = UNet().to(device)
optimizer_unet = torch.optim.Adam(model_unet.parameters(), lr=1e-4)

train_model(model_unet, train_loader, val_loader, optimizer_unet, num_epochs=100)

In [None]:
# Save the U-Net weights (state dict only) to the current working directory.
torch.save(model_unet.state_dict(), "UNet.pth")

# CNN GAN Architecture

In [None]:
class CNNGenerator(nn.Module):
    """GAN generator variant of the plain CNN, ending in Tanh.

    Every encoder stage is conv -> ReLU -> regulariser, where the regulariser
    is ``Dropout(0.5)`` when ``use_dropout`` is set and ``Identity``
    otherwise, so layer indices are the same in both configurations.
    """

    def __init__(self, use_dropout=False):
        super().__init__()

        def regulariser():
            return nn.Dropout(0.5) if use_dropout else nn.Identity()

        encoder_layers = []
        channels_in = 3
        for _ in range(3):
            encoder_layers.extend([
                nn.Conv2d(channels_in, 64, kernel_size=5, padding=2),
                nn.ReLU(),
                regulariser(),
            ])
            channels_in = 64
        self.encoder = nn.Sequential(*encoder_layers)

        self.decoder = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a ``(B, 3, H, W)`` batch to a same-sized output in [-1, 1]."""
        return self.decoder(self.encoder(x))

In [None]:
# Discriminator is the same as UNet GAN (PatchGAN)

# Loss is the same as UNet GAN (Adversarial(BCE), L1, Perceptual)

# UNet GAN Architecture

In [None]:
# Generator (U-Net)
class UNetGenerator(nn.Module):
    """U-Net generator for the GAN setup: four down steps, four up steps, Tanh.

    Sub-module names (``in_conv``, ``enc1..enc4``, ``dec1..dec4``,
    ``out_conv``) are part of the checkpoint format.
    """

    def __init__(self, in_channels=3, out_channels=3, use_dropout=False):
        super().__init__()

        channel_plan = (64, 128, 256, 512, 1024)

        self.in_conv = ConvBlock(in_channels, channel_plan[0], use_dropout)

        # enc1..enc4 widen the features while the Encoder pools them.
        for step, (c_in, c_out) in enumerate(
            zip(channel_plan, channel_plan[1:]), start=1
        ):
            setattr(self, f"enc{step}", Encoder(c_in, c_out, use_dropout))

        # dec1..dec4 undo the encoder in reverse channel order.
        for step, (c_in, c_out) in enumerate(
            zip(channel_plan[::-1], channel_plan[-2::-1]), start=1
        ):
            setattr(self, f"dec{step}", Decoder(c_in, c_out, use_dropout))

        self.out_conv = nn.Sequential(
            nn.Conv2d(channel_plan[0], out_channels, kernel_size=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then decode using the stored skip tensors."""
        skip_stack = [self.in_conv(x)]
        for encoder_step in (self.enc1, self.enc2, self.enc3, self.enc4):
            skip_stack.append(encoder_step(skip_stack[-1]))

        current = skip_stack.pop()  # deepest feature map
        for decoder_step in (self.dec1, self.dec2, self.dec3, self.dec4):
            current = decoder_step(current, skip_stack.pop())

        return self.out_conv(current)

In [None]:
class DiscriminatorBlock(nn.Module):
    """4x4 conv -> BatchNorm -> LeakyReLU(0.2), optionally plus Dropout(0.3)."""

    def __init__(self, in_channels, out_channels, stride, use_dropout=False):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        stages += [nn.Dropout(0.3)] if use_dropout else []
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)

In [None]:
# Discriminator (PatchGAN)
class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator returning a map of per-patch probabilities.

    The default ``in_channels=6`` matches the training loop, which feeds the
    channel-wise concatenation of a blurred image and a sharp or generated one.
    """

    def __init__(self, in_channels=6, use_dropout=False):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # layer1..layer3 as (in_channels, out_channels, stride)
        stage_plan = ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        for number, (c_in, c_out, step) in enumerate(stage_plan, start=1):
            setattr(
                self,
                f"layer{number}",
                DiscriminatorBlock(c_in, c_out, stride=step, use_dropout=use_dropout),
            )

        self.final = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sigmoid patch map for the input ``x``."""
        for stage in (self.initial, self.layer1, self.layer2, self.layer3, self.final):
            x = stage(x)
        return x

In [None]:
class GANLoss(nn.Module):
    """Loss terms for the conditional GAN: adversarial BCE, L1 and perceptual."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()
        self.l1 = nn.L1Loss()
        self.percep = PerceptualLoss()

    def discriminator_loss(self, real_preds, fake_preds, smoothing=0.1):
        """Return the discriminator loss for one batch.

        Args:
            real_preds: Discriminator probabilities for real pairs.
            fake_preds: Discriminator probabilities for generated pairs.
            smoothing: One-sided label smoothing; real targets become
                ``1 - smoothing``. Fake targets stay at 0.

        Returns:
            The mean of the BCE on real predictions (target ``1 - smoothing``)
            and the BCE on fake predictions (target 0).
        """
        real_labels = torch.ones_like(real_preds) * (1.0 - smoothing)
        fake_labels = torch.zeros_like(fake_preds)
        # Real predictions are scored against real labels and fake predictions
        # against fake labels. Swapping the two would train the discriminator
        # towards the same target the generator optimises for.
        loss_real = self.bce(real_preds, real_labels)
        loss_fake = self.bce(fake_preds, fake_labels)
        return 0.5 * (loss_real + loss_fake)

    def generator_loss(self, fake_preds, fake_imgs, real_imgs,
                       adv_weight=0.001, pixel_weight=0.5, perceptual_weight=0.5):
        """Return the weighted generator objective.

        Args:
            fake_preds: Discriminator probabilities for the generated pairs.
            fake_imgs: Generator output.
            real_imgs: Ground-truth sharp images.
            adv_weight: Weight of the adversarial BCE term (target: all ones).
            pixel_weight: Weight of the L1 term between the two image batches.
            perceptual_weight: Weight of the perceptual term, computed on
                ``preprocess_for_vgg`` versions of both image batches.
        """
        real_labels = torch.ones_like(fake_preds)
        adv_loss = self.bce(fake_preds, real_labels)
        l1_loss = self.l1(fake_imgs, real_imgs)
        perceptual_loss = self.percep(preprocess_for_vgg(fake_imgs), preprocess_for_vgg(real_imgs))
        return adv_weight * adv_loss + pixel_weight * l1_loss + perceptual_weight * perceptual_loss

# ResNet GAN Architecture

In [None]:
# Generator (ResNet)
class ResNetGenerator(nn.Module):
    """ResNet generator for the GAN setup, ending in Tanh.

    Layout: 7x7 stem -> two stride-2 downsampling convs -> ``num_resnet_blocks``
    residual blocks (each followed by ``Dropout(0.5)`` when ``use_dropout`` is
    set) -> two stride-2 transposed convs -> 7x7 output conv -> Tanh.
    """

    def __init__(self, input_nc=3, output_nc=3, num_resnet_blocks=9, use_dropout=False):
        super().__init__()

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: 64 -> 128 -> 256 channels
        width = 64
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Residual trunk, optionally regularised after every block
        for _ in range(num_resnet_blocks):
            layers.append(ResidualBlock(width))
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        # Upsampling: 256 -> 128 -> 64 channels
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.ConvTranspose2d(width, narrower, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Output projection
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full layer stack to ``x``."""
        return self.model(x)

In [None]:
# Discriminator is the same as UNet GAN (PatchGAN)

# Loss is the same as UNet GAN (Adversarial(BCE), L1, Perceptual)

# UNet GAN Execution

In [None]:
def train_gan(generator, discriminator, train_loader, val_loader, optimizer_G, optimizer_D, loss_fn, device, num_epochs):
    """Alternately train the discriminator and generator, then validate.

    Args:
        generator: Network mapping a blurred batch to a restored batch.
        discriminator: Network scoring the channel-wise concatenation of a
            blurred batch and a sharp (or generated) batch.
        train_loader: Iterable of ``(blur, sharp)`` training batches.
        val_loader: Iterable of ``(blur, sharp)`` validation batches.
        optimizer_G: Optimizer stepped on the generator loss.
        optimizer_D: Optimizer stepped on the discriminator loss.
        loss_fn: Object providing ``discriminator_loss`` and ``generator_loss``.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Prints the mean discriminator loss, mean generator loss and mean
    validation generator loss once per epoch; returns nothing.
    """
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()

        running_d_loss = 0.0
        running_g_loss = 0.0

        for real_blur, real_sharp in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Train"):
            real_blur, real_sharp = real_blur.to(device), real_sharp.to(device)
            fake_sharp = generator(real_blur)

            # --- Discriminator step ---
            # The generated image is detached so this backward pass does not
            # reach the generator.
            real_pair = torch.cat((real_blur, real_sharp), dim=1)
            fake_pair = torch.cat((real_blur, fake_sharp.detach()), dim=1)
            d_real = discriminator(real_pair)
            d_fake = discriminator(fake_pair)
            d_loss = loss_fn.discriminator_loss(d_real, d_fake)

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # --- Generator step ---
            # Re-score the (non-detached) fake pair with the just-updated
            # discriminator so gradients flow back into the generator.
            fake_pair = torch.cat((real_blur, fake_sharp), dim=1)
            d_fake_pred = discriminator(fake_pair)
            g_loss = loss_fn.generator_loss(d_fake_pred, fake_sharp, real_sharp)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            running_d_loss += d_loss.item()
            running_g_loss += g_loss.item()

        avg_d_loss = running_d_loss / len(train_loader)
        avg_g_loss = running_g_loss / len(train_loader)

        # --- Validation: generator loss only, no gradient tracking ---
        generator.eval()
        discriminator.eval()
        val_g_loss = 0.0

        with torch.no_grad():
            for val_blur, val_sharp in val_loader:
                val_blur, val_sharp = val_blur.to(device), val_sharp.to(device)
                val_fake = generator(val_blur)
                val_fake_pair = torch.cat((val_blur, val_fake), dim=1)
                d_fake_pred_val = discriminator(val_fake_pair)
                val_g_loss += loss_fn.generator_loss(d_fake_pred_val, val_fake, val_sharp).item()

        avg_val_g_loss = val_g_loss / len(val_loader)
        print(f"Epoch {epoch+1}: D Loss = {avg_d_loss:.4f}, G Loss = {avg_g_loss:.4f}, Val G Loss = {avg_val_g_loss:.4f}")

In [None]:
# Build the U-Net generator / PatchGAN discriminator pair and train them for
# 100 epochs. Dropout is enabled only in the discriminator.
loss_fn = GANLoss()

generator_unet = UNetGenerator(use_dropout=False).to(device)
discriminator_patch = PatchDiscriminator(use_dropout=True).to(device)

# Both optimizers use Adam with lr=1e-4 and betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(generator_unet.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator_patch.parameters(), lr=1e-4, betas=(0.5, 0.999))

train_gan(generator_unet, discriminator_patch, train_loader, val_loader, optimizer_G, optimizer_D, loss_fn, device, num_epochs=100)

In [None]:
# Save generator and discriminator weights (state dicts only) to the current
# working directory.
torch.save(generator_unet.state_dict(), "GAN_UNet_Generator.pth")
torch.save(discriminator_patch.state_dict(), "GAN_Patch_Discriminator.pth")

# Load models weights

In [None]:
# Fresh, untrained instances of every architecture; weights are loaded into
# them in the next cell.
# NOTE: this rebinds model_cnn, model_unet and generator_unet, so any models
# trained earlier in the session are no longer reachable under these names.
model_cnn = DeblurCNN().to(device)
model_unet = UNet().to(device)
model_resnet = ResNet().to(device)

generator_cnn = CNNGenerator(use_dropout=False).to(device)
generator_unet = UNetGenerator(use_dropout=False).to(device)
generator_resnet = ResNetGenerator(use_dropout=False).to(device)

In [None]:
# Load checkpoints for all six models. The second positional argument of
# torch.load is map_location, so tensors are placed on `device`.
# NOTE(review): these are /kaggle/input paths, whereas the training cells save
# to the working directory — confirm which checkpoints are intended here.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
# CNN
model_cnn.load_state_dict(torch.load("/kaggle/input/deblur-image-phase-1/pytorch/prototype/1/CNN.pth", device))
# UNet
model_unet.load_state_dict(torch.load("/kaggle/input/deblur-image-phase-1/pytorch/prototype/1/UNet.pth", device))
# ResNet
model_resnet.load_state_dict(torch.load("/kaggle/input/deblur-image-phase-1/pytorch/prototype/1/ResNet.pth", device))

# GAN CNN
generator_cnn.load_state_dict(torch.load("/kaggle/input/deblur-image-phase-1/pytorch/prototype/1/GAN_CNN_Generator.pth", device))
# GAN UNet
generator_unet.load_state_dict(torch.load("/kaggle/input/deblur-image-phase-1/pytorch/prototype/1/GAN_UNet_Generator.pth", device))
# GAN ResNet
generator_resnet.load_state_dict(torch.load("/kaggle/input/deblur-image-phase-1/pytorch/prototype/1/GAN_ResNet_Generator.pth", device))

# Models evaluation

In [None]:
def denorm(t):
    """Undo mean-0.5 / std-0.5 normalisation: map [-1, 1] values to [0, 1]."""
    half = 0.5
    return (t * half) + half

def to_gray(img):
    """Collapse the first three channels of a channels-last array to luminance.

    Uses the weights 0.299 / 0.587 / 0.114; any further channels are ignored.
    """
    luma_weights = [0.299, 0.587, 0.114]
    rgb = img[..., :3]
    return np.dot(rgb, luma_weights)

In [None]:
def model_evaluation(model, dataloader, device, max_batches=None, verbose=False):
    """Return the mean PSNR and mean SSIM of ``model`` over ``dataloader``.

    Every prediction and target is passed through ``denorm``, clamped to
    [0, 1] and converted to a channels-last NumPy array. PSNR is computed on
    the colour arrays, SSIM on their ``to_gray`` versions.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(blur, sharp)`` batches.
        device: Device used for the forward pass.
        max_batches: If given, stop after this many batches.
        verbose: Print the metrics of every sample.

    Returns:
        ``(avg_psnr, avg_ssim)`` over all evaluated samples.

    Raises:
        ValueError: If no sample was evaluated.
    """
    model.eval()
    psnr_values = []
    ssim_values = []

    def as_image(tensor):
        # CHW tensor -> HWC array limited to [0, 1]
        return denorm(tensor).clamp(0, 1).permute(1, 2, 0).numpy()

    with torch.no_grad():
        for batch_index, (blur_img, sharp_img) in enumerate(dataloader):
            if max_batches is not None and batch_index >= max_batches:
                break

            restored_batch = model(blur_img.to(device)).cpu()
            reference_batch = sharp_img.to(device).cpu()

            for pred, target in zip(restored_batch, reference_batch):
                pred_np = as_image(pred)
                target_np = as_image(target)

                psnr_val = psnr(target_np, pred_np, data_range=1.0)
                ssim_val = ssim(to_gray(target_np), to_gray(pred_np), data_range=1.0)

                psnr_values.append(psnr_val)
                ssim_values.append(ssim_val)

                if verbose:
                    print(f"Sample {len(psnr_values)}: PSNR={psnr_val:.2f}, SSIM={ssim_val:.4f}")

    sample_count = len(psnr_values)
    if sample_count == 0:
        raise ValueError("No samples were evaluated")

    return sum(psnr_values) / sample_count, sum(ssim_values) / sample_count

In [None]:
# Evaluate every model on the full validation loader (no batch limit). The
# psnr_*/ssim_* names are reused by the comparison plot in the next cell.
psnr_cnn, ssim_cnn = model_evaluation(model_cnn, val_loader, device, max_batches=None)
psnr_unet, ssim_unet = model_evaluation(model_unet, val_loader, device, max_batches=None)
psnr_resnet, ssim_resnet = model_evaluation(model_resnet, val_loader, device, max_batches=None)

psnr_gan_cnn, ssim_gan_cnn = model_evaluation(generator_cnn, val_loader, device, max_batches=None)
psnr_gan_unet, ssim_gan_unet = model_evaluation(generator_unet, val_loader, device, max_batches=None)
psnr_gan_resnet, ssim_gan_resnet = model_evaluation(generator_resnet, val_loader, device, max_batches=None)

# Fixed-width table: 15-character model column, two 12-character metric columns.
print(f"{'Model':<15}{'PSNR (dB)':>12}{'SSIM':>12}")
print("-" * 39)

print(f"{'CNN':<15}{psnr_cnn:>12.2f}{ssim_cnn:>12.4f}")
print(f"{'UNet':<15}{psnr_unet:>12.2f}{ssim_unet:>12.4f}")
print(f"{'ResNet':<15}{psnr_resnet:>12.2f}{ssim_resnet:>12.4f}")

print(f"{'GAN CNN':<15}{psnr_gan_cnn:>12.2f}{ssim_gan_cnn:>12.4f}")
print(f"{'GAN UNet':<15}{psnr_gan_unet:>12.2f}{ssim_gan_unet:>12.4f}")
print(f"{'GAN ResNet':<15}{psnr_gan_resnet:>12.2f}{ssim_gan_resnet:>12.4f}")

In [None]:
# Collect the averaged metrics per model for plotting.
model_scores = {
    "CNN": {"psnr": psnr_cnn, "ssim": ssim_cnn},
    "UNet": {"psnr": psnr_unet, "ssim": ssim_unet},
    "ResNet": {"psnr": psnr_resnet, "ssim": ssim_resnet},
    "GAN_CNN": {"psnr": psnr_gan_cnn, "ssim": ssim_gan_cnn},
    "GAN_UNet": {"psnr": psnr_gan_unet, "ssim": ssim_gan_unet},
    "GAN_ResNet": {"psnr": psnr_gan_resnet, "ssim": ssim_gan_resnet}
}

# One tab10 colour per model.
colors = cm.tab10(np.linspace(0, 1, len(model_scores)))

plt.figure(figsize=(8, 6))
texts = []

# Scatter each model at (PSNR, SSIM) and keep its label for adjust_text.
for (model, scores), color in zip(model_scores.items(), colors):
    plt.scatter(scores["psnr"], scores["ssim"], color=color, s=100)
    texts.append(
        plt.text(scores["psnr"], scores["ssim"], model, fontsize=11)
    )

# Reposition the collected labels, drawing thin grey arrows for them.
adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.5))

plt.xlabel("PSNR (dB)")
plt.ylabel("SSIM")
plt.title("Model Comparison: PSNR vs SSIM")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
def get_visuals(index, model, dataset=None):
    """Return displayable blurred, predicted and sharp images for one sample.

    Args:
        index: Position of the sample in ``dataset``.
        model: Network used to produce the prediction.
        dataset: Indexable source of ``(blur, sharp)`` tensor pairs. Defaults
            to ``val_loader.dataset``; the previously referenced
            ``full_dataset`` is not defined anywhere in this notebook.

    Returns:
        ``(blur_np, pred_np, sharp_np)``: channels-last NumPy arrays obtained
        by applying ``denorm`` and clamping to [0, 1].
    """
    if dataset is None:
        dataset = val_loader.dataset

    blur_img, sharp_img = dataset[index]
    with torch.no_grad():
        # Add a batch dimension for the forward pass, then drop it again.
        input_tensor = blur_img.unsqueeze(0).to(device)
        output_tensor = model(input_tensor).squeeze(0).cpu()

    blur_np = denorm(blur_img).permute(1, 2, 0).clamp(0, 1).numpy()
    pred_np = denorm(output_tensor).permute(1, 2, 0).clamp(0, 1).numpy()
    sharp_np = denorm(sharp_img).permute(1, 2, 0).clamp(0, 1).numpy()

    return blur_np, pred_np, sharp_np

In [None]:
def evaluate_ssim(models, dataloader, device, amount):
    """Collect per-sample grayscale SSIM scores for several models.

    For every model, samples ``0 .. amount - 1`` are fetched through
    ``get_visuals`` and the prediction is compared with the ground truth.

    Args:
        models: mapping of model name -> model (each is put in eval mode).
        dataloader: accepted for interface symmetry; not read by this function.
        device: accepted for interface symmetry; not read by this function.
        amount: number of leading sample indices to score.

    Returns:
        dict mapping each model name to its list of SSIM scores.
    """
    per_model_scores = {name: [] for name in models}

    for name, net in models.items():
        net.eval()

        with torch.no_grad():
            collected = []
            for sample_idx in range(amount):
                _, pred_np, sharp_np = get_visuals(sample_idx, net)
                pred_gray = to_gray(pred_np)
                sharp_gray = to_gray(sharp_np)
                collected.append(ssim(pred_gray, sharp_gray, data_range=1.0))

        per_model_scores[name] = collected

    return per_model_scores

In [None]:
# Name -> model mapping for every network compared in this notebook.
models = dict(
    CNN=model_cnn,
    UNet=model_unet,
    ResNet=model_resnet,
    GAN_CNN=generator_cnn,
    GAN_UNet=generator_unet,
    GAN_ResNet=generator_resnet,
)

# Per-sample SSIM for the first 100 samples of each model.
model_ssim_scores = evaluate_ssim(models, val_loader, device, amount=100)

# Long-format table: one (Model, SSIM) row per scored sample.
df = pd.DataFrame(
    [
        (model_label, sample_score)
        for model_label, sample_scores in model_ssim_scores.items()
        for sample_score in sample_scores
    ],
    columns=["Model", "SSIM"],
)

plt.figure(figsize=(10, 6))
sns.violinplot(x="Model", y="SSIM", data=df, inner="box", palette="Set2")
plt.title("SSIM Distribution per Model")
plt.grid(True)
plt.xticks(rotation=15)
plt.tight_layout()
plt.show()


In [None]:
def visualize_best_worst(models, dataloader, device, max_batches):
    """Plot each model's worst and best prediction among the first samples.

    For every model, samples ``0 .. max_batches - 1`` are fetched through
    ``get_visuals`` and scored with grayscale SSIM against the ground truth.
    The lowest- and highest-scoring samples are shown as two rows of
    (input, prediction, ground truth) panels per model.

    Args:
        models: mapping of model name -> model (each is put in eval mode).
        dataloader: accepted for interface symmetry; not read by this function.
        device: accepted for interface symmetry; not read by this function.
        max_batches: number of leading sample indices to score per model.

    Raises:
        ValueError: if a model has no scored samples (``max_batches < 1``).
    """
    model_best_worst_scores = {}

    for model_name, model in models.items():
        model.eval()

        # Running extremes, each stored as (blur, pred, sharp, ssim). Keeping the
        # images avoids a second forward pass for the selected samples.
        worst = None
        best = None

        with torch.no_grad():
            for i in range(max_batches):
                blur_np, pred_np, sharp_np = get_visuals(i, model)

                pred_gray = to_gray(pred_np)
                sharp_gray = to_gray(sharp_np)
                score = ssim(pred_gray, sharp_gray, data_range=1.0)

                candidate = (blur_np, pred_np, sharp_np, score)
                # Strict comparisons keep the first occurrence on ties,
                # matching np.argmin / np.argmax.
                if worst is None or score < worst[3]:
                    worst = candidate
                if best is None or score > best[3]:
                    best = candidate

        if worst is None:
            raise ValueError("max_batches must be at least 1 to select best/worst samples")

        model_best_worst_scores[model_name] = {'worst': worst, 'best': best}

    n_models = len(models)
    plt.figure(figsize=(12, 4 * n_models))

    for i, (model_name, scores) in enumerate(model_best_worst_scores.items()):
        # Two consecutive grid rows per model: worst case first, then best case.
        for case_row, (case_key, case_label) in enumerate((('worst', 'Worst'), ('best', 'Best'))):
            blur_np, pred_np, sharp_np, score = scores[case_key]
            first_cell = (2 * i + case_row) * 3 + 1  # subplot indices are 1-based

            panels = (
                (blur_np, f"{model_name} {case_label} Input\nSSIM: {score:.4f}"),
                (pred_np, "Prediction"),
                (sharp_np, "Ground Truth"),
            )
            for offset, (image, title) in enumerate(panels):
                plt.subplot(n_models * 2, 3, first_cell + offset)
                plt.title(title)
                plt.imshow(image)
                plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Name -> model mapping, in the order the rows should appear in the figure.
models = dict(zip(
    ("CNN", "UNet", "ResNet", "GAN_CNN", "GAN_UNet", "GAN_ResNet"),
    (model_cnn, model_unet, model_resnet, generator_cnn, generator_unet, generator_resnet),
))

# Score the first 10 samples per model and show the extremes.
visualize_best_worst(models, val_loader, device, max_batches=10)

In [None]:
def select_random_sample(dataloader):
    """Return one uniformly chosen ``(blur, sharp)`` item of ``dataloader.dataset``."""
    dataset = dataloader.dataset
    chosen = random.randint(0, len(dataset) - 1)
    blur_img, sharp_img = dataset[chosen]
    return blur_img, sharp_img

In [None]:
def visualize_predictions(models, dataloader, device, model_names):
    """Show one random sample next to every model's prediction for it.

    Args:
        models: iterable of models; each is put in eval mode and run once.
        dataloader: loader whose dataset the random sample is drawn from.
        device: device the blurred input is moved to before inference.
        model_names: display names, in the same order as ``models``.
    """
    def _to_display(image_tensor):
        # CHW tensor -> HWC array limited to the [0, 1] range
        return denorm(image_tensor).permute(1, 2, 0).clamp(0, 1).numpy()

    blur_img, sharp_img = select_random_sample(dataloader)
    net_input = blur_img.unsqueeze(0).to(device)

    predictions = []
    for net in models:
        net.eval()
        with torch.no_grad():
            raw_output = net(net_input)
            predictions.append(_to_display(raw_output.squeeze(0).cpu()))

    images = [_to_display(blur_img), _to_display(sharp_img), *predictions]
    titles = ["Blurred Input", "Ground Truth", *(f"{name} Prediction" for name in model_names)]

    # Four panels per row; unused trailing cells are left blank.
    cols = 4
    n_images = len(images)
    rows = -(-n_images // cols)  # ceiling division

    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))

    for slot, ax in enumerate(axes.flat):
        if slot < n_images:
            ax.imshow(images[slot])
            ax.set_title(titles[slot])
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Parallel lists: model_names[i] is the caption for models[i].
model_names = ["CNN", "UNet", "ResNet", "GAN_CNN", "GAN_UNet", "GAN_ResNet"]
models = [
    model_cnn,
    model_unet,
    model_resnet,
    generator_cnn,
    generator_unet,
    generator_resnet,
]
visualize_predictions(models, val_loader, device, model_names)

# END OF PHASE 1
# ========================

# Dataset setup change for Diffusion

In [None]:
class CropPatchDataset(Dataset):
    """Dataset exposing every image pair as a 3x3 grid of aligned patches.

    Item ``idx`` is patch ``idx % 9`` (row-major) of pair ``idx // 9``. The same
    crop box is applied to the blurred and the sharp image, and each crop is
    resized to ``patch_size`` x ``patch_size`` and normalised with mean/std 0.5.
    """

    def __init__(self, image_pairs, patch_size=256):
        self.image_pairs = image_pairs
        self.patch_size = patch_size
        self.n_patches = 9  # 3x3 grid

        target_size = (patch_size, patch_size)
        self.transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ])

    def __len__(self):
        # every pair contributes n_patches items
        return len(self.image_pairs) * self.n_patches

    def __getitem__(self, idx):
        img_idx, patch_idx = divmod(idx, self.n_patches)

        blur_path, sharp_path = self.image_pairs[img_idx]
        blur = Image.open(blur_path).convert("RGB")
        sharp = Image.open(sharp_path).convert("RGB")

        # Patch size is derived from the blurred image; remainder pixels on the
        # right/bottom edge are dropped by the integer division.
        width, height = blur.size
        grid_size = int(self.n_patches ** 0.5)
        patch_w, patch_h = width // grid_size, height // grid_size

        grid_row, grid_col = divmod(patch_idx, grid_size)
        left, upper = grid_col * patch_w, grid_row * patch_h
        box = (left, upper, left + patch_w, upper + patch_h)

        blur_crop = blur.crop(box)
        sharp_crop = sharp.crop(box)

        return {
            "condition": self.transform(blur_crop),
            "target": self.transform(sharp_crop),
        }

In [None]:
class FullImageDataset(Dataset):
    """Dataset yielding whole blurred/sharp image pairs resized to 256x256.

    Each item is a dict with the blurred image under ``"condition"`` and the
    sharp image under ``"target"``, both normalised with mean/std 0.5.
    """

    def __init__(self, image_pairs):
        self.image_pairs = image_pairs
        pipeline = [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        # one item per image pair
        return len(self.image_pairs)

    def __getitem__(self, idx):
        blur_path, sharp_path = self.image_pairs[idx]
        blur_rgb = Image.open(blur_path).convert("RGB")
        sharp_rgb = Image.open(sharp_path).convert("RGB")

        condition = self.transform(blur_rgb)
        target = self.transform(sharp_rgb)
        return {"condition": condition, "target": target}

In [None]:
def to_gray_np(img):
    """Convert an image to a grayscale array using 0.299/0.587/0.114 weights.

    ``img`` is converted to RGB, scaled by 1/255, and its three channels are
    combined with the weights above.
    """
    luma_weights = np.array([0.299, 0.587, 0.114])
    rgb = np.array(img.convert("RGB")) / 255.0
    return np.dot(rgb[..., :3], luma_weights)

# Keep only the pairs whose blurred and sharp images differ enough, i.e. whose
# grayscale SSIM is at most 0.8.
filtered_pairs = []

for pair in tqdm(image_pairs, desc='Computing SSIM'):
    blur_path, sharp_path = pair

    blur_gray = to_gray_np(Image.open(blur_path))
    sharp_gray = to_gray_np(Image.open(sharp_path))

    if ssim(blur_gray, sharp_gray, data_range=1.0) <= 0.8:
        filtered_pairs.append((blur_path, sharp_path))

In [None]:
# === Cropped Patches Image Dataset ===
# 80/20 random split over individual patches.
# NOTE(review): the split is per patch, so patches cut from one image can land
# in both the training and the validation subset.
patch_dataset = CropPatchDataset(image_pairs=filtered_pairs)
total_patches = len(patch_dataset)
train_size = int(0.8 * total_patches)
val_size = total_patches - train_size
train_dataset, val_dataset = random_split(patch_dataset, [train_size, val_size])

BATCH_SIZE = 2
NUM_WORKERS = 2
loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# === Dataset Statistics ===
for stat_line in (
    f"Original filtered pairs: {len(filtered_pairs)}",
    f"Total patches in dataset: {len(patch_dataset)}",
    f"Training set size: {len(train_dataset)}",
    f"Validation set size: {len(val_dataset)}",
):
    print(stat_line)

# Noise Scheduler

In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas derived from a squared-cosine alpha-bar curve.

    The cumulative curve is evaluated at ``timesteps + 1`` evenly spaced points,
    normalised by its first value, turned into per-step betas, and the result is
    limited to the range [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi / 2
    alpha_bar = phase.cos().pow(2)
    alpha_bar = alpha_bar / alpha_bar[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
    return step_betas.clamp(0.0001, 0.9999)

In [None]:
timesteps = 1000

# Per-step noise levels and the derived quantities, all on `device`.
betas = cosine_beta_schedule(timesteps).to(device)
alphas = 1. - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Shifted copy with a leading 1 so index t holds the cumulative product up to t-1.
alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# Forward process

In [None]:
def q_sample(x_start, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    """Mix ``x_start`` with ``noise`` using the per-sample coefficients at ``t``.

    Both lookup tables are indexed by ``t`` and reshaped to ``[-1, 1, 1, 1]`` so
    one coefficient is broadcast over each sample.
    """
    def _coeff(table):
        return table[t].view(-1, 1, 1, 1)

    signal_part = _coeff(sqrt_alphas_cumprod) * x_start
    noise_part = _coeff(sqrt_one_minus_alphas_cumprod) * noise
    return signal_part + noise_part

In [None]:
def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    The first half of each row holds sines and the second half cosines of the
    timestep scaled by geometrically spaced frequencies (from 1 down to 1/10000).

    Args:
        timesteps: 1-D tensor of timesteps, one per batch element.
        embedding_dim: width of the returned embedding.

    Returns:
        Float tensor of shape ``[B, embedding_dim]`` on ``timesteps.device``.
        For an odd ``embedding_dim`` the last column is zero padding.
    """
    half_dim = embedding_dim // 2
    # max(..., 1) avoids a division by zero when half_dim == 1 (embedding_dim 2 or 3);
    # the single frequency is then exp(0) == 1.
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -scale)
    angles = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # sin/cos halves only cover 2 * half_dim columns; pad to the requested width
        emb = F.pad(emb, (0, 1))
    return emb  # [B, embedding_dim]

# Diffusion UNet (Conditional UNet)

In [None]:
class ConvBlock(nn.Module):
    """Two Conv-BatchNorm-ReLU stages plus an additive timestep projection.

    With ``use_dropout=True`` a ``Dropout(0.2)`` layer sits between the two
    stages (index 3 of ``self.conv``).
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, out_channels)
        )

        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        between = [nn.Dropout(0.2)] if use_dropout else []

        self.conv = nn.Sequential(*first_stage, *between, *second_stage)

    def forward(self, x, t_emb):
        """
        x: feature map [B, C, H, W]
        t_emb: timestep embedding [B, time_emb_dim]
        """
        features = self.conv(x)
        # Project the timestep embedding to one value per channel and broadcast
        # it over the spatial dimensions.
        time_bias = self.time_mlp(t_emb)[:, :, None, None]
        return features + time_bias

In [None]:
class Encoder(nn.Module):
    """Downsampling stage: 2x2 max-pooling followed by a ``ConvBlock``."""

    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = ConvBlock(in_channels, out_channels, time_emb_dim, use_dropout)

    def forward(self, x, t_emb):
        # pool first, then convolve with the timestep embedding applied
        return self.conv(self.pool(x), t_emb)

In [None]:
class Decoder(nn.Module):
    """Upsampling stage: transposed conv, skip concatenation, then a ``ConvBlock``."""

    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.conv = ConvBlock(in_channels, out_channels, time_emb_dim, use_dropout)

    def forward(self, x1, x2, t_emb):
        """
        x1: features to upsample
        x2: skip-connection features, concatenated in front of the upsampled x1
        t_emb: timestep embedding passed on to the conv block
        """
        upsampled = self.upconv(x1)
        if upsampled.shape != x2.shape:
            # bring the upsampled map to the skip connection's spatial size
            upsampled = F.interpolate(upsampled, size=x2.shape[2:], mode="bilinear", align_corners=False)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged, t_emb)

In [None]:
class ConditionalUNet(nn.Module):
    """U-Net that takes a noisy image, a timestep and a conditioning image.

    The noisy image and the condition are concatenated along the channel axis,
    passed through four encoder and four decoder stages with skip connections,
    and projected to ``out_channels`` with a 1x1 convolution. Every stage
    receives the processed timestep embedding.
    """

    def __init__(self, in_channels=3, condition_channels=3, out_channels=3, time_emb_dim=256, use_dropout=False):
        """
        in_channels: channels of noisy image
        condition_channels: channels of blurry image
        out_channels: output channels (predicting noise)
        time_emb_dim: width of the timestep embedding (defaults to 256; a
            parameter without a default may not follow defaulted ones)
        use_dropout: forwarded to every conv block
        """
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        total_in_channels = in_channels + condition_channels

        self.in_conv = ConvBlock(total_in_channels, 64, time_emb_dim, use_dropout)

        self.enc1 = Encoder(64, 128, time_emb_dim, use_dropout)
        self.enc2 = Encoder(128, 256, time_emb_dim, use_dropout)
        self.enc3 = Encoder(256, 512, time_emb_dim, use_dropout)
        self.enc4 = Encoder(512, 1024, time_emb_dim, use_dropout)

        self.dec1 = Decoder(1024, 512, time_emb_dim, use_dropout)
        self.dec2 = Decoder(512, 256, time_emb_dim, use_dropout)
        self.dec3 = Decoder(256, 128, time_emb_dim, use_dropout)
        self.dec4 = Decoder(128, 64, time_emb_dim, use_dropout)

        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x_noisy, t, condition):
        """
        x_noisy: [B, in_channels, H, W]
        t: [B] integer timesteps
        condition: [B, condition_channels, H, W] (blurred input)
        """
        x = torch.cat([x_noisy, condition], dim=1)

        # The embedding width is read back from the first linear layer.
        t_emb = get_timestep_embedding(t, self.time_mlp[0].in_features)
        t_emb = self.time_mlp(t_emb)

        # Pass through UNet with timestep embedding
        x1 = self.in_conv(x, t_emb)
        x2 = self.enc1(x1, t_emb)
        x3 = self.enc2(x2, t_emb)
        x4 = self.enc3(x3, t_emb)
        x5 = self.enc4(x4, t_emb)

        # Decoder stages consume the encoder outputs in reverse order as skips.
        x = self.dec1(x5, x4, t_emb)
        x = self.dec2(x, x3, t_emb)
        x = self.dec3(x, x2, t_emb)
        x = self.dec4(x, x1, t_emb)

        return self.out_conv(x)

In [None]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block with an additive timestep projection.

    Layout: GroupNorm -> ReLU -> Conv, GroupNorm -> ReLU -> (Dropout) -> Conv,
    plus the projected timestep embedding, plus a shortcut of the input
    (a 1x1 conv when the channel count changes, identity otherwise).
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        # NOTE: attribute order is kept as-is because it fixes the parameter order.
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)

        # at most 8 groups, fewer when the layer has fewer channels
        self.norm1 = nn.GroupNorm(min(8, in_channels), in_channels)
        self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)

        if use_dropout:
            self.dropout = nn.Dropout(0.2)
        else:
            self.dropout = nn.Identity()

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t_emb):
        # First stage: norm -> activation -> conv
        h = self.conv1(self.relu(self.norm1(x)))

        # Second stage: norm -> activation -> dropout -> conv
        h = self.conv2(self.dropout(self.relu(self.norm2(h))))

        # Per-channel timestep bias, broadcast over the spatial dimensions
        h = h + self.time_mlp(t_emb)[:, :, None, None]

        # Residual connection
        return h + self.shortcut(x)

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels // 8, 1)  # Fixed to match query/key
        self.proj_out = nn.Conv2d(in_channels // 8, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # B x HW x C//8
        key = self.key(x).view(B, -1, H * W)  # B x C//8 x HW
        value = self.value(x).view(B, -1, H * W)  # B x C//8 x HW

        attn = torch.bmm(query, key)  # B x HW x HW
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(value, attn.permute(0, 2, 1))  # B x C//8 x HW
        out = out.view(B, C // 8, H, W)
        out = self.proj_out(out)
        out = self.gamma * out + x
        return out

In [None]:
class Encoder(nn.Module):
    """Downsampling stage: 2x2 max-pooling followed by a ``ResidualBlock``."""

    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = ResidualBlock(in_channels, out_channels, time_emb_dim, use_dropout)

    def forward(self, x, t_emb):
        # pool first, then apply the residual block with the timestep embedding
        return self.block(self.pool(x), t_emb)

In [None]:
class Decoder(nn.Module):
    """Upsampling stage: transposed conv, skip concatenation, then a ``ResidualBlock``."""

    def __init__(self, in_channels, out_channels, time_emb_dim, use_dropout=False):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.block = ResidualBlock(in_channels, out_channels, time_emb_dim, use_dropout)

    def forward(self, x1, x2, t_emb):
        """
        x1: features to upsample
        x2: skip-connection features, concatenated in front of the upsampled x1
        t_emb: timestep embedding passed on to the residual block
        """
        upsampled = self.upconv(x1)
        if upsampled.shape != x2.shape:
            # bring the upsampled map to the skip connection's spatial size
            upsampled = F.interpolate(upsampled, size=x2.shape[2:], mode="bilinear", align_corners=False)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.block(merged, t_emb)

In [None]:
class ConditionalUNet(nn.Module):
    """Residual U-Net with self-attention, conditioned on a second image.

    The noisy image and the condition are concatenated along the channel axis.
    Self-attention is applied after encoder stage 3, at the bottleneck, and
    after decoder stages 2 and 3. All Conv2d/Linear layers are re-initialised
    with Kaiming-normal weights and zero biases.
    """

    def __init__(self, in_channels=3, condition_channels=3, out_channels=3,
                 time_emb_dim=256, use_dropout=False):
        """
        in_channels: channels of noisy image
        condition_channels: channels of blurry image
        out_channels: output channels (predicting noise)
        """
        super().__init__()
        # NOTE: submodules are registered in a fixed order; keep it stable so
        # parameter ordering does not change.
        self.time_emb_dim = time_emb_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        stem_channels = in_channels + condition_channels
        self.in_conv = ResidualBlock(stem_channels, 64, time_emb_dim, use_dropout)

        # Encoder path (attention at enc2 is currently disabled).
        self.enc1 = Encoder(64, 128, time_emb_dim, use_dropout)
        self.enc2 = Encoder(128, 256, time_emb_dim, use_dropout)
        self.enc3 = Encoder(256, 512, time_emb_dim, use_dropout)
        self.attn_enc3 = SelfAttention(512)  # attention after enc3
        self.enc4 = Encoder(512, 1024, time_emb_dim, use_dropout)
        self.attn_bottleneck = SelfAttention(1024)  # attention at the bottleneck

        # Decoder path.
        self.dec1 = Decoder(1024, 512, time_emb_dim, use_dropout)
        self.dec2 = Decoder(512, 256, time_emb_dim, use_dropout)
        self.attn_dec2 = SelfAttention(256)  # attention after dec2
        self.dec3 = Decoder(256, 128, time_emb_dim, use_dropout)
        self.attn_dec3 = SelfAttention(128)  # attention after dec3
        self.dec4 = Decoder(128, 64, time_emb_dim, use_dropout)

        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Kaiming-normal weights and zero bias for Conv2d and Linear modules."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x_noisy, t, condition):
        """
        x_noisy: [B, in_channels, H, W]
        t: [B] integer timesteps
        condition: [B, condition_channels, H, W] (blurred input)
        """
        if condition is None:
            # neutral placeholder when no conditioning image is given
            condition = torch.zeros_like(x_noisy)

        stem_input = torch.cat([x_noisy, condition], dim=1)
        t_emb = self.time_mlp(get_timestep_embedding(t, self.time_emb_dim))

        # Encoder: keep every stage output for the skip connections.
        skip1 = self.in_conv(stem_input, t_emb)
        skip2 = self.enc1(skip1, t_emb)
        skip3 = self.enc2(skip2, t_emb)
        skip4 = self.attn_enc3(self.enc3(skip3, t_emb))
        bottleneck = self.attn_bottleneck(self.enc4(skip4, t_emb))

        # Decoder: consume the skips in reverse order.
        up = self.dec1(bottleneck, skip4, t_emb)
        up = self.attn_dec2(self.dec2(up, skip3, t_emb))
        up = self.attn_dec3(self.dec3(up, skip2, t_emb))
        up = self.dec4(up, skip1, t_emb)

        return self.out_conv(up)

In [None]:
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` holds the averaged weights keyed by parameter name. Call
    ``update()`` after each optimizer step, ``apply_shadow()`` to load the
    averaged weights into the model, and ``restore()`` to put the original
    weights back.
    """

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}   # parameter name -> averaged tensor
        self.backup = {}   # parameter name -> weights saved by apply_shadow()
        self.register()

    def register(self):
        """Initialise the shadow weights from the model's current parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend the current parameters into the shadow weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * param.data

    def apply_shadow(self):
        """Back up the current weights and load the shadow weights into the model."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                # Copy values instead of rebinding ``param.data``: rebinding would
                # make the parameter share storage with the shadow tensor, so any
                # in-place change to the model would silently corrupt the average.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Put back the weights saved by ``apply_shadow()`` and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

# Loss functions

In [None]:
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG-16 feature maps of two image batches.

    Uses the first 16 layers of ``vgg16(...).features`` in eval mode with
    gradients disabled for its weights, placed on the module-level ``device``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(pretrained=True).features[:16]
        backbone.requires_grad_(False)  # feature extractor is never trained
        self.vgg = backbone.eval().to(device)

    def forward(self, input, target):
        assert input.shape == target.shape, "Perceptual input/target shape mismatch"
        # input and target: [B, 3, H, W]; callers are expected to have applied
        # any VGG-specific preprocessing already.
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return F.l1_loss(input_features, target_features)

In [None]:
def preprocess_for_vgg(tensor):
    """Map a [-1, 1] image batch to ImageNet-normalised values.

    The input is shifted to [0, 1], clamped to that range, and standardised
    per channel with the ImageNet mean and standard deviation.
    """
    unit_range = ((tensor + 1) * 0.5).clamp(0, 1)

    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device)

    # reshape to [1, 3, 1, 1] so the statistics broadcast over batch and pixels
    return (unit_range - imagenet_mean.view(1, 3, 1, 1)) / imagenet_std.view(1, 3, 1, 1)

In [None]:
# Shared loss objects used by the diffusion training loop:
# VGG feature loss, noise-prediction MSE, and pixel-wise L1.
percep_loss_fn, mse_loss_fn, l1_loss_fn = (
    PerceptualLoss(),
    nn.MSELoss(),
    nn.L1Loss(),
)

# UNet Diffusion Execution

In [None]:
def train_diff(model, train_loader, optimizer, scheduler, scheduler_params,
              device, epochs, log_interval=10):
    """Train a conditional noise-prediction network with mixed precision.

    Each step draws a random timestep per sample, noises the sharp target with
    ``q_sample``, asks ``model`` to predict that noise given the blurred
    condition, and minimises ``MSE(noise) + 0.4 * L1(x0) + 0.3 * perceptual(x0)``,
    where ``x0`` is the clean image recovered from the predicted noise.

    Args:
        model: network called as ``model(x_t, t, condition)``.
        train_loader: iterable of dicts with ``'condition'`` and ``'target'`` tensors.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        scheduler_params: dict with ``'betas'``, ``'sqrt_alphas_cumprod'`` and
            ``'sqrt_one_minus_alphas_cumprod'`` tensors indexed by timestep.
        device: device the model and the batches are moved to.
        epochs: number of passes over ``train_loader``.
        log_interval: the progress-bar loss is refreshed every this many steps.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.

    Note:
        Uses the module-level ``ema``, ``mse_loss_fn``, ``l1_loss_fn`` and
        ``percep_loss_fn`` objects, which must exist before this is called.
    """
    model.to(device)
    betas = scheduler_params['betas']
    sqrt_alphas_cumprod = scheduler_params['sqrt_alphas_cumprod']
    sqrt_one_minus_alphas_cumprod = scheduler_params['sqrt_one_minus_alphas_cumprod']

    scaler = torch.cuda.amp.GradScaler()
    T = len(betas)  # number of diffusion timesteps

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        steps_done = 0

        loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)
        for step, batch in enumerate(loop, 1):
            x_blur = batch['condition'].to(device)
            x_sharp = batch['target'].to(device)

            B = x_sharp.size(0)
            t = torch.randint(0, T, (B,), device=device)
            noise = torch.randn_like(x_sharp)
            x_t = q_sample(x_sharp, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)

            optimizer.zero_grad()

            # Autocast wraps only the forward pass and the loss computation; the
            # backward pass and optimizer step run outside it, as PyTorch AMP recommends.
            with autocast():
                pred_noise = model(x_t, t, x_blur)

                mse_loss = mse_loss_fn(pred_noise, noise)

                # Recover x0 from the noisy sample and the predicted noise
                alpha_cumprod_t = sqrt_alphas_cumprod[t] ** 2
                alpha_cumprod_t = alpha_cumprod_t.view(-1, 1, 1, 1)
                eps = 1e-8
                x0_pred = (x_t - torch.sqrt(1 - alpha_cumprod_t + eps) * pred_noise) / torch.sqrt(alpha_cumprod_t)
                x0_pred = x0_pred.clamp(-1.0, 1.0)

                # Preprocess for perceptual
                x0_pred_vgg = preprocess_for_vgg(x0_pred)
                x_sharp_vgg = preprocess_for_vgg(x_sharp)

                # Perceptual and L1 Loss
                percep = percep_loss_fn(x0_pred_vgg, x_sharp_vgg)
                l1 = l1_loss_fn(x0_pred, x_sharp)

                # Total Loss
                loss = mse_loss + 0.4 * l1 + 0.3 * percep

            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            ema.update()

            loss_value = loss.item()  # single host sync per step
            epoch_loss += loss_value
            steps_done = step

            if step % log_interval == 0:
                loop.set_postfix(loss=loss_value)

        if steps_done == 0:
            raise ValueError("train_loader produced no batches")

        scheduler.step()

        # End-of-epoch diagnostic on the last batch: statistics of the predicted noise.
        with torch.no_grad():
            t_log = torch.randint(0, T, (B,), device=device).long()
            noise = torch.randn_like(x_sharp)
            x_t_log = q_sample(x_sharp, t_log, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
            pred_noise_log = model(x_t_log, t_log, x_blur)
            print(f"[Step {step}] pred_noise mean/std: {pred_noise_log.mean().item():.4f} / {pred_noise_log.std().item():.4f}")

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch}/{epochs}] - Avg Loss: {avg_loss:.6f}")

In [None]:
def count_parameters(model):
    """Print the total and the trainable parameter counts of ``model``."""
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, trainable in sizes if trainable)
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")

In [None]:
# Fresh diffusion model with its EMA tracker, optimizer and LR schedule
# (learning rate halves every 20 scheduler steps).
model_diffunet = ConditionalUNet(
    in_channels=3,
    condition_channels=3,
    out_channels=3,
    time_emb_dim=256,
).to(device)
ema = EMA(model_diffunet)
optimizer = torch.optim.AdamW(model_diffunet.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

# Noise-schedule tensors needed by the training loop.
scheduler_params = dict(
    betas=betas,
    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
)

count_parameters(model_diffunet)

print("\n[Phase 1] Training on patch dataset...")
train_diff(
    model_diffunet,
    train_loader,
    optimizer,
    scheduler,
    scheduler_params,
    device,
    epochs=20,
    log_interval=10,
)

In [None]:
# Persist everything needed to resume after Phase 1: model, optimizer and
# scheduler states plus the EMA shadow weights.
phase1_checkpoint_path = "phase1_checkpoint.pth"
phase1_state = dict(
    model_state=model_diffunet.state_dict(),
    optimizer_state=optimizer.state_dict(),
    scheduler_state=scheduler.state_dict(),
    ema_shadow=ema.shadow,
)
torch.save(phase1_state, phase1_checkpoint_path)
print(f"Checkpoint saved to '{phase1_checkpoint_path}'")

In [None]:
# Rebuild the model, EMA tracker, optimizer and scheduler with the same
# configuration as Phase 1 so the saved states can be loaded into them.
model_diffunet = ConditionalUNet(in_channels=3, condition_channels=3, out_channels=3, time_emb_dim=256).to(device)
ema = EMA(model_diffunet)
optimizer = torch.optim.AdamW(model_diffunet.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

scheduler_params = {'betas': betas, 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 
                    'sqrt_one_minus_alphas_cumprod': sqrt_one_minus_alphas_cumprod,}

count_parameters(model_diffunet)

# map_location remaps every stored tensor (including the EMA shadow weights)
# onto the current device, so a checkpoint written on another GPU index still
# loads and the shadow tensors live next to the model parameters.
checkpoint = torch.load(
    "/kaggle/input/deblurring-image/pytorch/default/3/phase1_checkpoint.pth",
    map_location=device,
)

model_diffunet.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
scheduler.load_state_dict(checkpoint["scheduler_state"])
ema.shadow = checkpoint["ema_shadow"]

# Clean up before switching
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# === Full Image Dataset ===
full_image_dataset = FullImageDataset(filtered_pairs)
train_loader = DataLoader(full_image_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                         num_workers=NUM_WORKERS, pin_memory=True)

print("\n[Phase 2] Fine-tuning on full image dataset...")
train_diff(model=model_diffunet, train_loader=train_loader, optimizer=optimizer, 
           scheduler=scheduler, scheduler_params=scheduler_params, device=device, 
           epochs=15, log_interval=10)

In [None]:
# Persist the Phase 2 (full-image fine-tuning) state under its own file name so
# the Phase 1 checkpoint is not overwritten and the message matches the file.
torch.save({
    "model_state": model_diffunet.state_dict(), 
    "optimizer_state": optimizer.state_dict(), 
    "scheduler_state": scheduler.state_dict(), 
    "ema_shadow": ema.shadow,
}, "phase2_checkpoint.pth")
print("Checkpoint saved to 'phase2_checkpoint.pth'")

# Reverse Process

In [None]:
@torch.no_grad()
def sample_ddpm(model, condition, scheduler_params, device, num_steps):
    """Run ancestral sampling from ``t = num_steps - 1`` down to ``t = 0``.

    Starts from Gaussian noise shaped like ``condition`` and, at every step,
    removes the model's predicted noise and (for ``t > 0``) adds fresh noise
    scaled by the posterior standard deviation.

    Args:
        model: network called as ``model(x_t, t_batch, condition)``.
        condition: conditioning batch of shape [B, C, H, W].
        scheduler_params: dict providing the ``'betas'`` tensor.
        device: device for the sampled tensors.
        num_steps: number of reverse steps to run.

    Returns:
        The final sample, same shape as ``condition``.
    """
    betas = scheduler_params['betas']
    alphas = 1.0 - betas
    acp = torch.cumprod(alphas, dim=0)
    acp_prev = torch.cat([torch.tensor([1.], device=device), acp[:-1]])

    floor = 1e-5  # lower bound that keeps the reciprocal and the variance usable
    inv_sqrt_alpha = torch.sqrt(1.0 / torch.clamp(alphas, min=floor))
    post_var = torch.clamp(betas * (1.0 - acp_prev) / (1.0 - acp), min=floor)
    # per-timestep weight of the predicted noise inside the posterior mean
    noise_weight = betas / torch.sqrt(1.0 - acp)

    B, C, H, W = condition.shape
    sample = torch.randn((B, C, H, W), device=device)

    for step in range(num_steps - 1, -1, -1):
        t_batch = torch.full((B,), step, device=device, dtype=torch.long)
        eps_hat = model(sample, t_batch, condition)

        mean = inv_sqrt_alpha[step] * (sample - noise_weight[step] * eps_hat)

        if step == 0:
            # last step is deterministic: no noise is added
            sample = mean
        else:
            fresh_noise = torch.randn_like(sample)
            sample = mean + torch.sqrt(post_var[step]) * fresh_noise

    return sample

In [None]:
@torch.no_grad()
def sample_ddim(model, condition, scheduler_params, device, num_steps, eta):
    """Sample with ``num_steps`` evenly spaced timesteps from ``T - 1`` to ``0``.

    At every visited timestep the clean image is estimated from the predicted
    noise and clamped to [-1, 1]; the last visited step returns that estimate,
    earlier steps move to the next timestep. ``eta`` scales the random part of
    each move (no noise is drawn when ``eta`` is not positive).

    Args:
        model: network called as ``model(x_t, t_batch, condition)``.
        condition: conditioning batch of shape [B, C, H, W].
        scheduler_params: dict providing the ``'betas'`` tensor.
        device: device for the sampled tensors.
        num_steps: number of timesteps to visit.
        eta: stochasticity factor.

    Returns:
        The final sample, same shape as ``condition``.
    """
    betas = scheduler_params['betas']
    acp = torch.cumprod(1.0 - betas, dim=0)

    T = len(betas)
    schedule = torch.linspace(T - 1, 0, steps=num_steps, dtype=torch.float64).round().long().to(device)
    last_index = len(schedule) - 1

    B, C, H, W = condition.shape
    x_t = torch.randn((B, C, H, W), device=device)

    for i, t in enumerate(schedule):
        t_batch = torch.full((B,), t, device=device, dtype=torch.long)
        eps_hat = model(x_t, t_batch, condition)

        a_t = acp[t]
        x0_est = (x_t - torch.sqrt(1 - a_t) * eps_hat) / torch.sqrt(a_t)
        x0_est = x0_est.clamp(-1.0, 1.0)

        if i == last_index:
            # final visited timestep: return the clean-image estimate itself
            x_t = x0_est
            continue

        a_next = acp[schedule[i + 1]]

        sigma = (
            eta
            * torch.sqrt((1 - a_next) / (1 - a_t))
            * torch.sqrt(1 - a_t / a_next)
        ).clamp(min=0)

        fresh_noise = torch.randn_like(x_t) if eta > 0 else 0.0

        # the floor keeps the square root argument positive
        direction_scale = torch.sqrt(torch.clamp(1 - a_next - sigma ** 2, min=1e-5))

        x_t = (
            torch.sqrt(a_next) * x0_est
            + direction_scale * eps_hat
            + sigma * fresh_noise
        )

    return x_t

# Diffusion Model Evaluation

In [None]:
def denorm(t):
    """Clip ``t`` to [-1, 1] and rescale it linearly to [0, 1]."""
    clipped = t.clamp(min=-1, max=1)
    shifted = clipped + 1
    return shifted / 2

# Loader over FullImageDataset(filtered_pairs): two samples per batch, served in
# a fixed (unshuffled) order. Used below as the evaluation loader.
val_loader_full = DataLoader(FullImageDataset(filtered_pairs), batch_size=2, shuffle=False)

In [None]:
def model_evaluation(model, val_loader, scheduler_params, device, num_steps, 
                     eta, model_name="Model"):
    """Sample with ``sample_ddim`` over ``val_loader`` and report image metrics.

    For every batch the ``'condition'`` entry is used as the sampler's
    conditioning input and the ``'target'`` entry as ground truth. PSNR, SSIM
    and LPIPS are computed per image on ``denorm``-ed tensors and averaged;
    FID is accumulated over all batches. A one-row summary table is printed.

    Args:
        model: Network passed to ``sample_ddim``; switched to eval mode here.
        val_loader: Iterable of dict batches with ``'condition'`` and
            ``'target'`` tensors.
        scheduler_params: Forwarded unchanged to ``sample_ddim``.
        device: Device the batches and metric modules are moved to.
        num_steps: Number of sampling steps, forwarded to ``sample_ddim``.
        eta: Sampling stochasticity, forwarded to ``sample_ddim``.
        model_name: Label used in the progress bar and the printed table.

    Returns:
        dict with keys ``"PSNR"``, ``"SSIM"``, ``"LPIPS"`` and ``"FID"``.
    """
    model.eval()
    psnr_list, ssim_list, lpips_list = [], [], []

    # The images handed to LPIPS below come from denorm() and therefore lie in
    # [0, 1]. torchmetrics' LPIPS assumes [-1, 1] unless normalize=True, so the
    # flag is required for the score to be computed on the right scale.
    lpips_module = LPIPS(net_type='alex', normalize=True).to(device)
    fid_module = FID(feature=2048).to(device)

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Evaluating {model_name}"):
            x_blur = batch['condition'].to(device)
            x_sharp = batch['target'].to(device)

            # Generate deblurred output
            x_fake = sample_ddim(model, x_blur, scheduler_params, device, num_steps, eta)

            # FID is fed uint8 images in [0, 255]: generated as fake, GT as real.
            fid_module.update(denorm(x_fake).mul(255).byte(), real=False)
            fid_module.update(denorm(x_sharp).mul(255).byte(), real=True)

            # Per-image metrics on [0, 1] tensors with a leading batch dim of 1.
            for b in range(x_blur.size(0)):
                pred = denorm(x_fake[b].unsqueeze(0))
                target = denorm(x_sharp[b].unsqueeze(0))

                psnr_list.append(psnr_fn(pred, target, data_range=1.0).item())
                ssim_list.append(ssim_fn(pred, target, data_range=1.0).item())
                lpips_list.append(lpips_module(pred, target).item())

    psnr_avg = np.mean(psnr_list)
    ssim_avg = np.mean(ssim_list)
    lpips_avg = np.mean(lpips_list)
    fid_value = fid_module.compute().item()

    print(f"{'Model':<15}{'PSNR (dB)':>12}{'SSIM':>12}{'LPIPS':>12}{'FID':>12}")
    print("-" * 63)
    print(f"{model_name:<15}{psnr_avg:>12.2f}{ssim_avg:>12.4f}{lpips_avg:>12.4f}{fid_value:>12.4f}")

    return {
        "PSNR": psnr_avg,
        "SSIM": ssim_avg,
        "LPIPS": lpips_avg,
        "FID": fid_value
    }

In [None]:
# Run the evaluation between ema.apply_shadow() and ema.restore()
# (presumably this swaps the EMA weights into the model and back; confirm
# against the EMA class). The try/finally ensures restore() still runs when
# evaluation raises, so a failed run does not leave the shadow state applied.
ema.apply_shadow()
try:
    metrics_unet = model_evaluation(
        model=model_diffunet,
        val_loader=val_loader_full,
        scheduler_params=scheduler_params,
        device=device,
        num_steps=10,
        eta=0.3,
        model_name="UNet Diffusion"
    )
finally:
    ema.restore()

In [None]:
def visualize_diffusion(model, val_loader, scheduler_params, device, num_steps, eta, batch):
    """Plot blurry input, sampled output and ground truth for one batch.

    The images come from ``batch`` (its ``'condition'`` and ``'target'``
    entries); ``val_loader`` is accepted but not read. At most the first four
    samples are drawn, one row per sample and one column per image kind.

    Args:
        model: Network passed to ``sample_ddim``; switched to eval mode here.
        val_loader: Unused; kept so existing call sites keep working.
        scheduler_params: Forwarded unchanged to ``sample_ddim``.
        device: Device the batch tensors are moved to before sampling.
        num_steps: Number of sampling steps, forwarded to ``sample_ddim``.
        eta: Sampling stochasticity, forwarded to ``sample_ddim``.
        batch: Dict with ``'condition'`` and ``'target'`` tensors.
    """
    model.eval()
    condition = batch['condition'].to(device)
    reference = batch['target'].to(device)

    # Sample the deblurred images without tracking gradients.
    with torch.no_grad():
        generated = sample_ddim(model, condition, scheduler_params, device, num_steps, eta)

    # Bring everything to [0, 1] on the CPU for plotting.
    blur_stack = denorm(condition).cpu()
    fake_stack = denorm(generated).cpu()
    sharp_stack = denorm(reference).cpu()

    n_rows = min(blur_stack.size(0), 4)

    # squeeze=False keeps the axes grid 2-D even when only one row is drawn.
    _, grid = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)

    for row in range(n_rows):
        panels = (
            (blur_stack[row], f"Blurry Input [{row}]"),
            (fake_stack[row], f"Deblurred Output [{row}]"),
            (sharp_stack[row], f"Ground Truth [{row}]"),
        )
        for ax, (image, title) in zip(grid[row], panels):
            # Channels-first tensor -> channels-last array for imshow.
            ax.imshow(image.permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Visualise one batch between ema.apply_shadow() and ema.restore()
# (presumably this swaps the EMA weights into the model and back; confirm
# against the EMA class). The try/finally ensures restore() still runs when
# data loading or plotting raises.
ema.apply_shadow()
try:
    # NOTE(review): the batch is drawn from `val_loader`, while
    # `val_loader_full` is what gets passed to visualize_diffusion (which does
    # not read it). Confirm that `val_loader` is the intended source.
    val_iter = iter(val_loader)
    batch1 = next(val_iter)
    visualize_diffusion(
        model=model_diffunet,
        val_loader=val_loader_full,
        scheduler_params=scheduler_params,
        device=device,
        num_steps=10,
        eta=0.3,
        batch=batch1
    )
finally:
    ema.restore()