<a href="https://colab.research.google.com/github/alexandrumeterez/ai_notebooks/blob/master/dark_knowledge_and_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import SGD

In [0]:
class Teacher(nn.Module):
    def __init__(self, T):
        super().__init__()
        self.T = T
        self.fc1 = nn.Linear(784, 1800)
        self.fc2 = nn.Linear(1800, 1800)
        self.fc3 = nn.Linear(1800, 10)

    def forward(self, x):
        if self.training is False:
            self.T = 1
        x = self.fc1(x)
        x = F.dropout(x)
        x = F.relu(x)

        x = self.fc2(x)
        x = F.dropout(x)
        x = F.relu(x)

        x = self.fc3(x)
        # x = F.softmax(x / self.T, dim=0)
        return x / self.T

In [0]:
class Student(nn.Module):
    def __init__(self, T):
        super().__init__()
        self.T = T
        self.fc1 = nn.Linear(784, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 10)

    def forward(self, x):
        if self.training is False:
            self.T = 1

        x = self.fc1(x)
        x = F.relu(x)

        x = self.fc2(x)
        x = F.relu(x)

        x = self.fc3(x)
        # x = F.softmax(x / self.T, dim=0)
        return x / self.T

In [0]:
preprocess = transforms.Compose([
    transforms.ToTensor()
])
train_ds = MNIST(root='.', train=True, download=True, transform=preprocess)
test_ds = MNIST(root='.', train=False, download=True, transform=preprocess)

In [0]:
BATCH_SIZE = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, drop_last=True)

