In [1]:
#import libs
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
from srresnet import srresnet

from torchvision.models import vgg16_bn
import warnings

#turn off pytorch warnings
warnings.filterwarnings('ignore')

In [2]:
#model parameters
batch_size = 8
num_epochs = 1
train_val_split_pct = 0.1   # fraction of the low res images held out for validation
fp_16 = False               # True -> train with mixed precision

if not 0 < train_val_split_pct < 1:
    raise ValueError(f"train_val_split_pct must be in (0, 1), got {train_val_split_pct}")

#set paths to the low resolution (input) and high resolution (target) image folders.
#Every input image must have a target image with the same file name.
#NOTE: replace the empty placeholders -- Path("") resolves to the current working directory.
#The paths are Path objects so later cells can build file paths with the `/` operator.
low_res_images_path = Path("")
high_res_images_path = Path("")

In [None]:
#load images from folders as lists and split the inputs into training and validation sets

#high res target images (listed for inspection only: targets are matched to inputs by file
#name when the databunch is built)
high_res_images_list = ImageList.from_folder(high_res_images_path)

#low res input images, before splitting
low_res_all_images = ImageImageList.from_folder(low_res_images_path)

#every input needs a same-named target in the high res folder; fail here with a readable
#message instead of a file-not-found error in the middle of training
missing_targets = [Path(p).name for p in low_res_all_images.items
                   if not (Path(high_res_images_path) / Path(p).name).exists()]
assert not missing_targets, (f"{len(missing_targets)} low res images have no same-named high res target, "
                             f"e.g. {missing_targets[:5]}")

#fixed seed so the train/validation split is identical on every run
split_seed = 42
low_res_images_list = low_res_all_images.split_by_rand_pct(train_val_split_pct, seed=split_seed)

In [None]:
#create fastai databunch: label every low res input with the high res image that has the
#same file name (looked up directly in high_res_images_path)
data = (low_res_images_list
        .label_from_func(lambda x: Path(high_res_images_path) / x.name)
        .databunch(bs=batch_size))
#number of output channels; 3 assumes RGB target images
data.c = 3

In [None]:
#preview input/target image pairs from the validation set
data.show_batch(rows=batch_size, figsize=(20, 20), ds_type=DatasetType.Valid)

In [None]:
#define the network to train: SRResNet from the local `srresnet` module, built with its default arguments
model = srresnet()

In [None]:
# define feature loss 
# Jeremy Howard's implementation of Perceptual Losses for Real-Time Style Transfer and Super-Resolution
# Justin Johnson, Alexandre Alahi, Li Fei-Fei
# https://arxiv.org/abs/1603.08155v1
# taken from fast.ai course v3 https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres.ipynb

# Loss building blocks: Gram matrix helper, pixel loss and a frozen VGG16-BN feature extractor.

def gram_matrix(x):
    """Batched Gram matrix of a feature map.

    Args:
        x: 4-D tensor of shape (n, c, h, w).

    Returns:
        Tensor of shape (n, c, c): entry [b, i, j] is the dot product of channels i and j of
        sample b over all spatial positions, divided by c*h*w.
    """
    n_batch, n_chan, height, width = x.size()
    flat = x.view(n_batch, n_chan, -1)        # (n, c, h*w)
    gram = flat.bmm(flat.transpose(1, 2))     # (n, c, c)
    return gram / (n_chan * height * width)

#pixel-level loss; FeatureLoss also applies it to feature maps and Gram matrices
base_loss = F.l1_loss

#frozen, ImageNet-pretrained VGG16-BN convolutional stack used as the feature extractor.
#It is placed on the data's device (fastai's Learner moves the model there too) instead of a
#hard-coded .cuda(), so the notebook also runs on machines without a GPU.
vgg_m = vgg16_bn(pretrained=True).features.to(data.device).eval()
requires_grad(vgg_m, False)

