# Hinton Version - "Vanilla" Model

## Architecture - always execute this

In [None]:
################## 1 ########################

%%writefile models.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class TeacherNet(nn.Module):
    """Teacher MLP over flattened 784-pixel inputs: 784 -> 1200 -> 1200 -> 10.

    Each hidden layer is followed by ReLU and then dropout with probability
    ``dropout_rate``.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, temperature=1.0):
        """Return ``(probabilities, logits)`` for the batch ``x``.

        Args:
            x: Input that can be viewed as ``(-1, 784)``.
            temperature: Divisor applied to the logits before the softmax.

        Returns:
            Tuple of the softmax over ``logits / temperature`` (class axis 1)
            and the unscaled logits.
        """
        flat = x.view(-1, 784)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        logits = self.fc3(hidden)
        probs = torch.softmax(logits / temperature, dim=1)
        return probs, logits


class StudentNet(nn.Module):
    """Student MLP over flattened 784-pixel inputs: 784 -> 800 -> 800 -> 10.

    Plain ReLU hidden layers with no dropout.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 800)
        self.fc2 = nn.Linear(800, 800)
        self.fc3 = nn.Linear(800, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(probabilities, logits)`` for the batch ``x``.

        Args:
            x: Input that can be viewed as ``(-1, 784)``.
            temperature: Divisor applied to the logits before the softmax.

        Returns:
            Tuple of the softmax over ``logits / temperature`` (class axis 1)
            and the unscaled logits.
        """
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        logits = self.fc3(hidden)
        probs = torch.softmax(logits / temperature, dim=1)
        return probs, logits


def distillation_loss(student_logits, teacher_soft_targets, hard_targets,
                      temperature, alpha=0.5):
    """
    Blend a hard-label loss with a temperature-softened teacher-matching loss.

    Args:
        student_logits: Unscaled student outputs, classes along dim 1.
        teacher_soft_targets: Teacher probabilities (computed by the caller at
            the same temperature).
        hard_targets: Ground-truth class indices.
        temperature: Softening temperature applied to the student logits.
        alpha: Weight of the hard-label term; the soft term gets 1 - alpha.

    Returns:
        alpha * cross_entropy + (1 - alpha) * T^2 * KL(teacher || student).
    """
    # Hard-label term: ordinary cross-entropy on the unscaled logits.
    ce_term = F.cross_entropy(student_logits, hard_targets)

    # Soft term: kl_div expects log-probabilities for its first argument and
    # probabilities for its second.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    kd_term = F.kl_div(student_log_probs, teacher_soft_targets,
                       reduction='batchmean')

    # Softened gradients shrink as 1/T^2, so rescale to keep both terms
    # on a comparable footing.
    kd_term = kd_term * (temperature ** 2)

    return alpha * ce_term + (1 - alpha) * kd_term

Writing models.py


In [None]:
################## 2 ########################

import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import TeacherNet, StudentNet, distillation_loss


def train_teacher(model, train_loader, epochs=10, lr=0.001):
    """Fit the teacher with Adam (weight decay 1e-4) and cross-entropy.

    Args:
        model: Module whose forward returns a ``(probs, logits)`` pair.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The trained model, moved to CUDA when available and left in train mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    model.train()
    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            _, logits = model(inputs)
            batch_loss = torch.nn.functional.cross_entropy(logits, labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Teacher Epoch {epoch_num}/{epochs}, Loss: {mean_loss:.4f}")

    return model


def train_student_normal(model, train_loader, epochs=10, lr=0.001):
    """Baseline training: Adam (no weight decay) on plain cross-entropy.

    Args:
        model: Module whose forward returns a ``(probs, logits)`` pair.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The trained model, moved to CUDA when available and left in train mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            _, logits = model(inputs)
            batch_loss = torch.nn.functional.cross_entropy(logits, labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Student (normal) Epoch {epoch_num}/{epochs}, Loss: {mean_loss:.4f}")

    return model


def train_student_distilled(student, teacher, train_loader, temperature=20,
                           alpha=0.1, epochs=10, lr=0.001):
    """Fit the student against the teacher's softened outputs plus hard labels.

    Args:
        student: Module being trained; forward returns ``(probs, logits)``.
        teacher: Module providing soft targets; put in eval mode and only
            queried under ``torch.no_grad()``.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        temperature: Temperature passed to the teacher and to the loss.
        alpha: Hard-label weight passed to ``distillation_loss``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate for the student.

    Returns:
        The trained student (left in train mode).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)

    # Only the student's parameters are handed to the optimizer.
    optimizer = optim.Adam(student.parameters(), lr=lr)

    teacher.eval()
    student.train()
    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # Soft targets at temperature T; no graph is kept for the teacher.
            with torch.no_grad():
                soft_targets, _ = teacher(inputs, temperature=temperature)

            _, logits = student(inputs)
            batch_loss = distillation_loss(logits, soft_targets, labels,
                                           temperature, alpha)

            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Student (distilled T={temperature}) Epoch {epoch_num}/{epochs}, Loss: {mean_loss:.4f}")

    return student


def evaluate(model, test_loader):
    """Score ``model`` on ``test_loader``.

    The model is moved to CUDA when available and switched to eval mode; its
    forward must return a pair whose first element holds per-class scores.

    Returns:
        ``(accuracy, errors)``: accuracy as a percentage and the number of
        misclassified samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            probs, _ = model(inputs, temperature=1.0)
            guesses = torch.max(probs, 1)[1]  # index of the top class
            num_seen += labels.size(0)
            num_correct += (guesses == labels).sum().item()

    accuracy = 100 * num_correct / num_seen
    return accuracy, num_seen - num_correct


## Section 3 of paper replicated - NO NEED TO EXECUTE

### Version 1: T = 20, Dropout = 0.2, Epochs = 20, lr = 0.001, batch size = 256

In [None]:
def main():
    """Run the MNIST distillation experiment end to end.

    Steps, in order: train the teacher, train a baseline student, train a
    distilled student, print a results summary next to the paper's numbers,
    time single-sample inference for each model, save the three checkpoints
    under ``./saved_models`` and, when ``google.colab`` is importable, zip
    them and trigger a browser download.
    """
    import os
    import time

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"Using device: {device}")
    if use_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Experiment hyper-parameters.
    dropout = 0.2
    epochs = 20
    temp = 20
    alpha = 0.1
    lr = 0.001

    # Data loading with jittering (up to 2 pixels in any direction)
    transform_train = transforms.Compose([
        transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,
                              num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                             num_workers=2, pin_memory=True)

    print("=" * 60)
    print("MNIST Knowledge Distillation Experiment (Section 3)")
    print("=" * 60)

    # 1. Train large teacher network (1200-1200 units with dropout)
    print("\n[1/4] Training Teacher Network (2x1200 ReLU units + dropout)...")
    teacher = TeacherNet(dropout_rate=dropout)
    teacher = train_teacher(teacher, train_loader, epochs=epochs, lr=lr)
    teacher_acc, teacher_err = evaluate(teacher, test_loader)
    print(f"✓ Teacher: {teacher_acc:.2f}% accuracy ({teacher_err} errors)")

    # 2. Train small student network normally (800-800 units, no regularization)
    print("\n[2/4] Training Student Network Normally (2x800 ReLU units, no regularization)...")
    student_normal = StudentNet()
    student_normal = train_student_normal(student_normal, train_loader, epochs=epochs, lr=lr)
    student_normal_acc, student_normal_err = evaluate(student_normal, test_loader)
    print(f"✓ Student (normal): {student_normal_acc:.2f}% accuracy ({student_normal_err} errors)")

    # 3. Train small student network with distillation.
    # The labels are built from temp/alpha so they cannot drift from the
    # values actually passed to the trainer.
    print(f"\n[3/4] Training Student Network with Distillation (T={temp}, alpha={alpha})...")
    student_distilled = StudentNet()
    student_distilled = train_student_distilled(student_distilled, teacher, train_loader,
                                                temperature=temp, alpha=alpha, epochs=epochs, lr=lr)
    student_distilled_acc, student_distilled_err = evaluate(student_distilled, test_loader)
    print(f"✓ Student (distilled): {student_distilled_acc:.2f}% accuracy ({student_distilled_err} errors)")

    # 4. Results summary
    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)
    print(f"Teacher (2x1200 + dropout):        {teacher_err:3d} test errors ({teacher_acc:.2f}%)")
    print(f"Student normal (2x800):            {student_normal_err:3d} test errors ({student_normal_acc:.2f}%)")
    print(f"Student distilled (2x800, T={temp}):   {student_distilled_err:3d} test errors ({student_distilled_acc:.2f}%)")
    print("\nPaper reported:")
    print("Teacher:           67 test errors")
    print("Student normal:   146 test errors")
    print("Student distilled: 74 test errors")
    print("=" * 60)

    # 5. Quick inference speed test
    print("\n" + "=" * 60)
    print("INFERENCE SPEED TEST")
    print("=" * 60)

    sample = next(iter(test_loader))[0][:1].to(device)  # Single sample

    def count_params(model):
        """Total number of elements across the model's parameters."""
        return sum(p.numel() for p in model.parameters())

    def benchmark(model, name, runs=1000, warmup=10):
        """Print and return the mean forward time in ms for one sample."""
        model.eval()
        params = count_params(model)
        with torch.no_grad():
            for _ in range(warmup):
                model(sample)
            # CUDA kernels launch asynchronously; synchronize so the timer
            # brackets the actual work.
            if use_cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high resolution, unlike time.time.
            start = time.perf_counter()
            for _ in range(runs):
                model(sample)
            if use_cuda:
                torch.cuda.synchronize()
            ms = (time.perf_counter() - start) / runs * 1000
        print(f"{name:25s}: {ms:.3f} ms/sample  ({params:,} params)")
        return ms

    t_time = benchmark(teacher, "Teacher (2x1200)")
    s_time = benchmark(student_normal, "Student (2x800)")
    benchmark(student_distilled, "Distilled (2x800)")

    print(f"\nSpeedup: {t_time/s_time:.2f}x faster with {count_params(student_normal)/count_params(teacher)*100:.1f}% params")
    print("=" * 60)

    # 6. Save model weights (plus the teacher's dropout rate, needed to rebuild it)
    print("\n" + "=" * 60)
    print("SAVING MODELS")
    print("=" * 60)

    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)

    torch.save({
        'model_state_dict': teacher.state_dict(),
        'dropout_rate': dropout
    }, f'{save_dir}/teacher_model.pth')

    torch.save({
        'model_state_dict': student_normal.state_dict(),
    }, f'{save_dir}/student_normal_model.pth')

    torch.save({
        'model_state_dict': student_distilled.state_dict(),
    }, f'{save_dir}/student_distilled_model.pth')

    print(f"✓ Models saved to {save_dir}/")

    # Download in Colab. Only a missing google.colab package means "not in
    # Colab"; archive/download failures are reported as such instead of being
    # mislabelled, and KeyboardInterrupt is no longer swallowed.
    try:
        from google.colab import files
    except ImportError:
        print("Not running in Colab - models saved locally only")
    else:
        import shutil

        try:
            # Create zip file
            shutil.make_archive('mnist_distillation_models', 'zip', save_dir)
            files.download('mnist_distillation_models.zip')
        except Exception as exc:  # best-effort convenience step
            print(f"Download failed ({exc}) - models saved locally only")
        else:
            print("✓ Models downloaded as mnist_distillation_models.zip!")

    print("=" * 60)

if __name__ == "__main__":
    main()

Using device: cuda
GPU: Tesla T4
GPU Memory: 15.83 GB
MNIST Knowledge Distillation Experiment (Section 3)

[1/4] Training Teacher Network (2x1200 ReLU units + dropout)...
Teacher Epoch 1/20, Loss: 0.4231
Teacher Epoch 2/20, Loss: 0.1569
Teacher Epoch 3/20, Loss: 0.1245
Teacher Epoch 4/20, Loss: 0.1110
Teacher Epoch 5/20, Loss: 0.0983
Teacher Epoch 6/20, Loss: 0.0941
Teacher Epoch 7/20, Loss: 0.0893
Teacher Epoch 8/20, Loss: 0.0833
Teacher Epoch 9/20, Loss: 0.0791
Teacher Epoch 10/20, Loss: 0.0774
Teacher Epoch 11/20, Loss: 0.0757
Teacher Epoch 12/20, Loss: 0.0712
Teacher Epoch 13/20, Loss: 0.0743
Teacher Epoch 14/20, Loss: 0.0718
Teacher Epoch 15/20, Loss: 0.0706
Teacher Epoch 16/20, Loss: 0.0677
Teacher Epoch 17/20, Loss: 0.0647
Teacher Epoch 18/20, Loss: 0.0678
Teacher Epoch 19/20, Loss: 0.0680
Teacher Epoch 20/20, Loss: 0.0640
✓ Teacher: 98.95% accuracy (105 errors)

[2/4] Training Student Network Normally (2x800 ReLU units, no regularization)...
Student (normal) Epoch 1/20, Loss: 0

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✓ Models downloaded as mnist_distillation_models.zip!


# Example of loading trained models' params and feeding samples through them

In [None]:
import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import TeacherNet, StudentNet, distillation_loss
from google.colab import drive
drive.mount('/content/drive')


# assuming new session so loading MNIST dataset from scratch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


# data loading with jittering (up to 2 pixels in any direction)
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,
                      num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                    num_workers=2, pin_memory=True)


# accessing the trained models

teacher_checkpoint_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'
student_distilled_checkpoint_path = '/content/drive/MyDrive/Distillations/student_distilled_model.pth'

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU session still loads on a CPU-only one.
checkpoint = torch.load(teacher_checkpoint_path, map_location=device)
teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
teacher.load_state_dict(checkpoint['model_state_dict'])


checkpoint = torch.load(student_distilled_checkpoint_path, map_location=device)
student_distilled = StudentNet().to(device)
student_distilled.load_state_dict(checkpoint['model_state_dict'])

# feeding each model one sample to see that it works

sample_data, sample_label = next(iter(test_loader))
single_image = sample_data[0:1]  # Take first image, keep batch dimension
true_label = sample_label[0].item()
print(f"True label: {true_label}")


# Move to device
single_image = single_image.to(device)


# Feed through teacher (eval mode disables dropout)
teacher.eval()
with torch.no_grad():
    teacher_probs, teacher_logits = teacher(single_image)
    teacher_pred = teacher_probs.argmax(dim=1).item()
    teacher_confidence = teacher_probs[0, teacher_pred].item()



print(f"\nTeacher prediction: {teacher_pred} (confidence: {teacher_confidence:.4f})")
print(f"Teacher probabilities: {teacher_probs[0].cpu().numpy()}")

# Feed the same image through the distilled student
student_distilled.eval()
with torch.no_grad():
    distilled_probs, distilled_logits = student_distilled(single_image)
    distilled_pred = distilled_probs.argmax(dim=1).item()
    distilled_confidence = distilled_probs[0, distilled_pred].item()

print(f"\nStudent (distilled) prediction: {distilled_pred} (confidence: {distilled_confidence:.4f})")
print(f"Distilled probabilities: {distilled_probs[0].cpu().numpy()}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
GPU: Tesla T4
GPU Memory: 15.83 GB


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 513kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.70MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 13.9MB/s]


