# mBERT Baseline for Kyrgyz Punctuation Restoration

This notebook trains `bert-base-multilingual-cased` on the Kyrgyz punctuation restoration task as a baseline comparison for the XLM-RoBERTa model.

**Before running:** Go to Runtime → Change runtime type → select **T4 GPU**

## 1. Install dependencies

In [None]:
!pip install -q transformers datasets seqeval accelerate scikit-learn

## 2. Upload dataset

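Upload your `train_data.json` file when prompted (the upload step is skipped if the file is already present). The file must be a JSON list of records whose `text` field is a list with the punctuated sentence as its first element.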

In [None]:
from google.colab import files
import json
import os

# Upload train_data.json (skipped when the file is already in the working dir).
DATA_PATH = 'train_data.json'

if os.path.exists(DATA_PATH):
    print("File already exists.")
else:
    uploaded = files.upload()
    # files.upload() returns {filename: bytes}; fail early and clearly if the
    # user picked a differently named file instead of crashing in open() below.
    if DATA_PATH not in uploaded:
        raise FileNotFoundError(
            f"Expected an upload named {DATA_PATH!r}, got: {sorted(uploaded)}"
        )
    print("File uploaded!")

# The corpus is Kyrgyz (Cyrillic) text; read it as UTF-8 explicitly instead of
# relying on the platform's default encoding.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

print(f"Total sentences: {len(raw_data)}")

## 3. Prepare dataset for token classification

In [None]:
import re
from collections import Counter

# Label vocabulary: 'O' = no punctuation after the word, otherwise the mark
# that follows it. ID2LABEL is the inverse mapping used to decode predictions.
LABEL2ID = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3}
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))
NUM_LABELS = len(ID2LABEL)

def extract_tokens_and_labels(text):
    """Split punctuated text into bare words and per-word punctuation labels.

    Each whitespace-separated token has its trailing punctuation, quote and
    bracket characters stripped to form the word. The label is decided by the
    first of '.', ',' or '?' found in that stripped tail ('PERIOD', 'COMMA',
    'QUESTION'); a tail containing none of them (e.g. only '!') yields 'O'.
    Tokens made up entirely of strippable characters are dropped.

    Returns:
        (words, labels): two parallel lists of equal length.
    """
    strip_chars = '.,?!;:\'\"\u201c\u201d\u2013\u2014\u2026()[]{}\u00ab\u00bb'
    mark_to_label = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}

    words, labels = [], []
    for token in text.split():
        bare = token.rstrip(strip_chars)
        if bare == '':
            continue
        tail = token[len(bare):]
        # First labelled mark in the tail wins; default is no punctuation.
        label = next(
            (mark_to_label[ch] for ch in tail if ch in mark_to_label), 'O'
        )
        words.append(bare)
        labels.append(label)
    return words, labels

# Convert every raw entry (its punctuated sentence is entry['text'][0]) into
# parallel word/label lists; sentences with fewer than two words are dropped.
all_words = []
all_labels = []
skipped = 0

for entry in raw_data:
    words, labels = extract_tokens_and_labels(entry['text'][0])
    if len(words) >= 2:
        all_words.append(words)
        all_labels.append(labels)
    else:
        skipped += 1

print(f"Processed: {len(all_words)} sentences, skipped: {skipped}")

# Label distribution over all kept words
flat_labels = [lab for sent_labels in all_labels for lab in sent_labels]
print(f"Label distribution: {Counter(flat_labels)}")

## 4. Train/Test split

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 15% of sentences as the test set, then take 10% of the remainder
# (about 8.5% of the corpus) as the validation set. Both splits use seed 42.
train_words, test_words, train_labels, test_labels = train_test_split(
    all_words, all_labels, random_state=42, test_size=0.15
)
train_words, val_words, train_labels, val_labels = train_test_split(
    train_words, train_labels, random_state=42, test_size=0.1
)

print(f"Train: {len(train_words)}, Val: {len(val_words)}, Test: {len(test_words)}")

## 5. Tokenize with mBERT tokenizer

In [None]:
from transformers import AutoTokenizer
import torch
from torch.utils.data import Dataset

MODEL_NAME = "bert-base-multilingual-cased"  # Hugging Face Hub checkpoint
MAX_LEN = 256  # every example is truncated/padded to this many sub-tokens

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

class PunctDataset(Dataset):
    """Token-classification dataset for punctuation restoration.

    Each item is one pre-split sentence, tokenized to exactly ``max_len``
    sub-tokens (truncated and padded). A word's punctuation label is placed on
    the LAST sub-token of that word, since the mark follows the word's end.
    Every other position (special tokens, padding, non-final sub-tokens) gets
    -100, so it is ignored by the loss and by the metrics.

    Args:
        words_list: list of sentences, each a list of word strings.
        labels_list: parallel list of per-word label names (keys of label2id).
        tokenizer: callable returning an encoding that supports
            ``word_ids(batch_index=0)`` and ``'input_ids'``/``'attention_mask'``
            tensors of shape (1, max_len), e.g. a Hugging Face fast tokenizer.
        max_len: fixed sequence length after truncation/padding.
        label2id: optional label-name -> id mapping; defaults to the
            notebook-level ``LABEL2ID``.
    """

    def __init__(self, words_list, labels_list, tokenizer, max_len, label2id=None):
        self.words_list = words_list
        self.labels_list = labels_list
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label2id = LABEL2ID if label2id is None else label2id

    def __len__(self):
        return len(self.words_list)

    def __getitem__(self, idx):
        words = self.words_list[idx]
        labels = self.labels_list[idx]

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt',
        )
        word_ids = encoding.word_ids(batch_index=0)

        # A position is the last sub-token of its word when the next position
        # belongs to a different word, or to a special/padding token (None).
        # NOTE: a word cut off by truncation still gets its label on its last
        # *visible* sub-token.
        label_ids = [-100] * len(word_ids)
        for pos, wid in enumerate(word_ids):
            if wid is None:
                continue
            next_wid = word_ids[pos + 1] if pos + 1 < len(word_ids) else None
            if next_wid != wid:
                label_ids[pos] = self.label2id[labels[wid]]

        return {
            # [0] drops the batch axis; unlike squeeze() it never collapses a
            # length-1 sequence axis into a 0-d tensor.
            'input_ids': encoding['input_ids'][0],
            'attention_mask': encoding['attention_mask'][0],
            'labels': torch.tensor(label_ids, dtype=torch.long),
        }

# One dataset per split; tokenization happens lazily in __getitem__.
train_dataset, val_dataset, test_dataset = (
    PunctDataset(split_words, split_labels, tokenizer, MAX_LEN)
    for split_words, split_labels in (
        (train_words, train_labels),
        (val_words, val_labels),
        (test_words, test_labels),
    )
)

print(f"Datasets created. Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Quick sanity check on the first training example
first_item = train_dataset[0]
n_labelled = (first_item['labels'] != -100).sum().item()
print(f"Input shape: {first_item['input_ids'].shape}")
print(f"Labels shape: {first_item['labels'].shape}")
print(f"Non-ignored labels: {n_labelled}")

## 6. Train mBERT

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from sklearn.metrics import classification_report
import numpy as np

from transformers import set_seed

# Seed before building the model so the randomly initialised classification
# head is reproducible. The seed=42 in TrainingArguments is applied by the
# Trainer, which is constructed only after the model already exists.
set_seed(42)

model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    id2label=ID2LABEL,
    label2id=LABEL2ID,
)

def compute_metrics(eval_pred):
    """Token-level metrics for the Hugging Face Trainer evaluation loop.

    Args:
        eval_pred: (logits, labels) pair supplied by the Trainer. ``labels``
            holds -100 at positions that carry no word label.

    Returns:
        dict with weighted-average precision/recall/F1 over all four classes
        (including 'O'; 'f1_weighted' drives best-checkpoint selection), plus
        macro-averaged precision/recall/F1 over the punctuation classes only,
        which are not dominated by the frequent 'O' class.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Keep only labelled positions (one per word); -100 marks ignored tokens.
    mask = labels != -100
    true_labels = [ID2LABEL[int(i)] for i in labels[mask]]
    pred_labels = [ID2LABEL[int(i)] for i in preds[mask]]

    label_names = [ID2LABEL[i] for i in sorted(ID2LABEL)]
    # zero_division=0 gives the same values as the default ('warn') without
    # flooding the logs when a class is never predicted in early epochs.
    report = classification_report(
        true_labels, pred_labels,
        labels=label_names,
        output_dict=True,
        zero_division=0,
    )

    punct_rows = [report[name] for name in label_names if name != 'O']
    return {
        'f1_weighted': report['weighted avg']['f1-score'],
        'precision_weighted': report['weighted avg']['precision'],
        'recall_weighted': report['weighted avg']['recall'],
        'f1_macro_punct': float(np.mean([row['f1-score'] for row in punct_rows])),
        'precision_macro_punct': float(np.mean([row['precision'] for row in punct_rows])),
        'recall_macro_punct': float(np.mean([row['recall'] for row in punct_rows])),
    }

training_args = TrainingArguments(
    output_dir='./mbert_punct_kyrgyz',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    eval_strategy='epoch',
    save_strategy='epoch',
    # Keep at most two checkpoints on disk; with load_best_model_at_end the
    # best checkpoint is always retained in addition to the latest.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='f1_weighted',
    greater_is_better=True,
    # fp16 mixed precision needs a GPU; fall back to fp32 on a CPU runtime
    # instead of failing at TrainingArguments construction.
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    report_to='none',
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

print("Starting training...")
trainer.train()

## 7. Evaluate on test set

In [None]:
from torch.utils.data import DataLoader

# Get predictions on the test set. With load_best_model_at_end=True the Trainer
# loads the best checkpoint's weights into this same `model` object.
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

all_true = []
all_pred = []

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=-1).cpu()

        # Select labelled positions with a boolean mask (row-major order, same
        # as a nested i/j loop) instead of indexing each tensor element.
        keep = labels != -100
        all_true.extend(ID2LABEL[i] for i in labels[keep].tolist())
        all_pred.extend(ID2LABEL[i] for i in preds[keep].tolist())

print(f"Total test tokens evaluated: {len(all_true)}")
print()

# Full classification report
print("=" * 60)
print("mBERT CLASSIFICATION REPORT")
print("=" * 60)
print(classification_report(
    all_true, all_pred,
    labels=['O', 'COMMA', 'PERIOD', 'QUESTION'],
    digits=3,
    zero_division=0,
))

## 8. Results summary

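Copy the **weighted avg** Precision, Recall and F1 from the report above (the next cell also prints them ready to paste). Note that this average is weighted by class support and includes the `O` (no punctuation) class, so make sure the other rows of the table were computed the same way.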

These numbers go into the paper's Table 3 (main results):

```
| Model              | Precision | Recall | F1    |
|--------------------|-----------|--------|-------|
| Rule-based         | 0.801     | 0.781  | 0.790 |
| mBERT              | ???       | ???    | ???   |  <-- YOUR NUMBERS HERE
| XLM-RoBERTa (ours) | 0.941     | 0.868  | 0.903 |
```

In [None]:
# Weighted-average test metrics, printed ready to paste into the paper table.
report = classification_report(
    all_true,
    all_pred,
    labels=['O', 'COMMA', 'PERIOD', 'QUESTION'],
    digits=3,
    output_dict=True,
)
weighted = report['weighted avg']
prec, rec, f1 = weighted['precision'], weighted['recall'], weighted['f1-score']
banner = "=" * 60

print("\n" + banner)
print("COPY THESE NUMBERS FOR THE PAPER:")
print(banner)
print(f"mBERT  &  {prec:.3f}  &  {rec:.3f}  &  {f1:.3f}")
print()
print("LaTeX table row:")
print(f"mBERT  & {prec:.3f} & {rec:.3f} & {f1:.3f} \\\\")

## 9. Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Class order for matrix rows/columns and tick labels. Named `label_names` so
# it does not overwrite `labels`, the batch tensor from the evaluation cell.
label_names = ['O', 'COMMA', 'PERIOD', 'QUESTION']

# Row-normalised confusion matrix: each row (true class) sums to 1
cm_norm = confusion_matrix(all_true, all_pred, labels=label_names, normalize='true')

fig, ax = plt.subplots(figsize=(6, 5))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_norm, display_labels=label_names)
disp.plot(ax=ax, cmap='Blues', values_format='.3f')
ax.set_title('mBERT — Confusion Matrix (Normalized)', fontsize=13)
fig.tight_layout()
fig.savefig('confusion_matrix_mbert.png', dpi=300, bbox_inches='tight')
plt.show()
print("Saved: confusion_matrix_mbert.png")

# Raw counts
cm_raw = confusion_matrix(all_true, all_pred, labels=label_names)
print("\nRaw counts:")
print(cm_raw)

In [None]:
# Send the saved figure to the local machine (Colab browser download).
cm_png_path = 'confusion_matrix_mbert.png'
files.download(cm_png_path)

## 10. Save Model

In [None]:
import shutil

SAVE_DIR = './mbert_kyrgyz_punct_model'

# Save model and tokenizer into the same directory (model first).
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR)

print(f"Model and tokenizer saved to: {SAVE_DIR}")
print("Files:")
for fname in sorted(os.listdir(SAVE_DIR)):
    nbytes = os.path.getsize(os.path.join(SAVE_DIR, fname))
    if nbytes > 1e6:
        print(f"  {fname} ({nbytes / 1e6:.1f} MB)")
    else:
        print(f"  {fname} ({nbytes} B)")

# Zip the directory contents and download the archive
shutil.make_archive(SAVE_DIR, 'zip', SAVE_DIR)
zip_path = f'{SAVE_DIR}.zip'
print(f"\nArchive created: {zip_path}")
files.download(zip_path)