# SRGAN

This notebook is a simple implementation of the SRGAN super resolution architecture. Further Improvements need to be done. 

### Key takeaways from the paper

- Ability of Mean Squared Error, Peak Signal to Noise Ratio to capture 'perceptually relevant'(like texture for instance) differences like high texture details is limited as both are based on pixel wise image differences. 

- Super resolved image may not be as photo realistic as the original image. 

- A Super Resolution Generative Adversarial Network is proposed, in which a deep residual network(ResNet) with skip connections. 

- A novel loss function is also proposed which uses high level feature maps of the VGG network. 

- Low resolution images were obtained by applying a Gaussian filter to the High Resolution Image, proceeded by a downsampling with a factor of r. 

- The ParametricReLU is used as the activation function. 

- The perceptual loss function was defined as a sum of content loss and adversarial loss. Weightage of the loss function is $1 : \frac{1}{1000}$(content loss : adversarial loss)

- Perceptual Loss > can use many humans to eval images, a far practical method is to use a pre trained network that has been trained on millions of images.

#### Things to read up on:
- MSE, PSNR 
- ResNet, skip connections
- Loss functions
- feature maps, vgg network

## Dependencies

In [1]:
import torch
import math
from os import listdir
import numpy as np
from torch.autograd import Variable
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from os.path import join
from torch import nn, optim
from torchvision.models.vgg import vgg16
from tqdm import tqdm
import os
from time import time
import matplotlib.pyplot as plt
import cv2
import torchvision.transforms as transforms
from scipy.ndimage import rotate
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x26f03e78fd0>

In [2]:
UPSCALE_FACTOR = 4
CROP_SIZE = 88

In [3]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Dataset

This section contains code for the dataloader needed for the data in the folder. This is a custom implementation of the `TrainDatasetFromFolder` class, which inherits from the original class in the `pytorch` library. The dataloader returns a HR(High Resolution, which is the image at its original resolution) image and an LR image(which has been resized to introduce image degradation).

The dataset we used in this notebook is the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) dataset, which contains 1000 images of 2K resolution(800 training images, 100 validation images, 100 testing images)

In [4]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

In [5]:
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

In [6]:
def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),
        ToTensor(),
    ])

In [7]:
def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),
        Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
        ToTensor()
    ])

In [8]:
def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])

In [9]:
class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)

In [10]:
train_set = TrainDatasetFromFolder("DIV2K_train_HR", crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
trainloader = DataLoader(train_set, batch_size=64, num_workers=0, shuffle=True)

# Model Implementation

![SRGAN Model Architecture](srgan.png)

## Generator Architecture

The first major block consists of a sequence of residual blocks, after the input layer. 

### What is a Residual Block?

A Residual block consists of the following components:
- A series of convolutional layers that are responsible for extracting features from the input data. 
- Batch normalisation is applied after each convolutional layer. This helps in stabilising and accelerating training by normalising the input to the following layer. 
- Activation functions are applied to introduce non linearity in the network, which allows the network to learn complex and non linear patterns in the data. 
- Skip Connection is essentially the sum of the input of the block and the output of the last batch normalisation layer, allows gradients to flow more freely throught the network. 

> Note: Read up on the Vanishing Gradient Problem

In [11]:
class ResidualBlock(nn.Module):
  def __init__(self, channels):
    super(ResidualBlock, self).__init__()
    self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(channels)
    self.prelu = nn.PReLU()
    self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(channels)
  def forward(self, x):
    residual = self.conv1(x)
    residual = self.bn1(residual)
    residual = self.prelu(residual)
    residual = self.conv2(residual)
    residual = self.bn2(residual)
    return x + residual

The next block in the generator is the upscaling block. The upsaming block is comprises of an initial convolution layer, which feeds into a pixel shuffle layer, and then lastly fed into a ParametricReLU activation function. 

In [12]:
class UpsampleBlock(nn.Module):
  def __init__(self, in_channels, up_scale):
    super(UpsampleBlock, self).__init__()
    self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2,
                          kernel_size=3, padding=1)
    self.pixel_shuffle = nn.PixelShuffle(up_scale)
    self.prelu = nn.PReLU()
  def forward(self, x):
    x = self.conv(x)
    x = self.pixel_shuffle(x)
    x = self.prelu(x)
    return x

