In [None]:
!pip install torch torchvision matplotlib numpy

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
# Knowledge Distillation Implementation
# A simple and effective implementation for distilling knowledge from a teacher model to a student model

import copy
import io
import itertools
import os
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Fix the torch and numpy RNGs so weight init, shuffling and dropout repeat across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Define teacher and student models
class TeacherModel(nn.Module):
    """Larger CNN that acts as the teacher, sized for 1x28x28 (MNIST) inputs.

    Three conv(3x3, padding 1) + ReLU + 2x2 max-pool stages
    (1 -> 32 -> 64 -> 128 channels, 28 -> 14 -> 7 -> 3 spatial), then
    FC 1152 -> 512 -> 10 with dropout before the final layer.
    """

    def __init__(self):
        super(TeacherModel, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Pooling and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

        # Fully connected layers; 128 * 3 * 3 assumes 28x28 inputs (28 -> 14 -> 7 -> 3 after three pools)
        self.fc1 = nn.Linear(128 * 3 * 3, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        # Reuse the feature path so forward() and get_features() cannot drift apart.
        hidden = self.get_features(x)[-1]
        return self.fc2(self.dropout(hidden))

    def get_features(self, x):
        """Return intermediate activations for feature distillation.

        Returns a list of four tensors:
        [relu(conv1) (B,32,28,28), relu(conv2) (B,64,14,14), relu(conv3) (B,128,7,7), relu(fc1) (B,512)].
        The conv entries are pre-pooling activations, and shapes are for 28x28 inputs.
        """
        features = []
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
            features.append(x)  # pre-pool activation
            x = self.pool(x)

        # flatten(x, 1) keeps the batch dimension intact. view(-1, 1152) would silently
        # turn a wrong-sized input into extra "samples" instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        features.append(x)

        return features

class StudentModel(nn.Module):
    """Smaller CNN trained from the teacher via knowledge distillation (1x28x28 inputs).

    Two conv(3x3, padding 1) + ReLU + 2x2 max-pool stages
    (1 -> 16 -> 32 channels, 28 -> 14 -> 7 spatial), then FC 1568 -> 128 -> 10.
    No dropout.
    """

    def __init__(self):
        super(StudentModel, self).__init__()
        # Convolutional layers (fewer filters than the teacher)
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Pooling
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers; 32 * 7 * 7 assumes 28x28 inputs (28 -> 14 -> 7 after two pools)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        # Reuse the feature path so forward() and get_features() cannot drift apart.
        hidden = self.get_features(x)[-1]
        return self.fc2(hidden)

    def get_features(self, x):
        """Return intermediate activations for feature distillation.

        Returns a list of three tensors:
        [relu(conv1) (B,16,28,28), relu(conv2) (B,32,14,14), relu(fc1) (B,128)].
        The conv entries are pre-pooling activations, and shapes are for 28x28 inputs.
        """
        features = []
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
            features.append(x)  # pre-pool activation
            x = self.pool(x)

        # flatten(x, 1) keeps the batch dimension intact. view(-1, 1568) would silently
        # turn a wrong-sized input into extra "samples" instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        features.append(x)

        return features

# Knowledge distillation loss
class DistillationLoss(nn.Module):
    """Knowledge-distillation objective in the style of Hinton et al. (2015).

        loss = (1 - alpha) * CE(student_logits, labels)
               + alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        alpha: weight of the soft (teacher) term; the hard-label term gets 1 - alpha.
        temperature: T, used to soften both distributions. The T^2 factor follows
            Hinton et al., who use it to keep soft-term gradients on a comparable
            scale as T changes.
    """

    def __init__(self, alpha=0.5, temperature=2.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits, labels):
        """Return the scalar mixed loss for one batch of (B, C) logits and (B,) labels."""
        T = self.temperature

        # KL(teacher || student) on temperature-softened distributions, averaged per sample.
        student_log_probs = F.log_softmax(student_logits / T, dim=1)
        teacher_probs = F.softmax(teacher_logits / T, dim=1)
        kl_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        soft_loss = kl_term * (T ** 2)

        # Ordinary supervised loss against the ground-truth labels.
        hard_loss = F.cross_entropy(student_logits, labels)

        return (1 - self.alpha) * hard_loss + self.alpha * soft_loss

# Feature distillation loss - optional enhancement
class FeatureDistillationLoss(nn.Module):
    """Feature-matching loss on the last entry of each model's feature list.

    If the student and teacher tensors have the same shape, this is plain MSE
    scaled by `beta`.

    If the shapes differ, a direct MSE cannot be computed. This is the case for
    the models in this notebook: StudentModel.get_features()[-1] is (B, 128) and
    TeacherModel.get_features()[-1] is (B, 512). In that case the loss instead
    compares the two batch-wise cosine-similarity matrices (both B x B), in the
    spirit of similarity-preserving KD (Tung & Mori, 2019). This needs no
    learnable projection layer, so the training loop's optimizer is unaffected.

    Args:
        beta: multiplier applied to the returned loss.
    """

    def __init__(self, beta=0.1):
        super(FeatureDistillationLoss, self).__init__()
        self.beta = beta  # Weight for feature distillation

    def forward(self, student_features, teacher_features):
        """Return beta * feature loss for the last feature of each list.

        Raises:
            ValueError: if the two last features have different batch sizes.
        """
        student_last = student_features[-1]
        teacher_last = teacher_features[-1]

        if student_last.shape == teacher_last.shape:
            feat_loss = F.mse_loss(student_last, teacher_last)
        else:
            if student_last.size(0) != teacher_last.size(0):
                raise ValueError(
                    f"batch size mismatch: student {student_last.size(0)} vs teacher {teacher_last.size(0)}")
            feat_loss = F.mse_loss(self._similarity(student_last), self._similarity(teacher_last))

        return self.beta * feat_loss

    @staticmethod
    def _similarity(features):
        """Return the (B, B) cosine-similarity matrix between the flattened samples of a batch."""
        normed = F.normalize(features.flatten(1), dim=1)
        return normed @ normed.t()

# Load MNIST dataset
def load_data(batch_size=64, root='./data', num_workers=2):
    """Build MNIST train/test DataLoaders, downloading the dataset if needed.

    Args:
        batch_size: samples per batch for both loaders.
        root: directory where torchvision stores / downloads MNIST.
        num_workers: DataLoader worker processes. Set to 0 if worker start-up
            is a problem in your notebook environment.

    Returns:
        (trainloader, testloader): the training loader shuffles; the test loader does not.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Mean / std values commonly used to normalize MNIST
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    trainset = torchvision.datasets.MNIST(root=root, train=True,
                                          download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size,
                             shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.MNIST(root=root, train=False,
                                         download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    return trainloader, testloader

# Train teacher model (standard training)
def train_teacher(model, trainloader, epochs=3, lr=0.001, model_name="teacher"):
    """Train `model` with plain cross-entropy and Adam.

    Despite its name this is a generic supervised trainer; it is also used for
    the non-distilled baseline student.

    Args:
        model: module to train; moved to the notebook-level `device`.
        trainloader: iterable of (inputs, labels) batches.
        epochs: number of passes over `trainloader`.
        lr: Adam learning rate.
        model_name: text used in the start/finish log lines (default keeps the original wording).

    Returns:
        The same `model` instance, trained and left in train mode.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"Training {model_name} model...")
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the mean loss over each window of 100 batches.
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
                running_loss = 0.0

    print(f'Finished training {model_name} model')
    return model

# Train student model with knowledge distillation
def train_student_with_distillation(student_model, teacher_model, trainloader,
                                   epochs=3, alpha=0.5, temperature=2.0, beta=0.0, lr=0.001):
    """Train `student_model` to mimic a frozen `teacher_model` (plus ground-truth labels).

    Per batch the loss is DistillationLoss(alpha, temperature) on the logits. When
    beta > 0, FeatureDistillationLoss(beta) on the last entries of get_features()
    is added.

    Args:
        student_model: module to train (must provide get_features() when beta > 0).
        teacher_model: frozen reference model; put in eval mode and never updated.
        trainloader: iterable of (inputs, labels) batches.
        epochs: number of passes over `trainloader`.
        alpha: weight of the soft (teacher) term in DistillationLoss.
        temperature: softening temperature for DistillationLoss.
        beta: weight of the optional feature-matching term; 0 disables it.
        lr: Adam learning rate for the student.

    Returns:
        The same `student_model` instance, trained and left in train mode.
    """
    student_model.to(device)
    teacher_model.to(device)
    teacher_model.eval()  # Teacher is fixed: eval() disables its dropout

    distill_criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    feature_criterion = FeatureDistillationLoss(beta=beta) if beta > 0 else None
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    print("Training student model with distillation...")
    student_model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            # Teacher targets need no gradient graph.
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)
                teacher_features = (teacher_model.get_features(inputs)
                                    if feature_criterion is not None else None)

            student_outputs = student_model(inputs)
            loss = distill_criterion(student_outputs, teacher_outputs, labels)

            if feature_criterion is not None:
                # get_features() does not return logits, so this is a second student pass.
                student_features = student_model.get_features(inputs)
                # Out-of-place add: do not mutate a tensor that is part of the autograd graph.
                loss = loss + feature_criterion(student_features, teacher_features)

            loss.backward()
            optimizer.step()

            # Report the mean loss over each window of 100 batches.
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
                running_loss = 0.0

    print('Finished training student model')
    return student_model

