# Environment

In [1]:
# Third-party
import multiprocessing

from pathlib import Path
from PIL import Image

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
sns.set_style('darkgrid')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision.transforms as T
from torchvision.models import vgg19

# First-party
from data import *
from utils import *
from models.mesrgan import *

# Settings

In [2]:
MODEL_NAME        = "MESRGAN_T2_ALL_DATA"
TRAIN_STAGE1      = False  # True: run stage-1 (L1) pre-training; False: load the saved stage-1 generator
TRAIN_STAGE2      = True   # True: run stage-2 (GAN) training; False: load the saved stage-2 weights
TRAIN_DATA        = ["data/DIV2K_train_HR", "data/Flickr2K", "data/OutdoorSceneTrain_v2"]
VAL_DATA          = "./data/DIV2k_BEST_PICTURES"
DEVICE            = torch.device('cuda:1')
N_WORKERS         = 8
# Iteration budgets count mini-batches, so keep them integers (2e5 would be a float).
STAGE1_ITERATIONS = 200_000
STAGE2_ITERATIONS = 200_000
MONITOR_INTERVAL  = 1000   # mini-batches between TensorBoard/console logs
EXPANSION_FACTOR  = 2      # passed to Generator(t=...)

# Data

In [3]:
# Per-channel mean/std shared by the training and validation normalisation.
data_mean  = [0.4439, 0.4517, 0.4054]
data_std   = [0.2738, 0.2607, 0.2856]

# --- Training split ---------------------------------------------------------
# Random HR crops with flip + RandomRotationsTransform augmentation. LR inputs
# are not stored: the training loops bicubically downscale each HR batch to
# LR_size (a quarter of HR_size).
train_env = {"HR_size": 128}
train_env["LR_size"] = train_env["HR_size"] // 4
train_env["transform"] = T.Compose([
    T.RandomCrop((train_env["HR_size"],) * 2),
    T.RandomHorizontalFlip(p=0.5),
    RandomRotationsTransform([-90, 0, 90]),
    T.ToTensor(),
    T.Normalize(mean=data_mean, std=data_std),
])
train_env["dataset"] = ImageDataset(TRAIN_DATA, train_env["transform"])
train_env["dataloader"] = DataLoader(
    train_env["dataset"],
    batch_size=16,
    shuffle=True,
    num_workers=N_WORKERS,
)

# --- Validation split -------------------------------------------------------
# Deterministic 1024x1024 centre crops, served unshuffled in batches of 3.
val_env = {"HR_size": 1024}
val_env["LR_size"] = val_env["HR_size"] // 4
val_env["transform"] = T.Compose([
    T.CenterCrop((val_env["HR_size"],) * 2),
    T.ToTensor(),
    T.Normalize(mean=data_mean, std=data_std),
])
val_env["dataset"] = ImageDataset(VAL_DATA, val_env["transform"])
val_env["dataloader"] = DataLoader(
    val_env["dataset"],
    batch_size=3,
    shuffle=False,
    num_workers=3,
)

# Training

In [4]:
# Stage-1 training feeds DEVICE tensors straight into the generator, so it must
# live on DEVICE from the start (not only after a checkpoint is loaded).
generator = Generator(t=EXPANSION_FACTOR).to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# Content loss (i.e. L1) for visual similarity
content_loss = torch.nn.L1Loss().to(DEVICE)

# Perceptual loss (i.e. VGG) to preserve/reconstruct objects/features of the original HR image
vgg19_model = vgg19(pretrained=True).to(DEVICE)
# "54 indicates features obtained by the 4th convolution before the 5th maxpooling layer".
# In torchvision's VGG19 layer indexing, features[:35] ends at conv5_4, before its ReLU.
vgg_19_54_model = vgg19_model.features[:35]
vgg_19_54_model.eval()
# The VGG is a fixed feature extractor that is never given to an optimizer:
# freeze it so backward() neither computes nor keeps accumulating its gradients.
# Gradients w.r.t. the *inputs* (the SR images) still flow through it.
vgg_19_54_model.requires_grad_(False)

_perceptual_loss = torch.nn.MSELoss(reduction='mean').to(DEVICE)
def perceptual_loss(hr_imgs, sr_imgs):
    """MSE between VGG19 '54' feature maps of the HR targets and the SR outputs.

    Args:
        hr_imgs: ground-truth high-resolution batch (the fixed target).
        sr_imgs: generator output; gradients flow back through this argument.

    Returns:
        Scalar mean-squared error between the two feature maps.
    """
    # The HR branch is only a target, so no autograd graph is needed for it;
    # skipping it saves activation memory without changing the loss or the
    # gradients reaching the generator.
    with torch.no_grad():
        hr_features = vgg_19_54_model(hr_imgs)
    sr_features = vgg_19_54_model(sr_imgs)
    return _perceptual_loss(hr_features, sr_features)