#indices of the layers right before each MaxPool2d, i.e. the end of each VGG conv block
blocks = [i - 1 for i, o in enumerate(children(vgg_m)) if isinstance(o, nn.MaxPool2d)]

class FeatureLoss(nn.Module):
    """Perceptual loss for super-resolution (Johnson et al., https://arxiv.org/abs/1603.08155).

    Total loss = base_loss(input, target)
               + sum_i w_i * base_loss(feat_i(input), feat_i(target))
               + sum_i w_i**2 * gram_wgt * base_loss(gram(feat_i(input)), gram(feat_i(target)))
    where feat_i is the output of layer `layer_ids[i]` of `m_feat` and `base_loss` is the
    module-level pixel loss.

    Args:
        m_feat: feature extractor run on both prediction and target (here the frozen VGG16-BN
            `features` stack); must be indexable by the entries of `layer_ids`.
        layer_ids: indices of the layers of `m_feat` whose outputs are compared.
        layer_wgts: one weight per entry of `layer_ids`.
        gram_wgt: extra multiplier for every Gram-matrix term (5e3 in the original recipe).

    Raises:
        ValueError: if `layer_wgts` does not provide exactly one weight per layer id.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts, gram_wgt=5e3):
        super().__init__()
        if len(layer_wgts) != len(layer_ids):
            # zip() in forward would otherwise silently drop the unmatched layers/weights
            raise ValueError(f"expected one weight per layer: got {len(layer_wgts)} weights "
                             f"for {len(layer_ids)} layers")
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the stored activations in the autograd graph so the feature and
        # Gram terms can backpropagate into the network that produced `input`
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.gram_wgt = gram_wgt
        # one name per loss term, in the same order as self.feat_losses; read together with
        # self.metrics by fastai's LossMetrics callback
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        """Run `x` through `m_feat` and return the hooked layer outputs (cloned if `clone`)."""
        self.m_feat(x)
        # the hooks hold only the latest forward pass, so clone results that must survive the next call
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss; each individual term is kept in `feat_losses` and `metrics`."""
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * self.gram_wgt
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # hooks are absent if __init__ raised before registering them
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
        
#perceptual loss over the VGG layers blocks[2], blocks[3] and blocks[4], weighted 5, 15 and 2
feat_loss = FeatureLoss(vgg_m, layer_ids=blocks[2:5], layer_wgts=[5, 15, 2])

In [None]:
#create the fast.ai Learner: the SRResNet model trained with the perceptual feature loss.
#LossMetrics reports each loss component (pixel / feat_i / gram_i) during training.
wd = 1e-3  # weight decay
learn = Learner(data, model, loss_func=feat_loss, callback_fns=LossMetrics, wd=wd)

#release memory held by objects that are no longer referenced
gc.collect();

In [None]:
#optional mixed precision training; to_fp16 returns the learner, so reassigning keeps this explicit
if fp_16:
    learn = learn.to_fp16()

In [None]:
#find an appropriate learning rate: lr_find runs a short training sweep over increasing learning
#rates, then the loss-vs-lr curve is plotted. suggestion=True also marks the learning rate with
#the steepest loss decrease, which is a reasonable starting value for `lr` in the next cell.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# Set the learning rate from the lr_find plot above. If the plot is hard to interpret,
# 1e-4 has usually worked well in the author's experiments.
lr = 1e-4

In [None]:
#train with plain `fit` (fixed learning rate `lr`), save the weights, then preview predictions
learn.fit(num_epochs, lr=lr)
learn.save("NAME OF MODEL")
learn.show_results(imgsize=20, rows=batch_size)

In [None]:
#export the model for inference.
#The VGG16 weights and the feature loss are only needed for training, so the learner is rebuilt
#with the plain pixel loss before exporting; this significantly reduces the exported file size.
learn = Learner(data, model, loss_func=base_loss, wd=wd)

#reload the weights saved after training, then write the inference learner to disk
learn.load("NAME OF MODEL")
learn.export("NAME OF MODEL")