True label: 7

Teacher prediction: 7 (confidence: 1.0000)
Teacher probabilities: [4.0224847e-08 8.3918567e-06 1.2510760e-05 2.7998278e-06 2.1066224e-07
 9.7254912e-08 1.6154750e-10 9.9995697e-01 2.4853202e-07 1.8769684e-05]

Student (distilled) prediction: 7 (confidence: 1.0000)
Distilled probabilities: [5.6548139e-08 8.8660936e-06 1.6182468e-05 2.9288403e-06 1.7997323e-07
 1.5101638e-07 2.1161005e-10 9.9995589e-01 2.0029137e-07 1.5544521e-05]


# FitNets Version

## architecture

In [None]:
################## 3 ########################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class FitNetStudent(nn.Module):
    """Thin, deep student MLP: four 300-unit ReLU hidden layers, 10 outputs.

    With these layer sizes the network has about 0.51M parameters. The
    activation after the second hidden layer is exposed as the "guided"
    layer via ``forward_with_hint``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 300)  # its ReLU output is the guided layer
        self.fc3 = nn.Linear(300, 300)
        self.fc4 = nn.Linear(300, 300)
        self.fc5 = nn.Linear(300, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)``."""
        logits, _ = self.forward_with_hint(x)
        return torch.softmax(logits / temperature, dim=1), logits

    def forward_with_hint(self, x):
        """Return ``(logits, guided)``, guided being the second hidden activation."""
        hidden = torch.relu(self.fc1(x.view(-1, 784)))
        guided = torch.relu(self.fc2(hidden))
        hidden = guided
        for layer in (self.fc3, self.fc4):
            hidden = torch.relu(layer(hidden))
        return self.fc5(hidden), guided


class Regressor(nn.Module):
    """Linear layer + ReLU projecting ``student_dim`` features to ``teacher_dim``.

    Defaults map a 300-wide student guided layer onto a 1200-wide teacher hint.
    """

    def __init__(self, student_dim=300, teacher_dim=1200):
        super().__init__()
        self.fc = nn.Linear(student_dim, teacher_dim)

    def forward(self, x):
        """Return the rectified projection of ``x``."""
        projected = self.fc(x)
        return torch.relu(projected)



def get_teacher_hint(teacher, x):
    """Return the teacher's first hidden activation, ``relu(teacher.fc1(x))``.

    ``x`` is viewed as ``(-1, 784)`` first; the teacher's dropout is not applied.
    """
    flat = x.view(-1, 784)
    pre_activation = teacher.fc1(flat)
    return torch.relu(pre_activation)


def train_stage1_hints(student, teacher, regressor, train_loader, epochs=10, lr=0.001):
    """Stage 1: regress the student's guided layer onto the teacher's hint.

    The loss is the MSE between ``regressor(student_guided)`` and the
    teacher's first hidden activation. The optimizer is handed every student
    and regressor parameter, but this loss only back-propagates through the
    regressor and the student layers that feed the guided activation.

    Args:
        student: Module exposing ``forward_with_hint`` returning ``(logits, guided)``.
        teacher: Module used as the hint source; put in eval mode.
        regressor: Module mapping guided features to the hint's width.
        train_loader: Iterable of ``(data, target)`` batches; targets are ignored.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The student (the regressor is updated in place but not returned).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    regressor = regressor.to(device)
    teacher.eval()

    trainable = [*student.parameters(), *regressor.parameters()]
    optimizer = optim.Adam(trainable, lr=lr)

    student.train()
    regressor.train()

    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for inputs, _ in train_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()

            # Teacher side is a fixed target: no gradient tracking.
            with torch.no_grad():
                hint = get_teacher_hint(teacher, inputs)

            _, guided = student.forward_with_hint(inputs)
            projected = regressor(guided)

            hint_loss = F.mse_loss(projected, hint)
            hint_loss.backward()
            optimizer.step()
            running_loss += hint_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Stage 1 Epoch {epoch_num}/{epochs}, Hint Loss: {mean_loss:.4f}")

    return student


def train_stage2_kd(student, teacher, train_loader, temperature=20, alpha=0.1, epochs=10, lr=0.001):
    """Stage 2: knowledge distillation of the student from the teacher.

    Per batch the loss is
    ``alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(teacher_T || student_T)``.

    Args:
        student: Module being trained; forward returns ``(probs, logits)``.
        teacher: Module providing soft targets; put in eval mode and queried
            under ``torch.no_grad()``.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        temperature: Softening temperature T.
        alpha: Weight of the hard-label term.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate for the student.

    Returns:
        The trained student (left in train mode).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()

    optimizer = optim.Adam(student.parameters(), lr=lr)

    student.train()
    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            with torch.no_grad():
                soft_targets, _ = teacher(inputs, temperature=temperature)

            _, logits = student(inputs)

            # kl_div takes student log-probabilities and teacher probabilities.
            log_probs = F.log_softmax(logits / temperature, dim=1)
            kd_term = F.kl_div(log_probs, soft_targets, reduction='batchmean') * (temperature ** 2)
            ce_term = F.cross_entropy(logits, labels)
            batch_loss = alpha * ce_term + (1 - alpha) * kd_term

            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Stage 2 Epoch {epoch_num}/{epochs}, KD Loss: {mean_loss:.4f}")

    return student


def train_fitnet(teacher, train_loader, epochs_stage1=10, epochs_stage2=20, lr=0.001, temp=20):
    """Train a fresh FitNetStudent: hint stage first, then distillation.

    Args:
        teacher: Trained teacher module used by both stages.
        train_loader: Iterable of ``(data, target)`` batches.
        epochs_stage1: Epochs of hint training.
        epochs_stage2: Epochs of distillation.
        lr: Learning rate used in both stages.
        temp: Distillation temperature for stage 2 (alpha is fixed at 0.1).

    Returns:
        The trained student; the stage-1 regressor is discarded.
    """
    fitnet = FitNetStudent()
    hint_regressor = Regressor(student_dim=300, teacher_dim=1200)

    print("\n[FitNet Stage 1] Training with hints from teacher...")
    fitnet = train_stage1_hints(
        fitnet, teacher, hint_regressor, train_loader,
        epochs=epochs_stage1, lr=lr,
    )

    print("\n[FitNet Stage 2] Knowledge distillation...")
    return train_stage2_kd(
        fitnet, teacher, train_loader,
        temperature=temp, alpha=0.1, epochs=epochs_stage2, lr=lr,
    )


In [None]:
# Load trained teacher
teacher_checkpoint_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU session still loads on a CPU-only one.
checkpoint = torch.load(teacher_checkpoint_path, map_location=device)
teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
teacher.load_state_dict(checkpoint['model_state_dict'])

# Train FitNet (hint stage followed by distillation)
fitnet_student = train_fitnet(teacher, train_loader,
                               epochs_stage1=10,
                               epochs_stage2=20,
                               temp=20)




[FitNet Stage 1] Training with hints from teacher...
Stage 1 Epoch 1/10, Hint Loss: 0.0151
Stage 1 Epoch 2/10, Hint Loss: 0.0070
Stage 1 Epoch 3/10, Hint Loss: 0.0064
Stage 1 Epoch 4/10, Hint Loss: 0.0062
Stage 1 Epoch 5/10, Hint Loss: 0.0060
Stage 1 Epoch 6/10, Hint Loss: 0.0060
Stage 1 Epoch 7/10, Hint Loss: 0.0059
Stage 1 Epoch 8/10, Hint Loss: 0.0059
Stage 1 Epoch 9/10, Hint Loss: 0.0059
Stage 1 Epoch 10/10, Hint Loss: 0.0058

[FitNet Stage 2] Knowledge distillation...
Stage 2 Epoch 1/20, KD Loss: 4.4499
Stage 2 Epoch 2/20, KD Loss: 1.0118
Stage 2 Epoch 3/20, KD Loss: 0.5305
Stage 2 Epoch 4/20, KD Loss: 0.3713
Stage 2 Epoch 5/20, KD Loss: 0.2996
Stage 2 Epoch 6/20, KD Loss: 0.2565
Stage 2 Epoch 7/20, KD Loss: 0.2246
Stage 2 Epoch 8/20, KD Loss: 0.2018
Stage 2 Epoch 9/20, KD Loss: 0.1849
Stage 2 Epoch 10/20, KD Loss: 0.1715
Stage 2 Epoch 11/20, KD Loss: 0.1619
Stage 2 Epoch 12/20, KD Loss: 0.1551
Stage 2 Epoch 13/20, KD Loss: 0.1462
Stage 2 Epoch 14/20, KD Loss: 0.1427
Stage 2 Epoc

In [None]:
import os

# Score the FitNet student on the test set
fitnet_acc, fitnet_err = evaluate(fitnet_student, test_loader)
print(f"FitNet: {fitnet_acc:.2f}% accuracy ({fitnet_err} errors)")

# Make sure the output directory exists before writing into it
save_dir = './saved_models'
os.makedirs(save_dir, exist_ok=True)

# Persist the FitNet student's weights
fitnet_checkpoint = {'model_state_dict': fitnet_student.state_dict()}
torch.save(fitnet_checkpoint, f'{save_dir}/fitnet_student_model.pth')

print(f"✓ FitNet model saved to {save_dir}/")

FitNet: 98.83% accuracy (117 errors)
✓ FitNet model saved to ./saved_models/


# Relational KD


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from models import TeacherNet, StudentNet

# --- 1. Define Helper Functions Locally to Avoid Import Errors ---

def evaluate(model, test_loader):
    """
    Score ``model`` on ``test_loader``.

    Defined locally so this section does not depend on an earlier cell. The
    model's forward may return either a tuple (first element = class scores)
    or the class scores directly.

    Returns:
        ``(accuracy, errors)``: accuracy as a percentage and the number of
        misclassified samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            result = model(inputs, temperature=1.0)
            scores = result[0] if isinstance(result, tuple) else result

            guesses = torch.max(scores, 1)[1]  # index of the top class
            num_seen += labels.size(0)
            num_correct += (guesses == labels).sum().item()

    accuracy = 100 * num_correct / num_seen
    return accuracy, num_seen - num_correct

def _mean_positive_distance(dist):
    """Mean of the strictly positive entries of ``dist``, or 1 if there are none."""
    positive = dist[dist > 0]
    if positive.numel() == 0:
        # Batch of one, or all embeddings identical: every distance is zero.
        # The mean of an empty selection is NaN, so fall back to a neutral
        # scale that leaves the (all-zero) matrix unchanged.
        return torch.ones((), dtype=dist.dtype, device=dist.device)
    return positive.mean()


def rkd_distance_loss(student_emb, teacher_emb):
    """
    Relational Knowledge Distillation - distance loss.

    Pushes the student's pairwise-distance structure within a batch towards
    the teacher's. Both distance matrices are divided by the mean of their
    positive entries, which makes the loss invariant to the overall scale of
    either embedding space.

    Args:
        student_emb: 2-D tensor of student embeddings, one row per sample.
        teacher_emb: 2-D tensor of teacher embeddings for the same samples;
            the feature width may differ from the student's.

    Returns:
        Scalar smooth-L1 loss between the normalized distance matrices. It is
        finite even when a matrix has no positive distance (e.g. a batch of
        one sample), where the unguarded mean would produce NaN.
    """
    # Pairwise Euclidean (p=2) distance matrices, batch_size x batch_size.
    t_dist = torch.cdist(teacher_emb, teacher_emb, p=2)
    s_dist = torch.cdist(student_emb, student_emb, p=2)

    t_dist_norm = t_dist / _mean_positive_distance(t_dist)
    s_dist_norm = s_dist / _mean_positive_distance(s_dist)

    # Smooth L1 (Huber-style) penalty between the normalized matrices.
    return F.smooth_l1_loss(s_dist_norm, t_dist_norm)


def get_features(model, x, is_teacher=False):
    """
    Return the penultimate-layer embedding of ``x`` for a known model type.

    The branch is chosen from the model's class; ``is_teacher`` is accepted
    but not consulted.

    - ``TeacherNet``: fc1 -> relu -> dropout -> fc2 -> relu (taken before the
      second dropout and the classifier).
    - ``StudentNet``: fc1 -> relu -> fc2 -> relu.
    - Any other model: the flattened ``(-1, 784)`` input, unchanged.
    """
    flat = x.view(-1, 784)

    if isinstance(model, TeacherNet):
        first = model.dropout(F.relu(model.fc1(flat)))
        return F.relu(model.fc2(first))

    if isinstance(model, StudentNet):
        return F.relu(model.fc2(F.relu(model.fc1(flat))))

    return flat


def train_student_rkd(student, teacher, train_loader, epochs=10, lr=0.001, beta=1.0):
    """
    Train the student on cross-entropy plus ``beta`` times the RKD distance loss.

    Args:
        student: Module being trained; its ``fc3`` maps embeddings to logits.
        teacher: Module providing reference embeddings; put in eval mode and
            only queried under ``torch.no_grad()``.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate for the student.
        beta: Weight of the relational (distance) term.

    Returns:
        The trained student (left in train mode).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()  # teacher only provides targets

    optimizer = optim.Adam(student.parameters(), lr=lr)

    print(f"Training RKD Student (Beta={beta})...")

    student.train()
    for epoch_num in range(1, epochs + 1):
        sum_total = 0
        sum_rkd = 0
        sum_task = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # Reference embeddings: no gradient tracking for the teacher.
            with torch.no_grad():
                teacher_emb = get_features(teacher, inputs, is_teacher=True)

            # Student embeddings feed both the classifier head and the RKD term.
            student_emb = get_features(student, inputs, is_teacher=False)
            logits = student.fc3(student_emb)

            task_term = F.cross_entropy(logits, labels)
            rkd_term = rkd_distance_loss(student_emb, teacher_emb)
            batch_loss = task_term + (beta * rkd_term)

            batch_loss.backward()
            optimizer.step()

            sum_total += batch_loss.item()
            sum_rkd += rkd_term.item()
            sum_task += task_term.item()

        mean_total = sum_total / len(train_loader)
        mean_rkd = sum_rkd / len(train_loader)
        mean_task = sum_task / len(train_loader)
        print(f"RKD Epoch {epoch_num}/{epochs} | Total: {mean_total:.4f} | RKD: {mean_rkd:.4f} | Task: {mean_task:.4f}")

    return student


In [None]:

# --- 2. Main Execution ---