# Adversarial loss for plausibility of image.
def rel_disc(x1, x2):
    """Relativistic-average logit: ``x1`` shifted by the mean of ``x2``.

    Returns ``x1 - mean(x2)``, broadcast over every element of ``x1``.
    """
    baseline = x2.mean()
    return x1 - baseline
bce_loss = torch.nn.BCEWithLogitsLoss(reduction='mean').to(DEVICE)

def adversarial_loss(x_r, x_f):
    """Generator-side relativistic-average adversarial loss.

    The generator wants real logits to look *less* real than the average fake
    (target 0) and fake logits to look *more* real than the average real
    (target 1).

    Args:
        x_r: discriminator logits for real (HR) images.
        x_f: discriminator logits for generated (SR) images.

    Returns:
        Sum of the two BCE-with-logits terms.
    """
    # *_like targets match each logit tensor's shape, dtype and device, so no
    # CPU allocation + copy is needed and x_r/x_f may differ in batch size.
    real_term = bce_loss(rel_disc(x_r, x_f), torch.zeros_like(x_r))
    fake_term = bce_loss(rel_disc(x_f, x_r), torch.ones_like(x_f))
    return real_term + fake_term

# Discriminator loss used to train the discriminator.
def discriminator_loss(x_r, x_f):
    """Discriminator-side relativistic-average loss (targets flipped vs. the generator).

    Args:
        x_r: discriminator logits for real (HR) images.
        x_f: discriminator logits for generated (SR) images.

    Returns:
        Sum of the two BCE-with-logits terms.
    """
    real_term = bce_loss(rel_disc(x_r, x_f), torch.ones_like(x_r))
    fake_term = bce_loss(rel_disc(x_f, x_r), torch.zeros_like(x_f))
    return real_term + fake_term

In [5]:
def stage1_training_loop(n_iterations):
    """Stage 1: pre-train the generator with the L1 content loss only.

    Trains until exactly ``n_iterations`` mini-batches have been processed,
    making as many passes over ``train_env["dataloader"]`` as needed.
    Every MONITOR_INTERVAL batches the content loss and SR previews of the
    first validation batch go to TensorBoard. Every 20000 batches the generator
    weights are checkpointed.

    Args:
        n_iterations: number of mini-batches (optimizer steps) to train for.
    """
    torch.cuda.empty_cache()

    writer = SummaryWriter('runs/' + MODEL_NAME + '_stage1')
    checkpoint_path = "trained_models/" + MODEL_NAME + "_stage1_generator.trch"

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)
    # Halve the LR every 100k steps. This deviates from the paper's 2e5, which
    # is too much for us. StepLR expects an integer step size.
    g_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=100_000, gamma=0.5)

    train_lr_size = (train_env["LR_size"], train_env["LR_size"])
    val_lr_size = (val_env["LR_size"], val_env["LR_size"])

    n_mini_batches_processed = 0
    try:
        while n_mini_batches_processed < n_iterations:
            for hr_imgs in train_env["dataloader"]:
                g_optimizer.zero_grad()

                hr_imgs = hr_imgs.to(DEVICE)
                lr_imgs = F.interpolate(hr_imgs, size=train_lr_size, mode='bicubic', align_corners=False)

                sr_imgs = generator(lr_imgs)

                loss = content_loss(hr_imgs, sr_imgs)
                loss.backward()
                g_optimizer.step()
                g_scheduler.step()

                n_mini_batches_processed += 1
                if n_mini_batches_processed % MONITOR_INTERVAL == 0:
                    loss_value = loss.item()
                    print(f"{n_mini_batches_processed} mini-batches done. Content loss: {loss_value}")
                    writer.add_scalar('Content Loss', loss_value, n_mini_batches_processed)

                    # SR previews of the first (unshuffled) validation batch.
                    with torch.no_grad():
                        val_hr = next(iter(val_env['dataloader'])).to(DEVICE)
                        val_lr = F.interpolate(val_hr, size=val_lr_size, mode='bicubic', align_corners=False)
                        val_sr = denormalize(generator(val_lr).cpu())

                    writer.add_image('SR Butterfly', val_sr[0], n_mini_batches_processed)
                    writer.add_image('SR Food', val_sr[1], n_mini_batches_processed)
                    writer.add_image('SR House', val_sr[2], n_mini_batches_processed)

                # Checkpoint inside the batch loop. At epoch end the counter
                # almost never landed on a multiple of 20000, so saves were skipped.
                if n_mini_batches_processed % 20000 == 0:
                    torch.save(generator.state_dict(), checkpoint_path)

                # Stop mid-epoch once the budget is reached instead of overshooting.
                if n_mini_batches_processed >= n_iterations:
                    break
    finally:
        writer.close()

