# [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108)

In this lecture, we will explore the architecture of DistilBERT, its key components, and how it can be utilized for various natural language processing tasks. Additionally, we'll discuss its advantages, limitations, and provide hands-on examples to showcase its effectiveness.

Reference : [The Theory](https://towardsdatascience.com/distillation-of-bert-like-models-the-code-73c31e8c2b0a) | [Code](https://towardsdatascience.com/distillation-of-bert-like-models-the-theory-32e19a02641f)

In [1]:
# import os
# Set GPU device
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'

In [12]:
# !pip install datasets
# !pip install evaluate



In [15]:
# !pip install datasets --upgrade
import datasets
from datasets import load_dataset
import transformers
import torch
datasets.__version__, transformers.__version__, torch.__version__

import torch.nn as nn
import torch
from tqdm.auto import tqdm
import random, math, time


In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cuda


## 1. Loading our MNLI part of the GLUE dataset

In [17]:
# connect with google drive
from google.colab import drive
drive.mount('/content/drive')

# declare file path with os
import os
os.chdir('/content/drive/MyDrive/_NLP/A7/NLP-A7-Distillation_vs_LoRA')

Mounted at /content/drive


In [18]:
raw_datasets = load_dataset("Hate-speech-CNERG/hatexplain")
print(raw_datasets)

# testing another dataset
# raw_datasets = load_dataset("tweets-hate-speech-detection/tweets_hate_speech_detection")
# print(raw_datasets)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

hatexplain.py:   0%|          | 0.00/4.78k [00:00<?, ?B/s]

The repository for Hate-speech-CNERG/hatexplain contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Hate-speech-CNERG/hatexplain.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/592k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15383 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1922 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1924 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 15383
    })
    validation: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 1922
    })
    test: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 1924
    })
})


In [19]:
print(raw_datasets)

DatasetDict({
    train: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 15383
    })
    validation: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 1922
    })
    test: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 1924
    })
})


In [20]:
print(raw_datasets['train'][5])

{'id': '18790322_gab', 'annotators': {'label': [2, 1, 1], 'annotator_id': [203, 202, 204], 'target': [['Hispanic'], ['Hispanic', 'Refugee'], ['Hispanic', 'Refugee']]}, 'rationales': [], 'post_tokens': ['i', 'live', 'and', 'work', 'with', 'many', 'legal', 'mexican', 'immigrants', 'who', 'are', 'great', 'citizens', 'and', 'trump', 'supporters', 'they', 'have', 'no', 'problem', 'with', 'deporting', 'illegals', 'maga']}


In [21]:
from collections import Counter
print(Counter(raw_datasets['test']['annotators'][0]))  # Try for all entries


Counter({'label': [1, 1, 1], 'annotator_id': [9, 17, 64], 'target': [['None'], ['None'], ['None']]})


In [22]:
# HateXplain label definitions
label_list = ["hate_speech", "normal", "offensive"]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

print(id2label)


{0: 'hate_speech', 1: 'normal', 2: 'offensive'}


## 2. Model & Tokenization

In [23]:
import numpy as np
# Extract unique labels from "annotators" field
num_labels = np.unique([label for labels in raw_datasets["train"]["annotators"] for label in labels]).size

print("Number of unique labels:", num_labels)

Number of unique labels: 3


In [24]:
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

teacher_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(teacher_id)

teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
)

teacher_model

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

## 3. Preprocessing

In [25]:
def tokenize_function(examples):
    # Convert tokenized list into full sentences
    text_inputs = [" ".join(tokens) for tokens in examples["post_tokens"]]

    # Tokenize with dynamic padding
    return tokenizer(text_inputs, max_length=128, truncation=True, padding=True)


In [26]:
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets

Map:   0%|          | 0/15383 [00:00<?, ? examples/s]

Map:   0%|          | 0/1922 [00:00<?, ? examples/s]

Map:   0%|          | 0/1924 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 15383
    })
    validation: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1922
    })
    test: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1924
    })
})

