## Imports

In [None]:
!pip install transformers

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import transformers
from transformers import GPT2Config, GPT2Model
from transformers.models.gpt2 import modeling_gpt2

## Define GPT2

In [None]:
# Load GPT-2 with its pretrained weights ("gpt2" checkpoint); from_pretrained
# does not create a randomly initialized model from a bare configuration.
model = GPT2Model.from_pretrained("gpt2")

# Tokenizer loaded from the same "gpt2" checkpoint name as the model
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

In [None]:
# Dump every submodule (type followed by its repr) between separator lines.
for module in model.modules():
    print(
        '------------------',
        f'{type(module)}:\n{module}',
        '------------------',
        sep='\n',
    )

# Collect the GPT2Block submodules in the order model.modules() yields them.
block_list = [
    candidate
    for candidate in model.modules()
    if isinstance(candidate, modeling_gpt2.GPT2Block)
]
print(f'Number of blocks: {len(block_list)}')

In [None]:
# help() writes the documentation itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
help(model.register_forward_hook)

## Define Hooks

In [None]:
class SaveOutput:
    """Callable that remembers the arguments of its most recent invocation.

    Its call signature ``(module, module_in, module_out)`` matches what a
    PyTorch forward hook receives, so an instance can be registered as one.
    """

    def __init__(self):
        # Nothing recorded until the instance has been called at least once.
        self.module_in = self.module_out = None

    def __call__(self, module, module_in, module_out):
        """Store the given input and output; ``module`` itself is not kept."""
        self.module_in, self.module_out = module_in, module_out

In [None]:
class BlockHook:
    """Capture the output of one GPT2Block and project it onto the vocabulary.

    A forward hook is registered on the ``block_num``-th ``GPT2Block`` found in
    ``model``. After a forward pass, the captured output is passed through the
    last ``LayerNorm`` found in the model and multiplied with the weight of the
    first ``Embedding`` found in the model to obtain one score per vocabulary
    entry.
    """

    def __init__(self, model: "modeling_gpt2.GPT2Model", block_num: int, debug: bool = False):
        """Locate the block, embedding and final norm, and register the hook.

        Args:
            model: Model whose submodules are searched.
            block_num: Index into the list of GPT2Block submodules.
            debug: When True, ``get_logits`` prints tensor types and shapes.

        Raises:
            IndexError: If ``block_num`` is out of range or no LayerNorm exists.
            ValueError: If the model contains no Embedding module.
        """
        modules = list(model.modules())
        blocks = [m for m in modules if isinstance(m, modeling_gpt2.GPT2Block)]
        self.block = blocks[block_num]
        self.hook = SaveOutput()
        self.debug = debug
        # Keep the handle so the hook can be detached again; otherwise every
        # re-creation of BlockHook leaves another hook attached to the block.
        self.handle = self.block.register_forward_hook(self.hook)

        embeddings = [m for m in modules if isinstance(m, torch.nn.Embedding)]
        if not embeddings:
            raise ValueError("model contains no torch.nn.Embedding module")
        # The first embedding is used as the token embedding (same assumption
        # the previous two-element unpacking made), without requiring that
        # the model has exactly two embeddings.
        self.wte = embeddings[0]
        self.ln_f = [m for m in modules if isinstance(m, torch.nn.LayerNorm)][-1]

    def remove(self):
        """Detach the forward hook from the block."""
        self.handle.remove()

    def get_logits(self):
        """Return vocabulary scores computed from the captured block output.

        Returns:
            Tensor produced by ``einsum("bcl,tl->bct")`` of the normalized
            block output with the embedding weight.

        Raises:
            RuntimeError: If no forward pass has populated the hook yet.
        """
        captured = self.hook.module_out
        if captured is None:
            raise RuntimeError(
                "No block output captured yet; run a forward pass through the model first."
            )
        # The first element of the captured output is used as the hidden state.
        h = self.ln_f(captured[0])
        emb_weight = self.wte.weight
        if self.debug:
            print(type(emb_weight))
            print(emb_weight.shape)
            print(h.shape)
        return torch.einsum("bcl,tl->bct", h, emb_weight)

    def get_best_interests(self, num: int = 15):
        """Return the ``num`` highest-scoring vocabulary entries.

        Only the last position of the first batch element is considered. The
        scores are the raw values from ``get_logits`` (no softmax is applied).

        Args:
            num: Number of entries to return.

        Returns:
            List of ``(score, index)`` tuples sorted by descending score; equal
            scores keep ascending index order.
        """
        logits = self.get_logits()
        # .cpu() keeps this working when the model lives on another device.
        scores = logits[0, -1, :].detach().cpu().numpy()
        # Stable sort on the negated scores gives descending order while
        # preserving index order among ties, like sorted(..., reverse=True).
        order = np.argsort(-scores, kind="stable")[:num]
        return [(scores[i], i) for i in order]

## Register hook



In [None]:
# One hook per transformer block; the count comes from the model config
# instead of a hard-coded 12, so other GPT-2 sizes work unchanged.
hook_list = [BlockHook(model, i) for i in range(model.config.n_layer)]
my_hook = hook_list[0]
# The hook has not captured anything until a forward pass runs after registration.
print(my_hook.hook.module_out)

In [None]:
# Evaluation mode for inference.
model.eval()
token = tokenizer("I love eating", return_tensors="pt")
# Inference only: skip autograd bookkeeping so the outputs captured by the
# hooks do not keep a computation graph alive.
with torch.no_grad():
    res = model(**token)

In [None]:
# For every registered block hook, print its top-scoring vocabulary entries
# (score followed by the decoded token), framed by separator lines.
for hook in hook_list:
    # Default number of entries (see BlockHook.get_best_interests).
    interests = hook.get_best_interests()
    print('------------------')
    # Each entry is a (score, vocabulary index) pair; the index is decoded to text.
    for logit, interest in interests:
        print(f"{logit:.5f} :: {tokenizer.decode(interest)}")
    print('------------------')

## Plot