In [None]:
import os
import torch
import glob
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, EarlyStoppingCallback
from datasets import Dataset

# 1. File patterns
LV_PATTERN = "../txt/latvian_sentences_*.txt"
GLOSS_PATTERN = "../txt/lsl_glosses_*.txt"
PRUNED_MODEL_PATH = "../mt5-pruned"

# 2. Append all files to master list
lv_lines = []
gloss_lines = []
total_files = 0

for file_path in sorted(glob.glob(LV_PATTERN)):
    with open(file_path, "r", encoding="utf-8") as f:
        lv_lines.extend([line.strip() for line in f if line.strip()])
    total_files += 1

for file_path in sorted(glob.glob(GLOSS_PATTERN)):
    with open(file_path, "r", encoding="utf-8") as f:
        gloss_lines.extend([line.strip() for line in f if line.strip()])

# 3. Check validity
assert len(lv_lines) == len(gloss_lines), f"❌ Mismatch! Sentence lines: {len(lv_lines)}, gloss lines: {len(gloss_lines)}"
print(f"✅ Successfully loaded {total_files} batches with {len(lv_lines)} total pairs.")

# 4. Create a dataset
data = {"lv": lv_lines, "gloss": gloss_lines}
raw_dataset = Dataset.from_dict(data)

# Split into Train (90%) and Test (10%) so we can verify learning
split_dataset = raw_dataset.train_test_split(test_size=0.1)
print("Data split:", split_dataset)

# 5. Load model
tokenizer = T5Tokenizer.from_pretrained(PRUNED_MODEL_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(PRUNED_MODEL_PATH)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Pruned Model loaded on: {device}")

✅ Successfully loaded 3 batches with 630 total pairs.
Data split: DatasetDict({
    train: Dataset({
        features: ['lv', 'gloss'],
        num_rows: 567
    })
    test: Dataset({
        features: ['lv', 'gloss'],
        num_rows: 63
    })
})


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
The tokenizer you are loading from '../mt5-pruned' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.


Pruned Model loaded on: cpu


In [22]:
import json
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

# --- 1. LOAD THE MAP (The "Rosetta Stone") ---
# This file tells us: "Old ID 15020 is now New ID 5"
with open(os.path.join(PRUNED_MODEL_PATH, "vocab_map.json"), "r") as f:
    # The JSON is saved as { "new_id": old_id }, so we reverse it.
    new2old_map = json.load(f)
    old2new_map = {v: int(k) for k, v in new2old_map.items()}

# Find the "New" ID for the Unknown token (UNK)
# We need this for words like "translate" if they weren't in your pruning text
original_unk_id = tokenizer.unk_token_id
new_unk_id = old2new_map.get(original_unk_id, 0) # Default to 0 if weirdness happens

print(f"Loaded vocab map. Remapping {len(old2new_map)} tokens.")

# --- 2. DEFINE THE REMAPPING FUNCTION ---
def remap_tokens(token_ids):
    # Convert list of Old IDs to New IDs
    # If a token wasn't in our pruning list, turn it into UNK (new_unk_id)
    return [old2new_map.get(tid, new_unk_id) for tid in token_ids]

# --- 3. UPDATED PREPROCESS FUNCTION ---
def preprocess_function(examples):
    # A. Tokenize Inputs (Standard way - produces HUGE IDs)
    # inputs = ["translate Latvian to Gloss: " + ex for ex in examples["lv"]]
    inputs = examples["lv"]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True)

    # B. Tokenize Targets (Standard way - produces HUGE IDs)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["gloss"], max_length=128, truncation=True)

    # C. MANUAL REMAPPING STEP (The Fix!)
    # We replace the huge IDs with the tiny mapped IDs
    model_inputs["input_ids"] = [remap_tokens(ids) for ids in model_inputs["input_ids"]]
    
    # We must also remap the labels!
    model_inputs["labels"] = [remap_tokens(ids) for ids in labels["input_ids"]]
    
    return model_inputs

# Apply the new function
tokenized_datasets = split_dataset.map(preprocess_function, batched=True)

