In [1]:
from fastai import *
from fastai.vision import *
from fastai.vision.models.unet import DynamicUnet
from fastai.vision.learner import cnn_config
from fastai.callbacks import *
import sys
sys.path.append('../../')
sys.path.append('../../model')
from model.losses import *
from model.bpho.resnet import *
from model.bpho.unet import *
from model.metrics import psnr, ssim
from data.load_fluo import *

## Set the CUDA device

In [2]:
gpu_id, num_cores = 3, 4  # num_cores is unused in the cells shown; presumably meant for data-loader workers
torch.cuda.set_device(gpu_id)  # subsequent CUDA work defaults to this GPU index

## Adapting the feature loss to the wnresnet structure

In [3]:
import os

# Root directory of the trained critic/baseline models. Set MITOTRACKER_MODELS to
# override it instead of editing this machine-specific default path.
critic_pth = Path(os.environ.get('MITOTRACKER_MODELS',
                                 '/home/alaa/Dropbox/BPHO Staff/USF/Mitotracker/models'))
# load_learner unpickles the exported Learner: only point it at files you trust.
critic_sf = load_learner(path=critic_pth/'baselines', file='mitotracker_PSSR-SF.pkl')

In [15]:
encoder = critic_sf.model.model[0]
encoder.eval()  # nn.Module.eval() switches to inference mode in place and returns the module itself
# NOTE(review): this cell ran (In[15]) before the local flatten_model definition further
# down, so at that point the name resolves to fastai's flatten_model.
flattened_encoder = flatten_model(encoder)

In [16]:
# Collect the distinct layer classes in the encoder so the grid-changing kinds can be spotted.
layer_types = set(map(type, flattened_encoder))
layer_types
# grid-changing layers include:
# 1. conv2d with stride=2
# 2. all pooling layers
# add controls in find_layers() function to find these layers

{torch.nn.modules.activation.ReLU,
 torch.nn.modules.conv.Conv2d,
 torch.nn.modules.pooling.AdaptiveAvgPool2d,
 torch.nn.modules.pooling.AvgPool2d,
 torch.nn.modules.pooling.MaxPool2d}

## Feature loss definition

In [None]:
def flatten_model(model):
    """Flatten a (possibly nested) module into the list of its leaf layers.

    A leaf is a module with no children. Leaves are returned depth-first in the
    order given by ``children()`` (registration order). A module that has no
    children is itself a leaf, so ``flatten_model(leaf)`` returns ``[leaf]``,
    the same as fastai's ``flatten_model``. Previously this returned ``[]``.

    Args:
        model: an ``nn.Module``.

    Returns:
        list of leaf ``nn.Module`` objects.
    """
    children = list(model.children())
    if not children:
        return [model]
    flattened = []
    for child in children:
        flattened.extend(flatten_model(child))
    return flattened

In [None]:
def find_layers(flattened_model):
    """Return the layers that immediately precede each grid-changing layer.

    A layer is grid-changing (it alters the spatial resolution) when it is
    either of the following:
      * a pooling layer, i.e. its type string contains ``'pooling'``, which covers
        every class in ``torch.nn.modules.pooling``;
      * a ``Conv2d`` whose stride is greater than 1 along any axis.
    The output of the layer just before such a layer is taken at the current
    resolution, right before it changes.

    Args:
        flattened_model: iterable of leaf modules, e.g. from ``flatten_model``.

    Returns:
        list of the preceding modules, in order. A grid-changing layer at
        position 0 has no predecessor and is skipped, so an empty or
        single-layer input yields ``[]``.
    """

    def is_grid_changing(layer):
        """True if ``layer`` changes the spatial grid (add further controls here)."""
        if 'pooling' in str(type(layer)):
            return True
        if isinstance(layer, torch.nn.Conv2d):
            strides = layer.stride if isinstance(layer.stride, tuple) else (layer.stride,)
            return any(s > 1 for s in strides)
        return False

    layers = list(flattened_model)
    # Pair every layer with its predecessor and keep the predecessor. To collect
    # the grid-changing layers themselves instead, keep `layer` rather than `prev`.
    return [prev for prev, layer in zip(layers, layers[1:]) if is_grid_changing(layer)]

