In [1]:
from datasets import load_dataset
translation_dataset = load_dataset("wmt14", "de-en", split="train[:1%]")
sentiment_dataset = load_dataset("amazon_polarity", split="train[:1%]")  # Just an example in English, replace with multilingual.

In [2]:
#Phase 2: Experiments (Fine-tuning Models)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # For translation tasks, e.g., MarianMT

def load_custom_tokenizer(tokenizer_type="bpe"):
    if tokenizer_type == "bpe":
        # Load a fast tokenizer from files
        # This is a simplified example, in practice you'd wrap it with PreTrainedTokenizerFast
        return AutoTokenizer.from_pretrained("tokenizers/bpe", use_fast=True)
    elif tokenizer_type == "sp":
        return AutoTokenizer.from_pretrained("tokenizers/sp_unigram_hf", use_fast=True)
    elif tokenizer_type == "wp":
        return AutoTokenizer.from_pretrained("tokenizers/wp", use_fast=True)
    else:
        raise ValueError("Unsupported tokenizer type")

# Example: Load baseline tokenizer (BPE)
baseline_tokenizer = load_custom_tokenizer("sp")

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [3]:
import torch
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [4]:
# For translation tasks, you might use a MarianMT model or mBART, for sentiment XLM-R or mBERT.
model_name = "Helsinki-NLP/opus-mt-en-de"  
baseline_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

In [5]:
def preprocess_function_translation(examples):
    # Extract lists of source and target texts from the list of dictionaries
    source_texts = [item["en"] for item in examples["translation"]]
    target_texts = [item["de"] for item in examples["translation"]]
    
    # Tokenize the source texts
    inputs = baseline_tokenizer(
        source_texts,
        truncation=True,
        padding="max_length",
        max_length=128
    )
    
    # Tokenize the target texts
    with baseline_tokenizer.as_target_tokenizer():
        labels = baseline_tokenizer(
            target_texts,
            truncation=True,
            padding="max_length",
            max_length=128
        )
    
    inputs["labels"] = labels["input_ids"]
    return inputs

# Apply the preprocessing function
tokenized_translation_dataset = translation_dataset.map(preprocess_function_translation, batched=True)


In [6]:

from transformers import (
    TrainingArguments,
    Trainer,
)

train_args = TrainingArguments(
    output_dir="checkpoints/trans_sp",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,  # For demonstration, use more epochs in practice
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=torch.cuda.is_available()
)

trainer = Trainer(
    model=baseline_model,
    args=train_args,
    train_dataset=tokenized_translation_dataset,
    eval_dataset=tokenized_translation_dataset,  # In practice, use a separate validation set
)

trainer.train()



  0%|          | 0/56360 [00:00<?, ?it/s]

{'loss': 11.3872, 'grad_norm': 88.74807739257812, 'learning_rate': 4.9991128459900645e-05, 'epoch': 0.0}
{'loss': 11.4468, 'grad_norm': 66.42715454101562, 'learning_rate': 4.998225691980128e-05, 'epoch': 0.0}
{'loss': 11.0497, 'grad_norm': 61.85126876831055, 'learning_rate': 4.997338537970192e-05, 'epoch': 0.0}
{'loss': 10.0992, 'grad_norm': 58.527488708496094, 'learning_rate': 4.996451383960256e-05, 'epoch': 0.0}
{'loss': 8.8722, 'grad_norm': 63.668678283691406, 'learning_rate': 4.9955642299503194e-05, 'epoch': 0.0}
{'loss': 7.481, 'grad_norm': 62.80469512939453, 'learning_rate': 4.994677075940383e-05, 'epoch': 0.01}
{'loss': 6.384, 'grad_norm': 71.98995971679688, 'learning_rate': 4.993789921930447e-05, 'epoch': 0.01}
{'loss': 5.5564, 'grad_norm': 66.47242736816406, 'learning_rate': 4.9929027679205115e-05, 'epoch': 0.01}
{'loss': 4.2688, 'grad_norm': 62.83067321777344, 'learning_rate': 4.992015613910575e-05, 'epoch': 0.01}
{'loss': 3.5037, 'grad_norm': 35.13351058959961, 'learning_rat

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 1.4283379316329956, 'eval_runtime': 910.8857, 'eval_samples_per_second': 49.499, 'eval_steps_per_second': 12.375, 'epoch': 1.0}




