In [None]:
#------# Import libraries and datasets #------#

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import datasets as dts
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import nltk
import re
import gc
import random
import spacy
%matplotlib inline

from nltk.corpus import stopwords
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from wordcloud import WordCloud,STOPWORDS
from nltk.stem.snowball import SnowballStemmer
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer
from skmultilearn.problem_transform import BinaryRelevance
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multioutput import ClassifierChain
from imblearn.under_sampling import TomekLinks
from imblearn.under_sampling import RandomUnderSampler


from transformers import BertModel,AutoModel
from sklearn.metrics import f1_score
from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# LexGLUE "unfair_tos" configuration; its train/validation/test splits are used below.
dataset = dts.load_dataset(path="lex_glue", name="unfair_tos")

In [None]:
# Turn each Hugging Face split into a pandas DataFrame (columns "text" and "labels").
train_dataset, val_dataset, test_dataset = (
    pd.DataFrame.from_dict(dataset[split_name])
    for split_name in ("train", "validation", "test")
)

# NLTK English stop-word list (needs the `stopwords` corpus to be downloaded).
stop_words = [word for word in stopwords.words("english")]

In [None]:
# Natural-language definitions of the eight unfair-clause categories. Each one is
# appended to a clause to form the entailment hypothesis.
definitions = {
    "Limitation of liability": "this clause stipulates that the duty to pay damages is limited or excluded, for certain kind of losses, under certain conditions. ",
    "Unilateral termination": "this clause gives provider the right to suspend and/or terminate the service and/or the contract, and sometimes details the circumstances under which the provider claims to have a right to do so.",
    "Unilateral change": "this clause specifies the conditions under which the service provider could amend and modify the terms of service and/or the service itself.",
    "Content removal": "this clause gives the provider a right to modify/delete user’s content, including in-app purchases, and sometimes specifies the conditions under which the service provider may do so.",
    "Contract by using": "this clause stipulates that the consumer is bound by the terms of use of a specific service, simply by using the service, without even being required to mark that he or she has read and accepted them.",
    "Choice of law": "this clause specifies what law will govern the contract, meaning also what law will be applied in potential adjudication of a dispute arising under the contract.",
    # This used to be a near-copy of the Arbitration text, which left the model
    # with two almost identical hypotheses for different labels.
    "Jurisdiction": "this clause stipulates what courts will have the competence to adjudicate disputes under the contract.",
    "Arbitration": "this forum selection clause requires or allows the parties to resolve their disputes through an arbitration process, before the case could go to court however, such a clause may or may not specify that arbitration should occur within a specific jurisdiction. ",
}

# Label index (as it appears in the dataset's "labels" lists) -> key of `definitions`.
label_to_def = {
    0: "Limitation of liability",
    1: "Unilateral termination",
    2: "Unilateral change",
    3: "Content removal",
    4: "Contract by using",
    5: "Choice of law",
    6: "Jurisdiction",
    7: "Arbitration",
}

# Connector phrase(s) placed between the clause and the definition. The pair
# builders choose one per clause.
entail_con = ["entails that"]