In [13]:
class Generator(nn.Module):
  def __init__(self, scale_factor):
    super(Generator, self).__init__()
    upsample_block_num = int(math.log(scale_factor, 2))

    self.block1 = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=9, padding=4),
        nn.PReLU()
    )

    self.block2 = ResidualBlock(64)
    self.block3 = ResidualBlock(64)
    self.block4 = ResidualBlock(64)
    self.block5 = ResidualBlock(64)
    self.block6 = ResidualBlock(64)
    self.block7 = nn.Sequential(
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64)
    )
    block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
    block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
    self.block8 = nn.Sequential(*block8)
  def forward(self, x):
    block1 = self.block1(x)
    block2 = self.block2(block1)
    block3 = self.block3(block2)
    block4 = self.block4(block3)
    block5 = self.block5(block4)
    block6 = self.block6(block5)
    block7 = self.block7(block6)
    block8 = self.block8(block1 + block7)
    return (torch.tanh(block8) + 1) / 2

In [14]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.net = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1),
        nn.LeakyReLU(0.2),

        nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2),

        nn.Conv2d(64, 128, kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2),

        nn.Conv2d(128, 256, kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),

        nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),

        nn.Conv2d(256, 512, kernel_size=3, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2),

        nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2),

        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(512, 1024, kernel_size=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(1024, 1, kernel_size=1)
    )
  def forward(self, x):
    batch_size=x.size()[0]
    return torch.sigmoid(self.net(x).view(batch_size))

## Loss Functions

The paper proposed a perceptual loss function, which consists of an adversarial loss and a content loss. The content loss is motivated by the perceptual similarity insteam of pixel space similarity.  

In [15]:
class TVLoss(nn.Module):
  def __init__(self, tv_loss_weight=1):
    super(TVLoss, self).__init__()
    self.tv_loss_weight=tv_loss_weight
  def forward(self, x):
    batch_size=x.size()[0]
    h_x = x.size()[2]
    w_x = x.size()[3]

    count_h = self.tensor_size(x[:, :, 1:, :])
    count_w = self.tensor_size(x[:, :, :, 1:])

    h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
    w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()
    return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

  @staticmethod # Must add this
  def tensor_size(t):
    return t.size()[1] * t.size()[2] * t.size()[3]

In [16]:
class GeneratorLoss(nn.Module):
  def __init__(self):
    super(GeneratorLoss, self).__init__()
    vgg = vgg16(pretrained=True)
    loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
    for param in loss_network.parameters():
      param.requires_grad = False
    self.loss_network = loss_network
    self.mse_loss = nn.MSELoss()
    self.tv_loss = TVLoss()
  def forward(self, out_labels, out_images, target_images):
    adversial_loss = torch.mean(1 - out_labels)
    perception_loss = self.mse_loss(out_images, target_images)
    image_loss = self.mse_loss(out_images, target_images)
    tv_loss = self.tv_loss(out_images)
    return image_loss + 0.001 * adversial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

In [17]:
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Standard device selectoin
device

device(type='cuda')

In [18]:
netG = Generator(UPSCALE_FACTOR)
netD = Discriminator()

In [19]:
generator_criterion = GeneratorLoss()



In [20]:
generator_criterion = generator_criterion.to(device)
netG = netG.to(device)
netD = netD.to(device)

In [21]:
optimizerG = optim.Adam(netG.parameters(), lr=0.0002)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002)

In [22]:
results = {
    "d_loss":[],
    "g_loss":[],
    "d_score": [],
    "g_score": []
}

In [23]:
N_EPOCHS = 150 # 150 is good enough for our model. gives decent enough results

In [24]:
def saveCheckpoint(model, optimizer, filename):
    print('===> saving checkpoint')
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint,filename)

In [25]:
def loadCheckpoint(checkpoint_file, model, optimizer, lr): 
    print('===> load checkpoint')
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    for param_group in optimizer.param_groups:      # needed or model will have learning rate of older checkpoint 
        param_group['lr'] = lr

