In [1]:
# %load_ext autoreload
# %autoreload 2

# %pip install -r requirements.txt

In [2]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets
import pickle

In [3]:
print(f"PyTorch version: {torch.__version__}")

# MPS = Metal Performance Shaders, PyTorch's GPU backend on Apple hardware.
mps_built = torch.backends.mps.is_built()
mps_available = torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()

print(f"Is MPS (Metal Performance Shader) built? {mps_built}")
print(f"Is MPS available? {mps_available}")
print(f"Is CUDA available? {cuda_available}")

# Preference order: Apple GPU, then NVIDIA GPU, then CPU.
device = "mps" if mps_available else ("cuda" if cuda_available else "cpu")

print(f"Using device: {device}")


PyTorch version: 2.1.0+cu118
Is MPS (Metal Performance Shader) built? False
Is MPS available? False
Is CUDA available? True
Using device: cuda


In [4]:
# Backbone/dataset combinations the data-loading cell knows how to handle.
SUPPORTED_MODEL_TYPES = ("resnet", "vgg_cifar10", "vgg_cifar100", "densenet_cifar100")

# Which backbone/dataset pair to train; must be one of SUPPORTED_MODEL_TYPES.
model_type = "densenet_cifar100"
if model_type not in SUPPORTED_MODEL_TYPES:
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected one of {SUPPORTED_MODEL_TYPES}"
    )

# Per-model directory; passed to the trainer as `model_dir`.
model_path = os.path.join("models", model_type)

# CIFAR-100 variants predict 100 classes; the others (CIFAR-10, Imagenette) predict 10.
num_classes = 100 if "cifar100" in model_type else 10

In [5]:
from DataLoader import CustomDataset  # project-local module (distinct from torch.utils.data.DataLoader)
import numpy as np

# Fixed seed for the train/test split so the held-out images are identical on every run.
split_seed = 42

# Each branch merges all of the dataset's published splits (e.g. CIFAR's train and test)
# into one pool, which is randomly re-split below.
# NOTE(review): if the pretrained backbone was trained on the official train split, the
# re-split test set overlaps its training data and test accuracy will be optimistic — confirm.
if model_type == "resnet":
    # use the imagenette dataset
    hf_dataset = load_dataset("frgfm/imagenette", '320px')
    hf_dataset = concatenate_datasets(hf_dataset.values())
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

elif model_type == "vgg_cifar10":
    # use the cifar10 dataset
    hf_dataset = load_dataset("cifar10")
    hf_dataset = concatenate_datasets(hf_dataset.values())

    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        # NOTE(review): these mean/std figures are the ones usually quoted for CIFAR-100,
        # not CIFAR-10 — confirm they match the preprocessing the VGG backbone was trained with.
        transforms.Normalize(mean=[0.507, 0.4865, 0.4409],
                             std=[0.2673, 0.2564, 0.2761])
    ])

elif model_type in ("vgg_cifar100", "densenet_cifar100"):
    # use the cifar100 dataset
    hf_dataset = load_dataset("cifar100")
    hf_dataset = concatenate_datasets(hf_dataset.values())

    # NOTE(review): unlike the CIFAR-10 branch there is no Normalize step here — confirm this
    # matches how the backbone was trained.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

else:
    # Fail loudly here instead of with a NameError on `hf_dataset` a few lines below.
    raise ValueError(f"Unsupported model_type: {model_type!r}")

torch_dataset = CustomDataset(hf_dataset, transform=transform)

# Smaller batches for the 224x224 Imagenette inputs (presumably to fit in GPU memory).
batch_size = 32 if model_type == "resnet" else 64

test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

# Seeded generator makes the split reproducible across kernel restarts.
train_dataset, test_dataset = random_split(
    torch_dataset,
    [train_volume, test_volume],
    generator=torch.Generator().manual_seed(split_seed),
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle every epoch so SGD does not see the same batch order each time
    num_workers=4
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # evaluation order does not matter; keep it deterministic
    num_workers=4
)

In [6]:
import ModelLoader

# Build the architecture selected by `model_type` on `device`, requesting
# `num_classes` outputs (for the DenseNet config the log shows exits being added).
loader = ModelLoader.ModelLoader(model_type, device)
model = loader.load_model(
    num_outputs=num_classes,
)

# Bare last expression: Jupyter renders the module tree as the cell output.
model

Loading EarlyExit DenseNet121 model architecture...
Adding exits...


EarlyExitModel(
  (model): DenseNet(
    (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dense1): Sequential(
      (0): Bottleneck(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): Bottleneck(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): Bottleneck(
        (bn1): BatchNorm2d(48, eps

## Early Exit Model Training

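If you want TensorBoard support, run the following commands on **macOS** (they also work on Linux). Replace `{model_path}` with the value of `model_path` from the configuration cell (e.g. `models/densenet_cifar100`); Markdown does not substitute it for you: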

This clears your logs (the same directory TensorBoard reads below):
```
rm -rf {model_path}/runs/*
```

This shows your public IP address:
```
dig -4 TXT +short o-o.myaddr.l.google.com @ns1.google.com
```

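This starts TensorBoard on all network interfaces, so it is reachable at that IP address (port 6006 by default). Binding to `0.0.0.0` exposes the dashboard to anyone who can reach the machine:
```
tensorboard --host 0.0.0.0 --logdir={model_path}/runs
```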



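If you want TensorBoard support, run the following commands on **Windows**. Replace `{model_path}` with the value of `model_path` from the configuration cell, using backslashes (e.g. `models\densenet_cifar100`); Markdown does not substitute it for you: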

This clears your logs (the same directory TensorBoard reads below):
```
rmdir /s /q {model_path}\runs
```

This shows your public IP address:
```
powershell -command "$ipAddress = (Invoke-WebRequest -Uri 'http://ipinfo.io/ip').Content.Trim(); Write-Host 'Your IP address is: ' $ipAddress"  
```

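This starts TensorBoard on all network interfaces, so it is reachable at that IP address (port 6006 by default). Binding to `0.0.0.0` exposes the dashboard to anyone who can reach the machine:
```
tensorboard --host 0.0.0.0 --logdir={model_path}\runs
```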

In [7]:
from EarlyExitTrainer import ModelTrainer

# Trainer for the early-exit model; artifacts go under the per-model directory
# (presumably including the TensorBoard `runs/` logs — TODO confirm in EarlyExitTrainer).
trainer = ModelTrainer(
    model,
    device,
    model_dir=model_path,
)

In [8]:
# Train the classifier of each exit in turn, then the final classifier, for up to
# `max_epochs` epochs each; the log below shows a classifier stopping early once
# its validation accuracy starts decreasing.
# NOTE(review): `test_dataloader` doubles as the validation set, so it drives early
# stopping and the reported "Validation Accuracy" is not a fully held-out estimate.
max_epochs = 30
trainer.train_classifiers(
    train_dataloader,
    epoch_count=max_epochs,
    validation_loader=test_dataloader,
)


Training classifier for exit 1


                                                                                                    

Epoch 0 Loss 3.626804178873698
Epoch 0 Accuracy 0.22164583333333332
Epoch 0 Validation Loss 3.1221434448627714
Epoch 0 Validation Accuracy 0.29596077127659576


                                                                                                    

Epoch 1 Loss 2.8041664861043296
Epoch 1 Accuracy 0.3617708333333333
Epoch 1 Validation Loss 2.858804517603935
Epoch 1 Validation Accuracy 0.3346908244680851


                                                                                                    

Epoch 2 Loss 2.537349071184794
Epoch 2 Accuracy 0.41222916666666665
Epoch 2 Validation Loss 2.7860132151461663
Epoch 2 Validation Accuracy 0.34507978723404253


                                                                                                    

Epoch 3 Loss 2.3903164149920144
Epoch 3 Accuracy 0.4448125
Epoch 3 Validation Loss 2.7754577550482242
Epoch 3 Validation Accuracy 0.3507313829787234


                                                                                                    

Epoch 4 Loss 2.2939477372169494
Epoch 4 Accuracy 0.4671875
Epoch 4 Validation Loss 2.7917573807087352
Epoch 4 Validation Accuracy 0.3507313829787234


                                                                                                    

Epoch 5 Loss 2.2244452366828917
Epoch 5 Accuracy 0.4839791666666667
Epoch 5 Validation Loss 2.8224568138731287
Epoch 5 Validation Accuracy 0.3503989361702128


                                                                                                    

Epoch 6 Loss 2.171733702818553
Epoch 6 Accuracy 0.5000625
Epoch 6 Validation Loss 2.8618742899691805
Epoch 6 Validation Accuracy 0.3480718085106383
Validation accuracies are decreasing, stopping training early
Training classifier for exit 2


                                                                                                    

Epoch 0 Loss 3.641574485460917
Epoch 0 Accuracy 0.29129166666666667
Epoch 0 Validation Loss 2.9361559063830276
Epoch 0 Validation Accuracy 0.4486369680851064


                                                                                                    

Epoch 1 Loss 2.5877591670354207
Epoch 1 Accuracy 0.5170416666666666
Epoch 1 Validation Loss 2.3850791581133577
Epoch 1 Validation Accuracy 0.5301695478723404


                                                                                                    

Epoch 2 Loss 2.167040139834086
Epoch 2 Accuracy 0.58725
Epoch 2 Validation Loss 2.1351362092697874
Epoch 2 Validation Accuracy 0.5662400265957447


                                                                                                    

Epoch 3 Loss 1.9385779767036437
Epoch 3 Accuracy 0.6223333333333333
Epoch 3 Validation Loss 1.9961611324168267
Epoch 3 Validation Accuracy 0.5802027925531915


                                                                                                    

Epoch 4 Loss 1.7912328456242879
Epoch 4 Accuracy 0.6455625
Epoch 4 Validation Loss 1.9091938372622146
Epoch 4 Validation Accuracy 0.5878490691489362


                                                                                                    

Epoch 5 Loss 1.6859785148302715
Epoch 5 Accuracy 0.6616666666666666
Epoch 5 Validation Loss 1.8502489024020257
Epoch 5 Validation Accuracy 0.5920877659574468


                                                                                                    

Epoch 6 Loss 1.605444438457489
Epoch 6 Accuracy 0.6758541666666666
Epoch 6 Validation Loss 1.8083551031477907
Epoch 6 Validation Accuracy 0.5929188829787234


                                                                                                    

Epoch 7 Loss 1.5410148215293884
Epoch 7 Accuracy 0.6872291666666667
Epoch 7 Validation Loss 1.7777915539893698
Epoch 7 Validation Accuracy 0.5933344414893617


                                                                                                    

Epoch 8 Loss 1.48753169465065
Epoch 8 Accuracy 0.6960833333333334
Epoch 8 Validation Loss 1.7548197843927018
Epoch 8 Validation Accuracy 0.593999335106383


                                                                                                    

Epoch 9 Loss 1.4420182983080545
Epoch 9 Accuracy 0.7043125
Epoch 9 Validation Loss 1.737344485014043
Epoch 9 Validation Accuracy 0.5949135638297872


                                                                                                    

Epoch 10 Loss 1.4024857448736827
Epoch 10 Accuracy 0.7107916666666667
Epoch 10 Validation Loss 1.7240431790656232
Epoch 10 Validation Accuracy 0.5949966755319149


                                                                                                    

Epoch 11 Loss 1.3676384650071463
Epoch 11 Accuracy 0.7165625
Epoch 11 Validation Loss 1.7140243877755834
Epoch 11 Validation Accuracy 0.5940824468085106


                                                                                                    

Epoch 12 Loss 1.3365441302458445
Epoch 12 Accuracy 0.7216875
Epoch 12 Validation Loss 1.7065529271643212
Epoch 12 Validation Accuracy 0.593500664893617
Validation accuracies are decreasing, stopping training early
Training classifier for exit 3


                                                                                                    

Epoch 0 Loss 4.05200745010376
Epoch 0 Accuracy 0.47979166666666667
Epoch 0 Validation Loss 3.5129033377830017
Epoch 0 Validation Accuracy 0.6708776595744681


                                                                                                    

Epoch 1 Loss 3.1731879800160727
Epoch 1 Accuracy 0.6887916666666667
Epoch 1 Validation Loss 2.7959421277046204
Epoch 1 Validation Accuracy 0.7164228723404256


                                                                                                    

Epoch 2 Loss 2.5948354187011717
Epoch 2 Accuracy 0.7371041666666667
Epoch 2 Validation Loss 2.3324113589652042
Epoch 2 Validation Accuracy 0.7498337765957447


                                                                                                    

Epoch 3 Loss 2.2103522667884827
Epoch 3 Accuracy 0.7684166666666666
Epoch 3 Validation Loss 2.0227562550534595
Epoch 3 Validation Accuracy 0.7725232712765957


                                                                                                    

Epoch 4 Loss 1.946616992632548
Epoch 4 Accuracy 0.7863333333333333
Epoch 4 Validation Loss 1.808221308474845
Epoch 4 Validation Accuracy 0.7842420212765957


                                                                                                    

Epoch 5 Loss 1.759454741160075
Epoch 5 Accuracy 0.7978958333333334
Epoch 5 Validation Loss 1.6546130960292005
Epoch 5 Validation Accuracy 0.7911402925531915


                                                                                                    

Epoch 6 Loss 1.6221971033414204
Epoch 6 Accuracy 0.8058125
Epoch 6 Validation Loss 1.5413852796909657
Epoch 6 Validation Accuracy 0.7977892287234043


                                                                                                    

Epoch 7 Loss 1.5183940420150757
Epoch 7 Accuracy 0.8113541666666667
Epoch 7 Validation Loss 1.4555389120223674
Epoch 7 Validation Accuracy 0.800282579787234


                                                                                                    

Epoch 8 Loss 1.4375369044939676
Epoch 8 Accuracy 0.81525
Epoch 8 Validation Loss 1.3887425781564509
Epoch 8 Validation Accuracy 0.8026097074468085


                                                                                                    

Epoch 9 Loss 1.372887390613556
Epoch 9 Accuracy 0.8179791666666667
Epoch 9 Validation Loss 1.3354376643262011
Epoch 9 Validation Accuracy 0.8029421542553191


                                                                                                    

Epoch 10 Loss 1.319995778163274
Epoch 10 Accuracy 0.8199791666666667
Epoch 10 Validation Loss 1.2920682068834914
Epoch 10 Validation Accuracy 0.8043550531914894


                                                                                                    

Epoch 11 Loss 1.2758773959477743
Epoch 11 Accuracy 0.822
Epoch 11 Validation Loss 1.2561842203140259
Epoch 11 Validation Accuracy 0.8052692819148937


                                                                                                    

Epoch 12 Loss 1.2384796843528747
Epoch 12 Accuracy 0.8237708333333333
Epoch 12 Validation Loss 1.2260332037793829
Epoch 12 Validation Accuracy 0.8057679521276596


                                                                                                    

Epoch 13 Loss 1.2063152598539988
Epoch 13 Accuracy 0.8251875
Epoch 13 Validation Loss 1.20037079871969
Epoch 13 Validation Accuracy 0.8062666223404256


                                                                                                    

Epoch 14 Loss 1.1783359168370564
Epoch 14 Accuracy 0.8270625
Epoch 14 Validation Loss 1.178329793379662
Epoch 14 Validation Accuracy 0.8067652925531915


                                                                                                    

Epoch 15 Loss 1.1537036230564117
Epoch 15 Accuracy 0.8281875
Epoch 15 Validation Loss 1.159135476705876
Epoch 15 Validation Accuracy 0.8067652925531915


                                                                                                    

Epoch 16 Loss 1.1317985257307688
Epoch 16 Accuracy 0.8293125
Epoch 16 Validation Loss 1.1423179256789229
Epoch 16 Validation Accuracy 0.8071808510638298


                                                                                                    

Epoch 17 Loss 1.1121978398164114
Epoch 17 Accuracy 0.831125
Epoch 17 Validation Loss 1.1274809295192678
Epoch 17 Validation Accuracy 0.8067652925531915


                                                                                                    

Epoch 18 Loss 1.0945178136825562
Epoch 18 Accuracy 0.8322708333333333
Epoch 18 Validation Loss 1.1143148579496018
Epoch 18 Validation Accuracy 0.8066821808510638
Validation accuracies are decreasing, stopping training early
Training final classifier


                                                                                                    

Epoch 0 Loss 4.071361401240031
Epoch 0 Accuracy 0.227875
Epoch 0 Validation Loss 3.46170814113414
Epoch 0 Validation Accuracy 0.547623005319149


                                                                                                    

Epoch 1 Loss 3.006620939254761
Epoch 1 Accuracy 0.6860625
Epoch 1 Validation Loss 2.4684126377105713
Epoch 1 Validation Accuracy 0.796126994680851


                                                                                                    

Epoch 2 Loss 2.1616322231292724
Epoch 2 Accuracy 0.8062708333333334
Epoch 2 Validation Loss 1.719908615376087
Epoch 2 Validation Accuracy 0.8456615691489362


                                                                                                    

Epoch 3 Loss 1.5509364056587218
Epoch 3 Accuracy 0.8376875
Epoch 3 Validation Loss 1.2205448486703507
Epoch 3 Validation Accuracy 0.8617021276595744


                                                                                                    

Epoch 4 Loss 1.1538728663921356
Epoch 4 Accuracy 0.8506458333333333
Epoch 4 Validation Loss 0.919031782353178
Epoch 4 Validation Accuracy 0.8699301861702128


                                                                                                    

Epoch 5 Loss 0.9144474639892578
Epoch 5 Accuracy 0.8565
Epoch 5 Validation Loss 0.7461583652394883
Epoch 5 Validation Accuracy 0.8735039893617021


                                                                                                    

Epoch 6 Loss 0.7735398157835006
Epoch 6 Accuracy 0.8591458333333334
Epoch 6 Validation Loss 0.6468963363069169
Epoch 6 Validation Accuracy 0.8743351063829787


                                                                                                    

Epoch 7 Loss 0.6892132031122843
Epoch 7 Accuracy 0.8608333333333333
Epoch 7 Validation Loss 0.5881621429419264
Epoch 7 Validation Accuracy 0.8750831117021277


                                                                                                    

Epoch 8 Loss 0.6372130912542343
Epoch 8 Accuracy 0.8620208333333333
Epoch 8 Validation Loss 0.5523561612564198
Epoch 8 Validation Accuracy 0.875748005319149


                                                                                                    

Epoch 9 Loss 0.6040735071500143
Epoch 9 Accuracy 0.8627083333333333
Epoch 9 Validation Loss 0.5298923399854214
Epoch 9 Validation Accuracy 0.8771609042553191


                                                                                                    

Epoch 10 Loss 0.5825301775137584
Epoch 10 Accuracy 0.8630416666666667
Epoch 10 Validation Loss 0.5156401345862988
Epoch 10 Validation Accuracy 0.8770777925531915


                                                                                                    

Epoch 11 Loss 0.5683971085349718
Epoch 11 Accuracy 0.8632083333333334
Epoch 11 Validation Loss 0.5067133265448377
Epoch 11 Validation Accuracy 0.8778257978723404


                                                                                                    

Epoch 12 Loss 0.5592639628052711
Epoch 12 Accuracy 0.8637291666666667
Epoch 12 Validation Loss 0.5013603398457487
Epoch 12 Validation Accuracy 0.8780751329787234


                                                                                                    

Epoch 13 Loss 0.5536067534883817
Epoch 13 Accuracy 0.8638541666666667
Epoch 13 Validation Loss 0.49858006066147315
Epoch 13 Validation Accuracy 0.8782413563829787


                                                                                                    

Epoch 14 Loss 0.5504524027705192
Epoch 14 Accuracy 0.8640416666666667
Epoch 14 Validation Loss 0.49758800317315344
Epoch 14 Validation Accuracy 0.8785738031914894


                                                                                                    

Epoch 15 Loss 0.5491201723416647
Epoch 15 Accuracy 0.8642291666666667
Epoch 15 Validation Loss 0.4979155827709969
Epoch 15 Validation Accuracy 0.878656914893617


                                                                                                    

Epoch 16 Loss 0.549165763159593
Epoch 16 Accuracy 0.8643125
Epoch 16 Validation Loss 0.4992472084754325
Epoch 16 Validation Accuracy 0.8787400265957447


                                                                                                    

Epoch 17 Loss 0.5502656107743581
Epoch 17 Accuracy 0.86425
Epoch 17 Validation Loss 0.5013238047348693
Epoch 17 Validation Accuracy 0.8787400265957447


                                                                                                    

Epoch 18 Loss 0.5521711937785149
Epoch 18 Accuracy 0.8642916666666667
Epoch 18 Validation Loss 0.503999254567192
Epoch 18 Validation Accuracy 0.87890625


                                                                                                    

Epoch 19 Loss 0.5547117105722428
Epoch 19 Accuracy 0.8643125
Epoch 19 Validation Loss 0.5071509074657521
Epoch 19 Validation Accuracy 0.879404920212766


                                                                                                    

Epoch 20 Loss 0.5577881837288539
Epoch 20 Accuracy 0.8646458333333333
Epoch 20 Validation Loss 0.5106689823751754
Epoch 20 Validation Accuracy 0.8795711436170213


                                                                                                    

Epoch 21 Loss 0.5612681451042493
Epoch 21 Accuracy 0.8645833333333334
Epoch 21 Validation Loss 0.5145007645354626
Epoch 21 Validation Accuracy 0.8798204787234043


                                                                                                    

Epoch 22 Loss 0.5650871676405271
Epoch 22 Accuracy 0.8647083333333333
Epoch 22 Validation Loss 0.5185690947035526
Epoch 22 Validation Accuracy 0.879654255319149


                                                                                                    

Epoch 23 Loss 0.5691711524228255
Epoch 23 Accuracy 0.8647708333333334
Epoch 23 Validation Loss 0.5228287837606795
Epoch 23 Validation Accuracy 0.879654255319149


                                                                                                    

Epoch 24 Loss 0.5734692712128162
Epoch 24 Accuracy 0.8648333333333333
Epoch 24 Validation Loss 0.5272278325988892
Epoch 24 Validation Accuracy 0.8797373670212766


                                                                                                    

Epoch 25 Loss 0.5779502451817194
Epoch 25 Accuracy 0.8648333333333333
Epoch 25 Validation Loss 0.5317449408801312
Epoch 25 Validation Accuracy 0.8797373670212766


                                                                                                    

Epoch 26 Loss 0.5825937119026979
Epoch 26 Accuracy 0.8648333333333333
Epoch 26 Validation Loss 0.5363723434349323
Epoch 26 Validation Accuracy 0.8797373670212766


                                                                                                    

Epoch 27 Loss 0.5873580257892609
Epoch 27 Accuracy 0.8647916666666666
Epoch 27 Validation Loss 0.5410728501354126
Epoch 27 Validation Accuracy 0.8794880319148937


                                                                                                    

Epoch 28 Loss 0.5922221521536509
Epoch 28 Accuracy 0.8646458333333333
Epoch 28 Validation Loss 0.5458329455491077
Epoch 28 Validation Accuracy 0.8794880319148937


                                                                                                    

Epoch 29 Loss 0.5971557800869147
Epoch 29 Accuracy 0.8648541666666667
Epoch 29 Validation Loss 0.5506236114876067
Epoch 29 Validation Accuracy 0.8795711436170213
