In [None]:
import torch
import torch.nn as nn
import torchvision
import math

import glob
from pandas.core.common import flatten
from PIL import Image
import os
import json
import random

import torchvision.transforms.functional as FT
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_type)
device

device(type='cuda')

## Models

### Convolutional Block

In [None]:
class ConvolutionalBlock(nn.Module):
  """Conv2d, optionally followed by BatchNorm2d and an activation, wrapped in one nn.Sequential.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size: square kernel size; padding is kernel_size // 2, so odd kernels keep the
      spatial size at stride 1.
    stride: convolution stride.
    BatchNorm: a BatchNorm2d layer is added only when this is exactly ``True``.
    activation: None, or one of 'prelu', 'leakyrelu', 'tanh' (case-insensitive).
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, BatchNorm=False, activation=None):
    super(ConvolutionalBlock, self).__init__()

    # Factory per supported activation name; LeakyReLU uses a 0.2 negative slope.
    make_activation = {
        'prelu': nn.PReLU,
        'leakyrelu': lambda: nn.LeakyReLU(0.2),
        'tanh': nn.Tanh,
    }

    if activation is not None:
      activation = activation.lower()
      assert activation in make_activation

    conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                     stride=stride, padding=kernel_size // 2)
    layers = [conv]

    if BatchNorm is True:
      layers.append(nn.BatchNorm2d(out_channels))

    if activation is not None:
      layers.append(make_activation[activation]())

    # Keep the attribute name `conv` so state_dict keys stay 'conv.<idx>.*'.
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    return self.conv(x)

### Sub-Pixel Convolutional Block

In [None]:
class SubPixelConvolutionalBlock(nn.Module):
  """Upsample by `scaling_factor` via convolution -> PixelShuffle -> PReLU.

  The convolution expands the channels by scaling_factor**2, and PixelShuffle trades those
  channels for resolution, so for odd kernel sizes (N, C, H, W) -> (N, C, H*s, W*s).
  """
  def __init__(self, n_channels=64, kernel_size=3, scaling_factor=2):
    super(SubPixelConvolutionalBlock, self).__init__()

    expanded_channels = n_channels * scaling_factor ** 2
    self.conv = nn.Conv2d(in_channels=n_channels, out_channels=expanded_channels,
                          kernel_size=kernel_size, padding=kernel_size // 2)
    self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
    self.prelu = nn.PReLU()

  def forward(self, x):
    return self.prelu(self.pixel_shuffle(self.conv(x)))

### Residual Block

In [None]:
class ResidualBlock(nn.Module):
  """SRResNet residual block: (conv-BN-PReLU) -> (conv-BN), plus an identity skip connection.

  Args:
    n_channels: number of input and output channels (equal, so the skip can be added).
    kernel_size: kernel size of both convolutions; padding keeps the spatial size for odd kernels.
  """
  def __init__(self, n_channels=64, kernel_size=3):
    super(ResidualBlock, self).__init__()

    # kernel_size applies to both convolutions (default 3, so existing configurations are unchanged).
    self.conv_block1 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                                          BatchNorm=True, activation='prelu')

    self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                                          BatchNorm=True)

  def forward(self, x):
    identity = x

    output = self.conv_block1(x)
    # The second conv must consume the FIRST conv's output; feeding it `x` again would
    # silently discard conv_block1.
    output = self.conv_block2(output)
    output = output + identity

    return output

### SRResNet

In [None]:
class SRResNet(nn.Module):
  """SRResNet: head conv -> residual trunk -> global skip -> sub-pixel upsampling -> tanh output conv.

  Args:
    large_kernel_size: kernel size of conv1, conv2 and the output conv3.
    small_kernel_size: kernel size inside the residual and sub-pixel blocks.
    n_channels: feature channels carried through the network.
    n_blocks: number of residual blocks in the trunk.
    scaling_factor: total upscaling factor; must be 2, 4 or 8 (one x2 sub-pixel block per doubling).
  """
  def __init__(self, large_kernel_size=3, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
    super(SRResNet, self).__init__()

    scaling_factor = int(scaling_factor)
    assert scaling_factor in {2, 4, 8}, 'The scaling factor must be 2, 4, 8.'

    self.conv1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
                                    BatchNorm=False, activation='prelu')

    residual_blocks = [ResidualBlock(n_channels=n_channels, kernel_size=small_kernel_size)
                       for _ in range(n_blocks)]
    self.res_blocks = nn.Sequential(*residual_blocks)

    self.conv2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=large_kernel_size,
                                    BatchNorm=True, activation=None)

    # Each sub-pixel block doubles the resolution, so log2(scaling_factor) of them are stacked.
    n_doublings = int(math.log2(scaling_factor))
    upsamplers = [SubPixelConvolutionalBlock(n_channels=n_channels, kernel_size=small_kernel_size, scaling_factor=2)
                  for _ in range(n_doublings)]
    self.subpixel_blocks = nn.Sequential(*upsamplers)

    self.conv3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
                                    BatchNorm=False, activation='Tanh')

  def forward(self, lr_img):
    features = self.conv1(lr_img)
    trunk = self.conv2(self.res_blocks(features))
    # Global skip connection around the whole residual trunk.
    upsampled = self.subpixel_blocks(trunk + features)
    return self.conv3(upsampled)

### Generator

In [None]:
class Generator(nn.Module):
  """SRGAN generator: a thin wrapper around SRResNet so it can be warm-started from a trained SRResNet."""
  def __init__(self, large_kernel_size=3, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
    super(Generator, self).__init__()

    self.generator = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                              n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)

  def init_with_srresnet(self, srresnet_checkpoint):
    """Copy weights from a saved SRResNet into `self.generator`.

    Args:
      srresnet_checkpoint: path to a file written by `torch.save`. Accepted contents are a
        pickled SRResNet module (as written by `torch.save(model, path)`), a checkpoint dict
        holding the module under 'model', or a plain state_dict.

    The file is loaded onto the CPU, so a checkpoint written from a GPU run also loads on a
    CPU-only machine; `load_state_dict` then copies the values into this module's existing
    parameters wherever they live.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects (needed for a full module);
    # only load checkpoints from trusted sources.
    try:
      loaded = torch.load(srresnet_checkpoint, map_location='cpu', weights_only=False)
    except TypeError:
      # torch < 1.13 has no `weights_only` argument.
      loaded = torch.load(srresnet_checkpoint, map_location='cpu')

    if isinstance(loaded, nn.Module):
      state_dict = loaded.state_dict()
    elif isinstance(loaded, dict) and isinstance(loaded.get('model'), nn.Module):
      state_dict = loaded['model'].state_dict()
    else:
      # Assume a raw state_dict.
      state_dict = loaded
    self.generator.load_state_dict(state_dict)

    print('\nLoaded weights from pre-trained SRResNet.\n')

  def forward(self, lr_img):
    sr_img = self.generator(lr_img)

    return sr_img

