In [None]:
from datasets import load_dataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F
from easy_transformer import EasyTransformer
from torch.utils.data import Dataset, DataLoader
import random
import spacy
# spaCy English pipeline, loaded once; tokenize() calls it to read part-of-speech tags.
nlp = spacy.load('en_core_web_sm')

In [None]:
class CustomGPT2ForSequenceClassification(EasyTransformer):
    """EasyTransformer whose unembedding is replaced by a linear classification head.

    The output of the final layer norm is flattened over every dimension except
    the first and passed to ``classification_head1``, whose input size is
    ``config.d_model * config.n_ctx``. Inputs are therefore expected to be padded
    to exactly ``config.n_ctx`` tokens.

    Note: the number of output classes is read from the notebook-level variable
    ``num_labels``, which must be defined before this class is instantiated.
    """

    def __init__(self, config):
        super().__init__(config)
        # The language-model head is not used for classification.
        self.unembed = None
        self.classification_head1 = torch.nn.Linear(config.d_model * config.n_ctx, num_labels)

    def forward(self, input_ids):
        """Run the transformer stack and return classification logits.

        Args:
            input_ids: token ids; assumed to be (batch, n_ctx) - TODO confirm.

        Returns:
            Output of ``classification_head1`` applied to the flattened
            final-layer-norm activations.
        """
        embed = self.embed(tokens=input_ids)
        embed = embed.squeeze(1)
        pos_embed = self.pos_embed(input_ids)
        residual = embed + pos_embed

        for block in self.blocks:
            attn_out = block.attn(block.ln1(residual))
            resid_mid = residual + attn_out
            mlp_out = block.mlp(block.ln2(resid_mid))
            # Carry this block's output into the next block. Previously the
            # residual was never updated, so every block read the raw embeddings
            # and only the last block's output reached the head.
            # NOTE(review): a checkpoint fine-tuned with the old forward pass will
            # produce different activations here - re-fine-tune or re-validate it.
            residual = resid_mid + mlp_out

        normalized_resid_final = self.ln_final(residual)
        # Flatten everything except the leading (batch) dimension for the linear head.
        flattened = normalized_resid_final.view(normalized_resid_final.shape[0], -1)
        logits = self.classification_head1(flattened)
        return logits
        


In [None]:
# GLUE task to analyse; also used to build the checkpoint path and the output folders.
dataset_name = 'stsb'

In [None]:
# Load the validation split of the selected GLUE task.
validation_dataset = load_dataset('glue', dataset_name, split='validation')


In [None]:
# Number of validation examples to analyse.
num_samples = 100
# Seed the permutation so the same examples are drawn on every run; the output
# folders are named after the position of each example in this subset.
subset_generator = torch.Generator().manual_seed(0)
subset_indices = torch.randperm(len(validation_dataset), generator=subset_generator).tolist()[:num_samples]
validation_dataset = validation_dataset.select(subset_indices)
len(validation_dataset)


In [None]:
# Show one example together with the size of the sampled subset.
validation_dataset[43], len(validation_dataset)

In [None]:
# Distinct label values in the sampled subset. The label column is read once;
# the previous loop appended the whole column for every example.
c = list(validation_dataset['label'])
np.unique(c)

