Install packages:

In [None]:
!pip install --upgrade sacremoses
!pip install transformers[torch]
!pip install accelerate -U
!pip install --upgrade seqeval
!pip install --upgrade datasets
!pip install --upgrade evaluate
!pip install --upgrade tokenizers

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


---------------------
Load the serialized, experiment-ready corpus and convert it to a Hugging Face `Dataset`:

In [None]:
import pickle
import datasets

# NOTE(security): unpickling can execute arbitrary code - only load corpus files from a trusted source.
with open('hf-dataset_v2.pkl', 'rb') as corpus_file:
    dataset = pickle.load(corpus_file)
dataset_hf = datasets.Dataset.from_list(dataset)

# Collect every distinct tag that occurs in any example of the corpus.
unique_tags = set()
for item in dataset:
    unique_tags.update(item['labels'])
print(unique_tags)
nr_classes = len(unique_tags)

# Number the tags in sorted order: iterating a set of strings directly has no
# guaranteed order, which would make the tag <-> id mapping differ between sessions.
tag2id = {tag: idx for idx, tag in enumerate(sorted(unique_tags))}
id2tag = {idx: tag for tag, idx in tag2id.items()}
label_names = list(tag2id.keys())

{'aggression', 'O'}


In [None]:
dataset_hf

Dataset({
    features: ['tokens', 'labels', 'fname'],
    num_rows: 802
})

Auxiliary functions for computing metrics and for tokenization with label alignment:

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer, EvalPrediction

import numpy as np

from evaluate import load
metric = load("seqeval")

def compute_metrics(p: EvalPrediction):
    """Score token-level predictions with the loaded ``seqeval`` metric.

    The arg-max over the last axis of ``p.predictions`` is taken as the predicted
    label id. Positions whose gold label id is -100 are dropped, the remaining ids
    are mapped to tag names through ``label_names`` and both sequences are passed
    to ``metric.compute``.

    Returns:
        dict: the four ``overall_*`` scores, plus one ``"<key>_f1"`` entry for every
        other key present in the metric result.
    """
    predicted_ids = p.predictions.argmax(-1)
    gold_ids = p.label_ids

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only the positions that carry a real label (-100 marks ignored positions).
        kept_pairs = [(pred, gold) for pred, gold in zip(predicted_row, gold_row) if gold != -100]
        true_predictions.append([label_names[pred] for pred, _ in kept_pairs])
        true_labels.append([label_names[gold] for _, gold in kept_pairs])

    results = metric.compute(predictions=true_predictions, references=true_labels)

    overall_keys = ("overall_precision", "overall_recall", "overall_f1", "overall_accuracy")
    flattened_results = {key: results[key] for key in overall_keys}
    # Every remaining entry is reduced to its "f1" field.
    for key, value in results.items():
        if key not in flattened_results:
            flattened_results[key + "_f1"] = value["f1"]

    return flattened_results


#Get the values for input_ids, token_type_ids, attention_mask
def tokenize_adjust_labels(all_samples_per_split, tokenizer):
  """Tokenize pre-split sentences and expand word-level tags to sub-token level.

  Intended to be used with ``Dataset.map(..., batched=True)``.

  Args:
    all_samples_per_split: batch with a "tokens" entry (one list of words per
      sample) and a "labels" entry (one tag string per word).
    tokenizer: tokenizer providing ``batch_encode_plus`` whose result exposes
      ``word_ids(batch_index=...)``.

  Returns:
    The tokenizer output with an added "labels" entry. Every sub-token carries the
    numeric id (via ``tag2id``) of the word it came from; positions that map to no
    word (word id ``None``) get -100.
  """
  tokenized_samples = tokenizer.batch_encode_plus(all_samples_per_split["tokens"], is_split_into_words=True, truncation=True, max_length=512)

  total_adjusted_labels = []
  for k in range(len(tokenized_samples["input_ids"])):
    word_ids_list = tokenized_samples.word_ids(batch_index=k)
    # Numeric label id of every word in sample k.
    word_label_ids = [tag2id[label] for label in all_samples_per_split["labels"][k]]

    # Look the label up by word id instead of counting word boundaries: a word that
    # produces no sub-token is skipped in word_ids, and a running counter would then
    # assign every following sub-token the label of the wrong word.
    adjusted_label_ids = [
      -100 if wid is None else word_label_ids[wid]
      for wid in word_ids_list
    ]
    total_adjusted_labels.append(adjusted_label_ids)

  tokenized_samples["labels"] = total_adjusted_labels
  return tokenized_samples


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

