## Perceptual loss

In [31]:
from fastai.vision.all import *
from torcheval.metrics.functional import peak_signal_noise_ratio
import fastai; fastai.__version__

'2.7.15'

In [33]:
# define gram matrix calculation which enables the use of gram (style) loss
def gram_matrix(x):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    `x` is unpacked as `(n, c, h, w)`. The last two axes are flattened and the
    result is `x @ x^T` over the flattened axis, divided by `c*h*w`, giving a
    tensor of shape `(n, c, c)`.
    """
    n, c, h, w = x.size()
    # `reshape` instead of `view`: identical for contiguous inputs, but also works
    # for non-contiguous ones (e.g. transposed/permuted tensors) where `view` raises
    flat = x.reshape(n, c, -1)
    return (flat @ flat.transpose(1, 2)) / (c * h * w)

In [34]:
# Frozen VGG16-BN feature extractor used only for inference inside the loss:
# keep the convolutional `features` stack, put it in eval mode and disable
# gradient tracking on its parameters
vgg_m = (
    vgg16_bn(True)
    .features
    .eval()
    .requires_grad_(False)
)

In [35]:
# Find every MaxPool2d in the VGG feature stack and record the index of the layer
# directly before it, so feature maps can be grabbed prior to dimension reduction
blocks = [
    pos - 1
    for pos, layer in enumerate(vgg_m.children())
    if isinstance(layer, nn.MaxPool2d)
]
# Display the indices together with the layers they point at
blocks, [vgg_m[idx] for idx in blocks]

([5, 12, 22, 32, 42],
 [ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True)])

In [None]:
# Hyper parameters for model and loss
wd = 1e-3                # weight decay
y_range = (-3, 3)        # output range
base_loss = F.l1_loss    # used for the pixel, feature and gram terms of the loss
ref_loss = F.mse_loss    # reference loss, reported as the MSE metric
bbone = resnet34         # backbone architecture

In [22]:
# Feature loss class with weighting factors for feature and gram contributions
class FeatureLoss(Module):
    """Loss built from pixel L1, hooked feature-map L1 and Gram-matrix L1 terms.

    The outputs of `m_feat[i]` for every `i` in `layer_ids` are captured with
    hooks. `forward` compares those outputs for `input` and `target`, and also
    records MSE and PSNR, which are exposed as metrics but excluded from the
    returned loss. After each `forward`, every entry of `metric_names` is
    available as an attribute of the same name.

    Parameters:
        m_feat:       indexable feature-extractor model, called on input and target
        layer_ids:    indices into `m_feat` of the layers whose outputs are compared
        L1_wgt:       stored on the instance
        feature_wgts: one weight per hooked layer for the feature terms
        gram_wgts:    one weight per hooked layer for the gram terms
    """
    def __init__(self, m_feat, layer_ids, L1_wgt, feature_wgts, gram_wgts):
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.feature_wgts = feature_wgts
        self.gram_wgts = gram_wgts
        # NOTE(review): L1_wgt is stored but never applied in `forward`; the pixel
        # L1 term currently enters the loss unweighted -- confirm this is intended
        self.L1_wgt = L1_wgt
        # Order must match the order in which `forward` fills `self.feat_losses`
        self.metric_names = (['MSE', 'PSNR', 'l1']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    # feature generator
    def make_features(self, x, clone=False):
        "Run `x` through `m_feat` and return the hooked outputs (cloned if `clone`)."
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    @staticmethod
    def _per_sample(loss):
        "Average an unreduced loss over every axis except the first; scalars pass through."
        if loss.ndim <= 1: return loss
        return loss.mean(dim=list(range(1, loss.ndim)))

    def forward(self, input, target, reduction='mean'):
        """Return the sum of the L1, feature and gram terms.

        With `reduction='none'` every non-scalar term is averaged per sample, so
        the result has one value per item of the batch.
        """
        # Target features are cloned before the hooks are refreshed by the pass on `input`
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [ref_loss(input, target, reduction=reduction)]                   # MSE as metric
        self.feat_losses += [peak_signal_noise_ratio(input, target)]                        # PSNR as metric
        self.feat_losses += [base_loss(input, target, reduction=reduction)]                 # L1 loss
        self.feat_losses += [base_loss(f_in, f_out, reduction=reduction)*w                  # feature loss
                             for f_in, f_out, w in zip(in_feat, out_feat, self.feature_wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out), reduction=reduction)*w*20000    # gram loss
                             for f_in, f_out, w in zip(in_feat, out_feat, self.gram_wgts)]
        if reduction == 'none':
            # Reduce by tensor rank instead of fixed list positions: the list holds
            # image-shaped terms, a scalar PSNR and gram terms, and positional slicing
            # breaks as soon as the number of leading metrics or hooked layers changes
            self.feat_losses = [self._per_sample(f) for f in self.feat_losses]
        for n, l in zip(self.metric_names, self.feat_losses): setattr(self, n, l)
        return sum(self.feat_losses[2:])           # 2: to exclude MSE and PSNR from loss calculation, but to have them as metrics

    def __del__(self):
        # `hooks` is missing if `__init__` failed before registering them
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()