<a href="https://colab.research.google.com/github/Mahmoud-Khaled-Nasr/SRGAN/blob/master/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SRGAN, Single Image Super Resolution

This notebook is an implementation of the paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" using the PyTorch framework, trained on Google Colaboratory.

## SRGAN
Like any GAN, SRGAN is made of two networks:

*   **discriminator**: It is responsible for discriminating between real HR (high resolution) images and the SR (super resolution) images generated by the generator
*   **generator**: It is responsible for upscaling an LR (low resolution) image and convincing the discriminator that the result is a real HR image

## Notebook structure 


1.   Loading the images using a custom PyTorch Dataset loader to pair each LR image with the right HR image in the data loader
2.   Defining the hyperparameters
3.   Creating the data loaders
4.   Defining the model architecture of both the discriminator and the generator
5.   Defining the loss functions for the discriminator and the generator
6.   Training code




In [0]:
#@title Hyperparameters
import torch

# ---- Device selection ----
# True when PyTorch can see a CUDA device; later cells read this flag to decide
# whether models and batches are moved to the GPU.
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
else:
    print('CUDA is not available.  Training on CPU ...')


# ---- Network parameters ----
BATCH_SIZE = 16
EPOCH_NUM = 40

# ---- Discriminator ----
# Value handed to the Discriminator constructor as `final_feature_map_size`.
DISCRIMINATOR_FINAL_FEATURE_MAP_SIZE = 10

# ---- Generator ----
# How many residual blocks the generator stacks.
RESIDUAL_BLOCKS = 16
# How many upsampling blocks the generator stacks; every block doubles the
# resolution produced by the block before it.
UPSAMPLING_BLOCKS = 2

# ---- Optimizers ----
# Learning rate shared by the discriminator and generator optimizers.
lr = 0.001

# Data Loading
## Custom Image Dataset Loader
`ImageDataset` is a class that implements the PyTorch `Dataset` class to load HR and LR images and pair them correctly.

The constructor has the following parameters:
* `LR_root_dir`: The base directory for all the LR images
* `HR_root_dir`: The base directory for all the HR images
* `HR_transform`: PyTorch transform applied to the HR images
* `LR_transform`: PyTorch transform applied to the LR images

`__getitem__` returns a pair of (LR image, HR image)

In [0]:
#@title Custom Image Dataset loader
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """Dataset of (LR image, HR image) pairs matched by file name.

    Every regular file found in ``LR_root_dir`` is one sample; its HR
    counterpart is looked up under ``HR_root_dir`` using the same file name.

    Args:
        LR_root_dir: Directory holding the low-resolution images.
        HR_root_dir: Directory holding the high-resolution images.
        HR_transform: Callable applied to each loaded HR PIL image.
        LR_transform: Callable applied to each loaded LR PIL image.
    """

    def __init__(self, LR_root_dir, HR_root_dir, HR_transform, LR_transform):
        import os

        self.LR_root_dir = LR_root_dir
        self.HR_root_dir = HR_root_dir

        self.LR_transform = LR_transform
        self.HR_transform = HR_transform

        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. a notebook `.ipynb_checkpoints` folder), which
        # cannot be opened as images. Keep regular files only and sort them so
        # that a given index always refers to the same image.
        self.images_list = sorted(
            entry for entry in os.listdir(LR_root_dir)
            if os.path.isfile(os.path.join(LR_root_dir, entry))
        )
        self.length = len(self.images_list)

    def __len__(self):
        """Return the number of LR/HR pairs."""
        return self.length

    def pil_loader(self, path):
        """Load the image stored at ``path`` and return it converted to RGB."""
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        from PIL import Image
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the transformed ``(LR_image, HR_image)`` pair at position ``idx``."""
        import os

        # The same file name is resolved against both root directories.
        file_name = self.images_list[idx]
        LR_image = self.pil_loader(os.path.join(self.LR_root_dir, file_name))
        HR_image = self.pil_loader(os.path.join(self.HR_root_dir, file_name))

        # Apply the transformation to the images
        LR_image = self.LR_transform(LR_image)
        HR_image = self.HR_transform(HR_image)

        # Return pair of images (LR, HR)
        return LR_image, HR_image

## Data loaders
Create the training data loader `train_loader`. The code for the matching `validation_loader` is present in the cell below but commented out.

In [0]:
#@title Create The DataLoaders

from torchvision import transforms, datasets
import torch

# Image folders. ImageDataset pairs an LR file with the HR file of the same name.
train_LR, train_HR = "train/LR", "train/HR"
validation_LR, validation_HR = "valid/LR", "valid/HR"

LR_image_size = (64, 64)
HR_image_size = (256, 256)


def _resize_to_tensor(size):
    """Build a transform that resizes an image to `size` and converts it to a tensor."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor()
    ])


