In [1]:
import torch
import numpy as np
from functools import partial, reduce
from collections import OrderedDict
from pprint import pprint

from torchvision.models import vgg16


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = vgg16(False).to(device).eval()

In [3]:
# content_x = loader(content)
# storage = StoreFeatures(cnn, [cnn[0], cnn[5]])
# storage(content_x)
# storage.clear()
# print(storage._state)

In [4]:
class ModuleStorage():
    def __init__(self, where2layers):
        self.where2layers = where2layers
        self.where = list(self.names)[0]
        self._state = OrderedDict({ k : {} for k in self.names})
        self.unsubcribe = []
        
    @property
    def names(self):
        return self.where2layers.keys()
    
    @property
    def layers(self):
        """
        Flat all the layers in the same array
        """
        return reduce(lambda a, b: a + b, self.where2layers.values())
    
    def register_hooks(self, how='forward'):
        """
        Loop in all the layers and register a hook. There is ONLY one hook per layer to improve
        performance.
        """
        for layer in self.layers:
            # create a hash of a layer as an identifier, this is unique
            name = f"{type(layer).__name__.lower()}-{hash(layer)}"
            if how == 'forward':
                self.unsubcribe.append(layer.register_forward_hook(partial(self.hook, name=name)))
            elif how == 'backward':
                self.unsubcribe.append(layer.register_backward_hook(partial(self.hook, name=name)))
            else:
                raise ValueError("type must be 'forward' or 'backward'")
            print(f"[INFO] {how} hook registered to {layer}")
        
    def hook(self, m, i, o, name):
        print(f"{m} called")
#         store only the outputs from the correct layers defined in self.where2layers
        if m in self.where2layers[self.where]: self._state[self.where][name] = o
    
    def clear(self):
        print('[INFO] clear')
        [un.remove() for un in self.unsubcribe]

    def __call__(self, where=None):
        if where not in self.keys(): raise(f"we cannot find any layers with key {where}")
        if self.where is not None: self.where = where
        
    def __repr__(self):
        return str({k: [{i : e.shape for i, e in v.items()}] for k, v in self._state.items()})    

    def __getitem__(self, key):
        return self._state[key]
    
    def keys(self):
        return self._state.keys()

In [5]:
class ForwardModuleStorage(ModuleStorage):
    def __init__(self, module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.module = module
        self.register_hooks(how='forward')
    def __call__(self, x, *args, **kwargs):
        super().__call__(*args, **kwargs)
        self.module(x)
        
storage = ForwardModuleStorage(cnn, {'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
storage(torch.rand(1,3,100,100).to(device), 'style')
pprint(storage['style'])

storage.clear()

[INFO] forward hook registered to Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[INFO] forward hook registered to Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[INFO] forward hook registered to Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) called
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) called
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) called
{'conv2d--9223363251159186116': tensor([[[[0.1500, 0.1389, 0.1168,  ..., 0.0892, 0.0238, 0.0000],
          [0.1522, 0.0509, 0.0076,  ..., 0.0000, 0.0000, 0.0000],
          [0.1138, 0.0466, 0.0090,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1152, 0.0789, 0.0202,  ..., 0.0000, 0.0000, 0.0000],
          [0.1061, 0.0801, 0.0125,  ..., 0.0000, 0.0000, 0.0000],
          [0.0636, 0.0546, 0.0000,  ..., 0.0208, 0.0094, 0.0000]],

         [[0.0000, 0.0000, 0

In [6]:
class BackwardModuleStorage(ModuleStorage):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_hooks(how='backward')
        
    def __call__(self, x, *args, **kwargs):
        super().__call__(*args, **kwargs)
        x.backward()

storage = BackwardModuleStorage({'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
x = cnn(torch.rand(1,3,100,100).requires_grad_(True).to(device)).sum() 
storage(x, 'style')
pprint(storage['style'])

storage.clear()

[INFO] backward hook registered to Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[INFO] backward hook registered to Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[INFO] backward hook registered to Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) called
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) called
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) called
{'conv2d--9223363251159186116': (tensor([[[[ 7.8778e-04, -1.1037e-03,  1.3144e-03,  ..., -1.5999e-03,
            1.3931e-03,  0.0000e+00],
          [-1.0359e-03,  0.0000e+00,  1.0080e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-2.1427e-03, -8.6761e-04, -4.3194e-03,  ...,  3.8091e-03,
            0.0000e+00,  0.0000e+00],
          ...,
          [-3.1418e-03, -2.0311e-03,  0.0000e+00,  ..., -2.0851e-03,
            5.1013e-04,  0.0000e

In [7]:
storage.clear()

[INFO] clear
