In [55]:
import torch
import torch.nn as nn
import pandas as pd
import torch.optim as optim
import math
from tqdm import tqdm

In [56]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# **Load Dataset**

**text:** a string feature.

**label:** a classification label, with possible values including World (0), Sports (1), Business (2), Sci/Tech (3).

**train:** 120000

**test:** 7600

In [57]:
from datasets import load_dataset

# Load the AG News dataset from Hugging Face Hub
dataset = load_dataset("wangrongsheng/ag_news")

# Display the first two training samples
print("Sample 1:", dataset["train"][0])
print("Sample 2:", dataset["train"][1])

Sample 1: {'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}
Sample 2: {'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.', 'label': 2}


# **Tokenization**

In [58]:
from transformers import BertTokenizer

# Load the BERT tokenizer (uncased version)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Define the maximum sequence length
max_len = 128

# Tokenization function for a batch of text samples
def tokenize_batch(dataset_split):
    return tokenizer(
        [sample["text"] for sample in dataset_split],  # Extract texts
        padding="max_length",                          # Pad to max_len
        truncation=True,                               # Truncate if too long
        max_length=max_len,
        return_tensors="pt"                            # Return PyTorch tensors
    )

# Apply tokenization to train and test sets
train_encodings = tokenize_batch(dataset["train"])
test_encodings = tokenize_batch(dataset["test"])

# Extract labels and convert them to PyTorch tensors
train_labels = torch.tensor([sample["label"] for sample in dataset["train"]])
test_labels = torch.tensor([sample["label"] for sample in dataset["test"]])

In [59]:
print(train_encodings)

{'input_ids': tensor([[  101,  2813,  2358,  ...,     0,     0,     0],
        [  101, 18431,  2571,  ...,     0,     0,     0],
        [  101,  3514,  1998,  ...,     0,     0,     0],
        ...,
        [  101,  7842,  8193,  ...,     0,     0,     0],
        [  101,  2651,  1005,  ...,  2038,  2589,   102],
        [  101, 16996,  2131,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]])}


# **Create PyTorch Dataset**

In [60]:
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset class for AG News
class AGNewsDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }

# Create dataset instances
train_dataset = AGNewsDataset(train_encodings, train_labels)
test_dataset = AGNewsDataset(test_encodings, test_labels)

# Define batch size
batch_size = 32

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


# **Teacher Model**

![](https://heidloff.net/assets/img/2023/02/transformers.png)?

In [61]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, ff_dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim)
        )
        self.ff_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, attention_mask=None):
        # Create key padding mask from attention mask
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Multi-head self-attention
        attn_output, attn_weights = self.attention(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=False
        )
        value_vectors = x  # Input to attention acts as Value

        # Add & Norm (post-attention)
        x = self.attn_norm(x + attn_output)

        # Feed-forward network with Add & Norm
        ff_output = self.ff(x)
        x = self.ff_norm(x + ff_output)

        return x, attn_weights, value_vectors


class TeacherTransformerEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim=384, num_heads=8, ff_dim=1536, max_len=128, num_layers=4, type_vocab_size=2):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Embedding layers: token + position + segment (type)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)
        self.segment_embedding = nn.Embedding(type_vocab_size, hidden_dim)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_dim, num_heads, ff_dim) for _ in range(num_layers)
        ])

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # input_ids.shape = torch.Size([32, 128])
        B, L = input_ids.shape
        device = input_ids.device

        # Create position indices
        positions = torch.arange(0, L, device=device).unsqueeze(0).expand(B, L)

        # If segment (token type) IDs are not provided, default to all zeros (segment A)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Combine embeddings: token + position + segment
        x = self.embedding(input_ids) \
            + self.pos_embedding(positions) \
            + self.segment_embedding(token_type_ids)

        # Forward through transformer layers
        attentions = None
        value_vectors = None

        for i, layer in enumerate(self.layers):
            x, attn_weights, values = layer(x, attention_mask)

            # Save only final layer's outputs for distillation
            if i == len(self.layers) - 1:
                attentions = attn_weights
                value_vectors = values
                
        return {
            "last_hidden_state": x,
            "attentions": attentions,
            "value_vectors": value_vectors
        }

In [62]:
# Instantiate the teacher model and move it to the appropriate device (CPU or GPU)
teacher_model = TeacherTransformerEncoder(vocab_size=tokenizer.vocab_size).to(device)

# Fetch one batch of data from the training loader
batch = next(iter(train_loader))

# Move input tensors to the same device as the model
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)

