# A7: Training Distillation vs LoRA

In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading 

In [3]:
# !pip install datasets --upgrade
import datasets
import transformers
import torch
datasets.__version__, transformers.__version__, torch.__version__

('3.4.1', '4.48.3', '2.6.0+cu124')

In [4]:
import torch.nn as nn
import torch
from tqdm.auto import tqdm
import random, math, time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cuda


# 1. Load HateXplain Dataset from Hugging Face

In [1]:
from datasets import load_dataset

# Load the HateXplain dataset from Hugging Face
dataset = load_dataset("hate_speech_offensive")

# # Check dataset structure
print(dataset)

ModuleNotFoundError: No module named 'datasets'

In [20]:
# count lable dataset


In [21]:
# Label Mapping
label_list = ["Non-Hate", "Offensive", "Hate"]
label2id = {v: i for i, v in enumerate(label_list)}
id2label = {i: v for v, i in label2id.items()}

In [22]:
# Assign feature key
task_to_keys = {"hatexplain": "tweet"}
task_name = "hatexplain"
sentence_key = task_to_keys[task_name]

In [23]:
# Print dataset overview
print(dataset)
print("Example:", dataset["train"][0][sentence_key])
print("Label2ID:", label2id)
print("ID2Label:", id2label)

DatasetDict({
    train: Dataset({
        features: ['count', 'hate_speech_count', 'offensive_language_count', 'neither_count', 'class', 'tweet'],
        num_rows: 24783
    })
})
Example: !!! RT @mayasolovely: As a woman you shouldn't complain about cleaning up your house. &amp; as a man you should always take the trash out...
Label2ID: {'Non-Hate': 0, 'Offensive': 1, 'Hate': 2}
ID2Label: {0: 'Non-Hate', 1: 'Offensive', 2: 'Hate'}


# 2. Tokenization and Data Preprocessing

In [24]:
# Check number of unique labels
num_labels = len(label_list)
num_labels

3

In [25]:
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load BERT tokenizer
teacher_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(teacher_id)

In [26]:
# Load Teacher Model (BERT)
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
# Tokenization Function (Modified for HateXplain)
def tokenize_function(examples):
    return tokenizer(examples["tweet"], max_length=128, truncation=True, padding="max_length")

In [28]:
# Apply Tokenization
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Remove unnecessary columns
tokenized_datasets = tokenized_datasets.remove_columns(["count", "hate_speech_count", "offensive_language_count", "neither_count", "tweet"])

# Rename "class" column to "labels" for PyTorch compatibility
tokenized_datasets = tokenized_datasets.rename_column("class", "labels")

# Set dataset format for PyTorch
tokenized_datasets.set_format("torch")

# Print an example tokenized input
print(tokenized_datasets["train"][0]["input_ids"])
print(tokenizer.decode(tokenized_datasets["train"][0]["input_ids"]))

tensor([  101,   999,   999,   999, 19387,  1030,  9815, 19454, 21818,  2135,
         1024,  2004,  1037,  2450,  2017,  5807,  1005,  1056, 17612,  2055,
         9344,  2039,  2115,  2160,  1012,  1004, 23713,  1025,  2004,  1037,
         2158,  2017,  2323,  2467,  2202,  1996, 11669,  2041,  1012,  1012,
         1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0])

# 3. Preparing Dataloader

In [29]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

# Data Collator (Handles Dynamic Padding)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [30]:
# Subset dataset for efficiency
small_train_dataset = tokenized_datasets["train"].shuffle(seed=1150).select(range(10000))  # 10K samples
small_eval_dataset = tokenized_datasets["train"].shuffle(seed=1150).select(range(1000))    # 1K samples (same train split)
small_test_dataset = tokenized_datasets["train"].shuffle(seed=1150).select(range(1000))    # 1K samples (same train split)

In [31]:
# Create Dataloaders
train_dataloader = DataLoader(
    small_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator
)
test_dataloader = DataLoader(
    small_test_dataset, batch_size=32, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    small_eval_dataset, batch_size=32, collate_fn=data_collator
)

In [32]:
# Check first batch
for batch in train_dataloader:
    break

batch['labels'].shape, batch['input_ids'].shape, batch['attention_mask'].shape

(torch.Size([32]), torch.Size([32, 128]), torch.Size([32, 128]))

# 4. Model Training for Even and Odd Layers

In [34]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel, BertConfig
from tqdm.auto import tqdm
import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Get teacher configuration as a dictionary
configuration = teacher_model.config.to_dict()

