In [1]:
import os
import pprint
from typing import Tuple, List, Dict

import numpy as np
import torch
from numpy.testing import assert_array_almost_equal

from pytorch_probing import Interceptor

In [2]:
class TransformDecoderBlock(torch.nn.Module):
    """
    Single block of a Transformer decoder.

    Applies causal (masked) multi-head self-attention followed by a two-layer
    feed-forward network. Each sub-layer is wrapped in a residual connection
    and followed by layer normalization.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention output and
            to the output of each feed-forward linear layer.
    """

    def __init__(self, embed_dim:int, n_head:int, dropout_rate:float=0.0):
        super().__init__()

        self.attention = torch.nn.MultiheadAttention(embed_dim, n_head, bias=False, batch_first=True)
        self.dropout_attention = torch.nn.Dropout(dropout_rate)
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        # The feed-forward network expands to 4x the embedding size and projects back.
        self.linear1 = torch.nn.Linear(embed_dim, 4*embed_dim)
        self.dropout_linear1 = torch.nn.Dropout(dropout_rate)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4*embed_dim, embed_dim)
        self.dropout_linear2 = torch.nn.Dropout(dropout_rate)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """
        Run the block on a batch of sequences.

        Args:
            x: Tensor laid out as (batch, sequence, embed_dim); the attention
                layer is built with ``batch_first=True`` and the sequence
                length is read from ``x.shape[1]``.

        Returns:
            Tensor with the same shape as ``x``.
        """

        #Masked Multi-Head Attention
        # Build the causal mask on the same device and dtype as the input so it
        # stays valid when the model is moved off the CPU or run in another
        # floating-point precision.
        sequence_length = x.shape[1]
        attention_mask = torch.nn.Transformer.generate_square_subsequent_mask(
            sequence_length, device=x.device
        ).to(dtype=x.dtype)
        y1, _ = self.attention(x, x, x, is_causal=True, need_weights=False, attn_mask=attention_mask)
        y1 = self.dropout_attention(y1)

        #Add & Norm
        y1 = x+y1
        y1 = self.layer_norm1(y1)

        #Feed Forward
        # Dropout is applied to the first linear output before the ReLU.
        y2 = self.dropout_linear1(self.linear1(y1))
        y2 = self.relu(y2)
        y2 = self.dropout_linear2(self.linear2(y2))

        #Add & Norm
        result = y1+y2
        result = self.layer_norm2(result)

        return result

In [3]:
class ExampleModel(torch.nn.Module):
    """
    Small decoder-only model: token embedding, a stack of
    ``TransformDecoderBlock`` layers, and a linear projection to the
    vocabulary.

    Args:
        embed_dim: Embedding size used throughout the model.
        vocab_size: Number of distinct token ids; also the size of the output
            logits.
        n_layers: Number of decoder blocks stacked in ``self.decoder``.
        n_head: Number of attention heads in every decoder block.
        dropout_rate: Dropout probability forwarded to every decoder block.
            Defaults to 0.1, the value that used to be hard-coded.
    """

    def __init__(self, embed_dim:int, vocab_size:int, n_layers:int, n_head:int, dropout_rate:float=0.1):
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)

        blocks = [TransformDecoderBlock(embed_dim, n_head, dropout_rate) for _ in range(n_layers)]
        self.decoder = torch.nn.Sequential(*blocks)

        self.linear = torch.nn.Linear(embed_dim, vocab_size)

        # Detached copy of the embedding output from the most recent forward
        # pass; None until forward() has been called.
        self.embedding_output = None

    def forward(self, inputs:torch.Tensor) -> torch.Tensor:
        """
        Compute vocabulary logits for a batch of token ids.

        Args:
            inputs: Integer tensor of token ids; the decoder blocks treat
                dimension 1 as the sequence dimension.

        Returns:
            Tensor of logits whose last dimension has size ``vocab_size``.
        """
        x = self.embedding(inputs)

        # Keep a copy outside the autograd graph so it can be compared with
        # externally captured activations.
        self.embedding_output = x.detach().clone()

        x = self.decoder(x)
        x = self.linear(x)

        return x

    def my_method(self) -> str:
        """Return the constant string ``"A"``."""
        return "A"

In [4]:
# Deliberately tiny hyper-parameters so every cell runs instantly on CPU.
embed_dim, vocab_size = 8, 10
n_layers = n_head = 2

example_model = ExampleModel(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    n_layers=n_layers,
    n_head=n_head,
)

In [5]:
# Display the module tree of the freshly built model, before any interception.
example_model

