In [1]:
# Import Libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertForSequenceClassification, DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import time
from sklearn.metrics import classification_report, accuracy_score
import os
import torch.quantization
import torch.nn.utils.prune as prune

In [2]:
import transformers

In [3]:
PRUNING_RATE = 0.3  # Set the pruning rate as needed

In [4]:
# Constants
tEPOCHS = 8
EPOCHS = 10  # Set distillation epochs to 1 for demonstration purposes


# Load the dataset
data = pd.read_excel('/kaggle/input/498r-umbec/UBMEC.xlsx')
#data = data.sample(frac=1, random_state=42).reset_index(drop=True)
#data = data.tail(1000)

# Split the data
X = data['text'].astype(str).tolist()
y = data['classes']
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Device configuration
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Utility functions
def encode_texts(texts, tokenizer, max_len=512):
    return tokenizer(texts, padding=True, truncation=True, max_length=max_len, return_tensors='pt')

def calculate_accuracy(preds, labels):
    preds_argmax = torch.argmax(preds, dim=1)
    correct = (preds_argmax == labels).float()
    accuracy = correct.sum() / len(correct)
    return accuracy.item()

# Train Teacher Model (BERT)
#teacher_name = 'bert-base-multilingual-cased'
#tokenizer = BertTokenizer.from_pretrained(teacher_name)
#teacher_model = BertForSequenceClassification.from_pretrained(teacher_name, num_labels=6)

# Load model directly
#from transformers import AutoTokenizer, AutoModelForPreTraining

#model = AutoModelForPreTraining.from_pretrained("csebuetnlp/banglabert")

Using device: cuda


In [5]:
# Load Teacher Model (BERT)
teacher_name = 'csebuetnlp/banglabert'
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=6)
teacher_model.to(device)

# Checkpoint path for the teacher model
checkpoint_path = 'best_teacher_model.pth'

""""
# Load best teacher model checkpoint if it exists
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    teacher_model.load_state_dict(checkpoint['model_state_dict'])
    best_val_acc = checkpoint['val_acc']
    print(f"Loaded best teacher model checkpoint with val_acc: {best_val_acc*100:.4f}%")
"""
train_encodings = encode_texts(X_train, tokenizer)
test_encodings = encode_texts(X_test, tokenizer)

y_train = torch.tensor(y_train)
y_test = torch.tensor(y_test)

train_dataset = TensorDataset(train_encodings['input_ids'], train_encodings['attention_mask'], y_train)
test_dataset = TensorDataset(test_encodings['input_ids'], test_encodings['attention_mask'], y_test)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(teacher_model.parameters(), lr=1e-5)

# Training function for Teacher Model
def train(model, iterator, optimizer, criterion):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    
    for input_ids, attention_mask, labels in iterator:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        
        loss = criterion(logits, labels)
        acc = calculate_accuracy(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

# Evaluation function for Teacher Model
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for input_ids, attention_mask, labels in iterator:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            
            loss = criterion(logits, labels)
            acc = calculate_accuracy(logits, labels)

            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            epoch_loss += loss.item()
            epoch_acc += acc
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator), all_preds, all_labels

# Initialize the best validation accuracy variables
best_val_acc = 0.0
best_val_acc_student = 0.0

# Early stopping parameters
patience = 3
patience_counter = 0

tokenizer_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/586 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/528k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at csebuetnlp/banglabert and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Training Teacher Model with Checkpointing and Early Stopping
for epoch in range(tEPOCHS):
    start_time = time.time()
    
    train_loss, train_acc = train(teacher_model, train_loader, optimizer, criterion)
    val_loss, val_acc, all_preds, all_labels = evaluate(teacher_model, test_loader, criterion)
    
    end_time = time.time()
    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
    
    print(f'Epoch: {epoch+1:02}/{tEPOCHS} | Epoch Time: {int(epoch_mins)}m {int(epoch_secs)}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.4f}%')
    print(f'\tVal Loss: {val_loss:.3f} | Val Acc: {val_acc*100:.4f}%')
    
    # Checkpointing for the best teacher model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': teacher_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }, checkpoint_path)
        print(f'Saved new best teacher model with val_acc: {val_acc*100:.4f}%')
        patience_counter = 0  # Reset patience counter
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= patience:
        print("Early stopping triggered")
        break


