<a href="https://colab.research.google.com/github/adisav17/Deep-Semantic-Role-Labeling-with-Auxilary-tasks/blob/main/test_ensemble_srl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0 (from transformers)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m82.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1


In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.metrics import precision_recall_curve, f1_score, accuracy_score
import numpy as np
import math
import random
import itertools
import pandas as pd 
import csv


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:

def map_pos_tags(tag):
    pos_groups = {
        "NN": ["NNS", "NNP", "NNPS"],
        "VB": ["VBD", "VBG", "VBN", "VBP", "VBZ"],
        "CC": ["CC"],
        "DT": ["DT"],
        "JJ": ["JJ", "JJR", "JJS"],
        "IN": ["IN"],
        "PRP": ["PRP", "PRP$"],

        #'other', 'JJ 5', 'IN 6', 'PRP ', 'DT 4', 'NN 1', 'VB 2', 'CC 3'
    }
    
    for key, value in pos_groups.items():
        if tag in value:
            return key
    return "other"

In [5]:
def load_and_preprocess_data(file_path):
    sentences = []
    predicate_indices = []
    labels = []
    pos_tags = []
    bio_tags = []
    directed_distances = []
    skipped_count = 0
    taken_count = 0 

    with open(file_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        current_sentence = []
        current_predicate_index = None
        current_labels = []
        current_pos_tags = []
        current_bio_tags = []
        current_directed_distances = []
        has_arg1 = False

        for row in reader:
            if row:
                #print(row)
                word, pos, bio, word_idx, sentence_idx = row[:5]
                pos = map_pos_tags(pos)

                if len(row)>5:
                    if row[5] == "PRED":
                           current_predicate_index = int(word_idx)

                current_sentence.append(word)
                current_pos_tags.append(pos)
                current_bio_tags.append(bio)
                label = 1 if "ARG1" in row else 0
                current_labels.append(label)
        
                if current_predicate_index is not None:
                   #print("here in pred idx not none")
                   current_directed_distances.append(int(word_idx) - current_predicate_index)

                if "ARG1" in row:
                    has_arg1 = True
                    #print("here in has arg1", has_arg1)
            else:
                if current_sentence and has_arg1 and current_predicate_index is not None:
                   # print("here in else of if row") 
                    sentences.append(" ".join(current_sentence))
                    predicate_indices.append(current_predicate_index)
                    labels.append(current_labels)
                    pos_tags.append(current_pos_tags)
                    bio_tags.append(current_bio_tags)
                    for i in range(current_predicate_index ,0, -1):
                      current_directed_distances.insert(0,i - current_predicate_index - 1)

                    directed_distances.append(current_directed_distances)
                  
                    taken_count+=1
                else:
                    #print("Skipped sentence:", current_sentence)
                    skipped_count+=1
                    #break

                current_sentence = []
                current_predicate_index = None
                current_labels = []
                current_pos_tags = []
                current_bio_tags = []
                current_directed_distances = []
                has_arg1 = False

    return sentences, predicate_indices, labels, pos_tags, bio_tags, directed_distances, skipped_count, taken_count


In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
%cd /content/drive/My Drive/nlp_srl

/content/drive/My Drive/nlp_srl


In [8]:
def load_and_preprocess_data_multiple_files(file_paths):
    sentences = []
    predicate_indices = []
    labels = []
    pos_tags = []
    bio_tags = []
    directed_distances = []
    skipped_counts = []
    taken_counts = []

    for file_path in file_paths:
        result = load_and_preprocess_data(file_path)
        sentences.extend(result[0])
        predicate_indices.extend(result[1])
        labels.extend(result[2])
        pos_tags.extend(result[3])
        bio_tags.extend(result[4])
        directed_distances.extend(result[5])
        skipped_counts.append(result[6])
        taken_counts.append(result[7])

    return sentences, predicate_indices, labels, pos_tags, bio_tags, directed_distances, skipped_counts, taken_counts


In [9]:
file_paths = ['partitive_group_nombank.clean.train', 'partitive_group_nombank.clean.test', 'partitive_group_nombank.clean.dev']

In [10]:
def tokenize_loaded_data(sentences, predicate_indices, labels, pos_tags, bio_tags, directed_distances):
    # Create unique dictionaries for pos_tags and bio_tags
    pos_tag_dict = {tag: idx  for idx, tag in enumerate(set(tag for tags in pos_tags for tag in tags))}
    bio_tag_dict = {tag: idx  for idx, tag in enumerate(set(tag for tags in bio_tags for tag in tags))}

    # Convert pos_tags and bio_tags to integers using the dictionaries
    pos_tags = [[pos_tag_dict[tag] for tag in tags] for tags in pos_tags]
    bio_tags = [[bio_tag_dict[tag] for tag in tags] for tags in bio_tags]

    return sentences, predicate_indices, labels, pos_tags, bio_tags, directed_distances


In [11]:
def find_num_tags(tags):
  
  tags_flat = [tag for tag_list in tags for tag in tag_list]
  #print(tags_flat[2])
  #print(len(tags_flat))
  print(set(tags_flat))
  return len(set(tags_flat)) + 1

In [12]:
def pad_and_align_tags(sentence, labels, pos_tags, bio_tags, directed_distances, max_length, tokenizer):
    tokenized_sentence = tokenizer.tokenize(sentence)
    words = sentence.split()

    def align_tags(tags):
        aligned_tags = [-100] * len(tokenized_sentence)
        for word, tag, idx in zip(words, tags, range(len(words))):
            subwords = tokenizer.tokenize(word)
            subword_idx = tokenized_sentence.index(subwords[0], idx)
            aligned_tags[subword_idx] = tag
        return aligned_tags

    aligned_labels = align_tags(labels)
    aligned_pos_tags = align_tags(pos_tags)
    aligned_bio_tags = align_tags(bio_tags)
    aligned_directed_distances = align_tags(directed_distances)

    def pad_tags(tags):
        padded_tags = [-100] + tags[:max_length - 2] + [-100]
        padded_tags = padded_tags + [-100] * (max_length - len(padded_tags))
        return padded_tags

    padded_labels = pad_tags(aligned_labels)
    padded_pos_tags = pad_tags(aligned_pos_tags)
    padded_bio_tags = pad_tags(aligned_bio_tags)
    padded_directed_distances = pad_tags(aligned_directed_distances)

    return padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances

In [13]:
import torch
from torch.nn.utils.rnn import pad_sequence

class SRLfeatDataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, tokenizer, max_length):
        self.sentences, self.predicate_indices, self.labels, self.pos_tags, self.bio_tags, self.directed_distances, _, _ = load_and_preprocess_data_multiple_files(file_paths)
        self.sentences, self.predicate_indices, self.labels, self.pos_tags, self.bio_tags, self.directed_distances = tokenize_loaded_data(self.sentences, self.predicate_indices, self.labels, self.pos_tags, self.bio_tags, self.directed_distances)
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.padded_input_ids = []
        self.padded_labels = []
        self.padded_pos_tags = []
        self.padded_bio_tags = []
        self.padded_directed_distances = []

        for sentence, labels, pos_tags, bio_tags, directed_distances in zip(self.sentences, self.labels, self.pos_tags, self.bio_tags, self.directed_distances):
            input_ids, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = self.process_sentence(sentence, labels, pos_tags, bio_tags, directed_distances)
            self.padded_input_ids.append(input_ids)
            self.padded_labels.append(padded_labels)
            self.padded_pos_tags.append(padded_pos_tags)
            self.padded_bio_tags.append(padded_bio_tags)
            self.padded_directed_distances.append(padded_directed_distances)

    def process_sentence(self, sentence, labels, pos_tags, bio_tags, directed_distances):
        tokenized_sentence = self.tokenizer.tokenize(sentence)
        encoded_sentence = self.tokenizer.encode(sentence, add_special_tokens=True, max_length=self.max_length, padding='max_length', truncation=True)
        input_ids = torch.tensor(encoded_sentence, dtype=torch.long)

        padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = pad_and_align_tags(sentence, labels, pos_tags, bio_tags, directed_distances, self.max_length, self.tokenizer)

        padded_labels = torch.tensor(padded_labels, dtype=torch.long)
        padded_pos_tags = torch.tensor(padded_pos_tags, dtype=torch.long)
        padded_bio_tags = torch.tensor(padded_bio_tags, dtype=torch.long)
        padded_directed_distances = torch.tensor(padded_directed_distances, dtype=torch.float).unsqueeze(-1)

        return input_ids, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances

    def __getitem__(self, index):
        input_ids = self.padded_input_ids[index]
        predicate_idx = self.predicate_indices[index]
        padded_labels = self.padded_labels[index]
        padded_pos_tags = self.padded_pos_tags[index]
        padded_bio_tags = self.padded_bio_tags[index]
        padded_directed_distances = self.padded_directed_distances[index]

        return input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances

    def __len__(self):
        return len(self.sentences)


