# Basic usage of the modified transformers library (the modifications were originally made only for BERT; this notebook uses the T5 variant)

## Install the modified library (first time only)

1. Move into the modified library's directory:
> $ cd transformers

1. Install it with pip (editable mode, i.e. the `-e` option, is recommended):
> $ pip install -e .

1. Install PyTorch (https://pytorch.org/) \
 e.g., 
 > $ pip install torch torchvision

## Usage

In [None]:
import torch
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer, BartModel, BartTokenizer, BartConfig, T5Model, T5Tokenizer
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load pretrained model/tokenizer
model_name = "t5-large" #"facebook/bart-large" #"bert-large-uncased-whole-word-masking"
#tokenizer = BertTokenizer.from_pretrained(model_name)
#model = BertModel.from_pretrained(model_name).to(device)
# tokenizer = RobertaTokenizer.from_pretrained(model_name)
# model = RobertaModel.from_pretrained(model_name).to(device)
# tokenizer = BartTokenizer.from_pretrained(model_name)
# model = BartModel.from_pretrained(model_name).to(device)
tokenizer = T5Tokenizer.from_pretrained(model_name)
tokenizer.mask_token = "<extra_id_0>"
model = T5Model.from_pretrained(model_name).to(device)
model.eval()

# Tokenize text and convert to ids. 
# input_ids = torch.tensor([tokenizer.encode("A 41 year old person age is [MASK] than a 42 year old person .", add_special_tokens=True)]).to(device)
input_ids = torch.tensor([tokenizer.encode(f"A 41 year old person age is {tokenizer.mask_token} than a 42 year old person .", add_special_tokens=True)]).to(device)
#input_ids = torch.tensor([tokenizer.encode("The size of a feather is usually much <mask> than the size of a nail .", add_special_tokens=True)]).to(device)
tokenized_text = tokenizer.convert_ids_to_tokens(input_ids[0])
print(tokenized_text)

In [None]:
# Forward pass through the modified model.  If you want to extract vector norms
# (i.e., ||f(x)||, ||αf(x)||, ||Σαf(x)||), set the argument "output_norms" to True.
# The number and order of the returned values depend on the architecture in the
# modified library: the active line unpacks the T5 layout; the commented-out
# lines are earlier layouts (presumably BERT, then BART).
# input_ids is also passed as decoder_input_ids, so the decoder-side norms are
# computed over the same tokens (tokenized_text) that label the plots below.
with torch.no_grad():  # inference only; no autograd graph is built
    #last_hidden_state, pooler_output, hidden_states, attentions, norms
#     decoder_features, decoder_states, decoder_attns, encoder_last_state, encoder_states, encoder_attns, norms = model(input_ids=input_ids.repeat(2, 1), output_hidden_states=True, output_attentions=True, output_norms=True)
#     decoder_features, decoder_states, decoder_attns, decoder_self_norms, decoder_cross_norms, encoder_last_state, encoder_states, encoder_attns, encoder_norms = model(input_ids=input_ids, output_hidden_states=True, output_attentions=True, output_norms=True)  # BART
    decoder_features, decoder_cache, decoder_states, decoder_attns, decoder_self_norms, decoder_cross_norms, encoder_last_state, encoder_states, encoder_attns, encoder_norms = model(input_ids=input_ids, decoder_input_ids=input_ids, output_hidden_states=True, output_attentions=True, output_norms=True)  # T5

In [None]:
# Select the norms drawn by the next plotting cell (encoder self-attention).
norms = encoder_norms

In [None]:
# Grid of per-head ||αf(x)|| heatmaps for the currently selected `norms`
# (encoder self-attention): one row per head, one column per layer.
num_heads = model.config.num_attention_heads
num_layers = model.config.num_hidden_layers
# Size is in inches; very large because the grid holds num_heads * num_layers heatmaps.
fig = plt.figure(figsize=(160, 120))
for layer in tqdm(range(num_layers)):
    for head in range(num_heads):
        ax = fig.add_subplot(num_heads, num_layers, num_layers * head + layer + 1)
        # Index 1 of each per-layer norms entry is used for T5 (and BART
        # decoder-self / cross); presumably ||αf(x)||, the second norm listed
        # for output_norms.  [0] selects the single batch element.
        afx_norm = norms[layer][1]
        norm = afx_norm[0][head].cpu().numpy()
        df = pd.DataFrame(norm, columns=tokenized_text, index=tokenized_text)
        sns.heatmap(df, cmap="Reds", square=True, ax=ax)
# Adjust once after all axes exist: subplots_adjust repositions every subplot,
# so calling it per panel repeated that work num_heads * num_layers times.
fig.subplots_adjust(bottom=0.2)

fig.savefig("t5-encoder-self-attention-fixed.png", bbox_inches="tight")
plt.close(fig)

In [None]:
import time

# Pause for 30 seconds between the large figure renders.  The notebook does
# not state why; NOTE(review): presumably to let the previous save/close settle
# before the next huge figure is built. Confirm whether this is still needed.
time.sleep(30)

In [None]:
# Grid of per-head ||αf(x)|| heatmaps for the T5 decoder cross-attention:
# one row per head, one column per layer.  Both heatmap axes are labelled with
# tokenized_text; that matches only because the forward pass used input_ids as
# decoder_input_ids as well.
norms = decoder_cross_norms
num_heads = model.config.num_attention_heads
num_layers = model.config.num_hidden_layers
# Size is in inches; very large because the grid holds num_heads * num_layers heatmaps.
fig = plt.figure(figsize=(160, 120))
for layer in tqdm(range(num_layers)):
    for head in range(num_heads):
        ax = fig.add_subplot(num_heads, num_layers, num_layers * head + layer + 1)
        # Index 1 of each per-layer norms entry is the T5 ||αf(x)|| tensor
        # (presumably); [0] selects the single batch element.
        afx_norm = norms[layer][1]
        norm = afx_norm[0][head].cpu().numpy()
        df = pd.DataFrame(norm, columns=tokenized_text, index=tokenized_text)
        sns.heatmap(df, cmap="Reds", square=True, ax=ax)
# Adjust once after all axes exist instead of once per panel.
fig.subplots_adjust(bottom=0.2)

fig.savefig("t5-cross-attention-fixed-2.png", bbox_inches="tight")
plt.close(fig)

In [None]:
# Another 30-second pause between large figure renders (reason not stated in
# the notebook; NOTE(review): confirm whether it is still needed).
time.sleep(30)

In [None]:
# Grid of per-head ||αf(x)|| heatmaps for the T5 decoder self-attention:
# one row per head, one column per layer, labelled with tokenized_text (the
# decoder was fed the same input_ids as the encoder).
norms = decoder_self_norms
num_heads = model.config.num_attention_heads
num_layers = model.config.num_hidden_layers
# Size is in inches; very large because the grid holds num_heads * num_layers heatmaps.
fig = plt.figure(figsize=(160, 120))
for layer in tqdm(range(num_layers)):
    for head in range(num_heads):
        ax = fig.add_subplot(num_heads, num_layers, num_layers * head + layer + 1)
        # Index 1 of each per-layer norms entry is the T5 ||αf(x)|| tensor
        # (presumably); [0] selects the single batch element.
        afx_norm = norms[layer][1]
        norm = afx_norm[0][head].cpu().numpy()
        df = pd.DataFrame(norm, columns=tokenized_text, index=tokenized_text)
        sns.heatmap(df, cmap="Reds", square=True, ax=ax)
# Adjust once after all axes exist instead of once per panel.
fig.subplots_adjust(bottom=0.2)

fig.savefig("t5-decoder-self-attention-fixed.png", bbox_inches="tight")
plt.close(fig)

End of the oLMpics visualizations
------------------------------------------------------------------------

The code below is from the original Kobayashi notebook.

### Attention weight visualization (Head-level visualization)

In [None]:
# Raw attention weights α of one encoder self-attention head.
# layer and head are 1-based (hence the -1 when indexing).  They are set
# explicitly: otherwise the values left over from the last plotting loop
# would be used, shifted by one.
layer = 1
head = 1
plt.figure()
# `attentions` is not assigned anywhere in this notebook (it came from the
# BERT-style output unpacking); the T5 forward pass exposes the encoder
# self-attention weights as `encoder_attns`, whose sequence matches the
# tokenized_text labels.  Assumes each entry is (batch, heads, seq, seq),
# TODO confirm.
attention = encoder_attns[layer-1][0][head-1].cpu().numpy()
df = pd.DataFrame(attention,columns=tokenized_text,index=tokenized_text)
sns.heatmap(df,cmap="Reds",square=True)
plt.gcf().subplots_adjust(bottom=0.2)

### Norm of the weighted vectors (||αf(x)||) visualization (Head-level visualization)

In [None]:
from tqdm.auto import tqdm

In [None]:
# Mapping from layer index to a list of head indices.  In this notebook it is
# referenced only by commented-out plotting filters.
layer2head = {
    2: [2, 3],
    3: [12],
    5: [2],
    6: [6],
    12: [3],
    13: [4],
    15: [8],
    16: [0, 8],
    17: [2, 5, 10],
}

In [None]:
# Per-head ||αf(x)|| heatmaps for selected layers only (first and last by
# default), placed in the same head-by-layer grid as the full plots.
num_heads = model.config.num_attention_heads
num_layers = model.config.num_hidden_layers
# Use num_layers - 1 rather than a fixed index such as 23 (valid only for
# 24-layer models) so the last layer is plotted for any model depth.
layers_to_plot = {0, num_layers - 1}
# Size is in inches; sized for the full num_heads x num_layers grid.
fig = plt.figure(figsize=(160, 120))
for layer in tqdm(range(num_layers)):
    if layer not in layers_to_plot:
        # Skip creating axes entirely so unplotted cells stay blank instead
        # of showing empty frames with ticks.  To pick specific heads instead,
        # test `layer in layer2head and head in layer2head[layer]` per head.
        continue
    for head in range(num_heads):
        ax = fig.add_subplot(num_heads, num_layers, num_layers * head + layer + 1)
        # Index 1: ||αf(x)|| tensor (presumably); [0]: single batch element.
        afx_norm = norms[layer][1]
        norm = afx_norm[0][head].cpu().numpy()
        df = pd.DataFrame(norm, columns=tokenized_text, index=tokenized_text)
        sns.heatmap(df, cmap="Reds", square=True, ax=ax)
# Adjust once after all axes exist instead of once per panel.
fig.subplots_adjust(bottom=0.2)

fig.savefig(f"{model_name}-attention-important_all.png", bbox_inches="tight")
plt.close(fig)

In [None]:
# Single-head ||αf(x)|| heatmap using the BART encoder-self indexing
# norms[layer][0][1].  NOTE(review): with the T5 norms selected above this
# nesting presumably does not apply; the next cell uses the T5 indexing.
layer, head = 1, 1  # 1-based
plt.figure()
afx_norm = norms[layer - 1][0][1]  # BART
norm = afx_norm[0][head - 1].cpu().numpy()
df = pd.DataFrame(norm, index=tokenized_text, columns=tokenized_text)
ax = sns.heatmap(df, square=True, cmap="Reds")
ax.figure.subplots_adjust(bottom=0.2)

In [None]:
# Single-head ||αf(x)|| heatmap using the index-1 layout (T5, BART
# decoder-self / cross); [0] selects the single batch element.
layer, head = 1, 1  # 1-based
plt.figure()
afx_norm = norms[layer - 1][1]
norm = afx_norm[0][head - 1].cpu().numpy()
df = pd.DataFrame(norm, index=tokenized_text, columns=tokenized_text)
ax = sns.heatmap(df, square=True, cmap="Reds")
ax.figure.subplots_adjust(bottom=0.2)

### Attention weight visualization (Layer-level visualization)

In [None]:
# Head-averaged attention weights α for one encoder layer; `layer` is the
# 1-based value set by the preceding cells.
plt.figure()
# `attentions` is not assigned anywhere in this notebook (it came from the
# BERT-style output unpacking); use the T5 encoder self-attention weights.
# Assumes each entry is (batch, heads, seq, seq), so .mean(0) after the batch
# index averages over heads. TODO confirm.
attention = encoder_attns[layer-1][0].mean(0).cpu().numpy()
df = pd.DataFrame(attention,columns=tokenized_text,index=tokenized_text)
sns.heatmap(df,cmap="Reds",square=True)
plt.gcf().subplots_adjust(bottom=0.2)

### Norm of the summed weighted vectors (||Σ αf(x)||) visualization (Layer-level visualization)

In [None]:
# Layer-level heatmap of index 2 of the per-layer norms entry (presumably
# ||Σαf(x)||, the third norm listed for output_norms); [0] selects the single
# batch element and `layer` is the 1-based value set by earlier cells.
plt.figure()
summed_afx_norm = norms[layer - 1][2]
norm = summed_afx_norm[0].cpu().numpy()
df = pd.DataFrame(norm, index=tokenized_text, columns=tokenized_text)
ax = sns.heatmap(df, square=True, cmap="Reds")
ax.figure.subplots_adjust(bottom=0.2)