In [None]:
def tokenize(datapoint, max_length = 1024, token_to_add = 50256):
    """Turn a sentence pair into one padded id sequence and tag noun/pronoun positions.

    The sequence is: ``sentence1`` tokens, one separator token, ``sentence2``
    tokens, then padding up to ``max_length``. Relies on the notebook-level
    ``reference_gpt2`` (tokenisation) and ``nlp`` (spaCy part-of-speech tags).

    Args:
        datapoint: mapping with ``'sentence1'``, ``'sentence2'`` and ``'label'``.
        max_length: length to pad to. Longer sequences are returned as-is
            (nothing is truncated).
        token_to_add: NOTE(review): the value passed here currently has no
            effect. Both the separator and the padding use the hard-coded id
            50264, which is kept unchanged so the model inputs stay the same.
            50264 appears to lie outside GPT-2's vocabulary - confirm which id
            is intended.

    Returns:
        ``(tokens, labels, real_length, sep_place, noun, pnoun, verb, subj, obj, neg)``
        where ``tokens`` is the padded id tensor, ``labels`` the label as a
        tensor, ``real_length`` the unpadded token count, ``sep_place`` the
        start indices of the two sentences, ``noun``/``pnoun`` the token
        positions tagged NOUN/PROPN and PRON. ``verb``, ``subj``, ``obj`` and
        ``neg`` are always empty (their tagging is disabled) and are returned
        only to keep the tuple shape.
    """
    sentence1_tokens = reference_gpt2.to_tokens(datapoint['sentence1'], prepend_bos = False)
    sentence2_tokens = reference_gpt2.to_tokens(datapoint['sentence2'], prepend_bos = False)

    # sentence1 starts at 0; sentence2 starts one past the separator.
    sep_place = [0, sentence1_tokens.size(1) + 1]

    # Shape (1, 1) so it can be concatenated along the sequence dimension.
    # This deliberately overrides the argument - see the NOTE(review) above.
    token_to_add = torch.tensor([[50264]], dtype=torch.long)
    concatenated_tokens = torch.cat((sentence1_tokens, token_to_add, sentence2_tokens), dim=1)
    real_length = concatenated_tokens.size(1)

    recovered_tokens = reference_gpt2.to_str_tokens(concatenated_tokens)

    noun, pnoun = [], []
    verb, subj, obj, neg = [], [], [], []

    # Each token string is tagged on its own, without sentence context.
    for i, each in enumerate(recovered_tokens[:real_length]):
        for token in nlp(each):
            if token.pos_ in ["NOUN", "PROPN"]:  # noun or proper noun
                noun.append(i)
            if token.pos_ in ["PRON"]:  # pronoun
                pnoun.append(i)

    labels = torch.tensor(datapoint['label'])

    # Pad in a single concatenation instead of one torch.cat per missing token.
    remaining_length = max_length - real_length
    if remaining_length > 0:
        padding = token_to_add.expand(1, remaining_length)
        concatenated_tokens = torch.cat((concatenated_tokens, padding), dim=1)

    return concatenated_tokens, labels, real_length, sep_place, noun, pnoun, verb, subj, obj, neg


In [None]:
def register_hooks(module):
    """Attach a forward hook to ``module`` that prints the shape of every output it produces."""
    def _print_output_shape(_module, _inputs, result):
        print("Output shape:", result.shape)

    module.register_forward_hook(_print_output_shape)


In [None]:
# Filled by the hooks registered below: one entry per hooked attention module
# for each forward pass.
attention_scores_list = []
def register_attention_hooks(module):
    """Hook every block's ``attn.hook_attn`` so its output is recorded.

    Does nothing unless ``module`` is an ``EasyTransformer``. Each hook appends
    ``output[0]`` (presumably the first batch element - TODO confirm) to the
    notebook-level ``attention_scores_list``.
    """
    if not isinstance(module, EasyTransformer):
        return

    def _record_first_item(_module, _inputs, output):
        attention_scores_list.append(output[0])

    for block in module.blocks:
        block.attn.hook_attn.register_forward_hook(_record_first_item)


In [None]:
def get_max(attn_scores, tosee_list, name, directory, length, layer, head):
    """Save the attention row of the strongest position in ``tosee_list``.

    For every candidate position the matching row of ``attn_scores`` is taken,
    its own entry (``row[position]``) is zeroed, and the rest is summed. The
    row with the largest sum (first one on ties) is written to
    ``directory + name + '_' + layer + '_' + head + '.npy'``. With a single
    candidate that row is saved directly; with no candidates nothing is written.

    Args:
        attn_scores: 2-D tensor indexed as ``[position, :]``.
        tosee_list: candidate row indices into ``attn_scores``.
        name: file-name prefix.
        directory: output folder, including its trailing separator.
        length: unused; kept so existing calls keep working. It only ever
            scaled every candidate's score by the same constant, which cannot
            change which candidate wins.
        layer: layer index used in the file name.
        head: head index used in the file name.
    """
    if len(tosee_list) == 0:
        return

    def masked_row(position):
        # Copy before editing: .numpy() can share memory with a CPU tensor, and
        # zeroing in place would then silently modify attn_scores.
        row = attn_scores[position, :].cpu().detach().numpy().copy()
        row[position] = 0
        return row

    if len(tosee_list) == 1:
        best_position = tosee_list[0]
    else:
        # Pick the token position itself. The old code used the candidate's
        # index within tosee_list as if it were a token position.
        best_position = max(tosee_list, key=lambda position: masked_row(position).sum())

    np.save(directory + name + '_' + str(layer) + '_' + str(head) + '.npy', masked_row(best_position))
        

