In [1]:
%load_ext autoreload
%autoreload 2

%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets
import pickle

In [3]:
# load in the pretrained vgg code
import shutil

if not os.path.exists('cifar10_models'):
    !git clone https://github.com/huyvnphan/PyTorch_CIFAR10
    
    # copy cifar10_models folder to current directory
    shutil.copytree(os.path.join("PyTorch_CIFAR10", "cifar10_models"), "cifar10_models")
    
    # delete the cloned repo
    shutil.rmtree("PyTorch_CIFAR10")
    

from cifar10_models import *

In [4]:
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Check for CUDA support
print(f"Is CUDA available? {torch.cuda.is_available()}")

# Set the device
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [5]:
from EarlyExitModel import EarlyExitModel

def create_model_with_exits(model_type, num_classes):
    if model_type == "resnet":
        resnet = models.resnet50(pretrained=True)
        
        # set requires_grad to False to freeze the parameters
        for param in resnet.parameters():
            param.requires_grad = False
        
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        
        model = EarlyExitModel(resnet, num_classes, device)
        model.clear_exits()
        exit_layers = [model.add_exit(layer, model_type) for layer in ('layer1', 'layer2', 'layer3')]
        
    elif model_type == "vgg":
        # model loaded from https://github.com/huyvnphan/PyTorch_CIFAR10
        # use this code to generate a pretrained model and pickle/save it
        
        model_path = os.path.join('models', 'vgg', 'vgg11_bn.pkl')
        vgg = pickle.load(open(model_path, 'rb'))
        # set requires_grad to False to freeze the parameters
        for param in vgg.parameters():
            param.requires_grad = False
        vgg.classifier[-1] = nn.Linear(vgg.classifier[-1].in_features, num_classes)
        
        model = EarlyExitModel(vgg, num_classes, device)
        model.clear_exits()
        exit_layers = [model.add_exit(layer, model_type) for layer in ('features.8', 'features.15', 'features.22', 'avgpool')]
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    model.to(device)
    return model
        
        

In [6]:
# choose the model
model_type = "vgg" # either "resnet" or "vgg"

model_path = os.path.join("models", model_type)

model = create_model_with_exits(model_type, num_classes=10)

model

