# Perceptual Losses Networks
### This script applies to:
- the Original Perceptual Loss Network,
- the Width-Concatenated Perceptual Losses Network, and
- the Channel-Concatenated Perceptual Losses Network.

## Importing Libraries

In [None]:
from fastai.basics import *
from fastai.vision import *
import fastai
from fastai.callbacks import *
from fastai.vision.gan import *
from fastai.vision.learner import cnn_config
import torch.nn as nn
import random
import os
import torchvision
from torchvision.models import vgg16_bn
from torchvision.models import vgg19_bn
from torchvision.models import densenet201
from matplotlib.pyplot import *
from torchvision.utils import save_image

## Defining paths

In [None]:
# Root of the persistent storage; every other location is derived from it.
pathroot = Path('/storage/BEP')

# Training data: paired low-count inputs and high-count targets.
path = pathroot / 'data2'
path_low, path_high = path / 'low_count', path / 'high_count'

# Test images to predict on, and where the predictions are written.
path_tests = pathroot / 'tests2'
path_predictions = pathroot / 'Predictions_50'

## Initialising the DataBunch (data loaders) Object

In [None]:
#Batch size and training-image size (for the width-concatenated network the width is 3*512 = 1536)
bs, size = 8, 512

#Progressive resizing: start at half the image size, which allows twice the batch size
bs *= 2
size //= 2

#ResNet-34 serves as the encoder (downsampling path) of the U-Net learner
arch = models.resnet34

#Low-count image list; a random 10 percent is split off as validation set (fixed seed for reproducibility)
src = ImageImageList.from_folder(path_low).split_by_rand_pct(0.1, seed=42)

In [None]:
#Map a low-count image path to the path of its corresponding high-count image
def correct_path(path_in, path_high):
    """Return the high-count counterpart of a low-count image path.

    The last two components of ``path_in`` (its parent directory name and its
    file name) are reused under ``path_high``:
    ``<anything>/<dir>/<file>`` -> ``path_high/<dir>/<file>``.

    Args:
        path_in: path (``Path`` or ``str``) of a low-count image.
        path_high: root directory of the high-count images.

    Returns:
        ``Path`` of the matching high-count image.
    """
    # pathlib handles separators itself, unlike the previous manual split on '/'.
    path_in = Path(path_in)
    return Path(path_high) / path_in.parent.name / path_in.name

#Build the DataBunch that is passed to the U-Net learner
def get_data(bs,size):
    """Return a DataBunch of (low-count input, high-count target) image pairs.

    Each item of the module-level ``src`` is labelled with its high-count
    counterpart via ``correct_path``. The same augmentation is applied to input
    and target (``tfm_y=True``), and both are resized to ``size``.

    Args:
        bs: batch size.
        size: image size after the transforms.
    """
    labelled = src.label_from_func(lambda x: correct_path(x, path_high))
    augmentations = get_transforms(max_zoom=2., max_rotate=15)
    transformed = labelled.transform(augmentations, size=size, tfm_y=True)
    data = transformed.databunch(bs=bs, num_workers=0)
    # Number of output channels the U-Net should produce (3-channel images)
    data.c = 3
    return data

In [None]:
#Create the first (small-image, large-batch) DataBunch
data = get_data(bs=bs, size=size)

## Creating the VGG-19 Loss function

In [None]:
def gram_matrix(x):
    """Return the batched Gram matrix of a feature map.

    ``x`` is unpacked as ``(n, c, h, w)``. Each sample is flattened to
    ``(c, h*w)``, multiplied by its own transpose, and the result of shape
    ``(n, c, c)`` is divided by ``c*h*w``.
    """
    batch, chans, height, width = x.size()
    flat = x.view(batch, chans, -1)
    gram = flat.bmm(flat.transpose(1, 2))
    return gram / (chans * height * width)

In [None]:
# Pretrained VGG-19 (batch-norm) convolutional stack on the GPU in eval mode, with gradients disabled
loss_m = vgg19_bn(True).features.cuda().eval()
requires_grad(loss_m, False)