In [14]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [15]:
max_length = 128

In [16]:
srl_feat_dataset = SRLfeatDataset(file_paths, tokenizer, max_length)

In [17]:
from torch.utils.data import random_split

# Calculate the number of samples for the train and validation sets
total_samples = len(srl_feat_dataset)
train_samples = int(total_samples * 0.8)
val_samples = total_samples - train_samples

# Split the dataset into train and validation sets
train_dataset, val_dataset = random_split(srl_feat_dataset, [train_samples, val_samples])


In [18]:
batch_size = 4  # Choose a batch size according to your needs

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [19]:
class SRLauxDataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, tokenizer, max_length):
        self.sentences, self.predicate_indices, self.labels, self.pos_tags, self.bio_tags, self.directed_distances, _, _ = load_and_preprocess_data_multiple_files(file_paths)
        self.sentences, self.predicate_indices, self.labels, self.pos_tags, self.bio_tags, self.directed_distances = tokenize_loaded_data(self.sentences, self.predicate_indices, self.labels, self.pos_tags, self.bio_tags, self.directed_distances)
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.padded_input_ids = []
        self.padded_labels = []
        self.padded_pos_tags = []
        self.padded_bio_tags = []
        self.padded_directed_distances = []

        for sentence, labels, pos_tags, bio_tags, directed_distances in zip(self.sentences, self.labels, self.pos_tags, self.bio_tags, self.directed_distances):
            input_ids, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = self.process_sentence(sentence, labels, pos_tags, bio_tags, directed_distances)
            self.padded_input_ids.append(input_ids)
            self.padded_labels.append(padded_labels)
            self.padded_pos_tags.append(padded_pos_tags)
            self.padded_bio_tags.append(padded_bio_tags)
            self.padded_directed_distances.append(padded_directed_distances)

    def process_sentence(self, sentence, labels, pos_tags, bio_tags, directed_distances):
        tokenized_sentence = self.tokenizer.tokenize(sentence)
        encoded_sentence = self.tokenizer.encode(sentence, add_special_tokens=True, max_length=self.max_length, padding='max_length', truncation=True)
        input_ids = torch.tensor(encoded_sentence, dtype=torch.long)

        padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = pad_and_align_tags(sentence, labels, pos_tags, bio_tags, directed_distances, self.max_length, self.tokenizer)

        padded_labels = torch.tensor(padded_labels, dtype=torch.long)
        padded_pos_tags = torch.tensor(padded_pos_tags, dtype=torch.long)
        padded_bio_tags = torch.tensor(padded_bio_tags, dtype=torch.long)
        padded_directed_distances = torch.tensor(padded_directed_distances, dtype=torch.long)

        return input_ids, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances

    def __getitem__(self, index):
        input_ids = self.padded_input_ids[index]
        predicate_idx = self.predicate_indices[index]
        padded_labels = self.padded_labels[index]
        padded_pos_tags = self.padded_pos_tags[index]
        padded_bio_tags = self.padded_bio_tags[index]
        padded_directed_distances = self.padded_directed_distances[index]

        return input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances

    def __len__(self):
        return len(self.sentences)


