In [1]:
# ---- srgan_training_eval.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm
import os
import requests, zipfile, io
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np
import glob

In [2]:
# --- Download DIV2K Dataset ---
def download_div2k(root_dir):
    os.makedirs(root_dir, exist_ok=True)
    url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
    print("Downloading DIV2K HR dataset...")
    r = requests.get(url)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall(root_dir)

    url_lr = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip"
    print("Downloading DIV2K LR dataset...")
    r = requests.get(url_lr)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall(root_dir)

In [12]:
from torchvision.transforms import Resize

In [13]:
# --- Custom Dataset Loader ---
class DIV2KDataset(Dataset):
    def __init__(self, root_dir):
        self.hr_dir = os.path.join(root_dir, 'DIV2K_train_HR')
        self.lr_dir = os.path.join(root_dir, 'DIV2K_train_LR_bicubic', 'X4')
        self.hr_files = sorted(glob.glob(os.path.join(self.hr_dir, '*.png')))
        self.lr_files = sorted(glob.glob(os.path.join(self.lr_dir, '*.png')))

        self.hr_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
        self.lr_transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_files[idx]).convert('RGB')
        lr_img = Image.open(self.lr_files[idx]).convert('RGB')
        return self.lr_transform(lr_img), self.hr_transform(hr_img)

In [14]:
# --- Generator ---
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

In [15]:
class GeneratorSR(nn.Module):
    def __init__(self, in_channels=3, num_residuals=16):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, 9, 1, 4),
            nn.PReLU()
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residuals)])
        self.mid = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64)
        )
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.final = nn.Conv2d(64, in_channels, 9, 1, 4)

    def forward(self, x):
        x1 = self.initial(x)
        x2 = self.res_blocks(x1)
        x3 = self.mid(x2)
        x = x1 + x3
        x = self.upsample(x)
        return self.final(x)

In [16]:
# --- Discriminator ---
class DiscriminatorSR(nn.Module):
    def __init__(self):
        super().__init__()
        def conv_block(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )

        self.model = nn.Sequential(
            conv_block(3, 64, 1),
            conv_block(64, 64, 2),
            conv_block(64, 128, 1),
            conv_block(128, 128, 2),
            conv_block(128, 256, 1),
            conv_block(256, 256, 2),
            conv_block(256, 512, 1),
            conv_block(512, 512, 2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        return self.model(x)


In [17]:
# --- Perceptual Loss using VGG19 ---
class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg19 = models.vgg19(pretrained=True).features
        self.slice = nn.Sequential(*[vgg19[i] for i in range(36)])
        for p in self.slice.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.slice(x)

# --- GAN Loss ---
def gan_loss(dis_real, dis_fake):
    real_loss = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_real))
    fake_loss = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_fake))
    return real_loss + fake_loss

In [27]:
# --- Training Function ---
def train(root_dir, epochs=10, batch_size=16, lr=1e-4, save_dir='results_srgan'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    download_div2k(root_dir)
    dataset = DIV2KDataset(root_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)


    generator = GeneratorSR().to(device)
    discriminator = DiscriminatorSR().to(device)
    feature_extractor = VGGFeatureExtractor().to(device)

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.999))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.9, 0.999))

    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        g_loss_total, d_loss_total = 0.0, 0.0
        psnr_epoch, ssim_epoch = [], []

        loop = tqdm(dataloader, desc=f"Epoch [{epoch}/{epochs}]")
        for lr_imgs, hr_imgs in loop:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            # Generator forward
            fake_hr = generator(lr_imgs)
            dis_fake = discriminator(fake_hr)
            dis_real = discriminator(hr_imgs)

            content_loss = F.mse_loss(feature_extractor(fake_hr), feature_extractor(hr_imgs))
            adv_loss = gan_loss(dis_real, dis_fake)
            pixel_loss = F.mse_loss(fake_hr, hr_imgs)
            g_loss = content_loss + 1e-3 * adv_loss + pixel_loss

            opt_g.zero_grad()
            g_loss.backward()
            opt_g.step()

            # Discriminator forward
            dis_real = discriminator(hr_imgs)
            dis_fake = discriminator(fake_hr.detach())
            d_loss = gan_loss(dis_real, dis_fake)

            opt_d.zero_grad()
            d_loss.backward()
            opt_d.step()

            g_loss_total += g_loss.item()
            d_loss_total += d_loss.item()

            # Metrics
            for i in range(hr_imgs.size(0)):
                gen_img = fake_hr[i].detach().cpu().numpy().transpose(1, 2, 0)
                gt_img = hr_imgs[i].detach().cpu().numpy().transpose(1, 2, 0)
                psnr_epoch.append(psnr(gt_img, gen_img, data_range=1.0))
                ssim_epoch.append(ssim(gt_img, gen_img, channel_axis=-1, data_range=1.0, win_size=5))

        avg_psnr = np.mean(psnr_epoch)
        avg_ssim = np.mean(ssim_epoch)
        print(f"\nEpoch {epoch}: G_Loss={g_loss_total:.4f}, D_Loss={d_loss_total:.4f}, PSNR={avg_psnr:.4f}, SSIM={avg_ssim:.4f}")

        save_image(fake_hr, os.path.join(save_dir, f"gen_epoch_{epoch}.png"))
        save_image(hr_imgs, os.path.join(save_dir, f"gt_epoch_{epoch}.png"))