### Discriminator

In [None]:
class Discriminator(nn.Module):
  """SRGAN discriminator: conv blocks, adaptive average pooling and two fully connected layers.

  Args:
    in_channels: channels of the input images.
    n_channels: output channels of the first conv layer.
    kernel_size: kernel size of every conv layer.
    n_blocks: number of conv blocks after the first layer. Odd-numbered blocks use stride 2 at
      constant width; even-numbered blocks double the channels at stride 1.
    fc_size: width of the hidden fully connected layer.

  forward(hr_img) returns one unnormalised score (logit) per image, shape (N, 1).
  """
  def __init__(self, in_channels=3, n_channels=64, kernel_size=3, n_blocks=7, fc_size=1024):
    super(Discriminator, self).__init__()

    self.conv1 = ConvolutionalBlock(in_channels=in_channels, out_channels=n_channels, kernel_size=kernel_size,
                                    BatchNorm=False, activation='leakyrelu')

    # `channels` is the width of the most recent layer; it stays defined even when n_blocks == 0.
    channels = n_channels
    conv_blocks = []
    for i in range(1, n_blocks + 1):
      if i % 2 == 0:
        out_channels, stride = channels * 2, 1
      else:
        out_channels, stride = channels, 2

      conv_blocks.append(ConvolutionalBlock(in_channels=channels, out_channels=out_channels, kernel_size=kernel_size,
                                            stride=stride, BatchNorm=True, activation='LeakyReLU'))
      channels = out_channels
    self.conv_blocks = nn.Sequential(*conv_blocks)

    # A fixed 6x6 map makes fc1's input size independent of the image size.
    self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((6, 6))

    self.fc1 = nn.Linear(channels * 6 * 6, fc_size)

    self.leakyrelu = nn.LeakyReLU(0.2)

    self.fc2 = nn.Linear(fc_size, 1)

  def forward(self, hr_img):
    N = hr_img.shape[0]

    output = self.conv1(hr_img)
    output = self.conv_blocks(output)
    output = self.adaptive_avg_pool(output)
    output = self.fc1(output.view(N, -1))
    output = self.leakyrelu(output)
    logit = self.fc2(output)

    return logit

### Truncated VGG19