# Forward pass through the teacher model
out_t = teacher_model(input_ids, attention_mask)

# Print the shape of attention outputs (should be B x num_heads x L x L)
print("Teacher attention output shape: ", out_t["attentions"].shape)

# Print the shape of value vectors from the last transformer layer
print("Teacher value output shape: ", out_t["value_vectors"].shape)

Teacher attention output shape:  torch.Size([32, 8, 100, 100])
Teacher value output shape:  torch.Size([32, 100, 384])


# **Student Model**


In [63]:
class StudentTransformerEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim=128, num_heads=8, ff_dim=256, max_len=128, num_classes=4, type_vocab_size=2):
        super(StudentTransformerEncoder, self).__init__()
        
        self.hidden_dim = hidden_dim

        # Embedding layers: token, position, and segment (as in BERT)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)
        self.segment_embedding = nn.Embedding(type_vocab_size, hidden_dim)

        # Self-attention block
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, 
            num_heads=num_heads, 
            batch_first=True
        )
        self.attn_norm = nn.LayerNorm(hidden_dim)

        # Feed-forward block
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim)
        )
        self.ff_norm = nn.LayerNorm(hidden_dim)

        # Optional classification head
        if num_classes is not None:
            self.classifier = nn.Linear(hidden_dim, num_classes)
        else:
            self.classifier = None

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        B, L = input_ids.size()
        device = input_ids.device

        # Positional embedding
        positions = torch.arange(0, L, device=device).unsqueeze(0).expand(B, L)

        # Default to all zeros (segment A) if token_type_ids not provided
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Combine embeddings: token + position + segment
        x = self.embedding(input_ids) \
            + self.pos_embedding(positions) \
            + self.segment_embedding(token_type_ids)

        # Attention masking (PAD = 0 → True)
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Self-attention with attention weights
        attn_output, attn_weights = self.attention(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=False
        )

         # Save pre-attention input as value vectors (for VR loss)
        value_vectors = x.detach()  # (B, L, D)

        # Residual + LayerNorm after attention
        x = self.attn_norm(x + attn_output)

        # Feed-forward + Residual + LayerNorm
        ff_output = self.ff(x)
        x = self.ff_norm(x + ff_output)

        # Prepare outputs
        output = {
            "last_hidden_state": x,         # (B, L, D)
            "attentions": attn_weights,     # (B, num_heads, L, L)
            "value_vectors": value_vectors  # (B, L, D)
        }

        # Add classification logits if classifier is defined
        if self.classifier is not None:
            cls_repr = x[:, 0, :]  # First token (CLS)
            output["logits"] = self.classifier(cls_repr)

        return output

In [64]:
# Instantiate the student model and move it to the appropriate device (CPU or GPU)
student_model = StudentTransformerEncoder(vocab_size=tokenizer.vocab_size).to(device)

# Fetch one batch of data from the training loader
batch = next(iter(train_loader))

# Move input tensors to the same device as the model
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)

# Forward pass through the teacher model
out_s = student_model(input_ids, attention_mask)

# Print the shape of attention outputs (should be B x num_heads x L x L)
print("Teacher attention output shape: ", out_s["attentions"].shape)

# Print the shape of value vectors from the last transformer layer
print("Teacher value output shape: ", out_s["value_vectors"].shape)

Teacher attention output shape:  torch.Size([32, 8, 100, 100])
Teacher value output shape:  torch.Size([32, 100, 128])


# **Loss Function**

In [65]:
import torch.nn.functional as F

In [66]:
def attention_loss(student_attn_last, teacher_attn_last, eps=1e-8):
    """
    Compute KL-divergence loss between teacher and student attention distributions.
    
    Args:
        student_attn_last: Tensor of shape (B, A_h, L, L) - student's attention scores.
        teacher_attn_last: Tensor of shape (B, A_h, L, L) - teacher's attention scores.
        eps: small constant for numerical stability.

    Returns:
        Scalar tensor representing the averaged KL divergence loss.
    """
    # Clamp values to avoid log(0)
    teacher_attn = teacher_attn_last.clamp(min=eps)
    student_attn = student_attn_last.clamp(min=eps)

    # Compute element-wise KL divergence: KL(teacher || student)
    kl = teacher_attn * (torch.log(teacher_attn) - torch.log(student_attn))
    
    # Sum over keys dimension (last dimension)
    kl = kl.sum(dim=-1)  # shape: (B, A_h, L)
    
    # Average over batch, attention heads, and tokens
    loss = kl.mean()
    return loss

