In [1]:
import torch
import numpy as np
from functools import partial, reduce
from collections import OrderedDict
from pprint import pprint

from torchvision.models import vgg16

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = vgg16(False).to(device).eval()

In [3]:
class MutipleKeysDict(OrderedDict):
    """
    Allow to get values from multiple keys. Example:
    
    ```python
    d = MutipleKeysDict({ 'a' : 1, 'b' : 2, 'c' : 3})
    d[['a', 'b']]
    # out [1,2]
    ```
    """
    def __getitem__(self, keys):
        
        if type(keys) is list:
            res = [dict.__getitem__(self, key) for key in keys]
        else: res = super().__getitem__(keys)
        return res

In [20]:
class ModuleStorage():
    def __init__(self, where2layers, debug=False):
        self.where2layers = where2layers
        self.where = list(self.names)[0]
        self.state = self._state
        self.unsubcribe = []
        self.debug = debug
    
    @property
    def _state(self):
        return MutipleKeysDict({ 
            k : MutipleKeysDict() if type(self.where2layers) == dict else [] 
            for k in self.names 
        })
    
    @property
    def names(self):
        names = []
        if type(self.where2layers) == dict:
            names = self.where2layers.keys()
        elif type(self.where2layers) is list:
            names = self.where2layers
        return names
    
    @property
    def layers(self):
        """
        Flat all the layers in the same array
        """
        layers = []
        if type(self.where2layers) == dict:
            layers = reduce(lambda a, b: a + b, self.where2layers.values())
        elif type(self.where2layers) is list:
            layers = self.where2layers
        return layers 
    
    def register_hooks(self, how='forward'):
        """
        Loop in all the layers and register a hook. There is ONLY one hook per layer to improve
        performance.
        """
        for layer in self.layers:
            # create a hash of a layer as an identifier, this is unique
#             name = f"{type(layer).__name__.lower()}-{hash(layer)}"
            if how == 'forward':
                self.unsubcribe.append(layer.register_forward_hook(partial(self.hook, name=layer)))
            elif how == 'backward':
                self.unsubcribe.append(layer.register_backward_hook(partial(self.hook, name=layer)))
            else:
                raise ValueError("type must be 'forward' or 'backward'")
            if self.debug: print(f"[INFO] {how} hook registered to {layer}")
        
    def hook(self, m, i, o, name):
        if self.debug: print(f"{m} called")
            
        if type(self.where2layers) == dict:
    #       store only the outputs from the correct layers defined in self.where2layers
            if m in self.where2layers[self.where]: self.state[self.where][name] = o
        if type(self.where2layers) is list:
            self.state[name].append(o) 
            
    def clear(self):
        if self.debug: print('[INFO] clear')
        [un.remove() for un in self.unsubcribe]

    def __call__(self, where=None):
        if where is not None:
            if where not in self.keys(): raise(f"we cannot find any layers with key {where}")
            self.where = where
        
    def __repr__(self):
        items = lambda x: x.items() if type(x) == MutipleKeysDict else enumerate(x)
        return str({k: [{i : e.shape for i, e in items(v)}] for k, v in self.state.items()})    

    def __getitem__(self, key):
        return self.state[key]
    
    def keys(self):
        return self.state.keys()

## Store input for multiple layers

In [22]:
class ForwardModuleStorage(ModuleStorage):
    def __init__(self, module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.module = module
        self.register_hooks(how='forward')
        
        
    def __call__(self, x, *args, **kwargs):
        super().__call__(*args, **kwargs)
        if type(x) != list: x = [x]
        [self.module(_x) for _x in x]
        
storage = ForwardModuleStorage(cnn, {'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
storage(torch.rand(1,3,100,100).to(device), 'style')
storage(torch.rand(1,3,100,100).to(device), 'content')

pprint(storage['style'].keys())

storage.clear()

del storage

odict_keys([Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))])


## Store multiple inputs for same layers

In [25]:
storage = ForwardModuleStorage(cnn, [cnn.features[5], cnn.features[9]])
a = torch.rand(1,3,100,100).to(device)
b = torch.rand(1,3,100,100).to(device)
storage([a, b])

storage.clear()

pprint(storage[cnn.features[5]])

[tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0639, 0.1510,  ..., 0.1072, 0.2296, 0.2635],
          [0.0000, 0.2039, 0.0981,  ..., 0.0552, 0.1525, 0.2407],
          [0.0000, 0.0676, 0.1154,  ..., 

In [7]:
class BackwardModuleStorage(ModuleStorage):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_hooks(how='backward')
        
    def __call__(self, x, *args, **kwargs):
        super().__call__(*args, **kwargs)
        if type(x) != list: x = [x]
        [_x.backward() for _x in x]

storage = BackwardModuleStorage({'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
x = cnn(torch.rand(1,3,100,100).requires_grad_(True).to(device)).sum() 
storage(x, 'style')
pprint(storage['style'])

storage.clear()

MutipleKeysDict([(Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                  (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00

In [8]:
storage.clear()