In [None]:
class TruncatedVGG19(nn.Module):
  """Pretrained VGG19 feature extractor truncated after the ReLU of conv(i, j).

  Keeps `vgg19.features` up to and including the ReLU that follows the j-th convolution after
  the (i-1)-th max-pool layer (e.g. i=5, j=4 stops after conv5_4 and its ReLU).

  It is used only as a fixed loss network, so its parameters are frozen. Gradients still flow
  through it to the input, but no gradient buffers are computed for weights that no optimizer
  updates.

  Raises:
    AssertionError: if (i, j) does not name a convolution in VGG19.
  """
  def __init__(self, i, j):
    super(TruncatedVGG19, self).__init__()

    vgg19 = torchvision.models.vgg19(pretrained=True)

    truncate_at = 0
    conv_count = 0
    pool_count = 0

    for layer in vgg19.features.children():
      truncate_at += 1

      # Count the max-pool layers, and the convolutions since the most recent max-pool.
      if isinstance(layer, nn.Conv2d):
        conv_count += 1
      if isinstance(layer, nn.MaxPool2d):
        pool_count += 1
        conv_count = 0

      # Stop at the j-th convolution after the (i-1)-th max-pool.
      if pool_count == i - 1 and conv_count == j:
        break

    assert pool_count == i - 1 and conv_count == j, f'One or both of i = {i} and j = {j} are not valid choices for VGG19!'

    # `truncate_at` layers end at the target convolution; +1 also keeps the ReLU after it.
    self.new_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])

    # Freeze the loss network.
    for param in self.new_vgg19.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    output = self.new_vgg19(x)

    return output

## Utils

In [None]:
# Training images: expected to contain one sub-folder of images per entry (see create_data_lists).
train_folders = r'/content/drive/MyDrive/ImageNet100/train.X'

# Benchmark test sets; each folder gets its own '<name>_test_images.json'.
test_folders = [f'/content/drive/MyDrive/SR test dataset/{benchmark}'
                for benchmark in ('BSD100', 'Set14', 'Set5')]

# Directory for the JSON image lists (later also used as SRDataset's data_folder).
output_folder = r'/content/drive/MyDrive/Projects/Super Resolution'

# Test images narrower or shorter than this many pixels are skipped by create_data_lists.
min_size = 100

In [None]:
def create_data_lists(train_folders, test_folders, min_size, output_folder):
  """Write JSON lists of training and test image paths into `output_folder`.

  Args:
    train_folders: directory whose sub-directories hold the training images. Every entry two
      levels down is listed; training images are NOT filtered by `min_size`. The list is
      shuffled with a fixed seed and written to 'train_images.json'.
    test_folders: iterable of test-set directories; each one is written to
      '<basename>_test_images.json'.
    min_size: test images whose width or height is below this many pixels are skipped.
    output_folder: directory the JSON files are written to.

  Side effects:
    Re-seeds Python's global `random` module with 42 before shuffling.
  """
  print("\nCreating data lists... this may take some time.\n")

  # Flatten <train_folders>/<sub_dir>/<file> into a single list of paths.
  train_images = [img_path
                  for sub_dir in glob.glob(train_folders + '/*')
                  for img_path in glob.glob(sub_dir + '/*')]

  random.seed(42)
  random.shuffle(train_images)

  print(f"There are {len(train_images)} images in the training data.\n")
  with open(os.path.join(output_folder, 'train_images.json'), 'w') as j:
    json.dump(train_images, j)

  for folder in test_folders:
    test_images = []
    test_name = os.path.basename(folder)
    for path in glob.glob(folder + '/*'):
      # Close each file promptly; only its size is needed.
      with Image.open(path, mode='r') as img:
        if img.width >= min_size and img.height >= min_size:
          test_images.append(path)
    print(f'There are {len(test_images)} images in {test_name} dataset')
    with open(os.path.join(output_folder, test_name + '_test_images.json'), 'w') as j:
      json.dump(test_images, j)

  print(f"\nJSONS containing lists of Train and Test images have been saved to {output_folder}.\n")

In [None]:
# Uncomment to (re)build the JSON image lists in `output_folder`; SRDataset reads them from there.
#create_data_lists(train_folders, test_folders, min_size, output_folder)

In [None]:
# Luma weights: Y = 16 + 65.481*R + 128.553*G + 24.966*B for R, G, B in [0, 1]
# (used by convert_image's 'y_channels' target).
rgb_weights = torch.FloatTensor([65.481, 128.553, 24.966]).to(device)

_imagenet_mean_values = [0.485, 0.456, 0.406]
_imagenet_std_values = [0.229, 0.224, 0.225]

# Shaped (3, 1, 1) to broadcast against a single (C, H, W) image; these stay on the CPU.
imagenet_mean = torch.FloatTensor(_imagenet_mean_values).view(3, 1, 1)
imagenet_std = torch.FloatTensor(_imagenet_std_values).view(3, 1, 1)

