### Importing the required libraries

In [1]:
# Python libraries
from pathlib import Path

import pandas as pd
import numpy as np
from numpy import linalg as la

# Transformers libraries
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel, GPT2Model
import torch.nn.functional as F

### Definition of the model and the hook function

In [2]:
model_id = 'openai-community/gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_id)
# output_attentions=True: ask the model to return per-head attention weights.
model = AutoModelForCausalLM.from_pretrained(model_id, output_attentions=True)
gpt2_model = model.transformer

# Global capture buffers filled by hook_fn. Forward hooks fire after a module's
# forward() returns, so entries are ordered by module completion.
output_list, module_list = [], []

def hook_fn(module, input, output):
    """Forward hook: record `module` and the `output` it just produced.

    `input` (the positional args passed to the module's forward) is ignored.
    """
    output_list.append(output)
    module_list.append(module)

# Attach hook_fn to every module, including the root model itself. Keep the
# RemovableHandle objects so the hooks can be detached later, e.g.
#     for handle in hook_handles: handle.remove()
# Without them every subsequent forward pass keeps appending to the lists above.
hook_handles = [layer.register_forward_hook(hook_fn) for layer in model.modules()]


### Choosing the prompt and running gpt2

In [3]:
# About Walt Whitman (Tokens: 308)
prompt = "Explain this poem: O Captain! my Captain! our fearful trip is done,\nThe ship has weather’d every rack, the prize we sought is won,\nThe port is near, the bells I hear, the people all exulting,\nWhile follow eyes the steady keel, the vessel grim and daring;\nBut O heart! heart! heart!\nO the bleeding drops of red,\nWhere on the deck my Captain lies,\nFallen cold and dead.\n\nO Captain! my Captain! rise up and hear the bells;\nRise up—for you the flag is flung—for you the bugle trills,\nFor you bouquets and ribbon’d wreaths—for you the shores a-crowding,\nFor you they call, the swaying mass, their eager faces turning;\nHere Captain! dear father!\nThis arm beneath your head!\nIt is some dream that on the deck,\nYou’ve fallen cold and dead.\n\nMy Captain does not answer, his lips are pale and still,\nMy father does not feel my arm, he has no pulse nor will,\nThe ship is anchor’d safe and sound, its voyage closed and done,\nFrom fearful trip the victor ship comes in with object won;\nExult O shores, and ring O bells!\nBut I with mournful tread,\nWalk the deck my Captain lies,\nFallen cold and dead."

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Later cells index output_list/module_list by position, so they must hold the
# hooks of exactly one forward pass. Clearing here makes this cell safe to re-run.
output_list.clear()
module_list.clear()

with torch.no_grad():
    outputs = model(input_ids)

# Print every captured module with a summary of its output: the shape for
# tensors, the length for tuples/containers, nothing otherwise.
for i, (module, output) in enumerate(zip(module_list, output_list)):
    if hasattr(output, "shape"):
        print(i, module, '   output_shape: ', output.shape)
    elif hasattr(output, "__len__"):
        print(i, module, '   output_shape: ', len(output))
    else:
        print(i, module)

attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
attn_output shape:  torch.Size([1, 308, 768])
0 Embedding(50257, 768)    output_shape:  torch.Size([1, 308, 768])
1 Embedding(1024, 768)    output_shape:  torch.Size([1, 308, 768])
2 Dropout(p=0.1, inplace=False)    output_shape:  torch.Size([1, 308, 768])
3 LayerNorm((768,), eps=1e-05, elementwise_affine=True)    output_shape:  torch.Size([1, 308, 768])
4 Conv1D()    output_shape:  torch.Size([1, 308, 2304])
5 Dropout(p=0.1, inplace=False)    output_shape:  torch.Size([1, 12, 308, 308])
6

### Storing output list and module list

#### Word-to-vector conversion (token and positional embeddings)

Recall that the steps are:
    
    1. Convert words to tokens and tokens to vectors (token embedding lookup)
    2. Look up the learned positional embedding for each token position
    3. Add the positional embedding to the token embedding to obtain the input of the first decoder block

In [4]:
# Embedding-stage hook outputs; [0] drops the batch axis.
#   output_list[0] -> token embedding lookup, Embedding(50257, 768)
#   output_list[1] -> positional embedding lookup, Embedding(1024, 768)
#   output_list[2] -> embedding dropout, whose input is token + positional embedding
Token2Vec          = output_list[0][0]
PositionalEncoding = output_list[1][0]
PositionPlusVect   = output_list[2][0]