In [None]:
def convert_to_entailment(dataset, defs, lab2def, ent_con, remove_unseen, nlp=None):
    """Expand a multi-label clause dataset into binary entailment pairs.

    Each clause is paired with the definition of every label in ``lab2def`` as
    ``"<clause> <connector> <definition>"``. A pair gets ``[0, 1]`` when that
    label index is one of the clause's gold labels and ``[1, 0]`` otherwise,
    which is the two-column target later fed to ``BCEWithLogitsLoss``.

    Args:
        dataset: mapping/DataFrame with positionally indexable ``"text"`` and
            ``"labels"`` (iterable of int label indices) columns.
        defs: label name -> definition text.
        lab2def: label index -> label name (a key of ``defs``).
        ent_con: non-empty list of connector phrases. One is picked uniformly
            at random per clause.
        remove_unseen: if True, clauses with no gold label are skipped.
        nlp: optional spaCy-like pipeline. Defaults to ``en_core_web_sm``.

    Returns:
        ``(new_dataset, max_len)``. ``new_dataset`` has parallel lists
        ``"text"``, ``"labels"`` and ``"str_labels"``. ``max_len`` is the largest
        pipeline token count over all generated pairs.
    """
    if nlp is None:
        nlp = spacy.load("en_core_web_sm")

    new_dataset = {"text": [], "labels": [], "str_labels": []}
    max_len = 0
    # Use the caller-supplied mapping, not the notebook globals, and don't
    # assume exactly 8 labels.
    label_ids = sorted(lab2def)

    for i in range(len(dataset["text"])):
        gold = dataset["labels"][i]
        if remove_unseen and len(gold) == 0:
            continue

        doc = nlp(dataset["text"][i])
        # Drop punctuation tokens and re-join the sentences.
        # NOTE(review): the trailing [:-1] removes the last character of the
        # cleaned clause. Since punctuation is already gone, that is usually a
        # letter. It is kept because the other pair builders and
        # test_accuracies apply the same normalisation.
        premise = " ".join(
            " ".join(tok.text for tok in sent if not tok.is_punct) for sent in doc.sents
        )[:-1]
        # Uniform choice. randint(0, len) - 1 picked the last connector twice as often.
        connector = random.choice(ent_con)

        for j in label_ids:
            pair = premise + " " + connector + " " + defs[lab2def[j]]
            new_dataset["text"].append(pair)
            new_dataset["str_labels"].append(str(gold))
            new_dataset["labels"].append([0, 1] if j in gold else [1, 0])
            max_len = max(max_len, len(nlp(pair)))

    return new_dataset, max_len

# Build entailment pairs for each split. Only the training split drops
# clauses that carry no gold label.
train_dataset_ent, train_max = convert_to_entailment(
    train_dataset, definitions, label_to_def, entail_con, remove_unseen=True
)
val_dataset_ent, val_max = convert_to_entailment(
    val_dataset, definitions, label_to_def, entail_con, remove_unseen=False
)
test_dataset_ent, test_max = convert_to_entailment(
    test_dataset, definitions, label_to_def, entail_con, remove_unseen=False
)

print(train_max, val_max, test_max)

In [None]:
class CustomDataset(Dataset):
    """Torch dataset that tokenizes entailment pairs on the fly.

    Args:
        dataset: mapping with parallel ``"text"`` and ``"labels"`` sequences.
            Each label is a numeric list (here a 2-element one-hot).
        num_classes: stored for reference; ``__getitem__`` does not use it.
        tokenizer: Hugging Face tokenizer exposing ``encode_plus``.
        max_length: token budget per pair (default 128, the previous hard-coded
            value). Pairs are ``premise + connector + definition``, so with
            right-side truncation a long pair loses its definition first.
            Raise this if pairs run long.
    """

    def __init__(self, dataset, num_classes, tokenizer, max_length=128):
        self.dataset = dataset
        self.texts = self.dataset["text"]
        self.labels = self.dataset["labels"]
        self.num_classes = num_classes
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        encoded = self.tokenizer.encode_plus(
            self.texts[index],
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of 1. Drop only that
        # dimension so the DataLoader can stack items into (batch, max_length).
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "multi_label": torch.tensor(self.labels[index], dtype=torch.float),
        }
    

def list_it(curr):
    """Wrap a single value in a one-element list."""
    return list((curr,))

def delist_it(curr):
    """Return a new list holding ``item[0][0]`` for every item of ``curr``."""
    return [item[0][0] for item in curr]

def str2list(curr):
    """Parse the ``str()`` of a list of ints (e.g. ``"[1, 5]"``) back into a list."""
    if curr != "[]":
        return list(map(int, curr[1:-1].split(",")))
    return []

batch_size = 16
num_classes = 2  # binary entailment head: targets are [1, 0] (not entailed) / [0, 1] (entailed)

# Load the tokenizer that belongs to the encoder used in BERTClassifier
# ("nlpaueb/legal-bert-base-uncased"). Per its model card, LEGAL-BERT comes with
# its own vocabulary, so bert-base-uncased token ids would index the wrong
# embeddings.
tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")

