In [1]:
%load_ext autoreload
%autoreload 2

%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Check for CUDA support
print(f"Is CUDA available? {torch.cuda.is_available()}")

# Set the device
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [4]:
resnet50 = models.resnet50(pretrained=True)
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, 10)

# # freeze all layers
# for param in resnet50.parameters():
#     param.requires_grad = False



In [5]:
from EarlyExitModel import EarlyExitModel

model = EarlyExitModel(resnet50, 10, device)
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
     

In [6]:
model.clear_exits()
exit_layers = [model.add_exit(layer) for layer in ('layer1', 'layer2', 'layer3')]
model.to(device)  # Move the model to the selected device
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): OptionalExitModule(
      (module): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(i

In [7]:
from DataLoader import CustomDataset

hf_dataset = load_dataset("frgfm/imagenette", '320px')
hf_dataset = concatenate_datasets(hf_dataset.values())

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


torch_dataset = CustomDataset(hf_dataset, transform=transform)

batch_size = 32

test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

train_dataset, test_dataset = random_split(torch_dataset, [train_volume, test_volume])
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False, 
    num_workers=4
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

## Early Exit Model Training

In [8]:
from EarlyExitTrainer import ModelTrainer

trainer = ModelTrainer(model, device)

trainer.train(train_dataloader, epoch_count=1000, validation_loader=test_dataloader)


Training early exit layer 1
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 5.269671457205246
Epoch 0 Accuracy 0.21360607675906182




Epoch 0 Validation Loss 3.679361085096995
Epoch 0 Validation Accuracy 0.33154085497835495
Beginning epoch 1


                                                                                                    

Epoch 1 Loss 4.3230493844445075
Epoch 1 Accuracy 0.30283848614072495




Epoch 1 Validation Loss 5.549984863826206
Epoch 1 Validation Accuracy 0.3540313852813853
Beginning epoch 2


                                                                                                    

Epoch 2 Loss 4.010286066781229
Epoch 2 Accuracy 0.36082089552238805




Epoch 2 Validation Loss 4.988745953355517
Epoch 2 Validation Accuracy 0.3378652597402597
Beginning epoch 3


                                                                                                    

Epoch 3 Loss 3.797397490401766
Epoch 3 Accuracy 0.39949360341151385




Epoch 3 Validation Loss 3.937353843734378
Epoch 3 Validation Accuracy 0.4096320346320346
Beginning epoch 4


                                                                                                    

Epoch 4 Loss 3.6505854069296992
Epoch 4 Accuracy 0.42833155650319826




Epoch 4 Validation Loss 4.465637241091047
Epoch 4 Validation Accuracy 0.39495400432900435
Beginning epoch 5


                                                                                                    

Epoch 5 Loss 3.5978319182324765
Epoch 5 Accuracy 0.44048507462686565




Epoch 5 Validation Loss 3.845569196201506
Epoch 5 Validation Accuracy 0.42711715367965364
Beginning epoch 6


                                                                                                    

Epoch 6 Loss 3.5481372235426263
Epoch 6 Accuracy 0.4658182302771855




Epoch 6 Validation Loss 4.455735743045807
Epoch 6 Validation Accuracy 0.40983495670995673
Beginning epoch 7


                                                                                                    

Epoch 7 Loss 3.4853193464563854
Epoch 7 Accuracy 0.4697494669509594




Epoch 7 Validation Loss 4.9502956455662135
Epoch 7 Validation Accuracy 0.3961038961038961
Beginning epoch 8


                                                                                                    

Epoch 8 Loss 3.438027733119566
Epoch 8 Accuracy 0.4916311300639659




Epoch 8 Validation Loss 4.203825811545054
Epoch 8 Validation Accuracy 0.4433170995670996
Beginning epoch 9


                                                                                                    

Epoch 9 Loss 3.394413235294285
Epoch 9 Accuracy 0.5037713219616204




Epoch 9 Validation Loss 5.002994767257145
Epoch 9 Validation Accuracy 0.4116950757575758
Beginning epoch 10


                                                                                                    

Epoch 10 Loss 3.36383701420542
Epoch 10 Accuracy 0.5099413646055437




Epoch 10 Validation Loss 5.147018923645928
Epoch 10 Validation Accuracy 0.39870806277056275
Beginning epoch 11


                                                                                                    

Epoch 11 Loss 3.330303413120668
Epoch 11 Accuracy 0.5222681236673774




Epoch 11 Validation Loss 4.323728209450131
Epoch 11 Validation Accuracy 0.4555600649350649
Beginning epoch 12


                                                                                                    

Epoch 12 Loss 3.294378299143777
Epoch 12 Accuracy 0.5384461620469083




Epoch 12 Validation Loss 4.674117497035435
Epoch 12 Validation Accuracy 0.4325284090909091
Beginning epoch 13


                                                                                                    

Epoch 13 Loss 3.261983233779224
Epoch 13 Accuracy 0.5388326226012794




Epoch 13 Validation Loss 3.9208987043017434
Epoch 13 Validation Accuracy 0.48254870129870125
Beginning epoch 14


                                                                                                    

Epoch 14 Loss 3.218698866331755
Epoch 14 Accuracy 0.5519856076759062




Epoch 14 Validation Loss 4.241234606220608
Epoch 14 Validation Accuracy 0.46577380952380953
Beginning epoch 15


                                                                                                    

Epoch 15 Loss 3.173738477835015
Epoch 15 Accuracy 0.5546775053304904




Epoch 15 Validation Loss 5.019679429985228
Epoch 15 Validation Accuracy 0.41710633116883117
Beginning epoch 16


                                                                                                    

Epoch 16 Loss 3.1295132693959706
Epoch 16 Accuracy 0.5642857142857143




Epoch 16 Validation Loss 4.624407881782169
Epoch 16 Validation Accuracy 0.4541057900432901
Beginning epoch 17


                                                                                                    

Epoch 17 Loss 3.086501487867156
Epoch 17 Accuracy 0.5705490405117272




Epoch 17 Validation Loss 4.8151290984380815
Epoch 17 Validation Accuracy 0.45356466450216454
Beginning epoch 18


                                                                                                    

Epoch 18 Loss 3.0096102287520226
Epoch 18 Accuracy 0.5743203624733475




Epoch 18 Validation Loss 4.081027646859487
Epoch 18 Validation Accuracy 0.48867018398268397
Beginning epoch 19


                                                                                                    

Epoch 19 Loss 2.9602113695286993
Epoch 19 Accuracy 0.5849280383795309




Epoch 19 Validation Loss 3.9793059584640322
Epoch 19 Validation Accuracy 0.5046672077922078
Beginning epoch 20


                                                                                                    

Epoch 20 Loss 2.9066692964354557
Epoch 20 Accuracy 0.5904317697228145




Epoch 20 Validation Loss 4.03321107228597
Epoch 20 Validation Accuracy 0.5139339826839827
Beginning epoch 21


                                                                                                    

Epoch 21 Loss 2.8521133230693305
Epoch 21 Accuracy 0.5909781449893391




Epoch 21 Validation Loss 4.185628110454196
Epoch 21 Validation Accuracy 0.5056141774891775
Beginning epoch 22


                                                                                                    

Epoch 22 Loss 2.8638228565899295
Epoch 22 Accuracy 0.5998667377398721




Epoch 22 Validation Loss 4.350318429015932
Epoch 22 Validation Accuracy 0.5009469696969697
Beginning epoch 23


                                                                                                    

Epoch 23 Loss 2.815572405928996
Epoch 23 Accuracy 0.6091284648187634




Epoch 23 Validation Loss 4.428059120972951
Epoch 23 Validation Accuracy 0.5022659632034632
Beginning epoch 24


                                                                                                    

Epoch 24 Loss 2.8630457986646625
Epoch 24 Accuracy 0.6098480810234541




Epoch 24 Validation Loss 4.655508589176905
Epoch 24 Validation Accuracy 0.47270698051948057
Beginning epoch 25


                                                                                                    

Epoch 25 Loss 2.8230205512758513
Epoch 25 Accuracy 0.615271855010661




Epoch 25 Validation Loss 4.243175496657689
Epoch 25 Validation Accuracy 0.5297280844155845
Beginning epoch 26


                                                                                                    

Epoch 26 Loss 2.7916953414233765
Epoch 26 Accuracy 0.6250799573560768




Epoch 26 Validation Loss 4.392370825722104
Epoch 26 Validation Accuracy 0.5183982683982684
Beginning epoch 27


                                                                                                    

Epoch 27 Loss 2.8073734176692677
Epoch 27 Accuracy 0.6259061833688699




Epoch 27 Validation Loss 4.637024152846563
Epoch 27 Validation Accuracy 0.5013189935064934
Beginning epoch 28


                                                                                                    

Epoch 28 Loss 2.819385631048857
Epoch 28 Accuracy 0.625453091684435




Epoch 28 Validation Loss 4.761799273036775
Epoch 28 Validation Accuracy 0.5071022727272727
Beginning epoch 29


                                                                                                    

Epoch 29 Loss 2.7841823444437623
Epoch 29 Accuracy 0.6299173773987207




Epoch 29 Validation Loss 4.605806912694659
Epoch 29 Validation Accuracy 0.53125
Beginning epoch 30


                                                                                                    

Epoch 30 Loss 2.7725836272559947
Epoch 30 Accuracy 0.635047974413646




Epoch 30 Validation Loss 5.547536401521592
Epoch 30 Validation Accuracy 0.48143262987012986
Beginning epoch 31


                                                                                                    

Epoch 31 Loss 2.7399021089966618
Epoch 31 Accuracy 0.6402851812366738




Epoch 31 Validation Loss 5.279683293331237
Epoch 31 Validation Accuracy 0.5002367424242424
Beginning epoch 32


                                                                                                    

Epoch 32 Loss 2.703191670019235
Epoch 32 Accuracy 0.6470815565031982




Epoch 32 Validation Loss 4.776997245493389
Epoch 32 Validation Accuracy 0.5202922077922078
Beginning epoch 33


                                                                                                    

Epoch 33 Loss 2.6982530744218116
Epoch 33 Accuracy 0.65295842217484




Epoch 33 Validation Loss 4.880500972270966
Epoch 33 Validation Accuracy 0.5290516774891775
Beginning epoch 34


                                                                                                    

Epoch 34 Loss 2.689258317093351
Epoch 34 Accuracy 0.6558368869936034




Epoch 34 Validation Loss 4.9644708931446075
Epoch 34 Validation Accuracy 0.5301339285714286
Beginning epoch 35


                                                                                                    

Epoch 35 Loss 2.6712712585036433
Epoch 35 Accuracy 0.658821961620469




Epoch 35 Validation Loss 4.674476260230655
Epoch 35 Validation Accuracy 0.5349702380952381
Beginning epoch 36


                                                                                                    

Epoch 36 Loss 2.618691649009932
Epoch 36 Accuracy 0.6678971215351812




Epoch 36 Validation Loss 4.8136121063005355
Epoch 36 Validation Accuracy 0.5252976190476191
Beginning epoch 37


                                                                                                    

Epoch 37 Loss 2.6121639309534386
Epoch 37 Accuracy 0.665724946695096




Epoch 37 Validation Loss 4.935679977848416
Epoch 37 Validation Accuracy 0.5310808982683982
Beginning epoch 38


                                                                                                    

Epoch 38 Loss 2.62907851151566
Epoch 38 Accuracy 0.6634994669509594




Epoch 38 Validation Loss 5.637198114678974
Epoch 38 Validation Accuracy 0.4975987554112554
Beginning epoch 39


                                                                                                    

Epoch 39 Loss 2.5631051323307092
Epoch 39 Accuracy 0.6717217484008529




Epoch 39 Validation Loss 5.05754345087778
Epoch 39 Validation Accuracy 0.5180262445887446
Beginning epoch 40


                                                                                                    

Epoch 40 Loss 2.5801046952382842
Epoch 40 Accuracy 0.6713619402985075




Epoch 40 Validation Loss 4.9134140199139
Epoch 40 Validation Accuracy 0.5098417207792207
Beginning epoch 41


                                                                                                    

Epoch 41 Loss 2.5561052868170524
Epoch 41 Accuracy 0.6761060767590619




Epoch 41 Validation Loss 5.478901838972455
Epoch 41 Validation Accuracy 0.509469696969697
Beginning epoch 42


                                                                                                    

Epoch 42 Loss 2.535421756488174
Epoch 42 Accuracy 0.681116737739872




Epoch 42 Validation Loss 5.5160343533470515
Epoch 42 Validation Accuracy 0.500202922077922
Beginning epoch 43


                                                                                                    

Epoch 43 Loss 2.5432091710282796
Epoch 43 Accuracy 0.6818097014925373




Epoch 43 Validation Loss 5.630595113549914
Epoch 43 Validation Accuracy 0.5181953463203464
Beginning epoch 44


                                                                                                    

Epoch 44 Loss 2.5249993243324225
Epoch 44 Accuracy 0.6842217484008529




Epoch 44 Validation Loss 5.656682384865625
Epoch 44 Validation Accuracy 0.5089285714285714
Beginning epoch 45


                                                                                                    

Epoch 45 Loss 2.5408529057431575
Epoch 45 Accuracy 0.6846082089552239




Epoch 45 Validation Loss 5.497327188650767
Epoch 45 Validation Accuracy 0.5174851190476191
Beginning epoch 46


                                                                                                    

Epoch 46 Loss 2.5610182739015834
Epoch 46 Accuracy 0.6869269722814499




Epoch 46 Validation Loss 5.331137005771909
Epoch 46 Validation Accuracy 0.5191423160173161
Beginning epoch 47


                                                                                                    

Epoch 47 Loss 2.5078734445038124
Epoch 47 Accuracy 0.6937366737739872




Epoch 47 Validation Loss 5.907565114044008
Epoch 47 Validation Accuracy 0.4992559523809524
Beginning epoch 48


                                                                                                    

Epoch 48 Loss 2.5218736174057668
Epoch 48 Accuracy 0.6900186567164179




Epoch 48 Validation Loss 5.768491797503971
Epoch 48 Validation Accuracy 0.5042613636363636
Beginning epoch 49


                                                                                                    

Epoch 49 Loss 2.4767467854405516
Epoch 49 Accuracy 0.6972014925373134




Epoch 49 Validation Loss 6.447691409360795
Epoch 49 Validation Accuracy 0.4860998376623377
Beginning epoch 50


                                                                                                    

Epoch 50 Loss 2.4600703458732633
Epoch 50 Accuracy 0.700839552238806




Epoch 50 Validation Loss 4.647541494596572
Epoch 50 Validation Accuracy 0.5440679112554113
Beginning epoch 51


                                                                                                    

Epoch 51 Loss 2.4468676543057852
Epoch 51 Accuracy 0.7017857142857142




Epoch 51 Validation Loss 5.267027108442216
Epoch 51 Validation Accuracy 0.5197172619047619
Beginning epoch 52


                                                                                                    

Epoch 52 Loss 2.4670442375674178
Epoch 52 Accuracy 0.7036513859275053




Epoch 52 Validation Loss 5.297160531793322
Epoch 52 Validation Accuracy 0.5238095238095238
Beginning epoch 53


                                                                                                    

Epoch 53 Loss 2.4296076578435613
Epoch 53 Accuracy 0.7050506396588486




Epoch 53 Validation Loss 5.484219280027208
Epoch 53 Validation Accuracy 0.5119385822510822
Beginning epoch 54


                                                                                                    

Epoch 54 Loss 2.403322812559
Epoch 54 Accuracy 0.7053171641791045




Epoch 54 Validation Loss 5.202543278535207
Epoch 54 Validation Accuracy 0.5260416666666666
Beginning epoch 55


                                                                                                    

Epoch 55 Loss 2.3729796438964446
Epoch 55 Accuracy 0.7131529850746269




Epoch 55 Validation Loss 5.2336942326454885
Epoch 55 Validation Accuracy 0.5122767857142857
Beginning epoch 56


                                                                                                    

Epoch 56 Loss 2.3881801852539404
Epoch 56 Accuracy 0.7132462686567164




Epoch 56 Validation Loss 5.4487760705607275
Epoch 56 Validation Accuracy 0.5227272727272727
Beginning epoch 57


                                                                                                    

Epoch 57 Loss 2.3864215507880964
Epoch 57 Accuracy 0.7180970149253731




Epoch 57 Validation Loss 5.307838720934732
Epoch 57 Validation Accuracy 0.5308779761904762
Beginning epoch 58


                                                                                                    

Epoch 58 Loss 2.353934231667376
Epoch 58 Accuracy 0.7188432835820896




Epoch 58 Validation Loss 5.009374079250154
Epoch 58 Validation Accuracy 0.5310808982683982
Beginning epoch 59


                                                                                                    

Epoch 59 Loss 2.371256707453016
Epoch 59 Accuracy 0.7184834754797441




Epoch 59 Validation Loss 5.277942268621354
Epoch 59 Validation Accuracy 0.5340232683982684
Beginning epoch 60


                                                                                                    

Epoch 60 Loss 2.3669310823305327
Epoch 60 Accuracy 0.7191231343283582




Epoch 60 Validation Loss 4.861120690902074
Epoch 60 Validation Accuracy 0.5570549242424242
Beginning epoch 61


                                                                                                    

Epoch 61 Loss 2.40283000478104
Epoch 61 Accuracy 0.719776119402985




Epoch 61 Validation Loss 4.943558079855783
Epoch 61 Validation Accuracy 0.5546536796536797
Beginning epoch 62


                                                                                                    

Epoch 62 Loss 2.3997847378253936
Epoch 62 Accuracy 0.7239872068230276




Epoch 62 Validation Loss 5.342175722122192
Epoch 62 Validation Accuracy 0.5474161255411255
Beginning epoch 63


                                                                                                    

Epoch 63 Loss 2.370499412960081
Epoch 63 Accuracy 0.725453091684435




Epoch 63 Validation Loss 5.097795103277479
Epoch 63 Validation Accuracy 0.5459280303030303
Beginning epoch 64


                                                                                                    

Epoch 64 Loss 2.374289427097164
Epoch 64 Accuracy 0.7264925373134329




Epoch 64 Validation Loss 5.009346101965223
Epoch 64 Validation Accuracy 0.5390286796536797
Model is overfitting, stopping early
Training early exit layer 2
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 5.6759234488900026
Epoch 0 Accuracy 0.3869802771855011




Epoch 0 Validation Loss 4.922900461015248
Epoch 0 Validation Accuracy 0.43787202380952384
Beginning epoch 1


                                                                                                    

Epoch 1 Loss 3.3315434470105525
Epoch 1 Accuracy 0.5862473347547974




Epoch 1 Validation Loss 4.917433744385129
Epoch 1 Validation Accuracy 0.47673160173160173
Beginning epoch 2


                                                                                                    

Epoch 2 Loss 2.6701572878147237
Epoch 2 Accuracy 0.6627132196162048




Epoch 2 Validation Loss 4.243288338184357
Epoch 2 Validation Accuracy 0.5524215367965368
Beginning epoch 3


                                                                                                    

Epoch 3 Loss 2.27149970011035
Epoch 3 Accuracy 0.7176305970149254




Epoch 3 Validation Loss 3.30051957425617
Epoch 3 Validation Accuracy 0.6451907467532467
Beginning epoch 4


                                                                                                    

Epoch 4 Loss 1.988745261122709
Epoch 4 Accuracy 0.7552638592750534




Epoch 4 Validation Loss 3.4752173054785955
Epoch 4 Validation Accuracy 0.6330830627705628
Beginning epoch 5


                                                                                                    

Epoch 5 Loss 1.7813926560971052
Epoch 5 Accuracy 0.7838086353944562




Epoch 5 Validation Loss 3.2800772076561335
Epoch 5 Validation Accuracy 0.6581777597402597
Beginning epoch 6


                                                                                                    

Epoch 6 Loss 1.5740204139265106
Epoch 6 Accuracy 0.8091151385927505




Epoch 6 Validation Loss 4.245664085660662
Epoch 6 Validation Accuracy 0.6136025432900433
Beginning epoch 7


                                                                                                    

Epoch 7 Loss 1.3981255888963229
Epoch 7 Accuracy 0.8345682302771855




Epoch 7 Validation Loss 3.5560733789489385
Epoch 7 Validation Accuracy 0.6513460497835497
Beginning epoch 8


                                                                                                    

Epoch 8 Loss 1.2884859936754913
Epoch 8 Accuracy 0.8453224946695096




Epoch 8 Validation Loss 3.4850734784489585
Epoch 8 Validation Accuracy 0.6650771103896104
Beginning epoch 9


                                                                                                    

Epoch 9 Loss 1.1849253908454773
Epoch 9 Accuracy 0.8625799573560767




Epoch 9 Validation Loss 3.804074662072318
Epoch 9 Validation Accuracy 0.6557765151515151
Beginning epoch 10


                                                                                                    

Epoch 10 Loss 1.09158903235317
Epoch 10 Accuracy 0.8679904051172708




Epoch 10 Validation Loss 3.6884664382253374
Epoch 10 Validation Accuracy 0.6546604437229437
Beginning epoch 11


                                                                                                    

Epoch 11 Loss 1.0654820995555179
Epoch 11 Accuracy 0.8791844349680171




Epoch 11 Validation Loss 3.5864873351085755
Epoch 11 Validation Accuracy 0.6782670454545455
Beginning epoch 12


                                                                                                    

Epoch 12 Loss 1.0432916941557235
Epoch 12 Accuracy 0.8834754797441365




In [None]:
# import time

# # TODO: replace this with a transfer learned resnet model

# original_model = models.resnet50(pretrained=True)

# nonEarlyExitModel = EarlyExitModel(original_model, 1000, device)
# nonEarlyExitModel.to(device)
# nonModelTrainer = ModelTrainer(nonEarlyExitModel, device)

# # validate the model
# print("Validating original Resnet model")
# start = time.time()
# loss, acc, exits = nonModelTrainer.validate(test_dataloader)
# end = time.time()
# print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
# print(f"Validation time: {end - start}")
# print("=====================================================")

# # validate the new early exit model
# print("Validating new ResnetEE model")
# start = time.time()
# loss, acc = trainer.validate(test_dataloader)
# end = time.time()
# print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
# print(f"Average exit index: {exits}")
# print(f"Validation time: {end - start}")

Validating original Resnet model


KeyboardInterrupt: 