In [91]:
# %load_ext autoreload
# %autoreload 2

# %pip install -r requirements.txt

In [92]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets

In [93]:
# Report the installed PyTorch build and which accelerator backends it can reach,
# then choose the device string used by the rest of the notebook.
print(f"PyTorch version: {torch.__version__}")

# MPS (Metal Performance Shader) is Apple's GPU backend
mps_ready = torch.backends.mps.is_available()
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {mps_ready}")

# CUDA backend
cuda_ready = torch.cuda.is_available()
print(f"Is CUDA available? {cuda_ready}")

# Preference order: MPS first, then CUDA, otherwise fall back to the CPU
device = "mps" if mps_ready else "cuda" if cuda_ready else "cpu"

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [94]:
# Select the model variant to load: "resnet", "vgg_cifar10" or "vgg_cifar100".
model_name = "vgg_cifar100"

# Only the CIFAR-100 variant has 100 output classes; every other choice uses 10.
if model_name == "vgg_cifar100":
    num_classes = 100
else:
    num_classes = 10

# Per-model path under the local "models" directory.
model_path = os.path.join("models", model_name)

In [95]:
from DataLoader import CustomDataset

# Seed for the train/test split so the same partition is produced on every
# Restart & Run All (random_split draws a fresh permutation otherwise).
SPLIT_SEED = 0

# Pick the Hugging Face dataset and the preprocessing pipeline for the chosen model.
if model_name == "resnet":
    # use the imagenette dataset
    hf_dataset = load_dataset("frgfm/imagenette", '320px')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
elif model_name == "vgg_cifar10":
    # use the cifar10 dataset
    hf_dataset = load_dataset("cifar10")
    # NOTE(review): these mean/std values look like the commonly published
    # CIFAR-100 channel statistics rather than CIFAR-10's, while the CIFAR-100
    # branch below applies no normalisation at all. Confirm against how the
    # checkpoints were trained before changing either branch.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.507, 0.4865, 0.4409],
                             std=[0.2673, 0.2564, 0.2761])
    ])
elif model_name == "vgg_cifar100":
    # use the cifar100 dataset
    hf_dataset = load_dataset("cifar100")
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
else:
    # Fail with a clear message instead of a NameError on hf_dataset further down.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'resnet', 'vgg_cifar10' or 'vgg_cifar100'"
    )

# Merge every split returned by load_dataset into a single dataset; the 80/20
# split is redone below. list() hands concatenate_datasets a real list rather
# than a dict view.
hf_dataset = concatenate_datasets(list(hf_dataset.values()))

torch_dataset = CustomDataset(hf_dataset, transform=transform)

batch_size = 32 if model_name == "resnet" else 64

# Hold out 20% of the merged dataset for testing.
test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