In [27]:
#list data from tokenized
# tokenized_datasets['train'][0]['input_ids']

In [28]:
from collections import Counter

def convert_labels(example):
    batch_labels = []

    for annotators in example["annotators"]:  # Each annotator entry (dict)
        labels = annotators.get("label", [])

        # Keep only valid integer labels
        valid_labels = [int(l) for l in labels if isinstance(l, (int, float, str)) and str(l).isdigit()]

        # Majority vote or fallback
        if valid_labels:
            most_common = Counter(valid_labels).most_common(1)[0][0]
        else:
            most_common = 1  # Default to 'normal' if no valid label

        batch_labels.append(most_common)

    return {"labels": batch_labels}


In [29]:
tokenized_datasets_with_labels = tokenized_datasets.map(convert_labels, batched=True)

# Drop old 'annotators' field
tokenized_datasets_with_labels = tokenized_datasets_with_labels.remove_columns(["annotators"])

# Final dataset with labels
tokenized_datasets = tokenized_datasets_with_labels


Map:   0%|          | 0/15383 [00:00<?, ? examples/s]

Map:   0%|          | 0/1922 [00:00<?, ? examples/s]

Map:   0%|          | 0/1924 [00:00<?, ? examples/s]

In [30]:
import numpy as np
print(np.unique(tokenized_datasets['test']['labels'], return_counts=True))

(array([0, 1, 2]), array([594, 782, 548]))


In [31]:
import numpy as np
print(np.unique(tokenized_datasets['validation']['labels'], return_counts=True))

(array([0, 1, 2]), array([593, 781, 548]))


In [32]:
import numpy as np
print(np.unique(tokenized_datasets['train']['labels'], return_counts=True))

(array([0, 1, 2]), array([4748, 6251, 4384]))


In [33]:
print(tokenized_datasets)

DatasetDict({
    train: Dataset({
        features: ['id', 'rationales', 'post_tokens', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 15383
    })
    validation: Dataset({
        features: ['id', 'rationales', 'post_tokens', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1922
    })
    test: Dataset({
        features: ['id', 'rationales', 'post_tokens', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1924
    })
})


In [34]:
# tokenized_datasets['train'][0]

In [35]:
print(tokenizer.pad_token_id)  # Typically returns 0
print(tokenizer.pad_token)  # Typically returns [PAD]

0
[PAD]


In [36]:
non_padded_tokens = [token for token in tokenized_datasets['train'][0]['input_ids'] if token != tokenizer.pad_token_id]
print(non_padded_tokens)


[101, 1057, 2428, 2228, 1045, 2052, 2025, 2031, 2042, 15504, 2011, 18993, 7560, 2030, 5152, 2067, 1999, 2634, 2030, 7269, 1998, 1037, 9253, 6394, 2052, 9040, 2033, 2004, 2092, 2074, 2000, 2156, 2033, 5390, 102]


In [37]:
decoded_text = tokenizer.decode(non_padded_tokens)
print(decoded_text)


[CLS] u really think i would not have been raped by feral hindu or muslim back in india or bangladesh and a neo nazi would rape me as well just to see me cry [SEP]


In [38]:
tokenizer.decode(tokenized_datasets['train'][0]['input_ids'])

'[CLS] u really think i would not have been raped by feral hindu or muslim back in india or bangladesh and a neo nazi would rape me as well just to see me cry [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'





## 4. Preparing the dataloader

In [39]:
tokenized_datasets = tokenized_datasets.remove_columns(['id','rationales','post_tokens'])

In [40]:
tokenized_datasets.set_format("torch")

In [41]:
tokenized_datasets['train']

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 15383
})

In [42]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_datasets["train"], batch_size=8, shuffle=True, collate_fn=data_collator
)


In [43]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=1150).select(range(1000))
small_eval_dataset = tokenized_datasets["validation"].shuffle(seed=1150).select(range(100))
small_test_dataset = tokenized_datasets["test"].shuffle(seed=1150).select(range(100))

