# Imports

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pandas as pd
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizerFast
from tqdm import tqdm
from transformers import (
    BertForSequenceClassification,
    AutoTokenizer
)
from transformers import get_linear_schedule_with_warmup
from transformers import BertConfig, BertModel, BertTokenizer
from transformers.optimization import AdamW
from sys import platform
import time
from tqdm import tqdm
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    classification_report
)
from datasets import load_dataset

# Preparing dataset

In [3]:
class SST2Dataset(Dataset):
    def __init__(self, sentences, labels, tokenizer, max_len):
        self.sentences = sentences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        text = self.sentences[idx]
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

def load_sst2_data():
    sst2_dataset = load_dataset("glue", "sst2")
    train_dataset = sst2_dataset["train"]
    test_dataset = sst2_dataset["validation"]
    train_texts = train_dataset["sentence"]
    train_labels = train_dataset["label"]
    test_texts = test_dataset["sentence"]
    test_labels = test_dataset["label"]

    return train_texts, train_labels, test_texts, test_labels

train_texts, train_labels, test_texts, test_labels = load_sst2_data()

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 3.11M/3.11M [00:00<00:00, 10.1MB/s]
Downloading data: 100%|██████████| 72.8k/72.8k [00:00<00:00, 381kB/s]
Downloading data: 100%|██████████| 148k/148k [00:00<00:00, 709kB/s]


Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

# Creating dataloaders

In [4]:
tokenizer = AutoTokenizer.from_pretrained('textattack/bert-base-uncased-SST-2', do_lower_case=True)

device = "cuda" # change to cuda when on kaggl

max_len = 64             # from HF docs
batch_size = 8           # keep things fast

train_dataset = SST2Dataset(train_texts, train_labels, tokenizer, max_len)
test_dataset = SST2Dataset(test_texts, test_labels, tokenizer, max_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/477 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

# Evalutating teacher

In [5]:
# Define teacher model
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained('textattack/bert-base-uncased-SST-2', num_labels=2, output_hidden_states=True)

    def forward(self, input_ids, attention_mask, output_hidden_states=False):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states)
    
        # Conditionally return logits and hidden states
        if output_hidden_states:
            return outputs.logits, outputs.hidden_states
        else:
            return outputs.logits

# Instantiate models
teacher_model = TeacherModel()
teacher_model.to(device)

# Define optimizer and loss function
optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# Validate teacher model
teacher_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = teacher_model(input_ids, attention_mask)
        predicted = torch.argmax(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
total_params = sum(p.numel() for p in teacher_model.parameters() if p.requires_grad)
print("Teacher Model Validation Accuracy: {:.4f}".format(accuracy))
print(f"Parameters: {total_params}")


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
100%|██████████| 109/109 [00:02<00:00, 39.20it/s]

Teacher Model Validation Accuracy: 0.9243
Parameters: 109483778





# Define student model


In [7]:
# 4 layer bert config as per literature
class BertForSequenceClassification(nn.Module):
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__()
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, 2)  # Adjusted for 2 classes

    def forward(self, input_ids, attention_mask=None, output_hidden_states=False):
        outputs = self.bert(input_ids, 
                            attention_mask=attention_mask, 
                            output_hidden_states=output_hidden_states)
        # Using the pooled output for classification
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)

        if output_hidden_states:
            return logits, outputs.hidden_states
        return logits


# Configuration for BERT
config = BertConfig(
    attention_probs_dropout_prob=0.1,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    hidden_size=540,
    initializer_range=0.02,
    intermediate_size=500,
    max_position_embeddings=512,
    num_attention_heads=12,
    num_hidden_layers=4,
    type_vocab_size=2,
    vocab_size=30522,
    output_hidden_states = True
)


# Evaluate student model's zero-shot performance

In [8]:
# Initialize the model
student_model = BertForSequenceClassification(config)
student_model.to(device)     # change if not apple silicone

