In [246]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from collections import Counter
import os
import numpy as np
import pandas as pd

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [247]:
from transformers import BertModel

In [248]:
class ProcessFiles():
    """Read sentence files and derive word-level group labels for each sentence.

    One label is produced per space-separated word, plus one for the sentence-final
    period:
        0 - other words (and the period)
        1 - the ``args1_length`` words right before ``split_to``
        2 - the ``split_to`` word itself
        3 - every word after ``split_to``
    """

    def __init__(self):
        pass

    def open_process(self, fileurl, args1_length, split_to, merge_type):
        """Return ``(sentences, labels)`` for every non-blank line of ``fileurl``.

        Args:
            fileurl: path of a text file with one sentence per line.
            args1_length: number of words immediately before ``split_to`` labelled 1.
            split_to: the word that splits the sentence (labelled 2).
            merge_type: accepted for interface compatibility; not used here.

        Returns:
            sentences: the stripped lines (str).
            labels: one label list per sentence (see the class docstring).

        Raises:
            ValueError: if a non-blank line does not contain ``split_to`` as a word.
        """
        with open(fileurl, "r", encoding="utf-8") as f:
            lines = [line.strip() for line in f]

        # Blank lines (e.g. a trailing newline at EOF) contain no split word and
        # used to crash _process_single_line, so skip them.
        processed = [
            self._process_single_line(line, args1_length, split_to, merge_type)
            for line in lines
            if line
        ]

        sentences = [sentence for sentence, _ in processed]
        labels = [label for _, label in processed]
        return sentences, labels

    def _process_single_line(self, line, args1_length, split_to, merge_type):
        """Label the words of one sentence around its split word.

        Returns:
            ``(line, labels_space)``: the unchanged line (str) and its labels (list).

        Raises:
            ValueError: if ``split_to`` does not occur as a space-separated word.
        """
        # Split only at the first occurrence so that a repeated split word stays
        # inside arg2 and the label count keeps matching the word count.
        parts = line.strip(".").split(f" {split_to} ", 1)
        if len(parts) != 2:
            raise ValueError(f"{split_to!r} not found as a separate word in line: {line!r}")

        before_comp = parts[0].split(" ")
        after_comp = parts[1].split(" ")

        # Explicit split index: with a bare [-args1_length:] slice, args1_length == 0
        # would select the whole prefix as arg1.
        split_at = max(len(before_comp) - args1_length, 0)
        component_a = before_comp[:split_at]
        component_arg1 = before_comp[split_at:]
        component_arg2 = after_comp  # the sentence ends with the second argument

        labels_space = (
            [0] * len(component_a)
            + [1] * len(component_arg1)
            + [2]  # the split word itself
            + [3] * len(component_arg2)
        )
        labels_space.append(0)  # the sentence-final period

        return line, labels_space # str, list

    

In [249]:

class Attention_dataset(Dataset):
    """Dataset of BERT inputs with a custom 2-D attention mask and per-token group labels.

    Each item is ``(inputs, label_tokenize_expand)``:
      * ``inputs`` is the tokenizer output (padded/truncated to ``max_length``). Its
        ``attention_mask`` is replaced by a ``max_length x max_length`` 0/1 mask built
        from the word-level labels (see ``_make_mask``).
      * ``label_tokenize_expand`` is a length-``max_length`` tensor with one group label
        per position: [CLS]/[SEP] -> 0, padding -> 4.

    Args:
        sentences: raw sentences.
        labels_space: one word-level label list per sentence (as built by ``ProcessFiles``).
        tokenizer: tokenizer exposing ``__call__`` and ``tokenize``. WordPiece continuation
            pieces are assumed to start with ``"##"``.
        merge_type: ``"front"``, ``"back"`` or ``"none"``.
        max_length: padded sequence length, including [CLS] and [SEP].

    Raises:
        ValueError: if ``merge_type`` is not one of the supported values.
    """

    # merge_type -> (labels whose attention rows are restricted,
    #                labels those restricted rows may attend to)
    _MASK_GROUPS = {
        "front": ({1, 2}, {1, 2}),
        "back": ({2, 3}, {2, 3}),
        "none": ({2}, {1, 2, 3}),
    }

    def __init__(self, sentences, labels_space, tokenizer, merge_type, max_length = 32):
        self._mask_groups(merge_type)  # fail fast on an unknown merge_type
        self.sentences = sentences
        self.labels_space = labels_space
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.merge_type = merge_type

    @classmethod
    def _mask_groups(cls, merge_type):
        """Return ``(row_groups, col_groups)`` for ``merge_type``; raise ValueError if unknown."""
        if merge_type not in cls._MASK_GROUPS:
            raise ValueError(
                f"merge_type must be one of {sorted(cls._MASK_GROUPS)}, got {merge_type!r}"
            )
        return cls._MASK_GROUPS[merge_type]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        sentence = self.sentences[index]
        label_space = self.labels_space[index]

        inputs = self.tokenizer(sentence, padding = "max_length", max_length = self.max_length, truncation = True, return_tensors = "pt")
        label_tokenize, label_tokenize_expand = self._get_labels_tokenize(sentence, label_space)
        mask = self._make_mask(label_tokenize)

        # Replace the tokenizer's 1-D padding mask with the custom (max_length x max_length) mask.
        inputs['attention_mask'] = torch.tensor(mask, requires_grad=False)
        inputs = {key: value.requires_grad_(False) for key, value in inputs.items()}
        label_tokenize_expand = torch.tensor(label_tokenize_expand, requires_grad=False)

        return inputs, label_tokenize_expand

    def _get_labels_tokenize(self, sentence, label_space):
        """Project word-level labels onto WordPiece tokens.

        Returns:
            label_tokenize: one label per kept token.
            label_tokenize_expand: ``[0] + label_tokenize + [0] + [4] * pad``, length ``max_length``.
        """
        # Keep only tokens that survive the tokenizer's truncation (max_length minus
        # [CLS] and [SEP]). Otherwise labels and mask outgrow max_length and batching fails.
        tokens = self.tokenizer.tokenize(sentence)[: max(self.max_length - 2, 0)]

        # [a ##a ##a b] -> word indices [0, 0, 0, 1]: a "##" piece continues the previous word.
        # NOTE(review): every non-"##" token is assumed to start a new space-separated
        # word (or be the trailing period). Punctuation inside the sentence would shift
        # the alignment; confirm the dataset contains none.
        word_index = []
        current_word = -1
        for token in tokens:
            if not token.startswith("##"):
                current_word += 1
            word_index.append(current_word)

        label_tokenize = [label_space[i] for i in word_index]
        n_pad = self.max_length - (len(label_tokenize) + 2)
        label_tokenize_expand = [0] + label_tokenize + [0] + [4] * n_pad

        return label_tokenize, label_tokenize_expand

    def _make_mask(self, label_tokenize):
        """Build the ``max_length x max_length`` 0/1 attention mask for one sentence.

        Row i is position i of ``[CLS] + tokens + [SEP]`` (then padding).
        * A row whose label is in the merge_type's row groups may attend only to tokens
          whose label is in its column groups. It may not attend to [CLS] or [SEP].
        * Every other real row attends to all real positions.
        * Padding rows and columns are 0.
        """
        row_groups, col_groups = self._mask_groups(self.merge_type)
        n_real = len(label_tokenize) + 2  # tokens plus [CLS] and [SEP]
        padding = [0] * (self.max_length - n_real)

        full_row = [1] * n_real + padding
        restricted_row = [0] + [1 if label in col_groups else 0 for label in label_tokenize] + [0] + padding

        mask = [
            list(restricted_row) if label in row_groups else list(full_row)
            for label in [0] + label_tokenize + [0]
        ]
        # Padding query rows attend to nothing.
        mask.extend([0] * self.max_length for _ in range(self.max_length - len(mask)))

        return mask



class Collator():
    """Collate ``(inputs, labels)`` items from ``Attention_dataset`` into one batch.

    Args:
        max_length: padded sequence length (stored for reference).
        device: where batch tensors are placed. When omitted, uses "cuda" if torch
            sees a GPU, else "cpu", so it no longer depends on a notebook global.
    """

    def __init__(self, max_length, device = None):
        self.max_length = max_length
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def __call__(self, samples):
        """Stack a list of ``(inputs, labels)`` pairs.

        Returns:
            ``({"input_ids": ..., "attention_mask": ...}, labels)``, each tensor with a
            leading batch dimension and moved to ``self.device``.
        """
        data = [sample[0] for sample in samples]
        labels = [sample[1] for sample in samples]

        batch_inputs = {
            # The tokenizer returns input_ids with a leading dim of 1; flatten before stacking.
            "input_ids": torch.stack([item["input_ids"].reshape(-1) for item in data]).to(self.device),
            "attention_mask": torch.stack([item["attention_mask"] for item in data]).to(self.device),
        }
        return batch_inputs, torch.stack(labels).to(self.device)

In [250]:
1 in (2, 0)  # scratch check: is 1 equal to 2 or 0? (False); not used elsewhere

False

In [251]:
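# Statistical analysis of attention scores.
# Aggregate token-level attention into a 4x4 matrix over the groups arg1, arg2, the split
# word (e.g. "and"), and other tokens: sum over the group receiving attention and average
# over the group sending it.
# NOTE(review): Postprocessor.process currently sums over both groups; confirm which is intended.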


            


class Postprocessor():
    """Aggregate token-level attention maps into group-by-group attention matrices.

    Tokens are grouped by their label in ``labels_tokenize``. For every sentence, layer
    and head, entry (s, a) is the attention summed over all pairs of a sending token in
    group s and a receiving token in group a.

    NOTE(review): the cell comment above says the sending group should be averaged;
    like the original loop, this sums over both axes. Confirm which is intended.

    Args:
        max_length: padded sequence length (stored for reference).
        device: kept for backward compatibility. Helper tensors are created on the
            attention tensors' own device, so it is not needed for correctness.
    """

    def __init__(self, max_length, device = None):
        self.max_length = max_length
        self.device = device

    def process(self, whole_attention, labels_tokenize):
        """Return grouped attention of shape (batch, layers, heads, n_groups, n_groups).

        Args:
            whole_attention: sequence of per-layer attention tensors, each indexed as
                (batch, heads, seq_len, seq_len), e.g. ``model(...).attentions``.
            labels_tokenize: (batch, seq_len) integer group label per token.

        ``n_groups`` is the largest label in the batch plus one. It is shared by all
        sentences, so a sentence lacking some label (e.g. no padding) gets a zero
        row/column instead of a narrower matrix that cannot be stacked.
        """
        attention = torch.stack(tuple(whole_attention))  # layers, batch, heads, seq, seq
        labels = labels_tokenize.to(attention.device)
        n_groups = int(labels.max().item()) + 1 if labels.numel() else 0

        groups = torch.arange(n_groups, device=attention.device)
        # one_hot[b, r, g] == 1 iff token r of sentence b carries label g
        one_hot = (labels.unsqueeze(-1) == groups).to(attention.dtype)

        # Sum attention over sender tokens r in group s and receiver tokens c in group a.
        return torch.einsum("brs, nbhrc, bca -> bnhsa", one_hot, attention, one_hot)
            



In [252]:

def main(*, fileurl, args1_length, split_to, bertModel, bertTokenizer, max_length, merge_type):
    """Compute the mean group-to-group attention for one sentence file.

    Args:
        fileurl: path of the sentence file.
        args1_length, split_to, merge_type: forwarded to ``ProcessFiles`` / ``Attention_dataset``.
        bertModel, bertTokenizer: Hugging Face model and tokenizer to probe.
        max_length: padded sequence length.

    Returns:
        ``(file_type, values)``:
        * ``file_type`` is the file's base name up to the first ".".
        * ``values`` is a list of 9 numbers: the attention between label groups 1..3
          (arg1, split word, arg2), flattened row-major and rounded to 3 decimals.
        * It is averaged over every sentence in the file (all batches), all layers
          and all heads.

    Raises:
        ValueError: if the file yields no sentences.
    """
    bertModel.eval()
    bertModel.to(device)

    processfiles = ProcessFiles()
    collator = Collator(max_length)
    postprocessor = Postprocessor(max_length)

    sentences, labels_space = processfiles.open_process(fileurl, args1_length, split_to, merge_type)
    dataset = Attention_dataset(sentences, labels_space, bertTokenizer, merge_type, max_length)
    dataloader = DataLoader(dataset, batch_size = 128, shuffle = False, collate_fn=collator)

    # Accumulate over ALL batches. Previously each batch overwrote the result, so only
    # the last batch (at most 128 sentences) was reported.
    group_sum = None
    n_sentences = 0
    with torch.no_grad():  # inference only: do not build an autograd graph
        for inputs, labels in dataloader:
            out = bertModel(**inputs, output_attentions = True)
            grouped = postprocessor.process(out.attentions, labels)  # batch, layers, heads, G, G
            # Per-sentence mean over layers and heads, restricted to groups 1..3.
            batch_sum = grouped.mean(dim=(1, 2))[:, 1:4, 1:4].sum(dim=0)
            group_sum = batch_sum if group_sum is None else group_sum + batch_sum
            n_sentences += grouped.size(0)

    if n_sentences == 0:
        raise ValueError(f"no sentences found in {fileurl}")

    mean_attention = torch.round(group_sum / n_sentences, decimals=3)
    values = list(mean_attention.reshape(-1).cpu().numpy())

    file_type = os.path.basename(fileurl).split(".")[0]

    return file_type, values

    
        

In [256]:
if __name__ == "__main__":
    # Look up group-level attention scores of every semantic-dataset file,
    # once for the baseline BERT and once for the dependency-trained BERT.
    base = "/home/hyohyeongjang/dependency_bert/semantic_dataset/"
    result_dir = "/home/hyohyeongjang/dependency_bert/semantics_result/"

    fileurl = [base+i for i in sorted(os.listdir(base))]
    # `category` and `args1_length` are matched to files purely by position, so they
    # must stay in the same order as the sorted file listing above.
    category = [("every", "front", "and"),
                ("every", "front", "or"),
                ("less than 5", "front", "and"),
                ("less than 5", "front", "or"),
                ("more than 5", "front", "and"),
                ("more than 5", "front", "or"),
                ("every", "back", "and"),
                ("less than 5", "back", "and"),
                ("more than 5", "back", "and"),
                ("no", "back", "and"),
                ("every", "back", "or"),
                ("less than 5", "back", "or"),
                ("more than 5", "back", "or"),
                ("no", "back", "or"),
                ("no", "front", "and"),
                ("no", "front", "or"),
                ]

    args1_length = [2,2,4,4,4,4,1,1,1,1,1,1,1,1,2,2]
    # zip() below would silently drop files (or categories) on a length mismatch.
    assert len(fileurl) == len(category) == len(args1_length), (
        f"{len(fileurl)} files vs {len(category)} categories vs {len(args1_length)} args1 lengths"
    )
    # The split word (e.g. "and"/"or") is the second-to-last "_"-separated part of each path.
    split_to = [i.split("_")[-2] for i in fileurl]
    model_types = ["normal", "dependency"]
    model_checkpoints = ["google-bert/bert-base-uncased", "/home/hyohyeongjang/dependency_bert/checkpoint_output_dependency/checkpoint-dependency-high"]

    for model_type, model_checkpoint in zip(model_types, model_checkpoints):
        # Load the tokenizer/model for THIS checkpoint. Previously a single model was
        # loaded before the loop and reused, so the "normal" results were actually
        # computed with the dependency checkpoint.
        bertTokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        bertModel = AutoModel.from_pretrained(model_checkpoint)

        for merge_type, m in zip(["front", "back", "none"], ["f", "b", "n"]):
            outs = []
            files = []
            for f, l, s in zip(fileurl, args1_length, split_to):
                file, out = main(fileurl=f,
                    args1_length=l,
                    split_to=s,
                    bertModel=bertModel,
                    bertTokenizer=bertTokenizer,
                    max_length=32,
                    merge_type=merge_type)
                outs.append(out)
                files.append(file)

            category_this = pd.DataFrame(
                [[model_type] + list(i) for i in category],
                columns=["type", "quant", "frontness", "comple"],
            )
            df = pd.DataFrame(outs, columns=[f"{m}{i}" for i in range(1, 10)])
            df = pd.concat([category_this, df], axis = 1)

            df.to_csv(f"{result_dir}df_{model_type}_{merge_type}_FullofWhole.csv")
            print(f"done_{model_type}_{merge_type}")
        
        

Some weights of BertModel were not initialized from the model checkpoint at /home/hyohyeongjang/dependency_bert/checkpoint_output_dependency/checkpoint-dependency-high and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


done_normal_front
done_normal_back
done_normal_none
done_dependency_front
done_dependency_back
done_dependency_none


In [22]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8
every_and_n,0.114,0.1,0.099,0.1,0.095,0.125,0.106,0.099,0.111
every_or_n,0.113,0.108,0.097,0.101,0.102,0.124,0.105,0.107,0.11
lessthan5_and_n,0.095,0.073,0.08,0.086,0.078,0.107,0.092,0.076,0.09
lessthan5_or_n,0.093,0.079,0.079,0.086,0.084,0.108,0.091,0.082,0.089
morethan5_and_n,0.092,0.073,0.076,0.083,0.078,0.101,0.09,0.077,0.087
morethan5_or_n,0.091,0.079,0.075,0.082,0.085,0.103,0.089,0.083,0.086
n_and_every,0.105,0.118,0.101,0.1,0.098,0.107,0.095,0.096,0.114
n_and_lessthan5,0.088,0.092,0.087,0.085,0.077,0.087,0.081,0.073,0.093
n_and_morethan5,0.085,0.09,0.082,0.083,0.078,0.084,0.079,0.073,0.089
n_and_no,0.111,0.116,0.096,0.104,0.096,0.105,0.1,0.095,0.109


In [11]:
4  # leftover scratch cell: only echoes a literal; safe to delete

4