In [1]:
from datetime import datetime

from datasets import load_dataset, load_from_disk
from transformers import BatchEncoding, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments
from transformers.data import data_collator

from modelling_xlm_roberta import XLMRobertaForTokenClassification
import nervaluate

from functools import partial
import torch

from typing import Iterable
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

from sklearn.metrics import precision_score, recall_score, f1_score

import numpy as np
import wandb

model_name = 'facebook/xlm-v-base' # 'FacebookAI/xlm-roberta-base' 
device = 'cuda'
model_dtype = torch.bfloat16
torch.cuda.get_device_name(0)

lr = 2e-5
steps = 2500
batch_size = 16

# 1. Test that layer cutting works

In [2]:
model_test = XLMRobertaForTokenClassification.from_pretrained(model_name)
model_test

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


XLMRobertaForTokenClassification(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(901629, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768

In [3]:
model_test = XLMRobertaForTokenClassification.from_pretrained(model_name, skip_last_layer=True)
model_test

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


XLMRobertaForTokenClassification(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(901629, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-10): 11 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768

Works! Passing `skip_last_layer=True` removes the last layer in the transformer stack (11 x XLMRobertaLayer instead of 12 x XLMRobertaLayer)

# 2. Train models on the downstream tagging task and evaluate the knowledge transfer to a different language
For this we will use CoNLL 2003 corpus (`eriktks/conll2003`, 14k examples) to train the model and Afrikaans NER Corpus (`nwu-ctext/afrikaans_ner_corpus`, 9k examples) to test the model. The validation is done over CoNLL 2003, only the final scores for Afrikaans are reported.

In [4]:
train_dataset = load_dataset('eriktks/conll2003', split='train')
valid_dataset = load_dataset('eriktks/conll2003', split='validation')
test_dataset = load_dataset('nwu-ctext/afrikaans_ner_corpus', split='train')

Make sure that the labelling scheme is identical across datasets

In [5]:
train_dataset.features['ner_tags']

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

In [6]:
valid_dataset.features['ner_tags']

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

In [7]:
test_dataset.features['ner_tags']

Sequence(feature=ClassLabel(names=['OUT', 'B-PERS', 'I-PERS', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

The names are a bit different, but otherwise the schemes are identical

## 2.1 Convert word-level tags to subtoken-level tags

In [8]:
tok = AutoTokenizer.from_pretrained(model_name)
tok_name = model_name.replace('/', '__')

tok('test <mask> test', return_offsets_mapping=True)



{'input_ids': [0, 1340, 901628, 1340, 2], 'attention_mask': [1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 4), (4, 11), (11, 16), (0, 0)]}

In [9]:
# for reference
ner_tags_scheme = np.array(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])
ner_tags_ext    =          [  0,       2,       2,       4,       4,       6,       6,        8,        8]
# the ext is used when we need to split one word into multiple sub tokens

In [10]:
def tokenize(example: dict, tokenizer: PreTrainedTokenizer, tokenizer_name: str, max_length: int = 512) -> dict:
    ner_tags: list[int] = example['ner_tags']
    example_words: list[str] = example['tokens']
    text = ' '.join(example_words)
    
    # map words to positions in text
    word_positions: list[int] = example.get('word_positions', [])
    
    if len(word_positions) != len(example_words):
        text_iterator = 0
        for word in example_words:
            while text[text_iterator:text_iterator + len(word)] != word:
                text_iterator += 1
                assert text_iterator < len(text)
            
            word_positions.append(text_iterator)
    
    encoding: BatchEncoding = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=max_length)
    num_sub_tokens = len(encoding.offset_mapping)
    
    sub_token_iterator = 0
    sub_token_ner_tags: list[int] = []
    for word_id, ner_tag in enumerate(ner_tags):
        word_start = word_positions[word_id]
        word_end = word_start + len(example_words[word_id])
        
        # there may be some empty space between words. the sub tokens that include this empty space receive O label
        # we compare with the end ([1]) to ensure that 0-length tokens are labelled as -100 (for example <CLS>)
        while sub_token_iterator < num_sub_tokens and  encoding.offset_mapping[sub_token_iterator][1] <= word_start:
            if encoding.offset_mapping[sub_token_iterator][1] - encoding.offset_mapping[sub_token_iterator][0] == 0:
                # set to -100 for special tokens like <CLS>
                sub_token_ner_tags.append(-100)
            else:
                sub_token_ner_tags.append(0)  # 0 = O
            sub_token_iterator += 1
            
        ext_tag = ner_tags_ext[ner_tag]
        
        if sub_token_iterator < num_sub_tokens:
            # the first sub token of a word receives original label, the rest receive extended label
            sub_token_ner_tags.append(ner_tag)
            sub_token_iterator += 1
        
        # again, we need to be careful about 0-length tokens, so we compare start ([0]) with the word end
        while sub_token_iterator < num_sub_tokens and encoding.offset_mapping[sub_token_iterator][0] < word_end:
            
            # there is a weird quirk with transformers tokenizers: <SEP> token has (0, 0) offset 
            #   regardless of its real position, see https://github.com/huggingface/transformers/issues/35125
            if encoding.offset_mapping[sub_token_iterator][1] - encoding.offset_mapping[sub_token_iterator][0] == 0:
                sub_token_ner_tags.append(-100)
            else:
                sub_token_ner_tags.append(ext_tag)
                
            sub_token_iterator += 1
    
    # any tokens at the end (like <SEP>) receive O tokens
    while sub_token_iterator < num_sub_tokens:
        sub_token_iterator += 1
        sub_token_ner_tags.append(0)
        
    return {
        'word_positions': word_positions,
        f'{tokenizer_name}_sub_tokens': encoding.input_ids,
        f'{tokenizer_name}_sub_token_offsets': encoding.offset_mapping,
        f'{tokenizer_name}_sub_token_ner_tags': sub_token_ner_tags,
        'length': len(encoding.offset_mapping)
    }

tokenize_fn = partial(tokenize, tokenizer=tok, tokenizer_name=tok_name, max_length=512)

train_dataset = train_dataset.map(tokenize_fn)
valid_dataset = valid_dataset.map(tokenize_fn)
test_dataset = test_dataset.map(tokenize_fn)

In [11]:
for test_idx in range(25):
    ner_tags = torch.as_tensor(train_dataset[test_idx][f'{tok_name}_sub_token_ner_tags'])
    tokens = torch.as_tensor(train_dataset[test_idx][f'{tok_name}_sub_tokens'])
    print('Text:', ' '.join(train_dataset[test_idx]['tokens']))
    print('Ents:', tok.decode(tokens[ner_tags > 0]))
    print()

Text: EU rejects German call to boycott British lamb .
Ents: EU German British

Text: Peter Blackburn
Ents: Peter Blackburn

Text: BRUSSELS 1996-08-22
Ents: BRUSSELS

Text: The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep .
Ents: European Commission German British

Text: Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer .
Ents: Germany European Union Werner Zwingmann Britain

Text: " We do n't support any such recommendation because we do n't see any grounds for it , " the Commission 's chief spokesman Nikolaus van der Pas told a news briefing .
Ents: Commission Nikolaus van der Pas

Text: He said further scientific study was required and if it was found that action was needed it should be ta

Looks nice!

In [12]:
train_dataset.save_to_disk('data/train')
valid_dataset.save_to_disk('data/valid')
test_dataset.save_to_disk('data/test')

Saving the dataset (0/1 shards):   0%|          | 0/14041 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8962 [00:00<?, ? examples/s]

In [13]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, examples: Iterable[dict], tokenizer_name: str):
        self.input_ids = []
        self.labels = []
        
        for example in examples:
            self.input_ids.append(torch.as_tensor(example[f'{tokenizer_name}_sub_tokens']))
            self.labels.append(torch.as_tensor(example[f'{tokenizer_name}_sub_token_ner_tags']))
        
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.labels[idx]
    

def collate_fn(inputs: list[(Tensor, Tensor)], *, pad_token: int) -> dict:
    all_input_ids = []
    all_labels = []
    for input_ids, labels in inputs:
        all_input_ids.append(input_ids)
        all_labels.append(labels)
    
    input_ids = pad_sequence(all_input_ids, batch_first=True, padding_value=pad_token)
    
    batch_size, seq_length = input_ids.shape

    # do not attend to pad and pad does not attend to anything
    pad_mask = (input_ids != pad_token).long()
    return {
        'input_ids': input_ids,
        'labels': pad_sequence(all_labels, batch_first=True, padding_value=-100),
        'attention_mask': pad_mask
    }

In [14]:
def compute_ner_metrics(eval_pred) -> dict:
    predictions, labels = eval_pred
    
    predictions = np.argmax(predictions, axis=-1)
    padding = (labels < 0)
    
    predictions = predictions[~padding]
    labels = labels[~padding]

    predictions = ner_tags_scheme[predictions]
    labels = ner_tags_scheme[labels]
    
    token_precision = precision_score(labels, predictions, average='macro', zero_division=0)
    token_recall = recall_score(labels, predictions, average='macro', zero_division=0)
    token_f1 = f1_score(labels, predictions, average='macro', zero_division=0)

    evaluator = nervaluate.Evaluator([labels], [predictions], tags=['PER', 'LOC', 'ORG', 'MISC'], loader='list')
    results, results_per_tag, _, _ = evaluator.evaluate()

    overall_metrics = results['strict']
    
    metrics = {
        'token_precision_macro': token_precision,
        'token_recall_macro': token_recall,
        'token_f1_macro': token_f1,
        'overall_precision': overall_metrics['precision'],
        'overall_recall': overall_metrics['recall'],
        'overall_f1': overall_metrics['f1'],
    }
    
    for tag, tag_metrics in results_per_tag.items():
        metrics[f'{tag}_precision'] = tag_metrics['strict']['precision']
        metrics[f'{tag}_recall'] = tag_metrics['strict']['recall']
        metrics[f'{tag}_f1'] = tag_metrics['strict']['f1']

    # Return desired metrics
    return metrics

## 2.2 Train a conventional model

In [15]:
n_run = 0

In [16]:
run_name = f'{tok_name}-finetuned-l12-conll03/{datetime.now().strftime("%m-%d")}/{n_run}'
wandb.init(
    project='ner-alignment',
    name=run_name,
    dir=run_name,
    resume=False
)
n_run += 1

model = XLMRobertaForTokenClassification.from_pretrained(
    model_name, 
    num_labels=9,
)
model.roberta.embeddings.requires_grad_(False)  # freeze input embeddings to avoid parameter shift (training on english and inferencing on africaans -> different tokens are activated)
print(f"Percentage of frozen modules: {100 * sum(1 for module in model.modules() if not any(p.requires_grad for p in module.parameters())) / sum(1 for module in model.modules()):.2f}%")
print(f"Percentage of frozen parameters: {100 * sum(p.numel() for p in model.parameters() if not p.requires_grad) / sum(p.numel() for p in model.parameters()):.2f}%")


tok = AutoTokenizer.from_pretrained(model_name)


trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=run_name,
        overwrite_output_dir=True,
        eval_strategy='steps',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=128,
        learning_rate=lr,
        max_steps=steps,
        lr_scheduler_type='cosine_with_min_lr',
        lr_scheduler_kwargs={ 'num_cycles': 0.5, 'min_lr_rate': 0.01 },
        warmup_ratio=0.1,
        adam_epsilon=1e-8,
        adam_beta1=0.9,
        adam_beta2=0.999,
        weight_decay=0.0,
        logging_steps=100,
        eval_steps=200,
        bf16=True,
        torch_compile=False,
        include_num_input_tokens_seen=True,
        disable_tqdm=True,
        report_to='wandb'
    ),
    data_collator=partial(collate_fn, pad_token=tok.pad_token_id),
    train_dataset=Dataset(load_from_disk('data/train'), tokenizer_name=tok_name),
    eval_dataset=Dataset(load_from_disk('data/valid'), tokenizer_name=tok_name),
    compute_metrics=compute_ner_metrics
)
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mviktoroo-sch[0m ([33mviktoroo-sch-epfl[0m). Use [1m`wandb login --relogin`[0m to force relogin


Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Percentage of frozen modules: 24.12%
Percentage of frozen parameters: 89.07%


max_steps is given, it will override any value given in num_train_epochs


{'loss': 1.6795, 'grad_norm': 8.110573768615723, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.11389521640091116, 'num_input_tokens_seen': 76672}
{'loss': 0.7058, 'grad_norm': 2.834045171737671, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.22779043280182232, 'num_input_tokens_seen': 158272}
{'eval_loss': 0.5673161149024963, 'eval_token_precision_macro': 0.08892700414117066, 'eval_token_recall_macro': 0.1111111111111111, 'eval_token_f1_macro': 0.09878895554926144, 'eval_overall_precision': 0.0, 'eval_overall_recall': 0.0, 'eval_overall_f1': 0, 'eval_PER_precision': 0.0, 'eval_PER_recall': 0.0, 'eval_PER_f1': 0, 'eval_LOC_precision': 0, 'eval_LOC_recall': 0.0, 'eval_LOC_f1': 0, 'eval_ORG_precision': 0, 'eval_ORG_recall': 0.0, 'eval_ORG_f1': 0, 'eval_MISC_precision': 0, 'eval_MISC_recall': 0.0, 'eval_MISC_f1': 0, 'eval_runtime': 0.5682, 'eval_samples_per_second': 5719.996, 'eval_steps_per_second': 45.76, 'epoch': 0.22779043280182232, 'num_input_tokens_seen': 158272}
{'loss': 0

TrainOutput(global_step=2500, training_loss=0.20046089973449707, metrics={'train_runtime': 160.8835, 'train_samples_per_second': 248.627, 'train_steps_per_second': 15.539, 'train_loss': 0.20046089973449707, 'epoch': 2.847380410022779, 'num_input_tokens_seen': 1950135})

In [17]:
test_results_l12 = trainer.evaluate(Dataset(load_from_disk('data/test'), tokenizer_name=tok_name), metric_key_prefix='transfer')
test_results_l12

{'transfer_loss': 0.42435064911842346, 'transfer_token_precision_macro': 0.7183066027403103, 'transfer_token_recall_macro': 0.6434193935372868, 'transfer_token_f1_macro': 0.6434778181808135, 'transfer_overall_precision': 0.5156887354879197, 'transfer_overall_recall': 0.4577675649328041, 'transfer_overall_f1': 0.4850049798959755, 'transfer_PER_precision': 0.4577421344848859, 'transfer_PER_recall': 0.3596703829374697, 'transfer_PER_f1': 0.4028230184581976, 'transfer_LOC_precision': 0.7540342298288508, 'transfer_LOC_recall': 0.8449315068493151, 'transfer_LOC_f1': 0.7968992248062016, 'transfer_ORG_precision': 0.6072380106571936, 'transfer_ORG_recall': 0.7629009762900977, 'transfer_ORG_f1': 0.6762269749041909, 'transfer_MISC_precision': 0.3396679772826562, 'transfer_MISC_recall': 0.22575493612078978, 'transfer_MISC_f1': 0.27123669980812837, 'transfer_runtime': 26.7937, 'transfer_samples_per_second': 334.481, 'transfer_steps_per_second': 2.65, 'epoch': 2.847380410022779, 'num_input_tokens_se

{'transfer_loss': 0.42435064911842346,
 'transfer_token_precision_macro': 0.7183066027403103,
 'transfer_token_recall_macro': 0.6434193935372868,
 'transfer_token_f1_macro': 0.6434778181808135,
 'transfer_overall_precision': 0.5156887354879197,
 'transfer_overall_recall': 0.4577675649328041,
 'transfer_overall_f1': 0.4850049798959755,
 'transfer_PER_precision': 0.4577421344848859,
 'transfer_PER_recall': 0.3596703829374697,
 'transfer_PER_f1': 0.4028230184581976,
 'transfer_LOC_precision': 0.7540342298288508,
 'transfer_LOC_recall': 0.8449315068493151,
 'transfer_LOC_f1': 0.7968992248062016,
 'transfer_ORG_precision': 0.6072380106571936,
 'transfer_ORG_recall': 0.7629009762900977,
 'transfer_ORG_f1': 0.6762269749041909,
 'transfer_MISC_precision': 0.3396679772826562,
 'transfer_MISC_recall': 0.22575493612078978,
 'transfer_MISC_f1': 0.27123669980812837,
 'transfer_runtime': 26.7937,
 'transfer_samples_per_second': 334.481,
 'transfer_steps_per_second': 2.65,
 'epoch': 2.847380410022779

In [18]:
wandb.finish()

VBox(children=(Label(value='0.021 MB of 0.021 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/LOC_f1,▁▅▇▇████████
eval/LOC_precision,▁▅▇▇████████
eval/LOC_recall,▁▅█▇████████
eval/MISC_f1,▁▁▅▇▇███████
eval/MISC_precision,▁▁▅▇▇███████
eval/MISC_recall,▁▁▅▇▇███████
eval/ORG_f1,▁▂▇█████████
eval/ORG_precision,▁▁▇█████████
eval/ORG_recall,▁▂▇█████████
eval/PER_f1,▁███████████

0,1
eval/LOC_f1,0.96485
eval/LOC_precision,0.95862
eval/LOC_recall,0.97115
eval/MISC_f1,0.88488
eval/MISC_precision,0.86966
eval/MISC_recall,0.90065
eval/ORG_f1,0.90849
eval/ORG_precision,0.8976
eval/ORG_recall,0.91964
eval/PER_f1,0.97539


## 2.3 Train a truncated model (without last layer)

In [19]:
run_name = f'{tok_name}-finetuned-l11-conll03/{datetime.now().strftime("%m-%d")}/{n_run}'
wandb.init(
    project='ner-alignment',
    name=run_name,
    dir=run_name,
    resume=False
)
n_run += 1

model = XLMRobertaForTokenClassification.from_pretrained(
    model_name, 
    num_labels=9, 
    skip_last_layer=True
)
model.roberta.embeddings.requires_grad_(False)  # freeze input embeddings to avoid parameter shift (training on english and inferencing on africaans -> different tokens are activated)
print(f"Percentage of frozen modules: {100 * sum(1 for module in model.modules() if not any(p.requires_grad for p in module.parameters())) / sum(1 for module in model.modules()):.2f}%")
print(f"Percentage of frozen parameters: {100 * sum(p.numel() for p in model.parameters() if not p.requires_grad) / sum(p.numel() for p in model.parameters()):.2f}%")


tok = AutoTokenizer.from_pretrained(model_name)


trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=run_name,
        overwrite_output_dir=True,
        eval_strategy='steps',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=128,
        learning_rate=lr,
        max_steps=steps,
        lr_scheduler_type='cosine_with_min_lr',
        lr_scheduler_kwargs={ 'num_cycles': 0.5, 'min_lr_rate': 0.01 },
        warmup_ratio=0.1,
        adam_epsilon=1e-8,
        adam_beta1=0.9,
        adam_beta2=0.999,
        weight_decay=0.0,
        logging_steps=100,
        eval_steps=200,
        torch_compile=False,
        bf16=True,
        include_num_input_tokens_seen=True,
        disable_tqdm=True,
        report_to='wandb'
    ),
    data_collator=partial(collate_fn, pad_token=tok.pad_token_id),
    train_dataset=Dataset(load_from_disk('data/train'), tokenizer_name=tok_name),
    eval_dataset=Dataset(load_from_disk('data/valid'), tokenizer_name=tok_name),
    compute_metrics=compute_ner_metrics
)
trainer.train()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112837596899933, max=1.0…

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at facebook/xlm-v-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Percentage of frozen modules: 24.29%
Percentage of frozen parameters: 89.88%


max_steps is given, it will override any value given in num_train_epochs


{'loss': 1.3923, 'grad_norm': 5.5005974769592285, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.11389521640091116, 'num_input_tokens_seen': 76672}
{'loss': 0.2907, 'grad_norm': 2.5719411373138428, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.22779043280182232, 'num_input_tokens_seen': 158272}
{'eval_loss': 0.1464698612689972, 'eval_token_precision_macro': 0.8515416048410184, 'eval_token_recall_macro': 0.7551059357327375, 'eval_token_f1_macro': 0.7532563089953259, 'eval_overall_precision': 0.7416327807319362, 'eval_overall_recall': 0.7971087577744159, 'eval_overall_f1': 0.7683707364498097, 'eval_PER_precision': 0.9471153846153846, 'eval_PER_recall': 0.9625407166123778, 'eval_PER_f1': 0.9547657512116317, 'eval_LOC_precision': 0.8446215139442231, 'eval_LOC_recall': 0.9232444202504083, 'eval_LOC_f1': 0.882184655396619, 'eval_ORG_precision': 0.605296343001261, 'eval_ORG_recall': 0.7142857142857143, 'eval_ORG_f1': 0.6552901023890785, 'eval_MISC_precision': 0.33728448275862066, 'e

TrainOutput(global_step=2500, training_loss=0.11646024713516236, metrics={'train_runtime': 150.9191, 'train_samples_per_second': 265.043, 'train_steps_per_second': 16.565, 'train_loss': 0.11646024713516236, 'epoch': 2.847380410022779, 'num_input_tokens_seen': 1950135})

In [20]:
test_results_l11 = trainer.evaluate(Dataset(load_from_disk('data/test'), tokenizer_name=tok_name), metric_key_prefix='transfer')
test_results_l11

{'transfer_loss': 0.41569140553474426, 'transfer_token_precision_macro': 0.7318125169933103, 'transfer_token_recall_macro': 0.6453559423396728, 'transfer_token_f1_macro': 0.6507366667293208, 'transfer_overall_precision': 0.5066474208007293, 'transfer_overall_recall': 0.46438270315437646, 'transfer_overall_f1': 0.4845952623165238, 'transfer_PER_precision': 0.44285714285714284, 'transfer_PER_recall': 0.36063984488608825, 'transfer_PER_f1': 0.3975420785466204, 'transfer_LOC_precision': 0.7189163038219643, 'transfer_LOC_recall': 0.8142465753424658, 'transfer_LOC_f1': 0.7636176772867421, 'transfer_ORG_precision': 0.6208838821490468, 'transfer_ORG_recall': 0.799442119944212, 'transfer_ORG_f1': 0.6989391537617364, 'transfer_MISC_precision': 0.3277083333333333, 'transfer_MISC_recall': 0.22836817653890826, 'transfer_MISC_f1': 0.2691649555099247, 'transfer_runtime': 27.5571, 'transfer_samples_per_second': 325.215, 'transfer_steps_per_second': 2.576, 'epoch': 2.847380410022779, 'num_input_tokens_

{'transfer_loss': 0.41569140553474426,
 'transfer_token_precision_macro': 0.7318125169933103,
 'transfer_token_recall_macro': 0.6453559423396728,
 'transfer_token_f1_macro': 0.6507366667293208,
 'transfer_overall_precision': 0.5066474208007293,
 'transfer_overall_recall': 0.46438270315437646,
 'transfer_overall_f1': 0.4845952623165238,
 'transfer_PER_precision': 0.44285714285714284,
 'transfer_PER_recall': 0.36063984488608825,
 'transfer_PER_f1': 0.3975420785466204,
 'transfer_LOC_precision': 0.7189163038219643,
 'transfer_LOC_recall': 0.8142465753424658,
 'transfer_LOC_f1': 0.7636176772867421,
 'transfer_ORG_precision': 0.6208838821490468,
 'transfer_ORG_recall': 0.799442119944212,
 'transfer_ORG_f1': 0.6989391537617364,
 'transfer_MISC_precision': 0.3277083333333333,
 'transfer_MISC_recall': 0.22836817653890826,
 'transfer_MISC_f1': 0.2691649555099247,
 'transfer_runtime': 27.5571,
 'transfer_samples_per_second': 325.215,
 'transfer_steps_per_second': 2.576,
 'epoch': 2.8473804100227

In [21]:
wandb.finish()

VBox(children=(Label(value='0.021 MB of 0.021 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/LOC_f1,▁▄▄▅▇▇▇█████
eval/LOC_precision,▁▄▅▆▇▇▇█████
eval/LOC_recall,▁▂▂▃▇▆▆█▇▇▇▇
eval/MISC_f1,▁▆▇▇████████
eval/MISC_precision,▁▆▇▇████████
eval/MISC_recall,▁▇▇▇████████
eval/ORG_f1,▁▆▇▇▇███████
eval/ORG_precision,▁▆▆▇▇██▇████
eval/ORG_recall,▁▇▇████▇████
eval/PER_f1,▁▄▄▁▆▅▇▆▇█▇▇

0,1
eval/LOC_f1,0.96628
eval/LOC_precision,0.95775
eval/LOC_recall,0.97496
eval/MISC_f1,0.88407
eval/MISC_precision,0.86708
eval/MISC_recall,0.90173
eval/ORG_f1,0.9174
eval/ORG_precision,0.90936
eval/ORG_recall,0.9256
eval/PER_f1,0.97863
