## BERT and TOKENIZER 

In [27]:
from transformers import AutoModel,AutoTokenizer

auto_model = AutoModel.from_pretrained("bert-base-cased",output_hidden_states=True)



print(f"\nmodel class is      : {type(auto_model)}")


tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
print(f"\nmodel class is      : {type(tokenizer)}")


torch.save(auto_model,"hw2/stud/saved/bert.pth")



Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



model class is      : <class 'transformers.models.bert.modeling_bert.BertModel'>

model class is      : <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>


## Semantic Role Labelling Dataset

In [33]:
import os
import json
import logging
import torch
from torch.utils.data import DataLoader,Dataset
import random
from typing import Dict

class SRL(Dataset):
 
    def __init__(self,language,tokenizer,path,args_roles = None,pos_list = None,predicate_dis = None) -> None:
        #train
        #self.path_root = 'data'
        #inference 
        self.path_root = 'hw2/stud/data'
        #self.path_root = 'stud/data'
        self.load_data(language,path)
        if args_roles is None :
            self.args_roles,self.list_broken_id = self.list_arg_roles()
            self.args_roles.append("UNK")
        else : 
            self.args_roles = args_roles
            _,self.list_broken_id = self.list_arg_roles()
        

        if pos_list is None :
            self.pos_list,_ = self.list_pos()
            self.pos_list.append("Nothing")
            self.pos_list.append("UNK")
        else : 
            self.pos_list = pos_list
        


        if predicate_dis is None :
            self.predicate_dis,_ = self.list_predicate_roles()
            self.predicate_dis.append("Nothing")
            self.predicate_dis.append("UNK")
        else : 
            self.predicate_dis = predicate_dis
        
        
        


        self.tokenizer = tokenizer

    def load_data(self,language,mode):
        
        mode = mode+".json"
        path = os.path.join(self.path_root,language,mode)
        data_file = open(path)
       
        data_ = json.load(data_file)

        list_data = []

        for data in data_:
            list_data.append(data_[data])
        

        self.data = list_data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, id : int):

        flag = False
        if id in self.list_broken_id :
            flag = True
            while flag == True:

                rand_id = random.randint(0, len(self.data)-1)
                
                if rand_id in self.list_broken_id :
                    pass
                else :
                    flag = False
                    id = rand_id        


        data = self.pre_processing(self.data[id])
        data = self.processig(data)
        return data
        
    def pre_processing(self, data:dict):
        data_list = []
        for role in data["roles"]:
            dictionary = dict()
            dictionary["words"] = data["words"]
            dictionary["role"] = data["roles"][role]
            dictionary["pre_idx"] = role
            dictionary["pos_tags"] = data["pos_tags"]
            dictionary["predicate_meaning"] = data["predicates"]
            data_list.append(dictionary)    
        return data_list
    
    def processig(self,data_list:list):
        
        for dictionary in data_list:

            #dictionary["words"] = data["words"]
            dictionary["gt_arg_identification"] = self.arg_id(dictionary["role"])
            dictionary["gt_arg_classification"] = self.arg_class(dictionary["role"])
            dictionary["pos_idx"] = self.pos_idx(dictionary["pos_tags"])
            dictionary["predicate_meaning_idx"] = self.predicate_meaning_idx(dictionary["predicate_meaning"])
        
        return data_list
   
    def list_arg_roles(self):
        list_roles = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : roles = element["roles"]
            except : flag = False
            if flag :
                for e in roles:
                    sentence = element["roles"][e]

                    for word in sentence:
                        
                        list_roles.append(word)
                list_roles = list(set(list_roles))
            else : 
                list_broken_id.append(i)
        return list_roles,list_broken_id

    def list_predicate_roles(self):
        list_predicate_roles = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : predicates = element["predicates"]
            except : flag = False
            if flag :
                for pre in predicates:
                    list_predicate_roles.append(pre)
                list_predicate_roles = list(set(list_predicate_roles))
            else : 
                list_broken_id.append(i)
        return list_predicate_roles,list_broken_id

    def list_pos(self):
        list_pos = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : pos = element["pos_tags"]
            except : flag = False
            if flag :
                for e in pos:
                    list_pos.append(e)
                list_pos = list(set(list_pos))
            else : 
                list_broken_id.append(i)
        return list_pos,list_broken_id
  
    def arg_class(self,role:list):
        list_idxs = []
        for element in role:
            try : list_idxs.append(self.args_roles.index(element))
            except : list_idxs.append(self.args_roles.index("UNK"))
        

        return torch.tensor(list_idxs, dtype=torch.int64)

    def arg_id(self,role:dict):
        list_idxs = []
        for element in role:
            if element == "_":
                list_idxs.append(0)
            else :
                list_idxs.append(1)

        

        return torch.tensor(list_idxs, dtype=torch.int64)

    def pos_idx(self,pos_tags:dict):
        list_idxs = []
        list_idxs.append(self.pos_list.index("Nothing"))

        for element in pos_tags:
            try :list_idxs.append(self.pos_list.index(element))
            except :list_idxs.append(self.pos_list.index("UNK"))
        
        list_idxs.append(self.pos_list.index("Nothing"))
        return torch.tensor(list_idxs, dtype=torch.int64)
    
    def predicate_meaning_idx(self,predicate_meaning_tags:dict):
        list_idxs = []
        list_idxs.append(self.predicate_dis.index("Nothing"))

        for element in predicate_meaning_tags:
            try : list_idxs.append(self.predicate_dis.index(element))
            except : list_idxs.append(self.predicate_dis.index("UNK"))
            
        
        list_idxs.append(self.predicate_dis.index("Nothing"))
        return torch.tensor(list_idxs, dtype=torch.int64)
   
    def role_gen(self,sentence):

        base = ["_"]*len(sentence["predicates"])
        roles_dict = dict()
        counter = 0
        for i,item in enumerate(sentence["predicates"]):

            if item != "_":
                base = ["_"]*len(sentence["predicates"])
                sentence["roles"] = 10
                roles_dict[str(i)] = base
                counter += 1
        
        if counter == 0:
            sentence["roles"] = { }
            flag = False
            
                
        else :
            sentence["roles"] = roles_dict
            flag = True

        return sentence,flag
        
    def prepare_batch(self,sentence):

        sentence,flag = self.role_gen(sentence)
        
        if flag :

            data = self.pre_processing(sentence)
            data = self.processig(data)
            data = [data]
            
            
            input = dict() 
            gt = dict()
            batch_sentence = [] 
            
            for period in data:
                for sentence in period :

                    
                
                    #print(len(sentence[0]["words"]))
                    pre_idx = int(sentence["pre_idx"])
                    

                    predicate = sentence["words"][pre_idx]

                    text = " ".join(sentence["words"])
                    tokens: list[str] = text.split()
                    predicate: list[str] = predicate.split()

                    #text = sentence[0]["words"]
                    
                    t = (tokens,predicate)

                    batch_sentence.append(t)
                
                
            
        
        

            batch_output = self.tokenizer.batch_encode_plus(batch_sentence,padding=True,is_split_into_words=True, truncation=True,return_offsets_mapping=True, return_tensors="pt")
            


            for period in data:

                list_positional_predicate_encoding = []
                list_predicate_index = [] 
                list_pos_index = [] 
                list_arg_gt = []
                list_predicate_meaning_index = []
                list_meaning_predicate_encoding = []

                for sentence in period:
                    #positional_encoding
                    #+2 per il CLS iniziale ad SEP finale
                    sentence_words_lenght =  len(sentence["words"])
                    positional_predicate_encoding = torch.zeros(1,sentence_words_lenght+2)
                    #+1 per il CLS iniziale
                    pre_idx = int(sentence["pre_idx"])
                    positional_predicate_encoding[:,pre_idx+1] = 1
                    list_positional_predicate_encoding.append(positional_predicate_encoding)
                    #print("positional_prefix_encoding",positional_predicate_encoding)
                    list_predicate_index.append(pre_idx)

                    meaning_predicate_encoding = torch.zeros(1,sentence_words_lenght+2)
                    pre_idx = int(sentence["pre_idx"])
                    #rather then set the flag 0,1 set with class verb
                    meaning_predicate_encoding[:,pre_idx+1] = sentence["predicate_meaning_idx"][pre_idx+1]
                    list_meaning_predicate_encoding.append(meaning_predicate_encoding)
                    

                    pos = torch.unsqueeze(sentence["pos_idx"],dim = 0)
                    list_pos_index.append(pos)
                    predicate_meaning_idxs = torch.unsqueeze(sentence["predicate_meaning_idx"],dim = 0)
                    list_predicate_meaning_index.append(predicate_meaning_idxs)


                    arg_gt = torch.unsqueeze(sentence["gt_arg_classification"],dim = 0)
                    list_arg_gt.append(arg_gt)


            list_arg_gt = torch.cat(list_arg_gt,dim = 0)
            list_pos_index = torch.cat(list_pos_index,dim = 0)
            list_predicate_meaning_index = torch.cat(list_predicate_meaning_index,dim = 0)
            list_positional_predicate_encoding = torch.cat(list_positional_predicate_encoding,dim = 0)
            list_predicate_meaning_index_bis = torch.cat(list_meaning_predicate_encoding,dim = 0)
            gt["arg_gt"] = list_arg_gt
            input["predicate_index"] = list_predicate_index
            input["pos_index"] = list_pos_index.long()
            input["predicate_meaning_idx"] = list_predicate_meaning_index.long()
            input["predicate_meaning_idx_bis"] = list_predicate_meaning_index_bis.long()
            offset = batch_output.pop("offset_mapping")
            input["BERT_input"] = batch_output
            input["positional_encoding"] = list_positional_predicate_encoding.long()
            input["offset_mapping"] = offset
            input["gt"] = gt
        
        else :
            input = sentence






        return input,flag
    