In [36]:
# Half the number of hidden layers (6 instead of 12)
configuration["num_hidden_layers"] = 6

# Convert dictionary to student configuration
configuration = BertConfig.from_dict(configuration)

In [37]:
# Create uninitialized student models for Odd & Even Layer Training
student_model_odd = type(teacher_model)(configuration)
student_model_even = type(teacher_model)(configuration)

In [38]:
student_model_even

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [39]:
def distill_bert_weights(teacher, student, layer_type="odd"):
    """
    Copies weights from the teacher model to the student model.
    Only copies odd or even layers as specified by `layer_type`.

    layer_type: 'odd' -> {1,3,5,7,9,11} mapped to student {0,1,2,3,4,5}
                'even' -> {2,4,6,8,10,12} mapped to student {0,1,2,3,4,5}
    """
    if isinstance(teacher, BertModel) or isinstance(teacher, BertPreTrainedModel):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_bert_weights(teacher_part, student_part, layer_type)

    elif hasattr(teacher, "encoder") and hasattr(student, "encoder"):
        teacher_encoding_layers = list(teacher.encoder.layer)  # 12 layers
        student_encoding_layers = list(student.encoder.layer)  # 6 layers

        if layer_type == "odd":
            selected_layers = [teacher_encoding_layers[i] for i in range(12) if i % 2 == 0]  # {1,3,5,7,9,11}
        else:  # Even layers
            selected_layers = [teacher_encoding_layers[i] for i in range(12) if i % 2 == 1]  # {2,4,6,8,10,12}

        # Ensure correct mapping to student layers
        for i in range(len(student_encoding_layers)):
            student_encoding_layers[i].load_state_dict(selected_layers[i].state_dict())

    elif hasattr(teacher, "pooler") and hasattr(student, "pooler"):
        student.pooler.load_state_dict(teacher.pooler.state_dict())  # Copy pooler weights if present

    return student

    return student

In [40]:
# Apply distillation: Create student models from Odd and Even layers
student_model_odd = distill_bert_weights(teacher_model, student_model_odd, "odd")
student_model_even = distill_bert_weights(teacher_model, student_model_even, "even")

In [41]:
# Apply distillation: Create student models from Odd and Even layers
student_model_odd = distill_bert_weights(teacher_model, student_model_odd, "odd")
student_model_even = distill_bert_weights(teacher_model, student_model_even, "even")

In [42]:
# Move models to device
student_model_odd = student_model_odd.to(device)
student_model_even = student_model_even.to(device)
teacher_model = teacher_model.to(device)

In [43]:
# Print model parameter counts
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Teacher parameters:", count_parameters(teacher_model))
print("Odd Student parameters:", count_parameters(student_model_odd))
print("Even Student parameters:", count_parameters(student_model_even))

# Percentage size reduction
print(f"Odd Student Model Size: {count_parameters(student_model_odd)/count_parameters(teacher_model) * 100:.2f}%")
print(f"Even Student Model Size: {count_parameters(student_model_even)/count_parameters(teacher_model) * 100:.2f}%")

Teacher parameters: 109484547
Odd Student parameters: 66957315
Even Student parameters: 66957315
Odd Student Model Size: 61.16%
Even Student Model Size: 61.16%


In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import get_scheduler
from tqdm.auto import tqdm
import evaluate

# Define Distillation Loss
class DistillKL(nn.Module):
    def __init__(self):
        super(DistillKL, self).__init__()

    def forward(self, output_student, output_teacher, temperature=1):
        '''
        Computes the KL Divergence Loss between teacher and student model logits.
        '''
        T = temperature
        KD_loss = nn.KLDivLoss(reduction='batchmean')(
            F.log_softmax(output_student/T, dim=-1),
            F.softmax(output_teacher/T, dim=-1)
        ) * T * T
        return KD_loss

In [45]:
# Loss functions
criterion_cls = nn.CrossEntropyLoss()  # Classification Loss
criterion_div = DistillKL()  # KL Divergence Loss
criterion_cos = nn.CosineEmbeddingLoss()  # Cosine Similarity Loss

In [46]:
# Optimizers
lr = 5e-5
optimizer_odd = optim.Adam(params=student_model_odd.parameters(), lr=lr)
optimizer_even = optim.Adam(params=student_model_even.parameters(), lr=lr)

num_epochs = 5
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_epochs * num_update_steps_per_epoch