In [44]:
# prepare for Dataloader

from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_datasets["train"], batch_size=8, shuffle=True, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, shuffle=False, collate_fn=data_collator
)
test_dataloader = DataLoader(
    tokenized_datasets["test"], batch_size=8, shuffle=False, collate_fn=data_collator
)


In [45]:
for batch in train_dataloader:
    break

print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)


torch.Size([8, 128]) torch.Size([8, 128]) torch.Size([8])


In [46]:
for batch in train_dataloader:
    print(batch.keys())  # Ensure it includes 'input_ids', 'attention_mask', 'labels'
    print(batch["input_ids"].shape)
    print(batch["attention_mask"].shape)
    print(batch["labels"].shape)
    break


dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
torch.Size([8, 110])
torch.Size([8, 110])
torch.Size([8])


In [47]:
# for batch in train_dataloader:
#     break

# print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)


## 5. Design the model and losses

### 5.1 Teacher Model & Student Model

####  Architecture
In the present work, the student - DistilBERT - has the same general architecture as BERT.
- The `token-type embeddings` and the `pooler` are removed while `the number of layers` is reduced by a factor of 2.
- Most of the operations used in the Transformer architecture `linear layer` and `layer normalisation` are highly optimized in modern linear algebra frameworks.
- our investigations showed that variations on the last dimension of the tensor (hidden size dimension) have a smaller impact on computation efficiency (for a fixed parameters budget) than variations on other factors like the number of layers.
- Thus we focus on reducing the number of layers.

#### Initialize Student Model
- To initialize a new model from an existing one, we need to access the weights of the old model (the teacher).
- In order to get the weights, we first have to know how to access them. We’ll use BERT as our teacher model.

In [48]:
teacher_model.config

BertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "hate_speech",
    "1": "normal",
    "2": "offensive"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "hate_speech": 0,
    "normal": 1,
    "offensive": 2
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.49.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

####
- The student model has the same configuration, except the number of layers is reduced by a factor of 2
- The student layers are initilized by copying one out of two layers of the teacher, starting with layer 0.
- The head of the teacher is also copied.

In [49]:
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertConfig
# Get teacher configuration as a dictionnary
configuration = teacher_model.config.to_dict()
# configuration

In [50]:
# Half the number of hidden layer
configuration['num_hidden_layers'] //= 2
# Convert the dictionnary to the student configuration
configuration = BertConfig.from_dict(configuration)

In [51]:
# Create uninitialized student model
model = type(teacher_model)(configuration)
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

- Recursively copies the weights of the (teacher) to the (student).
- This function is meant to be first called on a BertFor... model, but is then called on every children of that model recursively.
- The only part that's not fully copied is the encoder, of which only half is copied.

In [52]:
from transformers.models.bert.modeling_bert import BertEncoder, BertModel
from torch.nn import Module

def distill_bert_weights(
    teacher: Module,
    student: Module,
    use_odd_layers: bool = True  # Flag to select odd or even layers
) -> None:
    """
    Copies weights from the teacher model to the student model.
    - If `use_odd_layers=True`, copies odd-numbered layers {1, 3, 5, 7, 9, 11}.
    - If `use_odd_layers=False`, copies even-numbered layers {2, 4, 6, 8, 10, 12}.
    """
    if isinstance(teacher, BertModel) or type(teacher).__name__.startswith('BertFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_bert_weights(teacher_part, student_part, use_odd_layers)

    elif isinstance(teacher, BertEncoder):
        teacher_encoding_layers = [layer for layer in next(teacher.children())]  # 12 layers
        student_encoding_layers = [layer for layer in next(student.children())]  # 6 layers

        #  Select Odd or Even Layers
        selected_layers = range(1, 12, 2) if use_odd_layers else range(0, 12, 2)

        for i, layer_idx in enumerate(selected_layers):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[layer_idx].state_dict())

    else:
        student.load_state_dict(teacher.state_dict())

    return student  # Fix: Return the student model after weight copying


