# Creator: Ivan Bardarov (University of Strathclyde, March 2019)
## This module extracts the Gram matrix from an input image

# In [1]:
import sys
sys.path.append('../../../')
from fastai.imports import *
from fastai.dataset import *

class SaveFeatures:
    """
    Attach a forward hook to an ``nn.Module`` and keep its most recent output.

    Attributes
    ----------
    features : object or None
        Whatever the hooked module returned on its latest forward pass (a
        tensor for the layers hooked in this file). ``None`` until the module
        has run at least once.
    hook : handle
        Handle returned by ``register_forward_hook``; ``close`` calls its
        ``remove()`` method.

    Methods
    -------
    hook_fn(module, input, output)
        Forward-hook callback; stores ``output`` on the instance.
    close()
        Detach the hook from the module.
    """

    # Class-level default so ``features`` reads as None before any forward pass.
    features = None

    def __init__(self, m):
        callback = self.hook_fn
        self.hook = m.register_forward_hook(callback)

    def hook_fn(self, module, input, output):
        # ``module`` and ``input`` are required by the hook signature but unused.
        self.features = output

    def close(self):
        handle = self.hook
        handle.remove()

def gram(input, scale=1e6):
    """
    Calculate the Gram matrix of a batch of feature maps.

    The tensor is flattened to shape ``(b*c, -1)``: one row per
    (batch, channel) pair, with every remaining position along the columns.
    The row-by-row inner products are then divided by the total number of
    elements in ``input`` and multiplied by ``scale``.

    Parameters
    ----------
    input : torch.Tensor
        Activations with at least two dimensions ``(b, c, ...)``, typically
        ``(b, c, h, w)``. Non-contiguous tensors are accepted.
    scale : float, optional
        Constant factor applied after normalisation. The default ``1e6`` is
        the value that was previously hard-coded here.

    Returns
    -------
    torch.Tensor
        Square matrix of shape ``(b*c, b*c)``.

    Notes
    -----
    With ``b > 1`` the rows of different images are mixed, so the result also
    contains cross-image products. Pass a single image (as ``get_gram`` does
    via ``[None]``) to obtain a per-image Gram matrix.
    """
    b, c = input.shape[:2]
    # reshape (unlike view) also works when the tensor is not contiguous,
    # e.g. after a permute or with a channels_last memory format.
    x = input.reshape(b * c, -1)
    return torch.mm(x, x.t()) / input.numel() * scale

def get_gram(filename):
    """
    Load an image and return the flattened Gram matrix of its hooked activations.

    The image is read with ``open_image``, preprocessed with ``val_tfms``,
    given a leading batch axis of size 1 and pushed through ``m_vgg``. The
    network's return value is discarded: the forward pass is run only so that
    the hook held by ``act`` records the output of the hooked layer.

    Depends on the module-level ``m_vgg``, ``val_tfms`` and ``act``, which
    must be set up before this function is called.

    Parameters
    ----------
    filename : str or Path
        Path to the input image.

    Returns
    -------
    ndarray
        The Gram matrix of the captured activations, flattened to 1-D.
    """
    batch = VV(val_tfms(open_image(filename))[None])
    m_vgg(batch)
    # NOTE(review): presumably cloned to decouple the result from the hooked
    # output tensor; confirm whether the copy is required.
    features = act.features.clone()
    flat = gram(features).view(-1)
    return to_np(flat)

# Set up the pretrained model.
# Side length (pixels) passed to tfms_from_model when building the transforms.
sz = 288

# Pretrained VGG16 backbone: placed with to_gpu, switched to eval mode and
# marked non-trainable, so it serves only as a fixed feature extractor.
m_vgg = to_gpu(vgg16(True)).eval()
set_trainable(m_vgg, False)

# tfms_from_model returns a pair of transforms; only the second one
# (val_tfms, used by get_gram) is kept.
_, val_tfms = tfms_from_model(vgg16, sz)

# Index of the layer just before every MaxPool2d in the backbone, i.e. the
# last layer of each convolutional block.
blocks = [
    idx - 1
    for idx, layer in enumerate(children(m_vgg))
    if isinstance(layer, nn.MaxPool2d)
]

# Hook the layer that closes the first block; every forward pass through
# m_vgg then stores that layer's output in act.features.
act = SaveFeatures(children(m_vgg)[blocks[0]])

# In [5]:
# Notebook inspection: display the class of the first layer of m_vgg
# (recorded output: torch.nn.modules.conv.Conv2d).
type(children(m_vgg)[0])

# Out[5]: torch.nn.modules.conv.Conv2d

# In [10]:
# Recompute the block boundaries by iterating m_vgg directly: the position
# immediately before each MaxPool2d layer.
blocks = [pos - 1 for pos, layer in enumerate(m_vgg)
          if isinstance(layer, nn.MaxPool2d)]


# In [11]:
# Notebook inspection: display the layer indices computed in the previous cell
# (recorded output: [5, 12, 22, 32, 42]).
blocks

# Out[11]: [5, 12, 22, 32, 42]