# Comma Restoration with Encoder-Decoder Models (Seq2Seq)
This notebook fine-tunes sequence-to-sequence transformers to restore commas by generating the punctuated sentence from an unpunctuated input. During training, commas are removed from the source text and the target is the original, correctly punctuated sentence. At inference, the model takes comma-stripped text and generates the full sentence with commas inserted, so no token-level labels are needed.

Models to try:
- https://huggingface.co/google/mt5-small
- https://huggingface.co/google/byt5-small

# Prepare environment

In [1]:
!python -V
!pip -V
!pip install numpy transformers[torch] scikit-learn datasets torch tiktoken blobfile protobuf sentencepiece wandb

Python 3.12.11
pip 24.1.2 from /usr/local/lib/python3.12/dist-packages/pip (python 3.12)


In [2]:
import torch
if torch.cuda.is_available():
    print('CUDA device:', torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0), 'bf16', torch.cuda.is_bf16_supported(False))
    free_mem, total_mem = torch.cuda.mem_get_info(torch.device('cuda:0'))
    print(f'Memory: {free_mem / 1024 ** 2:.2f} MB free / {total_mem / 1024 ** 2:.2f} MB total')
else:
    print('No CUDA device available')

CUDA device: NVIDIA L4 (8, 9) bf16 True
Memory: 22503.38 MB free / 22692.88 MB total


In [3]:
import difflib
import re

import numpy as np
import requests
import torch
from datasets import load_dataset
from contextlib import nullcontext
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    PreTrainedTokenizerBase,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)
import wandb

In [4]:
# Authenticate with Weights & Biases to enable logging and experiment tracking
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33martursz[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Prepare dataset
Raw sentences from the Latvian Universal Dependencies (LVTB) corpus: https://universaldependencies.org/treebanks/lv_lvtb/index.html

In [5]:
def fetch_ud_texts(split, seed=42):
    conllu = requests.get(f'https://raw.githubusercontent.com/UniversalDependencies/UD_Latvian-LVTB/r2.16/lv_lvtb-ud-{split}.conllu').text
    texts = [line[9:].strip() for line in conllu.splitlines() if line.startswith('# text = ')]
    if seed:
        import random
        random.Random(seed).shuffle(texts)
    return texts

def prepare_data(max_chars=200, dev_txt='dev.txt', train_txt='train.txt'):
    # Download UD Latvian splits, drop sentences longer than max_chars characters, and save plain .txt files.
    dev_texts = fetch_ud_texts('dev')
    train_texts = fetch_ud_texts('train')

    if max_chars:
        # Filter out long sentences to avoid truncation
        print('Sentence lengths before filtering:', 'DEV', len(dev_texts), 'TRAIN', len(train_texts))
        dev_texts = [t for t in dev_texts if len(t) <= max_chars]
        train_texts = [t for t in train_texts if len(t) <= max_chars]
    print('Dataset sentence lengths:', 'DEV', len(dev_texts), 'TRAIN', len(train_texts))

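    # Write UTF-8 explicitly so the files match what load_dataset('text', ...) reads by default
    with open(dev_txt, 'w', encoding='utf-8') as f:
        for t in dev_texts: f.write(t + '\n')
    with open(train_txt, 'w', encoding='utf-8') as f:
        for t in train_texts: f.write(t + '\n')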

    return dev_texts, train_texts

dev_texts, train_texts = prepare_data()
print(*train_texts[:5], sep='\n')

Sentence lengths before filtering: DEV 2080 TRAIN 15055
Dataset sentence lengths: DEV 1912 TRAIN 13811
To tu man stāstīji jau pirms divām nedēļām.
Ka pieticis tikai autobusa biļetei un barankām.
Uz skatuves kāpa skolas koris, pēc tam uzstājās arī dramatiskā pulciņa dalībnieki un divi bērnudārza audzēkņi.
Burka esot jāizdekorē ar dillēm, mārrutku lapu, upeņu zariņu un ķiploka pusdaiviņām, jāsaliek gurķīši un jāaplej ar verdošu ūdeni, kurā iebērta ēdamkarote cukura un ēdamkarote sāls.
Izaugs sava raža, nevajadzēs lieku reizi braukt uz tirgu.
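
For orientation, `fetch_ud_texts` relies on the `# text = ...` comment line that precedes each sentence in a CoNLL-U file. The sketch below applies the same filter to a tiny inline fragment. The `sent_id` values and token rows are made up for illustration and are not copied from the treebank; `extract_texts` is a hypothetical stand-in for the parsing step only.

```python
# Illustrative CoNLL-U fragment: sent_id values and token rows are hypothetical.
SAMPLE_CONLLU = """\
# sent_id = example-1
# text = To tu man stāstīji jau pirms divām nedēļām.
1\tTo\t_\t_\t_\t_\t_\t_\t_\t_
2\ttu\t_\t_\t_\t_\t_\t_\t_\t_

# sent_id = example-2
# text = Izaugs sava raža, nevajadzēs lieku reizi braukt uz tirgu.
1\tIzaugs\t_\t_\t_\t_\t_\t_\t_\t_
"""

PREFIX = '# text = '


def extract_texts(conllu: str) -> list[str]:
    # Keep only the sentence-level comment lines and drop the prefix.
    return [
        line[len(PREFIX):].strip()
        for line in conllu.splitlines()
        if line.startswith(PREFIX)
    ]


texts = extract_texts(SAMPLE_CONLLU)
print(*texts, sep='\n')
```

Using `len(PREFIX)` instead of the literal `9` keeps the slice in sync with the prefix if it ever changes.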


# Utilities

In [6]:
def remove_commas(s) -> str:
    return re.sub(r'\s*,+\s*', ' ', s)
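
One side effect worth knowing: `remove_commas` also strips decimal commas, so `3,2%` becomes `3 2%` in the model input. If that is not wanted, a variant could protect commas that sit directly between two digits. The sketch below is only one possible approach; `remove_commas_keep_decimals` is a hypothetical helper, and it would also keep commas in digit lists such as `1,2`, which may or may not be desirable.

```python
import re


def remove_commas(s: str) -> str:
    # Same behaviour as the notebook helper: every comma run becomes one space.
    return re.sub(r'\s*,+\s*', ' ', s)


# A comma with a digit immediately on both sides (e.g. "3,2").
DECIMAL_COMMA = re.compile(r'(?<=\d),(?=\d)')


def remove_commas_keep_decimals(s: str) -> str:
    # Hypothetical variant: hide decimal commas behind a placeholder,
    # strip the remaining commas, then restore the protected ones.
    protected = DECIMAL_COMMA.sub('\x00', s)
    stripped = re.sub(r'\s*,+\s*', ' ', protected)
    return stripped.replace('\x00', ',')


sample = 'sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās'
print(remove_commas(sample))
print(remove_commas_keep_decimals(sample))
```

If such a variant were adopted, the same function would have to be used at training and inference time so the model sees consistent inputs.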

# Tokenization

In [7]:
def test_tokenization(model=None):
    s = 'Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.'
    if model:
        print('Tokenizer', model)
        t = AutoTokenizer.from_pretrained(model)
        print('Encoded sample:', t(s))
        print('Encoded sample - subword units:', t.convert_ids_to_tokens(t.encode(s)))
        lengths = sorted([len(t.encode(seq)) for seq in train_texts])
        print(f'Max {max(lengths)}, min {min(lengths)}, avg {sum(lengths)/len(lengths)}')
        print(f'95% length: {lengths[int(len(lengths) * 0.95)]}')
        print(f'99% length: {lengths[int(len(lengths) * 0.99)]}')
        print(f'99.9% length: {lengths[int(len(lengths) * 0.999)]}')

test_tokenization('google/mt5-small')
test_tokenization('google/byt5-small')

Tokenizer google/mt5-small


Encoded sample: {'input_ids': [434, 35079, 259, 11859, 29953, 44897, 261, 427, 3546, 1184, 50856, 42786, 2849, 46813, 1460, 261, 2045, 381, 106373, 344, 135878, 5861, 58648, 11537, 260, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Encoded sample - subword units: ['▁V', 'ēl', '▁', '9%', '▁sac', 'īja', ',', '▁ka', '▁nav', '▁iz', 'lēm', 'uši', '▁kā', '▁bals', 'ot', ',', '▁bet', '▁3', ',2%', '▁at', 'teic', 'ās', '▁atbild', 'ēt', '.', '</s>']
Max 80, min 2, avg 29.943740496705523
95% length: 58
99% length: 66
99.9% length: 73
Tokenizer google/byt5-small
Encoded sample: {'input_ids': [89, 199, 150, 111, 35, 60, 40, 35, 118, 100, 102, 199, 174, 109, 100, 47, 35, 110, 100, 35, 113, 100, 121, 35, 108, 125, 111, 199, 150, 112, 120, 200, 164, 108, 35, 110, 199, 132, 35, 101, 100, 111, 118, 114, 119, 47, 35, 101, 104, 119, 35, 54, 47, 53, 40, 35, 100, 119, 119, 104, 108, 102, 199, 132, 118, 35, 100, 119, 101, 108, 111, 103, 199, 150, 119, 49
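
The ByT5 ids above can be read directly as UTF-8 bytes: each byte value appears shifted by 3, and the sequence ends with `</s>` (id 1). The sketch below reproduces the ids without loading the tokenizer, which makes the length difference to mT5 easy to see. `byt5_ids` is a hypothetical helper written for this illustration, not a `transformers` function.

```python
def byt5_ids(text: str) -> list[int]:
    # Assumption checked against the output above: id = byte value + 3,
    # followed by the end-of-sequence id 1.
    return [b + 3 for b in text.encode('utf-8')] + [1]


sample = 'Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.'
ids = byt5_ids(sample)

print(ids[:8])
print(len(sample), 'characters ->', len(ids), 'ByT5 ids')
```

Because Latvian diacritics take two bytes each, a sentence produces more ByT5 ids than it has characters. With sentences of up to 200 characters in the dataset, the `max_len=80` default used in `main` would truncate longer sentences for ByT5, so a larger limit is probably needed for that model.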

Tokenize and format dataset for model training and evaluation

In [8]:
def build_dataset(*, tokenizer, train_file='train.txt', dev_file='dev.txt', train_samples=None, dev_samples=None, max_length=100):
    ds = load_dataset('text', data_files={'train': train_file, 'dev': dev_file})
    if train_samples:
        ds['train'] = ds['train'].take(train_samples)
    if dev_samples:
        ds['dev'] = ds['dev'].take(dev_samples)

    def _encode_examples(batch):
        targets = batch['text']
        sources = [remove_commas(t) for t in targets]
        enc_in = tokenizer(sources, max_length=max_length, truncation=True)
        enc_out = tokenizer(text_target=targets, max_length=max_length, truncation=True)
        enc_in['labels'] = enc_out['input_ids']
        return enc_in

    ds_encoded = ds.map(_encode_examples, batched=True, remove_columns=ds['train'].column_names)

    return ds_encoded

tok = AutoTokenizer.from_pretrained('google/mt5-small')
ds = build_dataset(tokenizer=tok, train_samples=10, dev_samples=10)
collator = DataCollatorForSeq2Seq(tok)
loader = DataLoader(ds['train'], batch_size=3, shuffle=False, collate_fn=collator)
batch = next(iter(loader))
print(batch)

Generating train split: 0 examples [00:00, ? examples/s]

Generating dev split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

{'input_ids': tensor([[   926,    719,    674,    259,  61180, 171881,   6168,   7602,    263,
           6562,   4769,    448,  16521,  28981,    282,    260,      1,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0],
        [  1362,   2533,  82736,    263,    259,  13212,  31541,    262,    837,
          10999, 124749,    335,    259,  34422, 159443,    260,      1,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0],
        [ 16389,    259, 119665,   4570,    259,  53065,    262,    259,  74755,
         116533,    263,    259,   7659,   2335,   1678, 142198,   5861,    259,
           4815, 102020,  51670,    259, 146473,  17825,    350,  64041,  49955,
            335,  94591,  2

# Metrics
*F1-score* (the harmonic mean of precision and recall), computed over commas only, gives an honest view of model quality:
  - Precision: when the model outputs a comma, is it right?
  - Recall: does the model catch most of the true commas?
  - F1: balances both, penalizing if one is much lower.

*Changes* - the share of sentences in which the model altered something other than commas and whitespace, which is **not desired**.
This highlights over-correction: even if the model achieves good precision/recall on commas, a high *Changes* value means it is altering sentences unnecessarily, reducing usability in practice.

*Exact* - the share of sentences reproduced exactly as in the reference.
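
As a small worked example, the three comma scores can be computed from raw counts of true positives, false positives, and false negatives. The sketch below uses the same formulas as `eval_commas`; `prf` is a hypothetical helper name, and the second set of counts is taken from the third-epoch row of the sample training run in this notebook.

```python
def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    # Precision: correct commas among all predicted commas.
    p = tp / (tp + fp) if (tp + fp) else 0.0
    # Recall: correct commas among all reference commas.
    r = tp / (tp + fn) if (tp + fn) else 0.0
    # F1: harmonic mean of the two; 0.0 when both are zero.
    f1 = 2 * p * r / (p + r) if (p + r) else 0.0
    return p, r, f1


# One misplaced comma: 3 correct, 1 spurious, 1 missed.
print(prf(3, 1, 1))

# Counts from the third epoch of the sample run: high precision, low recall.
print(prf(52, 3, 77))
```

The second case shows why F1 is used for checkpoint selection: precision alone (about 0.95) would hide the fact that more than half of the reference commas are missed.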

In [9]:
def align_and_count_commas(hyp: str, ref: str) -> tuple[int, int, int]:
    sm = difflib.SequenceMatcher(a=hyp, b=ref, autojunk=False)
    tp = fp = fn = 0
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            tp += hyp[i1:i2].count(',')
        else:
            fp += hyp[i1:i2].count(',')
            fn += ref[j1:j2].count(',')
    return tp, fp, fn


def eval_commas(refs: list[str], preds: list[str], verbose=False) -> dict[str, float]:
    verbose_changes_limit = 5
    tp = fp = fn = changes = exact = 0
    for hyp, ref in zip(preds, refs):
        tpp, fpp, fnn = align_and_count_commas(hyp, ref)
        tp += tpp; fp += fpp; fn += fnn
        if hyp == ref:
            exact += 1
        is_changed = re.sub(r'[\s,]', '', hyp) != re.sub(r'[\s,]', '', ref)
        if is_changed:
            changes += 1

        if verbose and verbose_changes_limit > 0 and is_changed:
            print('--- Changed')
            print('REF:', ref)
            print('OUT:', hyp)
            verbose_changes_limit -= 1

    p = tp / (tp + fp) if (tp + fp) else 0.0
    r = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) else 0.0
    return {
        'f1': f1, 'p': p, 'r': r,
        'changes': (changes / len(preds) if preds else 0.0),
        'exact':  (exact / len(preds) if preds else 0.0),
        'tp': tp, 'fp': fp, 'fn': fn,
    }


def compute_metrics(eval_preds, tokenizer, verbose=False):
    preds, labels = eval_preds
    pad_id = tokenizer.pad_token_id

    # Replace the ignore index (-100) with the pad id so both arrays can be decoded
    labels = np.where(labels != -100, labels, pad_id)
    preds = np.where(preds != -100, preds, pad_id)

    decoded_preds  = tokenizer.batch_decode(preds,  skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    return eval_commas(decoded_labels, decoded_preds, verbose=verbose)

print(eval_commas(
    ['Labi atpūšamies, draugi mīļie, un lai veiksmīga, sportiska, panākumiem bagāta mums visiem jaunā vasaras sezona!'],
    ['Labi atpūšamies, draugi, mīļie un lai veiksmīga, sportiska, panākumiem bagāta mums visiem jaunā vasaras sezona!'],
))
print(eval_commas(
    ['Labi atpūšamies, draugi mīļie, un lai veiksmīga, sportiska, panākumiem bagāta mums visiem jaunā vasaras sezona!'],
    ['Labi atpūšamies draugi mīļie un lai veiksmīgas, sportiska, panākumiem bagāta mums visiem jaunā vasaras sezona.'],
))

{'f1': 0.75, 'p': 0.75, 'r': 0.75, 'changes': 0.0, 'exact': 0.0, 'tp': 3, 'fp': 1, 'fn': 1}
{'f1': 0.6666666666666666, 'p': 1.0, 'r': 0.5, 'changes': 1.0, 'exact': 0.0, 'tp': 2, 'fp': 0, 'fn': 2}
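
A few boundary cases can help build trust in the character-level alignment. The sketch below repeats `align_and_count_commas` so that it runs on its own and checks three situations whose counts follow directly from how the opcodes are tallied; the short test strings are made up for illustration.

```python
import difflib


def align_and_count_commas(hyp: str, ref: str) -> tuple[int, int, int]:
    # Copy of the notebook function: commas inside 'equal' spans are true
    # positives, all other hypothesis/reference commas are fp/fn.
    sm = difflib.SequenceMatcher(a=hyp, b=ref, autojunk=False)
    tp = fp = fn = 0
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            tp += hyp[i1:i2].count(',')
        else:
            fp += hyp[i1:i2].count(',')
            fn += ref[j1:j2].count(',')
    return tp, fp, fn


# Identical strings: every comma is a true positive.
print(align_and_count_commas('a, b, c', 'a, b, c'))

# Hypothesis without commas: every reference comma is missed.
print(align_and_count_commas('a b c', 'a, b, c'))

# Reference without commas: every hypothesis comma is spurious.
print(align_and_count_commas('a, b', 'a b'))
```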


# Inference
Given plain text, we strip commas, tokenize the result, let the model generate the punctuated sentence, and decode the generated tokens back to text.

In [10]:
def process_text(text, model, tokenizer: PreTrainedTokenizerBase, max_len=120, verbose=True):
    model.eval()
    device = next(model.parameters()).device
    source = remove_commas(text)
    inputs = tokenizer([source], return_tensors='pt', truncation=True, max_length=max_len).to(device)
    with torch.no_grad():
        gen = model.generate(
            **inputs,
            max_new_tokens=max_len,
        )
    result = tokenizer.decode(gen[0], skip_special_tokens=True)
    if verbose:
        print(f'REF: {text}')
        print(f' IN: {source}')
        print(f'OUT: {result}')
    return result
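
Because generation is unconstrained, the output can differ from the input in more than commas, which is what the *Changes* metric counts. One possible safeguard, sketched below, is to accept the generated text only if it matches the input once commas and whitespace are ignored, and to fall back to the comma-less input otherwise. `accept_or_fallback` is a hypothetical helper and not part of the notebook; the comparison reuses the regular expression from `eval_commas`.

```python
import re


def strip_commas_and_space(s: str) -> str:
    # Same normalisation that eval_commas uses for the "changes" check.
    return re.sub(r'[\s,]', '', s)


def accept_or_fallback(source: str, generated: str) -> str:
    # Keep the model output only if commas/whitespace are the sole difference.
    if strip_commas_and_space(source) == strip_commas_and_space(generated):
        return generated
    return source


ok = accept_or_fallback(
    'Vēl 9% sacīja ka nav izlēmuši kā balsot bet 3 2% atteicās atbildēt.',
    'Vēl 9% sacīja, ka nav izlēmuši, kā balsot, bet 3 2% atteicās atbildēt.',
)
rejected = accept_or_fallback(
    'Polietilēna maisiņā tās ir sasmakušas un tāpat nav ēdamas.',
    'Lietilēna maisiņā tās ir sasmakušas un tāpat nav ēdamas.',
)
print(ok)
print(rejected)
```

The trade-off is that a rejected sentence is returned without any commas, so this guard trades recall for the guarantee that the wording is never altered.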

# Model fine-tuning
- Track loss curves, gradient norms, and evaluation metrics over time
- Use an appropriate optimizer and learning rate schedule (e.g., warmup + decay)
- Watch for overfitting (gap between train and eval performance)
- Adjust batch size, accumulation steps, or precision (fp16/bf16) if needed
- Save best checkpoints based on validation metric (e.g., F1)
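
To make the "warmup + decay" item concrete, the sketch below computes a linear warmup followed by a linear decay to zero. The step counts mirror the sample run in this notebook (1000 training sentences, batch size 8, 3 epochs, warmup ratio 0.05). The function name `linear_warmup_decay` and the rounding of the warmup step count are assumptions made for this illustration; the exact values used by the trainer may differ slightly.

```python
import math


def linear_warmup_decay(step: int, total_steps: int, warmup_steps: int, peak_lr: float) -> float:
    # Ramp linearly from 0 to peak_lr during warmup...
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    # ...then decay linearly from peak_lr to 0 at the last step.
    remaining = total_steps - step
    return peak_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


steps_per_epoch = math.ceil(1000 / 8)         # optimizer steps per epoch
total_steps = steps_per_epoch * 3             # three epochs
warmup_steps = math.ceil(0.05 * total_steps)  # assumption: rounded up

for step in (0, warmup_steps // 2, warmup_steps, total_steps // 2, total_steps):
    print(step, f'{linear_warmup_decay(step, total_steps, warmup_steps, 1e-3):.6f}')
```

With only a few hundred steps in total, the warmup phase is short, so a high peak learning rate such as `1e-3` is reached almost immediately.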

In [11]:
def main(
    name='punctuator',
    base_model='google/mt5-small',
    max_len=80,
    seed=42,
    verbose=True,
    lr=1e-3,
    bs=8,
    train_samples=None,
    dev_samples=100,
    epochs=3,
    report_wandb=True,
    wandb_group=None
):
    with wandb.init(project='punctuator', group=wandb_group, name=name) if report_wandb else nullcontext():
        print('Train:', locals())
        set_seed(seed)
        tokenizer = AutoTokenizer.from_pretrained(base_model)

        # Load dataset
        ds = build_dataset(tokenizer=tokenizer, train_samples=train_samples, dev_samples=dev_samples, max_length=max_len)

        # Initialize the base model for the sequence-to-sequence task
        model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
        model.config.use_cache = False

        # Define training hyperparameters
        training_args = Seq2SeqTrainingArguments(
            output_dir=name,
            report_to='wandb' if report_wandb else 'none',
            learning_rate=lr,
            per_device_train_batch_size=bs,
            per_device_eval_batch_size=bs,
            num_train_epochs=epochs,
            warmup_ratio=0.05,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            bf16=True,
            logging_steps=50,
            save_total_limit=1,
            save_strategy='epoch',
            eval_strategy='epoch',
            eval_accumulation_steps=1,
            load_best_model_at_end=True,
            metric_for_best_model='f1',
            greater_is_better=True,
            predict_with_generate=True,
            generation_max_length=max_len * 2,
        )

        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=ds['train'],
            eval_dataset=ds['dev'],
            processing_class=tokenizer,
            data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
            compute_metrics=lambda p: compute_metrics(p, tokenizer, verbose=verbose),
        )

        # Actual training
        trainer.train()
        trainer.save_model(name)
        tokenizer.save_pretrained(name)

        process_text('Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.', trainer.model, tokenizer, max_len=max_len)

main('mt5_punctuator_sample', train_samples=1000)

Train: {'name': 'mt5_punctuator_sample', 'base_model': 'google/mt5-small', 'max_len': 80, 'seed': 42, 'verbose': True, 'lr': 0.001, 'bs': 8, 'train_samples': 1000, 'dev_samples': 100, 'epochs': 3, 'report_wandb': True, 'wandb_group': None}


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

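| Epoch | Training Loss | Validation Loss | F1 | P | R | Changes | Exact | Tp | Fp | Fn |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1.2608 | 0.257187 | 0.43 | 0.605634 | 0.333333 | 0.39 | 0.3 | 43 | 28 | 86 |
| 2 | 0.3605 | 0.164941 | 0.547486 | 0.98 | 0.379845 | 0.02 | 0.5 | 49 | 1 | 80 |
| 3 | 0.2668 | 0.149681 | 0.565217 | 0.945455 | 0.403101 | 0.03 | 0.53 | 52 | 3 | 77 |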


--- Changed
REF: Polietilēna maisiņā tās ir sasmakušas un tāpat nav ēdamas.
OUT: Lietilēna maisiņā tās ir sasmakušas un tāpat nav ēdamas.
--- Changed
REF: 1997.g. Nīderlandē bezdarba līmenis bija zemāks nekā vidēji ES – nedaudz vairāk par 6% [Visser, Hemerijck, 9].
OUT: Neīderlandē bezdarba līmenis bija zemāks nekā vidēji ES – nedaudz vairāk par 6% [Visser Hemerijck 9].
--- Changed
REF: Tāpat Kuks vērš uzmanību uz rakstniekiem, kurus ir samaitājušas Apgaismības idejas, kuri cionismam piešķir Apgaismības neprāta daļu, izraujot Toru no ebreju apziņas.
OUT: Taspat Kuks vērš uzmanību uz rakstniekiem, kurus ir samaitājušas Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cionismam piešķir Apgaismības idejas, kuri cioni

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


REF: Vēl 9% sacīja, ka nav izlēmuši kā balsot, bet 3,2% atteicās atbildēt.
 IN: Vēl 9% sacīja ka nav izlēmuši kā balsot bet 3 2% atteicās atbildēt.
OUT: Vēl 9% sacīja, ka nav izlēmuši, kā balsot, bet 3 2% atteicās atbildēt.


Run history:
eval/changes,█▁▁
eval/exact,▁▇█
eval/f1,▁▇█
eval/fn,█▃▁
eval/fp,█▁▂
eval/loss,█▂▁
eval/p,▁█▇
eval/r,▁▆█
eval/runtime,█▁▁
eval/samples_per_second,▁██

Run summary:
eval/changes,0.03
eval/exact,0.53
eval/f1,0.56522
eval/fn,77
eval/fp,3
eval/loss,0.14968
eval/p,0.94545
eval/r,0.4031
eval/runtime,16.6819
eval/samples_per_second,5.995


# Inference

In [12]:
m = AutoModelForSeq2SeqLM.from_pretrained('mt5_punctuator_sample')
t = AutoTokenizer.from_pretrained('mt5_punctuator_sample')
process_text('Nogalināt nedrīkst, apžēlot!', m, t)
process_text('Palielināt izdevumus nedrīkst taupīt!', m, t)

REF: Nogalināt nedrīkst, apžēlot!
 IN: Nogalināt nedrīkst apžēlot!
OUT: Nogalināt nedrīkst apžēlot!
REF: Palielināt izdevumus nedrīkst taupīt!
 IN: Palielināt izdevumus nedrīkst taupīt!
OUT: Palielināt izdevumus nedrīkst taupīt!


'Palielināt izdevumus nedrīkst taupīt!'