train_dataset, test_dataset = random_split(
    torch_dataset,
    [train_volume, test_volume],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# NOTE(review): shuffle is disabled for the training loader as well as the test
# loader — confirm this is intentional if the loader is used for training.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

In [96]:
import ModelLoader

# Build the loader for the selected architecture on the chosen device; it is
# also handed the training dataloader.
# NOTE(review): alpha=0.6 is a hard-coded setting whose meaning is defined
# inside ModelLoader — consider lifting it into the config cell.
loader = ModelLoader.ModelLoader(
    model_name,
    device,
    dataloader=train_dataloader,
    alpha=0.6,
)

# Load the model with one output per class, requesting pretrained weights.
model = loader.load_model(pretrained=True, num_outputs=num_classes)

# Bare last expression so the notebook displays the model architecture.
model

Loading EarlyExit VGG11 model architecture...
Adding exits...


Using cache found in /Users/dylanmace/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


Setting model weights...


EarlyExitModel(
  (model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): OptionalExitModule(
        (module): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (exit_gate): Linear(in_features=8192, out_features=1, bias=True)
        (classifier): Linear(in_features=8192, out_features=100, bias=True)
      )
      (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

## Early Exit Model Visualization

In [113]:
import pandas
import numpy as np

# Human-readable label tables, indexed by class id. Later cells index
# class_names with model predictions, so it must never be empty; previously
# every model except "vgg_cifar100" got an empty list and those lookups raised
# IndexError.
if model_name == "vgg_cifar100":
    class_names = [
        "apple", "aquarium_fish", "baby", "bear", "beaver",
        "bed", "bee", "beetle", "bicycle", "bottle",
        "bowl", "boy", "bridge", "bus", "butterfly",
        "camel", "can", "castle", "caterpillar", "cattle",
        "chair", "chimpanzee", "clock", "cloud", "cockroach",
        "couch", "crab", "crocodile", "cup", "dinosaur",
        "dolphin", "elephant", "flatfish", "forest", "fox",
        "girl", "hamster", "house", "kangaroo", "keyboard",
        "lamp", "lawn_mower", "leopard", "lion", "lizard",
        "lobster", "man", "maple_tree", "motorcycle", "mountain",
        "mouse", "mushroom", "oak_tree", "orange", "orchid",
        "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
        "plain", "plate", "poppy", "porcupine", "possum",
        "rabbit", "raccoon", "ray", "road", "rocket",
        "rose", "sea", "seal", "shark", "shrew",
        "skunk", "skyscraper", "snail", "snake", "spider",
        "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
        "tank", "telephone", "television", "tiger", "tractor",
        "train", "trout", "tulip", "turtle", "wardrobe",
        "whale", "willow_tree", "wolf", "woman", "worm",
    ]
elif model_name == "vgg_cifar10":
    # Standard CIFAR-10 label order.
    class_names = [
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    ]
else:
    # No label table is recorded in this notebook for the remaining model, so
    # use neutral placeholders that keep downstream lookups valid.
    class_names = [f"class_{idx}" for idx in range(num_classes)]

# Index of the class inspected by the following cells.
if model_name == "vgg_cifar10":
    class_idx = 5
elif model_name == "vgg_cifar100":
    class_idx = class_names.index("wolf")
else:
    class_idx = 0
  
    

In [114]:
# Collect every (image, label) sample whose label equals class_idx.
# Samples are fetched by integer index, in dataset order, so the result keeps
# the ordering of torch_dataset.
all_samples = (torch_dataset[idx] for idx in range(len(torch_dataset)))
class_images = [(img, label) for img, label in all_samples if label == class_idx]



In [115]:
# Summarise how many samples of the target class were collected.
# class_names can be empty for models without a label table, so fall back to
# the raw class index instead of raising IndexError.
class_label = class_names[class_idx] if class_idx < len(class_names) else str(class_idx)
print(f"Found {len(class_images)} images of class {class_label}")

# Only inspect the first sample when there is one.
if class_images:
    print(f"Shape of first image: {class_images[0][0].shape}")
else:
    print("No images found for this class; nothing to inspect.")

Found 600 images of class wolf
Shape of first image: torch.Size([3, 32, 32])


In [116]:
# load each image throughout the model
num_exits_taken = {}
misclassified_images = {}
correct_images = {}



for i in range(len(class_images)):
    image, label = class_images[i]
    image = image.reshape(1, *image.shape).to(device)
    
    output_class = model(image)
    exit_idx_taken = torch.tensor(model.num_exits_per_module, device=device).argmax() + 1
    num_exits_taken[exit_idx_taken.item()] = num_exits_taken.get(exit_idx_taken.item(), 0) + 1
    
    # get class name for prediction from model
    class_name = class_names[output_class.argmax().item()]
    
    if output_class.argmax() != class_idx:
        misclassified_images[i] = (output_class.argmax().item(), label, exit_idx_taken.item())
    else:
        correct_images[i] = exit_idx_taken.item()
    
    if i % 100 == 0:
        print(f"Processed {i} images:", num_exits_taken)
        print(f"Misclassified images: {len(misclassified_images)}")
        print(f"Correct images: {len(correct_images)}")

Processed 0 images: {3: 1}
Misclassified images: 1
Correct images: 0
Processed 100 images: {3: 28, 2: 73}
Misclassified images: 72
Correct images: 29
Processed 200 images: {3: 55, 2: 146}
Misclassified images: 142
Correct images: 59
Processed 300 images: {3: 88, 2: 213}
Misclassified images: 215
Correct images: 86
Processed 400 images: {3: 117, 2: 284}
Misclassified images: 291
Correct images: 110
Processed 500 images: {3: 145, 2: 356}
Misclassified images: 371
Correct images: 130


In [117]:
# Write the misclassified and correctly classified images to
# data_vis/misclassified and data_vis/correct, replacing files from earlier runs.
misclassified_dir = os.path.join("data_vis", "misclassified")
correct_dir = os.path.join("data_vis", "correct")

for out_dir in (misclassified_dir, correct_dir):
    # creates data_vis itself as well; a no-op when the folder already exists
    os.makedirs(out_dir, exist_ok=True)

    # clear the folder; only files/links are removed so a stray subdirectory
    # does not make os.remove raise
    for filename in os.listdir(out_dir):
        stale_path = os.path.join(out_dir, filename)
        if os.path.isfile(stale_path) or os.path.islink(stale_path):
            os.remove(stale_path)

# save the misclassified images as "<image idx>-<predicted label>-idx<exit>.png"
for i, (yhat, y, exit_idx) in misclassified_images.items():
    image = class_images[i][0]
    # fall back to the raw class index when no label table is available
    misclass_label = class_names[yhat] if yhat < len(class_names) else str(yhat)
    torchvision.utils.save_image(image, os.path.join(misclassified_dir, f"{i}-{misclass_label}-idx{exit_idx}.png"))

# save the correct images as "<image idx>-idx<exit>.png"
for image_idx, exit_idx in correct_images.items():
    image = class_images[image_idx][0]
    torchvision.utils.save_image(image, os.path.join(correct_dir, f"{image_idx}-idx{exit_idx}.png"))


In [136]:
# Iterate through the entire training dataloader and accumulate, per position
# in model.num_exits_per_module, the counts the model reports after each batch.
# One extra trailing slot holds the samples of each batch not covered by those
# counts.
exit_idx_counts = {}

# Forward passes only; no gradients are needed for counting.
with torch.no_grad():
    for x, y in train_dataloader:
        y_hat = model(x.to(device))

        exit_counts = torch.tensor(model.num_exits_per_module, device=device)

        # Use the real size of this batch: the final batch can be shorter than
        # batch_size, which would otherwise inflate the trailing slot.
        remaining = x.shape[0] - exit_counts.sum()
        per_slot = torch.cat((exit_counts, remaining.reshape(1))).tolist()

        for slot, count in enumerate(per_slot):
            exit_idx_counts[slot] = exit_idx_counts.get(slot, 0) + count

print(exit_idx_counts)

{0: 0, 1: 31365, 2: 15860, 3: 0, 4: 775, 5: 0}