{'loss': 1.4796, 'grad_norm': 14.376633644104004, 'learning_rate': 3.999290276792052e-05, 'epoch': 1.0}
{'loss': 1.253, 'grad_norm': 6.245235443115234, 'learning_rate': 3.998403122782115e-05, 'epoch': 1.0}
{'loss': 1.5516, 'grad_norm': 14.0660400390625, 'learning_rate': 3.997515968772179e-05, 'epoch': 1.0}
{'loss': 1.1313, 'grad_norm': 48.66598892211914, 'learning_rate': 3.9966288147622424e-05, 'epoch': 1.0}
{'loss': 1.6994, 'grad_norm': 22.175275802612305, 'learning_rate': 3.995741660752307e-05, 'epoch': 1.0}
{'loss': 1.4569, 'grad_norm': 10.748394966125488, 'learning_rate': 3.994854506742371e-05, 'epoch': 1.01}
{'loss': 1.4002, 'grad_norm': 8.731328010559082, 'learning_rate': 3.9939673527324345e-05, 'epoch': 1.01}
{'loss': 1.5652, 'grad_norm': 8.98436164855957, 'learning_rate': 3.993080198722499e-05, 'epoch': 1.01}
{'loss': 1.6971, 'grad_norm': 28.387203216552734, 'learning_rate': 3.992193044712562e-05, 'epoch': 1.01}
{'loss': 1.7431, 'grad_norm': 5.814235687255859, 'learning_rate': 

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 1.3162552118301392, 'eval_runtime': 866.7514, 'eval_samples_per_second': 52.02, 'eval_steps_per_second': 13.005, 'epoch': 2.0}
{'loss': 1.3534, 'grad_norm': 12.202564239501953, 'learning_rate': 2.9994677075940386e-05, 'epoch': 2.0}
{'loss': 1.3188, 'grad_norm': 15.192639350891113, 'learning_rate': 2.9985805535841022e-05, 'epoch': 2.0}
{'loss': 1.321, 'grad_norm': 18.515859603881836, 'learning_rate': 2.9976933995741664e-05, 'epoch': 2.0}
{'loss': 1.2443, 'grad_norm': 8.934161186218262, 'learning_rate': 2.9968062455642297e-05, 'epoch': 2.0}
{'loss': 1.3019, 'grad_norm': 14.695798873901367, 'learning_rate': 2.995919091554294e-05, 'epoch': 2.0}
{'loss': 1.4256, 'grad_norm': 18.170082092285156, 'learning_rate': 2.9950319375443582e-05, 'epoch': 2.0}
{'loss': 1.4267, 'grad_norm': 13.486499786376953, 'learning_rate': 2.9941447835344217e-05, 'epoch': 2.01}
{'loss': 1.1849, 'grad_norm': 17.81015968322754, 'learning_rate': 2.9932576295244856e-05, 'epoch': 2.01}
{'loss': 1.0519, 'gra

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 1.2567590475082397, 'eval_runtime': 1003.3396, 'eval_samples_per_second': 44.938, 'eval_steps_per_second': 11.234, 'epoch': 3.0}
{'loss': 1.2127, 'grad_norm': 15.508963584899902, 'learning_rate': 1.999645138396026e-05, 'epoch': 3.0}
{'loss': 1.4161, 'grad_norm': 15.442106246948242, 'learning_rate': 1.9987579843860894e-05, 'epoch': 3.0}
{'loss': 1.3861, 'grad_norm': 27.369983673095703, 'learning_rate': 1.9978708303761533e-05, 'epoch': 3.0}
{'loss': 1.1379, 'grad_norm': 18.529056549072266, 'learning_rate': 1.9969836763662173e-05, 'epoch': 3.0}
{'loss': 1.3775, 'grad_norm': 9.777101516723633, 'learning_rate': 1.996096522356281e-05, 'epoch': 3.0}
{'loss': 1.3872, 'grad_norm': 10.937087059020996, 'learning_rate': 1.995209368346345e-05, 'epoch': 3.0}
{'loss': 1.2241, 'grad_norm': 27.27080535888672, 'learning_rate': 1.994322214336409e-05, 'epoch': 3.01}
{'loss': 1.3149, 'grad_norm': 7.479949474334717, 'learning_rate': 1.993435060326473e-05, 'epoch': 3.01}
{'loss': 1.2368, 'grad_

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 1.2212984561920166, 'eval_runtime': 1011.4973, 'eval_samples_per_second': 44.576, 'eval_steps_per_second': 11.144, 'epoch': 4.0}
{'loss': 1.2387, 'grad_norm': 64.02332305908203, 'learning_rate': 9.99822569198013e-06, 'epoch': 4.0}
{'loss': 1.3623, 'grad_norm': 19.761276245117188, 'learning_rate': 9.989354151880767e-06, 'epoch': 4.0}
{'loss': 1.4115, 'grad_norm': 16.86296272277832, 'learning_rate': 9.980482611781406e-06, 'epoch': 4.0}
{'loss': 1.11, 'grad_norm': 14.551931381225586, 'learning_rate': 9.971611071682045e-06, 'epoch': 4.0}
{'loss': 1.4296, 'grad_norm': 10.941765785217285, 'learning_rate': 9.962739531582684e-06, 'epoch': 4.0}
{'loss': 1.2244, 'grad_norm': 15.053211212158203, 'learning_rate': 9.953867991483321e-06, 'epoch': 4.0}
{'loss': 1.2466, 'grad_norm': 34.5660400390625, 'learning_rate': 9.94499645138396e-06, 'epoch': 4.01}
{'loss': 1.2072, 'grad_norm': 12.563286781311035, 'learning_rate': 9.9361249112846e-06, 'epoch': 4.01}
{'loss': 1.1591, 'grad_norm': 21.

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 1.212054967880249, 'eval_runtime': 999.1641, 'eval_samples_per_second': 45.126, 'eval_steps_per_second': 11.281, 'epoch': 5.0}


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].


