In [3]:
!pip install transformers
from transformers import TrainingArguments

class DistillationTrainingArguments(TrainingArguments):
    """`TrainingArguments` extended with two extra distillation settings.

    Attributes:
        alpha: value of the ``alpha`` keyword (default 0.5).
        temperature: value of the ``temperature`` keyword (default 2.0).

    In this notebook, `DistillationTrainer.compute_loss` reads both values
    through ``self.args``.
    """

    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        # Everything except the two extra keywords is handed to the base
        # class unchanged.
        super().__init__(*args, **kwargs)
        self.temperature, self.alpha = temperature, alpha

Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m57.6 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m40.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m120.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m104.2 MB/s[0m eta [36m0:00:00[0m
C

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

class DistillationTrainer(Trainer):
    """`Trainer` whose loss blends the student's own loss with a
    temperature-softened KL-divergence term against a teacher model.

    The blend weight and the temperature are read from ``self.args.alpha``
    and ``self.args.temperature``.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        """Forward all arguments to `Trainer` and keep a handle on the teacher.

        Args:
            *args, **kwargs: passed unchanged to ``Trainer.__init__``.
            teacher_model: model called with the same ``inputs`` as the
                student; its output must expose ``.logits``.
        """
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``alpha * student_loss + (1 - alpha) * T**2 * KL(teacher || student)``.

        Args:
            model: the student; called as ``model(**inputs)`` and expected to
                return an object with ``.loss`` and ``.logits``.
            inputs: mapping of keyword arguments fed to both student and
                teacher.
            return_outputs: when true, return ``(loss, student_outputs)``
                instead of the loss alone.
            num_items_in_batch: accepted only so the override stays callable
                by `Trainer` versions that pass this keyword; it is not used.

        Returns:
            The combined loss, or a ``(loss, student_outputs)`` tuple.
        """
        alpha = self.args.alpha
        temperature = self.args.temperature

        student_outputs = model(**inputs)
        # Loss reported by the student itself (presumably cross-entropy
        # against the labels in `inputs` -- depends on the model head).
        student_loss = student_outputs.loss

        # The teacher only provides targets, so no gradients are tracked.
        with torch.no_grad():
            teacher_logits = self.teacher_model(**inputs).logits

        # Soften both distributions with the temperature. kl_div expects
        # log-probabilities as input and probabilities as target;
        # "batchmean" divides the summed divergence by the size of dim 0.
        soft_student = F.log_softmax(student_outputs.logits / temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
        # T**2 rescales the soft-target term so its contribution does not
        # shrink as the temperature grows.
        loss_kd = temperature ** 2 * F.kl_div(
            soft_student, soft_teacher, reduction="batchmean"
        )

        loss = alpha * student_loss + (1 - alpha) * loss_kd
        return (loss, student_outputs) if return_outputs else loss