In [21]:
## dataset defined here
max_length = 128
srl_aux_feat_dataset = SRLauxDataset(file_paths, tokenizer, max_length)

In [22]:
from torch.utils.data import Subset

# Calculate the number of samples for the train and validation sets
train_samples = 9100
val_samples_first = 500
val_samples_last = 1500

# Create the train and validation sets using list slicing
train_dataset = Subset(srl_aux_feat_dataset, range(0, train_samples))
val_dataset = Subset(srl_aux_feat_dataset, list(range(0, val_samples_first)) + list(range(-val_samples_last, 0)))


In [None]:
input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = next(iter(train_loader))

In [None]:
padded_pos_tags[0]

tensor([-100,    1,    3,    3,    4,    5,    3, -100,    3, -100,    3,    3,
           5, -100,    5,    3,    0,    4,    6,    3, -100,    3,    3,    5,
           5,    3,    3, -100, -100,    4,    5,    5,    3, -100,    3,    3,
           3,    4,    5,    5,    2, -100,    5,    3, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100])

In [23]:
class SRLfeatModel(nn.Module):
    
    def __init__(self, bert_model, pos_tag_embedding_dim, bio_tag_embedding_dim, num_pos_tags, num_bio_tags, linear_output_dim, num_labels, dropout_rate):
        super(SRLfeatModel, self).__init__()
              
        self.bert = bert_model 
        self.pos_tag_embedder = nn.Embedding(num_embeddings=num_pos_tags + 1, embedding_dim=pos_tag_embedding_dim)
        self.bio_tag_embedder = nn.Embedding(num_embeddings=num_bio_tags + 1, embedding_dim=bio_tag_embedding_dim)
        
        self.linear1 = nn.Linear(self.bert.config.hidden_size, linear_output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear2 = nn.Linear(linear_output_dim + pos_tag_embedding_dim + bio_tag_embedding_dim + 1, num_labels)

    def forward(self, input_ids, pos_tags, bio_tags, directed_distances, padded_labels):
        bert_output = self.bert(input_ids)
        sequence_output = bert_output.last_hidden_state

        transformed_bert_output = self.linear1(sequence_output)
        transformed_bert_output = self.dropout(transformed_bert_output)
        
        # Mask and replace -100 values in pos_tags and bio_tags with 0
        pos_tags = pos_tags.masked_fill(pos_tags == -100, 8)
        bio_tags = bio_tags.masked_fill(bio_tags == -100, 17)
        
        #print("bio_tag_embeddings:", bio_tag_embeddings)
        bio_tags = bio_tags.to(device).long()
        pos_tags = pos_tags.to(device).long()
        print("bio_tags in forward dtype:", bio_tags.dtype)
        print("pos_tags in forward dtype:", pos_tags.dtype)
   

        

        pos_tag_embeddings = self.pos_tag_embedder(pos_tags)
        bio_tag_embeddings = self.bio_tag_embedder(bio_tags)
        print(transformed_bert_output.shape)
        print(pos_tag_embeddings.shape)
        print(bio_tag_embeddings.shape)
        print(directed_distances.shape)
        bio_tag_embeddings = bio_tag_embeddings.squeeze(2)
        directed_distances = directed_distances.unsqueeze(-1)
        print(transformed_bert_output.shape)
        print(pos_tag_embeddings.shape)
        print(bio_tag_embeddings.shape)
        print(directed_distances.shape)


        concatenated_features = torch.cat([transformed_bert_output, pos_tag_embeddings, bio_tag_embeddings, directed_distances], dim=-1)
        
        logits = self.linear2(concatenated_features)

        mask = (padded_labels != -100)
        masked_labels = torch.masked_select(padded_labels, mask)
        masked_logits = torch.masked_select(logits, mask.unsqueeze(-1)).view(-1, logits.shape[-1])

        return masked_logits, masked_labels


In [25]:
class SRLauxModel(nn.Module):
    def __init__(self, bert_model, lstm_hidden_size, dropout_rate, layers_to_use=[1, 2, 3], num_pos_tags=9, num_bio_tags=18):
        super(SRLauxModel, self).__init__()
        self.bert = bert_model
        self.layers_to_use = layers_to_use
        self.layer_weights = nn.Parameter(torch.rand(len(layers_to_use), dtype=torch.float))
        self.softmax = nn.Softmax(dim=0)

        self.auxiliary_pos = nn.Linear(in_features=self.bert.config.hidden_size, out_features=num_pos_tags)
        self.auxiliary_bio = nn.Linear(in_features=self.bert.config.hidden_size, out_features=num_bio_tags)
        self.auxiliary_directed_distance = nn.Linear(in_features=self.bert.config.hidden_size, out_features=1)

        self.downstream = nn.Sequential(
            nn.LSTM(input_size=self.bert.config.hidden_size,
                    hidden_size=lstm_hidden_size,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=True),
            nn.Dropout(dropout_rate),
            nn.Linear(in_features=lstm_hidden_size * 2, out_features=1)
        )

    def forward(self, input_ids, predicate_idx, labels=None):
        input_embeddings = self.bert.embeddings(input_ids)
        
        # Create predicate indicator embedding
        predicate_mask = torch.zeros_like(input_ids).scatter_(1, predicate_idx.view(-1, 1), 1)
        predicate_indicator = self.bert.embeddings.token_type_embeddings(predicate_mask.to(input_ids.device))

        # Add predicate indicator to input embeddings
        input_embeddings = input_embeddings + predicate_indicator

        bert_output = self.bert(inputs_embeds=input_embeddings, output_hidden_states=True)
        all_layer_outputs = bert_output.hidden_states

        selected_layer_outputs = [all_layer_outputs[i] for i in self.layers_to_use]
        weighted_outputs = [self.softmax(self.layer_weights)[i] * output for i, output in enumerate(selected_layer_outputs)]
        weighted_average = torch.stack(weighted_outputs).sum(dim=0)

        pos_logits = self.auxiliary_pos(weighted_average)
        bio_logits = self.auxiliary_bio(weighted_average)
        directed_distance_logits = self.auxiliary_directed_distance(weighted_average)

        downstream_output, _ = self.downstream[0](weighted_average)
        downstream_output = self.downstream[1](downstream_output)
        logits = self.downstream[2](downstream_output)

        if labels is not None:
            labels_mask = (labels != -100)
            labels = labels[labels_mask]
            logits = logits[labels_mask]

       # logits = logits.squeeze(-1)     

        return logits, labels, pos_logits, bio_logits, directed_distance_logits


In [24]:


class SRLindModel(nn.Module):
    def __init__(self, bert_model, lstm_hidden_size, dropout_rate, predicate_emb_dim):
        super(SRLindModel, self).__init__()
        self.bert = bert_model

        self.predicate_embedding = nn.Embedding(2, predicate_emb_dim)

        self.downstream = nn.Sequential(
            nn.LSTM(input_size=self.bert.config.hidden_size + predicate_emb_dim,
                    hidden_size=lstm_hidden_size,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=True),
            nn.Dropout(dropout_rate),
            nn.Linear(in_features=lstm_hidden_size * 2, out_features=1)
        )

    def forward(self, input_ids, predicate_idx, labels=None, attention_mask=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_output.last_hidden_state

        # Create predicate indicator embedding
        predicate_indicator = torch.zeros_like(input_ids)
        predicate_indicator.scatter_(1, predicate_idx.unsqueeze(1), 1)
        predicate_emb = self.predicate_embedding(predicate_indicator)
        
        # Concatenate predicate indicator embeddings with the sequence output
        sequence_output = torch.cat((sequence_output, predicate_emb), dim=-1)

        downstream_output, _ = self.downstream[0](sequence_output)
        downstream_output = self.downstream[1](downstream_output)
        logits = self.downstream[2](downstream_output)

        if labels is not None:
            labels_mask = (labels != -100)
            labels = labels[labels_mask]
            logits = logits[labels_mask]

        return logits, labels


In [26]:
bert_model_1 = BertModel.from_pretrained("bert-base-uncased")

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [27]:
bert_model_2 = BertModel.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
num_bio_tags = 19
num_pos_tags = 10
num_labels = 1
linear_output_dim = 256
pos_tag_embedding_dim = 7
bio_tag_embedding_dim = 16
dropout_rate = 0.2
num_epochs = 100
learning_rate = 3e-5
clip_grad_value = 1.0
custom_weight_value = 27.0
lstm_hidden_size = 35
predicate_emb_dim = 40



In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
srl_aux_model  = SRLauxModel(bert_model_2, lstm_hidden_size, dropout_rate, layers_to_use=[1, 2, 3], num_pos_tags=9, num_bio_tags=18).to(device)

In [30]:
num_pos_tags

10

In [None]:
srl_feat_model = SRLfeatModel(bert_model_1, pos_tag_embedding_dim, bio_tag_embedding_dim, num_pos_tags, num_bio_tags, linear_output_dim, num_labels, dropout_rate).to(device)

In [31]:
import torch.nn as nn

class SRLEnsembleModel(nn.Module):
    def __init__(self, srl_aux_model, srl_feat_model, num_labels, weight_init=None):
        super(SRLEnsembleModel, self).__init__()
        self.srl_aux_model = srl_aux_model
        self.srl_aux_model = srl_feat_model
        self.weight = nn.Parameter(torch.Tensor(1, num_labels))
        
        if weight_init is None:
            nn.init.uniform_(self.weight)
        else:
            self.weight.data = weight_init

    def forward(self, input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances):
        # SRLindModel
        logits1, labels1,  pos_logits, bio_logits, directed_distance_logits = self.srl_aux_model(input_ids, predicate_idx, labels=padded_labels)

        # SRLfeatModel
        logits2, labels2 = self.srl_feat_model(input_ids, padded_pos_tags, padded_bio_tags, padded_directed_distances, padded_labels)

        # Ensemble logits using weighted average
        ensemble_logits = self.weight[0] * logits1 + (1 - self.weight[0]) * logits2

        return ensemble_logits, labels1,  pos_logits, bio_logits, directed_distance_logits  # labels1 and labels2 are the same


In [32]:
def validate_on_train(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    all_labels = []
    all_logits = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = batch

            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags,  padded_directed_distances  = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device), padded_pos_tags.to(device),padded_bio_tags.to(device).long(),  padded_directed_distances.to(device)
            
            #print("bio_tags:", bio_tags)
            print("bio_tags dtype:", padded_bio_tags.dtype)
            print("pos_tags dtype:", padded_pos_tags.dtype)
            logits, labels = model( input_ids, predicate_idx, padded_pos_tags, padded_bio_tags, padded_directed_distances, padded_labels)
        

            loss = criterion(logits, labels.float().unsqueeze(1))
             
            total_loss += loss.item()

            mask = labels.ne(-100)
            valid_labels = labels[mask].cpu().numpy()
            valid_logits = logits[mask].cpu().numpy().squeeze()

            all_labels.extend(valid_labels)
            all_logits.extend(valid_logits)

    average_loss = total_loss / len(dataloader)
    all_labels = np.array(all_labels)
    all_logits = np.array(all_logits)

    all_probs = 1 / (1 + np.exp(-all_logits))

    precision, recall, thresholds = precision_recall_curve(all_labels, all_probs)

    f_scores = np.where((precision + recall) != 0.0, (2 * precision * recall) / (precision + recall + 1e-10), 0)

    best_threshold = thresholds[np.argmax(f_scores)]

    preds = (all_probs > best_threshold).astype(int)
    accuracy = accuracy_score(all_labels, preds)
    best_f_score = f1_score(all_labels, preds)

    return average_loss, accuracy, best_f_score, best_threshold


In [33]:
def validate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    all_labels = []
    all_logits = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = batch
            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags,  padded_directed_distances  = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device), padded_pos_tags.to(device),padded_bio_tags.to(device).long(),  padded_directed_distances.to(device)
            #print("bio_tags:", bio_tags)
            print("bio_tags dtype:", padded_bio_tags.dtype)
            print("pos_tags dtype:", padded_pos_tags.dtype)
            logits, labels = model( input_ids, predicate_idx, padded_pos_tags, padded_bio_tags, padded_directed_distances, padded_labels)
        

            loss = criterion(logits, labels.float().unsqueeze(1))
             
            total_loss += loss.item()

            mask = labels.ne(-100)
            valid_labels = labels[mask].cpu().numpy()
            valid_logits = logits[mask].cpu().numpy().squeeze()

            all_labels.extend(valid_labels)
            all_logits.extend(valid_logits)

    average_loss = total_loss / len(dataloader)
    all_labels = np.array(all_labels)
    all_logits = np.array(all_logits)

    all_probs = 1 / (1 + np.exp(-all_logits))

    precision, recall, thresholds = precision_recall_curve(all_labels, all_probs)

    f_scores = np.where((precision + recall) != 0.0, (2 * precision * recall) / (precision + recall + 1e-10), 0)

    best_threshold = thresholds[np.argmax(f_scores)]

    preds = (all_probs > best_threshold).astype(int)
    accuracy = accuracy_score(all_labels, preds)
    best_f_score = f1_score(all_labels, preds)

    return average_loss, accuracy, best_f_score, best_threshold


In [34]:
def train_model(model, train_dataset, train_dataloader, val_dataloader, criterion, optimizer, num_epochs, clip_grad_value=1, weighting_method='none', custom_value=20, patience=15):

  if weighting_method != 'none':
      negative_count = 27.0 #sum([1 for label_seq in train_dataset.labels for label in label_seq if label == 0])
      positive_count = 1.0 #sum([1 for label_seq in train_dataset.labels for label in label_seq if label == 1])

      if weighting_method == 'direct':
          pos_weight = torch.tensor([negative_count / positive_count], device=device)
      elif weighting_method == 'log':
          pos_weight = torch.tensor([np.log(negative_count / positive_count)], device=device)
      elif weighting_method == 'custom':
          pos_weight = torch.tensor([custom_value], device=device)
      else:
          raise ValueError("Invalid weighting_method value. It must be 'none', 'direct', 'log', or 'custom'.")
  else:
      pos_weight = torch.tensor(1.0, device=device)



  train_accuracies = []
  val_accuracies = []
  train_f_scores = []
  val_f_scores = []
  avg_train_loss_per_epoch = []
  avg_val_loss_per_epoch = []

  # Early stopping initialization
  best_val_accuracy = float('-inf')
  patience_counter = 0

  for epoch in range(num_epochs):
      print(f"Epoch {epoch+1}/{num_epochs}")
      total_train_loss = 0
      num_train_batches = 0

      for i, batch in enumerate(train_dataloader):
          model.train()
          #input_ids, predicate_idx, padded_labels = batch
          #input_ids, predicate_idx, padded_labels = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device)

          #logits, labels = model(input_ids, predicate_idx, padded_labels)
          input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = batch
          input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags,  padded_directed_distances  = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device), padded_pos_tags.to(device),padded_bio_tags.to(device).long(),  padded_directed_distances.to(device)
          
          print("bio_tags dtype:", padded_bio_tags.dtype)
          print("pos_tags dtype:", padded_pos_tags.dtype)
          logits, labels = model( input_ids, predicate_idx, padded_pos_tags, padded_bio_tags, padded_directed_distances, padded_labels)

          criterion.pos_weight = pos_weight

          loss = criterion(logits, labels.float().unsqueeze(1))
          loss.backward()

          torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_value)

          optimizer.step()
          optimizer.zero_grad()

          total_train_loss += loss.item()
          num_train_batches += 1


      avg_train_loss_per_epoch.append(total_train_loss / num_train_batches)

      val_loss, val_accuracy, val_f_score, val_threshold = validate(model, val_dataloader, criterion)
      avg_val_loss_per_epoch.append(val_loss)
      print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}, Best F-score: {val_f_score}, Best Threshold: {val_threshold}")

      if (epoch + 1) % 10 == 0:
          train_loss, train_accuracy, train_f_score, train_threshold = validate_on_train(model, train_dataloader, criterion)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(val_accuracy)
          train_f_scores.append(train_f_score)
          val_f_scores.append(val_f_score)
          print(f"Train Loss: {train_loss}, Train Accuracy: {train_accuracy}, Best F-score: {train_f_score}, Best Threshold: {train_threshold}")

      # Early stopping
      if val_accuracy > best_val_accuracy:
          best_val_accuracy = val_accuracy
          patience_counter = 0
      else:
          patience_counter += 1

      if patience_counter >= patience:
          print(f"Early stopping triggered after {epoch + 1} epochs due to no improvement in validation accuracy")
          num_missing_values = num_epochs - epoch - 1
          train_accuracies.extend([None] * num_missing_values)
          val_accuracies.extend([None] * num_missing_values)
          train_f_scores.extend([None] * num_missing_values)
          val_f_scores.extend([None] * num_missing_values)

          #return avg_train_loss_per_epoch, avg_val_loss_per_epoch, train_accuracies, val_accuracies, train_f_scores, val_f_scores
          return avg_train_loss_per_epoch, avg_val_loss_per_epoch, train_accuracies, val_accuracies, train_f_scores, val_f_scores
          

  return avg_train_loss_per_epoch, avg_val_loss_per_epoch, train_accuracies, val_accuracies, train_f_scores, val_f_scores


