In [None]:
# Enable autoreload so edits to the project's local modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Install the notebook's dependencies; %pip targets the kernel's own environment.
%pip install -r requirements.txt

In [8]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Note: you may need to restart the kernel to use updated packages.


In [9]:
# Report the installed PyTorch build and which accelerator backends it can see.
print(f"PyTorch version: {torch.__version__}")

# MPS (Metal Performance Shaders) is the GPU backend on Apple hardware.
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# CUDA is the GPU backend on NVIDIA hardware.
print(f"Is CUDA available? {torch.cuda.is_available()}")

# Choose the device in priority order: MPS first, then CUDA, with CPU as the fallback.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [10]:
# Start from an ImageNet-pretrained ResNet-50. The `weights=` enum replaces the
# `pretrained=True` flag, which torchvision deprecated in 0.13 (it now emits a
# warning); IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer with a fresh 10-output head
# (presumably one output per class of the dataset loaded further down).
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, 10)

# Backbone freezing is currently disabled. Note that if re-enabled at this point
# it would also freeze the new `fc` head, since that is already part of
# `resnet50.parameters()`.
# # freeze all layers
# for param in resnet50.parameters():
#     param.requires_grad = False

In [11]:
from EarlyExitModel import EarlyExitModel

# Wrap the ResNet-50 in the project's EarlyExitModel.
# The 10 is presumably the class count (it matches the 10-output `fc` head on
# `resnet50`) — confirm against EarlyExitModel's signature.
model = EarlyExitModel(resnet50, 10, device)
# Bare expression so the notebook displays the wrapped module's structure.
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
     

In [12]:
# Reset the model's exits, then register one at each of the first three ResNet stages.
model.clear_exits()

# Keep whatever `add_exit` returns for each stage, in stage order.
exit_layers = list(map(model.add_exit, ('layer1', 'layer2', 'layer3')))

# Move the model (including the newly added exits) to the selected device.
model.to(device)

# Bare expression so the notebook displays the updated module structure.
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): OptionalExitModule(
      (module): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(i

In [13]:
from DataLoader import CustomDataset

# Load Imagenette (320px variant) and merge every published split into a single
# dataset; the train/test partition is drawn by this notebook below.
# `concatenate_datasets` is documented to take a list, so materialize the view.
hf_splits = load_dataset("frgfm/imagenette", '320px')
hf_dataset = concatenate_datasets(list(hf_splits.values()))

# Resize to the 224x224 input size used here and convert to a float tensor.
# NOTE(review): ImageNet mean/std normalization is left disabled as in the
# original; pretrained torchvision weights are normally used with it — confirm
# this is intentional.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


torch_dataset = CustomDataset(hf_dataset, transform=transform)

batch_size = 32

# Hold out 20% of the samples for evaluation.
test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

# Seed the split so the same samples land in train/test on every kernel restart;
# otherwise results are not comparable between runs.
split_seed = 42
train_dataset, test_dataset = random_split(
    torch_dataset,
    [train_volume, test_volume],
    generator=torch.Generator().manual_seed(split_seed),
)

# Reshuffle the training samples on every pass over the loader; with
# shuffle=False each training stage would see identical batches in identical order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4
)

# Evaluation order does not matter, so keep it fixed.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

## Early Exit Model Training

In [14]:
from EarlyExitTrainer import ModelTrainer

# Build the project's trainer around the early-exit model and the selected device.
trainer = ModelTrainer(model, device)

# Train with one epoch per stage, validating on the held-out loader.
# Per the recorded output, the stages are: each early-exit layer in turn, then
# the final classifier head, then a pass with no forced exits.
# NOTE(review): in the recorded output the last stage ("no forced exits") reports
# a loss in the hundreds and ~0.11 accuracy, far worse than the preceding
# stages — worth investigating before trusting downstream comparisons.
trainer.train(train_dataloader, epoch_count=1, validation_loader=test_dataloader)


Training early exit layer 1
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 27.511475364485783
Epoch 0 Accuracy 0.1990405117270789




Epoch 0 Validation Loss 34.58641808373587
Epoch 0 Validation Accuracy 0.27702245670995673
Training early exit layer 2
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 12.458656603898575
Epoch 0 Accuracy 0.36239339019189765




Epoch 0 Validation Loss 12.599169223081498
Epoch 0 Validation Accuracy 0.3824742965367965
Training early exit layer 3
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 4.80787136305624
Epoch 0 Accuracy 0.4085554371002132




Epoch 0 Validation Loss 8.705436788854145
Epoch 0 Validation Accuracy 0.41352137445887444
Beginning epoch 0 on final classifier head


                                                                                                    

Epoch 0 Loss 1.681870292193854
Epoch 0 Accuracy 0.43788646055437097




Epoch 0 Validation Loss 0.6867687815711612
Epoch 0 Validation Accuracy 0.5429518398268398
Beginning epoch 0 with no forced exits


                                                                                                    

Epoch 0 Loss 600.6875343038075
Epoch 0 Accuracy 0.11631130063965885




Epoch 0 Validation Loss 1012.1820882161459
Epoch 0 Validation Accuracy 0.10545183982683982


In [15]:
import time

# TODO: replace this with a transfer learned resnet model
# NOTE(review): this baseline keeps ResNet-50's stock 1000-output head while the
# early-exit model was built with 10 outputs, so the two accuracy figures are
# presumably not directly comparable until the TODO above is resolved.

# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# checkpoint that flag used to load.
original_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Wrap the baseline in EarlyExitModel without registering any exits so it can be
# driven by the same ModelTrainer.validate() as the early-exit model.
nonEarlyExitModel = EarlyExitModel(original_model, 1000, device)
nonEarlyExitModel.to(device)
nonModelTrainer = ModelTrainer(nonEarlyExitModel, device)

# validate the model
print("Validating original Resnet model")
# perf_counter is monotonic and intended for measuring elapsed intervals.
start = time.perf_counter()
loss, acc, exits = nonModelTrainer.validate(test_dataloader)
end = time.perf_counter()
print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
print(f"Validation time: {end - start}")
print("=====================================================")

# validate the new early exit model
print("Validating new ResnetEE model")
start = time.perf_counter()
# validate() yields three values (see the baseline call above). Unpacking only
# two raised ValueError, and the exit statistic printed below was the stale
# value left over from the baseline model rather than the early-exit model's.
loss, acc, exits = trainer.validate(test_dataloader)
end = time.perf_counter()
print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
print(f"Average exit index: {exits}")
print(f"Validation time: {end - start}")

Validating original Resnet model


KeyboardInterrupt: 