In [10]:
# Import Libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertForSequenceClassification, DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import time
from sklearn.metrics import classification_report, accuracy_score
import os
import torch.quantization
import torch.nn.utils.prune as prune


In [11]:
# Constants
tEPOCHS = 8
EPOCHS = 10  # Set distillation epochs to 1 for demonstration purposes


# Load the dataset
data = pd.read_excel('/kaggle/input/ubmec-498r/UBMEC.xlsx')
#data = data.sample(frac=1, random_state=42).reset_index(drop=True)
#data = data.tail(1000)

# Split the data
X = data['text'].astype(str).tolist()
y = data['classes']
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Device configuration
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Utility functions
def encode_texts(texts, tokenizer, max_len=512):
    return tokenizer(texts, padding=True, truncation=True, max_length=max_len, return_tensors='pt')

def calculate_accuracy(preds, labels):
    preds_argmax = torch.argmax(preds, dim=1)
    correct = (preds_argmax == labels).float()
    accuracy = correct.sum() / len(correct)
    return accuracy.item()

# Train Teacher Model (BERT)
#teacher_name = 'bert-base-multilingual-cased'
#tokenizer = BertTokenizer.from_pretrained(teacher_name)
#teacher_model = BertForSequenceClassification.from_pretrained(teacher_name, num_labels=6)

# Load model directly
#from transformers import AutoTokenizer, AutoModelForPreTraining

#model = AutoModelForPreTraining.from_pretrained("csebuetnlp/banglabert")

Using device: cuda


In [12]:
# Load Teacher Model (BERT)
teacher_name = 'bert-base-multilingual-uncased'
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=6)
teacher_model.to(device)

# Checkpoint path for the teacher model
checkpoint_path = 'best_teacher_model.pth'

""""
# Load best teacher model checkpoint if it exists
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    teacher_model.load_state_dict(checkpoint['model_state_dict'])
    best_val_acc = checkpoint['val_acc']
    print(f"Loaded best teacher model checkpoint with val_acc: {best_val_acc*100:.4f}%")
"""
train_encodings = encode_texts(X_train, tokenizer)
test_encodings = encode_texts(X_test, tokenizer)

y_train = torch.tensor(y_train)
y_test = torch.tensor(y_test)

train_dataset = TensorDataset(train_encodings['input_ids'], train_encodings['attention_mask'], y_train)
test_dataset = TensorDataset(test_encodings['input_ids'], test_encodings['attention_mask'], y_test)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(teacher_model.parameters(), lr=1e-5)

# Training function for Teacher Model
def train(model, iterator, optimizer, criterion):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    
    for input_ids, attention_mask, labels in iterator:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        
        loss = criterion(logits, labels)
        acc = calculate_accuracy(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

# Evaluation function for Teacher Model
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for input_ids, attention_mask, labels in iterator:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            
            loss = criterion(logits, labels)
            acc = calculate_accuracy(logits, labels)

            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            epoch_loss += loss.item()
            epoch_acc += acc
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator), all_preds, all_labels

# Initialize the best validation accuracy variables
best_val_acc = 0.0
best_val_acc_student = 0.0

# Early stopping parameters
patience = 3
patience_counter = 0

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Training Teacher Model with Checkpointing and Early Stopping
for epoch in range(tEPOCHS):
    start_time = time.time()
    
    train_loss, train_acc = train(teacher_model, train_loader, optimizer, criterion)
    val_loss, val_acc, all_preds, all_labels = evaluate(teacher_model, test_loader, criterion)
    
    end_time = time.time()
    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
    
    print(f'Epoch: {epoch+1:02}/{tEPOCHS} | Epoch Time: {int(epoch_mins)}m {int(epoch_secs)}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.4f}%')
    print(f'\tVal Loss: {val_loss:.3f} | Val Acc: {val_acc*100:.4f}%')
    
    # Checkpointing for the best teacher model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': teacher_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }, checkpoint_path)
        print(f'Saved new best teacher model with val_acc: {val_acc*100:.4f}%')
        patience_counter = 0  # Reset patience counter
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= patience:
        print("Early stopping triggered")
        break