In [53]:
# model = distill_bert_weights(teacher=teacher_model, student=model)

In [54]:
student_model_odd = distill_bert_weights(teacher=teacher_model, student=model, use_odd_layers=True)
student_model_even = distill_bert_weights(teacher=teacher_model, student=model, use_odd_layers=False)

In [55]:
def count_parameters(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print('Teacher parameters :', count_parameters(teacher_model))
print('Student parameters :', count_parameters(student_model_odd))

Teacher parameters : 109484547
Student parameters : 66957315


In [56]:
def count_parameters(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print('Teacher parameters :', count_parameters(teacher_model))
print('Student parameters :', count_parameters(student_model_even))

Teacher parameters : 109484547
Student parameters : 66957315


In [57]:
count_parameters(model)/count_parameters(teacher_model) * 100

61.15686353435797

In [58]:
# Move models to the available device
student_model_odd.to(device)
student_model_even.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [59]:
#It has 40% less parameters than bert-base-uncased

### 5.2 Loss function

#### Softmax

$$
P_i(\mathbf{z}_i, T) = \frac{\exp(\mathbf{z}_i / T)}{\sum_{q=0}^k \exp(\mathbf{z}_q / T)}
$$


#### Knowledge Distillation

#### CE Loss
$$\mathcal{L}_\text{CE} = -\sum^N_{j=0}\sum_{i=0}^k {y}_i^{(j)}\log(P_i({v}_i^{(j)}, 1))$$

#### KL Loss
$$\mathcal{L}_\text{KD} = -\sum^N_{j=0}\sum_{i=0}^k P_i({z}_i^{(j)}, T) \log (P_i({v}_i^{(j)}, T))$$

#### Cosine Embedding Loss
$$\mathcal{L}_{\text{cosine}}(x_1, x_2, y) = \frac{1}{N} \sum_{i=1}^{N} \left(1 - y_i \cdot \cos(\theta_i)\right)$$

<!-- $$\mathcal{L} = \lambda \mathcal{L}_\text{KD} + (1-\lambda)\mathcal{L}_\text{CE}$$
 -->

#### Total Loss
$$\mathcal{L} = \mathcal{L}_\text{KD} + \mathcal{L}_\text{CE} + \mathcal{L}_{\text{cosine}}$$

In [60]:
import torch.nn.functional as F

class DistillKL(nn.Module):
    """
    Distilling the Knowledge in a Neural Network
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha

    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities!
    """

    def __init__(self):
        super(DistillKL, self).__init__()

    def forward(self, output_student, output_teacher, temperature=1):
        '''
        Note: the output_student and output_teacher are logits
        '''
        T = temperature #.cuda()

        KD_loss = nn.KLDivLoss(reduction='batchmean')(
            F.log_softmax(output_student/T, dim=-1),
            F.softmax(output_teacher/T, dim=-1)
        ) * T * T

        return KD_loss

In [61]:
# Loss functions
criterion_cls = nn.CrossEntropyLoss()  # Classification Loss
criterion_div = DistillKL()  # KL Divergence Loss
criterion_cos = nn.CosineEmbeddingLoss()  # Cosine Similarity Loss

## 6. Optimizer

In [62]:
import torch.optim as optim
import torch.nn as nn

# Optimizers
lr = 5e-5
optimizer_odd = optim.Adam(params=student_model_odd.parameters(), lr=lr)
optimizer_even = optim.Adam(params=student_model_even.parameters(), lr=lr)

num_epochs = 5
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_epochs * num_update_steps_per_epoch

## 7. Learning rate scheduler

In [63]:
# Learning rate schedulers
from transformers import get_scheduler

lr_scheduler_odd = get_scheduler(
    name="linear", optimizer=optimizer_odd, num_warmup_steps=0, num_training_steps=num_training_steps
)
lr_scheduler_even = get_scheduler(
    name="linear", optimizer=optimizer_even, num_warmup_steps=0, num_training_steps=num_training_steps
)

## 8. Metric

In [64]:
# Metric for
import evaluate
import numpy as np
from tqdm.auto import tqdm
import torch
import evaluate

#metric = evaluate.load("accuracy")
# Load evaluation metrics once
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(preds, labels):
    return {
        "accuracy": accuracy_metric.compute(predictions=preds, references=labels)["accuracy"],
        "precision": precision_metric.compute(predictions=preds, references=labels, average="macro")["precision"],
        "recall": recall_metric.compute(predictions=preds, references=labels, average="macro")["recall"],
        "f1": f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"],
    }


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.56k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.38k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.79k [00:00<?, ?B/s]

## 9. Train

In [65]:
# traing odd and even student model

def train_student_model(student_model, optimizer, lr_scheduler, student_type="odd"):
    print(f"\nStarting training for {student_type} student model...\n")

    progress_bar = tqdm(range(num_training_steps))
    student_model.to(device)
    teacher_model.to(device)

    # Store loss and metric history
    train_losses, train_cls, train_div, train_cos = [], [], [], []
    eval_losses, eval_metrics_list = [], []

    # Track average metrics across epochs
    avg_metrics = {"accuracy": 0, "precision": 0, "recall": 0, "f1": 0}

    for epoch in range(num_epochs):
        student_model.train()
        teacher_model.eval()

        total_loss, loss_cls_total, loss_div_total, loss_cos_total = 0, 0, 0, 0

        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs_student = student_model(**batch)
            with torch.no_grad():
                outputs_teacher = teacher_model(**batch)

            # Losses
            loss_cls = criterion_cls(outputs_student.logits, batch["labels"])
            loss_div = criterion_div(outputs_student.logits, outputs_teacher.logits)
            loss_cos = criterion_cos(outputs_teacher.logits, outputs_student.logits,
                                     torch.ones(outputs_teacher.logits.size(0)).to(device))

            # Total loss (weighted avg)
            loss = (loss_cls + loss_div + loss_cos) / 3

            # Backprop
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Accumulate losses
            total_loss += loss.item()
            loss_cls_total += loss_cls.item()
            loss_div_total += loss_div.item()
            loss_cos_total += loss_cos.item()

            progress_bar.update(1)

        # Store training losses
        train_losses.append(total_loss / len(train_dataloader))
        train_cls.append(loss_cls_total / len(train_dataloader))
        train_div.append(loss_div_total / len(train_dataloader))
        train_cos.append(loss_cos_total / len(train_dataloader))

        # Evaluation
        student_model.eval()
        eval_loss, all_preds, all_labels = 0, [], []

        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = student_model(**batch)
                loss_eval = criterion_cls(outputs.logits, batch["labels"])
                preds = outputs.logits.argmax(dim=-1)

            eval_loss += loss_eval.item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch["labels"].cpu().numpy())

        # Compute metrics
        metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
        for key in avg_metrics:
            avg_metrics[key] += metrics[key]

        eval_losses.append(eval_loss / len(eval_dataloader))
        eval_metrics_list.append(metrics)

        # Logging
        print(f"\n Epoch {epoch+1}/{num_epochs} — {student_type} student")
        print(f" Train Loss:     {train_losses[-1]:.4f}")
        print(f" - CLS: {train_cls[-1]:.4f}, DIV: {train_div[-1]:.4f}, COS: {train_cos[-1]:.4f}")
        print(f" Eval Loss:      {eval_losses[-1]:.4f}")
        print(f" Eval Metrics:   Acc={metrics['accuracy']:.4f}, Prec={metrics['precision']:.4f}, "
              f"Rec={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")

    # Print average metrics
    print(f"\n Average Metrics for {student_type} student over {num_epochs} epochs:")
    for key in avg_metrics:
        print(f"  {key.capitalize()}: {avg_metrics[key]/num_epochs:.4f}")

    # Return results for plotting or further analysis
    return {
        "train_losses": train_losses,
        "train_cls": train_cls,
        "train_div": train_div,
        "train_cos": train_cos,
        "eval_losses": eval_losses,
        "metrics_per_epoch": eval_metrics_list,
        "final_avg_metrics": {k: avg_metrics[k]/num_epochs for k in avg_metrics}
    }


In [None]:
# Train Odd-Layer Student Model
print("\n---- Training result from student_model_odd ---")
odd_result = train_student_model(student_model_odd, optimizer_odd, lr_scheduler_odd, "odd")

# Train Even-Layer Student Model
print("\n--- Training result from student_model_even ---")
even_result = train_student_model(student_model_even, optimizer_even, lr_scheduler_even, "even")




---- Training result from student_model_odd ---

Starting training for odd student model...



  0%|          | 0/9615 [00:00<?, ?it/s]


 Epoch 1/5 — odd student
 Train Loss:     0.3562
 - CLS: 0.8928, DIV: 0.0763, COS: 0.0996
 Eval Loss:      0.8656
 Eval Metrics:   Acc=0.6847, Prec=0.6857, Rec=0.6771, F1=0.6802

 Epoch 2/5 — odd student
 Train Loss:     0.3266
 - CLS: 0.7782, DIV: 0.1114, COS: 0.0902
 Eval Loss:      0.8459
 Eval Metrics:   Acc=0.6587, Prec=0.6887, Rec=0.6613, F1=0.6622

 Epoch 3/5 — odd student
 Train Loss:     0.2956
 - CLS: 0.6441, DIV: 0.1546, COS: 0.0880
 Eval Loss:      0.8303
 Eval Metrics:   Acc=0.6712, Prec=0.6700, Rec=0.6731, F1=0.6684

 Epoch 4/5 — odd student
 Train Loss:     0.2713
 - CLS: 0.5389, DIV: 0.1882, COS: 0.0868
 Eval Loss:      0.8279
 Eval Metrics:   Acc=0.6785, Prec=0.6830, Rec=0.6736, F1=0.6763

 Epoch 5/5 — odd student
 Train Loss:     0.2602
 - CLS: 0.4926, DIV: 0.2027, COS: 0.0853
 Eval Loss:      0.8245
 Eval Metrics:   Acc=0.6675, Prec=0.6656, Rec=0.6644, F1=0.6639

 Average Metrics for odd student over 5 epochs:
  Accuracy: 0.6721
  Precision: 0.6786
  Recall: 0.6699


  0%|          | 0/9615 [00:00<?, ?it/s]


 Epoch 1/5 — even student
 Train Loss:     0.2749
 - CLS: 0.5475, DIV: 0.1897, COS: 0.0876
 Eval Loss:      0.8279
 Eval Metrics:   Acc=0.6608, Prec=0.6651, Rec=0.6568, F1=0.6591

 Epoch 2/5 — even student
 Train Loss:     0.2652
 - CLS: 0.5113, DIV: 0.1997, COS: 0.0846
 Eval Loss:      0.8284
 Eval Metrics:   Acc=0.6644, Prec=0.6573, Rec=0.6539, F1=0.6553

 Epoch 3/5 — even student
 Train Loss:     0.2560
 - CLS: 0.4765, DIV: 0.2097, COS: 0.0818
 Eval Loss:      0.8472
 Eval Metrics:   Acc=0.6592, Prec=0.6561, Rec=0.6577, F1=0.6555


In [None]:
# visualise the result
import matplotlib.pyplot as plt

# Extract from results
epochs = list(range(1, len(odd_result['train_losses']) + 1))

# Training Loss
plt.figure(figsize=(10, 5))
plt.plot(epochs, odd_result['train_losses'], label="Odd - Total Train Loss", marker="o")
plt.plot(epochs, even_result['train_losses'], label="Even - Total Train Loss", marker="s")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Total Training Loss Comparison")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Individual Losses (CLS, DIV, COS)
loss_types = ['train_cls', 'train_div', 'train_cos']
loss_labels = ['Classification Loss', 'KL Divergence Loss', 'Cosine Similarity Loss']

for loss_key, label in zip(loss_types, loss_labels):
    plt.figure(figsize=(10, 4))
    plt.plot(epochs, odd_result[loss_key], label=f"Odd - {label}", marker="o")
    plt.plot(epochs, even_result[loss_key], label=f"Even - {label}", marker="s")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{label} Comparison")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Evaluation Loss
plt.figure(figsize=(10, 4))
plt.plot(epochs, odd_result['eval_losses'], label="Odd - Eval Loss", marker="o")
plt.plot(epochs, even_result['eval_losses'], label="Even - Eval Loss", marker="s")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Evaluation Loss Comparison")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Accuracy
odd_acc = [m['accuracy'] for m in odd_result['metrics_per_epoch']]
even_acc = [m['accuracy'] for m in even_result['metrics_per_epoch']]

plt.figure(figsize=(10, 4))
plt.plot(epochs, odd_acc, label="Odd - Accuracy", marker="o")
plt.plot(epochs, even_acc, label="Even - Accuracy", marker="s")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Evaluation Accuracy Comparison")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


## 10. Compare LORA with Student Model

In [None]:
# Prepare 12-layer BERT Student Model
from transformers import AutoModelForSequenceClassification

student_model_lora = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=num_labels
)