def main_rkd():
    """Train, evaluate and save a StudentNet distilled with the RKD distance loss.

    Loads MNIST, restores the pre-trained teacher from Google Drive (returning
    early with a message if the checkpoint file is missing), trains the
    student for 20 epochs with beta=100, prints the test accuracy and writes
    the weights to ``./saved_models/student_rkd_model.pth``.
    """
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Transforms: training images are jittered by up to 2 of 28 pixels
    transform_train = transforms.Compose([
        transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=2, pin_memory=True)

    # Load Pre-trained Teacher
    teacher_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'

    if os.path.exists(teacher_path):
        print(f"Loading teacher from {teacher_path}")
        # map_location remaps the stored tensors onto the current device, so
        # a checkpoint written on a GPU session still loads on a CPU-only one.
        checkpoint = torch.load(teacher_path, map_location=device)
        teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
        teacher.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Teacher model not found! Please run the first section to train/save the teacher.")
        return

    # Initialize RKD Student (Same architecture as Normal Student: 2x800)
    student_rkd = StudentNet().to(device)

    # Train with RKD
    student_rkd = train_student_rkd(student_rkd, teacher, train_loader, epochs=20, lr=0.001, beta=100)

    # Evaluate
    rkd_acc, rkd_err = evaluate(student_rkd, test_loader)
    print("\n" + "=" * 60)
    print(f"RKD Student Results: {rkd_acc:.2f}% accuracy ({rkd_err} errors)")
    print("=" * 60)

    # Save
    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)
    torch.save({
        'model_state_dict': student_rkd.state_dict(),
    }, f'{save_dir}/student_rkd_model.pth')
    print(f"✓ RKD model saved to {save_dir}/")

if __name__ == "__main__":
    main_rkd()

Using device: cuda
Loading teacher from /content/drive/MyDrive/Distillations/teacher_model.pth
Training RKD Student (Beta=100)...
RKD Epoch 1/20 | Total: 0.9601 | RKD: 0.0042 | Task: 0.5376
RKD Epoch 2/20 | Total: 0.2273 | RKD: 0.0013 | Task: 0.1019
RKD Epoch 3/20 | Total: 0.1632 | RKD: 0.0009 | Task: 0.0712
RKD Epoch 4/20 | Total: 0.1397 | RKD: 0.0008 | Task: 0.0616
RKD Epoch 5/20 | Total: 0.1213 | RKD: 0.0007 | Task: 0.0517
RKD Epoch 6/20 | Total: 0.1146 | RKD: 0.0007 | Task: 0.0486
RKD Epoch 7/20 | Total: 0.1066 | RKD: 0.0006 | Task: 0.0446
RKD Epoch 8/20 | Total: 0.0998 | RKD: 0.0006 | Task: 0.0407
RKD Epoch 9/20 | Total: 0.0959 | RKD: 0.0006 | Task: 0.0390
RKD Epoch 10/20 | Total: 0.0924 | RKD: 0.0005 | Task: 0.0382
RKD Epoch 11/20 | Total: 0.0874 | RKD: 0.0005 | Task: 0.0357
RKD Epoch 12/20 | Total: 0.0838 | RKD: 0.0005 | Task: 0.0344
RKD Epoch 13/20 | Total: 0.0797 | RKD: 0.0005 | Task: 0.0318
RKD Epoch 14/20 | Total: 0.0777 | RKD: 0.0005 | Task: 0.0303
RKD Epoch 15/20 | Total: 

# Algorithm Change


## Vanilla

In [None]:
%%writefile models.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class TeacherNet(nn.Module):
    """Teacher MLP: 784 inputs -> two hidden layers of 1200 ReLU units -> 10 logits.

    Dropout (probability ``dropout_rate``) is applied after each hidden ReLU.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)`` for input ``x``.

        ``x`` is reshaped to rows of 784 values before the first layer.
        """
        hidden = x.view(-1, 784)
        # Both hidden layers follow the same linear -> ReLU -> dropout recipe.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        logits = self.fc3(hidden)
        probs = F.softmax(logits / temperature, dim=1)
        return probs, logits


class StudentNet(nn.Module):
    """Student MLP: two hidden layers of 800 ReLU units with BatchNorm and one skip connection.

    Layout: fc1 -> bn1 -> ReLU, then fc2 -> bn2 -> ReLU added to the first
    hidden activation, then fc3 producing 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 800)
        self.bn1 = nn.BatchNorm1d(800)
        self.fc2 = nn.Linear(800, 800)
        self.bn2 = nn.BatchNorm1d(800)
        self.fc3 = nn.Linear(800, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)`` for input ``x``."""
        flat = x.view(-1, 784)
        first_hidden = F.relu(self.bn1(self.fc1(flat)))
        second_hidden = F.relu(self.bn2(self.fc2(first_hidden)))
        # Residual connection: the second block's output is added to its own input.
        logits = self.fc3(second_hidden + first_hidden)
        return F.softmax(logits / temperature, dim=1), logits


def distillation_loss(student_logits, teacher_soft_targets, hard_targets,
                      temperature, alpha=0.5):
    """
    Blend a hard-label loss with a temperature-softened teacher-matching loss.

    Args:
        student_logits: Raw (unnormalised) outputs of the student
        teacher_soft_targets: Teacher probabilities computed at `temperature`
        hard_targets: Ground-truth class indices
        temperature: Softening temperature applied to the student logits
        alpha: Weight of the hard-label term; the soft term gets (1 - alpha)

    Returns:
        alpha * cross_entropy + (1 - alpha) * T^2 * KL(teacher || student)
    """
    # Student log-probabilities at the same temperature as the teacher targets.
    log_probs_at_t = F.log_softmax(student_logits / temperature, dim=1)

    # KL term, multiplied by T^2 to compensate for the 1/T^2 gradient scaling.
    kd_term = F.kl_div(log_probs_at_t, teacher_soft_targets, reduction='batchmean') * (temperature ** 2)

    # Ordinary cross-entropy on the unscaled logits (temperature 1).
    ce_term = F.cross_entropy(student_logits, hard_targets)

    return alpha * ce_term + (1 - alpha) * kd_term

Writing models.py


In [None]:
################## 2 ########################

import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import TeacherNet, StudentNet, distillation_loss


def train_teacher(model, train_loader, epochs=10, lr=0.001):
    """Fit the teacher with Adam (weight decay 1e-4) on a cross-entropy objective.

    The model is moved to CUDA when available (CPU otherwise), put in training
    mode, and returned after `epochs` passes over `train_loader`. The mean
    batch loss is printed once per epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    model.train()
    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # The model returns (probabilities, logits); only the logits feed the loss.
            _, logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(logits, labels)

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Teacher Epoch {epoch}/{epochs}, Loss: {mean_loss:.4f}")

    return model


def train_student_normal(model, train_loader, epochs=10, lr=0.001):
    """Baseline student training: Adam on plain cross-entropy, no teacher involved.

    The model is moved to CUDA when available (CPU otherwise), put in training
    mode, and returned after `epochs` passes over `train_loader`. The mean
    batch loss is printed once per epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # The model returns (probabilities, logits); only the logits feed the loss.
            _, logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(logits, labels)

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Student (normal) Epoch {epoch}/{epochs}, Loss: {mean_loss:.4f}")

    return model


def train_student_distilled(student, teacher, train_loader, temperature=20,
                           alpha=0.1, epochs=10, lr=0.001):
    """Train `student` against a frozen `teacher` using `distillation_loss`.

    For every batch the teacher's probabilities at `temperature` are computed
    without gradients and combined with the hard labels via
    `distillation_loss(..., temperature, alpha)`. Only the student's
    parameters are optimised (Adam). Returns the trained student.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()  # evaluation mode: the teacher only supplies targets

    optimizer = optim.Adam(student.parameters(), lr=lr)

    student.train()
    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Soft targets come from the teacher at temperature T, with no gradient tracking.
            with torch.no_grad():
                soft_targets, _ = teacher(inputs, temperature=temperature)

            # Raw student logits; the loss applies the temperature itself.
            _, student_logits = student(inputs)

            loss = distillation_loss(student_logits, soft_targets,
                                     labels, temperature, alpha)

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Student (distilled T={temperature}) Epoch {epoch}/{epochs}, Loss: {mean_loss:.4f}")

    return student


def evaluate(model, test_loader):
    """Return `(accuracy_percent, error_count)` of `model` over `test_loader`.

    The model is moved to CUDA when available (CPU otherwise), switched to
    evaluation mode, and called with `temperature=1.0`; the first element of
    its output is used for the arg-max prediction.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            probs, _ = model(inputs, temperature=1.0)
            predicted = torch.max(probs, 1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    errors = n_seen - n_correct
    return accuracy, errors


In [None]:
def main():
    """Run the post-change experiment: baseline vs. distilled student on MNIST.

    Steps:
      1. Build MNIST loaders (training images are jittered by up to 2 pixels).
      2. Load the pre-trained teacher from the Google Drive checkpoint.
      3. Train one StudentNet normally and one with distillation, evaluating each.
      4. Save both students under ./saved_models and, when running in Colab,
         offer the directory as a zip download.
    """
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Hyper-parameters. The training batch size is 256, the value the loaders
    # have always used (a previous unused `batchsize = 512` local was misleading).
    epochs = 20
    temp = 20
    lr = 0.001
    batch_size = 256

    # Data loading with jittering (up to 2 pixels in any direction)
    transform_train = transforms.Compose([
        transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                             num_workers=2, pin_memory=True)

    print("=" * 60)
    print("MNIST Knowledge Distillation Experiment (Section 3)")
    print("=" * 60)

    # Load the pretrained teacher model instead of retraining it.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    teacher_checkpoint_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'
    checkpoint = torch.load(teacher_checkpoint_path, map_location=device)
    teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
    teacher.load_state_dict(checkpoint['model_state_dict'])

    # 2. Train small student network normally (800-800 units, no regularization)
    print("\n[2/4] Training Student Network Normally (2x800 ReLU units, no regularization)...")
    student_normal = StudentNet()
    student_normal = train_student_normal(student_normal, train_loader, epochs=epochs, lr=lr)
    student_normal_acc, student_normal_err = evaluate(student_normal, test_loader)
    print(f"✓ Student (normal): {student_normal_acc:.2f}% accuracy ({student_normal_err} errors)")

    # 3. Train small student network with distillation (T=20)
    print("\n[3/4] Training Student Network with Distillation (T=20, alpha=0.1)...")
    student_distilled = StudentNet()
    student_distilled = train_student_distilled(student_distilled, teacher, train_loader,
                                                temperature=temp, alpha=0.1, epochs=epochs, lr=lr)
    student_distilled_acc, student_distilled_err = evaluate(student_distilled, test_loader)
    print(f"✓ Student (distilled): {student_distilled_acc:.2f}% accuracy ({student_distilled_err} errors)")

    # Save model weights
    print("\n" + "=" * 60)
    print("SAVING MODELS")
    print("=" * 60)

    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)

    torch.save({
        'model_state_dict': student_normal.state_dict(),
    }, f'{save_dir}/student_normal_model_after_algo_change.pth')

    torch.save({
        'model_state_dict': student_distilled.state_dict(),
    }, f'{save_dir}/student_distilled_model_after_algo_change.pth')

    print(f"✓ Models saved to {save_dir}/")

    # Download in Colab. Only a missing google.colab package means "not in
    # Colab"; any other failure is reported with its cause instead of being
    # mislabelled, and never aborts the run because the models are already saved.
    try:
        from google.colab import files
    except ImportError:
        print("Not running in Colab - models saved locally only")
    else:
        import shutil

        try:
            # Create zip file
            shutil.make_archive('mnist_distillation_models', 'zip', save_dir)
            files.download('mnist_distillation_models.zip')
            print("✓ Models downloaded as mnist_distillation_models.zip!")
        except Exception as exc:
            print(f"Model download failed ({exc}) - models saved locally only")

    print("=" * 60)

if __name__ == "__main__":
    main()

Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
GPU Memory: 42.47 GB
MNIST Knowledge Distillation Experiment (Section 3)

[2/4] Training Student Network Normally (2x800 ReLU units, no regularization)...
Student (normal) Epoch 1/20, Loss: 0.2686
Student (normal) Epoch 2/20, Loss: 0.1262
Student (normal) Epoch 3/20, Loss: 0.0982
Student (normal) Epoch 4/20, Loss: 0.0867
Student (normal) Epoch 5/20, Loss: 0.0739
Student (normal) Epoch 6/20, Loss: 0.0673
Student (normal) Epoch 7/20, Loss: 0.0624
Student (normal) Epoch 8/20, Loss: 0.0589
Student (normal) Epoch 9/20, Loss: 0.0549
Student (normal) Epoch 10/20, Loss: 0.0507
Student (normal) Epoch 11/20, Loss: 0.0485
Student (normal) Epoch 12/20, Loss: 0.0444
Student (normal) Epoch 13/20, Loss: 0.0449
Student (normal) Epoch 14/20, Loss: 0.0406
Student (normal) Epoch 15/20, Loss: 0.0414
Student (normal) Epoch 16/20, Loss: 0.0389
Student (normal) Epoch 17/20, Loss: 0.0375
Student (normal) Epoch 18/20, Loss: 0.0341
Student (normal) Epoch 19/20, Loss

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✓ Models downloaded as mnist_distillation_models.zip!


## FitNets

In [None]:
################## 3 ########################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class FitNetStudent(nn.Module):
    """Thin, deep student: four 300-unit hidden layers with BatchNorm and skip connections.

    The first layer is fc1 -> bn1 -> ReLU; each of the next three layers is
    fc -> bn -> ReLU added to that layer's input. fc5 maps to 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 300)
        self.bn1 = nn.BatchNorm1d(300)
        self.fc2 = nn.Linear(300, 300)
        self.bn2 = nn.BatchNorm1d(300)
        self.fc3 = nn.Linear(300, 300)
        self.bn3 = nn.BatchNorm1d(300)
        self.fc4 = nn.Linear(300, 300)
        self.bn4 = nn.BatchNorm1d(300)
        self.fc5 = nn.Linear(300, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)``."""
        # Same layer stack as forward_with_hint; the guided activation is not needed here.
        logits, _ = self.forward_with_hint(x)
        return F.softmax(logits / temperature, dim=1), logits

    def forward_with_hint(self, x):
        """Return ``(logits, guided)`` where ``guided`` is the output of the second layer."""
        hidden = F.relu(self.bn1(self.fc1(x.view(-1, 784))))

        # Second layer (with skip): this is the "guided" activation used for hint training.
        guided = F.relu(self.bn2(self.fc2(hidden))) + hidden

        hidden = F.relu(self.bn3(self.fc3(guided))) + guided
        hidden = F.relu(self.bn4(self.fc4(hidden))) + hidden

        return self.fc5(hidden), guided


class Regressor(nn.Module):
    """Single linear layer + ReLU projecting the student's guided activation
    (default width 300) to the teacher's hint width (default 1200)."""

    def __init__(self, student_dim=300, teacher_dim=1200):
        super().__init__()
        self.fc = nn.Linear(student_dim, teacher_dim)

    def forward(self, x):
        """Return ``relu(fc(x))``."""
        projected = self.fc(x)
        return F.relu(projected)



def get_teacher_hint(teacher, x):
    """Return the teacher's first hidden activation, ``relu(teacher.fc1(x))``.

    ``x`` is reshaped to rows of 784 values first. Only ``teacher.fc1`` is
    used, so no dropout is applied.
    """
    flat = x.view(-1, 784)
    pre_activation = teacher.fc1(flat)
    return F.relu(pre_activation)


def train_stage1_hints(student, teacher, regressor, train_loader, epochs=10, lr=0.001):
    """Stage 1 (hint training): fit the student's guided layer, via `regressor`, to the teacher hint.

    For each batch the teacher's first hidden activation is taken as the
    target, the student's guided activation is passed through `regressor`,
    and the mean-squared error between the two is minimised with Adam.
    Labels are ignored. Returns the student (the regressor is not returned).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    regressor = regressor.to(device)
    teacher.eval()

    # All student and regressor parameters are handed to the optimizer; the
    # loss only depends on the guided activation and the regressor output.
    trainable = list(student.parameters()) + list(regressor.parameters())
    optimizer = optim.Adam(trainable, lr=lr)

    student.train()
    regressor.train()

    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, _ in train_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()

            # Target: teacher's first hidden layer, computed without gradients.
            with torch.no_grad():
                hint_target = get_teacher_hint(teacher, inputs)

            # Prediction: student's guided activation mapped to the teacher's width.
            _, guided = student.forward_with_hint(inputs)
            hint_prediction = regressor(guided)

            loss = F.mse_loss(hint_prediction, hint_target)

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Stage 1 Epoch {epoch}/{epochs}, Hint Loss: {mean_loss:.4f}")

    return student


def train_stage2_kd(student, teacher, train_loader, temperature=20, alpha=0.1, epochs=10, lr=0.001):
    """Stage 2: knowledge-distillation training of `student` against `teacher`.

    Loss per batch: ``alpha * CE(student_logits, target)
    + (1 - alpha) * T^2 * KL(teacher_soft || student_soft)``, with both soft
    distributions taken at `temperature`. The teacher is put in evaluation
    mode and queried without gradients; the student is optimised with Adam
    and returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()

    optimizer = optim.Adam(student.parameters(), lr=lr)
    t_squared = temperature ** 2

    student.train()
    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher probabilities at temperature T (no gradient tracking).
            with torch.no_grad():
                teacher_soft, _ = teacher(inputs, temperature=temperature)

            _, student_logits = student(inputs)

            # Soft term: KL between tempered distributions, rescaled by T^2.
            student_log_soft = F.log_softmax(student_logits / temperature, dim=1)
            soft_loss = F.kl_div(student_log_soft, teacher_soft, reduction='batchmean') * t_squared
            # Hard term: cross-entropy on the untempered logits.
            hard_loss = F.cross_entropy(student_logits, labels)
            loss = alpha * hard_loss + (1 - alpha) * soft_loss

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Stage 2 Epoch {epoch}/{epochs}, KD Loss: {mean_loss:.4f}")

    return student


def train_fitnet(teacher, train_loader, epochs_stage1=10, epochs_stage2=20, lr=0.001, temp=20):
    """Run the full FitNet recipe and return the trained student.

    A fresh FitNetStudent is first trained on teacher hints through a
    300 -> 1200 Regressor (stage 1), then trained with knowledge distillation
    at temperature `temp` and alpha 0.1 (stage 2).
    """
    print("\n[FitNet Stage 1] Training with hints from teacher...")
    student = FitNetStudent()
    hint_regressor = Regressor(student_dim=300, teacher_dim=1200)
    student = train_stage1_hints(
        student, teacher, hint_regressor, train_loader,
        epochs=epochs_stage1, lr=lr,
    )

    print("\n[FitNet Stage 2] Knowledge distillation...")
    return train_stage2_kd(
        student, teacher, train_loader,
        temperature=temp, alpha=0.1, epochs=epochs_stage2, lr=lr,
    )


In [None]:
# FitNet experiment driver: build the MNIST loaders, load the pre-trained
# teacher, and train the FitNet student (10 hint epochs, 20 KD epochs, T=20).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Data loading with jittering (up to 2 pixels in any direction)
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,
                      num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                    num_workers=2, pin_memory=True)

