In [1]:
from datasets import load_dataset, load_metric
translation_dataset = load_dataset("wmt14", "de-en", split="train[:1%]")
sentiment_dataset = load_dataset("amazon_polarity", split="train[:1%]")  # Just an example in English, replace with multilingual.

In [2]:
#Phase 2: Experiments (Fine-tuning Models)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # For translation tasks, e.g., MarianMT

def load_custom_tokenizer(tokenizer_type="bpe"):
    if tokenizer_type == "bpe":
        # Load a fast tokenizer from files
        # This is a simplified example, in practice you'd wrap it with PreTrainedTokenizerFast
        return AutoTokenizer.from_pretrained("tokenizers/bpe", use_fast=True)
    elif tokenizer_type == "sp":
        return AutoTokenizer.from_pretrained("tokenizers/sp_unigram_hf", use_fast=True)
    elif tokenizer_type == "wp":
        return AutoTokenizer.from_pretrained("tokenizers/wp", use_fast=True)
    else:
        raise ValueError("Unsupported tokenizer type")

# Example: Load baseline tokenizer (BPE)
baseline_tokenizer = load_custom_tokenizer("wp")

In [3]:
import torch
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [4]:
# For translation tasks, you might use a MarianMT model or mBART, for sentiment XLM-R or mBERT.
model_name = "Helsinki-NLP/opus-mt-en-de"  
baseline_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

In [5]:
def preprocess_function_translation(examples):
    # Extract lists of source and target texts from the list of dictionaries
    source_texts = [item["en"] for item in examples["translation"]]
    target_texts = [item["de"] for item in examples["translation"]]
    
    # Tokenize the source texts
    inputs = baseline_tokenizer(
        source_texts,
        truncation=True,
        padding="max_length",
        max_length=128
    )
    
    # Tokenize the target texts
    with baseline_tokenizer.as_target_tokenizer():
        labels = baseline_tokenizer(
            target_texts,
            truncation=True,
            padding="max_length",
            max_length=128
        )
    
    inputs["labels"] = labels["input_ids"]
    return inputs

# Apply the preprocessing function
tokenized_translation_dataset = translation_dataset.map(preprocess_function_translation, batched=True)


Map:   0%|          | 0/45088 [00:00<?, ? examples/s]



In [6]:
from transformers import (
    TrainingArguments,
    Trainer,
)

train_args = TrainingArguments(
    output_dir="checkpoints/trans_wp",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,  # For demonstration, use more epochs in practice
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=torch.cuda.is_available()
)

trainer = Trainer(
    model=baseline_model,
    args=train_args,
    train_dataset=tokenized_translation_dataset,
    eval_dataset=tokenized_translation_dataset,  
)

trainer.train()



  0%|          | 0/56360 [00:00<?, ?it/s]

{'loss': 4.6093, 'grad_norm': 31.651872634887695, 'learning_rate': 4.9991128459900645e-05, 'epoch': 0.0}
{'loss': 3.4113, 'grad_norm': 27.671483993530273, 'learning_rate': 4.998225691980128e-05, 'epoch': 0.0}
{'loss': 3.1317, 'grad_norm': 26.274314880371094, 'learning_rate': 4.997338537970192e-05, 'epoch': 0.0}
{'loss': 2.3484, 'grad_norm': 10.815181732177734, 'learning_rate': 4.996451383960256e-05, 'epoch': 0.0}
{'loss': 2.1898, 'grad_norm': 4.794785976409912, 'learning_rate': 4.9955642299503194e-05, 'epoch': 0.0}
{'loss': 1.6582, 'grad_norm': 2.499250888824463, 'learning_rate': 4.994677075940383e-05, 'epoch': 0.01}
{'loss': 1.7859, 'grad_norm': 1.653375267982483, 'learning_rate': 4.993789921930447e-05, 'epoch': 0.01}
{'loss': 2.1346, 'grad_norm': 1.731756567955017, 'learning_rate': 4.9929027679205115e-05, 'epoch': 0.01}
{'loss': 1.6399, 'grad_norm': 1.460418462753296, 'learning_rate': 4.992015613910575e-05, 'epoch': 0.01}
{'loss': 1.7745, 'grad_norm': 3.647819757461548, 'learning_rat

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.8776467442512512, 'eval_runtime': 1719.801, 'eval_samples_per_second': 26.217, 'eval_steps_per_second': 6.554, 'epoch': 1.0}




