# Single Image Super Resolution using CNN and GAN
Minor Project



## 1. Imports and Environment Setup


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import cv2
import numpy as np
import matplotlib.pyplot as plt

from torchvision.models import vgg19
from torchvision import transforms

## 2. Dataset Preparation (LRâ€“HR Image Pairs)


In [10]:
class SRDataset(Dataset):
    def __init__(self, image_paths, scale=4):
        self.image_paths = image_paths
        self.scale = scale

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read HR image
        hr = cv2.imread(self.image_paths[idx])
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2RGB)

        # FORCE SAME SIZE (VERY IMPORTANT)
        hr = cv2.resize(hr, (256, 256))

        h, w, _ = hr.shape

        # Create LR image
        lr = cv2.resize(hr, (w // self.scale, h // self.scale))
        lr = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)

        # Normalize
        hr = hr / 255.0
        lr = lr / 255.0

        # To tensor
        hr = torch.tensor(hr).permute(2, 0, 1).float()
        lr = torch.tensor(lr).permute(2, 0, 1).float()

        return lr, hr


## 3. SRCNN Model (Baseline CNN)


In [11]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x
def psnr(pred, target):
    mse = nn.functional.mse_loss(pred, target)
    if mse == 0:
        return 100
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# ---- DEFINE IMAGE PATHS (THIS IS THE ANSWER TO YOUR CONFUSION) ----
image_paths = [
    r"C:\Users\KIIT0001\Videos\Minor Proj 6th sem\images\hr1.jpg",
    r"C:\Users\KIIT0001\Videos\Minor Proj 6th sem\images\hr2.jpg",
    r"C:\Users\KIIT0001\Videos\Minor Proj 6th sem\images\hr3.jpg",
    r"C:\Users\KIIT0001\Videos\Minor Proj 6th sem\images\hr4.jpg",
    r"C:\Users\KIIT0001\Videos\Minor Proj 6th sem\images\hr5.jpg"
    
]

dataset = SRDataset(image_paths, scale=4)
loader = DataLoader(dataset, batch_size=1, shuffle=True)



# ---- MODEL ----
srcnn = SRCNN()
criterion = nn.MSELoss()
optimizer = optim.Adam(srcnn.parameters(), lr=1e-4)

epochs = 5
for epoch in range(epochs):
    epoch_loss = 0.0
    for lr, hr in loader:
        sr = srcnn(lr)
        loss = criterion(sr, hr)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"SRCNN Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(loader):.6f}")
# ---- END OF CODE ----


SRCNN Epoch 1/5, Loss: 0.143988
SRCNN Epoch 2/5, Loss: 0.117418
SRCNN Epoch 3/5, Loss: 0.086303
SRCNN Epoch 4/5, Loss: 0.055813
SRCNN Epoch 5/5, Loss: 0.032425


## 4. SRGAN Model (Generative Super Resolution)


In [13]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4),
            nn.PReLU()
        )
        self.residuals = nn.Sequential(*[ResidualBlock(64) for _ in range(5)])
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, 3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 3, 9, padding=4)
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.residuals(x)
        return self.upsample(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)
class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features.children())[:36])
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)

In [14]:
G = Generator()
D = Discriminator()

content_loss = nn.MSELoss()
adversarial_loss = nn.BCELoss()

vgg = vgg19(pretrained=True).features[:36].eval()
for p in vgg.parameters():
    p.requires_grad = False

opt_G = optim.Adam(G.parameters(), lr=1e-4)
opt_D = optim.Adam(D.parameters(), lr=1e-4)

for epoch in range(3):
    for lr, hr in loader:
        fake_hr = G(lr)

        real_pred = D(hr)
        fake_pred = D(fake_hr.detach())

        d_loss = adversarial_loss(real_pred, torch.ones_like(real_pred)) + \
                 adversarial_loss(fake_pred, torch.zeros_like(fake_pred))

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        fake_pred = D(fake_hr)
        hr_resized = torch.nn.functional.interpolate(
        hr, size=fake_hr.shape[2:], mode="bilinear", align_corners=False
        )
        perceptual = content_loss(vgg(fake_hr), vgg(hr_resized))


        g_loss = content_loss(fake_hr, hr_resized) + \
                 1e-3 * adversarial_loss(fake_pred, torch.ones_like(fake_pred)) + \
                 0.006 * perceptual

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    print(f"SRGAN Epoch {epoch+1}")
    




SRGAN Epoch 1
SRGAN Epoch 2
SRGAN Epoch 3


## 5. Results and Comparison


In [15]:
import os
os.makedirs("results", exist_ok=True)

srcnn.eval()
G.eval()

for i in range(len(dataset)):
    lr, hr = dataset[i]

    with torch.no_grad():
        srcnn_sr = srcnn(lr.unsqueeze(0)).squeeze(0)
        srgan_sr = G(lr.unsqueeze(0)).squeeze(0)

    fig, axs = plt.subplots(1, 4, figsize=(12,4))
    titles = ["Low-Res", "SRCNN", "SRGAN", "High-Res"]
    images = [
        lr.permute(1,2,0),
        srcnn_sr.permute(1,2,0),   
        srgan_sr.permute(1,2,0),
        hr.permute(1,2,0)
    ]

    for j in range(4):
        axs[j].imshow(images[j])
        axs[j].set_title(titles[j])
        axs[j].axis("off")

    plt.savefig(f"results/result_{i+1}.png")
    plt.close()
 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.08338928..0.43917587].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.08302008..0.5069683].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.09568811..0.50984436].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.115445755..0.53162223].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.07253943..0.39773285].