EarlyExitModel(
  (model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): OptionalExitModule(
        (module): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU(inplace=True)
      (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (12): BatchNorm2d(256, eps=1e-05, mom

In [7]:
from DataLoader import CustomDataset
import numpy as np

if model_type == "resnet":
    # use the imagenette dataset
    hf_dataset = load_dataset("frgfm/imagenette", '320px')
    hf_dataset = concatenate_datasets(hf_dataset.values())
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

elif model_type == "vgg":
    # use the cifar10 dataset
    hf_dataset = load_dataset("cifar10")
    hf_dataset = concatenate_datasets(hf_dataset.values())
    
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    
# get first image
img = hf_dataset[0]['img']
label = hf_dataset[0]['label']

print(f"Image shape: {np.array(img).shape}")
print(f"Label: {label}")


torch_dataset = CustomDataset(hf_dataset, transform=transform)

batch_size = 32 if model_type == "resnet" else 64

test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

train_dataset, test_dataset = random_split(torch_dataset, [train_volume, test_volume])
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False, 
    num_workers=4
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

Image shape: (32, 32, 3)
Label: 0


## Early Exit Model Training

If you want tensorboard support, you need to run the following commands on **macOS**:

This clears your logs:
```
rm -rf runs/*
```

This shows your public IP address:
```
dig -4 TXT +short o-o.myaddr.l.google.com @ns1.google.com
```

This starts tensorboard at this address:
```
tensorboard --host 0.0.0.0 --logdir={model_path}/runs
```
```



If you want tensorboard support, you need to run the following commands on **Windows**:

This clears your logs:
```
rmdir /s /q runs
```

This shows your IP address:
```
powershell -command "$ipAddress = (Invoke-WebRequest -Uri 'http://ipinfo.io/ip').Content.Trim(); Write-Host 'Your IP address is: ' $ipAddress"  
```

This starts tensorboard at this address:
```
tensorboard --host 0.0.0.0 --logdir={model_path}\runs
```

In [8]:
from EarlyExitTrainer import ModelTrainer

trainer = ModelTrainer(model, device, model_dir=model_path)

In [9]:
# train the classifiers
trainer.train_classifiers(train_dataloader, epoch_count=30, validation_loader=test_dataloader)


Training classifier for exit 1


                                                                                                    

Epoch 0 Loss 1.9695088030497232
Epoch 0 Accuracy 0.4980833333333333




Epoch 0 Validation Loss 1.7297682210485985
Epoch 0 Validation Accuracy 0.5691489361702128


                                                                                                    

Epoch 1 Loss 1.5724122323989869
Epoch 1 Accuracy 0.5925416666666666




Epoch 1 Validation Loss 1.4693488225023796
Epoch 1 Validation Accuracy 0.6018949468085106


                                                                                                    

Epoch 2 Loss 1.3724816891352336
Epoch 2 Accuracy 0.6195




Epoch 2 Validation Loss 1.32303135636005
Epoch 2 Validation Accuracy 0.6193484042553191


                                                                                                    

Epoch 3 Loss 1.2514568719863892
Epoch 3 Accuracy 0.6380833333333333




Epoch 3 Validation Loss 1.2283820135796324
Epoch 3 Validation Accuracy 0.6363863031914894


                                                                                                    

Epoch 4 Loss 1.1690987172921499
Epoch 4 Accuracy 0.6524791666666667




Epoch 4 Validation Loss 1.1614058718402336
Epoch 4 Validation Accuracy 0.648188164893617


                                                                                                    

Epoch 5 Loss 1.1086213686466218
Epoch 5 Accuracy 0.6631458333333333




Epoch 5 Validation Loss 1.1110648287737623
Epoch 5 Validation Accuracy 0.6565824468085106


                                                                                                    

Epoch 6 Loss 1.0618245236078898
Epoch 6 Accuracy 0.672375




Epoch 6 Validation Loss 1.0715562048110556
Epoch 6 Validation Accuracy 0.6644780585106383


                                                                                                    

Epoch 7 Loss 1.0242428472836813
Epoch 7 Accuracy 0.6810416666666667




Epoch 7 Validation Loss 1.0395877513479679
Epoch 7 Validation Accuracy 0.6719581117021277


                                                                                                    

Epoch 8 Loss 0.99322549478213
Epoch 8 Accuracy 0.68875




Epoch 8 Validation Loss 1.0131157424221648
Epoch 8 Validation Accuracy 0.6761968085106383


                                                                                                    

Epoch 9 Loss 0.9670765421390534
Epoch 9 Accuracy 0.6943958333333333




Epoch 9 Validation Loss 0.9907969670726898
Epoch 9 Validation Accuracy 0.6818484042553191


                                                                                                    

Epoch 10 Loss 0.9446423842112224
Epoch 10 Accuracy 0.6999166666666666




Epoch 10 Validation Loss 0.9716902140607225
Epoch 10 Validation Accuracy 0.6848404255319149


                                                                                                    

Epoch 11 Loss 0.9251346615155538
Epoch 11 Accuracy 0.7044583333333333




Epoch 11 Validation Loss 0.9551507513573829
Epoch 11 Validation Accuracy 0.6883311170212766


                                                                                                    

Epoch 12 Loss 0.9079930868148803
Epoch 12 Accuracy 0.7081875




Epoch 12 Validation Loss 0.9406906223677575
Epoch 12 Validation Accuracy 0.6899102393617021


                                                                                                    

Epoch 13 Loss 0.8927633100350698
Epoch 13 Accuracy 0.711875




Epoch 13 Validation Loss 0.927928089461428
Epoch 13 Validation Accuracy 0.6925698138297872


                                                                                                    

Epoch 14 Loss 0.8791285354296366
Epoch 14 Accuracy 0.715




Epoch 14 Validation Loss 0.916557325961742
Epoch 14 Validation Accuracy 0.6948969414893617


                                                                                                    

Epoch 15 Loss 0.8668293616771698
Epoch 15 Accuracy 0.7180625




Epoch 15 Validation Loss 0.90637310419945
Epoch 15 Validation Accuracy 0.6979720744680851


                                                                                                    

Epoch 16 Loss 0.8556658920447031
Epoch 16 Accuracy 0.7206666666666667




Epoch 16 Validation Loss 0.8972097705019281
Epoch 16 Validation Accuracy 0.7008809840425532


                                                                                                    

Epoch 17 Loss 0.8454723443984985
Epoch 17 Accuracy 0.7231666666666666




Epoch 17 Validation Loss 0.8889168402616013
Epoch 17 Validation Accuracy 0.702626329787234


                                                                                                    

Epoch 18 Loss 0.8361206556161245
Epoch 18 Accuracy 0.726125




Epoch 18 Validation Loss 0.881368486488119
Epoch 18 Validation Accuracy 0.7039561170212766


                                                                                                    

Epoch 19 Loss 0.8274975117047628
Epoch 19 Accuracy 0.7281666666666666




Epoch 19 Validation Loss 0.8744787107756797
Epoch 19 Validation Accuracy 0.7059507978723404


                                                                                                    

Epoch 20 Loss 0.8195094281435013
Epoch 20 Accuracy 0.7302083333333333




Epoch 20 Validation Loss 0.8681648040705539
Epoch 20 Validation Accuracy 0.7076961436170213


                                                                                                    

Epoch 21 Loss 0.8120908803939819
Epoch 21 Accuracy 0.7324166666666667




Epoch 21 Validation Loss 0.8623533432788038
Epoch 21 Validation Accuracy 0.7092752659574468


                                                                                                    

Epoch 22 Loss 0.8051838702360788
Epoch 22 Accuracy 0.7350625




Epoch 22 Validation Loss 0.8570115074832388
Epoch 22 Validation Accuracy 0.7108543882978723


                                                                                                    

Epoch 23 Loss 0.7987292038202286
Epoch 23 Accuracy 0.7368541666666667




Epoch 23 Validation Loss 0.8520746370579334
Epoch 23 Validation Accuracy 0.7112699468085106


                                                                                                    

Epoch 24 Loss 0.792672700603803
Epoch 24 Accuracy 0.7386458333333333




Epoch 24 Validation Loss 0.8474956278471236
Epoch 24 Validation Accuracy 0.7128490691489362


                                                                                                    

Epoch 25 Loss 0.7869799691836039
Epoch 25 Accuracy 0.7404166666666666




Epoch 25 Validation Loss 0.8432446334590303
Epoch 25 Validation Accuracy 0.7134308510638298


                                                                                                    

Epoch 26 Loss 0.7816111831267675
Epoch 26 Accuracy 0.7421458333333333




Epoch 26 Validation Loss 0.8392916680016416
Epoch 26 Validation Accuracy 0.7138464095744681


                                                                                                    

Epoch 27 Loss 0.7765401279926301
Epoch 27 Accuracy 0.7439166666666667




Epoch 27 Validation Loss 0.8356024210757398
Epoch 27 Validation Accuracy 0.7149268617021277


                                                                                                    

Epoch 28 Loss 0.7717383006413777
Epoch 28 Accuracy 0.7454166666666666




Epoch 28 Validation Loss 0.8321581321193817
Epoch 28 Validation Accuracy 0.7165059840425532


                                                                                                    

Epoch 29 Loss 0.767185189406077
Epoch 29 Accuracy 0.747125




Epoch 29 Validation Loss 0.8289405357964496
Epoch 29 Validation Accuracy 0.7173371010638298
Training classifier for exit 2


                                                                                                    

Epoch 0 Loss 2.109144184748332
Epoch 0 Accuracy 0.6539166666666667




Epoch 0 Validation Loss 1.9272087311491053
Epoch 0 Validation Accuracy 0.737782579787234


                                                                                                    

Epoch 1 Loss 1.7775250811576844
Epoch 1 Accuracy 0.7565208333333333




Epoch 1 Validation Loss 1.6317128748335736
Epoch 1 Validation Accuracy 0.7688663563829787


                                                                                                    

Epoch 2 Loss 1.516120148817698
Epoch 2 Accuracy 0.7808333333333334




Epoch 2 Validation Loss 1.4023914413249239
Epoch 2 Validation Accuracy 0.781748670212766


                                                                                                    

Epoch 3 Loss 1.314148777961731
Epoch 3 Accuracy 0.7925416666666667




Epoch 3 Validation Loss 1.2269517570099933
Epoch 3 Validation Accuracy 0.7913896276595744


                                                                                                    

Epoch 4 Loss 1.1593159280618033
Epoch 4 Accuracy 0.8010416666666667




Epoch 4 Validation Loss 1.0927684005904705
Epoch 4 Validation Accuracy 0.7978723404255319


                                                                                                    

Epoch 5 Loss 1.0399975946744282
Epoch 5 Accuracy 0.8079791666666667




Epoch 5 Validation Loss 0.9890794072379457
Epoch 5 Validation Accuracy 0.8032746010638298


                                                                                                    

Epoch 6 Loss 0.9467825048764547
Epoch 6 Accuracy 0.8143333333333334




Epoch 6 Validation Loss 0.9076278143106623
Epoch 6 Validation Accuracy 0.8068484042553191


                                                                                                    

Epoch 7 Loss 0.8726701966921488
Epoch 7 Accuracy 0.8191875




Epoch 7 Validation Loss 0.8424445701406357
Epoch 7 Validation Accuracy 0.8119182180851063


                                                                                                    

Epoch 8 Loss 0.8126735224723816
Epoch 8 Accuracy 0.8234375




Epoch 8 Validation Loss 0.789350209401009
Epoch 8 Validation Accuracy 0.816156914893617


                                                                                                    

Epoch 9 Loss 0.7632506414254506
Epoch 9 Accuracy 0.82775




Epoch 9 Validation Loss 0.745359646196061
Epoch 9 Validation Accuracy 0.8193982712765957


                                                                                                    

Epoch 10 Loss 0.7218988355795543
Epoch 10 Accuracy 0.8315




Epoch 10 Validation Loss 0.7083912818355763
Epoch 10 Validation Accuracy 0.8233876329787234


                                                                                                    

Epoch 11 Loss 0.6867986324628195
Epoch 11 Accuracy 0.8349583333333334




Epoch 11 Validation Loss 0.6768692389447638
Epoch 11 Validation Accuracy 0.8271276595744681


                                                                                                    

Epoch 12 Loss 0.6566227316061656
Epoch 12 Accuracy 0.8381666666666666




Epoch 12 Validation Loss 0.6496692884792673
Epoch 12 Validation Accuracy 0.8307014627659575


                                                                                                    

Epoch 13 Loss 0.630385461807251
Epoch 13 Accuracy 0.841125




Epoch 13 Validation Loss 0.6259599126716877
Epoch 13 Validation Accuracy 0.8342752659574468


                                                                                                    

Epoch 14 Loss 0.6073592750628789
Epoch 14 Accuracy 0.8436458333333333




Epoch 14 Validation Loss 0.6051074856139244
Epoch 14 Validation Accuracy 0.8361037234042553


                                                                                                    

Epoch 15 Loss 0.586974630355835
Epoch 15 Accuracy 0.8465416666666666




Epoch 15 Validation Loss 0.5866017091147443
Epoch 15 Validation Accuracy 0.8380152925531915


                                                                                                    

Epoch 16 Loss 0.5687825384934744
Epoch 16 Accuracy 0.8487708333333334




Epoch 16 Validation Loss 0.5700683493880515
Epoch 16 Validation Accuracy 0.8400099734042553


                                                                                                    

Epoch 17 Loss 0.5524407925605774
Epoch 17 Accuracy 0.8508333333333333




Epoch 17 Validation Loss 0.5552085213204647
Epoch 17 Validation Accuracy 0.8419215425531915


                                                                                                    

Epoch 18 Loss 0.5376805284818014
Epoch 18 Accuracy 0.8528333333333333




Epoch 18 Validation Loss 0.5417860808207634
Epoch 18 Validation Accuracy 0.8433344414893617


                                                                                                    

Epoch 19 Loss 0.5242775911092759
Epoch 19 Accuracy 0.8549583333333334




Epoch 19 Validation Loss 0.5295979624733012
Epoch 19 Validation Accuracy 0.8450797872340425


                                                                                                    

Epoch 20 Loss 0.512054606239001
Epoch 20 Accuracy 0.8567291666666667




Epoch 20 Validation Loss 0.5184851891182839
Epoch 20 Validation Accuracy 0.8470744680851063


                                                                                                    

Epoch 21 Loss 0.5008574284315109
Epoch 21 Accuracy 0.85825




Epoch 21 Validation Loss 0.5083091262173145
Epoch 21 Validation Accuracy 0.8492353723404256


                                                                                                    

Epoch 22 Loss 0.49055897919336955
Epoch 22 Accuracy 0.8595625




Epoch 22 Validation Loss 0.49895732929097847
Epoch 22 Validation Accuracy 0.8501496010638298


                                                                                                    

Epoch 23 Loss 0.4810562068223953
Epoch 23 Accuracy 0.8612708333333333




Epoch 23 Validation Loss 0.49033354048399214
Epoch 23 Validation Accuracy 0.8515625


                                                                                                    

Epoch 24 Loss 0.4722587054570516
Epoch 24 Accuracy 0.8626041666666666




Epoch 24 Validation Loss 0.4823589957457908
Epoch 24 Validation Accuracy 0.8524767287234043


                                                                                                    

Epoch 25 Loss 0.464090389808019
Epoch 25 Accuracy 0.8640833333333333




Epoch 25 Validation Loss 0.4749587041900513
Epoch 25 Validation Accuracy 0.854720744680851


                                                                                                    

Epoch 26 Loss 0.4564857354958852
Epoch 26 Accuracy 0.8651875




Epoch 26 Validation Loss 0.4680790587308559
Epoch 26 Validation Accuracy 0.8561336436170213


                                                                                                    

Epoch 27 Loss 0.4493886033097903
Epoch 27 Accuracy 0.8664791666666667




Epoch 27 Validation Loss 0.46166650158293704
Epoch 27 Validation Accuracy 0.8573803191489362


                                                                                                    

Epoch 28 Loss 0.4427487685084343
Epoch 28 Accuracy 0.8677291666666667




Epoch 28 Validation Loss 0.4556769362472473
Epoch 28 Validation Accuracy 0.8578789893617021


                                                                                                    

Epoch 29 Loss 0.43652374523878096
Epoch 29 Accuracy 0.8692291666666667




Epoch 29 Validation Loss 0.4500698513807134
Epoch 29 Validation Accuracy 0.8592087765957447
Training classifier for exit 3


                                                                                                    

Epoch 0 Loss 1.9723586581548056
Epoch 0 Accuracy 0.9158333333333334




Epoch 0 Validation Loss 1.6404699816348705
Epoch 0 Validation Accuracy 0.9679188829787234


                                                                                                    

Epoch 1 Loss 1.4020823510487874
Epoch 1 Accuracy 0.9667291666666666




Epoch 1 Validation Loss 1.1344434894779896
Epoch 1 Validation Accuracy 0.9715757978723404


                                                                                                    

Epoch 2 Loss 0.9866690288384755
Epoch 2 Accuracy 0.9695625




Epoch 2 Validation Loss 0.7939502168843087
Epoch 2 Validation Accuracy 0.973154920212766


                                                                                                    

Epoch 3 Loss 0.7117592747608821
Epoch 3 Accuracy 0.9713125




Epoch 3 Validation Loss 0.5782765333956861
Epoch 3 Validation Accuracy 0.9733211436170213


                                                                                                    

Epoch 4 Loss 0.5353982052803039
Epoch 4 Accuracy 0.9715416666666666




Epoch 4 Validation Loss 0.4415395083896657
Epoch 4 Validation Accuracy 0.9736535904255319


                                                                                                    

Epoch 5 Loss 0.42072833915551505
Epoch 5 Accuracy 0.9716875




Epoch 5 Validation Loss 0.3522066152159204
Epoch 5 Validation Accuracy 0.9738198138297872


                                                                                                    

Epoch 6 Loss 0.3436890744169553
Epoch 6 Accuracy 0.9721458333333334




Epoch 6 Validation Loss 0.2915331428830928
Epoch 6 Validation Accuracy 0.9740691489361702


                                                                                                    

Epoch 7 Loss 0.28996623289585116
Epoch 7 Accuracy 0.9724375




Epoch 7 Validation Loss 0.24871182853871204
Epoch 7 Validation Accuracy 0.9742353723404256


                                                                                                    

Epoch 8 Loss 0.2511425569852193
Epoch 8 Accuracy 0.9727083333333333




Epoch 8 Validation Loss 0.21741823190228737
Epoch 8 Validation Accuracy 0.9745678191489362


                                                                                                    

Epoch 9 Loss 0.2221895851790905
Epoch 9 Accuracy 0.972875




Epoch 9 Validation Loss 0.1938647095193254
Epoch 9 Validation Accuracy 0.9747340425531915


                                                                                                    

Epoch 10 Loss 0.2000020337899526
Epoch 10 Accuracy 0.973125




Epoch 10 Validation Loss 0.17566663518230966
Epoch 10 Validation Accuracy 0.9748171542553191


                                                                                                    

Epoch 11 Loss 0.18260319767395655
Epoch 11 Accuracy 0.9734583333333333




Epoch 11 Validation Loss 0.16130302121189047
Epoch 11 Validation Accuracy 0.9748171542553191


                                                                                                    

Epoch 12 Loss 0.16868840725223225
Epoch 12 Accuracy 0.9736875




Epoch 12 Validation Loss 0.1497563888972744
Epoch 12 Validation Accuracy 0.9751496010638298


                                                                                                    

Epoch 13 Loss 0.15736605495711167
Epoch 13 Accuracy 0.9739166666666667




Epoch 13 Validation Loss 0.14032551138959032
Epoch 13 Validation Accuracy 0.9753989361702128


                                                                                                    

Epoch 14 Loss 0.14802449347575505
Epoch 14 Accuracy 0.9740625




Epoch 14 Validation Loss 0.13252176829871345
Epoch 14 Validation Accuracy 0.9753158244680851


                                                                                                    

Epoch 15 Loss 0.1402177489300569
Epoch 15 Accuracy 0.9741041666666667




Epoch 15 Validation Loss 0.12598509555484386
Epoch 15 Validation Accuracy 0.9754820478723404


                                                                                                    

Epoch 16 Loss 0.1336230592429638
Epoch 16 Accuracy 0.9741666666666666




Epoch 16 Validation Loss 0.12045238048155257
Epoch 16 Validation Accuracy 0.9753989361702128


                                                                                                    

Epoch 17 Loss 0.1279999733865261
Epoch 17 Accuracy 0.9742916666666667




Epoch 17 Validation Loss 0.11572843891112729
Epoch 17 Validation Accuracy 0.9754820478723404


                                                                                                    

Epoch 18 Loss 0.1231664575835069
Epoch 18 Accuracy 0.9743958333333333




Epoch 18 Validation Loss 0.11166672190611666
Epoch 18 Validation Accuracy 0.9757313829787234


                                                                                                    

Epoch 19 Loss 0.11898114201923211
Epoch 19 Accuracy 0.9744583333333333




Epoch 19 Validation Loss 0.10815071294757914
Epoch 19 Validation Accuracy 0.9756482712765957


                                                                                                    

Epoch 20 Loss 0.11533521266529957
Epoch 20 Accuracy 0.9744166666666667




Epoch 20 Validation Loss 0.10508716454174608
Epoch 20 Validation Accuracy 0.9755651595744681


                                                                                                    

Epoch 21 Loss 0.11214128395169973
Epoch 21 Accuracy 0.9745833333333334




Epoch 21 Validation Loss 0.10240675964729583
Epoch 21 Validation Accuracy 0.9756482712765957


                                                                                                    

Epoch 22 Loss 0.10933112942924102
Epoch 22 Accuracy 0.9746041666666667




Epoch 22 Validation Loss 0.10005116364621419
Epoch 22 Validation Accuracy 0.9754820478723404


                                                                                                    

Epoch 23 Loss 0.1068478057856361
Epoch 23 Accuracy 0.9746458333333333




Epoch 23 Validation Loss 0.09797152953143133
Epoch 23 Validation Accuracy 0.9755651595744681


                                                                                                    

Epoch 24 Loss 0.10464558496077855
Epoch 24 Accuracy 0.9746041666666667




Epoch 24 Validation Loss 0.09613121011631286
Epoch 24 Validation Accuracy 0.9757313829787234


                                                                                                    

Epoch 25 Loss 0.10268690517793098
Epoch 25 Accuracy 0.974625




Epoch 25 Validation Loss 0.09449900700611637
Epoch 25 Validation Accuracy 0.975814494680851


                                                                                                    

Epoch 26 Loss 0.10093917045742273
Epoch 26 Accuracy 0.9747083333333333




Epoch 26 Validation Loss 0.09304462194918318
Epoch 26 Validation Accuracy 0.9756482712765957


                                                                                                    

Epoch 27 Loss 0.09937575236707925
Epoch 27 Accuracy 0.9747291666666666




Epoch 27 Validation Loss 0.09174839679666973
Epoch 27 Validation Accuracy 0.9756482712765957


                                                                                                    

Epoch 28 Loss 0.09797377347573638
Epoch 28 Accuracy 0.9747916666666666




Epoch 28 Validation Loss 0.09058878138819908
Epoch 28 Validation Accuracy 0.9756482712765957


                                                                                                    

Epoch 29 Loss 0.0967143261184295
Epoch 29 Accuracy 0.9748125




Epoch 29 Validation Loss 0.08955240652876649
Epoch 29 Validation Accuracy 0.9757313829787234
Training classifier for exit 4


                                                                                                    

Epoch 0 Loss 0.6858307979106903
Epoch 0 Accuracy 0.9475833333333333




Epoch 0 Validation Loss 0.11509134073523765
Epoch 0 Validation Accuracy 0.9791389627659575


                                                                                                    

Epoch 1 Loss 0.09853727541988094
Epoch 1 Accuracy 0.9758958333333333




Epoch 1 Validation Loss 0.07586528703649627
Epoch 1 Validation Accuracy 0.9790558510638298


                                                                                                    

Epoch 2 Loss 0.08590072102937847
Epoch 2 Accuracy 0.9759583333333334




Epoch 2 Validation Loss 0.07774949907523362
Epoch 2 Validation Accuracy 0.9792220744680851


                                                                                                    

Epoch 3 Loss 0.08921093175187707
Epoch 3 Accuracy 0.9760625




Epoch 3 Validation Loss 0.08283359062439326
Epoch 3 Validation Accuracy 0.9788065159574468


                                                                                                    

Epoch 4 Loss 0.09440996365916605
Epoch 4 Accuracy 0.9760416666666667




Epoch 4 Validation Loss 0.08831178609135021
Epoch 4 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 5 Loss 0.09987924634758383
Epoch 5 Accuracy 0.9761041666666667




Epoch 5 Validation Loss 0.09369561836687586
Epoch 5 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 6 Loss 0.10526919008952487
Epoch 6 Accuracy 0.9761458333333334




Epoch 6 Validation Loss 0.0988671285323359
Epoch 6 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 7 Loss 0.1104759377706893
Epoch 7 Accuracy 0.9760625




Epoch 7 Validation Loss 0.10377987757650037
Epoch 7 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 8 Loss 0.11544496283287299
Epoch 8 Accuracy 0.9761041666666667




Epoch 8 Validation Loss 0.10842297962158517
Epoch 8 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 9 Loss 0.12014537084186547
Epoch 9 Accuracy 0.9760833333333333




Epoch 9 Validation Loss 0.11278034546803804
Epoch 9 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 10 Loss 0.12456792786793085
Epoch 10 Accuracy 0.9761041666666667




Epoch 10 Validation Loss 0.1168691259446008
Epoch 10 Validation Accuracy 0.9783909574468085


                                                                                                    

Epoch 11 Loss 0.12871263198022642
Epoch 11 Accuracy 0.9761041666666667




Epoch 11 Validation Loss 0.12069158374188578
Epoch 11 Validation Accuracy 0.9783909574468085


                                                                                                    

Epoch 12 Loss 0.13259530021552685
Epoch 12 Accuracy 0.9761458333333334




Epoch 12 Validation Loss 0.12428164893285708
Epoch 12 Validation Accuracy 0.9783909574468085


                                                                                                    

Epoch 13 Loss 0.13623704175310802
Epoch 13 Accuracy 0.9761458333333334




Epoch 13 Validation Loss 0.1276464822450383
Epoch 13 Validation Accuracy 0.9783909574468085


                                                                                                    

Epoch 14 Loss 0.13965481346604428
Epoch 14 Accuracy 0.9762083333333333




Epoch 14 Validation Loss 0.13081787699840372
Epoch 14 Validation Accuracy 0.9783909574468085


                                                                                                    

Epoch 15 Loss 0.1428716711451246
Epoch 15 Accuracy 0.9762291666666667




Epoch 15 Validation Loss 0.13381218131323
Epoch 15 Validation Accuracy 0.9783909574468085


                                                                                                    

Epoch 16 Loss 0.1459102198343256
Epoch 16 Accuracy 0.97625




Epoch 16 Validation Loss 0.13665647018965718
Epoch 16 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 17 Loss 0.14879261112622347
Epoch 17 Accuracy 0.9763125




Epoch 17 Validation Loss 0.13936266937058012
Epoch 17 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 18 Loss 0.15153240833279596
Epoch 18 Accuracy 0.9763125




Epoch 18 Validation Loss 0.1419400011181734
Epoch 18 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 19 Loss 0.15414560515363124
Epoch 19 Accuracy 0.9763333333333334




Epoch 19 Validation Loss 0.14440686799963426
Epoch 19 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 20 Loss 0.1566469933811968
Epoch 20 Accuracy 0.9762708333333333




Epoch 20 Validation Loss 0.14677202713590562
Epoch 20 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 21 Loss 0.159044077420097
Epoch 21 Accuracy 0.97625




Epoch 21 Validation Loss 0.1490398450063469
Epoch 21 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 22 Loss 0.16134556184040502
Epoch 22 Accuracy 0.97625




Epoch 22 Validation Loss 0.1512211263847937
Epoch 22 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 23 Loss 0.16355581347891016
Epoch 23 Accuracy 0.9762708333333333




Epoch 23 Validation Loss 0.15331372983676242
Epoch 23 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 24 Loss 0.16567906680359706
Epoch 24 Accuracy 0.9762916666666667




Epoch 24 Validation Loss 0.15533257531520706
Epoch 24 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 25 Loss 0.16772359856226832
Epoch 25 Accuracy 0.97625




Epoch 25 Validation Loss 0.1572770863997581
Epoch 25 Validation Accuracy 0.9784740691489362


                                                                                                    

Epoch 26 Loss 0.169695673850255
Epoch 26 Accuracy 0.9762708333333333




Epoch 26 Validation Loss 0.159158671359072
Epoch 26 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 27 Loss 0.17160176172998481
Epoch 27 Accuracy 0.9762291666666667




Epoch 27 Validation Loss 0.16097440746413552
Epoch 27 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 28 Loss 0.17344415002932675
Epoch 28 Accuracy 0.9762083333333333




Epoch 28 Validation Loss 0.16272984960214373
Epoch 28 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 29 Loss 0.17522523855180408
Epoch 29 Accuracy 0.9762291666666667




Epoch 29 Validation Loss 0.16443065780917462
Epoch 29 Validation Accuracy 0.9786402925531915
Training final classifier


                                                                                                    

Epoch 0 Loss 0.27702689278715603
Epoch 0 Accuracy 0.967625




Epoch 0 Validation Loss 0.08051054309149511
Epoch 0 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 1 Loss 0.10432342614745721
Epoch 1 Accuracy 0.9757708333333334




Epoch 1 Validation Loss 0.10349952988036286
Epoch 1 Validation Accuracy 0.9786402925531915


                                                                                                    

Epoch 2 Loss 0.12700780745300774
Epoch 2 Accuracy 0.975125




Epoch 2 Validation Loss 0.12127358233955748
Epoch 2 Validation Accuracy 0.9785571808510638


                                                                                                    

Epoch 3 Loss 0.14143466924691286
Epoch 3 Accuracy 0.9755416666666666




Epoch 3 Validation Loss 0.13341736067700416
Epoch 3 Validation Accuracy 0.9788065159574468


                                                                                                    

Epoch 4 Loss 0.15345141256775605
Epoch 4 Accuracy 0.9753541666666666




Epoch 4 Validation Loss 0.14233391551491384
Epoch 4 Validation Accuracy 0.9788065159574468


                                                                                                    

Epoch 5 Loss 0.16201893004776835
Epoch 5 Accuracy 0.9756041666666667




Epoch 5 Validation Loss 0.14958009263354652
Epoch 5 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 6 Loss 0.1684098675408759
Epoch 6 Accuracy 0.9756875




Epoch 6 Validation Loss 0.15560424435621836
Epoch 6 Validation Accuracy 0.9790558510638298


                                                                                                    

Epoch 7 Loss 0.1726826741698697
Epoch 7 Accuracy 0.9757708333333334




Epoch 7 Validation Loss 0.16064280821704027
Epoch 7 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 8 Loss 0.18072620217768334
Epoch 8 Accuracy 0.9756875




Epoch 8 Validation Loss 0.1651795990991355
Epoch 8 Validation Accuracy 0.9787234042553191


                                                                                                    

Epoch 9 Loss 0.1823356085816978
Epoch 9 Accuracy 0.97575




Epoch 9 Validation Loss 0.16925020467933619
Epoch 9 Validation Accuracy 0.9788065159574468


                                                                                                    

Epoch 10 Loss 0.1877325348559084
Epoch 10 Accuracy 0.975375




Epoch 10 Validation Loss 0.1729978552245313
Epoch 10 Validation Accuracy 0.9788065159574468


                                                                                                    

Epoch 11 Loss 0.1923598044483463
Epoch 11 Accuracy 0.9755




Epoch 11 Validation Loss 0.17585361473326291
Epoch 11 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 12 Loss 0.19399009919040183
Epoch 12 Accuracy 0.9753541666666666




Epoch 12 Validation Loss 0.17886593394538605
Epoch 12 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 13 Loss 0.19726935756908165
Epoch 13 Accuracy 0.9752708333333333




Epoch 13 Validation Loss 0.18171616599848198
Epoch 13 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 14 Loss 0.20106115859871226
Epoch 14 Accuracy 0.9757083333333333




Epoch 14 Validation Loss 0.18455182670510717
Epoch 14 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 15 Loss 0.20281502123289252
Epoch 15 Accuracy 0.9761875




Epoch 15 Validation Loss 0.1871519542385707
Epoch 15 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 16 Loss 0.20461169142979307
Epoch 16 Accuracy 0.9761041666666667




Epoch 16 Validation Loss 0.18916283839708417
Epoch 16 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 17 Loss 0.2087664592222784
Epoch 17 Accuracy 0.9760208333333333




Epoch 17 Validation Loss 0.19168800988332202
Epoch 17 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 18 Loss 0.21095655167140404
Epoch 18 Accuracy 0.975625




Epoch 18 Validation Loss 0.19353933023499187
Epoch 18 Validation Accuracy 0.9790558510638298


                                                                                                    

Epoch 19 Loss 0.2116240621013648
Epoch 19 Accuracy 0.9755833333333334




Epoch 19 Validation Loss 0.19552893550818742
Epoch 19 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 20 Loss 0.2150060554859781
Epoch 20 Accuracy 0.97575




Epoch 20 Validation Loss 0.19750768262438054
Epoch 20 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 21 Loss 0.2140483017793586
Epoch 21 Accuracy 0.9759583333333334




Epoch 21 Validation Loss 0.19923000086237155
Epoch 21 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 22 Loss 0.21804904460468463
Epoch 22 Accuracy 0.9760208333333333




Epoch 22 Validation Loss 0.20085996248264706
Epoch 22 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 23 Loss 0.21931195044122243
Epoch 23 Accuracy 0.9759166666666667




Epoch 23 Validation Loss 0.20234874734410724
Epoch 23 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 24 Loss 0.22028198322250014
Epoch 24 Accuracy 0.9756875




Epoch 24 Validation Loss 0.20371184110142243
Epoch 24 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 25 Loss 0.2230765962978066
Epoch 25 Accuracy 0.9756041666666667




Epoch 25 Validation Loss 0.20528249698982057
Epoch 25 Validation Accuracy 0.9788896276595744


                                                                                                    

Epoch 26 Loss 0.22379022026938886
Epoch 26 Accuracy 0.9759375




Epoch 26 Validation Loss 0.2066712325389107
Epoch 26 Validation Accuracy 0.9789727393617021


                                                                                                    

Epoch 27 Loss 0.22783797233343844
Epoch 27 Accuracy 0.9759583333333334




Epoch 27 Validation Loss 0.20809844856245985
Epoch 27 Validation Accuracy 0.9790558510638298


                                                                                                    

Epoch 28 Loss 0.22592648675628293
Epoch 28 Accuracy 0.9762291666666667




Epoch 28 Validation Loss 0.20956317325193696
Epoch 28 Validation Accuracy 0.9791389627659575


                                                                                                    

Epoch 29 Loss 0.23104527500456587
Epoch 29 Accuracy 0.9756041666666667




Epoch 29 Validation Loss 0.21083713732575526
Epoch 29 Validation Accuracy 0.9788896276595744


In [10]:
# train the exits

#trainer.train_exit_layers(train_dataloader, epoch_count=30, validation_loader=test_dataloader)