<a href="https://colab.research.google.com/github/SanyaGandhi/IIITH_SSD/blob/master/KD_MNIST_RB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
import torchvision.models as models
import torch.optim as optim
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models

In [8]:
teacher_model = models.resnet18(pretrained=True)



In [22]:
# Modify the first convolutional layer to take a single-channel input
teacher_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Freeze all layers except the last fully connected layer
for param in teacher_model.parameters():
    param.requires_grad = False
teacher_model.fc.requires_grad = True

# Modify the last fully connected layer to output 10 classes
teacher_model.fc = nn.Linear(512, 10)

In [13]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Define transformations to be applied to the input images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the training set
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Split the training set into training and validation sets
trainset, valset = torch.utils.data.random_split(trainset, [50000, 10000])

# Load the testing set
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders for training, validation, and testing sets
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


In [4]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(64 * 3 * 3, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        x = x.view(-1, 64 * 3 * 3)
        x = self.fc(x)
        return x

student_model = StudentModel()


In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student_model.parameters(), lr=0.001, momentum=0.9)


In [6]:
temperature = 5
alpha = 0.5
epochs=5

In [24]:
teacher_model.eval()
for epoch in range(epochs):
    student_model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()

        # forward pass with teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)

        # forward pass with student model
        student_outputs = student_model(inputs)

        # calculate soft targets
        soft_targets = F.softmax(teacher_outputs / temperature, dim=1)

        # calculate loss using both hard and soft targets
        hard_loss = criterion(student_outputs, labels)
        soft_loss = nn.KLDivLoss()(F.log_softmax(student_outputs / temperature, dim=1),
                                   soft_targets) * temperature * temperature * alpha
        loss = hard_loss + soft_loss

        # backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(trainloader)
    # print(model_time)
    print('Model: Epoch [%d/%d], Loss: %.4f' % (epoch+1, 5, epoch_loss))
        # # print statistics
        # running_loss += loss.item()
        # if i % 2000 == 1999:
        #     print('[%d, %5d] loss: %.3f' %
        #           (epoch + 1, i + 1, running_loss / 2000))
        # running_loss = 0.0


Model: Epoch [1/5], Loss: 0.2259
Model: Epoch [2/5], Loss: 0.2175
Model: Epoch [3/5], Loss: 0.2111
Model: Epoch [4/5], Loss: 0.2061
Model: Epoch [5/5], Loss: 0.2019


In [25]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = student_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the student model on the 10000 test images: %d %%' % (
    100 * correct / total))


Accuracy of the student model on the 10000 test images: 98 %
