In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional, Any

if torch.backends.mps.is_available():
    device = torch.device("mps")  # Use Apple GPU via Metal
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


In [2]:
class TextDataset(Dataset):
    """Custom dataset class for text classification with tokenization."""
    
    def __init__(self, data: Any, tokenizer: Any, max_length: int = 150):
        """
        Initialize the dataset.
        
        Args:
            data: Dataset containing 'label' and 'text' fields
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length for tokenization
        """
        self.targets = torch.tensor(data['label'])
        texts = data['text']
        
        tokens = tokenizer(
            texts, 
            return_tensors='pt', 
            truncation=True, 
            padding='max_length', 
            max_length=max_length
        )
        
        self.input_ids = tokens['input_ids']
        self.attention_mask = tokens['attention_mask']
        self.length = len(texts)
    
    def __len__(self) -> int:
        return self.length
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.input_ids[index], self.attention_mask[index], self.targets[index]

## Prepeare Torch Data Loaders

In [3]:
class DataManager:
    """Manages data loading and preprocessing for knowledge distillation."""
    
    def __init__(self, dataset_name: str, tokenizer: Any, test_size: float = 0.2, 
                 max_length: int = 150, batch_size: int = 32):
        """
        Initialize data manager.
        
        Args:
            dataset_name: Name of the dataset to load
            tokenizer: Tokenizer for text processing
            test_size: Fraction of data to use for validation
            max_length: Maximum sequence length
            batch_size: Batch size for data loaders
        """
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.test_size = test_size
        self.max_length = max_length
        self.batch_size = batch_size
        
        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None
    
    def prepare_data(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Load and prepare data loaders.
        
        Returns:
            Tuple of (train_loader, valid_loader, test_loader)
        """
        # Load dataset
        data = load_dataset(self.dataset_name)
        
        # Split data
        train_test = data['train'].train_test_split(test_size=self.test_size, shuffle=True)
        train_data = train_test['train']
        valid_data = train_test['test']
        test_data = data['test']
        
        # Print dataset statistics
        self._print_dataset_stats(train_data, valid_data, test_data)
        
        # Create custom datasets
        train_dataset = TextDataset(train_data, self.tokenizer, self.max_length)
        valid_dataset = TextDataset(valid_data, self.tokenizer, self.max_length)
        test_dataset = TextDataset(test_data, self.tokenizer, self.max_length)
        
        # Create data loaders
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size)
        self.valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size)
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size)
        
        return self.train_loader, self.valid_loader, self.test_loader
    
    def _print_dataset_stats(self, train_data: Any, valid_data: Any, test_data: Any):
        """Print dataset statistics."""
        print(f'Train set has {train_data.num_rows} samples')
        print(f'Validation set has {valid_data.num_rows} samples')
        print(f'Test set has {test_data.num_rows} samples')

## Model Management

In [4]:
class ModelManager:
    """Manages teacher and student models for knowledge distillation."""
    
    def __init__(self, teacher_model_name: str, student_model_name: str, 
                 num_labels: int, device: torch.device, dropout_rate: float = 0.2):
        """
        Initialize model manager.
        
        Args:
            teacher_model_name: Name/path of teacher model
            student_model_name: Name/path of student model  
            num_labels: Number of classification labels
            device: Device to load models on
            dropout_rate: Dropout rate for student model
        """
        self.teacher_model_name = teacher_model_name
        self.student_model_name = student_model_name
        self.num_labels = num_labels
        self.device = device
        self.dropout_rate = dropout_rate
        
        self.teacher_model = None
        self.student_model = None
    
    def load_models(self) -> Tuple[nn.Module, nn.Module]:
        """
        Load teacher and student models.
        
        Returns:
            Tuple of (student_model, teacher_model)
        """
        # Load student model
        self.student_model = AutoModelForSequenceClassification.from_pretrained(
            self.student_model_name, 
            num_labels=self.num_labels
        ).to(self.device)
        self.student_model.dropout = nn.Dropout(self.dropout_rate)
        
        # Load teacher model
        self.teacher_model = AutoModelForSequenceClassification.from_pretrained(
            self.teacher_model_name,
            num_labels=self.num_labels  
        ).to(self.device)
        
        return self.student_model, self.teacher_model

## Metrics Tracking

In [5]:
class MetricsTracker:
    """Tracks and manages training metrics."""
    
    def __init__(self):
        """Initialize metrics tracker."""
        self.training_loss_list = []
        self.training_kd_loss_list = []
        self.training_ce_loss_list = []
        self.training_accuracy_list = []
        self.valid_loss_list = []
        self.valid_accuracy_list = []
    
    def update_metrics(self, train_loss: float, train_kd_loss: float, train_ce_loss: float,
                      train_accuracy: float, valid_loss: float, valid_accuracy: float):
        """Update all metrics with current epoch values."""
        self.training_loss_list.append(train_loss)
        self.training_kd_loss_list.append(train_kd_loss)
        self.training_ce_loss_list.append(train_ce_loss)
        self.training_accuracy_list.append(train_accuracy)
        self.valid_loss_list.append(valid_loss)
        self.valid_accuracy_list.append(valid_accuracy)
    
    def print_detailed_metrics(self, epoch: int, total_epochs: int, train_loss: float, 
                             train_kd_loss: float, train_ce_loss: float, train_accuracy: float,
                             valid_loss: float, valid_accuracy: float):
        """Print detailed metrics for current epoch."""
        print(f"""
        Epoch {epoch + 1}/{total_epochs}:
        ├─ Combined Loss: {train_loss:.4f}
        ├─ KD Loss: {train_kd_loss:.4f} 
        ├─ CE Loss: {train_ce_loss:.4f}
        ├─ KD/CE Ratio: {train_kd_loss/train_ce_loss:.2f}
        ├─ Train Accuracy: {train_accuracy:.4f}
        ├─ Valid Loss: {valid_loss:.4f}
        └─ Valid Accuracy: {valid_accuracy:.4f}
        """)

## Knowldge Distillation Training

In [6]:
class KnowledgeDistillationTrainer:
    """Main trainer class for knowledge distillation."""
    
    def __init__(self, student_model: nn.Module, teacher_model: nn.Module, 
                 device: torch.device, temperature: float = 3.0, alpha: float = 0.7,
                 learning_rate: float = 2e-5, epochs: int = 3):
        """
        Initialize knowledge distillation trainer.
        
        Args:
            student_model: Student model to train
            teacher_model: Teacher model for distillation
            device: Device for training
            temperature: Temperature for knowledge distillation
            alpha: Weight for combining CE and KD losses
            learning_rate: Learning rate for optimizer
            epochs: Number of training epochs
        """
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.device = device
        self.temperature = temperature
        self.alpha = alpha
        self.learning_rate = learning_rate
        self.epochs = epochs
        
        # Loss functions
        self.entropy_loss = nn.CrossEntropyLoss()
        self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
        
        # Metrics tracker
        self.metrics_tracker = MetricsTracker()
        
        # Initialize optimizer and scheduler (will be set during training)
        self.optimizer = None
        self.scheduler = None
    
    def accuracy_score(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 
                      model: nn.Module) -> float:
        """Calculate accuracy for a batch."""
        with torch.no_grad():
            outputs = model(batch[0].to(self.device), batch[1].to(self.device))
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)
            class_predictions = torch.argmax(probabilities, dim=1)
            acc = torch.mean((class_predictions == batch[2].to(self.device)).to(torch.float)).data.item()
            return acc
    
    def setup_optimizer_scheduler(self, train_loader: DataLoader):
        """Setup optimizer and learning rate scheduler."""
        self.optimizer = optim.AdamW(self.student_model.parameters(), lr=self.learning_rate)
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=self.learning_rate,
            total_steps=len(train_loader) * self.epochs,
            pct_start=0.1,
            anneal_strategy='cos'
        )
    
    def train_epoch(self, train_loader: DataLoader) -> Tuple[float, float, float, float]:
        """Train for one epoch."""
        self.student_model.train()
        train_loss = 0.0
        train_kd_loss = 0.0
        train_ce_loss = 0.0
        train_accuracy = 0.0
        
        for batch in train_loader:
            self.optimizer.zero_grad()
            input_ids = batch[0].to(self.device)
            attention_mask = batch[1].to(self.device)
            target_tensors = batch[2].to(self.device)
            
            # Student model predictions
            student_logits = self.student_model(input_ids=input_ids, attention_mask=attention_mask).logits
            
            # Cross-entropy loss
            ce_loss = self.entropy_loss(student_logits, target_tensors)
            
            # Teacher model logits
            with torch.no_grad():
                teacher_outputs = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits
            
            # Knowledge distillation loss
            kd_loss = self.temperature ** 2 * self.kd_loss_fn(
                F.log_softmax(student_logits / self.temperature, dim=-1),
                F.softmax(teacher_logits / self.temperature, dim=-1)
            )
            
            # Combined loss
            loss = self.alpha * ce_loss + (1. - self.alpha) * kd_loss
            loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), 1.0)
            
            self.optimizer.step()
            self.scheduler.step()
            
            # Update training metrics
            train_kd_loss += kd_loss.item()
            train_ce_loss += ce_loss.item()
            train_loss += loss.item()
            accuracy = self.accuracy_score(batch, self.student_model)
            train_accuracy += accuracy
        
        # Calculate averages
        train_accuracy /= len(train_loader)
        train_loss /= len(train_loader)
        train_kd_loss /= len(train_loader)
        train_ce_loss /= len(train_loader)
        
        return train_loss, train_kd_loss, train_ce_loss, train_accuracy
    
    def validate_epoch(self, valid_loader: DataLoader) -> Tuple[float, float]:
        """Validate for one epoch."""
        self.student_model.eval()
        valid_loss = 0.0
        valid_accuracy = 0.0
        
        for batch in valid_loader:
            input_ids = batch[0].to(self.device)
            attention_mask = batch[1].to(self.device)
            target_tensors = batch[2].to(self.device)
            
            output = self.student_model(input_ids=input_ids, attention_mask=attention_mask)
            val_loss = self.entropy_loss(output.logits, target_tensors)
            valid_loss += val_loss.item()
            accuracy = self.accuracy_score(batch, self.student_model)
            valid_accuracy += accuracy
        
        # Calculate averages
        valid_accuracy /= len(valid_loader)
        valid_loss /= len(valid_loader)
        
        return valid_loss, valid_accuracy
    
    def train(self, train_loader: DataLoader, valid_loader: DataLoader) -> Dict[str, List[float]]:
        """
        Main training loop.
        
        Args:
            train_loader: Training data loader
            valid_loader: Validation data loader
            
        Returns:
            Dictionary containing all training metrics
        """
        # Setup optimizer and scheduler
        self.setup_optimizer_scheduler(train_loader)
        
        for epoch in tqdm(range(self.epochs), total=self.epochs):
            # Training
            train_loss, train_kd_loss, train_ce_loss, train_accuracy = self.train_epoch(train_loader)
            
            # Validation
            valid_loss, valid_accuracy = self.validate_epoch(valid_loader)
            
            # Print metrics
            self.metrics_tracker.print_detailed_metrics(
                epoch, self.epochs, train_loss, train_kd_loss, train_ce_loss,
                train_accuracy, valid_loss, valid_accuracy
            )
            
            # Update metrics
            self.metrics_tracker.update_metrics(
                train_loss, train_kd_loss, train_ce_loss,
                train_accuracy, valid_loss, valid_accuracy
            )
        
        return {
            'training_loss': self.metrics_tracker.training_loss_list,
            'training_kd_loss': self.metrics_tracker.training_kd_loss_list,
            'training_ce_loss': self.metrics_tracker.training_ce_loss_list,
            'training_accuracy': self.metrics_tracker.training_accuracy_list,
            'valid_loss': self.metrics_tracker.valid_loss_list,
            'valid_accuracy': self.metrics_tracker.valid_accuracy_list
        }
    
    def save_model(self, save_path: str, tokenizer: Any):
        """Save the trained student model and tokenizer."""
        self.student_model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)

## Pipeline 

In [7]:
class KnowledgeDistillationPipeline:
    """Complete pipeline for knowledge distillation."""
    
    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the pipeline with configuration.
        
        Args:
            config: Configuration dictionary containing all parameters
        """
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        print(f"Using device: {self.device}")
        
    def run(self):
        """Run the complete knowledge distillation pipeline."""
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.config['student_model_name'])
        
        # Prepare data
        data_manager = DataManager(
            dataset_name=self.config['dataset_name'],
            tokenizer=tokenizer,
            test_size=self.config.get('test_size', 0.2),
            max_length=self.config.get('max_length', 150),
            batch_size=self.config.get('batch_size', 64)
        )
        train_loader, valid_loader, test_loader = data_manager.prepare_data()
        
        # Load models
        model_manager = ModelManager(
            teacher_model_name=self.config['teacher_model_name'],
            student_model_name=self.config['student_model_name'],
            num_labels=self.config['num_labels'],
            device=self.device,
            dropout_rate=self.config.get('dropout_rate', 0.2)
        )
        student_model, teacher_model = model_manager.load_models()
        
        # Initialize trainer
        trainer = KnowledgeDistillationTrainer(
            student_model=student_model,
            teacher_model=teacher_model,
            device=self.device,
            temperature=self.config.get('temperature', 3.0),
            alpha=self.config.get('alpha', 0.7),
            learning_rate=self.config.get('learning_rate', 2e-5),
            epochs=self.config.get('epochs', 3)
        )
        
        # Train model
        metrics = trainer.train(train_loader, valid_loader)
        
        # Save model
        if 'save_path' in self.config:
            trainer.save_model(self.config['save_path'], tokenizer)
        
        return metrics, trainer, test_loader