In [35]:
def validate_on_train(model, dataloader):
    main_task_criterion = nn.BCEWithLogitsLoss()
    pos_tag_criterion = nn.CrossEntropyLoss()
    bio_tag_criterion = nn.CrossEntropyLoss()
    directed_distance_criterion = nn.MSELoss()

    model.eval()
    total_loss = 0
    total_aux_pos_loss = 0
    total_aux_bio_loss = 0
    total_aux_directed_distance_loss = 0
    
    all_labels = []
    all_logits = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = batch
            input_ids, predicate_idx, padded_labels = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device)
            padded_pos_tags, padded_bio_tags, padded_directed_distances = padded_pos_tags.to(device), padded_bio_tags.to(device), padded_directed_distances.to(device)

            logits, labels, pos_logits, bio_logits, directed_distance_logits = model(input_ids, predicate_idx, padded_labels)

            loss = main_task_criterion(logits, labels.float().unsqueeze(1))
            aux_pos_loss = pos_tag_criterion(pos_logits.view(-1, pos_logits.shape[-1]), padded_pos_tags.view(-1))
            aux_bio_loss = bio_tag_criterion(bio_logits.view(-1, bio_logits.shape[-1]), padded_bio_tags.view(-1))
            aux_directed_distance_loss = directed_distance_criterion(directed_distance_logits.squeeze(-1), padded_directed_distances.float())

            total_loss += loss.item()
            total_aux_pos_loss += aux_pos_loss.item()
            total_aux_bio_loss += aux_bio_loss.item()
            total_aux_directed_distance_loss += aux_directed_distance_loss.item()

            mask = labels.ne(-100)
            valid_labels = labels[mask].cpu().numpy()
            valid_logits = logits[mask].cpu().numpy().squeeze()

            all_labels.extend(valid_labels)
            all_logits.extend(valid_logits)

        # Masking the -100 labels for auxiliary tasks
        aux_mask = padded_labels.ne(-100).view(-1)

        average_loss = total_loss / len(dataloader)
        average_aux_pos_loss = (total_aux_pos_loss / len(dataloader)) * aux_mask.float().mean()
        average_aux_bio_loss = (total_aux_bio_loss / len(dataloader)) * aux_mask.float().mean()
        average_aux_directed_distance_loss = (total_aux_directed_distance_loss / len(dataloader)) * aux_mask.float().mean()
    
    all_labels = np.array(all_labels)
    all_logits = np.array(all_logits)

    all_probs = 1 / (1 + np.exp(-all_logits))

    precision, recall, thresholds = precision_recall_curve(all_labels, all_probs)

    f_scores = np.where((precision + recall) != 0.0, (2 * precision * recall) / (precision + recall + 1e-10), 0)

    best_threshold = thresholds[np.argmax(f_scores)]

    preds = (all_probs > best_threshold).astype(int)
    accuracy = accuracy_score(all_labels, preds)
    best_f_score = f1_score(all_labels, preds)

    return average_loss, accuracy, best_f_score, best_threshold, average_aux_pos_loss, average_aux_bio_loss, average_aux_directed_distance_loss