## HerBERT-large
The cell below contains 10-fold cross validation experiments using HerBERT-large:

In [None]:
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score
from sklearn.model_selection import KFold, train_test_split
from transformers import EvalPrediction, BertConfig, AutoTokenizer, BertForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
from scipy.stats import t
import datasets, torch, gc, transformers, os
import wandb


transformers.logging.set_verbosity_error()
os.environ["WANDB_DISABLED"] = "true"


def compute_metrics_sklearn(p: EvalPrediction):
    """Compute weighted token-level F1, recall and precision with scikit-learn.

    Predicted ids are the arg-max over the last axis of ``p.predictions``.
    Positions whose gold label id is -100 are skipped; the remaining ids are
    mapped to tag names through ``label_names`` and pooled over all sequences.

    Returns:
        dict with the keys "f1", "recall" and "precision".
    """
    predicted_ids = p.predictions.argmax(-1)
    gold_ids = p.label_ids

    flat_preds = []
    flat_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        for pred, gold in zip(predicted_row, gold_row):
            if gold == -100:
                # Ignored position - contributes to neither list.
                continue
            flat_preds.append(label_names[pred])
            flat_labels.append(label_names[gold])

    averaging = dict(average='weighted', zero_division=0)
    return {
        "f1": f1_score(flat_labels, flat_preds, **averaging),
        "recall": recall_score(flat_labels, flat_preds, **averaging),
        "precision": precision_score(flat_labels, flat_preds, **averaging),
    }


# Shared Trainer configuration for every HerBERT fold.
training_args = TrainingArguments(
    do_eval=True,
    output_dir='./aggression_model',  # output directory
    overwrite_output_dir=True,
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=20,   # batch size per device during training
    per_device_eval_batch_size=20,    # batch size for evaluation
    warmup_steps=80,                  # number of warmup steps for learning rate scheduler
    weight_decay=0.01,                # strength of weight decay
    learning_rate=5e-5,
    logging_dir='./logs',             # directory for storing logs
    logging_steps=80,
    eval_strategy="steps",            # evaluate every `eval_steps` optimizer steps
    eval_steps=20,
    save_strategy="steps",            # checkpoint every `save_steps` optimizer steps
    # Checkpoint on the same schedule as evaluation. With the default save_steps (500)
    # no checkpoint exists for the evaluated steps, so load_best_model_at_end has no
    # best-F1 checkpoint to restore.
    save_steps=20,
    save_only_model=True,             # skip optimizer/scheduler state: checkpoints are only used to restore weights
    save_total_limit=1,
    load_best_model_at_end=True,      # restore the checkpoint with the best `metric_for_best_model` after training
    metric_for_best_model='eval_f1',
    disable_tqdm=False,
    report_to="none",                 # the string "none" disables logging integrations; None would enable all installed ones
)


n_splits = 10
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

f1_scores = []
recall_scores = []
precision_scores = []

# The tokenizer does not change between folds, so it is loaded once.
tokenizer_obj = AutoTokenizer.from_pretrained("allegro/herbert-large-cased", do_lower_case=False)

for fold, (train_index, test_index) in enumerate(kf.split(dataset_hf)):

    print(f"Fold {fold+1}/{n_splits}")

    # Every fold starts from freshly loaded pretrained weights.
    model_config = BertConfig.from_pretrained("allegro/herbert-large-cased", num_labels=len(id2tag), id2label=id2tag,  label2id=tag2id)
    model = BertForTokenClassification.from_pretrained("allegro/herbert-large-cased", config=model_config)

    train_val_dataset = dataset_hf.select(train_index)
    test_dataset = dataset_hf.select(test_index)

    # Seed the inner train/validation split so that a re-run reproduces the same partition.
    train_dataset, val_dataset = train_val_dataset.train_test_split(test_size=0.125, shuffle=True, seed=42).values()

    tokenized_train_dataset = train_dataset.map(lambda p: tokenize_adjust_labels(p, tokenizer_obj), batched=True)
    tokenized_val_dataset = val_dataset.map(lambda p: tokenize_adjust_labels(p, tokenizer_obj), batched=True)
    tokenized_test_dataset = test_dataset.map(lambda p: tokenize_adjust_labels(p, tokenizer_obj), batched=True)

    data_collator = DataCollatorForTokenClassification(tokenizer_obj, return_tensors='pt') #, padding='max_length')

    # Fine-tune on the current fold; the validation split drives best-model selection.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer_obj,
        compute_metrics=compute_metrics_sklearn # use sklearn metrics
    )
    trainer.train()

    # Evaluate on the held-out test fold
    eval_results = trainer.predict(tokenized_test_dataset)
    metrics = compute_metrics_sklearn(eval_results)

    f1_scores.append(metrics["f1"])
    recall_scores.append(metrics["recall"])
    precision_scores.append(metrics["precision"])

    # The trainer keeps its own reference to the model, so both must be dropped
    # before the memory can be reclaimed for the next fold.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()