# Normalisation of the HR images to [-1, 1] is currently switched off. To enable
# it, append transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) to the two
# HR pipelines.
train_LR_transform = _resize_to_tensor(LR_image_size)
train_HR_transform = _resize_to_tensor(HR_image_size)
validation_LR_transform = _resize_to_tensor(LR_image_size)
validation_HR_transform = _resize_to_tensor(HR_image_size)

dataset = ImageDataset(
    HR_root_dir=train_HR,
    LR_root_dir=train_LR,
    HR_transform=train_HR_transform,
    LR_transform=train_LR_transform,
)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Validation loader creation is disabled; uncomment the two lines below to build
# `validation_loader` in this cell.
#dataset = ImageDataset(HR_root_dir=validation_HR, LR_root_dir=validation_LR, HR_transform=validation_HR_transform, LR_transform=validation_LR_transform)
#validation_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


## Display Sample of Training Images
The `imshow` function below converts a given image tensor to a NumPy image and reorders its axes so that it can be displayed by `plt`. The next two cells should display two grids containing one batch of image data (LR, then HR).

In [0]:
import matplotlib.pyplot as plt
import torchvision
import torch
import numpy as np

# helper imshow function
def imshow(img):
  """Draw an image tensor with matplotlib.

  The tensor is converted to a NumPy array and its first axis is moved to the
  end before being handed to `plt.imshow`.
  """
  channels_last = img.numpy().transpose(1, 2, 0)
  plt.imshow(channels_last)
    

# Pull one (LR, HR) batch from the training loader.
dataiter = iter(train_loader)
# Use the built-in next(): the iterator's `.next()` method is not available in
# recent PyTorch releases, whereas next() works with every version.
LR, HR = next(dataiter)

# show images
fig = plt.figure(figsize=(12, 8))
# Display the LR images
imshow(torchvision.utils.make_grid(LR))

In [0]:
# Tile the HR half of the sampled batch into one grid image and draw it on a
# new, larger figure.
HR_grid = torchvision.utils.make_grid(HR)
fig = plt.figure(figsize=(16, 12))
imshow(HR_grid)

# Model Architecture
![Model Architecture](https://drive.google.com/uc?export=view&id=1qjzJAeJ16fFcQcEfsjbEzr4aAWcLCY3p)

## Generator
The generator is a fully convolutional network so it can accept any image size and the output is an upscaled image with a factor that depends on the number of upscaling blocks. Each upscale block provides upscale factor of 2.

The generator components are:
* `ResidualBlock`
* `UpscaleBlock`
* tanh activation function

### Residual Blocks
Based on the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf), residual blocks have proven very effective against the degradation problems caused by deep network architectures.

Each block consists of:


1.   Convolution layer with $kernel = 3$, $stride = 1$ and $padding = 1$
2.   Batch normalization
3.   Parametric Relu
4.   Convolution layer with $kernel = 3$, $stride = 1$ and $padding = 1$
5.   Batch normalization 

Let $y$ be the result of the previous layers for input $x$; the output of the block is then $R = x + y$

### Upscale Blocks
Instead of the normal de-convolution layer we are using pixel shuffle.

Each block consists of:


1.   Convolution layer with $out\ channels = in\ channels * 2^2$ where 2 is the upscale factor
2.   Pixel shuffle layer
3.   Parametric Relu as activation function





In [0]:
#@title Generator Model
from torch import nn

