In [1]:
!pip install --upgrade transformers datasets fsspec



In [2]:
import warnings
# Suppress every FutureWarning from this point on so cell outputs stay readable.
warnings.filterwarnings('ignore', category=FutureWarning)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding

In [4]:
# Hyperparameters and runtime configuration for the distillation run.
seed = 42            # global torch RNG seed (new classifier-head weights, DataLoader shuffling)
batch_size = 16      # examples per batch in every DataLoader
lr = 5e-5            # AdamW learning rate for the student
epochs = 1           # number of passes over the training subset
temperature = 2.0    # softmax temperature applied to teacher and student logits
alpha_soft = 0.5     # weight of the soft (KL) loss; the hard (CE) loss gets 1 - alpha_soft
max_len = 128        # tokenizer truncation length
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch once, up front, so that Restart & Run All reproduces the same run.
torch.manual_seed(seed)

In [5]:
# Load the "sentiment" configuration of the tweet_eval dataset; individual splits are indexed by name.
raw = load_dataset("tweet_eval", "sentiment")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
# Metadata object of the "label" column of the training split (exposes the class names via `.names`).
label_feature = raw["train"].features["label"]

In [7]:
# Show the human-readable class names in label-id order.
print("Label Names: ", label_feature.names)

Label Names:  ['negative', 'neutral', 'positive']


In [8]:
# Shuffle the training split with a fixed seed and keep the first 2,500 examples as the training subset.
train = raw["train"].shuffle(seed=42).select(range(2500))

In [9]:
# The validation split is used in full (no shuffling, no subsampling).
val = raw["validation"]

# Loading the Tokenizer

In [10]:
# One tokenizer ("bert-base-uncased") is used for every split and is the one saved next to the student.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [11]:
def tokenize(example):
  """Tokenize the ``text`` field of ``example`` with the notebook's tokenizer.

  Sequences longer than ``max_len`` tokens are truncated; no padding is
  applied here. The function is passed to ``Dataset.map(..., batched=True)``
  in this notebook.

  Args:
    example: mapping with a ``"text"`` entry.

  Returns:
    Whatever the module-level ``tokenizer`` returns for that text.
  """
  encoded = tokenizer(example["text"], max_length=max_len, truncation=True)
  return encoded

In [12]:
# Container for the tokenized splits, keyed by split name.
tokenized = {}

In [13]:
# Tokenize both splits in batches; the raw "text" column is dropped once it has been encoded.
tokenized.update(
    train=train.map(tokenize, batched=True, remove_columns=["text"]),
    validation=val.map(tokenize, batched=True, remove_columns=["text"]),
)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [14]:
# Display the tokenized splits to check their columns and row counts.
tokenized

{'train': Dataset({
     features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 2500
 }),
 'validation': Dataset({
     features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 2000
 })}

In [15]:
# Defining a Collator
# Examples were tokenized without padding, so the collator pads each batch when it is assembled;
# pad_to_multiple_of=8 asks for padded lengths that are a multiple of 8.

collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

In [16]:
# Training batches are reshuffled every epoch and padded by the collator.
train_dl = DataLoader(tokenized["train"], batch_size=batch_size, shuffle=True, collate_fn=collator)

In [17]:
# Validation batches keep dataset order (shuffle=False) and use the same collator.
val_dl = DataLoader(tokenized["validation"], batch_size=batch_size, shuffle=False, collate_fn=collator)

In [18]:
# Derive the class count from the dataset's label metadata instead of hard-coding it,
# so the classifier heads always match the labels actually present.
num_labels = len(label_feature.names)

In [19]:
from transformers import AutoModelForSequenceClassification

# Teacher Student Model Definition

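We use the BERT Large model as the teacher and the BERT Base model as the student, and we try to distil the teacher's knowledge into the student. Note that `bert-large-uncased` is only a pre-trained checkpoint: its classification head is randomly initialised and it is not fine-tuned in this notebook, so its predictions are close to chance. For the distillation to transfer useful knowledge, the teacher should first be fine-tuned on the sentiment task.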

In [20]:
# Teacher: BERT-large with a freshly created `num_labels`-way classification head, moved to `device`.
# NOTE(review): "bert-large-uncased" is loaded as-is and never fine-tuned in this notebook. The load
# log reports classifier.weight / classifier.bias as newly initialised, so the teacher's logits carry
# no task knowledge (its test accuracy later in the notebook is 0.30). Fine-tune the teacher on the
# task, or load a fine-tuned checkpoint that shares this tokenizer, before relying on its soft targets.
teacher = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=num_labels).to(device)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Student: BERT-base with a freshly created `num_labels`-way classification head, moved to `device`.
# This is the only model whose weights are optimised.
student = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels).to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
# The teacher only supplies targets and is never optimised: switch off gradient tracking on each of its weights.

for p in teacher.parameters():
  p.requires_grad_(False)

We do not train the teacher model; it is kept in evaluation mode and used only to produce predictions.

In [23]:
# Put the teacher in evaluation mode (turns off its dropout layers). `eval()` returns the module,
# so the model summary is displayed as this cell's output.
teacher.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1

In [24]:
# Hard-label loss: cross-entropy between the student's raw logits and the ground-truth class ids.
ce_loss = nn.CrossEntropyLoss()

In [25]:
# Soft-label loss: KL divergence, called with student log-probabilities as input and teacher
# probabilities as target; "batchmean" divides the summed divergence by the batch size.
kl_loss = nn.KLDivLoss(reduction="batchmean")

In [26]:
# Only the student's parameters are handed to the optimizer.
optimizer = optim.AdamW(student.parameters(), lr=lr)

In [27]:
from transformers import get_scheduler

In [28]:
# Linear learning-rate schedule over every optimisation step of the run, with no warm-up phase.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_training_steps=len(train_dl) * epochs,
    num_warmup_steps=0,
)

In [29]:
from tqdm.auto import tqdm

In [31]:
def distil_epoch():
  """Run one epoch of knowledge distillation over ``train_dl``.

  For every batch the teacher produces temperature-softened class
  probabilities (without gradient tracking) and the student is updated on

      alpha_soft * KL(student || teacher) * temperature**2
      + (1 - alpha_soft) * CE(student logits, labels)

  The function works on the notebook's module-level objects: ``student``,
  ``teacher``, ``train_dl``, ``device``, ``temperature``, ``alpha_soft``,
  ``kl_loss``, ``ce_loss``, ``optimizer`` and ``lr_scheduler``.

  Returns:
    float: mean training loss over the batches of this epoch, or ``nan``
    when ``train_dl`` yields no batch.
  """
  student.train()
  # Do not depend on another cell having been run: the teacher must have dropout disabled,
  # otherwise its soft targets are noisy.
  teacher.eval()

  loss_sum = 0.0
  n_batches = 0
  pbar = tqdm(train_dl, desc="train")

  for batch in pbar:
    input_ids = batch["input_ids"].to(device)
    attention = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    with torch.no_grad():
      # Teacher predictions, softened by the temperature
      t_logits = teacher(input_ids, attention_mask=attention).logits
      t_soft = torch.softmax(t_logits / temperature, dim=1)

    # Student predictions; log-probabilities because that is what kl_loss takes as input
    s_logits = student(input_ids, attention_mask=attention).logits
    s_soft = torch.log_softmax(s_logits / temperature, dim=1)

    # Distillation + CE loss. The temperature**2 factor is the usual distillation scaling that
    # compensates for the smaller gradients of temperature-softened targets.
    loss_soft = kl_loss(s_soft, t_soft) * (temperature**2)
    loss_hard = ce_loss(s_logits, labels)
    loss = alpha_soft * loss_soft + (1 - alpha_soft) * loss_hard

    # Clear gradients first so leftovers from an interrupted earlier run cannot leak into this step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    batch_loss = loss.item()
    loss_sum += batch_loss
    n_batches += 1
    pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

  return loss_sum / n_batches if n_batches else float("nan")

In [32]:
def evaluate():
  """Measure the student's accuracy on ``val_dl``.

  The student is switched to evaluation mode and run without gradient
  tracking; the predicted class is the arg-max over the logits.

  Returns:
    float: accuracy in percent, rounded to two decimals.
  """
  student.eval()
  n_correct, n_seen = 0, 0

  with torch.no_grad():
    for batch in val_dl:
      ids, attn, lbl = (
          batch[key].to(device) for key in ("input_ids", "attention_mask", "labels")
      )
      logits = student(ids, attention_mask=attn).logits
      hits = logits.argmax(dim=1) == lbl
      n_correct += hits.sum().item()
      n_seen += len(lbl)

  accuracy_pct = n_correct / n_seen * 100
  return round(accuracy_pct, 2)

In [33]:
# Train for `epochs` epochs; after each one, report the student's validation accuracy (in percent).
for ep in range(1, epochs+1):
  distil_epoch()
  acc = evaluate()
  print(f"Epoch: {ep}, Accuracy: {acc}")

train:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 1, Accuracy: 65.2


In [34]:
# Write the distilled student and its tokenizer to the "distilled_bert" directory.
# The tokenizer call is the last expression, so the saved file paths are shown as the cell output.
student.save_pretrained("distilled_bert")
tokenizer.save_pretrained("distilled_bert")

('distilled_bert/tokenizer_config.json',
 'distilled_bert/special_tokens_map.json',
 'distilled_bert/vocab.txt',
 'distilled_bert/added_tokens.json',
 'distilled_bert/tokenizer.json')

We have now created and saved our own distilled version of BERT.

In [35]:
# Loading the test dataset

# Only the first 500 examples of the test split are used.
test = load_dataset("tweet_eval", "sentiment", split="test[:500]")
# Same tokenization as for training/validation (truncation to max_len, raw text column dropped).
tokenized_test = test.map(tokenize, batched=True, remove_columns=["text"])
# Return torch tensors for the three listed columns.
# NOTE(review): train/validation are not given this format; the collator is used for all three loaders.
tokenized_test.set_format("torch", columns=["input_ids", "attention_mask", "label"])
test_dl = DataLoader(tokenized_test, batch_size=batch_size, shuffle=False, collate_fn=collator)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [36]:
from sklearn.metrics import accuracy_score
import time

def predict_and_evaluate(model, name, test_dl):
  """Run ``model`` over ``test_dl`` and report its accuracy and inference time.

  Args:
    model: sequence-classification model whose output exposes ``.logits``.
    name: label used in the printed summary line.
    test_dl: loader yielding batches with ``input_ids``, ``attention_mask``
      and ``labels``; ``len(test_dl.dataset)`` gives the number of examples.

  Returns:
    tuple: ``(accuracy, total_time, avg_time)`` where ``accuracy`` is the
    value returned by ``accuracy_score``, ``total_time`` is the elapsed time
    of the whole loop in seconds (batch loading included) and ``avg_time``
    is ``total_time`` divided by the number of examples.
  """
  model.eval()
  all_preds, all_labels = [], []
  # perf_counter is the monotonic, high-resolution clock intended for measuring durations;
  # time.time() can jump when the system clock is adjusted.
  start_time = time.perf_counter()

  with torch.no_grad():
    for batch in test_dl:
      ids = batch["input_ids"].to(device)
      attn = batch["attention_mask"].to(device)

      logits = model(ids, attention_mask=attn).logits
      preds = torch.argmax(logits, dim=1)

      # Copying predictions to the CPU waits for the device, so the timing covers finished work.
      all_preds.extend(preds.cpu().tolist())
      # Labels are only compared on the host; no need to send them to the device and back.
      all_labels.extend(batch["labels"].tolist())

  total_time = time.perf_counter() - start_time
  acc = accuracy_score(all_labels, all_preds)
  avg_time = total_time / len(test_dl.dataset)

  print(f"\n Name: {name}, Accuracy: {acc}, Time: {total_time}, Avg Time: {avg_time}")
  return acc, total_time, avg_time

In [37]:
# Compare the distilled student with the teacher on the same test loader.
# The second call is the cell's last expression, so its return tuple is also displayed.
predict_and_evaluate(student, "Distilled BERT", test_dl)
predict_and_evaluate(teacher, "Teacher BERT", test_dl)


 Name: Distilled BERT, Accuracy: 0.666, Time: 0.4991264343261719, Avg Time: 0.0009982528686523437

 Name: Teacher BERT, Accuracy: 0.3, Time: 1.0917081832885742, Avg Time: 0.0021834163665771485


(0.3, 1.0917081832885742, 0.0021834163665771485)