In [26]:
t0 = time()
for epoch in range(1, N_EPOCHS + 1):
  train_bar = tqdm(trainloader)
  running_results = {'batch_sizes':0, 'd_loss':0,
                     "g_loss":0, "d_score":0, "g_score":0}

  netG.train()
  netD.train()
  for data, target in train_bar:
    g_update_first = True
    batch_size = data.size(0)
    running_results['batch_sizes'] += batch_size

    real_img = Variable(target)
    real_img = real_img.to(device)
    z = Variable(data)
    z = z.to(device)

    ## Update Discriminator ##
    fake_img = netG(z)
    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph = True)
    optimizerD.step()

    ## Now update Generator
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()

    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    optimizerG.step()

    running_results['g_loss'] += g_loss.item() * batch_size
    running_results['d_loss'] += d_loss.item() * batch_size
    running_results['d_score'] += real_out.item() * batch_size
    running_results['g_score'] += real_out.item() * batch_size

    ## Updating the progress bar
    train_bar.set_description(desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f" % (
        epoch, N_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
        running_results['g_loss'] / running_results['batch_sizes'],
        running_results['d_score'] / running_results['batch_sizes'],
        running_results['g_score'] / running_results['batch_sizes']
    ))
  if epoch%50==0:
        saveCheckpoint(netG, optimizerG, filename='checkpoints/generator_checkpoint.pth')
        print(f'Saved Generator Checkpoint')
        saveCheckpoint(netD, optimizerD, filename='checkpoints/discriminator_checkpoint.pth')
        print(f'Saved Discriminator Checkpoint')  
  netG.eval()
print(f'Total Training Time: {(t0 - time())/3600} Hours')