{'train_runtime': 26334.5727, 'train_samples_per_second': 8.561, 'train_steps_per_second': 2.14, 'train_loss': 1.4012301090848285, 'epoch': 5.0}


TrainOutput(global_step=56360, training_loss=1.4012301090848285, metrics={'train_runtime': 26334.5727, 'train_samples_per_second': 8.561, 'train_steps_per_second': 2.14, 'total_flos': 7642047389368320.0, 'train_loss': 1.4012301090848285, 'epoch': 5.0})

In [7]:
import sacrebleu

seed = 123
sample = translation_dataset.shuffle(seed=seed).select(range(100))

def generate_predictions(model, tokenizer, dataset):
    model.eval()
    predictions = []
    references = []
    for example in dataset:
        input_ids = tokenizer(example["translation"]["en"], return_tensors="pt").input_ids.to(device)
        # Generate translations
        outputs = model.generate(input_ids)
        pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
        predictions.append(pred)
        references.append([example["translation"]["de"]])  # sacrebleu expects a list of references
    return predictions, references

preds, refs = generate_predictions(baseline_model, baseline_tokenizer, sample)
bleu_score = sacrebleu.corpus_bleu(preds, list(zip(*refs)))
print("BLEU Score:", bleu_score.score)

BLEU Score: 6.163352900935236


In [8]:
rare_word = "supercalifragilisticexpialidocious"
print("SP Tokens:", baseline_tokenizer.tokenize(rare_word))

SP Tokens: ['▁super', 'cal', 'i', 'f', 'r', 'ag', 'il', 'ist', 'ic', 'ex', 'p', 'ial', 'id', 'oc', 'ious']