In [0]:
def train(model, train_loader, criterion, optimizer):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        inputs = inputs.view(BATCH_SIZE, -1)
        predictions = model(inputs)
        loss = criterion(predictions, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        _, predicted = torch.max(predictions.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return epoch_loss / len(train_loader), correct / total

def train_distill(model, teacher, train_loader, criterion, optimizer):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        inputs = inputs.view(BATCH_SIZE, -1)
        predictions = model(inputs)
        teacher_labels = teacher(inputs)

        loss = criterion(predictions, teacher_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        _, predicted = torch.max(predictions.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return epoch_loss / len(train_loader), correct / total


def test(model, test_loader, criterion):
    model.eval()
    with torch.no_grad():
        epoch_loss = 0
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            inputs = inputs.view(BATCH_SIZE, -1)
            predictions = model(inputs)
            loss = criterion(predictions, labels)
            epoch_loss += loss.item()

            _, predicted = torch.max(predictions.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return epoch_loss / len(test_loader), correct / total

In [78]:
TEMPERATURE = 20
N_EPOCHS = 10


# Train teacher
teacher = Teacher(TEMPERATURE)
teacher.to(DEVICE)

optimizer = SGD(teacher.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch + 1}')
    train_loss, train_acc = train(teacher, train_loader, criterion, optimizer)
    print(f'\tTrain loss: {train_loss:.4f}, Train acc: {train_acc * 100:.2f}%')
    test_loss, test_acc = test(teacher, test_loader, criterion)
    print(f'\tTest loss: {test_loss:.4f}, Test acc: {test_acc * 100:.2f}%')


Epoch: 1
	Train loss: 2.3023, Train acc: 11.50%
	Test loss: 2.2941, Test acc: 13.00%
Epoch: 2
	Train loss: 2.2114, Train acc: 37.69%
	Test loss: 2.1019, Test acc: 57.13%
Epoch: 3
	Train loss: 1.9041, Train acc: 63.98%
	Test loss: 1.6212, Test acc: 69.17%
Epoch: 4
	Train loss: 1.3241, Train acc: 72.31%
	Test loss: 1.0364, Test acc: 77.19%
Epoch: 5
	Train loss: 0.8993, Train acc: 77.96%
	Test loss: 0.7579, Test acc: 80.58%
Epoch: 6
	Train loss: 0.7065, Train acc: 80.94%
	Test loss: 0.6305, Test acc: 83.07%
Epoch: 7
	Train loss: 0.6087, Train acc: 83.10%
	Test loss: 0.5591, Test acc: 84.32%
Epoch: 8
	Train loss: 0.5485, Train acc: 84.60%
	Test loss: 0.5015, Test acc: 85.66%
Epoch: 9
	Train loss: 0.5068, Train acc: 85.48%
	Test loss: 0.4754, Test acc: 86.70%
Epoch: 10
	Train loss: 0.4772, Train acc: 86.30%
	Test loss: 0.4464, Test acc: 87.02%


In [93]:
# Train student without distilation
student = Student(TEMPERATURE)
student.to(DEVICE)

optimizer = SGD(student.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch + 1}')
    train_loss, train_acc = train(student, train_loader, criterion, optimizer)
    print(f'\tTrain loss: {train_loss:.4f}, Train acc: {train_acc * 100:.2f}%')
    test_loss, test_acc = test(student, test_loader, criterion)
    print(f'\tTest loss: {test_loss:.4f}, Test acc: {test_acc * 100:.2f}%')


Epoch: 1
	Train loss: 2.3024, Train acc: 12.06%
	Test loss: 2.3247, Test acc: 11.96%
Epoch: 2
	Train loss: 2.3012, Train acc: 11.84%
	Test loss: 2.2771, Test acc: 13.90%
Epoch: 3
	Train loss: 2.2505, Train acc: 15.71%
	Test loss: 2.2144, Test acc: 17.09%
Epoch: 4
	Train loss: 2.1762, Train acc: 17.44%
	Test loss: 2.1163, Test acc: 18.59%
Epoch: 5
	Train loss: 2.0693, Train acc: 18.65%
	Test loss: 1.9909, Test acc: 19.36%
Epoch: 6
	Train loss: 1.9410, Train acc: 23.20%
	Test loss: 1.8519, Test acc: 27.70%
Epoch: 7
	Train loss: 1.8019, Train acc: 34.57%
	Test loss: 1.7072, Test acc: 42.78%
Epoch: 8
	Train loss: 1.6481, Train acc: 46.26%
	Test loss: 1.5370, Test acc: 51.26%
Epoch: 9
	Train loss: 1.4640, Train acc: 53.31%
	Test loss: 1.3334, Test acc: 58.48%
Epoch: 10
	Train loss: 1.2544, Train acc: 61.47%
	Test loss: 1.1167, Test acc: 67.92%


In [94]:
# Train student with distilation
student = Student(TEMPERATURE)
student.to(DEVICE)

optimizer = SGD(student.parameters(), lr=0.001)
criterion = nn.MSELoss()
criterion_test = nn.CrossEntropyLoss()

for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch + 1}')
    train_loss, train_acc = train_distill(student, teacher, train_loader, criterion, optimizer)
    print(f'\tTrain loss: {train_loss:.4f}, Train acc: {train_acc * 100:.2f}%')
    test_loss, test_acc = test(student, test_loader, criterion_test)
    print(f'\tTest loss: {test_loss:.4f}, Test acc: {test_acc * 100:.2f}%')

Epoch: 1
	Train loss: 11.5638, Train acc: 9.86%
	Test loss: 2.3076, Test acc: 9.59%
Epoch: 2
	Train loss: 8.1813, Train acc: 23.03%
	Test loss: 1.4821, Test acc: 45.26%
Epoch: 3
	Train loss: 2.5671, Train acc: 61.33%
	Test loss: 0.8330, Test acc: 73.50%
Epoch: 4
	Train loss: 1.5226, Train acc: 75.77%
	Test loss: 0.7259, Test acc: 79.11%
Epoch: 5
	Train loss: 1.1257, Train acc: 80.68%
	Test loss: 0.6193, Test acc: 82.36%
Epoch: 6
	Train loss: 0.9079, Train acc: 82.43%
	Test loss: 0.5882, Test acc: 83.18%
Epoch: 7
	Train loss: 0.8350, Train acc: 83.28%
	Test loss: 0.5680, Test acc: 83.93%
Epoch: 8
	Train loss: 0.7895, Train acc: 83.94%
	Test loss: 0.5534, Test acc: 84.15%
Epoch: 9
	Train loss: 0.7507, Train acc: 84.45%
	Test loss: 0.5366, Test acc: 85.09%
Epoch: 10
	Train loss: 0.7245, Train acc: 84.85%
	Test loss: 0.5244, Test acc: 85.56%
