- The base for this code was copied from [Transformers_for_Negation_and_Speculation](https://github.com/adityak6798/Transformers-For-Negation-and-Speculation/blob/master/Transformers_for_Negation_and_Speculation.ipynb)

In [None]:
!pip install transformers
!pip install knockknock==0.1.7
!pip install keras_preprocessing
!pip install datasets
!pip install sentencepiece

In [2]:
# Imports
import os, re, torch, html, tempfile, copy, json, math, shutil, tarfile, tempfile, sys, random, pickle
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, ReLU
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from keras_preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
from transformers import AutoModelForTokenClassification,AutoTokenizer
from datasets import load_dataset

import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
from scipy import stats
import csv
from datetime import datetime

In [3]:
# Global configuration read by the cells below.
# NOTE(review): SCOPE_MODEL (the name passed to AutoTokenizer.from_pretrained /
# AutoModelForTokenClassification.from_pretrained) is read by Data and ScopeModel but is
# not defined in any visible cell -- it must be set before those are used.
MAX_LEN = 256  # every scope-resolution input sequence is padded/truncated to this length
bs = 16  # batch size of the DataLoaders built by Data
EPOCHS = 32  # not read in the visible cells; presumably ScopeModel.train(epochs=...) -- confirm
PATIENCE = 8  # not read in the visible cells; presumably ScopeModel.train(patience=...) -- confirm
INITIAL_LEARNING_RATE = 1e-5  # not read in the visible cells; presumably ScopeModel(learning_rate=...) -- confirm

# 'augment' inserts an [unusedN] cue-marker token before each cue sub-word;
# 'replace' substitutes that marker for the cue sub-word itself.
SCOPE_METHOD = 'augment' # Options: augment, replace
# How word-level predictions are formed from sub-words during validation:
# 'average' averages the sub-word logits of a word, 'first_token' uses its first sub-word.
F1_METHOD = 'average' # Options: average, first_token
TASK = 'negation'  # used in the "Task" field of ScopeModel result dicts
SUBTASK = 'scope_resolution'  # not read in the visible cells

In [4]:
# Handle for the default CUDA device, plus how many GPUs torch can see.
device, n_gpu = torch.device("cuda"), torch.cuda.device_count()

In [5]:
class Cues:
    '''Cue-detection view of a dataset.

    data: a pair-like (sentences, cues). Both containers are stored as given (not copied),
    so in-place extensions of them are visible through this object.
    '''
    def __init__(self, data):
        sentences, cues = data[0], data[1]
        self.sentences = sentences
        self.cues = cues
        self.num_sentences = len(sentences)

class Scopes:
    '''Scope-resolution view of a dataset.

    data: a triple-like (sentences, cues, scopes). The containers are stored as given
    (not copied), so in-place extensions of them are visible through this object.
    '''
    def __init__(self, data):
        sentences, cues, scopes = data[0], data[1], data[2]
        self.sentences = sentences
        self.cues = cues
        self.scopes = scopes
        self.num_sentences = len(sentences)

In [6]:
class Data:
    '''
    Negation/speculation annotations, split into a cue-detection view (self.cue_data, a Cues)
    and a scope-resolution view (self.scope_data, a Scopes), with helpers that turn the scope
    view into padded PyTorch dataloaders (tokenizer: SCOPE_MODEL, see the global config).

    Per-word cue labels built by __init__:
        3 - not a cue
        1 - the sentence's only cue, and that cue is a single token
        2 - a token of a multi-token cue, or of any cue when the sentence has several cues
        0 - a token that would get label 1 but also lies inside a SCOPE span
            (presumably the affix-cue class of the upstream notebook -- confirm)
    Per-word scope labels: 1 inside a SCOPE span, 0 otherwise.
    '''

    def append(self, new_data):
        '''Extend this object in place with the cue and scope data of every Data object in new_data.'''
        for other in new_data:
            self.cue_data.sentences += other.cue_data.sentences
            self.cue_data.cues += other.cue_data.cues
            self.cue_data.num_sentences += other.cue_data.num_sentences
            self.scope_data.sentences += other.scope_data.sentences
            self.scope_data.cues += other.scope_data.cues
            self.scope_data.scopes += other.scope_data.scopes
            self.scope_data.num_sentences += other.scope_data.num_sentences

    def __init__(self, data, dataset_name = 'sfu', frac_no_cue_sents = 1.0):
        '''
        data: Iterable of annotated sentences. Each item needs a 'tokens' list of dicts with a
            'text' key, and a 'spans' list of dicts with 'label' ('CUE' or 'SCOPE'),
            'token_start' and 'token_end' (inclusive) keys.
        dataset_name: Name of the dataset being preprocessed (stored as self.dataset_name).
            Values supported: sfu, bioscope, starsem.
        frac_no_cue_sents: Fraction of span-less sentences (chosen with random.sample) added to
            the cue-detection data. Span-less sentences are never added to the scope data.
        '''
        self.dataset_name = dataset_name

        cue_sentence, cue_cues = [], []
        scope_sentence, scope_cues, scope_scopes = [], [], []
        no_cue_data = []
        for line in data:
            sentence = [token['text'] for token in line['tokens']]
            if not line['spans']:
                no_cue_data.append([sentence, [3] * len(sentence)])
                continue
            if len(line['tokens']) == 0:
                continue

            cues = [span for span in line['spans'] if span['label'] == 'CUE']
            scopes = [span for span in line['spans'] if span['label'] == 'SCOPE']

            cue_row = [3] * len(sentence)
            scope_row = [0] * len(sentence)
            for cue in cues:
                positions = range(cue['token_start'], cue['token_end'] + 1)
                # Only a lone single-token cue gets label 1; every other cue token gets 2.
                cue_label = 1 if len(cues) == 1 and len(positions) == 1 else 2
                for pos in positions:
                    cue_row[pos] = cue_label
            for scope in scopes:
                for pos in range(scope['token_start'], scope['token_end'] + 1):
                    scope_row[pos] = 1
            for pos, in_scope in enumerate(scope_row):
                if in_scope == 1 and cue_row[pos] == 1:
                    cue_row[pos] = 0

            cue_sentence.append(sentence)
            scope_sentence.append(sentence)
            cue_cues.append(cue_row)
            scope_cues.append(list(cue_row))  # independent copy for the scope view
            scope_scopes.append(scope_row)

        cue_only_samples = random.sample(no_cue_data, k=int(frac_no_cue_sents*len(no_cue_data)))
        cue_only_sents = [sample[0] for sample in cue_only_samples]
        cue_only_cues = [sample[1] for sample in cue_only_samples]

        self.cue_data = Cues((cue_sentence + cue_only_sents, cue_cues + cue_only_cues))
        self.scope_data = Scopes((scope_sentence, scope_cues, scope_scopes))

    def _load_scope_tokenizer(self, announce_cased = False):
        '''
        Load the SCOPE_MODEL tokenizer into self.tokenizer and return its do_lower_case flag.

        Lower-casing is disabled only when the model name contains 'cased' but not 'uncased'.
        announce_cased: If True, print "cased model" when lower-casing is disabled.
        '''
        do_lower_case = 'uncased' in SCOPE_MODEL or 'cased' not in SCOPE_MODEL
        if announce_cased and not do_lower_case:
            print("cased model")
        self.tokenizer = AutoTokenizer.from_pretrained(SCOPE_MODEL, do_lower_case=do_lower_case)
        return do_lower_case

    def _preprocess_scope_data(self, obj, do_lower_case):
        '''
        Turn obj.scope_data into padded model inputs, using self.tokenizer.

        Every word is split into sub-words that inherit its scope and cue labels; the
        first-subword mask is 1 on the first sub-word of each word and 0 elsewhere.
        SCOPE_METHOD then injects the cue information:
            'replace': every cue sub-word is replaced by the token [unused{cue+1}].
            'augment': an [unused{cue+1}] marker is inserted before every cue sub-word.
        All sequences are post-padded / post-truncated to MAX_LEN with pad_sequences.
        Returns: [input_ids, tags, attention_masks, first_subword_masks], each a list of
            MAX_LEN-long lists; attention_masks is 1.0 where the input id is > 0, else 0.0.
        Raises: ValueError if SCOPE_METHOD is neither 'replace' nor 'augment'.
        '''
        method = SCOPE_METHOD
        mytexts, mylabels, mycues, mymasks = [], [], [], []
        for words, word_tags, word_cues in zip(obj.scope_data.sentences, obj.scope_data.scopes, obj.scope_data.cues):
            new_text, new_tags, new_cues, new_masks = [], [], [], []
            for word, tag, cue in zip(words, word_tags, word_cues):
                # Tokenize each annotated token directly. Joining the tokens with spaces and
                # re-splitting dropped empty tokens and split tokens containing whitespace,
                # which shifted every later label onto the wrong word.
                if do_lower_case:
                    word = word.lower()
                for count, sub_word in enumerate(self.tokenizer.tokenize(word)):
                    new_masks.append(1 if count == 0 else 0)
                    new_tags.append(tag)
                    new_cues.append(cue)
                    new_text.append(sub_word)
            mytexts.append(new_text)
            mylabels.append(new_tags)
            mycues.append(new_cues)
            mymasks.append(new_masks)

        if method == 'replace':
            final_sentences = [[token if cue == 3 else f'[unused{cue+1}]' for token, cue in zip(sent, sent_cues)]
                               for sent, sent_cues in zip(mytexts, mycues)]
            final_labels = mylabels
            final_masks = mymasks
        elif method == 'augment':
            final_sentences, final_labels, final_masks = [], [], []
            for sent, sent_cues, sent_labels, sent_masks in zip(mytexts, mycues, mylabels, mymasks):
                temp_sent, temp_label, temp_masks = [], [], []
                first_part = 0
                for token, cue, label, mask in zip(sent, sent_cues, sent_labels, sent_masks):
                    if cue != 3:
                        if first_part == 0:
                            # First sub-word of a run of cue sub-words: the marker takes the
                            # word-start mask and the sub-word itself gets mask 0.
                            first_part = 1
                            temp_sent.append(f'[unused{cue+1}]')
                            temp_masks.append(1)
                            temp_label.append(label)
                            temp_sent.append(token)
                            temp_masks.append(0)
                            temp_label.append(label)
                            continue
                        # Later cue sub-words of the same run: a marker with mask 0, then
                        # (below) the sub-word with its own mask.
                        temp_sent.append(f'[unused{cue+1}]')
                        temp_masks.append(0)
                        temp_label.append(label)
                    else:
                        first_part = 0
                    temp_masks.append(mask)
                    temp_sent.append(token)
                    temp_label.append(label)
                final_sentences.append(temp_sent)
                final_labels.append(temp_label)
                final_masks.append(temp_masks)
                if (len(temp_sent) > MAX_LEN):
                    print("***WARNING*** \ninput over MAX LENGTH")
        else:
            raise ValueError("Supported methods for scope detection are:\nreplace\naugment")

        input_ids = pad_sequences([[self.tokenizer.convert_tokens_to_ids(word) for word in txt] for txt in final_sentences],
                                  maxlen=MAX_LEN, dtype="long", truncating="post", padding="post").tolist()
        tags = pad_sequences(final_labels,
                             maxlen=MAX_LEN, value=0, padding="post",
                             dtype="long", truncating="post").tolist()
        final_masks = pad_sequences(final_masks,
                                    maxlen=MAX_LEN, value=0, padding="post",
                                    dtype="long", truncating="post").tolist()
        attention_masks = [[float(i>0) for i in ii] for ii in input_ids]
        return [input_ids, tags, attention_masks, final_masks]

    def _scope_tensors(self, do_lower_case, other_datasets, combine):
        '''
        Preprocess self and then each of other_datasets with _preprocess_scope_data.

        Returns four lists of LongTensors: input ids, tags, attention masks and
        first-subword masks. With combine=True each list holds a single tensor covering all
        datasets (self first); otherwise it holds one tensor per dataset.
        '''
        columns = [[part] for part in self._preprocess_scope_data(self, do_lower_case)]
        for other in other_datasets:
            for column, part in zip(columns, self._preprocess_scope_data(other, do_lower_case)):
                if combine:
                    column[0] += part
                else:
                    column.append(part)
        return [[torch.LongTensor(part) for part in column] for column in columns]

    def get_scope_dataloader(self, val_size = 0.15, test_size=0.15, other_datasets = None, combine = True):
        '''
        Build shuffled (RandomSampler) scope-resolution dataloaders with batch size bs.

        val_size, test_size: Unused; kept for backward compatibility.
        other_datasets: Data objects to preprocess alongside this one (default: none).
        combine: If True, concatenate all datasets into one dataloader; otherwise build one
            dataloader per dataset (this one first).
        Returns: a DataLoader (combine=True) or a list of DataLoaders, each yielding
            (input_ids, attention_mask, scope_labels, first_subword_mask) batches.
        '''
        if other_datasets is None:
            other_datasets = []
        do_lower_case = self._load_scope_tokenizer()
        inputs, tags, masks, mymasks = self._scope_tensors(do_lower_case, other_datasets, combine)

        dataloaders = []
        for ids, attention, labels, first_subword in zip(inputs, masks, tags, mymasks):
            data = TensorDataset(ids, attention, labels, first_subword)
            dataloaders.append(DataLoader(data, sampler=RandomSampler(data), batch_size=bs))
        return dataloaders[0] if combine else dataloaders

    def get_test_dataloader(self, val_size = 0.15, test_size=0.15, other_datasets = None, combine = True):
        '''
        Build label-free scope-resolution dataloaders for prediction, plus the gold scope tags.

        val_size, test_size: Unused; kept for backward compatibility.
        other_datasets: Data objects to preprocess alongside this one (default: none).
        combine: If True, concatenate all datasets into one dataloader; otherwise build one
            dataloader per dataset (this one first).
        Batches are (input_ids, attention_mask, first_subword_mask) with batch size bs, yielded
        in dataset order so that predictions line up with the returned tags.
        Returns: (DataLoader, tags) when combine=True, else (list of DataLoaders, tags); tags
            is a list of LongTensors of scope labels, one per dataloader.
        '''
        # Sequential, not random: the gold tags are returned separately in dataset order, and
        # shuffling the inputs made it impossible to align predictions with them.
        from torch.utils.data import SequentialSampler

        if other_datasets is None:
            other_datasets = []
        do_lower_case = self._load_scope_tokenizer(announce_cased=True)
        inputs, tags, masks, mymasks = self._scope_tensors(do_lower_case, other_datasets, combine)

        dataloaders = []
        for ids, attention, first_subword in zip(inputs, masks, mymasks):
            data = TensorDataset(ids, attention, first_subword)
            dataloaders.append(DataLoader(data, sampler=SequentialSampler(data), batch_size=bs))
        return (dataloaders[0], tags) if combine else (dataloaders, tags)



In [7]:
def f1_cues(y_true, y_pred):
    '''
    Print and return token-level precision, recall and F1 for cue detection.

    Needs flattened cues: y_true and y_pred are flat sequences of cue labels in which 3
    means "not a cue". A true positive is a cue token predicted with exactly its gold label;
    a false positive is a non-cue token predicted as a cue; a false negative is a cue token
    predicted as non-cue. A cue token predicted with a different cue label counts as none.
    Returns: (precision, recall, f1); all three are 0.0 when there is no true positive.
    '''
    tp = fp = fn = 0
    for gold, pred in zip(y_true, y_pred):
        if gold != 3 and gold == pred:
            tp += 1
        elif gold == 3 and pred != 3:
            fp += 1
        elif gold != 3 and pred == 3:
            fn += 1
    if tp == 0:
        # Report a true zero. The former 0.0001 placeholders (used only to dodge a division
        # by zero) made a run with no correct cue print a non-zero F1.
        prec = rec = f1 = 0.0
    else:
        prec = tp / (tp + fp)
        rec = tp / (tp + fn)
        f1 = 2 * prec * rec / (prec + rec)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    print(f"F1 Score: {f1}")
    return prec, rec, f1


def f1_scope(y_true, y_pred, level = 'token'): #This is for gold cue annotation scope, thus the precision is always 1.
    '''
    Print scope-resolution scores, assuming gold cue annotations (so precision is always 1).

    y_true, y_pred: lists with one label sequence per sentence.
    level='token': prints sklearn's f1_score over all labels, flattened sentence by sentence.
    level='scope': a sentence counts as correct only if its whole predicted sequence equals
        the gold one; prints precision (fixed to 1), recall and F1. Recall is 0.0 for empty input.
    Any other level prints nothing.
    '''
    if level == 'token':
        # Flatten sentence-major. The former comprehension had its two for-clauses swapped,
        # so it raised NameError (or silently iterated a stray global `j`).
        flat_true = [tag for sent in y_true for tag in sent]
        flat_pred = [tag for sent in y_pred for tag in sent]
        print(f1_score(flat_true, flat_pred))
    elif level == 'scope':
        tp = 0
        fn = 0
        for y_t, y_p in zip(y_true, y_pred):
            if y_t == y_p:
                tp += 1
            else:
                fn += 1
        prec = 1
        # Guard the empty case, which used to raise ZeroDivisionError.
        rec = tp / (tp + fn) if tp + fn else 0.0
        print(f"Precision: {prec}")
        print(f"Recall: {rec}")
        print(f"F1 Score: {2*prec*rec/(prec+rec)}")

def report_per_class_accuracy(y_true, y_pred):
    '''
    Print a confusion matrix as a pandas DataFrame.

    Rows are gold labels, columns are predicted labels, and cells are float counts. The
    axes cover the union of labels seen in y_true and y_pred.
    '''
    seen = list(np.unique(y_true)) + list(np.unique(y_pred))
    all_labels = list(np.unique(seen))
    size = len(all_labels)
    confusion = pd.DataFrame(data=np.zeros((size, size)), index=all_labels, columns=all_labels)
    for gold, pred in zip(y_true, y_pred):
        confusion.at[gold, pred] += 1
    print(confusion)

def flat_accuracy(preds, labels, input_mask = None):
    '''
    Return the fraction of positions where prediction equals label.

    preds and labels are nested sequences and are flattened first. The denominator is the
    total number of labels. input_mask is ignored. Raises ZeroDivisionError if there are no labels.
    '''
    flat_preds = [p for seq in preds for p in seq]
    flat_labels = [t for seq in labels for t in seq]
    hits = sum(1 for p, t in zip(flat_preds, flat_labels) if p == t)
    return hits / len(flat_labels)


def flat_accuracy_positive_cues(preds, labels, input_mask = None):
    '''
    Return accuracy over positions whose gold label is neither 3 nor 4, or None if there are none.

    preds and labels are nested sequences and are flattened first. input_mask is ignored.
    '''
    flat_preds = [p for seq in preds for p in seq]
    flat_labels = [t for seq in labels for t in seq]
    kept_labels = [t for t in flat_labels if (t != 4 and t != 3)]
    if len(kept_labels) == 0:
        return None
    kept_pairs = [(p, t) for p, t in zip(flat_preds, flat_labels) if (t != 4 and t != 3)]
    hits = sum(1 for p, t in kept_pairs if p == t)
    return hits / len(kept_labels)

def scope_accuracy(preds, labels):
    '''
    Return the fraction of paired items where the prediction equals the label as a whole.

    Pairs beyond the shorter input are ignored. Raises ZeroDivisionError if there are no pairs.
    '''
    pairs = list(zip(preds, labels))
    exact = sum(1 for pred, gold in pairs if pred == gold)
    return exact / len(pairs)


In [8]:
class EarlyStopping:
    """Signal early stopping once a higher-is-better score stops improving.

    Each call whose score is not lower than the best so far checkpoints the model's
    state_dict to 'checkpoint.pt'. After `patience` consecutive lower scores, early_stop
    is set to True.
    """
    def __init__(self, patience=7, verbose=False):
        """
        Args:
            patience (int): Consecutive non-improving calls tolerated before early_stop is set.
                            Default: 7
            verbose (bool): If True, print a message every time a checkpoint is saved.
                            Default: False
        """
        self.patience, self.verbose = patience, verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = 0  # last checkpointed score (name kept from a loss-based version)

    def __call__(self, score, model):
        if self.best_score is not None and score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return
        first_call = self.best_score is None
        self.best_score = score
        self.save_checkpoint(score, model)
        if not first_call:
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Write model.state_dict() to 'checkpoint.pt' and remember val_loss as the last saved score.'''
        if self.verbose:
            print(f'Validation F1 increased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

In [9]:
class ScopeModel:
    def __init__(self, full_finetuning = True, train = False, pretrained_model_path = 'Scope_Resolution_Augment.pickle', device = 'cuda', learning_rate = 3e-5):
        '''
        Build the 2-label scope-resolution token classifier and its Adam optimizer.

        full_finetuning: If True, optimise every model parameter; otherwise only the
            parameters of self.model.classifier.
        train: If True, load a fresh AutoModelForTokenClassification for SCOPE_MODEL;
            otherwise torch.load the whole pickled model from pretrained_model_path.
        pretrained_model_path: File read when train is not True.
        device: Any torch.device spec ('cuda', 'cuda:1', 'cpu', ...). The model is moved
            there, and training later sends batches to self.device.
        learning_rate: Adam learning rate.
        '''
        self.model_name = SCOPE_MODEL
        self.task = TASK
        self.num_labels = 2
        if train == True:
            self.model = AutoModelForTokenClassification.from_pretrained(SCOPE_MODEL, num_labels=self.num_labels)
        else:
            # NOTE(review): torch.load unpickles arbitrary Python objects; only load trusted files.
            self.model = torch.load(pretrained_model_path)
        self.device = torch.device(device)
        # Move the model to exactly self.device. Previously only the literal string 'cuda'
        # selected the GPU and anything else (e.g. 'cuda:1') sent the model to the CPU, so
        # the model and the batches sent to self.device could end up on different devices.
        self.model.to(self.device)

        if full_finetuning:
            param_optimizer = list(self.model.named_parameters())
            no_decay = ['bias', 'gamma', 'beta']
            # NOTE(review): torch.optim.Adam's weight-decay option is named 'weight_decay';
            # 'weight_decay_rate' is not one of its documented options, so as far as this code
            # shows no weight decay is applied. Left unchanged to keep training results
            # reproducible -- confirm intent before renaming.
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                 'weight_decay_rate': 0.01},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                 'weight_decay_rate': 0.0}
            ]
        else:
            param_optimizer = list(self.model.classifier.named_parameters())
            optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
        self.optimizer = Adam(optimizer_grouped_parameters, lr=learning_rate)

    def train(self, train_dataloader, valid_dataloaders, train_dl_name, val_dl_name, epochs = 5, max_grad_norm = 1.0, patience = 3):
        self.train_dl_name = train_dl_name
        return_dict = {"Task": f"{self.task} Scope Resolution",
                       "Model": self.model_name,
                       "Train Dataset": train_dl_name,
                       "Val Dataset": val_dl_name,
                       "Best Precision": 0,
                       "Best Recall": 0,
                       "Best F1": 0}
        train_loss = []
        valid_loss = []
        early_stopping = EarlyStopping(patience=patience, verbose=True)
        loss_fn = CrossEntropyLoss()
        for _ in tqdm(range(epochs), desc="Epoch"):
            self.model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(self.device) for t in batch)
                b_input_ids, b_input_mask, b_labels, b_mymasks = batch
                logits = self.model(b_input_ids,
                             attention_mask=b_input_mask)[0]
                active_loss = b_input_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss] #2 is num_labels
                active_labels = b_labels.view(-1)[active_loss]
                loss = loss_fn(active_logits, active_labels)
                loss.backward()
                tr_loss += loss.item()
                train_loss.append(loss.item())
                if step%100 == 0:
                    print(f"Batch {step}, loss {loss.item()}")
                nb_tr_examples += b_input_ids.size(0)
                nb_tr_steps += 1
                torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_grad_norm)
                self.optimizer.step()
                self.model.zero_grad()
            print("Train loss: {}".format(tr_loss/nb_tr_steps))

            self.model.eval()

            eval_loss, eval_accuracy, eval_scope_accuracy = 0, 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            predictions , true_labels, ip_mask = [], [], []
            loss_fn = CrossEntropyLoss()
            for valid_dataloader in valid_dataloaders:
                for batch in valid_dataloader:
                    batch = tuple(t.to(self.device) for t in batch)
                    b_input_ids, b_input_mask, b_labels, b_mymasks = batch

                    with torch.no_grad():
                        logits = self.model(b_input_ids,
                                      attention_mask=b_input_mask)[0]
                        active_loss = b_input_mask.view(-1) == 1
                        active_logits = logits.view(-1, self.num_labels)[active_loss]
                        active_labels = b_labels.view(-1)[active_loss]
                        tmp_eval_loss = loss_fn(active_logits, active_labels)

                    logits = logits.detach().cpu().numpy()
                    label_ids = b_labels.to('cpu').numpy()
                    b_input_ids = b_input_ids.to('cpu').numpy()

                    mymasks = b_mymasks.to('cpu').numpy()

                    if F1_METHOD == 'first_token':

                        logits = [list(p) for p in np.argmax(logits, axis=2)]
                        actual_logits = []
                        actual_label_ids = []
                        for l,lid,m in zip(logits, label_ids, mymasks):
                            actual_logits.append([i for i,j in zip(l,m) if j==1])
                            actual_label_ids.append([i for i,j in zip(lid, m) if j==1])

                        logits = actual_logits
                        label_ids = actual_label_ids

                        predictions.append(logits)
                        true_labels.append(label_ids)
                    elif F1_METHOD == 'average':

                        logits = [list(p) for p in logits]

                        actual_logits = []
                        actual_label_ids = []

                        for l,lid,m,b_ii in zip(logits, label_ids, mymasks, b_input_ids):

                            actual_label_ids.append([i for i,j in zip(lid, m) if j==1])
                            my_logits = []
                            curr_preds = []
                            in_split = 0
                            for i,j,k in zip(l,m, b_ii):
                                '''if k == 0:
                                    break'''
                                if j==1:
                                    if in_split == 1:
                                        if len(my_logits)>0:
                                            curr_preds.append(my_logits[-1])
                                        mode_pred = np.argmax(np.average(np.array(curr_preds), axis=0), axis=0)
                                        if len(my_logits)>0:
                                            my_logits[-1] = mode_pred
                                        else:
                                            my_logits.append(mode_pred)
                                        curr_preds = []
                                        in_split = 0
                                    my_logits.append(np.argmax(i))
                                if j==0:
                                    curr_preds.append(i)
                                    in_split = 1
                            if in_split == 1:
                                if len(my_logits)>0:
                                    curr_preds.append(my_logits[-1])
                                mode_pred = np.argmax(np.average(np.array(curr_preds), axis=0), axis=0)
                                if len(my_logits)>0:
                                    my_logits[-1] = mode_pred
                                else:
                                    my_logits.append(mode_pred)
                            actual_logits.append(my_logits)

                        predictions.append(actual_logits)
                        true_labels.append(actual_label_ids)

                    tmp_eval_accuracy = flat_accuracy(actual_logits, actual_label_ids)
                    tmp_eval_scope_accuracy = scope_accuracy(actual_logits, actual_label_ids)
                    eval_scope_accuracy += tmp_eval_scope_accuracy
                    valid_loss.append(tmp_eval_loss.mean().item())

                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy

                    nb_eval_examples += len(b_input_ids)
                    nb_eval_steps += 1
                eval_loss = eval_loss/nb_eval_steps
            print("Validation loss: {}".format(eval_loss))
            print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
            print("Validation Accuracy Scope Level: {}".format(eval_scope_accuracy/nb_eval_steps))
            f1_scope([j for i in true_labels for j in i], [j for i in predictions for j in i], level='scope')
            labels_flat = [l_ii for l in true_labels for l_i in l for l_ii in l_i]
            pred_flat = [p_ii for p in predictions for p_i in p for p_ii in p_i]
            classification_dict = classification_report(labels_flat, pred_flat, output_dict= True)
            p = classification_dict["1"]["precision"]
            r = classification_dict["1"]["recall"]
            f1 = classification_dict["1"]["f1-score"]
            if f1>return_dict['Best F1']:
                return_dict['Best F1'] = f1
                return_dict['Best Precision'] = p
                return_dict['Best Recall'] = r
            print("F1-Score Token: {}".format(f1))
            print(classification_report(labels_flat, pred_flat))
            print(f'F1: {f1}')
            early_stopping(f1, self.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        self.model.load_state_dict(torch.load('checkpoint.pt'))
        plt.xlabel("Iteration")
        plt.ylabel("Train Loss")
        plt.plot([i for i in range(len(train_loss))], train_loss)
        plt.figure()
        plt.xlabel("Iteration")
        plt.ylabel("Validation Loss")
        plt.plot([i for i in range(len(valid_loss))], valid_loss)
        return return_dict

    def evaluate(self, test_dataloader, test_dl_name):
        """Evaluate the scope-resolution model on a labelled dataloader.

        Each batch must unpack into four tensors: input ids, attention mask,
        gold label ids and a word-start mask. Only positions where the
        word-start mask is 1 are scored. Token logits are reduced to one
        prediction per scored position according to the global ``F1_METHOD``:

        * ``'first_token'``: argmax of the logits at each mask-1 position.
        * ``'average'``: argmax of the mean logit vector over a mask-1
          position and the mask-0 positions that follow it.

        Prints the mean loss, the mean per-batch ``flat_accuracy`` and
        ``scope_accuracy``, calls ``f1_scope`` at scope level (its return value
        is discarded), and prints a classification report.

        Args:
            test_dataloader: iterable of batches laid out as described above.
            test_dl_name: name of the evaluation dataset; stored in the result.

        Returns:
            dict with the task, model and dataset names, the precision, recall
            and F1 for label ``"1"``, and the mean of the flattened predicted
            and gold labels (the fraction labelled 1 when labels are binary).

        Raises:
            ValueError: if ``F1_METHOD`` is not ``'first_token'`` or ``'average'``.
        """
        if F1_METHOD not in ('first_token', 'average'):
            raise ValueError(f"Unsupported F1_METHOD: {F1_METHOD!r}")

        def merge_subword_logits(row_logits, row_mask):
            """Return one predicted label per word-start group of one sequence.

            A mask-1 position opens a new group; a mask-0 position joins the
            current group, or opens one if no group exists yet. Other mask
            values are ignored. Each group's label is the argmax of the mean
            of its logit *vectors*. Class indices are never averaged in.
            NOTE(review): mask-0 padding positions (if any) join the last
            group, as in the previous implementation. Confirm how the
            word-start mask is built.
            """
            groups = []
            for vec, flag in zip(row_logits, row_mask):
                if flag == 1:
                    groups.append([vec])
                elif flag == 0:
                    if groups:
                        groups[-1].append(vec)
                    else:
                        groups.append([vec])
            return [np.argmax(np.mean(np.stack(group), axis=0)) for group in groups]

        return_dict = {"Task": f"{self.task} Scope Resolution",
                       "Model": self.model_name,
                       "Train Dataset": self.train_dl_name,
                       "Test Dataset": test_dl_name,
                       "Precision": 0,
                       "Recall": 0,
                       "F1": 0}
        self.model.eval()
        eval_loss, eval_accuracy, eval_scope_accuracy = 0, 0, 0
        nb_eval_steps = 0
        predictions, true_labels = [], []
        loss_fn = CrossEntropyLoss()
        for batch in test_dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            b_input_ids, b_input_mask, b_labels, b_mymasks = batch

            with torch.no_grad():
                logits = self.model(b_input_ids,
                                    attention_mask=b_input_mask)[0]
                # Loss is computed only over positions kept by the attention mask.
                active_loss = b_input_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = b_labels.view(-1)[active_loss]
                tmp_eval_loss = loss_fn(active_logits, active_labels)

            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()
            mymasks = b_mymasks.to('cpu').numpy()

            # Gold labels are read at word-start positions only.
            actual_label_ids = [[lab for lab, m in zip(row, mask) if m == 1]
                                for row, mask in zip(label_ids, mymasks)]
            if F1_METHOD == 'first_token':
                token_preds = np.argmax(logits, axis=2)
                actual_logits = [[p for p, m in zip(row, mask) if m == 1]
                                 for row, mask in zip(token_preds, mymasks)]
            else:  # 'average'
                actual_logits = [merge_subword_logits(row, mask)
                                 for row, mask in zip(logits, mymasks)]

            predictions.append(actual_logits)
            true_labels.append(actual_label_ids)

            eval_accuracy += flat_accuracy(actual_logits, actual_label_ids)
            eval_scope_accuracy += scope_accuracy(actual_logits, actual_label_ids)
            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
        eval_loss = eval_loss/nb_eval_steps
        print("Validation loss: {}".format(eval_loss))
        print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
        print("Validation Accuracy Scope Level: {}".format(eval_scope_accuracy/nb_eval_steps))
        f1_scope([j for i in true_labels for j in i], [j for i in predictions for j in i], level='scope')
        labels_flat = [l_ii for l in true_labels for l_i in l for l_ii in l_i]
        pred_flat = [p_ii for p in predictions for p_i in p for p_ii in p_i]
        classification_dict = classification_report(labels_flat, pred_flat, output_dict= True)
        p = classification_dict["1"]["precision"]
        r = classification_dict["1"]["recall"]
        f1 = classification_dict["1"]["f1-score"]
        return_dict['Precision'] = p
        return_dict['Recall'] = r
        return_dict['F1'] = f1
        return_dict['Mean Pred Scope Length'] = np.sum(pred_flat)/len(pred_flat)
        return_dict['Mean Real Scope Length'] = np.sum(labels_flat)/len(labels_flat)
        print("Classification Report:")
        print(classification_report(labels_flat, pred_flat))
        return return_dict

    def predict(self, dataloader):
        """Predict labels for an unlabelled dataloader.

        Each batch must unpack into three tensors: input ids, attention mask
        and a word-start mask. Token logits are reduced to one prediction per
        word-start group according to the global ``F1_METHOD``:

        * ``'first_token'``: argmax of the logits at each mask-1 position.
        * ``'average'``: argmax of the mean logit vector over a mask-1
          position and the mask-0 positions that follow it.

        Args:
            dataloader: iterable of batches laid out as described above.

        Returns:
            list with one entry per batch. Each entry is a list, one per
            sequence, of predicted label ids.

        Raises:
            ValueError: if ``F1_METHOD`` is not ``'first_token'`` or ``'average'``.
        """
        if F1_METHOD not in ('first_token', 'average'):
            raise ValueError(f"Unsupported F1_METHOD: {F1_METHOD!r}")

        def merge_subword_logits(row_logits, row_mask):
            """Return one predicted label per word-start group of one sequence.

            A mask-1 position opens a new group; a mask-0 position joins the
            current group, or opens one if no group exists yet. Other mask
            values are ignored. Each group's label is the argmax of the mean
            of its logit *vectors*. Class indices are never averaged in.
            NOTE(review): mask-0 padding positions (if any) join the last
            group, as in the previous implementation. Confirm how the
            word-start mask is built.
            """
            groups = []
            for vec, flag in zip(row_logits, row_mask):
                if flag == 1:
                    groups.append([vec])
                elif flag == 0:
                    if groups:
                        groups[-1].append(vec)
                    else:
                        groups.append([vec])
            return [np.argmax(np.mean(np.stack(group), axis=0)) for group in groups]

        self.model.eval()
        predictions = []
        for batch in dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            b_input_ids, b_input_mask, b_mymasks = batch

            with torch.no_grad():
                logits = self.model(b_input_ids, attention_mask=b_input_mask)[0]
            logits = logits.detach().cpu().numpy()
            mymasks = b_mymasks.to('cpu').numpy()

            if F1_METHOD == 'first_token':
                # The data is unlabelled: only the word-start mask is needed.
                token_preds = np.argmax(logits, axis=2)
                batch_preds = [[p for p, m in zip(row, mask) if m == 1]
                               for row, mask in zip(token_preds, mymasks)]
            else:  # 'average'
                batch_preds = [merge_subword_logits(row, mask)
                               for row, mask in zip(logits, mymasks)]
            predictions.append(batch_preds)
        return predictions

Load the MultiLegalNeg data from [Hugging Face](https://huggingface.co/datasets/rcds/MultiLegalNeg).

In [10]:
# Load each configuration of the rcds/MultiLegalNeg dataset, one variable per
# configuration, in the same order as the names on the left-hand side.
de_dataset, fr_dataset, it_dataset, ch_dataset, fr_dalloux, en_sfu, en_bioscope, en_sherlock = (
    load_dataset("rcds/MultiLegalNeg", config_name)
    for config_name in (
        "de", "fr", "it", "swiss",
        "fr_dalloux", "en_sfu", "en_bioscope", "en_sherlock",
    )
)

Downloading builder script:   0%|          | 0.00/3.81k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/3.32k [00:00<?, ?B/s]

Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rcds___multi_legal_neg/de/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec)
INFO:datasets.builder:Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rcds___multi_legal_neg/de/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec)
Downloading and preparing dataset multi_legal_neg/de to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/de/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec...
INFO:datasets.builder:Downloading and preparing dataset multi_legal_neg/de to /root/.cache/

Downloading data:   0%|          | 0.00/164k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/de_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/13787decc95c218af83991ab761eeb23834b86ec905644a65d4b87f7e4ba5fba
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/de_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/13787decc95c218af83991ab761eeb23834b86ec905644a65d4b87f7e4ba5fba
creating metadata file for /root/.cache/huggingface/datasets/downloads/13787decc95c218af83991ab761eeb23834b86ec905644a65d4b87f7e4ba5fba
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/13787decc95c218af83991ab761eeb23834b86ec905644a65d4b87f7e4ba5fba
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co/d

Downloading data:   0%|          | 0.00/55.3k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/de_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/bf409f10a85751b357f153399f9c51eb8e3d169808793ff76629e6edf3ea4128
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/de_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/bf409f10a85751b357f153399f9c51eb8e3d169808793ff76629e6edf3ea4128
creating metadata file for /root/.cache/huggingface/datasets/downloads/bf409f10a85751b357f153399f9c51eb8e3d169808793ff76629e6edf3ea4128
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/bf409f10a85751b357f153399f9c51eb8e3d169808793ff76629e6edf3ea4128
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co/datas

Downloading data:   0%|          | 0.00/29.8k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/de_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/163cc0c56784d9211952c3fceddac1811893bce4441465a9aae3c923cc26674d
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/de_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/163cc0c56784d9211952c3fceddac1811893bce4441465a9aae3c923cc26674d
creating metadata file for /root/.cache/huggingface/datasets/downloads/163cc0c56784d9211952c3fceddac1811893bce4441465a9aae3c923cc26674d
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/163cc0c56784d9211952c3fceddac1811893bce4441465a9aae3c923cc26674d
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
Gene

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/de/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/de/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rcds___mul

Downloading data:   0%|          | 0.00/216k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/fr_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/51cab29516fd75d229a2b3685f2f8c3c719dcaa2176662f5548f84d383554b50
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/fr_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/51cab29516fd75d229a2b3685f2f8c3c719dcaa2176662f5548f84d383554b50
creating metadata file for /root/.cache/huggingface/datasets/downloads/51cab29516fd75d229a2b3685f2f8c3c719dcaa2176662f5548f84d383554b50
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/51cab29516fd75d229a2b3685f2f8c3c719dcaa2176662f5548f84d383554b50
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co/d

Downloading data:   0%|          | 0.00/71.6k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/fr_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/800581e2bd810e76ed533ff07fb419344a8fc879488d7d2607431960b8225a76
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/fr_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/800581e2bd810e76ed533ff07fb419344a8fc879488d7d2607431960b8225a76
creating metadata file for /root/.cache/huggingface/datasets/downloads/800581e2bd810e76ed533ff07fb419344a8fc879488d7d2607431960b8225a76
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/800581e2bd810e76ed533ff07fb419344a8fc879488d7d2607431960b8225a76
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co/datas

Downloading data:   0%|          | 0.00/41.4k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/fr_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/6db0da4e8ea0c83a2e2cbf4af53256137a7dc827350e5ecc31576d620f6a23d0
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/fr_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/6db0da4e8ea0c83a2e2cbf4af53256137a7dc827350e5ecc31576d620f6a23d0
creating metadata file for /root/.cache/huggingface/datasets/downloads/6db0da4e8ea0c83a2e2cbf4af53256137a7dc827350e5ecc31576d620f6a23d0
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/6db0da4e8ea0c83a2e2cbf4af53256137a7dc827350e5ecc31576d620f6a23d0
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
Gene

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/fr/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/fr/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rcds___mul

Downloading data:   0%|          | 0.00/187k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/it_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/e0a2945adb33ae3857bdefe72ebf2ca9f3040226702e506c6c60db118251986c
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/it_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/e0a2945adb33ae3857bdefe72ebf2ca9f3040226702e506c6c60db118251986c
creating metadata file for /root/.cache/huggingface/datasets/downloads/e0a2945adb33ae3857bdefe72ebf2ca9f3040226702e506c6c60db118251986c
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/e0a2945adb33ae3857bdefe72ebf2ca9f3040226702e506c6c60db118251986c
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co/d

Downloading data:   0%|          | 0.00/62.6k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/it_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/a02e3ae8dd6a327106998468365dfd4ac5c351193bf8833ce5905d0658c1d08f
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/it_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/a02e3ae8dd6a327106998468365dfd4ac5c351193bf8833ce5905d0658c1d08f
creating metadata file for /root/.cache/huggingface/datasets/downloads/a02e3ae8dd6a327106998468365dfd4ac5c351193bf8833ce5905d0658c1d08f
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/a02e3ae8dd6a327106998468365dfd4ac5c351193bf8833ce5905d0658c1d08f
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co/datas

Downloading data:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/it_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/c295aa329eaa009f88850c61b4828946a602dc9d047b2a635b568bc825b50c7e
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/it_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/c295aa329eaa009f88850c61b4828946a602dc9d047b2a635b568bc825b50c7e
creating metadata file for /root/.cache/huggingface/datasets/downloads/c295aa329eaa009f88850c61b4828946a602dc9d047b2a635b568bc825b50c7e
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/c295aa329eaa009f88850c61b4828946a602dc9d047b2a635b568bc825b50c7e
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
Gene

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/it/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/it/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rcds___mul

Downloading data:   0%|          | 0.00/28.3k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/swiss_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/3ff7aa0b1a02318fd83ac07bec802e750e6c7cdc44bba1b7535111de01705590
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/swiss_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/3ff7aa0b1a02318fd83ac07bec802e750e6c7cdc44bba1b7535111de01705590
creating metadata file for /root/.cache/huggingface/datasets/downloads/3ff7aa0b1a02318fd83ac07bec802e750e6c7cdc44bba1b7535111de01705590
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/3ff7aa0b1a02318fd83ac07bec802e750e6c7cdc44bba1b7535111de01705590
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingfac

Downloading data:   0%|          | 0.00/11.1k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/swiss_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/f067ca98f0a6df9f214aff67616b0472da5049aa64a399a60d6310df07f10ec0
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/swiss_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/f067ca98f0a6df9f214aff67616b0472da5049aa64a399a60d6310df07f10ec0
creating metadata file for /root/.cache/huggingface/datasets/downloads/f067ca98f0a6df9f214aff67616b0472da5049aa64a399a60d6310df07f10ec0
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/f067ca98f0a6df9f214aff67616b0472da5049aa64a399a60d6310df07f10ec0
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.co

Downloading data:   0%|          | 0.00/5.85k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/swiss_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/2fa2b6b979356286642e0e84f762955e5b3d1edc15921615334e9bbe6c1a6f53
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/swiss_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/2fa2b6b979356286642e0e84f762955e5b3d1edc15921615334e9bbe6c1a6f53
creating metadata file for /root/.cache/huggingface/datasets/downloads/2fa2b6b979356286642e0e84f762955e5b3d1edc15921615334e9bbe6c1a6f53
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/2fa2b6b979356286642e0e84f762955e5b3d1edc15921615334e9bbe6c1a6f53
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 mi

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/swiss/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/swiss/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rcds

Downloading data:   0%|          | 0.00/449k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/fr_dalloux_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/4702f2ef8d2b4709e614c24578c92a52007d0fce965626f1e5ee032763a39a54
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/fr_dalloux_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/4702f2ef8d2b4709e614c24578c92a52007d0fce965626f1e5ee032763a39a54
creating metadata file for /root/.cache/huggingface/datasets/downloads/4702f2ef8d2b4709e614c24578c92a52007d0fce965626f1e5ee032763a39a54
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/4702f2ef8d2b4709e614c24578c92a52007d0fce965626f1e5ee032763a39a54
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://

Downloading data:   0%|          | 0.00/144k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/fr_dalloux_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/0874929d2efb0d9a1e857788fbc2e0ac0cb95d7b90b8acf75115dbd4496ecea7
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/fr_dalloux_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/0874929d2efb0d9a1e857788fbc2e0ac0cb95d7b90b8acf75115dbd4496ecea7
creating metadata file for /root/.cache/huggingface/datasets/downloads/0874929d2efb0d9a1e857788fbc2e0ac0cb95d7b90b8acf75115dbd4496ecea7
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/0874929d2efb0d9a1e857788fbc2e0ac0cb95d7b90b8acf75115dbd4496ecea7
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://hugg

Downloading data:   0%|          | 0.00/78.9k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/fr_dalloux_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/dd5fca5e8c0b84e4c86d953efdb04c161ac9b472eaced0fe96f03715bf535886
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/fr_dalloux_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/dd5fca5e8c0b84e4c86d953efdb04c161ac9b472eaced0fe96f03715bf535886
creating metadata file for /root/.cache/huggingface/datasets/downloads/dd5fca5e8c0b84e4c86d953efdb04c161ac9b472eaced0fe96f03715bf535886
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/dd5fca5e8c0b84e4c86d953efdb04c161ac9b472eaced0fe96f03715bf535886
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation t

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/fr_dalloux/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/fr_dalloux/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/dat

Downloading data:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/en_sfu_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/e793ac3b83b2eb7cdfa69fef2343144341d9fe21746b82053a7a9c601a808156
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/en_sfu_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/e793ac3b83b2eb7cdfa69fef2343144341d9fe21746b82053a7a9c601a808156
creating metadata file for /root/.cache/huggingface/datasets/downloads/e793ac3b83b2eb7cdfa69fef2343144341d9fe21746b82053a7a9c601a808156
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/e793ac3b83b2eb7cdfa69fef2343144341d9fe21746b82053a7a9c601a808156
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingf

Downloading data:   0%|          | 0.00/469k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/en_sfu_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/af75bd9fd5b758a60851f4709a51fd494abf9fc16d2be34748488889444c748f
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/en_sfu_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/af75bd9fd5b758a60851f4709a51fd494abf9fc16d2be34748488889444c748f
creating metadata file for /root/.cache/huggingface/datasets/downloads/af75bd9fd5b758a60851f4709a51fd494abf9fc16d2be34748488889444c748f
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/af75bd9fd5b758a60851f4709a51fd494abf9fc16d2be34748488889444c748f
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://huggingface.

Downloading data:   0%|          | 0.00/240k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/en_sfu_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/ac3399aa3319a3aff9b9885ea220c9f6f20f920ae2807cc2c936dcec484eb4cf
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/en_sfu_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/ac3399aa3319a3aff9b9885ea220c9f6f20f920ae2807cc2c936dcec484eb4cf
creating metadata file for /root/.cache/huggingface/datasets/downloads/ac3399aa3319a3aff9b9885ea220c9f6f20f920ae2807cc2c936dcec484eb4cf
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/ac3399aa3319a3aff9b9885ea220c9f6f20f920ae2807cc2c936dcec484eb4cf
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/en_sfu/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/en_sfu/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/datasets/rc

Downloading data:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/en_bioscope_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/f507ba609457daf88e9068a6ef62fd47a8b46fc9c21c6c6c841be4763f65e2db
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/en_bioscope_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/f507ba609457daf88e9068a6ef62fd47a8b46fc9c21c6c6c841be4763f65e2db
creating metadata file for /root/.cache/huggingface/datasets/downloads/f507ba609457daf88e9068a6ef62fd47a8b46fc9c21c6c6c841be4763f65e2db
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/f507ba609457daf88e9068a6ef62fd47a8b46fc9c21c6c6c841be4763f65e2db
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https:

Downloading data:   0%|          | 0.00/521k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/en_bioscope_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/ea94b00747a39b0fff249773b6528975bc512a9e3f035cb30f312cacb4e86ed6
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/en_bioscope_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/ea94b00747a39b0fff249773b6528975bc512a9e3f035cb30f312cacb4e86ed6
creating metadata file for /root/.cache/huggingface/datasets/downloads/ea94b00747a39b0fff249773b6528975bc512a9e3f035cb30f312cacb4e86ed6
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/ea94b00747a39b0fff249773b6528975bc512a9e3f035cb30f312cacb4e86ed6
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://hu

Downloading data:   0%|          | 0.00/269k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/en_bioscope_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/73ae168bfeb20fd9921f6e8bfca805ab3480c9bf063d715bab08b360b9c7abd1
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/en_bioscope_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/73ae168bfeb20fd9921f6e8bfca805ab3480c9bf063d715bab08b360b9c7abd1
creating metadata file for /root/.cache/huggingface/datasets/downloads/73ae168bfeb20fd9921f6e8bfca805ab3480c9bf063d715bab08b360b9c7abd1
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/73ae168bfeb20fd9921f6e8bfca805ab3480c9bf063d715bab08b360b9c7abd1
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/en_bioscope/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/en_bioscope/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/rcds--MultiLegalNeg/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec
Generating dataset multi_legal_neg (/root/.cache/huggingface/d

Downloading data:   0%|          | 0.00/399k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/en_sherlock_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/645cda412e8d2c4f449131551723144dfef05dc3addeceae541920e568bb08c8
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/train/en_sherlock_train.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/645cda412e8d2c4f449131551723144dfef05dc3addeceae541920e568bb08c8
creating metadata file for /root/.cache/huggingface/datasets/downloads/645cda412e8d2c4f449131551723144dfef05dc3addeceae541920e568bb08c8
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/645cda412e8d2c4f449131551723144dfef05dc3addeceae541920e568bb08c8
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https:

Downloading data:   0%|          | 0.00/131k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/en_sherlock_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/82f10929629e22f8dc0f54eae2c3bad0e1c345ca7b59e1f474fde82c9a1a233b
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/test/en_sherlock_test.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/82f10929629e22f8dc0f54eae2c3bad0e1c345ca7b59e1f474fde82c9a1a233b
creating metadata file for /root/.cache/huggingface/datasets/downloads/82f10929629e22f8dc0f54eae2c3bad0e1c345ca7b59e1f474fde82c9a1a233b
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/82f10929629e22f8dc0f54eae2c3bad0e1c345ca7b59e1f474fde82c9a1a233b
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
https://hu

Downloading data:   0%|          | 0.00/93.8k [00:00<?, ?B/s]

storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/en_sherlock_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/053ef0aa14a81505a4dda0b7c965a9554f47b72bc7990b6124f56eecc865b472
INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/rcds/MultiLegalNeg/resolve/main/data/validation/en_sherlock_validation.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/053ef0aa14a81505a4dda0b7c965a9554f47b72bc7990b6124f56eecc865b472
creating metadata file for /root/.cache/huggingface/datasets/downloads/053ef0aa14a81505a4dda0b7c965a9554f47b72bc7990b6124f56eecc865b472
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/053ef0aa14a81505a4dda0b7c965a9554f47b72bc7990b6124f56eecc865b472
Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split
INFO:datasets.builder:Generating test split


Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/en_sherlock/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset multi_legal_neg downloaded and prepared to /root/.cache/huggingface/datasets/rcds___multi_legal_neg/en_sherlock/0.0.0/3473d423c8924d16bb6bb82e2c36dfc33a90d70f4c03337bd49a7e395aa152ec. Subsequent calls will reuse this data.


In [11]:
def _wrap_split(hf_dataset, hf_split, name):
    """Return split `hf_split` of `hf_dataset` wrapped in a `Data` labelled `name`."""
    return Data(hf_dataset[hf_split], dataset_name=name)


# Training splits.
de_train = _wrap_split(de_dataset, "train", "de_train")
fr_train = _wrap_split(fr_dataset, "train", "fr_train")
it_train = _wrap_split(it_dataset, "train", "it_train")
ch_train = _wrap_split(ch_dataset, "train", "ch_train")
fr_dalloux_train = _wrap_split(fr_dalloux, "train", "fr_dalloux_train")
en_sfu_train = _wrap_split(en_sfu, "train", "en_sfu_train")
en_bioscope_train = _wrap_split(en_bioscope, "train", "en_bioscope_train")
en_sherlock_train = _wrap_split(en_sherlock, "train", "en_sherlock_train")

# Validation splits (the HF split is called "validation"; names use "val").
de_val = _wrap_split(de_dataset, "validation", "de_val")
fr_val = _wrap_split(fr_dataset, "validation", "fr_val")
it_val = _wrap_split(it_dataset, "validation", "it_val")
ch_val = _wrap_split(ch_dataset, "validation", "ch_val")
fr_dalloux_val = _wrap_split(fr_dalloux, "validation", "fr_dalloux_val")
en_sfu_val = _wrap_split(en_sfu, "validation", "en_sfu_val")
en_bioscope_val = _wrap_split(en_bioscope, "validation", "en_bioscope_val")
en_sherlock_val = _wrap_split(en_sherlock, "validation", "en_sherlock_val")

# Test splits.
de_test = _wrap_split(de_dataset, "test", "de_test")
fr_test = _wrap_split(fr_dataset, "test", "fr_test")
it_test = _wrap_split(it_dataset, "test", "it_test")
ch_test = _wrap_split(ch_dataset, "test", "ch_test")
fr_dalloux_test = _wrap_split(fr_dalloux, "test", "fr_dalloux_test")
en_sfu_test = _wrap_split(en_sfu, "test", "en_sfu_test")
en_bioscope_test = _wrap_split(en_bioscope, "test", "en_bioscope_test")
en_sherlock_test = _wrap_split(en_sherlock, "test", "en_sherlock_test")

In [12]:
# Registry used by the training loop: corpus key -> split name -> Data object.
# Split names are "train", "test" and "val".
datasets = {
    "de": dict(train=de_train, test=de_test, val=de_val),
    "fr": dict(train=fr_train, test=fr_test, val=fr_val),
    "it": dict(train=it_train, test=it_test, val=it_val),
    "ch": dict(train=ch_train, test=ch_test, val=ch_val),
    "dalloux": dict(train=fr_dalloux_train, test=fr_dalloux_test, val=fr_dalloux_val),
    "sfu": dict(train=en_sfu_train, test=en_sfu_test, val=en_sfu_val),
    "sherlock": dict(train=en_sherlock_train, test=en_sherlock_test, val=en_sherlock_val),
    "bioscope": dict(train=en_bioscope_train, test=en_bioscope_test, val=en_bioscope_val),
}

## Train models

In [None]:
# Results are stored on Google Drive (Colab only): mount it at /content/drive
# so that `file_path` can point inside it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# CSV file the results are appended to. Placeholder: replace PATHTOFILE with a
# real file location below the Drive mount point (/content/drive).
file_path = "/content/drive/PATHTOFILE"

# "multilingual": train on the train splits of `train_datasets`.
# "zero_shot":    train on the train AND test splits of `train_datasets`; the
#                 datasets in `test_datasets` should then not be listed in
#                 `train_datasets`. ("zero-shot" with a hyphen is accepted too.)
experiment_type = "multilingual"  # "multilingual" or "zero_shot"
experiment_type = experiment_type.replace("-", "_")
if experiment_type not in ("multilingual", "zero_shot"):
    # Fail here instead of much later with an empty training set.
    raise ValueError(
        f"experiment_type must be 'multilingual' or 'zero_shot', got {experiment_type!r}"
    )

# Keys into the `datasets` registry.
train_datasets = ["sfu", "bioscope", "sherlock", "dalloux", "de", "fr", "it", "ch"]
test_datasets = ["fr", "it", "de", "ch"]

# Hugging Face model identifiers to fine-tune, one after another.
models = ['xlm-roberta-base',
          'distilbert-base-multilingual-cased',
          'bert-base-multilingual-uncased',
          'xlm-roberta-large',
          'cis-lmu/glot500-base',
          "joelito/legal-swiss-roberta-base",
          "joelito/legal-swiss-roberta-large",
          "joelito/legal-xlm-roberta-base",
          "joelito/legal-xlm-roberta-large"]

In [None]:
# Column layout of the results CSV. The training loop appends one row with
# these keys per (model, seed, test dataset).
headers = ["Task", "Model", "Train Dataset", "Test Dataset", "Precision", "Recall", "F1",
           'Mean Pred Scope Length', 'Mean Real Scope Length', "Seed"]

# Write the header only when the file is new or empty. Re-running this cell on
# an existing results file would otherwise insert duplicate header rows.
if not (os.path.exists(file_path) and os.path.getsize(file_path) > 0):
    # newline="" as required by the csv module; `with` guarantees the close.
    with open(file_path, "a", newline="") as f:
        csv.DictWriter(f, fieldnames=headers).writeheader()

In [None]:
# Train each model in `models` once per seed on `train_datasets`, evaluate it
# on every dataset in `test_datasets`, and append that round's result rows to
# the CSV at `file_path`.
#
# experiment_type:
#   "multilingual" - train on the train splits of `train_datasets`.
#   "zero_shot"    - train on their train AND test splits ("zero-shot" with a
#                    hyphen is accepted too). The test datasets must then be
#                    disjoint from `train_datasets`, otherwise test data would
#                    leak into training.
SEEDS = range(1, 6)

_mode = experiment_type.replace("-", "_")
if _mode not in ("multilingual", "zero_shot"):
    raise ValueError(
        f"experiment_type must be 'multilingual' or 'zero_shot', got {experiment_type!r}"
    )
if not train_datasets:
    raise ValueError("train_datasets must name at least one dataset")
if _mode == "zero_shot":
    _leaked = sorted(set(train_datasets) & set(test_datasets))
    if _leaked:
        raise ValueError(f"zero_shot run would train on the test splits of {_leaked}")

# Splits of each training corpus that are merged into the training set.
_train_splits = ("train", "test") if _mode == "zero_shot" else ("train",)
_train_names = ", ".join(train_datasets)

for m in models:
    print(f'training model {m}')
    # ScopeModel is constructed without a checkpoint argument; presumably it
    # reads the global SCOPE_MODEL (TODO confirm against its definition).
    SCOPE_MODEL = m
    model_results = []

    for seed in SEEDS:
        print(f'ROUND {seed}/{len(SEEDS)}')
        # Seed every RNG the data pipeline or model code might draw from,
        # not only torch's.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        t_ds = [datasets[t][split] for t in train_datasets for split in _train_splits]
        v_ds = [datasets[t]["val"] for t in train_datasets]
        train_dl = t_ds[0].get_scope_dataloader(other_datasets=t_ds[1:], combine=True)
        val_dls = v_ds[0].get_scope_dataloader(other_datasets=v_ds[1:], combine=False)
        test_dataloaders = {
            tst: datasets[tst]["test"].get_scope_dataloader() for tst in test_datasets
        }

        model = ScopeModel(full_finetuning=True, train=True, learning_rate=INITIAL_LEARNING_RATE)
        result = model.train(
            train_dl, val_dls, epochs=EPOCHS, patience=PATIENCE,
            train_dl_name=_train_names, val_dl_name=_train_names,
        )

        round_results = []
        for name, test_dl in test_dataloaders.items():
            print(f"Evaluate on {name}:")
            res = model.evaluate(test_dl, test_dl_name=name)
            res["Seed"] = seed
            round_results.append(res)

        for res in round_results:
            print(f'{res["Test Dataset"]}: {res["F1"]}')

        model_results.append(round_results)
        # Append after every round so completed rounds are kept even if a
        # later round fails. newline="" as required by the csv module.
        with open(file_path, "a", newline="") as f:
            csv.DictWriter(f, fieldnames=headers).writerows(round_results)

Helper to pretty-print the gold annotations and the model predictions

In [36]:
class color:
    """ANSI escape sequences for styled terminal output.

    Put a style code in front of some text and `END` after it to reset the
    formatting, e.g. ``color.UNDERLINE + "word" + color.END``.
    """

    # Foreground colours.
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Reset all attributes.
    END = "\033[0m"

def _underline_scope(words, labels):
    """Join `words` with spaces, underlining each word whose label equals 1.

    Every word is followed by a space and `color.END`, so the underline never
    leaks into the following word. `labels` is indexed by word position and
    must be at least as long as `words`.
    """
    parts = []
    for idx, word in enumerate(words):
        prefix = color.UNDERLINE if labels[idx] == 1 else ""
        parts.append(prefix + word + " " + color.END)
    return "".join(parts)


def print_predictions(test_dataset, model):
    """Print gold vs. predicted scopes for each example in `test_dataset`.

    Each example is wrapped in its own `Data` container. Examples whose
    container holds no scope sentences are skipped. For the others, the words
    of the first scope sentence are printed twice, once with the gold scope
    underlined and once with the predicted scope underlined (ANSI codes from
    `color`).

    At the end it prints:
    - how many examples were predicted exactly;
    - the token-level F1 of label 1 over all printed tokens, taken from
      `classification_report`.

    Args:
        test_dataset: iterable of raw examples, each accepted by
            `Data([example])` (e.g. `de_dataset["test"]`).
        model: a trained `ScopeModel`; `predict` is called with one
            dataloader per example.
    """
    real_labels = []
    pred_labels = []
    count_corr = 0
    count_false = 0
    for example in test_dataset:
        tst = Data([example])
        if tst.scope_data.num_sentences <= 0:
            continue
        tst_dl, _ = tst.get_test_dataloader(combine=False)
        preds = model.predict(tst_dl[0])
        # NOTE(review): exact match compares all predicted scope sentences
        # with all gold scopes, although only the first one is printed below.
        if preds[0] == tst.scope_data.scopes:
            count_corr += 1
        else:
            count_false += 1

        words = tst.scope_data.sentences[0]
        gold = [tst.scope_data.scopes[0][idx] for idx in range(len(words))]
        pred = [preds[0][0][idx] for idx in range(len(words))]
        real_labels.extend(gold)
        pred_labels.extend(pred)

        print(f"Annotation: {_underline_scope(words, gold)}")
        print(f"Prediction: {_underline_scope(words, pred)}")
        print("---")

    print(f"Exact matches: {count_corr}/{count_corr + count_false}")
    if not real_labels:
        # classification_report cannot be computed on empty input.
        print("F1-score: n/a (no annotated scopes in test_dataset)")
        return
    report = classification_report(real_labels, pred_labels, output_dict=True)
    if '1' not in report:
        print("F1-score: n/a (label '1' not present in annotations or predictions)")
        return
    print(f"F1-score: {report['1']['f1-score']}")

# **Use pre-trained negation model**
- load model from [huggingface](https://huggingface.co/rcds/neg-xlm-roberta-base)

In [None]:
# Checkpoint published on the Hugging Face Hub. ScopeModel presumably picks
# it up through the global SCOPE_MODEL (its definition is not shown here;
# verify). NOTE(review): train=True is kept from the original; confirm it is
# needed when only running inference.
SCOPE_MODEL = "rcds/neg-xlm-roberta-base"
model = ScopeModel(
    learning_rate=INITIAL_LEARNING_RATE,
    train=True,
    full_finetuning=True,
)

In [None]:
# Show gold vs. predicted scopes on the German test split.
print_predictions(test_dataset=de_dataset["test"], model=model)