# here we define our collate function
def collate_fn(batch) -> Dict[str, torch.Tensor]:
    #print(batch)
    input = dict() 
    batch_sentence = [] 
    #print(len(batch))
    for period in batch:
        for sentence in period :
        
            #print(len(sentence[0]["words"]))
            pre_idx = int(sentence["pre_idx"])
            

            predicate = sentence["words"][pre_idx]

            text = " ".join(sentence["words"])
            tokens: list[str] = text.split()
            predicate: list[str] = predicate.split()

            #text = sentence[0]["words"]
            
            t = (tokens,predicate)

            batch_sentence.append(t)
            #print(batch_sentence)

    batch_output = tokenizer.batch_encode_plus(batch_sentence,padding=True,is_split_into_words=True, truncation=True,return_offsets_mapping=True, return_tensors="pt")
    #print(batch_output.keys())


    gt = dict()
    
    


    for period in batch:

        list_positional_predicate_encoding = []
        list_arg_gt = []
        list_predicate_index = [] 
        list_pos_index = [] 
        list_predicate_meaning_index = []
        list_meaning_predicate_encoding = []

        for sentence in period:
            #positional_encoding
            #+2 per il CLS iniziale ad SEP finale
            sentence_words_lenght =  len(sentence["words"])
            positional_predicate_encoding = torch.zeros(1,sentence_words_lenght+2)
            #+1 per il CLS iniziale
            pre_idx = int(sentence["pre_idx"])
            positional_predicate_encoding[:,pre_idx+1] = 1
            list_positional_predicate_encoding.append(positional_predicate_encoding)
            #print("positional_prefix_encoding",positional_predicate_encoding)
            list_predicate_index.append(pre_idx)


            meaning_predicate_encoding = torch.zeros(1,sentence_words_lenght+2)
            pre_idx = int(sentence["pre_idx"])
            #rather then set the flag 0,1 set with class verb
            meaning_predicate_encoding[:,pre_idx+1] = sentence["predicate_meaning_idx"][pre_idx+1]
            list_meaning_predicate_encoding.append(meaning_predicate_encoding)




            pos = torch.unsqueeze(sentence["pos_idx"],dim = 0)
            list_pos_index.append(pos)
            predicate_meaning_idxs = torch.unsqueeze(sentence["predicate_meaning_idx"],dim = 0)
            list_predicate_meaning_index.append(predicate_meaning_idxs)






            #note CLS and SEP are discharder after Bi-LSTM, the Classifier takes in input only wokrds hidden state embedding
            arg_gt = torch.unsqueeze(sentence["gt_arg_classification"],dim = 0)
            list_arg_gt.append(arg_gt)
        

    list_arg_gt = torch.cat(list_arg_gt,dim = 0)
    list_pos_index = torch.cat(list_pos_index,dim = 0)
    list_predicate_meaning_index = torch.cat(list_predicate_meaning_index,dim = 0)
    list_predicate_meaning_index_bis = torch.cat(list_meaning_predicate_encoding,dim = 0)
    list_positional_predicate_encoding = torch.cat(list_positional_predicate_encoding,dim = 0)
    gt["arg_gt"] = list_arg_gt
    input["predicate_index"] = list_predicate_index
    input["pos_index"] = list_pos_index.long()
    input["predicate_meaning_idx"] = list_predicate_meaning_index.long()
    input["predicate_meaning_idx_bis"] = list_predicate_meaning_index_bis.long()
    offset = batch_output.pop("offset_mapping")
    input["BERT_input"] = batch_output
    input["positional_encoding"] = list_positional_predicate_encoding.long()
    input["offset_mapping"] = offset
    input["gt"] = gt

   


    return input



In [35]:
train_dataset = SRL("EN",tokenizer,"train")
#same mapping should be used in both the dataset
dev_dataset = SRL("EN",tokenizer,"dev",train_dataset.args_roles,train_dataset.pos_list,train_dataset.predicate_dis)



In [52]:
#print(train_dataset.args_roles,train_dataset.pos_list,train_dataset.predicate_dis)



import json

data = {
    'args_roles' : train_dataset.args_roles,
    "pos_list" : train_dataset.pos_list,
    "predicate_dis" : train_dataset.predicate_dis,
}





with open('hw2/stud/saved/vocabulary.json', 'w') as outfile:
    json.dump(data, outfile)


with open('hw2/stud/saved/vocabulary.json') as json_file:
    data = json.load(json_file)

print(data)
pos_list = data['pos_list']
args_roles = data['args_roles']
predicate_dis = data['predicate_dis']
print(pos_list)
   

