In [9]:
#imports
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from models.cnn import SimpleCNN
from models.mlp import MLP
from distillation_utils import Distiller
from invariances_utils import shift_preserving_shape
import numpy as np


In [10]:
# Experiment configuration: every tunable the later cells read lives here.

# Teacher CNN architecture (forwarded to SimpleCNN).
in_channels = 1
num_classes = 10
num_conv_layers = 2
temperature = 1

# Optimisation settings.
num_epochs = 10
batch_size = 64
lr = 0.001

# TRAIN = True retrains the models; TRAIN = False restores them from
# the checkpoints under saved_models/.
TRAIN = False

# Prefer the GPU, but fall back to the CPU so the configuration cell does not
# pin the notebook to CUDA-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed NumPy and PyTorch so shuffling and weight initialisation are
# repeatable across "Restart & Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [11]:
# MNIST dataset
# Images become tensors and are then normalised with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Arguments shared by the train and test splits.
mnist_common = dict(root='./data', transform=transform, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_common)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_common)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [12]:
# Obtaining CNN
# Build the teacher CNN, then either train it from scratch (TRAIN) or restore
# the weights stored at `cnn_path`, and finally report accuracy on the test set.
cnn_path = "saved_models/model"
cnn = SimpleCNN(
    in_channels=in_channels,
    num_classes=num_classes,
    num_conv_layers=num_conv_layers,
    temperature=temperature,
).to(device)  # use the configured device instead of a hard-coded 'cuda:0'

if TRAIN:
    criterion_cnn = torch.nn.CrossEntropyLoss()
    optimizer_cnn = torch.optim.Adam(cnn.parameters(), lr=lr)
    # model training
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Move the batch once, onto the same device as the model.
            images, labels = images.to(device), labels.to(device)
            outputs = cnn(images)
            loss = criterion_cnn(outputs, labels)

            optimizer_cnn.zero_grad()
            loss.backward()
            optimizer_cnn.step()

            # Progress line every 100 steps.
            if (i + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
    # Save the trained model
    torch.save(cnn.state_dict(), cnn_path)
    print(f"Model saved as {cnn_path}!")
else:
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU can still be restored when running elsewhere.
    state_dict = torch.load(cnn_path, map_location=device)
    cnn.load_state_dict(state_dict=state_dict)

# Testing the model
cnn.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = cnn(images)
        # Predicted class = index of the largest output along the class axis.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Test Accuracy: {accuracy:.4f}')

Test Accuracy: 0.9914


In [13]:
#Loading undistilled MLP
# Baseline student-sized MLP trained directly on the labels (no teacher).
# Constructor arguments are shared by both branches so the trained and the
# restored model cannot drift apart; `device` comes from the config cell
# instead of a hard-coded 'cuda'.
mlp_kwargs = dict(input_dim=784, output_dim=num_classes, hidden_size=2048,
                  hidden_layers=4, device=device)
if TRAIN:
    mlp = MLP(**mlp_kwargs)
    criterion_mlp = torch.nn.CrossEntropyLoss()
    optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=lr)
    # NOTE(review): MLP.train is called here with a loader/optimizer/criterion,
    # so it appears to be the project's own training loop rather than
    # torch.nn.Module.train -- confirm in models/mlp.py.
    mlp.train(train_loader=train_loader, optimizer=optimizer_mlp, criterion=criterion_mlp,
              num_epochs=5)
else:
    # Restore previously saved weights instead of training.
    mlp = MLP(**mlp_kwargs, from_saved_state_dict="saved_models/mlp")
# Evaluate on the test loader; kept as the last expression of the cell.
mlp.eval(test_loader=test_loader)

Not using softmax
Test Accuracy: 0.9704


In [14]:
#loading distilled MLP
# Student MLP distilled from the CNN teacher. `device` and `lr` come from the
# config cell instead of being repeated as literals.
mlp_student = MLP(input_dim=784, output_dim=num_classes, hidden_size=2048,
                  hidden_layers=4, device=device)
# Arguments common to the "train" and "load" construction paths.
distiller_kwargs = dict(student=mlp_student, teacher=cnn, device=device, lr=lr)
if TRAIN:
    distiller = Distiller(**distiller_kwargs)
    # Positional arguments: data loader, 5, and the "saved_models/" directory.
    # NOTE(review): 5 is presumably the number of epochs -- confirm in
    # distillation_utils.
    distiller.distill(train_loader, 5, "saved_models/")
else:
    # Restore a previously distilled student instead of running distillation.
    distiller = Distiller(**distiller_kwargs,
                          load_student_from_path='saved_models/distiller')
# Evaluate the student on the test loader; kept as the cell's last expression.
distiller.test_step(test_loader=test_loader)

Not using softmax
Test Accuracy: 0.9688


In [15]:
# Run test_IM (from invariances_utils) over the test loader with the
# undistilled MLP. The call is the cell's last expression, so the value it
# returns is what the notebook displays below the printed output.
# NOTE(review): "IM" is presumably an invariance measure -- confirm in
# invariances_utils. This import would normally sit in the imports cell.
from invariances_utils import test_IM
test_IM(test_loader, mlp)

[tensor([50.8540], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([55.7828], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([53.8485], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([49.8548], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([48.7668], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([56.7308], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([48.2267], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([53.7515], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([55.2343], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([57.1900], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([50.1332], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([53.3430], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([51.0499], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([62.9042], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([49.4475], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([5

tensor(8554.4023, device='cuda:0', grad_fn=<SumBackward0>)

In [16]:
# Run test_IM over the test loader again, this time with the student model
# returned by distiller.get_student(), so the displayed value can be compared
# with the one obtained for the undistilled MLP.
test_IM(test_loader, distiller.get_student())

[tensor([39.4146], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([39.3120], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([40.5080], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([41.2471], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([39.8635], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([35.3821], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([39.0805], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([41.0281], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([41.0057], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([41.6737], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([41.6637], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([40.4811], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([40.5587], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([45.6170], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([37.3044], device='cuda:0', grad_fn=<UnsqueezeBackward0>), tensor([3

tensor(6376.8823, device='cuda:0', grad_fn=<SumBackward0>)