In [1]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mnist_cnn import MNIST
from model_splitting import split_model
from lat_acc_test_funcs import eval_accuracy
from bottlefit_injection import bottlefit_bottleneck, stage1_training, stage2_training

In [2]:
# Seed the global PyTorch RNG before any data or model object is built so that
# "Restart Kernel -> Run All" is repeatable: with no explicit generator, the
# shuffling DataLoader below and the default weight initialisation of nn layers
# created later both draw from this global RNG.
SEED = 42
torch.manual_seed(SEED)

# MNIST() is a project helper that returns three values; the middle one is not
# used in this notebook (presumably a validation split -- confirm in mnist_cnn).
trainset, _, testset = MNIST()

# Training batches of 64 samples, reshuffled at every epoch.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Test samples are served one at a time in a fixed order.
# NOTE(review): presumably batch_size=1 so that the latency / transmission
# statistics gathered during evaluation are per input -- confirm.
testloader = DataLoader(testset, batch_size=1, shuffle=False)

In [3]:
# Load the pre-trained teacher network from disk.
# NOTE(review): weights_only=False unpickles the entire saved object and can
# execute arbitrary code embedded in the file -- only load checkpoints that
# come from a trusted source.
model = torch.load('models/MNIST_CNN_Complex.pt', weights_only=False)
# Split the teacher into a head part and a tail part with the project helper.
# Presumably the second argument (10) is the number of leading layers that go
# into the head -- confirm against split_model.
t_head, t_tail = split_model(model, 10)
# Independent deep copy of the teacher tail; this copy (not t_tail itself) is
# what gets passed into the student model later in the notebook.
s_tail = copy.deepcopy(t_tail)

In [4]:
# Display the direct child layers of both teacher parts as a
# (head_layers, tail_layers) tuple.
tuple(list(part.children()) for part in (t_head, t_tail))

([Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  ReLU(),
  Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  ReLU(),
  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
  Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  ReLU(),
  Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  ReLU(),
  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)],
 [Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  ReLU(),
  Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  ReLU(),
  Flatten(start_dim=1, end_dim=-1),
  Linear(in_features=6272, out_features=512, bias=True),
  ReLU(),
  Linear(in_features=512, out_features=256, bias=True),
  ReLU(),
  Linear(in_features=256, out_features=128, bias=True),
  ReLU(),
  Linear(in_features=128, out_features=10, bias=True)])

In [5]:
# Student encoder: three 3x3 convolutions (1 -> 32 -> 64 -> 4 channels) with a
# single 2x2 max-pool, so the tensor leaving the encoder has only 4 channels.
# Layers are listed in execution order and passed positionally, so the
# Sequential indices (and therefore the state_dict keys) are 0..6.
encoder_layers = [
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
]
s_encoder = nn.Sequential(*encoder_layers)

# Student decoder: a transposed convolution widens the 4-channel bottleneck
# back to 64 channels, then a second 2x2 max-pool halves the spatial size.
# 64 output channels matches the in_channels of the first Conv2d of the
# teacher tail listed above.
decoder_layers = [
    nn.ConvTranspose2d(in_channels=4, out_channels=64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
]
s_decoder = nn.Sequential(*decoder_layers)

In [6]:
# Build the student model from the compact encoder, the decoder and the copied
# teacher tail using the project helper bottlefit_bottleneck.
# NOTE(review): presumably the three parts are applied in the order given
# (encoder -> decoder -> tail) -- confirm in bottlefit_injection.
student = bottlefit_bottleneck(s_encoder, s_decoder, s_tail)

In [7]:
# Stage 1 of the project's two-stage training routine, given the split teacher
# (head and tail) and the student.
# NOTE(review): no epoch count is passed, so the helper's default applies;
# quiet=False presumably enables the per-epoch loss lines -- confirm.
stage1_training(
    t_head,
    t_tail,
    student,
    trainloader,
    quiet=False,
)

# Stage 2: five epochs against the full teacher model with freeze_encoder=True.
# NOTE(review): presumably this keeps the student's encoder weights fixed while
# the remaining student parts are trained -- confirm in bottlefit_injection.
stage2_training(
    model,
    student,
    trainloader,
    epochs=5,
    freeze_encoder=True,
    quiet=False,
)

Epoch 1 - Loss: 76.2666
Epoch 2 - Loss: 10.2871
Epoch 3 - Loss: 8.1554
Epoch 4 - Loss: 7.0989
Epoch 5 - Loss: 6.3692
Epoch 6 - Loss: 5.8201
Epoch 7 - Loss: 5.6496
Epoch 8 - Loss: 5.2469
Epoch 9 - Loss: 5.0460
Epoch 10 - Loss: 4.8709
Epoch 1 - Loss: 0.0345
Epoch 2 - Loss: 0.0238
Epoch 3 - Loss: 0.0189
Epoch 4 - Loss: 0.0229
Epoch 5 - Loss: 0.0182


In [8]:
# Run the student over the test loader with the project helper.
# NOTE(review): presumably this is also what fills the latency / transmission
# statistics read back below -- confirm in lat_acc_test_funcs.
eval_accuracy(student, testloader, quiet=False)

# Average inference time, printed with 4 significant digits
# (units are whatever the student model records -- not visible here).
avg_latency = student.get_avg_inference_time()
print(f"Latency: {avg_latency:.4}")

# Average transmission size as reported by the student model.
avg_payload = student.get_avg_transmission_sizes()
print(f"Transmission Size: {avg_payload} bytes")

# Clear the statistics held by the student once they have been reported.
student.clear_stats()

Accuracy: 99.50%
Latency: 0.006618
Transmission Size: 3136.0 bytes


In [9]:
# Persist the whole trained student object (not just a state_dict) in the same
# models directory the teacher checkpoint was loaded from.
torch.save(obj=student, f='./models/MNIST_Complex_Bottlefit.pt')