### SRResNet

In [1]:
# HELPER FUNCTIONS TO BE USED DURING TRAINING
import DataLoader
from DataLoader import MabulaDataset
from albumentations import Flip, Rotate, RandomCrop
from albumentations.pytorch import ToTensor
from torch.utils.data import Dataset, DataLoader
from SRResNetBlock import *
import torch
import torch.optim as optim
from Utility import *

##########################################################
############### Model creation ###########################

def create_model(dim=320, scale_factor=2, batch_size=1, n_res_blocks=8, GPU=True):
    """Build an SRResNet and the matrix used to degrade its training inputs.

    Args:
        dim: Size argument forwarded to ``transformationMatrix`` (presumably
            the high-resolution image side length -- confirm).
        scale_factor: Upscaling factor of the network, also forwarded to
            ``transformationMatrix``.
        batch_size: Forwarded to ``transformationMatrix``.
        n_res_blocks: Number of residual blocks in the SRResNet.
        GPU: Place the model on ``cuda:0`` when CUDA is available.

    Returns:
        ``(net, t)`` where ``net`` is the SRResNet and ``t`` is
        ``transformationMatrix(...)[0, 0]``. The training code in this file
        forms low-resolution inputs as ``torch.matmul(t, X_hr)``.
    """
    # Only request GPU tensors when CUDA can actually provide them. The GPU
    # flag presumably selects the device of ``t`` -- confirm in the
    # definition of transformationMatrix.
    use_gpu = GPU and torch.cuda.is_available()
    t = transformationMatrix(dim, scale_factor, batch_size, GPU=use_gpu)[0, 0]
    net = SRResNet(scale_factor=scale_factor, n_res_blocks=n_res_blocks)
    # Training here is a pure MSE objective, so no discriminator is built.
    # One would only occupy (GPU) memory without ever being returned.

    if use_gpu:
        net.to(torch.device("cuda:0"))
        print('Model moved to GPU.')
    elif GPU:
        print('Only CPU available.')
    else:
        print('GPU disabled; using CPU.')
    return net, t

##########################################################
############### Training loop ############################

