This is a rudimentary implementation of the process described in the paper that attempts to distill a pretrained bert into a smaller model of the same architecture. Creating an actually efficient distillation of a SOTA model would of course be a research project in and of itself, but this is an outline of the process.

### Shared tokenizer

In [1]:
#create shared tokenizer
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

  from .autonotebook import tqdm as notebook_tqdm


### Teacher model

In [2]:
#create pretrained teacher model
from transformers import BertModel
teacher = BertModel.from_pretrained('bert-base-uncased')



### Student Model

In [54]:
from transformers import BertConfig
from math import ceil
from torch import nn

#get default settings
config = BertConfig.from_pretrained('bert-base-uncased')

#shrink down the model
config.num_hidden_layers = ceil(config.num_hidden_layers*0.6)
config.hidden_size = ceil(config.hidden_size*0.6)
config.num_attention_heads = ceil(config.num_attention_heads*0.6)
config.intermediate_size = ceil(config.intermediate_size*0.6)

#round hidden size to multiple of num attention heads as required by model
config.hidden_size = (config.hidden_size//config.num_attention_heads)*config.num_attention_heads

#base model
student_bert = BertModel(config)

#projected model
class ProjectedSmallBert(nn.Module):
    def __init__(self, bert, output_dim):
        super().__init__()
        self.bert = bert
        self.proj = nn.Linear(config.hidden_size, output_dim)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        return self.proj(outputs.last_hidden_state)

#Linearly project from smaller hidden size to that of the teacher so that final embeddings match
#In practice, this may be done with more MLP layers, but this is just for demonstration
student = ProjectedSmallBert(student_bert, 768) 

### Data

In [7]:
from datasets import load_dataset

#use sst2 just for demo purposes
dataset = load_dataset('glue', 'sst2')

Downloading readme: 100%|██████████| 35.3k/35.3k [00:00<00:00, 82.2MB/s]
Downloading data: 100%|██████████| 3.11M/3.11M [00:00<00:00, 7.67MB/s]
Downloading data: 100%|██████████| 72.8k/72.8k [00:00<00:00, 412kB/s]
Downloading data: 100%|██████████| 148k/148k [00:00<00:00, 655kB/s]
Generating train split: 100%|██████████| 67349/67349 [00:00<00:00, 2436745.68 examples/s]
Generating validation split: 100%|██████████| 872/872 [00:00<00:00, 659710.15 examples/s]
Generating test split: 100%|██████████| 1821/1821 [00:00<00:00, 988843.55 examples/s]


In [44]:
#tokenize data
def tokenize(examples):
    tokenized = tokenizer(examples['sentence'], padding=True, truncation=True, max_length=128)
    examples.pop('sentence')
    return tokenized
dataset = dataset.map(tokenize, batched=True)


Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map: 100%|██████████| 67349/67349 [00:05<00:00, 12625.18 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 7674.85 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 7985.17 examples/s]


In [47]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

collate = DataCollatorWithPadding(tokenizer=tokenizer)

#create dataloaders
train_loader = DataLoader(dataset['train'], batch_size=8, shuffle=True, collate_fn=collate)
eval_loader = DataLoader(dataset['validation'], batch_size=8, collate_fn=collate)
# train_loader = DataLoader(dataset['train'], batch_size=8, shuffle=True)
# eval_loader = DataLoader(dataset['validation'], batch_size=8)

### Train

In [59]:
from torch.optim import Adam
from torch.nn import KLDivLoss
from torch.nn.functional import log_softmax, softmax
import torch
from tqdm import tqdm

NUM_EPOCHS = 10
LR = 1e-4
TEMP = 3 # >1 temp as per paper

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device', device)

optimizer = Adam(student.parameters(), LR)
loss = KLDivLoss(reduction='batchmean') #KL Divergence loss for logits as per paper

teacher.eval() #will only be doing inference on teacher
student.train() 

teacher.to(device)
student.to(device)

for epoch in range(NUM_EPOCHS):
    for batch in tqdm(train_loader):
        inputs = batch['input_ids'].squeeze(1).to(device)
        attention_mask = batch['attention_mask'].squeeze(1).to(device)    #inputs were padded

        #pass input thru teacher first
        with torch.no_grad():
            teacher_outputs = teacher(input_ids=inputs, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.last_hidden_state

        soft_targets = softmax(teacher_logits / TEMP, dim=-1)

        #now student
        student_output = student(input_ids=inputs, attention_mask=attention_mask)

        #student has to be log softmaxed to work with KL Divergence
        student_logits = log_softmax(student_output, dim=-1)

        #calculate loss
        l = loss(student_logits, soft_targets) * TEMP**2

        #backprop
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    
    
    print(f'Epoch {epoch+1}/{NUM_EPOCHS} Loss: {l.item()}')


Device cuda


100%|██████████| 8419/8419 [02:07<00:00, 66.04it/s]


Epoch 1/10 Loss: 1.7077540159225464


100%|██████████| 8419/8419 [02:07<00:00, 65.92it/s]


Epoch 2/10 Loss: 1.9323749542236328


100%|██████████| 8419/8419 [02:07<00:00, 65.84it/s]


Epoch 3/10 Loss: 1.332496166229248


100%|██████████| 8419/8419 [02:07<00:00, 65.83it/s]


Epoch 4/10 Loss: 1.3278758525848389


100%|██████████| 8419/8419 [02:08<00:00, 65.76it/s]


Epoch 5/10 Loss: 1.1895263195037842


  6%|▌         | 494/8419 [00:07<02:00, 65.63it/s]


KeyboardInterrupt: 