train_custom = CustomDataset(train_dataset_ent, num_classes, tokenizer)
train_dataloader = DataLoader(train_custom, batch_size=batch_size, shuffle=True)
valid_custom = CustomDataset(val_dataset_ent, num_classes, tokenizer)
# Validation metrics are averaged per batch. A fixed order keeps successive
# validation runs comparable.
val_dataloader = DataLoader(valid_custom, batch_size=batch_size, shuffle=False)

In [None]:
# Class balance of the training entailment pairs. The labels are 2-element
# lists ([0, 1] = entailed, [1, 0] = not entailed). Lists are unhashable, so
# value_counts() on them raises TypeError; map them to readable names first.
pair_labels = pd.DataFrame.from_dict(train_dataset_ent)["labels"].map(
    lambda one_hot: "entailed" if one_hot[1] == 1 else "not entailed"
)
ax = pair_labels.value_counts().plot.pie(autopct="%1.1f%%")
ax.set_ylabel("")
ax.set_title("Train entailment-pair label balance");

In [None]:
class BERTClassifier(nn.Module):
    """Pre-trained encoder -> dropout -> linear head on the pooled output.

    Args:
        num_classes: width of the output layer (2 for the entailment setup).
        model_name: Hugging Face model id of the encoder. The default is the
            previously hard-coded LEGAL-BERT checkpoint.
        dropout: dropout probability applied to the pooled output.
    """

    def __init__(self, num_classes, model_name="nlpaueb/legal-bert-base-uncased", dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Take the head's input width from the encoder config instead of hard-coding
        # 768, so checkpoints with other hidden sizes work unchanged. Parameter
        # names (bert.*, fc.*) are unchanged, so saved state_dicts still load.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw logits. BCEWithLogitsLoss applies the sigmoid during training."""
        # With return_dict=False the encoder returns a tuple; element 1 is the pooled output.
        _, pooled_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=False
        )
        return self.fc(self.dropout(pooled_output))

learning_rate = 3e-5
# Targets are 2-column float one-hots, so a per-column sigmoid + BCE loss is used.
loss_function = nn.BCEWithLogitsLoss()
base_model = BERTClassifier(num_classes)
optimizer = optim.Adam(base_model.parameters(), lr=learning_rate)
print(base_model)

In [None]:
# Use the GPU when one is present. The train/eval code moves every batch to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
base_model.to(device)