Epoch: 01/8 | Epoch Time: 10m 39s
	Train Loss: 1.478 | Train Acc: 40.8575%
	Val Loss: 1.291 | Val Acc: 50.3348%
Saved new best teacher model with val_acc: 50.3348%
Epoch: 02/8 | Epoch Time: 10m 38s
	Train Loss: 1.181 | Train Acc: 55.2362%
	Val Loss: 1.203 | Val Acc: 54.1667%
Saved new best teacher model with val_acc: 54.1667%
Epoch: 03/8 | Epoch Time: 10m 42s
	Train Loss: 0.986 | Train Acc: 63.4487%
	Val Loss: 1.204 | Val Acc: 55.4688%
Saved new best teacher model with val_acc: 55.4688%
Epoch: 04/8 | Epoch Time: 10m 39s
	Train Loss: 0.816 | Train Acc: 70.1544%
	Val Loss: 1.269 | Val Acc: 54.6503%
Epoch: 05/8 | Epoch Time: 10m 39s
	Train Loss: 0.653 | Train Acc: 77.0554%
	Val Loss: 1.383 | Val Acc: 54.9851%
Epoch: 06/8 | Epoch Time: 10m 39s
	Train Loss: 0.529 | Train Acc: 80.7850%
	Val Loss: 1.492 | Val Acc: 54.7619%
Early stopping triggered


In [14]:
# Generate classification report for Teacher Model
class_names = ['joy', 'disgust', 'anger', 'sadness', 'surprise', 'fear']
teacher_model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
_, _, all_preds_teacher, all_labels_teacher = evaluate(teacher_model, test_loader, criterion)
report_teacher = classification_report(all_labels_teacher, all_preds_teacher, target_names=class_names, digits=4)
print("Teacher Model Classification Report:\n", report_teacher)
# Generate classification report for Teacher Model

#report = classification_report(all_labels, all_preds, target_names=class_names, digits=4)
#print("Teacher Model Classification Report:\n", report)


Teacher Model Classification Report:
               precision    recall  f1-score   support

         joy     0.5422    0.5234    0.5326       491
     disgust     0.4856    0.2428    0.3237       416
       anger     0.6064    0.5491    0.5763       275
     sadness     0.7384    0.7537    0.7460       674
    surprise     0.4178    0.6580    0.5111       541
        fear     0.5438    0.4055    0.4646       291

    accuracy                         0.5547      2688
   macro avg     0.5557    0.5221    0.5257      2688
weighted avg     0.5643    0.5547    0.5466      2688



In [33]:
# Load the best teacher model for distillation
#teacher_model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

# Define a custom distillation model
class DistillationModel(nn.Module):
    def __init__(self, student_model, teacher_model, temperature=2, alpha=0.5):
        super(DistillationModel, self).__init__()
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, input_ids, attention_mask):
        # Compute student logits
        student_outputs = self.student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits
        
        # Compute teacher logits (no gradients needed)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits
        
        return student_logits, teacher_logits

# Training and evaluation functions for Student Model
def train_student(model, iterator, optimizer, criterion, temperature=2, alpha=0.5):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    
    for input_ids, attention_mask, labels in iterator:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        student_logits, teacher_logits = model(input_ids, attention_mask=attention_mask)
        
        student_loss = criterion(student_logits, labels)
        distillation_loss = nn.KLDivLoss()(nn.functional.log_softmax(student_logits / temperature, dim=1),
                                           nn.functional.softmax(teacher_logits / temperature, dim=1)) * (temperature * temperature)
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += calculate_accuracy(student_logits, labels)

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate_student(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    all_preds_student = []
    all_labels_student = []
    
    with torch.no_grad():
        for input_ids, attention_mask, labels in iterator:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            
            student_logits, _ = model(input_ids, attention_mask=attention_mask)
            
            loss = criterion(student_logits, labels)
            acc = calculate_accuracy(student_logits, labels)

            all_preds_student.extend(student_logits.argmax(dim=1).cpu().numpy())
            all_labels_student.extend(labels.cpu().numpy())
            
            epoch_loss += loss.item()
            epoch_acc += acc
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator), all_preds_student, all_labels_student