def training_loop(net, optimizer, dataloaders, t, n_epochs=1000, GPU=True,
                  print_every=10, test_every=20):
    """Train ``net`` with an MSE objective and periodically evaluate it.

    Each high-resolution batch is degraded with ``torch.matmul(t, X_hr)``.
    The network is then optimised to reconstruct the original batch from
    that input. Every ``test_every`` epochs the network is evaluated on the
    test loader, and the mean PSNR / SSIM returned by ``calculate_scores``
    are recorded.

    Args:
        net: Model mapping the degraded batch back to the original batch.
        optimizer: Optimizer over ``net``'s parameters.
        dataloaders: ``(train_loader, test_loader)``. Each yields mappings
            with an ``'image'`` tensor.
        t: Matrix left-multiplied onto every batch to form the network input.
        n_epochs: Index of the last epoch. Note that ``n_epochs + 1`` epochs
            are run (0..n_epochs inclusive).
        GPU: Move batches to ``cuda:0`` when CUDA is available.
        print_every: Logging interval in epochs. Logging only happens on
            epochs that were also evaluated, so the effective interval is the
            least common multiple of ``print_every`` and ``test_every``.
        test_every: Evaluation interval in epochs.

    Returns:
        ``(losses, meanPSNR_lst, meanSSIM_lst)``. ``losses`` holds one
        ``(last_train_batch_loss, last_test_batch_loss)`` float tuple per
        logged epoch; NaN means no batch has been seen yet. The other two
        lists hold one mean per evaluated epoch.
    """
    train_loader, test_loader = dataloaders
    use_cuda = GPU and torch.cuda.is_available()
    # Loss modules are stateless, so one instance serves every batch.
    criterion = torch.nn.MSELoss()

    def _to_device(tensor):
        # Same placement rule used for the model in create_model.
        return tensor.to("cuda:0") if use_cuda else tensor

    def _as_float(loss):
        # Losses are kept as detached tensors and only converted (which
        # synchronises with the device) when logged.
        return float("nan") if loss is None else loss.item()

    losses = []
    meanPSNR_lst = []
    meanSSIM_lst = []
    # Last-batch losses. Starting as None avoids an undefined name when a
    # loader yields no batches.
    last_train_loss = None
    last_test_loss = None

    for epoch in range(n_epochs + 1):
        # ---- Training pass ----
        for batch in train_loader:
            X_hr_true = _to_device(batch['image'])
            X_lr = torch.matmul(t, X_hr_true)

            optimizer.zero_grad()
            X_hr_fake = net(X_lr)
            MSE_loss = criterion(X_hr_fake, X_hr_true)
            MSE_loss.backward()
            optimizer.step()
            last_train_loss = MSE_loss.detach()

        # ---- Evaluation pass ----
        has_tested = False
        if epoch % test_every == 0:
            print("#### Testing ", epoch, " ####")
            PSNR_lst = []
            SSIM_lst = []
            net.eval()
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        X_hr_true = _to_device(batch['image'])
                        X_lr = torch.matmul(t, X_hr_true)
                        X_hr_fake = net(X_lr)
                        last_test_loss = criterion(X_hr_fake, X_hr_true)
                        PSNR, SSIM = calculate_scores(X_hr_fake.cpu(), X_hr_true.cpu())
                        PSNR_lst.append(PSNR)
                        SSIM_lst.append(SSIM)
            finally:
                # Restore training mode even if evaluation raises.
                net.train()
            meanPSNR = np.mean(PSNR_lst)
            meanSSIM = np.mean(SSIM_lst)
            meanPSNR_lst.append(meanPSNR)
            meanSSIM_lst.append(meanSSIM)
            has_tested = True

        # ---- Logging (only on evaluated epochs, see docstring) ----
        if has_tested and epoch % print_every == 0:
            train_value = _as_float(last_train_loss)
            test_value = _as_float(last_test_loss)
            print("---- Epoch nr: ", epoch, " ----")
            print("Train loss: ", train_value)
            print("Last Test loss: ", test_value)
            print("Last PSNR: ", meanPSNR)
            print("Last SSIM: ", meanSSIM)
            losses.append((train_value, test_value))

    print("Done")
    return losses, meanPSNR_lst, meanSSIM_lst

##########################################################
############### Training Multiple models #################