In [None]:
num_labels = 2
model_path = '../trained_models/easy_transformer_gpt2small_' + dataset_name + '.pth'

# Use Apple's MPS backend when present, otherwise fall back so the cell still runs.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the models and load the checkpoint once. None of this depends on the
# data point, so it does not belong inside the loop.
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
config = reference_gpt2.cfg
model = CustomGPT2ForSequenceClassification(config)
state_dict = torch.load(model_path, map_location="cpu")
# strict=False tolerates key mismatches; report them so a wrong checkpoint
# does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('missing keys:', load_result.missing_keys)
    print('unexpected keys:', load_result.unexpected_keys)

model.to(device)
model.eval()
reference_gpt2.to(device)
reference_gpt2.eval()

# Register the hooks once on the single model instance.
register_attention_hooks(model)

output_root = '../gpt2_small/verb_subject/' + dataset_name + '/'

for i, point in enumerate(validation_dataset):
    print('data point ',i)

    tokens, label, length, seperators, noun, pnoun, verb, subj, obj, neg = tokenize(point)

    # Fresh list per data point; the hooks look this name up at call time.
    attention_scores_list = []
    with torch.no_grad():  # inference only - no autograd graph needed
        outputs = model(tokens.to(device))

    # Only data points with at least one noun and one pronoun are saved.
    if not (len(noun) > 0 and len(pnoun) > 0):
        continue

    directory = output_root + str(i) + '/'
    os.makedirs(directory, exist_ok=True)

    # One recorded entry per hooked attention module, indexed [head, :, :].
    for attention, layer_scores in enumerate(attention_scores_list):
        for head in range(layer_scores.shape[0]):
            get_max(layer_scores[head,:,:], noun, 'noun', directory, length, attention, head)
            get_max(layer_scores[head,:,:], pnoun, 'pnoun', directory, length, attention, head)


In [None]:
# Re-declares the dataset name used to locate the saved attention maps.
dataset_name = 'stsb'

In [None]:
# Root folder that holds one sub-folder of saved attention rows per data point.
main_path = '../gpt2_small/verb_subject/' + dataset_name + '/'

In [None]:
# One entry per data-point folder. Keep directories only, so stray files
# (macOS '.DS_Store' or anything else) are never mistaken for a data point.
# Sorted for a stable order.
all_points = sorted(
    entry for entry in os.listdir(main_path)
    if os.path.isdir(os.path.join(main_path, entry))
)


In [None]:
# Relation names to aggregate; each is the file-name prefix looked up in every data-point folder.
relations = ['noun', 'pnoun', 'verb', 'subj', 'obj', 'neg']

In [None]:
# Maps relation name -> (12, 12) array indexed [layer, head]. Each cell is the
# mean, over data points, of the summed saved attention row.
relation_dict = {}
for relation in relations:
    print(relation)
    # NaN marks (layer, head) cells for which no saved file exists.
    new_array = np.full((12, 12), np.nan)
    for layer in range(0, 12):
        for head in range(0, 12):
            check_file = relation + '_' + str(layer) + '_' + str(head) + '.npy'
            # Reset per (layer, head). Resetting only per layer made each
            # head's mean include the values of all earlier heads.
            temp = []
            for folder in all_points:
                file_path = os.path.join(main_path, folder, check_file)
                if os.path.isfile(file_path):
                    temp.append(np.load(file_path).sum())
            if temp:
                new_array[layer, head] = np.mean(np.array(temp))
    relation_dict[relation] = new_array
        
         

In [None]:
# Sanity check: the aggregated grid is allocated as (12, 12), indexed [layer, head].
relation_dict['noun'].shape

In [None]:
# Draw one layer-by-head heatmap per relation and save it as a PNG.
for key in relation_dict.keys():
    # Relations with any missing cell are skipped.
    if np.isnan(relation_dict[key]).any():
        continue

    fig, ax = plt.subplots()
    image = ax.imshow(relation_dict[key], cmap='YlGn')
    fig.colorbar(image, ax=ax)
    # Rows are layers and columns are heads (the grid is filled as [layer, head]).
    ax.set_title(key)
    ax.set_xlabel('head')
    ax.set_ylabel('layer')
    fig.savefig('../gpt2_small/verb_subject/' + dataset_name + '_' + key + '_finetuned' + '.png')

    plt.show()
    plt.close(fig)  # release the figure so they do not pile up across iterations
    print(key)