In [None]:
%%capture

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
%%capture

!pip install lpips

In [None]:
%%capture
!pip install piq

In [None]:
import torch
import os
import time
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
import numpy as np
import lpips
import torch.nn.functional as F
from piq import SSIMLoss, MultiScaleSSIMLoss
from torchmetrics.functional import peak_signal_noise_ratio as psnr

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# define custom dataset class

class ImageRestorationDataset(Dataset):

  def __init__(self, corrupted_dir, clean_dir, transform=None):
    self.corrupted_dir = corrupted_dir
    self.clean_dir = clean_dir
    self.transform = transform

    # get sorted list of image file names from the corrupted image folder
    self.filenames = sorted(os.listdir(corrupted_dir))

  def __len__(self):
    return len(self.filenames)

  def __getitem__(self, idx):

    # get file paths
    corrupted_path = os.path.join(self.corrupted_dir, self.filenames[idx])
    clean_path = os.path.join(self.clean_dir, self.filenames[idx])


    # open images and convert to RGB format
    corrupted_image = Image.open(corrupted_path).convert("RGB")
    clean_image = Image.open(clean_path).convert("RGB")

    if self.transform:
      corrupted_image = self.transform(corrupted_image)
      clean_image = self.transform(clean_image)

    return corrupted_image, clean_image

In [None]:
# Define transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
])

val_clean_dir = '/kaggle/input/clearvision-image-dataset/val_clean_images'
val_corrupted_dir = '/kaggle/input/clearvision-image-dataset/val_corrupted_images'

# Validation dataset
val_dataset = ImageRestorationDataset(val_corrupted_dir, val_clean_dir, transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(x + self.block(x))

In [None]:

class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=64):
        super(UNetGenerator, self).__init__()

        # Encoder with residuals
        self.down1 = nn.Sequential(self._contract_block(in_channels, features, use_batchnorm=False), ResidualBlock(features))
        self.down2 = nn.Sequential(self._contract_block(features, features*2), ResidualBlock(features*2))
        self.down3 = nn.Sequential(self._contract_block(features*2, features*4), ResidualBlock(features*4))
        self.down4 = nn.Sequential(self._contract_block(features*4, features*8), ResidualBlock(features*8))

        # Decoder with residuals
        self.up1 = nn.Sequential(self._expand_block(features*8, features*4), ResidualBlock(features*4))
        self.up2 = nn.Sequential(self._expand_block(features*8, features*2), ResidualBlock(features*2))
        self.up3 = nn.Sequential(self._expand_block(features*4, features), ResidualBlock(features))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(features*2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def _contract_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_batchnorm=True):
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)]
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def _expand_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        u1 = self.up1(d4)
        u2 = self.up2(torch.cat([u1, d3], dim=1))
        u3 = self.up3(torch.cat([u2, d2], dim=1))

        output = self.final(torch.cat([u3, d1], dim=1))
        return output

In [None]:
generator = UNetGenerator().to(device)
generator.load_state_dict(torch.load('/kaggle/input/gan-checkpoints/final_generator_best.pth', map_location=device))

<All keys matched successfully>

In [None]:
# losses

# Loss functions and optimizers
adversarial_loss = torch.nn.BCEWithLogitsLoss()
pixelwise_loss   = torch.nn.L1Loss()

# Perceptual alternatives
ssim_loss        = SSIMLoss(data_range=2.0).to(device)         # outputs are in [-1,1]
ms_ssim_loss     = MultiScaleSSIMLoss(data_range=2.0).to(device)
lpips_model = lpips.LPIPS(net='alex').to(device)
lpips_model.eval()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
scaler_G = torch.cuda.amp.GradScaler()

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 202MB/s] 


Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/alex.pth


  scaler_G = torch.cuda.amp.GradScaler()


In [None]:
gen_dir = '/kaggle/working/generated'
os.makedirs(gen_dir, exist_ok=True)

In [None]:
def rescale_to_01(x):
    return (x + 1) / 2