def trainModelEnsemble(name="name", n_epochs=100, dim=320, scale_factors=(2, 4, 8), batch_size=16, n_res_blocks=8, GPU=True,
                       train_data=None, test_data=None, lr=None, betas=None):
    """Train one SRResNet per scale factor and checkpoint each of them.

    For every entry of ``scale_factors``:
      * a fresh model is built with ``create_model``;
      * it is optimised with Adam via ``training_loop``;
      * it is saved with ``saveModels`` under
        ``checkpoints/<name>_<scale_factor>``.

    Args:
        name: Prefix for checkpoint paths and ``loss_dict`` keys.
        n_epochs: Forwarded to ``training_loop``.
        dim: Forwarded to ``create_model``.
        scale_factors: Iterable of scale factors; one model is trained per entry.
        batch_size: Batch size of both data loaders; also forwarded to
            ``create_model``.
        n_res_blocks: Number of residual blocks in each SRResNet.
        GPU: Use ``cuda:0`` when CUDA is available.
        train_data: Training dataset. ``None`` falls back to the
            notebook-level global of the same name.
        test_data: Test dataset. ``None`` falls back to the notebook-level
            global of the same name.
        lr: Adam learning rate. ``None`` falls back to the global ``lr``.
        betas: Adam ``(beta1, beta2)``. ``None`` falls back to the globals
            ``beta1`` and ``beta2``.

    Returns:
        dict mapping ``"<name>_<scale_factor>"`` to the
        ``(losses, meanPSNR_lst, meanSSIM_lst)`` tuple from ``training_loop``.

    Raises:
        NameError: if an argument left as ``None`` has no notebook-level
            fallback defined.
    """
    import os

    scale_factors = list(scale_factors)
    if not scale_factors:
        return {}

    def _notebook_global(var_name):
        # Keep the original behaviour of reading these values from the
        # notebook namespace when they are not passed explicitly.
        try:
            return globals()[var_name]
        except KeyError:
            raise NameError(
                f"name '{var_name}' is not defined; pass it to "
                "trainModelEnsemble or define it in the notebook") from None

    if train_data is None:
        train_data = _notebook_global("train_data")
    if test_data is None:
        test_data = _notebook_global("test_data")
    if lr is None:
        lr = _notebook_global("lr")
    if betas is None:
        betas = (_notebook_global("beta1"), _notebook_global("beta2"))

    # The loaders do not depend on the scale factor, so build them once.
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=0)
    # Evaluation needs no shuffling; a fixed order keeps test metrics reproducible.
    test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, num_workers=0)
    dataloaders = [train_loader, test_loader]

    # Ensure the checkpoint directory exists before spending time on training.
    os.makedirs("checkpoints", exist_ok=True)

    loss_dict = {}
    for scale_factor in scale_factors:
        model_key = f"{name}_{scale_factor}"

        net, t = create_model(dim=dim, scale_factor=scale_factor, batch_size=batch_size, GPU=GPU, n_res_blocks=n_res_blocks)
        optimizer = optim.Adam(net.parameters(), lr=lr, betas=tuple(betas))

        print("\n --- Training with parameters: ---")
        print("scale factor: ", scale_factor)
        print("epochs: ", n_epochs)
        print("batch size: ", batch_size)

        losses, meanPSNR_lst, meanSSIM_lst = training_loop(net, optimizer, dataloaders, t, n_epochs=n_epochs, GPU=GPU)
        loss_dict[model_key] = (losses, meanPSNR_lst, meanSSIM_lst)

        # NOTE(review): the same model is passed twice; confirm which
        # arguments saveModels expects.
        saveModels(net, net, path="checkpoints/" + model_key)

        # The optimizer holds references to the parameters (and Adam's
        # moment buffers), so it must go too for the memory to be released
        # before the next model is built.
        del net, optimizer, t
        print("Passed training of model: ", model_key)
        print("\n")

    return loss_dict

In [2]:
# ---- Run configuration ----
GPU = True            # use cuda:0 when torch reports CUDA as available
dim = 320             # forwarded to create_model / transformationMatrix
batch_size = 4
n_res_blocks = 6
n_epochs = 600
scale_factors = [2, 4, 8]

# ---- Adam optimizer hyper-parameters ----
lr = 1e-4
beta1, beta2 = 0.5, 0.99

In [3]:
# Choose augmentations:
# Augmentations: flip/rotate for training, tensor conversion only for testing.
# NOTE(review): `albumentations.pytorch.ToTensor` is deprecated/removed in newer
# albumentations releases in favour of `ToTensorV2` -- confirm the pinned version.
transforms = [Flip(), Rotate(), ToTensor()]
test_transforms = [ToTensor()]

# Datasets:
train_data = MabulaDataset(file_path="/Data/OCTA/Train", transforms=transforms)
test_data = MabulaDataset(file_path="/Data/OCTA/Test", transforms=test_transforms)

# Data loaders. The test set is not shuffled, so evaluation order (and any
# per-batch metric averaging) is deterministic.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, num_workers=0)

# Bundle both loaders in the (train, test) order that training_loop unpacks.
dataloaders = [train_loader, test_loader]

In [4]:
# Train SRResNet
#loss_dict = trainModelEnsemble(name="SRResNet",
#                               n_epochs=n_epochs,
#                               dim=dim,
#                               batch_size=batch_size,
#                               n_res_blocks=n_res_blocks,
#                               scale_factors=scale_factors,
#                               GPU=GPU)