def train(base_model, train_dataloader, optimizer, loss_function,
          val_loader=None, num_epochs=20, valid_interval=20,
          save_dir="model_trained_ent", stop_criterion=2000000):
    """Fine-tune ``base_model`` and save a checkpoint whenever validation improves.

    Runs ``num_epochs`` passes over ``train_dataloader``. Every
    ``valid_interval`` iterations, starting at iteration 0, the model is
    evaluated on the validation loader. If the mean per-batch macro-F1 (the
    second value returned by ``find_metrics``) beats the best so far, the
    ``state_dict`` is written to ``{save_dir}/model_{iteration}.pth``.

    Args:
        base_model: model mapping (input_ids, attention_mask) to logits shaped
            like the targets.
        train_dataloader: yields dicts with ``input_ids``, ``attention_mask``
            and ``multi_label`` tensors.
        optimizer, loss_function: used for each training step.
        val_loader: validation loader. Defaults to the notebook-level
            ``val_dataloader``, which is what the original code used.
        num_epochs, valid_interval, save_dir, stop_criterion: previously
            hard-coded. The defaults reproduce the original values.

    Returns:
        ``(base_model, train_dataloader, optimizer, loss_function)``.

    NOTE(review): ``find_metrics`` is called below but is not defined in this
    notebook (only ``find_metrics1`` is, in a later cell). It has to exist in
    the kernel, presumably returning ``(f1_micro, f1_macro, accuracy)`` for
    (targets, logits).
    """
    import os

    if val_loader is None:
        val_loader = val_dataloader  # notebook global from the DataLoader cell
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create directories

    iteration = 0
    max_f1 = 0
    running_loss = []
    for epoch in range(num_epochs):
        base_model.train()
        for curr_batch in train_dataloader:
            if iteration > stop_criterion:
                break

            input_ids = curr_batch['input_ids'].to(device)
            attention_mask = curr_batch['attention_mask'].to(device)
            targets = curr_batch['multi_label'].to(device)

            outputs = base_model(input_ids, attention_mask)  # already on `device`
            loss = loss_function(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean over the last 20 training losses.
            running_loss.append(loss.item())
            if len(running_loss) > 20:
                running_loss.pop(0)
            print(f"Epoch : {epoch} ,Iteration : {iteration}, training loss: {loss:.4f} , running loss:{sum(running_loss)/len(running_loss)}", find_metrics(targets, outputs))

            # Free this step's graph before (possibly) running validation.
            del loss, outputs
            gc.collect()
            torch.cuda.empty_cache()

            if iteration % valid_interval == 0:
                base_model.eval()
                with torch.no_grad():
                    total_loss, f1_micro, f1_macro, f1_avg = [], [], [], []
                    for val_batch in val_loader:
                        val_input_ids = val_batch['input_ids'].to(device)
                        val_attention_mask = val_batch['attention_mask'].to(device)
                        val_targets = val_batch['multi_label'].to(device)

                        outputs = base_model(val_input_ids, val_attention_mask)
                        loss = loss_function(outputs, val_targets)

                        total_loss.append(loss.item())
                        val_out = find_metrics(val_targets, outputs)
                        f1_micro.append(val_out[0])
                        f1_macro.append(val_out[1])
                        f1_avg.append(val_out[2])

                        del val_out, loss, outputs
                        gc.collect()
                        torch.cuda.empty_cache()

                    avg_acc = sum(f1_avg) / len(f1_avg)
                    avg_f1mic = sum(f1_micro) / len(f1_micro)
                    avg_f1mac = sum(f1_macro) / len(f1_macro)
                    avg_loss = sum(total_loss) / len(total_loss)

                    print(f"Validation loss : {avg_loss} ", ' ,acc : ', avg_acc, " ,f1-micro : ", avg_f1mic, " ,f1-macro : ", avg_f1mac)
                    # Compare against and record the SAME metric. The original
                    # compared macro-F1 but stored micro-F1 as the best score.
                    if avg_f1mac > max_f1:
                        max_f1 = avg_f1mac
                        torch.save(base_model.state_dict(), os.path.join(save_dir, f"model_{iteration}.pth"))

                    del total_loss, f1_micro, f1_macro, f1_avg

                base_model.train()

            iteration += 1
    return base_model, train_dataloader, optimizer, loss_function

# Fine-tune the entailment model on all training pairs.
base_model, train_dataloader, optimizer, loss_function = train(
    base_model, train_dataloader, optimizer, loss_function
)

In [None]:
# Collect unreferenced Python objects, then hand cached-but-unused CUDA blocks
# back to the driver before evaluation.
import torch  # redundant with the top-level import; harmless

for _free_memory in (gc.collect, torch.cuda.empty_cache):
    _free_memory()

In [None]:
def find_metrics1(targets, prediction):
    """Score hard 0/1 predictions against 0/1 targets.

    Both arguments are torch tensors with matching shapes. ``prediction`` must
    already be thresholded; no sigmoid is applied here.

    Returns:
        ``(f1_micro, f1_macro, accuracy)``: micro-F1 over the flattened arrays
        (zero_division=0), macro-F1 on the arrays as given (zero_division=1),
        and sklearn ``accuracy_score`` on the arrays as given.
    """
    y_true, y_pred = (t.cpu().detach().numpy() for t in (targets, prediction))
    micro = f1_score(y_true.flatten(), y_pred.flatten(), average='micro', zero_division=0)
    macro = f1_score(y_true, y_pred, average='macro', zero_division=1)
    acc = accuracy_score(y_true, y_pred)
    return micro, macro, acc

def test_accuracies(base_model, test_data, tokenizer_private):
    """Evaluate the entailment model as a multi-label classifier on raw clauses.

    For every clause, one ``"<clause> <connector> <definition>"`` pair is scored
    per label, using the same template as the training pairs. A label is
    predicted when the "entailed" logit (column 1) beats the "not entailed"
    logit (column 0).

    Prints per-example predictions, targets and ``find_metrics1`` scores, a
    table of how often each label index was mispredicted, and the mean of the
    per-example scores.

    Args:
        base_model: classifier returning 2-column logits.
        test_data: DataFrame/mapping with ``"text"`` and ``"labels"`` (list of
            int label indices) columns.
        tokenizer_private: tokenizer exposing ``encode_plus``.

    Returns:
        ``(global_pred, global_tar)``: arrays of shape (n_examples, n_labels).
    """
    nlp = spacy.load("en_core_web_sm")
    num_labels = len(label_to_def)
    # Training draws a connector from entail_con per clause. Use the first one
    # here so inference sees the same template (identical when there is only one).
    connector = entail_con[0]

    total_acc = []
    global_pred = []
    global_tar = []

    was_training = base_model.training
    base_model.eval()  # disable dropout so predictions are deterministic
    try:
        with torch.no_grad():  # inference only: no autograd graph, no memory growth
            for i in range(len(test_data["text"])):
                doc = nlp(test_data["text"][i])
                # Same normalisation as the training pairs: drop punctuation, then [:-1].
                # (The old code chopped a second character and omitted the
                # connector, so test inputs did not match training inputs.)
                premise = " ".join(
                    " ".join(tok.text for tok in sent if not tok.is_punct) for sent in doc.sents
                )[:-1]

                prediction = []
                for j in range(num_labels):
                    hypothesis = premise + " " + connector + " " + definitions[label_to_def[j]]
                    inputs = tokenizer_private.encode_plus(
                        hypothesis,
                        add_special_tokens=True,
                        max_length=128,
                        padding='max_length',
                        truncation=True,
                        return_tensors='pt',
                    )
                    logits = base_model(inputs["input_ids"].to(device), inputs["attention_mask"].to(device))
                    # sigmoid is monotonic, so comparing logits == comparing probabilities.
                    prediction.append(1 if logits[0][1] > logits[0][0] else 0)

                gold = set(test_data["labels"][i])
                targets = torch.tensor([1.0 if j in gold else 0.0 for j in range(num_labels)])
                metrics = find_metrics1(targets, torch.tensor(prediction))

                global_pred.append(prediction)
                global_tar.append(targets.tolist())
                total_acc.append(list(metrics))
                print("prediction : ", prediction)
                print("targets : ", targets)
                print("metrics : ", metrics)
    finally:
        base_model.train(was_training)  # restore whatever mode the caller had

    global_pred = np.array(global_pred).reshape(-1, num_labels)
    global_tar = np.array(global_tar).reshape(-1, num_labels)

    # The mismatch mask is (example, label). Summing over axis 0 counts errors
    # per label. The old zip over np.where unpacked the example index as the label.
    errors_per_label = (global_pred != global_tar).sum(axis=0)
    wrong_labels = [int(j) for j in np.flatnonzero(errors_per_label)]
    df = pd.DataFrame({
        "Incorrect Label": wrong_labels,
        "Frequency": [int(errors_per_label[j]) for j in wrong_labels],
    })
    print("incorrect frequencies table: \n", df)
    if total_acc:
        print("test statistics : ", np.mean(np.array(total_acc), axis=0))
    return global_pred, global_tar
# Evaluate the fine-tuned model on the raw (non-expanded) test split.
test_all_pred, test_all_tar = test_accuracies(
    base_model=base_model, test_data=test_dataset, tokenizer_private=tokenizer
)

In [None]:
def infreq(test_all_pred, test_all_tar):
    """Print each prediction row followed by the matching target row."""
    for row_idx, pred_row in enumerate(test_all_pred):
        print(pred_row)
        print(test_all_tar[row_idx])
infreq(test_all_pred=test_all_pred, test_all_tar=test_all_tar)

In [None]:
# Label-index splits for the held-out-label experiments. Presumably "train_*"
# are the labels treated as seen and "test_*" the held-out ones.
config = dict(
    train_1_lab=[0, 1, 4, 5, 6, 7],
    test_1_lab=[2, 3],
    train_2_lab=[0, 1, 4, 5],
    test_2_lab=[2, 3, 6, 7],
)

In [None]:
# Containers filled by the split-building helpers below.
train_new_show = {}
test_new_show = {}
val_new_show = {}

def compare_list(list1, list2):
    """Return True if any element of ``list1`` also appears in ``list2``."""
    return any(item in list2 for item in list1)
    

def convert_to_hidden(train_data, config, train_new_show, val_new_show, nlp=None):
    """Split a labelled split into seen-label and unseen-label entailment pairs.

    There are two settings: "1_lab" uses ``config["train_1_lab"]`` and "2_lab"
    uses ``config["train_2_lab"]``. Every clause of ``train_data`` with at least
    one gold label becomes one entailment pair per label definition. If the
    clause shares a label with the setting's training labels, its pairs go to
    ``train_new_show[setting]``; otherwise they go to ``val_new_show[setting]``.

    Pairs are *appended*, so calling this for the train split and then for the
    validation split accumulates both. Previously the dicts were re-initialised
    on every call, the second call discarded the first, the rows were read from
    the global ``train_dataset`` instead of ``train_data``, and "2_lab" routed
    with ``train_1_lab``.

    NOTE(review): a clause labelled [seen, unseen] goes to the train side and
    also yields a positive pair for its unseen label. Confirm whether that is
    intended for the held-out-label setup.

    Returns:
        ``(train_new_show, val_new_show)``: the same dict objects, updated in place.
    """
    if nlp is None:
        nlp = spacy.load("en_core_web_sm")
    num_labels = len(label_to_def)

    def clean_premise(text):
        # Drop punctuation tokens and re-join the sentences.
        # NOTE(review): the trailing [:-1] strips the last character (usually a
        # letter). It is kept so these pairs match the other builders and test_accuracies.
        doc = nlp(text)
        return " ".join(
            " ".join(tok.text for tok in sent if not tok.is_punct) for sent in doc.sents
        )[:-1]

    def append_pairs(bucket, premise, gold):
        # One pair per label definition: [0, 1] = entailed, [1, 0] = not entailed.
        connector = random.choice(entail_con)
        for j in range(num_labels):
            bucket["text"].append(premise + " " + connector + " " + definitions[label_to_def[j]])
            bucket["labels"].append([0, 1] if j in gold else [1, 0])

    # (setting name, labels treated as seen in that setting)
    settings = (("1_lab", config["train_1_lab"]), ("2_lab", config["train_2_lab"]))
    for name, _ in settings:
        # setdefault rather than reassignment, so repeated calls accumulate.
        train_new_show.setdefault(name, {"text": [], "labels": []})
        val_new_show.setdefault(name, {"text": [], "labels": []})

    for i in range(len(train_data["text"])):
        gold = train_data["labels"][i]
        if len(gold) == 0:
            continue  # clauses without any gold label are skipped
        premise = clean_premise(train_data["text"][i])
        for name, seen_labels in settings:
            bucket = train_new_show[name] if compare_list(gold, seen_labels) else val_new_show[name]
            append_pairs(bucket, premise, gold)

    return train_new_show, val_new_show

# Run the split builder on the train split, then on the validation split.
for source_split in (train_dataset, val_dataset):
    train_new_show, val_new_show = convert_to_hidden(source_split, config, train_new_show, val_new_show)

test_new_show = {
    "1_lab_seen": {"text": [], "labels": []},
    "2_lab_unseen": {"text": [], "labels": []},
}
def calculate_train_hidden_met(test_dataset, config, test_new_show, nlp=None):
    """Expand a test split into entailment pairs, bucketed by label visibility.

    Every clause with at least one gold label becomes one pair per label
    definition (``[0, 1]`` = entailed, ``[1, 0]`` = not). If the clause shares a
    label with ``config["train_1_lab"]``, its pairs go to
    ``test_new_show["1_lab_seen"]``; otherwise they go to
    ``test_new_show["2_lab_unseen"]``. Missing buckets are created.

    NOTE(review): the "unseen" bucket is defined relative to ``train_1_lab``,
    not ``train_2_lab``, despite its "2_lab" prefix. Confirm this is intended.

    Args:
        test_dataset: DataFrame/mapping with ``"text"`` and ``"labels"`` columns.
        config: dict holding ``"train_1_lab"``.
        test_new_show: dict the pairs are appended to.
        nlp: optional spaCy-like pipeline. Defaults to ``en_core_web_sm``.

    Returns:
        ``test_new_show``, updated in place.
    """
    if nlp is None:
        nlp = spacy.load("en_core_web_sm")
    num_labels = len(label_to_def)
    seen_labels = config["train_1_lab"]
    for key in ("1_lab_seen", "2_lab_unseen"):
        test_new_show.setdefault(key, {"text": [], "labels": []})

    for i in range(len(test_dataset["text"])):
        gold = test_dataset["labels"][i]
        if len(gold) == 0:
            continue  # clauses without any gold label are skipped

        doc = nlp(test_dataset["text"][i])
        # Same normalisation as the training pairs (punctuation dropped, then [:-1]).
        premise = " ".join(
            " ".join(tok.text for tok in sent if not tok.is_punct) for sent in doc.sents
        )[:-1]
        bucket = test_new_show["1_lab_seen" if compare_list(gold, seen_labels) else "2_lab_unseen"]
        # Uniform choice. randint(0, len) - 1 over-weighted the last connector.
        connector = random.choice(entail_con)
        for j in range(num_labels):
            bucket["text"].append(premise + " " + connector + " " + definitions[label_to_def[j]])
            bucket["labels"].append([0, 1] if j in gold else [1, 0])

    return test_new_show

test_new_show = calculate_train_hidden_met(test_dataset, config, test_new_show=test_new_show)

In [None]:
# train 1 lab
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model1 = BERTClassifier(2)
base_model1.to(device)
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(base_model1.parameters(), lr=learning_rate)

class CustomDataset(Dataset):
    """Torch dataset that tokenizes entailment pairs on the fly.

    Re-declared here so this training cell stays self-contained. Keep it in
    sync with the earlier definition of the same name, which this one shadows.

    Args:
        dataset: mapping with parallel ``"text"`` and ``"labels"`` sequences.
            Each label is a numeric list (here a 2-element one-hot).
        num_classes: stored for reference; ``__getitem__`` does not use it.
        tokenizer: Hugging Face tokenizer exposing ``encode_plus``.
        max_length: token budget per pair (default 128, the previous hard-coded
            value). With right-side truncation a long pair loses its
            definition (the end of the pair) first.
    """

    def __init__(self, dataset, num_classes, tokenizer, max_length=128):
        self.dataset = dataset
        self.texts = self.dataset["text"]
        self.labels = self.dataset["labels"]
        self.num_classes = num_classes
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        encoded = self.tokenizer.encode_plus(
            self.texts[index],
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension added by return_tensors="pt".
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "multi_label": torch.tensor(self.labels[index], dtype=torch.float),
        }
    
def list_it(curr):
    """Wrap a single value in a one-element list."""
    return list((curr,))

def delist_it(curr):
    """Return a new list holding ``item[0][0]`` for every item of ``curr``."""
    return [item[0][0] for item in curr]

def str2list(curr):
    """Parse the ``str()`` of a list of ints (e.g. ``"[1, 5]"``) back into a list."""
    if curr != "[]":
        return list(map(int, curr[1:-1].split(",")))
    return []

batch_size = 16
num_classes = 2  # binary entailment head: targets are [1, 0] / [0, 1]

# Load the tokenizer that belongs to the encoder used in BERTClassifier
# ("nlpaueb/legal-bert-base-uncased"). Per its model card, LEGAL-BERT has its
# own vocabulary, so bert-base-uncased ids would hit the wrong embeddings.
tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")

train_custom1 = CustomDataset(train_new_show["1_lab"], num_classes, tokenizer)
train_dataloader1 = DataLoader(train_custom1, batch_size=batch_size, shuffle=True)
valid_custom1 = CustomDataset(val_new_show["1_lab"], num_classes, tokenizer)
# Validation metrics are averaged per batch. A fixed order keeps successive
# validation runs comparable.
val_dataloader1 = DataLoader(valid_custom1, batch_size=batch_size, shuffle=False)

def train(base_model, train_dataloader, val_dataloader, optimizer, loss_function,
          num_epochs=20, valid_interval=20, save_dir="models_1_lab", stop_criterion=2000000):
    """Fine-tune ``base_model`` and save a checkpoint whenever validation micro-F1 improves.

    This redefinition shadows the earlier ``train`` and takes the validation
    loader explicitly. Every ``valid_interval`` iterations, starting at
    iteration 0, the model is evaluated on ``val_dataloader``. If the mean
    per-batch micro-F1 (first value returned by ``find_metrics``) beats the best
    so far, the ``state_dict`` is written to ``{save_dir}/model_{iteration}.pth``.

    Args:
        base_model: model mapping (input_ids, attention_mask) to logits.
        train_dataloader, val_dataloader: yield dicts with ``input_ids``,
            ``attention_mask`` and ``multi_label`` tensors.
        optimizer, loss_function: used for each training step.
        num_epochs, valid_interval, save_dir, stop_criterion: previously
            hard-coded. The defaults reproduce the original values.

    Returns:
        ``(base_model, train_dataloader, optimizer, loss_function)``.

    NOTE(review): ``find_metrics`` is not defined anywhere in this notebook and
    must exist in the kernel, presumably returning ``(f1_micro, f1_macro, accuracy)``.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create directories

    iteration = 0
    max_f1 = 0
    running_loss = []
    for epoch in range(num_epochs):
        base_model.train()
        for curr_batch in train_dataloader:
            if iteration > stop_criterion:
                break

            input_ids = curr_batch['input_ids'].to(device)
            attention_mask = curr_batch['attention_mask'].to(device)
            targets = curr_batch['multi_label'].to(device)

            outputs = base_model(input_ids, attention_mask)  # already on `device`
            loss = loss_function(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean over the last 20 training losses.
            running_loss.append(loss.item())
            if len(running_loss) > 20:
                running_loss.pop(0)
            print(f"Epoch : {epoch} ,Iteration : {iteration}, training loss: {loss:.4f} , running loss:{sum(running_loss)/len(running_loss)}", find_metrics(targets, outputs))

            del loss, outputs
            gc.collect()
            torch.cuda.empty_cache()

            if iteration % valid_interval == 0:
                base_model.eval()
                with torch.no_grad():
                    total_loss, f1_micro, f1_macro, f1_avg = [], [], [], []
                    for val_batch in val_dataloader:
                        val_input_ids = val_batch['input_ids'].to(device)
                        val_attention_mask = val_batch['attention_mask'].to(device)
                        val_targets = val_batch['multi_label'].to(device)

                        outputs = base_model(val_input_ids, val_attention_mask)
                        loss = loss_function(outputs, val_targets)

                        total_loss.append(loss.item())
                        val_out = find_metrics(val_targets, outputs)
                        f1_micro.append(val_out[0])
                        f1_macro.append(val_out[1])
                        f1_avg.append(val_out[2])

                        del val_out, loss, outputs
                        gc.collect()
                        torch.cuda.empty_cache()

                    avg_acc = sum(f1_avg) / len(f1_avg)
                    avg_f1mic = sum(f1_micro) / len(f1_micro)
                    avg_f1mac = sum(f1_macro) / len(f1_macro)
                    avg_loss = sum(total_loss) / len(total_loss)

                    print(f"Validation loss : {avg_loss} ", ' ,acc : ', avg_acc, " ,f1-micro : ", avg_f1mic, " ,f1-macro : ", avg_f1mac)
                    if avg_f1mic > max_f1:
                        max_f1 = avg_f1mic
                        torch.save(base_model.state_dict(), os.path.join(save_dir, f"model_{iteration}.pth"))

                    del total_loss, f1_micro, f1_macro, f1_avg

                base_model.train()

            iteration += 1
    return base_model, train_dataloader, optimizer, loss_function

# Train the "1_lab" model. Checkpoint selection uses the unseen-label pairs.
base_model1, train_dataloader1, optimizer, loss_function = train(
    base_model1, train_dataloader1, val_dataloader1, optimizer, loss_function
)