# Fine-tune token classifier for social group mention detection and extraction

In this notebook, we use annotations from

> Licht, Hauke, and Ronja Sczepanski. 2025. “Detecting Group Mentions in Political Rhetoric: A Supervised Learning Approach.” British Journal of Political Science 55: e119. https://doi.org/10.1017/S0007123424000954.

to fine-tune a classifier that identifies and extracts phrases in texts that refer to social groups.

<a target="_blank" href="https://colab.research.google.com/github/haukelicht/advanced_text_analysis/blob/main/notebooks/encoder_finetuning/finetune_token_classifier.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Background

span extraction:

- **task**: extract spans of words from a text that are mentions/indicators of a target concept
- **approaches**
    - supervised learning: token classification
    - prompting: instruct an LLM to return structured output, i.e., a list of strings extracted verbatim from the input text (see Kasner et al., [2025](https://doi.org/10.48550/arXiv.2504.08697))

### examples

- **named entities** (place, organization, person)

    > "This year, **COMPTEXT** was in **Vienna**."

- **group mention** (social group, political group, etc.)

    > "**Labour** fights for **hard-working people**."

- **policy pledge**

    > "We will **lower taxes by 50%**."
    
- **valence attack** (i.e., criticism of a political opponent's character/ability/credibility)

    <!-- > "The Prime Minister **has not been honest with us**. -->

    > "The government has **betrayed the people** over and over again."
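
Each of the bolded phrases above can be stored as a span: a start offset, an end offset, and a type. The sketch below assumes character offsets with an exclusive end and uses made-up type names; `find_span` is a hypothetical helper for illustration and not part of any library used in this notebook.

```python
from typing import List, Tuple

# Assumed representation: one (start, end, type) triple per annotated span,
# with character offsets into the text (end is exclusive).
text = "Labour fights for hard-working people."

def find_span(text: str, phrase: str, span_type: str) -> Tuple[int, int, str]:
    """Locate the first occurrence of `phrase` and return it as an offset triple."""
    start = text.index(phrase)
    return (start, start + len(phrase), span_type)

spans: List[Tuple[int, int, str]] = [
    find_span(text, "Labour", "political group"),
    find_span(text, "hard-working people", "social group"),
]

# Extraction is then plain string slicing.
extracted = [(text[start:end], span_type) for start, end, span_type in spans]
print(spans)
print(extracted)
```

Storing offsets rather than strings keeps repeated phrases distinguishable and makes it possible to map spans onto tokens later.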


### token classification

![Illustration of token classification](https://github.com/haukelicht/advanced_text_analysis/blob/main/notebooks/.assets/task_types-token_classification.svg?raw=1){ height=50% }


- _token_-level classification: assign a category to each token of a document (e.g., a sentence)
- single-label classification: assign each token to one and only one category
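
To make "one label per token" concrete, here is a minimal sketch that turns character-level spans into token-level labels. It assumes plain whitespace tokenization and an IOB2-style label set; `spans_to_iob2` is a hypothetical helper, whereas the notebook itself relies on `soft_seqeval` for this conversion.

```python
import re
from typing import List, Tuple

def spans_to_iob2(text: str, spans: List[Tuple[int, int, str]]) -> List[Tuple[str, str]]:
    """Assign exactly one IOB2 label to every whitespace-delimited token."""
    labeled = []
    for match in re.finditer(r"\S+", text):
        label = "O"  # default: token is outside any span
        for start, end, span_type in spans:
            if match.start() >= start and match.end() <= end:
                # the first token of a span gets "B-", later ones get "I-"
                label = ("B-" if match.start() == start else "I-") + span_type
                break
        labeled.append((match.group(), label))
    return labeled

text = "Labour fights for hard-working people"
spans = [(0, 6, "group"), (18, 37, "group")]  # (start, end, type), end exclusive
for token, label in spans_to_iob2(text, spans):
    print(f"{token!r:15} {label}")
```

Because every token receives exactly one label, the model can be trained with an ordinary per-token cross-entropy loss.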

## methods papers and research applications

- Licht, Hauke, and Ronja Sczepanski. 2025. “Detecting Group Mentions in Political Rhetoric: A Supervised Learning Approach.” British Journal of Political Science 55: e119. https://doi.org/10.1017/S0007123424000954.
- Kasner, Zdeněk, Vilém Zouhar, Patrícia Schmidtová, et al. 2025. “Large Language Models as Span Annotators.” arXiv:2504.08697. Version 1. Preprint, arXiv, April 11. https://doi.org/10.48550/arXiv.2504.08697.
- Klamm C, Rehbein I, Ponzetto SP. Our kind of people? Detecting populist references in political debates. 2023. *Findings of the Association for Computational Linguistics: EACL 2023*. 1227–1243. doi:[10.18653/v1/2023.findings-eacl.91](https://doi.org/10.18653/v1/2023.findings-eacl.91)
- Skorupa Parolin E, Hosseini MS, Hu Y, Khan L, Brandt PT, Osorio J, D'Orazio V. Multi-CoPED: A Multilingual Multi-Task Approach for Coding Political Event Data on Conflict and Mediation Domain. 2022. *Proceedings of the 2022 AAAI/ACM Conference on AI, Ethics, and Society (AIES '22)*. 700–711. doi:[10.1145/3514094.3534178](https://doi.org/10.1145/3514094.3534178)


## Setup

### Setup Colab (if using Colab)

In [63]:
# check if on Colab
COLAB = True
try:
    from google import colab
except ImportError:
    COLAB = False

In [64]:
# install soft-seqeval (latest version) and seqeval
!pip install -q --upgrade --force-reinstall --no-deps git+https://github.com/haukelicht/soft-seqeval.git@main
!pip install -q seqeval

### Load required libraries

In [65]:
from pathlib import Path
import shutil

import numpy as np
import pandas as pd

import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    set_seed,
)

from soft_seqeval.metrics import compute_sequence_metrics

In [66]:
import os
from nltk.data import find as nltk_find
from nltk import download as nltk_download
nltk_res = ['punkt', 'punkt_tab']
for res in nltk_res:
    try:
        nltk_find(os.path.join('tokenizers', res))
    except LookupError:
        nltk_download(res)

In [67]:
import json
from pathlib import Path

from soft_seqeval.classes import LabeledSequence, Entities, Entity
from collections import OrderedDict

from typing import List, Dict, Any, Mapping, Union

def read_jsonlines_corpus(
    file: str,
    id_field: str='id',
    text_field: str='text',
    annotations_field: str='label',
    remove_unsure: bool=True,
    lang: str='english'
) -> List[LabeledSequence]:
    """Read a jsonlines corpus and return a list of LabeledSequence objects.
    Args:
        file (str): Path to the jsonlines file.
        id_field (str): Name of the field containing the document ID.
        text_field (str): Name of the field containing the document text.
        annotations_field (str): Name of the field containing the annotations.
        remove_unsure (bool): Whether to remove annotations that end with 'unsure'.
        lang (str): Language of the documents.

    Returns:
        List[LabeledSequence]: One LabeledSequence object per document, in file order.
            Lines that cannot be parsed as JSON are skipped.
    """
    with open(file, 'r', encoding='utf-8') as f:
        data = []
        for line in f:
            try:
                line = json.loads(line)
                data.append(line)
            except json.JSONDecodeError:
                pass

    documents = [
        # doc[text_field]:
        LabeledSequence(
            text=doc[text_field],
            entities=Entities([
                Entity(*lab)
                for lab in doc[annotations_field]
                if (not lab[2].lower().endswith('unsure') if remove_unsure else True)
            ]),
            id = doc[id_field],
            lang=lang
        )
        for doc in data
    ]

    return documents
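
Judging from the reader above, each line of the corpus file is a JSON object with an ID, a text, and a list of annotations whose third element is the label type. The sketch below assumes the first two elements are start and end character offsets; the two records and the `other-unsure` type are made up for illustration and are not taken from the real corpus.

```python
import json

# Two made-up records in the assumed format (not taken from the real corpus).
raw = "\n".join([
    json.dumps({"id": "a1", "text": "We support young families.", "label": [[11, 25, "SG"]]}),
    json.dumps({"id": "a2", "text": "The banks must pay.", "label": [[4, 9, "other-unsure"]]}),
])

def parse_records(raw: str, remove_unsure: bool = True):
    """Parse JSON lines and optionally drop annotations whose type ends with 'unsure'."""
    records = []
    for line in raw.splitlines():
        doc = json.loads(line)
        spans = [
            (start, end, span_type)
            for start, end, span_type in doc["label"]
            if not (remove_unsure and span_type.lower().endswith("unsure"))
        ]
        records.append({"id": doc["id"], "text": doc["text"], "spans": spans})
    return records

for rec in parse_records(raw):
    print(rec["id"], [(rec["text"][s:e], t) for s, e, t in rec["spans"]])
```

With `remove_unsure=True`, the second record keeps its text but loses its only annotation, so it enters training as a sentence without any mention.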


In [68]:
MODEL_NAME = "answerdotai/ModernBERT-base"

In [69]:
device = 'cuda:0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'cuda:0'

In [70]:
SEED = 42

In [71]:
set_seed(SEED)

In [72]:
base_path = Path("/content/advanced_text_analysis/" if COLAB else "../../")

## Load and prepare the data

In [73]:
data_path = base_path / "data" / "labeled" / "licht_detecting_2025"

In [74]:
fp = data_path / "licht_detecting_2025-uk_manifestos.jsonl"
if not fp.exists():
    url = "https://cta-text-datasets.s3.eu-central-1.amazonaws.com/labeled/licht_detecting_2025/licht_detecting_2025-uk_manifestos.jsonl"
    df = pd.read_json(url, lines=True)
    fp.parent.mkdir(parents=True, exist_ok=True)
    df.to_json(fp, lines=True, orient='records', force_ascii=False)

In [75]:
corpus = read_jsonlines_corpus(fp)

In [76]:
corpus[1]

48b1c6ba33bb5e538c420148ec993090: "This would provide a boost of over £100 million, which we believe will provide important new opportunities for production companies [other] and the creative sector [other] in Scotland."

In [77]:
import os
import json
import shutil
import pandas as pd

from copy import deepcopy

from datasets import Dataset

import torch
from transformers import (
    PreTrainedTokenizer,
    DefaultDataCollator,
    TrainingArguments,
    TrainerCallback,
    EarlyStoppingCallback,
    Trainer,
)

from sklearn.model_selection import train_test_split
import gc

from typing import List, Dict, Union, Optional, Callable, Tuple

# ------------------------------------------------
#  General utils
# ------------------------------------------------

def get_device() -> torch.device:
    return torch.device('cuda:0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

# ------------------------------------------------
#  Dataset splitting
# ------------------------------------------------

def _check_split_sizes(dev_size, test_size, n):
        if dev_size and isinstance(dev_size, float): assert 0 < dev_size < 1, "dev_size must be in (0, 1)"
        if test_size and isinstance(test_size, float): assert 0 < test_size < 1, "test_size must be in (0, 1)"
        if (
            dev_size and isinstance(dev_size, float)
            and
            test_size and isinstance(test_size, float)
        ):
            assert (dev_size + test_size) < 1, "dev_size + test_size must be less than 1"

        if dev_size and isinstance(dev_size, int): assert dev_size < n, "dev_size must be less than "+str(n)
        if test_size and isinstance(test_size, int): assert test_size < n, "test_size must be less than "+str(n)

        # compute sizes
        n_test = 0 if test_size is None else test_size if isinstance(test_size, int) else int(test_size * n)
        n_dev =  0 if dev_size  is None else dev_size  if isinstance(dev_size, int)  else int(dev_size  * n)
        assert n_test + n_dev < n, "test_size + dev_size must be less than the number of examples"

        return n_dev, n_test

def _split_data_frame(
        df: pd.DataFrame,
        dev_size: float=0.15,
        test_size: float=0.15,
        stratify_by: Optional[Union[str, List[str]]]=None,
        seed: int=42,
        return_dict: bool=False
    ):
    n = len(df)
    dev_size, test_size = _check_split_sizes(dev_size, test_size, n)

    if stratify_by:
        if isinstance(stratify_by, str):
            stratify_by = [stratify_by]
        for col in stratify_by:
            assert col in df.columns, f"Column '{col}' not found in ``df``. cannot use for stratified splitting."
        # create a grouping indicator based on the stratification columns
        df['__stratum__'] = df.groupby(stratify_by).ngroup()
    else:
        df['__stratum__'] = 0

    idxs = df.index
    tmp, test_idxs = train_test_split(idxs, test_size=test_size, random_state=seed, stratify=df['__stratum__'] if stratify_by else None) if test_size > 0 else (idxs, [])
    train_idxs, dev_idxs = train_test_split(tmp, test_size=dev_size, random_state=seed, stratify=df.loc[tmp, '__stratum__'] if stratify_by else None) if dev_size > 0 else (tmp, [])

    del df['__stratum__']

    out = {'train': df.loc[train_idxs]}
    out['dev'] = df.loc[dev_idxs] if dev_size > 0 else None
    out['test'] = df.loc[test_idxs] if test_size > 0 else None
    del df, tmp, test_idxs, train_idxs, dev_idxs
    gc.collect()

    if return_dict:
        return {s: d for s, d in out.items() if d is not None}
    else:
        return tuple(out.values())

def _split_corpus(
        corpus: List[Dict],
        test_size: Union[None, float, int]=0.2,
        dev_size: Union[None, float, int]=0.2,
        stratify_by: Optional[Union[str, List[str]]]=None,
        seed: int=42,
        return_dict: bool=False
    ):
    n = len(corpus)
    dev_size, test_size = _check_split_sizes(dev_size, test_size, n)

    if stratify_by:
        assert all('metadata' in doc for doc in corpus), "Stratification requires 'metadata' field in each document's dictionary"
        if isinstance(stratify_by, str):
            stratify_by = [stratify_by]
        for field in stratify_by:
            assert all(field in doc['metadata'] for doc in corpus), f"Field '{field}' not found in 'metadata' of all documents"
        # create a grouping indicator based on the stratification columns
        strata = ['__'.join([str(doc['metadata'][field]) for field in stratify_by]) for doc in corpus]
    else:
        strata = None

    idxs = list(range(n))
    tmp, test_idxs = train_test_split(idxs, test_size=test_size, random_state=seed, stratify=strata) if test_size > 0 else (idxs, [])
    # keep only the strata of the remaining (non-test) examples for the second split
    strata = [strata[i] for i in tmp] if stratify_by else None
    train_idxs, dev_idxs = train_test_split(tmp, test_size=dev_size, random_state=seed, stratify=strata) if dev_size > 0 else (tmp, [])

    out = {'train': [corpus[i] for i in train_idxs]}
    out['dev']  = [corpus[i] for i in dev_idxs]  if dev_size > 0 else None
    out['test'] = [corpus[i] for i in test_idxs] if test_size > 0 else None

    if return_dict:
        return {s: d for s, d in out.items() if d is not None}
    else:
        return tuple(out.values())

def split_data(
        data: Union[pd.DataFrame, List[Dict]],
        test_size: Union[None, float, int]=0.2,
        dev_size: Union[None, float, int]=0.2,
        stratify_by: Optional[Union[str, List[str]]]=None,
        seed: int=42,
        return_dict: bool=False
    ):
    """Split a dataset into training, development, and test sets.

    data: pd.DataFrame or List[Dict]
        The data to split. Must be a data frame or a list of documents.
    dev_size: float or int, optional
        The proportion (float) or number (int) of examples to include in the development set.
    test_size: float or int, optional
        The proportion (float) or number (int) of examples to include in the test set.
    stratify_by: str or list of str, optional
        Metadata field(s)/column(s) to use for stratified splitting. If a single field is
        provided, the data will be stratified by the values in that field/column.
        If multiple columns are provided, the data will be stratified by
        the unique combinations of values in these fields/columns.
    seed: int
        Random seed for reproducibility.
    return_dict: bool
        Whether to return the splits as a dictionary.
    """
    if isinstance(data, pd.DataFrame):
        return _split_data_frame(data, dev_size, test_size, stratify_by, seed, return_dict)
    elif isinstance(data, list) and all(isinstance(doc, dict) for doc in data):
        return _split_corpus(data, test_size, dev_size, stratify_by, seed, return_dict)
    elif isinstance(data, list):
        return _split_corpus(data, test_size, dev_size, None, seed, return_dict)
    else:
        raise ValueError('`data` must be a pandas DataFrame or a list of dictionaries')


# ------------------------------------------------
#  Sequence classification
# ------------------------------------------------

def create_sequence_classification_dataset(
        corpus: Union[pd.DataFrame, List[Dict]],
        text_field: str='text',
        label_field: str='label'
    ) -> Dataset:
    dataset = Dataset.from_list(corpus) if isinstance(corpus, list) else Dataset.from_pandas(corpus)
    if text_field != 'text':
        dataset = dataset.rename_column(text_field, 'text')
    if label_field != 'label':
        dataset = dataset.rename_column(label_field, 'label')
    required = ['text', 'label']
    rm = [c for c in dataset.column_names if c not in required]
    if len(rm) > 0:
        dataset = dataset.remove_columns(rm)
    return dataset

def preprocess_sequence_classification_dataset(examples, tokenizer, label2id: Optional[Dict[str, int]]=None, **kwargs):
    output = tokenizer(examples['text'], **kwargs)
    output['labels'] = [label2id[l] for l in examples['label']] if label2id else examples['label']
    return output

# ------------------------------------------------
#  Pairwise finetuning
# ------------------------------------------------

def unpair_data(data: List[Dict]) -> List[Dict]:
    """
    Convert paired data into separate examples
    """
    seen_ids = set()
    unpacked_data = []
    for row in data:
        for i in [1, 2]:
            if row[f"id{i}"] not in seen_ids:
                seen_ids.add(row[f"id{i}"])
                unpacked_data.append({
                    "id": row[f"id{i}"],
                    "text": row[f"text{i}"],
                })
    return unpacked_data

def create_pairwise_classification_dataset(
        corpus: List[Dict],
        text_fields: List[str]=['text1', 'text2'],
        label_field: str='label'
    ) -> Dataset:
    dataset = Dataset.from_list(corpus)
    if len(text_fields) != 2:
        raise ValueError('text_fields must be a list of length 2')
    if label_field != 'label':
        dataset = dataset.rename_column(label_field, 'label')

    if text_fields[0] != 'text1':
        dataset = dataset.rename_column(text_fields[0], 'text1')
    if text_fields[1] != 'text2':
        dataset = dataset.rename_column(text_fields[1], 'text2')
    required = ['text1', 'text2', 'label']
    rm = [c for c in dataset.column_names if c not in required]
    if len(rm) > 0:
        dataset = dataset.remove_columns(rm)
    return dataset

def preprocess_pairwise_classification_dataset_for_reward_modeling(
        examples,
        tokenizer,
        max_seq_length: Optional[int]= None,
        **kwargs
    ):
    new_examples = {
        # "labels": [],
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for text1, text2, label in zip(examples["text1"], examples["text2"], examples["label"]):
        _tokenize = lambda x: tokenizer(x, **kwargs)
        if label == 1:
            lab = 0
            tokenized_chosen, tokenized_rejected = _tokenize(text1), _tokenize(text2)
        elif label == 2:
            lab = 1
            tokenized_rejected, tokenized_chosen = _tokenize(text1), _tokenize(text2)
        else:
            raise ValueError("Label must be `1` or `2` to indicate index of chosen item.")

        # new_examples["labels"].append(lab)
        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
    return new_examples

# ------------------------------------------------
#  Token Classification
# ------------------------------------------------

def create_token_classification_dataset(
    corpus: List[Dict],
    tokens_field: str='tokens',
    labels_field: Union[None, str]='labels'
):
    dataset = Dataset.from_list(corpus)
    if tokens_field != 'tokens':
        dataset = dataset.rename_column(tokens_field, 'tokens')
    if labels_field is not None and labels_field != 'labels':
        dataset = dataset.rename_column(labels_field, 'labels')
    required = ['tokens'] if labels_field is None else ['tokens', 'labels']
    rm = [c for c in dataset.column_names if c not in required]
    if len(rm) > 0:
        dataset = dataset.remove_columns(rm)
    return dataset


def preprocess_token_classification_dataset(examples, tokenizer, label2id: Optional[Dict[str, int]]=None, **kwargs):
    # source: simplified from https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/examples/pytorch/token-classification/run_ner.py#L442
    tokenized_inputs = tokenizer(examples['tokens'], is_split_into_words=True, **kwargs)

    labels = []
    for i, label in enumerate(examples['labels']):
        # map tokens to their respective word
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # set the special tokens to -100
            if word_idx is None:
                label_ids.append(-100)
            # only label the first token of a given word
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    if label2id:
        labels = [[-100 if l==-100 else label2id[l] for l in label] for label in labels]

    tokenized_inputs['labels'] = labels
    return tokenized_inputs

# ------------------------------------------------
#  Trainer
# ------------------------------------------------

from torch.nn import CrossEntropyLoss
class ClassWeightsTrainer(Trainer):

    def __init__(self, class_weights: Union[List, Dict[Union[int, str], float]], **kwargs):
        """
        argument ``class_weights`` should be a dictionary mapping class labels to weights or a list of only the weights
        """
        super().__init__(**kwargs)
        # self.model = self.model.to(self.model.device)
        if len(class_weights) != self.model.config.num_labels:
            raise ValueError(f'length of `class_weights` must be {self.model.config.num_labels}')
        if isinstance(class_weights, dict):
            if set(class_weights.keys()) != set(self.model.config.id2label.keys()):
                raise ValueError(f'keys of `class_weights` mismatch label classes {list(self.model.config.id2label.keys())}')
            # order the weights by class ID (the dictionary key), not by weight value
            class_weights = [v for k, v in sorted(class_weights.items(), key=lambda item: item[0])]
        self.class_weights = torch.tensor(class_weights, dtype=torch.float32).to(self.model.device)

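    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g., `num_items_in_batch`) that newer Trainer versions pass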
        labels = inputs.get('labels')
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get('logits')
        # compute custom loss
        loss_fct = CrossEntropyLoss(weight=self.class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


class WriteValidationResultsCallback(TrainerCallback):
    """Trainer callback to write validation set results to disk while training"""
    def __init__(self, path='validation_results.jsonl', overwrite=True):
        super().__init__()
        self.path = path
        if overwrite:
            with open(self.path, 'w') as f:
                f.write('')

    def on_evaluate(self, args, state, control, **kwargs):
        validation_results = state.log_history[-1]
        with open(self.path, "a") as f:
            f.write(json.dumps(validation_results) + "\n")


def train_and_test(
    experiment_name: str,
    experiment_results_path: str,
    run_id: Union[None, str],
    model_init: Callable,
    tokenizer: PreTrainedTokenizer,
    data_collator: DefaultDataCollator,
    train_dat: Dataset,
    dev_dat: Union[None, Dataset],
    test_dat: Union[None, Dataset],
    compute_metrics: Callable,
    metric: str,
    class_weights: Optional[Union[List, Dict[Union[int, str], float]]]=None,
    epochs: int = TrainingArguments.num_train_epochs,
    learning_rate: float = TrainingArguments.learning_rate,
    train_batch_size: int = TrainingArguments.per_device_train_batch_size,
    gradient_accumulation_steps: int = TrainingArguments.gradient_accumulation_steps,
    fp16_training: bool = True,
    eval_batch_size: int = TrainingArguments.per_device_eval_batch_size,
    weight_decay: float = TrainingArguments.weight_decay,
    early_stopping: bool = True,
    early_stopping_patience: int = 3,
    early_stopping_threshold: float = 0.03,
    seed: int = 42,
    save_best_model: bool = True,
    save_tokenizer: bool = True,
) -> Tuple[Trainer, str, Dict[str, float]]:
    """
    Fine-tune and evaluate a Transformer model.

    Args:
        experiment_name (str):
            Name of the experiment. Used for creating directories for saving results.
        experiment_results_path (str):
            Base path where experiment results will be saved.
        run_id (Union[None, str]):
            Optional unique identifier for the run. If None, an identifier will be generated.
        model_init (Callable):
            A function that initializes the model to be trained.
        tokenizer (PreTrainedTokenizer):
            Tokenizer for preprocessing text data.
        data_collator (DefaultDataCollator):
            Data collator that batches data samples.
        train_dat (Dataset):
            The dataset used for training.
        dev_dat (Union[None, Dataset]):
            The dataset used for validation. If None, validation is skipped.
        test_dat (Union[None, Dataset]):
            The dataset used for testing. If None, testing is skipped.
        compute_metrics (Callable):
            Function to compute metrics based on predictions and true labels.
        metric (str):
            Name of the metric to be used for evaluation.
        epochs (int):
            Number of training epochs. Defaults to TrainingArguments.num_train_epochs.
        learning_rate (float):
            Learning rate for the optimizer. Defaults to TrainingArguments.learning_rate.
        train_batch_size (int):
            Batch size for training. Defaults to TrainingArguments.per_device_train_batch_size.
        gradient_accumulation_steps (int):
            Number of steps to accumulate gradients before updating model parameters. Defaults to TrainingArguments.gradient_accumulation_steps.
        fp16_training (bool):
            Whether to use mixed precision training. Defaults to True.
        eval_batch_size (int):
            Batch size for evaluation. Defaults to TrainingArguments.per_device_eval_batch_size.
        weight_decay (float):
            Weight decay for the optimizer. Defaults to TrainingArguments.weight_decay.
        early_stopping (bool):
            Whether to use early stopping. Defaults to True.
        early_stopping_patience (int):
            Number of evaluations with no improvement after which training will be stopped. Defaults to 3.
        early_stopping_threshold (float):
            Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.03.
        seed (int):
            Random seed for reproducibility. Defaults to 42.

    Returns:
        Trainer:
            Trainer object used for training.
        str:
            Path to the best model checkpoint.
        dict:
            Evaluation results on the test set.
    """
    results_path = os.path.join(experiment_results_path, experiment_name)
    os.makedirs(results_path, exist_ok=True)

    output_path = os.path.join(results_path, 'checkpoints')
    logs_path = os.path.join(results_path, 'logs')

    # note: the following training options depend on the availability of a dev set and will be disabled if none is provided
    #  - evaluating after each epoch
    #  - early stopping
    #  - saving at most 2 models during training
    #  - saving the best model at the end
    #  - saving the dev results

    training_args = TrainingArguments(
        # hyperparameters
        num_train_epochs=epochs,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=eval_batch_size,
        weight_decay=weight_decay,
        optim='adamw_torch',
        # how to select "best" model (only possible if a dev set is provided)
        do_eval=dev_dat is not None,
        metric_for_best_model=metric if dev_dat is not None else None,
        load_best_model_at_end=dev_dat is not None,
        # when to evaluate
        eval_strategy='epoch' if dev_dat is not None else 'no',
        # when to save
        save_strategy='epoch',
        save_total_limit=2 if dev_dat is not None else None, # don't save all model checkpoints
        # where to store results
        output_dir=output_path,
        overwrite_output_dir=True,
        # logging
        logging_dir=logs_path,
        logging_strategy='epoch',
        report_to='none',
        # efficiency
        fp16=fp16_training if torch.cuda.is_available() else False,
        fp16_full_eval=False,
        # reproducibility
        seed=seed,
        data_seed=seed,
        full_determinism=True
    )

    # build callbacks
    callbacks = []
    if early_stopping:
        if dev_dat is None:
            raise ValueError('Early stopping requires a dev data set')
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=early_stopping_patience, early_stopping_threshold=early_stopping_threshold))
    if dev_dat:
        fn = run_id+'-dev_results.jsonl' if run_id else 'dev_results.jsonl'
        fp = os.path.join(results_path, fn)
        callbacks.append(WriteValidationResultsCallback(path=fp))

    # train
    trainer_args = dict(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dat,
        eval_dataset=dev_dat if dev_dat is not None else None,
        tokenizer=tokenizer,
        data_collator=data_collator if data_collator is not None else None,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )
    if class_weights:
        trainer_args['class_weights'] = class_weights
        trainer = ClassWeightsTrainer(**trainer_args)
    else:
        trainer = Trainer(**trainer_args)

    print('Training ...')
    _ = trainer.train()

    # save best model to results folder
    # CAVEAT: this is not the "best" model if no dev_dat is provided
    dest = run_id+'-best_model' if run_id else 'best_model'
    dest = os.path.join(results_path, dest)
    if os.path.exists(dest):
        shutil.rmtree(dest)
    if save_best_model:
        trainer.save_model(dest)
        # save tokenizer to best_model folder
        if save_tokenizer:
            tokenizer.save_pretrained(dest)

    # evaluate
    if test_dat:
        print('Evaluating ...')
        res = trainer.evaluate(test_dat, metric_key_prefix='test')
        print(res)
        fn = run_id+'-test_results.json' if run_id else 'test_results.json'
        fp = os.path.join(results_path, fn)
        with open(fp, 'w') as file:
            json.dump(res, file)
    else:
        res = None

    # finally: clean up
    if os.path.exists(output_path):
        # TODO: reconsider this when dev_dat is None (in this case, no best model will be copied and deleting the output path would delete any model checkpoints)
        shutil.rmtree(output_path)
    if os.path.exists(logs_path):
        shutil.rmtree(logs_path)

    return trainer, dest, res



data_splits = split_data(corpus, dev_size=0.1, test_size=0.15, seed=SEED, return_dict=True)

In [78]:
# from the annotations, get all entity "types" and construct a label2id mapping
#  where the labels are the IOB2-scheme for each entity type

types = list(set(ent.type for dataset in data_splits.values() for doc in dataset for ent in doc.entities))
scheme = ['O'] + ['I-'+t for t in types] + ['B-'+t for t in types]
label2id = {l: i for i, l in enumerate(scheme)}
id2label = {i: l for i, l in enumerate(scheme)}
NUM_LABELS = len(label2id)

label2id
# NOTE: the span-level annotations will be converted to token-level annotations using the IOB2 scheme.
#       This means that
#        - a word that is not part of any entity will be labeled as "O",
#        - a word at the beginning of a span will be labeled as "B-<entity_type>", and
#        - a word inside a span will be labeled as "I-<entity_type>"

{'O': 0, 'I-other': 1, 'I-SG': 2, 'B-other': 3, 'B-SG': 4}

In [79]:
# NOTE: here we use the `to_labeled_tokens` method of each LabeledSequence instance to convert
#       the span-level annotations to token-level annotations.
#       This method returns a LabeledTokens instance with the token-level annotations,
#       which we then convert into a dictionary with fields 'tokens' and 'labels'.
data_splits = {
    s: [doc.to_labeled_tokens(label2id).to_dict() for doc in dataset]
    for s, dataset in data_splits.items()
}

In [80]:
# show an example
i = 2
for t, l in zip(data_splits['train'][i]['tokens'], data_splits['train'][i]['labels']):
    print(repr(t), '==>', id2label[l])

'Finally' ==> O
',' ==> O
' we' ==> O
' will' ==> O
' transform' ==> O
' our' ==> O
' mental' ==> B-other
' health' ==> I-other
' service' ==> I-other
' by' ==> O
' treating' ==> O
' it' ==> O
' with' ==> O
' the' ==> O
' same' ==> O
' urgency' ==> O
' as' ==> O
' physical' ==> O
' health' ==> O
'.' ==> O


In [81]:

from datasets import DatasetDict

# use custom function defined above to convert corpus to a datasets.Dataset instance (used by transformers' Trainer below)
datasets = DatasetDict({
    s: create_token_classification_dataset(dataset)
    for s, dataset in data_splits.items()
})

In [82]:
datasets.num_rows

{'train': 6433, 'dev': 857, 'test': 1286}
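
The split sizes above add up to 8,576 documents, with roughly 10% in the dev set and 15% in the test set. The `split_data` helper is defined elsewhere, so the sketch below is only one plausible way to obtain such a split: the function `split_indices`, the rounding-down of the split sizes, and the seed value 42 are all assumptions for illustration.

```python
import random

def split_indices(n, dev_size, test_size, seed):
    """Shuffle indices 0..n-1 and cut them into train/dev/test parts."""
    idxs = list(range(n))
    random.Random(seed).shuffle(idxs)   # seeded shuffle for reproducibility
    n_dev = int(n * dev_size)           # rounds down
    n_test = int(n * test_size)         # rounds down
    return {
        'train': idxs[n_dev + n_test:],
        'dev': idxs[:n_dev],
        'test': idxs[n_dev:n_dev + n_test],
    }

splits = split_indices(8576, dev_size=0.1, test_size=0.15, seed=42)
sizes = {name: len(idxs) for name, idxs in splits.items()}
print(sizes)
```

Under these assumptions the three parts are disjoint and their sizes match the ones reported above.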

In [83]:
datasets['train'][0]

{'tokens': ['Obviously',
  ' there',
  ' is',
  ' a',
  ' price',
  ' we',
  ' would',
  ' not',
  ' be',
  ' prepared',
  ' to',
  ' pay',
  '.'],
 'labels': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

In [84]:
# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, add_prefix_space=True)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

# apply the custom function defined above to set subword tokens' labels to -100
# this is necessary because the tokenization may split a word into multiple subwords
datasets = datasets.map(lambda example: preprocess_token_classification_dataset(example, tokenizer=tokenizer), batched=True)

In [85]:
# show an example
example = datasets['train'][2]
for t, l in zip(example['input_ids'], example['labels']):
    if t == tokenizer.pad_token_id:
        break
    print(l, '\t', repr(tokenizer.decode(t)))

-100 	 '[CLS]'
0 	 ' Finally'
0 	 ','
0 	 ' we'
0 	 ' will'
0 	 ' transform'
0 	 ' our'
3 	 ' mental'
1 	 ' health'
1 	 ' service'
0 	 ' by'
0 	 ' treating'
0 	 ' it'
0 	 ' with'
0 	 ' the'
0 	 ' same'
0 	 ' urgency'
0 	 ' as'
0 	 ' physical'
0 	 ' health'
0 	 '.'
-100 	 '[SEP]'
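
The output above shows the label -100 on the special tokens `[CLS]` and `[SEP]`. A common way to produce such labels is to walk over the word indices that a fast tokenizer reports for each subword token (its `word_ids()` output) and keep a word's label only on its first subword. The sketch below illustrates this idea on hand-written word indices; `align_labels` is a hypothetical helper, not the notebook's `preprocess_token_classification_dataset` function.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def align_labels(word_labels, word_ids):
    """Give each word's label to its first subword and IGNORE_INDEX to everything else."""
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:            # special token such as [CLS] or [SEP]
            aligned.append(IGNORE_INDEX)
        elif word_id != previous:      # first subword of a word
            aligned.append(word_labels[word_id])
        else:                          # continuation subword
            aligned.append(IGNORE_INDEX)
        previous = word_id
    return aligned

# hypothetical example: 3 words, the second one split into two subwords
word_labels = [0, 3, 1]
word_ids = [None, 0, 1, 1, 2, None]
aligned = align_labels(word_labels, word_ids)
print(aligned)
```

Labeling only the first subword means each word contributes exactly one term to the loss, regardless of how many pieces the tokenizer splits it into.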


In [86]:
# NOTE: after tokenization, the text tokens are represented by their token IDs,
#       so we can remove the 'tokens' column from the dataset (no need to load it onto the GPU)
datasets = datasets.remove_columns(['tokens'])

## Prepare the model fine-tuning

In [87]:
from transformers import AutoConfig
# NOTE: the Trainer calls `model_init` to initialize the model each time training starts.
#       We define it here to load the pre-trained model from the Hugging Face Hub,
#       set the number of labels to the number of unique labels in the dataset,
#       and attach the label2id and id2label mappings.
def model_init():
    config = AutoConfig.from_pretrained(MODEL_NAME)
    config.num_labels = NUM_LABELS
    config.label2id = label2id
    config.id2label = id2label
    return AutoModelForTokenClassification.from_pretrained(MODEL_NAME, config=config, device_map='auto')

In [88]:
# NOTE: we use a custom function to compute the fine-tuned model's performance
#       from its predictions for the dev or test set examples

# example with perfect scores (the "predictions" are the true labels)
y_true = datasets['test']['labels'][:25]
y_pred = datasets['test']['labels'][:25]

compute_sequence_metrics(y_true, y_pred, id2label, flatten_output=True)

{'seqeval-macro_f1': 1.0,
 'seqeval-macro_precision': 1.0,
 'seqeval-macro_recall': 1.0,
 'seqeval-micro_f1': 1.0,
 'seqeval-micro_precision': 1.0,
 'seqeval-micro_recall': 1.0,
 'seqeval-other_f1': 1.0,
 'seqeval-other_precision': 1.0,
 'seqeval-other_recall': 1.0,
 'seqeval-SG_f1': 1.0,
 'seqeval-SG_precision': 1.0,
 'seqeval-SG_recall': 1.0,
 'softseqeval-macro_f1': 1.0,
 'softseqeval-macro_precision': 1.0,
 'softseqeval-macro_recall': 1.0,
 'softseqeval-micro_f1': 1.0,
 'softseqeval-micro_precision': 1.0,
 'softseqeval-micro_recall': 1.0,
 'softseqeval-other_f1': 1.0,
 'softseqeval-other_precision': 1.0,
 'softseqeval-other_recall': 1.0,
 'softseqeval-SG_f1': 1.0,
 'softseqeval-SG_precision': 1.0,
 'softseqeval-SG_recall': 1.0,
 'doclevel-micro_precision': 1.0,
 'doclevel-micro_recall': 1.0,
 'doclevel-micro_f1': 1.0,
 'doclevel-other_precision': 1.0,
 'doclevel-other_recall': 1.0,
 'doclevel-other_f1': 1.0,
 'doclevel-SG_precision': 1.0,
 'doclevel-SG_recall': 1.0,
 'doclevel-SG_f

In [89]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # convert predictions and labels to list of lists of ints
    predictions = predictions.astype(int).tolist()
    labels = labels.astype(int).tolist()
    return compute_sequence_metrics(y_true=labels, y_pred=predictions, id2label=id2label, flatten_output=True)
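
The two steps in `compute_metrics` (taking the argmax over the class dimension and comparing against the labels) can be illustrated without NumPy. The sketch below also drops positions labeled -100 before mapping IDs to label strings; that `compute_sequence_metrics` handles the -100 positions in this way is an assumption, and `decode_predictions` and the toy `id2label` mapping are hypothetical.

```python
def decode_predictions(logits, labels, id2label, ignore_index=-100):
    """Turn per-token scores into label strings, skipping ignored positions."""
    decoded = []
    for seq_logits, seq_labels in zip(logits, labels):
        seq = []
        for scores, gold in zip(seq_logits, seq_labels):
            if gold == ignore_index:   # special token, padding, or continuation subword
                continue
            best = max(range(len(scores)), key=scores.__getitem__)  # argmax over classes
            seq.append(id2label[best])
        decoded.append(seq)
    return decoded

id2label = {0: 'O', 1: 'I-SG', 2: 'B-SG'}   # toy mapping for illustration only
logits = [[[2.0, 0.1, 0.3], [0.2, 0.1, 1.5], [0.1, 2.2, 0.4], [0.9, 0.1, 0.2]]]
labels = [[-100, 2, 1, -100]]
decoded = decode_predictions(logits, labels, id2label)
print(decoded)
```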

In [91]:
# NOTE: `metric` is the metric used for early stopping
#       and for model selection among saved checkpoints after stopping.
#       It must be available in the output of our `compute_metrics` function defined above,
#       so let's check this.
metric = "seqeval-macro_f1"
ex = ['O', 'B-social group', 'I-social group', 'O']
scores = compute_sequence_metrics([ex], [ex], id2label, flatten_output=True)
if metric not in scores.keys():
    raise ValueError(f"Invalid metric: {metric}, valid metrics are: {', '.join(scores.keys())}")
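
For intuition about the selection metric: assuming the `seqeval-*` scores follow the strict scoring of the seqeval library, a predicted span counts as correct only if its boundaries and type match a true span exactly. The sketch below implements a simplified micro-averaged version of this idea (the notebook selects on the macro average); `extract_spans` and `strict_f1` are hypothetical helpers, and unlike seqeval this sketch ignores stray `I-` tags that do not follow a matching `B-` tag.

```python
def extract_spans(tags):
    """Return the set of (start, end, type) spans in an IOB2 tag sequence (end exclusive)."""
    spans, start, typ = set(), None, None
    for i, tag in enumerate(list(tags) + ['O']):   # sentinel 'O' closes a trailing span
        inside = tag.startswith('I-') and typ == tag[2:]
        if start is not None and not inside:
            spans.add((start, i, typ))             # current span ends before position i
            start, typ = None, None
        if tag.startswith('B-'):
            start, typ = i, tag[2:]                # a new span begins here
    return spans

def strict_f1(y_true, y_pred):
    """Micro-averaged precision, recall, and F1 over exactly matching spans."""
    n_true = n_pred = n_correct = 0
    for true_tags, pred_tags in zip(y_true, y_pred):
        true_spans, pred_spans = extract_spans(true_tags), extract_spans(pred_tags)
        n_true += len(true_spans)
        n_pred += len(pred_spans)
        n_correct += len(true_spans & pred_spans)
    precision = n_correct / n_pred if n_pred else 0.0
    recall = n_correct / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

y_true = [['O', 'B-SG', 'I-SG', 'O', 'B-other']]
y_pred = [['O', 'B-SG', 'O', 'O', 'B-other']]   # the first span is cut short
scores_strict = strict_f1(y_true, y_pred)
print(scores_strict)
```

In this example the truncated `SG` span earns no credit at all, which is why strict span-level scores are typically lower than word-level scores.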

### Define the training arguments

In [92]:
model_path = base_path / "models" / "licht_detecting_2025-group_mention_detector"

In [97]:
out_dir = model_path
checkpoints_dir = out_dir / 'checkpoints'
logs_dir = out_dir / 'logs'

training_args = TrainingArguments(

    # hyperparameters
    num_train_epochs=10,
    learning_rate=4e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=32,
    weight_decay=0.3,
    optim='adamw_torch',

    # when to evaluate
    eval_strategy='epoch',
    # how to select "best" model
    do_eval=bool('dev' in datasets),
    metric_for_best_model="seqeval-macro_f1",
    load_best_model_at_end=True,
    # when to save
    save_strategy='epoch',
    save_total_limit=2 if 'dev' in datasets else None, # don't save all model checkpoints
    # where to store results
    output_dir=checkpoints_dir,
    overwrite_output_dir=True,

    # logging
    logging_dir=logs_dir,
    logging_strategy='epoch',
    report_to='none', # Disable Weights & Biases logging

    # reproducibility
    seed=SEED,
    data_seed=SEED,
    full_determinism=True
)


# build callbacks
callbacks = []
if 'dev' in datasets:
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.03))
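
With `early_stopping_patience=3` and `early_stopping_threshold=0.03`, training stops once the selection metric has failed to beat the running best by more than 0.03 for three evaluations in a row. The sketch below is a simplified stand-alone version of this rule, not the `transformers` implementation. It assumes a higher-is-better metric and that the running best is updated on any improvement, while the patience counter resets only on improvements larger than the threshold. The example scores are the validation `seqeval-macro_f1` values reported in the training log further down.

```python
def early_stopping_epoch(scores, patience, threshold):
    """Return the epoch (1-based) at which training would stop, or None if it never stops."""
    best = None
    bad_evals = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or (score > best and abs(score - best) > threshold):
            bad_evals = 0              # clear improvement: reset the patience counter
        else:
            bad_evals += 1             # no (sufficient) improvement
        if best is None or score > best:
            best = score               # running best tracks any improvement
        if bad_evals >= patience:
            return epoch
    return None

macro_f1 = [0.672303, 0.697367, 0.705948, 0.69211]
stop = early_stopping_epoch(macro_f1, patience=3, threshold=0.03)
print(stop)
```

Under these assumptions, none of the epochs after the first improves on the running best by more than 0.03, which is consistent with the training log ending after the fourth epoch.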

### Create the trainer

In [99]:
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['dev'] if 'dev' in datasets else None,
    processing_class=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
    callbacks=callbacks
)

Some weights of ModernBertForTokenClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train

In [100]:
print('Training ...')
train_hist = trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': None, 'bos_token_id': None}.


Training ...



Epoch,Training Loss,Validation Loss,Seqeval-macro F1,Seqeval-macro Precision,Seqeval-macro Recall,Seqeval-micro F1,Seqeval-micro Precision,Seqeval-micro Recall,Seqeval-other F1,Seqeval-other Precision,Seqeval-other Recall,Seqeval-sg F1,Seqeval-sg Precision,Seqeval-sg Recall,Softseqeval-macro F1,Softseqeval-macro Precision,Softseqeval-macro Recall,Softseqeval-micro F1,Softseqeval-micro Precision,Softseqeval-micro Recall,Softseqeval-other F1,Softseqeval-other Precision,Softseqeval-other Recall,Softseqeval-sg F1,Softseqeval-sg Precision,Softseqeval-sg Recall,Doclevel-micro Precision,Doclevel-micro Recall,Doclevel-micro F1,Doclevel-other Precision,Doclevel-other Recall,Doclevel-other F1,Doclevel-sg Precision,Doclevel-sg Recall,Doclevel-sg F1,Wordlevel-accuracy,Wordlevel-macro F1,Wordlevel-macro Precision,Wordlevel-macro Recall,Wordlevel-o F1,Wordlevel-o Precision,Wordlevel-o Recall,Wordlevel-other F1,Wordlevel-other Precision,Wordlevel-other Recall,Wordlevel-sg F1,Wordlevel-sg Precision,Wordlevel-sg Recall
1,0.2294,0.150491,0.672303,0.664315,0.680782,0.676415,0.667598,0.685468,0.707173,0.689145,0.72617,0.637433,0.639485,0.635394,0.745851,0.774832,0.740149,0.787308,0.818905,0.780748,0.761916,0.794404,0.751871,0.729785,0.75526,0.728427,0.922987,0.922987,0.922987,0.933489,0.933489,0.933489,0.91832,0.91832,0.91832,0.956713,0.877655,0.911985,0.849046,0.976583,0.967354,0.985989,0.83136,0.854744,0.809221,0.825021,0.913858,0.751926
2,0.104,0.143183,0.697367,0.686802,0.708301,0.699623,0.689239,0.710325,0.720412,0.713073,0.727903,0.674322,0.660532,0.688699,0.769879,0.794397,0.768031,0.810213,0.837276,0.808536,0.770181,0.797684,0.762714,0.769576,0.79111,0.773349,0.941657,0.941657,0.941657,0.934656,0.934656,0.934656,0.928821,0.928821,0.928821,0.959495,0.887148,0.906779,0.869087,0.978205,0.972498,0.983979,0.83292,0.868103,0.800477,0.850318,0.879736,0.822804
3,0.0532,0.166651,0.705948,0.667228,0.750358,0.708259,0.667513,0.754302,0.723946,0.669118,0.788562,0.687951,0.665339,0.712154,0.767507,0.780901,0.7733,0.815002,0.831688,0.820507,0.769439,0.784166,0.768748,0.765574,0.777636,0.777853,0.935823,0.935823,0.935823,0.928821,0.928821,0.928821,0.917153,0.917153,0.917153,0.957325,0.883751,0.883664,0.88549,0.977321,0.977733,0.976909,0.841663,0.809244,0.876789,0.832268,0.864013,0.802773
4,0.0219,0.207925,0.69211,0.67518,0.710707,0.69642,0.677828,0.716061,0.726073,0.692913,0.762565,0.658147,0.657447,0.658849,0.747403,0.767097,0.748277,0.796665,0.820658,0.797257,0.761345,0.785703,0.756223,0.73346,0.748491,0.740332,0.941657,0.941657,0.941657,0.935823,0.935823,0.935823,0.912485,0.912485,0.912485,0.956824,0.880424,0.892081,0.870046,0.976989,0.973591,0.980411,0.831354,0.828076,0.834658,0.83293,0.874576,0.795069


### Evaluate

In [None]:
# apply the best model loaded after finishing training to the test set
print('Evaluating ...')
test_res = trainer.evaluate(datasets['test'], metric_key_prefix='test')

In [None]:
# format the results as a more readable table
out = pd.DataFrame(test_res, index=['value']).T
out = out.reset_index().rename(columns={'index': 'cat'})
out[['set', 'scheme', 'metric', 'misc']] = out.cat.str.split('_', expand=True)
out = out[out.misc.isnull()]
out = out[out.metric.notnull()]
out[['scheme', 'type']] = out.scheme.str.split('-', expand=True)
out = out.drop(columns=['set', 'cat', 'misc'])
out = out[['scheme', 'type', 'metric', 'value']]
out = out.pivot(index=['type', 'scheme', ], columns='metric', values='value')
keys = [
    (typ, scheme)
    for typ in types
    for scheme in ['seqeval', 'softseqeval', 'wordlevel', 'doclevel']
]
out.loc[keys, :]
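
The reshaping above relies on the naming pattern of the metric keys, `<set>_<scheme>-<type>_<metric>`. As a plain-Python illustration of the same parsing logic, the sketch below splits such keys and discards entries like the loss or runtime statistics that do not fit the pattern. The function `parse_metric_key` is hypothetical and the values in the example dictionary are made up.

```python
def parse_metric_key(key):
    """Split e.g. 'test_seqeval-SG_f1' into (scheme, type, metric); return None otherwise."""
    parts = key.split('_')
    if len(parts) != 3 or '-' not in parts[1]:
        return None                      # e.g. 'test_loss' or 'test_samples_per_second'
    scheme, typ = parts[1].split('-', 1)
    return scheme, typ, parts[2]

# made-up values, for illustration only
example_res = {
    'test_loss': 0.15,
    'test_seqeval-SG_f1': 0.68,
    'test_doclevel-other_recall': 0.93,
    'test_samples_per_second': 250.0,
}
parsed = {k: parse_metric_key(k) for k in example_res}
for key, fields in parsed.items():
    print(key, '->', fields)
```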

### Inference

In [None]:
from datasets import Dataset

from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

from tqdm import tqdm

We'll use the `pipeline` function from `transformers` for inference (i.e., predicting spans in unlabeled data).

Specifically, we use the **NER** (named entity recognition) task and pass the fine-tuned model from the trainer.

In [None]:
extractor = pipeline(task='ner', model=trainer.model, tokenizer=tokenizer, batch_size=32, aggregation_strategy='simple')

In [None]:
fields = ['id', 'text']
df = pd.DataFrame([{f: doc[f] for f in fields} for doc in data_splits['test']])

In [None]:
# apply the extractor to the dataset
pred_ents = [p for p in extractor(df['text'].tolist())]

In [None]:
pred_ents[0]

For each text in `df['text']`, the pipeline returns a list of dictionaries, one per predicted span. `pred_ents` collects these lists, one per text.

Each dictionary in such a list has the following fields:

- `start`: character start index of the entity in the text
- `end`: character end index of the entity in the text
- `score`: confidence score of the prediction
- `word`: the text of the entity
- `entity_group`: the entity type (e.g., 'SG')
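
As a small illustration of how these fields can be used, the sketch below slices the predicted mentions out of the input text by their character offsets and groups them by entity type. The predictions are made up to match the format described above (they are not actual model output), and `mentions_by_type` and its `min_score` parameter are hypothetical.

```python
from collections import defaultdict

text = "Labour fights for hard-working people."

# made-up predictions in the format described above (not actual model output)
pred = [
    {'start': 0, 'end': 6, 'score': 0.98, 'word': 'Labour', 'entity_group': 'other'},
    {'start': 17, 'end': 37, 'score': 0.95, 'word': ' hard-working people', 'entity_group': 'SG'},
]

def mentions_by_type(text, pred, min_score=0.0):
    """Group predicted mentions by entity type, slicing them out of the text."""
    out = defaultdict(list)
    for ent in pred:
        if ent['score'] < min_score:
            continue                                      # skip low-confidence predictions
        mention = text[ent['start']:ent['end']].strip()  # offsets may include a leading space
        out[ent['entity_group']].append(mention)
    return dict(out)

grouped = mentions_by_type(text, pred)
print(grouped)
```

Slicing the original text by `start` and `end` is usually more reliable than using `word`, because `word` is decoded from tokens and may differ from the input in whitespace.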


Let's convert these annotations into one `Entities` instance per text and create a new `LabeledSequence` instance from this information for each text:

In [None]:
from soft_seqeval.classes import Entity, Entities
from soft_seqeval.classes import LabeledSequence
from copy import deepcopy

def pipeline_output_to_entities(pred) -> Entities:
    """Take output from the NER pipeline and convert to Entities instance"""
    ents = []
    for ent in pred:
        ent = deepcopy(ent)
        # trim leading/trailing whitespace from the span (safe for empty strings, too)
        if ent['word'].startswith(' '):
            ent['start'] += 1
        if ent['word'].endswith(' '):
            ent['end'] -= 1
        ents.append(Entity(ent['start'], ent['end'], ent['entity_group']))
    return Entities(ents)

# iterate over the documents and predicted annotations to create a list of LabeledSequence instances
preds = [
    LabeledSequence(text=doc['text'], entities=pipeline_output_to_entities(pred), id=doc['id'], lang='english')
    for (_, doc), pred in zip(df.iterrows(), pred_ents)
]

In [None]:
# look at first 10 examples
preds[:10]


## Finally

### Delete intermediate checkpoints and log files

In [None]:
# finally: clean up
if checkpoints_dir.exists():
    shutil.rmtree(checkpoints_dir)
if logs_dir.exists():
    shutil.rmtree(logs_dir)

### Save the best model (if desired)

In [None]:
trainer.save_model(out_dir)
tokenizer.save_pretrained(out_dir)

### Free the GPU and remove large objects

In [None]:
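import gc
import torch

trainer.model.to('cpu')  # move the model off the GPU before deleting the references to it
del trainer, tokenizer, data_splits, datasets
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached GPU memory held by PyTorch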