In [36]:
def validate(model, dataloader):
    main_task_criterion = nn.BCEWithLogitsLoss()
    pos_tag_criterion = nn.CrossEntropyLoss()
    bio_tag_criterion = nn.CrossEntropyLoss()
    directed_distance_criterion = nn.MSELoss()

    model.eval()
    total_loss = 0
    total_aux_pos_loss = 0
    total_aux_bio_loss = 0
    total_aux_directed_distance_loss = 0
    
    all_labels = []
    all_logits = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = batch
            input_ids, predicate_idx, padded_labels = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device)
            padded_pos_tags, padded_bio_tags, padded_directed_distances = padded_pos_tags.to(device), padded_bio_tags.to(device), padded_directed_distances.to(device)

            logits, labels, pos_logits, bio_logits, directed_distance_logits = model(input_ids, predicate_idx, padded_labels)

            loss = main_task_criterion(logits, labels.float().unsqueeze(1))
            aux_pos_loss = pos_tag_criterion(pos_logits.view(-1, pos_logits.shape[-1]), padded_pos_tags.view(-1))
            aux_bio_loss = bio_tag_criterion(bio_logits.view(-1, bio_logits.shape[-1]), padded_bio_tags.view(-1))
            aux_directed_distance_loss = directed_distance_criterion(directed_distance_logits.squeeze(-1), padded_directed_distances.float())

            total_loss += loss.item()
            total_aux_pos_loss += aux_pos_loss.item()
            total_aux_bio_loss += aux_bio_loss.item()
            total_aux_directed_distance_loss += aux_directed_distance_loss.item()

            mask = labels.ne(-100)
            valid_labels = labels[mask].cpu().numpy()
            valid_logits = logits[mask].cpu().numpy().squeeze()

            all_labels.extend(valid_labels)
            all_logits.extend(valid_logits)

        # Masking the -100 labels for auxiliary tasks
        aux_mask = padded_labels.ne(-100).view(-1)

        average_loss = total_loss / len(dataloader)
        average_aux_pos_loss = (total_aux_pos_loss / len(dataloader)) * aux_mask.float().mean()
        average_aux_bio_loss = (total_aux_bio_loss / len(dataloader)) * aux_mask.float().mean()
        average_aux_directed_distance_loss = (total_aux_directed_distance_loss / len(dataloader)) * aux_mask.float().mean()
    
    all_labels = np.array(all_labels)
    all_logits = np.array(all_logits)

    all_probs = 1 / (1 + np.exp(-all_logits))

    precision, recall, thresholds = precision_recall_curve(all_labels, all_probs)

    f_scores = np.where((precision + recall) != 0.0, (2 * precision * recall) / (precision + recall + 1e-10), 0)

    best_threshold = thresholds[np.argmax(f_scores)]

    preds = (all_probs > best_threshold).astype(int)
    accuracy = accuracy_score(all_labels, preds)
    best_f_score = f1_score(all_labels, preds)

    return average_loss, accuracy, best_f_score, best_threshold, average_aux_pos_loss, average_aux_bio_loss, average_aux_directed_distance_loss


