In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Install the notebook's dependencies; `%pip` targets the environment of the
# running kernel. A kernel restart may be required afterwards.
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Report the PyTorch build and which accelerator backends it can reach.
print(f"PyTorch version: {torch.__version__}")

# MPS = Metal Performance Shaders, the GPU backend on Apple hardware.
mps_backend = torch.backends.mps
print(f"Is MPS (Metal Performance Shader) built? {mps_backend.is_built()}")
print(f"Is MPS available? {mps_backend.is_available()}")

# CUDA = NVIDIA GPUs.
print(f"Is CUDA available? {torch.cuda.is_available()}")

# Pick the device in priority order: MPS first, then CUDA, otherwise the CPU.
device = (
    "mps" if mps_backend.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [4]:
# Load an ImageNet-pretrained ResNet-50 backbone.
# `pretrained=True` has been deprecated since torchvision 0.13 and emits a
# warning; the `weights` enum is its replacement, and IMAGENET1K_V1 is the
# checkpoint the old flag used to load.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Swap the original classifier for a fresh, randomly initialised 10-way linear
# head that keeps the backbone's feature width (10 matches the class count
# passed to EarlyExitModel later in the notebook).
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, 10)

# Optional, currently disabled: freeze the network's parameters.
# NOTE(review): run as written (after the head swap) this loop would also
# freeze the new `fc` layer — confirm the intended placement before enabling.
# # freeze all layers
# for param in resnet50.parameters():
#     param.requires_grad = False



In [5]:
from EarlyExitModel import EarlyExitModel

# Presumably the number of output classes for the exits; it equals the width of
# the replacement `fc` head built above — TODO confirm against EarlyExitModel.
num_classes = 10

# Wrap the backbone so that early-exit branches can be attached to it.
model = EarlyExitModel(resnet50, num_classes, device)

# Bare last expression: show the wrapped architecture as the cell output.
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
     

In [6]:
# Presumably drops any exits attached by an earlier run of this cell, so that
# re-running it does not stack duplicates — verify against EarlyExitModel.
model.clear_exits()

# Attach one exit after each of the first three ResNet stages and keep whatever
# `add_exit` returns for each of them.
exit_layers = []
for layer_name in ('layer1', 'layer2', 'layer3'):
    exit_layers.append(model.add_exit(layer_name))

# Move the model (including the newly attached exits) to the selected device.
model.to(device)

