# Comma Restoration with Token Classification using BERT
This notebook fine-tunes a transformer encoder (e.g., mBERT, LvBERT) to restore commas in text as a token classification task. Each token receives one of two labels: COMMA (a comma should follow this word) or O (no comma). At inference time, existing commas are stripped, labels are predicted, and the sentence is rebuilt by inserting commas after tokens predicted as COMMA.
Models to try:
- https://huggingface.co/google-bert/bert-base-multilingual-cased
- https://huggingface.co/AiLab-IMCS-UL/lvbert
- https://huggingface.co/FacebookAI/xlm-roberta-base
- https://huggingface.co/EMBEDDIA/litlat-bert
- https://huggingface.co/jhu-clsp/mmBERT-small

# Prepare environment

In [None]:
# Authenticate with Weights & Biases to enable logging and experiment tracking.
# Comment out the following lines if you don't want to use W&B.
!pip install wandb
import wandb
wandb.login()

In [None]:
# Check if a CUDA device is available
!pip install torch
import torch
if torch.cuda.is_available():
    print('CUDA device:', torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0), 'bf16', torch.cuda.is_bf16_supported(False))
    free_mem, total_mem = torch.cuda.mem_get_info(torch.device('cuda:0'))
    print(f'Memory: {free_mem / 1024 ** 2:.2f} MB free / {total_mem / 1024 ** 2:.2f} MB total')
else:
    print('No CUDA device available')

CUDA device: Tesla T4 (7, 5) bf16 False
Memory: 14992.12 MB free / 15095.06 MB total


In [None]:
!python -V
!pip -V
!pip install numpy transformers[torch] scikit-learn datasets wandb

Python 3.12.11
pip 24.1.2 from /usr/local/lib/python3.12/dist-packages/pip (python 3.12)


In [None]:
import json
import re
from contextlib import nullcontext

import numpy as np
import requests
import torch
from datasets import load_dataset
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
)
import wandb

# Prepare dataset
Raw sentences from the Latvian Treebank (LVTB) in Universal Dependencies, release 2.16: https://universaldependencies.org/treebanks/lv_lvtb/index.html
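`fetch_ud_texts` (next cell) keeps only the `# text = ` comment lines of each CoNLL-U sentence block, which hold the original, untokenized sentence. The same filtering can be checked offline on a small hand-written snippet. `SAMPLE` below is made up and abbreviated (real CoNLL-U token rows have ten tab-separated columns), and `extract_texts` is a hypothetical name.

```python
# Keep only the raw sentence text stored in '# text = ' comment lines of a CoNLL-U file.
TEXT_PREFIX = '# text = '

def extract_texts(conllu: str) -> list[str]:
    return [line[len(TEXT_PREFIX):].strip()
            for line in conllu.splitlines()
            if line.startswith(TEXT_PREFIX)]

# Hypothetical, abbreviated CoNLL-U snippet (token rows shortened to three columns).
SAMPLE = (
    '# sent_id = demo-1\n'
    '# text = Viens, divi.\n'
    '1\tViens\t_\n'
    '2\t,\t_\n'
    '3\tdivi\t_\n'
    '4\t.\t_\n'
    '\n'
    '# sent_id = demo-2\n'
    '# text = Izaugs sava raža.\n'
    '1\tIzaugs\t_\n'
)

texts = extract_texts(SAMPLE)
print(texts)  # ['Viens, divi.', 'Izaugs sava raža.']
```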

In [None]:
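def fetch_ud_texts(split, seed=42):
    # Download one LVTB split (UD release r2.16) and keep the raw sentence from each '# text = ' comment line
    url = f'https://raw.githubusercontent.com/UniversalDependencies/UD_Latvian-LVTB/r2.16/lv_lvtb-ud-{split}.conllu'
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    texts = [line[len('# text = '):].strip() for line in response.text.splitlines() if line.startswith('# text = ')]
    if seed is not None:
        # Deterministic shuffle, so that taking the first N sentences later gives a mixed sample
        import random
        random.Random(seed).shuffle(texts)
    return texts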

def prepare_data(max_chars=200, dev_txt='dev.txt', train_txt='train.txt'):
    # Download the UD Latvian splits, drop sentences longer than max_chars characters,
    # and save them as plain .txt files (one sentence per line).
    dev_texts = fetch_ud_texts('dev')
    train_texts = fetch_ud_texts('train')

    if max_chars:
        # Drop long sentences up front to keep tokenized sequences well below the model's length limit
        print('Sentence counts before filtering:', 'DEV', len(dev_texts), 'TRAIN', len(train_texts))
        dev_texts = [t for t in dev_texts if len(t) <= max_chars]
        train_texts = [t for t in train_texts if len(t) <= max_chars]
    print('Sentence counts in the dataset:', 'DEV', len(dev_texts), 'TRAIN', len(train_texts))

    with open(dev_txt, 'w', encoding='utf-8') as f:
        f.writelines(t + '\n' for t in dev_texts)
    with open(train_txt, 'w', encoding='utf-8') as f:
        f.writelines(t + '\n' for t in train_texts)

    return dev_texts, train_texts

dev_texts, train_texts = prepare_data()
print(*train_texts[:5], sep='\n')

Sentence counts before filtering: DEV 2080 TRAIN 15055
Sentence counts in the dataset: DEV 1912 TRAIN 13811
To tu man stāstīji jau pirms divām nedēļām.
Ka pieticis tikai autobusa biļetei un barankām.
Uz skatuves kāpa skolas koris, pēc tam uzstājās arī dramatiskā pulciņa dalībnieki un divi bērnudārza audzēkņi.
Burka esot jāizdekorē ar dillēm, mārrutku lapu, upeņu zariņu un ķiploka pusdaiviņām, jāsaliek gurķīši un jāaplej ar verdošu ūdeni, kurā iebērta ēdamkarote cukura un ēdamkarote sāls.
Izaugs sava raža, nevajadzēs lieku reizi braukt uz tirgu.


# Tokenization

In [None]:
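def tokenize(s):
    # Split a string into word and punctuation tokens. Each token keeps its leading whitespace,
    # so joining the tokens restores the original spacing (except trailing whitespace).
    return re.findall(r'\s*(?:\w+|\S)', s)

def tokenize_with_comma_labels(s):
    # Split into word/punctuation tokens (commas excluded) and label each token COMMA if one or more
    # commas follow it (possibly after whitespace), otherwise O.
    # Note: a decimal comma such as '3,2' is also labeled COMMA (see the output below).
    tokens_with_labels = re.findall(r'(\s*\w+|[^\s,])\s*(,+)?', s)
    tokens_with_labels = [(tok, 'COMMA' if comma else 'O') for tok, comma in tokens_with_labels]
    return tokens_with_labels

def remove_commas(s) -> str:
    # Replace every run of commas, together with the surrounding whitespace, by a single space
    return re.sub(r'\s*,+\s*', ' ', s)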

tokenize_with_comma_labels('Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.')

[('Vēl', 'O'),
 ('9', 'O'),
 ('%', 'O'),
 ('sacīja', 'COMMA'),
 (' ka', 'O'),
 ('nav', 'O'),
 ('izlēmuši', 'O'),
 ('kā', 'O'),
 ('balsot', 'COMMA'),
 (' bet', 'O'),
 ('3', 'COMMA'),
 ('2', 'O'),
 ('%', 'O'),
 ('atteicās', 'O'),
 ('atbildēt', 'O'),
 ('.', 'O')]
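A quick sanity check for the three helpers above: stripping commas, re-tokenizing with `tokenize`, and re-inserting commas from the gold labels should reproduce the original sentence. The block repeats the notebook's regexes so it runs on its own; `rebuild` and `round_trip` are hypothetical helpers that mirror the joining step in `process_text`.

```python
import re

def tokenize(s):
    return re.findall(r'\s*(?:\w+|\S)', s)

def tokenize_with_comma_labels(s):
    pairs = re.findall(r'(\s*\w+|[^\s,])\s*(,+)?', s)
    return [(tok, 'COMMA' if comma else 'O') for tok, comma in pairs]

def remove_commas(s):
    return re.sub(r'\s*,+\s*', ' ', s)

def rebuild(words, labels):
    # Append a comma after every word labeled COMMA; words already carry their leading whitespace.
    return ''.join(w + (',' if lab == 'COMMA' else '') for w, lab in zip(words, labels))

def round_trip(s):
    gold = [lab for _, lab in tokenize_with_comma_labels(s)]
    words = tokenize(remove_commas(s))
    assert len(words) == len(gold), (words, gold)  # both views must yield one label per word
    return rebuild(words, gold)

s = 'Izaugs sava raža, nevajadzēs lieku reizi braukt uz tirgu.'
print(round_trip(s) == s)       # True
print(round_trip('Jā, 3,2%.'))  # Jā, 3, 2%.  (the decimal comma comes back as a separate comma)
```

Note that `tokenize_with_comma_labels` keeps leading whitespace only on words that follow a comma (e.g. `' ka'` above), while `tokenize` keeps it on every word. BERT-style WordPiece tokenizers strip this whitespace, so both views should encode identically there; for other tokenizer types this is worth verifying.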

In [None]:
def test_tokenization(model=None):
    s = 'Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.'
    if model:
        print('Tokenizer stats', model)
        t = AutoTokenizer.from_pretrained(model)
        print('Encoded sample:', t(s))
        print('Encoded sample - subword units:', t.convert_ids_to_tokens(t.encode(s)))
        lengths = sorted([len(t.encode(seq)) for seq in train_texts])
        print(f'Max {max(lengths)}, min {min(lengths)}, avg {sum(lengths)/len(lengths)}')
        print(f'95% length: {lengths[int(len(lengths) * 0.95)]}')
        print(f'99% length: {lengths[int(len(lengths) * 0.99)]}')
        print(f'99.9% length: {lengths[int(len(lengths) * 0.999)]}')

test_tokenization('AiLab-IMCS-UL/lvbert')
test_tokenization('jhu-clsp/mmBERT-small')


LABELS = ['O', 'COMMA']
LABEL2ID = {name: i for i, name in enumerate(LABELS)}
ID2LABEL = {i: name for i, name in enumerate(LABELS)}

Tokenizer stats AiLab-IMCS-UL/lvbert



Encoded sample: {'input_ids': [2, 574, 684, 70, 417, 5, 16, 35, 29811, 24, 4622, 5, 27, 168, 5, 146, 70, 6862, 4850, 6, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Encoded sample - subword units: ['[CLS]', 'Vēl', '9', '%', 'sacīja', ',', 'ka', 'nav', 'izlēmuši', 'kā', 'balsot', ',', 'bet', '3', ',', '2', '%', 'atteicās', 'atbildēt', '.', '[SEP]']
Max 65, min 3, avg 21.21772500181015
95% length: 40
99% length: 48
99.9% length: 56
Tokenizer stats jhu-clsp/mmBERT-small
Encoded sample: {'input_ids': [2, 744, 229673, 235248, 235315, 235358, 6817, 236073, 1663, 235269, 5675, 5103, 9417, 135924, 2704, 27536, 52635, 70402, 562, 235269, 1285, 235248, 235304, 235269, 235284, 235358, 41643, 520, 28688, 696, 137369, 235265, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Encoded sample - subword unit

In [None]:
# Tokenize text into subwords and align word-level labels to the correct subword positions

def tokenize_and_align_labels(tokenizer, words, word_labels=None, label2id=None, debug=False, return_tensors=None):
    # Tokenize with word boundaries preserved
    enc = tokenizer(
        list(words),
        is_split_into_words=True,
        add_special_tokens=True,
        return_tensors=return_tensors,
        truncation=False
    )

    # Map each token back to its source word index
    word_ids = enc.word_ids()  # one per token position (None for specials)

    # Find the last subword of each word and assign the word label only there; all other positions get -100.
    # PyTorch's CrossEntropyLoss ignores index -100 by default, and the HF token-classification models rely on this, so the loss needs no changes.
    if word_labels is not None:
        IGNORE = -100
        labels = [IGNORE] * len(word_ids)
        for i, wid in enumerate(word_ids):
            if wid is None:
                continue
            next_wid = word_ids[i+1] if i+1 < len(word_ids) else None
            if wid != next_wid:
                # last subword of this word: assign the word label
                labels[i] = label2id[word_labels[wid]]
    else:
        labels = None

    if debug:
        input_ids = enc['input_ids']
        if return_tensors == 'pt':
            input_ids = input_ids.tolist()[0]
        print('WORDS:         ', words)
        print('WORD_LABELS:   ', word_labels)
        print('WORD_IDS:      ', word_ids)
        print('TOKEN_IDS:     ', input_ids)
        print('TOKENS:        ', tokenizer.convert_ids_to_tokens(input_ids))
        print('ALIGNED_LABELS:', labels)

    r = {
        'input_ids': enc['input_ids'],
        'attention_mask': enc['attention_mask'],
    }
    if labels is not None:
        if return_tensors == 'pt':
            labels = torch.tensor([labels], dtype=torch.long)
        r['labels'] = labels
    return r

print(tokenize_and_align_labels(AutoTokenizer.from_pretrained('AiLab-IMCS-UL/lvbert'), *zip(*tokenize_with_comma_labels('Viens, divi.')), LABEL2ID, debug=True))

WORDS:          ('Viens', ' divi', '.')
WORD_LABELS:    ('COMMA', 'O', 'O')
WORD_IDS:       [None, 0, 1, 2, None]
TOKEN_IDS:      [2, 1394, 516, 6, 3]
TOKENS:         ['[CLS]', 'Viens', 'divi', '.', '[SEP]']
ALIGNED_LABELS: [-100, 1, 0, 0, -100]
{'input_ids': [2, 1394, 516, 6, 3], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [-100, 1, 0, 0, -100]}
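The alignment rule inside `tokenize_and_align_labels` depends only on `word_ids`, not on the tokenizer itself. The sketch below isolates it as a hypothetical helper, `align_to_last_subword`, and runs it on a made-up split in which word 0 becomes two subwords.

```python
IGNORE = -100
LABEL2ID = {'O': 0, 'COMMA': 1}

def align_to_last_subword(word_ids, word_labels, label2id=LABEL2ID):
    """Put each word's label on its last subword; every other position gets -100."""
    labels = [IGNORE] * len(word_ids)
    for i, wid in enumerate(word_ids):
        if wid is None:  # special tokens such as [CLS] / [SEP]
            continue
        next_wid = word_ids[i + 1] if i + 1 < len(word_ids) else None
        if wid != next_wid:  # the next position belongs to another word (or is a special token)
            labels[i] = label2id[word_labels[wid]]
    return labels

# Hypothetical tokens: ['[CLS]', 'izlē', '##muši', 'kā', '[SEP]']
print(align_to_last_subword([None, 0, 0, 1, None], ['COMMA', 'O']))
# [-100, -100, 1, 0, -100]
```

Labeling the last subword fits this task because the comma follows the end of the word; NER-style examples often label the first subword instead. Either convention should work as long as training and inference (`process_text`) use the same one.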


# Tokenize and format dataset for model training and evaluation

In [None]:
def build_dataset(*, tokenizer, train_file='train.txt', dev_file='dev.txt', train_samples=None, dev_samples=None, max_length=100, label2id=None):
    ds = load_dataset('text', data_files={'train': train_file, 'dev': dev_file})
    if train_samples:
        ds['train'] = ds['train'].take(train_samples)
    if dev_samples:
        ds['dev'] = ds['dev'].take(dev_samples)

    def _map(example):
        words, word_labels = zip(*tokenize_with_comma_labels(example['text']))
        return tokenize_and_align_labels(tokenizer, words=words, word_labels=word_labels, label2id=label2id)

    ds_tokenized = ds.map(_map, remove_columns=ds['train'].column_names)

    if max_length is not None:
        ds_tokenized = ds_tokenized.filter(lambda ex: len(ex['input_ids']) <= max_length)

    return ds_tokenized

tok = AutoTokenizer.from_pretrained('AiLab-IMCS-UL/lvbert')
ds = build_dataset(tokenizer=tok, train_samples=2, dev_samples=2, label2id=LABEL2ID)
loader = DataLoader(ds['train'], batch_size=2, shuffle=False, collate_fn=DataCollatorForTokenClassification(tok))
batch = next(iter(loader))
print(batch)

{'input_ids': tensor([[    2,   317,   277,   100, 26927,    38,   134,  1516,  7068,     6,
             3,     0,     0,     0,     0],
        [    2,  1105,    41,  1464,    61,    55,  8697, 10471,    12,     8,
          1984,  1209,  4887,     6,     3]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[-100,    0,    0,    0,    0,    0,    0,    0,    0,    0, -100, -100,
         -100, -100, -100],
        [-100,    0, -100, -100,    0,    0,    0, -100,    0,    0, -100, -100,
            0,    0, -100]])}
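`DataCollatorForTokenClassification` pads each batch to its longest sequence: `input_ids` with the tokenizer's pad ID, `attention_mask` with 0, and `labels` with -100 so the padding is ignored by the loss, as the batch above shows. The function below is a simplified, hypothetical re-implementation of that behavior for plain lists (the real collator also handles tensors, `pad_to_multiple_of`, and the tokenizer's padding side). `pad_token_id=0` matches the padding visible above for LvBERT; other tokenizers expose their own `tokenizer.pad_token_id`.

```python
def pad_features(features, pad_token_id=0, label_pad_id=-100):
    """Right-pad a list of tokenized examples to the longest one in the batch."""
    max_len = max(len(f['input_ids']) for f in features)
    batch = {'input_ids': [], 'attention_mask': [], 'labels': []}
    for f in features:
        n_pad = max_len - len(f['input_ids'])
        batch['input_ids'].append(f['input_ids'] + [pad_token_id] * n_pad)
        batch['attention_mask'].append(f['attention_mask'] + [0] * n_pad)  # 0 = do not attend
        batch['labels'].append(f['labels'] + [label_pad_id] * n_pad)      # -100 = ignored by the loss
    return batch

features = [
    # 'Viens, divi.' as encoded by LvBERT earlier in the notebook
    {'input_ids': [2, 1394, 516, 6, 3], 'attention_mask': [1] * 5, 'labels': [-100, 1, 0, 0, -100]},
    # Hypothetical shorter example: '[CLS] Viens [SEP]'
    {'input_ids': [2, 1394, 3], 'attention_mask': [1] * 3, 'labels': [-100, 0, -100]},
]
batch = pad_features(features)
print(batch['input_ids'][1], batch['labels'][1])  # [2, 1394, 3, 0, 0] [-100, 0, -100, -100, -100]
```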


# Metrics for token classification
*Accuracy* can be misleading for imbalanced tasks:
  - In our data, most tokens are labeled O (no comma).
  - A trivial model that always predicts O would score an accuracy equal to the share of O tokens, which is high, without predicting a single comma.

The *F1-score* for the COMMA class (the harmonic mean of precision and recall) gives a more honest view of model quality:
  - Precision: when the model predicts COMMA, how often is it right?
  - Recall: what fraction of the true commas does the model find?
  - F1: balances the two and drops sharply if either one is low.
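A toy illustration of the point above, using the same scikit-learn calls as `compute_metrics_fn`. The 5% comma rate is made up for the example: the trivial predictor scores 95% accuracy but an F1 of 0 for COMMA.

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Toy reference labels: 19 O tokens (0) and one COMMA token (1).
y_ref = [0] * 19 + [1]
# A trivial baseline that never predicts a comma.
y_pred = [0] * 20

acc = accuracy_score(y_ref, y_pred)
p, r, f1, _ = precision_recall_fscore_support(
    y_ref, y_pred, average='binary', pos_label=1, zero_division=0,
)
print(f'acc={acc:.2f} p={p:.2f} r={r:.2f} f1={f1:.2f}')  # acc=0.95 p=0.00 r=0.00 f1=0.00
```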

In [None]:
def compute_metrics_fn(p):
    # Model outputs: shape [batch_size, seq_len, num_labels]
    # -> pick the most likely label for each token
    preds = np.argmax(p.predictions, axis=-1)

    # True labels: shape [batch_size, seq_len]
    labels = p.label_ids

    # Flatten but skip positions marked with -100
    y_ref = []
    y_pred = []
    for ref_seq, pred_seq in zip(labels, preds):
        for t, p_ in zip(ref_seq, pred_seq):
            if t == -100:
                continue
            y_ref.append(t)
            y_pred.append(p_)

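    precision, recall, f1, _ = precision_recall_fscore_support(
        y_ref, y_pred,
        average='binary', pos_label=1,  # binary task: COMMA (1) is the positive class
        zero_division=0,  # report 0 instead of warning when no COMMA is predicted (e.g. early in training)
        # With more label types, pass labels=[...] without O together with average='micro', or use average='macro'
    )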
    acc = accuracy_score(y_ref, y_pred)
    return {
        'f1': f1,
        'p': precision,
        'r': recall,
        'acc': acc,
    }
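The nested loop in `compute_metrics_fn` can also be written with a NumPy boolean mask, which yields the same flattened arrays (in row-major order) and is usually faster on large evaluation sets. `flatten_valid` is a hypothetical helper and the inputs are made up. Since the Trainer pads labels with -100 when it concatenates evaluation batches of different lengths, the same mask should also skip that padding.

```python
import numpy as np

def flatten_valid(label_ids, logits, ignore_index=-100):
    """Return 1-D reference and predicted labels, skipping positions marked ignore_index."""
    preds = np.argmax(logits, axis=-1)   # [batch, seq_len]
    mask = label_ids != ignore_index     # True where a real word label exists
    return label_ids[mask], preds[mask]

labels = np.array([[-100, 1, 0, -100],
                   [-100, 0, -100, -100]])
logits = np.array([[[0.0, 0.0], [0.1, 0.9], [0.8, 0.2], [0.0, 0.0]],
                   [[0.0, 0.0], [0.3, 0.7], [0.0, 0.0], [0.0, 0.0]]])

y_ref, y_pred = flatten_valid(labels, logits)
print(y_ref.tolist(), y_pred.tolist())  # [1, 0, 0] [1, 0, 1]
```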

# Inference
Given plain text, we strip commas, tokenize with word boundaries, run the model, and insert commas after tokens labeled COMMA.

In [None]:
def process_text(text, model, tokenizer, verbose=True):
    # Preprocess: remove commas and split into words (each word keeps its leading whitespace)
    input_text = remove_commas(text)
    words = tokenize(input_text)

    # Tokenize into subwords while keeping the word alignment.
    # No truncation: inputs longer than the model's maximum sequence length will fail.
    enc = tokenizer(
        words,
        is_split_into_words=True,
        add_special_tokens=True,
        return_tensors='pt',
        truncation=False
    )
    word_ids = enc.word_ids()
    # Move the inputs to the model's device
    device = next(model.parameters()).device
    enc = {k: v.to(device) for k, v in enc.items()}

    # Forward pass
    model.eval()
    with torch.no_grad():
        logits = model(**enc).logits  # [1, seq_len, num_labels]
        pred_ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()

    # Collapse subwords: each word takes the label predicted for its last subword
    word_preds = {}
    for i, wid in enumerate(word_ids):
        if wid is None:  # skip [CLS], [SEP], etc.
            continue
        next_wid = word_ids[i + 1] if i + 1 < len(word_ids) else None
        if wid != next_wid:  # last subword of the word
            word_preds[wid] = pred_ids[i]

    # Word-level labels; a word that produced no subwords (rare) has no prediction and defaults to O
    results = [(w, model.config.id2label[word_preds[i]] if i in word_preds else 'O') for i, w in enumerate(words)]
    # Rebuild the text, appending a comma after every word labeled COMMA
    output_text = ''.join([w + (',' if label == 'COMMA' else '') for w, label in results])

    if verbose:
        print(f'REF: {text}')
        print(f' IN: {input_text}')
        print(f'OUT: {output_text}')
    return output_text

# Model fine-tuning
- Track loss curves, gradient norms, and evaluation metrics over time
- Use an appropriate optimizer and learning rate schedule (e.g., warmup + decay)
- Watch for overfitting (gap between train and eval performance)
- Adjust batch size, accumulation steps, or precision (fp16/bf16) if needed
- Save best checkpoints based on validation metric (e.g., F1)
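`main()` below sets `warmup_ratio=0.05` and leaves the scheduler at the Trainer default (`lr_scheduler_type='linear'`): the learning rate ramps up linearly during warmup, then decays linearly to zero. The plain-Python function below is a sketch of that shape for plotting or sanity checks; its name is ours, and rounding the warmup steps up with `math.ceil` is an assumption that may differ across transformers versions.

```python
import math

def linear_warmup_decay_lr(step, total_steps, warmup_ratio, peak_lr):
    """Learning rate at `step` for linear warmup to peak_lr followed by linear decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

for step in (0, 5, 50, 100):
    print(step, linear_warmup_decay_lr(step, total_steps=100, warmup_ratio=0.05, peak_lr=5e-6))
```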

In [None]:
def main(
    name,
    base_model='AiLab-IMCS-UL/lvbert',
    max_len=100,
    seed=42,
    verbose=True,
    lr=5e-6,
    bs=32,
    train_samples=None,
    dev_samples=None,
    epochs=3,
    report_wandb=True,
    wandb_group=None,
    save=True,
):
    if report_wandb and not wandb.api.api_key:
        print('Not authenticated with W&B')
        report_wandb = False

    with wandb.init(project='punctuator', group=wandb_group, name=name) if report_wandb else nullcontext():
        print('Train:', locals())
        set_seed(seed)
        tokenizer = AutoTokenizer.from_pretrained(base_model)

        ds = build_dataset(tokenizer=tokenizer, train_samples=train_samples, dev_samples=dev_samples, max_length=max_len, label2id=LABEL2ID)

        # Initialize base model for token classification task
        model = AutoModelForTokenClassification.from_pretrained(base_model, num_labels=len(LABELS), id2label=ID2LABEL, label2id=LABEL2ID)
        model.config.use_cache = False

        # Define training hyperparameters
        training_args = TrainingArguments(
            output_dir=name,
            learning_rate=lr,
            per_device_train_batch_size=bs,
            per_device_eval_batch_size=bs,
            num_train_epochs=epochs,
            eval_strategy='epoch',
            save_strategy='epoch' if save else 'no',
            load_best_model_at_end=save,
            metric_for_best_model='f1',
            greater_is_better=True,
            warmup_ratio=0.05,
            gradient_accumulation_steps=1,
            fp16=torch.cuda.is_available(),  # fp16 mixed precision needs a CUDA GPU; the T4 above reports no bf16 support

            logging_steps=20,
            report_to='wandb' if report_wandb else 'none',
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=ds['train'],
            eval_dataset=ds['dev'],
            processing_class=tokenizer,
            data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
            compute_metrics=compute_metrics_fn,
        )

        # Actual training
        trainer.train()
        if save:
            trainer.save_model(name)
            tokenizer.save_pretrained(name)

        process_text('Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.', trainer.model, tokenizer)

In [None]:
# Use only 3000 samples for training to run a quick experiment
main('bert_punctuator_sample', train_samples=3000)

Train: {'name': 'bert_punctuator_sample', 'base_model': 'AiLab-IMCS-UL/lvbert', 'max_len': 100, 'seed': 42, 'verbose': True, 'lr': 5e-06, 'bs': 32, 'train_samples': 3000, 'dev_samples': None, 'epochs': 3, 'report_wandb': True, 'wandb_group': None, 'save': True}


Some weights of BertForTokenClassification were not initialized from the model checkpoint at AiLab-IMCS-UL/lvbert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


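| Epoch | Training Loss | Validation Loss | F1 | Precision (P) | Recall (R) | Accuracy |
|---|---|---|---|---|---|---|
| 1 | 0.1671 | 0.131554 | 0.679846 | 0.890007 | 0.549978 | 0.95327 |
| 2 | 0.0877 | 0.095823 | 0.784766 | 0.9 | 0.695691 | 0.965574 |
| 3 | 0.0754 | 0.089201 | 0.804215 | 0.896721 | 0.729009 | 0.967979 |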
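As a sanity check, F1 is the harmonic mean of precision (P) and recall (R). Plugging in the epoch-3 values from the results above:

$$
F_1 = \frac{2PR}{P + R} = \frac{2 \cdot 0.896721 \cdot 0.729009}{0.896721 + 0.729009} \approx 0.8042
$$

This matches the reported 0.804215 up to the rounding of P and R.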


REF: Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.
 IN: Vēl 9% sacīja ka nav izlēmuši kā balsot bet 3 2% atteicās atbildēt.
OUT: Vēl 9% sacīja, ka nav izlēmuši, kā balsot, bet 3, 2% atteicās atbildēt.


W&B run history: sparkline charts for eval/acc, eval/f1, eval/loss, eval/p, eval/r, eval/runtime, eval/samples_per_second, eval/steps_per_second, train/epoch and train/global_step (not reproducible as text).

W&B run summary:

| Metric | Value |
|---|---|
| eval/acc | 0.96798 |
| eval/f1 | 0.80421 |
| eval/loss | 0.0892 |
| eval/p | 0.89672 |
| eval/r | 0.72901 |
| eval/runtime | 1.3395 |
| eval/samples_per_second | 1427.427 |
| eval/steps_per_second | 44.794 |
| total_flos | 191873474680704.0 |
| train/epoch | 3 |


# Inference with the saved model

In [None]:
m = AutoModelForTokenClassification.from_pretrained('bert_punctuator_sample')
t = AutoTokenizer.from_pretrained('bert_punctuator_sample')
process_text('Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.', m, t)
process_text('Nogalināt nedrīkst, apžēlot!', m, t)
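The sample prediction printed after training rebuilds `3,2%` as `3, 2%`: `tokenize_with_comma_labels` labels the decimal comma in `3,2` as COMMA, so the model is trained to predict commas inside numbers. One possible fix, sketched below as hypothetical `*_v2` variants that are not part of this notebook, is to keep digit-comma-digit sequences as single tokens both when building labels and when stripping commas. Using them would require rebuilding the dataset and retraining, and alphanumeric cases such as `a3,2` are not handled.

```python
import re

# Hypothetical variant: a digit-comma-digit sequence such as '3,2' is one number token.
NUMBER = r'\d+(?:,\d+)+'

def tokenize_v2(s):
    return re.findall(rf'\s*(?:{NUMBER}|\w+|\S)', s)

def tokenize_with_comma_labels_v2(s):
    pairs = re.findall(rf'(\s*(?:{NUMBER}|\w+)|[^\s,])\s*(,+)?', s)
    return [(tok, 'COMMA' if comma else 'O') for tok, comma in pairs]

def remove_commas_v2(s):
    def repl(m):
        start, end = m.start(), m.end()
        between_digits = (m.group(0) == ','
                          and start > 0 and s[start - 1].isdigit()
                          and end < len(s) and s[end].isdigit())
        return ',' if between_digits else ' '  # keep decimal commas, replace all others by a space
    return re.sub(r'\s*,+\s*', repl, s)

s = 'Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.'
stripped = remove_commas_v2(s)
print(stripped)  # Vēl 9% sacīja ka nav izlēmuši kā balsot bet 3,2% atteicās atbildēt.
print(tokenize_with_comma_labels_v2(s)[9:12])  # [(' bet', 'O'), ('3,2', 'O'), ('%', 'O')]
```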

# Hyperparameter optimization
- Use small experiments (1 epoch, limited training and evaluation samples) for fast iteration while testing setups
- Try random or Bayesian search for hyperparameter tuning
- Scale up once the pipeline works end-to-end
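The cell below runs a grid over learning rates spaced roughly evenly on a log scale. A random-search alternative draws learning rates log-uniformly; `sample_log_uniform` is a hypothetical helper, and the bounds below simply mirror the grid's range.

```python
import math
import random

def sample_log_uniform(low, high, n, seed=42):
    """Draw n values whose log10 is uniformly distributed on [log10(low), log10(high)]."""
    rng = random.Random(seed)
    lo, hi = math.log10(low), math.log10(high)
    return sorted(10 ** rng.uniform(lo, hi) for _ in range(n))

lrs = sample_log_uniform(1e-6, 3e-3, n=8)
print([f'{lr:.2e}' for lr in lrs])
# Each value could then be passed to main(), for example:
# for lr in lrs:
#     main(f'bert_rs_lr{lr:.2e}', lr=lr, train_samples=3000, dev_samples=100, epochs=1, wandb_group='bert_random_search', save=False)
```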

In [None]:
for lr in [1e-6, 3e-6, 1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3]:
    main(f'bert_sweep_lr{lr:.2e}', lr=lr, train_samples=3000, dev_samples=100, epochs=1, wandb_group='bert_sweep', save=False)