In [1]:
import copy
from itertools import product

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from compressai.entropy_models import EntropyBottleneck

from mnist_cnn import MNIST
from model_splitting import split_model
from lat_acc_test_funcs import eval_accuracy
from HECS import hecs_bottleneck, training_stage1, training_stage2


In [2]:
BATCH_SIZE = 128  # shared by the train and validation loaders

trainset, validset, testset = MNIST()
trainloader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE)
validloader = DataLoader(validset, shuffle=True, batch_size=BATCH_SIZE)
# Test loader: one sample per batch, fixed order.
testloader = DataLoader(testset, shuffle=False, batch_size=1)

In [3]:
TEACHER_PATH = 'models/MNIST_CNN_Complex.pt'
SPLIT_INDEX = 10  # split point passed to split_model — TODO document its meaning

# weights_only=False unpickles a full nn.Module, which can execute arbitrary
# code: only load checkpoint files you trust.
model = torch.load(TEACHER_PATH, weights_only=False)
head, tail = split_model(model, SPLIT_INDEX)

In [4]:
encoder = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2), 
    nn.Conv2d(64, 4, 3, padding=1),
)

decoder = nn.Sequential(
    nn.ConvTranspose2d(4, 32, 3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
)

bottleneck = EntropyBottleneck(4)

student = hecs_bottleneck(encoder, bottleneck, decoder, copy.deepcopy(tail))

In [5]:
a_values = [0.1, 0.3, 0.5, 0.7, 0.9]
T_values = [1, 2, 4]
beta_values = [0.001, 0.01, 0.1]

results = []
for a, T, beta in product(a_values, T_values, beta_values):
    print(f"\n=== Running a={a}, T={T}, beta={beta} ===")
    training_stage1(model, student, trainloader, epochs=5, a=a, T=T, beta=beta, quiet=True)

    acc = eval_accuracy(student, validloader, quiet=False)
    print(f"Latency: {student.get_avg_inference_time():.4}")
    print(f"Transmission Size: {student.get_avg_transmission_sizes()} bytes")

    results.append((a, T, beta, acc, student.get_avg_inference_time(), student.get_avg_transmission_sizes()))
    student.clear_stats()



=== Running a=0.1, T=1, beta=0.001 ===
Accuracy: 98.74%
Latency: 0.2396
Transmission Size: (392000.0, 5045.6) bytes

=== Running a=0.1, T=1, beta=0.01 ===
Accuracy: 95.18%
Latency: 0.229
Transmission Size: (392000.0, 2535.2) bytes

=== Running a=0.1, T=1, beta=0.1 ===
Accuracy: 58.84%
Latency: 0.2365
Transmission Size: (392000.0, 2000.0) bytes

=== Running a=0.1, T=2, beta=0.001 ===
Accuracy: 98.86%
Latency: 0.237
Transmission Size: (392000.0, 6007.0) bytes

=== Running a=0.1, T=2, beta=0.01 ===
Accuracy: 98.60%
Latency: 0.2329
Transmission Size: (392000.0, 3017.9) bytes

=== Running a=0.1, T=2, beta=0.1 ===
Accuracy: 94.78%
Latency: 0.1825
Transmission Size: (392000.0, 2202.8) bytes

=== Running a=0.1, T=4, beta=0.001 ===
Accuracy: 99.22%
Latency: 0.1721
Transmission Size: (392000.0, 8003.6) bytes

=== Running a=0.1, T=4, beta=0.01 ===
Accuracy: 99.02%
Latency: 0.1927
Transmission Size: (392000.0, 3879.0) bytes

=== Running a=0.1, T=4, beta=0.1 ===
Accuracy: 98.14%
Latency: 0.1924
Tr

In [None]:
# All grid-search rows: (a, T, beta, accuracy, latency, transmission sizes).
# (Lists have no .filter() method; the filtered view is in a later cell.)
results

[(0.1, 1, 0.001, 98.74, 0.2395565390586853, (392000.0, 5045.6)),
 (0.1, 1, 0.01, 95.18, 0.2290206789970398, (392000.0, 2535.2)),
 (0.1, 1, 0.1, 58.84, 0.2364691376686096, (392000.0, 2000.0)),
 (0.1, 2, 0.001, 98.86, 0.23698900938034057, (392000.0, 6007.0)),
 (0.1, 2, 0.01, 98.6, 0.2328951895236969, (392000.0, 3017.9)),
 (0.1, 2, 0.1, 94.78, 0.18251324892044068, (392000.0, 2202.8)),
 (0.1, 4, 0.001, 99.22, 0.17210156321525574, (392000.0, 8003.6)),
 (0.1, 4, 0.01, 99.02, 0.1926710546016693, (392000.0, 3879.0)),
 (0.1, 4, 0.1, 98.14, 0.19238696098327637, (392000.0, 2429.7)),
 (0.3, 1, 0.001, 98.38, 0.17650046348571777, (392000.0, 2856.6)),
 (0.3, 1, 0.01, 95.86, 0.19117312431335448, (392000.0, 2214.0)),
 (0.3, 1, 0.1, 80.12, 0.19246523976325988, (392000.0, 2000.0)),
 (0.3, 2, 0.001, 99.04, 0.2567841649055481, (392000.0, 3445.9)),
 (0.3, 2, 0.01, 97.14, 0.19354044198989867, (392000.0, 2375.2)),
 (0.3, 2, 0.1, 92.3, 0.22992007732391356, (392000.0, 2038.4)),
 (0.3, 4, 0.001, 99.14, 0.1927240

In [10]:
# Keep configurations meeting all three targets:
#   accuracy (row[3], in %) above MIN_ACCURACY,
#   average inference time (row[4]) below MAX_LATENCY,
#   second entry of the transmission sizes (row[5][1]) below MAX_TX_BYTES.
MIN_ACCURACY = 90
MAX_LATENCY = 0.2
MAX_TX_BYTES = 2500

[
    row
    for row in results
    if row[3] > MIN_ACCURACY and row[4] < MAX_LATENCY and row[5][1] < MAX_TX_BYTES
]

[(0.1, 2, 0.1, 94.78, 0.18251324892044068, (392000.0, 2202.8)),
 (0.1, 4, 0.1, 98.14, 0.19238696098327637, (392000.0, 2429.7)),
 (0.3, 1, 0.01, 95.86, 0.19117312431335448, (392000.0, 2214.0)),
 (0.3, 2, 0.01, 97.14, 0.19354044198989867, (392000.0, 2375.2)),
 (0.3, 4, 0.1, 96.08, 0.19130685329437255, (392000.0, 2135.8)),
 (0.5, 1, 0.01, 94.74, 0.1917216360569, (392000.0, 2075.3)),
 (0.5, 2, 0.01, 95.62, 0.1923802375793457, (392000.0, 2186.6)),
 (0.5, 4, 0.1, 94.0, 0.1893610417842865, (392000.0, 2054.0)),
 (0.7, 1, 0.001, 96.5, 0.19287680387496947, (392000.0, 2366.8)),
 (0.7, 1, 0.01, 94.54, 0.19160792827606202, (392000.0, 2050.7)),
 (0.7, 2, 0.01, 95.08, 0.18250105977058412, (392000.0, 2104.6)),
 (0.9, 1, 0.001, 96.28, 0.19173746705055236, (392000.0, 2344.0)),
 (0.9, 1, 0.01, 93.06, 0.19347615242004396, (392000.0, 2023.7))]

In [6]:
# Test-set evaluation of the current student (before training_stage2),
# then reset its accumulated statistics.
stage1_test_acc = eval_accuracy(student, testloader, quiet=False)
avg_latency = student.get_avg_inference_time()
print(f"Latency: {avg_latency:.4}")
avg_tx_sizes = student.get_avg_transmission_sizes()
print(f"Transmission Size: {avg_tx_sizes} bytes")
student.clear_stats()

Accuracy: 88.96%
Latency: 0.03368
Transmission Size: (3136.0, 12.28) bytes


In [8]:
# Stage 2 of HECS training on the current student. The number of epochs is
# decided inside training_stage2 (the recorded run logged 10).
training_stage2(
    student,
    trainloader,
    quiet=False,
)

Epoch 1/10, Loss: 0.8937
Epoch 2/10, Loss: 0.8107
Epoch 3/10, Loss: 0.7671
Epoch 4/10, Loss: 0.7392
Epoch 5/10, Loss: 0.7226
Epoch 6/10, Loss: 0.7041
Epoch 7/10, Loss: 0.6974
Epoch 8/10, Loss: 0.6831
Epoch 9/10, Loss: 0.6712
Epoch 10/10, Loss: 0.6638


In [9]:
# Test-set evaluation after training_stage2, then reset the statistics.
# NOTE(review): in the recorded run, accuracy fell after stage 2
# (88.96% -> 83.78%) while the second transmission size shrank
# (12.28 -> 8.7584 bytes); check that this trade-off is intended.
stage2_test_acc = eval_accuracy(student, testloader, quiet=False)
avg_latency = student.get_avg_inference_time()
print(f"Latency: {avg_latency:.4}")
avg_tx_sizes = student.get_avg_transmission_sizes()
print(f"Transmission Size: {avg_tx_sizes} bytes")
student.clear_stats()

Accuracy: 83.78%
Latency: 0.03239
Transmission Size: (3136.0, 8.7584) bytes


In [7]:
HECS_MODEL_PATH = './models/MNIST_Complex_HECS.pt'

# NOTE(review): in the recorded run this cell executed as In [7], i.e. BEFORE
# training_stage2 (In [8]), so the saved file holds the pre-stage-2 student.
# On Restart & Run All it saves the post-stage-2 student instead.
torch.save(student, HECS_MODEL_PATH)