# Shaped (1, 3, 1, 1) to broadcast against an (N, C, H, W) batch; these live on `device`.
imagenet_mean_cuda = torch.FloatTensor(_imagenet_mean_values).to(device).view(1, 3, 1, 1)
imagenet_std_cuda = torch.FloatTensor(_imagenet_std_values).to(device).view(1, 3, 1, 1)

In [None]:
def convert_image(img, source, target):
  """Convert an image between pixel representations.

  Args:
    img: a PIL image (source='pil') or a float tensor whose values lie in the range named by `source`.
    source: 'pil', '[0, 1]' or '[-1, 1]'.
    target: 'pil', '[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet_norm' or 'y_channels'.

  Returns:
    The converted image. 'imagenet_norm' accepts a (C, H, W) image or an (N, C, H, W) batch on
    any device. 'y_channels' expects an (N, 3, H, W) batch and returns the (N, H-8, W-8) luma
    channel (values in [16, 235]) with a 4-pixel border cropped away.

  Raises:
    AssertionError: for an unsupported `source` or `target`.
  """
  assert source in {'pil', '[0, 1]', '[-1, 1]'}, f'Cannot convert from source format: {source}'
  assert target in {'pil', '[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet_norm',
                    'y_channels'}, f'Cannot convert to target format: {target}'

  # Step 1: bring the input to [0, 1].
  if source == 'pil':
    img = FT.to_tensor(img)
  elif source == '[0, 1]':
    pass
  elif source == '[-1, 1]':
    img = (img + 1.0) / 2

  # Step 2: convert from [0, 1] to the target.
  if target == 'pil':
    img = FT.to_pil_image(img)
  elif target == '[0, 255]':
    img = 255.0 * img
  elif target == '[0, 1]':
    pass
  elif target == '[-1, 1]':
    img = 2.0 * img - 1.0

  elif target == 'imagenet_norm':
    # Move the statistics to the image's device so CPU and GPU tensors both work
    # (.to() is a no-op when they already match).
    if img.ndimension() == 3:
      img = (img - imagenet_mean.to(img.device)) / imagenet_std.to(img.device)
    elif img.ndimension() == 4:
      img = (img - imagenet_mean_cuda.to(img.device)) / imagenet_std_cuda.to(img.device)

  # y_channels converts RGB to the luma (Y) channel of YCbCr for PSNR/SSIM evaluation.
  # This is not used for training.
  elif target == 'y_channels':
    img = torch.matmul(255.0 * img.permute(0, 2, 3, 1)[:, 4:-4, 4:-4, :], rgb_weights.to(img.device)) / 255.0 + 16

  return img

In [None]:
class ImageTransform():
  """Build a (LR_img, HR_img) pair from a PIL image.

  split='train': take a random crop_size x crop_size HR crop.
  split='test':  take the largest centred crop whose sides are divisible by scaling_factor.
  The LR image is the HR crop bicubically downsampled by scaling_factor. Both are then converted
  with convert_image(source='pil') to LR_img_type / HR_img_type.
  """
  def __init__(self, split, crop_size, scaling_factor, LR_img_type, HR_img_type):

    self.split = split.lower()
    self.crop_size = crop_size
    self.scaling_factor = scaling_factor
    self.LR_img_type = LR_img_type
    self.HR_img_type = HR_img_type

    # Validate the lower-cased value so that e.g. 'Train' is accepted.
    assert self.split in {'train', 'test'}

  def __call__(self, img):
    """Return (LR_img, HR_img) for the PIL image `img`.

    Raises:
      ValueError: in training mode, if `img` is smaller than crop_size on either side.
    """
    if self.split == 'train':
      if img.width < self.crop_size or img.height < self.crop_size:
        raise ValueError(f'Image of size {img.width}x{img.height} is smaller than crop_size={self.crop_size}')
      # randint is inclusive at both ends, so offsets 0 .. (side - crop_size) cover every valid
      # crop position, including an image exactly crop_size wide or high.
      left = random.randint(0, img.width - self.crop_size)
      top = random.randint(0, img.height - self.crop_size)
      right = left + self.crop_size
      bottom = top + self.crop_size
    else:
      # Largest centred crop whose dimensions are divisible by the scaling factor.
      x = img.width % self.scaling_factor
      y = img.height % self.scaling_factor
      left = x // 2
      top = y // 2
      right = left + img.width - x
      bottom = top + img.height - y

    HR_img = img.crop((left, top, right, bottom))

    # Downsample the HR crop with bicubic interpolation to obtain the LR image.
    LR_img = HR_img.resize((int(HR_img.width / self.scaling_factor), int(HR_img.height / self.scaling_factor)),
                           Image.BICUBIC)

    assert HR_img.width == LR_img.width * self.scaling_factor and HR_img.height == LR_img.height * self.scaling_factor

    LR_img = convert_image(LR_img, source='pil', target=self.LR_img_type)
    HR_img = convert_image(HR_img, source='pil', target=self.HR_img_type)

    return LR_img, HR_img