class ResidualBlock(nn.Module):
  """Residual unit: conv -> batch norm -> PReLU -> conv -> batch norm, added to its input.

  Args:
    channels: Number of input channels, which is also the number of output channels.
    kernel_size, stride, padding: Settings shared by both convolutions.
  """

  def __init__(self, channels, kernel_size=3, stride=1, padding=1):
    super().__init__()
    conv_options = dict(stride=stride, padding=padding)
    self.conv1 = nn.Conv2d(channels, channels, kernel_size, **conv_options)
    self.bn1 = nn.BatchNorm2d(channels)
    self.prelu = nn.PReLU()
    self.conv2 = nn.Conv2d(channels, channels, kernel_size, **conv_options)
    self.bn2 = nn.BatchNorm2d(channels)

  def forward(self, x):
    """Return `x` plus the output of the two-convolution branch."""
    branch = self.conv1(x)
    branch = self.bn1(branch)
    branch = self.prelu(branch)
    branch = self.conv2(branch)
    branch = self.bn2(branch)
    # The skip connection requires the branch to preserve the input shape.
    assert branch.shape == x.shape
    return branch + x


class UpscaleBlock(nn.Module):
  """Upsampling unit: convolution, pixel shuffle, then PReLU.

  The convolution multiplies the channel count by `upscale_factor ** 2` and the
  pixel shuffle is configured with the same `upscale_factor`.
  """

  def __init__(self, in_channels, upscale_factor):
    super().__init__()
    expanded_channels = in_channels * (upscale_factor ** 2)
    self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, stride=1, padding=1)
    self.ps = nn.PixelShuffle(upscale_factor)
    self.prelu = nn.PReLU()

  def forward(self, x):
    """Apply convolution, pixel shuffle and activation in that order."""
    expanded = self.conv(x)
    shuffled = self.ps(expanded)
    return self.prelu(shuffled)
    
  
class Generator(nn.Module):
  """Super-resolution generator.

  Pipeline: 9x9 input convolution with PReLU, a stack of residual blocks, a
  3x3 convolution with batch norm whose output is added to the input block's
  output, a stack of upscaling blocks, a 9x9 output convolution and tanh.

  Args:
    residual_block_number: How many `ResidualBlock`s to stack.
    upscaling_block_number: How many `UpscaleBlock`s to stack (each is built
      with an upscale factor of 2).
  """

  def __init__(self, residual_block_number, upscaling_block_number):
    super().__init__()

    self.residual_block_channels = 64
    self.input_channels = 3
    self.scale_factor = 2
    channels = self.residual_block_channels

    self.input_block = nn.Sequential(
        nn.Conv2d(self.input_channels, channels, kernel_size=9, stride=1, padding=4),
        nn.PReLU()
    )

    self.residual_blocks = nn.ModuleList(
        [ResidualBlock(channels) for _ in range(residual_block_number)]
    )

    self.residual_output = nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(channels)
    )

    self.upscaling_blocks = nn.ModuleList(
        [UpscaleBlock(channels, self.scale_factor) for _ in range(upscaling_block_number)]
    )

    self.output_layer = nn.Conv2d(channels, 3, kernel_size=9, stride=1, padding=4)

    self.tanh = nn.Tanh()

  def forward(self, x):
    """Upscale the batch `x`; the result has passed through tanh."""
    stem = self.input_block(x)

    features = stem
    for residual_block in self.residual_blocks:
      features = residual_block(features)

    # Long skip connection around the whole residual stack.
    features = self.residual_output(features) + stem

    for upscaling_block in self.upscaling_blocks:
      features = upscaling_block(features)

    output = self.output_layer(features)
    return self.tanh(output)
    

## Discriminator
The discriminator consists of

1.   A series of convolution layers with leaky ReLU as the activation function
2.   A fully connected network that outputs a single logit; the sigmoid that turns it into the probability of the image being real or fake is applied inside the loss (`BCEWithLogitsLoss`)




In [0]:
#@title Discriminator Model
from torch import nn


class ConvBlock(nn.Module):
    """Discriminator building block: 3x3 convolution, batch norm, LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.lrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply convolution, normalisation and activation in that order."""
        return self.lrelu(self.bn(self.conv(x)))

    
class Flatten(nn.Module):
  """Collapse every dimension after the first into a single one."""

  def forward(self, x):
    leading = x.size(0)
    return x.view(leading, -1)
  