{'loss': 0.8734, 'grad_norm': 2.3798351287841797, 'learning_rate': 3.999290276792052e-05, 'epoch': 1.0}
{'loss': 0.7512, 'grad_norm': 2.2173285484313965, 'learning_rate': 3.998403122782115e-05, 'epoch': 1.0}
{'loss': 1.0216, 'grad_norm': 2.615579843521118, 'learning_rate': 3.997515968772179e-05, 'epoch': 1.0}
{'loss': 0.6932, 'grad_norm': 2.2587454319000244, 'learning_rate': 3.9966288147622424e-05, 'epoch': 1.0}
{'loss': 1.0695, 'grad_norm': 3.7716736793518066, 'learning_rate': 3.995741660752307e-05, 'epoch': 1.0}
{'loss': 0.844, 'grad_norm': 2.1957476139068604, 'learning_rate': 3.994854506742371e-05, 'epoch': 1.01}
{'loss': 0.8563, 'grad_norm': 1.9534205198287964, 'learning_rate': 3.9939673527324345e-05, 'epoch': 1.01}
{'loss': 0.9775, 'grad_norm': 1.9640684127807617, 'learning_rate': 3.993080198722499e-05, 'epoch': 1.01}
{'loss': 1.063, 'grad_norm': 2.30810546875, 'learning_rate': 3.992193044712562e-05, 'epoch': 1.01}
{'loss': 1.1008, 'grad_norm': 2.0015220642089844, 'learning_rate':

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.700666069984436, 'eval_runtime': 1758.899, 'eval_samples_per_second': 25.634, 'eval_steps_per_second': 6.409, 'epoch': 2.0}
{'loss': 0.7445, 'grad_norm': 1.836183786392212, 'learning_rate': 2.9994677075940386e-05, 'epoch': 2.0}
{'loss': 0.7054, 'grad_norm': 2.9236626625061035, 'learning_rate': 2.9985805535841022e-05, 'epoch': 2.0}
{'loss': 0.698, 'grad_norm': 2.2732112407684326, 'learning_rate': 2.9976933995741664e-05, 'epoch': 2.0}
{'loss': 0.6759, 'grad_norm': 2.7948811054229736, 'learning_rate': 2.9968062455642297e-05, 'epoch': 2.0}
{'loss': 0.739, 'grad_norm': 2.185814380645752, 'learning_rate': 2.995919091554294e-05, 'epoch': 2.0}
{'loss': 0.728, 'grad_norm': 2.294501304626465, 'learning_rate': 2.9950319375443582e-05, 'epoch': 2.0}
{'loss': 0.7744, 'grad_norm': 1.9246397018432617, 'learning_rate': 2.9941447835344217e-05, 'epoch': 2.01}
{'loss': 0.6527, 'grad_norm': 2.8892300128936768, 'learning_rate': 2.9932576295244856e-05, 'epoch': 2.01}
{'loss': 0.5852, 'grad_no

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.6039337515830994, 'eval_runtime': 1782.1093, 'eval_samples_per_second': 25.3, 'eval_steps_per_second': 6.325, 'epoch': 3.0}
{'loss': 0.6679, 'grad_norm': 1.892907738685608, 'learning_rate': 1.999645138396026e-05, 'epoch': 3.0}
{'loss': 0.7634, 'grad_norm': 2.524404525756836, 'learning_rate': 1.9987579843860894e-05, 'epoch': 3.0}
{'loss': 0.6834, 'grad_norm': 2.685338020324707, 'learning_rate': 1.9978708303761533e-05, 'epoch': 3.0}
{'loss': 0.5489, 'grad_norm': 2.332313060760498, 'learning_rate': 1.9969836763662173e-05, 'epoch': 3.0}
{'loss': 0.6624, 'grad_norm': 2.5267112255096436, 'learning_rate': 1.996096522356281e-05, 'epoch': 3.0}
{'loss': 0.7135, 'grad_norm': 2.3519816398620605, 'learning_rate': 1.995209368346345e-05, 'epoch': 3.0}
{'loss': 0.6308, 'grad_norm': 1.909571647644043, 'learning_rate': 1.994322214336409e-05, 'epoch': 3.01}
{'loss': 0.6472, 'grad_norm': 1.8806449174880981, 'learning_rate': 1.993435060326473e-05, 'epoch': 3.01}
{'loss': 0.6567, 'grad_norm'

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.5489139556884766, 'eval_runtime': 852.0534, 'eval_samples_per_second': 52.917, 'eval_steps_per_second': 13.229, 'epoch': 4.0}
{'loss': 0.5987, 'grad_norm': 2.40364933013916, 'learning_rate': 9.99822569198013e-06, 'epoch': 4.0}
{'loss': 0.6252, 'grad_norm': 2.481786012649536, 'learning_rate': 9.989354151880767e-06, 'epoch': 4.0}
{'loss': 0.7076, 'grad_norm': 1.8038291931152344, 'learning_rate': 9.980482611781406e-06, 'epoch': 4.0}
{'loss': 0.5152, 'grad_norm': 1.26546049118042, 'learning_rate': 9.971611071682045e-06, 'epoch': 4.0}
{'loss': 0.6876, 'grad_norm': 1.8192981481552124, 'learning_rate': 9.962739531582684e-06, 'epoch': 4.0}
{'loss': 0.5546, 'grad_norm': 1.798527717590332, 'learning_rate': 9.953867991483321e-06, 'epoch': 4.0}
{'loss': 0.5494, 'grad_norm': 2.6462104320526123, 'learning_rate': 9.94499645138396e-06, 'epoch': 4.01}
{'loss': 0.523, 'grad_norm': 1.818846583366394, 'learning_rate': 9.9361249112846e-06, 'epoch': 4.01}
{'loss': 0.5221, 'grad_norm': 2.5477

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.5269470810890198, 'eval_runtime': 855.0975, 'eval_samples_per_second': 52.728, 'eval_steps_per_second': 13.182, 'epoch': 5.0}


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].