# Calculate mean and confidence interval
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and its two-sided Student-t confidence interval.

    Args:
        data: sequence of numeric observations.
        confidence: confidence level of the interval (default 0.95).

    Returns:
        tuple ``(mean, lower_bound, upper_bound)``.
    """
    samples = np.array(data) * 1.0
    sample_count = len(samples)
    center = np.mean(samples)
    # Standard error of the mean, based on the sample standard deviation (ddof=1).
    std_error = np.std(samples, ddof=1) / np.sqrt(sample_count)
    # Half-width of the interval: standard error scaled by the t quantile with n-1 degrees of freedom.
    half_width = std_error * t.ppf((1 + confidence) / 2., sample_count - 1)
    return center, center - half_width, center + half_width

# Summarize the per-fold scores as a mean with a 95% confidence interval per metric.
f1_mean, f1_lower, f1_upper = mean_confidence_interval(f1_scores)
recall_mean, recall_lower, recall_upper = mean_confidence_interval(recall_scores)
precision_mean, precision_lower, precision_upper = mean_confidence_interval(precision_scores)

summary_rows = (
    ("F1", f1_mean, f1_lower, f1_upper),
    ("Recall", recall_mean, recall_lower, recall_upper),
    ("Precision", precision_mean, precision_lower, precision_upper),
)
for metric_label, mean_value, ci_low, ci_high in summary_rows:
    print(f"{metric_label}: Mean={mean_value:.4f}, 95% CI=({ci_low:.4f}, {ci_high:.4f})")


config.json:   0%|          | 0.00/664 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/907k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/556k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Fold 1/10


Map:   0%|          | 0/630 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/81 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.438171,0.794552,0.859496,0.738733
40,No log,0.320223,0.871057,0.871802,0.870348
60,No log,0.270246,0.858757,0.885433,0.875962
80,0.383500,0.233668,0.897105,0.898577,0.895889
100,0.383500,0.252463,0.898015,0.891086,0.911509
120,0.383500,0.264181,0.891584,0.903094,0.894457
140,0.383500,0.302298,0.90144,0.903987,0.899744
160,0.118400,0.304286,0.902358,0.904068,0.901041
180,0.118400,0.32693,0.907123,0.908774,0.905887
200,0.118400,0.401096,0.90929,0.911884,0.907779


Fold 2/10


Map:   0%|          | 0/630 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/81 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.303772,0.859031,0.904422,0.817979
40,No log,0.22631,0.901015,0.915799,0.900223
60,No log,0.1943,0.911872,0.917649,0.908771
80,0.356800,0.217307,0.902005,0.892082,0.918481
100,0.356800,0.240767,0.914332,0.921956,0.911898
120,0.356800,0.217881,0.911745,0.905106,0.922064
140,0.356800,0.275024,0.921801,0.924262,0.919999
160,0.108400,0.273571,0.918849,0.920639,0.917375
180,0.108400,0.293789,0.924611,0.927835,0.922643
200,0.108400,0.325457,0.92845,0.930749,0.926877


Fold 3/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.417819,0.805479,0.86717,0.751983
40,No log,0.295954,0.861709,0.863656,0.85994
60,No log,0.264784,0.874685,0.890566,0.875342
80,0.350900,0.260889,0.878726,0.891477,0.877342
100,0.350900,0.293935,0.891725,0.89473,0.889549
120,0.350900,0.377876,0.880631,0.892882,0.879333
140,0.350900,0.440095,0.871048,0.886662,0.869932
160,0.105500,0.368962,0.891973,0.891243,0.892753
180,0.105500,0.419944,0.889466,0.892518,0.887223
200,0.105500,0.463342,0.886027,0.884398,0.887875


Fold 4/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.354934,0.828803,0.883465,0.780511
40,No log,0.258278,0.879731,0.893388,0.875887
60,No log,0.238131,0.898344,0.900772,0.896398
80,0.365000,0.254057,0.896527,0.906886,0.894924
100,0.365000,0.2581,0.90844,0.908493,0.908387
120,0.365000,0.331404,0.902472,0.909425,0.900022
140,0.365000,0.374474,0.909516,0.913804,0.907284
160,0.112900,0.346609,0.906006,0.913001,0.904063
180,0.112900,0.373045,0.906056,0.911964,0.90372
200,0.112900,0.445922,0.907093,0.912742,0.904802


Fold 5/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.316405,0.858943,0.904361,0.817869
40,No log,0.240465,0.893519,0.88599,0.903888
60,No log,0.179551,0.916211,0.921074,0.913525
80,0.354800,0.189133,0.920354,0.9147,0.929459
100,0.354800,0.218524,0.925199,0.925505,0.924905
120,0.354800,0.213398,0.926218,0.932656,0.925368
140,0.354800,0.305858,0.92435,0.930376,0.92285
160,0.107100,0.232984,0.927412,0.926515,0.928415
180,0.107100,0.282407,0.927369,0.929262,0.925949
200,0.107100,0.320352,0.9283,0.930428,0.926791


Fold 6/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.354972,0.83703,0.889187,0.790653
40,No log,0.269208,0.87682,0.889728,0.870915
60,No log,0.257428,0.862212,0.89612,0.875836
80,0.369200,0.253916,0.88315,0.903349,0.886523
100,0.369200,0.27157,0.898231,0.904676,0.894846
120,0.369200,0.315196,0.895382,0.907283,0.893663
140,0.369200,0.41874,0.896516,0.904922,0.89314
160,0.125200,0.374969,0.895761,0.90507,0.89254
180,0.125200,0.426732,0.900047,0.907406,0.896919
200,0.125200,0.454708,0.897347,0.903963,0.893879


Fold 7/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.430858,0.802597,0.865148,0.748481
40,No log,0.308015,0.855681,0.865381,0.84992
60,No log,0.249936,0.881896,0.890522,0.878905
80,0.356400,0.25639,0.88179,0.881295,0.882303
100,0.356400,0.276166,0.888683,0.887567,0.889912
120,0.356400,0.344166,0.881275,0.896379,0.884409
140,0.356400,0.411498,0.880431,0.895057,0.882398
160,0.117600,0.35899,0.892852,0.901356,0.891526
180,0.117600,0.396808,0.894702,0.898401,0.892396
200,0.117600,0.493437,0.889987,0.901589,0.891231


Fold 8/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.365188,0.821082,0.878083,0.77103
40,No log,0.274055,0.87508,0.875732,0.874451
60,No log,0.23356,0.878855,0.898176,0.883252
80,0.360600,0.216787,0.901545,0.909804,0.900188
100,0.360600,0.238964,0.903626,0.906746,0.901514
120,0.360600,0.340675,0.883937,0.897706,0.882751
140,0.360600,0.287913,0.907128,0.909568,0.905365
160,0.115800,0.344873,0.887181,0.900214,0.886366
180,0.115800,0.346029,0.907981,0.908967,0.907109
200,0.115800,0.367182,0.915183,0.915865,0.914567


Fold 9/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.421384,0.790094,0.856358,0.733349
40,No log,0.305856,0.863862,0.869136,0.860224
60,No log,0.295965,0.83781,0.875823,0.872741
80,0.361300,0.248741,0.887205,0.890182,0.885096
100,0.361300,0.280028,0.882789,0.893344,0.882287
120,0.361300,0.277956,0.890895,0.898709,0.889643
140,0.361300,0.33422,0.88935,0.899616,0.890224
160,0.121000,0.371562,0.887863,0.898606,0.888936
180,0.121000,0.537028,0.879446,0.89464,0.884598
200,0.121000,0.515239,0.889045,0.897802,0.888219


Fold 10/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.353932,0.838,0.88986,0.791851
40,No log,0.271608,0.886761,0.889257,0.884595
60,No log,0.202353,0.910854,0.915511,0.908451
80,0.365500,0.202087,0.907971,0.919104,0.909802
100,0.365500,0.308822,0.914781,0.921867,0.913763
120,0.365500,0.248317,0.91555,0.922545,0.91461
140,0.365500,0.316006,0.906456,0.911441,0.903741
160,0.107600,0.315049,0.91483,0.921164,0.913298
180,0.107600,0.327622,0.916408,0.916064,0.916767
200,0.107600,0.330978,0.915315,0.919581,0.913211


F1: Mean=0.8967, 95% CI=(0.8857, 0.9077)
Recall: Mean=0.9028, 95% CI=(0.8935, 0.9121)
Precision: Mean=0.8952, 95% CI=(0.8842, 0.9062)


## Bi-LSTM
The cells below are required for 10-fold cross validation experiments using the Bi-LSTM approach:

In [None]:
# prompt: given the same c-v split of
# kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
# and dataset_hf, the same splits to  train_val_dataset and test_dataset, propose a code to train and evaluate a Bi-LSTM model linear model. Use torch and transformers packages. Evaluate models using the same metrics as in the previous cell, use mean_confidence_interval.

from sklearn.metrics import f1_score, recall_score, precision_score
from transformers import EvalPrediction, BertConfig, AutoTokenizer, BertForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification

import numpy as np
from scipy.stats import t

def compute_metrics_sklearn(p: EvalPrediction):
    """Weighted token-level F1 / recall / precision for a prediction bundle.

    The predicted id of a position is the arg-max over the last axis of
    ``p.predictions``. Positions whose gold id in ``p.label_ids`` is -100 are left
    out; all other ids are translated to tag names with ``label_names`` before the
    scikit-learn scorers are applied to the pooled tokens.

    Returns:
        dict with the keys "f1", "recall" and "precision".
    """
    labels = p.label_ids
    predictions = p.predictions.argmax(-1)

    # (predicted id, gold id) of every scored position, in sequence order.
    scored_pairs = [
        (pred_id, gold_id)
        for pred_seq, gold_seq in zip(predictions, labels)
        for pred_id, gold_id in zip(pred_seq, gold_seq)
        if gold_id != -100
    ]
    flat_preds = [label_names[pred_id] for pred_id, _ in scored_pairs]
    flat_labels = [label_names[gold_id] for _, gold_id in scored_pairs]

    scorers = (("f1", f1_score), ("recall", recall_score), ("precision", precision_score))
    scores = {}
    for score_name, score_fn in scorers:
        scores[score_name] = score_fn(flat_labels, flat_preds, average='weighted', zero_division=0)
    return scores



In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import PreTrainedModel, PretrainedConfig
from transformers import DataCollatorForTokenClassification
from transformers import EvalPrediction, BertConfig, AutoTokenizer, BertForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification

from sklearn.model_selection import KFold, train_test_split
import torch, os
import torch.nn as nn
import numpy as np

# Calculate mean and confidence interval
def mean_confidence_interval(data, confidence=0.95):
    """Mean of ``data`` together with the bounds of its Student-t confidence interval.

    Args:
        data: sequence of numeric observations.
        confidence: two-sided confidence level (default 0.95).

    Returns:
        tuple ``(mean, lower_bound, upper_bound)``.
    """
    values = 1.0 * np.array(data)
    n_obs = len(values)
    mean_value = np.mean(values)
    # Sample standard deviation (ddof=1) divided by sqrt(n) gives the standard error.
    standard_error = np.std(values, ddof=1) / np.sqrt(n_obs)
    t_quantile = t.ppf((1 + confidence) / 2., n_obs - 1)
    margin = standard_error * t_quantile
    return mean_value, mean_value - margin, mean_value + margin

def count_total_parameters(model):
    """Return the summed ``numel()`` of everything yielded by ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# 1. Define custom Config
class BiLSTMTokenClassificationConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of the Bi-LSTM token classifier.

    Args:
        vocab_size: number of rows of the embedding table.
        hidden_size: embedding width; also the width of the LSTM output fed to the classifier.
        num_labels: number of output classes.
        num_lstm_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the LSTM output.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "bilstm_token_classification"

    def __init__(self, vocab_size=30522, hidden_size=256, num_labels=9, num_lstm_layers=1, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Stored after the base initializer has run, in the order listed here.
        own_fields = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "num_labels": num_labels,
            "num_lstm_layers": num_lstm_layers,
            "dropout": dropout,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

# 2. Define BiLSTM Model
class BiLSTMTokenClassifier(PreTrainedModel):
    """Token classifier: embedding -> bidirectional LSTM -> dropout -> linear layer.

    ``forward`` returns a dict with "loss" (``None`` when no labels are given) and
    "logits" of shape (batch_size, seq_len, num_labels).
    """
    config_class = BiLSTMTokenClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        # Each direction gets hidden_size // 2 units, so the concatenated output of
        # both directions is hidden_size wide (for even hidden_size).
        self.lstm = nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=config.num_lstm_layers,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute per-token logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: token ids, (batch_size, seq_len).
            attention_mask: optional mask, 1 for real tokens and 0 for padding.
            labels: optional gold label ids; -100 entries are ignored by the loss
                (default ``ignore_index`` of ``nn.CrossEntropyLoss``).
        """
        embeds = self.embeddings(input_ids)  # (batch_size, seq_len, hidden_size)

        if attention_mask is None:
            lstm_out, _ = self.lstm(embeds)
        else:
            # Run the LSTM on the real tokens only. Without packing, the backward
            # direction first consumes the padding, so the states of real tokens
            # depend on how much padding the batch happens to contain.
            # Assumes right-side padding (mask is a prefix of ones) - confirm if the
            # tokenizer's padding side is ever changed.
            lengths = attention_mask.sum(dim=1).clamp(min=1).cpu()
            packed_embeds = nn.utils.rnn.pack_padded_sequence(
                embeds, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed_embeds)
            # total_length restores the original seq_len so logits line up with labels.
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=input_ids.size(1)
            )

        lstm_out = self.dropout(lstm_out)     # (batch_size, seq_len, hidden_size)
        logits = self.classifier(lstm_out)    # (batch_size, seq_len, num_labels)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Flatten logits and labels for loss computation
            flat_logits = logits.reshape(-1, self.config.num_labels)
            flat_labels = labels.reshape(-1)
            if attention_mask is not None:
                # Restrict the loss to non-padding positions.
                active_loss = attention_mask.reshape(-1) == 1
                flat_logits = flat_logits[active_loss]
                flat_labels = flat_labels[active_loss]
            loss = loss_fct(flat_logits, flat_labels)

        return {"loss": loss, "logits": logits}

os.environ["WANDB_DISABLED"] = "true"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

n_splits = 10
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

f1_scores = []
recall_scores = []
precision_scores = []

# Depth of the LSTM stack. It used to be taken from len(id2tag), which tied the
# network depth to the number of labels; 2 keeps the depth that value gave for the
# two-tag corpus while making it an independent hyper-parameter.
NUM_LSTM_LAYERS = 2

# Tokenizer and training configuration are identical for every fold, so build them once.
tokenizer_obj = AutoTokenizer.from_pretrained("allegro/herbert-large-cased", do_lower_case=False)

training_args = TrainingArguments(
  output_dir="./bilstm-token-classifier",
  overwrite_output_dir=True,
  num_train_epochs=10,              # total number of training epochs
  per_device_train_batch_size=20,   # batch size per device during training
  per_device_eval_batch_size=20,    # batch size for evaluation
  warmup_steps=80,                  # number of warmup steps for learning rate scheduler
  weight_decay=0.01,                # strength of weight decay
  learning_rate=5e-5,
  logging_dir='./logs',             # directory for storing logs
  logging_steps=80,
  eval_strategy="steps",            # evaluate every `eval_steps` optimizer steps
  eval_steps=20,
  save_strategy="steps",            # checkpoint every `save_steps` optimizer steps
  # Checkpoint on the same schedule as evaluation. With the default save_steps (500)
  # no checkpoint exists for the evaluated steps, so load_best_model_at_end has no
  # best-F1 checkpoint to restore.
  save_steps=20,
  save_only_model=True,             # skip optimizer/scheduler state: checkpoints are only used to restore weights
  save_total_limit=1,
  load_best_model_at_end=True,      # restore the checkpoint with the best `metric_for_best_model` after training
  metric_for_best_model='eval_f1',
  disable_tqdm=False,
  report_to="none"                  # the string "none" disables logging integrations; None would enable all installed ones
)

for fold, (train_index, test_index) in enumerate(kf.split(dataset_hf)):

    print(f"Fold {fold+1}/{n_splits}")

    train_val_dataset = dataset_hf.select(train_index)
    test_dataset = dataset_hf.select(test_index)

    # Seed the inner train/validation split so that a re-run reproduces the same partition.
    train_dataset, val_dataset = train_val_dataset.train_test_split(test_size=0.125, shuffle=True, seed=42).values()

    tokenized_train_dataset = train_dataset.map(lambda p: tokenize_adjust_labels(p, tokenizer_obj), batched=True)
    tokenized_val_dataset = val_dataset.map(lambda p: tokenize_adjust_labels(p, tokenizer_obj), batched=True)
    tokenized_test_dataset = test_dataset.map(lambda p: tokenize_adjust_labels(p, tokenizer_obj), batched=True)

    data_collator = DataCollatorForTokenClassification(tokenizer_obj, return_tensors='pt') #, padding='max_length')

    # Initialize a fresh, randomly initialized Bi-LSTM model for this fold
    config = BiLSTMTokenClassificationConfig(
        vocab_size=len(tokenizer_obj),
        hidden_size=256,
        num_labels=nr_classes,
        num_lstm_layers=NUM_LSTM_LAYERS,
        dropout=0.1,
    )

    model = BiLSTMTokenClassifier(config)
    model.to(device)
    total_params = count_total_parameters(model)
    print(f"Total parameters: {total_params:,}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer_obj,
        compute_metrics=compute_metrics_sklearn
    )
    trainer.train()

    # Evaluate on the held-out test fold
    eval_results = trainer.predict(tokenized_test_dataset)
    metrics = compute_metrics_sklearn(eval_results)

    f1_scores.append(metrics["f1"])
    recall_scores.append(metrics["recall"])
    precision_scores.append(metrics["precision"])

    # Drop the fold's trainer and model before the next fold allocates new ones.
    del trainer, model
    torch.cuda.empty_cache()


# Aggregate the per-fold scores and print each metric's mean with its 95% confidence interval
f1_mean, f1_lower, f1_upper = mean_confidence_interval(f1_scores)
recall_mean, recall_lower, recall_upper = mean_confidence_interval(recall_scores)
precision_mean, precision_lower, precision_upper = mean_confidence_interval(precision_scores)

summary_rows = (
    ("F1", f1_mean, f1_lower, f1_upper),
    ("Recall", recall_mean, recall_lower, recall_upper),
    ("Precision", precision_mean, precision_lower, precision_upper),
)
for metric_label, mean_value, ci_low, ci_high in summary_rows:
    print(f"{metric_label}: Mean={mean_value:.4f}, 95% CI=({ci_low:.4f}, {ci_high:.4f})")


Using device: cuda
Fold 1/10


Map:   0%|          | 0/630 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/81 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.672801,0.788769,0.812623,0.76868
40,No log,0.653064,0.805266,0.864017,0.764493
60,No log,0.61929,0.806336,0.867672,0.753099
80,0.645200,0.564692,0.806411,0.867823,0.753116
100,0.645200,0.4841,0.806411,0.867823,0.753116
120,0.645200,0.416974,0.806411,0.867823,0.753116
140,0.645200,0.403929,0.806411,0.867823,0.753116
160,0.449000,0.400397,0.806411,0.867823,0.753116
180,0.449000,0.39738,0.806411,0.867823,0.753116
200,0.449000,0.395488,0.806411,0.867823,0.753116


Fold 2/10


Map:   0%|          | 0/630 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/81 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.711458,0.240658,0.212213,0.830536
40,No log,0.68626,0.69656,0.618903,0.825833
60,No log,0.642956,0.850736,0.879184,0.827465
80,0.676100,0.568993,0.85842,0.902521,0.825461
100,0.676100,0.459664,0.85893,0.904219,0.817961
120,0.676100,0.370302,0.859006,0.904371,0.817974
140,0.676100,0.343381,0.859031,0.904422,0.817979
160,0.457300,0.344343,0.859031,0.904422,0.817979
180,0.457300,0.340298,0.859031,0.904422,0.817979
200,0.457300,0.334127,0.859031,0.904422,0.817979


Fold 3/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.709494,0.248107,0.241405,0.775916
40,No log,0.687003,0.667201,0.607053,0.766161
60,No log,0.648073,0.800107,0.846714,0.767025
80,0.675900,0.582393,0.804972,0.866051,0.754823
100,0.675900,0.489718,0.805479,0.86717,0.751983
120,0.675900,0.428339,0.805479,0.86717,0.751983
140,0.675900,0.418547,0.805479,0.86717,0.751983
160,0.454600,0.412737,0.805479,0.86717,0.751983
180,0.454600,0.408832,0.805479,0.86717,0.751983
200,0.454600,0.405895,0.805479,0.86717,0.751983


Fold 4/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.710203,0.246882,0.230737,0.803601
40,No log,0.686514,0.686272,0.620783,0.79367
60,No log,0.644863,0.822034,0.861806,0.791597
80,0.676100,0.573745,0.828235,0.882326,0.780394
100,0.676100,0.472011,0.828803,0.883465,0.780511
120,0.676100,0.396428,0.828803,0.883465,0.780511
140,0.676100,0.382789,0.828803,0.883465,0.780511
160,0.456000,0.378183,0.828803,0.883465,0.780511
180,0.456000,0.374217,0.828803,0.883465,0.780511
200,0.456000,0.372366,0.828803,0.883465,0.780511


Fold 5/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.711039,0.245738,0.214858,0.826899
40,No log,0.686293,0.698463,0.621512,0.825626
60,No log,0.642576,0.851111,0.881922,0.825856
80,0.676200,0.568528,0.858477,0.903428,0.817788
100,0.676200,0.457663,0.858943,0.904361,0.817869
120,0.676200,0.363879,0.858943,0.904361,0.817869
140,0.676200,0.34592,0.858943,0.904361,0.817869
160,0.456600,0.338355,0.858943,0.904361,0.817869
180,0.456600,0.332411,0.858943,0.904361,0.817869
200,0.456600,0.334389,0.858943,0.904361,0.817869


Fold 6/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.710585,0.244852,0.223397,0.794816
40,No log,0.686558,0.685516,0.616345,0.79634
60,No log,0.644877,0.828772,0.867329,0.797424
80,0.676100,0.574538,0.836573,0.88808,0.795101
100,0.676100,0.472087,0.837018,0.889162,0.790651
120,0.676100,0.392823,0.83703,0.889187,0.790653
140,0.676100,0.37516,0.83703,0.889187,0.790653
160,0.461600,0.37056,0.83703,0.889187,0.790653
180,0.461600,0.366639,0.83703,0.889187,0.790653
200,0.461600,0.363886,0.83703,0.889187,0.790653


Fold 7/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.7095,0.247845,0.242568,0.773344
40,No log,0.686366,0.673212,0.616696,0.763053
60,No log,0.646883,0.796803,0.844958,0.762145
80,0.675900,0.580492,0.802316,0.864085,0.761659
100,0.675900,0.486041,0.802597,0.865148,0.748481
120,0.675900,0.424103,0.802597,0.865148,0.748481
140,0.675900,0.417178,0.802597,0.865148,0.748481
160,0.450000,0.410537,0.802597,0.865148,0.748481
180,0.450000,0.406553,0.802597,0.865148,0.748481
200,0.450000,0.403897,0.802597,0.865148,0.748481


Fold 8/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.709935,0.255168,0.236909,0.779531
40,No log,0.686759,0.676533,0.612014,0.782049
60,No log,0.6459,0.813182,0.855665,0.779926
80,0.676900,0.578059,0.820418,0.876751,0.770887
100,0.676900,0.478051,0.821069,0.878057,0.771027
120,0.676900,0.404825,0.821082,0.878083,0.77103
140,0.676900,0.390653,0.821082,0.878083,0.77103
160,0.453700,0.385961,0.821082,0.878083,0.77103
180,0.453700,0.382134,0.821082,0.878083,0.77103
200,0.453700,0.379958,0.821082,0.878083,0.77103


Fold 9/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.708749,0.251056,0.249469,0.757966
40,No log,0.686662,0.666637,0.613239,0.752824
60,No log,0.64801,0.786488,0.8385,0.752502
80,0.675700,0.582683,0.789831,0.855528,0.744634
100,0.675700,0.495459,0.790094,0.856358,0.733349
120,0.675700,0.439892,0.790094,0.856358,0.733349
140,0.675700,0.434421,0.790094,0.856358,0.733349
160,0.457300,0.427586,0.790094,0.856358,0.733349
180,0.457300,0.424211,0.790094,0.856358,0.733349
200,0.457300,0.421156,0.790094,0.856358,0.733349


Fold 10/10


Map:   0%|          | 0/631 [00:00<?, ? examples/s]

Map:   0%|          | 0/91 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Total parameters: 13,591,042


  trainer = Trainer(


Step,Training Loss,Validation Loss,F1,Recall,Precision
20,No log,0.710652,0.241064,0.222365,0.808107
40,No log,0.686279,0.687692,0.618104,0.802602
60,No log,0.643751,0.830537,0.868606,0.800232
80,0.675700,0.571779,0.837533,0.88873,0.796271
100,0.675700,0.465976,0.837987,0.889835,0.791849
120,0.675700,0.387149,0.838,0.88986,0.791851
140,0.675700,0.369591,0.838,0.88986,0.791851
160,0.453000,0.365523,0.838,0.88986,0.791851
180,0.453000,0.361778,0.838,0.88986,0.791851
200,0.453000,0.358345,0.838,0.88986,0.791851


F1: Mean=0.8068, 95% CI=(0.7870, 0.8266)
Recall: Mean=0.8680, 95% CI=(0.8542, 0.8819)
Precision: Mean=0.7538, 95% CI=(0.7297, 0.7779)