class Discriminator(nn.Module):
  """Convolutional classifier that scores an image with a single raw output.

  The last layer is a plain `nn.Linear(1024, 1)`; no sigmoid is applied here.

  Args:
    final_feature_map_size: Output size of the adaptive average pool placed in
      front of the fully connected layers. Must be positive.
  """

  def __init__(self, final_feature_map_size):
    super().__init__()
    assert final_feature_map_size > 0
    leaky_slope = 0.2

    self.input_block = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1),
        nn.LeakyReLU(leaky_slope)
    )

    # (in_channels, out_channels, stride) for each ConvBlock, in order.
    block_specs = [
        (64, 64, 2),
        (64, 128, 1),
        (128, 128, 2),
        (128, 256, 1),
        (256, 256, 2),
        (256, 512, 1),
        (512, 512, 2),
    ]
    self.blocks = nn.Sequential(
        *[ConvBlock(n_in, n_out, stride) for n_in, n_out, stride in block_specs]
    )

    pooled_size = final_feature_map_size
    flattened_features = 512 * pooled_size * pooled_size

    self.output_block = nn.Sequential(
        nn.AdaptiveAvgPool2d(pooled_size),
        Flatten(),
        nn.Linear(flattened_features, 1024),
        nn.LeakyReLU(leaky_slope),
        nn.Linear(1024, 1)
    )

  def forward(self, x):
    """Score the batch `x`; both spatial dimensions must be at least 64."""
    assert x.shape[2] >= 64 and x.shape[3] >= 64
    feature_maps = self.blocks(self.input_block(x))
    return self.output_block(feature_maps)


# Loss
The goal of this paper is super resolution with an eye on the perceptual properties of the final image. It must be as photo-realistic as possible.

## Generator Loss
The paper uses the following loss
$$Loss = Perception\ Loss + 0.001 * Adversarial\ Loss$$
but during training and experimentation, training converged much faster when the mean square error of the images' pixel values was added to the equation
$$Loss = Image\ Loss + Perception\ Loss + 0.001 * Adversarial\ Loss$$

### Adversarial Loss
This measures the ability of the generator to fool the discriminator. Calculated by BCE (Binary Cross Entropy).

### Image Loss
This measures how much the two images are away from each other in terms of intensity. Calculated as MSE (mean square error).

### Perception Loss
This measures the difference in high frequency features and perceptual properties of the SR image, to make it as photo-realistic as possible. Using the **VGG19** network to extract feature maps for both the HR and SR images, we can compare the feature maps using MSE and improve the generator based on the result.


In [0]:
#@title Generator Loss { form-width: "150px" }
from torch import nn
from torchvision import models

class PerceptualLoss(nn.Module):
  """MSE between the truncated-VGG19 feature maps of two image batches."""

  def __init__(self):
    super().__init__()
    vgg = models.vgg19(pretrained=True)
    vgg.eval()

    # Keep the entries of `vgg.features` at indices 0..26 inclusive.
    # NOTE(review): the original code named this index "fifth conv layer";
    # confirm which VGG19 layer index 26 actually is.
    last_layer_index = 26
    kept_layers = list(vgg.features)[:last_layer_index + 1]
    self.feature_map_extractor = nn.Sequential(*kept_layers)
    self.feature_map_extractor.eval()
    # The extractor is a fixed feature function: its weights are never trained.
    for weight in self.feature_map_extractor.parameters():
      weight.requires_grad = False

    self.mse = nn.MSELoss()

  def forward(self, real_image, generated_image):
    """Return the feature-space MSE; both batches must have the same shape."""
    assert real_image.shape == generated_image.shape

    generated_features = self.feature_map_extractor(generated_image)
    real_features = self.feature_map_extractor(real_image)
    return self.mse(generated_features, real_features)
  
  
class GeneratorLoss(nn.Module):
  """Generator objective: image MSE + perceptual loss + 0.001 * adversarial BCE-with-logits."""

  def __init__(self):
    super().__init__()
    self.perceptual_loss = PerceptualLoss()
    self.discrimenator_loss = nn.BCEWithLogitsLoss()
    self.image_loss = nn.MSELoss()

  def forward(self, real_imges, generated_images, output_labels, target_labels):
    """Compute the combined loss.

    The three individual terms are stored on the module as `perc_loss`,
    `adv_loss` and `img_loss` so that callers can read them after the call.
    """
    self.perc_loss = self.perceptual_loss(real_imges, generated_images)
    self.adv_loss = self.discrimenator_loss(output_labels, target_labels)
    self.img_loss = self.image_loss(generated_images, real_imges)

    adversarial_weight = 0.001
    content_term = self.img_loss + self.perc_loss
    return content_term + adversarial_weight * self.adv_loss
    
    

## Discriminator Loss
From the discriminator's point of view, this is a simple binary classification problem. The loss used is BCE.

