In [None]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mnist_cnn import MNIST
from model_splitting import split_model
from lat_acc_test_funcs import eval_accuracy
from bottlefit_quantizer import bottlefit_quantizer, stage1_training, stage2_training

In [2]:
trainset, _, testset = MNIST()
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=1, shuffle=False)

In [3]:
model = torch.load('models/MNIST_CNN.pt', weights_only=False)
t_head, t_tail = split_model(model, 6)
s_head, s_tail = copy.deepcopy(t_head), copy.deepcopy(t_tail)

In [4]:
list(t_tail.children())

[Flatten(start_dim=1, end_dim=-1),
 Linear(in_features=3136, out_features=128, bias=True),
 ReLU(),
 Linear(in_features=128, out_features=10, bias=True)]

In [5]:
s_encoder = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 4, kernel_size=3, stride=2, padding=1),
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2)
)

s_decoder = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(4, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2)
)

In [None]:
student = bottlefit_quantizer(s_encoder, s_decoder, s_tail)

In [7]:
stage1_training(t_head, t_tail, student, trainloader, quiet=False)
stage2_training(model, student, trainloader, epochs=5, freeze_encoder=True, quiet=False)

Epoch 1 - Loss: 48.2563
Epoch 2 - Loss: 14.4717
Epoch 3 - Loss: 10.7037
Epoch 4 - Loss: 8.9590
Epoch 5 - Loss: 7.9502
Epoch 6 - Loss: 7.2553
Epoch 7 - Loss: 6.7862
Epoch 8 - Loss: 6.4352
Epoch 9 - Loss: 6.1721
Epoch 10 - Loss: 5.9632
Epoch 1 - Loss: 0.0563
Epoch 2 - Loss: 0.0387
Epoch 3 - Loss: 0.0309
Epoch 4 - Loss: 0.0278
Epoch 5 - Loss: 0.0245


In [8]:
eval_accuracy(student, testloader, quiet=False)
print(f"Latency: {student.get_avg_inference_time():.4}")
print(f"Transmission Size: {student.get_avg_transmission_sizes()} bytes")
student.clear_stats()

Accuracy: 99.22%
Latency: 0.006367
Transmission Size: 784.0 bytes


In [9]:
from bottlefit_quantizer import Quantizer
quantizer = Quantizer(4)

c = torch.full((5,3), 5)

y = quantizer.quantize(c)
print(y)
print(y.shape)

y = torch.tensor(y)

y.element_size() * y.nelement()

tensor([1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
        1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
        1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1], dtype=torch.uint8)
torch.Size([60])


  y = torch.tensor(y)


60

In [10]:
y.element_size()

1

In [11]:
y.nelement()

60

In [12]:
len(y)

60