In [1]:
import os
import time
import json
from pathlib import Path

import torch
from datasets import load_dataset, load_metric
import sentencepiece as spm
from tokenizers import Tokenizer, trainers, models, pre_tokenizers, decoders, processors

# Additional imports for evaluation and logging
import sacrebleu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
translation_dataset = load_dataset("wmt14", "de-en", split="train[:1%]")
sentiment_dataset = load_dataset("amazon_polarity", split="train[:1%]")  # Just an example in English, replace with multilingual.

In [3]:
#Phase 2: Experiments (Fine-tuning Models)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # For translation tasks, e.g., MarianMT

def load_custom_tokenizer(tokenizer_type="bpe"):
    if tokenizer_type == "bpe":
        # Load a fast tokenizer from files
        # This is a simplified example, in practice you'd wrap it with PreTrainedTokenizerFast
        return AutoTokenizer.from_pretrained("tokenizers/bpe", use_fast=True)
    elif tokenizer_type == "sp":
        return AutoTokenizer.from_pretrained("tokenizers/sp_unigram_hf", use_fast=True)
    elif tokenizer_type == "wp":
        return AutoTokenizer.from_pretrained("tokenizers/wp", use_fast=True)
    else:
        raise ValueError("Unsupported tokenizer type")

# Example: Load baseline tokenizer (BPE)
baseline_tokenizer = load_custom_tokenizer("bpe")

In [4]:
# For translation tasks, you might use a MarianMT model or mBART, for sentiment XLM-R or mBERT.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model_name = "Helsinki-NLP/opus-mt-en-de"  
baseline_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

In [5]:
def preprocess_function_translation(examples):
    # Extract lists of source and target texts from the list of dictionaries
    source_texts = [item["en"] for item in examples["translation"]]
    target_texts = [item["de"] for item in examples["translation"]]
    
    # Tokenize the source texts
    inputs = baseline_tokenizer(
        source_texts,
        truncation=True,
        padding="max_length",
        max_length=128
    )
    
    # Tokenize the target texts
    with baseline_tokenizer.as_target_tokenizer():
        labels = baseline_tokenizer(
            target_texts,
            truncation=True,
            padding="max_length",
            max_length=128
        )
    
    inputs["labels"] = labels["input_ids"]
    return inputs

# Apply the preprocessing function
tokenized_translation_dataset = translation_dataset.map(preprocess_function_translation, batched=True)


In [6]:
from transformers import (
    TrainingArguments,
    Trainer,
)

train_args = TrainingArguments(
    output_dir="checkpoints/trans_bpe",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,  
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=torch.cuda.is_available()
)

trainer = Trainer(
    model=baseline_model,
    args=train_args,
    train_dataset=tokenized_translation_dataset,
    eval_dataset=tokenized_translation_dataset,  
)

trainer.train()



  0%|          | 0/56360 [00:00<?, ?it/s]

{'loss': 3.8996, 'grad_norm': 18.763490676879883, 'learning_rate': 4.9991128459900645e-05, 'epoch': 0.0}
{'loss': 2.5345, 'grad_norm': 3.4558730125427246, 'learning_rate': 4.998225691980128e-05, 'epoch': 0.0}
{'loss': 2.2961, 'grad_norm': 2.2678959369659424, 'learning_rate': 4.997338537970192e-05, 'epoch': 0.0}
{'loss': 1.9061, 'grad_norm': 2.102567195892334, 'learning_rate': 4.996451383960256e-05, 'epoch': 0.0}
{'loss': 2.0006, 'grad_norm': 1.8110806941986084, 'learning_rate': 4.9955642299503194e-05, 'epoch': 0.0}
{'loss': 1.6279, 'grad_norm': 7.101086616516113, 'learning_rate': 4.994677075940383e-05, 'epoch': 0.01}
{'loss': 1.7614, 'grad_norm': 1.3375777006149292, 'learning_rate': 4.993789921930447e-05, 'epoch': 0.01}
{'loss': 2.1774, 'grad_norm': 1.6551400423049927, 'learning_rate': 4.9929027679205115e-05, 'epoch': 0.01}
{'loss': 1.7076, 'grad_norm': 1.6849730014801025, 'learning_rate': 4.992015613910575e-05, 'epoch': 0.01}
{'loss': 1.8357, 'grad_norm': 2.656099319458008, 'learning_

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.9538769125938416, 'eval_runtime': 994.2464, 'eval_samples_per_second': 45.349, 'eval_steps_per_second': 11.337, 'epoch': 1.0}