In [0]:
#@title Discriminator Loss { form-width: "150px" }
from torch import nn

class DiscriminatorLoss(nn.Module):
  """Binary cross-entropy on raw discriminator outputs (`nn.BCEWithLogitsLoss`)."""

  def __init__(self):
    super().__init__()
    self.loss_critrion = nn.BCEWithLogitsLoss()

  def forward(self, output_labels, target_labels):
    """Return the loss between the discriminator outputs and their targets."""
    loss = self.loss_critrion(output_labels, target_labels)
    return loss
    

In [0]:
#@title Save training and models state
def save_state():
  """Write a checkpoint of the current training run and return its path.

  Reads the notebook globals `epoch`, `D`, `G` and `training_results`, plus the
  architecture hyperparameters, and stores them in one file named after the
  current timestamp inside the Drive `Models` folder.

  Returns:
    The path of the file that was written.
  """
  import datetime
  import os

  state = {
      'epoch': epoch,
      'discriminator_state_dict': D.state_dict(),
      'generator_state_dict': G.state_dict(),
      'training_results': training_results,
      'DISCRIMINATOR_FINAL_FEATURE_MAP_SIZE': DISCRIMINATOR_FINAL_FEATURE_MAP_SIZE,
      'RESIDUAL_BLOCKS': RESIDUAL_BLOCKS,
      'UPSAMPLING_BLOCKS': UPSAMPLING_BLOCKS
  }

  models_dir = '/content/drive/My Drive/Models/'
  # torch.save does not create missing folders; without this, a missing
  # `Models` folder aborts training at the first checkpoint.
  os.makedirs(models_dir, exist_ok=True)

  file_name = 'model ' + str(datetime.datetime.now()) + '.pth'
  file_path = os.path.join(models_dir, file_name)
  torch.save(state, file_path)
  return file_path
  

In [0]:
#@title Load training and models state
def load_state(file_name, saved_file_src='/content/drive/My Drive/Models/'):
  """Load a checkpoint written by `save_state`.

  Args:
    file_name: Name of the checkpoint file inside `saved_file_src`.
    saved_file_src: Folder that holds the checkpoints. Defaults to the Drive
      `Models` folder.

  Returns:
    The loaded checkpoint object, or None when the file does not exist.
  """
  import os
  import torch

  file_path = os.path.join(saved_file_src, file_name)
  if not os.path.isfile(file_path):
    return None

  # Load every tensor onto the CPU so that a checkpoint saved from GPU models
  # can also be opened on a runtime without CUDA.
  # NOTE: torch.load unpickles the file, so only open checkpoints you trust.
  return torch.load(file_path, map_location='cpu')

In [0]:
#@title Models' Creation 
import torch.optim as optim

# Networks
D = Discriminator(DISCRIMINATOR_FINAL_FEATURE_MAP_SIZE)
G = Generator(RESIDUAL_BLOCKS, UPSAMPLING_BLOCKS)

# Loss modules
D_loss = DiscriminatorLoss()
G_loss = GeneratorLoss()

# Optimizers: SGD for the discriminator, Adam for the generator, same learning rate
d_optimizer = optim.SGD(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)

# ---- Resume from a previous checkpoint when one is found ----
# NOTE(review): save_state() writes files named 'model <timestamp>.pth'; the
# leading 'mm' here means this name will not match such a file. Confirm whether
# that is a typo or a deliberate way to start from scratch.
file_name = 'mmodel 2019-03-16 16:42:36.967231.pth'
state = load_state(file_name)

old_state_exists = state is not None

if state is None:
  print("starting from the beginning")
else:
  print('loading old state from', file_name)
  G.load_state_dict(state['generator_state_dict'])
  D.load_state_dict(state['discriminator_state_dict'])

# Move the networks and the loss modules to the GPU when one is available
if train_on_gpu:
  D = D.cuda()
  G = G.cuda()
  D_loss = D_loss.cuda()
  G_loss = G_loss.cuda()

In [0]:
!nvidia-smi

# Training
The training process was powered by Google Colab, with 40 epochs of training.

## Optimizations 


*   Each discriminator training batch contains either real or fake images, not both
*   The training process isn't symmetric: the generator or the discriminator may be trained more than the other, depending on the accuracy of the discriminator
*   Smoothing of the labels (real, fake) is applied
*   Tanh activation function is applied to the generator




