# PytorchModuleStorage
### Easy-to-use API to store forward/backward features
*Francesco Saverio Zuppichini*

## Quick Start

Suppose you have a model, e.g. `vgg19`, and you want to store the features of a specific layer (below, `cnn.features[3]`) given an input `x`.

![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PytorchModuleStorage/master/images/vgg-19.png)

First, we need a model. We will load `vgg19` from `torchvision.models`. Then, we create a random input `x`.

```python
import torch

from torchvision.models import vgg19
from PytorchStorage import ForwardModuleStorage

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = vgg19(False).to(device).eval() # False: randomly initialised weights, no pretrained checkpoint
```

Then, we define a `ForwardModuleStorage` instance by passing the model and the list of layers we are interested in.

```python
storage = ForwardModuleStorage(cnn, [cnn.features[3]])
```

Finally, we pass an input to the `storage`.

```python
x = torch.rand(1,3,224,224).to(device) # random input, this can be an image
storage(x) # pass the input to the storage
storage[cnn.features[3]][0] # the features can be accessed by passing the layer as a key
```

```
tensor([[[[0.0815, 0.0000, 0.0136,  ..., 0.0435, 0.0058, 0.0584],
          [0.1270, 0.0873, 0.0800,  ..., 0.0910, 0.0808, 0.0875],
          [0.0172, 0.0095, 0.1667,  ..., 0.2503, 0.0938, 0.1044],
          ...,
          [0.0000, 0.0181, 0.0950,  ..., 0.1760, 0.0261, 0.0092],
          [0.0533, 0.0043, 0.0625,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0776, 0.1942, 0.2467,  ..., 0.1669, 0.0778, 0.0969],
          [0.1714, 0.1516, 0.3037,  ..., 0.1950, 0.0428, 0.0892],
          [0.1219, 0.2611, 0.2902,  ..., 0.1964, 0.2083, 0.2422],
          ...,
          [0.1813, 0.1193, 0.2079,  ..., 0.3328, 0.4176, 0.2015],
          [0.0870, 0.2522, 0.1454,  ..., 0.2726, 0.1916, 0.2314],
          [0.0250, 0.1256, 0.1301,  ..., 0.1425, 0.1691, 0.0775]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.1044],
          [0.0000, 0.0202, 0.0000,  ..., 0.0000, 0.0873, 0.0908],
          [0.0000, 0.0000, 0.0000,  ..., 0
```

The storage keeps an internal `state` (`storage.state`) that uses the layers as keys to access the stored values.
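One way to build this kind of storage is with PyTorch forward hooks. The snippet below is a minimal sketch of the idea, not the library's actual implementation: `SimpleForwardStorage` is a hypothetical class that only relies on `nn.Module.register_forward_hook` and records every output of a hooked layer in a dict keyed by the layer object itself, on a small toy model instead of `vgg19`.

```python
import torch
import torch.nn as nn


class SimpleForwardStorage:
    """Hypothetical minimal storage: maps each hooked layer to a list of its outputs."""

    def __init__(self, model, layers):
        self.model = model
        # modules hash by identity, so the layer objects themselves can be dict keys
        self.state = {layer: [] for layer in layers}
        for layer in layers:
            # the forward hook is called as hook(module, inputs, output) after each forward pass
            layer.register_forward_hook(self._save)

    def _save(self, module, inputs, output):
        # detach so the stored tensors do not keep the autograd graph alive
        self.state[module].append(output.detach())

    def __call__(self, x):
        with torch.no_grad():
            self.model(x)

    def __getitem__(self, layer):
        return self.state[layer]


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
storage = SimpleForwardStorage(model, [model[1]])
storage(torch.rand(1, 3, 32, 32))
print(storage[model[1]][0].shape)  # torch.Size([1, 8, 32, 32])
```

Keeping a list per layer (rather than a single tensor) is what makes the `[0]` indexing above meaningful: each forward pass appends one entry.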

### Hook to a list of layers
You can pass a list of layers and then access the stored output of each one.

```python
storage = ForwardModuleStorage(cnn, [cnn.features[3], cnn.features[5]])
x = torch.rand(1,3,224,224).to(device) # random input, this can be an image
storage(x) # pass the input to the storage
print(storage[cnn.features[3]][0].shape)
print(storage[cnn.features[5]][0].shape)
```

```
torch.Size([1, 64, 224, 224])
torch.Size([1, 128, 112, 112])
```
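The two shapes follow from the first six modules of VGG-19's `features` block. As a sanity check, assuming torchvision's standard VGG-19 layout (index 3 is the ReLU after the second 64-channel convolution, index 4 a 2x2 max-pool, index 5 the first 128-channel convolution), the sketch below rebuilds just those modules by hand and traces the shape after each one. The weights are random; only the shapes matter.

```python
import torch
import torch.nn as nn

# Hand-built copy of the first six modules of VGG-19's `features`
# (assumed to mirror torchvision's layout; weights are random)
features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),    # 0
    nn.ReLU(),                                     # 1
    nn.Conv2d(64, 64, kernel_size=3, padding=1),   # 2
    nn.ReLU(),                                     # 3 -> cnn.features[3]
    nn.MaxPool2d(kernel_size=2, stride=2),         # 4 halves height and width
    nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 5 -> cnn.features[5]
)

shapes = {}
out = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    for i, layer in enumerate(features):
        out = layer(out)
        shapes[i] = tuple(out.shape)  # record the shape after module i

print(shapes[3])  # (1, 64, 224, 224)
print(shapes[5])  # (1, 128, 112, 112)
```

The 3x3 convolutions with `padding=1` keep the spatial size, so only the max-pool at index 4 changes it, from 224 to 112.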


### Multiple Inputs

You can also pass multiple inputs; they will be stored in call order.

![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PytorchModuleStorage/master/images/vgg-19-1.png)

```python
storage = ForwardModuleStorage(cnn, [cnn.features[3]])
x = torch.rand(1,3,224,224).to(device) # random input, this can be an image
y = torch.rand(1,3,224,224).to(device) # random input, this can be an image
storage([x, y]) # pass the inputs to the storage
print(storage[cnn.features[3]][0].shape) # x
print(storage[cnn.features[3]][1].shape) # y
```

```
torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])
```
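Storing in call order falls out naturally if each hooked layer appends to a list. The sketch below assumes that design (plain `register_forward_hook`, a toy model and a hypothetical `outputs` dict; it is not the library's code) and checks that index `0` holds the features of the first input and index `1` those of the second.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU()).eval()
layer = model[1]

outputs = {layer: []}  # hypothetical store: one list per hooked layer
# every forward pass appends, so the list index equals the call order
layer.register_forward_hook(lambda module, inputs, output: outputs[module].append(output.detach()))

x = torch.rand(1, 3, 16, 16)
y = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    for inp in [x, y]:  # inputs are processed one after the other
        model(inp)
    # recompute the expected features directly; calling model[0] does not fire the ReLU hook
    expected_x = torch.relu(model[0](x))
    expected_y = torch.relu(model[0](y))

print(len(outputs[layer]))                            # 2
print(torch.allclose(outputs[layer][0], expected_x))  # True
print(torch.allclose(outputs[layer][1], expected_y))  # True
```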


### Different inputs for different layers
Imagine we want to run `x` on one set of layers and `y` on another. This can be done by specifying a dictionary of the form `{ NAME: [layers...], ...}`.

![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PytorchModuleStorage/master/images/vgg-19-2.png)

```python
storage = ForwardModuleStorage(cnn, {'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
storage(x, 'style') # we run x only on the 'style' layers
storage(y, 'content') # we run y only on the 'content' layers

print(storage['style'])
print(storage['style'][cnn.features[5]])
```

```
MutipleKeysDict([(Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0383, 0.0042,  ..., 0.0852, 0.0246, 0.1101],
          [0.0000, 0.0000, 0.1106,  ..., 0.0000, 0.0107, 0.0487],
          ...,
          [0.0085, 0.0809, 0.0000,  ..., 0.0000, 0.0012, 0.0018],
          [0.0000, 0.0817, 0.1753,  ..., 0.0000, 0.0000, 0.0701],
          [0.0000, 0.1445, 0.1105,  ..., 0.2428, 0.0418, 0.0803]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0400, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0
```
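Per-group recording can be sketched with the same hook mechanism: keep one store per name and only record outputs for the layers of the group that is currently active. `GroupedStorage` below is hypothetical and assumes only `register_forward_hook`; it is not the library's implementation, and in this sketch the whole model still runs forward on each input while only the active group's layers are recorded.

```python
import torch
import torch.nn as nn


class GroupedStorage:
    """Hypothetical sketch: {name: [layers]} -> {name: {layer: [outputs]}}."""

    def __init__(self, model, groups):
        self.model = model
        self.state = {name: {layer: [] for layer in layers} for name, layers in groups.items()}
        self.active = None  # name of the group currently being recorded
        # hook each distinct layer once, even if it appears in several groups
        hooked = {layer for layers in groups.values() for layer in layers}
        for layer in hooked:
            layer.register_forward_hook(self._save)

    def _save(self, module, inputs, output):
        # record only if the module belongs to the active group
        if self.active is not None and module in self.state[self.active]:
            self.state[self.active][module].append(output.detach())

    def __call__(self, x, name):
        self.active = name
        try:
            with torch.no_grad():
                self.model(x)
        finally:
            self.active = None

    def __getitem__(self, name):
        return self.state[name]


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 8, 3, padding=1))
storage = GroupedStorage(model, {'style': [model[1]], 'content': [model[1], model[2]]})
x, y = torch.rand(1, 3, 16, 16), torch.rand(1, 3, 16, 16)
storage(x, 'style')    # records model[1] under 'style' only
storage(y, 'content')  # records model[1] and model[2] under 'content'
print(len(storage['style'][model[1]]), len(storage['content'][model[2]]))  # 1 1
```

Because `model[1]` belongs to both groups, its outputs end up in two separate lists, similar to how `cnn.features[5]` appears under both `'style'` and `'content'` in the example above.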

## Backward
You can also store gradients by using `BackwardModuleStorage`.

```python
from PytorchStorage import BackwardModuleStorage
```

```python
import torch.nn as nn
# unlike ForwardModuleStorage, we don't need the model, just the layers
storage = BackwardModuleStorage([cnn.features[3]])
x = torch.rand(1, 3, 224, 224, device=device, requires_grad=True) # random input, this can be an image
loss = nn.CrossEntropyLoss()
# 1 is the ground-truth class; the target must be on the same device as the model output
output = loss(cnn(x), torch.tensor([1], device=device))
storage(output)
# then we can use the layer to get its gradient
storage[cnn.features[3]]
```

```
[(tensor([[[[ 1.6662e-05,  0.0000e+00,  9.1222e-06,  ...,  1.2165e-07,
              0.0000e+00,  0.0000e+00],
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  1.8770e-05],
            [ 4.9425e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00],
            ...,
            [ 7.3107e-05,  0.0000e+00,  0.0000e+00,  ..., -2.6335e-05,
              0.0000e+00,  2.1168e-05],
            [ 1.0214e-07,  0.0000e+00,  8.3543e-06,  ...,  0.0000e+00,
              8.6060e-06,  0.0000e+00],
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00]],

           [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  2.9192e-05,
              0.0000e+00,  0.0000e+00],
            [ 0.0000e+00, -1.3629e-05,  0.0000e+00,  ...,  0.0000e+00,
             -8.7888e-06,  0.0000e+00],
            [ 0.0000e+00,  0.0000e+00, -3.7738e-05,  ...,  0.0000e+00,
             -3.6711
```
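Gradients can be captured in a similar way with backward hooks. The sketch below is an assumption, not the library's implementation: it uses `nn.Module.register_full_backward_hook` (available since PyTorch 1.8) on a toy model and stores each `grad_output` tuple in a hypothetical `grads` dict, so the stored value is also a list of tuples like the output above. The ReLUs are not in-place because PyTorch raises an error when inputs or outputs of a module with full backward hooks are modified in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),
    nn.ReLU(),  # not inplace: full backward hooks reject in-place modification
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 10),
)
layer = model[1]
grads = {layer: []}  # hypothetical store: layer -> list of grad_output tuples


def save_grad(module, grad_input, grad_output):
    # grad_output holds the gradient of the loss w.r.t. the module's output
    grads[module].append(tuple(g.detach() for g in grad_output))


layer.register_full_backward_hook(save_grad)

x = torch.rand(1, 3, 8, 8)
loss = nn.CrossEntropyLoss()(model(x), torch.tensor([1]))  # 1 is the ground-truth class
loss.backward()  # the hook fires during this call

print(grads[layer][0][0].shape)  # torch.Size([1, 4, 8, 8])
```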