<a href="https://colab.research.google.com/github/Saoudyahya/Distillation-ml/blob/main/Distillation_ml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define Teacher Model (larger)
class TeacherModel(nn.Module):
    """Larger MLP teacher: 784 -> 512 -> 256 -> 10 with ReLU between layers.

    Takes already-flattened inputs with 784 features (the training loops
    flatten each 28x28 image first) and returns unnormalized class logits.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 512, 256, 10)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # The attribute name `model` and this layer order determine the
        # state_dict keys (model.0.*, model.2.*, model.4.*); keep them stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.model(x)
        return logits

# Define Student Model (smaller)
class StudentModel(nn.Module):
    """Compact MLP student: 784 -> 128 -> 10 with a single ReLU.

    Same input/output contract as the teacher (784 flattened features in,
    10 logits out) but with far fewer parameters.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 128, 10)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # Keep `model` and the layer order: they fix the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.model(x)
        return logits

# Define Distillation Loss
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing a hard-label task loss with soft targets.

    The soft-target term is the KL divergence between the temperature-softened
    teacher and student distributions. It is multiplied by ``temperature ** 2``
    so its gradient scale stays comparable to the task loss as the temperature
    changes.

    Args:
        temperature: Softening temperature; must be > 0.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, temperature):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature
        # 'batchmean' divides the summed KL by the batch size, which is the
        # reduction that matches the definition of KL divergence.
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, true_labels, alpha=0.5, task_loss_fn=None):
        """Return ``alpha * task_loss + (1 - alpha) * T**2 * KL(teacher || student)``.

        Args:
            student_logits: Raw student outputs, classes along dim 1.
            teacher_logits: Raw teacher outputs, same shape as ``student_logits``.
            true_labels: Hard targets passed to ``task_loss_fn``.
            alpha: Weight of the task loss, in [0, 1]. Defaults to 0.5.
            task_loss_fn: Callable ``(logits, labels) -> loss``. Defaults to
                ``nn.CrossEntropyLoss()``.

        Raises:
            ValueError: if ``alpha`` is outside [0, 1].
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if task_loss_fn is None:
            task_loss_fn = nn.CrossEntropyLoss()
        T = self.temperature
        teacher_probs = torch.softmax(teacher_logits / T, dim=1)
        # KLDivLoss expects its input as log-probabilities and its target as
        # probabilities, hence log_softmax here vs. softmax above.
        student_log_probs = torch.log_softmax(student_logits / T, dim=1)
        distillation_loss = self.kl_div_loss(student_log_probs, teacher_probs) * (T ** 2)
        task_loss = task_loss_fn(student_logits, true_labels)
        return alpha * task_loss + (1 - alpha) * distillation_loss

# Load Dataset
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping ToTensor's
# [0, 1] pixel range onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)

# Train Teacher Model
teacher = TeacherModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=0.001)

# A freshly constructed module already starts in training mode; made explicit.
teacher.train()
for epoch in range(5):
    for images, labels in train_loader:
        # Flatten each image into a 28*28 = 784 feature vector for the MLP.
        images = images.flatten(start_dim=1)
        optimizer.zero_grad()
        outputs = teacher(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
print("Teacher training complete.")

# Train Student Model Using Distillation
student = StudentModel()
distillation_loss_fn = DistillationLoss(temperature=3.0)
task_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=0.001)

# The teacher only supplies soft targets here, so it runs in inference mode.
# This keeps dropout / batch-norm layers (if ever added) from perturbing the
# targets. The student's training mode is made explicit.
teacher.eval()
student.train()

for epoch in range(5):
    for images, labels in train_loader:
        images = images.view(-1, 28*28)  # flatten each image to 784 features
        # The teacher is not being optimized: skip building its autograd graph.
        with torch.no_grad():
            teacher_outputs = teacher(images)
        student_outputs = student(images)
        loss = distillation_loss_fn(student_outputs, teacher_outputs, labels, alpha=0.5, task_loss_fn=task_loss_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
print("Student training complete.")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 39.6MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.22MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 10.8MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.41MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Teacher training complete.
Student training complete.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define Teacher Model (larger)
class TeacherModel(nn.Module):
    """Larger MLP teacher: 784 -> 512 -> 256 -> 10 with ReLU between layers.

    Takes already-flattened inputs with 784 features (the training loops
    flatten each 28x28 image first) and returns unnormalized class logits.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 512, 256, 10)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # The attribute name `model` and this layer order determine the
        # state_dict keys (model.0.*, model.2.*, model.4.*); keep them stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.model(x)
        return logits