In [None]:
# Apply LoRA Using PEFT
from peft import get_peft_model, LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,           # Rank
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none"
)

student_model_lora = get_peft_model(student_model_lora, lora_config)
student_model_lora.print_trainable_parameters()  # Optional: to confirm only LoRA layers are trainable

In [None]:
from transformers import AdamW

optimizer_lora = AdamW(student_model_lora.parameters(), lr=2e-5, weight_decay=0.01)


In [None]:
from transformers import get_scheduler

num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_epochs * num_update_steps_per_epoch
num_warmup_steps = int(0.1 * num_training_steps)

scheduler_lora = get_scheduler(
    name="linear",
    optimizer=optimizer_lora,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


In [None]:
def train_lora_model(model, optimizer, scheduler, student_type="lora"):
    print(f"\nStarting LoRA training for {student_type} model...\n")

    model.to(device)
    progress_bar = tqdm(range(num_training_steps))

    # For storing logs
    train_losses = []
    eval_losses = []
    eval_metrics_list = []
    avg_metrics = {"accuracy": 0, "precision": 0, "recall": 0, "f1": 0}

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = criterion_cls(outputs.logits, batch["labels"])

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            progress_bar.update(1)

        avg_train_loss = total_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)

        # Evaluation
        model.eval()
        eval_loss = 0
        all_preds, all_labels = [], []

        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
                loss = criterion_cls(outputs.logits, batch["labels"])
                preds = outputs.logits.argmax(dim=-1)

            eval_loss += loss.item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch["labels"].cpu().numpy())

        avg_eval_loss = eval_loss / len(eval_dataloader)
        eval_losses.append(avg_eval_loss)

        metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
        eval_metrics_list.append(metrics)

        for key in avg_metrics:
            avg_metrics[key] += metrics[key]

        print(f"\nEpoch {epoch+1}/{num_epochs} — {student_type} model")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Eval Loss:  {avg_eval_loss:.4f}")
        print(f"Eval Metrics: Acc={metrics['accuracy']:.4f}, Prec={metrics['precision']:.4f}, "
              f"Rec={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")

    print(f"Final Average Metrics for {student_type} model:")
    for key in avg_metrics:
        print(f"  {key.capitalize()}: {avg_metrics[key]/num_epochs:.4f}")

    return {
        "train_losses": train_losses,
        "eval_losses": eval_losses,
        "metrics_per_epoch": eval_metrics_list,
        "final_avg_metrics": {k: avg_metrics[k]/num_epochs for k in avg_metrics}
    }


