<a href="https://colab.research.google.com/github/RaiAnant/MangaChroma/blob/master/fasterai/loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Show the GPU, driver and CUDA version attached to this runtime (the loss classes below call `.cuda()`).
!nvidia-smi

Fri Aug 16 11:51:44 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P8    17W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

## Imports

In [0]:
from fastai import *
from fastai.core import *
from fastai.torch_core import *
from fastai.callbacks  import hook_outputs
import torchvision.models as models

## Loss classes

In [0]:
# m = models.vgg16_bn(True)

In [0]:
# ??models.vgg16_bn

In [0]:
# ??models.vgg

In [0]:
# ??models.vgg.make_layers

In [0]:
# m = models.vgg16_bn(True).features.cuda().eval()
# blocks = [i-1 for i,o in enumerate(children(m)) if isinstance(o,nn.MaxPool2d)]
# l = [m[i] for i in blocks[2:5]]
# l

[ReLU(inplace), ReLU(inplace), ReLU(inplace)]

In [0]:
class FeatureLoss(nn.Module):
  """Pixel + VGG feature loss used for pretraining the generator.

  The total loss is the L1 distance between `input` and `target` plus a
  weighted L1 distance between the activations both images produce at three
  layers of a frozen, pretrained VGG16-BN feature extractor.

  Args:
    layer_wts: one weight per hooked VGG layer; each weight scales that
      layer's feature term. The sequence is copied, so the caller's list
      (and the shared default) is never aliased by the instance.

  Attributes refreshed on every `forward` call:
    feat_losses: list of individual loss terms, pixel term first.
    metrics: dict mapping each name in `metric_names` to its loss term.
  """
  def __init__(self, layer_wts=[20,70,10]):
    super().__init__()
    # Feature extractor: `.features` is the convolutional part of VGG16-BN
    # with pretrained weights. It is frozen and kept in eval mode so that
    # batch-norm statistics are not updated while computing the loss.
    self.m_feat = models.vgg16_bn(True).features.cuda().eval()
    requires_grad(self.m_feat, False)
    # Index of the layer sitting right before every MaxPool2d, i.e. the last
    # activation at each resolution before downsampling occurs.
    blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
    layer_ids = blocks[2:5]
    self.loss_features = [self.m_feat[i] for i in layer_ids]
    # Hooks capture the outputs of these specific layers on every pass
    # through `m_feat`; detach=False keeps them attached to the autograd
    # graph so the loss can back-propagate into `input`.
    self.hooks = hook_outputs(self.loss_features, detach=False)
    self.wgts = list(layer_wts)
    self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))]
    self.base_loss = F.l1_loss

  def _make_features(self, x, clone=False):
    """Run `x` through the extractor and return the hooked activations.

    `clone=True` copies each activation (the copy stays attached to the
    graph) so it survives the next pass, which overwrites `hooks.stored`.
    """
    self.m_feat(x)
    return [(o.clone() if clone else o) for o in self.hooks.stored]

  def forward(self, input, target):
    """Return the summed pixel and weighted feature L1 losses."""
    # Target (coloured image) features must be cloned because the second
    # pass below replaces the stored activations.
    out_feat = self._make_features(target, clone=True)
    # Features of the transformed / generated image.
    in_feat = self._make_features(input)
    self.feat_losses = [self.base_loss(input,target)]
    self.feat_losses += [self.base_loss(f_in, f_out)*w
                         for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]

    self.metrics = dict(zip(self.metric_names, self.feat_losses))
    return sum(self.feat_losses)

  def __del__(self):
    # `hooks` does not exist when __init__ failed part-way (or was bypassed),
    # so look it up defensively instead of raising during finalisation.
    hooks = getattr(self, 'hooks', None)
    if hooks is not None: hooks.remove()

