In [1]:
import torch
import numpy as np
from functools import partial, reduce
from collections import OrderedDict
from pprint import pprint

from torchvision.models import vgg16


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = vgg16(False).to(device).eval()

In [3]:
# content_x = loader(content)
# storage = StoreFeatures(cnn, [cnn[0], cnn[5]])
# storage(content_x)
# storage.clear()
# print(storage._state)

In [13]:
class MutipleKeysDict(OrderedDict):
    """
    Allow to get values from multiple keys. Example:
    
    ```python
    d = MutipleKeysDict({ 'a' : 1, 'b' : 2, 'c' : 3})
    d[['a', 'b']]
    # out [1,2]
    ```
    """
    def __getitem__(self, keys):
        
        if type(keys) is list:
            res = [dict.__getitem__(self, key) for key in keys]
        else: res = super().__getitem__(keys)
        return res
    
d = MutipleKeysDict({ 'a' : 1, 'b' : 2, 'c' : 3})
d[['a', 'b']]

[1, 2]

In [5]:
class ModuleStorage():
    def __init__(self, where2layers, debug=False):
        self.where2layers = where2layers
        self.where = list(self.names)[0]
        self._state = MutipleKeysDict({ k : MutipleKeysDict() for k in self.names})
        self.unsubcribe = []
        self.debug = debug
        
    @property
    def names(self):
        return self.where2layers.keys()
    
    @property
    def layers(self):
        """
        Flat all the layers in the same array
        """
        return reduce(lambda a, b: a + b, self.where2layers.values())
    
    def register_hooks(self, how='forward'):
        """
        Loop in all the layers and register a hook. There is ONLY one hook per layer to improve
        performance.
        """
        for layer in self.layers:
            # create a hash of a layer as an identifier, this is unique
#             name = f"{type(layer).__name__.lower()}-{hash(layer)}"
            if how == 'forward':
                self.unsubcribe.append(layer.register_forward_hook(partial(self.hook, name=layer)))
            elif how == 'backward':
                self.unsubcribe.append(layer.register_backward_hook(partial(self.hook, name=layer)))
            else:
                raise ValueError("type must be 'forward' or 'backward'")
            if self.debug: print(f"[INFO] {how} hook registered to {layer}")
        
    def hook(self, m, i, o, name):
        if self.debug: print(f"{m} called")
#         store only the outputs from the correct layers defined in self.where2layers
        if m in self.where2layers[self.where]: self._state[self.where][name] = o
    
    def clear(self):
        if self.debug: print('[INFO] clear')
        [un.remove() for un in self.unsubcribe]

    def __call__(self, where=None):
        if where not in self.keys(): raise(f"we cannot find any layers with key {where}")
        if self.where is not None: self.where = where
        
    def __repr__(self):
        return str({k: [{i : e.shape for i, e in v.items()}] for k, v in self._state.items()})    

    def __getitem__(self, key):
        return self._state[key]
    
    def keys(self):
        return self._state.keys()

In [17]:
class ForwardModuleStorage(ModuleStorage):
    def __init__(self, module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.module = module
        self.register_hooks(how='forward')
        
        
    def __call__(self, x, *args, **kwargs):
        super().__call__(*args, **kwargs)
        self.module(x)
#         if type(x) != list: x = [x]
#         [self.module(_x) for _x in x]
        
storage = ForwardModuleStorage(cnn, {'style' : [cnn.features[5],  cnn.features[9]], 'content' : [cnn.features[5], cnn.features[10]]})
storage(torch.rand(1,3,100,100).to(device), 'style')
storage(torch.rand(1,3,100,100).to(device), 'content')

pprint(storage['style'])

storage.clear()

MutipleKeysDict([(Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                  tensor([[[[0.0000, 0.0456, 0.0284,  ..., 0.0000, 0.0167, 0.0819],
          [0.0000, 0.0911, 0.0458,  ..., 0.1335, 0.1861, 0.1201],
          [0.0000, 0.0549, 0.0175,  ..., 0.1046, 0.1009, 0.1384],
          ...,
          [0.0000, 0.0946, 0.0626,  ..., 0.0850, 0.1443, 0.1094],
          [0.0000, 0.0679, 0.0662,  ..., 0.0249, 0.0723, 0.0224],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0209, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0233, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0465, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0815, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0543, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0134],
          [0.0290, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0167, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
    

In [18]:
storage = ForwardModuleStorage(cnn, [cnn.features[5], cnn.features[9]])
storage([torch.rand(1,3,100,100).to(device), torch.rand(1,3,100,100).to(device)])

pprint(storage['style'])

storage.clear()

AttributeError: 'list' object has no attribute 'keys'

In [11]:
storage['style'][[cnn.features[5], cnn.features[9]]]

[tensor([[[[0.0000e+00, 6.5853e-02, 0.0000e+00,  ..., 4.4428e-02,
            4.2158e-02, 1.1559e-01],
           [0.0000e+00, 1.4792e-01, 9.2415e-02,  ..., 3.6840e-02,
            1.0875e-01, 8.4957e-02],
           [0.0000e+00, 5.5843e-03, 3.7081e-02,  ..., 0.0000e+00,
            1.0339e-01, 6.5520e-02],
           ...,
           [0.0000e+00, 1.1574e-01, 4.7353e-02,  ..., 1.1194e-01,
            3.6669e-02, 9.9927e-02],
           [0.0000e+00, 1.1178e-01, 3.4352e-02,  ..., 0.0000e+00,
            5.5424e-02, 1.2428e-01],
           [0.0000e+00, 3.6883e-04, 2.9656e-02,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00]],
 
          [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [1.9318e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           ...,
           [2.1758e-02, 0.0000e+00, 0.

In [None]:
class BackwardModuleStorage(ModuleStorage):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_hooks(how='backward')
        
    def __call__(self, x, *args, **kwargs):
        super().__call__(*args, **kwargs)
        x.backward()

storage = BackwardModuleStorage({'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
x = cnn(torch.rand(1,3,100,100).requires_grad_(True).to(device)).sum() 
storage(x, 'style')
pprint(storage['style'])

storage.clear()

In [None]:
storage.clear()