### SRGAN

This notebook implements SRGAN model along with training and validation data creation.

In [4]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Super-resolution/SRGAN

Mounted at /content/drive
/content/drive/MyDrive/Super-resolution/SRGAN


In [5]:
"""
Import Library
"""
from torch import nn
import h5py
import numpy as np
import glob
import os
from PIL import Image
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import torch
from tqdm import tqdm
from collections import namedtuple
import copy
import math
from torch.autograd import Variable
import pandas as pd
from torchvision import transforms
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
import random
from torch import nn

In [6]:
"""
SRGAN model
"""
class Generator(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(32)
        self.block3 = ResidualBlock(32)
        self.block4 = ResidualBlock(32)
        self.block5 = ResidualBlock(32)
        self.block6 = ResidualBlock(32)
        self.block7 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32)
        )
        block8 = [UpsampleBLock(32, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(32, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual


class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

In [4]:
# """
# Dataset feeding
# """
# class CustomDataset(Dataset):
#     def __init__(self, h5_file):
#         super(CustomDataset, self).__init__()
#         self.h5_file = h5_file

#     def __getitem__(self, idx):
#         with h5py.File(self.h5_file, 'r') as f:
#             return f['lr'][idx], f['hr'][idx]

#     def __len__(self):
#         with h5py.File(self.h5_file, 'r') as f:
#             return len(f['lr'])


In [7]:
"""
Loss Functions
"""
from torchvision.models.vgg import vgg16

# TV loss is optional but implemented in paper
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        # use VGG16 for loss calculation
        vgg = vgg16(pretrained=True, progress=False)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss



# Custom dataloader

In [8]:
# Custom dataset class to load images
class CustomDataset(Dataset):
    def __init__(self, lr_image_path, hr_image_path,num, transform=None):
        """
        Custom dataset to load low-resolution (LR) and high-resolution (HR) images.

        :param lr_image_path: Path to low-resolution images
        :param hr_image_path: Path to high-resolution images
        :param transform: Optional transformation to apply to images
        """
        self.lr_image_path = lr_image_path
        self.hr_image_path = hr_image_path
        self.transform = transform

        # Get list of image file paths
        self.lr_image_list = glob.glob(lr_image_path)[:num]
        self.hr_image_list = glob.glob(hr_image_path)[:num]

        # Shuffle the lists (optional)
        random.shuffle(self.lr_image_list)
        random.shuffle(self.hr_image_list)

        # Ensure both lists have the same length (minimum of the two lengths)
        self.num_images = min(len(self.lr_image_list), len(self.hr_image_list))

    def __getitem__(self, idx):
        # Open the images (convert them to RGB)
        lr_image = Image.open(self.lr_image_list[idx]).convert('RGB')
        hr_image = Image.open(self.hr_image_list[idx]).convert('RGB')

        # Convert images to numpy arrays
        lr_image = np.array(lr_image).astype(np.float32)
        hr_image = np.array(hr_image).astype(np.float32)

        # Transpose to match PyTorch image format (C, H, W)
        lr_image = np.transpose(lr_image, axes=[2, 0, 1])  # Convert to C, H, W
        hr_image = np.transpose(hr_image, axes=[2, 0, 1])  # Convert to C, H, W

        # Normalize the image to [0, 1] range
        lr_image /= 255.0
        hr_image /= 255.0

        # Apply any transformations if provided
        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

    def __len__(self):
        return self.num_images

In [9]:
# Create DataLoader for training and evaluation datasets
def create_dataloader(lr_image_path, hr_image_path,num, batch_size=1):
    dataset = CustomDataset(lr_image_path, hr_image_path,num)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0,
                            pin_memory=True,
                            drop_last=True)
    return dataloader

In [10]:
# Set paths to the directories containing LR and HR images
lr_train_dir = 'train/images_stage5/*.png'  # Path for low-resolution training images
hr_train_dir = 'train/images_stage3/*.png'  # Path for high-resolution training images

# Set paths to the directories containing LR and HR images for validation set
lr_valid_dir = 'valid/images_stage5/*.png'  # Path for low-resolution validation images
hr_valid_dir = 'valid/images_stage3/*.png'  # Path for high-resolution validation images

In [11]:
train_dataset = CustomDataset(lr_train_dir, hr_train_dir, num=1000)

# Create DataLoader for training
train_dataloader = create_dataloader(lr_train_dir, hr_train_dir,1000, batch_size=4)

eval_dataset = CustomDataset(lr_valid_dir, hr_valid_dir, num=100)

# Create DataLoader for evaluation
eval_dataloader = create_dataloader(lr_valid_dir, hr_valid_dir, 100,batch_size=4)

In [12]:
"""
Setup network parameter
"""
upscale_factor = 4
num_epoch = 1  #20

torch.manual_seed(123)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
"""
Setup network
"""
netG = Generator(upscale_factor)
netD = Discriminator()
generator_criterion = GeneratorLoss()

if torch.cuda.is_available():
    netG.to(device)
    netD.to(device)
    generator_criterion.to(device)

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


In [12]:
"""
Util function to measure error
"""
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

"""
Calculate PSNR
"""
def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


In [13]:
results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': []}
best_weights = copy.deepcopy(netG.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(1, num_epoch + 1):

    epoch_losses = AverageMeter()
    netG.train()
    netD.train()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % 1)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epoch))

        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        # training
        netG.train()
        netD.train()

        for data in train_dataloader:
            inputs, labels = data

            g_update_first = True
            batch_size = inputs.size(0)
            running_results['batch_sizes'] += batch_size

            # Update D network
            real_img = Variable(labels).to(device, dtype=torch.float)
            z = Variable(inputs).to(device, dtype=torch.float)

            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)

            # Update G network
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()

            epoch_losses.update(g_loss.item(), len(inputs))

            optimizerD.step()
            optimizerG.step()

            # Loss for current batch
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            t.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epoch, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))
            t.update(len(inputs))

        torch.save(netG.state_dict(), 'weight_srgan/netG_epoch_%d.pth' % epoch)
        torch.save(netD.state_dict(), 'weight_srgan/netD_epoch_%d.pth' % epoch)

        # validation
        netG.eval()
        epoch_psnr = AverageMeter()

        with torch.no_grad():
            val_images = []
            for data in eval_dataloader:
                inputs, labels = data
                inputs = inputs.to(device, dtype=torch.float)
                labels = labels.to(device, dtype=torch.float)

                preds = netG(inputs)

                epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

            print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

            if epoch_psnr.avg > best_psnr:
                best_epoch = epoch
                best_psnr = epoch_psnr.avg
                best_weights = copy.deepcopy(netG.state_dict())

