<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_Model_Distillation_for_Efficient_Deployment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Define the Teacher and Student models
class TeacherModel(nn.Module):
    """Teacher network for distillation: a torchvision ResNet-50 stored in ``self.model``.

    NOTE(review): the wrapped network keeps torchvision's default classifier
    head (1000 ImageNet classes), while the training script in this file uses
    CIFAR-10 (10 classes). Confirm the teacher is meant to be used without
    first fine-tuning it on the target dataset.
    """

    def __init__(self, pretrained=True):
        """Build the ResNet-50.

        Args:
            pretrained: when True (the default, as before) load the ImageNet
                weights; when False the network is randomly initialised, which
                avoids a weight download (useful for offline smoke tests).
        """
        super().__init__()
        # The attribute name ``model`` prefixes every state_dict key, so it is kept.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            self.model = models.resnet50(pretrained=pretrained)
        else:
            # torchvision >= 0.13 deprecates ``pretrained=``. IMAGENET1K_V1 is the
            # checkpoint ``pretrained=True`` selected, so the loaded weights are unchanged.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.model = models.resnet50(weights=weights)

    def forward(self, x):
        """Return the wrapped ResNet-50's output for the input batch ``x``."""
        return self.model(x)

class StudentModel(nn.Module):
    """Student network for distillation: a torchvision ResNet-18 stored in ``self.model``.

    NOTE(review): like the teacher, it keeps torchvision's default classifier
    head (1000 ImageNet classes) even though the training script uses
    CIFAR-10 labels; confirm this is intended.
    """

    def __init__(self, pretrained=True):
        """Build the ResNet-18.

        Args:
            pretrained: when True (the default, as before) load the ImageNet
                weights; when False the network is randomly initialised, which
                avoids a weight download (useful for offline smoke tests).
        """
        super().__init__()
        # The attribute name ``model`` prefixes every state_dict key, so it is kept.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            self.model = models.resnet18(pretrained=pretrained)
        else:
            # torchvision >= 0.13 deprecates ``pretrained=``. IMAGENET1K_V1 is the
            # checkpoint ``pretrained=True`` selected, so the loaded weights are unchanged.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.model = models.resnet18(weights=weights)

    def forward(self, x):
        """Return the wrapped ResNet-18's output for the input batch ``x``."""
        return self.model(x)

# Distillation loss function
def distillation_loss(student_outputs, teacher_outputs, labels, alpha=0.5, temperature=2):
    """Knowledge-distillation loss: soft-target KL term plus hard-label cross-entropy.

    Args:
        student_outputs: student logits, class scores along dim 1.
        teacher_outputs: teacher logits, expected to match ``student_outputs`` in shape.
        labels: class targets for ``F.cross_entropy`` on ``student_outputs``.
        alpha: weight of the soft-target term; the hard-label term is weighted ``1 - alpha``.
        temperature: divisor applied to both sets of logits before the softmax.

    Returns:
        Scalar tensor ``KL * (alpha * T * T) + CE * (1 - alpha)``, where the KL
        term is averaged over the batch (``reduction='batchmean'``).
    """
    student_log_probs = F.log_softmax(student_outputs / temperature, dim=1)
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)

    # Scaling by T*T keeps the soft-target gradients comparable in magnitude to
    # the hard-label gradients, since dividing the logits by T shrinks them.
    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    soft_term = soft_term * (alpha * temperature * temperature)

    hard_term = F.cross_entropy(student_outputs, labels) * (1. - alpha)
    return soft_term + hard_term

# Training function
def train_model(teacher_model, student_model, dataloader, optimizer, num_epochs=10, device='cpu',
                *, alpha=0.5, temperature=2):
    """Train ``student_model`` against a frozen ``teacher_model`` with the distillation loss.

    Both models are moved to ``device``. The teacher is put in eval mode and
    queried under ``torch.no_grad()``, and the student is put in train mode.
    Each step calls ``optimizer.step()``, so only the parameters the caller
    registered with ``optimizer`` are updated. After every epoch the mean
    per-batch loss and the student's running training accuracy are printed.

    Args:
        teacher_model: network whose outputs serve as soft targets.
        student_model: network being trained.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        optimizer: optimizer built by the caller (presumably over the student's parameters).
        num_epochs: number of passes over ``dataloader``.
        device: device the models and every batch are moved to.
        alpha: soft-target weight forwarded to ``distillation_loss``.
        temperature: softening temperature forwarded to ``distillation_loss``.
    """
    teacher_model.to(device)
    student_model.to(device)
    teacher_model.eval()   # inference behaviour for layers such as BatchNorm/Dropout
    student_model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0    # counted directly so iterables without __len__ also work
        correct = 0
        total = 0

        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Teacher outputs are fixed targets; no autograd graph is needed.
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)

            student_outputs = student_model(inputs)
            loss = distillation_loss(student_outputs, teacher_outputs, labels,
                                     alpha=alpha, temperature=temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # Training accuracy: argmax of the student's logits on this batch.
            _, predicted = student_outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # An empty dataloader would otherwise divide by zero here.
        avg_loss = running_loss / num_batches if num_batches else 0.0
        accuracy = 100. * correct / total if total else 0.0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, '
              f'Accuracy: {accuracy:.2f}%')

# Data preparation with augmentations.
# Both networks load ImageNet-pretrained weights, which were trained on inputs
# normalised with these per-channel statistics (torchvision's documented values).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),   # random shift of the 32x32 CIFAR-10 image
    transforms.Resize(224),                 # ResNet models expect 224x224
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# The guard keeps DataLoader worker processes (num_workers > 0) from re-running
# the download and training when they re-import this module under the "spawn"
# start method (Windows/macOS). In a notebook __name__ == "__main__", so this
# cell still runs exactly as before and the names below remain module globals.
if __name__ == "__main__":
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

    # Initialize models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the teacher is ImageNet-pretrained and never fine-tuned on
    # CIFAR-10, so its soft targets are ImageNet-class distributions rather than
    # CIFAR-10 predictions. Consider fine-tuning it on CIFAR-10 before distillation.
    teacher_model = TeacherModel().to(device)
    student_model = StudentModel().to(device)

    # Optimizer over the student's parameters only; the teacher stays frozen.
    optimizer = optim.Adam(student_model.parameters(), lr=0.001)

    # Training the student model using distillation
    train_model(teacher_model, student_model, train_loader, optimizer, num_epochs=10, device=device)

    # Nothing is saved or deployed above, so report only what actually happened.
    print("Distillation training complete.")