# Load Student Model (DistilBERT with Sequence Classification Head)
student_name = 'h4g3n/distilbert-mini-multilingual-cased'
student_tokenizer = AutoTokenizer.from_pretrained(student_name)
student_model = AutoModelForSequenceClassification.from_pretrained(student_name, num_labels=6)

student_model.to(device)

distillation_model = DistillationModel(student_model, teacher_model)
distillation_model.to(device)

optimizer_student = optim.AdamW(distillation_model.student_model.parameters(), lr=5e-5)

# Checkpoint path for the student model
checkpoint_path_student = 'best_student_model.pth'


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at h4g3n/distilbert-mini-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [35]:
# Training Student Model with Checkpointing and Early Stopping
best_val_acc_student = 0
for epoch in range(EPOCHS):
    start_time = time.time()
    
    train_loss, train_acc = train_student(distillation_model, train_loader, optimizer_student, criterion, temperature=2, alpha=0.5)
    val_loss, val_acc, all_preds_student, all_labels_student = evaluate_student(distillation_model, test_loader, criterion)
    
    end_time = time.time()
    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
    
    print(f'Epoch: {epoch+1:02}/{EPOCHS} | Epoch Time: {int(epoch_mins)}m {int(epoch_secs)}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.4f}%')
    print(f'\tVal Loss: {val_loss:.3f} | Val Acc: {val_acc*100:.4f}%')
    
    # Checkpointing for the best student model
    
    if val_acc > best_val_acc_student:
        best_val_acc_student = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': distillation_model.student_model.state_dict(),
            'optimizer_state_dict': optimizer_student.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }, checkpoint_path_student)
        print(f'Saved new best student model with val_acc: {val_acc*100:.4f}%')
        patience_counter = 0  # Reset patience counter
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= patience:
        print("Early stopping triggered")
        break


Epoch: 01/10 | Epoch Time: 4m 34s
	Train Loss: 0.792 | Train Acc: 40.8854%
	Val Loss: 1.475 | Val Acc: 39.5833%
Saved new best student model with val_acc: 39.5833%
Epoch: 02/10 | Epoch Time: 4m 33s
	Train Loss: 0.752 | Train Acc: 43.9174%
	Val Loss: 1.462 | Val Acc: 41.8155%
Saved new best student model with val_acc: 41.8155%
Epoch: 03/10 | Epoch Time: 4m 33s
	Train Loss: 0.720 | Train Acc: 46.7820%
	Val Loss: 1.461 | Val Acc: 42.5595%
Saved new best student model with val_acc: 42.5595%
Epoch: 04/10 | Epoch Time: 4m 33s
	Train Loss: 0.688 | Train Acc: 49.9535%
	Val Loss: 1.463 | Val Acc: 42.2619%
Saved new best student model with val_acc: 42.2619%
Epoch: 05/10 | Epoch Time: 4m 33s
	Train Loss: 0.652 | Train Acc: 53.0878%
	Val Loss: 1.444 | Val Acc: 44.6801%
Saved new best student model with val_acc: 44.6801%
Epoch: 06/10 | Epoch Time: 4m 33s
	Train Loss: 0.615 | Train Acc: 56.2593%
	Val Loss: 1.463 | Val Acc: 45.3125%
Saved new best student model with val_acc: 45.3125%
Epoch: 07/10 | E

KeyboardInterrupt: 

In [36]:
# Load the best teacher and student model checkpoints for evaluation
teacher_model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
distillation_model.student_model.load_state_dict(torch.load(checkpoint_path_student)['model_state_dict'])

# Generate evaluation reports for the best models
_, _, all_preds_teacher, all_labels_teacher = evaluate(teacher_model, test_loader, criterion)
_, _, all_preds_student, all_labels_student = evaluate_student(distillation_model, test_loader, criterion)