In [37]:

def train_model_aux(model, train_dataset, train_dataloader, val_dataloader, optimizer, num_epochs, task_weights, clip_grad_value=1, weighting_method='none', custom_value=20, patience=15):

    main_task_criterion = nn.BCEWithLogitsLoss()
    pos_tag_criterion = nn.CrossEntropyLoss()
    bio_tag_criterion = nn.CrossEntropyLoss()

    # For regression
    directed_distance_criterion = nn.MSELoss()
    
    if weighting_method != 'none':
        negative_count = 24 #sum([1 for label_seq in train_dataset.labels for label in label_seq if label == 0])
        positive_count = 1 #sum([1 for label_seq in train_dataset.labels for label in label_seq if label == 1])

        if weighting_method == 'direct':
            pos_weight = torch.tensor([negative_count / positive_count], device=device)
        elif weighting_method == 'log':
            pos_weight = torch.tensor([np.log(negative_count / positive_count)], device=device)
        elif weighting_method == 'custom':
            pos_weight = torch.tensor([custom_value], device=device)
        else:
            raise ValueError("Invalid weighting_method value. It must be 'none', 'direct', 'log', or 'custom'.")
    else:
        pos_weight = torch.tensor(1.0, device=device)

    train_accuracies = []
    val_accuracies = []
    train_f_scores = []
    val_f_scores = []
    avg_train_loss_per_epoch = []
    avg_val_loss_per_epoch = []

    # Early stopping initialization
    best_val_accuracy = float('-inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        total_train_loss = 0
        num_train_batches = 0

        for i, batch in enumerate(train_dataloader):
            model.train()
            input_ids, predicate_idx, padded_labels, padded_pos_tags, padded_bio_tags, padded_directed_distances = batch
            input_ids, predicate_idx, padded_labels = input_ids.to(device), predicate_idx.to(device), padded_labels.to(device)
            padded_pos_tags, padded_bio_tags, padded_directed_distances = padded_pos_tags.to(device), padded_bio_tags.to(device), padded_directed_distances.to(device)

            logits, labels, pos_logits, bio_logits, directed_distance_logits = model(input_ids, predicate_idx, padded_labels)

            main_task_criterion.pos_weight = pos_weight
            main_task_loss = main_task_criterion(logits, labels.float().unsqueeze(1))

            pos_tag_loss = pos_tag_criterion(pos_logits.view(-1, pos_logits.shape[-1]), padded_pos_tags.view(-1))
            bio_tag_loss = bio_tag_criterion(bio_logits.view(-1, bio_logits.shape[-1]), padded_bio_tags.view(-1))
            padded_directed_distances.unsqueeze(-1)
        #    print("directed_distance_logits",directed_distance_logits.shape)
       #     print("padded_directed_distances",padded_directed_distances.shape)
            directed_distance_loss = directed_distance_criterion(directed_distance_logits.float().view(-1, directed_distance_logits.shape[-1]), padded_directed_distances.float().view(-1,directed_distance_logits.shape[-1]))
         #   print("after")
          #  print(directed_distance_logits.float().view(-1, directed_distance_logits.shape[-1]).shape)
          #  print(padded_directed_distances.float().view(-1).shape)

    

           # total_loss = main_task_loss + pos_tag_loss + bio_tag_loss + directed_distance_loss
            task_weights = task_weights.float()
            total_loss = task_weights[0]*main_task_loss + task_weights[1]*pos_tag_loss + task_weights[2]*bio_tag_loss + task_weights[3]*directed_distance_loss


            total_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_value)

            optimizer.step()
            optimizer.zero_grad()

            total_train_loss += total_loss.item()
            num_train_batches += 1

        avg_train_loss_per_epoch.append(total_train_loss / num_train_batches)

        val_loss, val_accuracy, val_f_score, val_threshold, average_aux_pos_loss, average_aux_bio_loss, average_aux_directed_distance_loss = validate(model, val_dataloader)
        avg_val_loss_per_epoch.append(val_loss)

        print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}, Best F-score: {val_f_score}, Best Threshold: {val_threshold}, POS loss: {average_aux_pos_loss}, BIO loss; {average_aux_bio_loss}, Dir dist loss : {average_aux_directed_distance_loss}")

        if (epoch + 1) % 10 == 0:
            average_loss, train_accuracy, train_f_score, train_threshold, average_aux_pos_loss, average_aux_bio_loss, average_aux_directed_distance_loss= validate_on_train(model, train_dataloader)
            train_accuracies.append(train_accuracy)
            val_accuracies.append(val_accuracy)
            train_f_scores.append(train_f_score)
            val_f_scores.append(val_f_score)

            #torch.save(model.state_dict(), f"{model_path}/{file_name+str(epoch)+'.pth'}")
            #print("model saved")

            print(f"Train Loss: {val_loss}, Train Accuracy: {val_accuracy}, Best F-score: {train_f_score}, Best Threshold: {train_threshold}, POS loss: {average_aux_pos_loss}, BIO loss; {average_aux_bio_loss}, Dir dist loss : {average_aux_directed_distance_loss}")

         # Early stopping
        if val_accuracy > best_val_accuracy:
             best_val_accuracy = val_accuracy
             patience_counter = 0
        else:
            patience_counter += 1 

        if patience_counter >= patience:
          print(f"Early stopping triggered after {epoch + 1} epochs due to no improvement in validation accuracy")
          num_missing_values = num_epochs - epoch - 1
          train_accuracies.extend([None] * num_missing_values)
          val_accuracies.extend([None] * num_missing_values)
          train_f_scores.extend([None] * num_missing_values)
          val_f_scores.extend([None] * num_missing_values)

          #return avg_train_loss_per_epoch, avg_val_loss_per_epoch, train_accuracies, val_accuracies, train_f_scores, val_f_scores
          return avg_train_loss_per_epoch, avg_val_loss_per_epoch, train_accuracies, val_accuracies, train_f_scores, val_f_scores      



    return avg_train_loss_per_epoch, avg_val_loss_per_epoch, train_accuracies, val_accuracies, train_f_scores, val_f_scores        


