In [1]:
from datetime import datetime

from datasets import load_dataset, load_from_disk
from transformers import BatchEncoding, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments
from transformers.data import data_collator

from modelling_xlm_roberta import XLMRobertaForTokenClassification
import nervaluate

from functools import partial
import torch

from typing import Iterable
from torch import Tensor

import numpy as np

device = 'cuda'
model_dtype = torch.bfloat16
torch.cuda.get_device_name(0)

'NVIDIA H100 80GB HBM3'

# 1. Test that layer cutting works

In [2]:
model_test = XLMRobertaForTokenClassification.from_pretrained('facebook/xlm-v-base')
model_test

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


XLMRobertaForTokenClassification(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(901629, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768

In [3]:
model_test = XLMRobertaForTokenClassification.from_pretrained('facebook/xlm-v-base', skip_last_layer=True)
model_test

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


XLMRobertaForTokenClassification(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(901629, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-10): 11 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768

Works! Passing `skip_last_layer=True` removes the last layer in the transformer stack (11 x XLMRobertaLayer instead of 12 x XLMRobertaLayer)

# 2. Train models on the downstream tagging task and evaluate the knowledge transfer to a different language
For this we will use CoNLL 2003 corpus (`eriktks/conll2003`, 14k examples) to train the model and Afrikaans NER Corpus (`nwu-ctext/afrikaans_ner_corpus`, 9k examples) to test the model. The validation is done over CoNLL 2003, only the final scores for Afrikaans are reported.

In [4]:
train_dataset = load_dataset('eriktks/conll2003', split='train')
valid_dataset = load_dataset('eriktks/conll2003', split='validation')
test_dataset = load_dataset('nwu-ctext/afrikaans_ner_corpus', split='train')

conll2003.py:   0%|          | 0.00/9.57k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/12.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/5.82k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/945k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8962 [00:00<?, ? examples/s]

Make sure that the labelling scheme is identical across datasets

In [5]:
train_dataset.features['ner_tags']

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

In [6]:
valid_dataset.features['ner_tags']

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

In [7]:
test_dataset.features['ner_tags']

Sequence(feature=ClassLabel(names=['OUT', 'B-PERS', 'I-PERS', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

The names are a bit different, but otherwise the schemes are identical

## 2.1 Convert word-level tags to subtoken-level tags

In [2]:
xlm_tok = AutoTokenizer.from_pretrained('facebook/xlm-v-base')
xlm_tok_name = 'xlm-v'

xlm_tok('test <mask> test', return_offsets_mapping=True)



{'input_ids': [0, 1340, 901628, 1340, 2], 'attention_mask': [1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 4), (4, 11), (11, 16), (0, 0)]}

In [6]:
# for reference
ner_tags_scheme = np.array(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])
ner_tags_ext    =          [  0,       2,       2,       4,       4,       6,       6,        8,        8]
# the ext is used when we need to split one word into multiple sub tokens

In [9]:
def tokenize(example: dict, tokenizer: PreTrainedTokenizer, tokenizer_name: str, max_length: int = 512) -> dict:
    ner_tags: list[int] = example['ner_tags']
    example_words: list[str] = example['tokens']
    text = ' '.join(example_words)
    
    # map words to positions in text
    word_positions: list[int] = example.get('word_positions', [])
    
    if len(word_positions) != len(example_words):
        text_iterator = 0
        for word in example_words:
            while text[text_iterator:text_iterator + len(word)] != word:
                text_iterator += 1
                assert text_iterator < len(text)
            
            word_positions.append(text_iterator)
    
    encoding: BatchEncoding = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=max_length)
    num_sub_tokens = len(encoding.offset_mapping)
    
    sub_token_iterator = 0
    sub_token_ner_tags: list[int] = []
    for word_id, ner_tag in enumerate(ner_tags):
        word_start = word_positions[word_id]
        word_end = word_start + len(example_words[word_id])
        
        # there may be some empty space between words. the sub tokens that include this empty space receive O label
        # we compare with the end ([1]) to ensure that 0-length tokens are labelled as O (for example <CLS>)
        while sub_token_iterator < num_sub_tokens and  encoding.offset_mapping[sub_token_iterator][1] <= word_start:
            if encoding.offset_mapping[sub_token_iterator][1] - encoding.offset_mapping[sub_token_iterator][0] == 0:
                # set to -100 for special tokens like <CLS>
                sub_token_ner_tags.append(-100)
            else:
                sub_token_ner_tags.append(0)  # 0 = O
            sub_token_iterator += 1
            
        ext_tag = ner_tags_ext[ner_tag]
        
        if sub_token_iterator < num_sub_tokens:
            # the first sub token of a word receives original label, the rest receive extended label
            sub_token_ner_tags.append(ner_tag)
            sub_token_iterator += 1
        
        # again, we need to be careful about 0-length tokens, so we compare start ([0]) with the word end
        while sub_token_iterator < num_sub_tokens and encoding.offset_mapping[sub_token_iterator][0] < word_end:
            
            # there is a weird quirk with transformers tokenizers: <SEP> token has (0, 0) offset 
            #   regardless of its real position, see https://github.com/huggingface/transformers/issues/35125
            if encoding.offset_mapping[sub_token_iterator][1] - encoding.offset_mapping[sub_token_iterator][0] == 0:
                sub_token_ner_tags.append(-100)
            else:
                sub_token_ner_tags.append(ext_tag)
                
            sub_token_iterator += 1
    
    # any tokens at the end (like <SEP>) receive O tokens
    while sub_token_iterator < num_sub_tokens:
        sub_token_iterator += 1
        sub_token_ner_tags.append(0)
        
    return {
        'word_positions': word_positions,
        f'{tokenizer_name}_sub_tokens': encoding.input_ids,
        f'{tokenizer_name}_sub_token_offsets': encoding.offset_mapping,
        f'{tokenizer_name}_sub_token_ner_tags': sub_token_ner_tags,
        'length': len(encoding.offset_mapping)
    }

tokenize_fn = partial(tokenize, tokenizer=xlm_tok, tokenizer_name=xlm_tok_name, max_length=512)

train_dataset = train_dataset.map(tokenize_fn)
valid_dataset = valid_dataset.map(tokenize_fn)
test_dataset = test_dataset.map(tokenize_fn)

Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/8962 [00:00<?, ? examples/s]

In [10]:
for test_idx in range(25):
    ner_tags = torch.as_tensor(train_dataset[test_idx]['xlm-v_sub_token_ner_tags'])
    tokens = torch.as_tensor(train_dataset[test_idx]['xlm-v_sub_tokens'])
    print('Text:', ' '.join(train_dataset[test_idx]['tokens']))
    print('Ents:', xlm_tok.decode(tokens[ner_tags > 0]))
    print()

Text: EU rejects German call to boycott British lamb .
Ents: EU German British

Text: Peter Blackburn
Ents: Peter Blackburn

Text: BRUSSELS 1996-08-22
Ents: BRUSSELS

Text: The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep .
Ents: European Commission German British

Text: Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer .
Ents: Germany European Union Werner Zwingmann Britain

Text: " We do n't support any such recommendation because we do n't see any grounds for it , " the Commission 's chief spokesman Nikolaus van der Pas told a news briefing .
Ents: Commission Nikolaus van der Pas

Text: He said further scientific study was required and if it was found that action was needed it should be ta

Looks nice!

In [18]:
train_dataset.save_to_disk('data/train')
valid_dataset.save_to_disk('data/valid')
test_dataset.save_to_disk('data/test')

Saving the dataset (0/1 shards):   0%|          | 0/14041 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8962 [00:00<?, ? examples/s]

In [2]:
from torch.nn.utils.rnn import pad_sequence


class Dataset(torch.utils.data.Dataset):
    def __init__(self, examples: Iterable[dict], tokenizer_name: str):
        self.input_ids = []
        self.labels = []
        
        for example in examples:
            self.input_ids.append(torch.as_tensor(example[f'{tokenizer_name}_sub_tokens']))
            self.labels.append(torch.as_tensor(example[f'{tokenizer_name}_sub_token_ner_tags']))
        
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.labels[idx]
    

def collate_fn(inputs: list[(Tensor, Tensor)], *, pad_token: int) -> dict:
    all_input_ids = []
    all_labels = []
    for input_ids, labels in inputs:
        all_input_ids.append(input_ids)
        all_labels.append(labels)
    
    input_ids = pad_sequence(all_input_ids, batch_first=True, padding_value=pad_token)
    
    batch_size, seq_length = input_ids.shape

    # do not attend to pad and pad does not attend to anything
    pad_mask = (input_ids != pad_token)
    attention_mask = (pad_mask.reshape(batch_size, 1, -1) != pad_mask.reshape(batch_size, -1, 1))
    return {
        'input_ids': input_ids,
        'labels': pad_sequence(all_labels, batch_first=True, padding_value=-100),
        'attention_mask': attention_mask
    }

In [19]:
def compute_ner_metrics(eval_pred) -> dict:
    predictions, labels = eval_pred
    
    predictions = np.argmax(predictions, axis=-1)
    padding = (labels < 0)
    
    predictions = predictions[~padding]
    labels = labels[~padding]

    predictions = ner_tags_scheme[predictions]
    labels = ner_tags_scheme[labels]

    evaluator = nervaluate.Evaluator([labels], [predictions], tags=['PER', 'LOC', 'ORG', 'MISC'], loader='list')
    results, results_per_tag, _, _ = evaluator.evaluate()

    overall_metrics = results["strict"]
    
    metrics = {
        'overall_precision': overall_metrics['precision'],
        'overall_recall': overall_metrics['recall'],
        'overall_f1': overall_metrics['f1'],
    }
    
    for tag, tag_metrics in results_per_tag.items():
        metrics[f'{tag}_precision'] = tag_metrics['strict']['precision']
        metrics[f'{tag}_recall'] = tag_metrics['strict']['recall']
        metrics[f'{tag}_f1'] = tag_metrics['strict']['f1']

    # Return desired metrics
    return metrics

## 2.2 Train a conventional model

In [20]:
n_run = 0

In [24]:
n_run += 1
model = XLMRobertaForTokenClassification.from_pretrained('facebook/xlm-v-base', num_labels=9)
xlm_tok = AutoTokenizer.from_pretrained('facebook/xlm-v-base')
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=f'xlm-v-base-finetuned-l12-conll03/{datetime.now().strftime("%m-%d")}/{n_run}',
        overwrite_output_dir=True,
        eval_strategy='steps',
        eval_delay=0.001,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=128,
        learning_rate=2e-5,
        max_steps=10000,
        lr_scheduler_type='cosine',
        lr_scheduler_kwargs={ "num_cycles": 1 },
        warmup_ratio=0.1,
        logging_steps=250,
        bf16=True,
        eval_steps=500,
        dataloader_num_workers=4,
        torch_compile=True,
        include_num_input_tokens_seen=True,
        disable_tqdm=True
    ),
    data_collator=partial(collate_fn, pad_token=xlm_tok.pad_token_id),
    train_dataset=Dataset(load_from_disk('data/train'), tokenizer_name='xlm-v'),
    eval_dataset=Dataset(load_from_disk('data/valid'), tokenizer_name='xlm-v'),
    compute_metrics=compute_ner_metrics
)
trainer.train()

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
max_steps is given, it will override any value given in num_train_epochs


{'loss': 2.2377, 'grad_norm': 7.369006156921387, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.05694760820045558, 'num_input_tokens_seen': 42848}
{'loss': 2.1255, 'grad_norm': 21.421642303466797, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.11389521640091116, 'num_input_tokens_seen': 85568}
{'loss': 1.6144, 'grad_norm': 15.64748764038086, 'learning_rate': 3e-06, 'epoch': 0.17084282460136674, 'num_input_tokens_seen': 132768}
{'loss': 1.2593, 'grad_norm': 5.6847381591796875, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.22779043280182232, 'num_input_tokens_seen': 174496}
{'eval_loss': 1.1656874418258667, 'eval_overall_precision': 0.0022560631697687537, 'eval_overall_recall': 0.0006723819129265423, 'eval_overall_f1': 0.001036001036001036, 'eval_PER_precision': 0.0023937761819269898, 'eval_PER_recall': 0.002171552660152009, 'eval_PER_f1': 0.0022772559066325075, 'eval_LOC_precision': 0.0, 'eval_LOC_recall': 0.0, 'eval_LOC_f1': 0, 'eval_ORG_precision': 0.0, 'eval_ORG_recall

KeyboardInterrupt: 

In [14]:
test_results_l12 = trainer.evaluate(Dataset(load_from_disk('data/test'), tokenizer_name='xlm-v'))
test_results_l12

KeyboardInterrupt: 

zsh:1: no matches found: transformers[torch]
