In [None]:
# Import necessary libraries
import os
import glob
import random
import math

import torch
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.models import vgg19


# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# %% [code]
class DIV2KDataset(Dataset):
    """
    Paired low-/high-resolution image dataset for super-resolution.

    Training mode (is_train=True): each item is a random crop_size x crop_size crop of an
    HR image plus an LR counterpart obtained by bicubic-downscaling that crop by `scale`.
    The files in lr_dir are not read in this mode.

    Validation mode (is_train=False): each item is the full HR image and the full LR image
    found at the same position in the sorted file lists.
    """

    def __init__(self, hr_dir, lr_dir, crop_size=96, scale=4, is_train=True):
        """
        hr_dir: Directory containing high resolution images.
        lr_dir: Directory containing low resolution images (e.g., bicubic downscaled images).
        crop_size: Size of the HR crop (for training, a random crop will be taken).
        scale: The downscaling factor (e.g., 4 for x4 super-resolution).
        is_train: Flag indicating training or validation mode.

        Raises ValueError when, in training mode, crop_size is not a positive multiple of
        scale, or when, in validation mode, hr_dir and lr_dir hold different numbers of images.
        """
        # A crop that is not a multiple of the scale would produce an LR patch whose
        # upscaled size differs from the HR patch, so reject it up front.
        if is_train and (scale < 1 or crop_size < scale or crop_size % scale != 0):
            raise ValueError(
                f"crop_size ({crop_size}) must be a positive multiple of scale ({scale})"
            )

        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.hr_image_files = self._list_images(hr_dir)
        self.lr_image_files = self._list_images(lr_dir)
        self.crop_size = crop_size
        self.scale = scale
        self.is_train = is_train

        # Validation pairs files purely by position, so the two lists must line up.
        if not is_train and len(self.hr_image_files) != len(self.lr_image_files):
            raise ValueError(
                f"Found {len(self.hr_image_files)} HR images in {hr_dir!r} but "
                f"{len(self.lr_image_files)} LR images in {lr_dir!r}"
            )

        # Define transforms for HR and LR images (initially using ToTensor only)
        self.hr_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.lr_transform = transforms.Compose([
            transforms.ToTensor()
        ])

    @staticmethod
    def _list_images(directory):
        """Return the sorted paths of all .png and .jpg files directly inside `directory`."""
        return sorted(glob.glob(os.path.join(directory, '*.png')) +
                      glob.glob(os.path.join(directory, '*.jpg')))

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB and close the file handle right away."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _random_hr_crop(self, hr_image):
        """Return a random crop_size x crop_size crop, upscaling images smaller than the crop first."""
        hr_width, hr_height = hr_image.size
        if hr_width < self.crop_size or hr_height < self.crop_size:
            hr_image = hr_image.resize(
                (max(hr_width, self.crop_size), max(hr_height, self.crop_size)),
                Image.BICUBIC,
            )
            hr_width, hr_height = hr_image.size

        left = random.randint(0, hr_width - self.crop_size)
        top = random.randint(0, hr_height - self.crop_size)
        return hr_image.crop((left, top, left + self.crop_size, top + self.crop_size))

    def __len__(self):
        """Number of HR images (one sample per HR image)."""
        return len(self.hr_image_files)

    def __getitem__(self, idx):
        """Return {"lr": tensor, "hr": tensor} for sample `idx`."""
        hr_image = self._load_rgb(self.hr_image_files[idx])

        if self.is_train:
            # The LR patch is synthesised from the HR crop, so the LR file is never
            # opened here (it used to be decoded and then thrown away).
            hr_image = self._random_hr_crop(hr_image)
            lr_size = self.crop_size // self.scale
            lr_image = hr_image.resize((lr_size, lr_size), Image.BICUBIC)
        else:
            # For validation: keep the full images (add center cropping here if desired)
            lr_image = self._load_rgb(self.lr_image_files[idx])

        # Apply tensor transforms
        hr_tensor = self.hr_transform(hr_image)
        lr_tensor = self.lr_transform(lr_image)

        return {"lr": lr_tensor, "hr": hr_tensor}

# Kaggle locations of the DIV2K training images.
# Adjust both paths to match the layout of your own Kaggle dataset.
TRAIN_HR_DIR = "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_train_HR"
TRAIN_LR_DIR = "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_train_LR_bicubic_X4/X4"