# Define Student Model (smaller)
class StudentModel(nn.Module):
    """Compact MLP student: 784 -> 128 -> 10 with a single ReLU.

    Same input/output contract as the teacher (784 flattened features in,
    10 logits out) but with far fewer parameters.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 128, 10)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # Keep `model` and the layer order: they fix the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.model(x)
        return logits

# Define Distillation Loss
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing a hard-label task loss with soft targets.

    The soft-target term is the KL divergence between the temperature-softened
    teacher and student distributions. It is multiplied by ``temperature ** 2``
    so its gradient scale stays comparable to the task loss as the temperature
    changes.

    Args:
        temperature: Softening temperature; must be > 0.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, temperature):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature
        # 'batchmean' divides the summed KL by the batch size, which is the
        # reduction that matches the definition of KL divergence.
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, true_labels, alpha=0.5, task_loss_fn=None):
        """Return ``alpha * task_loss + (1 - alpha) * T**2 * KL(teacher || student)``.

        Args:
            student_logits: Raw student outputs, classes along dim 1.
            teacher_logits: Raw teacher outputs, same shape as ``student_logits``.
            true_labels: Hard targets passed to ``task_loss_fn``.
            alpha: Weight of the task loss, in [0, 1]. Defaults to 0.5.
            task_loss_fn: Callable ``(logits, labels) -> loss``. Defaults to
                ``nn.CrossEntropyLoss()``.

        Raises:
            ValueError: if ``alpha`` is outside [0, 1].
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if task_loss_fn is None:
            task_loss_fn = nn.CrossEntropyLoss()
        T = self.temperature
        teacher_probs = torch.softmax(teacher_logits / T, dim=1)
        # KLDivLoss expects its input as log-probabilities and its target as
        # probabilities, hence log_softmax here vs. softmax above.
        student_log_probs = torch.log_softmax(student_logits / T, dim=1)
        distillation_loss = self.kl_div_loss(student_log_probs, teacher_probs) * (T ** 2)
        task_loss = task_loss_fn(student_logits, true_labels)
        return alpha * task_loss + (1 - alpha) * distillation_loss

# Load Dataset
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping ToTensor's
# [0, 1] pixel range onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
# Only the training split is shuffled; evaluation order is fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

# Train Teacher Model
teacher = TeacherModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=0.001)

# A freshly constructed module already starts in training mode; made explicit.
teacher.train()
for epoch in range(5):
    for images, labels in train_loader:
        # Flatten each image into a 28*28 = 784 feature vector for the MLP.
        images = images.flatten(start_dim=1)
        optimizer.zero_grad()
        outputs = teacher(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
print("Teacher training complete.")

# Train Student Model Using Distillation
student = StudentModel()
distillation_loss_fn = DistillationLoss(temperature=3.0)
task_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=0.001)

# The teacher only supplies soft targets here, so it runs in inference mode.
# This keeps dropout / batch-norm layers (if ever added) from perturbing the
# targets. The student's training mode is made explicit.
teacher.eval()
student.train()

for epoch in range(5):
    for images, labels in train_loader:
        images = images.view(-1, 28*28)  # flatten each image to 784 features
        # The teacher is not being optimized: skip building its autograd graph.
        with torch.no_grad():
            teacher_outputs = teacher(images)
        student_outputs = student(images)
        loss = distillation_loss_fn(student_outputs, teacher_outputs, labels, alpha=0.5, task_loss_fn=task_loss_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
print("Student training complete.")

# Evaluate a model
def evaluate(model, data_loader):
    """Return the top-1 accuracy of ``model`` on ``data_loader``, in percent.

    Each batch's images are flattened to 28*28 = 784 features before the
    forward pass, matching the training loops. The model is put in eval mode
    for scoring. Its previous train/eval mode is restored afterwards, so
    calling this mid-training does not silently leave it in eval mode.

    Args:
        model: Module mapping (N, 784) inputs to (N, num_classes) scores.
        data_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in data_loader:
                images = images.view(-1, 28 * 28)
                # argmax over the class dimension; ties resolve to the first max.
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy")
    return 100 * correct / total

# Evaluate Teacher and Student Models
# Score both networks on the held-out test split (teacher first), then report.
teacher_accuracy, student_accuracy = (
    evaluate(net, test_loader) for net in (teacher, student)
)

print(f"Teacher Model Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Accuracy: {student_accuracy:.2f}%")


Teacher training complete.
Student training complete.
Teacher Model Accuracy: 97.11%
Student Model Accuracy: 96.52%