Epoch: 01/8 | Epoch Time: 10m 11s
	Train Loss: 1.321 | Train Acc: 49.8419%
	Val Loss: 1.112 | Val Acc: 58.4821%
Saved new best teacher model with val_acc: 58.4821%
Epoch: 02/8 | Epoch Time: 10m 11s
	Train Loss: 0.962 | Train Acc: 64.9926%
	Val Loss: 1.066 | Val Acc: 61.1979%
Saved new best teacher model with val_acc: 61.1979%
Epoch: 03/8 | Epoch Time: 10m 11s
	Train Loss: 0.733 | Train Acc: 74.2001%
	Val Loss: 1.110 | Val Acc: 61.7932%
Saved new best teacher model with val_acc: 61.7932%
Epoch: 04/8 | Epoch Time: 10m 11s
	Train Loss: 0.544 | Train Acc: 81.2779%
	Val Loss: 1.207 | Val Acc: 61.7560%
Epoch: 05/8 | Epoch Time: 10m 11s
	Train Loss: 0.379 | Train Acc: 87.5279%
	Val Loss: 1.483 | Val Acc: 59.7470%
Epoch: 06/8 | Epoch Time: 10m 10s
	Train Loss: 0.259 | Train Acc: 91.6760%
	Val Loss: 1.550 | Val Acc: 59.7470%
Early stopping triggered


In [7]:
# Generate classification report for Teacher Model
class_names = ['joy', 'disgust', 'anger', 'sadness', 'surprise', 'fear']
teacher_model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
_, _, all_preds_teacher, all_labels_teacher = evaluate(teacher_model, test_loader, criterion)
report_teacher = classification_report(all_labels_teacher, all_preds_teacher, target_names=class_names, digits=4)
print("Teacher Model Classification Report:\n", report_teacher)
# Generate classification report for Teacher Model

#report = classification_report(all_labels, all_preds, target_names=class_names, digits=4)
#print("Teacher Model Classification Report:\n", report)


Teacher Model Classification Report:
               precision    recall  f1-score   support

         joy     0.6235    0.5397    0.5786       491
     disgust     0.4115    0.4639    0.4362       416
       anger     0.7609    0.5091    0.6100       275
     sadness     0.7952    0.8353    0.8148       674
    surprise     0.5206    0.6765    0.5884       541
        fear     0.6734    0.4605    0.5469       291

    accuracy                         0.6179      2688
   macro avg     0.6309    0.5808    0.5958      2688
weighted avg     0.6325    0.6179    0.6175      2688



In [8]:
# Pruning Code
def prune_attention(module, amount):
    for name in ['query', 'key', 'value', 'dense']:
        layer = getattr(module, name)
        prune.l1_unstructured(layer, name='weight', amount=amount)
        if layer.bias is not None:
            prune.l1_unstructured(layer, name='bias', amount=amount)

def prune_model(model, pruning_rate=0.3):
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=pruning_rate)
            if module.bias is not None:
                prune.l1_unstructured(module, name='bias', amount=pruning_rate)
        elif isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
            prune_attention(module, pruning_rate)

In [9]:
# Apply Pruning and Inspect Masks
for module in teacher_model.modules():
    if isinstance(module, torch.nn.Linear) or isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
        prune.l1_unstructured(module, name='weight', amount=PRUNING_RATE)
        print(f"Pruning applied to {module}. Non-zero elements after pruning: {torch.count_nonzero(module.weight)}")


Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements after pruning: 412877
Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements after pruning: 412877
Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements after pruning: 412877
Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements after pruning: 412877
Pruning applied to Linear(in_features=768, out_features=3072, bias=True). Non-zero elements after pruning: 1651507
Pruning applied to Linear(in_features=3072, out_features=768, bias=True). Non-zero elements after pruning: 1651507
Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements after pruning: 412877
Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements after pruning: 412877
Pruning applied to Linear(in_features=768, out_features=768, bias=True). Non-zero elements a

