In [1]:
from early_exit.get_early_exit import create_networks
from models import vgg19_bn
import torch.nn as nn
from copy import deepcopy

In [2]:
class SplitModel(nn.Module):
    def __init__(self,pretrained_model):
        super(SplitModel, self).__init__() 
        self.device_side_net = deepcopy(pretrained_model.features[:14])
        self.server_side_net = deepcopy(pretrained_model.features[14:27])
        
        self.cloud_net = nn.Sequential(
            deepcopy(pretrained_model.features[27:]),
            deepcopy(pretrained_model.avgpool),
            nn.Flatten(),
            deepcopy(pretrained_model.classifier)
        )
        self.net = nn.Sequential(
            self.device_side_net,
            self.server_side_net,
            self.cloud_net
        )
        
    def forward(self, x):
        x = self.net(x)
        return x

In [4]:
import torchvision
from torchvision import  transforms
import torch
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

testset = torchvision.datasets.CIFAR10(root='data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)




Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:07<00:00, 22418811.48it/s]


Extracting data/cifar-10-python.tar.gz to data


In [5]:
model = vgg19_bn()
model = SplitModel(model)


In [6]:
# Define the input parameters
input_shape = (1, 3, 32, 32)
thresholds = [0.8, 0.9, 0.9]
neurons_in_exit_layers = [[1024,1025], [1024,1024]]
epochs = 5
train_dataloader = testloader
test_dataloader = testloader
optimizer = torch.optim.SGD
optimizer_parameters = {'lr': 0.001}
criterion = torch.nn.CrossEntropyLoss()
training_method = 'whole_network'

# Call the create_networks function
networks, train_losses, test_accuracies = create_networks(
    model,
    input_shape,
    thresholds,
    neurons_in_exit_layers,
    epochs,
    train_dataloader,
    test_dataloader,
    optimizer,
    optimizer_parameters,
    criterion,
    training_method
)

0 loss:  2.2984901827812196
1 loss:  2.275150856018066
2 loss:  2.2249294378757476
3 loss:  2.1505331762313844
4 loss:  2.08802056684494
Accuracy of exit 0: 31.16%
Accuracy of exit 1: 26.61%
Accuracy of exit 2: 32.65%


In [8]:
test_accuracies

{'exit_0': 31.16, 'exit_1': 26.61, 'exit_2': 32.65}

In [7]:
train_losses

[2.2984901827812196,
 2.275150856018066,
 2.2249294378757476,
 2.1505331762313844,
 2.08802056684494]