# Generate classification reports for both models
report_teacher = classification_report(all_labels_teacher, all_preds_teacher, target_names=class_names, digits=4)
report_student = classification_report(all_labels_student, all_preds_student, target_names=class_names, digits=4)

print("Teacher Model Classification Report:\n", report_teacher)
print("Distilled Student Model Classification Report:\n", report_student)

Teacher Model Classification Report:
               precision    recall  f1-score   support

         joy     0.5422    0.5234    0.5326       491
     disgust     0.4856    0.2428    0.3237       416
       anger     0.6064    0.5491    0.5763       275
     sadness     0.7384    0.7537    0.7460       674
    surprise     0.4178    0.6580    0.5111       541
        fear     0.5438    0.4055    0.4646       291

    accuracy                         0.5547      2688
   macro avg     0.5557    0.5221    0.5257      2688
weighted avg     0.5643    0.5547    0.5466      2688

Distilled Student Model Classification Report:
               precision    recall  f1-score   support

         joy     0.4080    0.5010    0.4497       491
     disgust     0.3152    0.2500    0.2788       416
       anger     0.4753    0.3855    0.4257       275
     sadness     0.5554    0.7062    0.6218       674
    surprise     0.4481    0.3512    0.3938       541
        fear     0.3506    0.3024    0.3247   

# Calculating size

In [37]:
import os

def get_model_size_and_params(model):
    param_count = sum(p.numel() for p in model.parameters())
    temp_model_path = "temp_model.pth"
    torch.save(model.state_dict(), temp_model_path)
    model_size = os.path.getsize(temp_model_path)
    os.remove(temp_model_path)
    return param_count, model_size

# Calculate size and param count for the distilled student model
#param_count_student, model_size_student = get_model_size_and_params(distillation_model.student_model)
#print(f"Distilled Student Model - Params: {param_count_student}, Size: {model_size_student / 1e6:.2f}MB")

In [38]:
# Save the distilled student model
#student_model_save_path = "distilled_student_model.pth"
#torch.save(distillation_model.student_model.state_dict(), student_model_save_path)

# Verify the saved model size
#model_size_student = os.path.getsize(student_model_save_path)
#print(f"Distilled Student Model Size: {model_size_student / 1e6:.2f} MB")


In [39]:
# Calculate size and param count before distillation for the student model
param_count_before_distillation, model_size_before_distillation = get_model_size_and_params(teacher_model)
print(f"Original Teacher Model - Params: {param_count_before_distillation}, Size: {model_size_before_distillation / 1e6:.2f}MB")

# Calculate size and param count after distillation for the student model
param_count_after_distillation, model_size_after_distillation = get_model_size_and_params(distillation_model.student_model)
print(f"Distilled Student Model - Params: {param_count_after_distillation}, Size: {model_size_after_distillation / 1e6:.2f}MB")


Original Teacher Model - Params: 167361030, Size: 669.53MB
Distilled Student Model - Params: 52164870, Size: 208.68MB


In [24]:
# Pruning Code
import torch.nn.utils.prune as prune
import transformers

def prune_attention(module, amount):
    for name in ['query', 'key', 'value', 'dense']:
        layer = getattr(module, name)
        prune.l1_unstructured(layer, name='weight', amount=amount)
        if layer.bias is not None:
            prune.l1_unstructured(layer, name='bias', amount=amount)

def prune_model(model, pruning_rate=0.3):
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=pruning_rate)
            if module.bias is not None:
                prune.l1_unstructured(module, name='bias', amount=pruning_rate)
        elif isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
            prune_attention(module, pruning_rate)

In [25]:
# Apply Pruning to the Distilled Student Model
pruning_rate = 0.3  # Adjust pruning rate as needed
prune_model(distillation_model.student_model, pruning_rate)