# Training dataset (random 96x96 HR crops, x4 scale) and its shuffled loader.
train_dataset = DIV2KDataset(
    hr_dir=TRAIN_HR_DIR,
    lr_dir=TRAIN_LR_DIR,
    crop_size=96,
    scale=4,
    is_train=True,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

print("Number of training images:", len(train_dataset))


In [None]:
# %% [code]
def show_sample_pair(dataset, idx=0):
    """Plot the LR and HR images of sample `idx` from `dataset` side by side."""
    sample = dataset[idx]
    to_pil = transforms.ToPILImage()
    panels = (
        ("LR Image", to_pil(sample["lr"])),
        ("HR Image", to_pil(sample["hr"])),
    )

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

# Show one LR/HR pair drawn from the training set
show_sample_pair(train_dataset, idx=10)

# Generator model

In [None]:
# %% [code]
import torch
import torch.nn as nn
from torchvision import transforms

class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN stack whose output is added back onto its input."""

    def __init__(self, channels):
        super().__init__()
        # Layers are created in forward order so parameter initialisation order is unchanged.
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """
    Residual super-resolution generator with PixelShuffle upsampling.

    num_residual_blocks: Number of ResidualBlock(64) modules in the trunk.
    upscale_factor: Total spatial upscaling; must be a power of two (1, 2, 4, 8, ...)
        because upsampling is built from x2 PixelShuffle stages.

    forward(x) maps a (N, 3, H, W) tensor to (N, 3, H * upscale_factor, W * upscale_factor).
    Raises ValueError when upscale_factor is not a positive power of two.
    """

    def __init__(self, num_residual_blocks=16, upscale_factor=4):
        factor = int(upscale_factor)
        if factor != upscale_factor or factor < 1 or factor & (factor - 1):
            raise ValueError(
                f"upscale_factor must be a positive power of two, got {upscale_factor!r}"
            )

        super(Generator, self).__init__()
        # Initial convolution
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.ReLU(inplace=True)
        )

        # Residual blocks
        res_blocks = []
        for _ in range(num_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)

        # Mid convolution layer after residual blocks
        self.mid_conv = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )

        # Upsampling layers using PixelShuffle: one x2 stage per factor of two, i.e.
        # log2(upscale_factor) stages (2 stages for x4). The previous `upscale_factor // 2`
        # only matched this for factors 2 and 4 (x8 would have produced x16).
        upsample_layers = []
        num_upsamples = factor.bit_length() - 1
        for _ in range(num_upsamples):
            upsample_layers += [
                nn.Conv2d(64, 256, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.ReLU(inplace=True)
            ]
        self.upsample = nn.Sequential(*upsample_layers)

        # Final output convolution
        self.final = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        initial_out = self.initial(x)
        res_out = self.res_blocks(initial_out)
        mid = self.mid_conv(res_out)
        combined = initial_out + mid  # Skip connection
        upsampled = self.upsample(combined)
        output = self.final(upsampled)
        return output

# Build the generator (8 residual blocks for faster training), then place it on the device.
G = Generator(num_residual_blocks=8, upscale_factor=4)
G = G.to(device)
print(G)

# Discriminator model

In [None]:
# %% [code]
class Discriminator(nn.Module):
    """
    Convolutional real/fake classifier that returns one raw logit per image.

    forward(x) takes a (N, 3, H, W) tensor and returns a (N, 1) tensor of logits (no sigmoid
    is applied). The conv stack halves the resolution four times; the result is pooled to
    6x6 before the fully connected head, so inputs are no longer restricted to 96x96.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # A sequence of convolutional layers with increasing feature channels
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Pool to the 6x6 grid the FC head expects. A 96x96 input already yields 6x6
        # feature maps, so this is a no-op there; the layer has no parameters, so
        # existing checkpoints still load.
        self.pool = nn.AdaptiveAvgPool2d((6, 6))
        # Fully connected layers for binary classification
        self.fc = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        features = self.pool(self.net(x))
        features = features.flatten(1)  # (N, 512 * 6 * 6)
        out = self.fc(features)
        return out

# Build the discriminator, then place it on the selected device.
D = Discriminator()
D = D.to(device)
print(D)

In [None]:
from torchvision.models import vgg19
import torch.nn.functional as F

class VGGFeatureExtractor(nn.Module):
    """
    Frozen, ImageNet-pretrained VGG19 feature extractor for perceptual losses.

    layer: Index of the last `vgg.features` module to keep, as an int or a dotted name
        such as 'features.35'. The default keeps modules 0..35, i.e. up to and including
        the ReLU that follows conv5_4.
    use_cuda: When true, the extractor is moved to the notebook-level `device`.

    forward(img) expects img in [0, 1] with 3 channels and returns the feature maps.
    Raises ValueError when `layer` does not name a valid index.
    """

    def __init__(self, layer='features.35', use_cuda=True):
        super().__init__()
        vgg = self._load_vgg19().eval()
        modules = list(vgg.features.children())

        # Honour `layer` (it used to be ignored in favour of a hard-coded [:36] slice).
        try:
            last_index = int(str(layer).rsplit('.', 1)[-1])
        except ValueError:
            raise ValueError(f"layer must end in an integer index, got {layer!r}") from None
        if not 0 <= last_index < len(modules):
            raise ValueError(
                f"layer index {last_index} is outside vgg19.features (0..{len(modules) - 1})"
            )

        self.features = nn.Sequential(*modules[:last_index + 1])
        for p in self.features.parameters():
            p.requires_grad = False

        # ImageNet normalisation constants, kept as non-persistent buffers so they follow
        # the module across devices without appearing in state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

        if use_cuda:
            self.to(device)

    @staticmethod
    def _load_vgg19():
        """Load ImageNet VGG19, preferring the `weights=` API and falling back to `pretrained=True`."""
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            return vgg19(pretrained=True)
        return vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

    def forward(self, img):
        # expects img in [0,1], normalize to VGG's ImageNet stats
        mean = self.mean.to(img.device)  # no-op when already on the input's device
        std = self.std.to(img.device)
        img_norm = (img - mean) / std
        return self.features(img_norm)

In [None]:
def weights_init(m):
    """
    Initialisation hook meant for `module.apply(weights_init)`.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights from N(1, 0.02) and zero bias.
    Anything else, and any matching module without the relevant parameter
    (e.g. BatchNorm with affine=False), is left untouched.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

# Run weights_init over the generator G and the discriminator D.
# The second call is the last expression of the cell, so its result is what the notebook echoes.
G.apply(weights_init)
D.apply(weights_init)


In [None]:
# Relative weights of the generator loss terms
# (the adversarial weight starts small for stability).
lambda_pixel, lambda_perceptual, lambda_adv = 1.0, 0.01, 0.001

# Pixel-wise (content) criterion: mean squared error
pixel_loss_fn = nn.MSELoss()

# Adversarial criterion: binary cross-entropy computed on raw logits
adv_loss_fn = nn.BCEWithLogitsLoss()

In [None]:
from torch.optim.lr_scheduler import StepLR

# One Adam optimizer per network, both with the same hyper-parameters.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999))
    for net in (G, D)
)

# Optional: halve each learning rate every 50 scheduler steps
scheduler_G, scheduler_D = (
    StepLR(opt, step_size=50, gamma=0.5)
    for opt in (optimizer_G, optimizer_D)
)


In [None]:
# Generator warm-up: optimise G alone on the pixel-wise MSE before adversarial training.
pretrain_epochs = 50
for epoch in range(pretrain_epochs):
    G.train()
    epoch_loss = 0
    for batch in train_dataloader:
        lr, hr = batch['lr'].to(device), batch['hr'].to(device)

        optimizer_G.zero_grad()
        sr = G(lr)
        loss = pixel_loss_fn(sr, hr)
        loss.backward()
        optimizer_G.step()

        # Accumulate the scalar batch loss; it is averaged over batches when reported.
        epoch_loss += loss.item()
    print(f"[Pretrain {epoch+1}/{pretrain_epochs}] MSE loss: {epoch_loss/len(train_dataloader):.4f}")

# Persist the warmed-up generator weights
torch.save(G.state_dict(), "generator_pretrained.pth")


In [None]:
# Adversarial training: alternate one discriminator update and one generator update per batch.
adv_epochs = 200
content_extractor = VGGFeatureExtractor(use_cuda=torch.cuda.is_available())

for epoch in range(adv_epochs):
    G.train(); D.train()
    d_loss_running = 0
    g_loss_running = 0

    for batch in train_dataloader:
        lr = batch['lr'].to(device)
        hr = batch['hr'].to(device)
        bs = lr.size(0)
        real_labels = torch.ones(bs, 1, device=device)
        fake_labels = torch.zeros(bs, 1, device=device)

        # Run the generator once per batch. G's weights do not change between the two
        # updates below, so a second forward pass would reproduce the same images while
        # costing extra compute and updating the BatchNorm running statistics twice.
        fake_hr = G(lr)

        # ——— Train Discriminator ———
        D.zero_grad()

        # Real
        real_out = D(hr)
        loss_real = adv_loss_fn(real_out, real_labels)

        # Fake: detach so the discriminator loss does not backpropagate into G
        fake_out = D(fake_hr.detach())
        loss_fake = adv_loss_fn(fake_out, fake_labels)

        d_loss = 0.5 * (loss_real + loss_fake)
        d_loss.backward()
        optimizer_D.step()

        # ——— Train Generator ———
        G.zero_grad()

        # Score the same fakes with the freshly updated discriminator; the generator is
        # rewarded when they are classified as real.
        pred_fake = D(fake_hr)
        adv_loss = adv_loss_fn(pred_fake, real_labels)

        perc_loss = F.l1_loss(
            content_extractor(fake_hr),
            content_extractor(hr)
        )
        pix_loss = pixel_loss_fn(fake_hr, hr)

        g_loss = (lambda_adv * adv_loss
                  + lambda_perceptual * perc_loss
                  + lambda_pixel * pix_loss)
        # This also leaves gradients on D's parameters; they are cleared by
        # D.zero_grad() at the start of the next iteration before D is updated.
        g_loss.backward()
        optimizer_G.step()

        d_loss_running += d_loss.item()
        g_loss_running += g_loss.item()

    # Step schedulers (once per epoch)
    scheduler_G.step()
    scheduler_D.step()

    print(f"Epoch {epoch+1}/{adv_epochs} — D_loss: {d_loss_running/len(train_dataloader):.4f}, "
          f"G_loss: {g_loss_running/len(train_dataloader):.4f}")

    # Save checkpoints every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(G.state_dict(), f"generator_epoch{epoch+1}.pth")
        torch.save(D.state_dict(), f"discriminator_epoch{epoch+1}.pth")


In [None]:
import glob
import torch
import shutil
from skimage.metrics import peak_signal_noise_ratio

# Pick the generator checkpoint with the highest mean PSNR on the validation set.
# NOTE(review): `valid_loader` is not created in any cell visible in this notebook — confirm
# where it is built (this cell expects batches with 'lr' and 'hr' tensors in [0, 1]).
if "valid_loader" not in globals():
    raise NameError(
        "valid_loader is not defined; build a validation DataLoader "
        "(e.g. batch_size=1, shuffle=False) before running this cell."
    )


def _image_psnr(sr_img, hr_img):
    """PSNR between two C×H×W tensors in [0, 1], measured on their shared top-left H×W region."""
    # SR is a fixed multiple of the LR size; if the HR image is not exactly that size the
    # arrays would differ in shape, so compare only the overlapping area.
    height = min(sr_img.shape[-2], hr_img.shape[-2])
    width = min(sr_img.shape[-1], hr_img.shape[-1])
    sr_np = sr_img[:, :height, :width].permute(1, 2, 0).cpu().numpy()
    hr_np = hr_img[:, :height, :width].permute(1, 2, 0).cpu().numpy()
    return peak_signal_noise_ratio(hr_np, sr_np, data_range=1.0)


best_psnr = float("-inf")
best_ckpt = None

checkpoint_paths = sorted(glob.glob("generator_epoch*.pth"))
if not checkpoint_paths:
    raise FileNotFoundError("No generator_epoch*.pth checkpoints found in the working directory")

# Loop through all saved generator checkpoints
for ckpt_path in checkpoint_paths:
    # Load weights
    G.load_state_dict(torch.load(ckpt_path, map_location=device))
    G.to(device).eval()

    # Compute PSNR over the entire validation set
    psnrs = []
    with torch.no_grad():
        for batch in valid_loader:
            lr = batch['lr'].to(device)
            hr = batch['hr']

            # Forward pass, clamped to the valid image range
            sr = G(lr).clamp(0,1).cpu()

            # One PSNR value per image, whatever the batch size
            psnrs.extend(_image_psnr(sr_img, hr_img) for sr_img, hr_img in zip(sr, hr))

    if not psnrs:
        raise ValueError("valid_loader produced no batches; cannot compute PSNR")

    avg_psnr = sum(psnrs) / len(psnrs)
    print(f"{ckpt_path:20s} → PSNR = {avg_psnr:.2f} dB")

    if avg_psnr > best_psnr:
        best_psnr = avg_psnr
        best_ckpt  = ckpt_path

print(f"\n🏆 Best checkpoint: {best_ckpt} with PSNR = {best_psnr:.2f} dB")

# Copy the best model to a stable filename
shutil.copy(best_ckpt, "generator_best.pth")
print("✅ Copied to generator_best.pth")


In [None]:
from skimage.metrics import peak_signal_noise_ratio
import torch.nn.functional as F

# Bicubic-upsampling baseline PSNR over the validation set (expects one image per batch).
# NOTE(review): `valid_loader` is not created in any cell visible in this notebook — confirm
# it is defined before this cell runs.
bic_psnrs = []
with torch.no_grad():
    for batch in valid_loader:
        lr = batch['lr']    # [1,3,h,w] ∈ [0,1]
        hr = batch['hr']    # [1,3,H,W] ∈ [0,1]
        # Bicubic upsample. Bicubic interpolation can overshoot the input range, so clamp
        # to [0, 1] — the SR outputs are clamped the same way and PSNR uses data_range=1.0.
        bic = F.interpolate(lr, size=hr.shape[-2:], mode='bicubic', align_corners=False).clamp(0, 1)
        # Convert to numpy
        bic_np = bic.squeeze(0).permute(1,2,0).cpu().numpy()
        hr_np  = hr.squeeze(0).permute(1,2,0).cpu().numpy()
        bic_psnrs.append(peak_signal_noise_ratio(hr_np, bic_np, data_range=1.0))

if not bic_psnrs:
    raise ValueError("valid_loader produced no batches; cannot compute the bicubic baseline")

print(f"🔹 Bicubic baseline PSNR: {sum(bic_psnrs)/len(bic_psnrs):.2f} dB")

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np
from skimage.filters import sobel

# 1) Load model on CPU
# NOTE: this rebinds the notebook-level `device` and `G`; cells run after this one see the
# CPU device and the best-checkpoint generator.
device = torch.device('cpu')
G = Generator(num_residual_blocks=8, upscale_factor=4)
G.load_state_dict(torch.load("generator_best.pth", map_location=device))
G.to(device).eval()

# 2) Grab one sample
# NOTE(review): `test_ds` is not defined in any cell visible in this notebook — confirm it
# is a dataset yielding {'lr': tensor, 'hr': tensor} items.
idx = 1
sample = test_ds[idx]
lr  = sample['lr'].unsqueeze(0).to(device)   # [1,3,h,w]
hr  = sample['hr'].unsqueeze(0).to(device)   # [1,3,H,W]

# 3) Inference + bicubic baseline (both at HR size)
with torch.no_grad():
    sr  = G(lr).clamp(0,1)
    # Bicubic interpolation can overshoot [0, 1]; clamp like the SR output so imshow does
    # not clip silently and both images are compared on the same range.
    bic = F.interpolate(lr, size=hr.shape[-2:], mode='bicubic', align_corners=False).clamp(0,1)

# 4) Convert to H×W×C numpy arrays
def to_np(x):
    """Turn a [1,C,H,W] tensor into an H×W×C numpy array."""
    return x.squeeze(0).permute(1,2,0).cpu().numpy()

bic_np, sr_np, hr_np = map(to_np, (bic, sr, hr))

# 5) Center‑crop up to 64×64
H, W, _ = hr_np.shape
patch = min(64, H, W)
half  = patch//2
ch, cw = H//2, W//2
y0, y1 = ch-half, ch+half
x0, x1 = cw-half, cw+half

bic_c = bic_np[y0:y1, x0:x1]
sr_c  = sr_np[y0:y1,  x0:x1]
hr_c  = hr_np[y0:y1,  x0:x1]

# 6) Difference heatmap (mean absolute difference over channels)
diff = np.abs(sr_c - bic_c).mean(axis=2)

# 7) Edge maps (Sobel on the channel-averaged image)
edges_sr  = sobel(sr_c.mean(axis=2))
edges_bic = sobel(bic_c.mean(axis=2))

# 8) Plot 2×3 grid
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

# Top row: the three crops
top_row = (
    (bic_c, "Bicubic Crop"),
    (sr_c,  "SRGAN Crop"),
    (hr_c,  "Ground‑Truth HR"),
)
for ax, (image, title) in zip(axes[0], top_row):
    ax.imshow(image, interpolation='nearest')
    ax.set_title(title)

# Bottom-left: where SR and bicubic disagree
axes[1,0].imshow(diff, cmap='magma')
axes[1,0].set_title("|SR – Bicubic|")

# Bottom-middle/right: each crop with its edge contours drawn on top
edge_panels = (
    (axes[1,1], sr_c,  edges_sr,  'r', "SRGAN Edges (red)"),
    (axes[1,2], bic_c, edges_bic, 'b', "Bicubic Edges (blue)"),
)
for ax, image, edges, color, title in edge_panels:
    ax.imshow(image, interpolation='nearest')
    ax.contour(edges, colors=color, linewidths=1)
    ax.set_title(title)

for ax in axes.flat:
    ax.axis('off')

plt.tight_layout()
plt.show()