# Stage-1 pre-training runs only on request; otherwise the next cell loads the
# saved stage-1 generator. The paper doesn't specify an iteration count.
if TRAIN_STAGE1:
    stage1_training_loop(STAGE1_ITERATIONS)

In [6]:
PATH = "trained_models/" + MODEL_NAME + "_stage1_generator.trch"

# Persist the freshly trained stage-1 generator, or restore it from disk.
if TRAIN_STAGE1:
    torch.save(generator.state_dict(), PATH)
else:
    # map_location loads the tensors directly onto DEVICE, so the checkpoint
    # still loads if it was written from a different device (e.g. another GPU).
    generator.load_state_dict(torch.load(PATH, map_location=DEVICE))
    generator.to(DEVICE)

In [None]:
def stage2_training_loop(n_iterations):
    """Stage 2: adversarial fine-tuning of the generator against the discriminator.

    Each mini-batch runs two steps:
    - a generator step on ``perceptual + lmbd * adversarial + eta * content``;
    - a discriminator step on ``discriminator_loss``.

    Trains until exactly ``n_iterations`` mini-batches have been processed,
    making as many passes over ``train_env["dataloader"]`` as needed.
    Every MONITOR_INTERVAL batches the losses are printed and logged to
    TensorBoard, together with SR previews of the first validation batch.
    The generator weights are saved every 20000 batches and once at the end.

    Args:
        n_iterations: number of mini-batches to train for.
    """
    torch.cuda.empty_cache()
    writer = SummaryWriter('runs/' + MODEL_NAME + '_stage2')
    checkpoint_path = 'trained_models/' + MODEL_NAME + '_stage2.trch'

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    # Halve the generator LR every 50k steps (StepLR expects an integer step size).
    g_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=50_000, gamma=0.5)

    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

    # Loop-invariant weights of the adversarial and content terms.
    lmbd = 5.0e-3
    eta  = 1.0e-2

    train_lr_size = (train_env["LR_size"], train_env["LR_size"])
    val_lr_size = (val_env["LR_size"], val_env["LR_size"])

    n_mini_batches_processed = 0
    try:
        while n_mini_batches_processed < n_iterations:
            for hr_imgs in train_env["dataloader"]:
                hr_imgs = hr_imgs.to(DEVICE)
                lr_imgs = F.interpolate(hr_imgs, size=train_lr_size, mode='bicubic', align_corners=False)

                # Generator
                g_optimizer.zero_grad()
                sr_imgs = generator(lr_imgs)

                p_loss = perceptual_loss(hr_imgs, sr_imgs)
                # Real logits have no autograd path to the generator, so skip
                # their graph. The generator still gets gradients through sr_d.
                with torch.no_grad():
                    hr_d = discriminator(hr_imgs)
                sr_d = discriminator(sr_imgs)
                a_loss = adversarial_loss(hr_d, sr_d)
                c_loss = content_loss(hr_imgs, sr_imgs)

                g_loss = p_loss + lmbd * a_loss + eta * c_loss
                g_loss.backward()
                g_optimizer.step()
                g_scheduler.step()

                # Discriminator: SR images are detached so this step cannot
                # push gradients into the generator.
                d_optimizer.zero_grad()
                hr_d = discriminator(hr_imgs)
                sr_d = discriminator(sr_imgs.detach())
                d_loss = discriminator_loss(hr_d, sr_d)
                d_loss.backward()
                d_optimizer.step()

                n_mini_batches_processed += 1
                if n_mini_batches_processed % MONITOR_INTERVAL == 0:
                    d_val, g_val = d_loss.item(), g_loss.item()
                    p_val, a_val, c_val = p_loss.item(), a_loss.item(), c_loss.item()
                    print(f'Mini-Batch {n_mini_batches_processed} done. '
                          f'Discriminator loss: {d_val:.5f}. '
                          f'Generator loss: {g_val:.2f} '
                          f'(perc: {p_val:.2f}, adv: {a_val:.2f}, content: {c_val:.2f})'
                         )
                    writer.add_scalar('Content Loss', c_val, n_mini_batches_processed)
                    writer.add_scalar('Perceptual Loss', p_val, n_mini_batches_processed)
                    # NOTE: tag spelling kept so existing TensorBoard runs stay comparable.
                    writer.add_scalar('Adverserial Loss', a_val, n_mini_batches_processed)
                    writer.add_scalar('Discriminator Loss', d_val, n_mini_batches_processed)

                    # SR previews of the first (unshuffled) validation batch.
                    with torch.no_grad():
                        val_hr = next(iter(val_env['dataloader'])).to(DEVICE)
                        val_lr = F.interpolate(val_hr, size=val_lr_size, mode='bicubic', align_corners=False)
                        val_sr = denormalize(generator(val_lr).cpu())

                    writer.add_image('SR Butterfly', val_sr[0], n_mini_batches_processed)
                    writer.add_image('SR Food', val_sr[1], n_mini_batches_processed)
                    writer.add_image('SR House', val_sr[2], n_mini_batches_processed)

                # Periodic checkpoint (as in stage 1), so a crash does not lose a long run.
                if n_mini_batches_processed % 20000 == 0:
                    torch.save(generator.state_dict(), checkpoint_path)

                # Stop mid-epoch once the budget is reached instead of overshooting.
                if n_mini_batches_processed >= n_iterations:
                    break
    finally:
        writer.close()

    torch.save(generator.state_dict(), checkpoint_path)
    