# Load trained teacher. map_location remaps the stored tensors onto the
# selected device, so a checkpoint saved on GPU also loads on a CPU-only machine.
teacher_checkpoint_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'
checkpoint = torch.load(teacher_checkpoint_path, map_location=device)
teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
teacher.load_state_dict(checkpoint['model_state_dict'])

# Train FitNet
fitnet_student = train_fitnet(teacher, train_loader,
                               epochs_stage1=10,
                               epochs_stage2=20,
                               temp=20)




[FitNet Stage 1] Training with hints from teacher...
Stage 1 Epoch 1/10, Hint Loss: 0.0353
Stage 1 Epoch 2/10, Hint Loss: 0.0297
Stage 1 Epoch 3/10, Hint Loss: 0.0289
Stage 1 Epoch 4/10, Hint Loss: 0.0284
Stage 1 Epoch 5/10, Hint Loss: 0.0280
Stage 1 Epoch 6/10, Hint Loss: 0.0275
Stage 1 Epoch 7/10, Hint Loss: 0.0270
Stage 1 Epoch 8/10, Hint Loss: 0.0267
Stage 1 Epoch 9/10, Hint Loss: 0.0264
Stage 1 Epoch 10/10, Hint Loss: 0.0262

[FitNet Stage 2] Knowledge distillation...
Stage 2 Epoch 1/20, KD Loss: 1.4957
Stage 2 Epoch 2/20, KD Loss: 0.4002
Stage 2 Epoch 3/20, KD Loss: 0.3248
Stage 2 Epoch 4/20, KD Loss: 0.2756
Stage 2 Epoch 5/20, KD Loss: 0.2670
Stage 2 Epoch 6/20, KD Loss: 0.2626
Stage 2 Epoch 7/20, KD Loss: 0.2448
Stage 2 Epoch 8/20, KD Loss: 0.2228
Stage 2 Epoch 9/20, KD Loss: 0.2269
Stage 2 Epoch 10/20, KD Loss: 0.2136
Stage 2 Epoch 11/20, KD Loss: 0.2256
Stage 2 Epoch 12/20, KD Loss: 0.2053
Stage 2 Epoch 13/20, KD Loss: 0.2008
Stage 2 Epoch 14/20, KD Loss: 0.1945
Stage 2 Epoc

In [None]:
# Persist the trained FitNet student's weights under ./saved_models
import os
save_dir = './saved_models'
os.makedirs(save_dir, exist_ok=True)

# The checkpoint holds only the state dict, under the 'model_state_dict' key
torch.save(
    {'model_state_dict': fitnet_student.state_dict()},
    f'{save_dir}/fitnet_student_model_after_algo_change.pth',
)

print(f"✓ FitNet model saved to {save_dir}/")

✓ FitNet model saved to ./saved_models/


## Relational KD

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from models import TeacherNet, StudentNet

# --- 1. Define Helper Functions Locally to Avoid Import Errors ---

def evaluate(model, test_loader):
    """
    Return `(accuracy_percent, error_count)` of `model` over `test_loader`.

    The model is called with `temperature=1.0`. If it returns a tuple, the
    first element is used for the arg-max prediction; otherwise the output
    itself is used.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            result = model(inputs, temperature=1.0)
            # Accept both (output, extra) tuples and bare output tensors.
            scores = result[0] if isinstance(result, tuple) else result

            predicted = torch.max(scores, 1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    errors = n_seen - n_correct
    return accuracy, errors

def _mean_normalized_pairwise_distances(emb):
    """Return the pairwise Euclidean distance matrix of `emb` divided by its mean positive entry."""
    dist = torch.cdist(emb, emb, p=2)
    positive = dist[dist > 0]
    if positive.numel() == 0:
        # No positive distance (single sample or identical rows): the mean of
        # an empty selection would be NaN, so return the matrix unscaled.
        return dist
    return dist / positive.mean()


def rkd_distance_loss(student_emb, teacher_emb):
    """
    Relational Knowledge Distillation - Distance Loss
    Forces student to mimic the pairwise distances found in the teacher's embedding space.

    Args:
        student_emb: 2-D tensor of student embeddings, one row per sample
        teacher_emb: 2-D tensor of teacher embeddings with the same number of rows

    Returns:
        Scalar Huber (smooth L1) loss between the two batch x batch distance
        matrices, each normalised by the mean of its positive entries so the
        loss does not depend on the overall scale of either embedding space.
        When a matrix has no positive entry (batch of one, or all embeddings
        identical) it is left unnormalised, so the result stays finite
        instead of becoming NaN.
    """
    t_dist_norm = _mean_normalized_pairwise_distances(teacher_emb)
    s_dist_norm = _mean_normalized_pairwise_distances(student_emb)

    # The loss is the Huber loss (smooth L1) between the normalized distance matrices
    return F.smooth_l1_loss(s_dist_norm, t_dist_norm)


def get_features(model, x, is_teacher=False):
    """Return the activation that feeds `model.fc3` (the embedding used for RKD).

    The model type is detected with isinstance; `is_teacher` is accepted but
    not consulted. For a TeacherNet the path is fc1 -> ReLU -> dropout ->
    fc2 -> ReLU. For a StudentNet it is fc1 -> bn1 -> ReLU followed by
    fc2 -> bn2 -> ReLU plus the skip connection. Any other model type gets
    the flattened input back unchanged.
    """
    flat = x.view(-1, 784)

    if isinstance(model, TeacherNet):
        hidden = model.dropout(F.relu(model.fc1(flat)))
        return F.relu(model.fc2(hidden))

    if isinstance(model, StudentNet):
        first_hidden = F.relu(model.bn1(model.fc1(flat)))
        # Residual add mirrors StudentNet.forward: block output + block input.
        return F.relu(model.bn2(model.fc2(first_hidden))) + first_hidden

    return flat


def train_student_rkd(student, teacher, train_loader, epochs=10, lr=0.001, beta=1.0):
    """
    Train `student` on cross-entropy plus `beta` times the RKD distance loss.

    Teacher embeddings are computed without gradients via `get_features`;
    student embeddings come from the same helper and are passed through
    `student.fc3` to obtain logits. Per epoch, the mean total, RKD and task
    losses are printed. Returns the trained student.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()  # teacher only provides target embeddings

    optimizer = optim.Adam(student.parameters(), lr=lr)

    print(f"Training RKD Student (Beta={beta})...")

    student.train()
    for epoch in range(1, epochs + 1):
        sum_total = 0
        sum_rkd = 0
        sum_task = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher embeddings (targets for the relational loss).
            with torch.no_grad():
                teacher_emb = get_features(teacher, inputs, is_teacher=True)

            # Student embeddings, then logits through the final layer.
            student_emb = get_features(student, inputs, is_teacher=False)
            student_logits = student.fc3(student_emb)

            task_loss = F.cross_entropy(student_logits, labels)
            rkd_term = rkd_distance_loss(student_emb, teacher_emb)
            loss = task_loss + (beta * rkd_term)

            loss.backward()
            optimizer.step()

            sum_total += loss.item()
            sum_rkd += rkd_term.item()
            sum_task += task_loss.item()

        n_batches = len(train_loader)
        avg_loss = sum_total / n_batches
        avg_rkd = sum_rkd / n_batches
        avg_task = sum_task / n_batches
        print(f"RKD Epoch {epoch}/{epochs} | Total: {avg_loss:.4f} | RKD: {avg_rkd:.4f} | Task: {avg_task:.4f}")

    return student


In [None]:

# --- 2. Main Execution ---

def main_rkd():
    """Train, evaluate and save an RKD student using the pre-trained teacher.

    Builds the MNIST loaders, loads the teacher checkpoint from Google Drive
    (printing a message and returning early if the file is missing), trains a
    StudentNet with `train_student_rkd` (20 epochs, beta=100), prints test
    accuracy and saves the weights under ./saved_models.
    """
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Transforms
    transform_train = transforms.Compose([
        transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=2, pin_memory=True)

    # Load Pre-trained Teacher
    teacher_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'

    if not os.path.exists(teacher_path):
        print("Teacher model not found! Please run the first section to train/save the teacher.")
        return

    print(f"Loading teacher from {teacher_path}")
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on GPU also loads on a CPU-only machine.
    checkpoint = torch.load(teacher_path, map_location=device)
    teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
    teacher.load_state_dict(checkpoint['model_state_dict'])

    # Initialize RKD Student (Same architecture as Normal Student: 2x800)
    student_rkd = StudentNet().to(device)

    # Train with RKD
    student_rkd = train_student_rkd(student_rkd, teacher, train_loader, epochs=20, lr=0.001, beta=100)

    # Evaluate
    rkd_acc, rkd_err = evaluate(student_rkd, test_loader)
    print("\n" + "=" * 60)
    print(f"RKD Student Results: {rkd_acc:.2f}% accuracy ({rkd_err} errors)")
    print("=" * 60)

    # Save
    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)
    torch.save({
        'model_state_dict': student_rkd.state_dict(),
    }, f'{save_dir}/student_rkd_model_after_algo_change.pth')
    print(f"✓ RKD model saved to {save_dir}/")

if __name__ == "__main__":
    main_rkd()

Using device: cuda
Loading teacher from /content/drive/MyDrive/Distillations/teacher_model.pth
Training RKD Student (Beta=100)...
RKD Epoch 1/20 | Total: 0.7007 | RKD: 0.0044 | Task: 0.2559
RKD Epoch 2/20 | Total: 0.2843 | RKD: 0.0019 | Task: 0.0986
RKD Epoch 3/20 | Total: 0.2317 | RKD: 0.0015 | Task: 0.0791
RKD Epoch 4/20 | Total: 0.2008 | RKD: 0.0013 | Task: 0.0669
RKD Epoch 5/20 | Total: 0.1789 | RKD: 0.0012 | Task: 0.0594
RKD Epoch 6/20 | Total: 0.1674 | RKD: 0.0011 | Task: 0.0549
RKD Epoch 7/20 | Total: 0.1563 | RKD: 0.0011 | Task: 0.0505
RKD Epoch 8/20 | Total: 0.1522 | RKD: 0.0010 | Task: 0.0490
RKD Epoch 9/20 | Total: 0.1380 | RKD: 0.0009 | Task: 0.0436
RKD Epoch 10/20 | Total: 0.1378 | RKD: 0.0009 | Task: 0.0438
RKD Epoch 11/20 | Total: 0.1299 | RKD: 0.0009 | Task: 0.0402
RKD Epoch 12/20 | Total: 0.1233 | RKD: 0.0008 | Task: 0.0394
RKD Epoch 13/20 | Total: 0.1208 | RKD: 0.0008 | Task: 0.0381
RKD Epoch 14/20 | Total: 0.1203 | RKD: 0.0008 | Task: 0.0370
RKD Epoch 15/20 | Total: 

# McNemar Test - comparing models before and after the algorithm change to check whether the difference is significant

## Load the architectures from before the algorithm change

In [None]:
################## 1 ########################

%%writefile modelsbefore.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class TeacherNet(nn.Module):
    """Teacher MLP: 784 inputs -> two hidden layers of 1200 ReLU units -> 10 logits.

    Dropout (probability ``dropout_rate``) is applied after each hidden ReLU.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)`` for input ``x``.

        ``x`` is reshaped to rows of 784 values before the first layer.
        """
        hidden = x.view(-1, 784)
        # Both hidden layers follow the same linear -> ReLU -> dropout recipe.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        logits = self.fc3(hidden)
        probs = F.softmax(logits / temperature, dim=1)
        return probs, logits


class StudentNetbefore(nn.Module):
    """Original student MLP: 784 inputs -> two hidden layers of 800 ReLU units -> 10 logits.

    No BatchNorm, dropout or skip connections.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 800)
        self.fc2 = nn.Linear(800, 800)
        self.fc3 = nn.Linear(800, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)`` for input ``x``."""
        hidden = x.view(-1, 784)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        logits = self.fc3(hidden)
        probs = F.softmax(logits / temperature, dim=1)
        return probs, logits


