Training the SRGAN

In [8]:
# builtin 
import glob
import random
import warnings
from pathlib import Path
warnings.filterwarnings("ignore")

# third-party imports
import torch 
import numpy as np 
import torch.nn as nn
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
from sklearn.model_selection import train_test_split



# our modules
from src.config import cfg, root_path
from src.utils import MeanSTDFinder
from src.data_loaders import SuperResolutionDataLoader
from src.models.srgan import Generator, Discriminator, VggFeatureExtractor


# make sure the SRGAN checkpoint directory and its images sub-directory exist
Path(root_path, "saved_models/srgan").mkdir(parents=True, exist_ok=True)
Path(root_path, "saved_models/srgan/images").mkdir(parents=True, exist_ok=True)



In [9]:
# collect every "<name>.<ext>" file in the images directory; sorting keeps the
# seeded split below reproducible regardless of filesystem listing order
images_pth = cfg.dataset.images_dir
all_image_paths = sorted(glob.glob(f"{images_pth}/*.*"))

# hold out 20% of the images for testing, with a fixed seed
train_paths, test_paths = train_test_split(all_image_paths, test_size=0.2, random_state=42)

# dataset mean / std, hard-coded here (three values each, presumably one per
# colour channel — TODO confirm); the commented-out call shows how to recompute them
# mean_std = MeanSTDFinder(images_dir=images_pth)()
mean_std = dict(
    mean=[0.2903465, 0.31224626, 0.29810828],
    std=[0.1457739, 0.13011318, 0.12317199],
)

In [10]:
# Training loader: reshuffled every epoch.
train_dataloader = DataLoader(
    SuperResolutionDataLoader(train_paths, **mean_std),
    batch_size=cfg.train.batch_size,
    shuffle=True,
    num_workers=cfg.train.n_cpu,
)
# Evaluation loader: 75% of the training batch size, but never below 1 because
# DataLoader rejects batch_size=0. The order is fixed (shuffle=False) so that the
# metrics and the sample grids saved during testing cover the same images in the
# same batches every epoch.
test_dataloader = DataLoader(
    SuperResolutionDataLoader(test_paths, **mean_std),
    batch_size=max(1, int(cfg.train.batch_size * 0.75)),
    shuffle=False,
    num_workers=cfg.train.n_cpu,
)


In [11]:
from math import exp
from torch.autograd import Variable