# torch.save does not create parent directories, so make sure the target exists.
word2vec_dir = Path("output_Captain/word2vec")
word2vec_dir.mkdir(parents=True, exist_ok=True)

torch.save(Token2Vec, word2vec_dir / "Token2Vec.pt")
torch.save(PositionalEncoding, word2vec_dir / "PositionalEncoding.pt")
torch.save(PositionPlusVect, word2vec_dir / "PositionPlusVect.pt")

#### Storing evolution of the decoder blocks

In [5]:
# Positions in output_list of the hook outputs kept for decoder block 1.
# Indices 0-2 belong to the embedding stage, so block 1 starts at index 3.
#   output_list[3]  -> FirstNormalization
#   output_list[4]  -> QKV_representation
#   output_list[5]  -> AttentionHeads
#   output_list[7]  -> AttentionProj
#   output_list[9]  -> SecondNormalization
#   output_list[11] -> FirstLayerNN
#   output_list[13] -> SecondLayerNN         ('Delta space')
#   output_list[15] -> Decoder_Final_Output  ('Residual + Delta space')
index_list = np.array([3, 4, 5, 7, 9, 11, 13, 15])
module_name = ["FirstNormalization", "QKV_representation", "AttentionHeads", "AttentionProj", "SecondNormalization", "FirstLayerNN", "SecondLayerNN", "Decoder_Final_Output"]

# Number of hook outputs each decoder block appends to output_list (presumably the
# block module plus its 12 submodules, each firing once). The alignment check
# below fails loudly if this assumption does not hold.
HOOKS_PER_BLOCK = 13
N_BLOCKS = len(gpt2_model.h)

# Global output_list indices for every (block, module_name) pair, block-major.
Decoder_mask = np.concatenate([index_list + i * HOOKS_PER_BLOCK for i in range(N_BLOCKS)])

# Alignment check: the last kept index of each block must be the block module itself.
for i, block in enumerate(gpt2_model.h):
    assert module_list[index_list[-1] + i * HOOKS_PER_BLOCK] is block, (
        f"hook outputs misaligned at decoder {i + 1}; "
        "re-run the model-definition and forward-pass cells once on a fresh kernel"
    )

Decoder_list = [output_list[mask] for mask in Decoder_mask]

# output_list[2] is the embedding dropout output (token + positional embedding).
PositionalEmbedding = output_list[2]

for i in range(N_BLOCKS):
    decoder_dir = Path(f"output_Captain/decoder/decoder_{i + 1}")
    decoder_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create directories
    for j, name in enumerate(module_name):
        print("Decoder ", i + 1 , " ", name)
        torch.save(Decoder_list[j + i * len(module_name)], decoder_dir / f"{name}.pt")

Decoder  1   FirstNormalization
Decoder  1   QKV_representation
Decoder  1   AttentionHeads
Decoder  1   AttentionProj
Decoder  1   SecondNormalization
Decoder  1   FirstLayerNN
Decoder  1   SecondLayerNN
Decoder  1   Decoder_Final_Output
Decoder  2   FirstNormalization
Decoder  2   QKV_representation
Decoder  2   AttentionHeads
Decoder  2   AttentionProj
Decoder  2   SecondNormalization
Decoder  2   FirstLayerNN
Decoder  2   SecondLayerNN
Decoder  2   Decoder_Final_Output
Decoder  3   FirstNormalization
Decoder  3   QKV_representation
Decoder  3   AttentionHeads
Decoder  3   AttentionProj
Decoder  3   SecondNormalization
Decoder  3   FirstLayerNN
Decoder  3   SecondLayerNN
Decoder  3   Decoder_Final_Output
Decoder  4   FirstNormalization
Decoder  4   QKV_representation
Decoder  4   AttentionHeads
Decoder  4   AttentionProj
Decoder  4   SecondNormalization
Decoder  4   FirstLayerNN
Decoder  4   SecondLayerNN
Decoder  4   Decoder_Final_Output
Decoder  5   FirstNormalization
Decoder  5  

Let's also extract the intermediate result: attention output + residual connection.