# Stage-2 adversarial training runs only on request; otherwise the next cell
# loads the saved stage-2 weights. The paper doesn't specify an iteration count.
if TRAIN_STAGE2:
    stage2_training_loop(n_iterations=STAGE2_ITERATIONS)

Mini-Batch 1000 done. Discriminator loss: 0.00037. Generator loss: 2.05 (perc: 1.93, adv: 24.67, content: 0.13)
Mini-Batch 2000 done. Discriminator loss: 0.00897. Generator loss: 2.89 (perc: 2.78, adv: 21.39, content: 0.19)
Mini-Batch 3000 done. Discriminator loss: 0.00018. Generator loss: 2.71 (perc: 2.56, adv: 29.43, content: 0.19)
Mini-Batch 4000 done. Discriminator loss: 0.02699. Generator loss: 3.42 (perc: 3.31, adv: 21.64, content: 0.21)
Mini-Batch 5000 done. Discriminator loss: 0.02797. Generator loss: 3.25 (perc: 3.16, adv: 17.01, content: 0.19)
Mini-Batch 6000 done. Discriminator loss: 0.53281. Generator loss: 2.64 (perc: 2.59, adv: 9.50, content: 0.21)
Mini-Batch 7000 done. Discriminator loss: 0.00062. Generator loss: 3.29 (perc: 3.19, adv: 19.08, content: 0.19)
Mini-Batch 8000 done. Discriminator loss: 0.00145. Generator loss: 3.87 (perc: 3.78, adv: 17.29, content: 0.20)
Mini-Batch 9000 done. Discriminator loss: 0.00000. Generator loss: 3.09 (perc: 2.88, adv: 41.28, content:

In [None]:
G_PATH = "trained_models/" + MODEL_NAME + "_stage2_generator.trch"
D_PATH = "trained_models/" + MODEL_NAME + "_stage2_discriminator.trch"

# Persist both stage-2 networks after training, or restore them from disk.
if TRAIN_STAGE2:
    torch.save(generator.state_dict(), G_PATH)
    torch.save(discriminator.state_dict(), D_PATH)
else:
    # map_location loads the tensors directly onto DEVICE, so the checkpoints
    # still load if they were written from a different device.
    generator.load_state_dict(torch.load(G_PATH, map_location=DEVICE))
    discriminator.load_state_dict(torch.load(D_PATH, map_location=DEVICE))

# Evaluation

In [None]:
# Visual check on 4 training crops: plots the LR input, the SR output and the
# HR target, in that order.
with torch.no_grad():
    hr_imgs = next(iter(train_env['dataloader']))[:4].to(DEVICE)
    lr_size = (train_env["LR_size"], train_env["LR_size"])
    lr_imgs = F.interpolate(hr_imgs, size=lr_size, mode='bicubic', align_corners=False)
    sr_imgs = generator(lr_imgs)

    for stage_imgs in (lr_imgs, sr_imgs, hr_imgs):
        plot_images(denormalize(stage_imgs.detach().cpu()))

In [None]:
# Visual check on the first (unshuffled) validation batch: plots the LR input,
# the SR output and the HR target, in that order.
with torch.no_grad():
    hr_imgs = next(iter(val_env['dataloader'])).to(DEVICE)
    lr_size = (val_env["LR_size"], val_env["LR_size"])
    lr_imgs = F.interpolate(hr_imgs, size=lr_size, mode='bicubic', align_corners=False)
    sr_imgs = generator(lr_imgs)

    for stage_imgs in (lr_imgs, sr_imgs, hr_imgs):
        plot_images(denormalize(stage_imgs.detach().cpu()))