# Index of the layer immediately before every max-pooling layer
blocks = []
for idx, layer in enumerate(children(loss_m)):
    if isinstance(layer, nn.MaxPool2d):
        blocks.append(idx - 1)

# Element-wise distance used for pixel, feature and Gram terms
base_loss = F.l1_loss

In [None]:
class FeatureLoss(nn.Module):
    """Perceptual loss: pixel L1 + weighted feature L1 + weighted Gram-matrix L1.

    Forward hooks on the layers of ``m_feat`` selected by ``layer_ids`` capture
    their activations for both prediction and target.

    Args:
        m_feat: feature-extractor network (in this notebook the VGG-19-BN
            ``features`` stack).
        layer_ids: indices into ``m_feat`` of the layers whose outputs are compared.
        layer_wgts: one weight ``w`` per entry of ``layer_ids``. Feature terms
            are scaled by ``w`` and Gram terms by ``w**2 * gram_wgt``.
        gram_wgt: extra multiplier on the Gram terms (default 5e3, the value
            previously hard-coded).

    After each call, ``metrics`` maps ``metric_names`` to the individual loss
    terms so fastai's ``LossMetrics`` callback can report them.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts, gram_wgt=5e3):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: the hooked activations presumably stay attached to the
        # autograd graph so the loss can backpropagate into the generator.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.gram_wgt = gram_wgt
        # Order must match the order in which forward() builds feat_losses.
        n_layers = len(layer_ids)
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(n_layers)]
                             + [f'gram_{i}' for i in range(n_layers)])

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked activations.

        With ``clone=True`` the activations are copied so that a later pass
        through ``m_feat`` cannot overwrite them.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss between ``input`` (prediction) and ``target``."""
        # Target features first, cloned, because the input pass refreshes the hook storage.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * self.gram_wgt
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # Remove the hooks; guarded in case __init__ failed before they were registered.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

In [None]:
# Compare the layers just before the 3rd, 4th and 5th max-pool of VGG-19, weighted 5, 15 and 2
feat_loss = FeatureLoss(loss_m, layer_ids=blocks[2:5], layer_wgts=[5, 15, 2]).cuda()

## Creating the Dynamic U-Net learner

In [None]:
#Weight decay for the learner's optimiser
wd = 1e-3

#Build the U-Net learner: `arch` (ResNet-34) is the encoder, `feat_loss` the training loss,
# and LossMetrics reports the individual pixel/feature/Gram terms
learn = unet_learner(
    data,
    arch,
    wd=wd,
    loss_func=feat_loss,
    callback_fns=LossMetrics,
    blur=True,
    self_attention=True,
    norm_type=NormType.Weight,
)
#Train in half precision
learn.to_fp16();

In [None]:
#Learning-rate finder: run a short learning-rate sweep and plot loss against learning rate.
# Choose a learning rate from the region where the loss decreases most steeply.
learn.lr_find()
learn.recorder.plot()

In [None]:
#Training helper around fit_one_cycle. A SaveModelCallback checkpoints the network every
# epoch, because the notebook shuts down after 6 hours. All files are written under
# `pathroot` (/storage/BEP), so the storage location is configured in one place only.
lr = 1e-4
def do_fit(save_name, no_of_cycles, lrs=slice(lr), pct_start=0.9):
    """Train ``learn`` with the one-cycle policy, save it and show one result.

    Args:
        save_name: base name for the per-epoch checkpoints
            (``<save_name>_saved_net``) and for the final weights
            (``<save_name>``), both under ``pathroot``.
        no_of_cycles: number of epochs passed to ``fit_one_cycle``.
        lrs: learning rate, or a ``slice`` giving discriminative learning rates.
        pct_start: fraction of the cycle spent increasing the learning rate.
    """
    checkpoint = SaveModelCallback(learn, every='epoch',
                                   name=f'{pathroot}/{save_name}_saved_net')
    learn.fit_one_cycle(no_of_cycles, lrs, pct_start=pct_start, callbacks=[checkpoint])
    learn.save(f'{pathroot}/{save_name}')
    learn.show_results(rows=1, imgsize=5)

