<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Efficient_GAN_Inference_with_Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define a simple teacher and student model
class TeacherModel(nn.Module):
    """Single fully connected layer mapping a flattened sample to class logits.

    Args:
        in_features: Number of values per sample once flattened. The default
            of 784 matches a 28x28 single-channel image.
        num_classes: Number of logits produced per sample (default 10).
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits for a batch.

        Args:
            x: Batch whose first dimension indexes samples; all remaining
                dimensions are collapsed into one before the linear layer.

        Returns:
            Tensor with one row of ``num_classes`` logits per sample.
        """
        # flatten() falls back to a copy when needed, so unlike view() it
        # accepts non-contiguous tensors and zero-sized batches.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

class StudentModel(nn.Module):
    """Single fully connected layer mapping a flattened sample to class logits.

    Args:
        in_features: Number of values per sample once flattened. The default
            of 784 matches a 28x28 single-channel image.
        num_classes: Number of logits produced per sample (default 10).
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits for a batch.

        Args:
            x: Batch whose first dimension indexes samples; all remaining
                dimensions are collapsed into one before the linear layer.

        Returns:
            Tensor with one row of ``num_classes`` logits per sample.
        """
        # flatten() falls back to a copy when needed, so unlike view() it
        # accepts non-contiguous tensors and zero-sized batches.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

# Build both networks (teacher first, then student)
teacher_model = TeacherModel()
student_model = StudentModel()

# Distillation objective: mean squared error between the two models' outputs.
# Only the student's parameters are registered with the optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=student_model.parameters(), lr=1e-3)

# Preprocessing pipeline: tensor conversion followed by normalization
# with mean 0.5 and standard deviation 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, fetched into ./data if absent, served in
# shuffled mini-batches of 64
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Knowledge distillation function
def distill_knowledge(teacher_model, student_model, data_loader, optimizer, criterion, num_epochs=10):
    """Train the student to reproduce the teacher's outputs.

    For every batch the teacher is run without gradient tracking, the student
    is run normally, and ``criterion(student_output, teacher_output)`` is
    minimized with ``optimizer``. Dataset labels are ignored.

    Args:
        teacher_model: Module providing the target outputs; put in eval mode.
        student_model: Module being optimized; put in train mode.
        data_loader: Iterable yielding ``(inputs, labels)`` pairs.
        optimizer: Optimizer stepped once per batch (expected to hold the
            student's parameters).
        criterion: Callable computing the loss from student and teacher outputs.
        num_epochs: Number of full passes over ``data_loader``.

    Returns:
        A list with the mean batch loss of each epoch. An epoch in which
        ``data_loader`` yields no batches contributes ``nan``.
    """
    # Take the target device from the student itself instead of relying on a
    # module-level ``device`` global existing at call time.
    first_param = next(student_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    teacher_model.eval()  # Teacher only supplies targets
    student_model.train()  # Student is the model being optimized

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, _ in data_loader:
            images = images.to(device)

            # Teacher targets need no gradient graph
            with torch.no_grad():
                teacher_output = teacher_model(images)

            student_output = student_model(images)
            loss = criterion(student_output, teacher_output)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Report the epoch average rather than whichever batch came last;
        # an empty loader yields nan instead of an unbound ``loss``.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {mean_loss}")

    return epoch_losses

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Place both networks on the selected device
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# Run the distillation loop for ten epochs over the training loader
distill_knowledge(
    teacher_model=teacher_model,
    student_model=student_model,
    data_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=10,
)

print("Knowledge distillation completed!")