In [None]:
lora_result = train_lora_model(student_model_lora, optimizer_lora, scheduler_lora, student_type="lora")

In [None]:
import matplotlib.pyplot as plt

# Epochs list
epochs = list(range(1, len(odd_result['train_losses']) + 1))

# Accuracy
odd_acc = [m['accuracy'] for m in odd_result['metrics_per_epoch']]
even_acc = [m['accuracy'] for m in even_result['metrics_per_epoch']]
lora_acc = [m['accuracy'] for m in lora_result['metrics_per_epoch']]

# F1 Score
odd_f1 = [m['f1'] for m in odd_result['metrics_per_epoch']]
even_f1 = [m['f1'] for m in even_result['metrics_per_epoch']]
lora_f1 = [m['f1'] for m in lora_result['metrics_per_epoch']]

# Plot: Training Loss
plt.figure(figsize=(10, 4))
plt.plot(epochs, odd_result['train_losses'], label="Odd", marker='o')
plt.plot(epochs, even_result['train_losses'], label="Even", marker='s')
plt.plot(epochs, lora_result['train_losses'], label="LoRA", marker='^')
plt.title("Training Loss Comparison")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot: Evaluation Loss
plt.figure(figsize=(10, 4))
plt.plot(epochs, odd_result['eval_losses'], label="Odd", marker='o')
plt.plot(epochs, even_result['eval_losses'], label="Even", marker='s')
plt.plot(epochs, lora_result['eval_losses'], label="LoRA", marker='^')
plt.title("Evaluation Loss Comparison")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot: Accuracy
plt.figure(figsize=(10, 4))
plt.plot(epochs, odd_acc, label="Odd", marker='o')
plt.plot(epochs, even_acc, label="Even", marker='s')
plt.plot(epochs, lora_acc, label="LoRA", marker='^')
plt.title("Evaluation Accuracy Comparison")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot: F1 Score
plt.figure(figsize=(10, 4))
plt.plot(epochs, odd_f1, label="Odd", marker='o')
plt.plot(epochs, even_f1, label="Even", marker='s')
plt.plot(epochs, lora_f1, label="LoRA", marker='^')
plt.title("Evaluation F1 Score Comparison")
plt.xlabel("Epochs")
plt.ylabel("F1 Score")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