In [10]:
# Remove pruning re-parametrization for cleaner model saving
def safe_remove_pruning(module, param_name):
    try:
        prune.remove(module, param_name)
    except ValueError as e:
        print(f"Skipping removal of pruning for {param_name} in {module}: {e}")

for module in teacher_model.modules():
    if isinstance(module, torch.nn.Linear) or isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
        safe_remove_pruning(module, 'weight')
        if module.bias is not None:
            safe_remove_pruning(module, 'bias')


Skipping removal of pruning for bias in Linear(in_features=768, out_features=768, bias=True): Parameter 'bias' of module Linear(in_features=768, out_features=768, bias=True) has to be pruned before pruning can be removed
Skipping removal of pruning for bias in Linear(in_features=768, out_features=768, bias=True): Parameter 'bias' of module Linear(in_features=768, out_features=768, bias=True) has to be pruned before pruning can be removed
Skipping removal of pruning for bias in Linear(in_features=768, out_features=768, bias=True): Parameter 'bias' of module Linear(in_features=768, out_features=768, bias=True) has to be pruned before pruning can be removed
Skipping removal of pruning for bias in Linear(in_features=768, out_features=768, bias=True): Parameter 'bias' of module Linear(in_features=768, out_features=768, bias=True) has to be pruned before pruning can be removed
Skipping removal of pruning for bias in Linear(in_features=768, out_features=3072, bias=True): Parameter 'bias' of m

In [11]:
# Evaluate the pruned model
teacher_model.to(device)
_, _, all_preds_pruned, all_labels_pruned = evaluate(teacher_model, test_loader, criterion)

# Generate classification reports
report_pruned = classification_report(all_labels_pruned, all_preds_pruned, target_names=class_names, digits=4)
print("Pruned Model Classification Report:\n", report_pruned)


Pruned Model Classification Report:
               precision    recall  f1-score   support

         joy     0.5915    0.5662    0.5786       491
     disgust     0.4136    0.4832    0.4457       416
       anger     0.6652    0.5636    0.6102       275
     sadness     0.7912    0.8264    0.8084       674
    surprise     0.5449    0.6396    0.5884       541
        fear     0.7125    0.3918    0.5055       291

    accuracy                         0.6142      2688
   macro avg     0.6198    0.5785    0.5895      2688
weighted avg     0.6253    0.6142    0.6130      2688



# Calculating size

In [12]:
# Model size and parameter count functions
def get_model_size_and_params(model):
    param_count = sum(p.numel() for p in model.parameters())
    temp_model_path = "temp_model.pth"
    torch.save(model.state_dict(), temp_model_path)
    model_size = os.path.getsize(temp_model_path)
    os.remove(temp_model_path)
    return param_count, model_size

# Calculate size and param count before pruning
param_count_before_pruning, model_size_before_pruning = get_model_size_and_params(teacher_model)
print(f"Model - Params: {param_count_before_pruning}, Size: {model_size_before_pruning / 1e6:.2f}MB")

# Calculate size and param count after pruning
param_count_after_pruning, model_size_after_pruning = get_model_size_and_params(teacher_model)
print(f"Pruned Model - Params: {param_count_after_pruning}, Size: {model_size_after_pruning / 1e6:.2f}MB")


Model - Params: 110621958, Size: 442.57MB
Pruned Model - Params: 110621958, Size: 442.57MB


In [13]:
# Count non-zero parameters
def count_nonzero_weights(model):
    nonzero_count = 0
    total_count = 0
    for param in model.parameters():
        nonzero_count += torch.count_nonzero(param)
        total_count += param.numel()
    return nonzero_count, total_count

nonzero_count, total_count = count_nonzero_weights(teacher_model)
print(f"Non-zero parameters after pruning: {nonzero_count}/{total_count} ({nonzero_count/total_count:.2%})")

Non-zero parameters after pruning: 84963237/110621958 (76.81%)


The size of the pruned model, based on the 84,963,237 non-zero parameters, is approximately 324.11 MB. ​