![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/legal-nlp/05.4.BertForTokenClassification_TrainEval.ipynb)

# BERT FOR TOKEN CLASSIFICATION - Training/Test Split and Evaluation
Using Hugging Face and importing it to Legal NLP for scalability.

This is a transformer-based approach, which usually returns much bigger models (10x) compared to NerModel, but it can improve the performance over NerModel.

In this notebook we don't save the model, we just train and get metrics on test set. Please see next notebook to check how we finally train with all data and save the model in Spark NLP format.

# Installation

In [None]:
! pip -q install seqeval

In [None]:
! pip install transformers==4.8.1
! pip install pyspark==3.1.2
! pip install spark-nlp
! pip install spark-nlp-display

# Setting name of the project

In [None]:
PROJECT_NAME = 'legal_obligations'

# Imports

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertConfig

from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

import sparknlp
from pyspark.sql import functions as F

from sparknlp.training import CoNLL
from google.colab import files

import pandas as pd
import numpy as np
from tqdm import tqdm, trange

import transformers
from transformers import BertForTokenClassification, TFBertForTokenClassification, AdamW
from transformers import get_linear_schedule_with_warmup

from sklearn.metrics import classification_report

## Setting up Torch

In [None]:
torch.__version__

'1.13.1+cu116'

In [None]:
# Pick the GPU when torch can see one, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
n_gpu = torch.cuda.device_count()

# get_device_name(0) raises when no CUDA device is present, so only query it on GPU runtimes.
torch.cuda.get_device_name(0) if has_cuda else "cpu"

'Tesla T4'

# Check that files are available

In [None]:
! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/legal-nlp/data/conll_noO.conll

In [None]:
!head -n 20 conll_noO.conll

head: cannot open 'conll_noO.conll' for reading: No such file or directory


# Creating folders for logs and checkpoints

In [None]:
!mkdir {PROJECT_NAME}

In [None]:
!mkdir {PROJECT_NAME}/logs

# Starting a Spark Session for SparkNLP

In [None]:
spark = sparknlp.start()

In [None]:
spark

# Convert JSL conlls in dataframe format

In [None]:
def get_conll_df(pth):
  """Read a CoNLL file with Spark NLP and return it as a token-level pandas DataFrame.

  Every row of the result is one token, with the columns `sentence_idx`
  (id assigned to the sentence row the token came from), `word`, `tag` and `pos`.
  """
  sentences = CoNLL().readDataset(spark, pth)
  # One id per sentence row, so tokens can be regrouped after the explode below.
  sentences = sentences.withColumn("sentence_idx", F.monotonically_increasing_id())

  # Zip the three parallel annotation arrays and emit one row per token.
  zipped = F.arrays_zip('token.result', 'label.result', 'pos.result')
  tokens = sentences.select('sentence_idx', F.explode(zipped).alias("cols"))

  # Unpack the zipped struct into named columns.
  flat = tokens.select(
      'sentence_idx',
      F.expr("cols['0']").alias("word"),
      F.expr("cols['1']").alias("tag"),
      F.expr("cols['2']").alias("pos"),
  )
  return flat.toPandas()

data_df = get_conll_df('./conll_noO.conll')

In [None]:
train_idx, test_idx = train_test_split(data_df['sentence_idx'].unique(), shuffle=True, random_state=42, train_size=0.85, test_size=0.15)

In [None]:
len(train_idx)

4268

In [None]:
len(train_idx)  # NOTE(review): same expression as the previous cell; len(test_idx) was probably intended - confirm

4268

In [None]:
train_data_df = data_df[data_df['sentence_idx'].isin(train_idx)]
test_data_df = data_df[data_df['sentence_idx'].isin(test_idx)]

In [None]:
train_data_df

Unnamed: 0,sentence_idx,word,tag,pos
0,0,Exhibit,O,NN
1,0,10.6,O,NN
2,0,memorandum,B-DOC,NN
3,0,Between,O,NN
4,0,(hereinafter,B-PARTY,NN
...,...,...,...,...
98359,8589937102,.,O,NN
98360,8589937102,Language,O,NN
98361,8589937102,and,O,NN
98362,8589937102,propietary,B-ROLE,NN


In [None]:
test_data_df

Unnamed: 0,sentence_idx,word,tag,pos
93,8,ARTICLE,O,NN
94,8,IV,O,NN
95,8,DUTIES,O,NN
96,8,AS,O,NN
97,8,WATER,B-PARTY,NN
...,...,...,...,...
98334,8589937099,determined,O,NN
98335,8589937099,to,O,NN
98336,8589937099,be,O,NN
98337,8589937099,void,O,NN


## Checking the DF looks good

In [None]:
train_data_df.head(25)

Unnamed: 0,sentence_idx,word,tag,pos
0,0,Exhibit,O,NN
1,0,10.6,O,NN
2,0,memorandum,B-DOC,NN
3,0,Between,O,NN
4,0,(hereinafter,B-PARTY,NN
5,0,collectively,I-PARTY,NN
6,0,called,I-PARTY,NN
7,0,"""Parties""",I-PARTY,NN
8,0,and,I-PARTY,NN
9,0,individually,I-PARTY,NN


In [None]:
test_data_df.head(25)

Unnamed: 0,sentence_idx,word,tag,pos
93,8,ARTICLE,O,NN
94,8,IV,O,NN
95,8,DUTIES,O,NN
96,8,AS,O,NN
97,8,WATER,B-PARTY,NN
98,8,"NOW,",I-PARTY,NN
99,8,INC.,I-PARTY,NN
100,8,9,O,NN
126,12,6.6,O,NN
127,12,PRODUCT,B-DOC,NN


In [None]:
print (train_data_df.shape)

(83262, 4)


In [None]:
print (test_data_df.shape)

(15102, 4)


In [None]:
train_data_df['tag'].value_counts()

O            60545
I-PARTY      10349
B-PARTY       4894
I-DOC         2730
B-DOC         1689
B-DATE        1527
B-LAW          659
B-ROLE         282
B-LOC          221
B-ORDINAL      132
B-PERCENT      116
B-PERSON        86
I-EFFDATE       17
B-EFFDATE       15
Name: tag, dtype: int64

In [None]:
test_data_df['tag'].value_counts()

O            11125
I-PARTY       1748
B-PARTY        910
I-DOC          473
B-DOC          318
B-DATE         269
B-LAW          123
B-LOC           38
B-ROLE          34
B-PERCENT       20
B-ORDINAL       19
I-EFFDATE       16
B-PERSON         6
B-EFFDATE        3
Name: tag, dtype: int64

# First, train / fine-tune a model on the dataset

## Iterating function to feed the model with sentences
Converting conll sentence annotations to tuples (word, pos, tag)

In [None]:
## convert conll file to sentences

class SentenceGetter(object):
    """Group a token-level DataFrame into per-sentence lists of tuples.

    `dataset` must provide the columns `sentence_idx`, `word`, `pos` and `tag`.
    Every sentence becomes a list of `(word, pos, tag)` tuples, one per row.

    Attributes:
        n_sent: 1-based position of the sentence `get_next` hands out next.
        dataset: the DataFrame passed to the constructor.
        empty: always False; kept so existing callers reading it keep working.
        grouped: per-sentence lists, indexed by `sentence_idx`.
        sentences: the same lists as a plain Python list.
    """

    def __init__(self, dataset):
        self.n_sent = 1
        self.dataset = dataset
        self.empty = False

        def agg_func(s):
            # Zip the three parallel columns of one sentence into (word, pos, tag) tuples.
            return list(zip(s["word"].values.tolist(),
                            s["pos"].values.tolist(),
                            s["tag"].values.tolist()))

        # Only the columns the aggregation reads are selected, so the grouping
        # column is never handed to `agg_func`.
        self.grouped = self.dataset.groupby("sentence_idx")[["word", "pos", "tag"]].apply(agg_func)
        # Guard the empty case: iterating an empty frame would yield column names.
        self.sentences = list(self.grouped) if len(self.grouped) else []

    def get_next(self):
        """Return the next sentence, or None once every sentence was returned.

        The original lookup by the label "Sentence: N" is kept for datasets
        whose sentence ids use that format; for any other id type (such as the
        integer ids used in this notebook) sentences are served by position.
        Previously a failed label lookup was hidden by a bare `except`, so the
        method returned None on every call.
        """
        legacy_key = "Sentence: {}".format(self.n_sent)
        if legacy_key in self.grouped.index:
            s = self.grouped[legacy_key]
        elif self.n_sent <= len(self.sentences):
            s = self.sentences[self.n_sent - 1]
        else:
            return None
        self.n_sent += 1
        return s

train_getter = SentenceGetter(train_data_df)
test_getter = SentenceGetter(test_data_df)

## Getting sentences and labels
- Sentences: concatenation of the first element of each tuple (word)
- Labels: concatenation of the third element of each tuple (tag)

In [None]:
# Sentences 
train_sentences = [[word[0] for word in sentence] for sentence in train_getter.sentences]
print("Example of train sentence:")
print (train_sentences[5])

test_sentences = [[word[0] for word in sentence] for sentence in test_getter.sentences]
print("Example of test sentence:")
print (test_sentences[5])

# Labels
train_labels = [[s[2] for s in sentence] for sentence in train_getter.sentences]
print("Example of train sentence:")
print(train_labels[5])

test_labels = [[s[2] for s in sentence] for sentence in test_getter.sentences]
print("Example of test sentence:")
print(test_labels[5])

Example of train sentence:
['3.2', '__________', '("Professional")', 'Default', '7']
Example of test sentence:
['on', 'which', 'commercial', 'banks', 'in', 'Dallas', ',']
Example of train sentence:
['O', 'B-PARTY', 'I-PARTY', 'O', 'O']
Example of test sentence:
['O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']


## Converting tags to numeric values with a dict

In [None]:
# Distinct tags seen in the training split.
# NOTE(review): a set of strings has no stable iteration order across interpreter runs,
# so the tag -> id mapping may differ between sessions; sorting (or persisting
# tag_values) would make the ids reproducible - confirm before reusing saved checkpoints.
tag_values = list(set(train_data_df["tag"].values))
# Extra class used as the label value of padded positions.
tag_values.append("PAD")
# Map every tag to its position in tag_values.
tag2idx = {t: i for i, t in enumerate(tag_values)}

In [None]:
print(tag_values[:10])
print(tag2idx)

['B-PERCENT', 'B-PARTY', 'B-DOC', 'B-EFFDATE', 'O', 'I-PARTY', 'B-PERSON', 'B-ROLE', 'B-LAW', 'B-DATE']
{'B-PERCENT': 0, 'B-PARTY': 1, 'B-DOC': 2, 'B-EFFDATE': 3, 'O': 4, 'I-PARTY': 5, 'B-PERSON': 6, 'B-ROLE': 7, 'B-LAW': 8, 'B-DATE': 9, 'I-DOC': 10, 'B-LOC': 11, 'I-EFFDATE': 12, 'B-ORDINAL': 13, 'PAD': 14}


## Model metadata

### Building on top of Legal-BERT

In [None]:
MODEL_TO_TRAIN = 'zlucia/custom-legalbert'

### Hyperparam settings

In [None]:
# Defining some key variables that will be used later on in the training
MAX_LEN = 256
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
EPOCHS = 15
LEARNING_RATE = 2e-05

## Instantiating the proper tokenizer

In [None]:
# Load the tokenizer that ships with the checkpoint being fine-tuned.
# NOTE(review): the sample tokenizations printed further down map capitalised words such as
# 'Dallas' to [UNK]; if this checkpoint's vocabulary is uncased, do_lower_case=False is the
# likely cause - confirm against the model card.
tokenizer = BertTokenizer.from_pretrained(MODEL_TO_TRAIN, do_lower_case=False)

Downloading:   0%|          | 0.00/251k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/307 [00:00<?, ?B/s]

### Tokenize and extend the labels in case a word is split

In [None]:
def tokenize_and_preserve_labels(sentence, text_labels):
    """Split every word into subword tokens and repeat its label for each piece.

    `sentence` and `text_labels` are parallel sequences (one label per word).
    Returns `(tokens, labels)`: the flat list of subword tokens and a label
    list of the same length, where each word's label is copied once per
    subword the tokenizer produced for it.
    """
    tokenized_sentence, labels = [], []

    for word, label in zip(sentence, text_labels):
        pieces = tokenizer.tokenize(word)
        tokenized_sentence += pieces
        # One copy of the word-level label per subword keeps both lists aligned.
        labels += [label] * len(pieces)

    return tokenized_sentence, labels

## Tokenize and get tokens and labels

In [None]:
# Tokenize every sentence of both splits; each entry is a (tokens, labels) pair.
train_tokenized_texts_and_labels = list(
    map(tokenize_and_preserve_labels, train_sentences, train_labels)
)
test_tokenized_texts_and_labels = list(
    map(tokenize_and_preserve_labels, test_sentences, test_labels)
)

# Split the pairs into parallel lists: subword tokens ...
train_tokenized_texts_tokens = [tokens for tokens, _labels in train_tokenized_texts_and_labels]
test_tokenized_texts_tokens = [tokens for tokens, _labels in test_tokenized_texts_and_labels]

# ... and their per-subword labels.
train_tokenized_texts_labels = [labels for _tokens, labels in train_tokenized_texts_and_labels]
test_tokenized_texts_labels = [labels for _tokens, labels in test_tokenized_texts_and_labels]

In [None]:
print(train_tokenized_texts_tokens[5])
print(train_tokenized_texts_labels[5])

['3', '[UNK]', '2', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '7']
['O', 'O', 'O', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'B-PARTY', 'I-PARTY', 'I-PARTY', 'I-PARTY', 'I-PARTY', 'I-PARTY', 'O', 'O']


In [None]:
print(test_tokenized_texts_tokens[5])
print(test_tokenized_texts_labels[5])

['on', 'which', 'commercial', 'banks', 'in', '[UNK]', '[UNK]']
['O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']


## Converting tokens to id && padding sentences to have fixed length

In [None]:
# Settings shared by all four calls: fixed length MAX_LEN, padding and truncation at the end.
_pad_opts = dict(maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

# Token ids, padded with id 0.
train_input_ids = pad_sequences(
    [tokenizer.convert_tokens_to_ids(toks) for toks in train_tokenized_texts_tokens],
    value=0.0, **_pad_opts)

test_input_ids = pad_sequences(
    [tokenizer.convert_tokens_to_ids(toks) for toks in test_tokenized_texts_tokens],
    value=0.0, **_pad_opts)

# Tag ids, padded with the id of the "PAD" tag.
train_tags = pad_sequences(
    [[tag2idx.get(tag) for tag in tags] for tags in train_tokenized_texts_labels],
    value=tag2idx["PAD"], **_pad_opts)

test_tags = pad_sequences(
    [[tag2idx.get(tag) for tag in tags] for tags in test_tokenized_texts_labels],
    value=tag2idx["PAD"], **_pad_opts)

In [None]:
print(train_input_ids[5])
print(test_input_ids[5])
print(train_tags[5])
print(test_tags[5])

[149   1 110   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1
   1 352   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   

## Now that sentences are padded, I need to prevent attention from seeing pads (id=0)

In [None]:
# Mask value is 1.0 wherever the token id is non-zero and 0.0 on the padding id 0.
train_attention_masks = (train_input_ids != 0).astype(float).tolist()
test_attention_masks = (test_input_ids != 0).astype(float).tolist()

In [None]:
print(train_attention_masks[5])
print(test_attention_masks[5])

[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

### Double checking that pairing input-mask is in place

In [None]:
# Print every token id of training sentence 5 next to its mask value to eyeball the alignment.
# NOTE(review): "\T" is not a recognised escape sequence, so the backslash is printed literally
# (see the recorded output); "\t" was probably intended.
for i,m in zip(train_input_ids[5], train_attention_masks[5]):
  print(f"Token id: {i}\Token mask: {m}")

Token id: 149\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 110\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 352\Token mask: 1.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\To

In [None]:
# Print every token id of test sentence 5 next to its mask value to eyeball the alignment.
# NOTE(review): "\T" is not a recognised escape sequence, so the backslash is printed literally
# (see the recorded output); "\t" was probably intended.
for i,m in zip(test_input_ids[5], test_attention_masks[5]):
  print(f"Token id: {i}\Token mask: {m}")

Token id: 18\Token mask: 1.0
Token id: 32\Token mask: 1.0
Token id: 1423\Token mask: 1.0
Token id: 2946\Token mask: 1.0
Token id: 11\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 1\Token mask: 1.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0\Token mask: 0.0
Token id: 0

## Arrays to tensors transformation

In [None]:
# Wrap the padded arrays in torch tensors, train and test side by side:
# token ids, tag ids and attention masks.
tr_inputs, val_inputs = torch.tensor(train_input_ids), torch.tensor(test_input_ids)
tr_tags, val_tags = torch.tensor(train_tags), torch.tensor(test_tags)
tr_masks, val_masks = torch.tensor(train_attention_masks), torch.tensor(test_attention_masks)

In [None]:
print(tr_inputs[5])
print(tr_tags[5])
print(tr_masks[5])

tensor([149,   1, 110,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1, 352,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  

In [None]:
print(val_inputs[5])
print(val_tags[5])
print(val_masks[5])

tensor([  18,   32, 1423, 2946,   11,    1,    1,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,   

### Checking sizes match

#### Training

In [None]:
len([x for x in tr_inputs[5] if x != 0]) # How many NO_PADs we have?

20

In [None]:
# Count the non-padding labels by comparing against the real id of the "PAD" tag;
# the previous hard-coded 7 was the id of another tag, so nothing was filtered out.
len([x for x in tr_tags[5] if x != tag2idx["PAD"]])

20

In [None]:
len([x for x in tr_masks[5] if x != 0])

20

#### Test

In [None]:
len([x for x in val_inputs[5] if x != 0]) # How many NO_PADs we have?

7

In [None]:
# Count the non-padding labels by comparing against the real id of the "PAD" tag;
# the previous hard-coded 7 was the id of another tag, so nothing was filtered out.
len([x for x in val_tags[5] if x != tag2idx["PAD"]])

7

In [None]:
len([x for x in val_masks[5] if x != 0])

7

## Creating the DataLoaders to feed the batches during training

In [None]:
# Training batches are drawn in random order.
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=TRAIN_BATCH_SIZE)

# Validation batches keep the dataset order and use the dedicated validation
# batch size (previously TRAIN_BATCH_SIZE was reused and VALID_BATCH_SIZE ignored).
valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=VALID_BATCH_SIZE)

# Loading the transformer model

In [None]:
transformers.__version__

'4.8.1'

In [None]:
# Load the pretrained encoder with a token-classification head that has
# one output per entry of tag2idx (the "PAD" tag included).
model = BertForTokenClassification.from_pretrained(
    MODEL_TO_TRAIN,
    num_labels=len(tag2idx),
    output_attentions=False,
    output_hidden_states=False,
)
# Move the model to the selected device; as the last expression of the cell
# this also displays the model architecture.
model.to(device)

Downloading:   0%|          | 0.00/747 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at zlucia/custom-legalbert were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized fr

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

## Setting up the optimizer.
We want to regularize the weight values, so we add weight decay.
We can get all the parameters from `model.named_parameters()`.
But we need to exclude `bias`, `gamma` and `beta` (the bias and Layer Normalization parameters) from the decay, since we don't want to shrink them.

Activate `FULL_FINETUNING` to modify weights in all the layers.

In [None]:
# True: update every layer of the model. False: only train the classifier head.
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    # Biases and LayerNorm weights are excluded from weight decay. The PyTorch
    # parameters are named `...LayerNorm.weight` / `...LayerNorm.bias`, so the
    # previous 'gamma' / 'beta' patterns never matched anything.
    no_decay = ['bias', 'LayerNorm.weight']
    # The per-group option AdamW reads is `weight_decay`; the previous
    # `weight_decay_rate` key was ignored, so no decay was applied at all.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

# NOTE(review): the learning rate is hard-coded here; the LEARNING_RATE constant
# (2e-05) defined with the other hyperparameters is not used - confirm which is intended.
optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-5,
    eps=1e-8
)


## Setting up the scheduler
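It manages how the optimizer's learning rate changes during training. We use a linear schedule; note that `num_warmup_steps` is set to 0 below, so there is no actual warmup phase.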

In [None]:
# Reuse the EPOCHS hyperparameter instead of repeating the literal, so both stay in sync.
epochs = EPOCHS
# Upper bound for the gradient norm used when clipping during training.
max_grad_norm = 1.0

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler (linear schedule, zero warmup steps).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)


Now, let's train

In [None]:
## Store the average loss after each epoch so we can plot them.
loss_values, validation_loss_values = [], []

for EPOCH in trange(epochs, desc="Epoch"):
    # ---------------- Training ----------------
    # Put the model into training mode.
    model.train()
    # Reset the total loss for this epoch.
    total_loss = 0

    for batch in train_dataloader:
        # add batch to gpu
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # Always clear any previously calculated gradients before performing a backward pass.
        model.zero_grad()
        # forward pass
        # This will return the loss (rather than the model output)
        # because we have provided the `labels`.
        outputs = model(b_input_ids, token_type_ids=None,
                        attention_mask=b_input_mask, labels=b_labels)
        # get the loss
        loss = outputs[0]
        # Perform a backward pass to calculate the gradients.
        loss.backward()
        # track train loss
        total_loss += loss.item()
        # Clip the norm of the gradient
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        # Update the learning rate.
        scheduler.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)
    tr_loss = f"Average train loss: {str(avg_train_loss)}\n"

    # Saving partial models (this creates the folder too).
    # The guard compares the epoch counter: it used to compare the index of the
    # last training batch, which made it true on every epoch. EPOCH is 0-based,
    # so only the final epochs (EPOCH > epochs - 5) are written to disk.
    if EPOCH > epochs - 5:
        tokenizer.save_pretrained(f'{PROJECT_NAME}/{str(EPOCH)}/tokenizer/')
        # save_pretrained serialises the model's own weights by default; the
        # previous call passed the bound method `model.state_dict` (not its
        # result) as the state dict.
        model.save_pretrained(save_directory=f'{PROJECT_NAME}/{str(EPOCH)}/',
                              save_config=True)
        # Saving checkpoint in case it crashes, to restore work
        torch.save({
            'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
            }, f'{PROJECT_NAME}/{str(EPOCH)}/checkpoint.pth')
    else:
        print("Skipping saving the model. Too early")

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    # ---------------- Validation ----------------
    # Put the model into evaluation mode
    model.eval()
    # Reset the validation loss for this epoch.
    eval_loss = 0
    predictions, true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        # Telling the model not to compute or store gradients,
        # saving memory and speeding up validation
        with torch.no_grad():
            # Forward pass. Labels are provided, so the output holds the loss
            # at position 0 and the logits at position 1.
            outputs = model(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
        # Move logits and labels to CPU
        logits = outputs[1].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        # Accumulate the loss and collect the predicted / true tag ids of this batch.
        eval_loss += outputs[0].mean().item()
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.extend(label_ids)

    eval_loss = eval_loss / len(valid_dataloader)
    validation_loss_values.append(eval_loss)

    val_loss = f"Validation loss: {str(eval_loss)}\n"

    # Saving losses log
    with open(f'{PROJECT_NAME}/logs/epoch_{EPOCH}_loss.log', 'a') as f:
        f.write(tr_loss)
        f.write(val_loss)

    # Calculating metrics: positions whose true label is "PAD" are dropped
    # from both the predictions and the references.
    pred_tags = [tag_values[p_i] for p, l in zip(predictions, true_labels)
                                 for p_i, l_i in zip(p, l) if tag_values[l_i] != "PAD"]
    valid_tags = [tag_values[l_i] for l in true_labels
                                  for l_i in l if tag_values[l_i] != "PAD"]

    report = classification_report(valid_tags, pred_tags)

    # Saving metrics
    with open(f'{PROJECT_NAME}/logs/epoch_{EPOCH}_metrics.log', 'a') as f:
        f.write(report)

    # Printing also to stdout
    print(tr_loss)
    print(val_loss)
    print(report)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:   7%|▋         | 1/15 [03:05<43:22, 185.89s/it]

Average train loss: 0.8051821073489402

Validation loss: 0.5784538872539997

              precision    recall  f1-score   support

      B-DATE       0.66      0.64      0.65       275
       B-DOC       0.81      0.33      0.47       356
   B-EFFDATE       0.00      0.00      0.00         8
       B-LAW       0.66      0.39      0.49       187
       B-LOC       0.00      0.00      0.00        46
   B-ORDINAL       0.00      0.00      0.00        21
     B-PARTY       0.63      0.41      0.50      1219
   B-PERCENT       0.00      0.00      0.00        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.14      0.24        44
       I-DOC       0.66      0.40      0.50       512
   I-EFFDATE       0.00      0.00      0.00        19
     I-PARTY       0.84      0.75      0.79      3139
           O       0.85      0.96      0.90     11581

    accuracy                           0.83     17439
   macro avg       0.44      0.29      0.32     17439
wei

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  13%|█▎        | 2/15 [06:14<40:40, 187.72s/it]

Average train loss: 0.5140464875235486

Validation loss: 0.49340613558888435

              precision    recall  f1-score   support

      B-DATE       0.59      0.75      0.66       275
       B-DOC       0.79      0.43      0.56       356
   B-EFFDATE       0.00      0.00      0.00         8
       B-LAW       0.61      0.53      0.57       187
       B-LOC       1.00      0.02      0.04        46
   B-ORDINAL       1.00      0.29      0.44        21
     B-PARTY       0.70      0.43      0.53      1219
   B-PERCENT       0.92      0.46      0.62        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       0.90      0.41      0.56        44
       I-DOC       0.73      0.48      0.58       512
   I-EFFDATE       0.00      0.00      0.00        19
     I-PARTY       0.83      0.77      0.80      3139
           O       0.87      0.95      0.91     11581

    accuracy                           0.84     17439
   macro avg       0.64      0.39      0.45     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  20%|██        | 3/15 [09:24<37:40, 188.37s/it]

Average train loss: 0.42840198683204933

Validation loss: 0.46881517705818015

              precision    recall  f1-score   support

      B-DATE       0.63      0.77      0.69       275
       B-DOC       0.73      0.49      0.58       356
   B-EFFDATE       0.00      0.00      0.00         8
       B-LAW       0.69      0.53      0.60       187
       B-LOC       0.50      0.04      0.08        46
   B-ORDINAL       1.00      0.43      0.60        21
     B-PARTY       0.70      0.46      0.56      1219
   B-PERCENT       0.72      0.69      0.71        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       0.97      0.73      0.83        44
       I-DOC       0.74      0.55      0.63       512
   I-EFFDATE       0.00      0.00      0.00        19
     I-PARTY       0.84      0.78      0.81      3139
           O       0.88      0.95      0.91     11581

    accuracy                           0.85     17439
   macro avg       0.60      0.46      0.50     17439
w

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  27%|██▋       | 4/15 [12:32<34:33, 188.53s/it]

Average train loss: 0.37674027269900734

Validation loss: 0.4438314704845349

              precision    recall  f1-score   support

      B-DATE       0.74      0.75      0.75       275
       B-DOC       0.77      0.50      0.60       356
   B-EFFDATE       1.00      0.25      0.40         8
       B-LAW       0.62      0.64      0.63       187
       B-LOC       0.45      0.11      0.18        46
   B-ORDINAL       1.00      0.52      0.69        21
     B-PARTY       0.66      0.53      0.59      1219
   B-PERCENT       0.75      0.69      0.72        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       0.94      0.75      0.84        44
       I-DOC       0.75      0.61      0.67       512
   I-EFFDATE       0.00      0.00      0.00        19
     I-PARTY       0.82      0.78      0.80      3139
           O       0.89      0.94      0.91     11581

    accuracy                           0.85     17439
   macro avg       0.67      0.51      0.56     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  33%|███▎      | 5/15 [15:41<31:25, 188.60s/it]

Average train loss: 0.35289629015015134

Validation loss: 0.4591724102695783

              precision    recall  f1-score   support

      B-DATE       0.73      0.77      0.75       275
       B-DOC       0.81      0.52      0.63       356
   B-EFFDATE       1.00      0.50      0.67         8
       B-LAW       0.62      0.64      0.63       187
       B-LOC       0.39      0.20      0.26        46
   B-ORDINAL       1.00      0.52      0.69        21
     B-PARTY       0.63      0.56      0.59      1219
   B-PERCENT       0.77      0.65      0.71        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.80      0.89        44
       I-DOC       0.76      0.58      0.65       512
   I-EFFDATE       0.00      0.00      0.00        19
     I-PARTY       0.78      0.80      0.79      3139
           O       0.90      0.92      0.91     11581

    accuracy                           0.85     17439
   macro avg       0.67      0.53      0.58     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  40%|████      | 6/15 [18:50<28:18, 188.70s/it]

Average train loss: 0.3252535241642105

Validation loss: 0.43623003115256626

              precision    recall  f1-score   support

      B-DATE       0.72      0.77      0.74       275
       B-DOC       0.77      0.56      0.65       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.65      0.65      0.65       187
       B-LOC       0.32      0.28      0.30        46
   B-ORDINAL       1.00      0.52      0.69        21
     B-PARTY       0.72      0.51      0.60      1219
   B-PERCENT       0.77      0.65      0.71        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.80      0.89        44
       I-DOC       0.80      0.62      0.70       512
   I-EFFDATE       0.00      0.00      0.00        19
     I-PARTY       0.90      0.77      0.83      3139
           O       0.88      0.96      0.92     11581

    accuracy                           0.87     17439
   macro avg       0.68      0.55      0.60     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  47%|████▋     | 7/15 [21:59<25:10, 188.82s/it]

Average train loss: 0.3041251828310205

Validation loss: 0.4424445399393638

              precision    recall  f1-score   support

      B-DATE       0.64      0.82      0.72       275
       B-DOC       0.79      0.56      0.66       356
   B-EFFDATE       1.00      0.50      0.67         8
       B-LAW       0.64      0.66      0.65       187
       B-LOC       0.34      0.28      0.31        46
   B-ORDINAL       1.00      0.52      0.69        21
     B-PARTY       0.70      0.54      0.61      1219
   B-PERCENT       0.77      0.65      0.71        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       0.97      0.82      0.89        44
       I-DOC       0.78      0.62      0.69       512
   I-EFFDATE       1.00      0.21      0.35        19
     I-PARTY       0.88      0.78      0.83      3139
           O       0.89      0.95      0.92     11581

    accuracy                           0.87     17439
   macro avg       0.74      0.57      0.62     17439
wei

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  53%|█████▎    | 8/15 [25:08<22:01, 188.77s/it]

Average train loss: 0.28523846142994824

Validation loss: 0.452697000776728

              precision    recall  f1-score   support

      B-DATE       0.68      0.79      0.73       275
       B-DOC       0.82      0.56      0.67       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.69      0.66      0.67       187
       B-LOC       0.35      0.26      0.30        46
   B-ORDINAL       1.00      0.52      0.69        21
     B-PARTY       0.68      0.55      0.61      1219
   B-PERCENT       0.77      0.65      0.71        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.80      0.89        44
       I-DOC       0.80      0.60      0.69       512
   I-EFFDATE       1.00      0.16      0.27        19
     I-PARTY       0.77      0.82      0.79      3139
           O       0.90      0.92      0.91     11581

    accuracy                           0.85     17439
   macro avg       0.75      0.57      0.62     17439
wei

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  60%|██████    | 9/15 [28:16<18:52, 188.78s/it]

Average train loss: 0.26900044328241207

Validation loss: 0.4636924589673678

              precision    recall  f1-score   support

      B-DATE       0.72      0.78      0.75       275
       B-DOC       0.79      0.56      0.66       356
   B-EFFDATE       0.71      0.62      0.67         8
       B-LAW       0.62      0.67      0.65       187
       B-LOC       0.46      0.24      0.31        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.70      0.54      0.61      1219
   B-PERCENT       0.71      0.65      0.68        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.80      0.89        44
       I-DOC       0.79      0.62      0.70       512
   I-EFFDATE       0.88      0.37      0.52        19
     I-PARTY       0.83      0.80      0.81      3139
           O       0.89      0.94      0.92     11581

    accuracy                           0.86     17439
   macro avg       0.72      0.58      0.63     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  67%|██████▋   | 10/15 [31:25<15:43, 188.76s/it]

Average train loss: 0.2588984980734427

Validation loss: 0.4543944566200177

              precision    recall  f1-score   support

      B-DATE       0.73      0.76      0.74       275
       B-DOC       0.81      0.58      0.68       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.64      0.66      0.65       187
       B-LOC       0.47      0.30      0.37        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.71      0.55      0.62      1219
   B-PERCENT       0.74      0.65      0.69        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.80      0.89        44
       I-DOC       0.81      0.63      0.71       512
   I-EFFDATE       1.00      0.37      0.54        19
     I-PARTY       0.87      0.78      0.82      3139
           O       0.89      0.96      0.92     11581

    accuracy                           0.87     17439
   macro avg       0.76      0.58      0.65     17439
wei

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  73%|███████▎  | 11/15 [34:34<12:35, 188.85s/it]

Average train loss: 0.25048426527585554

Validation loss: 0.4810361700753371

              precision    recall  f1-score   support

      B-DATE       0.67      0.81      0.74       275
       B-DOC       0.84      0.57      0.68       356
   B-EFFDATE       0.83      0.62      0.71         8
       B-LAW       0.66      0.66      0.66       187
       B-LOC       0.45      0.28      0.35        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.71      0.56      0.62      1219
   B-PERCENT       0.74      0.65      0.69        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.84      0.91        44
       I-DOC       0.83      0.59      0.69       512
   I-EFFDATE       1.00      0.05      0.10        19
     I-PARTY       0.81      0.80      0.81      3139
           O       0.90      0.94      0.92     11581

    accuracy                           0.86     17439
   macro avg       0.74      0.56      0.61     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  80%|████████  | 12/15 [37:43<09:26, 188.87s/it]

Average train loss: 0.24285436477234115

Validation loss: 0.48704534644881886

              precision    recall  f1-score   support

      B-DATE       0.71      0.77      0.74       275
       B-DOC       0.80      0.58      0.67       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.69      0.64      0.67       187
       B-LOC       0.45      0.33      0.38        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.71      0.55      0.62      1219
   B-PERCENT       0.63      0.65      0.64        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.82      0.90        44
       I-DOC       0.80      0.61      0.69       512
   I-EFFDATE       1.00      0.32      0.48        19
     I-PARTY       0.78      0.82      0.80      3139
           O       0.90      0.93      0.92     11581

    accuracy                           0.86     17439
   macro avg       0.74      0.58      0.64     17439
w

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  87%|████████▋ | 13/15 [40:52<06:17, 188.82s/it]

Average train loss: 0.23669564457082037

Validation loss: 0.47740485953787964

              precision    recall  f1-score   support

      B-DATE       0.69      0.80      0.74       275
       B-DOC       0.79      0.60      0.68       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.66      0.66      0.66       187
       B-LOC       0.42      0.30      0.35        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.71      0.55      0.62      1219
   B-PERCENT       0.63      0.65      0.64        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.84      0.91        44
       I-DOC       0.78      0.65      0.71       512
   I-EFFDATE       1.00      0.21      0.35        19
     I-PARTY       0.82      0.80      0.81      3139
           O       0.90      0.94      0.92     11581

    accuracy                           0.86     17439
   macro avg       0.74      0.58      0.63     17439
w

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch:  93%|█████████▎| 14/15 [44:01<03:08, 188.78s/it]

Average train loss: 0.23165874761431965

Validation loss: 0.4857822445531686

              precision    recall  f1-score   support

      B-DATE       0.72      0.81      0.76       275
       B-DOC       0.83      0.58      0.68       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.66      0.67      0.66       187
       B-LOC       0.42      0.30      0.35        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.71      0.56      0.62      1219
   B-PERCENT       0.65      0.65      0.65        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.82      0.90        44
       I-DOC       0.80      0.63      0.71       512
   I-EFFDATE       1.00      0.16      0.27        19
     I-PARTY       0.83      0.80      0.82      3139
           O       0.90      0.95      0.92     11581

    accuracy                           0.87     17439
   macro avg       0.75      0.58      0.63     17439
we

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch: 100%|██████████| 15/15 [47:09<00:00, 188.66s/it]

Average train loss: 0.23067494748688455

Validation loss: 0.48736114737888175

              precision    recall  f1-score   support

      B-DATE       0.70      0.81      0.75       275
       B-DOC       0.81      0.59      0.68       356
   B-EFFDATE       1.00      0.62      0.77         8
       B-LAW       0.66      0.65      0.65       187
       B-LOC       0.44      0.33      0.38        46
   B-ORDINAL       0.92      0.52      0.67        21
     B-PARTY       0.71      0.55      0.62      1219
   B-PERCENT       0.61      0.65      0.63        26
    B-PERSON       0.00      0.00      0.00         6
      B-ROLE       1.00      0.84      0.91        44
       I-DOC       0.80      0.63      0.71       512
   I-EFFDATE       1.00      0.16      0.27        19
     I-PARTY       0.81      0.80      0.81      3139
           O       0.90      0.94      0.92     11581

    accuracy                           0.86     17439
   macro avg       0.74      0.58      0.63     17439
w


