In [1]:
import torch
import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck
from model_splitting import split_model
from mnist_cnn import MNIST
from torch.utils.data import DataLoader
from Tx_Rx import transmit, receive
import threading
import time
from bottlefit_injection import stage2_training
import copy

  @amp.autocast(enabled=False)


In [7]:
from HECS import hecs_bottleneck, training_stage1, training_stage2

In [2]:
model = torch.load('models/MNIST_Bottlefit.pt', weights_only=False).to('mps')
# nn.Sequential(*[layer for seq in model.children() for layer in seq])

In [3]:
nn.Sequential(model.children())

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ConvTranspose2d(4, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=3136, out_features=128, bias=True)
    (10): ReLU()
    (11): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [23]:
len(list(model.children()))

3

In [4]:
list(head.children())

[Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)]

In [5]:
list(tail.children())

[ConvTranspose2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 Flatten(start_dim=1, end_dim=-1),
 Linear(in_features=6272, out_features=512, bias=True),
 ReLU(),
 Linear(in_features=512, out_features=256, bias=True),
 ReLU(),
 Linear(in_features=256, out_features=128, bias=True),
 ReLU(),
 Linear(in_features=128, out_features=10, bias=True)]

In [8]:

trainset, _, testset = MNIST()
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
    
encoder = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2), 
    nn.Conv2d(64, 8, 3, padding=1),
)

decoder = nn.Sequential(
    nn.ConvTranspose2d(8, 32, 3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
)

bottleneck = EntropyBottleneck(8)

model = torch.load('models/MNIST_CNN_Complex.pt', weights_only=False).to('mps')
head, tail = split_model(model, 10)

student = hecs_bottleneck(encoder, bottleneck, decoder, copy.deepcopy(tail))

In [9]:
list(tail.children())

[Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(),
 Flatten(start_dim=1, end_dim=-1),
 Linear(in_features=6272, out_features=512, bias=True),
 ReLU(),
 Linear(in_features=512, out_features=256, bias=True),
 ReLU(),
 Linear(in_features=256, out_features=128, bias=True),
 ReLU(),
 Linear(in_features=128, out_features=10, bias=True)]

In [4]:
training_stage1(model, student, trainloader, epochs=10, lr=0.01, quiet=False)

Epoch 1/10, Loss: 332.3908
Epoch 2/10, Loss: 14.8213
Epoch 3/10, Loss: 9.1526
Epoch 4/10, Loss: 7.7052
Epoch 5/10, Loss: 6.8733
Epoch 6/10, Loss: 6.6553
Epoch 7/10, Loss: 6.3463
Epoch 8/10, Loss: 6.2874
Epoch 9/10, Loss: 6.2260
Epoch 10/10, Loss: 6.1512


In [5]:
student.eval()

total = 0
correct = 0
times = []
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to('cpu'), labels.to('cpu')

        time.sleep(0.001)  # Add delay between batches to simulate real-world conditions (also prevents crash with small batches)

        torch.mps.synchronize()
        start_time = time.time()
        outputs = student(inputs)
        torch.mps.synchronize()
        end_time = time.time()

        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()

        total += labels.size(0)
        times.append(end_time - start_time)

accuracy = 100 * correct / total
latency =  sum(times) / len(times)

print(f"Accuracy: {accuracy:.2f}%")
print(f"Latency per batch: {latency:.4f} seconds")

# Accuracy: 61.68%
# Latency per batch: 0.1203 seconds

Accuracy: 11.46%
Latency per batch: 0.1073 seconds


In [10]:
training_stage2(model, student, trainloader, quiet=False)

Epoch 1/10, Loss: 0.9927
Epoch 2/10, Loss: 0.5897
Epoch 3/10, Loss: 0.4424
Epoch 4/10, Loss: 0.3378
Epoch 5/10, Loss: 0.2703
Epoch 6/10, Loss: 0.2178
Epoch 7/10, Loss: 0.1771
Epoch 8/10, Loss: 0.1495
Epoch 9/10, Loss: 0.1268
Epoch 10/10, Loss: 0.1040


In [None]:
total = 0
correct = 0
times = []
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to('cpu'), labels.to('cpu')

        time.sleep(0.001)  # Add delay between batches to simulate real-world conditions (also prevents crash with small batches)

        torch.mps.synchronize()
        start_time = time.time()
        outputs = student(inputs)
        torch.mps.synchronize()
        end_time = time.time()

        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()

        total += labels.size(0)
        times.append(end_time - start_time)



accuracy = 100 * correct / total
latency =  sum(times) / len(times)

print(f"Accuracy: {accuracy:.2f}%")
print(f"Latency per batch: {latency:.4f} seconds")

# Accuracy: 97.06%
# Latency per batch: 0.1240 seconds

Accuracy: 97.06%
Latency per batch: 0.1240 seconds