In [67]:
def value_relation_loss(student_value_last, teacher_value_last, eps=1e-8):
    """
    Compute KL-divergence loss between value relation matrices of teacher and student.
    
    Args:
        student_value_last: Tensor of shape (B, L, D) - student's value vectors.
        teacher_value_last: Tensor of shape (B, L, D) - teacher's value vectors.
        eps: small constant for numerical stability.

    Returns:
        Scalar tensor representing the averaged KL divergence loss.
    """
    # Normalize value vectors along feature dimension
    sv = F.normalize(student_value_last, p=2, dim=-1)  # (B, L, D)
    tv = F.normalize(teacher_value_last, p=2, dim=-1)  # (B, L, D)

    # Compute relation matrices: similarity between tokens
    student_rel = torch.matmul(sv, sv.transpose(-1, -2))  # (B, L, L)
    teacher_rel = torch.matmul(tv, tv.transpose(-1, -2))  # (B, L, L)

    # sv.shape == (B, L, D)
    # sv.transpose(-1, -2) shape = (B, D, L)
    # student_rel.shape => (B, L, D) @ (B, D, L) → (B, L, L)

    # Convert to probability distributions using softmax along keys dimension
    student_rel = F.softmax(student_rel, dim=-1).clamp(min=eps)
    teacher_rel = F.softmax(teacher_rel, dim=-1).clamp(min=eps)

    # Compute element-wise KL divergence: KL(teacher || student)
    kl = teacher_rel * (torch.log(teacher_rel) - torch.log(student_rel))

    # Sum over keys dimension
    kl = kl.sum(dim=-1)  # shape: (B, L)

    # Average over batch and tokens
    loss = kl.mean()
    return loss

# **Forward Pass**

In [68]:
from torch.nn import CrossEntropyLoss

num_epochs = 10
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

for epoch in range(num_epochs):
    print(f"------------------------ Epoch {epoch+1}/{num_epochs} ------------------------")
    
    student_model.train()
    teacher_model.eval()
    train_loss = 0.0

    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # ---- TEACHER FORWARD PASS (No gradient) ----
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_attn = teacher_outputs["attentions"]        # (B, num_heads, L, L)
            teacher_value = teacher_outputs["value_vectors"]    # (B, L, D)

        # ---- STUDENT FORWARD PASS ----
        student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
        student_attn = student_outputs["attentions"]           # (B, num_heads, L, L)
        student_value = student_outputs["value_vectors"]       # (B, L, D)

        # ---- DISTILLATION LOSS COMPUTATION ----
        loss_at = attention_loss(student_attn, teacher_attn)
        loss_vr = value_relation_loss(student_value, teacher_value)
        loss = loss_at + loss_vr

        # ---- OPTIMIZATION ----
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    print(f"Training - Distillation Loss: {train_loss / len(train_loader):.4f}")

    # print("------------------------ Validation ------------------------")
    
    student_model.eval()
    val_loss = 0.0
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]
            
            # Classification loss for evaluation
            loss = criterion(logits, labels)
            val_loss += loss.item()

    print(f"Average Validation Loss: {val_loss / len(test_loader):.4f}")

------------------------ Epoch 1/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.21it/s]


Training - Distillation Loss: 0.0771
Average Validation Loss: 1.5559
------------------------ Epoch 2/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.26it/s]


Training - Distillation Loss: 0.0342
Average Validation Loss: 1.5442
------------------------ Epoch 3/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.45it/s]


Training - Distillation Loss: 0.0280
Average Validation Loss: 1.5231
------------------------ Epoch 4/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.38it/s]


Training - Distillation Loss: 0.0244
Average Validation Loss: 1.5079
------------------------ Epoch 5/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.36it/s]


Training - Distillation Loss: 0.0220
Average Validation Loss: 1.4960
------------------------ Epoch 6/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.38it/s]


Training - Distillation Loss: 0.0204
Average Validation Loss: 1.4878
------------------------ Epoch 7/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.47it/s]


Training - Distillation Loss: 0.0193
Average Validation Loss: 1.4823
------------------------ Epoch 8/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.49it/s]


Training - Distillation Loss: 0.0185
Average Validation Loss: 1.4782
------------------------ Epoch 9/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.51it/s]


Training - Distillation Loss: 0.0179
Average Validation Loss: 1.4747
------------------------ Epoch 10/10 ------------------------


100%|██████████| 3750/3750 [01:22<00:00, 45.55it/s]


Training - Distillation Loss: 0.0174
Average Validation Loss: 1.4718
