# Knowledge Distillation for COVID Tweet Sentiment Analysis

This notebook implements knowledge distillation to compress a BERTweet sentiment analysis model into a smaller DistilRoBERTa student model.

## Distillation Setup
- **Teacher Model**: `finiteautomata/bertweet-base-sentiment-analysis` (frozen)
- **Student Model**: `distilroberta-base` with classification head
- **Loss Function**: Combined hard target (CrossEntropy) + soft target (KL Divergence) loss
- **Temperature**: T = 2.0 for soft target computation
- **Alpha**: α = 0.5, the weight on the hard-target loss (the soft-target loss gets 1 − α)
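
Written out, the combined objective described by the bullets above takes roughly the following form. The symbols are notation introduced here for illustration, not names from the notebook: $z_s$ and $z_t$ are the student and teacher logits, and $y$ is the ground-truth label.

```latex
\mathcal{L}
  = \alpha \,\mathrm{CE}\bigl(y,\ \mathrm{softmax}(z_s)\bigr)
  + (1-\alpha)\, T^{2}\,
    \mathrm{KL}\bigl(\mathrm{softmax}(z_t / T)\ \big\|\ \mathrm{softmax}(z_s / T)\bigr)
```

The $T^{2}$ factor offsets the $1/T^{2}$ shrinkage of the soft-target gradients, so the two terms stay on a comparable scale when the temperature changes.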


## Environment Setup


In [70]:
import os

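# Configure PyTorch and tokenizers. These variables do not force CPU execution:
# - PYTORCH_ENABLE_MPS_FALLBACK=1 lets ops that MPS does not support fall back to the CPU
# - PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 disables the MPS memory allocation limit
# - TOKENIZERS_PARALLELISM=false avoids tokenizer fork warnings
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

print("Environment configured: MPS fallback enabled, tokenizer parallelism disabled")
print("Please restart the kernel and run all cells from the beginning")


Environment configured: MPS fallback enabled, tokenizer parallelism disabled
Please restart the kernel and run all cells from the beginning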


## Installations


In [71]:
%pip install -q evaluate
%pip install -q emoji==0.6.0
%pip install -q torch
%pip install -q transformers
%pip install -q accelerate
%pip install -q wandb


Note: you may need to restart the kernel to use updated packages.


## Imports


In [72]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch import nn
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, cohen_kappa_score
from transformers import (
    TrainingArguments,
    Trainer,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainerCallback,
    EarlyStoppingCallback
)
import evaluate
import wandb
import time
from datetime import datetime


## Device Setup


In [73]:
# Smart device selection for cross-platform GPU support
def get_optimal_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using CUDA GPU: {torch.cuda.get_device_name()}")
        return device
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using Apple Silicon MPS")
        return device
    else:
        device = torch.device("cpu")
        print("Using CPU")
        return device

device = get_optimal_device()

Using Apple Silicon MPS


In [74]:
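# W&B login: read the API key from the WANDB_API_KEY environment variable
# instead of hard-coding a secret in the notebook
print('Logging to wandb.ai account')
wandb.login(key=os.environ.get("WANDB_API_KEY"))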


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/taltzafrir/.netrc


Logging to wandb.ai account


True

## Data Loading and Preprocessing


In [75]:
# Read raw data
train = pd.read_csv("OOT_train.csv", encoding='latin-1')
val = pd.read_csv("OOT_val.csv", encoding='latin-1')
test = pd.read_csv("OOT_test.csv", encoding='latin-1')


### Label Encoding


In [76]:
# Encoding the labels numerically from Sentiment
ordinal_mapping = {
    'Extremely Negative': 0,
    'Negative': 1,
    'Neutral': 2,
    'Positive': 3,
    'Extremely Positive': 4
}

# Map to ordinal labels
train["ordinal_label_id"] = train["Sentiment"].map(ordinal_mapping)
val["ordinal_label_id"] = val["Sentiment"].map(ordinal_mapping)
test["ordinal_label_id"] = test["Sentiment"].map(ordinal_mapping)

# Define mapping between label id and sentiment for later use
ordinal_label2id = ordinal_mapping
ordinal_id2label = {v: k for k, v in ordinal_mapping.items()}
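
`Series.map` with a dict yields `NaN` for any value that has no key in the dict, so a casing or spelling mismatch in `Sentiment` would silently become a missing label. A minimal sketch of a pre-flight check follows; `find_unmapped` is a hypothetical helper, not part of the notebook, and the sample labels are made up.

```python
# Same mapping as in the notebook cell above
ordinal_mapping = {
    'Extremely Negative': 0,
    'Negative': 1,
    'Neutral': 2,
    'Positive': 3,
    'Extremely Positive': 4,
}

def find_unmapped(labels, mapping):
    """Return the sorted label strings that the mapping does not cover."""
    return sorted(set(labels) - set(mapping))

# Made-up sample with one wrongly cased label
sample_labels = ["Positive", "Neutral", "extremely negative", "Negative"]
unmapped = find_unmapped(sample_labels, ordinal_mapping)
print(unmapped)  # ['extremely negative']
```

In the notebook, the same idea would apply to `train["Sentiment"].unique()` before calling `.map`.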


### Augmented Input Building


In [77]:
# Function to build the input string from multiple columns
def build_augmented_input(row):
    parts = []

    if pd.notna(row.get('clean_tweet')):
        parts.append(f"{row['clean_tweet']}")

    if pd.notna(row.get('Location_standardized')) and row['Location_standardized'].lower() != 'unknown':
        parts.append(f"{row['Location_standardized']}")

    if pd.notna(row.get('TweetAt')):
        parts.append(f"{row['TweetAt']}")

    return " | ".join(parts)

# Apply to the DataFrames
train['model_input'] = train.apply(build_augmented_input, axis=1)
val['model_input'] = val.apply(build_augmented_input, axis=1)
test['model_input'] = test.apply(build_augmented_input, axis=1)

# Create new DataFrames with only what's needed for modeling
formatted_train = train[['model_input', 'ordinal_label_id']].copy()
formatted_val = val[['model_input', 'ordinal_label_id']].copy()
formatted_test = test[['model_input', 'ordinal_label_id']].copy()


### Dataset Balancing


In [78]:
def balance_dataset(df, target_samples_per_class=5000):
    """Balance dataset by undersampling"""
    balanced_dfs = []

    print("Original class distribution:")
    print(df['ordinal_label_id'].value_counts().sort_index())

    for class_id in range(5):
        class_data = df[df['ordinal_label_id'] == class_id]

        if len(class_data) > target_samples_per_class:
            class_data = class_data.sample(n=target_samples_per_class, random_state=42)
            print(f"Class {class_id}: {len(class_data)} samples (undersampled)")
        else:
            print(f"Class {class_id}: {len(class_data)} samples (kept all)")

        balanced_dfs.append(class_data)

    balanced_df = pd.concat(balanced_dfs, ignore_index=True).sample(frac=1, random_state=42)

    print(f"Balanced dataset: {len(balanced_df)} total samples")
    print("New distribution:")
    print(balanced_df['ordinal_label_id'].value_counts().sort_index())

    return balanced_df

# Apply balancing to training data
formatted_train = balance_dataset(formatted_train, target_samples_per_class=5000)


Original class distribution:
ordinal_label_id
0     5175
1     9230
2     6784
3    10140
4     5845
Name: count, dtype: int64
Class 0: 5000 samples (undersampled)
Class 1: 5000 samples (undersampled)
Class 2: 5000 samples (undersampled)
Class 3: 5000 samples (undersampled)
Class 4: 5000 samples (undersampled)
Balanced dataset: 25000 total samples
New distribution:
ordinal_label_id
0    5000
1    5000
2    5000
3    5000
4    5000
Name: count, dtype: int64


## Model Setup

- **Teacher Model (BERTweet)**: frozen
- **Student Model (DistilRoBERTa)**: trainable


In [79]:
# Teacher model configuration
MODEL_TYPE = 'bert'
teacher_model_name = "finiteautomata/bertweet-base-sentiment-analysis"
teacher_model_path = "./Full model/bert/best_bert_model_so_far"
teacher_model_file = "model_bert.pt"

# Student model configuration
student_model_name = "distilroberta-base"

print(f"Teacher Model: {teacher_model_name}")
print(f"Student Model: {student_model_name}")

Teacher Model: finiteautomata/bertweet-base-sentiment-analysis
Student Model: distilroberta-base


### Load Teacher Model


In [80]:
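# Load the best teacher model from hyperparameter tuning.
# The checkpoint is a fully pickled model object, so weights_only=False is needed on
# recent PyTorch releases, where the default became True. Only load files you trust.
print("Loading teacher model from best checkpoint...")
teacher_model = torch.load(
    os.path.join(teacher_model_path, teacher_model_file),
    map_location=device,
    weights_only=False,
)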
teacher_model.eval()  # Set to evaluation mode
teacher_model.to(device)

# Freeze teacher model parameters
for param in teacher_model.parameters():
    param.requires_grad = False

print("Teacher model loaded and frozen!")

# Load teacher tokenizer
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)


Loading teacher model from best checkpoint...
Teacher model loaded and frozen!


### Initialize Student Model


In [81]:
# Initialize student model with classification head
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_model_name,
    num_labels=5,
    id2label=ordinal_id2label,
    label2id=ordinal_label2id
)
student_model.to(device)

# Load student tokenizer
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

print(f"Student model initialized with {sum(p.numel() for p in student_model.parameters() if p.requires_grad):,} trainable parameters")
print(f"Teacher model has {sum(p.numel() for p in teacher_model.parameters()):,} total parameters (frozen)")


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Student model initialized with 82,122,245 trainable parameters
Teacher model has 134,903,813 total parameters (frozen)


## Custom Dataset for Knowledge Distillation


In [82]:
class DistillationDataset(Dataset):
    """Custom dataset for knowledge distillation that stores raw text"""
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __getitem__(self, idx):
        # Return raw text and label - tokenization will happen in the collate function
        return {
            'text': self.texts[idx],
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

    def __len__(self):
        return len(self.labels)


def create_distillation_collate_fn(student_tokenizer, teacher_tokenizer, device, max_length=128):
    """Create a collate function that tokenizes for both student and teacher"""
    def collate_fn(batch):
        texts = [item['text'] for item in batch]
        labels = torch.stack([item['labels'] for item in batch])
        
        # Tokenize for student
        student_encodings = student_tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt'
        )
        
        # Tokenize for teacher
        teacher_encodings = teacher_tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt'
        )
        
        # Move all tensors to the correct device
        return {
            'student_input_ids': student_encodings['input_ids'].to(device),
            'student_attention_mask': student_encodings['attention_mask'].to(device),
            'teacher_input_ids': teacher_encodings['input_ids'].to(device),
            'teacher_attention_mask': teacher_encodings['attention_mask'].to(device),
            'labels': labels.to(device)
        }
    
    return collate_fn

### Create Datasets


In [83]:
# Convert texts and labels to lists
train_texts = formatted_train['model_input'].tolist()
val_texts = formatted_val['model_input'].tolist()
test_texts = formatted_test['model_input'].tolist()

train_labels = formatted_train['ordinal_label_id'].tolist()
val_labels = formatted_val['ordinal_label_id'].tolist()
test_labels = formatted_test['ordinal_label_id'].tolist()

# Create distillation datasets
train_dataset = DistillationDataset(train_texts, train_labels)
val_dataset = DistillationDataset(val_texts, val_labels)
test_dataset = DistillationDataset(test_texts, test_labels)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")


Train dataset size: 25000
Validation dataset size: 4357
Test dataset size: 3424


## Metrics Computation


In [84]:
def compute_detailed_metrics(eval_pred):
    """Compute overall, per-class, and ordinal metrics from (logits, labels)"""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    # Load HuggingFace metrics (cached after first load)
    accuracy_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")

    results = {}

    # Overall metrics. Every f1_metric.compute() call returns {'f1': ...}, so the macro
    # score is stored under its own key to avoid being overwritten by the weighted one;
    # 'f1' holds the weighted F1, while 'precision' and 'recall' are macro-averaged
    results.update(accuracy_metric.compute(predictions=predictions, references=labels))
    results['f1_macro'] = f1_metric.compute(predictions=predictions, references=labels, average='macro')['f1']
    results.update(f1_metric.compute(predictions=predictions, references=labels, average='weighted'))
    results.update(precision_metric.compute(predictions=predictions, references=labels, average='macro'))
    results.update(recall_metric.compute(predictions=predictions, references=labels, average='macro'))

    # Per-class scores, computed once for all five classes
    class_names = ['extremely_negative', 'negative', 'neutral', 'positive', 'extremely_positive']
    class_ids = list(range(len(class_names)))
    f1_per_class = f1_score(labels, predictions, labels=class_ids, average=None, zero_division=0)
    precision_per_class = precision_score(labels, predictions, labels=class_ids, average=None, zero_division=0)
    recall_per_class = recall_score(labels, predictions, labels=class_ids, average=None, zero_division=0)

    for i, class_name in enumerate(class_names):
        results[f'f1_{class_name}'] = f1_per_class[i]
        results[f'precision_{class_name}'] = precision_per_class[i]
        results[f'recall_{class_name}'] = recall_per_class[i]

        # Per-class accuracy over the examples of class i (numerically the same as its recall)
        class_mask = (labels == i)
        if class_mask.sum() > 0:
            results[f'accuracy_{class_name}'] = accuracy_score(labels[class_mask], predictions[class_mask])
        else:
            results[f'accuracy_{class_name}'] = 0.0

    # Custom ordinal metrics
    results['mae'] = np.mean(np.abs(predictions - labels))
    results['adjacent_accuracy'] = np.sum(np.abs(predictions - labels) <= 1) / len(labels)

    # Quadratic Weighted Kappa
    try:
        qwk = cohen_kappa_score(labels, predictions, weights='quadratic')
        results['quadratic_weighted_kappa'] = qwk
    except Exception:
        results['quadratic_weighted_kappa'] = 0.0

    return results
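
The ordinal metrics above (MAE, adjacent accuracy, quadratic weighted kappa) reward predictions that land close to the true class. The sketch below recomputes them in plain Python on two made-up prediction sets with identical accuracy, to show how they separate near misses from far misses. `ordinal_metrics` is a hypothetical helper written for illustration; the notebook itself relies on scikit-learn's `cohen_kappa_score`.

```python
def ordinal_metrics(y_true, y_pred, num_classes=5):
    """Return (MAE, adjacent accuracy, quadratic weighted kappa) for integer labels."""
    n = len(y_true)
    diffs = [abs(t - p) for t, p in zip(y_true, y_pred)]
    mae = sum(diffs) / n
    adjacent_accuracy = sum(d <= 1 for d in diffs) / n

    # Observed confusion counts: rows are true classes, columns are predictions
    observed = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    true_hist = [sum(row) for row in observed]
    pred_hist = [sum(observed[i][j] for i in range(num_classes)) for j in range(num_classes)]

    # QWK = 1 - (weighted observed disagreement) / (weighted chance disagreement)
    # Assumes the chance term is non-zero, i.e. labels are not all one class
    weighted_observed = 0.0
    weighted_expected = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            weight = (i - j) ** 2
            weighted_observed += weight * observed[i][j]
            weighted_expected += weight * true_hist[i] * pred_hist[j] / n
    qwk = 1.0 - weighted_observed / weighted_expected
    return mae, adjacent_accuracy, qwk

y_true = [0, 1, 2, 3, 4]
near_misses = [1, 1, 2, 3, 3]  # two errors, each one class away
far_misses = [4, 1, 2, 3, 0]   # two errors, each four classes away

near = ordinal_metrics(y_true, near_misses)
far = ordinal_metrics(y_true, far_misses)
print([round(v, 3) for v in near])  # [0.4, 1.0, 0.857]
print([round(v, 3) for v in far])   # [1.6, 0.6, -0.6]
```

Both prediction sets get 3 of 5 labels right, yet the far misses push QWK below zero, which is why the ordinal scores are tracked alongside accuracy and F1.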


## Knowledge Distillation Components


### Distillation Hyperparameters


In [85]:
# Knowledge distillation hyperparameters
TEMPERATURE = 2.0  # Temperature for soft targets
ALPHA = 0.5  # Weight for hard target loss (1-alpha for soft target loss)

print(f"Distillation Temperature: {TEMPERATURE}")
print(f"Alpha (hard loss weight): {ALPHA}")
print(f"Soft loss weight: {1-ALPHA}")


Distillation Temperature: 2.0
Alpha (hard loss weight): 0.5
Soft loss weight: 0.5
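
To make the role of the temperature concrete, the sketch below applies a temperature-scaled softmax to one set of made-up teacher logits. `softmax_with_temperature` and the logit values are illustrative assumptions; the notebook computes the same quantity with `F.softmax(teacher_logits / T, dim=-1)`.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by the temperature (numerically stabilised)."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for the five sentiment classes
teacher_logits = [0.5, 1.0, 4.0, 2.0, 0.2]

probs_t1 = softmax_with_temperature(teacher_logits, temperature=1.0)
probs_t2 = softmax_with_temperature(teacher_logits, temperature=2.0)

for name, probs in (("T=1", probs_t1), ("T=2", probs_t2)):
    print(name, [round(p, 3) for p in probs])
```

At T = 2 the top class keeps its rank but gives up probability mass to the other classes, so the student sees more of how the teacher ranks the non-winning labels.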


### Custom Distillation Trainer


In [86]:
class DistillationTrainer(Trainer):
    """Custom trainer for knowledge distillation"""
    
    def __init__(self, teacher_model, temperature, alpha, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute combined knowledge distillation loss
        """
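        # Batches from the distillation collate function carry 'student_*' and 'teacher_*' keys.
        # prediction_step below passes plain 'input_ids' instead, so evaluation takes the else branch.
        if 'student_input_ids' in inputs and inputs['student_input_ids'] is not None:
            # Distillation mode - use both teacher and student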
            # Get student predictions
            student_outputs = model(
                input_ids=inputs.get('student_input_ids'),
                attention_mask=inputs.get('student_attention_mask')
            )
            student_logits = student_outputs.logits
            
            # Get teacher predictions (no gradient)
            with torch.no_grad():
                teacher_outputs = self.teacher_model(
                    input_ids=inputs.get('teacher_input_ids'),
                    attention_mask=inputs.get('teacher_attention_mask')
                )
                teacher_logits = teacher_outputs.logits
            
            # Get labels
            labels = inputs.get('labels')
            
            # Hard target loss (CrossEntropy)
            loss_fct = nn.CrossEntropyLoss()
            hard_loss = loss_fct(student_logits, labels)
            
            # Soft target loss (KL Divergence)
            student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
            
            # KL Divergence loss
            soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
            soft_loss = soft_loss * (self.temperature ** 2)
            
            # Combined loss
            loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
            
            return (loss, student_outputs) if return_outputs else loss
        
        else:
            # Regular mode - plain cross-entropy on the student only, so the validation loss
            # is not directly comparable to the combined training loss
            student_outputs = model(
                input_ids=inputs.get('input_ids'),
                attention_mask=inputs.get('attention_mask'),
                labels=inputs.get('labels')
            )
            
            return (student_outputs.loss, student_outputs) if return_outputs else student_outputs.loss
        
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Override prediction step to use student inputs only
        """
        # Check if we have the expected distillation inputs
        if 'student_input_ids' in inputs and inputs['student_input_ids'] is not None:
            # During training/evaluation with distillation data
            student_inputs = {
                'input_ids': inputs.get('student_input_ids'),
                'attention_mask': inputs.get('student_attention_mask'),
                'labels': inputs.get('labels')
            }
        else:
            # Fallback for regular evaluation (when using SimpleDataset)
            student_inputs = {
                'input_ids': inputs.get('input_ids'),
                'attention_mask': inputs.get('attention_mask'),
                'labels': inputs.get('labels')
            }
        
        # Call parent's prediction_step with student inputs
        return super().prediction_step(model, student_inputs, prediction_loss_only, ignore_keys)
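
The combined loss in `compute_loss` can be checked by hand on a single example. The following is a plain-Python sketch under stated assumptions: `distillation_loss` and the logits are made up for illustration, and with one example the `batchmean` reduction is just the per-example KL sum.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Log-probabilities of logits divided by the temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_norm for s in scaled]

def distillation_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """Return (combined, hard, soft) loss terms for one example."""
    # Hard target: cross-entropy of the student against the true label (T = 1)
    hard = -log_softmax(student_logits)[label]

    # Soft target: KL(teacher || student) on temperature-softened distributions
    log_p_student = log_softmax(student_logits, temperature)
    log_p_teacher = log_softmax(teacher_logits, temperature)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student))
    soft = kl * temperature ** 2

    return alpha * hard + (1 - alpha) * soft, hard, soft

# Made-up logits for one tweet whose true class is 2 (Neutral)
teacher = [0.5, 1.0, 4.0, 2.0, 0.2]
student = [0.1, 2.5, 1.5, 0.3, 0.0]

combined, hard, soft = distillation_loss(student, teacher, label=2)
matched, matched_hard, matched_soft = distillation_loss(teacher, teacher, label=2)
print(f"combined={combined:.4f} hard={hard:.4f} soft={soft:.4f}")
print(f"student == teacher: soft={matched_soft:.4f}")
```

When the student reproduces the teacher's logits exactly, the soft term vanishes and only the weighted cross-entropy remains.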

## Training Setup


### Metrics Logger Callback


In [87]:
class DistillationMetricsLogger(TrainerCallback):
    """Callback to log detailed metrics during distillation"""
    
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None or not wandb.run:
            return
        
        # Only log when we have evaluation metrics
        if 'eval_loss' in logs:
            current_epoch = int(state.epoch) if state.epoch else 0
            
            # Get training loss from state history
            train_loss = 0
            if state.log_history:
                for log_entry in reversed(state.log_history):
                    if 'loss' in log_entry:
                        train_loss = log_entry['loss']
                        break
            
            # Detailed metrics
            detailed_metrics = {
                "Epoch": current_epoch,
                "Train Loss": train_loss,
                "Validation Loss": logs.get('eval_loss', 0),
                "Validation Accuracy": logs.get('eval_accuracy', 0),
                "Validation F1": logs.get('eval_f1', 0),
                "Validation MAE": logs.get('eval_mae', 0),
                "Validation QWK": logs.get('eval_quadratic_weighted_kappa', 0),
                "Learning_Rate": logs.get('learning_rate', args.learning_rate),
            }
            
            # Log to WandB
            wandb.log(detailed_metrics)
            
            # Print progress
            print(f"Epoch {current_epoch}: "
                  f"Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {logs.get('eval_loss', 0):.4f}, "
                  f"Val F1: {logs.get('eval_f1', 0):.4f}, "
                  f"QWK: {logs.get('eval_quadratic_weighted_kappa', 0):.4f}")


### Training Configuration


In [88]:
# Initialize W&B for tracking
wandb.init(
    project="covid-tweet-sentiment-distillation",
    name=f"{MODEL_TYPE}-to-distilroberta",
    config={
        "model_type": MODEL_TYPE,
        "teacher_model": teacher_model_name,
        "student_model": student_model_name,
        "temperature": TEMPERATURE,
        "alpha": ALPHA,
        "num_epochs": 10,
        "batch_size": 32,
        "learning_rate": 5e-5,  # TrainingArguments default; the arguments below do not set learning_rate
    }
)

In [89]:
# Training arguments with GPU optimizations
training_args = TrainingArguments(
    output_dir="./distilled_model",
    num_train_epochs=10,
    per_device_train_batch_size=16 if device.type == "cpu" else 32,  # Smaller batch for CPU
    per_device_eval_batch_size=32 if device.type == "cpu" else 64,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_quadratic_weighted_kappa",
    greater_is_better=True,
    save_total_limit=2,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="wandb",
    fp16=device.type == "cuda",  # Enable FP16 only for CUDA
    dataloader_num_workers=0,
    dataloader_pin_memory=device.type != "cpu",
)
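
As a sanity check on these arguments, the optimizer step count can be estimated from the dataset size. The arithmetic below assumes a single device, the non-CPU train batch size of 32, no gradient accumulation, and that the last partial batch is kept. The warm-up figure is an approximation of what `warmup_ratio=0.1` implies.

```python
import math

train_examples = 25_000   # balanced training set size
batch_size = 32           # per_device_train_batch_size on GPU/MPS
num_epochs = 10
warmup_ratio = 0.1

steps_per_epoch = math.ceil(train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
approx_warmup_steps = round(total_steps * warmup_ratio)

print(steps_per_epoch, total_steps, approx_warmup_steps)  # 782 7820 782
```

The total of 7820 matches the `global_step` reported by the training run, and the learning rate warms up over roughly the first epoch.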

### Create Distillation Trainer


In [90]:
# Create custom collate function
collate_fn = create_distillation_collate_fn(student_tokenizer, teacher_tokenizer, device)

# Create distillation trainer
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
    alpha=ALPHA,
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_detailed_metrics,
    data_collator=collate_fn,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=2),
        DistillationMetricsLogger(),
    ],
)

# GPU memory optimization
if device.type in ["cuda", "mps"]:
    # Enable gradient checkpointing for memory efficiency
    student_model.gradient_checkpointing_enable()
    
    if device.type == "cuda":
        # CUDA-specific optimizations
        torch.backends.cudnn.benchmark = True
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print("Distillation trainer created successfully!")

Distillation trainer created successfully!


## Training with Knowledge Distillation


In [91]:
# Train the student model with knowledge distillation
print("Starting knowledge distillation training...")
trainer.train()

Starting knowledge distillation training...


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,F1 Extremely Negative,Precision Extremely Negative,Recall Extremely Negative,Accuracy Extremely Negative,F1 Negative,Precision Negative,Recall Negative,Accuracy Negative,F1 Neutral,Precision Neutral,Recall Neutral,Accuracy Neutral,F1 Positive,Precision Positive,Recall Positive,Accuracy Positive,F1 Extremely Positive,Precision Extremely Positive,Recall Extremely Positive,Accuracy Extremely Positive,Mae,Adjacent Accuracy,Quadratic Weighted Kappa
1,0.8347,0.753833,0.718614,0.718342,0.735603,0.714441,0.683215,0.807263,0.592213,0.592213,0.675609,0.619308,0.743169,0.743169,0.783738,0.786848,0.780652,0.780652,0.678035,0.702862,0.654902,0.654902,0.780999,0.761733,0.801266,0.801266,0.356667,0.934129,0.834216
2,0.5804,0.775524,0.721597,0.714726,0.732524,0.757951,0.769091,0.691176,0.866803,0.866803,0.674832,0.767409,0.602186,0.602186,0.854866,0.865207,0.844769,0.844769,0.612339,0.74136,0.521569,0.521569,0.734893,0.597464,0.95443,0.95443,0.339454,0.948359,0.866675
3,0.4613,0.541632,0.817305,0.818108,0.82878,0.812407,0.807091,0.821656,0.793033,0.793033,0.771163,0.770742,0.771585,0.771585,0.862927,0.888095,0.839145,0.839145,0.803412,0.762139,0.849412,0.849412,0.852568,0.901269,0.808861,0.808861,0.231122,0.956392,0.893931
4,0.3605,0.607764,0.803076,0.801692,0.794178,0.820251,0.772294,0.668666,0.913934,0.913934,0.70762,0.769923,0.654645,0.654645,0.873497,0.889277,0.858268,0.858268,0.795399,0.835203,0.759216,0.759216,0.85816,0.807821,0.91519,0.91519,0.238467,0.96213,0.904807
5,0.3118,0.662436,0.784026,0.7818,0.781652,0.805671,0.791434,0.725256,0.870902,0.870902,0.73425,0.77136,0.700546,0.700546,0.863507,0.87703,0.850394,0.850394,0.739394,0.825121,0.669804,0.669804,0.80742,0.709492,0.936709,0.936709,0.258894,0.962359,0.897559
6,0.2591,0.627535,0.808584,0.80774,0.808417,0.820407,0.813861,0.787356,0.842213,0.842213,0.771733,0.792627,0.751913,0.751913,0.861127,0.890625,0.833521,0.833521,0.776866,0.814273,0.742745,0.742745,0.835414,0.757202,0.931646,0.931646,0.228827,0.96672,0.907647
7,0.2286,0.718361,0.792288,0.789662,0.791133,0.814804,0.808271,0.746528,0.881148,0.881148,0.757417,0.809701,0.711475,0.711475,0.855543,0.851728,0.859393,0.859393,0.745969,0.839216,0.671373,0.671373,0.811892,0.708491,0.950633,0.950633,0.243287,0.969015,0.90651
8,0.2063,0.690931,0.804682,0.802871,0.80416,0.819135,0.820614,0.794626,0.848361,0.848361,0.769231,0.816953,0.726776,0.726776,0.855721,0.841304,0.870641,0.870641,0.768395,0.822739,0.720784,0.720784,0.827042,0.745178,0.929114,0.929114,0.232958,0.967409,0.905317
9,0.1472,0.710835,0.8056,0.80381,0.798935,0.82488,0.802525,0.716586,0.911885,0.911885,0.744431,0.802781,0.693989,0.693989,0.862302,0.865232,0.859393,0.859393,0.785534,0.846782,0.732549,0.732549,0.83705,0.763295,0.926582,0.926582,0.228139,0.970163,0.911929
10,0.1388,0.708193,0.80583,0.803935,0.799346,0.822445,0.806361,0.741824,0.883197,0.883197,0.75,0.797048,0.708197,0.708197,0.848318,0.832251,0.865017,0.865017,0.785684,0.848182,0.731765,0.731765,0.844419,0.777423,0.924051,0.924051,0.226532,0.971311,0.911993


Epoch 1: Train Loss: 0.8347, Val Loss: 0.7538, Val F1: 0.7183, QWK: 0.8342
Epoch 2: Train Loss: 0.5804, Val Loss: 0.7755, Val F1: 0.7147, QWK: 0.8667
Epoch 3: Train Loss: 0.4613, Val Loss: 0.5416, Val F1: 0.8181, QWK: 0.8939
Epoch 4: Train Loss: 0.3605, Val Loss: 0.6078, Val F1: 0.8017, QWK: 0.9048
Epoch 5: Train Loss: 0.3118, Val Loss: 0.6624, Val F1: 0.7818, QWK: 0.8976
Epoch 6: Train Loss: 0.2591, Val Loss: 0.6275, Val F1: 0.8077, QWK: 0.9076
Epoch 7: Train Loss: 0.2286, Val Loss: 0.7184, Val F1: 0.7897, QWK: 0.9065
Epoch 8: Train Loss: 0.2063, Val Loss: 0.6909, Val F1: 0.8029, QWK: 0.9053
Epoch 9: Train Loss: 0.1472, Val Loss: 0.7108, Val F1: 0.8038, QWK: 0.9119
Epoch 10: Train Loss: 0.1388, Val Loss: 0.7082, Val F1: 0.8039, QWK: 0.9120


TrainOutput(global_step=7820, training_loss=0.3834667370447417, metrics={'train_runtime': 3895.0233, 'train_samples_per_second': 64.184, 'train_steps_per_second': 2.008, 'total_flos': 0.0, 'train_loss': 0.3834667370447417, 'epoch': 10.0})
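
Because `metric_for_best_model` is the quadratic weighted kappa, the "best" epoch need not be the one with the lowest validation loss. The short sketch below replays the per-epoch values from the table above to show the difference. `history` is a hand-copied list for illustration, not something the Trainer exposes under that name.

```python
# (epoch, validation loss, quadratic weighted kappa) copied from the training table
history = [
    (1, 0.753833, 0.834216),
    (2, 0.775524, 0.866675),
    (3, 0.541632, 0.893931),
    (4, 0.607764, 0.904807),
    (5, 0.662436, 0.897559),
    (6, 0.627535, 0.907647),
    (7, 0.718361, 0.906510),
    (8, 0.690931, 0.905317),
    (9, 0.710835, 0.911929),
    (10, 0.708193, 0.911993),
]

best_by_qwk = max(history, key=lambda row: row[2])
best_by_loss = min(history, key=lambda row: row[1])

print("Best epoch by QWK:", best_by_qwk[0])        # 10
print("Best epoch by val loss:", best_by_loss[0])  # 3
```

Validation loss bottoms out at epoch 3 while QWK keeps creeping up, and the last two epochs differ by under 0.0001 QWK, so the choice between them is within noise.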

## Evaluation and Comparison


### Model Size Comparison


In [92]:
def get_model_size(model):
    """Calculate model size in MB"""
    param_size = 0
    buffer_size = 0
    
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_mb = (param_size + buffer_size) / (1024 * 1024)
    return size_mb

# Compare model sizes
teacher_size = get_model_size(teacher_model)
student_size = get_model_size(student_model)
compression_ratio = teacher_size / student_size

print(f"\nModel Size Comparison:")
print(f"Teacher Model Size: {teacher_size:.2f} MB")
print(f"Student Model Size: {student_size:.2f} MB")
print(f"Compression Ratio: {compression_ratio:.2f}x")
print(f"Size Reduction: {(1 - student_size/teacher_size) * 100:.1f}%")



Model Size Comparison:
Teacher Model Size: 514.62 MB
Student Model Size: 313.28 MB
Compression Ratio: 1.64x
Size Reduction: 39.1%
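
The reported sizes follow almost entirely from the parameter counts printed earlier, at 4 bytes per fp32 parameter. The sketch below redoes that arithmetic. The small remainder for the student (313.27 vs 313.28) presumably comes from registered buffers, which `get_model_size` also counts.

```python
BYTES_PER_FP32 = 4
BYTES_PER_MIB = 1024 * 1024  # get_model_size divides by 1024 * 1024

teacher_params = 134_903_813  # from the teacher parameter count printed earlier
student_params = 82_122_245   # from the student parameter count printed earlier

teacher_mb = teacher_params * BYTES_PER_FP32 / BYTES_PER_MIB
student_mb = student_params * BYTES_PER_FP32 / BYTES_PER_MIB

print(f"{teacher_mb:.2f} MB, {student_mb:.2f} MB, ratio {teacher_mb / student_mb:.2f}x")
# 514.62 MB, 313.27 MB, ratio 1.64x
```

Since the figure scales linearly with bytes per parameter, storing either model in fp16 would halve it without changing the 1.64x ratio.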


### Inference Speed Comparison


In [93]:
def measure_inference_time(model, tokenizer, sample_text, num_runs=50):
    """Measure mean and std of single-example forward-pass time in seconds"""
    model.eval()
    times = []

    def sync():
        # CUDA and MPS execute asynchronously; wait for queued work so the timer covers it
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()

    # Tokenize once so that only the forward pass is timed
    inputs = tokenizer(sample_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Warm-up runs
    for _ in range(5):
        with torch.no_grad():
            _ = model(**inputs)
    sync()

    # Actual timing runs
    for _ in range(num_runs):
        start_time = time.perf_counter()
        with torch.no_grad():
            _ = model(**inputs)
        sync()
        times.append(time.perf_counter() - start_time)

    return np.mean(times), np.std(times)

# Test inference speed
sample_tweet = "COVID-19 vaccines have been crucial in reducing hospitalizations and saving lives worldwide."

print("Measuring inference speed...")
teacher_time, teacher_std = measure_inference_time(teacher_model, teacher_tokenizer, sample_tweet)
student_time, student_std = measure_inference_time(student_model, student_tokenizer, sample_tweet)

speedup = teacher_time / student_time

print(f"\nInference Speed Comparison:")
print(f"Teacher Model: {teacher_time*1000:.2f} ± {teacher_std*1000:.2f} ms")
print(f"Student Model: {student_time*1000:.2f} ± {student_std*1000:.2f} ms")
print(f"Speedup: {speedup:.2f}x")


Measuring inference speed...

Inference Speed Comparison:
Teacher Model: 15.73 ± 0.91 ms
Student Model: 8.59 ± 0.07 ms
Speedup: 1.83x


### Performance Evaluation


In [94]:
def evaluate_model(model, tokenizer, dataset, model_name):
    """Evaluate a model on a given dataset"""
    model.eval()
    
    # Create regular dataset for single-model evaluation
    eval_encodings = tokenizer(
        [dataset[i]['text'] for i in range(len(dataset))],
        truncation=True,
        padding=False,
        max_length=128,
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=False
    )
    
    class SimpleDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels
        
        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(self.labels[idx])
            return item
        
        def __len__(self):
            return len(self.labels)
    
    eval_dataset = SimpleDataset(eval_encodings, [dataset[i]['labels'].item() for i in range(len(dataset))])
    
    # Create trainer for evaluation
    eval_trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./temp",
            per_device_eval_batch_size=64,
            remove_unused_columns=False,
            dataloader_num_workers=0,
        ),
        processing_class=tokenizer,  # replaces the deprecated `tokenizer` argument (transformers >= 4.46)
        compute_metrics=compute_detailed_metrics,
    )
    
    # Evaluate
    results = eval_trainer.evaluate(eval_dataset)
    
    # Format results with model name prefix
    formatted_results = {}
    for key, value in results.items():
        new_key = f"{model_name}_{key}"
        formatted_results[new_key] = value
    
    return results, formatted_results

# Evaluate both models on validation set
print("Evaluating Teacher Model on Validation Set...")
val_results_teacher, val_formatted_teacher = evaluate_model(teacher_model, teacher_tokenizer, val_dataset, "teacher")

print("\nEvaluating Student Model on Validation Set...")
val_results_student, val_formatted_student = evaluate_model(student_model, student_tokenizer, val_dataset, "student")

# Evaluate both models on test set
print("\nEvaluating Teacher Model on Test Set...")
test_results_teacher, test_formatted_teacher = evaluate_model(teacher_model, teacher_tokenizer, test_dataset, "teacher")

print("\nEvaluating Student Model on Test Set...")
test_results_student, test_formatted_student = evaluate_model(student_model, student_tokenizer, test_dataset, "student")

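# Log to wandb. evaluate_model() prefixes keys with the model name only, so the
# validation and test dicts share keys; add a split prefix so the test metrics
# do not overwrite the validation metrics in the merged dict.
wandb.log({
    **{f"val_{k}": v for k, v in val_formatted_teacher.items()},
    **{f"val_{k}": v for k, v in val_formatted_student.items()},
    **{f"test_{k}": v for k, v in test_formatted_teacher.items()},
    **{f"test_{k}": v for k, v in test_formatted_student.items()},
})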


Evaluating Teacher Model on Validation Set...

Evaluating Student Model on Validation Set...

Evaluating Teacher Model on Test Set...

Evaluating Student Model on Test Set...
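
Because `evaluate_model` prefixes each metric key with the model name only, the validation and test results for the same model share keys such as `teacher_eval_accuracy`. When such dicts are merged into one, later entries replace earlier ones. The following is a minimal sketch of one way to keep both, assuming a hypothetical helper `prefix_metrics` and made-up metric values (neither is part of the notebook):

```python
def prefix_metrics(metrics, model_name, split):
    """Return a copy of `metrics` with keys namespaced as '<split>_<model_name>_<key>'."""
    return {f"{split}_{model_name}_{key}": value for key, value in metrics.items()}


# Illustrative values only, not results from this notebook.
val_metrics = {"eval_accuracy": 0.80}
test_metrics = {"eval_accuracy": 0.75}

# Model-name prefix only: both splits map to the same key, so the merge keeps one value.
collided = {
    **{f"teacher_{k}": v for k, v in val_metrics.items()},
    **{f"teacher_{k}": v for k, v in test_metrics.items()},
}

# Split + model-name prefix: both values survive the merge.
merged = {
    **prefix_metrics(val_metrics, "teacher", "val"),
    **prefix_metrics(test_metrics, "teacher", "test"),
}

print(collided)
print(merged)
```

The same effect could be had by passing a combined name such as `"val_teacher"` as `model_name`; the separate `split` argument here is only a design choice for readability.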


## Comprehensive Results Summary


In [95]:
print("="*80)
print("KNOWLEDGE DISTILLATION RESULTS SUMMARY")
print("="*80)

print(f"\n📊 MODEL COMPRESSION:")
print(f"Teacher Model Size: {teacher_size:.2f} MB")
print(f"Student Model Size: {student_size:.2f} MB")
print(f"Compression Ratio: {compression_ratio:.2f}x")
print(f"Size Reduction: {(1 - student_size/teacher_size) * 100:.1f}%")

print(f"\n⚡ INFERENCE SPEED:")
print(f"Teacher Model: {teacher_time*1000:.2f} ms")
print(f"Student Model: {student_time*1000:.2f} ms")
print(f"Speedup: {speedup:.2f}x")

def print_comparison(title, teacher_results, student_results):
    """Print teacher vs. student metrics side by side with the signed difference."""
    print(f"\n{title}")
    print("                      Teacher     Student     Difference")
    for label, key in [
        ("Accuracy", "eval_accuracy"),
        ("F1-Score", "eval_f1"),
        ("QWK", "eval_quadratic_weighted_kappa"),
        ("MAE", "eval_mae"),
    ]:
        name = f"{label}:"
        t, s = teacher_results[key], student_results[key]
        print(f"{name:<22}{t:.4f}      {s:.4f}      {s - t:+.4f}")

print_comparison("🎯 VALIDATION SET PERFORMANCE:", val_results_teacher, val_results_student)
print_comparison("🧪 TEST SET PERFORMANCE:", test_results_teacher, test_results_student)

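# Performance retention on the test set: student metric as a percentage of the teacher's.
# MAE is omitted because it is lower-is-better, so a ratio would not read as "retention".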
acc_retention = (test_results_student['eval_accuracy'] / test_results_teacher['eval_accuracy']) * 100
f1_retention = (test_results_student['eval_f1'] / test_results_teacher['eval_f1']) * 100
qwk_retention = (test_results_student['eval_quadratic_weighted_kappa'] / test_results_teacher['eval_quadratic_weighted_kappa']) * 100

print(f"\n🎯 PERFORMANCE RETENTION:")
print(f"Accuracy Retention: {acc_retention:.1f}%")
print(f"F1-Score Retention: {f1_retention:.1f}%")
print(f"QWK Retention: {qwk_retention:.1f}%")

print(f"\n💡 DISTILLATION SUMMARY:")
print(f"• Temperature: {TEMPERATURE}")
print(f"• Alpha (hard loss weight): {ALPHA}")
print(f"• Achieved {compression_ratio:.1f}x model compression with {speedup:.1f}x inference speedup")
print(f"• Test F1 drop: {test_results_teacher['eval_f1'] - test_results_student['eval_f1']:.4f} points (teacher minus student)")
print(f"• F1 retention: {f1_retention:.1f}% of the teacher's test F1")

# Final wandb log with summary metrics
wandb.log({
    "final_compression_ratio": compression_ratio,
    "final_speedup": speedup,
    "final_f1_retention": f1_retention,
    "final_accuracy_retention": acc_retention,
    "final_qwk_retention": qwk_retention,
    "temperature": TEMPERATURE,
    "alpha": ALPHA
})

print(f"\n✅ Knowledge distillation complete! Results logged to W&B.")

# Save the distilled model
trainer.save_model("./distilled_model_final")
print(f"✅ Distilled model saved to ./distilled_model_final")

# Finish wandb run
wandb.finish()

KNOWLEDGE DISTILLATION RESULTS SUMMARY

📊 MODEL COMPRESSION:
Teacher Model Size: 514.62 MB
Student Model Size: 313.28 MB
Compression Ratio: 1.64x
Size Reduction: 39.1%

⚡ INFERENCE SPEED:
Teacher Model: 15.73 ms
Student Model: 8.59 ms
Speedup: 1.83x

🎯 VALIDATION SET PERFORMANCE:
                      Teacher     Student     Difference
Accuracy:             0.8442      0.8058      -0.0383
F1-Score:             0.8433      0.8039      -0.0394
QWK:                  0.9374      0.9120      -0.0254
MAE:                  0.1714      0.2265      +0.0551

🧪 TEST SET PERFORMANCE:
                      Teacher     Student     Difference
Accuracy:             0.8423      0.8067      -0.0356
F1-Score:             0.8414      0.8045      -0.0368
QWK:                  0.9330      0.9127      -0.0203
MAE:                  0.1790      0.2281      +0.0491

🎯 PERFORMANCE RETENTION:
Accuracy Retention: 95.8%
F1-Score Retention: 95.6%
QWK Retention: 97.8%

💡 DISTILLATION SUMMARY:
• Temperature: 2.0
• Alp

Run history:

| Metric | History |
|---|---|
| Epoch | ▁▂▃▃▄▅▆▆▇█ |
| Learning_Rate | ▁▁▁▁▁▁▁▁▁▁ |
| Train Loss | █▅▄▃▃▂▂▂▁▁ |
| Validation Accuracy | ▁▁█▇▆▇▆▇▇▇ |
| Validation F1 | ▁▁█▇▆▇▆▇▇▇ |
| Validation Loss | ▇█▁▃▅▄▆▅▆▆ |
| Validation MAE | █▇▁▂▃▁▂▁▁▁ |
| Validation QWK | ▁▄▆▇▇██▇██ |
| alpha | ▁ |
| eval/accuracy | ▁▁▇▆▅▆▅▆▆▆█▆█▆ |

Run summary:

| Metric | Value |
|---|---|
| Epoch | 10.0 |
| Learning_Rate | 5e-05 |
| Train Loss | 0.1388 |
| Validation Accuracy | 0.80583 |
| Validation F1 | 0.80394 |
| Validation Loss | 0.70819 |
| Validation MAE | 0.22653 |
| Validation QWK | 0.91199 |
| alpha | 0.5 |
| eval/accuracy | 0.80666 |
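
The retention figures in the results summary divide the student's test metric by the teacher's, which suits higher-is-better metrics (accuracy, F1, QWK). MAE is lower-is-better, so a relative increase is arguably the more natural summary. The sketch below re-derives both from the rounded test-set numbers printed in the summary; the helper names `retention_pct` and `relative_increase_pct` are illustrative and not part of the notebook:

```python
def retention_pct(student, teacher):
    """Student score as a percentage of the teacher score (higher-is-better metrics)."""
    return student / teacher * 100


def relative_increase_pct(student, teacher):
    """Relative increase of the student value over the teacher value, in percent."""
    return (student - teacher) / teacher * 100


# Test-set values copied from the printed summary (rounded to 4 decimals there).
teacher = {"accuracy": 0.8423, "f1": 0.8414, "qwk": 0.9330, "mae": 0.1790}
student = {"accuracy": 0.8067, "f1": 0.8045, "qwk": 0.9127, "mae": 0.2281}

for name in ("accuracy", "f1", "qwk"):
    print(f"{name} retention: {retention_pct(student[name], teacher[name]):.1f}%")

print(f"mae relative increase: {relative_increase_pct(student['mae'], teacher['mae']):.1f}%")
```

With these rounded inputs the retention values agree with the reported 95.8%, 95.6%, and 97.8%, while the student's MAE comes out roughly 27% higher than the teacher's.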


print()