# Distillation for Efficient BERT Models

In this tutorial, we will go over the high-level theory and implementation details of various distillation techniques for building efficient task-specific BERT models. 

This notebook was created by [Ganesh Jawahar](ganeshjwhr@gmail.com). Contact me for any questions or suggestions.

## Prerequisites
- [PyTorch](https://pytorch.org/)
- [Transformers](https://arxiv.org/abs/1706.03762)
- [BERT](https://arxiv.org/abs/1810.04805)

## Problem Setting (General)
The goal of distillation is to transfer knowledge from a large model (also known as teacher model) to a small model (also known as student model). Teacher model is typically large (in terms of model size) and hence needs large amount of memory to be loaded into device memory and slow to make predictions. On the other hand, student model is typically small and hence needs less memory to be loaded into device memory and relatively fast to make predictions. Usually, the student model is trained to reduce the distance between one or more of these:
1. Student and teacher prediction probabilities (**logits distillation**) 
2. Student and teacher hidden layer outputs (**hidden layer distillation**)
3. Student and teacher token embeddings (**embedding distillation**)
4. Student and teacher self-attention matrices (**self-attention distillation**)
5. Student and teacher value-value matrices (**value-relation distillation**)
6. Student prediction and gold label (**task-specific loss**) 

## Problem Setting (This Tutorial)

In this tutorial, we will focus on distilling a student model from the pre-trained BERT model (teacher model). Some of the existing research include:

1. **[DistilBERT](https://arxiv.org/abs/1910.01108)** - Triple loss: (1) Distillation over soft target probabilities of teacher, (2) Supervised training loss (MLM), and (3) Cosine embedding loss that aligns student and teacher last hidden states.
2. **[Patient Knowledge Distillation](https://arxiv.org/abs/1908.09355)** - Triple loss: (1) PKD-Last that has task-specific loss, (2) soft-target disitillation loss, and (3) hidden states MSE loss. PKD-Last distills from last k layers, while PKD-Skip distills from every k layers.
3. **[TinyBERT](https://arxiv.org/abs/1909.10351)** - Embedding distillation, hidden and attention distillation, logits distillation.
4. **[MiniLM](https://arxiv.org/abs/2002.10957)** - Self-attention and value-relation distillation.

![HAT block diagram](https://d3i71xaburhd42.cloudfront.net/459b34447952feebacc2f12778c539618e8a299f/5-Figure2-1.png)

(Picture courtesy: [A Survey on Model Compression for Natural Language Processing
](https://arxiv.org/abs/2202.07105))




## Solution (This Tutorial)

In this tutorial, we show the application of logits, hidden layer, and embedding to distill a 6-layer student model from 12-layer BERT-base model.

Let's start with the standard implementation of fine-tuning BERT model for a classification task.

In [62]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [63]:
import tensorflow
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from tqdm import tqdm, trange
import pandas as pd
import numpy as np
import io
import os
import matplotlib.pyplot as plt
from keras_preprocessing.sequence import pad_sequences
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, classification_report, confusion_matrix
import matplotlib
import matplotlib.pyplot as plt

## Set seed of randomization and working device
manual_seed = 77
torch.manual_seed(manual_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed(manual_seed)

cuda


Install `transformers` library

In [64]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Import necessary classes: `BertModel`, `BertForSequenceClassification`

In [65]:
from transformers import BertModel, BertTokenizerFast, BertForSequenceClassification

### Prepare data, load tokenizer

In this tutorial, the downstream task we will focus on is `sociality classification` (part of [CL-Aff shared task](https://sites.google.com/view/affcon2019/cl-aff-shared-task?authuser=0)). This task takes a piece of text and maps into one of the two classes: (1) the author of the text interacts with other people in the emotion situation and (2) the author of the text doesn't interact with other people. The train (`train.tsv`), validation (`dev.tsv`) and test size (`test.tsv`) consist of $8448$, $2112$ and $2112$ records respectively. 

In [66]:
# define a function for data preparation
def data_prepare(file_path, lab2ind, tokenizer, max_len = 32, mode = 'train'):
    '''
    file_path: the path to input file. 
                In train mode, the input must be a tsv file that includes two columns where the first is text, and second column is label.
                The first row must be header of columns.

                In predict mode, the input must be a tsv file that includes only one column where the first is text.
                The first row must be header of column.

    lab2ind: dictionary of label classes
    tokenizer: BERT tokenizer
    max_len: maximal length of input sequence
    mode: train or predict
    '''
    # if we are in train mode, we will load two columns (i.e., text and label).
    if mode == 'train':
        # Use pandas to load dataset
        df = pd.read_csv(file_path, delimiter='\t',header=0, names=['content','label'])
        print("Data size ", df.shape)
        labels = df.label.values
        
        # Create sentence and label lists
        labels = [lab2ind[i] for i in labels] 
        print("Label is ", labels[0])
        
        # Convert data into torch tensors
        labels = torch.tensor(labels)

    # if we are in predict mode, we will load one column (i.e., text).
    elif mode == 'predict':
        df = pd.read_csv(file_path, delimiter='\t',header=0, names=['content'])
        print("Data size ", df.shape)
        # create placeholder
        labels = []
    else:
        print("the type of mode should be either 'train' or 'predict'. ")
        return
        
    # Create sentence and label lists
    content = df.content.values

    #### REF START ####

    # We need to add a special token at the beginning for BERT to work properly.
    content = ["[CLS] " + text for text in content]

    # Import the BERT tokenizer, used to convert our text into tokens that correspond to BERT's vocabulary.
    tokenized_texts = [tokenizer.tokenize(text) for text in content]
    
    # if the sequence is longer the maximal length, we truncate it to the pre-defined maximal length
    tokenized_texts = [ text[:max_len+1] for text in tokenized_texts]

    # We also need to add a special token at the end.
    tokenized_texts = [ text+['[SEP]'] for text in tokenized_texts]
    print ("Tokenize the first sentence:\n",tokenized_texts[0])
    
    # Use the BERT tokenizer to convert the tokens to their index numbers in the BERT vocabulary
    input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
    print ("Index numbers of the first sentence:\n",input_ids[0])

    # Pad our input seqeunce to the fixed length (i.e., max_len) with index of [PAD] token
    pad_ind = tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
    input_ids = pad_sequences(input_ids, maxlen=max_len+2, dtype="long", truncating="post", padding="post", value=pad_ind)
    print ("Index numbers of the first sentence after padding:\n",input_ids[0])

    # Create attention masks
    attention_masks = []

    # Create a mask of 1s for each token followed by 0s for pad tokens
    for seq in input_ids:
        seq_mask = [float(i>0) for i in seq]
        attention_masks.append(seq_mask)

    # Convert all of our data into torch tensors, the required datatype for our model
    inputs = torch.tensor(input_ids)
    masks = torch.tensor(attention_masks)
    #### REF END ####

    return inputs, labels, masks


In [67]:
# set the model path
model_path = "bert-base-uncased"

# label dictionary
lab2ind = {'no': 0, 'yes': 1}

# tokenizer for pre-trained BERT model
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)

# preprocess the data
train_inputs, train_labels, train_masks = data_prepare("./drive/My Drive/Colab Notebooks/happy_db/train.tsv", lab2ind,tokenizer)
validation_inputs, validation_labels, validation_masks = data_prepare("./drive/My Drive/Colab Notebooks/happy_db/dev.tsv", lab2ind,tokenizer)

# set batch size
batch_size = 32

# take training samples in random order in each epoch. 
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_dataloader = DataLoader(train_data, 
                              sampler = RandomSampler(train_data), # Select batches randomly
                              batch_size=batch_size)

# Read validation set sequentially.
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_dataloader = DataLoader(validation_data, 
                                   sampler = SequentialSampler(validation_data), # Pull out batches sequentially.
                                   batch_size=batch_size)


Data size  (8448, 2)
Label is  1
Tokenize the first sentence:
 ['[CLS]', 'it', 'was', 'my', 'birthday', ',', 'and', 'my', 'wife', 'and', 'daughter', 'surprised', 'me', 'with', 'some', 'surprise', 'guests', 'and', 'a', 'small', 'party', '.', '[SEP]']
Index numbers of the first sentence:
 [101, 2009, 2001, 2026, 5798, 1010, 1998, 2026, 2564, 1998, 2684, 4527, 2033, 2007, 2070, 4474, 6368, 1998, 1037, 2235, 2283, 1012, 102]
Index numbers of the first sentence after padding:
 [ 101 2009 2001 2026 5798 1010 1998 2026 2564 1998 2684 4527 2033 2007
 2070 4474 6368 1998 1037 2235 2283 1012  102    0    0    0    0    0
    0    0    0    0    0    0]
Data size  (1056, 2)
Label is  1
Tokenize the first sentence:
 ['[CLS]', 'my', 'baby', 'took', 'a', '1', '.', '5', 'hour', 'nap', 'instead', 'of', 'a', '20', '##min', '##ute', 'nap', 'and', 'i', 'was', 'able', 'to', 'get', 'some', 'things', 'done', '!', '[SEP]']
Index numbers of the first sentence:
 [101, 2026, 3336, 2165, 1037, 1015, 1012, 1019, 

### Load BERT teacher model, criterion and finetuning settings

Let us load the BERT teacher model with the pre-trained weights.

In [68]:
class Bert_cls(nn.Module):

    def __init__(self, lab2ind, model_path, hidden_size):
        super(Bert_cls, self).__init__()
        self.model_path = model_path
        self.hidden_size = hidden_size
        self.bert_model = BertModel.from_pretrained(model_path, output_hidden_states=True, output_attentions=True)
        
        self.label_num = len(lab2ind)
        
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.hidden_size, self.label_num)

    def forward(self, bert_ids, bert_mask):
        outputs = self.bert_model(input_ids=bert_ids, attention_mask = bert_mask)
        pooler_output = outputs['pooler_output']
        attentions = outputs['attentions']
        
        x = self.dense(pooler_output)
        x = torch.tanh(x)
        x = self.dropout(x)
        fc_output = self.fc(x)

        return fc_output, outputs

bert_teacher_model = Bert_cls(lab2ind, model_path, 768).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Print structure of the teacher model (12 layers, 768 hidden, 3072 FFN. Inter.):

In [69]:
print(bert_teacher_model)

Bert_cls(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru

Let us set the fine-tuning settings:

In [70]:
# fine-tuning hyper-parameters:
lr = 2e-5
max_grad_norm = 1.0
epochs = 3
warmup_proportion = 0.1
num_training_steps  = len(train_dataloader) * epochs
num_warmup_steps = num_training_steps * warmup_proportion

# create the optimizer
from transformers import AdamW, get_linear_schedule_with_warmup
optimizer = AdamW(bert_teacher_model.parameters(), lr=lr, correct_bias=False)

# set the learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)  # PyTorch scheduler




Let us set the task specific criterion:

In [71]:
# create the criterion 
task_specific_criterion = nn.CrossEntropyLoss()

### Setup the model training and evaluation functions


Let us create the training logic for single epoch:


In [72]:
def train(model, iterator, optimizer, scheduler, criterion):
    
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        input_ids, input_mask, labels = batch

        outputs,_ = model(input_ids, input_mask)

        loss = criterion(outputs, labels)
        # delete used variables to free GPU memory
        del batch, input_ids, input_mask, labels
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.cpu().item()
        optimizer.zero_grad()
    
    # free GPU memory
    if device == 'cuda':
        torch.cuda.empty_cache()

    return epoch_loss / len(iterator)

Let us create the evaluation logic for a single evaluation:

In [73]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    all_pred=[]
    all_label = []
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            # Add batch to GPU
            batch = tuple(t.to(device) for t in batch)
            # Unpack the inputs from our dataloader
            input_ids, input_mask, labels = batch

            outputs,_ = model(input_ids, input_mask)
            
            loss = criterion(outputs, labels)

            # delete used variables to free GPU memory
            del batch, input_ids, input_mask
            epoch_loss += loss.cpu().item()

            # identify the predicted class for each example in the batch
            probabilities, predicted = torch.max(outputs.cpu().data, 1)
            # put all the true labels and predictions to two lists
            all_pred.extend(predicted)
            all_label.extend(labels.cpu())
    
    accuracy = accuracy_score(all_label, all_pred)
    f1score = f1_score(all_label, all_pred, average='macro') 
    return epoch_loss / len(iterator), accuracy, f1score

### Fine-tune the BERT teacher model on the classification task

In [74]:
# Train the model
bert_teacher_checkpoint_dir = "./drive/My Drive/Colab Notebooks/teacher_ckpt"

for epoch in trange(epochs, desc="Epoch"):
    train_loss = train(bert_teacher_model, train_dataloader, optimizer, scheduler, task_specific_criterion)  
    val_loss, val_acc, val_f1 = evaluate(bert_teacher_model, validation_dataloader, task_specific_criterion)

    # Create checkpoint at end of each epoch
    state = {
      'epoch': epoch,
      'state_dict': bert_teacher_model.state_dict(),
      'optimizer': optimizer.state_dict(),
      'scheduler': scheduler.state_dict()
    }

    torch.save(state, bert_teacher_checkpoint_dir + "/BERT_"+str(epoch+1)+".pt")

    print('\n Epoch [{}/{}], Train Loss: {:.4f}, Validation Loss: {:.4f}, Validation Accuracy: {:.4f}, Validation F1: {:.4f}'.format(epoch+1, epochs, train_loss, val_loss, val_acc, val_f1))
    
print("Fine-tuning completed.")


Epoch:  33%|███▎      | 1/3 [01:01<02:03, 61.98s/it]


 Epoch [1/3], Train Loss: 0.2758, Validation Loss: 0.1973, Validation Accuracy: 0.9318, Validation F1: 0.9313


Epoch:  67%|██████▋   | 2/3 [02:03<01:01, 61.83s/it]


 Epoch [2/3], Train Loss: 0.1435, Validation Loss: 0.2111, Validation Accuracy: 0.9394, Validation F1: 0.9390


Epoch: 100%|██████████| 3/3 [03:05<00:00, 61.67s/it]


 Epoch [3/3], Train Loss: 0.0906, Validation Loss: 0.2283, Validation Accuracy: 0.9403, Validation F1: 0.9399
Fine-tuning completed.





### Student model initialization

Usually, the student BERT model is initialized with every k layers of teacher BERT model. In this tutorial, we will choose k as 3, that is, student BERT model is 4 layers (mapped to layer 2, 5, 8, 11 of teacher).

Let us try to create this student model from the teacher model.

In [75]:
def create_student_model(teacher_model, student_to_teacher_layer):
  # adapted from https://github.com/huggingface/transformers/issues/2483

  oldModuleList = teacher_model.bert_model.encoder.layer
  newModuleList = nn.ModuleList()

  # copy the layers to keep
  for student_layer_id in range(len(student_to_teacher_layer)):
    newModuleList.append(oldModuleList[student_to_teacher_layer[student_layer_id]])
  
  # create a copy of the model, modify it with the new list, and return
  import copy
  student_model = copy.deepcopy(teacher_model)
  student_model.bert_model.encoder.layer = newModuleList

  return student_model

# layer_mapping from student to teacher (and reverse)
student_to_teacher_layer = {0:2, 1:5, 2:8, 3:11}
teacher_to_student_layer = {2:0, 5:1, 8:2, 11:3}

# creates the student model
bert_student_model = create_student_model(bert_teacher_model, student_to_teacher_layer)

# print the structure of the student model
print(bert_student_model)


Bert_cls(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru

### Logits distillation

The logits based distillation loss can be simply defined as:

$\mathcal{L}_{pred} = $Cross-Entropy$(\frac{\textbf{z}^T}{t}, \frac{\textbf{z}^S}{t})$ 

where $t$ corresponds to the temperature (usually set to 1). $\textbf{z}^T$ and $\textbf{z}^S$ correspond to teacher model and student model prediction logits respectively. 

Let us augment the fine-tuning loss for the student model with the logits-based distillation loss.




In [76]:
import torch.nn.functional as F

def train_with_logitsdistill(student_model, iterator, optimizer, scheduler, criterion, teacher_model):
    
    student_model.train()
    teacher_model.eval()
    epoch_loss = 0
    distillation_loss = 0
    
    for i, batch in enumerate(iterator):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)

        # Unpack the inputs from our dataloader
        input_ids, input_mask, labels = batch

        student_logits, _ = student_model(input_ids, input_mask)

        # standard supervised loss
        loss = criterion(student_logits, labels)

        # distillation loss
        teacher_logits, _ = teacher_model(input_ids, input_mask)
        # compute KL-div of student logits and teacher logits
        # https://raw.githubusercontent.com/liuzechun/ReActNet/master/utils/KD_loss.py
        model_output_log_prob = F.log_softmax(student_logits, dim=1)
        model_output_log_prob = model_output_log_prob.unsqueeze(2)

        real_output_soft = F.softmax(teacher_logits, dim=1)
        real_output_soft = real_output_soft.unsqueeze(1)
        cross_entropy_loss = -torch.bmm(real_output_soft, model_output_log_prob)
        cross_entropy_loss = cross_entropy_loss.mean()
        distillation_loss += cross_entropy_loss.item()
        loss += cross_entropy_loss

        # delete used variables to free GPU memory
        del batch, input_ids, input_mask, labels
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.cpu().item()
        optimizer.zero_grad()
    
    # free GPU memory
    if device == 'cuda':
        torch.cuda.empty_cache()

    return epoch_loss / len(iterator), distillation_loss / len(iterator)

# Train the student model
student_checkpoint_dir = "./drive/My Drive/Colab Notebooks/student_ckpt"

optimizer = AdamW(bert_student_model.parameters(), lr=lr, correct_bias=False)

# set the learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)  # PyTorch scheduler

for epoch in trange(epochs, desc="Epoch"):
    train_loss, distill_loss = train_with_logitsdistill(bert_student_model, train_dataloader, optimizer, scheduler, task_specific_criterion, bert_teacher_model)  
    val_loss, val_acc, val_f1 = evaluate(bert_student_model, validation_dataloader, task_specific_criterion)

    # Create checkpoint at end of each epoch
    state = {
      'epoch': epoch,
      'state_dict': bert_teacher_model.state_dict(),
      'optimizer': optimizer.state_dict(),
      'scheduler': scheduler.state_dict()
    }

    torch.save(state, bert_teacher_checkpoint_dir + "/BERT_"+str(epoch+1)+".pt")

    print('\n Epoch [{}/{}], Train Loss: {:.4f}, Train Distill Loss: {:.4f}, Validation Loss: {:.4f}, Validation Accuracy: {:.4f}, Validation F1: {:.4f}'.format(epoch+1, epochs, train_loss, distill_loss, val_loss, val_acc, val_f1))
    
print("Fine-tuning completed.")

Epoch:  33%|███▎      | 1/3 [01:07<02:14, 67.45s/it]


 Epoch [1/3], Train Loss: 0.4314, Train Distill Loss: 0.1869, Validation Loss: 0.2402, Validation Accuracy: 0.9252, Validation F1: 0.9249


Epoch:  67%|██████▋   | 2/3 [02:15<01:07, 67.76s/it]


 Epoch [2/3], Train Loss: 0.2125, Train Distill Loss: 0.0715, Validation Loss: 0.2381, Validation Accuracy: 0.9280, Validation F1: 0.9277


Epoch: 100%|██████████| 3/3 [03:23<00:00, 67.89s/it]


 Epoch [3/3], Train Loss: 0.1442, Train Distill Loss: 0.0485, Validation Loss: 0.2789, Validation Accuracy: 0.9299, Validation F1: 0.9296
Fine-tuning completed.





### Hidden and Embedding distillation

The hidden state distillation loss can be simply defined as:

$\mathcal{L}_{hidn} = $Mean-Squared-Error$(\textbf{H}^S, \textbf{H}^T)$ 

where $\textbf{H}^S$ corresponds to the student hidden states ($\#layers \times $hidden-dim) and $\textbf{H}^T$ corresponds to the teacher hidden states ($\#layers \times $hidden-dim).

The embedding distillation loss can be simply defined as:

$\mathcal{L}_{embd} = $Mean-Squared-Error$(\textbf{e}^S, \textbf{e}^T)$

where $\textbf{e}^S$ and $\textbf{e}^T$ correspond to student and teacher input embeddings respectively.

Let us augment the fine-tuning loss for the student model with the hidden state and embedding distillation loss.




In [77]:
import torch.nn.functional as F

def train_with_hidembddistill(student_model, iterator, optimizer, scheduler, criterion, teacher_model):
    
    student_model.train()
    teacher_model.eval()
    epoch_loss = 0
    distillation_loss = 0
    
    for i, batch in enumerate(iterator):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)

        # Unpack the inputs from our dataloader
        input_ids, input_mask, labels = batch

        student_logits, student_outputs  = student_model(input_ids, input_mask)

        # standard supervised loss
        loss = criterion(student_logits, labels)

        # distillation loss
        teacher_logits, teacher_outputs = teacher_model(input_ids, input_mask)
        full_loss = 0.0
        # hidden states
        for student_layer_id in student_to_teacher_layer:
          teacher_layer_id = student_to_teacher_layer[student_layer_id]
          non_trainable_layernorm = nn.LayerNorm(teacher_outputs["hidden_states"][teacher_layer_id+1].shape[1:], elementwise_affine=False)
          teacher_hidden, student_hidden  = teacher_outputs["hidden_states"][teacher_layer_id+1], student_outputs["hidden_states"][student_layer_id+1]
          teacher_hidden = non_trainable_layernorm(teacher_hidden) 
          student_hidden = non_trainable_layernorm(student_hidden)
          cur_fkt = nn.MSELoss()(teacher_hidden, student_hidden)
          full_loss += cur_fkt
        # embedding 
        teacher_layer_id = student_to_teacher_layer[0]
        non_trainable_layernorm = nn.LayerNorm(teacher_outputs["hidden_states"][0].shape[1:], elementwise_affine=False)
        teacher_hidden, student_hidden  = teacher_outputs["hidden_states"][0], student_outputs["hidden_states"][0]
        teacher_hidden = non_trainable_layernorm(teacher_hidden) 
        student_hidden = non_trainable_layernorm(student_hidden)
        cur_fkt = nn.MSELoss()(teacher_hidden, student_hidden)
        full_loss += cur_fkt
        # full distillation loss
        full_loss = full_loss / (float(len(student_to_teacher_layer))+1)
        distillation_loss += full_loss.item()
        loss += full_loss

        # delete used variables to free GPU memory
        del batch, input_ids, input_mask, labels
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.cpu().item()
        optimizer.zero_grad()
    
    # free GPU memory
    if device == 'cuda':
        torch.cuda.empty_cache()

    return epoch_loss / len(iterator), distillation_loss / len(iterator)

# Train the student model
student_checkpoint_dir = "./drive/My Drive/Colab Notebooks/student_ckpt"

optimizer = AdamW(bert_student_model.parameters(), lr=lr, correct_bias=False)

# set the learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)  # PyTorch scheduler

for epoch in trange(epochs, desc="Epoch"):
    train_loss, distill_loss = train_with_hidembddistill(bert_student_model, train_dataloader, optimizer, scheduler, task_specific_criterion, bert_teacher_model)  
    val_loss, val_acc, val_f1 = evaluate(bert_student_model, validation_dataloader, task_specific_criterion)
    
    # Create checkpoint at end of each epoch
    state = {
      'epoch': epoch,
      'state_dict': bert_teacher_model.state_dict(),
      'optimizer': optimizer.state_dict(),
      'scheduler': scheduler.state_dict()
    }

    torch.save(state, bert_teacher_checkpoint_dir + "/BERT_"+str(epoch+1)+".pt")

    print('\n Epoch [{}/{}], Train Loss: {:.4f}, Train Distill Loss: {:.4f}, Validation Loss: {:.4f}, Validation Accuracy: {:.4f}, Validation F1: {:.4f}'.format(epoch+1, epochs, train_loss, distill_loss, val_loss, val_acc, val_f1))
    
print("Fine-tuning completed.")

Epoch:  33%|███▎      | 1/3 [01:09<02:19, 69.71s/it]


 Epoch [1/3], Train Loss: 0.3386, Train Distill Loss: 0.2051, Validation Loss: 0.2722, Validation Accuracy: 0.9176, Validation F1: 0.9168


Epoch:  67%|██████▋   | 2/3 [02:19<01:09, 69.78s/it]


 Epoch [2/3], Train Loss: 0.1771, Train Distill Loss: 0.0948, Validation Loss: 0.3416, Validation Accuracy: 0.9176, Validation F1: 0.9169


Epoch: 100%|██████████| 3/3 [03:29<00:00, 69.83s/it]


 Epoch [3/3], Train Loss: 0.1277, Train Distill Loss: 0.0800, Validation Loss: 0.3486, Validation Accuracy: 0.9186, Validation F1: 0.9180
Fine-tuning completed.





That's it!