# Bare last expression: show the updated architecture as the cell output.
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): OptionalExitModule(
      (module): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(i

In [7]:
from DataLoader import CustomDataset

# Load every split of Imagenette (320px variant) and merge them into a single
# dataset; the train/test partition is re-drawn below.
hf_splits = load_dataset("frgfm/imagenette", '320px')
# `concatenate_datasets` documents a list argument, so materialise the dict view.
hf_dataset = concatenate_datasets(list(hf_splits.values()))

# Resize to the 224x224 input size and convert PIL images to float tensors.
# NOTE(review): ImageNet mean/std normalisation is disabled here although the
# pretrained torchvision ResNet weights were trained with it — confirm whether
# leaving it off is intentional before re-enabling.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

torch_dataset = CustomDataset(hf_dataset, transform=transform)

batch_size = 32

# Hold out 20% of the samples for validation.
test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

# Seed the split so that the same samples land in train/test on every
# Restart & Run All; otherwise metrics are not comparable between runs.
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, test_dataset = random_split(
    torch_dataset,
    [train_volume, test_volume],
    generator=split_generator,
)

# Training batches are reshuffled every epoch so that the optimiser does not
# see the exact same batch sequence each time.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4
)

# Evaluation order does not matter, so the held-out loader stays unshuffled.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

## Early Exit Model Training

In [8]:
from EarlyExitTrainer import ModelTrainer

# Upper bound on epochs; in the recorded run the trainer stopped each exit long
# before this limit ("Model is overfitting, stopping early").
max_epochs = 1000

trainer = ModelTrainer(model, device)

# Train on the training loader and evaluate on the held-out loader after each
# epoch, as reflected by the per-epoch validation lines in the output.
trainer.train(
    train_dataloader,
    epoch_count=max_epochs,
    validation_loader=test_dataloader,
)


Training early exit layer 1
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 5.221151706353942
Epoch 0 Accuracy 0.22297441364605544




Epoch 0 Validation Loss 5.11471977971849
Epoch 0 Validation Accuracy 0.19970914502164505
Beginning epoch 1


                                                                                                    

Epoch 1 Loss 4.337191618022634
Epoch 1 Accuracy 0.3115138592750533




Epoch 1 Validation Loss 4.671873507045564
Epoch 1 Validation Accuracy 0.26663961038961037
Beginning epoch 2


                                                                                                    

Epoch 2 Loss 3.970699915245398
Epoch 2 Accuracy 0.3574893390191898




Epoch 2 Validation Loss 3.91069191410428
Epoch 2 Validation Accuracy 0.386228354978355
Beginning epoch 3


                                                                                                    

Epoch 3 Loss 3.731722704688115
Epoch 3 Accuracy 0.3967217484008529




Epoch 3 Validation Loss 3.796754096235548
Epoch 3 Validation Accuracy 0.3960362554112554
Beginning epoch 4


                                                                                                    

Epoch 4 Loss 3.565681627971023
Epoch 4 Accuracy 0.42324093816631125




Epoch 4 Validation Loss 4.07407511983599
Epoch 4 Validation Accuracy 0.3625541125541125
Beginning epoch 5


                                                                                                    

Epoch 5 Loss 3.4928495777187063
Epoch 5 Accuracy 0.440658315565032




Epoch 5 Validation Loss 4.461827332065219
Epoch 5 Validation Accuracy 0.4371956168831169
Beginning epoch 6


                                                                                                    

Epoch 6 Loss 3.36140046368784
Epoch 6 Accuracy 0.46500533049040516




Epoch 6 Validation Loss 4.45181842928841
Epoch 6 Validation Accuracy 0.38704004329004327
Beginning epoch 7


                                                                                                    

Epoch 7 Loss 3.305938783688332
Epoch 7 Accuracy 0.4773587420042644




Epoch 7 Validation Loss 3.4546586062226976
Epoch 7 Validation Accuracy 0.5078463203463203
Beginning epoch 8


                                                                                                    

Epoch 8 Loss 3.2155497426417337
Epoch 8 Accuracy 0.49661513859275047




Epoch 8 Validation Loss 4.491165615263439
Epoch 8 Validation Accuracy 0.43215638528138534
Beginning epoch 9


                                                                                                    

Epoch 9 Loss 3.171648998758686
Epoch 9 Accuracy 0.5109408315565033




Epoch 9 Validation Loss 4.333704210463024
Epoch 9 Validation Accuracy 0.4552218614718615
Beginning epoch 10


                                                                                                    

Epoch 10 Loss 3.1563932376121406
Epoch 10 Accuracy 0.5220149253731343




Epoch 10 Validation Loss 4.866016495795477
Epoch 10 Validation Accuracy 0.4044575216450217
Beginning epoch 11


                                                                                                    

Epoch 11 Loss 3.125861639051295
Epoch 11 Accuracy 0.5271455223880597




Epoch 11 Validation Loss 5.554895980017526
Epoch 11 Validation Accuracy 0.4005681818181818
Model is overfitting, stopping early
Training early exit layer 2
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 3.5997576635275315
Epoch 0 Accuracy 0.5137393390191898




Epoch 0 Validation Loss 4.248706677130291
Epoch 0 Validation Accuracy 0.4938785173160173
Beginning epoch 1


                                                                                                    

Epoch 1 Loss 3.2666477238954004
Epoch 1 Accuracy 0.5368736673773987




Epoch 1 Validation Loss 5.368245641390483
Epoch 1 Validation Accuracy 0.44341856060606066
Beginning epoch 2


                                                                                                    

Epoch 2 Loss 3.106726313704875
Epoch 2 Accuracy 0.5547841151385928




Epoch 2 Validation Loss 5.002213804494767
Epoch 2 Validation Accuracy 0.46171536796536794
Beginning epoch 3


                                                                                                    

Epoch 3 Loss 2.996122082667564
Epoch 3 Accuracy 0.5696828358208955




Epoch 3 Validation Loss 5.304633918262663
Epoch 3 Validation Accuracy 0.44439935064935066
Beginning epoch 4


                                                                                                    

Epoch 4 Loss 2.8924140654393096
Epoch 4 Accuracy 0.5853944562899787




Epoch 4 Validation Loss 6.120252140930721
Epoch 4 Validation Accuracy 0.4225175865800866
Beginning epoch 5


                                                                                                    

Epoch 5 Loss 2.897205625719099
Epoch 5 Accuracy 0.5943230277185501




Epoch 5 Validation Loss 5.182737001350948
Epoch 5 Validation Accuracy 0.4670589826839827
Beginning epoch 6


                                                                                                    

Epoch 6 Loss 2.8420655825244845
Epoch 6 Accuracy 0.6038512793176972




Epoch 6 Validation Loss 5.984544288544428
Epoch 6 Validation Accuracy 0.4310403138528139
Beginning epoch 7


                                                                                                    

Epoch 7 Loss 2.8458057768309293
Epoch 7 Accuracy 0.6094482942430705




Epoch 7 Validation Loss 5.053296821457999
Epoch 7 Validation Accuracy 0.4869791666666667
Beginning epoch 8


                                                                                                    

Epoch 8 Loss 2.8401395771040847
Epoch 8 Accuracy 0.6097414712153518




Epoch 8 Validation Loss 5.94535261676425
Epoch 8 Validation Accuracy 0.453125
Beginning epoch 9


                                                                                                    

Epoch 9 Loss 2.8168668202499845
Epoch 9 Accuracy 0.6234408315565032




Epoch 9 Validation Loss 5.309238408293043
Epoch 9 Validation Accuracy 0.4672957251082251
Beginning epoch 10


                                                                                                    

Epoch 10 Loss 2.7798048869887393
Epoch 10 Accuracy 0.6211087420042645




Epoch 10 Validation Loss 4.676582001504444
Epoch 10 Validation Accuracy 0.5104166666666666
Beginning epoch 11


                                                                                                    

Epoch 11 Loss 2.7307468547749876
Epoch 11 Accuracy 0.6305303837953092




Epoch 11 Validation Loss 4.9877222435815
Epoch 11 Validation Accuracy 0.479876893939394
Beginning epoch 12


                                                                                                    

Epoch 12 Loss 2.7370634575388326
Epoch 12 Accuracy 0.6303171641791044




Epoch 12 Validation Loss 4.891923343851452
Epoch 12 Validation Accuracy 0.5048363095238095
Beginning epoch 13


                                                                                                    

Epoch 13 Loss 2.712978790589233
Epoch 13 Accuracy 0.6394856076759062




Epoch 13 Validation Loss 5.534655500025976
Epoch 13 Validation Accuracy 0.4720982142857143
Beginning epoch 14


                                                                                                    

Epoch 14 Loss 2.700586729441116
Epoch 14 Accuracy 0.6435900852878466




Epoch 14 Validation Loss 5.856686137971424
Epoch 14 Validation Accuracy 0.46486066017316013
Model is overfitting, stopping early
Training early exit layer 3
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 2.7541374215439185
Epoch 0 Accuracy 0.6327691897654585




Epoch 0 Validation Loss 6.1579809699739725
Epoch 0 Validation Accuracy 0.46820887445887444
Beginning epoch 1


                                                                                                    

Epoch 1 Loss 2.73915738347751
Epoch 1 Accuracy 0.6363805970149253




Epoch 1 Validation Loss 6.583019253753481
Epoch 1 Validation Accuracy 0.43841314935064934
Beginning epoch 2


                                                                                                    

Epoch 2 Loss 2.681108968204527
Epoch 2 Accuracy 0.6468283582089552




Epoch 2 Validation Loss 7.356714663051424
Epoch 2 Validation Accuracy 0.4211647727272727
Beginning epoch 3


                                                                                                    

Epoch 3 Loss 2.712539380699841
Epoch 3 Accuracy 0.654117803837953




In [None]:
# NOTE: this cell is currently disabled — every line below is commented out.
# It times validation of a stock ResNet-50 against the early-exit model on
# `test_dataloader` and prints loss, accuracy and wall-clock time for both.

# import time

# # TODO: replace this with a transfer learned resnet model

# original_model = models.resnet50(pretrained=True)

# nonEarlyExitModel = EarlyExitModel(original_model, 1000, device)
# nonEarlyExitModel.to(device)
# nonModelTrainer = ModelTrainer(nonEarlyExitModel, device)

# # validate the model
# print("Validating original Resnet model")
# start = time.time()
# loss, acc, exits = nonModelTrainer.validate(test_dataloader)
# end = time.time()
# print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
# print(f"Validation time: {end - start}")
# print("=====================================================")

# # validate the new early exit model
# print("Validating new ResnetEE model")
# start = time.time()
# # `validate` is unpacked into three values for the baseline above, so `exits`
# # has to be captured here as well; otherwise the "Average exit index" line
# # below would report the baseline's value instead of this model's.
# loss, acc, exits = trainer.validate(test_dataloader)
# end = time.time()
# print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
# print(f"Average exit index: {exits}")
# print(f"Validation time: {end - start}")

Validating original Resnet model


KeyboardInterrupt: 