# --- 4. TRAINING ARGUMENTS ---
args = Seq2SeqTrainingArguments(
    output_dir="../mt5-lsl-model",
    learning_rate=1e-3,        # Used to be 1e-3
    num_train_epochs=50,       # can be high, but early stopping will prevent overfitting
    eval_strategy="steps",
    eval_steps=5,             # evaluate frequently, since dataset is tiny
    save_strategy="steps",
    save_steps=1000000,
    per_device_train_batch_size=8,
    weight_decay=0.01,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # track validation loss
    greater_is_better=False,
    predict_with_generate=True,
    optim="adafactor",
    report_to="none"
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

Loaded vocab map. Remapping 1210 tokens.


Map: 100%|██████████| 567/567 [00:00<00:00, 23602.56 examples/s]
Map: 100%|██████████| 63/63 [00:00<00:00, 10492.42 examples/s]
  trainer = Seq2SeqTrainer(


In [23]:
sample = tokenized_datasets["train"][1]
print(sample["input_ids"][:20])
print(sample["labels"][:20])
print("Max input id:", max(sample["input_ids"]))
print("Max label id:", max(sample["labels"]))
print("Vocab size:", model.config.vocab_size)

token_set = set()
for line in lv_lines + gloss_lines:
    token_set.update(tokenizer(line)['input_ids'])
print("Dataset vocab size:", len(token_set))
print("Pruned vocab size:", len(old2new_map))

[510, 409, 495, 865, 260, 257, 1]
[575, 663, 495, 865, 260, 1]
Max input id: 865
Max label id: 865
Vocab size: 1210
Dataset vocab size: 955
Pruned vocab size: 1210


In [24]:
# --- 3. TRAIN ---
print("Starting training...")
trainer.train()

Starting training...




Step,Training Loss,Validation Loss
5,No log,10.431396
10,14.615900,7.898682
15,14.615900,6.731386
20,9.392400,6.419647
25,9.392400,6.690964
30,8.305200,4.966996
35,8.305200,4.557003
40,7.488600,6.201177
45,7.488600,4.72754
50,6.849400,4.498833


TrainOutput(global_step=95, training_loss=7.467303627415707, metrics={'train_runtime': 55.8699, 'train_samples_per_second': 507.428, 'train_steps_per_second': 63.54, 'total_flos': 2937432502272.0, 'train_loss': 7.467303627415707, 'epoch': 1.3380281690140845})

In [None]:
print("Saving model...")

save_dir = "E:/Documents/GitHub/LSL/mt5-lsl-model"

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print("Done.")

Saving model...
Done.


## Testing

In [None]:
import torch
import json
import os
import glob

# --- 1. SETUP MAPS (Crucial!) ---
# We need to load the map to convert between "Big Tokenizer" and "Small Model"
vocab_map_path = os.path.join(PRUNED_MODEL_PATH, "vocab_map.json")

with open(vocab_map_path, "r") as f:
    new2old_map = json.load(f)
    # We need both directions!
    old2new_map = {int(v): int(k) for k, v in new2old_map.items()} # Big -> Small
    new2old_map = {int(k): int(v) for k, v in new2old_map.items()} # Small -> Big

# Identify the UNK token ID in the new mapping
# If a word (like "translate") isn't in our map, we point it to the pruned UNK ID.
# Usually, UNK is ID 2 in standard T5, let's find where ID 2 went.
original_unk_id = tokenizer.unk_token_id
pruned_unk_id = old2new_map.get(original_unk_id, 0) # Fallback to 0 if not found


# 1. File patterns
LV_PATTERN = "../txt/latvian_sentences_*.txt"
GLOSS_PATTERN = "../txt/lsl_glosses_*.txt"
PRUNED_MODEL_PATH = "../mt5-pruned"

# 2. Append all files to master list
lv_lines = []
gloss_lines = []
total_files = 0

for file_path in sorted(glob.glob(LV_PATTERN)):
    with open(file_path, "r", encoding="utf-8") as f:
        lv_lines.extend([line.strip() for line in f if line.strip()])
    total_files += 1

for file_path in sorted(glob.glob(GLOSS_PATTERN)):
    with open(file_path, "r", encoding="utf-8") as f:
        gloss_lines.extend([line.strip() for line in f if line.strip()])

# 3. Check validity
assert len(lv_lines) == len(gloss_lines), f"❌ Mismatch! Sentence lines: {len(lv_lines)}, gloss lines: {len(gloss_lines)}"
print(f"✅ Successfully loaded {total_files} batches with {len(lv_lines)} total pairs.")




# --- 2. CUSTOM TRANSLATION FUNCTION ---
def predict_gloss(text):
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    input_ids = inputs.input_ids[0].tolist()
    
    # Remap to pruned IDs
    pruned_input_ids = [old2new_map.get(tid, pruned_unk_id) for tid in input_ids]
    input_tensor = torch.tensor([pruned_input_ids]).to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_tensor,
            max_length=32,
            num_beams=3,
            early_stopping=True,
            no_repeat_ngram_size=2
        )
    
    # Remap output IDs back to full tokenizer
    output_ids = outputs[0].tolist()
    original_output_ids = [new2old_map.get(tid, tokenizer.unk_token_id) for tid in output_ids]

    # Decode
    return tokenizer.decode(original_output_ids, skip_special_tokens=True)


# --- 3. RUN TESTS ---
print("\n--- RESULTS ---")

test_sentences = [
    lv_lines[10],
    lv_lines[20],
    lv_lines[30],
    lv_lines[40],
    lv_lines[50]
]

for text in test_sentences:
    gloss = predict_gloss(text)
    print(f"\nInput:  {text}")
    print(f"Result: {gloss}")

✅ Successfully loaded 3 batches with 630 total pairs.

--- RESULTS ---

Input:  Čau!
Result: vai Čas

Input:  Kāds ir jūsu vārds?
Result: vai jūs ir jūsu vārds

Input:  Kāds ir tavs vārds?
Result: mans ir tavs vārds

Input:  Priecājos ar Jums iepazīties!
Result: mans iepazīties

Input:  Man iet labi.
Result: man iet labi

Input:  Vai tu esi labs cilvēks?
Result: vai tu esi labs cilvēks

Input:  Man ir ļoti lielas mājas.
Result: mans ir liels māja

Input:  Sveika!
Result: vai Sveika

Input:  Labdien, [NAME]!
Result: [NAME]

Input:  [NAME] ir īss...
Result: [NAME] ir īss

Input:  Braucam pie zaļās gaismas.
Result: vai zaļās gaismas

Input:  Viņš nopirka divus kreklus
Result: vai nopirkt krekls krevs