In [None]:
class Metric:
  """Accumulate a per-batch quantity and record its per-epoch average.

  `add(val, n, epoch)` adds `val` to the running total; when the batch counter `n` equals
  `len_loader`, the mean (total / n) is appended to `values` and `epoch` to `epochs`.
  `reset()` clears the running total.
  """
  def __init__(self, len_loader):
    self.values = []
    self.epochs = []
    self.len_loader = len_loader
    # Start with an empty accumulator so add() works without a prior reset().
    self.reset()

  def reset(self):
    self.val = 0

  def update(self, value, epoch):
    self.values.append(value)
    self.epochs.append(epoch)

  def add(self, val, n, epoch):
    self.val += val
    if n == self.len_loader:
      self.avg = self.val / n
      self.update(value=self.avg, epoch=epoch)

## Dataset

In [None]:
class SRDataset(Dataset):
  """Dataset of (LR, HR) image pairs for super-resolution.

  Image paths come from JSON lists in `data_folder` (as written by create_data_lists):
  'train_images.json' for split='train', '<test_data_name>_test_images.json' for split='test'.
  Each item is produced by ImageTransform.

  Args:
    data_folder: directory holding the JSON lists.
    split: 'train' or 'test' (case-insensitive).
    crop_size: side of the square HR crop (train only); must be divisible by scaling_factor.
    scaling_factor: downsampling factor between the HR and LR images.
    LR_img_type, HR_img_type: target formats passed to convert_image.
    test_data_name: required for split='test' (e.g. 'Set5').

  Raises:
    ValueError: split='test' without test_data_name.
    AssertionError: invalid split/image types, or crop_size not divisible by scaling_factor.
  """
  def __init__(self, data_folder, split, crop_size, scaling_factor, LR_img_type, HR_img_type, test_data_name=None):

    self.data_folder = data_folder
    self.split = split.lower()
    self.crop_size = crop_size
    self.scaling_factor = scaling_factor
    self.LR_img_type = LR_img_type
    self.HR_img_type = HR_img_type
    self.test_data_name = test_data_name

    assert self.split in {'train', 'test'}
    if self.split == 'test' and self.test_data_name is None:
      raise ValueError('Provide the name of the test dataset!')
    assert LR_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet_norm'}
    assert HR_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet_norm'}

    if self.split == 'train':
      assert self.crop_size % self.scaling_factor == 0, 'Crop dimensions are not perfectly divisible by the scaling factor!'

    if self.split == 'train':
      with open(os.path.join(data_folder, 'train_images.json'), 'r') as j:
        self.images = json.load(j)
    else:
      with open(os.path.join(data_folder, self.test_data_name + '_test_images.json'), 'r') as j:
        self.images = json.load(j)

    self.transform = ImageTransform(split=self.split,
                                    crop_size=self.crop_size,
                                    scaling_factor=self.scaling_factor,
                                    LR_img_type=self.LR_img_type,
                                    HR_img_type=self.HR_img_type)

  def __getitem__(self, idx):
    """Load image `idx` as RGB and return its (LR_img, HR_img) pair.

    Raises:
      ValueError: for split='train', when the image is smaller than crop_size on either side.
    """
    path = self.images[idx]
    # The context manager closes the file; convert() returns an independent, loaded copy.
    with Image.open(path, mode='r') as img:
      img = img.convert('RGB')

    # Name the offending file instead of failing obscurely inside the random crop.
    if self.split == 'train' and (img.width < self.crop_size or img.height < self.crop_size):
      raise ValueError(f'{path} is {img.width}x{img.height}, smaller than crop_size={self.crop_size}')

    LR_img, HR_img = self.transform(img)

    return LR_img, HR_img

  def __len__(self):
    return len(self.images)

## Training

### SRResNet

In [None]:
import time
import torch.backends.cudnn as cudnn
import torch
from torch import nn

#### Data Parameters

In [None]:
# SRDataset reads its JSON image lists from the folder create_data_lists writes to.
data_folder = output_folder
# HR crop side in pixels, and the HR -> LR downscaling factor.
crop_size, scaling_factor = 96, 4

#### Model Parameters

In [None]:
# Kernel sizes of SRResNet's conv1/conv2/conv3 (large) and of its inner blocks (small).
large_kernel_size, small_kernel_size = 9, 3
# Feature channels and number of residual blocks.
n_channels, n_blocks = 64, 16

#### Learning Parameters