In [26]:
# Remove pruning re-parametrization for cleaner model saving
def safe_remove_pruning(module, param_name):
    try:
        prune.remove(module, param_name)
    except ValueError as e:
        print(f"Skipping removal of pruning for {param_name} in {module}: {e}")

for module in distillation_model.student_model.modules():
    if isinstance(module, torch.nn.Linear) or isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
        safe_remove_pruning(module, 'weight')
        if module.bias is not None:
            safe_remove_pruning(module, 'bias')

# Delete zeroed parameters to reduce model size
def delete_zero_parameters(model):
    for name, param in list(model.named_parameters()):
        mask = param != 0
        if not mask.any():
            delattr(model, name)
            
def count_nonzero_weights(model):
    nonzero_count = 0
    total_count = 0
    for param in model.parameters():
        nonzero_count += param.nonzero().size(0)
        total_count += param.numel()
    return nonzero_count, total_count

delete_zero_parameters(distillation_model.student_model)




In [53]:
# Evaluate the Pruned Student Model
_, _, all_preds_pruned_student, all_labels_pruned_student = evaluate_student(distillation_model, test_loader, criterion)

# Generate classification report for the Pruned Student Model
report_pruned_student = classification_report(all_labels_pruned_student, all_preds_pruned_student, target_names=class_names, digits=4)
print("Pruned Distilled Student Model Classification Report:\n", report_pruned_student)

# Calculate size and param count after pruning for the student model
param_count_after_pruning_student, model_size_after_pruning_student = get_model_size_and_params(distillation_model.student_model)
print(f"Pruned Distilled Student Model - Params: {param_count_after_pruning_student}, Size: {model_size_after_pruning_student / 1e6:.2f}MB")

# Count non-zero parameters
nonzero_count_pruned_student, total_count_pruned_student = count_nonzero_weights(distillation_model.student_model)
print(f"Non-zero parameters after pruning in Distilled Student Model: {nonzero_count_pruned_student}/{total_count_pruned_student} ({nonzero_count_pruned_student/total_count_pruned_student:.2%})")

Pruned Distilled Student Model Classification Report:
               precision    recall  f1-score   support

         joy     0.4019    0.5051    0.4477       491
     disgust     0.2956    0.2260    0.2561       416
       anger     0.4730    0.3818    0.4225       275
     sadness     0.5713    0.6958    0.6274       674
    surprise     0.4328    0.3752    0.4020       541
        fear     0.3610    0.2990    0.3271       291

    accuracy                         0.4487      2688
   macro avg     0.4226    0.4138    0.4138      2688
weighted avg     0.4370    0.4487    0.4383      2688

Pruned Distilled Student Model - Params: 52164870, Size: 208.68MB
Non-zero parameters after pruning in Distilled Student Model: 50323476/52164870 (96.47%)


In [51]:
import torch
import torch.nn as nn

# Apply dynamic quantization to the pruned model
quantized_model = torch.quantization.quantize_dynamic(
    distillation_model.student_model,  # the pruned model
    {nn.Linear},  # layers to quantize
    dtype=torch.qint8  # quantization data type
)

# Function to calculate model size
def calculate_model_size(model):
    torch.save(model.state_dict(), "temp.pth")
    model_size = os.path.getsize("temp.pth")
    os.remove("temp.pth")
    return model_size

# Calculate size and number of parameters after quantization
param_count_after_quantization = sum(p.numel() for p in quantized_model.parameters())
model_size_after_quantization = calculate_model_size(quantized_model)

print(f"Quantized Pruned Model - Params: {param_count_after_quantization}, Size: {model_size_after_quantization / 1e6:.2f}MB")

# Count non-zero parameters after quantization
nonzero_count_quantized, total_count_quantized = count_nonzero_weights(quantized_model)
print(f"Non-zero parameters after quantization in Pruned Model: {nonzero_count_quantized}/{total_count_quantized} ({nonzero_count_quantized/total_count_quantized:.2%})")


Quantized Pruned Model - Params: 46106496, Size: 190.54MB
Non-zero parameters after quantization in Pruned Model: 46106111/46106496 (100.00%)