## 11. Evalute and Save the best model

In [None]:
def evaluate_model_on_test(model, test_dataloader, name=""):
    model.eval()
    model.to(device)

    all_preds, all_labels = [], []
    total_loss = 0

    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
            loss = criterion_cls(outputs.logits, batch["labels"])
            total_loss += loss.item()

            preds = outputs.logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch["labels"].cpu().numpy())

    metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
    avg_loss = total_loss / len(test_dataloader)

    print(f"{name} Test Performance:")
    print(f"  Loss:      {avg_loss:.4f}")
    print(f"  Accuracy:  {metrics['accuracy']:.4f}")
    print(f"  Precision: {metrics['precision']:.4f}")
    print(f"  Recall:    {metrics['recall']:.4f}")
    print(f"  F1-score:  {metrics['f1']:.4f}")

    return metrics, avg_loss


In [None]:
metrics_odd, loss_odd = evaluate_model_on_test(student_model_odd, test_dataloader, name="Odd Layer")
metrics_even, loss_even = evaluate_model_on_test(student_model_even, test_dataloader, name="Even Layer")
metrics_lora, loss_lora = evaluate_model_on_test(student_model_lora, test_dataloader, name="LoRA")


In [None]:
best_model = None
best_model_name = ""
best_f1 = max(metrics_odd["f1"], metrics_even["f1"], metrics_lora["f1"])

if metrics_odd["f1"] == best_f1:
    best_model = student_model_odd
    best_model_name = "student_model_odd"
elif metrics_even["f1"] == best_f1:
    best_model = student_model_even
    best_model_name = "student_model_even"
else:
    best_model = student_model_lora
    best_model_name = "student_model_lora"

print(f"Best Model Selected: {best_model_name} (F1: {best_f1:.4f})")


In [None]:
# Create a folder to store the model
save_path = f"./saved_models/{best_model_name}"
best_model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print(f"\n Model and tokenizer saved to: {save_path}")
