In [19]:
!pip install -q transformers datasets evaluate sacrebleu torch sentencepiece rouge_score onnxruntime onnx

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m110.3 MB/s[0m eta [36m0:00:00[0m
[?25h

# Final model lies here: https://disk.yandex.com/d/kQm4p7UOrYK1qA

In [2]:
import torch
from torch.optim import AdamW
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import load_dataset
import evaluate
import time
import os
from tqdm import tqdm
import numpy as np

In [3]:
MODEL_TEACHER = "google-t5/t5-base"
MODEL_STUDENT = "google-t5/t5-small"
DATASET = "wmt16"
DATASET_CONFIG = "ro-en"
BATCH_SIZE = 4
EPOCHS = 3
LEARNING_RATE = 10e-5
MAX_LENGTH = 256

In [None]:
dataset = load_dataset(DATASET, DATASET_CONFIG)
tokenizer = T5Tokenizer.from_pretrained(MODEL_TEACHER)

In [None]:
split_dataset = dataset["train"].train_test_split(test_size=0.1)
train_dataset = split_dataset["train"].select(range(10000))
val_dataset = split_dataset["test"].select(range(300))

def preprocess_function(examples):
    inputs = ["translate Romanian to English: " + ex["ro"] for ex in examples["translation"]]
    targets = [ex["en"] for ex in examples["translation"]]
    model_inputs = tokenizer(
        inputs, max_length=MAX_LENGTH, truncation=True
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets, max_length=MAX_LENGTH, truncation=True
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)

In [6]:
## # take only 10 000 samples
# tokenized_train = tokenized_train.select(range(10000))
# tokenized_val = tokenized_val.select(range(1000))

In [7]:
teacher = T5ForConditionalGeneration.from_pretrained(MODEL_TEACHER)
student = T5ForConditionalGeneration.from_pretrained(MODEL_STUDENT)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)
optimizer = AdamW(student.parameters(), lr=LEARNING_RATE)
kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
ce_loss = torch.nn.CrossEntropyLoss()
student.gradient_checkpointing_enable()

In [8]:
class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        if hasattr(self.teacher, "config"):
          # Configure teacher to use less memory
          self.teacher.config.use_cache = False

        self.kl_loss = torch.nn.KLDivLoss(reduction="batchmean")

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        # Save labels before popping
        labels = inputs.pop("labels")
        outputs = model(**inputs, labels=labels)
        student_loss = outputs.loss
        student_logits = outputs.logits
        # Move the teacher model to the same device as the student model
        self.teacher.to(model.device)
        if not self.model.training:
          with torch.no_grad():
              # Process on the same device as the student model
              inputs_device = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
              # Get labels and move to the same device as the student model
              labels_device = labels.to(model.device) if isinstance(labels, torch.Tensor) else labels
              # No need to pop again, use saved labels
              teacher_outputs = self.teacher(**inputs_device, labels=labels_device)
              teacher_logits = teacher_outputs.logits
        else:
              # Normal processing during training
              # No need to pop again, use saved labels
              with torch.no_grad():
                  teacher_outputs = self.teacher(**inputs, labels=labels)
                  teacher_logits = teacher_outputs.logits

        # Calculate KL divergence
        kl_loss = self.kl_loss(
            torch.nn.functional.log_softmax(student_logits, dim=-1),
            torch.nn.functional.softmax(teacher_outputs.logits, dim=-1)
        )

        # Combine losses
        total_loss = 0.7 * kl_loss + 0.3 * student_loss

        return (total_loss, outputs) if return_outputs else total_loss

    def evaluation_loop(self, *args, **kwargs):
      output = super().evaluation_loop(*args, **kwargs)
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
      return output

In [9]:
training_args = TrainingArguments(
    output_dir="./content/distill_results",
    # no way to cram the eval into colab gpu
    eval_strategy="no",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=4,
    # max_grad_norm=1.0,  # Add gradient clipping
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    save_total_limit=3,
    logging_steps=250,
    report_to="none",
)

# Metrics for Validation
sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")

In [10]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Flatten predictions if they are nested lists
    preds = [item for sublist in preds for item in sublist]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute BLEU
    bleu_result = sacrebleu.compute(
        predictions=decoded_preds,
        references=[[label] for label in decoded_labels]
    )

    # Compute ROUGE
    rouge_result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels
    )

    return {
        "bleu": bleu_result["score"],
        "rouge": rouge_result["rougeL"],
    }

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher.to(device).eval()
student.to(device).train()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [12]:
trainer = DistillationTrainer(
    teacher_model=teacher,
    model=student,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train with progress bar
trainer.train()


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
250,29.8435
500,25.1205
750,24.1756
1000,23.4791
1250,22.4075
1500,22.5606
1750,21.4965
2000,21.1894
2250,20.384
2500,20.7939


TrainOutput(global_step=7500, training_loss=20.023592578125, metrics={'train_runtime': 1316.2839, 'train_samples_per_second': 22.791, 'train_steps_per_second': 5.698, 'total_flos': 598039470538752.0, 'train_loss': 20.023592578125, 'epoch': 3.0})

In [26]:
t_data = dataset["test"].select(range(20))
def evaluate_model(model, dataset):
    model.eval()
    latencies = []
    predictions = []
    references = []

    for example in tqdm(t_data, desc="Evaluating"):
        input_text = "translate Romanian to English: " + example["translation"]["ro"]
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        start = time.time()
        outputs = model.generate(**inputs, max_length=MAX_LENGTH)
        latencies.append(time.time() - start)

        predictions.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        references.append([example["translation"]["en"]])

    return {
        "bleu": sacrebleu.compute(predictions=predictions, references=references)["score"],
        "rouge": rouge.compute(predictions=predictions, references=references)["rougeL"],
        "avg_latency": sum(latencies)/len(latencies),
    }

# Save and evaluate
trainer.save_model("./content/distill_results")
student.to(device).eval()
final_metrics = evaluate_model(student, dataset)
teacher_metrics = evaluate_model(teacher, dataset)
print(f"\nFinal Evaluation Results:\nStudent:{final_metrics}\n\nTeacher:{teacher_metrics}")

Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.46it/s]
Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


Final Evaluation Results:
Student:{'bleu': 1.2400353466674146, 'rouge': np.float64(0.08968097626320509), 'avg_latency': 0.39914491176605227}

Teacher:{'bleu': 0.9846649936268649, 'rouge': np.float64(0.07925201387926117), 'avg_latency': 0.6880018234252929}





In [20]:
import onnxruntime as ort

# Load the saved student model
student = T5ForConditionalGeneration.from_pretrained("./content/distill_results")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")

# Example input
text = "translate Romanian to English: Bună ziua!"
inputs = tokenizer(text, return_tensors="pt")

# Convert to ONNX
dummy_input = {k: v.to('cpu') for k, v in inputs.items()}
dummy_decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long)
dummy_input["decoder_input_ids"] = dummy_decoder_input_ids
torch.onnx.export(
    student,
    tuple(dummy_input.values()),
    "./content/student_model.onnx",
    input_names=list(dummy_input.keys()),
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "output": {0: "batch_size", 1: "sequence_length"}
    }
)