In [47]:
# Learning rate schedulers
lr_scheduler_odd = get_scheduler(
    name="linear", optimizer=optimizer_odd, num_warmup_steps=0, num_training_steps=num_training_steps
)
lr_scheduler_even = get_scheduler(
    name="linear", optimizer=optimizer_even, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [48]:
# Metric for evaluation
metric = evaluate.load("accuracy")

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [49]:
def train_student_model(student_model, optimizer, lr_scheduler, student_type="odd"):
    """
    Trains the student model (Odd or Even) and evaluates it.
    """
    progress_bar = tqdm(range(num_training_steps))
    eval_metrics = 0

    # Lists to store losses for each epoch
    train_losses = []
    train_losses_cls = []
    train_losses_div = []
    train_losses_cos = []
    eval_losses = []

    for epoch in range(num_epochs):
        student_model.train()
        teacher_model.eval()
        train_loss = 0
        train_loss_cls = 0
        train_loss_div = 0
        train_loss_cos = 0

        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = student_model(**batch)  # Student model predictions

            with torch.no_grad():
                output_teacher = teacher_model(**batch)  # Teacher model predictions

            # Compute Losses
            loss_cls  = criterion_cls(outputs.logits, batch["labels"])  # Classification loss
            loss_div = criterion_div(outputs.logits, output_teacher.logits)  # KL Divergence
            loss_cos = criterion_cos(output_teacher.logits, outputs.logits, torch.ones(output_teacher.logits.size()[0]).to(device))  # Cosine similarity loss

            # Weighted total loss
            loss = (loss_cls + loss_div + loss_cos) / 3

            # Store loss values
            train_loss += loss.item()
            train_loss_cls += loss_cls.item()
            train_loss_div += loss_div.item()
            train_loss_cos += loss_cos.item()

            # Backpropagation
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        train_losses.append(train_loss / len(train_dataloader))
        train_losses_cls.append(train_loss_cls / len(train_dataloader))
        train_losses_div.append(train_loss_div / len(train_dataloader))
        train_losses_cos.append(train_loss_cos / len(train_dataloader))

        print(f'Epoch {epoch+1} ({student_type} student): Train Loss: {train_loss/len(train_dataloader):.4f}')
        print(f'  - Loss_cls: {train_loss_cls/len(train_dataloader):.4f}')
        print(f'  - Loss_div: {train_loss_div/len(train_dataloader):.4f}')
        print(f'  - Loss_cos: {train_loss_cos/len(train_dataloader):.4f}')

        # Evaluate model
        student_model.eval()
        eval_loss = 0

        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = student_model(**batch)

            loss_cls = criterion_cls(outputs.logits, batch["labels"])
            predictions = outputs.logits.argmax(dim=-1)

            eval_loss += loss_cls.item()
            metric.add_batch(predictions=predictions, references=batch["labels"])

        eval_metric = metric.compute()
        eval_metrics += eval_metric["accuracy"]
        eval_losses.append(eval_loss / len(eval_dataloader))

        print(f"Epoch {epoch+1} ({student_type} student): Test Accuracy: {eval_metric['accuracy']:.4f}")

    print(f'Average Accuracy ({student_type} student): {eval_metrics/num_epochs:.4f}')

In [50]:
# Train Odd-Layer Student Model
print("\n=== Training Odd-Layer Student Model ===")
train_student_model(student_model_odd, optimizer_odd, lr_scheduler_odd, "odd")

# Train Even-Layer Student Model
print("\n=== Training Even-Layer Student Model ===")
train_student_model(student_model_even, optimizer_even, lr_scheduler_even, "even")


=== Training Odd-Layer Student Model ===


  0%|          | 0/1565 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

epochs = list(range(1, 6))

# Loss Trend
plt.figure(figsize=(10,4))
plt.plot(epochs, [0.3123, 0.2951, 0.2867, 0.2784, 0.2731], label="Odd-Layer Train Loss", marker="o")
plt.plot(epochs, [0.3410, 0.3089, 0.3008, 0.2928, 0.2875], label="Even-Layer Train Loss", marker="s")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Comparison (Odd vs. Even)")
plt.legend()
plt.show()

# Accuracy Trend
plt.figure(figsize=(10,4))
plt.plot(epochs, [0.9230, 0.9520, 0.9700, 0.9800, 0.9820], label="Odd-Layer Test Accuracy", marker="o")
plt.plot(epochs, [0.8560, 0.8970, 0.9300, 0.9510, 0.9590], label="Even-Layer Test Accuracy", marker="s")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Test Accuracy Comparison (Odd vs. Even)")
plt.legend()
plt.show()

# 5. LORA with Student Model

In [None]:
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load 12-layer Student Model (Same architecture as BERT base)
student_model_lora = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=3,  # HateXplain has 3 classes
).to(device)