## Run Pipeline 

In [7]:
config = {
    'dataset_name': 'ag_news',
    'teacher_model_name': 'bert-base-uncased',  
    'student_model_name': 'distilbert-base-uncased', 
    'num_labels': 4,  # Number of classes in ag_news
    'test_size': 0.2, # train-validation split 
    'max_length': 150,
    'batch_size': 64,  
    'temperature': 3.0,
    'alpha': 0.7,
    'learning_rate': 2e-5,
    'epochs': 3,  
    'dropout_rate': 0.2, #set dropout for student model 
    'save_path': "/Users/arsalsyed/Documents/student_model_distilled"  # Path to save the trained model
}

# Initialize and run pipeline
pipeline = KnowledgeDistillationPipeline(config)
metrics, trainer, test_loader = pipeline.run()


 33%|██████████████▎                            | 1/3 [1:25:49<2:51:39, 5149.58s/it]


    Epoch 1/3:
    ├─ Combined Loss: 0.5981
    ├─ KD Loss: 1.1593 
    ├─ CE Loss: 0.3576
    ├─ KD/CE Ratio: 3.24
    ├─ Train Accuracy: 0.8985
    ├─ Valid Loss: 0.2353
    └─ Valid Accuracy: 0.9369
    


 67%|████████████████████████████▋              | 2/3 [2:29:30<1:12:47, 4367.79s/it]


    Epoch 2/3:
    ├─ Combined Loss: 0.2024
    ├─ KD Loss: 0.2559 
    ├─ CE Loss: 0.1795
    ├─ KD/CE Ratio: 1.43
    ├─ Train Accuracy: 0.9537
    ├─ Valid Loss: 0.1997
    └─ Valid Accuracy: 0.9463
    


100%|█████████████████████████████████████████████| 3/3 [3:28:32<00:00, 4170.93s/it]


    Epoch 3/3:
    ├─ Combined Loss: 0.1415
    ├─ KD Loss: 0.1569 
    ├─ CE Loss: 0.1349
    ├─ KD/CE Ratio: 1.16
    ├─ Train Accuracy: 0.9647
    ├─ Valid Loss: 0.2049
    └─ Valid Accuracy: 0.9480
    





('/Users/arsalsyed/Documents/student_model_distilled/tokenizer_config.json',
 '/Users/arsalsyed/Documents/student_model_distilled/special_tokens_map.json',
 '/Users/arsalsyed/Documents/student_model_distilled/vocab.txt',
 '/Users/arsalsyed/Documents/student_model_distilled/added_tokens.json',
 '/Users/arsalsyed/Documents/student_model_distilled/tokenizer.json')