# Fine-tune student model on task-specific data
def fine_tune_student(model, trainloader, epochs=2, lr=0.0005):
    """Continue training `model` on hard labels with cross-entropy and Adam.

    Args:
        model: module to fine-tune; moved to the notebook-level `device`.
        trainloader: iterable of (inputs, labels) batches.
        epochs: number of passes over `trainloader`.
        lr: Adam learning rate (default is half of train_teacher's default).

    Returns:
        The same `model` instance, fine-tuned and left in train mode.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print("Fine-tuning student model...")
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the mean loss over each window of 100 batches.
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Fine-tuning Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
                running_loss = 0.0

    print('Finished fine-tuning student model')
    return model

# Evaluate model accuracy
def evaluate_model(model, testloader):
    """Return top-1 accuracy of `model` on `testloader`, as a percentage (0-100).

    Moves the model to the notebook-level `device` and leaves it in eval mode.

    Raises:
        ValueError: if `testloader` yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_model: testloader yielded no samples")
    return 100 * correct / total

# Count parameters
def count_parameters(model):
    """Return how many scalar parameters of `model` are trainable (requires_grad=True)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Measure inference time
def measure_inference_time(model, testloader, num_batches=10, warmup_batches=6):
    """Return the mean wall-clock seconds per batch for a forward pass of `model`.

    Up to `num_batches` batches are pulled from `testloader` and moved to `device`
    before the clock starts. The measurement therefore covers only the model's
    forward passes, not DataLoader worker start-up, data decoding or host->device
    copies. On CUDA the device is synchronized before and after the timed region,
    because kernel launches return before the work finishes.

    Args:
        model: module to time; moved to `device` and put in eval mode.
        testloader: iterable of (images, labels) batches; labels are ignored.
        num_batches: number of batches to time. These are held on `device` while timing.
        warmup_batches: untimed forward passes run first, cycling over the cached batches.

    Returns:
        float: mean seconds per timed batch.

    Raises:
        ValueError: if no batches are available to time.
    """
    model.to(device)
    model.eval()

    batches = [images.to(device) for images, _ in itertools.islice(testloader, max(num_batches, 0))]
    if not batches:
        raise ValueError("measure_inference_time: no batches available to time")

    on_cuda = torch.device(device).type == "cuda"

    with torch.no_grad():
        # Warm-up (lazy init, cuDNN autotuning, caches) is excluded from the measurement.
        for images in itertools.islice(itertools.cycle(batches), max(warmup_batches, 0)):
            model(images)
        if on_cuda:
            torch.cuda.synchronize()

        start = time.perf_counter()
        for images in batches:
            model(images)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

    return elapsed / len(batches)

# Get model size
def get_model_size(model):
    """Return the size of `model.state_dict()` as serialized by `torch.save`, in MiB.

    Serializes to an in-memory buffer rather than a temporary file. This way nothing
    is written to (or left behind in) the working directory, an existing file with
    the same name cannot be overwritten, and an error cannot leak a stray file.
    """
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / (1024 * 1024)

# Save model predictions for further analysis
def save_predictions(model, testloader, filename):
    """Run `model` over `testloader` and save arg-max predictions plus true labels.

    Writes an .npz archive (np.savez) with two int arrays, `predictions` and
    `labels`, both in loader order. Moves the model to `device` and leaves it in eval mode.
    """
    model.to(device)
    model.eval()

    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1]
            pred_chunks.append(model(images).argmax(dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    # Concatenate once per array instead of extending a Python list element by element.
    # An empty loader still produces integer (not float) arrays.
    predictions = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.int64)
    true_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.int64)
    np.savez(filename, predictions=predictions, labels=true_labels)

# Visualize predictions
def plot_confusion_matrix(model_name, predictions_file):
    """Save a confusion-matrix heatmap for an .npz file written by save_predictions().

    Reads the `predictions` and `labels` arrays and writes
    '{model_name}_confusion_matrix.png' to the working directory.
    Requires scikit-learn and seaborn; both are imported lazily here.
    """
    from sklearn.metrics import confusion_matrix
    import seaborn as sns

    # The context manager closes the file handle that NpzFile keeps open.
    with np.load(predictions_file) as data:
        preds = data['predictions']
        labels = data['labels']

    # Label rows/columns with the actual class values. Otherwise heatmap numbers
    # them 0..k-1 by position, which is wrong if the observed classes are not exactly 0..k-1.
    class_labels = np.union1d(labels, preds)
    cm = confusion_matrix(labels, preds, labels=class_labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=class_labels, yticklabels=class_labels)
    ax.set_title(f'Confusion Matrix - {model_name}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.savefig(f'{model_name}_confusion_matrix.png')
    plt.close(fig)

# Main function
def _load_or_train(model, checkpoint_path, train_fn, label):
    """Restore `model` from `checkpoint_path` if it exists; otherwise train it with `train_fn` and save it.

    `label` only appears in the log line "Loading pre-trained {label}...".
    map_location=device lets a checkpoint saved on a GPU machine load on a CPU-only
    one. weights_only=True restricts unpickling to tensors and plain containers.
    """
    if os.path.exists(checkpoint_path):
        print(f"Loading pre-trained {label}...")
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        return model
    model = train_fn(model)
    torch.save(model.state_dict(), checkpoint_path)
    return model


def _profile(model, testloader):
    """Return (accuracy %, seconds per batch, serialized size in MB) for `model`."""
    accuracy = evaluate_model(model, testloader)
    seconds = measure_inference_time(model, testloader)
    size_mb = get_model_size(model)
    return accuracy, seconds, size_mb


def _print_metrics(title, accuracy, params, seconds, size_mb):
    """Print the per-model metrics section for one model."""
    print(f"\n--- {title} Model Metrics ---")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Parameters: {params:,}")
    print(f"Inference Time: {seconds*1000:.2f} ms per batch")
    print(f"Model Size: {size_mb:.2f} MB")


def main():
    """Run the full experiment: teacher, baseline student, distilled student, fine-tuned student.

    Each trained model is cached as a state_dict (.pt) in the working directory and
    reloaded on later runs; delete those files to retrain. The function prints metrics
    and a summary table, and writes prediction .npz files, a comparison figure and
    one confusion-matrix image per model.
    """
    trainloader, testloader = load_data()

    # Teacher
    teacher_model = TeacherModel()
    teacher_params = count_parameters(teacher_model)
    print(f"Teacher model has {teacher_params:,} parameters")
    teacher_model = _load_or_train(
        teacher_model, 'teacher_model.pt',
        lambda m: train_teacher(m, trainloader, epochs=3),
        'teacher model')
    teacher_accuracy, teacher_time, teacher_size = _profile(teacher_model, testloader)
    _print_metrics("Teacher", teacher_accuracy, teacher_params, teacher_time, teacher_size)

    # Shared student initialization: baseline and distilled students both start from
    # copies of this one model, so the comparison isolates the effect of distillation.
    student_model = StudentModel()
    student_params = count_parameters(student_model)
    print(f"\nStudent model has {student_params:,} parameters")
    print(f"Parameter reduction: {(1 - student_params/teacher_params)*100:.1f}%")

    # Baseline student: plain cross-entropy, no teacher.
    # train_teacher() is a generic supervised trainer; its log lines still say "teacher".
    standard_student = _load_or_train(
        copy.deepcopy(student_model), 'standard_student.pt',
        lambda m: train_teacher(m, trainloader, epochs=3),
        'standard student model')
    standard_accuracy, standard_time, standard_size = _profile(standard_student, testloader)
    _print_metrics("Standard Student", standard_accuracy, student_params, standard_time, standard_size)

    # Distilled student
    distilled_student = _load_or_train(
        copy.deepcopy(student_model), 'distilled_student.pt',
        lambda m: train_student_with_distillation(
            m, teacher_model, trainloader, epochs=3, alpha=0.5, temperature=4.0),
        'distilled student model')
    distilled_accuracy, distilled_time, distilled_size = _profile(distilled_student, testloader)
    _print_metrics("Distilled Student", distilled_accuracy, student_params, distilled_time, distilled_size)

    # Fine-tuned student (starts from the distilled weights)
    fine_tuned_student = _load_or_train(
        copy.deepcopy(distilled_student), 'fine_tuned_student.pt',
        lambda m: fine_tune_student(m, trainloader, epochs=2),
        'fine-tuned student model')
    fine_tuned_accuracy, fine_tuned_time, fine_tuned_size = _profile(fine_tuned_student, testloader)

    print("\n--- Fine-tuned Student Model Metrics ---")
    print(f"Accuracy: {fine_tuned_accuracy:.2f}%")
    print(f"Accuracy Improvement from Distillation: {distilled_accuracy - standard_accuracy:.2f}%")
    print(f"Accuracy Improvement from Fine-tuning: {fine_tuned_accuracy - distilled_accuracy:.2f}%")
    print(f"Inference Time: {fine_tuned_time*1000:.2f} ms per batch")

    # One record per model drives predictions, the summary table and all plots.
    results = [
        {"table": "Teacher", "plot": "Teacher", "model": teacher_model,
         "accuracy": teacher_accuracy, "size": teacher_size, "seconds": teacher_time,
         "params": teacher_params, "preds": 'teacher_preds.npz', "cm": 'Teacher'},
        {"table": "Student (Standard)", "plot": "Student\nStandard", "model": standard_student,
         "accuracy": standard_accuracy, "size": standard_size, "seconds": standard_time,
         "params": student_params, "preds": 'standard_student_preds.npz', "cm": 'Standard_Student'},
        {"table": "Student (Distilled)", "plot": "Student\nDistilled", "model": distilled_student,
         "accuracy": distilled_accuracy, "size": distilled_size, "seconds": distilled_time,
         "params": student_params, "preds": 'distilled_student_preds.npz', "cm": 'Distilled_Student'},
        {"table": "Student (Fine-tuned)", "plot": "Student\nFine-tuned", "model": fine_tuned_student,
         "accuracy": fine_tuned_accuracy, "size": fine_tuned_size, "seconds": fine_tuned_time,
         "params": student_params, "preds": 'fine_tuned_student_preds.npz', "cm": 'Fine_Tuned_Student'},
    ]

    # Save predictions for analysis
    for r in results:
        save_predictions(r["model"], testloader, r["preds"])

    # Comparison summary
    print("\n" + "="*50)
    print("KNOWLEDGE DISTILLATION SUMMARY")
    print("="*50)
    print(f"{'Model':<25} {'Accuracy':<10} {'Size (MB)':<12} {'Inference (ms)':<15} {'Parameters':<12}")
    print("-" * 75)
    for r in results:
        print(f"{r['table']:<25} {r['accuracy']:<10.2f} {r['size']:<12.2f} "
              f"{r['seconds']*1000:<15.2f} {r['params']:,}")

    # 2x2 bar-chart comparison
    labels = [r["plot"] for r in results]
    colors = ['blue', 'orange', 'green', 'red']
    panels = [
        ('Model Accuracy (%)', 'Accuracy', [r["accuracy"] for r in results]),
        ('Model Parameters', 'Parameters', [r["params"] for r in results]),
        ('Inference Time (ms)', 'Time (ms)', [r["seconds"] * 1000 for r in results]),
        ('Model Size (MB)', 'Size (MB)', [r["size"] for r in results]),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for ax, (title, ylabel, values) in zip(axes.flat, panels):
        ax.bar(labels, values, color=colors)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.savefig('knowledge_distillation_comparison.png')
    plt.close(fig)

    # Confusion matrices
    for r in results:
        plot_confusion_matrix(r["cm"], r["preds"])

    print("\nVisualization saved as 'knowledge_distillation_comparison.png'")
    print("Confusion matrices saved for each model.")

# In a notebook kernel __name__ is also "__main__", so the pipeline runs when this cell executes.
if __name__ == "__main__":
    main()

Using device: cpu


100%|██████████| 9.91M/9.91M [00:00<00:00, 38.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.46MB/s]

Teacher model has 688,138 parameters
Training teacher model...





Epoch 1, Batch 100, Loss: 0.606
Epoch 1, Batch 200, Loss: 0.158
Epoch 1, Batch 300, Loss: 0.110
Epoch 1, Batch 400, Loss: 0.086
Epoch 1, Batch 500, Loss: 0.072
Epoch 1, Batch 600, Loss: 0.074
Epoch 1, Batch 700, Loss: 0.070
Epoch 1, Batch 800, Loss: 0.067
Epoch 1, Batch 900, Loss: 0.054
Epoch 2, Batch 100, Loss: 0.043
Epoch 2, Batch 200, Loss: 0.030
Epoch 2, Batch 300, Loss: 0.049
Epoch 2, Batch 400, Loss: 0.042
Epoch 2, Batch 500, Loss: 0.041
Epoch 2, Batch 600, Loss: 0.041
Epoch 2, Batch 700, Loss: 0.049
Epoch 2, Batch 800, Loss: 0.037
Epoch 2, Batch 900, Loss: 0.041
Epoch 3, Batch 100, Loss: 0.030
Epoch 3, Batch 200, Loss: 0.030
Epoch 3, Batch 300, Loss: 0.028
Epoch 3, Batch 400, Loss: 0.030
Epoch 3, Batch 500, Loss: 0.035
Epoch 3, Batch 600, Loss: 0.030
Epoch 3, Batch 700, Loss: 0.030
Epoch 3, Batch 800, Loss: 0.029
Epoch 3, Batch 900, Loss: 0.027
Finished training teacher model

--- Teacher Model Metrics ---
Accuracy: 98.80%
Parameters: 688,138
Inference Time: 75.43 ms per batch
M