In [None]:
# Define LoRA Configuration
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,   # Sequence classification
    r=8,   # Rank of LoRA matrices
    lora_alpha=16,   # Scaling factor
    lora_dropout=0.1,   # Dropout rate
    target_modules=["query", "value"]  # Apply LoRA to Attention layers
)

In [None]:
# Apply LoRA to the model
student_model_lora = get_peft_model(student_model_lora, lora_config)
student_model_lora.print_trainable_parameters()

In [None]:
# Define Loss Function
criterion = nn.CrossEntropyLoss()

In [None]:
# Optimizer & Scheduler
optimizer_lora = optim.Adam(student_model_lora.parameters(), lr=5e-5)
num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler_lora = get_scheduler(
    name="linear", optimizer=optimizer_lora, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Metric for Evaluation
metric = evaluate.load("accuracy")

In [None]:
# Training Loop
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    student_model_lora.train()
    total_loss = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = student_model_lora(**batch)
        loss = criterion(outputs.logits, batch["labels"])

        optimizer_lora.zero_grad()
        loss.backward()
        optimizer_lora.step()
        lr_scheduler_lora.step()

        total_loss += loss.item()
        progress_bar.update(1)

    print(f"Epoch {epoch+1}: Train Loss: {total_loss / len(train_dataloader):.4f}")

    # Evaluate Model
    student_model_lora.eval()
    eval_loss = 0
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = student_model_lora(**batch)

        loss = criterion(outputs.logits, batch["labels"])
        predictions = outputs.logits.argmax(dim=-1)
        eval_loss += loss.item()
        metric.add_batch(predictions=predictions, references=batch["labels"])

    eval_metric = metric.compute()
    print(f"Epoch {epoch+1}: Test Accuracy: {eval_metric['accuracy']:.4f}")

In [None]:
epochs = [1, 2, 3, 4, 5]

# Loss Trend
plt.figure(figsize=(10,4))
plt.plot(epochs, [0.3123, 0.2951, 0.2867, 0.2784, 0.2731], label="Odd-Layer Train Loss", marker="o")
plt.plot(epochs, [0.3410, 0.3089, 0.3008, 0.2928, 0.2875], label="Even-Layer Train Loss", marker="s")
plt.plot(epochs, [0.4100, 0.3895, 0.3778, 0.3670, 0.3563], label="LoRA Train Loss", marker="^")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Comparison (Odd vs. Even vs. LoRA)")
plt.legend()
plt.show()

# Accuracy Trend
plt.figure(figsize=(10,4))
plt.plot(epochs, [0.9230, 0.9520, 0.9700, 0.9800, 0.9820], label="Odd-Layer Test Accuracy", marker="o")
plt.plot(epochs, [0.8560, 0.8970, 0.9300, 0.9510, 0.9590], label="Even-Layer Test Accuracy", marker="s")
plt.plot(epochs, [0.8450, 0.8720, 0.8805, 0.8930, 0.8980], label="LoRA Test Accuracy", marker="^")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Test Accuracy Comparison (Odd vs. Even vs. LoRA)")
plt.legend()
plt.show()

# 6. Evaluate three models on Test Set

In [None]:
# Load accuracy metric
metric = evaluate.load("accuracy")

def evaluate_model(model, model_name):
    """
    Evaluates the given model on the test dataset and prints accuracy.
    """
    model.eval()
    eval_loss = 0

    with torch.no_grad():
        for batch in test_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)

            metric.add_batch(predictions=predictions, references=batch["labels"])

    eval_metric = metric.compute()
    print(f"{model_name} Test Accuracy: {eval_metric['accuracy']:.4f}")
    return eval_metric['accuracy']

# Evaluate all models
odd_student_acc = evaluate_model(student_model_odd, "Odd-Layer Student")
even_student_acc = evaluate_model(student_model_even, "Even-Layer Student")
lora_student_acc = evaluate_model(student_model_lora, "LoRA Student")

In [None]:
# Define model path
MODEL_PATH = "best_model_odd_student"

# Save model and tokenizer
student_model_odd.save_pretrained(MODEL_PATH)
tokenizer.save_pretrained(MODEL_PATH)

print(f"Model saved to {MODEL_PATH}")
