# Sequence-level knowledge distillation
We perform sequence-level knowledge distillation on an English→German (en->de) dataset, using a teacher model provided by Helsinki-NLP (https://huggingface.co/Helsinki-NLP).

## Import all the libraries and set up the device and the parameters

In [1]:
import torch
import evaluate
import pandas as pd
import os
from transformers import MarianMTModel, MarianTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from utilities import model_n_parameters, model_size

# Parameters
src_lang = "en"
tgt_lang = "de"
lang_pair = src_lang + "-" + tgt_lang
dataset = "yhavinga/ccmatrix"
dataset_size = 100
# Root of the local datasets cache; "_<src>_<tgt>" is appended to it when the dataset is loaded.
# It can be overridden with the CCMATRIX_CACHE_DIR environment variable, so the notebook runs
# on other machines without editing this machine-specific default path.
cache_dir = os.environ.get("CCMATRIX_CACHE_DIR", "D:/MasterDegreeThesis/datasets/ccmatrix")
batch_size = 16
# When True, the teacher's translations are scored (BLEU / chrF) and the scores are stored in a csv file
evaluate_teacher = True

# Set-up device: the first CUDA GPU when available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on: {0}".format(device))

Working on: cpu


## Load the model and its tokenizer
The chosen model is based on the Transformer-base architecture, while the tokenizer is based on SentencePiece.

In [2]:
# The model and its tokenizer are fetched from the same Hugging Face Hub checkpoint
teacher_checkpoint = "Helsinki-NLP/opus-mt-{0}".format(lang_pair)
opus_mt_model: MarianMTModel = MarianMTModel.from_pretrained(teacher_checkpoint).to(device)
opus_mt_tokenizer: MarianTokenizer = MarianTokenizer.from_pretrained(teacher_checkpoint)

# Report the teacher's footprint using the helpers from the local `utilities` module
model_parameters, model_trainable_parameters = model_n_parameters(opus_mt_model)
opus_mt_size = model_size(opus_mt_model)
print("Model parameters: {0}\nModel trainable parameters: {1}".format(model_parameters, model_trainable_parameters))
print("Model size in mb: {0}\n".format(opus_mt_size))

Model parameters: 74410496
Model trainable parameters: 73886208
Model size in mb: 284.0751533508301



## Load the dataset and extract target sentences
In this example, we work with the first `dataset_size` sentence pairs of the train split of the CCMatrix dataset, available at https://huggingface.co/datasets/yhavinga/ccmatrix.

In [3]:
# Load the first `dataset_size` pairs of the train split, caching them in a per-language-pair folder.
# NOTE(review): `ignore_verifications` is deprecated in recent `datasets` releases in favour of
# `verification_mode="no_checks"`; it is kept here for the pinned library version — TODO confirm.
local_cache = "{0}_{1}_{2}".format(cache_dir, src_lang, tgt_lang)
split_spec = "train[:{0}]".format(dataset_size)
dataset_to_distill = load_dataset(
    dataset,
    lang_pair,
    cache_dir=local_cache,
    split=split_spec,
    ignore_verifications=True,
)

# Reference sentences in the target language, in the same order as the dataset rows
tgt_sentences = [pair[tgt_lang] for pair in dataset_to_distill["translation"]]

Found cached dataset ccmatrix (D:/MasterDegreeThesis/datasets/ccmatrix_en_de/yhavinga___ccmatrix/en-de/1.0.0/5f733aeea277b2b1bb792442ba120c0f7f4b1c7288897051bdf1e9865fe77b93)


## Compute translations
Use the model's built-in generator (`generate`) to predict output tokens with beam search ($beam\_size=4$). The beam size is not set explicitly in the code, so it comes from the model's configuration. The predicted tokens are then detokenized to build the distilled sentences.

In [4]:
# Build the dataloader over the translation pairs; each batch is indexed by language code
dataloader_opus_mt = DataLoader(dataset_to_distill["translation"], batch_size=batch_size)
translations = []
for batch in tqdm(dataloader_opus_mt, "Tokens prediction"):
    # Retrieve the sentences in the source language
    src_batch = batch[src_lang]

    # Tokenize the batch (padded to its longest sentence) and generate predictions;
    # this is the most time-consuming part
    batch_tokens = opus_mt_tokenizer(src_batch, padding=True, return_tensors="pt").to(device)
    output = opus_mt_model.generate(**batch_tokens, max_new_tokens=opus_mt_model.config.max_length)

    # Detokenize the model's output and add it straight to the flat list of translations,
    # so no re-flattening of per-batch lists is needed afterwards
    translations.extend(opus_mt_tokenizer.batch_decode(output, skip_special_tokens=True))

# Show one example. Index 12 is used when available; otherwise the last pair is used,
# so small datasets (fewer than 13 pairs) do not raise an IndexError.
# Labels and keys follow src_lang/tgt_lang instead of hard-coding "en"/"De".
if translations:
    sample_idx = min(12, len(translations) - 1)
    # Read a single row instead of materializing the whole "translation" column
    sample_src = dataset_to_distill[sample_idx]["translation"][src_lang]
    print("{0} sentence: {1}".format(src_lang.capitalize(), sample_src))
    print("{0} target sentence: {1}".format(tgt_lang.capitalize(), tgt_sentences[sample_idx]))
    print("{0} generated by the model: {1}".format(tgt_lang.capitalize(), translations[sample_idx]))

Tokens prediction: 100%|██████████| 7/7 [00:33<00:00,  4.73s/it]

En sentence: It is “before the Lord,” and therefore it ought to be before us.
De target sentence: Es ist "vor dem Herrn", und deshalb sollte es vor uns sein.
De generated by the model: Es ist vor dem Herrn, und deshalb sollte es vor uns sein.





## Evaluate teacher performance
We can also evaluate the teacher model on the dataset it has just distilled, by computing the BLEU and chrF scores of its translations against the original reference sentences.

In [5]:
# Evaluate the teacher's translations and store the scores in a csv file
if evaluate_teacher:
    scores_csv_path = "../data/distillation_teacher_scores.csv"

    bleu_metric = evaluate.load("bleu")
    chrf_metric = evaluate.load("chrf")
    # The "bleu" value is returned as a fraction, hence the * 100 to match chrF's 0-100 scale
    bleu_score = bleu_metric.compute(predictions=translations, references=tgt_sentences)["bleu"] * 100
    chrf_score = chrf_metric.compute(predictions=translations, references=tgt_sentences)["score"]

    # One-row table describing this run (scalars are broadcast against the single-element lists)
    df_scores = pd.DataFrame({
        "teacher_model": "Helsinki-NLP/opus-mt-{0}".format(lang_pair),
        "lang_pair": lang_pair,
        "dataset_size": dataset_size,
        "bleu": [bleu_score],
        "chrf": [chrf_score],
    })

    # Create the output directory on a fresh checkout instead of failing in to_csv
    os.makedirs(os.path.dirname(scores_csv_path), exist_ok=True)
    if os.path.exists(scores_csv_path):
        df_teacher_scores: pd.DataFrame = pd.read_csv(scores_csv_path, index_col=0)
        df_teacher_scores = pd.concat([df_teacher_scores, df_scores], ignore_index=True)
        # The new row is appended last: keep="last" makes a re-run of the same configuration
        # replace its stale scores instead of being silently discarded.
        # reset_index keeps the stored index contiguous after rows are dropped.
        df_teacher_scores = (
            df_teacher_scores
            .drop_duplicates(["teacher_model", "lang_pair", "dataset_size"], keep="last")
            .reset_index(drop=True)
        )
        df_teacher_scores.to_csv(scores_csv_path)
    else:
        df_scores.to_csv(scores_csv_path)

    print(df_scores)

                teacher_model lang_pair  dataset_size       bleu       chrf
0  Helsinki-NLP/opus-mt-en-de     en-de           100  51.141876  71.473842


## Save the translations
As a final step, and perhaps the most important one, we need to save the translations so that they can be used as the new ground truth for the student model. We recommend publishing the dataset on https://huggingface.co/ so that other users can work with it.

In [6]:
# Save the translations one per line, in the same order as the dataset rows
distilled_path = "../data/distilled_dataset_{0}_{1}.txt".format(src_lang, tgt_lang)
# Create the output directory on a fresh checkout instead of failing in open()
os.makedirs(os.path.dirname(distilled_path), exist_ok=True)
with open(distilled_path, "w", encoding="utf_8") as datafile:
    # A newline inside a translation would shift every following line and break the
    # line-to-sentence alignment of the file, so embedded newlines are flattened to spaces.
    datafile.writelines(translation.replace("\n", " ") + "\n" for translation in translations)