In [None]:
class FeatureLoss(nn.Module):
    """Pixel loss plus feature and Gram-matrix losses measured through ``m_feat``.

    Forward hooks are attached to the layers returned by
    ``find_layers(flatten_model(m_feat))``. ``forward`` runs ``m_feat`` on the
    target and then on the prediction, and returns

        base_loss(pred, target)
        + sum_i w_i * base_loss(f_i(pred), f_i(target))
        + sum_i base_loss(gram(f_i(pred)), gram(f_i(target))) * w_i**2 * gram_wgt

    Each term is also stored in ``self.metrics``, keyed by ``self.metric_names``
    (``pixel``, ``feat_k``, ``gram_k``).

    NOTE(review): ``base_loss`` and ``gram_matrix`` are not defined in this notebook;
    presumably they come from ``model.losses`` via the star import. Confirm this.
    """

    def __init__(self, m_feat, layer_wgts, gram_wgt=5e3):
        """
        Args:
            m_feat: feature-extractor module (the notebook passes the critic's encoder
                in eval mode).
            layer_wgts: one weight per hooked layer. Must have exactly
                ``len(find_layers(flatten_model(m_feat)))`` entries.
            gram_wgt: extra multiplier on every Gram-matrix term. Defaults to 5e3,
                the value previously hard-coded in ``forward``.

        Raises:
            ValueError: if ``layer_wgts`` does not provide one weight per hooked layer.
        """
        super().__init__()
        self.__name__ = 'feat_loss'
        self.m_feat = m_feat
        # NOTE(review): m_feat's parameters are not frozen here. If it is not meant to
        # be trained, consider m_feat.requires_grad_(False) to save gradient memory.
        self.loss_features = find_layers(flatten_model(self.m_feat))
        self.wgts = list(layer_wgts)
        n_feats = len(self.loss_features)
        # Without this check, zip() in forward silently drops layers and metric_names
        # no longer lines up with feat_losses (gram terms get logged as feat_k).
        if len(self.wgts) != n_feats:
            raise ValueError(f'layer_wgts has {len(self.wgts)} entries but '
                             f'{n_feats} feature layers were found')
        self.gram_wgt = gram_wgt
        # Registered only after validation so a failed __init__ leaves no dangling hooks.
        # detach=False keeps the hooked activations in the graph so the feature terms
        # can backpropagate into `pred`.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(n_feats)]
                             + [f'gram_{i}' for i in range(n_feats)])

    def make_features(self, x, clone=False):
        """Run ``m_feat`` on ``x`` and return the hooked activations.

        With ``clone=True`` the activations are copied, so they do not depend on the
        hooks' stored outputs, which the next forward pass replaces.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, pred, target):
        """Return the summed loss of ``pred`` vs ``target`` and record each term in ``self.metrics``."""
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(pred)
        self.feat_losses = [base_loss(pred, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * self.gram_wgt
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # __init__ may have raised before the hooks were registered.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

## Instantiate the feature loss

In [None]:
encoder = critic_sf.model.model[0]
encoder.eval()  # in-place switch to inference mode; eval() returns the same module

#### How many feature maps does the feature loss use?

In [17]:
# Layers the feature loss will hook: the ones immediately before each grid-changing layer.
feature_maps = find_layers(flattened_model=flatten_model(encoder))

In [19]:
# Number of hooked layers; the layer_wgts passed to FeatureLoss should have one entry per layer.
len(feature_maps)

8

### Tweak `layer_wgts` accordingly (one weight per feature map)

In [None]:
# Equal weight per hooked layer. The count comes from feature_maps rather than a
# hard-coded 8, so the weights stay in sync if the encoder changes.
n_feature_maps = len(feature_maps)
assert n_feature_maps > 0, 'find_layers() found no grid-changing layers in the encoder'
feat_loss = FeatureLoss(m_feat=encoder, layer_wgts=[1 / n_feature_maps] * n_feature_maps)

### Training with the feature loss

In [None]:
# NOTE(review): data, arch, n_frames, wnres_args, metrics, model_pth and wd are not
# defined in the cells shown here; they must already exist in the kernel before this runs.
learn = wnres_unet_learner(
    data,
    arch,
    in_c=n_frames,
    wnres_args=wnres_args,
    loss_func=feat_loss,
    metrics=metrics,
    model_dir=model_pth,
    callback_fns=[LossMetrics],  # reports the per-term values FeatureLoss stores in .metrics
    wd=wd,
)