optimizer_student = optim.Adam(student_model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = student_model.forward(input_ids, attention_mask)
        predicted = torch.argmax(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print("Student Model 0-shot Accuracy: {:.4f}".format(accuracy))

total_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

100%|██████████| 109/109 [00:00<00:00, 171.53it/s]

Student Model 0-shot Accuracy: 0.4908
Total parameters: 23900782





# Simple logits-based knowledge distillation

In [9]:
def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    soft_labels = nn.functional.log_softmax(student_logits / temperature, dim=1)
    with torch.no_grad():
        soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=1)
    soft_loss = nn.KLDivLoss(reduction='batchmean')(soft_labels, soft_targets) * (temperature ** 2)
    hard_loss = nn.CrossEntropyLoss()(student_logits, labels)
    return alpha * soft_loss + (1. - alpha) * hard_loss

def train_student_with_distillation(student, teacher, train_loader, device, temperature=2.0, alpha=0.5, num_epochs=3):
    optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)

    student.train()
    teacher.eval()

    for epoch in (range(num_epochs)):
        for batch in tqdm(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher(input_ids, attention_mask=attention_mask)

            student_logits = student(input_ids, attention_mask=attention_mask)

            loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

In [10]:
num_epochs = 10
train_student_with_distillation(student_model, teacher_model, train_loader, device, num_epochs=num_epochs)

100%|██████████| 8419/8419 [03:57<00:00, 35.42it/s]


Epoch 1/10, Loss: 0.16002747416496277


100%|██████████| 8419/8419 [03:57<00:00, 35.43it/s]


Epoch 2/10, Loss: 0.004210581071674824


100%|██████████| 8419/8419 [03:57<00:00, 35.45it/s]


Epoch 3/10, Loss: 0.4726927876472473


100%|██████████| 8419/8419 [03:57<00:00, 35.45it/s]


Epoch 4/10, Loss: 0.1543080359697342


100%|██████████| 8419/8419 [03:57<00:00, 35.49it/s]


Epoch 5/10, Loss: 0.033204078674316406


100%|██████████| 8419/8419 [03:57<00:00, 35.45it/s]


Epoch 6/10, Loss: 0.8552101850509644


100%|██████████| 8419/8419 [03:57<00:00, 35.48it/s]


Epoch 7/10, Loss: 0.09905454516410828


100%|██████████| 8419/8419 [03:57<00:00, 35.45it/s]


Epoch 8/10, Loss: 0.13442720472812653


100%|██████████| 8419/8419 [03:57<00:00, 35.43it/s]


Epoch 9/10, Loss: 0.001954026520252228


100%|██████████| 8419/8419 [03:57<00:00, 35.44it/s]


Epoch 10/10, Loss: 0.0010987480636686087


In [13]:
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = student_model(input_ids, attention_mask)
        predicted = torch.argmax(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total

print("Logits-KD Student Accuracy: {:.4f}".format(accuracy))

total_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

100%|██████████| 109/109 [00:00<00:00, 157.33it/s]

Logits-KD Student Accuracy: 0.7839
Total parameters: 23900782





# Simple layer-wise knowledge distillation 

In [10]:
# Projection module
class ProjectionModule(nn.Module):
    def __init__(self, teacher_dims, student_dim, device):
        super(ProjectionModule, self).__init__()
        self.projections = nn.ModuleDict()
        for idx, t_dim in enumerate(teacher_dims):
            self.projections[str(idx)] = nn.Linear(t_dim, student_dim).to(device)

    def forward(self, idx, x):
        return self.projections[str(idx)](x)
    
# Layer mapper
mapper = {
    0: [0, 1, 2],   # Combine Teacher layers 0, 1, 2 for Student layer 0
    1: [3, 4, 5],   # Combine Teacher layers 3, 4, 5 for Student layer 1
    2: [6, 7, 8],   # Combine Teacher layers 6, 7, 8 for Student layer 2
    3: [9, 10, 11]  # Combine Teacher layers 9, 10, 11 for Student layer 3
}



In [11]:
# CKD loss
def ckd_loss(student_layers, teacher_layers, mapper, projections, device):
    loss_fn = nn.MSELoss()
    total_loss = torch.tensor(0.0, device=device)
    for s_idx, t_indices in mapper.items():
        combined_layer = torch.cat([teacher_layers[t_idx] for t_idx in t_indices], dim=-1)
        # Apply the projection to the combined layer
        projected_layer = projections(s_idx, combined_layer)
        # Compute the loss
        loss = loss_fn(projected_layer, student_layers[s_idx])
        total_loss += loss

    return total_loss

teacher_layer_dim = 768
teacher_dims = [teacher_layer_dim * len(mapper[i]) for i in range(len(mapper))] 
student_dim = 540

projections = ProjectionModule(teacher_dims, student_dim, device)


In [12]:
def train_student_with_ckd(student, teacher, train_loader, projections, device, mapper, temperature=5.0, alpha=0.2, num_epochs=3):
    optimizer = torch.optim.AdamW(list(student.parameters()) + list(projections.parameters()), lr=1e-5)
    student.train()
    teacher.eval()

    for epoch in range(num_epochs):
        
        for batch in tqdm(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            # Get outputs with hidden states
            student_logits, student_layers = student(input_ids, attention_mask, output_hidden_states=True)
            soft_labels = nn.functional.log_softmax(student_logits / temperature, dim=1)

            with torch.no_grad():
                teacher_outputs, teacher_layers = teacher(input_ids, attention_mask=attention_mask, output_hidden_states=True)
                soft_targets = nn.functional.softmax(teacher_outputs / temperature, dim=1)
        
            soft_loss = nn.KLDivLoss(reduction='batchmean')(soft_labels, soft_targets) * (temperature ** 2)
            
            # Calculate CKD loss using the mapper and projection module
            ckd = ckd_loss(student_layers, teacher_layers, mapper, projections, device)
            loss = alpha * ckd + 0.5 * nn.CrossEntropyLoss()(student_logits, labels) + 0.1 *soft_loss

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

In [20]:
num_epochs = 10
student_model_ckd = BertForSequenceClassification(config)
student_model_ckd.to(device)
train_student_with_ckd(student_model_ckd, teacher_model, train_loader, projections, device, mapper, num_epochs=num_epochs)

100%|██████████| 8419/8419 [04:50<00:00, 28.98it/s]


Epoch 1/10, Loss: 0.5304255485534668


100%|██████████| 8419/8419 [04:50<00:00, 28.97it/s]


Epoch 2/10, Loss: 0.6978040337562561


100%|██████████| 8419/8419 [04:49<00:00, 29.07it/s]


Epoch 3/10, Loss: 0.16866007447242737


100%|██████████| 8419/8419 [04:49<00:00, 29.10it/s]


Epoch 4/10, Loss: 0.34456735849380493


100%|██████████| 8419/8419 [04:48<00:00, 29.16it/s]


Epoch 5/10, Loss: 0.08422499895095825


100%|██████████| 8419/8419 [04:49<00:00, 29.07it/s]


Epoch 6/10, Loss: 0.2337331771850586


100%|██████████| 8419/8419 [04:49<00:00, 29.10it/s]


Epoch 7/10, Loss: 0.1074661910533905


100%|██████████| 8419/8419 [04:48<00:00, 29.17it/s]


Epoch 8/10, Loss: 0.32152849435806274


100%|██████████| 8419/8419 [04:48<00:00, 29.19it/s]


Epoch 9/10, Loss: 1.0617775917053223


100%|██████████| 8419/8419 [04:46<00:00, 29.35it/s]

Epoch 10/10, Loss: 0.052588656544685364





In [23]:
student_model_ckd.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = student_model_ckd(input_ids, attention_mask)
        predicted = torch.argmax(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total

print("CKD Student Accuracy: {:.4f}".format(accuracy))

total_params = sum(p.numel() for p in student_model_ckd.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

100%|██████████| 109/109 [00:00<00:00, 166.32it/s]

CKD Student Accuracy: 0.7982
Total parameters: 23900782





# Dynamic Weighting for Layer Loss

In [13]:
def adjust_loss_alpha(initial_alpha, epoch, total_epochs, lambda_factor=0.9):
    return initial_alpha * (lambda_factor ** (epoch / total_epochs))

In [14]:
def train_student_with_ckd_decay(student, teacher, train_loader, projections, device, mapper, temperature=3.0, alpha=0.5, beta = 0.5, num_epochs=3):
    optimizer = torch.optim.AdamW(list(student.parameters()) + list(projections.parameters()), lr=1e-5)
    student.train()
    teacher.eval()
    
    lambda_factor = 0.9

    for epoch in range(num_epochs):
        current_alpha = adjust_loss_alpha(alpha, epoch, num_epochs, lambda_factor)
        for batch in tqdm(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            # Get outputs with hidden states
            student_logits, student_layers = student(input_ids, attention_mask, output_hidden_states=True)
            soft_labels = nn.functional.log_softmax(student_logits / temperature, dim=1)

            with torch.no_grad():
                teacher_outputs, teacher_layers = teacher(input_ids, attention_mask=attention_mask, output_hidden_states=True)
                soft_targets = nn.functional.softmax(teacher_outputs / temperature, dim=1)
        
            soft_loss = nn.KLDivLoss(reduction='batchmean')(soft_labels, soft_targets) * (temperature ** 2)
            # Calculate CKD loss using the mapper and projection module
            ckd = ckd_loss(student_layers, teacher_layers, mapper, projections, device)
            loss = current_alpha * ckd + 0.5 * nn.CrossEntropyLoss()(student_logits, labels) + 0.1 * soft_loss

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

In [15]:
num_epochs = 10
student_model_ckd_decay = BertForSequenceClassification(config)
student_model_ckd_decay.to(device)
train_student_with_ckd_decay(student_model_ckd_decay, teacher_model, train_loader, projections, device, mapper, num_epochs=num_epochs)                                              

100%|██████████| 8419/8419 [04:44<00:00, 29.57it/s]


Epoch 1/10, Loss: 0.6603521108627319


100%|██████████| 8419/8419 [04:44<00:00, 29.59it/s]


Epoch 2/10, Loss: 0.2214069366455078


100%|██████████| 8419/8419 [04:43<00:00, 29.69it/s]


Epoch 3/10, Loss: 0.2177843153476715


100%|██████████| 8419/8419 [04:45<00:00, 29.52it/s]


Epoch 4/10, Loss: 0.1356973946094513


100%|██████████| 8419/8419 [04:47<00:00, 29.30it/s]


Epoch 5/10, Loss: 0.10787427425384521


100%|██████████| 8419/8419 [04:44<00:00, 29.59it/s]


Epoch 6/10, Loss: 0.11893336474895477


100%|██████████| 8419/8419 [04:48<00:00, 29.16it/s]


Epoch 7/10, Loss: 0.32407450675964355


100%|██████████| 8419/8419 [04:54<00:00, 28.61it/s]


Epoch 8/10, Loss: 0.10395871847867966


100%|██████████| 8419/8419 [04:50<00:00, 28.95it/s]


Epoch 9/10, Loss: 0.11174415796995163


100%|██████████| 8419/8419 [04:50<00:00, 28.94it/s]

Epoch 10/10, Loss: 0.10902029275894165





In [17]:
student_model_ckd_decay.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = student_model_ckd_decay(input_ids, attention_mask)
        predicted = torch.argmax(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total

print("Decay CKD Student Accuracy: {:.4f}".format(accuracy))

total_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")

100%|██████████| 109/109 [00:00<00:00, 171.13it/s]

Decay CKD Student Accuracy: 0.7901
Total parameters: 23900782





# Using Attention 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention_loss(student_layers, teacher_layers,mapper, projections, device):
    loss_fn = nn.MSELoss()
    total_loss = torch.tensor(0.0, device=device)
    print("total layers: ", len(teacher_layers))
    for i, (s_idx, t_indices) in enumerate(mapper.items()):
        teacher_layers = teacher_layers[i]
        combined_layer = torch.cat([teacher_layers[t_idx] for t_idx in t_indices], dim=-1)
        # Apply the projection to the combined layer
        projected_layer = projections(s_idx, combined_layer)
        # Compute the loss
        loss = loss_fn(projected_layer, student_layers[s_idx])
        total_loss += loss

    return total_loss

class ProjectionModule(nn.Module):
    def __init__(self, teacher_dims, student_dim, device):
        super(ProjectionModule, self).__init__()
        self.projections = nn.ModuleDict()
        for idx, t_dim in enumerate(teacher_dims):
            self.projections[str(idx)] = nn.Linear(t_dim, student_dim).to(device)

    def forward(self, idx, x):
        return self.projections[str(idx)](x)




In [None]:
class Attention_Object(nn.Module):
    def __init__(self, teacher_dim, student_dim, num_heads):
        super(Attention_Object, self).__init__()
        self.teacher_dim = teacher_dim
        self.student_dim = student_dim
        self.num_heads = num_heads
        
        # If the dimensions are different, include a projection layer
        if self.teacher_dim != self.student_dim:
            self.projection = nn.Linear(self.student_dim, self.teacher_dim)
        else:
            self.projection = None
        
        self.attention = nn.MultiheadAttention(self.teacher_dim, self.num_heads)

    def forward(self, student_layer, teacher_layer):
        # Assuming student_layer and teacher_layers have shapes:
        # [seq_len, batch_size, feature_dim]
    
#         if isinstance(student_layes, tuple):
#             student_layer = student_layer[0]  # Assuming the tensor you need is the first elemen
        
        # Project student features if dimensions are different
        if self.projection is not None:
            student_layer = self.projection(student_layer)
        
        # Compute attention weights and apply them to the teacher layers
#         print("teacher_layers: ", teacher_layer.size())
#         print("student_layer: ", student_layer.size())
        attn_output, attn_output_weights = self.attention(teacher_layer, student_layer, student_layer)
        proj_student_layer = student_layer
        return attn_output, attn_output_weights, proj_student_layer
    
class CustomAttentionLoss(nn.Module):
    def __init__(self):
        super(CustomAttentionLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, combined_teacher_layers, student_layers):
        # Loss for the weighted hidden states
        losses = []
        for student_layer, weighted_teacher_layer in zip(student_layers, combined_teacher_layers):
            layer_loss = self.mse_loss(student_layer, weighted_teacher_layer)
            losses.append(layer_loss.mean())  # Mean loss per layer
        final_loss = torch.mean(torch.stack(losses))
        return final_loss

In [None]:
def train_student_with_attention_decay(student, teacher, train_loader, device, temperature=5.0, alpha=0.5, beta = 0.5, num_epochs=3):
    d_student = 540  
    d_teacher = 768 
    
    lambda_factor = 0.8
    Attention_Obj = Attention_Object(d_teacher,d_student,d_student)
    attention_loss_function = CustomAttentionLoss()
    Attention_Obj.to(device)
    optimizer = AdamW(list(student.parameters()) + list(Attention_Obj.parameters()), lr=2e-5)
    teacher.eval()
    for epoch in range(num_epochs):
        student.train()
        current_alpha = adjust_loss_alpha(alpha, epoch, num_epochs, lambda_factor)
        for batch in tqdm(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            # Get outputs with hidden states
            student_logits, student_layers = student(input_ids, attention_mask, output_hidden_states=True)
            soft_labels = nn.functional.log_softmax(student_logits / temperature, dim=1)

            with torch.no_grad():
                teacher_outputs, teacher_layers = teacher(input_ids, attention_mask=attention_mask, output_hidden_states=True)
                soft_targets = nn.functional.softmax(teacher_outputs / temperature, dim=1)
        
            soft_loss = nn.KLDivLoss(reduction='batchmean')(soft_labels, soft_targets) * (temperature ** 2)
            attention_loss_val_temp = 0
            for student_layer in student_layers:
                
                sum_hidden = None
                for teacher_layer in teacher_layers:                    
                    attn_output, attn_output_weights, proj_student_layer = Attention_Obj(student_layer,teacher_layer)
                    if(sum_hidden == None):
                        sum_hidden =attn_output
                    else:
                        sum_hidden+=attn_output                    
                attention_loss_val_temp += attention_loss_function(sum_hidden, proj_student_layer)  
            attention_loss_val = attention_loss_val_temp 

            loss = current_alpha * attention_loss_val + beta * nn.CrossEntropyLoss()(student_logits, labels) + 0.1 * soft_loss

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

        student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(test_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                logits = student(input_ids, attention_mask)
                predicted = torch.argmax(logits, dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total

        print("Validation Accuracy: {:.4f}".format(accuracy))

        total_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")


In [None]:
student_model_attention = BertForSequenceClassification(config)
student_model_attention.to(device)
teacher_model.to(device)
num_epochs = 10
train_student_with_attention_decay(student_model_attention, teacher_model, train_loader, device, temperature=3.0, alpha=0.5, beta = 0.5, num_epochs=num_epochs)

In [None]:
student_model_attention.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = student_model_attention(input_ids, attention_mask)
        predicted = torch.argmax(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total

print("Attention Accuracy: {:.4f}".format(accuracy))

total_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")