In [0]:
#@title New Train Loop { form-width: "150px" }
import random

# Training
INTERLEAV_TRAINING_LIMIT = -1
# For logging the losses
EPOCH_LOG_INTERVAL = 1   # append to training_results every N epochs
BATCH_LOG_INTERVAL = 5   # print a progress line every N batches
SAVE_MODEL_INTERVAL = 2  # write a checkpoint every N epochs


# Turns raw discriminator outputs into probabilities when counting predictions
sigmoid = nn.Sigmoid()


# Keys of the `running_results` / `training_results` dictionaries
G_LOSS = "G_LOSS"
G_ADV_LOSS = "G_ADV_LOSS"
G_PERC_LOSS = "G_PERC_LOSS"
G_IMG_LOSS = "G_IMG_LOSS"
G_TRAINING_ITERATIONS = "G_TRAINING_ITERATIONS"
D_REAL_LOSS = "D_REAL_LOSS"
D_FAKE_LOSS = "D_FAKE_LOSS"
D_REAL_TRAINING_ITERATIONS = "D_REAL_TRAINING_ITERATIONS"
D_FAKE_TRAINING_ITERATIONS = "D_FAKE_TRAINING_ITERATIONS"
D_CORRECT_PREDICTIONS = "D_CORRECT_PREDICTIONS"
CURRENT_TRAINED_IMAGES = "CURRENT_TRAINED_IMAGES"
D_ACC = "D_ACC"


if old_state_exists:
  training_results = state['training_results']
  # A checkpoint is written at the end of an epoch and stores that epoch's
  # number, so training resumes with the following epoch rather than repeating
  # one whose results are already in `training_results`.
  START_EPOCH = state['epoch'] + 1
else:
  training_results = {
      G_LOSS: [], G_ADV_LOSS: [], G_PERC_LOSS: [], G_IMG_LOSS: [], G_TRAINING_ITERATIONS: [],
      D_REAL_LOSS: [], D_FAKE_LOSS: [], D_REAL_TRAINING_ITERATIONS: [], D_FAKE_TRAINING_ITERATIONS: [],
      D_ACC: []
  }
  START_EPOCH = 1

# The discriminator alternates between fake and real batches; start with fake.
train_on_fake = True

def _mean_batch_loss(weighted_sum, iterations):
  """Average a running loss sum over the batches that contributed to it.

  Every batch adds `loss * BATCH_SIZE` to the running sums, so dividing by
  `iterations * BATCH_SIZE` gives the mean batch loss. Returns 0 when no batch
  has contributed yet instead of dividing by zero.
  """
  return weighted_sum / (max(iterations, 1) * BATCH_SIZE)