In [6]:
# Recover "attention output + residual connection" for every decoder block.
# Assumption (TODO confirm): a block's output equals (attention + residual) plus
# the MLP output saved as SecondLayerNN, so subtracting the latter leaves the former.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; this only works
# if the saved objects are plain tensors/tuples — confirm for your transformers version.
SecondLayerNN_list, Decoder_Final_Output_list = [], []
for block_number in range(1, 13):
    block_dir = f"output_Captain/decoder/decoder_{block_number}"
    # [0] drops the batch axis of the saved MLP output.
    SecondLayerNN_list.append(torch.load(f"{block_dir}/SecondLayerNN.pt")[0])
    # The saved block output is indexed twice: first element, then batch axis.
    Decoder_Final_Output_list.append(torch.load(f"{block_dir}/Decoder_Final_Output.pt")[0][0])

AttentionPlusResidual_list = [
    block_out - mlp_out
    for block_out, mlp_out in zip(Decoder_Final_Output_list, SecondLayerNN_list)
]

# Store as a 1-tuple holding a batch-of-1 tensor, i.e. the same [0][0] access
# pattern used above for Decoder_Final_Output.pt.
for i, AttentionPlusResidual in enumerate(AttentionPlusResidual_list):
    target = f"output_Captain/decoder/decoder_{i+1}/AttentionPlusResidual.pt"
    torch.save((AttentionPlusResidual.unsqueeze(0),), target)

Project the last token back onto the vocabulary to study how its probability distribution evolves across the 12 blocks.

In [7]:
# Extract the final normalization layer and the 'inverse matrix'
ln_f = gpt2_model.ln_f
lm_head = model.lm_head

# Let's save the 'projection' on the vocabulary before and after the softmax
for i, (out_attention, out_decoder) in enumerate(zip(AttentionPlusResidual_list, Decoder_Final_Output_list)):
    print(f"Decoder {i+1}")
    
    final_norm = ln_f(out_attention)
    projection = lm_head(final_norm[-1])
    softmax = F.softmax(projection, dim=-1)
    
    top_values, top_indices = torch.topk(softmax, 10, dim=-1)

    print("    Top10 attn: ", top_indices, top_values)

    torch.save(projection, f"output_Captain/last_token_pdf/decoder_{i+1}/attention_projection.pt")
    torch.save(softmax, f"output_Captain/last_token_pdf/decoder_{i+1}/attention_softmax.pt")

    final_norm = ln_f(out_decoder)
    projection = lm_head(final_norm[-1])
    softmax = F.softmax(projection, dim=-1)

    torch.save(projection, f"output_Captain/last_token_pdf/decoder_{i+1}/out_decoder_projection.pt")
    torch.save(softmax, f"output_Captain/last_token_pdf/decoder_{i+1}/out_decoder_softmax.pt")
    
    # Extract the top 10 entries with highest softmax values
    top_values, top_indices = torch.topk(softmax, 10, dim=-1)
    print("     Top10 dec: ", top_indices, top_values)
    

Decoder 1
    Top10 attn:  tensor([ 383,  198,  317,  314, 2448,  554,  632,  843, 1081, 1649]) tensor([0.0799, 0.0761, 0.0451, 0.0286, 0.0252, 0.0192, 0.0190, 0.0171, 0.0164,
        0.0158], grad_fn=<TopkBackward0>)
     Top10 dec:  tensor([ 383,  198,  632,  843,  887, 1081, 1002,  554, 1318, 1649]) tensor([0.1186, 0.1049, 0.0713, 0.0696, 0.0537, 0.0327, 0.0299, 0.0293, 0.0271,
        0.0267], grad_fn=<TopkBackward0>)
Decoder 2
    Top10 attn:  tensor([ 383,  198,  843,  632,  887, 1649, 1081, 1002, 1318,  554]) tensor([0.0966, 0.0772, 0.0644, 0.0564, 0.0532, 0.0361, 0.0316, 0.0299, 0.0293,
        0.0236], grad_fn=<TopkBackward0>)
     Top10 dec:  tensor([ 632,  383,  198, 1002, 1649, 1318,  843,  887,  770, 1081]) tensor([0.0903, 0.0777, 0.0605, 0.0564, 0.0516, 0.0500, 0.0481, 0.0360, 0.0333,
        0.0257], grad_fn=<TopkBackward0>)
Decoder 3
    Top10 attn:  tensor([ 632,  198,  383, 1002, 1649,  843,  887, 1318,  770,  775]) tensor([0.0828, 0.0588, 0.0557, 0.0544, 0.0540, 0.05

In [10]:
# Final cell: emits a completion marker once every earlier cell has run.
print("Finished", end="\n")

Finished
