In [1]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mnist_cnn import MNIST
from model_splitting import split_model
from lat_acc_test_funcs import eval_accuracy
from bottlefit_injection import bottlefit_bottleneck, stage1_training, stage2_training

In [2]:
# Seed torch's global RNG before anything random happens so that a fresh
# "Restart Kernel -> Run All" is repeatable: the student's layer
# initialisation and the shuffled order drawn by `trainloader` both come from
# this generator. (If MNIST() performs a random split it is covered as well --
# TODO confirm against mnist_cnn.)
SEED = 0
torch.manual_seed(SEED)

# MNIST() returns three datasets; the middle one is not used in this notebook.
trainset, _, testset = MNIST()

# Training uses shuffled mini-batches of 64 samples.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Evaluation feeds one sample at a time, in a fixed order. Presumably this is
# so the per-inference latency / transmission statistics reported later are
# per-sample figures -- TODO confirm against lat_acc_test_funcs.
testloader = DataLoader(testset, batch_size=1, shuffle=False)

In [3]:
# Load the pre-trained teacher network as a fully pickled module object.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints you produced or trust.
model = torch.load('models/MNIST_CNN.pt', weights_only=False)

# Cut the teacher into a head and a tail at index 6.
# NOTE(review): the exact meaning of the index is defined by
# model_splitting.split_model -- confirm there.
t_head, t_tail = split_model(model, 6)

# Independent deep copies for the student so that training it cannot modify
# the teacher's parameters. `s_head` is not referenced by the other cells
# visible in this notebook; it is kept in case later code relies on it.
s_head = copy.deepcopy(t_head)
s_tail = copy.deepcopy(t_tail)

In [4]:
# Show the teacher tail's immediate child modules. The in_features of its
# first Linear layer is the flattened size the student decoder has to produce.
list(t_tail.children())

[Flatten(start_dim=1, end_dim=-1),
 Linear(in_features=3136, out_features=128, bias=True),
 ReLU(),
 Linear(in_features=128, out_features=10, bias=True)]

In [5]:
# Student encoder: compresses the input image into a small bottleneck tensor.
# Shapes in the comments assume a 1x28x28 MNIST input.
s_encoder = nn.Sequential(
    # 1x28x28 -> 32x28x28 (3x3 conv, spatial size preserved by padding=1)
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    # 32x28x28 -> 4x14x14 (strided conv halves the resolution, 4 channels)
    nn.Conv2d(in_channels=32, out_channels=4, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    # 4x14x14 -> 4x7x7 (196 values make up the bottleneck)
    nn.MaxPool2d(kernel_size=2, stride=2),
)

# Student decoder: expands the bottleneck back to the feature-map size the
# tail expects (64x7x7 = 3136 values once flattened).
s_decoder = nn.Sequential(
    # 4x7x7 -> 64x14x14 (output_padding=1 makes the upsampled size exactly 14)
    nn.ConvTranspose2d(
        in_channels=4,
        out_channels=64,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    ),
    nn.ReLU(),
    # 64x14x14 -> 64x7x7
    nn.MaxPool2d(kernel_size=2, stride=2),
)

In [6]:
# Assemble the student from the new encoder/decoder pair and the copied tail.
# NOTE(review): how the three parts are wired together (and where the
# transmission point sits) is defined by bottlefit_injection -- confirm there.
student = bottlefit_bottleneck(s_encoder, s_decoder, s_tail)

In [7]:
# Stage 1: train the student against the split teacher (head + tail).
# No epoch count is passed, so the helper's default is used; the recorded log
# below shows ten epochs. NOTE(review): presumably this stage matches the
# student's intermediate features to the teacher head -- confirm in
# bottlefit_injection.
stage1_training(t_head, t_tail, student, trainloader, quiet=False)

# Stage 2: five further epochs against the full teacher model, with the
# encoder frozen so that only the remaining student parts are updated.
stage2_training(
    model,
    student,
    trainloader,
    epochs=5,
    freeze_encoder=True,
    quiet=False,
)

Epoch 1 - Loss: 27.3761
Epoch 2 - Loss: 7.5373
Epoch 3 - Loss: 6.5751
Epoch 4 - Loss: 6.1018
Epoch 5 - Loss: 5.7762
Epoch 6 - Loss: 5.5314
Epoch 7 - Loss: 5.3486
Epoch 8 - Loss: 5.2069
Epoch 9 - Loss: 5.0952
Epoch 10 - Loss: 4.9943
Epoch 1 - Loss: 0.0550
Epoch 2 - Loss: 0.0338
Epoch 3 - Loss: 0.0287
Epoch 4 - Loss: 0.0250
Epoch 5 - Loss: 0.0246


In [8]:
# Run the student over the test set; with quiet=False the helper prints the
# accuracy itself.
eval_accuracy(student, testloader, quiet=False)

# Report the statistics the student recorded during that evaluation pass.
# NOTE(review): the latency unit is not shown here (presumably seconds) --
# confirm in bottlefit_injection.
avg_latency = student.get_avg_inference_time()
print(f"Latency: {avg_latency:.4}")

avg_payload = student.get_avg_transmission_sizes()
print(f"Transmission Size: {avg_payload} bytes")

# Reset the recorded statistics so later runs start from a clean slate.
student.clear_stats()

Accuracy: 99.16%
Latency: 0.005288
Transmission Size: 784.0 bytes


In [9]:
# Persist the trained student as a whole pickled module (not just a
# state_dict). Loading it back needs the defining classes to be importable
# and torch.load(..., weights_only=False), as is done for the teacher above.
torch.save(student, './models/MNIST_Bottlefit.pt')