{'train_runtime': 38193.8481, 'train_samples_per_second': 5.903, 'train_steps_per_second': 1.476, 'train_loss': 0.8031453931771584, 'epoch': 5.0}


TrainOutput(global_step=56360, training_loss=0.8031453931771584, metrics={'train_runtime': 38193.8481, 'train_samples_per_second': 5.903, 'train_steps_per_second': 1.476, 'total_flos': 7642047389368320.0, 'train_loss': 0.8031453931771584, 'epoch': 5.0})

In [7]:
import sacrebleu

seed = 123
sample = translation_dataset.shuffle(seed=seed).select(range(100))

def generate_predictions(model, tokenizer, dataset):
    model.eval()
    predictions = []
    references = []
    for example in dataset:
        input_ids = tokenizer(example["translation"]["en"], return_tensors="pt").input_ids.to(device)
        # Generate translations
        outputs = model.generate(input_ids)
        pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
        predictions.append(pred)
        references.append([example["translation"]["de"]])  # sacrebleu expects a list of references
    return predictions, references

preds, refs = generate_predictions(baseline_model, baseline_tokenizer, sample)
bleu_score = sacrebleu.corpus_bleu(preds, list(zip(*refs)))
print("BLEU Score:", bleu_score.score)

BLEU Score: 7.779238359913661


In [8]:
rare_word = "supercalifragilisticexpialidocious"
print("SP Tokens:", baseline_tokenizer.tokenize(rare_word))

SP Tokens: ['super', '##cal', '##if', '##rag', '##il', '##istic', '##ex', '##p', '##ial', '##id', '##oc', '##ious']
