In [27]:
!pip install --quiet transformers datasets accelerate>=0.20.1

In [2]:
from transformers import TrainingArguments

In [42]:
class KnowledgeDistillationTrainingArguments(TrainingArguments):
  def __init__(self, *args,  **kwargs):
    #*args allows us to pass a variable number of non-keyword arguments to a Python function.
    #**kwargs stands for keyword arguments. The only difference from args is that it uses keywords and returns the values in the form of a dictionary.
    super().__init__(*args, **kwargs)
    #The super() function is often used with the __init__() method to initialize the attributes of the parent class.
    self.alpha = 1
    self.temperature = 2

In [5]:
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

In [46]:
class KnowledgeDistillationTrainer(Trainer):
  def __init__(self, *args, teacher_model=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.teacher_model = teacher_model
    self.alpha = 1
    self.temperature = 2


  def compute_loss(self, model, inputs, return_outputs=False):
    #Extract cross-entropy loss and logits from student
    outputs_student = model(**inputs)
    loss_ce = outputs_student.loss
    logits_student = outputs_student.logits

    # Extract logits from teacher
    outputs_teacher = self.teacher_model(**inputs)
    logits_teacher = outputs_teacher.logits

     #Computing distillation loss by Softening probabilities
    loss_fct = nn.KLDivLoss(reduction="batchmean")
    #The reduction=batchmean argument in nn.KLDivLoss() specifies that we average the losses over the batch dimension.
    loss_kd = self.temperature ** 2 * loss_fct(
                F.log_softmax(logits_student / self.temperature, dim=-1),
                F.softmax(logits_teacher / self.temperature, dim=-1))

    # Return weighted student loss
    loss = self.alpha * loss_ce + (1. - self.alpha) * loss_kd
    return (loss, outputs_student) if return_outputs else loss

In [8]:
from datasets import load_dataset

In [9]:
clinic = load_dataset('clinc_oos', 'plus')

Downloading builder script:   0%|          | 0.00/8.57k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/14.4k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/23.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/291k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15250 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3100 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5500 [00:00<?, ? examples/s]

In [11]:
sample = clinic['train'][4]
sample

{'text': 'if i were mongolian, how would i say that i am a tourist',
 'intent': 61}

In [12]:
intents = clinic['train'].features['intent']
intent = intents.int2str(sample['intent'])
intent

'translate'

## Tokenize the data

In [13]:
from transformers import AutoTokenizer

In [14]:
student_checkpoint = 'distilbert-base-uncased'
student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [15]:
def tokenize_text(batch):
  return student_tokenizer(batch['text'], truncation = True)

In [17]:
clinc_tokenized = clinic.map(tokenize_text, batched=True, remove_columns=["text"])

#We will remove text column as we don't need it
#We will also rename the intent column to labels so it can be automatically detected by the trainer.
clinc_tokenized = clinc_tokenized.rename_column("intent", "labels")

Map:   0%|          | 0/15250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3100 [00:00<?, ? examples/s]

Map:   0%|          | 0/5500 [00:00<?, ? examples/s]

In [18]:
import numpy as np
from datasets import load_metric


In [19]:

accuracy_score = load_metric("accuracy")

def compute_metrics(pred):
  predictions, labels = pred
  predictions = np.argmax(predictions, axis=1)
  return accuracy_score.compute(predictions=predictions, references=labels)

  accuracy_score = load_metric("accuracy")


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

In [26]:
!pip install transformers[torch]



In [20]:
batch_size = 48
finetuned_student_ckpt = "distilbert-base-uncased-finetuned-clinc-student"

In [48]:
student_training_args = KnowledgeDistillationTrainingArguments(
    output_dir=finetuned_student_ckpt, evaluation_strategy = "epoch",
    num_train_epochs=1, learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size, weight_decay=0.01)

FrozenInstanceError: ignored

In [32]:
from transformers import pipeline

In [35]:
bert_ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
pipe = pipeline('text-classification', model=bert_ckpt)
id2label = pipe.model.config.id2label
label2id = pipe.model.config.label2id

Downloading (…)lve/main/config.json:   0%|          | 0.00/8.18k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/252 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Initialize the student Model

In [33]:
from transformers import AutoConfig

In [36]:
num_labels = intents.num_classes
student_config = (AutoConfig.from_pretrained(student_checkpoint, num_labels = num_labels, id2label = id2label, label2id=label2id))

In [37]:
import torch
from transformers import AutoModelForSequenceClassification


In [38]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def student_init():
  return(AutoModelForSequenceClassification.from_pretrained(student_checkpoint, config=student_config)).to(device)

In [39]:
teacher_checkpoint = "transformersbook/bert-base-uncased-finetuned-clinc"
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_checkpoint, num_labels = num_labels).to(device)

In [47]:
distilbert_trainer = KnowledgeDistillationTrainer(model_init=student_init,
        teacher_model=teacher_model, args=student_training_args,
        train_dataset=clinc_tokenized['train'], eval_dataset=clinc_tokenized['validation'],
        compute_metrics=compute_metrics, tokenizer=student_tokenizer)

distilbert_trainer.train()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.160308,0.572581


TrainOutput(global_step=318, training_loss=4.556366926469143, metrics={'train_runtime': 79.9938, 'train_samples_per_second': 190.64, 'train_steps_per_second': 3.975, 'total_flos': 82628707452228.0, 'train_loss': 4.556366926469143, 'epoch': 1.0})