{'loss': 0.9986, 'grad_norm': 2.33758282661438, 'learning_rate': 3.999290276792052e-05, 'epoch': 1.0}
{'loss': 0.8204, 'grad_norm': 2.3600807189941406, 'learning_rate': 3.998403122782115e-05, 'epoch': 1.0}
{'loss': 1.1192, 'grad_norm': 2.64839768409729, 'learning_rate': 3.997515968772179e-05, 'epoch': 1.0}
{'loss': 0.7536, 'grad_norm': 2.489570379257202, 'learning_rate': 3.9966288147622424e-05, 'epoch': 1.0}
{'loss': 1.1712, 'grad_norm': 3.623785972595215, 'learning_rate': 3.995741660752307e-05, 'epoch': 1.0}
{'loss': 0.9476, 'grad_norm': 2.2167155742645264, 'learning_rate': 3.994854506742371e-05, 'epoch': 1.01}
{'loss': 0.9648, 'grad_norm': 1.9794601202011108, 'learning_rate': 3.9939673527324345e-05, 'epoch': 1.01}
{'loss': 1.0353, 'grad_norm': 2.330047845840454, 'learning_rate': 3.993080198722499e-05, 'epoch': 1.01}
{'loss': 1.1469, 'grad_norm': 2.432835340499878, 'learning_rate': 3.992193044712562e-05, 'epoch': 1.01}
{'loss': 1.1823, 'grad_norm': 3.1246016025543213, 'learning_rate':

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.7680385708808899, 'eval_runtime': 1047.5652, 'eval_samples_per_second': 43.041, 'eval_steps_per_second': 10.76, 'epoch': 2.0}
{'loss': 0.8352, 'grad_norm': 2.101726531982422, 'learning_rate': 2.9994677075940386e-05, 'epoch': 2.0}
{'loss': 0.7687, 'grad_norm': 3.020157814025879, 'learning_rate': 2.9985805535841022e-05, 'epoch': 2.0}
{'loss': 0.8027, 'grad_norm': 2.3877310752868652, 'learning_rate': 2.9976933995741664e-05, 'epoch': 2.0}
{'loss': 0.7465, 'grad_norm': 2.4928441047668457, 'learning_rate': 2.9968062455642297e-05, 'epoch': 2.0}
{'loss': 0.7815, 'grad_norm': 2.3846893310546875, 'learning_rate': 2.995919091554294e-05, 'epoch': 2.0}
{'loss': 0.8124, 'grad_norm': 2.407050371170044, 'learning_rate': 2.9950319375443582e-05, 'epoch': 2.0}
{'loss': 0.855, 'grad_norm': 2.813831090927124, 'learning_rate': 2.9941447835344217e-05, 'epoch': 2.01}
{'loss': 0.7295, 'grad_norm': 2.8531298637390137, 'learning_rate': 2.9932576295244856e-05, 'epoch': 2.01}
{'loss': 0.6434, 'grad

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.6646057367324829, 'eval_runtime': 947.1873, 'eval_samples_per_second': 47.602, 'eval_steps_per_second': 11.9, 'epoch': 3.0}
{'loss': 0.7267, 'grad_norm': 1.8957370519638062, 'learning_rate': 1.999645138396026e-05, 'epoch': 3.0}
{'loss': 0.8089, 'grad_norm': 2.5564990043640137, 'learning_rate': 1.9987579843860894e-05, 'epoch': 3.0}
{'loss': 0.7231, 'grad_norm': 2.6713905334472656, 'learning_rate': 1.9978708303761533e-05, 'epoch': 3.0}
{'loss': 0.596, 'grad_norm': 2.0855352878570557, 'learning_rate': 1.9969836763662173e-05, 'epoch': 3.0}
{'loss': 0.7581, 'grad_norm': 2.252817392349243, 'learning_rate': 1.996096522356281e-05, 'epoch': 3.0}
{'loss': 0.7939, 'grad_norm': 2.1396076679229736, 'learning_rate': 1.995209368346345e-05, 'epoch': 3.0}
{'loss': 0.6916, 'grad_norm': 2.1294164657592773, 'learning_rate': 1.994322214336409e-05, 'epoch': 3.01}
{'loss': 0.7096, 'grad_norm': 2.1968603134155273, 'learning_rate': 1.993435060326473e-05, 'epoch': 3.01}
{'loss': 0.6978, 'grad_no

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.605286180973053, 'eval_runtime': 1005.2621, 'eval_samples_per_second': 44.852, 'eval_steps_per_second': 11.213, 'epoch': 4.0}
{'loss': 0.685, 'grad_norm': 2.622464895248413, 'learning_rate': 9.99822569198013e-06, 'epoch': 4.0}
{'loss': 0.6992, 'grad_norm': 2.7732419967651367, 'learning_rate': 9.989354151880767e-06, 'epoch': 4.0}
{'loss': 0.7505, 'grad_norm': 1.7019001245498657, 'learning_rate': 9.980482611781406e-06, 'epoch': 4.0}
{'loss': 0.575, 'grad_norm': 1.2946555614471436, 'learning_rate': 9.971611071682045e-06, 'epoch': 4.0}
{'loss': 0.7732, 'grad_norm': 1.9766230583190918, 'learning_rate': 9.962739531582684e-06, 'epoch': 4.0}
{'loss': 0.6163, 'grad_norm': 2.0382940769195557, 'learning_rate': 9.953867991483321e-06, 'epoch': 4.0}
{'loss': 0.6243, 'grad_norm': 2.8159990310668945, 'learning_rate': 9.94499645138396e-06, 'epoch': 4.01}
{'loss': 0.5705, 'grad_norm': 1.7958823442459106, 'learning_rate': 9.9361249112846e-06, 'epoch': 4.01}
{'loss': 0.5988, 'grad_norm': 3

  0%|          | 0/11272 [00:00<?, ?it/s]

{'eval_loss': 0.5807990431785583, 'eval_runtime': 1111.7491, 'eval_samples_per_second': 40.556, 'eval_steps_per_second': 10.139, 'epoch': 5.0}


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].