{'args_roles': ['result', 'goal', 'co-agent', 'purpose', 'stimulus', 'co-patient', 'attribute', 'time', 'destination', 'agent', 'asset', '_', 'location', 'cause', 'product', 'experiencer', 'topic', 'source', 'extent', 'theme', 'beneficiary', 'patient', 'value', 'co-theme', 'recipient', 'instrument', 'material', 'UNK'], 'pos_list': ['ADV', 'VERB', 'NOUN', 'PRON', 'INTJ', 'ADJ', 'ADP', 'PUNCT', 'DET', 'PART', 'NUM', 'SCONJ', 'CCONJ', 'AUX', 'X', 'PROPN', 'SYM', 'Nothing', 'UNK'], 'predicate_dis': ['VISIT', 'DISCARD', 'WAIT', 'JUSTIFY_EXCUSE', 'HURT_HARM_ACHE', 'RECORD', 'RELY', 'OPPOSE_REBEL_DISSENT', 'LIBERATE_ALLOW_AFFORD', 'INCREASE_ENLARGE_MULTIPLY', 'OBTAIN', 'HIRE', 'FOLLOW-IN-SPACE', 'APPEAR', 'SPEED-UP', 'BUY', 'TAKE-A-SERVICE_RENT', 'HIT', 'VERIFY', 'ALLY_ASSOCIATE_MARRY', 'SORT_CLASSIFY_ARRANGE', 'REPEAT', 'IMAGINE', 'NAME', 'RESIST', 'ATTEND', '_', 'RESULT_CONSEQUENCE', 'REPAIR_REMEDY', 'DESTROY', 'LEARN', 'EXTEND', 'DISCUSS', 'LEAVE_DEPART_RUN-AWAY', 'TRY', 'CONSIDER', 'ADJUS

## Model

In [36]:


embeddings = dict()

embeddings["predicate_flag_embedding_output_dim"] = 32
#defined in initial exploration of the dataset
embeddings["pos_embedding_input_dim"] = 0
embeddings["pos_embedding_output_dim"] = 100
#-------------------------------------------------
embeddings["predicate_embedding_input_dim"] = 0
embeddings["predicate_embedding_output_dim"] = False
#defined in initial exploration of the dataset
n_classes = 0



bilstm = dict()
bilstm["n_layers"] = 2
bilstm["output_dim"] = 50
dropouts = [0.4,0.3,0.3]

language_portable = True
predicate_meaning = True
pos = True

cfg = dict()
cfg["embeddings"] = embeddings
cfg["n_classes"] = n_classes
cfg["bilstm"] = bilstm
cfg["language_portable"] = language_portable
cfg["dropouts"] = dropouts

In [37]:
#from mmcv import Config
from hw2.stud.arg import Arg_Classifier, Arg_Classifier_from_paper



#cfg = Config.fromfile('/home/francesco/Desktop/nlp2022-hw2-main/hw2/stud/configs/model.py')

cfg["embeddings"]["pos_embedding_input_dim"] = len(train_dataset.pos_list)
cfg["embeddings"]["predicate_embedding_input_dim"] = len(train_dataset.predicate_dis)
cfg["n_classes"] = len(train_dataset.args_roles)


model = Arg_Classifier("EN",cfg).cuda()
#model = Arg_Classifier_from_paper("EN",cfg).cuda()
print(model)

automodel = auto_model.cuda()

Arg_Classifier(
  (bi_lstm_portable): LSTM(132, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (embedding_predicate_flag): Embedding(2, 32, max_norm=True)
  (embedding_predicate): Embedding(305, False, max_norm=True)
  (embedding_pos): Embedding(19, 100, max_norm=True)
  (bi_lstm): LSTM(900, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (dropout_language_constraint): Dropout(p=0.6, inplace=False)
  (dropout_in_classifier): Dropout(p=0.4, inplace=False)
  (Relu): ReLU()
  (Sigmoid): Sigmoid()
  (linear0): Linear(in_features=300, out_features=700, bias=True)
  (linear1): Linear(in_features=700, out_features=140, bias=True)
  (linear2): Linear(in_features=140, out_features=28, bias=True)
)


In [38]:

def metrics(gold,pred):
    true_positives, false_positives, false_negatives = 0, 0, 0
    null_tag = "_"

    for i,r_g in  enumerate(gold):
        r_p = pred[i]

        if r_g != null_tag and r_p != null_tag:
            true_positives += 1
        elif r_g != null_tag and r_p == null_tag:
            false_negatives += 1
        elif r_g == null_tag and r_p != null_tag:
            false_positives += 1

    a = true_positives + false_positives
    b = true_positives + false_negatives
    if a == 0 and b == 0 :        
        argument_identification = {
            "true_positives": 0,
            "false_positives": 0,
            "false_negatives": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
        } 

    else : 
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
        f1 = 2 * (precision * recall) / (precision + recall)
        argument_identification = {
            "true_positives": true_positives,
            "false_positives": false_positives,
            "false_negatives": false_negatives,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }




    true_positives, false_positives, false_negatives = 0, 0, 0
    for i,r_g in  enumerate(gold):
        r_p = pred[i]

        if r_g != null_tag and r_p != null_tag:
            if r_g == r_p:
                true_positives += 1
            else:
                false_positives += 1
                false_negatives += 1
        elif r_g != null_tag and r_p == null_tag:
                false_negatives += 1
        elif r_g == null_tag and r_p != null_tag:
                false_positives += 1


    a = true_positives + false_positives
    b = true_positives + false_negatives
    if a == 0 and b == 0 :
        argument_classification = {
            "true_positives": 0,
            "false_positives": 0,
            "false_negatives": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
        } 

    else : 
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
        f1 = 2 * (precision * recall) / (precision + recall)
        argument_classification = {
            "true_positives": true_positives,
            "false_positives": false_positives,
            "false_negatives": false_negatives,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }


    return argument_identification,argument_classification




"""
from [1,1,8,8,8,8,8,8,8,8,8,8,8,8,2,2,2,8,8,8......]
to from [agent,agent,_,_,_........,]
"""
def mapping_args(g,p,mapping):
    
    
    gt = [mapping[elem] for elem in g]
    predictions = [mapping[elem] for elem in p]


    return gt,predictions



## Training Argument Identification and Classification

In [40]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from sklearn.metrics import f1_score,confusion_matrix
from torch.optim.lr_scheduler import ExponentialLR
from datetime import datetime



currentDateAndTime = datetime.now()
_id = str(currentDateAndTime.year)+"_"+str(currentDateAndTime.month)+"_"+str(currentDateAndTime.day)+"_"+str(currentDateAndTime.hour)+"_"+str(currentDateAndTime.minute)+"_"+str(currentDateAndTime.second)



#optimizer = torch.optim.Adam(model.parameters(),lr = 0.000005)
optimizer = torch.optim.Adam(model.parameters())


scheduler = ExponentialLR(optimizer, gamma=0.9)


logSotfMax = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()

dataloader_train = DataLoader(train_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

dataloader_dev = DataLoader(dev_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

mapping = dataloader_train.dataset.args_roles

auto_model.eval()

EPOCHS = 200
patience_counter = 0
patience = 5
max_val_loss = 9999
f1_score_max = 0
output_path = "hw2/stud/saved"
model_name = "model_"+_id+".pth"
PATH = os.path.join(output_path,model_name)


for epoch in range(EPOCHS):

    #TRAINING
    p = []
    g = []
    model.train()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_train):
        #print(sample_batched)

        
        optimizer.zero_grad()
       
        #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
        #------------------FILTERING SUB-WORDS----------------------
        subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
        word_emebedding = []
        for i in range(n):
            subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
            flag = subtoken_mask[0,i,0]
            if flag :
                continue
            else :
                word_emebedding.append(subwords_embedding)
        word_emebedding = torch.cat(word_emebedding,dim = 1)
        #-------------------------FORWARD/BACKWARD----------------------------------
        x = model.forward(subwords_embeddings = output_hidden_states_sum,
            perdicate_positional_encoding = sample_batched["positional_encoding"],
            predicate_index = sample_batched["predicate_index"],
            pos_index_encoding = sample_batched["pos_index"],
            predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])        
        b,n = sample_batched["gt"]["arg_gt"].size()
        loss = nll_loss(logSotfMax(x),gt)
        total_loss = total_loss + loss
        #print(loss)
        loss.backward()
        optimizer.step()
    

        counter += 1 
            

        #-------------------------RESULT STORING----------------------------------
        predicted = torch.argmax(x, dim=1)
        p += predicted.tolist()
        g += gt.tolist()
    

    #-------------------------RESULTS----------------------------------
    scheduler.step()

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    print("Epochs n.", epoch)
    print("F1 train:",f1)
    print("F1 avg train:",f1_avg)
    
    avg_train_loss = total_loss/counter
    writer.add_scalar("Loss/train", avg_train_loss, epoch)


    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Train_EN/identification", identification_result["f1"], epoch)
    writer.add_scalar("Train_EN/classification", classification_result["f1"], epoch)



    #EVALUATION
    p = []
    g = []
    model.eval()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_dev):
    
      #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
            #------------------FILTERING SUB-WORDS----------------------
            subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
            word_emebedding = []
            for i in range(n): 
                subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
                flag = subtoken_mask[0,i,0]
                if flag :
                    continue
                else :
                    word_emebedding.append(subwords_embedding)
            word_emebedding = torch.cat(word_emebedding,dim = 1)
            #-------------------------FORWARD----------------------------------
            x = model.forward(subwords_embeddings = output_hidden_states_sum,
                        perdicate_positional_encoding = sample_batched["positional_encoding"],
                        predicate_index = sample_batched["predicate_index"],
                        pos_index_encoding = sample_batched["pos_index"],
                        predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])   


            b,n = sample_batched["gt"]["arg_gt"].size()
            loss = nll_loss(logSotfMax(x),gt)
            total_loss = total_loss + loss
            #-------------------------RESULT STORING----------------------------------
            predicted = torch.argmax(x, dim=1)
            p += predicted.tolist()
            g += gt.tolist()
            counter += 1 
    
    #-------------------------RESULTS----------------------------------
    avg_eval_loss = total_loss/counter

    if avg_eval_loss < max_val_loss:
        max_val_loss = avg_eval_loss
    else :
        patience_counter += 1
    

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    if patience_counter >= patience :  


        print("Early stopping at epoch : ",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
        break
    else :
        print("EPOCHS :",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
    

    writer.add_scalar("Loss/validation", avg_eval_loss, epoch)

    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Eval_EN/identification", identification_result["f1"], epoch)
    writer.add_scalar("Eval_EN/classification", classification_result["f1"], epoch)




    if f1_avg > f1_score_max:
        f1_score_max = f1_avg
        print("SAVED :",PATH)
        torch.save(model.state_dict(),PATH)
    

Epochs n. 0
F1 train: [0.08643815 0.21511018 0.         0.         0.         0.
 0.         0.         0.         0.74044153 0.         0.98816612
 0.         0.         0.         0.         0.46160962 0.
 0.         0.52845967 0.         0.32482172 0.         0.
 0.28207307 0.         0.        ]
F1 avg train: 0.9601526783280693
identification {'true_positives': 16033, 'false_positives': 1728, 'false_negatives': 8855, 'precision': 0.9027081808456731, 'recall': 0.6442060430729669, 'f1': 0.7518581912823279}
classification_result {'true_positives': 10839, 'false_positives': 6922, 'false_negatives': 14049, 'precision': 0.6102696920218457, 'recall': 0.4355110896817743, 'f1': 0.5082885882435696}
EPOCHS : 0
F1 eval : [0.         0.2697201  0.         0.         0.         0.
 0.         0.         0.         0.75091697 0.         0.98994539
 0.         0.         0.         0.         0.50916497 0.
 0.         0.54124189 0.         0.27832512 0.         0.
 0.35744681 0.         0.        

## Language Transfert Learning 

### Fine-Tuning over Structural Information over English dataset


#### Loading English Pretrained Model 

In [16]:
#note that parameter EN is only used for tracking on which dataset was/is trained
model = Arg_Classifier("EN",cfg).cuda()
model.load_state_dict(torch.load(PATH))
model.eval()

Arg_Classifier(
  (bi_lstm_portable): LSTM(132, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (embedding_predicate_flag): Embedding(2, 32, max_norm=True)
  (embedding_predicate): Embedding(304, False, max_norm=True)
  (embedding_pos): Embedding(18, 100, max_norm=True)
  (bi_lstm): LSTM(900, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (dropout_language_constraint): Dropout(p=0.6, inplace=False)
  (dropout_in_classifier): Dropout(p=0.4, inplace=False)
  (Relu): ReLU()
  (Sigmoid): Sigmoid()
  (linear0): Linear(in_features=300, out_features=675, bias=True)
  (linear1): Linear(in_features=675, out_features=135, bias=True)
  (linear2): Linear(in_features=135, out_features=27, bias=True)
)

#### Language constained training

In [17]:
model.set_language_constrains()

#### Fine-Tuning

In [18]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from sklearn.metrics import f1_score,confusion_matrix
from torch.optim.lr_scheduler import ExponentialLR
from datetime import datetime



currentDateAndTime = datetime.now()
_id = str(currentDateAndTime.year)+"_"+str(currentDateAndTime.month)+"_"+str(currentDateAndTime.day)+"_"+str(currentDateAndTime.hour)+"_"+str(currentDateAndTime.minute)+"_"+str(currentDateAndTime.second)
_id =  _id +"Language constained training"



optimizer = torch.optim.Adam(model.parameters(),lr = 0.00000005)
scheduler = ExponentialLR(optimizer, gamma=0.9)


logSotfMax = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()

dataloader_train = DataLoader(train_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

dataloader_dev = DataLoader(dev_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

mapping = dataloader_train.dataset.args_roles

auto_model.eval()

EPOCHS = 1
patience_counter = 0
patience = 5
max_val_loss = 9999
f1_score_max = 0
output_path = "saved"
model_name = "model_"+_id+".pth"
PATH = os.path.join(output_path,model_name)


for epoch in range(EPOCHS):

    #TRAINING
    p = []
    g = []
    model.train()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_train):
        #print(sample_batched)

        
        optimizer.zero_grad()
       
        #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
        #------------------FILTERING SUB-WORDS----------------------
        subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
        word_emebedding = []
        for i in range(n):
            subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
            flag = subtoken_mask[0,i,0]
            if flag :
                continue
            else :
                word_emebedding.append(subwords_embedding)
        word_emebedding = torch.cat(word_emebedding,dim = 1)
        #-------------------------FORWARD/BACKWARD----------------------------------
        x = model.forward(subwords_embeddings = output_hidden_states_sum,
            perdicate_positional_encoding = sample_batched["positional_encoding"],
            predicate_index = sample_batched["predicate_index"],
            pos_index_encoding = sample_batched["pos_index"],
            predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])        
        b,n = sample_batched["gt"]["arg_gt"].size()
        loss = nll_loss(logSotfMax(x),gt)
        total_loss = total_loss + loss
        #print(loss)
        loss.backward()
        optimizer.step()
    

        counter += 1 
            

        #-------------------------RESULT STORING----------------------------------
        predicted = torch.argmax(x, dim=1).cpu()
        b,n = sample_batched["gt"]["arg_gt"].size()
        print(predicted.size())
        print(b,n)

        p += predicted.tolist()
        g += gt.tolist()
    

    #-------------------------RESULTS----------------------------------
    print("Epochs n.", epoch)
    print("F1 train:",f1_score(g, p, average=None))
    scheduler.step()
    avg_train_loss = total_loss/counter
    writer.add_scalar("Loss/train", avg_train_loss, epoch)



    #EVALUATION
    p = []
    g = []
    model.eval()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_dev):
    
      #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
            #------------------FILTERING SUB-WORDS----------------------
            subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
            word_emebedding = []
            for i in range(n): 
                subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
                flag = subtoken_mask[0,i,0]
                if flag :
                    continue
                else :
                    word_emebedding.append(subwords_embedding)
            word_emebedding = torch.cat(word_emebedding,dim = 1)
            #-------------------------FORWARD----------------------------------
            x = model.forward(subwords_embeddings = output_hidden_states_sum,
                        perdicate_positional_encoding = sample_batched["positional_encoding"],
                        predicate_index = sample_batched["predicate_index"],
                        pos_index_encoding = sample_batched["pos_index"],
                        predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])   


            b,n = sample_batched["gt"]["arg_gt"].size()
            loss = nll_loss(logSotfMax(x),gt)
            total_loss = total_loss + loss
            #-------------------------RESULT STORING----------------------------------
            predicted = torch.argmax(x, dim=1)
            p += predicted.tolist()
            g += gt.tolist()
            counter += 1 
    
    #-------------------------RESULTS----------------------------------

    avg_eval_loss = total_loss/counter

    if avg_eval_loss < max_val_loss:
        max_val_loss = avg_eval_loss
    else :
        patience_counter += 1
    

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    if patience_counter >= patience :  


        print("Early stopping at epoch : ",epoch)
        print("F1 eval :",f1)
        print("F1 eval :",f1_avg)
        break
    else :
        print("EPOCHS :",epoch)
        print("F1 eval :",f1)
        print("F1 eval :",f1_avg)
    

    writer.add_scalar("Loss/validation", avg_eval_loss, epoch)

    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Eval_EN/identification", identification_result["f1"], epoch)
    writer.add_scalar("Eval_EN/classification", classification_result["f1"], epoch)




    if f1_avg > f1_score_max:
        f1_score_max = f1_avg
        torch.save(model.state_dict(),PATH)
    



    





    


torch.Size([28])
2 14
torch.Size([320])
5 64
torch.Size([150])
5 30
torch.Size([329])
7 47
torch.Size([45])
1 45
torch.Size([34])
2 17
torch.Size([192])
3 64
torch.Size([69])
3 23
torch.Size([30])
2 15
torch.Size([68])
4 17
torch.Size([58])
1 58
torch.Size([17])
1 17
torch.Size([48])
2 24
torch.Size([116])
4 29
torch.Size([44])
2 22
torch.Size([105])
3 35
torch.Size([20])
1 20
torch.Size([99])
3 33
torch.Size([32])
2 16
torch.Size([45])
3 15
torch.Size([36])
2 18
torch.Size([84])
2 42
torch.Size([50])
2 25
torch.Size([46])
2 23
torch.Size([432])
6 72
torch.Size([250])
5 50
torch.Size([176])
4 44
torch.Size([78])
3 26
torch.Size([192])
4 48
torch.Size([72])
2 36
torch.Size([57])
3 19
torch.Size([190])
5 38
torch.Size([102])
3 34
torch.Size([26])
2 13
torch.Size([120])
3 40
torch.Size([108])
2 54
torch.Size([76])
2 38
torch.Size([180])
4 45
torch.Size([54])
2 27
torch.Size([66])
2 33
torch.Size([40])
2 20
torch.Size([102])
3 34
torch.Size([106])
2 53
torch.Size([26])
1 26
torch.Size([18]

KeyboardInterrupt: 

### Spanish

#### New Spanish Dataset
bert-base-multilingual-cased


In [14]:
from transformers import BertTokenizer, BertModel
import os
import json
import logging
import torch
from torch.utils.data import DataLoader,Dataset
import random
from typing import Dict



auto_model = AutoModel.from_pretrained("bert-base-multilingual-cased",output_hidden_states=True)
print(f"\nmodel class is      : {type(auto_model)}")

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
print(f"\nmodel class is      : {type(tokenizer)}")



class SRL(Dataset):
 
    def __init__(self,language,path,args_roles = None,pos_list = None) -> None:

        self.path_root = 'data'
        self.load_data(language,path)
        if args_roles is None :
            self.args_roles,self.list_broken_id = self.list_arg_roles()
        else : 
            self.args_roles = args_roles
            _,self.list_broken_id = self.list_arg_roles()
        

        if pos_list is None :
            self.pos_list,_ = self.list_pos()
            self.pos_list.append("Nothing")
        else : 
            self.pos_list = pos_list
        


        self.predicate_dis,_ = self.list_predicate_roles()
        
        
        
        print(self.pos_list)

        self.predicate_dis.append("Nothing")

    def load_data(self,language,mode):
        
        mode = mode+".json"
        path = os.path.join(self.path_root,language,mode)
        data_file = open(path)
       
        data_ = json.load(data_file)

        list_data = []

        for data in data_:
            list_data.append(data_[data])
        

        self.data = list_data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, id : int):

        flag = False
        if id in self.list_broken_id :
            flag = True
            while flag == True:

                rand_id = random.randint(0, len(self.data)-1)
                
                if rand_id in self.list_broken_id :
                    pass
                else :
                    flag = False
                    id = rand_id        


        data = self.pre_processing(self.data[id])
        data = self.processig(data)
        return data
        
    def pre_processing(self, data:dict):
        data_list = []
        for role in data["roles"]:
            dictionary = dict()
            dictionary["words"] = data["words"]
            dictionary["role"] = data["roles"][role]
            dictionary["pre_idx"] = role
            dictionary["pos_tags"] = data["pos_tags"]
            dictionary["predicate_meaning"] = data["predicates"]
            data_list.append(dictionary)    
        return data_list
    
    def processig(self,data_list:list):
        
        for dictionary in data_list:

            #dictionary["words"] = data["words"]
            dictionary["gt_arg_identification"] = self.arg_id(dictionary["role"])
            dictionary["gt_arg_classification"] = self.arg_class(dictionary["role"])
            dictionary["pos_idx"] = self.pos_idx(dictionary["pos_tags"])
            dictionary["predicate_meaning_idx"] = self.predicate_meaning_idx(dictionary["predicate_meaning"])
        
        return data_list
   
    def list_arg_roles(self):
        list_roles = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : roles = element["roles"]
            except : flag = False
            if flag :
                for e in roles:
                    sentence = element["roles"][e]

                    for word in sentence:
                        
                        list_roles.append(word)
                list_roles = list(set(list_roles))
            else : 
                list_broken_id.append(i)
        return list_roles,list_broken_id

    def list_predicate_roles(self):
        list_predicate_roles = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : predicates = element["predicates"]
            except : flag = False
            if flag :
                for pre in predicates:
                    list_predicate_roles.append(pre)
                list_predicate_roles = list(set(list_predicate_roles))
            else : 
                list_broken_id.append(i)
        return list_predicate_roles,list_broken_id

    def list_pos(self):
        list_pos = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : pos = element["pos_tags"]
            except : flag = False
            if flag :
                for e in pos:
                    list_pos.append(e)
                list_pos = list(set(list_pos))
            else : 
                list_broken_id.append(i)
        return list_pos,list_broken_id
  
    def arg_class(self,role:list):
        list_idxs = []
        for element in role:
            list_idxs.append(self.args_roles.index(element))
        

        return torch.tensor(list_idxs, dtype=torch.int64)

    def arg_id(self,role:dict):
        list_idxs = []
        for element in role:
            if element == "_":
                list_idxs.append(0)
            else :
                list_idxs.append(1)

        

        return torch.tensor(list_idxs, dtype=torch.int64)

    def pos_idx(self,pos_tags:dict):
        list_idxs = []
        list_idxs.append(self.pos_list.index("Nothing"))

        for element in pos_tags:
            list_idxs.append(self.pos_list.index(element))
        
        list_idxs.append(self.pos_list.index("Nothing"))
        return torch.tensor(list_idxs, dtype=torch.int64)
    
    def predicate_meaning_idx(self,predicate_meaning_tags:dict):
        list_idxs = []
        list_idxs.append(self.predicate_dis.index("Nothing"))

        for element in predicate_meaning_tags:
            list_idxs.append(self.predicate_dis.index(element))
        
        list_idxs.append(self.predicate_dis.index("Nothing"))
        return torch.tensor(list_idxs, dtype=torch.int64) 

    
# here we define our collate function
def collate_fn(batch) -> Dict[str, torch.Tensor]:
    #print(batch)
    input = dict() 
    batch_sentence = [] 
    #print(len(batch))
    for period in batch:
        for sentence in period :
        
            #print(len(sentence[0]["words"]))
            pre_idx = int(sentence["pre_idx"])
            

            predicate = sentence["words"][pre_idx]

            text = " ".join(sentence["words"])
            tokens: list[str] = text.split()
            predicate: list[str] = predicate.split()

            #text = sentence[0]["words"]
            
            t = (tokens,predicate)

            batch_sentence.append(t)
            #print(batch_sentence)

    batch_output = tokenizer.batch_encode_plus(batch_sentence,padding=True,is_split_into_words=True, truncation=True,return_offsets_mapping=True, return_tensors="pt")
    #print(batch_output.keys())


    gt = dict()
    
    


    for period in batch:

        list_positional_predicate_encoding = []
        list_arg_gt = []
        list_predicate_index = [] 
        list_pos_index = [] 
        list_predicate_meaning_index = []

        for sentence in period:
            #positional_encoding
            #+2 per il CLS iniziale ad SEP finale
            sentence_words_lenght =  len(sentence["words"])
            positional_predicate_encoding = torch.zeros(1,sentence_words_lenght+2)
            #+1 per il CLS iniziale
            pre_idx = int(sentence["pre_idx"])
            positional_predicate_encoding[:,pre_idx+1] = 1
            list_positional_predicate_encoding.append(positional_predicate_encoding)
            #print("positional_prefix_encoding",positional_predicate_encoding)
            list_predicate_index.append(pre_idx)




            pos = torch.unsqueeze(sentence["pos_idx"],dim = 0)
            list_pos_index.append(pos)
            predicate_meaning_idxs = torch.unsqueeze(sentence["predicate_meaning_idx"],dim = 0)
            list_predicate_meaning_index.append(predicate_meaning_idxs)






            #note CLS and SEP are discharder after Bi-LSTM, the Classifier takes in input only wokrds hidden state embedding
            arg_gt = torch.unsqueeze(sentence["gt_arg_classification"],dim = 0)
            list_arg_gt.append(arg_gt)
        

    list_arg_gt = torch.cat(list_arg_gt,dim = 0)
    list_pos_index = torch.cat(list_pos_index,dim = 0)
    list_predicate_meaning_index = torch.cat(list_predicate_meaning_index,dim = 0)
    list_positional_predicate_encoding = torch.cat(list_positional_predicate_encoding,dim = 0)
    gt["arg_gt"] = list_arg_gt
    input["predicate_index"] = list_predicate_index
    input["pos_index"] = list_pos_index.long()
    input["predicate_meaning_idx"] = list_predicate_meaning_index.long()
    offset = batch_output.pop("offset_mapping")
    input["BERT_input"] = batch_output
    input["positional_encoding"] = list_positional_predicate_encoding.long()
    input["offset_mapping"] = offset
    input["gt"] = gt

   


    return input
















Downloading: 100%|██████████| 625/625 [00:00<00:00, 137kB/s]
Downloading: 100%|██████████| 714M/714M [00:09<00:00, 73.2MB/s] 
Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



model class is      : <class 'transformers.models.bert.modeling_bert.BertModel'>


Downloading: 100%|██████████| 29.0/29.0 [00:00<00:00, 8.39kB/s]
Downloading: 100%|██████████| 996k/996k [00:01<00:00, 810kB/s] 
Downloading: 100%|██████████| 1.96M/1.96M [00:01<00:00, 1.00MB/s]



model class is      : <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>


In [15]:

#train_dataset = SRL("EN","train")

#note here we are directly loading args_roles mapping as computed before the the dasetet where we have perfomerd 
#EN and ES dataset should have the same consistency in generating 
train_dataset = SRL("ES","train",train_dataset.args_roles,train_dataset.pos_list)
#same mapping should be used in both the dataset
dev_dataset = SRL("ES","dev",train_dataset.args_roles,train_dataset.pos_list)

['SCONJ', 'ADV', 'PUNCT', 'NOUN', 'DET', 'NUM', 'PART', 'ADP', 'CCONJ', 'SYM', 'INTJ', 'PRON', 'X', 'ADJ', 'AUX', 'PROPN', 'VERB', 'Nothing']
['SCONJ', 'ADV', 'PUNCT', 'NOUN', 'DET', 'NUM', 'PART', 'ADP', 'CCONJ', 'SYM', 'INTJ', 'PRON', 'X', 'ADJ', 'AUX', 'PROPN', 'VERB', 'Nothing']


#### English-Spanish attempt

In [17]:
#note that parameter EN is only used for tracking on which dataset was/is trained, and activate loading of the pretrained head
#load the fine-tuned model over english
PATH = "/media/mv/Volume/Download/TEST_EXP/nlp2022-hw2-main-master/saved/model_2022_12_19_15_25_3.pth"
model = Arg_Classifier("ES",cfg)
model.load_state_dict(torch.load(PATH))
model.train().cuda()

Arg_Classifier(
  (bi_lstm_portable): LSTM(132, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (embedding_predicate_flag): Embedding(2, 32, max_norm=True)
  (embedding_predicate): Embedding(304, False, max_norm=True)
  (embedding_pos): Embedding(18, 100, max_norm=True)
  (bi_lstm): LSTM(900, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (dropout_language_constraint): Dropout(p=0.6, inplace=False)
  (dropout_in_classifier): Dropout(p=0.4, inplace=False)
  (Relu): ReLU()
  (Sigmoid): Sigmoid()
  (linear0): Linear(in_features=300, out_features=675, bias=True)
  (linear1): Linear(in_features=675, out_features=135, bias=True)
  (linear2): Linear(in_features=135, out_features=27, bias=True)
)

In [18]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from sklearn.metrics import f1_score,confusion_matrix
from torch.optim.lr_scheduler import ExponentialLR
from datetime import datetime



currentDateAndTime = datetime.now()
_id = str(currentDateAndTime.year)+"_"+str(currentDateAndTime.month)+"_"+str(currentDateAndTime.day)+"_"+str(currentDateAndTime.hour)+"_"+str(currentDateAndTime.minute)+"_"+str(currentDateAndTime.second)




#optimizer = torch.optim.Adam(model.parameters(),lr = 0.000005)
optimizer = torch.optim.Adam(model.parameters())

scheduler = ExponentialLR(optimizer, gamma=0.9)


logSotfMax = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()

dataloader_train = DataLoader(train_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

dataloader_dev = DataLoader(dev_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

mapping = dataloader_train.dataset.args_roles

auto_model.eval()
auto_model.cuda()

EPOCHS = 200
patience_counter = 0
patience = 5
max_val_loss = 9999
f1_score_max = 0
output_path = "saved"
model_name = "model_"+_id+".pth"
PATH = os.path.join(output_path,model_name)


for epoch in range(EPOCHS):

    #TRAINING
    p = []
    g = []
    model.train()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_train):
        #print(sample_batched)

        
        optimizer.zero_grad()
       
        #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
        #------------------FILTERING SUB-WORDS----------------------
        subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
        word_emebedding = []
        for i in range(n):
            subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
            flag = subtoken_mask[0,i,0]
            if flag :
                continue
            else :
                word_emebedding.append(subwords_embedding)
        word_emebedding = torch.cat(word_emebedding,dim = 1)
        #-------------------------FORWARD/BACKWARD----------------------------------
        x = model.forward(subwords_embeddings = output_hidden_states_sum,
            perdicate_positional_encoding = sample_batched["positional_encoding"],
            predicate_index = sample_batched["predicate_index"],
            pos_index_encoding = sample_batched["pos_index"],
            predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])        
        b,n = sample_batched["gt"]["arg_gt"].size()
        loss = nll_loss(logSotfMax(x),gt)
        total_loss = total_loss + loss
        #print(loss)
        loss.backward()
        optimizer.step()
    

        counter += 1 
            

        #-------------------------RESULT STORING----------------------------------
        predicted = torch.argmax(x, dim=1)
        p += predicted.tolist()
        g += gt.tolist()
    

    #-------------------------RESULTS----------------------------------
    scheduler.step()

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    print("Epochs n.", epoch)
    print("F1 train:",f1)
    print("F1 avg train:",f1_avg)
    
    avg_train_loss = total_loss/counter
    writer.add_scalar("EN_Loss_ES/train", avg_train_loss, epoch)


    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("EN_Train_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("EN_Train_ES/classification", classification_result["f1"], epoch)



    #EVALUATION
    p = []
    g = []
    model.eval()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_dev):
    
      #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
            #------------------FILTERING SUB-WORDS----------------------
            subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
            word_emebedding = []
            for i in range(n): 
                subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
                flag = subtoken_mask[0,i,0]
                if flag :
                    continue
                else :
                    word_emebedding.append(subwords_embedding)
            word_emebedding = torch.cat(word_emebedding,dim = 1)
            #-------------------------FORWARD----------------------------------
            x = model.forward(subwords_embeddings = output_hidden_states_sum,
                        perdicate_positional_encoding = sample_batched["positional_encoding"],
                        predicate_index = sample_batched["predicate_index"],
                        pos_index_encoding = sample_batched["pos_index"],
                        predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])   


            b,n = sample_batched["gt"]["arg_gt"].size()
            loss = nll_loss(logSotfMax(x),gt)
            total_loss = total_loss + loss
            #-------------------------RESULT STORING----------------------------------
            predicted = torch.argmax(x, dim=1)
            p += predicted.tolist()
            g += gt.tolist()
            counter += 1 
    
    #-------------------------RESULTS----------------------------------
    avg_eval_loss = total_loss/counter

    if avg_eval_loss < max_val_loss:
        max_val_loss = avg_eval_loss
    else :
        patience_counter += 1
    

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    if patience_counter >= patience :  


        print("Early stopping at epoch : ",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
        break
    else :
        print("EPOCHS :",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
    

    writer.add_scalar("EN_Loss_ES/validation", avg_eval_loss, epoch)

    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("EN_Eval_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("EN_Eval_ES/classification", classification_result["f1"], epoch)




    if f1_avg > f1_score_max:
        f1_score_max = f1_avg
        print("SAVED :",PATH)
        torch.save(model.state_dict(),PATH)
    

Epochs n. 0
F1 train: [0.         0.06857143 0.         0.         0.         0.
 0.         0.         0.         0.         0.43322476 0.
 0.         0.9834875  0.         0.         0.         0.04166667
 0.         0.         0.44676806 0.         0.         0.
 0.04545455 0.        ]
F1 avg train: 0.9492945394245353
identification {'true_positives': 932, 'false_positives': 224, 'false_negatives': 1222, 'precision': 0.8062283737024222, 'recall': 0.43268337975858867, 'f1': 0.5631419939577039}
classification_result {'true_positives': 517, 'false_positives': 639, 'false_negatives': 1637, 'precision': 0.4472318339100346, 'recall': 0.24001857010213556, 'f1': 0.31238670694864046}
EPOCHS : 0
F1 eval : [0.         0.01474926 0.         0.         0.         0.
 0.         0.         0.         0.51171393 0.         0.
 0.98792859 0.         0.         0.         0.         0.
 0.         0.56926316 0.         0.         0.         0.
 0.03058104 0.        ]
F1 avg eval : 0.9593153965026872

#### Compare without Transfert Learning 

In [19]:
model = Arg_Classifier("ES",cfg).cuda()

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from sklearn.metrics import f1_score,confusion_matrix
from torch.optim.lr_scheduler import ExponentialLR
from datetime import datetime



currentDateAndTime = datetime.now()
_id = str(currentDateAndTime.year)+"_"+str(currentDateAndTime.month)+"_"+str(currentDateAndTime.day)+"_"+str(currentDateAndTime.hour)+"_"+str(currentDateAndTime.minute)+"_"+str(currentDateAndTime.second)+"WT"




#optimizer = torch.optim.Adam(model.parameters(),lr = 0.000005)
optimizer = torch.optim.Adam(model.parameters())

scheduler = ExponentialLR(optimizer, gamma=0.9)


logSotfMax = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()

dataloader_train = DataLoader(train_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

dataloader_dev = DataLoader(dev_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

mapping = dataloader_train.dataset.args_roles

auto_model.eval()
auto_model.cuda()

EPOCHS = 200
patience_counter = 0
patience = 5
max_val_loss = 9999
f1_score_max = 0
output_path = "saved"
model_name = "model_"+_id+".pth"
PATH = os.path.join(output_path,model_name)


for epoch in range(EPOCHS):

    #TRAINING
    p = []
    g = []
    model.train()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_train):
        #print(sample_batched)

        
        optimizer.zero_grad()
       
        #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
        #------------------FILTERING SUB-WORDS----------------------
        subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
        word_emebedding = []
        for i in range(n):
            subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
            flag = subtoken_mask[0,i,0]
            if flag :
                continue
            else :
                word_emebedding.append(subwords_embedding)
        word_emebedding = torch.cat(word_emebedding,dim = 1)
        #-------------------------FORWARD/BACKWARD----------------------------------
        x = model.forward(subwords_embeddings = output_hidden_states_sum,
            perdicate_positional_encoding = sample_batched["positional_encoding"],
            predicate_index = sample_batched["predicate_index"],
            pos_index_encoding = sample_batched["pos_index"],
            predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])        
        b,n = sample_batched["gt"]["arg_gt"].size()
        loss = nll_loss(logSotfMax(x),gt)
        total_loss = total_loss + loss
        #print(loss)
        loss.backward()
        optimizer.step()
    

        counter += 1 
            

        #-------------------------RESULT STORING----------------------------------
        predicted = torch.argmax(x, dim=1)
        p += predicted.tolist()
        g += gt.tolist()
    

    #-------------------------RESULTS----------------------------------
    scheduler.step()

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    print("Epochs n.", epoch)
    print("F1 train:",f1)
    print("F1 avg train:",f1_avg)
    
    avg_train_loss = total_loss/counter
    writer.add_scalar("Loss_ES/train", avg_train_loss, epoch)


    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Train_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("Train_ES/classification", classification_result["f1"], epoch)



    #EVALUATION
    p = []
    g = []
    model.eval()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_dev):
    
      #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
            #------------------FILTERING SUB-WORDS----------------------
            subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
            word_emebedding = []
            for i in range(n): 
                subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
                flag = subtoken_mask[0,i,0]
                if flag :
                    continue
                else :
                    word_emebedding.append(subwords_embedding)
            word_emebedding = torch.cat(word_emebedding,dim = 1)
            #-------------------------FORWARD----------------------------------
            x = model.forward(subwords_embeddings = output_hidden_states_sum,
                        perdicate_positional_encoding = sample_batched["positional_encoding"],
                        predicate_index = sample_batched["predicate_index"],
                        pos_index_encoding = sample_batched["pos_index"],
                        predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])   


            b,n = sample_batched["gt"]["arg_gt"].size()
            loss = nll_loss(logSotfMax(x),gt)
            total_loss = total_loss + loss
            #-------------------------RESULT STORING----------------------------------
            predicted = torch.argmax(x, dim=1)
            p += predicted.tolist()
            g += gt.tolist()
            counter += 1 
    
    #-------------------------RESULTS----------------------------------
    avg_eval_loss = total_loss/counter

    if avg_eval_loss < max_val_loss:
        max_val_loss = avg_eval_loss
    else :
        patience_counter += 1
    

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    if patience_counter >= patience :  


        print("Early stopping at epoch : ",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
        break
    else :
        print("EPOCHS :",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
    

    writer.add_scalar("Loss_ES/validation", avg_eval_loss, epoch)

    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Eval_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("Eval_ES/classification", classification_result["f1"], epoch)




    if f1_avg > f1_score_max:
        f1_score_max = f1_avg
        print("SAVED :",PATH)
        torch.save(model.state_dict(),PATH)
    

Epochs n. 0
F1 train: [0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.17606602 0.
 0.         0.97629743 0.         0.         0.         0.
 0.         0.         0.020558   0.         0.         0.
 0.         0.        ]
F1 avg train: 0.9326034676010221
identification {'true_positives': 157, 'false_positives': 106, 'false_negatives': 1994, 'precision': 0.596958174904943, 'recall': 0.07298930729893073, 'f1': 0.1300745650372825}
classification_result {'true_positives': 71, 'false_positives': 192, 'false_negatives': 2080, 'precision': 0.26996197718631176, 'recall': 0.03300790330079033, 'f1': 0.0588235294117647}
EPOCHS : 0
F1 eval : [0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.4334126  0.         0.
 0.98445385 0.         0.         0.         0.         0.
 0.         0.12718204 0.         0.         0.         0.
 0.         0.        ]
F1 avg eval : 0.9491122608744599
identific

### French

#### New French Dataset
bert-base-multilingual-cased


In [24]:
from transformers import BertTokenizer, BertModel
import os
import json
import logging
import torch
from torch.utils.data import DataLoader,Dataset
import random
from typing import Dict






class SRL(Dataset):
 
    def __init__(self,language,path,args_roles = None) -> None:

        self.path_root = 'data'
        self.load_data(language,path)
        if args_roles is None :
            self.args_roles,self.list_broken_id = self.list_arg_roles()
        else : 
            self.args_roles = args_roles
            _,self.list_broken_id = self.list_arg_roles()

        self.pos_list,_ = self.list_pos()
        print(self.args_roles)
        self.predicate_dis,_ = self.list_predicate_roles()
        self.pos_list.append("Nothing")
        self.predicate_dis.append("Nothing")

    def load_data(self,language,mode):
        
        mode = mode+".json"
        path = os.path.join(self.path_root,language,mode)
        data_file = open(path)
       
        data_ = json.load(data_file)

        list_data = []

        for data in data_:
            list_data.append(data_[data])
        

        self.data = list_data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, id : int):

        flag = False
        if id in self.list_broken_id :
            flag = True
            while flag == True:

                rand_id = random.randint(0, len(self.data)-1)
                
                if rand_id in self.list_broken_id :
                    pass
                else :
                    flag = False
                    id = rand_id        


        data = self.pre_processing(self.data[id])
        data = self.processig(data)
        return data
        
    def pre_processing(self, data:dict):
        data_list = []
        for role in data["roles"]:
            dictionary = dict()
            dictionary["words"] = data["words"]
            dictionary["role"] = data["roles"][role]
            dictionary["pre_idx"] = role
            dictionary["pos_tags"] = data["pos_tags"]
            dictionary["predicate_meaning"] = data["predicates"]
            data_list.append(dictionary)    
        return data_list
    
    def processig(self,data_list:list):
        
        for dictionary in data_list:

            #dictionary["words"] = data["words"]
            dictionary["gt_arg_identification"] = self.arg_id(dictionary["role"])
            dictionary["gt_arg_classification"] = self.arg_class(dictionary["role"])
            dictionary["pos_idx"] = self.pos_idx(dictionary["pos_tags"])
            dictionary["predicate_meaning_idx"] = self.predicate_meaning_idx(dictionary["predicate_meaning"])
        
        return data_list
   
    def list_arg_roles(self):
        list_roles = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : roles = element["roles"]
            except : flag = False
            if flag :
                for e in roles:
                    sentence = element["roles"][e]

                    for word in sentence:
                        
                        list_roles.append(word)
                list_roles = list(set(list_roles))
            else : 
                list_broken_id.append(i)
        return list_roles,list_broken_id

    def list_predicate_roles(self):
        list_predicate_roles = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : predicates = element["predicates"]
            except : flag = False
            if flag :
                for pre in predicates:
                    list_predicate_roles.append(pre)
                list_predicate_roles = list(set(list_predicate_roles))
            else : 
                list_broken_id.append(i)
        return list_predicate_roles,list_broken_id

    def list_pos(self):
        list_pos = []
        list_broken_id = []
        for i,element in enumerate(self.data):
            flag = True
            try : pos = element["pos_tags"]
            except : flag = False
            if flag :
                for e in pos:
                    list_pos.append(e)
                list_pos = list(set(list_pos))
            else : 
                list_broken_id.append(i)
        return list_pos,list_broken_id
  
    def arg_class(self,role:list):
        list_idxs = []
        for element in role:
            list_idxs.append(self.args_roles.index(element))
        

        return torch.tensor(list_idxs, dtype=torch.int64)

    def arg_id(self,role:dict):
        list_idxs = []
        for element in role:
            if element == "_":
                list_idxs.append(0)
            else :
                list_idxs.append(1)

        

        return torch.tensor(list_idxs, dtype=torch.int64)

    def pos_idx(self,pos_tags:dict):
        list_idxs = []
        list_idxs.append(self.pos_list.index("Nothing"))

        for element in pos_tags:
            list_idxs.append(self.pos_list.index(element))
        
        list_idxs.append(self.pos_list.index("Nothing"))
        return torch.tensor(list_idxs, dtype=torch.int64)
    
    def predicate_meaning_idx(self,predicate_meaning_tags:dict):
        list_idxs = []
        list_idxs.append(self.predicate_dis.index("Nothing"))

        for element in predicate_meaning_tags:
            list_idxs.append(self.predicate_dis.index(element))
        
        list_idxs.append(self.predicate_dis.index("Nothing"))
        return torch.tensor(list_idxs, dtype=torch.int64) 

    
# here we define our collate function
def collate_fn(batch) -> Dict[str, torch.Tensor]:
    #print(batch)
    input = dict() 
    batch_sentence = [] 
    #print(len(batch))
    for period in batch:
        for sentence in period :
        
            #print(len(sentence[0]["words"]))
            pre_idx = int(sentence["pre_idx"])
            

            predicate = sentence["words"][pre_idx]

            text = " ".join(sentence["words"])
            tokens: list[str] = text.split()
            predicate: list[str] = predicate.split()

            #text = sentence[0]["words"]
            
            t = (tokens,predicate)

            batch_sentence.append(t)
            #print(batch_sentence)

    batch_output = tokenizer.batch_encode_plus(batch_sentence,padding=True,is_split_into_words=True, truncation=True,return_offsets_mapping=True, return_tensors="pt")
    #print(batch_output.keys())


    gt = dict()
    
    


    for period in batch:

        list_positional_predicate_encoding = []
        list_arg_gt = []
        list_predicate_index = [] 
        list_pos_index = [] 
        list_predicate_meaning_index = []

        for sentence in period:
            #positional_encoding
            #+2 per il CLS iniziale ad SEP finale
            sentence_words_lenght =  len(sentence["words"])
            positional_predicate_encoding = torch.zeros(1,sentence_words_lenght+2)
            #+1 per il CLS iniziale
            pre_idx = int(sentence["pre_idx"])
            positional_predicate_encoding[:,pre_idx+1] = 1
            list_positional_predicate_encoding.append(positional_predicate_encoding)
            #print("positional_prefix_encoding",positional_predicate_encoding)
            list_predicate_index.append(pre_idx)




            pos = torch.unsqueeze(sentence["pos_idx"],dim = 0)
            list_pos_index.append(pos)
            predicate_meaning_idxs = torch.unsqueeze(sentence["predicate_meaning_idx"],dim = 0)
            list_predicate_meaning_index.append(predicate_meaning_idxs)






            #note CLS and SEP are discharder after Bi-LSTM, the Classifier takes in input only wokrds hidden state embedding
            arg_gt = torch.unsqueeze(sentence["gt_arg_classification"],dim = 0)
            list_arg_gt.append(arg_gt)
        

    list_arg_gt = torch.cat(list_arg_gt,dim = 0)
    list_pos_index = torch.cat(list_pos_index,dim = 0)
    list_predicate_meaning_index = torch.cat(list_predicate_meaning_index,dim = 0)
    list_positional_predicate_encoding = torch.cat(list_positional_predicate_encoding,dim = 0)
    gt["arg_gt"] = list_arg_gt
    input["predicate_index"] = list_predicate_index
    input["pos_index"] = list_pos_index.long()
    input["predicate_meaning_idx"] = list_predicate_meaning_index.long()
    offset = batch_output.pop("offset_mapping")
    input["BERT_input"] = batch_output
    input["positional_encoding"] = list_positional_predicate_encoding.long()
    input["offset_mapping"] = offset
    input["gt"] = gt

   


    return
     



OSError: Can't load the configuration of 'bert-base-multili ngual-cased'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'bert-base-multili ngual-cased' is the correct path to a directory containing a config.json file

In [25]:

#train_dataset = SRL("EN","train")

#note here we are directly loading args_roles mapping as computed before the the dasetet where we have perfomerd 
#EN and ES dataset should have the same consistency in generating 
train_dataset = SRL("FR","train",train_dataset.args_roles,train_dataset.pos_list)
#same mapping should be used in both the dataset
dev_dataset = SRL("FR","dev",train_dataset.args_roles,train_dataset.pos_list)

TypeError: __init__() takes from 3 to 4 positional arguments but 5 were given

#### English-French attempt

In [None]:
#note that parameter EN is only used for tracking on which dataset was/is trained, and activate loading of the pretrained head
#load the fine-tuned model over english
PATH = "saved/model_2022_12_18_21_33_45.pth"
model = Arg_Classifier("ES",cfg)
model.load_state_dict(torch.load(PATH))
model.train().cuda()

Arg_Classifier(
  (bi_lstm_portable): LSTM(132, 50, num_layers=2, batch_first=True, dropout=0.4, bidirectional=True)
  (embedding_predicate_flag): Embedding(2, 32, max_norm=True)
  (embedding_predicate): Embedding(304, False, max_norm=True)
  (embedding_pos): Embedding(18, 100, max_norm=True)
  (bi_lstm): LSTM(900, 50, num_layers=2, batch_first=True, dropout=0.4, bidirectional=True)
  (dropout_language_constraint): Dropout(p=0.6, inplace=False)
  (dropout_in_classifier): Dropout(p=0.4, inplace=False)
  (Relu): ReLU()
  (Sigmoid): Sigmoid()
  (linear0): Linear(in_features=300, out_features=675, bias=True)
  (linear1): Linear(in_features=675, out_features=135, bias=True)
  (linear2): Linear(in_features=135, out_features=27, bias=True)
)

In [None]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from sklearn.metrics import f1_score,confusion_matrix
from torch.optim.lr_scheduler import ExponentialLR
from datetime import datetime



currentDateAndTime = datetime.now()
_id = str(currentDateAndTime.year)+"_"+str(currentDateAndTime.month)+"_"+str(currentDateAndTime.day)+"_"+str(currentDateAndTime.hour)+"_"+str(currentDateAndTime.minute)+"_"+str(currentDateAndTime.second)




#optimizer = torch.optim.Adam(model.parameters(),lr = 0.000005)
optimizer = torch.optim.Adam(model.parameters())

scheduler = ExponentialLR(optimizer, gamma=0.9)


logSotfMax = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()

dataloader_train = DataLoader(train_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

dataloader_dev = DataLoader(dev_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

mapping = dataloader_train.dataset.args_roles

auto_model.eval()
auto_model.cuda()

EPOCHS = 200
patience_counter = 0
patience = 5
max_val_loss = 9999
f1_score_max = 0
output_path = "saved"
model_name = "model_"+_id+".pth"
PATH = os.path.join(output_path,model_name)


for epoch in range(EPOCHS):

    #TRAINING
    p = []
    g = []
    model.train()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_train):
        #print(sample_batched)

        
        optimizer.zero_grad()
       
        #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
        #------------------FILTERING SUB-WORDS----------------------
        subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
        word_emebedding = []
        for i in range(n):
            subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
            flag = subtoken_mask[0,i,0]
            if flag :
                continue
            else :
                word_emebedding.append(subwords_embedding)
        word_emebedding = torch.cat(word_emebedding,dim = 1)
        #-------------------------FORWARD/BACKWARD----------------------------------
        x = model.forward(subwords_embeddings = output_hidden_states_sum,
            perdicate_positional_encoding = sample_batched["positional_encoding"],
            predicate_index = sample_batched["predicate_index"],
            pos_index_encoding = sample_batched["pos_index"],
            predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])        
        b,n = sample_batched["gt"]["arg_gt"].size()
        loss = nll_loss(logSotfMax(x),gt)
        total_loss = total_loss + loss
        #print(loss)
        loss.backward()
        optimizer.step()
    

        counter += 1 
            

        #-------------------------RESULT STORING----------------------------------
        predicted = torch.argmax(x, dim=1)
        p += predicted.tolist()
        g += gt.tolist()
    

    #-------------------------RESULTS----------------------------------
    scheduler.step()

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    print("Epochs n.", epoch)
    print("F1 train:",f1)
    print("F1 avg train:",f1_avg)
    
    avg_train_loss = total_loss/counter
    writer.add_scalar("EN_Loss_ES/train", avg_train_loss, epoch)


    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("EN_Train_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("EN_Train_ES/classification", classification_result["f1"], epoch)



    #EVALUATION
    p = []
    g = []
    model.eval()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_dev):
    
      #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
            #------------------FILTERING SUB-WORDS----------------------
            subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
            word_emebedding = []
            for i in range(n): 
                subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
                flag = subtoken_mask[0,i,0]
                if flag :
                    continue
                else :
                    word_emebedding.append(subwords_embedding)
            word_emebedding = torch.cat(word_emebedding,dim = 1)
            #-------------------------FORWARD----------------------------------
            x = model.forward(subwords_embeddings = output_hidden_states_sum,
                        perdicate_positional_encoding = sample_batched["positional_encoding"],
                        predicate_index = sample_batched["predicate_index"],
                        pos_index_encoding = sample_batched["pos_index"],
                        predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])   


            b,n = sample_batched["gt"]["arg_gt"].size()
            loss = nll_loss(logSotfMax(x),gt)
            total_loss = total_loss + loss
            #-------------------------RESULT STORING----------------------------------
            predicted = torch.argmax(x, dim=1)
            p += predicted.tolist()
            g += gt.tolist()
            counter += 1 
    
    #-------------------------RESULTS----------------------------------
    avg_eval_loss = total_loss/counter

    if avg_eval_loss < max_val_loss:
        max_val_loss = avg_eval_loss
    else :
        patience_counter += 1
    

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    if patience_counter >= patience :  


        print("Early stopping at epoch : ",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
        break
    else :
        print("EPOCHS :",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
    

    writer.add_scalar("EN_Loss_ES/validation", avg_eval_loss, epoch)

    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("EN_Eval_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("EN_Eval_ES/classification", classification_result["f1"], epoch)




    if f1_avg > f1_score_max:
        f1_score_max = f1_avg
        print("SAVED :",PATH)
        torch.save(model.state_dict(),PATH)
    

Epochs n. 0
F1 train: [0.         0.41519926 0.00668896 0.98159286 0.         0.38427948
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]
F1 avg train: 0.9458893102533754
identification {'true_positives': 657, 'false_positives': 127, 'false_negatives': 1495, 'precision': 0.8380102040816326, 'recall': 0.30529739776951675, 'f1': 0.4475476839237057}
classification_result {'true_positives': 401, 'false_positives': 383, 'false_negatives': 1751, 'precision': 0.5114795918367347, 'recall': 0.18633828996282528, 'f1': 0.2731607629427793}
EPOCHS : 0
F1 eval : [0.         0.50540098 0.         0.98476308 0.25635359 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]
F1 avg eval : 0.951819105330271
i

#### Compare without Transfert Learning 

In [None]:
model = Arg_Classifier("FR",cfg).cuda()

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from sklearn.metrics import f1_score,confusion_matrix
from torch.optim.lr_scheduler import ExponentialLR
from datetime import datetime



currentDateAndTime = datetime.now()
_id = str(currentDateAndTime.year)+"_"+str(currentDateAndTime.month)+"_"+str(currentDateAndTime.day)+"_"+str(currentDateAndTime.hour)+"_"+str(currentDateAndTime.minute)+"_"+str(currentDateAndTime.second)




#optimizer = torch.optim.Adam(model.parameters(),lr = 0.000005)
optimizer = torch.optim.Adam(model.parameters())

scheduler = ExponentialLR(optimizer, gamma=0.9)


logSotfMax = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss()

dataloader_train = DataLoader(train_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

dataloader_dev = DataLoader(dev_dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=collate_fn,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)

mapping = dataloader_train.dataset.args_roles

auto_model.eval()
auto_model.cuda()

EPOCHS = 200
patience_counter = 0
patience = 5
max_val_loss = 9999
f1_score_max = 0
output_path = "saved"
model_name = "model_"+_id+".pth"
PATH = os.path.join(output_path,model_name)


for epoch in range(EPOCHS):

    #TRAINING
    p = []
    g = []
    model.train()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_train):
        #print(sample_batched)

        
        optimizer.zero_grad()
       
        #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
        #------------------FILTERING SUB-WORDS----------------------
        subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
        word_emebedding = []
        for i in range(n):
            subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
            flag = subtoken_mask[0,i,0]
            if flag :
                continue
            else :
                word_emebedding.append(subwords_embedding)
        word_emebedding = torch.cat(word_emebedding,dim = 1)
        #-------------------------FORWARD/BACKWARD----------------------------------
        x = model.forward(subwords_embeddings = output_hidden_states_sum,
            perdicate_positional_encoding = sample_batched["positional_encoding"],
            predicate_index = sample_batched["predicate_index"],
            pos_index_encoding = sample_batched["pos_index"],
            predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])        
        b,n = sample_batched["gt"]["arg_gt"].size()
        loss = nll_loss(logSotfMax(x),gt)
        total_loss = total_loss + loss
        #print(loss)
        loss.backward()
        optimizer.step()
    

        counter += 1 
            

        #-------------------------RESULT STORING----------------------------------
        predicted = torch.argmax(x, dim=1)
        p += predicted.tolist()
        g += gt.tolist()
    

    #-------------------------RESULTS----------------------------------
    scheduler.step()

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    print("Epochs n.", epoch)
    print("F1 train:",f1)
    print("F1 avg train:",f1_avg)
    
    avg_train_loss = total_loss/counter
    writer.add_scalar("Loss_ES/train", avg_train_loss, epoch)


    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Train_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("Train_ES/classification", classification_result["f1"], epoch)



    #EVALUATION
    p = []
    g = []
    model.eval()
    total_loss = 0
    counter = 0
    for i_batch, sample_batched in enumerate(dataloader_dev):
    
      #----------------------PREPARE INPUT/OUTPUT-------------------------------
        input_bert = sample_batched["BERT_input"]
        input_bert['input_ids'] = input_bert['input_ids'].cuda()
        input_bert['token_type_ids'] = input_bert['token_type_ids'].cuda()
        input_bert['attention_mask'] = input_bert['attention_mask'].cuda()
        sample_batched["positional_encoding"] = sample_batched["positional_encoding"].cuda()
        sample_batched["pos_index"] = sample_batched["pos_index"].cuda()
        sample_batched["predicate_meaning_idx"] = sample_batched["predicate_meaning_idx"].cuda()
        #prepare gt
        gt = torch.flatten(sample_batched["gt"]["arg_gt"]).cuda()
        offset = sample_batched["offset_mapping"]
        #-----------------BERT EMBEDDING---------------------------
        with torch.no_grad():
            output = auto_model(**input_bert)
            output_hidden_states_sum = torch.stack(output.hidden_states[-4:], dim=0).sum(dim=0)
            b,n,h = output_hidden_states_sum.size()
    
            #------------------FILTERING SUB-WORDS----------------------
            subtoken_mask = torch.unsqueeze(offset[:,:, 0] != 0,dim =-1)
            word_emebedding = []
            for i in range(n): 
                subwords_embedding = torch.unsqueeze(output_hidden_states_sum[:,i,:],dim = 1)
                flag = subtoken_mask[0,i,0]
                if flag :
                    continue
                else :
                    word_emebedding.append(subwords_embedding)
            word_emebedding = torch.cat(word_emebedding,dim = 1)
            #-------------------------FORWARD----------------------------------
            x = model.forward(subwords_embeddings = output_hidden_states_sum,
                        perdicate_positional_encoding = sample_batched["positional_encoding"],
                        predicate_index = sample_batched["predicate_index"],
                        pos_index_encoding = sample_batched["pos_index"],
                        predicate_meaning_encoding = sample_batched["predicate_meaning_idx"])   


            b,n = sample_batched["gt"]["arg_gt"].size()
            loss = nll_loss(logSotfMax(x),gt)
            total_loss = total_loss + loss
            #-------------------------RESULT STORING----------------------------------
            predicted = torch.argmax(x, dim=1)
            p += predicted.tolist()
            g += gt.tolist()
            counter += 1 
    
    #-------------------------RESULTS----------------------------------
    avg_eval_loss = total_loss/counter

    if avg_eval_loss < max_val_loss:
        max_val_loss = avg_eval_loss
    else :
        patience_counter += 1
    

    f1 = f1_score(g, p, average=None)
    f1_avg = f1_score(g, p, average="weighted")

    if patience_counter >= patience :  


        print("Early stopping at epoch : ",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
        break
    else :
        print("EPOCHS :",epoch)
        print("F1 eval :",f1)
        print("F1 avg eval :",f1_avg)
    

    writer.add_scalar("Loss_ES/validation", avg_eval_loss, epoch)

    g,p = mapping_args(g,p,mapping)

    identification_result,classification_result = metrics(g,p)
    print("identification",identification_result)
    print("classification_result",classification_result)

    writer.add_scalar("Eval_ES/identification", identification_result["f1"], epoch)
    writer.add_scalar("Eval_ES/classification", classification_result["f1"], epoch)




    if f1_avg > f1_score_max:
        f1_score_max = f1_avg
        print("SAVED :",PATH)
        torch.save(model.state_dict(),PATH)
    

Epochs n. 0
F1 train: [0.         0.19946809 0.         0.9768764  0.         0.0295421
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]
F1 avg train: 0.9336015019908596
identification {'true_positives': 172, 'false_positives': 71, 'false_negatives': 1987, 'precision': 0.7078189300411523, 'recall': 0.07966651227420102, 'f1': 0.14321398834304747}
classification_result {'true_positives': 85, 'false_positives': 158, 'false_negatives': 2074, 'precision': 0.3497942386831276, 'recall': 0.03937007874015748, 'f1': 0.070774354704413}
EPOCHS : 0
F1 eval : [0.         0.23416618 0.         0.9777718  0.00265252 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]
F1 avg eval : 0.9389493973815146
ide