def distillation_loss(student_logits, teacher_soft_targets, hard_targets,
                      temperature, alpha=0.5):
    """
    Blend a hard-label loss with a temperature-softened teacher-matching loss.

    Args:
        student_logits: Raw (unnormalised) outputs of the student
        teacher_soft_targets: Teacher probabilities computed at `temperature`
        hard_targets: Ground-truth class indices
        temperature: Softening temperature applied to the student logits
        alpha: Weight of the hard-label term; the soft term gets (1 - alpha)

    Returns:
        alpha * cross_entropy + (1 - alpha) * T^2 * KL(teacher || student)
    """
    # Student log-probabilities at the same temperature as the teacher targets.
    log_probs_at_t = F.log_softmax(student_logits / temperature, dim=1)

    # KL term, multiplied by T^2 to compensate for the 1/T^2 gradient scaling.
    kd_term = F.kl_div(log_probs_at_t, teacher_soft_targets, reduction='batchmean') * (temperature ** 2)

    # Ordinary cross-entropy on the unscaled logits (temperature 1).
    ce_term = F.cross_entropy(student_logits, hard_targets)

    return alpha * ce_term + (1 - alpha) * kd_term


Writing modelsbefore.py


In [None]:
################## 2 ########################

import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from modelsbefore import TeacherNet, StudentNetbefore, distillation_loss


def train_teacher(model, train_loader, epochs=10, lr=0.001):
    """Fit the teacher with Adam (weight decay 1e-4) on a cross-entropy objective.

    The model is moved to CUDA when available (CPU otherwise), put in training
    mode, and returned after `epochs` passes over `train_loader`. The mean
    batch loss is printed once per epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    model.train()
    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # The model returns (probabilities, logits); only the logits feed the loss.
            _, logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(logits, labels)

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Teacher Epoch {epoch}/{epochs}, Loss: {mean_loss:.4f}")

    return model


def train_student_normal(model, train_loader, epochs=10, lr=0.001):
    """Baseline student training: Adam on plain cross-entropy, no teacher involved.

    The model is moved to CUDA when available (CPU otherwise), put in training
    mode, and returned after `epochs` passes over `train_loader`. The mean
    batch loss is printed once per epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(1, epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # The model returns (probabilities, logits); only the logits feed the loss.
            _, logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(logits, labels)

            loss.backward()
            optimizer.step()
            running += loss.item()

        mean_loss = running / len(train_loader)
        print(f"Student (normal) Epoch {epoch}/{epochs}, Loss: {mean_loss:.4f}")

    return model


def train_student_distilled(student, teacher, train_loader, temperature=20,
                           alpha=0.1, epochs=10, lr=0.001):
    """Train ``student`` against a frozen ``teacher`` with the distillation loss.

    Args:
        student: Network being trained; its call returns ``(probabilities, logits)``.
        teacher: Network providing soft targets; switched to eval mode and
            only ever called under ``torch.no_grad()``.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        temperature: Softmax temperature passed to the teacher and to
            ``distillation_loss``.
        alpha: Weight forwarded to ``distillation_loss``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate for the student.

    Returns:
        The trained student (moved to the selected device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()  # the teacher only supplies targets; it is never updated

    optimizer = optim.Adam(student.parameters(), lr=lr)

    def run_batch(inputs, labels):
        optimizer.zero_grad()

        # Soft targets come from the teacher at the distillation temperature.
        with torch.no_grad():
            soft_targets, _teacher_logits = teacher(inputs, temperature=temperature)

        # The student is called at its default temperature; only logits are used.
        _student_probs, student_logits = student(inputs)

        batch_loss = distillation_loss(student_logits, soft_targets,
                                       labels, temperature, alpha)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    student.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in train_loader:
            total_loss += run_batch(inputs.to(device), labels.to(device))

        print(f"Student (distilled T={temperature}) Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

    return student


def evaluate(model, test_loader):
    """Measure top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: Network whose call accepts ``temperature`` and returns a
            ``(probabilities, logits)`` pair.
        test_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(accuracy, errors)`` where accuracy is a percentage and errors is
        the count of misclassified samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    seen = 0
    hits = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            probs, _logits = model(inputs, temperature=1.0)
            guesses = torch.max(probs, 1).indices
            seen += labels.size(0)
            hits += (guesses == labels).sum().item()

    return 100 * hits / seen, seen - hits


In [None]:
################## 3 ########################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class FitNetStudentbefore(nn.Module):
    """Thin, deep student: four 300-unit ReLU hidden layers and a 10-way head.

    The activation after the second hidden layer (``fc2``) is the "guided"
    layer exposed by ``forward_with_hint``.
    """
    def __init__(self):
        super().__init__()
        width = 300
        self.fc1 = nn.Linear(784, width)
        self.fc2 = nn.Linear(width, width)  # output of this layer is the guided activation
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, width)
        self.fc5 = nn.Linear(width, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)``."""
        logits, _guided = self.forward_with_hint(x)
        return F.softmax(logits / temperature, dim=1), logits

    def forward_with_hint(self, x):
        """Return ``(logits, guided)`` where ``guided`` is the post-ReLU fc2 output."""
        flat = x.view(-1, 784)
        guided = F.relu(self.fc2(F.relu(self.fc1(flat))))
        tail = F.relu(self.fc4(F.relu(self.fc3(guided))))
        return self.fc5(tail), guided


class Regressor(nn.Module):
    """Single linear layer + ReLU that lifts the student's guided activation
    (``student_dim`` features) to the teacher's hint width (``teacher_dim``)."""
    def __init__(self, student_dim=300, teacher_dim=1200):
        super().__init__()
        self.fc = nn.Linear(in_features=student_dim, out_features=teacher_dim)

    def forward(self, x):
        projected = self.fc(x)
        return F.relu(projected)



def get_teacher_hint(teacher, x):
    """Return the teacher's hint: ReLU of ``teacher.fc1`` on the flattened input.

    The input is reshaped to rows of 784 features first; no dropout is applied.
    """
    flat = x.view(-1, 784)
    return F.relu(teacher.fc1(flat))


def train_stage1_hints(student, teacher, regressor, train_loader, epochs=10, lr=0.001):
    """FitNet stage 1: regress the teacher's hint from the student's guided layer.

    The loss is the MSE between ``regressor(guided)`` and the teacher hint, so
    labels in ``train_loader`` are ignored.

    Args:
        student: Network exposing ``forward_with_hint(x) -> (logits, guided)``.
        teacher: Frozen network passed to ``get_teacher_hint``.
        regressor: Module mapping the guided activation to the hint width.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The student (moved to the selected device); the regressor is not returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    regressor = regressor.to(device)
    teacher.eval()

    # The optimizer is handed every student parameter plus the regressor's.
    # Only parameters that feed the guided activation get a gradient from
    # this loss; the rest are passed in but see no gradient here.
    trainable = [*student.parameters(), *regressor.parameters()]
    optimizer = optim.Adam(trainable, lr=lr)

    student.train()
    regressor.train()

    for epoch in range(epochs):
        total_loss = 0
        for inputs, _labels in train_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()

            # Target: teacher's first hidden layer, computed without autograd.
            with torch.no_grad():
                hint_target = get_teacher_hint(teacher, inputs)

            # Prediction: student's guided activation lifted by the regressor.
            _logits, guided = student.forward_with_hint(inputs)
            hint_pred = regressor(guided)

            batch_loss = F.mse_loss(hint_pred, hint_target)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Stage 1 Epoch {epoch+1}/{epochs}, Hint Loss: {total_loss/len(train_loader):.4f}")

    return student


def train_stage2_kd(student, teacher, train_loader, temperature=20, alpha=0.1, epochs=10, lr=0.001):
    """FitNet stage 2: knowledge-distillation training of the whole student.

    Loss per batch is ``alpha * CE(student_logits, labels) + (1 - alpha) *
    T^2 * KL(teacher_soft || student_soft)``, with both soft distributions
    taken at ``temperature``.

    Args:
        student: Network being trained; returns ``(probabilities, logits)``.
        teacher: Frozen network; called with ``temperature`` under ``no_grad``.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        temperature: Distillation temperature ``T``.
        alpha: Weight of the hard-label term.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The trained student (moved to the selected device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()

    optimizer = optim.Adam(student.parameters(), lr=lr)

    def kd_loss(student_logits, teacher_soft, labels):
        # kl_div expects log-probabilities for the input and probabilities
        # for the target; the T^2 factor rescales the soft-target term.
        log_student = F.log_softmax(student_logits / temperature, dim=1)
        soft_loss = F.kl_div(log_student, teacher_soft, reduction='batchmean') * (temperature ** 2)
        hard_loss = F.cross_entropy(student_logits, labels)
        return alpha * hard_loss + (1 - alpha) * soft_loss

    student.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            with torch.no_grad():
                teacher_soft, _teacher_logits = teacher(inputs, temperature=temperature)

            _student_probs, student_logits = student(inputs)

            batch_loss = kd_loss(student_logits, teacher_soft, labels)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Stage 2 Epoch {epoch+1}/{epochs}, KD Loss: {total_loss/len(train_loader):.4f}")

    return student


def train_fitnet(teacher, train_loader, epochs_stage1=10, epochs_stage2=20, lr=0.001, temp=20):
    """Run both FitNet stages on a fresh ``FitNetStudentbefore``.

    Stage 1 trains the student with a throw-away ``Regressor`` (300 -> 1200)
    to match the teacher hint; stage 2 continues with knowledge distillation
    at temperature ``temp`` and a fixed ``alpha=0.1``.

    Returns:
        The student after stage 2.
    """
    print("\n[FitNet Stage 1] Training with hints from teacher...")
    fitnet = FitNetStudentbefore()
    hint_regressor = Regressor(student_dim=300, teacher_dim=1200)

    fitnet = train_stage1_hints(
        fitnet, teacher, hint_regressor, train_loader,
        epochs=epochs_stage1, lr=lr,
    )

    print("\n[FitNet Stage 2] Knowledge distillation...")
    return train_stage2_kd(
        fitnet, teacher, train_loader,
        temperature=temp, alpha=0.1, epochs=epochs_stage2, lr=lr,
    )


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from modelsbefore import TeacherNet, StudentNetbefore

# --- 1. Define Helper Functions Locally to Avoid Import Errors ---

def evaluate(model, test_loader):
    """Measure top-1 accuracy of ``model`` over ``test_loader``.

    The model is called with ``temperature=1.0``. If it returns a tuple, the
    first element is taken as the class scores; otherwise the return value
    itself is used.

    Returns:
        ``(accuracy, errors)``: accuracy as a percentage and the number of
        misclassified samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    seen = 0
    hits = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            result = model(inputs, temperature=1.0)
            scores = result[0] if isinstance(result, tuple) else result

            guesses = torch.max(scores, 1).indices
            seen += labels.size(0)
            hits += (guesses == labels).sum().item()

    return 100 * hits / seen, seen - hits

def rkd_distance_loss(student_emb, teacher_emb):
    """
    Relational Knowledge Distillation - Distance Loss.

    Pushes the student's pairwise Euclidean distance structure towards the
    teacher's. Each (batch x batch) distance matrix is divided by the mean of
    its strictly positive entries, which makes the comparison independent of
    the absolute scale of either embedding space.

    Args:
        student_emb: Student embeddings, one row per sample.
        teacher_emb: Teacher embeddings for the same samples (same row count;
            the feature width may differ).

    Returns:
        Scalar smooth-L1 (Huber) loss between the two normalized matrices.
    """
    def normalized_distances(emb):
        dist = torch.cdist(emb, emb, p=2)
        positive = dist[dist > 0]
        # With a single sample, or when every embedding is identical, there
        # are no positive distances; the mean of an empty selection is NaN
        # and would poison the loss. The matrix is all zeros in that case,
        # so it is returned unscaled.
        if positive.numel() == 0:
            return dist
        return dist / positive.mean()

    return F.smooth_l1_loss(normalized_distances(student_emb),
                            normalized_distances(teacher_emb))


def get_featuresbefore(model, x, is_teacher=False):
    """
    Return the penultimate-layer embedding (input to ``fc3``) for ``model``.

    The path is chosen from the model's class, not from ``is_teacher``, which
    is accepted but never read. ``TeacherNet`` applies its dropout module
    after the first hidden layer only; the embedding is taken straight after
    the second ReLU. For any other model type the flattened input is returned
    unchanged.
    """
    flat = x.view(-1, 784)

    if isinstance(model, TeacherNet):
        # fc1 -> relu -> dropout -> fc2 -> relu  (embedding; no second dropout)
        hidden = model.dropout(F.relu(model.fc1(flat)))
        return F.relu(model.fc2(hidden))

    if isinstance(model, StudentNetbefore):
        # fc1 -> relu -> fc2 -> relu  (embedding)
        return F.relu(model.fc2(F.relu(model.fc1(flat))))

    return flat


def train_student_rkd(student, teacher, train_loader, epochs=10, lr=0.001, beta=1.0):
    """
    Train ``student`` with cross-entropy plus ``beta`` times the RKD distance loss.

    Embeddings come from ``get_featuresbefore``; the student's logits are
    produced by applying ``student.fc3`` to its embedding instead of calling
    the student's own ``forward``.

    Args:
        student: Network being trained (must expose ``fc3``).
        teacher: Frozen network providing reference embeddings.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        beta: Weight of the RKD term in the combined loss.

    Returns:
        The trained student (moved to the selected device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()  # teacher is a fixed reference

    optimizer = optim.Adam(student.parameters(), lr=lr)

    print(f"Training RKD Student (Beta={beta})...")

    student.train()
    for epoch in range(epochs):
        total_loss = 0
        total_rkd_loss = 0
        total_task_loss = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # Reference embeddings: no autograd graph needed for the teacher.
            with torch.no_grad():
                ref_emb = get_featuresbefore(teacher, inputs, is_teacher=True)

            # Student embedding, then its classifier head on top of it.
            emb = get_featuresbefore(student, inputs, is_teacher=False)
            logits = student.fc3(emb)

            task_loss = F.cross_entropy(logits, labels)
            rkd_loss_val = rkd_distance_loss(emb, ref_emb)
            batch_loss = task_loss + (beta * rkd_loss_val)

            batch_loss.backward()
            optimizer.step()

            total_loss += batch_loss.item()
            total_rkd_loss += rkd_loss_val.item()
            total_task_loss += task_loss.item()

        avg_loss = total_loss / len(train_loader)
        avg_rkd = total_rkd_loss / len(train_loader)
        print(f"RKD Epoch {epoch+1}/{epochs} | Total: {avg_loss:.4f} | RKD: {avg_rkd:.4f} | Task: {total_task_loss/len(train_loader):.4f}")

    return student


In [None]:
import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from modelsbefore import TeacherNet, StudentNetbefore, distillation_loss
# Colab-only: mount Google Drive so the checkpoint paths under
# /content/drive/... used further down in this cell can be read.
from google.colab import drive
drive.mount('/content/drive')


# ------------------------------------------------------------------------
# ------------------------ LOADING THE DATA ------------------------------
# ------------------------------------------------------------------------


# assuming new session so loading MNIST dataset from scratch
# Pick the GPU when one is visible, otherwise fall back to the CPU; `device`
# is reused below when models and the benchmark sample are moved.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report which GPU was allocated and its total memory in GB (bytes / 1e9).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


# data loading with jittering (up to 2 pixels in any direction)
# translate is given as a fraction of the image size: 2/28 of a 28-pixel
# MNIST image is 2 pixels. No rotation is applied (degrees=0).
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
    transforms.ToTensor(),
])
# The test split is only converted to tensors - no augmentation.
transform_test = transforms.Compose([transforms.ToTensor()])