[1/150] Loss_D: 0.7994 Loss_G: 0.0412 D(x): 0.5923 D(G(z)): 0.5923: 100%|██████████████| 13/13 [01:48<00:00,  8.35s/it]
[2/150] Loss_D: 0.5155 Loss_G: 0.0173 D(x): 0.7294 D(G(z)): 0.7294: 100%|██████████████| 13/13 [01:25<00:00,  6.56s/it]
[3/150] Loss_D: 0.3522 Loss_G: 0.0149 D(x): 0.8116 D(G(z)): 0.8116: 100%|██████████████| 13/13 [01:31<00:00,  7.06s/it]
[4/150] Loss_D: 0.2502 Loss_G: 0.0128 D(x): 0.8766 D(G(z)): 0.8766: 100%|██████████████| 13/13 [01:25<00:00,  6.61s/it]
[5/150] Loss_D: 0.0844 Loss_G: 0.0115 D(x): 0.9532 D(G(z)): 0.9532: 100%|██████████████| 13/13 [01:43<00:00,  7.98s/it]
[6/150] Loss_D: 0.0546 Loss_G: 0.0104 D(x): 0.9706 D(G(z)): 0.9706: 100%|██████████████| 13/13 [01:28<00:00,  6.77s/it]
[7/150] Loss_D: 0.6345 Loss_G: 0.0108 D(x): 0.5194 D(G(z)): 0.5194: 100%|██████████████| 13/13 [01:36<00:00,  7.44s/it]
[8/150] Loss_D: 0.7253 Loss_G: 0.0098 D(x): 0.4340 D(G(z)): 0.4340: 100%|██████████████| 13/13 [01:28<00:00,  6.82s/it]
[9/150] Loss_D: 0.6795 Loss_G: 0.0103 D(

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0013 Loss_G: 0.0053 D(x): 0.9993 D(G(z)): 0.9993:  15%|██▏           | 2/13 [00:12<01:12,  6.56s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0010 Loss_G: 0.0053 D(x): 0.9995 D(G(z)): 0.9995:  23%|███▏          | 3/13 [00:19<01:03,  6.34s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0016 Loss_G: 0.0051 D(x): 0.9989 D(G(z)): 0.9989:  31%|████▎         | 4/13 [00:25<00:58,  6.55s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0017 Loss_G: 0.0051 D(x): 0.9991 D(G(z)): 0.9991:  38%|█████▍        | 5/13 [00:32<00:51,  6.46s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0030 Loss_G: 0.0051 D(x): 0.9991 D(G(z)): 0.9991:  46%|██████▍       | 6/13 [00:41<00:52,  7.53s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0026 Loss_G: 0.0051 D(x): 0.9992 D(G(z)): 0.9992:  54%|███████▌      | 7/13 [00:47<00:42,  7.07s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0024 Loss_G: 0.0051 D(x): 0.9992 D(G(z)): 0.9992:  62%|████████▌     | 8/13 [00:55<00:35,  7.08s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0022 Loss_G: 0.0053 D(x): 0.9993 D(G(z)): 0.9993:  69%|█████████▋    | 9/13 [01:01<00:27,  6.99s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0021 Loss_G: 0.0054 D(x): 0.9993 D(G(z)): 0.9993:  77%|██████████   | 10/13 [01:10<00:22,  7.54s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0020 Loss_G: 0.0054 D(x): 0.9993 D(G(z)): 0.9993:  85%|███████████  | 11/13 [01:16<00:14,  7.07s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0019 Loss_G: 0.0054 D(x): 0.9993 D(G(z)): 0.9993:  92%|████████████ | 12/13 [01:25<00:07,  7.76s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[50/150] Loss_D: 0.0019 Loss_G: 0.0054 D(x): 0.9993 D(G(z)): 0.9993: 100%|█████████████| 13/13 [01:29<00:00,  6.86s/it]


===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[51/150] Loss_D: 0.0018 Loss_G: 0.0052 D(x): 0.9994 D(G(z)): 0.9994: 100%|█████████████| 13/13 [01:31<00:00,  7.07s/it]
[52/150] Loss_D: 0.0010 Loss_G: 0.0058 D(x): 0.9994 D(G(z)): 0.9994: 100%|█████████████| 13/13 [01:26<00:00,  6.66s/it]
[53/150] Loss_D: 0.0016 Loss_G: 0.0057 D(x): 0.9989 D(G(z)): 0.9989: 100%|█████████████| 13/13 [01:32<00:00,  7.13s/it]
[54/150] Loss_D: 0.0015 Loss_G: 0.0052 D(x): 0.9994 D(G(z)): 0.9994: 100%|█████████████| 13/13 [01:33<00:00,  7.18s/it]
[55/150] Loss_D: 0.0023 Loss_G: 0.0057 D(x): 0.9987 D(G(z)): 0.9987: 100%|█████████████| 13/13 [01:22<00:00,  6.37s/it]
[56/150] Loss_D: 0.0010 Loss_G: 0.0060 D(x): 0.9996 D(G(z)): 0.9996: 100%|█████████████| 13/13 [01:35<00:00,  7.34s/it]
[57/150] Loss_D: 0.0032 Loss_G: 0.0055 D(x): 0.9976 D(G(z)): 0.9976: 100%|█████████████| 13/13 [01:21<00:00,  6.29s/it]
[58/150] Loss_D: 0.6885 Loss_G: 0.0048 D(x): 0.8514 D(G(z)): 0.8514: 100%|█████████████| 13/13 [01:34<00:00,  7.26s/it]
[59/150] Loss_D: 0.8516 Loss_G: 0.0050 D

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0001 Loss_G: 0.0045 D(x): 1.0000 D(G(z)): 1.0000:  15%|██           | 2/13 [00:13<01:13,  6.66s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0007 Loss_G: 0.0047 D(x): 0.9995 D(G(z)): 0.9995:  23%|███          | 3/13 [00:18<01:02,  6.22s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0005 Loss_G: 0.0048 D(x): 0.9996 D(G(z)): 0.9996:  31%|████         | 4/13 [00:25<00:57,  6.42s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0005 Loss_G: 0.0047 D(x): 0.9997 D(G(z)): 0.9997:  38%|█████        | 5/13 [00:33<00:56,  7.07s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0004 Loss_G: 0.0048 D(x): 0.9997 D(G(z)): 0.9997:  46%|██████       | 6/13 [00:40<00:49,  7.01s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0004 Loss_G: 0.0050 D(x): 0.9998 D(G(z)): 0.9998:  54%|███████      | 7/13 [00:46<00:40,  6.68s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0003 Loss_G: 0.0050 D(x): 0.9998 D(G(z)): 0.9998:  62%|████████     | 8/13 [00:53<00:33,  6.77s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0003 Loss_G: 0.0050 D(x): 0.9998 D(G(z)): 0.9998:  69%|█████████    | 9/13 [00:59<00:26,  6.54s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0004 Loss_G: 0.0050 D(x): 0.9998 D(G(z)): 0.9998:  77%|█████████▏  | 10/13 [01:07<00:20,  6.79s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0004 Loss_G: 0.0049 D(x): 0.9998 D(G(z)): 0.9998:  85%|██████████▏ | 11/13 [01:13<00:13,  6.53s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0004 Loss_G: 0.0049 D(x): 0.9998 D(G(z)): 0.9998:  92%|███████████ | 12/13 [01:20<00:06,  6.70s/it]

===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[100/150] Loss_D: 0.0004 Loss_G: 0.0048 D(x): 0.9997 D(G(z)): 0.9997: 100%|████████████| 13/13 [01:24<00:00,  6.47s/it]


===> saving checkpoint
Saved Generator Checkpoint
===> saving checkpoint
Saved Discriminator Checkpoint


[101/150] Loss_D: 0.0003 Loss_G: 0.0054 D(x): 1.0000 D(G(z)): 1.0000: 100%|████████████| 13/13 [01:52<00:00,  8.65s/it]
[102/150] Loss_D: 0.0004 Loss_G: 0.0049 D(x): 0.9998 D(G(z)): 0.9998: 100%|████████████| 13/13 [01:31<00:00,  7.02s/it]
[103/150] Loss_D: 0.0001 Loss_G: 0.0048 D(x): 0.9999 D(G(z)): 0.9999: 100%|████████████| 13/13 [01:32<00:00,  7.08s/it]
[104/150] Loss_D: 0.0004 Loss_G: 0.0052 D(x): 0.9998 D(G(z)): 0.9998: 100%|████████████| 13/13 [01:34<00:00,  7.27s/it]
[105/150] Loss_D: 0.0001 Loss_G: 0.0048 D(x): 0.9999 D(G(z)): 0.9999: 100%|████████████| 13/13 [01:38<00:00,  7.57s/it]
[106/150] Loss_D: 0.0003 Loss_G: 0.0047 D(x): 0.9998 D(G(z)): 0.9998: 100%|████████████| 13/13 [04:37<00:00, 21.35s/it]
[107/150] Loss_D: 0.0003 Loss_G: 0.0048 D(x): 0.9998 D(G(z)): 0.9998: 100%|████████████| 13/13 [02:42<00:00, 12.53s/it]
[108/150] Loss_D: 0.0003 Loss_G: 0.0048 D(x): 0.9998 D(G(z)): 0.9998: 100%|████████████| 13/13 [02:39<00:00, 12.27s/it]
[109/150] Loss_D: 0.0004 Loss_G: 0.0048 

# An attempt at testing the model

In [27]:
def testModel(input_file, save_image=False):
    '''
    This function passes an image into the model, and retrieves its input. 
    Need to implement model saving and checkpointing. 
    '''
    # Read input image and convert to Tensor for feeding into model
    image = cv2.imread(input_file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Need to add this as OpenCV uses BGR instead of RGB
    transform = transforms.Compose([transforms.ToTensor()])
    in_tensor = transform(image)
    
    # Feed input tensor into model
    in_tensor = in_tensor.to(device) # Send input to device, can cause problems if input tensor is on CPU and Model on GPU!!
    out_tensor = netG(in_tensor.unsqueeze(0))
    
    # Final transformations, as for some reason, output is rotated
    display_tensor = out_tensor.to('cpu') # send output tensor back to CPU
    display_tensor = display_tensor.detach().numpy()
    display_tensor = np.transpose(display_tensor)
    display_tensor = np.squeeze(display_tensor)
    display_tensor = cv2.flip(display_tensor, 0)
    display_tensor = rotate(display_tensor, -90)
    
    if save_image: 
        f = (display_tensor * 255).astype(np.uint8) 
        out_file = input('What name do you want to save the image as: ')
        cv2.imwrite(f'{out_file}.png', t)
    
    return display_tensor

In [None]:
a = testModel('source/test_images/barbara_test.png')
i = cv2.imread('source/test_images/barbara_test.png')
i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(1, 2, figsize=(12,8))
ax[0].imshow(i)
ax[0].set_title('Input Image')
ax[1].imshow(a)
ax[1].set_title('Output Image')