In [None]:
val_ssim_c    = 0.0
val_ms_ssim_c = 0.0
val_psnr_c    = 0.0
val_lpips_c   = 0.0

val_ssim_g    = 0.0
val_ms_ssim_g = 0.0
val_psnr_g    = 0.0
val_lpips_g   = 0.0

total_inference_time = 0.0
total_images = 0

print("Running GAN on corrupted images...")

input_filenames = sorted(os.listdir(val_corrupted_dir))

with torch.no_grad():
    for batch_idx, (corrupted, clean) in enumerate(tqdm(val_dataloader)):
        if corrupted is None or clean is None:
            continue
        corrupted = corrupted.to(device).float()
        clean = clean.to(device).float()

        start_time = time.time()
        fake = generator(corrupted).float()
        end_time = time.time()

        batch_time = (end_time - start_time)
        total_inference_time += batch_time
        total_images += corrupted.size(0)

        fake_res      = rescale_to_01(fake)
        clean_res     = rescale_to_01(clean)
        corrupted_res = rescale_to_01(corrupted)

        val_ssim_c      += (1 - ssim_loss(corrupted_res, clean_res)).item()
        val_ms_ssim_c   += (1 - ms_ssim_loss(corrupted_res, clean_res)).item()
        val_psnr_c      += psnr(corrupted_res, clean_res).item()
        val_lpips_c     += lpips_model(corrupted_res * 2 - 1, clean_res * 2 - 1).mean().item()

        val_ssim_g      += (1 - ssim_loss(fake_res, clean_res)).item()
        val_ms_ssim_g   += (1 - ms_ssim_loss(fake_res, clean_res)).item()
        val_psnr_g      += psnr(fake_res, clean_res).item()
        val_lpips_g     += lpips_model(fake_res * 2 - 1, clean_res * 2 - 1).mean().item()

        for i in range(fake_res.size(0)):
            filename = input_filenames[batch_idx * val_dataloader.batch_size + i]
            img_to_save = TF.to_pil_image(fake_res[i].cpu())
            img_to_save.save(os.path.join(gen_dir, filename))


val_ssim_c    /= len(val_dataloader)
val_ms_ssim_c /= len(val_dataloader)
val_psnr_c    /= len(val_dataloader)
val_lpips_c   /= len(val_dataloader)

val_ssim_g    /= len(val_dataloader)
val_ms_ssim_g /= len(val_dataloader)
val_psnr_g    /= len(val_dataloader)
val_lpips_g   /= len(val_dataloader)

avg_latency_per_image = total_inference_time / total_images
avg_latency_per_batch = total_inference_time / len(val_dataloader)

print(f"Before GAN → SSIM: {val_ssim_c:.4f}, MS-SSIM: {val_ms_ssim_c:.4f}, PSNR: {val_psnr_c:.2f}, LPIPS: {val_lpips_c:.4f}")
print(f"After GAN → SSIM: {val_ssim_g:.4f}, MS-SSIM: {val_ms_ssim_g:.4f}, PSNR: {val_psnr_g:.2f}, LPIPS: {val_lpips_g:.4f}")
print(f"Avg inference latency → {avg_latency_per_image*1000:.2f} ms/image | {avg_latency_per_batch:.4f} s/batch")

Running GAN on corrupted images...


100%|██████████| 321/321 [02:23<00:00,  2.23it/s]

Before GAN → SSIM: 0.8076, MS-SSIM: 0.9185, PSNR: 22.09, LPIPS: 0.4203
After GAN → SSIM: 0.9235, MS-SSIM: 0.9828, PSNR: 28.69, LPIPS: 0.0979
Avg inference latency → 0.18 ms/image | 0.0059 s/batch





In [None]:
import shutil
from IPython.display import FileLink

# Replace with your output folder path
output_folder = "/kaggle/working/generated"

# Create a zip archive of the output folder
shutil.make_archive("generated_output_1", "zip", output_folder)
print("Output folder zipped as generated_output.zip")

# Display a clickable download link
display(FileLink("generated_output_1.zip"))


Output folder zipped as generated_output.zip
