In [1]:
%load_ext autoreload

%autoreload 2

# PytorchModuleStorage
### Easy-to-use API to store forward/backward features
*Francesco Saverio Zuppichini*

## Quick Start

You have a model, e.g. `vgg19`, and you want to store the features of the third layer for a given input `x`.

![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PytorchModuleStorage/master/images/vgg-19.png)

First, we need a model. We will load `vgg19` from `torchvision.models` and put it in evaluation mode on the available device. The random input `x` is created just before it is passed to the storage.

In [2]:
import torch

from torchvision.models import vgg19
from PytorchStorage import ForwardModuleStorage

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = vgg19(False).to(device).eval()

Then, we define a `ForwardModuleStorage` instance by passing the model and the list of layers we are interested in.

In [10]:
storage = ForwardModuleStorage(cnn, [cnn.features[3]])

Finally, we can pass an input to the `storage`.

In [11]:
x = torch.rand(1,3,224,224).to(device) # random input, this can be an image
storage(x) # pass the input to the storage
storage[cnn.features[3]][0] # the features can be accessed by passing the layer as a key

tensor([[[[0.0000, 0.0096, 0.0000,  ..., 0.0000, 0.0779, 0.0000],
          [0.0838, 0.0567, 0.0973,  ..., 0.0000, 0.1429, 0.0132],
          [0.0417, 0.0249, 0.0000,  ..., 0.0000, 0.0653, 0.0000],
          ...,
          [0.1135, 0.0429, 0.0000,  ..., 0.0000, 0.0187, 0.0000],
          [0.0000, 0.0715, 0.0000,  ..., 0.0140, 0.0000, 0.0000],
          [0.0569, 0.0000, 0.0228,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0998, 0.0073,  ..., 0.0000, 0.0725, 0.0000],
          [0.0538, 0.1496, 0.1861,  ..., 0.1608, 0.2325, 0.0000],
          [0.2112, 0.1708, 0.4880,  ..., 0.1965, 0.2087, 0.1108],
          ...,
          [0.0504, 0.0474, 0.1651,  ..., 0.3195, 0.1704, 0.1532],
          [0.2454, 0.2351, 0.2507,  ..., 0.1891, 0.3085, 0.0966],
          [0.0627, 0.1082, 0.1874,  ..., 0.1319, 0.3948, 0.1490]],

         [[0.0936, 0.1198, 0.1036,  ..., 0.2526, 0.1110, 0.0000],
          [0.1442, 0.0190, 0.1689,  ..., 0.2353, 0.0020, 0.0406],
          [0.0000, 0.1516, 0.0460,  ..., 0

The storage keeps an internal `state` (`storage.state`) in which the layers themselves are the keys used to access the stored values.
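To make the idea of a layer-keyed state concrete, here is a minimal sketch built on PyTorch's own `register_forward_hook`. It is not the library's implementation: the tiny `model`, the `state` dictionary and the `save_output` function are hypothetical names chosen for illustration, and a small stand-in network is used instead of `vgg19` so that it runs quickly.

```python
import torch
import torch.nn as nn

# A tiny stand-in model (hypothetical, used here instead of vgg19).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
).eval()

# Hypothetical stand-in for `storage.state`: layer -> list of stored outputs.
state = {}

def save_output(module, inputs, output):
    # PyTorch calls this right after `module.forward`; the module is the key.
    state.setdefault(module, []).append(output.detach())

layer = model[1]
handle = layer.register_forward_hook(save_output)

with torch.no_grad():
    model(torch.rand(1, 3, 32, 32))

handle.remove()  # detach the hook once the features are captured
print(state[layer][0].shape)
```

Modules are hashable by identity, which is why the layer object itself can serve as a dictionary key, mirroring the `storage[cnn.features[3]]` access shown above.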

### Hook to a list of layers
You can pass a list of layers and then access the stored output of each one.

In [20]:
storage = ForwardModuleStorage(cnn, [cnn.features[3], cnn.features[5]])
x = torch.rand(1,3,224,224).to(device) # random input, this can be an image
storage(x) # pass the input to the storage
print(storage[cnn.features[3]][0].shape)
print(storage[cnn.features[5]][0].shape)

torch.Size([1, 64, 224, 224])
torch.Size([1, 128, 112, 112])


### Multiple Inputs

You can also pass multiple inputs. Their features are stored in call order.

![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PytorchModuleStorage/master/images/vgg-19-1.png)

In [22]:
storage = ForwardModuleStorage(cnn, [cnn.features[3]])
x = torch.rand(1,3,224,224).to(device) # random input, this can be an image
y = torch.rand(1,3,224,224).to(device) # random input, this can be an image
storage([x, y]) # pass the inputs to the storage
print(storage[cnn.features[3]][0].shape) # x
print(storage[cnn.features[3]][1].shape) # y

torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])
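The call-order behaviour can be pictured with a plain PyTorch hook that appends to a list. This is only a sketch under the assumption that each input is forwarded in turn; the names `outputs`, `x` and `y` below are illustrative, and the two inputs are given different batch sizes purely so that the two stored entries are easy to tell apart.

```python
import torch
import torch.nn as nn

# Small stand-in model (hypothetical, not vgg19).
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU()).eval()
layer = model[1]

outputs = []  # filled in the order the inputs are forwarded

# `list.append` returns None, so the hook does not alter the layer's output.
handle = layer.register_forward_hook(lambda m, i, o: outputs.append(o.detach()))

x = torch.zeros(1, 3, 16, 16)
y = torch.ones(2, 3, 16, 16)  # different batch size, to distinguish the entries

with torch.no_grad():
    for inp in (x, y):
        model(inp)

handle.remove()
print(outputs[0].shape)  # features for x
print(outputs[1].shape)  # features for y
```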


### Different inputs for different layers
Imagine we want to run `x` on one set of layers and `y` on another. This can be done by specifying a dictionary of the form `{ NAME: [layers...], ...}`.
![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PytorchModuleStorage/master/images/vgg-19-2.png)

In [23]:
storage = ForwardModuleStorage(cnn, {'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
storage(x, 'style') # we run x only on the 'style' layers
storage(y, 'content') # we run y only on the 'content' layers


print(storage['style']) 
print(storage['style'][cnn.features[5]])

MutipleKeysDict([(Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), tensor([[[[0.0968, 0.0824, 0.0756,  ..., 0.0599, 0.0000, 0.0959],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0049, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0302],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0142],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0026],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0008]],

         [[0.1660, 0.0000, 0.0743,  ..., 0.0000, 0.0091, 0.0000],
          [0.0217, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0015, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0705, 0.0371],
          [0.0300, 0.0
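One way to think about named groups is to register hooks only on the layers of the requested group for the duration of a single forward pass. The sketch below does this with plain PyTorch; `groups`, `state` and `run` are hypothetical names, and nothing here is claimed to match how `ForwardModuleStorage` works internally.

```python
import torch
import torch.nn as nn

# Small stand-in model (hypothetical, not vgg19).
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.ReLU(),
).eval()

# Group name -> layers of interest, and one sub-dictionary of results per group.
groups = {"style": [model[1]], "content": [model[1], model[3]]}
state = {name: {} for name in groups}

def run(x, name):
    """Forward `x`, recording outputs only for the layers in group `name`."""
    def hook(module, inputs, output):
        state[name][module] = output.detach()

    handles = [layer.register_forward_hook(hook) for layer in groups[name]]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # always detach the hooks, even if the forward pass fails

run(torch.rand(1, 3, 16, 16), "style")    # recorded only under 'style'
run(torch.rand(1, 3, 16, 16), "content")  # recorded only under 'content'
print(len(state["style"]), len(state["content"]))
```

Note that the same layer can appear in several groups, as `model[1]` does here, and each group keeps its own copy of that layer's output.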

## Backward
You can also store gradients by using `BackwardModuleStorage`.

In [3]:
from PytorchStorage import BackwardModuleStorage

In [6]:
import torch.nn as nn
# we don't need the module, just the layers
storage = BackwardModuleStorage([cnn.features[3]])
x = torch.rand(1,3,224,224).requires_grad_(True).to(device) # random input, this can be an image
loss = nn.CrossEntropyLoss()
# 1 is the ground-truth class index; the target must be on the same device as the model output
output = loss(cnn(x), torch.tensor([1], device=device))
storage(output)
# then we can use the layer to get the gradient out from it
storage[cnn.features[3]]

[(tensor([[[[ 0.0000e+00,  4.8176e-06, -5.2397e-07,  ..., -1.6433e-05,
              0.0000e+00,  2.1016e-05],
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00],
            [-6.5110e-05,  0.0000e+00,  2.1711e-05,  ...,  0.0000e+00,
              0.0000e+00,  8.7276e-06],
            ...,
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00, -6.9027e-05],
            [-3.0660e-05,  0.0000e+00, -3.6719e-05,  ...,  9.3895e-06,
              0.0000e+00, -3.5631e-05],
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00]],
  
           [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00],
            [ 7.7562e-06,  0.0000e+00,  1.2513e-05,  ...,  7.4844e-05,
              0.0000e+00,  5.3793e-06],
            [ 1.7280e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000
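For readers who want to see how gradients can be captured per layer, here is a minimal sketch using PyTorch's `register_full_backward_hook`. It is an illustration only: the names `grads` and `save_grad` are hypothetical, a small stand-in model replaces `vgg19`, and the sketch stores `grad_output`, whereas which gradients `BackwardModuleStorage` keeps is not stated in this document.

```python
import torch
import torch.nn as nn

# Small stand-in classifier (hypothetical, not vgg19).
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 10),
)
layer = model[0]
grads = {}  # layer -> tuple of gradients

def save_grad(module, grad_input, grad_output):
    # grad_output holds the gradients of the loss w.r.t. the module's outputs.
    grads[module] = grad_output

handle = layer.register_full_backward_hook(save_grad)

x = torch.rand(1, 3, 8, 8, requires_grad=True)
loss = nn.CrossEntropyLoss()(model(x), torch.tensor([1]))  # 1 is the target class
loss.backward()  # the hook fires during this backward pass

handle.remove()
print(grads[layer][0].shape)
```

The gradient with respect to the layer's output has the same shape as that output, so it can be compared directly with the features captured in the forward pass.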