In [None]:
# Checkpoint path to resume from (None = train from scratch) and the first epoch index.
checkpoint, start_epoch = None, 0
# Total training iterations (batches), batch size and Adam learning rate.
iteration, batch_size, lr = 1e6, 16, 1e-4
# DataLoader worker processes, and how often (in batches) to log the loss.
num_workers, print_every = 4, 500

In [None]:
# Let cuDNN auto-tune convolution algorithms (input sizes are fixed during training).
torch.backends.cudnn.benchmark = True

In [None]:
# Inputs: ImageNet-normalised LR crops. Targets: HR crops in [-1, 1], matching SRResNet's tanh output.
train_dataset = SRDataset(data_folder, split='train', crop_size=crop_size, scaling_factor=scaling_factor,
                          LR_img_type='imagenet_norm', HR_img_type='[-1, 1]')

loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
train_loader = DataLoader(train_dataset, **loader_options)



In [None]:
# Number of mini-batches per epoch.
n_batches_per_epoch = len(train_loader)
n_batches_per_epoch

8125

#### Training SRResNet

In [None]:
# Whole epochs needed for `iteration` optimizer steps (one step per batch), rounded down.
# Derived from the `iteration` setting instead of repeating the literal.
epochs = int(iteration // len(train_loader))
epochs

123

In [None]:
if checkpoint is not None:
  # Resume: restore the pickled model/optimizer and continue after the saved epoch.
  checkpoint = torch.load(checkpoint)
  start_epoch = checkpoint['epoch'] + 1
  model = checkpoint['model']
  optimizer = checkpoint['optimizer']
else:
  # Fresh start: new SRResNet optimised with Adam over its trainable parameters.
  model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                   n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
  trainable_params = [p for p in model.parameters() if p.requires_grad]
  optimizer = torch.optim.Adam(params=trainable_params, lr=lr)

criterion = nn.MSELoss().to(device)

In [None]:
def train_SRResNet(train_loader, model, criterion, optimizer, epoch, print_every=100):
  """Run one training epoch of SRResNet and return the mean per-batch loss.

  Args:
    train_loader: yields (LR_imgs, HR_imgs) batches.
    model: the network; moved to `device` and put in train mode here.
    criterion: loss called as criterion(SR_imgs, HR_imgs).
    optimizer: optimizer over `model`'s parameters.
    epoch: epoch index, only used in log messages.
    print_every: log the current batch loss every this many batches.

  Returns:
    The average batch loss as a Python float (0.0 if the loader yields nothing).
  """
  model.to(device)
  model.train()

  total_loss = 0.0
  n_batches = 0

  for batch, (LR_imgs, HR_imgs) in enumerate(train_loader):

    LR_imgs = LR_imgs.to(device)
    HR_imgs = HR_imgs.to(device)

    # Forward pass and loss
    SR_imgs = model(LR_imgs)
    loss = criterion(SR_imgs, HR_imgs)

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accumulate a detached value: adding `loss` itself would make the total a graph-tracking
    # tensor chained to every batch of the epoch.
    total_loss += loss.detach()
    n_batches += 1

    if (batch + 1) % print_every == 0:
      print(f'Epoch: {epoch}| Batch: {batch+1}| loss: {loss.item():.4f}')

  if n_batches == 0:
    return 0.0
  return (total_loss / n_batches).item()

In [None]:
PATH_checkpoint_srresnet = r'/content/drive/MyDrive/Projects/Super Resolution'
name = '/checkpoint_SRResNet.pth'
PATH_checkpoint_srresnet += name

# No checkpoint is written before training. A pre-training save would record epoch 0 as
# finished (so a resume skips it). When resuming, it would also overwrite the loaded
# checkpoint with one that lacks the 'optimizer' entry the resume code reads.

train_losses = []
epoch_list = []

for epoch in range(start_epoch, epochs):
  train_loss = train_SRResNet(train_loader=train_loader,
                              model=model,
                              criterion=criterion,
                              optimizer=optimizer,
                              epoch=epoch,
                              print_every=print_every)

  train_losses.append(train_loss)
  epoch_list.append(epoch)

  # Save after each completed epoch with the keys the resume branch reads.
  torch.save({'epoch': epoch,
              'model': model,
              'optimizer': optimizer},
             PATH_checkpoint_srresnet)

"\ntrain_losses = []\nepoch_list = []\n\nfor epoch in range(start_epoch, epochs):\n  train_loss = train_SRResNet(train_loader=train_loader,\n                     model=model,\n                     criterion=criterion,\n                     optimizer=optimizer,\n                     epoch=epoch,\n                     print_every=print_every)\n\n  train_losses.append(train_loss)\n  epoch_list.append(epoch)\n\n  torch.save({'epoch': epoch,\n              'model': model,\n              'optimizer': optimizer},\n             'checkpoint_srresnet.pth.tar')"

In [None]:
# Pickle the whole trained SRResNet module; the SRGAN generator is warm-started from this file.
name = '/SRResNet.pth'
PATH_srresnet = r'/content/drive/MyDrive/Projects/Super Resolution' + name
torch.save(model, PATH_srresnet)

### SRGAN

#### Generator Parameters

In [None]:
# SRGAN generator architecture (same as the SRResNet trained above).
large_kernel_size_g, small_kernel_size_g = 9, 3
n_channels_g, n_blocks_g = 64, 16
# Weights used to warm-start the generator.
srresnet_checkpoint = PATH_srresnet

#### Discriminator Parameters

In [None]:
# Discriminator: conv kernel size, first-layer channels, extra conv blocks, hidden FC width.
kernel_size_d, n_channels_d, n_blocks_d, fc_size_d = 3, 64, 7, 1024

#### Learning Parameters

In [None]:
# Checkpoint to resume SRGAN from (None = start fresh) and the first epoch index.
checkpoint_g, start_epoch_g = None, 0
batch_size, num_workers = 16, 4
# Total SRGAN training iterations (batches).
iteration_g = 2e5
# Content loss is computed on VGG19 features phi_{i,j}.
vgg19_i, vgg19_j = 5, 4
# Weight of the adversarial term in the generator's perceptual loss.
beta = 1e-3
print_every = 500
# Adam learning rate (used for both generator and discriminator).
lr_g = 1e-4

In [None]:
# Let cuDNN auto-tune convolution algorithms for SRGAN training as well.
torch.backends.cudnn.benchmark = True

In [None]:
if checkpoint_g is None:
  # Generator: same architecture and upscaling factor as the SRResNet trained above
  # (scaling_factor is passed explicitly so it cannot drift from the data pipeline).
  generator = Generator(large_kernel_size=large_kernel_size_g,
                        small_kernel_size=small_kernel_size_g,
                        n_channels=n_channels_g,
                        n_blocks=n_blocks_g,
                        scaling_factor=scaling_factor)

  # Warm-start the generator from the pre-trained SRResNet
  generator.init_with_srresnet(srresnet_checkpoint)

  # Generator's optimizer
  optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr_g)

  # Discriminator
  discriminator = Discriminator(n_channels=n_channels_d,
                                kernel_size=kernel_size_d,
                                n_blocks=n_blocks_d,
                                fc_size=fc_size_d)

  # Discriminator's optimizer (reuses the generator's learning rate)
  optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr_g)