[1/1] Loss_D: 0.2191 Loss_G: 0.0353 D(x): 0.9082 D(G(z)): 0.1273: 100%|██████████| 1000/1000 [12:51<00:00,  1.30it/s]

eval psnr: 15.01





In [1]:
# """
# Evaluate the model with test set
# """
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# model = Generator(upscale_factor).to(device)
# state_dict = model.state_dict()
# for n, p in torch.load('weight_srgan/netG_epoch_1.pth', map_location=lambda storage, loc: storage).items():
#     if n in state_dict.keys():
#         state_dict[n].copy_(p)
#     else:
#         raise KeyError(n)

# model.eval()


# Loading pretrained weights

In [16]:
# Define paths to the saved weights
generator_path = 'weight_srgan/netG_epoch_1.pth'  # Replace with the desired epoch
discriminator_path = 'weight_srgan/netD_epoch_1.pth'

# Load model weights
netG.load_state_dict(torch.load(generator_path,map_location=torch.device('cpu') ))
netD.load_state_dict(torch.load(discriminator_path,map_location=torch.device('cpu') ))

  netG.load_state_dict(torch.load(generator_path,map_location=torch.device('cpu') ))
  netD.load_state_dict(torch.load(discriminator_path,map_location=torch.device('cpu') ))


<All keys matched successfully>

In [1]:
# results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': []}
# best_weights = copy.deepcopy(netG.state_dict())
# best_epoch = 0
# best_psnr = 0.0

# for epoch in range(2, num_epoch + 2):

#     epoch_losses = AverageMeter()
#     netG.train()
#     netD.train()

#     with tqdm(total=(len(train_dataset) - len(train_dataset) % 1)) as t:
#         t.set_description('epoch: {}/{}'.format(epoch, num_epoch))

#         running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

#         # training
#         netG.train()
#         netD.train()

#         for data in train_dataloader:
#             inputs, labels = data

#             g_update_first = True
#             batch_size = inputs.size(0)
#             running_results['batch_sizes'] += batch_size

#             # Update D network
#             real_img = Variable(labels).to(device, dtype=torch.float)
#             z = Variable(inputs).to(device, dtype=torch.float)

#             fake_img = netG(z)

#             netD.zero_grad()
#             real_out = netD(real_img).mean()
#             fake_out = netD(fake_img).mean()
#             d_loss = 1 - real_out + fake_out
#             d_loss.backward(retain_graph=True)

#             # Update G network
#             netG.zero_grad()
#             g_loss = generator_criterion(fake_out, fake_img, real_img)
#             g_loss.backward()

#             epoch_losses.update(g_loss.item(), len(inputs))

#             optimizerD.step()
#             optimizerG.step()

#             # Loss for current batch
#             running_results['g_loss'] += g_loss.item() * batch_size
#             running_results['d_loss'] += d_loss.item() * batch_size
#             running_results['d_score'] += real_out.item() * batch_size
#             running_results['g_score'] += fake_out.item() * batch_size

#             t.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
#                 epoch, num_epoch, running_results['d_loss'] / running_results['batch_sizes'],
#                 running_results['g_loss'] / running_results['batch_sizes'],
#                 running_results['d_score'] / running_results['batch_sizes'],
#                 running_results['g_score'] / running_results['batch_sizes']))
#             t.update(len(inputs))

#         torch.save(netG.state_dict(), 'weight_srgan/netG_epoch_%d.pth' % epoch)
#         torch.save(netD.state_dict(), 'weight_srgan/netD_epoch_%d.pth' % epoch)

#         # validation
#         netG.eval()
#         epoch_psnr = AverageMeter()

#         with torch.no_grad():
#             val_images = []
#             for data in eval_dataloader:
#                 inputs, labels = data
#                 inputs = inputs.to(device, dtype=torch.float)
#                 labels = labels.to(device, dtype=torch.float)

#                 preds = netG(inputs)

#                 epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

#             print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

#             if epoch_psnr.avg > best_psnr:
#                 best_epoch = epoch
#                 best_psnr = epoch_psnr.avg
#                 best_weights = copy.deepcopy(netG.state_dict())