### SRGAN

In [1]:
# HELPER FUNCTIONS TO BE USED DURING TRAINING
import DataLoader
from DataLoader import MabulaDataset
from albumentations import Flip, Rotate, RandomCrop, Blur
from albumentations.pytorch import ToTensor
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

##########################################################
############### Model creation ###########################

from SRResNetBlock import *
def create_model(dim=320, scale_factor=2, batch_size=1, n_res_blocks=8, GPU=True, test_dim=320):
    """Create the SRGAN networks and the down-sampling matrices used for training and testing.

    Args:
        dim: Side length used to build the training matrix ``t``
            (passed to ``transformationMatrix``).
        scale_factor: Scale factor passed to both ``transformationMatrix`` and
            the ``SRResNet`` generator.
        batch_size: Forwarded to ``transformationMatrix``; only the ``[0, 0]``
            slice of its result is kept.
        n_res_blocks: Number of residual blocks of the generator.
        GPU: When True and CUDA is available, both networks are moved to ``cuda:0``.
        test_dim: Side length used to build the test matrix ``t_test``.
            Defaults to 320, the value that used to be hard-coded.

    Returns:
        ``(generator, discriminator, t, t_test)``.
    """
    t = transformationMatrix(dim, scale_factor, batch_size, GPU=GPU)[0, 0]
    t_test = transformationMatrix(test_dim, scale_factor, batch_size, GPU=GPU)[0, 0]
    net = SRResNet(scale_factor=scale_factor, n_res_blocks=n_res_blocks)
    discriminator = SRResNetDiscriminator()
    # Move both models to the GPU if one is available and requested.
    if torch.cuda.is_available() and GPU:
        device = torch.device("cuda:0")
        net.to(device)
        discriminator.to(device)
        print('Models moved to GPU.')
    else:
        print('Only CPU available.')
    return net, discriminator, t, t_test

##########################################################
############### Losses ###################################

def real_mse_loss(D_out):
    """Least-squares loss measuring how far discriminator outputs are from "real" (1.0).

    Iterates over ``D_out``, takes the mean of ``(x - 1.0) ** 2`` for every item
    and returns the sum of those means. When ``D_out`` is a single tensor, the
    iteration runs over its first dimension. Returns ``0`` for an empty input.
    """
    return sum(torch.mean((out - 1.0) ** 2) for out in D_out)

def fake_mse_loss(D_out):
    """Least-squares loss measuring how far discriminator outputs are from "fake" (0.0).

    Iterates over ``D_out``, takes the mean of ``x ** 2`` for every item and
    returns the sum of those means. When ``D_out`` is a single tensor, the
    iteration runs over its first dimension. Returns ``0`` for an empty input.
    """
    return sum(torch.mean(out ** 2) for out in D_out)


##########################################################
############### Training loop ############################

from Utility import *
def _gan_train_step(generator, discriminator, g_optimizer, d_optimizer,
                    X_hr_discriminator, X_hr_true, t, GAN_loss_weight):
    """Run one discriminator update followed by one generator update.

    Returns the loss tensors of the step in the order
    ``(d_x_loss, g_T_real_loss, D_X_hr_real_loss, D_X_hr_fake_loss,
    D_X_hr_real_fake_loss, MSE_loss)``.
    """
    # Low-resolution generator input: the degradation matrix applied to the HR target.
    X_lr = torch.matmul(t, X_hr_true)

    # ---- Discriminator: push real images towards 1 and generated images towards 0 ----
    d_optimizer.zero_grad()
    D_X_hr_real_loss = real_mse_loss(discriminator(X_hr_discriminator))
    # The fake batch is a constant for this update. Building the generator graph here
    # would only create generator gradients that g_optimizer.zero_grad() discards below.
    with torch.no_grad():
        X_hr_fake = generator(X_lr)
    D_X_hr_fake_loss = fake_mse_loss(discriminator(X_hr_fake))
    d_x_loss = D_X_hr_real_loss + D_X_hr_fake_loss
    d_x_loss.backward()
    d_optimizer.step()

    # ---- Generator: fool the (just updated) discriminator and match the HR target in MSE ----
    g_optimizer.zero_grad()
    X_hr_fake = generator(X_lr)
    D_X_hr_real_fake_loss = real_mse_loss(discriminator(X_hr_fake))
    MSE_loss = nn.MSELoss()(X_hr_fake, X_hr_true)
    g_T_real_loss = GAN_loss_weight * D_X_hr_real_fake_loss + MSE_loss
    g_T_real_loss.backward()
    g_optimizer.step()

    return (d_x_loss, g_T_real_loss, D_X_hr_real_loss, D_X_hr_fake_loss,
            D_X_hr_real_fake_loss, MSE_loss)


def _evaluate_generator(generator, test_loader, t_test, use_cuda):
    """Return ``(mean PSNR, mean SSIM)`` of the generator over ``test_loader``.

    The generator is put in eval mode for the pass and switched back to train
    mode afterwards, even if an exception is raised.
    """
    generator.eval()
    PSNR_lst = []
    SSIM_lst = []
    try:
        for batch in test_loader:
            X_hr_true = batch['image']
            if use_cuda:
                X_hr_true = X_hr_true.to("cuda:0")
            with torch.no_grad():
                X_hr_fake = generator(torch.matmul(t_test, X_hr_true))
            PSNR, SSIM = calculate_scores(X_hr_fake.cpu(), X_hr_true.cpu())
            PSNR_lst.append(PSNR)
            SSIM_lst.append(SSIM)
    finally:
        generator.train()
    return np.mean(PSNR_lst), np.mean(SSIM_lst)


def training_loop(models, optimizers, dataloaders,  t, t_test, n_epochs=200, GAN_loss_weight=10**-3, GPU=True,
                  print_every=10, test_every=20):
    """Adversarially train an SRGAN generator/discriminator pair and evaluate it periodically.

    Args:
        models: ``(generator, discriminator)``.
        optimizers: ``(g_optimizer, d_optimizer)``.
        dataloaders: ``(train_loader, test_loader)``; each batch is indexed with ``'image'``.
        t: Matrix applied with ``torch.matmul`` to training HR images to produce the LR input.
        t_test: Same as ``t``, but for the test images.
        n_epochs: Number of passes over ``train_loader``.
        GAN_loss_weight: Weight of the adversarial term in the generator loss.
        GPU: Move batches to ``cuda:0`` when CUDA is available.
        print_every: Logging interval in epochs. Logging only happens on epochs
            that are also evaluation epochs, so the loss lists stay aligned with
            the PSNR/SSIM lists.
        test_every: Evaluation interval in epochs.

    Returns:
        ``(losses, D_losses, G_losses, meanPSNR_lst, meanSSIM_lst)``:
        ``losses`` holds ``(discriminator_total, generator_total)``,
        ``D_losses`` holds ``(real_term, fake_term)`` and ``G_losses`` holds
        ``(adversarial_term, mse_term)``, each taken from the last batch of a
        logged epoch. The PSNR/SSIM lists hold one mean per evaluation.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    generator, discriminator = models
    g_optimizer, d_optimizer = optimizers
    train_loader, test_loader = dataloaders
    use_cuda = torch.cuda.is_available() and GPU

    # Keep track of losses and metrics over time.
    losses = []
    D_losses = []
    G_losses = []
    meanPSNR_lst = []
    meanSSIM_lst = []
    for epoch in range(1, n_epochs + 1):
        # A second iterator over the same loader supplies the discriminator's "real"
        # images. When the loader shuffles, these generally differ from the target batch.
        disc_iter = iter(train_loader)
        last_step = None
        for batch in train_loader:
            # Builtin next(): recent PyTorch DataLoader iterators have no .next() method.
            X_hr_discriminator = next(disc_iter)['image']
            X_hr_true = batch['image']
            if use_cuda:
                X_hr_discriminator = X_hr_discriminator.to("cuda:0")
                X_hr_true = X_hr_true.to("cuda:0")
            last_step = _gan_train_step(generator, discriminator, g_optimizer, d_optimizer,
                                        X_hr_discriminator, X_hr_true, t, GAN_loss_weight)
        if last_step is None:
            raise ValueError("train_loader yielded no batches")
        (d_x_loss, g_T_real_loss, D_X_hr_real_loss, D_X_hr_fake_loss,
         D_X_hr_real_fake_loss, MSE_loss) = last_step

        has_tested = False
        if epoch % test_every == 0:
            print("#### Testing ", epoch, " ####")
            meanPSNR, meanSSIM = _evaluate_generator(generator, test_loader, t_test, use_cuda)
            meanPSNR_lst.append(meanPSNR)
            meanSSIM_lst.append(meanSSIM)
            has_tested = True

        # Logging info (only on evaluated epochs, so meanPSNR/meanSSIM are defined).
        if epoch % print_every == 0 and has_tested:
            print("---- Epoch nr: ", epoch, " ----")
            print("Train loss: ")
            print("Discriminator: ", d_x_loss.item())
            print("Generator: ", g_T_real_loss.item())
            print("Last PSNR: ", meanPSNR)
            print("Last SSIM: ", meanSSIM)
            losses.append((d_x_loss.item(), g_T_real_loss.item()))
            D_losses.append((D_X_hr_real_loss.item(), D_X_hr_fake_loss.item()))
            G_losses.append((D_X_hr_real_fake_loss.item(), MSE_loss.item()))
    print("Done")
    return losses, D_losses, G_losses, meanPSNR_lst, meanSSIM_lst


##########################################################
############### Training Multiple models #################

def trainModelEnsemble(name="name", n_epochs=100, dim=320, scale_factors=(2, 4, 8), batch_size=16, n_res_blocks=8, GAN_loss_weight=10**-3, GPU=True, IMPORT_GENERATOR=False):
    """Train, checkpoint and release one SRGAN per scale factor.

    For each entry of ``scale_factors``:
    1. A generator/discriminator pair is built with ``create_model``.
    2. The pair is trained with ``training_loop``.
    3. The pair is saved with ``saveModels`` to ``checkpoints/<name>_<scale_factor>``.

    NOTE: this function reads notebook globals that are defined in later cells
    and must exist before it is called: ``train_data`` and ``test_data``
    (datasets), and ``lr``, ``beta1``, ``beta2`` (Adam hyper-parameters).

    Args:
        name: Prefix for checkpoint paths and ``loss_dict`` keys.
        n_epochs: Epochs per model.
        dim: Training crop size, forwarded to ``create_model``.
        scale_factors: Iterable of scale factors; one model is trained per entry.
        batch_size: Batch size of both the train and the test loader.
        n_res_blocks: Residual blocks of a freshly created generator.
        GAN_loss_weight: Adversarial loss weight, forwarded to ``training_loop``.
        GPU: Use ``cuda:0`` when CUDA is available.
        IMPORT_GENERATOR: If True, replace the fresh generator with the first model
            returned by ``readModels("checkpoints/SRResNet_<scale_factor>_generator")``.
            In that case ``n_res_blocks`` does not affect the generator.

    Returns:
        dict mapping ``"<name>_<scale_factor>"`` to the 5-tuple returned by
        ``training_loop``.
    """
    # Records losses of the trained models.
    loss_dict = {}
    for scale_factor in scale_factors:
        # Define data loaders over the notebook-level datasets.
        train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=0)
        test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True, num_workers=0)
        dataloaders = [train_loader, test_loader]

        # Create the models.
        generator, discriminator, t, t_test = create_model(dim=dim, scale_factor=scale_factor, batch_size=batch_size, n_res_blocks=n_res_blocks, GPU=GPU)
        if IMPORT_GENERATOR:
            # Start from the pretrained SRResNet generator instead of random weights.
            del generator
            path = "checkpoints/SRResNet_" + str(scale_factor) + "_generator"
            generator, _ = readModels(path, gpu=GPU)
        models = [generator, discriminator]

        # Define optimizers (lr, beta1, beta2 are notebook globals).
        g_optimizer = optim.Adam(generator.parameters(), lr, [beta1, beta2])
        d_optimizer = optim.Adam(discriminator.parameters(), lr, [beta1, beta2])
        optimizers = [g_optimizer, d_optimizer]

        print("\n --- Training with parameters: ---")
        print("scale factor: ", scale_factor)
        print("epochs: ", n_epochs)
        print("batch size: ", batch_size)

        # Train the model.
        losses, D_losses, G_losses, meanPSNR_lst, meanSSIM_lst = training_loop(models=models,
                                                                               optimizers=optimizers,
                                                                               dataloaders=dataloaders,
                                                                               n_epochs=n_epochs,
                                                                               t=t,
                                                                               t_test=t_test,
                                                                               GAN_loss_weight=GAN_loss_weight,
                                                                               GPU=GPU)
        loss_dict[name + "_" + str(scale_factor)] = (losses, D_losses, G_losses, meanPSNR_lst, meanSSIM_lst)

        # Save the trained model.
        path = "checkpoints/" + name + "_" + str(scale_factor)
        saveModels(generator, discriminator, path=path)

        # Drop EVERY reference to this scale's networks before the next iteration
        # allocates new ones. The `models` list and the optimizers (whose Adam state
        # holds extra per-parameter tensors) would otherwise keep them alive, so the
        # old and new models would sit in GPU memory at the same time.
        del models, optimizers, g_optimizer, d_optimizer, generator, discriminator, t, t_test
        if GPU and torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Passed training of model: ", name+str(scale_factor))
        print("\n")

    return loss_dict

In [2]:
# DEFINE PARAMETERS FOR TRAINING
import random

import torch

GPU = True                # use cuda:0 when CUDA is available
IMPORT_GENERATOR = True   # start each GAN from the saved SRResNet generator checkpoint
dim = 96                  # side length of the random training crops
batch_size = 32
n_res_blocks = 6
n_epochs = 400
scale_factors = [2, 4, 8]
GAN_loss_weight=10**-3    # weight of the adversarial term relative to the MSE term
# Parameters of optimizer:
lr = 0.0001
beta1 = 0.5
beta2 = 0.99
# Seed the RNGs so weight initialisation and loader shuffling are repeatable.
# NOTE(review): the augmentations presumably draw from Python's `random` (and possibly
# numpy, which is not seeded here); CUDA kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Training-time augmentation: random flip and rotation, tensor conversion, then a dim x dim crop.
# NOTE(review): ToTensor runs before RandomCrop, which is unusual for albumentations;
# presumably MabulaDataset applies these in a way that tolerates it — confirm before reordering.
transforms = [
    Flip(),
    Rotate(),
    ToTensor(),
    RandomCrop(always_apply=True, height=dim, width=dim),
]
# Test images are only converted to tensors, never augmented.
test_transforms = [ToTensor()]

# Datasets for the OCTA train and test folders.
train_data, test_data = (
    MabulaDataset(file_path="/Data/OCTA/Train", transforms=transforms),
    MabulaDataset(file_path="/Data/OCTA/Test", transforms=test_transforms),
)

# Loaders: training uses the configured batch size, testing a fixed batch of 12.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_data, batch_size=12, shuffle=True, num_workers=0)

# Group both loaders in [train, test] order, as unpacked by training_loop.
dataloaders = [train_loader, test_loader]

In [4]:
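# Uncomment to train one SRGAN per scale factor.
# Note: with IMPORT_GENERATOR=True this assumes the SRResNet generators have already been
# pretrained and saved as checkpoints/SRResNet_<scale_factor>_generator.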
#loss_dict = trainModelEnsemble(name="SRResNetGAN",
#                               n_epochs=n_epochs,
#                               dim=dim,
#                               scale_factors=scale_factors,
#                               batch_size=batch_size,
#                               n_res_blocks=n_res_blocks,
#                               GAN_loss_weight=GAN_loss_weight,
#                               GPU=GPU,
#                               IMPORT_GENERATOR=IMPORT_GENERATOR)