# Only the training dataset requests a download; the test dataset is built
# from the same './data' root right afterwards.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

# Training batches are shuffled; test batches keep dataset order so the
# first test batch (used for the speed-test sample below) is deterministic.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,
                      num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                    num_workers=2, pin_memory=True)


# ------------------------------------------------------------------------
# ---------------- LOADING THE TRAINED MODELS ----------------------------
# ------------------------------------------------------------------------


teacher_checkpoint_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'
student_distilled_checkpoint_path = '/content/drive/MyDrive/Distillations/student_distilled_model.pth'
student_normal_checkpoint_path = '/content/drive/MyDrive/Distillations/student_normal_model.pth'
fitnet_student_checkpoint_path = '/content/drive/MyDrive/Distillations/fitnet_student_model.pth'
student_rkd_checkpoint_path = '/content/drive/MyDrive/Distillations/student_rkd_model.pth'


# Every checkpoint is a dict holding the weights under 'model_state_dict'
# (the teacher's also stores 'dropout_rate'). map_location=device remaps the
# stored tensors onto the device chosen above, so checkpoints written from a
# GPU session still load when this notebook runs on a CPU-only runtime.

# teacher
checkpoint = torch.load(teacher_checkpoint_path, map_location=device)
teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
teacher.load_state_dict(checkpoint['model_state_dict'])


# student KD
checkpoint = torch.load(student_distilled_checkpoint_path, map_location=device)
student_distilled_before = StudentNetbefore().to(device)
student_distilled_before.load_state_dict(checkpoint['model_state_dict'])

# student regular
checkpoint = torch.load(student_normal_checkpoint_path, map_location=device)
student_normal_before = StudentNetbefore().to(device)
student_normal_before.load_state_dict(checkpoint['model_state_dict'])

# student fitnet (FitNetStudentbefore is defined in an earlier notebook cell)
checkpoint = torch.load(fitnet_student_checkpoint_path, map_location=device)
fitnet_student_before = FitNetStudentbefore().to(device)
fitnet_student_before.load_state_dict(checkpoint['model_state_dict'])

# rkd
checkpoint = torch.load(student_rkd_checkpoint_path, map_location=device)
student_rkd_before = StudentNetbefore().to(device)
student_rkd_before.load_state_dict(checkpoint['model_state_dict'])



# ------------------------------------------------------------------------
# ---------------- Models Errors Evaluations ----------------------------
# ------------------------------------------------------------------------



# Evaluate all models
# Each evaluate() call returns (accuracy in percent, number of test errors)
# for one of the checkpoints loaded above.
print("\n" + "=" * 60)
print("EVALUATING ALL MODELS")
print("=" * 60)

teacher_acc, teacher_err = evaluate(teacher, test_loader)
print(f"✓ Teacher evaluated: {teacher_acc:.2f}% accuracy ({teacher_err} errors)")

student_normal_acc, student_normal_err = evaluate(student_normal_before, test_loader)
print(f"✓ Student (normal) evaluated: {student_normal_acc:.2f}% accuracy ({student_normal_err} errors)")

student_distilled_acc, student_distilled_err = evaluate(student_distilled_before, test_loader)
print(f"✓ Student (distilled) evaluated: {student_distilled_acc:.2f}% accuracy ({student_distilled_err} errors)")

fitnet_acc, fitnet_err = evaluate(fitnet_student_before, test_loader)
print(f"✓ FitNet evaluated: {fitnet_acc:.2f}% accuracy ({fitnet_err} errors)")

student_rkd_acc, student_rkd_err = evaluate(student_rkd_before, test_loader)
print(f"✓ RKD Student evaluated: {student_rkd_acc:.2f}% accuracy ({student_rkd_err} errors)")


# Results summary
# Side-by-side table of the measured error counts, followed by the numbers
# quoted from the papers (those are hard-coded strings, not computed here).
print("\n" + "=" * 60)
print("RESULTS SUMMARY")
print("=" * 60)
print(f"Teacher (2x1200 + dropout):        {teacher_err:3d} test errors ({teacher_acc:.2f}%)")
print(f"Student normal (2x800):            {student_normal_err:3d} test errors ({student_normal_acc:.2f}%)")
print(f"Student distilled (2x800, T=20):   {student_distilled_err:3d} test errors ({student_distilled_acc:.2f}%)")
print(f"FitNet (4x300, T=20):              {fitnet_err:3d} test errors ({fitnet_acc:.2f}%)")
# RKD student row; no paper reference value is printed for it below.
print(f"RKD (2x800):                       {student_rkd_err:3d} test errors ({student_rkd_acc:.2f}%)")
print("\nPaper reported (Hinton et al.):")
print("Teacher:           67 test errors")
print("Student normal:   146 test errors")
print("Student distilled: 74 test errors")
print("\nPaper reported (FitNets):")
print("FitNet:            51 test errors")
print("=" * 60)




# ------------------------------------------------------------------------
# ---------------- Models Inference Time ----------------------------
# ------------------------------------------------------------------------




# Inference speed test
print("\n" + "=" * 60)
print("INFERENCE SPEED TEST")
print("=" * 60)

# Take the first test batch and keep only its first image; `sample` is the
# module-level input that benchmark() below feeds to every model.
images, labels = next(iter(test_loader))
sample = images[:1].to(device)  # Single sample image
label = labels[:1]  # Corresponding label (stays on the CPU; only printed)
print(f'label is {label}')

# get sample label

def count_params(model):
    """Return the total number of elements across all of ``model``'s parameters."""
    return sum(p.numel() for p in model.parameters())

def benchmark(model, name, num_samples=10000, num_trials=5, sample_input=None, warmup=100):
    """Time repeated single-input forward passes and print a summary line.

    Args:
        model: Module to time; it is switched to eval mode and called under
            ``torch.no_grad()``.
        name: Label printed at the start of the summary line.
        num_samples: Forward passes per trial.
        num_trials: Number of timed trials; mean and (population) standard
            deviation are taken across trials.
        sample_input: Tensor fed to the model. When ``None`` the module-level
            ``sample`` defined earlier in this cell is used.
        warmup: Untimed forward passes executed before the first trial.

    Returns:
        ``(mean_ms, params)``: mean milliseconds per forward pass and the
        model's parameter count.
    """
    if sample_input is None:
        sample_input = sample  # notebook-global single test image

    model.eval()
    params = count_params(model)
    use_cuda = torch.cuda.is_available()
    times = []

    with torch.no_grad():
        # Warm-up passes so one-off start-up costs are not timed.
        for _ in range(warmup):
            model(sample_input)
        if use_cuda:
            torch.cuda.synchronize()

        for _trial in range(num_trials):
            # CUDA kernels launch asynchronously: synchronize before and after
            # the timed loop so the interval covers the actual GPU work.
            if use_cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time,
            # which can jump if the wall clock is adjusted.
            start = time.perf_counter()
            for _ in range(num_samples):
                model(sample_input)
            if use_cuda:
                torch.cuda.synchronize()
            times.append((time.perf_counter() - start) * 1000 / num_samples)

    mean_ms = sum(times) / len(times)
    std_ms = (sum((t - mean_ms)**2 for t in times) / len(times))**0.5
    print(f"{name:30s}: {mean_ms:.3f} ± {std_ms:.3f} ms/sample  ({params:,} params)")
    return mean_ms, params

# Time every loaded model on the same single-image input; each call returns
# (mean ms per forward pass, parameter count).
t_time, t_params = benchmark(teacher, "Teacher (2x1200)")
s_time, s_params = benchmark(student_normal_before, "Student Normal (2x800)")
d_time, d_params = benchmark(student_distilled_before, "Student Distilled (2x800)")
f_time, f_params = benchmark(fitnet_student_before, "FitNet (4x300)")
# NOTE(review): the RKD timings are measured but not used in the comparison
# printed below.
r_time, r_params = benchmark(student_rkd_before, "RKD (2x800)")

# Speedup is teacher time / model time (values > 1 mean faster than the
# teacher); compression is the model's parameter count as a percentage of
# the reference model's.
print("\n" + "-" * 60)
print("COMPRESSION & SPEED COMPARISON")
print("-" * 60)
print(f"Student Normal vs Teacher:")
print(f"  Speedup: {t_time/s_time:.2f}x faster")
print(f"  Compression: {s_params/t_params*100:.1f}% params")

print(f"\nStudent Distilled vs Teacher:")
print(f"  Speedup: {t_time/d_time:.2f}x faster")
print(f"  Compression: {d_params/t_params*100:.1f}% params")

print(f"\nFitNet vs Teacher:")
print(f"  Speedup: {t_time/f_time:.2f}x faster")
print(f"  Compression: {f_params/t_params*100:.1f}% params ({t_params/f_params:.1f}x smaller)")

# Here the distilled 2x800 student is the reference instead of the teacher.
print(f"\nFitNet vs Student Distilled:")
print(f"  Speedup: {d_time/f_time:.2f}x")
print(f"  Compression: {f_params/d_params*100:.1f}% params")
print("=" * 60)

Mounted at /content/drive
Using device: cuda
GPU: Tesla T4
GPU Memory: 15.83 GB


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.61MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 132kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.26MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.74MB/s]



EVALUATING ALL MODELS
✓ Teacher evaluated: 98.95% accuracy (105 errors)
✓ Student (normal) evaluated: 98.77% accuracy (123 errors)
✓ Student (distilled) evaluated: 98.85% accuracy (115 errors)
✓ FitNet evaluated: 98.83% accuracy (117 errors)
✓ RKD Student evaluated: 98.82% accuracy (118 errors)

RESULTS SUMMARY
Teacher (2x1200 + dropout):        105 test errors (98.95%)
Student normal (2x800):            123 test errors (98.77%)
Student distilled (2x800, T=20):   115 test errors (98.85%)
FitNet (4x300, T=20):              117 test errors (98.83%)
RKD (2x800):                       118 test errors (98.82%)

Paper reported (Hinton et al.):
Teacher:           67 test errors
Student normal:   146 test errors
Student distilled: 74 test errors

Paper reported (FitNets):
FitNet:            51 test errors

INFERENCE SPEED TEST
label is tensor([7])
Teacher (2x1200)              : 0.183 ± 0.026 ms/sample  (2,395,210 params)
Student Normal (2x800)        : 0.157 ± 0.025 ms/sample  (1,276,810 par

## Load architectures - "after" versions (students with BatchNorm + residual connections)

In [None]:
%%writefile modelsafter.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class TeacherNet(nn.Module):
    """Teacher MLP: 784 -> 1200 -> 1200 -> 10 with ReLU and dropout after
    each hidden layer."""
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        hidden = 1200
        self.fc1 = nn.Linear(784, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, 10)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)`` for flattened input."""
        hidden = x.view(-1, 784)
        # Both hidden layers share the same pattern: linear -> ReLU -> dropout.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        logits = self.fc3(hidden)
        probs = F.softmax(logits / temperature, dim=1)
        return probs, logits


class StudentNetafter(nn.Module):
    """Student MLP with two 800-unit hidden layers, BatchNorm after each
    hidden linear layer, and a residual connection around the second one."""
    def __init__(self):
        super().__init__()
        hidden = 800
        self.fc1 = nn.Linear(784, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)   # normalizes fc1 output before ReLU
        self.fc2 = nn.Linear(hidden, hidden)
        self.bn2 = nn.BatchNorm1d(hidden)   # normalizes fc2 output before ReLU
        self.fc3 = nn.Linear(hidden, 10)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)`` for flattened input."""
        flat = x.view(-1, 784)
        first = F.relu(self.bn1(self.fc1(flat)))
        # Residual: the first hidden activation is added after the second ReLU.
        second = F.relu(self.bn2(self.fc2(first))) + first
        logits = self.fc3(second)
        return F.softmax(logits / temperature, dim=1), logits


def distillation_loss(student_logits, teacher_soft_targets, hard_targets,
                      temperature, alpha=0.5):
    """
    Weighted sum of a hard-label loss and a temperature-softened KD loss.

    Args:
        student_logits: Raw (pre-softmax) student outputs
        teacher_soft_targets: Teacher probabilities, already softened at ``temperature``
        hard_targets: Ground-truth class indices
        temperature: Temperature applied to the student logits for the soft term
        alpha: Weight of the hard-label term; the soft term gets ``1 - alpha``

    Returns:
        ``alpha * CE + (1 - alpha) * T^2 * KL`` as a scalar tensor
    """
    # kl_div takes log-probabilities as input and probabilities as target.
    log_q = F.log_softmax(student_logits / temperature, dim=1)
    kd_term = F.kl_div(log_q, teacher_soft_targets, reduction='batchmean')

    # T^2 keeps the soft-target gradient magnitude comparable across temperatures.
    kd_term = kd_term * (temperature ** 2)

    # Hard-label cross-entropy uses the unscaled logits (temperature 1).
    ce_term = F.cross_entropy(student_logits, hard_targets)

    return alpha * ce_term + (1 - alpha) * kd_term

Writing modelsafter.py


In [None]:
################## 2 ########################

import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from modelsafter import TeacherNet, StudentNetafter, distillation_loss


def train_teacher(model, train_loader, epochs=10, lr=0.001):
    """Fit the teacher on hard labels with Adam and L2 weight decay.

    Args:
        model: Network whose call returns ``(probabilities, logits)``.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The model after it was moved to the selected device and trained.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = torch.nn.functional.cross_entropy

    # Training mode keeps the model's stochastic layers (e.g. dropout) active.
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            _probs, logits = model(inputs)
            batch_loss = criterion(logits, labels)

            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Teacher Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

    return model


def train_student_normal(model, train_loader, epochs=10, lr=0.001):
    """Baseline student training: plain cross-entropy on the hard labels.

    Args:
        model: Network whose call returns ``(probabilities, logits)``.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate (no weight decay is applied here).

    Returns:
        The model after it was moved to the selected device and trained.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.functional.cross_entropy

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            _probs, logits = model(inputs)
            batch_loss = criterion(logits, labels)

            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Student (normal) Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

    return model