def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel.

    Args:
        window_size: number of taps in the kernel.
        sigma: standard deviation of the Gaussian, in samples.

    Returns:
        A 1-D ``torch.Tensor`` of length ``window_size`` whose entries sum to 1;
        the un-normalised weight is largest at index ``window_size // 2``.
    """
    centre = window_size // 2

    def weight(position):
        # un-normalised Gaussian evaluated at an integer offset from the centre
        return exp(-(position - centre) ** 2 / float(2 * sigma ** 2))

    kernel = torch.Tensor(list(map(weight, range(window_size))))
    return kernel / kernel.sum()


def create_window(window_size, channel):
    """Build the per-channel 2-D Gaussian window used by the SSIM computation.

    Args:
        window_size: side length of the square window.
        channel: number of image channels; the same kernel is repeated for each.

    Returns:
        A contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    # sigma is fixed at 1.5; multiplying the column kernel by its transpose
    # yields the 2-D kernel
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # Return a plain tensor: the torch.autograd.Variable wrapper is deprecated
    # and simply hands back a tensor anyway.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches using a pre-built window.

    Local means, variances and the covariance are obtained by filtering with
    ``window`` through a grouped convolution (``groups=channel``, zero padding
    of ``window_size // 2``).

    Args:
        img1, img2: image batches to compare (4-D tensors, as required by
            ``F.conv2d``).
        window: filter weights for the grouped convolution.
        window_size: side length of ``window``; only used to derive the padding.
        channel: number of channels, passed as ``groups``.
        size_average: if True return the mean over the whole SSIM map,
            otherwise average over dims 1, 2 and 3 and return one value per
            batch element.

    Returns:
        A scalar tensor, or a 1-D tensor with one entry per batch element.
    """
    half = window_size // 2

    def local_mean(tensor):
        # window-weighted local average, computed independently per channel
        return F.conv2d(tensor, window, padding=half, groups=channel)

    mean_a = local_mean(img1)
    mean_b = local_mean(img2)
    mean_a_sq, mean_b_sq, mean_ab = mean_a.pow(2), mean_b.pow(2), mean_a * mean_b

    # E[x^2] - E[x]^2 style estimates of the local (co)variances
    var_a = local_mean(img1 * img1) - mean_a_sq
    var_b = local_mean(img2 * img2) - mean_b_sq
    cov_ab = local_mean(img1 * img2) - mean_ab

    # small constants that keep the ratio stable when means/variances are ~0
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    numerator = (2 * mean_ab + c1) * (2 * cov_ab + c2)
    denominator = (mean_a_sq + mean_b_sq + c1) * (var_a + var_b + c2)
    ssim_map = numerator / denominator

    if not size_average:
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()

def ssim(img1, img2, window_size=11, size_average=True):
    """Structural similarity (SSIM) between two batches of images.

    Args:
        img1, img2: 4-D tensors; dimension 1 of ``img1`` is taken as the
            channel count.
        window_size: side length of the Gaussian window (default 11).
        size_average: if True return a single scalar, otherwise one value per
            batch element.

    Returns:
        The tensor produced by ``_ssim`` for the two inputs.

    Raises:
        ValueError: if ``img1`` is not 4-dimensional (size unpacking fails).
    """
    (_, channel, _, _) = img1.size()

    # Match the window to img1's device and dtype in one step. `.to` works for
    # any device, whereas the previous is_cuda/.cuda() branch only special-cased CUDA.
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)

    return _ssim(img1, img2, window, window_size, channel, size_average)

In [12]:
########## Define the Model Parameters ##########
# Build the two trainable networks and move them to the configured device.
generator = Generator().to(cfg.device.device)
discriminator = Discriminator().to(cfg.device.device)

# The feature extractor only supplies features for the content loss and is never
# given an optimizer, so keep it in eval mode and freeze its weights. Gradients
# still flow *through* it to the generator, but no gradients are computed or
# stored for its own parameters.
feature_extractor = VggFeatureExtractor().to(cfg.device.device)
feature_extractor.eval()
feature_extractor.requires_grad_(False)

# Adversarial loss works on raw discriminator logits; content loss is L1
# between feature maps.
gan_loss = torch.nn.BCEWithLogitsLoss().to(cfg.device.device)
content_loss = torch.nn.L1Loss().to(cfg.device.device)



In [13]:
# one Adam optimizer per network (generator first, then discriminator),
# both using the learning rate and betas from the config
optimizer_G, optimizer_D = (
    torch.optim.Adam(
        network.parameters(),
        lr=cfg.train.learning_rate,
        betas=(cfg.train.b1, cfg.train.b2),
    )
    for network in (generator, discriminator)
)

In [14]:
import math

# Per-epoch history: training losses, test losses and test image-quality metrics.
train_gen_loss, train_disc_loss, train_counter = [], [], []
test_gen_loss, test_disc_loss = [], []
test_ssim = []
test_psnr = []

# Sample grids and checkpoints are written under root_path rather than relative
# to the current working directory, so saving works wherever the notebook runs.
checkpoint_dir = Path(root_path).joinpath("saved_models/srgan")
sample_images_dir = checkpoint_dir.joinpath("images")
sample_images_dir.mkdir(exist_ok=True, parents=True)


for epoch in range(cfg.train.n_epochs):

    ############################ Training ####################
    # The test phase below leaves both networks in eval mode, so switch back to
    # train mode once at the start of every epoch.
    generator.train()
    discriminator.train()

    gen_loss = 0
    disc_loss = 0
    train_bar = tqdm(train_dataloader, desc="Training")

    for batch_idx, imgs in enumerate(train_bar):

        low_res_ipt = imgs["lr"].to(cfg.device.device)
        high_res_ipt = imgs["hr"].to(cfg.device.device)

        #################### Generator ######################

        optimizer_G.zero_grad()
        generated_hr = generator(low_res_ipt)
        disc_opt = discriminator(generated_hr)

        # Adversarial loss: the generator wants its output to be scored as real
        loss_GAN = gan_loss(disc_opt, torch.ones_like(disc_opt))

        # Content loss between extracted features. The real-image features are
        # only a target, so they are computed without building an autograd graph.
        generated_features = feature_extractor(generated_hr)
        with torch.no_grad():
            real_features = feature_extractor(high_res_ipt)
        loss_CONTENT = content_loss(generated_features, real_features)

        # total loss: content term plus lightly weighted adversarial term
        total_loss_generator = loss_CONTENT + 1e-3 * loss_GAN

        # backpropagate
        total_loss_generator.backward()
        optimizer_G.step()

        #################### discriminator ######################

        # zero_grad also clears the discriminator gradients produced by the
        # generator's backward pass above
        optimizer_D.zero_grad()

        real_disc_opt = discriminator(high_res_ipt)
        loss_D_real = gan_loss(real_disc_opt, torch.ones_like(real_disc_opt))

        # detach so the discriminator update does not backpropagate into the generator
        fake_disc_opt = discriminator(generated_hr.detach())
        loss_D_fake = gan_loss(fake_disc_opt, torch.zeros_like(fake_disc_opt))

        # total loss
        total_disc_loss = (loss_D_real + loss_D_fake) / 2

        # backprop
        total_disc_loss.backward()
        optimizer_D.step()

        ################## Accumulate losses ###############

        gen_loss += total_loss_generator.item()
        disc_loss += total_disc_loss.item()

        # running averages over the batches seen so far in this epoch
        train_bar.set_postfix(
            gen_loss=gen_loss / (batch_idx + 1), disc_loss=disc_loss / (batch_idx + 1)
        )

    train_gen_loss.append(gen_loss / len(train_dataloader))
    train_disc_loss.append(disc_loss / len(train_dataloader))

    ############################ Testing ####################
    generator.eval()
    discriminator.eval()

    gen_loss = 0
    disc_loss = 0
    # 'mse' and 'ssims' are sample-weighted running sums; 'psnr' and 'ssim' hold
    # the running values derived from them after each batch.
    valid_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}

    test_bar = tqdm(test_dataloader, desc="Testing")

    # Nothing is optimised during evaluation, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch_idx, imgs in enumerate(test_bar):

            # get the inputs
            low_res_ipt = imgs["lr"].to(cfg.device.device)
            high_res_ipt = imgs["hr"].to(cfg.device.device)

            # get the batch size
            batch_size = low_res_ipt.size(0)
            valid_results['batch_sizes'] += batch_size

            ############# Generator Eval ###############

            generated_hr = generator(low_res_ipt)
            disc_opt = discriminator(generated_hr)

            # NOTE(review): the metrics below are computed on the tensors exactly as
            # the dataset yields them. If those are mean/std-normalised, the PSNR
            # peak (taken as the current batch's max) and the SSIM constants do not
            # correspond to a [0, 1] image range — confirm and de-normalise if needed.

            # mean squared error, accumulated as a plain Python float
            batch_mse = ((generated_hr - high_res_ipt) ** 2).mean().item()
            valid_results['mse'] += batch_mse * batch_size

            batch_ssim = ssim(generated_hr, high_res_ipt).item()
            valid_results['ssims'] += batch_ssim * batch_size

            mean_mse = valid_results['mse'] / valid_results['batch_sizes']
            peak_sq = high_res_ipt.max().item() ** 2
            # a perfect reconstruction has zero error, i.e. infinite PSNR
            valid_results['psnr'] = (
                float("inf") if mean_mse == 0 else 10 * math.log10(peak_sq / mean_mse)
            )
            valid_results['ssim'] = valid_results['ssims'] / valid_results['batch_sizes']

            # Adversarial loss
            loss_GAN = gan_loss(disc_opt, torch.ones_like(disc_opt))

            # content loss
            generated_features = feature_extractor(generated_hr)
            real_features = feature_extractor(high_res_ipt)
            loss_CONTENT = content_loss(generated_features, real_features)

            # total loss
            total_loss_generator = loss_CONTENT + 1e-3 * loss_GAN

            #################### discriminator eval ######################

            real_disc_opt = discriminator(high_res_ipt)
            loss_D_real = gan_loss(real_disc_opt, torch.ones_like(real_disc_opt))

            fake_disc_opt = discriminator(generated_hr.detach())
            loss_D_fake = gan_loss(fake_disc_opt, torch.zeros_like(fake_disc_opt))

            # total loss
            total_disc_loss = (loss_D_real + loss_D_fake) / 2

            ############### Accumulate losses ##########################
            gen_loss += total_loss_generator.item()
            disc_loss += total_disc_loss.item()

            # Every 10th epoch, save a side-by-side grid (HR | upscaled LR | generated)
            # for a subset of batches; the batch size doubles as the batch interval.
            if epoch % 10 == 0 and batch_idx % cfg.train.batch_size == 0:
                # upsample LR by 4x so the three grids can be concatenated (assumes
                # the model's upscale factor is 4 — TODO confirm)
                imgs_lr = nn.functional.interpolate(low_res_ipt, scale_factor=4)
                imgs_hr = make_grid(high_res_ipt, nrow=1, normalize=True)
                gen_hr = make_grid(generated_hr, nrow=1, normalize=True)
                imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
                img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
                save_image(
                    img_grid,
                    str(sample_images_dir.joinpath(f"{epoch}_{batch_idx}.png")),
                    normalize=False,
                )

            test_bar.set_postfix(
                gen_loss=gen_loss / (batch_idx + 1), disc_loss=disc_loss / (batch_idx + 1),
                ssim=valid_results["ssim"],
                psnr=valid_results["psnr"],
            )

    test_gen_loss.append(gen_loss / len(test_dataloader))
    test_disc_loss.append(disc_loss / len(test_dataloader))
    test_psnr.append(valid_results['psnr'])
    test_ssim.append(valid_results['ssim'])

    # overwrite the latest checkpoints after every epoch
    torch.save(generator.state_dict(), str(checkpoint_dir.joinpath("generator.pth")))
    torch.save(discriminator.state_dict(), str(checkpoint_dir.joinpath("discriminator.pth")))


Training:   4%|▎         | 6/166 [00:22<10:04,  3.78s/it, disc_loss=0.633, gen_loss=1.31]


KeyboardInterrupt: 