## Actual Training

In [None]:
#Stage 1a: train for 10 epochs at 10x the base learning rate (encoder still frozen)
do_fit('PLN_Original_1a', no_of_cycles=10, lrs=slice(lr * 10))

In [None]:
#Stage 1b: unfreeze the whole network and train with discriminative learning rates
# (1e-5 for the earliest layer group up to lr for the last)
learn.unfreeze()
do_fit('PLN_Original_1b', no_of_cycles=10, lrs=slice(1e-5, lr))

In [None]:
# Progressive resizing, stage 2: double the image size (256 -> 512) and quarter the
# batch size (16 -> 4) to keep GPU memory usage down.
# The new DataBunch must be attached to the learner; otherwise training silently
# keeps using the original small-image data.
data = get_data(bs//4,size*2)
learn.data = data
# Again use half precision and freeze the lowest (encoder) layers
learn.to_fp16();
learn.freeze();

In [None]:
# Stage 2a: 10 epochs on 512x512 images with the encoder frozen, at the default learning rate
do_fit('PLN_Original_2a', no_of_cycles=10)

In [None]:
#Stage 2b: unfreeze and use discriminative learning rates, from 1e-6 for the earliest
# layers up to 1e-4 for the last layer group.
# Checkpoint name aligned with the other 'PLN_Original_*' stages of this run.
learn.unfreeze()
do_fit('PLN_Original_2b', 10, slice(1e-6,1e-4))

In [None]:
# Stage 2c: another 10 unfrozen epochs, with a shorter warm-up (pct_start=0.3).
# Saved as 'PLN_Original_2c', the name loaded in the prediction section (from pathroot/'Models').
learn.unfreeze()
do_fit('PLN_Original_2c', 10, slice(1e-6,1e-4), pct_start=0.3)

## Performing predictions

In [None]:
#Load the trained network. do_fit saves directly under pathroot, so the final weights
# are presumably moved into pathroot/'Models' beforehand — confirm before running.
learn.load(pathroot/'Models'/'PLN_Original_2c')

In [None]:
#Rebuild the 512x512 DataBunch for prediction and attach it to the learner: fastai's
# predict() preprocesses items through learn.data, so leaving the old DataBunch attached
# would predict at the wrong image size. Then switch to full precision and freeze.
data = get_data(bs//4,size*2)
learn.data = data
learn.to_fp32();
learn.freeze();

In [None]:
#Return the file name (last path component) of the i-th item of `data`
def name(data,i):
    """Return the file name of ``data.items[i]``."""
    return Path(data.items[i]).name

def save_predics(test_set):
    """Predict every image in ``path_tests/test_set`` and save the outputs.

    Predictions are written under their original file names to
    ``path_predictions/test_set``. That directory, and any missing parents, is
    created if needed. Existing files with the same names are overwritten, so
    re-running the cell no longer fails.
    """
    # Image list to predict from. The labels are required to get a labelled list,
    # but are not used for prediction.
    src_test = ImageImageList.from_folder(path_tests/test_set).split_none()
    data_test = src_test.label_from_func(lambda x: path_tests/test_set/x.name)

    out_dir = path_predictions/test_set
    # Unlike os.mkdir, this does not fail when the directory already exists or
    # when path_predictions has not been created yet.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Predict each test image and save the generated image
    for i in range(len(data_test.items)):
        image = learn.predict(data_test.x[i])[0]
        image.save(out_dir/name(data_test, i))

In [None]:
#Export predictions for the uniform-phantom test set
save_predics(test_set='test_uniform_new_phantom')

In [None]:
#Archive (tar + gzip) the predictions so they can be downloaded from Paperspace storage.
# `PLN_50_predictions.tar.gz` is the archive name (relative, so it is created in the current
# working directory); the last argument selects the files to include.
!tar cvfz PLN_50_predictions.tar.gz /storage/BEP/Predictions_50/*