# Epochs are numbered from 1, so the upper bound is EPOCH_NUM + 1: this runs
# EPOCH_NUM epochs and lets the final epoch reach the checkpoint code below.
for epoch in range(START_EPOCH, EPOCH_NUM + 1):

  # Per-epoch accumulators
  running_results = {
      G_LOSS: 0, G_ADV_LOSS: 0, G_PERC_LOSS: 0, G_IMG_LOSS: 0, G_TRAINING_ITERATIONS: 0,
      D_REAL_LOSS: 0, D_FAKE_LOSS: 0, D_REAL_TRAINING_ITERATIONS: 0, D_FAKE_TRAINING_ITERATIONS: 0,
      D_CORRECT_PREDICTIONS: 0,
      CURRENT_TRAINED_IMAGES: 0
  }

  D.train()
  G.train()

  for batch_id, (LR_images, HR_images) in enumerate(train_loader):

    if train_on_gpu:
      HR_images, LR_images = HR_images.cuda(), LR_images.cuda()

    # The last batch of an epoch can hold fewer than BATCH_SIZE images; count
    # the images that were really scored so the accuracy is not underestimated.
    batch_images = LR_images.size(0)

    ###############################
    # Choose which network to train
    ###############################

    assert running_results[D_CORRECT_PREDICTIONS] <= running_results[CURRENT_TRAINED_IMAGES]

    # Running discriminator accuracy for this epoch; 0.5 until something has
    # been scored, which makes both networks train on the first batch.
    if running_results[CURRENT_TRAINED_IMAGES] > 0:
      acc = running_results[D_CORRECT_PREDICTIONS] / running_results[CURRENT_TRAINED_IMAGES]
    else:
      acc = 0.5

    # Skip the generator while the discriminator is too weak, and skip the
    # discriminator while it is too strong.
    g_train = acc > 0.3
    d_train = acc < 0.85

    ###############################
    # Train the Generator
    ###############################

    if g_train:

      g_optimizer.zero_grad()

      generated_image = G(LR_images)
      D_fake_output = D(generated_image)

      # The target is to make the discriminator believe that all the images are real
      g_loss = G_loss(HR_images, generated_image, D_fake_output, torch.ones_like(D_fake_output) * 0.9)

      g_loss.backward()
      g_optimizer.step()

      running_results[G_LOSS] += g_loss.item() * BATCH_SIZE
      running_results[G_ADV_LOSS] += G_loss.adv_loss.item() * BATCH_SIZE
      running_results[G_PERC_LOSS] += G_loss.perc_loss.item() * BATCH_SIZE
      running_results[G_IMG_LOSS] += G_loss.img_loss.item() * BATCH_SIZE
      running_results[G_TRAINING_ITERATIONS] += 1
      running_results[CURRENT_TRAINED_IMAGES] += batch_images
      # A generated image scored <= 0.5 counts as a correct discriminator call
      running_results[D_CORRECT_PREDICTIONS] += (sigmoid(D_fake_output).cpu().detach().numpy()<=0.5).sum()

    ###############################
    # Train the discriminator
    ###############################

    if d_train:

      d_optimizer.zero_grad()
      # Alternate between a fake batch and a real batch on successive
      # discriminator steps (see `train_on_fake`).

      if train_on_fake:
        # The generator is not updated in this step, so its forward pass does
        # not need an autograd graph.
        with torch.no_grad():
          generated_image = G(LR_images)
        D_fake_output = D(generated_image)
        # The goal is to make the discriminator get the fake images right with smooth factor
        target = torch.zeros_like(D_fake_output) + 0.1
        d_fake_loss = D_loss(D_fake_output, target)
        d_fake_loss.backward()

        running_results[D_FAKE_LOSS] += d_fake_loss.item() * BATCH_SIZE
        running_results[D_FAKE_TRAINING_ITERATIONS] += 1
        running_results[D_CORRECT_PREDICTIONS] += (sigmoid(D_fake_output).cpu().detach().numpy()<=0.5).sum()
      else:
        D_real_output = D(HR_images)
        # The goal is to make the discriminator get the real images right with smooth factor
        target = torch.ones_like(D_real_output) * 0.9
        d_real_loss = D_loss(D_real_output, target)
        d_real_loss.backward()

        running_results[D_REAL_LOSS] += d_real_loss.item() * BATCH_SIZE
        running_results[D_REAL_TRAINING_ITERATIONS] += 1
        running_results[D_CORRECT_PREDICTIONS] += (sigmoid(D_real_output).cpu().detach().numpy()>0.5).sum()

      train_on_fake = not train_on_fake
      d_optimizer.step()
      running_results[CURRENT_TRAINED_IMAGES] += batch_images

    ###############################
    # Logging
    ###############################

    total_d_iterations = running_results[D_REAL_TRAINING_ITERATIONS] + running_results[D_FAKE_TRAINING_ITERATIONS]
    total_d_loss = running_results[D_REAL_LOSS] + running_results[D_FAKE_LOSS]

    if batch_id % BATCH_LOG_INTERVAL == 0:
      g_iterations = running_results[G_TRAINING_ITERATIONS]
      print('[%d/%d/%d] Acc_D: %.4f Corr_D :%d Used_IMG_D: %d Loss_D: %.4f R_Loss_D: %.4f F_Loss_D: %.4f Loss_G: %.4f Adv_G: %.4f Perc_G: %.4f Img_G: %.4f D_Train: %d G_Train: %d' % (
          batch_id,
          epoch,
          EPOCH_NUM,

          acc,
          running_results[D_CORRECT_PREDICTIONS],
          running_results[CURRENT_TRAINED_IMAGES],

          _mean_batch_loss(total_d_loss, total_d_iterations),
          _mean_batch_loss(running_results[D_REAL_LOSS], running_results[D_REAL_TRAINING_ITERATIONS]),
          _mean_batch_loss(running_results[D_FAKE_LOSS], running_results[D_FAKE_TRAINING_ITERATIONS]),

          _mean_batch_loss(running_results[G_LOSS], g_iterations),
          _mean_batch_loss(running_results[G_ADV_LOSS], g_iterations),
          _mean_batch_loss(running_results[G_PERC_LOSS], g_iterations),
          _mean_batch_loss(running_results[G_IMG_LOSS], g_iterations),

          total_d_iterations,
          g_iterations
      ))

  if epoch % EPOCH_LOG_INTERVAL == 0:

    g_iterations = running_results[G_TRAINING_ITERATIONS]
    d_real_iterations = running_results[D_REAL_TRAINING_ITERATIONS]
    d_fake_iterations = running_results[D_FAKE_TRAINING_ITERATIONS]

    training_results[G_LOSS].append(_mean_batch_loss(running_results[G_LOSS], g_iterations))
    training_results[G_ADV_LOSS].append(_mean_batch_loss(running_results[G_ADV_LOSS], g_iterations))
    training_results[G_PERC_LOSS].append(_mean_batch_loss(running_results[G_PERC_LOSS], g_iterations))
    training_results[G_IMG_LOSS].append(_mean_batch_loss(running_results[G_IMG_LOSS], g_iterations))
    training_results[G_TRAINING_ITERATIONS].append(g_iterations)
    training_results[D_REAL_LOSS].append(_mean_batch_loss(running_results[D_REAL_LOSS], d_real_iterations))
    training_results[D_FAKE_LOSS].append(_mean_batch_loss(running_results[D_FAKE_LOSS], d_fake_iterations))
    training_results[D_REAL_TRAINING_ITERATIONS].append(d_real_iterations)
    training_results[D_FAKE_TRAINING_ITERATIONS].append(d_fake_iterations)
    # Accuracy is correct / scored images. The division is guarded against an
    # empty epoch, and the value is stored as a plain float.
    training_results[D_ACC].append(
        float(running_results[D_CORRECT_PREDICTIONS] / max(running_results[CURRENT_TRAINED_IMAGES], 1))
    )

  if epoch % SAVE_MODEL_INTERVAL == 0:
    print("saving model state", save_state())
    


