# Knowledge Distillation with Custom ResNet Models

This notebook explores the implementation of knowledge distillation between custom ResNet models on the CIFAR-10 dataset. Knowledge distillation is a technique where a smaller model (student) learns from a larger model (teacher) to achieve better performance than it would when trained from scratch.

We'll cover the following components:
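1. Model Architecture: a custom ResNet implementation
2. Data Loading and Preprocessing
3. Training and Evaluation Functions: plotting, standard (generalist) training, and evaluation
4. Knowledge Distillation Implementation: the distillation loss, specialist training, and specialist evaluation
5. Main Training Process, which trains the generalist first and then distills it into the specialists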

Let's start by importing the necessary libraries.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import os

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Model Architecture

We'll implement a custom ResNet architecture with configurable depth and width. The architecture consists of:

- Basic residual blocks with skip connections
- A custom ResNet class that allows for different widths and depths

### 1.1 Basic Residual Block

In [None]:
class BasicResidualBlock(nn.Module):
    """Basic residual block for our custom ResNet"""
    
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicResidualBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                             stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                             stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection to match dimensions
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                        stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        residual = x
        
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        out += self.shortcut(residual)
        out = F.relu(out)
        
        return out

### 1.2 Custom ResNet Architecture

In [None]:
class CustomResNet(nn.Module):
    """Custom ResNet with configurable width and depth"""
    
    def __init__(self, num_blocks, width_factors, num_classes=10):
        super(CustomResNet, self).__init__()
        
        # Initial convolution
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        # Create residual blocks with different widths
        self.layer1 = self._make_layer(width_factors[0], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(width_factors[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(width_factors[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(width_factors[3], num_blocks[3], stride=2)
        
        # Final classification layer
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(width_factors[3], num_classes)
        
        # Initialize weights
        self._initialize_weights()
        
    def _make_layer(self, width, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        
        for stride in strides:
            layers.append(BasicResidualBlock(self.in_channels, width, stride))
            self.in_channels = width
            
        return nn.Sequential(*layers)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x
    
    def count_parameters(self):
        """Count the number of trainable parameters in the model"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
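
To see why the final layer can take width_factors[3] inputs, it helps to trace the feature-map resolution through the network. The sketch below is plain Python arithmetic (no PyTorch) using the standard output-size formula for a convolution without dilation; conv_output_size and trace_resnet_resolution are illustrative helpers, not part of the notebook. It assumes the stage strides (1, 2, 2, 2) used in CustomResNet and also checks that the 3x3 main path and the 1x1 shortcut of each stage's first block produce the same size, which the residual addition requires.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a 2D convolution without dilation."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_resnet_resolution(input_size=32, stage_strides=(1, 2, 2, 2)):
    """Return the feature-map side length after the stem and after each stage.

    Only the first block of each stage can downsample: its 3x3 conv uses
    padding=1 and the 1x1 shortcut uses padding=0, so both must agree.
    """
    size = conv_output_size(input_size, kernel_size=3, stride=1, padding=1)  # stem conv1
    sizes = [size]
    for stride in stage_strides:
        main_path = conv_output_size(size, kernel_size=3, stride=stride, padding=1)
        shortcut = conv_output_size(size, kernel_size=1, stride=stride, padding=0)
        assert main_path == shortcut, "residual addition needs matching shapes"
        size = main_path
        sizes.append(size)
    return sizes


print(trace_resnet_resolution())  # [32, 32, 16, 8, 4]
```

For 32x32 CIFAR-10 inputs the last stage outputs 4x4 maps, which AdaptiveAvgPool2d((1, 1)) averages to one value per channel, so self.fc receives exactly width_factors[3] features.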

### 1.3 Creating Generalist and Specialist Models

We'll define two functions to create our models:
1. A larger "generalist" model that handles all 10 CIFAR-10 classes
2. A smaller "specialist" model designed to handle only a subset of classes

In [None]:
# Create generalist model - deeper and wider
def create_generalist_model(num_classes=10):
    # More blocks and higher width factors for more parameters
    num_blocks = [3, 4, 6, 3]  # Similar to ResNet-34
    width_factors = [64, 128, 256, 512]  # Standard widths
    
    return CustomResNet(num_blocks, width_factors, num_classes)

# Create specialist model - smaller than generalist
def create_specialist_model(num_classes):
    # Fewer blocks and smaller width factors for fewer parameters
    num_blocks = [2, 2, 2, 2]  # Same block layout as ResNet-18
    width_factors = [32, 64, 128, 256]  # Half the standard widths
    
    return CustomResNet(num_blocks, width_factors, num_classes)
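
The two configurations differ in both depth and width. As a sanity check on how much smaller the specialist is, the parameter counts can be tallied by hand. resnet_param_count below is a hypothetical helper that mirrors CustomResNet's layer layout as written in this notebook, so its totals should agree with count_parameters() on the real models.

```python
def resnet_param_count(num_blocks, widths, num_classes, stem_channels=64):
    """Count trainable parameters of the CustomResNet layout by hand.

    Mirrors the notebook's design: a 3x3 stem conv (no bias) + BatchNorm,
    BasicResidualBlocks of two 3x3 convs (no bias) each followed by BatchNorm,
    a 1x1 conv + BatchNorm shortcut whenever stride != 1 or channels change,
    and a final Linear layer. BatchNorm running statistics are buffers, not
    parameters, so each BatchNorm contributes 2 * channels (weight + bias).
    """
    def conv(c_in, c_out, k):
        return c_in * c_out * k * k

    def bn(c):
        return 2 * c

    total = conv(3, stem_channels, 3) + bn(stem_channels)
    in_ch = stem_channels
    for stage, (blocks, width) in enumerate(zip(num_blocks, widths)):
        for b in range(blocks):
            # layer1 uses stride 1; the first block of layers 2-4 uses stride 2
            stride = 2 if (stage > 0 and b == 0) else 1
            total += conv(in_ch, width, 3) + bn(width)
            total += conv(width, width, 3) + bn(width)
            if stride != 1 or in_ch != width:
                total += conv(in_ch, width, 1) + bn(width)  # projection shortcut
            in_ch = width
    total += widths[-1] * num_classes + num_classes  # Linear weight + bias
    return total


generalist = resnet_param_count([3, 4, 6, 3], [64, 128, 256, 512], num_classes=10)
specialist = resnet_param_count([2, 2, 2, 2], [32, 64, 128, 256], num_classes=4)
print(f"generalist: {generalist:,}")  # generalist: 21,282,122
print(f"specialist: {specialist:,}")  # specialist: 2,808,324
print(f"ratio: {generalist / specialist:.1f}x")
```

One detail the tally makes visible: the stem always outputs 64 channels, so the specialist's first block (64 to 32 channels) also gets a 1x1 projection shortcut even though its stride is 1.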

## 2. Data Loading and Preprocessing

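Now we'll set up data loading for the CIFAR-10 dataset. Training images are augmented with random 32×32 crops taken from a 4-pixel-padded image and with random horizontal flips; both training and test images are then normalized per channel.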

In [None]:
# Data loading and preprocessing for CIFAR-10
def load_cifar10():
    # Training transforms: augmentation followed by per-channel normalization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                           shuffle=True, num_workers=4)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                          shuffle=False, num_workers=4)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck')
    
    return trainloader, testloader, classes
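
The Normalize transform subtracts a per-channel mean and divides by a per-channel standard deviation after ToTensor has scaled 8-bit pixel values to [0, 1]. The worked example below (plain Python, with a hypothetical normalize_pixel helper) applies that arithmetic to a black and a white pixel using the notebook's constants, which shows the range of values the network actually sees.

```python
# Per-channel mean/std used by the notebook's transforms.Normalize call.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb_uint8):
    """Scale 0..255 to 0..1 (as ToTensor does), then apply (x - mean) / std per channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb_uint8, CIFAR10_MEAN, CIFAR10_STD)
    )


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

Every input channel therefore lies roughly between -2.43 and 2.75 after preprocessing.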

## 3. Training and Evaluation Functions

### 3.1 Visualization Functions

In [None]:
# Create plotting functions for training visualization
def plot_training_progress(train_losses, train_accs, val_losses=None, val_accs=None, title="Training Progress"):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    epochs = range(1, len(train_losses) + 1)
    
    # Plot losses
    ax1.plot(epochs, train_losses, 'b-', label='Training Loss')
    if val_losses:
        ax1.plot(epochs, val_losses, 'r-', label='Validation Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Plot accuracies
    ax2.plot(epochs, train_accs, 'b-', label='Training Accuracy')
    if val_accs:
        ax2.plot(epochs, val_accs, 'r-', label='Validation Accuracy')
    ax2.set_title('Training and Validation Accuracy')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)
    
    plt.suptitle(title)
    plt.tight_layout()
    
    return fig

### 3.2 Generalist Model Training

In [None]:
# Train function for models with learning rate scheduling
def train_model(model, trainloader, testloader, epochs=90, lr=0.1, model_name="Generalist"):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Track metrics
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    
    best_acc = 0.0
    start_time = time.time()
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{epochs}')
        for i, data in enumerate(progress_bar):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': running_loss/(i+1), 
                'acc': 100.*correct/total,
                'lr': scheduler.get_last_lr()[0]
            })
        
        train_loss = running_loss/len(trainloader)
        train_acc = 100.*correct/total
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # Evaluate on validation set
        val_loss, val_acc = evaluate_model(model, testloader, return_metrics=True)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), f'{model_name}_best.pth')
        
        # Print progress
        elapsed = time.time() - start_time
        print(f'Epoch: {epoch+1}/{epochs} | '
              f'Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.2f}% | '
              f'Time: {elapsed/60:.2f}m | LR: {scheduler.get_last_lr()[0]:.6f}')
        
        # Update learning rate
        scheduler.step()
        
        # Plot every 10 epochs
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            fig = plot_training_progress(train_losses, train_accs, val_losses, val_accs, 
                                      title=f'{model_name} Training Progress')
            plt.savefig(f'{model_name}_progress_epoch_{epoch+1}.png')
            plt.close(fig)
    
    # Load best model
    model.load_state_dict(torch.load(f'{model_name}_best.pth'))
    
    # Final plot
    fig = plot_training_progress(train_losses, train_accs, val_losses, val_accs, 
                              title=f'{model_name} Training Progress (Final)')
    plt.savefig(f'{model_name}_progress_final.png')
    
    # Print training summary
    print(f'\n{model_name} Training completed in {elapsed/60:.2f} minutes')
    print(f'Best validation accuracy: {best_acc:.2f}%')
    
    return model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_acc': best_acc
    }
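
train_model steps CosineAnnealingLR once per epoch with T_max=epochs and the default eta_min=0, so the learning rate follows a half cosine from lr down to zero. The sketch below evaluates the closed-form expression given in the PyTorch documentation for this scheduler; cosine_annealing_lr is a hypothetical helper written for illustration, and PyTorch itself updates the rate recursively, so tiny floating-point differences are possible.

```python
import math


def cosine_annealing_lr(epoch, base_lr=0.1, t_max=90, eta_min=0.0):
    """Closed-form cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Learning rate used during selected epochs with the notebook's defaults (lr=0.1, T_max=90).
for epoch in (0, 30, 45, 60, 89):
    print(f"epoch {epoch + 1:>2}: lr = {cosine_annealing_lr(epoch):.5f}")
```

With the defaults the rate is 0.05 halfway through training; main() calls the training functions with epochs=10, which compresses the same curve into 10 epochs.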

### 3.3 Model Evaluation

In [None]:
# Evaluate model with option to return metrics
def evaluate_model(model, testloader, return_metrics=False):
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    running_loss = 0.0
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    test_loss = running_loss / len(testloader)
    
    if not return_metrics:
        print(f'Accuracy on test set: {accuracy:.2f}%')
        return accuracy
    else:
        return test_loss, accuracy

## 4. Knowledge Distillation Implementation

### 4.1 Distillation Loss Function
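This function implements the knowledge distillation loss, a weighted sum (controlled by alpha) of:
1. A hard loss: cross-entropy between the specialist's outputs and the ground-truth labels, remapped to the specialist's class indices
2. A soft loss: the KL divergence between the temperature-softened teacher and student distributions, using only the teacher logits for the specialist's classes and scaled by T²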
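
For reference, the same loss can be written out as follows (a restatement of the code in the next cell, with notation chosen here): $z_s$ are the specialist's logits, $z_t^{S}$ the teacher's logits restricted to the specialist's class set $S$ and ordered to match the specialist's indices, $y$ the remapped label, $\sigma$ the softmax, $T$ the temperature and $\alpha$ the weighting:

$$
\mathcal{L} = \alpha \, \mathrm{CE}(z_s, y) + (1 - \alpha) \, T^2 \, \mathrm{KL}\big(\sigma(z_t^{S} / T) \,\|\, \sigma(z_s / T)\big),
\qquad
\mathrm{KL}(p \,\|\, q) = \sum_{k=1}^{|S|} p_k \log \frac{p_k}{q_k}
$$

In the code the KL term is averaged over the batch (reduction='batchmean'). Because the gradients of the softened term shrink roughly as $1/T^2$, the $T^2$ factor keeps its contribution comparable to the hard term when the temperature changes.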

In [None]:
# Knowledge distillation loss function
def distillation_loss(outputs, labels, teacher_outputs, class_mapping, T=2.0, alpha=0.5):
    """
    Modified distillation loss to handle different output dimensions
    
    Args:
        outputs: specialist model outputs (e.g., 4 classes)
        labels: remapped labels for specialist classes
        teacher_outputs: generalist model outputs (10 classes)
        class_mapping: maps original classes to specialist classes
        T: temperature for distillation
        alpha: weight for hard vs soft loss
    """
    # Hard loss with actual labels
    hard_loss = nn.CrossEntropyLoss()(outputs, labels)
    
    # Select the teacher logits for the specialist's classes, ordered so that
    # column i of the result corresponds to specialist class i
    reverse_mapping = {v: k for k, v in class_mapping.items()}
    original_indices = [reverse_mapping[i] for i in range(len(reverse_mapping))]
    teacher_relevant_logits = teacher_outputs[:, original_indices]
    
    # Compute soft targets KL divergence
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        nn.functional.log_softmax(outputs/T, dim=1),
        nn.functional.softmax(teacher_relevant_logits/T, dim=1)
    ) * (T * T)
    
    return alpha * hard_loss + (1 - alpha) * soft_loss
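
To make the temperature and subset handling concrete, here is a single-sample restatement of distillation_loss in plain Python. It is a sketch for illustration only: the logits are made up, softmax and kd_loss_single are hypothetical helpers, and the notebook's real version operates on batches of PyTorch tensors. Raising T flattens the teacher's distribution over the specialist's four classes, which exposes how the teacher ranks the non-target classes.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss_single(student_logits, label, teacher_logits, group, T=2.0, alpha=0.5):
    """Single-sample version of the notebook's distillation loss.

    `group` lists the original CIFAR-10 class ids in specialist order, so
    teacher_logits[group[i]] lines up with student_logits[i].
    """
    # Hard term: cross-entropy with the remapped label at temperature 1.
    hard = -math.log(softmax(student_logits)[label])
    # Soft term: KL(teacher || student) over the specialist's classes only.
    teacher_subset = [teacher_logits[c] for c in group]
    p = softmax(teacher_subset, T)
    q = softmax(student_logits, T)
    soft = sum(pk * math.log(pk / qk) for pk, qk in zip(p, q))
    return alpha * hard + (1 - alpha) * soft * T * T


group = [2, 3, 4, 5]  # bird, cat, deer, dog
teacher = [0.1, -1.0, 2.0, 3.5, 0.3, 1.2, 0.0, -0.5, -2.0, 0.4]  # made-up 10-class logits
student = [1.0, 2.5, 0.2, 0.8]  # made-up 4-class specialist logits
print(softmax([teacher[c] for c in group], T=1.0))  # sharper
print(softmax([teacher[c] for c in group], T=2.0))  # softer
print(kd_loss_single(student, label=1, teacher_logits=teacher, group=group))
```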

### 4.2 Specialist Model Training with Knowledge Distillation

In [None]:
# Train specialist models with knowledge distillation
def train_specialist_with_distillation(specialist_model, generalist_model, trainloader, testloader,
                                     class_mapping, epochs=90, lr=0.1, model_name="Specialist"):
    
    optimizer = optim.SGD(specialist_model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    specialist_model.to(device)
    generalist_model.to(device)
    generalist_model.eval()  # Set teacher model to evaluation mode
    
    # Track metrics
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    
    best_acc = 0.0
    start_time = time.time()
    
    for epoch in range(epochs):
        specialist_model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{epochs}')
        for i, data in enumerate(progress_bar):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Remap labels for specialist classes
            specialist_labels = torch.tensor([class_mapping[l.item()] for l in labels], 
                                          device=device)
            
            optimizer.zero_grad()
            
            # Get student outputs
            student_outputs = specialist_model(inputs)
            
            # Get teacher outputs
            with torch.no_grad():
                teacher_outputs = generalist_model(inputs)
            
            # Calculate distillation loss
            loss = distillation_loss(student_outputs, specialist_labels, teacher_outputs, class_mapping)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            _, predicted = student_outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(specialist_labels).sum().item()
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': running_loss/(i+1), 
                'acc': 100.*correct/total,
                'lr': scheduler.get_last_lr()[0]
            })
        
        train_loss = running_loss/len(trainloader)
        train_acc = 100.*correct/total
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # Evaluate on validation set
        val_loss, val_acc = evaluate_specialist(
            specialist_model, testloader, class_mapping, return_metrics=True
        )
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(specialist_model.state_dict(), f'{model_name}_best.pth')
        
        # Print progress
        elapsed = time.time() - start_time
        print(f'Epoch: {epoch+1}/{epochs} | '
              f'Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.2f}% | '
              f'Time: {elapsed/60:.2f}m | LR: {scheduler.get_last_lr()[0]:.6f}')
        
        # Update learning rate
        scheduler.step()
        
        # Plot every 10 epochs
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            fig = plot_training_progress(train_losses, train_accs, val_losses, val_accs, 
                                       title=f'{model_name} Training Progress')
            plt.savefig(f'{model_name}_progress_epoch_{epoch+1}.png')
            plt.close(fig)
    
    # Load best model
    specialist_model.load_state_dict(torch.load(f'{model_name}_best.pth'))
    
    # Final plot
    fig = plot_training_progress(train_losses, train_accs, val_losses, val_accs, 
                               title=f'{model_name} Training Progress (Final)')
    plt.savefig(f'{model_name}_progress_final.png')
    
    # Print training summary
    print(f'\n{model_name} Training completed in {elapsed/60:.2f} minutes')
    print(f'Best validation accuracy: {best_acc:.2f}%')
    
    return specialist_model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_acc': best_acc
    }
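
The specialist training loop remaps labels with a Python list comprehension that calls .item() on every label; on a GPU, each .item() call copies a value back to the host. A lookup table performs the same remapping in one indexing operation. The sketch below shows the idea in NumPy with a hypothetical make_label_lookup helper; in PyTorch the analogous approach would be a long tensor on the same device as the labels, indexed with the label tensor.

```python
import numpy as np


def make_label_lookup(group, num_original_classes=10):
    """Build an array mapping original class id -> specialist index (-1 if not in the group)."""
    lookup = np.full(num_original_classes, -1, dtype=np.int64)
    lookup[np.asarray(group)] = np.arange(len(group))
    return lookup


group = [0, 1, 8, 9]  # plane, car, ship, truck
lookup = make_label_lookup(group)
labels = np.array([8, 0, 9, 1, 8])
print(lookup[labels])  # [2 0 3 1 2]
```

Classes outside the group map to -1, which makes an accidental out-of-group label easy to spot.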

### 4.3 Specialist Model Evaluation

In [None]:
# Evaluate specialist model
def evaluate_specialist(specialist_model, testloader, class_mapping, return_metrics=False):
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    running_loss = 0.0
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    specialist_model.to(device)
    specialist_model.eval()
    
    # Original CIFAR-10 class ids handled by this specialist
    class_subset = list(class_mapping.keys())
    
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Filter only samples belonging to specialist classes
            mask = torch.tensor([l.item() in class_subset for l in labels], device=device)
            if not mask.any():
                continue
            
            specialist_images = images[mask]
            original_labels = labels[mask]
            
            # Remap labels for specialist 
            specialist_labels = torch.tensor([class_mapping[l.item()] for l in original_labels], 
                                          device=device)
            
            outputs = specialist_model(specialist_images)
            loss = criterion(outputs, specialist_labels)
            # Weight by batch size: filtered batches vary in size
            running_loss += loss.item() * specialist_labels.size(0)
            
            _, predicted = torch.max(outputs.data, 1)
            total += specialist_labels.size(0)
            correct += (predicted == specialist_labels).sum().item()
    
    if total == 0:  # Avoid division by zero
        accuracy = 0
        test_loss = 0
    else:
        accuracy = 100 * correct / total
        # Per-sample average loss over the specialist's test images
        test_loss = running_loss / total
    
    if not return_metrics:
        print(f'Specialist accuracy on relevant test set: {accuracy:.2f}%')
        return accuracy
    else:
        return test_loss, accuracy
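
After the class filter, each test batch contributes a different number of specialist samples, so the way per-batch losses are averaged matters. The sketch below uses made-up per-sample losses and two hypothetical helpers to contrast averaging per-batch means with a true per-sample average; the latter is what multiplying each batch's mean loss by its size and dividing by the total sample count computes.

```python
def mean_of_batch_means(batches):
    """Average of per-batch mean losses (what dividing by a batch count gives)."""
    return sum(sum(b) / len(b) for b in batches) / len(batches)


def per_sample_mean(batches):
    """Average over all samples: weight each batch mean by its batch size."""
    total_loss = sum(sum(b) / len(b) * len(b) for b in batches)
    total_samples = sum(len(b) for b in batches)
    return total_loss / total_samples


# Made-up per-sample losses: after filtering, batch sizes can differ a lot.
batches = [[1.0] * 50, [1.0] * 52, [4.0] * 3]
print(mean_of_batch_means(batches))  # 2.0
print(per_sample_mean(batches))      # about 1.086
```

The three high-loss samples in the small last batch dominate the first average but barely move the second.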

## 5. Main Training Process

Finally, let's put everything together in the main function that runs the entire knowledge distillation process.

In [None]:
# Main function to run the distillation process
def main():
    # Create output directories
    os.makedirs('models', exist_ok=True)
    os.makedirs('plots', exist_ok=True)
    
    # Load data
    trainloader, testloader, classes = load_cifar10()
    
    # Define class groups for specialists
    class_groups = [
        [2, 3, 4, 5],  # Bird, Cat, Deer, Dog (animals)
        [0, 1, 8, 9]   # Plane, Car, Ship, Truck (vehicles)
    ]
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Step 1: Create and train generalist model
    print("Creating generalist model...")
    generalist_model = create_generalist_model(num_classes=10)
    print(f"Generalist model parameters: {generalist_model.count_parameters():,}")
    
    print("Training generalist model...")
    generalist_model, gen_metrics = train_model(
        generalist_model, trainloader, testloader, epochs=10, lr=0.1, model_name="Generalist"
    )
    gen_accuracy = gen_metrics['best_acc']
    
    # Step 2: Train specialist models with knowledge distillation
    specialist_models = []
    class_mappings = []
    specialist_metrics = []
    
    for i, group in enumerate(class_groups):
        print(f"Training specialist model {i+1} for classes {group}...")
        
        # Create class mapping for this specialist
        class_mapping = {original_class: idx for idx, original_class in enumerate(group)}
        class_mappings.append(class_mapping)
        
        # Create dataset subset from the stored labels (.targets), which avoids
        # loading and augmenting every image just to read its class
        indices = [j for j, target in enumerate(trainloader.dataset.targets)
                   if target in group]
        
        subset = torch.utils.data.Subset(trainloader.dataset, indices)
        specialist_loader = torch.utils.data.DataLoader(
            subset, batch_size=128, shuffle=True, num_workers=4
        )
        
        # Create smaller specialist model
        specialist_model = create_specialist_model(num_classes=len(group))
        print(f"Specialist model {i+1} parameters: {specialist_model.count_parameters():,}")
        
        # Train specialist model
        specialist_model, spec_metrics = train_specialist_with_distillation(
            specialist_model, generalist_model, specialist_loader, testloader, 
            class_mapping, epochs=10, lr=0.1, model_name=f"Specialist_{i+1}"
        )
        specialist_models.append(specialist_model)
        specialist_metrics.append(spec_metrics)
    
    # Save the trained models
    print("\nSaving all models...")
    
    # Save generalist model
    torch.save({
        'model_state_dict': generalist_model.state_dict(),
        'accuracy': gen_accuracy
    }, 'models/generalist_model_complete.pth')
    
    # Save specialist models
    for i, model in enumerate(specialist_models):
        torch.save({
            'model_state_dict': model.state_dict(),
            'accuracy': specialist_metrics[i]['best_acc'],
            'class_mapping': class_mappings[i],
            'class_group': class_groups[i]
        }, f'models/specialist_model_{i+1}_complete.pth')
    
    print("Training complete!")
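
Each specialist's training loader holds only the images of its four classes. Because CIFAR-10 is balanced (50,000 training and 10,000 test images over 10 classes), the per-specialist data budget can be worked out directly. The sketch below does that arithmetic for the two groups in main() with a hypothetical specialist_data_budget helper; it assumes the batch size of 128 and the DataLoader default of keeping the last partial batch.

```python
import math

# CIFAR-10 split sizes, balanced over 10 classes.
TRAIN_PER_CLASS = 50_000 // 10
TEST_PER_CLASS = 10_000 // 10


def specialist_data_budget(group, batch_size=128):
    """Training images, test images, and mini-batches per epoch for a class group."""
    train_images = TRAIN_PER_CLASS * len(group)
    test_images = TEST_PER_CLASS * len(group)
    # DataLoader keeps the final partial batch by default (drop_last=False)
    batches_per_epoch = math.ceil(train_images / batch_size)
    return train_images, test_images, batches_per_epoch


for group in ([2, 3, 4, 5], [0, 1, 8, 9]):
    print(group, specialist_data_budget(group))  # (20000, 4000, 157) for both groups
```

For comparison, the generalist sees all 50,000 training images in 391 batches per epoch. Note also that classes 6 (frog) and 7 (horse) are not assigned to any specialist.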

## 6. Run the Training Pipeline

Now let's execute our main function to train the models.

In [None]:
if __name__ == "__main__":
    main()

## 7. Summary

In this notebook, we've implemented knowledge distillation with custom ResNet models on the CIFAR-10 dataset. Here's a recap of what we've done:

1. **Architecture**: We created a custom ResNet architecture with configurable depth and width.
   
2. **Models**:
   - **Generalist Model**: A larger model trained on all 10 CIFAR-10 classes
   - **Specialist Models**: Smaller models trained on specific subsets of classes
   
3. **Knowledge Distillation**:
   - Used the generalist model as a teacher to train specialist models
   - Implemented a modified distillation loss that handles different output dimensions
   - The specialist models learn from both ground truth labels and the generalist model's soft targets
   
4. **Evaluation**:
   - The specialist models are evaluated only on their specific subset of classes
   - Training progress is visualized with plots of loss and accuracy

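Knowledge distillation is a useful technique for model compression and specialization. Using a larger teacher model to guide the training of smaller student models can produce more efficient models for specific class subsets; whether the specialists here retain the generalist's accuracy is best checked by comparing each specialist's test accuracy with the generalist's accuracy on the same classes.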