In [0]:
class WassFeatureLoss(nn.Module):
  """Feature loss paired with a Wasserstein style loss.

  The total loss is the sum of
    * the L1 distance between `input` and `target` (pixel term),
    * weighted L1 distances between their activations at three layers of a
      frozen, pretrained VGG16-BN (feature terms), and
    * weighted squared 2-Wasserstein distances between Gaussians fitted to
      the channel statistics of those activations (style terms).

  Args:
    layer_wgts: one weight per hooked layer for the feature terms.
    wass_wgts: one weight per hooked layer for the Wasserstein terms.
    Both sequences are copied so the shared defaults are never aliased.

  Attributes refreshed on every `forward` call:
    feat_losses: list of individual loss terms, pixel term first.
    metrics: dict mapping each name in `metric_names` to its loss term.
  """
  def __init__(self, layer_wgts=[5,15,2], wass_wgts=[3.0,0.7,0.01]):
    super().__init__()
    # Frozen pretrained extractor, kept in eval mode.
    self.m_feat = models.vgg16_bn(True).features.cuda().eval()
    requires_grad(self.m_feat, False)
    # Index of the layer right before every MaxPool2d (last activation at
    # each resolution).
    blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
    layer_ids = blocks[2:5]
    self.loss_features = [self.m_feat[i] for i in layer_ids]
    # Capture those layers' outputs on every pass; detach=False keeps them
    # in the autograd graph.
    self.hooks = hook_outputs(self.loss_features, detach=False)
    self.wgts = list(layer_wgts)
    self.wass_wgts = list(wass_wgts)
    self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'wass_{i}' for i in range(len(layer_ids))]
    self.base_loss = F.l1_loss

  @staticmethod
  def _sym_eig(mat):
    """Eigen-decompose a symmetric matrix; returns (eigenvalues, eigenvectors).

    Uses `torch.linalg.eigh` where available and falls back to the legacy
    `torch.symeig` (removed in recent PyTorch). UPLO='U' matches the
    upper-triangle default of `symeig`.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'eigh'):
      return linalg.eigh(mat, UPLO='U')
    return torch.symeig(mat, eigenvectors=True)

  def _make_features(self, x, clone=False):
    """Run `x` through the extractor and return the hooked activations.

    `clone=True` copies each activation so it survives the next pass,
    which overwrites `hooks.stored`.
    """
    self.m_feat(x)
    return [(o.clone() if clone else o) for o in self.hooks.stored]

  def _calc_2_moments(self, tensor):
    """Return (mean, covariance) of the rows of `tensor` viewed as (chans, n).

    `mu` has shape (1, chans) and `cov` has shape (chans, chans).
    Returns (None, None) when there are no values per channel.
    """
    chans = tensor.shape[1]
    # NOTE(review): `view` folds the batch axis into the last axis without
    # transposing, so rows only correspond to channels when batch size is 1
    # - confirm callers never rely on larger batches here.
    tensor = tensor.view(1, chans, -1)
    n = tensor.shape[2]
    # Guard against division by zero on empty feature maps.
    if n == 0: return None, None
    mu = tensor.mean(2)
    # Centre the values so the product below is a covariance.
    tensor = (tensor - mu[:,:,None]).squeeze(0)
    # Gram matrix of the centred features. The feature-loss paper normalises
    # by n*chans; this implementation divides by n only.
    cov = torch.mm(tensor, tensor.t()) / float(n)
    return mu, cov

  def _get_style_vals(self, tensor):
    """Return (mean, trace of covariance, matrix square root of covariance).

    Returns (None, None, None) when `tensor` has no values per channel.
    """
    mean, cov = self._calc_2_moments(tensor)
    if mean is None:
      return None, None, None
    eigvals, eigvects = self._sym_eig(cov)
    # Clamp tiny negative eigenvalues produced by round-off before sqrt.
    eigvals = eigvals.clamp(min=0)
    eigroot_mat = torch.diag(torch.sqrt(eigvals))
    root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
    tr_cov = eigvals.sum()
    return mean, tr_cov, root_cov

  def _calc_l2wass_dist(self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth):
    """Squared 2-Wasserstein distance between two Gaussians.

    |m1 - m2|^2 + tr(C1) + tr(C2) - 2 * tr((C1^1/2 C2 C1^1/2)^1/2), with the
    style Gaussian given by its precomputed trace and square root.
    """
    tr_cov_synth = self._sym_eig(cov_synth)[0].clamp(min=0).sum()
    mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
    cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
    # The small epsilon keeps the sqrt gradient finite at zero eigenvalues.
    var_overlap = torch.sqrt(self._sym_eig(cov_prod)[0].clamp(min=0) + 1e-8).sum()
    return mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap

  def _single_wass_loss(self, pred, targ):
    """Wasserstein term for one layer.

    `pred` is the raw activation of the generated image; `targ` is the
    (mean, tr_cov, root_cov) triple from `_get_style_vals` for the target.
    """
    mean_targ, tr_cov_targ, root_cov_targ = targ
    mean_synth, cov_synth = self._calc_2_moments(pred)
    return self._calc_l2wass_dist(mean_targ, tr_cov_targ, root_cov_targ, mean_synth, cov_synth)

  def forward(self, input, target):
    """Return the summed pixel, feature and Wasserstein style losses."""
    out_feat = self._make_features(target, clone=True)
    in_feat = self._make_features(input)
    # Pixel and feature reconstruction terms.
    self.feat_losses = [self.base_loss(input,target)]
    self.feat_losses += [self.base_loss(f_in, f_out)*w
                         for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]

    # Style reconstruction terms: extract the target's style statistics and
    # compare the generated features against them.
    styles = [self._get_style_vals(i) for i in out_feat]

    # Skipped when the feature maps were empty (statistics are None).
    if styles[0][0] is not None:
      self.feat_losses += [self._single_wass_loss(f_pred, f_targ)*w
                           for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)]

    self.metrics = dict(zip(self.metric_names, self.feat_losses))
    return sum(self.feat_losses)

  def __del__(self):
    # `hooks` is missing when __init__ failed part-way; never raise here.
    hooks = getattr(self, 'hooks', None)
    if hooks is not None: hooks.remove()


In [33]:
# Broadcasting check: elementwise `*` of a (1,2) and a (2,1) tensor yields a (2,2) result.
torch.randn(1,2) * torch.randn(2,1)

tensor([[ 0.0517, -0.7658],
        [ 0.0377, -0.5578]])

In [35]:
# `.sum()` with no `dim` reduces over every axis and returns a single scalar tensor.
torch.randn(4,5,3).sum()

tensor(-4.1681)