In [28]:
train(root_dir="data", epochs=20, batch_size=16)


Downloading DIV2K HR dataset...
Downloading DIV2K LR dataset...


Epoch [1/20]: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]



Epoch 1: G_Loss=15.1875, D_Loss=36.5347, PSNR=15.6927, SSIM=0.1707


Epoch [2/20]: 100%|██████████| 50/50 [01:40<00:00,  2.01s/it]



Epoch 2: G_Loss=13.6784, D_Loss=7.6783, PSNR=18.3875, SSIM=0.2918


Epoch [3/20]: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]



Epoch 3: G_Loss=13.2451, D_Loss=3.2596, PSNR=18.6784, SSIM=0.3196


Epoch [4/20]: 100%|██████████| 50/50 [01:41<00:00,  2.03s/it]



Epoch 4: G_Loss=12.6648, D_Loss=0.7558, PSNR=18.8250, SSIM=0.3429


Epoch [5/20]: 100%|██████████| 50/50 [01:41<00:00,  2.04s/it]



Epoch 5: G_Loss=11.8520, D_Loss=0.7488, PSNR=19.3278, SSIM=0.3736


Epoch [6/20]: 100%|██████████| 50/50 [01:40<00:00,  2.02s/it]



Epoch 6: G_Loss=11.2667, D_Loss=7.7763, PSNR=19.6989, SSIM=0.4061


Epoch [7/20]: 100%|██████████| 50/50 [01:40<00:00,  2.01s/it]



Epoch 7: G_Loss=10.6405, D_Loss=0.7042, PSNR=20.1317, SSIM=0.4367


Epoch [8/20]: 100%|██████████| 50/50 [01:40<00:00,  2.02s/it]



Epoch 8: G_Loss=10.2551, D_Loss=0.6451, PSNR=20.2980, SSIM=0.4540


Epoch [9/20]: 100%|██████████| 50/50 [01:41<00:00,  2.02s/it]



Epoch 9: G_Loss=9.8603, D_Loss=0.3111, PSNR=20.5785, SSIM=0.4655


Epoch [10/20]: 100%|██████████| 50/50 [01:40<00:00,  2.02s/it]



Epoch 10: G_Loss=9.6089, D_Loss=0.3936, PSNR=20.5206, SSIM=0.4714


Epoch [11/20]: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]



Epoch 11: G_Loss=9.2974, D_Loss=1.3469, PSNR=20.7908, SSIM=0.4733


Epoch [12/20]: 100%|██████████| 50/50 [01:42<00:00,  2.06s/it]



Epoch 12: G_Loss=9.0673, D_Loss=0.2538, PSNR=20.9319, SSIM=0.4749


Epoch [13/20]: 100%|██████████| 50/50 [01:43<00:00,  2.06s/it]



Epoch 13: G_Loss=8.8900, D_Loss=0.1263, PSNR=20.8052, SSIM=0.4782


Epoch [14/20]: 100%|██████████| 50/50 [01:41<00:00,  2.04s/it]



Epoch 14: G_Loss=8.6925, D_Loss=0.1081, PSNR=20.8754, SSIM=0.4783


Epoch [15/20]: 100%|██████████| 50/50 [01:40<00:00,  2.01s/it]



Epoch 15: G_Loss=8.4603, D_Loss=0.0675, PSNR=20.9903, SSIM=0.4793


Epoch [16/20]: 100%|██████████| 50/50 [01:41<00:00,  2.02s/it]



Epoch 16: G_Loss=8.2916, D_Loss=0.0990, PSNR=20.8745, SSIM=0.4778


Epoch [17/20]: 100%|██████████| 50/50 [01:40<00:00,  2.01s/it]



Epoch 17: G_Loss=8.0993, D_Loss=0.2380, PSNR=21.0113, SSIM=0.4767


Epoch [18/20]: 100%|██████████| 50/50 [01:41<00:00,  2.03s/it]



Epoch 18: G_Loss=7.9657, D_Loss=0.1030, PSNR=20.8988, SSIM=0.4752


Epoch [19/20]: 100%|██████████| 50/50 [01:41<00:00,  2.03s/it]



Epoch 19: G_Loss=7.7901, D_Loss=0.0513, PSNR=20.9662, SSIM=0.4755


Epoch [20/20]: 100%|██████████| 50/50 [01:40<00:00,  2.01s/it]



Epoch 20: G_Loss=7.6448, D_Loss=0.0417, PSNR=21.0083, SSIM=0.4739