In [0]:
#@title Validation and Results Reporting
import os
import math

if not os.path.exists("/content/pytorch_ssim"):
  !git clone https://github.com/Po-Hsun-Su/pytorch-ssim.git
  !mv /content/pytorch-ssim /content/pytorch_ssim

from pytorch_ssim import pytorch_ssim
  

EPOCH_MSE = 'EPOCH_MSE'
EPOCH_SSIM = 'EPOCH_SSIM'

# Build the validation loader here: the data-loader cell only creates
# `train_loader` (its validation lines are commented out), so
# `validation_loader` would otherwise be undefined. No shuffling is needed for
# evaluation.
validation_dataset = ImageDataset(
    HR_root_dir=validation_HR,
    LR_root_dir=validation_LR,
    HR_transform=validation_HR_transform,
    LR_transform=validation_LR_transform,
)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Sums of per-batch metrics, each weighted by the number of images in the batch
valing_results = {EPOCH_MSE: 0, EPOCH_SSIM: 0}
dataset_size = len(validation_loader.dataset)

# Evaluate with batch-norm layers in inference mode and without building
# autograd graphs.
G.eval()
with torch.no_grad():
  for batch_id, (LR_images, HR_images) in enumerate(validation_loader):

    if train_on_gpu:
      LR_images = LR_images.cuda()
      HR_images = HR_images.cuda()

    # The last batch can be smaller than BATCH_SIZE; weight by its real size so
    # that dividing by `dataset_size` yields a true per-image mean.
    batch_images = LR_images.size(0)

    SR = G(LR_images)

    valing_results[EPOCH_MSE] += ((SR - HR_images) ** 2).mean().item() * batch_images
    valing_results[EPOCH_SSIM] += pytorch_ssim.ssim(SR, HR_images).mean().item() * batch_images

total_mse_loss = valing_results[EPOCH_MSE] / dataset_size
total_ssim_loss = valing_results[EPOCH_SSIM] / dataset_size
# PSNR with a peak value of 1; a zero MSE means identical images.
psnr = float('inf') if total_mse_loss == 0 else 10 * math.log10(1 / total_mse_loss)

print("MSE: %.4f SSIM: %.4f PSNR: %.4f" %(
    total_mse_loss,
    total_ssim_loss,
    psnr
))