In [38]:

batch_size = 8  # Choose a batch size according to your needs

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [39]:
num_epochs = 100
lstm_hidden_size= 50
dropout_rate = 0.3
learning_rate = 3e-5
clip_grad_value = 1.5
custom_weight_value = 27.0
layers_to_use = [1,2,3]
task_weights = torch.tensor([1.0, 0.1,0.1,0.06]).to(device)

In [40]:
model = SRLEnsembleModel(srl_aux_model, srl_feat_model, num_labels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()

In [41]:
train_model(model = model, train_dataset = train_dataset, train_dataloader=train_dataloader , val_dataloader = val_dataloader,
                optimizer = optimizer, num_epochs = num_epochs, task_weights = task_weights, clip_grad_value=clip_grad_value, weighting_method='custom', custom_value=custom_weight_value , patience=25)

Epoch 1/100
Validation Loss: 0.32934448981285097, Validation Accuracy: 0.9383330492585649, Best F-score: 0.3996017258546299, Best Threshold: 0.8257491588592529, POS loss: 0.17547602951526642, BIO loss; 0.18892383575439453, Dir dist loss : 736.897216796875
Epoch 2/100
Validation Loss: 0.3291298671364784, Validation Accuracy: 0.9512527697289926, Best F-score: 0.449576597382602, Best Threshold: 0.8933901190757751, POS loss: 0.11921053379774094, BIO loss; 0.14621210098266602, Dir dist loss : 231.6236114501953
Epoch 3/100
Validation Loss: 0.2841283204257488, Validation Accuracy: 0.9602352139083007, Best F-score: 0.5009625668449197, Best Threshold: 0.9294103980064392, POS loss: 0.09342692792415619, BIO loss; 0.12603749334812164, Dir dist loss : 33.29819107055664
Epoch 4/100
Validation Loss: 0.18534785366058348, Validation Accuracy: 0.9651781148798364, Best F-score: 0.5493050959629385, Best Threshold: 0.8705502152442932, POS loss: 0.08021482825279236, BIO loss; 0.11479443311691284, Dir dist l

([275.4905132166321,
  119.07815915461374,
  28.565370393344093,
  5.8059071405072205,
  4.161684576886819,
  3.6236191428084785,
  3.20186154017968,
  2.9376959513957974,
  2.689681495639687,
  2.4743520652839295,
  2.2473141227591853,
  2.038776628070745,
  1.904306429158405,
  1.7402170220216464,
  1.6106767959251136,
  1.4650635750725316,
  1.3271743994212737,
  1.2078684006665419,
  1.1197761288976626,
  1.036059954169293,
  0.9516950527006377,
  0.8734609679055444,
  0.8146412822504664,
  0.7310982451615413,
  0.7126488216443934,
  0.6778602480875272,
  0.644144073580449,
  0.6001734368909013,
  0.5506792187016385,
  0.5077046431103469,
  0.49400528635574353,
  0.482317535167393,
  0.44183789452567357,
  0.40860982701143084,
  0.40343561627769514,
  0.39008871812916807,
  0.37408122018974777,
  0.3574363535797135,
  0.327671445215445,
  0.3188030057699667,
  0.3065150296872053,
  0.2932271952290736,
  0.3000912800683314,
  0.2658453937325306,
  0.26865465722500176,
  0.2468323868