def train_student_distilled(student, teacher, train_loader, temperature=20,
                           alpha=0.1, epochs=10, lr=0.001):
    """Train ``student`` against a frozen ``teacher`` with the distillation loss.

    Args:
        student: Network being trained; its call returns ``(probabilities, logits)``.
        teacher: Network providing soft targets; switched to eval mode and
            only ever called under ``torch.no_grad()``.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        temperature: Softmax temperature passed to the teacher and to
            ``distillation_loss``.
        alpha: Weight forwarded to ``distillation_loss``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate for the student.

    Returns:
        The trained student (moved to the selected device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()  # the teacher only supplies targets; it is never updated

    optimizer = optim.Adam(student.parameters(), lr=lr)

    student.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # Soft targets come from the teacher at the distillation temperature.
            with torch.no_grad():
                soft_targets, _teacher_logits = teacher(inputs, temperature=temperature)

            # The student is called at its default temperature; only logits are used.
            _student_probs, student_logits = student(inputs)

            batch_loss = distillation_loss(student_logits, soft_targets,
                                           labels, temperature, alpha)

            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Student (distilled T={temperature}) Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

    return student


def evaluate(model, test_loader):
    """Measure top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: Network whose call accepts ``temperature`` and returns a
            ``(probabilities, logits)`` pair.
        test_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(accuracy, errors)`` where accuracy is a percentage and errors is
        the count of misclassified samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    seen = 0
    hits = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            probs, _logits = model(inputs, temperature=1.0)
            guesses = torch.max(probs, 1).indices
            seen += labels.size(0)
            hits += (guesses == labels).sum().item()

    return 100 * hits / seen, seen - hits


In [None]:
################## 3 ########################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class FitNetStudentafter(nn.Module):
    """Thin, deep student: four 300-unit hidden layers with BatchNorm, and
    residual connections around hidden layers 2-4.

    The output of the second hidden block is the "guided" activation exposed
    by ``forward_with_hint``.
    """
    def __init__(self):
        super().__init__()
        width = 300
        self.fc1 = nn.Linear(784, width)
        self.bn1 = nn.BatchNorm1d(width)
        self.fc2 = nn.Linear(width, width)
        self.bn2 = nn.BatchNorm1d(width)
        self.fc3 = nn.Linear(width, width)
        self.bn3 = nn.BatchNorm1d(width)
        self.fc4 = nn.Linear(width, width)
        self.bn4 = nn.BatchNorm1d(width)
        self.fc5 = nn.Linear(width, 10)

    @staticmethod
    def _residual_block(fc, bn, h):
        # linear -> batchnorm -> relu, then add the block's own input back.
        return F.relu(bn(fc(h))) + h

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(logits / temperature), logits)``."""
        logits, _guided = self.forward_with_hint(x)
        return F.softmax(logits / temperature, dim=1), logits

    def forward_with_hint(self, x):
        """Return ``(logits, guided)``; ``guided`` is the second block's output."""
        flat = x.view(-1, 784)
        stem = F.relu(self.bn1(self.fc1(flat)))  # first layer has no skip connection

        guided = self._residual_block(self.fc2, self.bn2, stem)
        deeper = self._residual_block(self.fc3, self.bn3, guided)
        deeper = self._residual_block(self.fc4, self.bn4, deeper)

        return self.fc5(deeper), guided


class Regressor(nn.Module):
    """Single linear layer + ReLU that lifts the student's guided activation
    (``student_dim`` features) to the teacher's hint width (``teacher_dim``)."""
    def __init__(self, student_dim=300, teacher_dim=1200):
        super().__init__()
        self.fc = nn.Linear(in_features=student_dim, out_features=teacher_dim)

    def forward(self, x):
        projected = self.fc(x)
        return F.relu(projected)



def get_teacher_hint(teacher, x):
    """Return the teacher's hint: ReLU of ``teacher.fc1`` on the flattened input.

    The input is reshaped to rows of 784 features first; no dropout is applied.
    """
    flat = x.view(-1, 784)
    return F.relu(teacher.fc1(flat))


def train_stage1_hints(student, teacher, regressor, train_loader, epochs=10, lr=0.001):
    """FitNet stage 1: regress the teacher's hint from the student's guided layer.

    The loss is the MSE between ``regressor(guided)`` and the teacher hint, so
    labels in ``train_loader`` are ignored.

    Args:
        student: Network exposing ``forward_with_hint(x) -> (logits, guided)``.
        teacher: Frozen network passed to ``get_teacher_hint``.
        regressor: Module mapping the guided activation to the hint width.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The student (moved to the selected device); the regressor is not returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    regressor = regressor.to(device)
    teacher.eval()

    # The optimizer is handed every student parameter plus the regressor's.
    # Only parameters that feed the guided activation get a gradient from
    # this loss; the rest are passed in but see no gradient here.
    trainable = [*student.parameters(), *regressor.parameters()]
    optimizer = optim.Adam(trainable, lr=lr)

    student.train()
    regressor.train()

    for epoch in range(epochs):
        total_loss = 0
        for inputs, _labels in train_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()

            # Target: teacher's first hidden layer, computed without autograd.
            with torch.no_grad():
                hint_target = get_teacher_hint(teacher, inputs)

            # Prediction: student's guided activation lifted by the regressor.
            _logits, guided = student.forward_with_hint(inputs)
            hint_pred = regressor(guided)

            batch_loss = F.mse_loss(hint_pred, hint_target)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Stage 1 Epoch {epoch+1}/{epochs}, Hint Loss: {total_loss/len(train_loader):.4f}")

    return student


def train_stage2_kd(student, teacher, train_loader, temperature=20, alpha=0.1, epochs=10, lr=0.001):
    """FitNet stage 2: knowledge-distillation training of the whole student.

    Loss per batch is ``alpha * CE(student_logits, labels) + (1 - alpha) *
    T^2 * KL(teacher_soft || student_soft)``, with both soft distributions
    taken at ``temperature``.

    Args:
        student: Network being trained; returns ``(probabilities, logits)``.
        teacher: Frozen network; called with ``temperature`` under ``no_grad``.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        temperature: Distillation temperature ``T``.
        alpha: Weight of the hard-label term.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The trained student (moved to the selected device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student = student.to(device)
    teacher = teacher.to(device)
    teacher.eval()

    optimizer = optim.Adam(student.parameters(), lr=lr)

    def kd_loss(student_logits, teacher_soft, labels):
        # kl_div expects log-probabilities for the input and probabilities
        # for the target; the T^2 factor rescales the soft-target term.
        log_student = F.log_softmax(student_logits / temperature, dim=1)
        soft_loss = F.kl_div(log_student, teacher_soft, reduction='batchmean') * (temperature ** 2)
        hard_loss = F.cross_entropy(student_logits, labels)
        return alpha * hard_loss + (1 - alpha) * soft_loss

    student.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            with torch.no_grad():
                teacher_soft, _teacher_logits = teacher(inputs, temperature=temperature)

            _student_probs, student_logits = student(inputs)

            batch_loss = kd_loss(student_logits, teacher_soft, labels)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        print(f"Stage 2 Epoch {epoch+1}/{epochs}, KD Loss: {total_loss/len(train_loader):.4f}")

    return student


def train_fitnet(teacher, train_loader, epochs_stage1=10, epochs_stage2=20, lr=0.001, temp=20):
    """Run both FitNet stages and return the resulting student.

    Stage 1 trains a freshly built ``FitNetStudentafter`` against teacher
    hints through a ``Regressor`` (300 -> 1200 units); stage 2 continues with
    knowledge distillation at temperature ``temp`` and a fixed ``alpha=0.1``.

    Args:
        teacher: Trained teacher network used in both stages.
        train_loader: Training batches shared by both stages.
        epochs_stage1: Epochs of hint training.
        epochs_stage2: Epochs of distillation.
        lr: Learning rate passed to both stages.
        temp: Distillation temperature for stage 2.

    Returns:
        The student returned by the distillation stage.
    """
    print("\n[FitNet Stage 1] Training with hints from teacher...")
    # Construction order is kept: student first, then the regressor.
    fitnet = FitNetStudentafter()
    hint_regressor = Regressor(student_dim=300, teacher_dim=1200)
    hinted = train_stage1_hints(
        fitnet, teacher, hint_regressor, train_loader,
        epochs=epochs_stage1, lr=lr,
    )

    print("\n[FitNet Stage 2] Knowledge distillation...")
    return train_stage2_kd(
        hinted, teacher, train_loader,
        temperature=temp, alpha=0.1, epochs=epochs_stage2, lr=lr,
    )


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from modelsafter import TeacherNet, StudentNetafter

# --- 1. Define Helper Functions Locally to Avoid Import Errors ---

def evaluate(model, test_loader):
    """Measure classification accuracy of ``model`` over ``test_loader``.

    The model is moved to CUDA when available (CPU otherwise), put in eval
    mode, and called as ``model(data, temperature=1.0)``. If the call returns
    a tuple, its first element is used for the arg-max prediction.

    Args:
        model: Network to score.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        ``(accuracy, errors)``: accuracy as a percentage and the number of
        misclassified samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            result = model(inputs, temperature=1.0)
            # Tuple-returning models put the class scores first.
            scores = result[0] if isinstance(result, tuple) else result

            guesses = torch.max(scores, 1)[1]
            n_seen += labels.size(0)
            n_right += (guesses == labels).sum().item()

    return 100 * n_right / n_seen, n_seen - n_right

def _mean_positive_distance(dist):
    """Return the mean of the strictly positive entries of ``dist``.

    Falls back to 1 when there are none, so the caller's division leaves an
    all-zero matrix unchanged instead of producing NaN.
    """
    positive = dist[dist > 0]
    if positive.numel() == 0:
        return dist.new_ones(())
    return positive.mean()


def rkd_distance_loss(student_emb, teacher_emb):
    """Relational Knowledge Distillation distance loss.

    Pushes the student to reproduce the pairwise Euclidean distance structure
    of the teacher's embeddings. Each (batch x batch) distance matrix is
    divided by the mean of its non-zero entries, which makes the comparison
    scale-invariant, and the two normalised matrices are compared with the
    Huber (smooth L1) loss.

    A batch with a single sample, or one whose embeddings are all identical,
    has no non-zero distances. In that case the normaliser is 1, so the loss
    stays finite rather than becoming NaN.

    Args:
        student_emb: Student embeddings, one row per sample.
        teacher_emb: Teacher embeddings for the same samples; the feature
            width may differ from the student's.

    Returns:
        Scalar tensor holding the smooth L1 loss.
    """
    # p=2 means Euclidean distance.
    t_dist = torch.cdist(teacher_emb, teacher_emb, p=2)
    s_dist = torch.cdist(student_emb, student_emb, p=2)

    t_dist_norm = t_dist / _mean_positive_distance(t_dist)
    s_dist_norm = s_dist / _mean_positive_distance(s_dist)

    return F.smooth_l1_loss(s_dist_norm, t_dist_norm)


def get_featuresafter(model, x, is_teacher=False):
    """Return the activations feeding ``model``'s final layer.

    The input is flattened to 784 features per row. The path taken depends on
    the model's class, not on ``is_teacher`` (which is accepted but not read):

    * ``TeacherNet``: fc1 -> ReLU -> dropout -> fc2 -> ReLU.
    * ``StudentNetafter``: fc1 -> bn1 -> ReLU, then fc2 -> bn2 -> ReLU with
      the first activation added back as a residual.
    * any other model: the flattened input is returned as-is.

    Args:
        model: Network whose hidden features are wanted.
        x: Input batch reshapeable to ``(-1, 784)``.
        is_teacher: Unused flag kept for call-site compatibility.

    Returns:
        The feature tensor described above.
    """
    flat = x.view(-1, 784)

    if isinstance(model, TeacherNet):
        hidden = model.dropout(F.relu(model.fc1(flat)))
        return F.relu(model.fc2(hidden))

    if isinstance(model, StudentNetafter):
        first = F.relu(model.bn1(model.fc1(flat)))
        # Skip connection: add the first block's output to the second's.
        return F.relu(model.bn2(model.fc2(first))) + first

    return flat


def train_student_rkd(student, teacher, train_loader, epochs=10, lr=0.001, beta=1.0):
    """Train ``student`` with cross-entropy plus the RKD distance loss.

    For every batch the teacher's features are extracted without gradients,
    the student's features are extracted with gradients and passed through
    ``student.fc3`` to obtain logits, and the objective
    ``cross_entropy + beta * rkd_distance_loss`` is minimised with Adam.

    Args:
        student: Network being trained; must expose ``fc3``.
        teacher: Reference network, switched to eval mode.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        beta: Weight of the RKD term.

    Returns:
        The trained student (moved to the selected device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student, teacher = student.to(device), teacher.to(device)
    teacher.eval()  # teacher only supplies reference embeddings

    optimizer = optim.Adam(student.parameters(), lr=lr)

    print(f"Training RKD Student (Beta={beta})...")

    student.train()
    for epoch_no in range(1, epochs + 1):
        sum_total = sum_rkd = sum_task = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Reference embeddings: no gradient flows into the teacher.
            with torch.no_grad():
                reference = get_featuresafter(teacher, inputs, is_teacher=True)

            optimizer.zero_grad()

            # Logits are produced from the same features the RKD term sees.
            embedding = get_featuresafter(student, inputs, is_teacher=False)
            logits = student.fc3(embedding)

            ce_term = F.cross_entropy(logits, labels)
            relational_term = rkd_distance_loss(embedding, reference)
            combined = ce_term + (beta * relational_term)

            combined.backward()
            optimizer.step()

            sum_total += combined.item()
            sum_rkd += relational_term.item()
            sum_task += ce_term.item()

        n_batches = len(train_loader)
        print(f"RKD Epoch {epoch_no}/{epochs} | Total: {sum_total / n_batches:.4f} | RKD: {sum_rkd / n_batches:.4f} | Task: {sum_task / n_batches:.4f}")

    return student


In [None]:
import torch
import time
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from modelsafter import TeacherNet, StudentNetafter, distillation_loss
from google.colab import drive
drive.mount('/content/drive')


# ------------------------------------------------------------------------
# ------------------------ LOADING THE DATA ------------------------------
# ------------------------------------------------------------------------


# Assumes a fresh session, so the device is selected and MNIST is loaded from
# scratch. `device` is module-level and is reused by the model-loading and
# benchmarking code below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report the first GPU's name and total memory (bytes / 1e9 -> GB).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


# Training transform: random translation of up to 2/28 of the image size on
# each axis (i.e. up to 2 pixels on a 28-pixel image), with no rotation.
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(2/28, 2/28)),
    transforms.ToTensor(),
])
# Test transform: tensor conversion only, no augmentation.
transform_test = transforms.Compose([transforms.ToTensor()])

# Only the training split requests a download; the test split is opened from
# the same ./data root.
# NOTE(review): presumably the training download also fetches the test files — confirm.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST('./data', train=False, transform=transform_test)

# Shuffled training batches of 256; ordered test batches of 1000.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,
                      num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                    num_workers=2, pin_memory=True)


# ------------------------------------------------------------------------
# ---------------- LOADING THE TRAINED MODELS ----------------------------
# ------------------------------------------------------------------------


# Checkpoint locations on the mounted Google Drive.
teacher_checkpoint_path = '/content/drive/MyDrive/Distillations/teacher_model.pth'
student_distilled_checkpoint_path = '/content/drive/MyDrive/Distillations/student_distilled_model_after_algo_change.pth'
student_normal_checkpoint_path = '/content/drive/MyDrive/Distillations/student_normal_model_after_algo_change.pth'
fitnet_student_checkpoint_path = '/content/drive/MyDrive/Distillations/fitnet_student_model_after_algo_change.pth'
student_rkd_checkpoint_path = '/content/drive/MyDrive/Distillations/student_rkd_model_after_algo_change.pth'


