<a href="https://colab.research.google.com/github/FatimaJahara/Teacher-Student-Knowledge-Distillation/blob/main/vision_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Knowledge Distillation using MNIST (CNN Teacher, Linear Student)


**MNIST Handwritten Digit Classification**

*   **Teacher Model:** A Convolutional Neural Network (CNN) with two convolutional layers, dropout,
    and two fully connected layers.
*   **Student Model:** A simpler fully connected network with a single hidden layer and dropout.
*   **Distillation:** The student is trained using a combination of the standard cross-entropy loss
    (hard targets) and a distillation loss based on the Kullback-Leibler (KL) divergence between the
    student's softened output probabilities and the teacher's softened output probabilities.  The
    "softening" is achieved using a temperature parameter.
* **Baseline Comparison:** The student is also trained *without* distillation to provide a
    baseline for comparison.
* **Dataset**: MNIST
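
For reference, the combined objective described above can be written as one expression. The symbols are notation chosen here for illustration: $z_s$ and $z_t$ are the student and teacher logits, $y$ the true labels, $T$ the temperature, and $\alpha$ the mixing weight (`temperature` and `alpha` in the code). The $T^2$ factor matches the scaling used in the training function later in the notebook.

$$
\mathcal{L} = \alpha \, T^{2} \, \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\right) + (1 - \alpha)\, \mathrm{CE}(z_s, y)
$$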


**Key Concepts Illustrated:**

*   **Knowledge Distillation:**  The core concept of transferring knowledge from a teacher to a
    student network.
*   **Soft Targets:**  Using the teacher's output probabilities (softened by a temperature
    parameter) as targets for the student.
*   **Temperature Scaling:**  The use of a temperature parameter to control the "softness" of the
    probability distributions.
*   **KL Divergence:**  Measuring the difference between the teacher's and student's softened
    probability distributions.
*   **Combined Loss:**  Combining the standard cross-entropy loss with the distillation loss.
*   **Model Compression:**  The student model is significantly smaller than the teacher model,
    demonstrating potential for model compression.
*   **PyTorch:**  Implementation using the PyTorch deep learning framework.
*   **Hugging Face Transformers:** Using the transformers library for BERT models.
* **DataLoader:** Using PyTorch's DataLoader to efficiently load and batch data.
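
To make the effect of temperature concrete, here is a minimal sketch in plain Python (no PyTorch needed). The helper name `softmax_with_temperature` and the logit values are hypothetical and chosen only for illustration; they do not come from the notebook's trained models.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes (illustrative values only).
example_logits = [6.0, 2.0, 1.0]

sharp = softmax_with_temperature(example_logits, temperature=1.0)
soft = softmax_with_temperature(example_logits, temperature=5.0)

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

With these made-up logits, the top class holds nearly all of the probability mass at `temperature=1.0`, while at `temperature=5.0` the other two classes receive a visible share. That redistributed mass is the extra signal the student learns from.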

**Structure of the Notebook:**

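The code is organized into the following sections. Only the MNIST example is implemented in the code below; the BERT sentiment-analysis items describe a second example: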

1.  **Model Definition(s):** Defines the `TeacherNet` and `StudentNet` classes (MNIST) and functions
    to load/create the teacher and student BERT models (Sentiment Analysis).
2.  **Data Loading:**  Includes functions to load and pre-process the MNIST dataset and a class and
    function to prepare the sentiment analysis dataset and DataLoaders.
3.  **Training Function(s):**  Contains `train_teacher`, `train_student_with_distillation`, and
    `test`. The baseline student is trained by reusing `train_teacher`.
4.  **Main Execution:**  The `main` function orchestrates the entire process: loading data,
    training the teacher, training the student with and without distillation, and reporting results.

**How to Run:**

1.  **Install Dependencies:**
    ```bash
    pip install torch torchvision transformers tqdm
    ```
2.  **Run the Notebook:** Execute the code cells sequentially.  The MNIST dataset will be
    downloaded automatically.  The BERT model will also be downloaded.

**Expected Output:**

The notebook prints training progress (loss values) and final test accuracy for the teacher
model, the student model trained with distillation, and the student model trained without
distillation.  Distillation is meant to help the student beat the baseline, but that is not
guaranteed: in the run recorded at the end of this notebook (3 epochs, temperature 5.0, alpha 0.7),
the distilled student reached 96.98% against 97.51% for the baseline, so expect to tune the
temperature, alpha, and number of epochs. For the BERT example, training and validation metrics
will be printed.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

In [3]:
# --- 1. Define Teacher Model ---


class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

In [4]:
# --- 2. Define Student Model ---


class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # Single hidden layer: 28*28 = 784 inputs -> 128 units
        self.dropout = nn.Dropout(0.2)  # Dropout on the hidden activations
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1) # Flatten the input
        x = F.relu(self.fc1(x))
        x = self.dropout(x) # Apply dropout
        x = self.fc2(x)
        return x
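
As a rough check on how much smaller the student is, the trainable parameters of both networks can be counted by hand from the layer shapes defined above. This is a sketch in plain Python; the helper names (`conv2d_params`, `linear_params`, `conv_out`) are hypothetical, and the counts assume the default bias terms on every layer.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a square-kernel Conv2d layer."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels

def linear_params(in_features, out_features):
    """Weights plus biases of a Linear layer."""
    return in_features * out_features + out_features

def conv_out(size, kernel_size, stride=1):
    """Spatial size after an unpadded convolution."""
    return (size - kernel_size) // stride + 1

# 28 -> 26 -> 24 after two 3x3 convolutions, then 12 after 2x2 max-pooling.
side = conv_out(conv_out(28, 3), 3) // 2
flat_features = 64 * side * side

teacher_params = (conv2d_params(1, 32, 3) + conv2d_params(32, 64, 3)
                  + linear_params(flat_features, 128) + linear_params(128, 10))
student_params = linear_params(784, 128) + linear_params(128, 10)

print(flat_features)                              # 9216
print(teacher_params)                             # 1199882
print(student_params)                             # 101770
print(round(teacher_params / student_params, 1))  # 11.8
```

The `9216` matches the input size of the teacher's `fc1`, and the ratio shows the student has roughly one twelfth of the teacher's parameters.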

In [5]:
# --- 3. Data Loading (MNIST) ---

def get_data_loaders(batch_size=64):
    train_loader = DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)

    test_loader = DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=batch_size, shuffle=False)
    return train_loader, test_loader
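
The `Normalize((0.1307,), (0.3081,))` step maps each pixel `p` (already scaled to `[0, 1]` by `ToTensor`) to `(p - mean) / std`. The two constants are the values commonly cited as the mean and standard deviation of the MNIST training pixels. The sketch below reproduces that arithmetic in plain Python; the `normalize` helper is hypothetical and only mirrors the formula.

```python
MNIST_MEAN = 0.1307  # value used in the transform above
MNIST_STD = 0.3081   # value used in the transform above

def normalize(pixel):
    """Apply (pixel - mean) / std to a single pixel intensity in [0, 1]."""
    return (pixel - MNIST_MEAN) / MNIST_STD

print(round(normalize(0.0), 3))  # -0.424 (black background)
print(round(normalize(1.0), 3))  # 2.821  (brightest stroke)
```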

In [8]:
# --- 4. Training Teacher Model ---

def train_teacher(model, train_loader, optimizer, epochs=3, device='cpu'): # Reduced epochs for Colab
    model.to(device)
    model.train()
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if batch_idx % 100 == 0: # More frequent updates for quicker feedback
                print(f'Teacher Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
        print(f"Teacher Epoch {epoch} Average Loss: {running_loss / len(train_loader):.6f}")

In [10]:
# --- 5. Training Student Model ---

def train_student_with_distillation(teacher_model, student_model, train_loader, optimizer,
                                    epochs=3, temperature=5.0, alpha=0.5, device='cpu'):  # Reduced epochs
    teacher_model.to(device)
    student_model.to(device)
    teacher_model.eval()  # Teacher set to eval mode
    student_model.train()

    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        running_distillation_loss = 0.0
        running_student_loss = 0.0


        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Teacher forward pass. The teacher is frozen: torch.no_grad() skips gradient
            # tracking, and teacher_model.eval() (set above) disables dropout.
            with torch.no_grad():
                teacher_logits = teacher_model(data)

            # Student's logits (Student forward)
            student_logits = student_model(data)

            # --- Distillation Loss ---
            # A temperature T > 1 softens both distributions, exposing "dark knowledge"
            # (the teacher's view of class similarities).
            # F.kl_div expects log-probabilities as input and probabilities as target.
            # Dividing logits by T scales the soft-target gradients by roughly 1/T^2, so the
            # loss is multiplied by T^2 to compensate (Hinton et al., 2015).
            soft_targets = F.softmax(teacher_logits / temperature, dim=1)  # probs from teacher
            soft_predictions = F.log_softmax(student_logits / temperature, dim=1)  # log-probs from student
            distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction='batchmean') * (temperature ** 2)

            # --- Student Loss (Cross-Entropy) --- Hard label loss
            student_loss = F.cross_entropy(student_logits, target)

            # --- Combined Loss ---
            loss = alpha * distillation_loss + (1 - alpha) * student_loss

            # Backprop + logging
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_distillation_loss += distillation_loss.item()
            running_student_loss += student_loss.item()


            if batch_idx % 100 == 0:
                print(f'Student Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f} '
                      f'Distill Loss: {distillation_loss.item():.6f}, Student Loss: {student_loss.item():.6f}')

        print(f"Student Epoch {epoch} Average Loss: {running_loss / len(train_loader):.6f}, "
              f"Average Distillation Loss: {running_distillation_loss / len(train_loader):.6f}, "
              f"Average Student Loss: {running_student_loss / len(train_loader):.6f}")
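
To see what the loss computation inside the loop does for a single example, here is a sketch in plain Python that mirrors the same steps: softened KL divergence scaled by `temperature ** 2`, hard-label cross-entropy, then the `alpha`-weighted sum. The function `single_example_loss` and the logit values are hypothetical; the notebook itself relies on `F.kl_div` and `F.cross_entropy` over whole batches.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Return log(softmax(logits / temperature)) as a list."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_total = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_total for s in scaled]

def single_example_loss(student_logits, teacher_logits, label,
                        temperature=5.0, alpha=0.5):
    """Return (combined, soft, hard) losses for one example."""
    log_p_teacher = log_softmax(teacher_logits, temperature)
    log_p_student = log_softmax(student_logits, temperature)
    # KL(teacher || student) = sum_i p_t(i) * (log p_t(i) - log p_s(i))
    kl = sum(math.exp(lt) * (lt - ls)
             for lt, ls in zip(log_p_teacher, log_p_student))
    soft = kl * temperature ** 2
    # Cross-entropy against the true label uses the unscaled student logits.
    hard = -log_softmax(student_logits)[label]
    return alpha * soft + (1 - alpha) * hard, soft, hard

# Hypothetical logits for a three-class problem (illustrative values only).
teacher_example = [6.0, 2.0, 1.0]
student_example = [3.0, 2.5, 0.5]

total, soft, hard = single_example_loss(
    student_example, teacher_example, label=0, temperature=5.0, alpha=0.7)
print(f"total={total:.4f} soft={soft:.4f} hard={hard:.4f}")
```

Because the two terms are on different scales, the `alpha` weighting alone does not determine which term dominates the gradient. This is one reason the reported distillation loss can be much larger than the cross-entropy term.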

In [None]:
def test(model, test_loader, device='cpu'):
    model.to(device)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    return accuracy

In [11]:
# --- 6. Main Execution ---

def main():
    # Check CUDA availability
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Get data loaders
    train_loader, test_loader = get_data_loaders(batch_size=64)

    # --- Train Teacher ---
    teacher_model = TeacherNet()
    teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001) # Use Adam optimizer
    start_time = time.time()
    train_teacher(teacher_model, train_loader, teacher_optimizer, epochs=3, device=device)
    end_time = time.time()
    print(f"Teacher training time: {end_time - start_time:.2f} seconds")
    teacher_accuracy = test(teacher_model, test_loader, device=device)


    # --- Train Student with Distillation ---
    student_model = StudentNet()
    student_optimizer = optim.Adam(student_model.parameters(), lr=0.001) # Use Adam
    start_time = time.time()
    train_student_with_distillation(teacher_model, student_model, train_loader, student_optimizer,
                                    epochs=3, temperature=5.0, alpha=0.7, device=device)  # Experiment with T and alpha
    end_time = time.time()
    print(f"Student training time (with distillation): {end_time - start_time:.2f} seconds")
    distilled_student_accuracy = test(student_model, test_loader, device=device)


    # --- Train Student *without* Distillation (Baseline) ---
    student_model_no_distill = StudentNet()
    student_optimizer_no_distill = optim.Adam(student_model_no_distill.parameters(), lr=0.001)
    start_time = time.time()
    # Reuse train_teacher for the baseline student; its log lines are still labeled "Teacher".
    train_teacher(student_model_no_distill, train_loader, student_optimizer_no_distill, epochs=3, device=device)
    end_time = time.time()
    print(f"Student training time (no distillation): {end_time - start_time:.2f} seconds")
    baseline_student_accuracy = test(student_model_no_distill, test_loader, device=device)

    print("-" * 20)
    print("Results:")
    print(f"Teacher Accuracy: {teacher_accuracy:.2f}%")
    print(f"Distilled Student Accuracy: {distilled_student_accuracy:.2f}%")
    print(f"Baseline Student Accuracy (no distillation): {baseline_student_accuracy:.2f}%")
    print("-" * 20)


if __name__ == '__main__':
    main()

Using device: cpu
Teacher Epoch 1 Average Loss: 0.190221
Teacher Epoch 2 Average Loss: 0.079064
Teacher Epoch 3 Average Loss: 0.058591
Teacher training time: 577.97 seconds

Test set: Accuracy: 9907/10000 (99.07%)

Student Epoch 1 Average Loss: 3.062320, Average Distillation Loss: 4.203429, Average Student Loss: 0.399733
Student Epoch 2 Average Loss: 1.485089, Average Distillation Loss: 2.039994, Average Student Loss: 0.190309
Student Epoch 3 Average Loss: 1.234251, Average Distillation Loss: 1.703532, Average Student Loss: 0.139261
Student training time (with distillation): 214.22 seconds

Test set: Accuracy: 9698/10000 (96.98%)

Teacher Epoch 1 Average Loss: 0.290976
Teacher Epoch 2 Average Loss: 0.142389
Teacher Epoch 3 Average Loss: 0.112699
Student training time (no distillation): 58.65 seconds

Test set: Accuracy: 9751/10000 (97.51%)

--------------------
Results:
Teacher Accuracy: 99.07%
Distilled Student Accuracy: 96.98%
Baseline Student Accuracy (no distillation): 97.51%
-----