ExampleModel(
  (embedding): Embedding(10, 8)
  (decoder): Sequential(
    (0): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
      (dropout_attention): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=8, out_features=32, bias=True)
      (dropout_linear1): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (linear2): Linear(in_features=32, out_features=8, bias=True)
      (dropout_linear2): Dropout(p=0.1, inplace=False)
      (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
      (dropout_attention): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((8,), eps=1e-05, el

In [6]:
# NOTE(review): `path` is not referenced by any later visible cell, and it uses
# "/" as separator while the paths given to Interceptor below use "." ("decoder.0").
# Looks like a leftover — confirm before removing.
path = "decoder/0"

In [7]:
test_batch_size = 2
test_sequence_size = 3

# Row i of the batch is filled with token id i, so every element is explicitly
# initialized and the shape follows test_batch_size / test_sequence_size.
# Token ids equal the row index, so test_batch_size must not exceed the
# vocabulary size of the model that consumes this batch.
inputs = (
    torch.arange(test_batch_size, dtype=torch.long)
    .unsqueeze(1)
    .repeat(1, test_sequence_size)
)


In [8]:
# Switch to evaluation mode so the Dropout layers are inactive and repeated
# forward passes on the same inputs can be compared.
example_model.eval()

ExampleModel(
  (embedding): Embedding(10, 8)
  (decoder): Sequential(
    (0): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
      (dropout_attention): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=8, out_features=32, bias=True)
      (dropout_linear1): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (linear2): Linear(in_features=32, out_features=8, bias=True)
      (dropout_linear2): Dropout(p=0.1, inplace=False)
      (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
      (dropout_attention): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((8,), eps=1e-05, el

In [9]:
# Two independent forward passes over the same batch; a later cell checks that
# they agree with each other and with the intercepted model.
original_outputs, original_outputs2 = (example_model(inputs) for _ in range(2))

In [10]:
# Dotted paths of the submodules whose outputs should be captured.
paths = ["decoder.0", "embedding"]

# Wrap the model. The next cell shows that `example_model` itself now contains
# InterceptorLayer wrappers at these paths, i.e. the original module is modified.
# NOTE(review): detach=False presumably keeps captured tensors attached to the
# autograd graph — confirm against the pytorch_probing documentation.
intercepted_model = Interceptor(example_model, paths, detach=False)

In [11]:
# The original model object now shows InterceptorLayer wrappers around
# `embedding` and `decoder.0`.
example_model

ExampleModel(
  (embedding): InterceptorLayer(
    (_module): Embedding(10, 8)
  )
  (decoder): Sequential(
    (0): InterceptorLayer(
      (_module): TransformDecoderBlock(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
        )
        (dropout_attention): Dropout(p=0.1, inplace=False)
        (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (linear1): Linear(in_features=8, out_features=32, bias=True)
        (dropout_linear1): Dropout(p=0.1, inplace=False)
        (relu): ReLU()
        (linear2): Linear(in_features=32, out_features=8, bias=True)
        (dropout_linear2): Dropout(p=0.1, inplace=False)
        (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
     

In [12]:
# Forward pass through the wrapper; compared against the plain model's outputs below.
intercepted_outputs = intercepted_model(inputs)

In [13]:
# The first plain forward pass is the reference; both the second plain pass and
# the intercepted pass must match it to 5 decimals.
reference = original_outputs.detach().numpy()

for candidate in (original_outputs2, intercepted_outputs):
    assert_array_almost_equal(reference, candidate.detach().numpy(), decimal=5)

In [14]:
# The copy the model stores itself during forward() must equal what the
# interceptor captured for the "embedding" path.
expected_embedding = example_model.embedding_output.detach().numpy()
captured_embedding = intercepted_model.outputs["embedding"].detach().numpy()

assert_array_almost_equal(expected_embedding, captured_embedding, decimal=5)


In [15]:
# Drop the captured activations; the displayed dict then maps every path to None.
intercepted_model.interceptor_clear()
intercepted_model.outputs   

{'decoder.0': None, 'embedding': None}

In [16]:
# `my_method` is defined on ExampleModel; calling it through the wrapper still works.
intercepted_model.my_method()

'A'

In [17]:
# The wrapper object is an Interceptor instance.
isinstance(intercepted_model, Interceptor)

True

In [18]:
# It is not an instance of the wrapped model's class, even though it exposes
# that model's methods.
isinstance(intercepted_model, ExampleModel)

False

In [19]:
# Display the wrapper: the ExampleModel sits under `_module`.
intercepted_model

Interceptor(
  (_module): ExampleModel(
    (embedding): InterceptorLayer(
      (_module): Embedding(10, 8)
    )
    (decoder): Sequential(
      (0): InterceptorLayer(
        (_module): TransformDecoderBlock(
          (attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
          )
          (dropout_attention): Dropout(p=0.1, inplace=False)
          (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
          (linear1): Linear(in_features=8, out_features=32, bias=True)
          (dropout_linear1): Dropout(p=0.1, inplace=False)
          (relu): ReLU()
          (linear2): Linear(in_features=32, out_features=8, bias=True)
          (dropout_linear2): Dropout(p=0.1, inplace=False)
          (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        )
      )
      (1): TransformDecoderBlock(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuan

In [20]:
# Set a new attribute through the wrapper; it is read back from the reduced
# model two cells below.
intercepted_model.a = 1

In [21]:
# Rebind `example_model` to the result of reduce().
# NOTE(review): a later cell asserts that the printed model no longer contains
# "Interceptor", so reduce() appears to strip the wrappers — confirm against
# the pytorch_probing documentation.
example_model = intercepted_model.reduce()

In [22]:
# The attribute assigned through the wrapper is visible on the reduced model.
example_model.a

1

In [23]:
paths = ["decoder.0", "WRONG_PATH", "embedding"]

# A path that does not exist in the model must be rejected with ValueError.
# The `else` branch makes this a real check: the cell fails if no exception is
# raised, where a bare `pass` would let a regression go unnoticed.
try:
    intercepted_model = Interceptor(example_model, paths, detach=False)
except ValueError:
    pass
else:
    raise AssertionError("Interceptor accepted an invalid path without raising ValueError")

In [24]:
# After reduce() and the rejected construction above, the printed model
# structure must contain no interceptor wrappers.
# (`pprint` is imported in the notebook's top import cell.)
model_string = pprint.pformat(example_model)

assert "Interceptor" not in model_string

In [25]:
# Restore the valid path list (the previous `paths` deliberately held a wrong entry).
paths = ["decoder.0", "embedding"]

# Wrap the reduced model again with the valid paths.
intercepted_model = Interceptor(example_model, paths, detach=False)

# Forward pass; the displayed logits carry a grad_fn because the call is not
# made under no_grad.
intercepted_model(inputs)

tensor([[[ 0.5432, -0.0260,  0.3923,  0.0831,  0.7893, -0.7662, -0.0886,
          -0.5506,  0.8688,  0.2964],
         [ 0.5432, -0.0260,  0.3923,  0.0831,  0.7893, -0.7662, -0.0886,
          -0.5506,  0.8688,  0.2964],
         [ 0.5432, -0.0260,  0.3923,  0.0831,  0.7893, -0.7662, -0.0886,
          -0.5506,  0.8688,  0.2964]],

        [[-0.5086, -0.3076, -0.2503, -0.7200, -0.3460,  0.3234,  0.7612,
          -0.6162,  0.0605, -0.1657],
         [-0.5086, -0.3076, -0.2503, -0.7200, -0.3460,  0.3234,  0.7612,
          -0.6162,  0.0605, -0.1657],
         [-0.5086, -0.3076, -0.2503, -0.7200, -0.3460,  0.3234,  0.7612,
          -0.6162,  0.0605, -0.1657]]], grad_fn=<ViewBackward0>)

In [26]:
# Round-trip the whole wrapper (not just a state_dict) through torch.save/load.
torch.save(intercepted_model, "intercepted_model.pth")

# weights_only=False is needed to unpickle a full module on PyTorch versions
# where torch.load defaults to weights_only=True. Unpickling can execute
# arbitrary code; it is acceptable here only because the file was written by
# the line above.
intercepted_model2 = torch.load("intercepted_model.pth", weights_only=False)
intercepted_model2.eval()

Interceptor(
  (_module): ExampleModel(
    (embedding): InterceptorLayer(
      (_module): Embedding(10, 8)
    )
    (decoder): Sequential(
      (0): InterceptorLayer(
        (_module): TransformDecoderBlock(
          (attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
          )
          (dropout_attention): Dropout(p=0.1, inplace=False)
          (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
          (linear1): Linear(in_features=8, out_features=32, bias=True)
          (dropout_linear1): Dropout(p=0.1, inplace=False)
          (relu): ReLU()
          (linear2): Linear(in_features=32, out_features=8, bias=True)
          (dropout_linear2): Dropout(p=0.1, inplace=False)
          (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        )
      )
      (1): TransformDecoderBlock(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuan

In [27]:
# Load a fresh copy from disk. weights_only=False is needed to unpickle a full
# module on PyTorch versions where torch.load defaults to weights_only=True;
# only use it on files this notebook wrote itself.
intercepted_model2 = torch.load("intercepted_model.pth", weights_only=False)

In [28]:
# A model loaded from disk must not carry captured activations for any path.
for intercepted_path in ("decoder.0", "embedding"):
    assert intercepted_model2.outputs[intercepted_path] is None

In [29]:
# Remove the temporary checkpoint written by the save/load cells.
# (`os` is imported in the notebook's top import cell.)
os.remove("intercepted_model.pth")

In [30]:
# reduce() again; the return value is not bound, only displayed. The repr shows
# a plain ExampleModel without InterceptorLayer wrappers.
intercepted_model.reduce()

ExampleModel(
  (embedding): Embedding(10, 8)
  (decoder): Sequential(
    (0): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
      (dropout_attention): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=8, out_features=32, bias=True)
      (dropout_linear1): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (linear2): Linear(in_features=32, out_features=8, bias=True)
      (dropout_linear2): Dropout(p=0.1, inplace=False)
      (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformDecoderBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
      )
      (dropout_attention): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((8,), eps=1e-05, el

In [31]:
# Interceptor used as a context manager; `detach` is left at its default here.
with Interceptor(example_model, paths) as intercepted_model:
    intercepted_model(inputs)
    # An expression inside a `with` block is not auto-displayed, hence print().
    print(intercepted_model)

Interceptor(
  (_module): ExampleModel(
    (embedding): InterceptorLayer(
      (_module): Embedding(10, 8)
    )
    (decoder): Sequential(
      (0): InterceptorLayer(
        (_module): TransformDecoderBlock(
          (attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
          )
          (dropout_attention): Dropout(p=0.1, inplace=False)
          (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
          (linear1): Linear(in_features=8, out_features=32, bias=True)
          (dropout_linear1): Dropout(p=0.1, inplace=False)
          (relu): ReLU()
          (linear2): Linear(in_features=32, out_features=8, bias=True)
          (dropout_linear2): Dropout(p=0.1, inplace=False)
          (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        )
      )
      (1): TransformDecoderBlock(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuan

In [32]:
# The name bound by `as` is still in scope; this call happens after the `with`
# block has exited.
intercepted_model(inputs)



tensor([[[ 0.5432, -0.0260,  0.3923,  0.0831,  0.7893, -0.7662, -0.0886,
          -0.5506,  0.8688,  0.2964],
         [ 0.5432, -0.0260,  0.3923,  0.0831,  0.7893, -0.7662, -0.0886,
          -0.5506,  0.8688,  0.2964],
         [ 0.5432, -0.0260,  0.3923,  0.0831,  0.7893, -0.7662, -0.0886,
          -0.5506,  0.8688,  0.2964]],

        [[-0.5086, -0.3076, -0.2503, -0.7200, -0.3460,  0.3234,  0.7612,
          -0.6162,  0.0605, -0.1657],
         [-0.5086, -0.3076, -0.2503, -0.7200, -0.3460,  0.3234,  0.7612,
          -0.6162,  0.0605, -0.1657],
         [-0.5086, -0.3076, -0.2503, -0.7200, -0.3460,  0.3234,  0.7612,
          -0.6162,  0.0605, -0.1657]]], grad_fn=<ViewBackward0>)

In [33]:
# Captured activations, keyed by the intercepted path.
intercepted_model.outputs

{'decoder.0': tensor([[[-0.1937, -0.3488,  1.5743,  0.4033, -0.7181, -1.5511, -0.5654,
            1.3994],
          [-0.1937, -0.3488,  1.5743,  0.4033, -0.7181, -1.5511, -0.5654,
            1.3994],
          [-0.1937, -0.3488,  1.5743,  0.4033, -0.7181, -1.5511, -0.5654,
            1.3994]],
 
         [[-0.1782, -0.6812,  1.5887,  0.9179, -0.7413,  0.6664, -1.7629,
            0.1906],
          [-0.1782, -0.6812,  1.5887,  0.9179, -0.7413,  0.6664, -1.7629,
            0.1906],
          [-0.1782, -0.6812,  1.5887,  0.9179, -0.7413,  0.6664, -1.7629,
            0.1906]]]),
 'embedding': tensor([[[-0.2290,  0.0151, -0.0612, -0.2276, -0.7952, -1.4815, -0.1022,
            0.5629],
          [-0.2290,  0.0151, -0.0612, -0.2276, -0.7952, -1.4815, -0.1022,
            0.5629],
          [-0.2290,  0.0151, -0.0612, -0.2276, -0.7952, -1.4815, -0.1022,
            0.5629]],
 
         [[-0.9286, -0.0813,  0.6004, -0.0322, -0.9358, -0.4568, -1.3712,
           -0.7904],
          [-0.9