# Every checkpoint is a dict with a 'model_state_dict' entry (the teacher's
# also stores 'dropout_rate'). map_location=device remaps the stored tensors
# onto the active device, so checkpoints written on a GPU runtime still load
# when this session falls back to CPU.

# teacher
checkpoint = torch.load(teacher_checkpoint_path, map_location=device)
teacher = TeacherNet(dropout_rate=checkpoint['dropout_rate']).to(device)
teacher.load_state_dict(checkpoint['model_state_dict'])


# student KD
checkpoint = torch.load(student_distilled_checkpoint_path, map_location=device)
student_distilled_after = StudentNetafter().to(device)
student_distilled_after.load_state_dict(checkpoint['model_state_dict'])

# student regular
checkpoint = torch.load(student_normal_checkpoint_path, map_location=device)
student_normal_after = StudentNetafter().to(device)
student_normal_after.load_state_dict(checkpoint['model_state_dict'])

# student fitnet
# NOTE(review): FitNetStudentafter is not imported in this cell; presumably it
# is defined by an earlier notebook cell — confirm before running standalone.
checkpoint = torch.load(fitnet_student_checkpoint_path, map_location=device)
fitnet_student_after = FitNetStudentafter().to(device)
fitnet_student_after.load_state_dict(checkpoint['model_state_dict'])

# rkd
checkpoint = torch.load(student_rkd_checkpoint_path, map_location=device)
student_rkd_after = StudentNetafter().to(device)
student_rkd_after.load_state_dict(checkpoint['model_state_dict'])



# ------------------------------------------------------------------------
# ---------------- Models Errors Evaluations ----------------------------
# ------------------------------------------------------------------------



# Evaluate all models on the test set. evaluate() returns
# (accuracy in percent, number of misclassified samples) for each one.
print("\n" + "=" * 60)
print("EVALUATING ALL MODELS")
print("=" * 60)

teacher_acc, teacher_err = evaluate(teacher, test_loader)
print(f"✓ Teacher evaluated: {teacher_acc:.2f}% accuracy ({teacher_err} errors)")

student_normal_acc, student_normal_err = evaluate(student_normal_after, test_loader)
print(f"✓ Student (normal) evaluated: {student_normal_acc:.2f}% accuracy ({student_normal_err} errors)")

student_distilled_acc, student_distilled_err = evaluate(student_distilled_after, test_loader)
print(f"✓ Student (distilled) evaluated: {student_distilled_acc:.2f}% accuracy ({student_distilled_err} errors)")

fitnet_acc, fitnet_err = evaluate(fitnet_student_after, test_loader)
print(f"✓ FitNet evaluated: {fitnet_acc:.2f}% accuracy ({fitnet_err} errors)")

student_rkd_acc, student_rkd_err = evaluate(student_rkd_after, test_loader)
print(f"✓ RKD Student evaluated: {student_rkd_acc:.2f}% accuracy ({student_rkd_err} errors)")


# Results summary: one aligned row per model (error count, then accuracy),
# followed by the figures quoted from the papers for comparison.
print("\n" + "=" * 60)
print("RESULTS SUMMARY")
print("=" * 60)
print(f"Teacher (2x1200 + dropout):        {teacher_err:3d} test errors ({teacher_acc:.2f}%)")
print(f"Student normal (2x800):            {student_normal_err:3d} test errors ({student_normal_acc:.2f}%)")
print(f"Student distilled (2x800, T=20):   {student_distilled_err:3d} test errors ({student_distilled_acc:.2f}%)")
print(f"FitNet (4x300, T=20):              {fitnet_err:3d} test errors ({fitnet_acc:.2f}%)")
# RKD student row
print(f"RKD (2x800):                       {student_rkd_err:3d} test errors ({student_rkd_acc:.2f}%)")
# Reference numbers are hard-coded strings, not computed here.
print("\nPaper reported (Hinton et al.):")
print("Teacher:           67 test errors")
print("Student normal:   146 test errors")
print("Student distilled: 74 test errors")
print("\nPaper reported (FitNets):")
print("FitNet:            51 test errors")
print("=" * 60)




# ------------------------------------------------------------------------
# ---------------- Models Inference Time ----------------------------
# ------------------------------------------------------------------------




# Inference speed test
print("\n" + "=" * 60)
print("INFERENCE SPEED TEST")
print("=" * 60)

# Take the first test batch and keep only its first image. `sample` is the
# module-level input that benchmark() feeds to every model.
images, labels = next(iter(test_loader))
sample = images[:1].to(device)  # Single sample image, moved to the active device
label = labels[:1]  # Corresponding label (printed for reference only)
print(f'label is {label}')

# get sample label

def count_params(model):
    """Return the total number of elements across all of ``model``'s parameters."""
    return sum(p.numel() for p in model.parameters())


def benchmark(model, name, num_samples=10000, num_trials=5, sample_input=None, warmup=100):
    """Time repeated forward passes of ``model`` and print a one-line summary.

    Each trial runs ``num_samples`` forward passes on the same input and
    records the average latency per pass. CUDA work is synchronised before
    and after every timed region so queued kernels are included in the
    measurement. Timing uses ``time.perf_counter``, a monotonic
    high-resolution clock, so system clock adjustments cannot distort the
    short intervals being measured.

    Args:
        model: Network to time; it is switched to eval mode.
        name: Label printed in the summary line.
        num_samples: Forward passes per trial.
        num_trials: Number of timed trials.
        sample_input: Input tensor fed to the model. Defaults to the
            module-level ``sample`` prepared before the benchmarks run.
        warmup: Untimed forward passes executed before the first trial.

    Returns:
        ``(mean_ms, params)``: mean latency per forward pass in milliseconds
        across trials, and the model's parameter count.
    """
    model.eval()
    inputs = sample if sample_input is None else sample_input
    params = count_params(model)
    use_cuda = torch.cuda.is_available()
    times = []

    with torch.no_grad():
        # Warmup: let lazy initialisation and caches settle before timing.
        for _ in range(warmup):
            model(inputs)
        if use_cuda:
            torch.cuda.synchronize()

        # Multiple trials
        for _ in range(num_trials):
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(num_samples):
                model(inputs)
            if use_cuda:
                torch.cuda.synchronize()
            times.append((time.perf_counter() - start) * 1000 / num_samples)

    mean_ms = sum(times) / len(times)
    # Population standard deviation across trials.
    std_ms = (sum((t - mean_ms)**2 for t in times) / len(times))**0.5
    print(f"{name:30s}: {mean_ms:.3f} ± {std_ms:.3f} ms/sample  ({params:,} params)")
    return mean_ms, params

# Time every model; each call returns (mean ms per forward pass, parameter count).
t_time, t_params = benchmark(teacher, "Teacher (2x1200)")
s_time, s_params = benchmark(student_normal_after, "Student Normal (2x800)")
d_time, d_params = benchmark(student_distilled_after, "Student Distilled (2x800)")
f_time, f_params = benchmark(fitnet_student_after, "FitNet (4x300)")
# r_time / r_params are measured but not used in the comparison below.
r_time, r_params = benchmark(student_rkd_after, "RKD (2x800)")

# Speedup is teacher latency divided by the other model's latency (>1 means
# faster than the teacher); compression is the parameter count as a
# percentage of the reference model's.
print("\n" + "-" * 60)
print("COMPRESSION & SPEED COMPARISON")
print("-" * 60)
print(f"Student Normal vs Teacher:")
print(f"  Speedup: {t_time/s_time:.2f}x faster")
print(f"  Compression: {s_params/t_params*100:.1f}% params")

print(f"\nStudent Distilled vs Teacher:")
print(f"  Speedup: {t_time/d_time:.2f}x faster")
print(f"  Compression: {d_params/t_params*100:.1f}% params")

print(f"\nFitNet vs Teacher:")
print(f"  Speedup: {t_time/f_time:.2f}x faster")
print(f"  Compression: {f_params/t_params*100:.1f}% params ({t_params/f_params:.1f}x smaller)")

# Here the reference is the distilled student rather than the teacher.
print(f"\nFitNet vs Student Distilled:")
print(f"  Speedup: {d_time/f_time:.2f}x")
print(f"  Compression: {f_params/d_params*100:.1f}% params")
print("=" * 60)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
GPU: Tesla T4
GPU Memory: 15.83 GB

EVALUATING ALL MODELS
✓ Teacher evaluated: 98.95% accuracy (105 errors)
✓ Student (normal) evaluated: 98.73% accuracy (127 errors)
✓ Student (distilled) evaluated: 98.92% accuracy (108 errors)
✓ FitNet evaluated: 98.80% accuracy (120 errors)
✓ RKD Student evaluated: 99.08% accuracy (92 errors)

RESULTS SUMMARY
Teacher (2x1200 + dropout):        105 test errors (98.95%)
Student normal (2x800):            127 test errors (98.73%)
Student distilled (2x800, T=20):   108 test errors (98.92%)
FitNet (4x300, T=20):              120 test errors (98.80%)
RKD (2x800):                        92 test errors (99.08%)

Paper reported (Hinton et al.):
Teacher:           67 test errors
Student normal:   146 test errors
Student distilled: 74 test errors

Paper reported (FitNets):
FitNet:            51 test errors

INFEREN

## test

In [None]:
from statsmodels.stats.contingency_tables import mcnemar
import numpy as np

def get_all_predictions(model, loader, device):
    """Collect arg-max predictions and labels for every batch in ``loader``.

    The model is put in eval mode and run without gradient tracking. Inputs
    are moved to ``device``; the model itself is not moved here. When the
    model returns a tuple, its first element is used for the prediction.

    Args:
        model: Network to run.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the input batches are moved to.

    Returns:
        ``(predictions, labels)`` as NumPy arrays covering all batches in
        iteration order.
    """
    model.eval()
    predicted, truth = [], []

    with torch.no_grad():
        for batch, batch_labels in loader:
            result = model(batch.to(device))
            # Tuple-returning models put the class scores first.
            scores = result[0] if isinstance(result, tuple) else result

            best = torch.max(scores, 1)[1]
            predicted.extend(best.cpu().numpy())
            truth.extend(batch_labels.numpy())

    return np.array(predicted), np.array(truth)

def mcnemar_comparison(preds1, preds2, true_labels, name1, name2):
    """Run McNemar's test between two models' predictions and print the result.

    The test looks only at the discordant pairs: samples one model classifies
    correctly and the other does not. It is run with the chi-square
    approximation and continuity correction.

    When the two models never disagree on correctness (``b + c == 0``) the
    chi-square statistic, which divides by ``b + c``, is undefined. That case
    is reported as statistic 0 and p-value 1 (no evidence of a difference)
    without calling the test.

    Args:
        preds1: NumPy array of the first model's predicted labels.
        preds2: NumPy array of the second model's predicted labels.
        true_labels: NumPy array of ground-truth labels.
        name1: Display name of the first model.
        name2: Display name of the second model.

    Returns:
        The p-value of the test.
    """
    correct1 = (preds1 == true_labels)
    correct2 = (preds2 == true_labels)

    both_wrong = int(np.sum(~correct1 & ~correct2))
    b = int(np.sum(correct1 & ~correct2))  # model1 right, model2 wrong
    c = int(np.sum(~correct1 & correct2))  # model1 wrong, model2 right
    both_right = int(np.sum(correct1 & correct2))

    if b + c == 0:
        # No discordant pairs: the models are indistinguishable on this data.
        statistic, pvalue = 0.0, 1.0
    else:
        # Contingency table: [both wrong,        model1 only right]
        #                    [model2 only right, both right       ]
        table = [[both_wrong, b], [c, both_right]]
        result = mcnemar(table, exact=False, correction=True)
        statistic, pvalue = result.statistic, result.pvalue

    print(f"\n{name1} vs {name2}:")
    print(f"  {name1} correct, {name2} wrong: {b}")
    print(f"  {name1} wrong, {name2} correct: {c}")
    print(f"  McNemar statistic: {statistic:.3f}")
    print(f"  p-value: {pvalue:.4f}")
    print(f"  Significant: {'Yes' if pvalue < 0.05 else 'No'}")

    return pvalue

In [None]:
# Before/after significance check: for each training method, compare the
# model trained before the algorithmic change against the one trained after
# it, on the same test set.
# NOTE(review): the *_before models are not defined in this part of the
# notebook; presumably they are loaded by an earlier cell — confirm.
print("=" * 60)
print("McNEMAR'S TEST: ALGORITHMIC CHANGE IMPACT")
print("=" * 60)

# Compare Baseline before vs after algo change
# `labels` is rebound here to the full test-set labels and reused for every
# comparison below.
baseline_before_preds, labels = get_all_predictions(student_normal_before, test_loader, device)
baseline_after_preds, _ = get_all_predictions(student_normal_after, test_loader, device)
mcnemar_comparison(baseline_before_preds, baseline_after_preds, labels,
                   "Baseline (before)", "Baseline (after)")

# Compare Vanilla KD before vs after algo change
vanilla_before_preds, _ = get_all_predictions(student_distilled_before, test_loader, device)
vanilla_after_preds, _ = get_all_predictions(student_distilled_after, test_loader, device)
mcnemar_comparison(vanilla_before_preds, vanilla_after_preds, labels,
                   "Vanilla KD (before)", "Vanilla KD (after)")

# Compare FitNet before vs after
fitnet_before_preds, _ = get_all_predictions(fitnet_student_before, test_loader, device)
fitnet_after_preds, _ = get_all_predictions(fitnet_student_after, test_loader, device)
mcnemar_comparison(fitnet_before_preds, fitnet_after_preds, labels,
                   "FitNet (before)", "FitNet (after)")

# Compare RKD before vs after
# This is the cell's final expression, so its returned p-value is what the
# notebook displays after the printed report.
rkd_before_preds, _ = get_all_predictions(student_rkd_before, test_loader, device)
rkd_after_preds, _ = get_all_predictions(student_rkd_after, test_loader, device)
mcnemar_comparison(rkd_before_preds, rkd_after_preds, labels,
                   "RKD (before)", "RKD (after)")

McNEMAR'S TEST: ALGORITHMIC CHANGE IMPACT

Baseline (before) vs Baseline (after):
  Baseline (before) correct, Baseline (after) wrong: 77
  Baseline (before) wrong, Baseline (after) correct: 73
  McNemar statistic: 0.060
  p-value: 0.8065
  Significant: No

Vanilla KD (before) vs Vanilla KD (after):
  Vanilla KD (before) correct, Vanilla KD (after) wrong: 16
  Vanilla KD (before) wrong, Vanilla KD (after) correct: 23
  McNemar statistic: 0.923
  p-value: 0.3367
  Significant: No

FitNet (before) vs FitNet (after):
  FitNet (before) correct, FitNet (after) wrong: 25
  FitNet (before) wrong, FitNet (after) correct: 22
  McNemar statistic: 0.085
  p-value: 0.7705
  Significant: No

RKD (before) vs RKD (after):
  RKD (before) correct, RKD (after) wrong: 34
  RKD (before) wrong, RKD (after) correct: 60
  McNemar statistic: 6.649
  p-value: 0.0099
  Significant: Yes


np.float64(0.009921504538268757)