else:
  # Resume from a checkpoint dict with keys 'epoch', 'generator', 'discriminator',
  # 'optimizer_g' and 'optimizer_d'.
  checkpoint_g = torch.load(checkpoint_g)
  start_epoch_g = checkpoint_g['epoch'] + 1
  generator = checkpoint_g['generator']
  discriminator = checkpoint_g['discriminator']
  optimizer_g = checkpoint_g['optimizer_g']
  optimizer_d = checkpoint_g['optimizer_d']
  print(f'\nLoaded checkpoint from epoch {start_epoch_g}')


Loaded weights from pre-trained SRResNet.



In [None]:
# Truncated VGG19 to calculate loss in vgg space
# Truncated VGG19 used to compute the content loss in VGG feature space. It must live on the
# same device as the generator's output, otherwise the forward pass fails with a device mismatch.
truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j).to(device)
truncated_vgg19.eval()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 76.5MB/s]


TruncatedVGG19(
  (new_vgg19): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=

In [None]:
# Loss functions
# Content loss: MSE between VGG feature maps. Adversarial loss: BCE applied to raw discriminator logits.
content_loss_criterion, adversarial_loss_criterion = nn.MSELoss(), nn.BCEWithLogitsLoss()

In [None]:
# For SRGAN both LR inputs and HR targets are ImageNet-normalised: HR images go straight
# into VGG19 and the discriminator.
dataset_kwargs = dict(split='train', crop_size=crop_size, scaling_factor=scaling_factor,
                      LR_img_type='imagenet_norm', HR_img_type='imagenet_norm')
train_dataset_g = SRDataset(data_folder, **dataset_kwargs)

train_dataloader_g = DataLoader(train_dataset_g, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, pin_memory=True)

#### Training SRGAN

In [None]:
batches_per_epoch_g = len(train_dataloader_g)
# Floor plus one epoch, so at least `iteration_g` batches are processed.
epochs_g = int(iteration_g // batches_per_epoch_g) + 1
epochs_g

25

In [None]:
def train_SRGAN(train_loader_g, generator, discriminator, truncated_vgg19, content_loss_criterion,
                adversarial_loss_criterion, optimizer_g, optimizer_d, beta, epoch, print_every=100):
  """Run one SRGAN epoch: per batch, one generator step followed by one discriminator step.

  Args:
    train_loader_g: yields (LR_imgs, HR_imgs); HR_imgs are fed to VGG19 and the discriminator
      as-is, so they must already be in the 'imagenet_norm' format.
    generator: maps LR_imgs to SR images in [-1, 1].
    discriminator: returns one logit per image.
    truncated_vgg19: fixed feature extractor for the content loss.
    content_loss_criterion: loss between SR and HR VGG features.
    adversarial_loss_criterion: loss on discriminator logits against 0/1 targets.
    optimizer_g, optimizer_d: optimizers of the generator and the discriminator.
    beta: weight of the adversarial term in the generator's perceptual loss.
    epoch: epoch index, only used in log messages.
    print_every: log the losses every this many batches.
  """
  # Every network must live on the same device as the batches.
  generator.to(device)
  discriminator.to(device)
  truncated_vgg19.to(device)

  # Put the trained models in train mode
  generator.train()
  discriminator.train()

  for batch, (LR_imgs, HR_imgs) in enumerate(train_loader_g):

    LR_imgs, HR_imgs = LR_imgs.to(device), HR_imgs.to(device)

    # ---- GENERATOR update ----
    SR_imgs = generator(LR_imgs)
    # The generator is initialised from SRResNet, which outputs [-1, 1]; convert to
    # imagenet_norm so SR and HR images are compared in the same space.
    SR_imgs = convert_image(SR_imgs, source='[-1, 1]', target='imagenet_norm')

    SR_imgs_vgg_space = truncated_vgg19(SR_imgs)
    # HR features are fixed targets, so no gradient is needed through them.
    HR_imgs_vgg_space = truncated_vgg19(HR_imgs).detach()

    SR_imgs_discriminated = discriminator(SR_imgs)

    content_loss = content_loss_criterion(SR_imgs_vgg_space, HR_imgs_vgg_space)
    # The generator wants its outputs classified as real (1).
    adversarial_loss_g = adversarial_loss_criterion(SR_imgs_discriminated, torch.ones_like(SR_imgs_discriminated))
    perceptual_loss = content_loss + beta * adversarial_loss_g

    optimizer_g.zero_grad()
    perceptual_loss.backward()
    optimizer_g.step()

    # ---- DISCRIMINATOR update ----
    HR_imgs_discriminated = discriminator(HR_imgs)
    # Detach so this loss does not backpropagate into the generator.
    SR_imgs_discriminated = discriminator(SR_imgs.detach())

    # SR (fake) images are pushed towards 0, HR (real) images towards 1.
    adversarial_loss_d = (adversarial_loss_criterion(SR_imgs_discriminated, torch.zeros_like(SR_imgs_discriminated))
                          + adversarial_loss_criterion(HR_imgs_discriminated, torch.ones_like(HR_imgs_discriminated)))

    # Also discards the discriminator gradients left over from perceptual_loss.backward().
    optimizer_d.zero_grad()
    adversarial_loss_d.backward()
    optimizer_d.step()

    if (batch + 1) % print_every == 0:
      print(f'Epoch: {epoch} | Batch: {batch+1} | Content Loss = {content_loss.item():.4f} | '
            f'Adversarial Loss Generator = {adversarial_loss_g.item():.4f} | '
            f'Adversarial Loss Discriminator = {adversarial_loss_d.item():.4f}')

In [None]:
# SRGAN training state is saved after every epoch in the format the resume branch
# (checkpoint_g) reads.
PATH_checkpoint_srgan = r'/content/drive/MyDrive/Projects/Super Resolution' + '/checkpoint_SRGAN.pth'

for epoch in range(start_epoch_g, epochs_g):

  train_SRGAN(train_loader_g=train_dataloader_g,
              generator=generator,
              discriminator=discriminator,
              truncated_vgg19=truncated_vgg19,
              content_loss_criterion=content_loss_criterion,
              adversarial_loss_criterion=adversarial_loss_criterion,
              optimizer_g=optimizer_g,
              optimizer_d=optimizer_d,
              beta=beta,
              epoch=epoch,
              print_every=print_every)

  torch.save({'epoch': epoch,
              'generator': generator,
              'discriminator': discriminator,
              'optimizer_g': optimizer_g,
              'optimizer_d': optimizer_d},
             PATH_checkpoint_srgan)

In [None]:
# Pickle the whole trained SRGAN generator module.
name = '/SRGAN.pth'
PATH_srgan = r'/content/drive/MyDrive/Projects/Super Resolution' + name
torch.save(generator, PATH_srgan)