{'train_runtime': 26614.9996, 'train_samples_per_second': 8.47, 'train_steps_per_second': 2.118, 'train_loss': 0.87365803493516, 'epoch': 5.0}


TrainOutput(global_step=56360, training_loss=0.87365803493516, metrics={'train_runtime': 26614.9996, 'train_samples_per_second': 8.47, 'train_steps_per_second': 2.118, 'total_flos': 7642047389368320.0, 'train_loss': 0.87365803493516, 'epoch': 5.0})

In [7]:
seed = 123
sample = translation_dataset.shuffle(seed=seed).select(range(100))

def generate_predictions(model, tokenizer, dataset):
    model.eval()
    predictions = []
    references = []
    for example in dataset:
        input_ids = tokenizer(example["translation"]["en"], return_tensors="pt").input_ids.to(device)
        # Generate translations
        outputs = model.generate(input_ids)
        pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
        predictions.append(pred)
        references.append([example["translation"]["de"]])  # sacrebleu expects a list of references
    return predictions, references

preds, refs = generate_predictions(baseline_model, baseline_tokenizer, sample)
bleu_score = sacrebleu.corpus_bleu(preds, list(zip(*refs)))
print("BLEU Score:", bleu_score.score)

BLEU Score: 18.527955881408168


In [8]:
rare_word = "supercalifragilisticexpialidocious"
print("BPE Tokens:", baseline_tokenizer.tokenize(rare_word))

BPE Tokens: ['s', 'up', 'erc', 'al', 'if', 'rag', 'il', 'ist', 'ice', 'xp', 'ial', 'id', 'oc', 'ious']
