In [15]:
# %load_ext autoreload
# %autoreload 2

# %pip install -r requirements.txt
import warnings

# ignore FutureWarning
warnings.resetwarnings()
warnings.simplefilter(action='ignore', category=FutureWarning)

In [16]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets

In [17]:
print(f"PyTorch version: {torch.__version__}")

# MPS = Metal Performance Shaders, PyTorch's backend for Apple GPUs.
mps_built = torch.backends.mps.is_built()
mps_available = torch.backends.mps.is_available()
print(f"Is MPS (Metal Performance Shader) built? {mps_built}")
print(f"Is MPS available? {mps_available}")

cuda_available = torch.cuda.is_available()
print(f"Is CUDA available? {cuda_available}")

# Preference order: Apple GPU, then NVIDIA GPU, then CPU.
device = "mps" if mps_available else ("cuda" if cuda_available else "cpu")

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [18]:
# Model to evaluate: one of "resnet", "vgg_cifar10" or "vgg_cifar100".
model_name = "vgg_cifar10"

# CIFAR-100 models predict 100 classes; every other option here predicts 10.
num_classes = 10
if "cifar100" in model_name:
    num_classes = 100

model_path = os.path.join("models", model_name)

In [19]:
from DataLoader import *

def create_dataloader(test_batch_size=None):
    """Build the data loaders for the dataset matching the global `model_name`.

    The dataset is chosen from the model name: names containing "cifar100"
    use CIFAR-100, names containing "cifar10" use CIFAR-10, and anything else
    uses Imagenette. The result of `CustomDataLoader().get_dataset` is
    returned unchanged (callers in this notebook unpack it as a 3-tuple).
    """
    # "cifar100" must be tested first because it also contains "cifar10".
    if "cifar100" in model_name:
        dataset_name = "cifar100"
    elif "cifar10" in model_name:
        dataset_name = "cifar10"
    else:
        dataset_name = "imagenette"

    data_source = CustomDataLoader()
    return data_source.get_dataset(
        dataset_name, batch_size=None, test_batch_size=test_batch_size
    )

## Early Exit Model Visualization

In [20]:
import time
import numpy as np
from OptionalExitModule import TrainingState
import datetime

def _synchronize_device(dev):
    """Wait for all queued work on `dev` to finish (no-op on CPU).

    CUDA and MPS execute kernels asynchronously, so a wall-clock timer that is
    stopped without synchronising may not cover the work it is meant to time.
    """
    dev_type = torch.device(dev).type
    if dev_type == "cuda":
        torch.cuda.synchronize()
    elif dev_type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def collect_val_data(num_batches=50, num_runs=4, should_print=True,
                     eval_model=None, dataloader=None, eval_device=None):
    """Benchmark forward-pass latency, accuracy and exit usage on a loader.

    When ``should_print`` is True, one silent warm-up pass is run first and
    its results are discarded. Then ``num_runs`` timed passes follow.

    Args:
        num_batches: stop each pass after this many batches; ``None`` = all.
        num_runs: number of timed passes.
        should_print: print per-run and final results (and do the warm-up).
        eval_model: model to evaluate; defaults to the notebook global ``model``.
        dataloader: iterable of ``(x, y)`` batches; defaults to the global
            ``test_dataloader``.
        eval_device: device inputs are moved to; defaults to the global ``device``.

    Returns:
        ``(exit_idx_counts, average_times, average_accs)``:
        ``exit_idx_counts`` maps the 1-based position in
        ``num_exits_per_module`` to the sum of that entry over all batches of
        the LAST run. ``average_times`` holds one mean per-batch forward latency
        (seconds) per run. ``average_accs`` holds one sample-weighted accuracy
        per run.
    """
    # Fall back to the notebook globals so existing call sites keep working.
    eval_model = model if eval_model is None else eval_model
    dataloader = test_dataloader if dataloader is None else dataloader
    eval_device = device if eval_device is None else eval_device

    eval_model.eval()

    # Warm-up pass (first-call overheads); its numbers are thrown away.
    if should_print:
        print("Throwing away the first run...")
        collect_val_data(num_batches=num_batches, num_runs=1, should_print=False,
                         eval_model=eval_model, dataloader=dataloader,
                         eval_device=eval_device)
        print(f"Collecting data for {num_batches if num_batches is not None else 'all'} batches {num_runs} times...")

    average_times = []
    average_accs = []
    exit_idx_counts = {}  # defined up front so num_runs == 0 cannot raise NameError
    with torch.no_grad():  # inference only: no autograd graphs
        for _ in range(num_runs):
            exit_idx_counts = {}
            times = []
            adj_times = []
            correct = 0
            seen = 0
            for i, (x, y) in enumerate(dataloader):
                x = x.to(eval_device)

                # Time only the forward pass: drain pending device work before
                # starting the clock and wait for the forward before stopping it.
                _synchronize_device(eval_device)
                start = time.perf_counter()
                y_hat = eval_model(x)
                _synchronize_device(eval_device)
                times.append(time.perf_counter() - start)
                # Latency minus the exit modules' `gate_time` (presumably time
                # spent evaluating the exit gates — confirm in OptionalExitModule).
                adj_times.append(times[-1] - sum(exit_module.gate_time for exit_module in eval_model.exit_modules))

                # Count correct samples rather than averaging per-batch
                # accuracies, so a short final batch is not over-weighted.
                correct += (torch.argmax(y_hat, dim=1) == y.to(eval_device)).sum().item()
                seen += len(y)

                exit_counts = torch.tensor(eval_model.num_exits_per_module, device=eval_device).tolist()
                for j, exit_count in enumerate(exit_counts, start=1):
                    exit_idx_counts[j] = exit_idx_counts.get(j, 0) + exit_count

                if i + 1 == num_batches:
                    break

            average_times.append(np.mean(times) if times else float("nan"))
            average_accs.append(correct / seen if seen else float("nan"))

            if not should_print:
                continue
            print(datetime.datetime.now())
            print("TIME", average_times[-1])
            print("ACCURACY", average_accs[-1])
            if adj_times and adj_times[-1] > 0:
                print(adj_times[-1])

    if should_print:
        print("=== FINAL RESULTS ===")
        print(exit_idx_counts)
        print(f"Average time: {np.array(average_times).mean()}")
        print(f"Average accuracy: {np.array(average_accs).mean()}")
    return exit_idx_counts, average_times, average_accs

In [21]:
import ModelLoader

# Loaders first: the training loader is handed to the ModelLoader and the
# test loader is the global that collect_val_data iterates over.
train_dataloader, test_dataloader, _ = create_dataloader(test_batch_size=None)
loader = ModelLoader.ModelLoader(model_name, device, alpha=0.5, dataloader=train_dataloader)

# Pretrained early-exit model. It is stored in the global `model`, which
# collect_val_data evaluates. The architecture dump in the output appears to
# come from load_model itself.
model = loader.load_model(num_outputs=num_classes, pretrained=True)

divider = "\n=====================================================\n"
print(divider)
with torch.no_grad():
    ee_idxs, ee_times, ee_accs = collect_val_data(num_runs=5, num_batches=None)

Loading EarlyExit VGG11 model architecture...
Adding exits...


  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)


Setting model weights...
EarlyExitModel(
  (model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): OptionalExitModule(
        (module): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (exit_gate): Linear(in_features=8192, out_features=1, bias=True)
        (classifier): Linear(in_features=8192, out_features=10, bias=True)
      )
      (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_r

  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)


In [None]:
# Baseline: reload the same pretrained model, strip all early exits, and run
# the identical benchmark so the numbers are comparable with the run above.
train_dataloader, test_dataloader, _ = create_dataloader(test_batch_size=None)

# NOTE(review): alpha is 0.75 here but 0.5 for the early-exit run; confirm
# whether alpha still affects anything once the exits are cleared.
loader = ModelLoader.ModelLoader(model_name, device, alpha=0.75, dataloader=train_dataloader)

model = loader.load_model(num_outputs=num_classes, pretrained=True)
model.clear_exits()  # turn the early-exit model back into a plain network

divider = "\n=====================================================\n"
print(divider)
with torch.no_grad():
    orig_idxs, orig_times, orig_accs = collect_val_data(num_runs=5, num_batches=None)

In [None]:
print("\n=====================================================\n")

# Positive drop = the early-exit model is less accurate than the baseline.
baseline_acc = np.array(orig_accs).mean()
early_exit_acc = np.array(ee_accs).mean()
print(f"Accuracy Drop: {baseline_acc - early_exit_acc}")

# > 1 means the early-exit model is faster than the baseline.
baseline_time = np.array(orig_times).mean()
early_exit_time = np.array(ee_times).mean()
print(f"Speedup Factor: {baseline_time / early_exit_time}x")

print(f"Exit Index Distribution: {ee_idxs}")

In [None]:
# Deliberate stop: "Run All" halts here, so the per-image analysis cells below
# only run when executed manually. Those cells use whichever `model` is
# currently loaded.
raise Exception("Done")

In [None]:
import pandas
import numpy as np
  
# Human-readable label names, indexed by class id, for the dataset the
# selected model uses. Every branch yields a full table so later lookups like
# `class_names[class_idx]` cannot raise IndexError (previously CIFAR-10 got []).
if "cifar100" in model_name:  # must be tested before "cifar10"
    class_names = [
        "apple", "aquarium_fish", "baby", "bear", "beaver",
        "bed", "bee", "beetle", "bicycle", "bottle",
        "bowl", "boy", "bridge", "bus", "butterfly",
        "camel", "can", "castle", "caterpillar", "cattle",
        "chair", "chimpanzee", "clock", "cloud", "cockroach",
        "couch", "crab", "crocodile", "cup", "dinosaur",
        "dolphin", "elephant", "flatfish", "forest", "fox",
        "girl", "hamster", "house", "kangaroo", "keyboard",
        "lamp", "lawn_mower", "leopard", "lion", "lizard",
        "lobster", "man", "maple_tree", "motorcycle", "mountain",
        "mouse", "mushroom", "oak_tree", "orange", "orchid",
        "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
        "plain", "plate", "poppy", "porcupine", "possum",
        "rabbit", "raccoon", "ray", "road", "rocket",
        "rose", "sea", "seal", "shark", "shrew",
        "skunk", "skyscraper", "snail", "snake", "spider",
        "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
        "tank", "telephone", "television", "tiger", "tractor",
        "train", "trout", "tulip", "turtle", "wardrobe",
        "whale", "willow_tree", "wolf", "woman", "worm",
    ]
elif "cifar10" in model_name:
    # Standard CIFAR-10 label order (as in torchvision / HF `cifar10`).
    class_names = [
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    ]
else:
    # No name table for this dataset here: use the numeric ids as names.
    class_names = [str(i) for i in range(num_classes)]

assert len(class_names) == num_classes, "class_names does not match num_classes"

# Class whose test images are analysed below.
if model_name == "vgg_cifar10":
    class_idx = 5
elif model_name == "vgg_cifar100" or model_name == "densenet_cifar100":
    class_idx = class_names.index("apple")
else:
    class_idx = 0
  
    

In [None]:
# Keep every (image, label) sample of the target class from the dataset that
# create_dataloader returns as its third element.
_, _, torch_dataset = create_dataloader()

class_images = [
    (sample_img, sample_label)
    for sample_img, sample_label in (torch_dataset[k] for k in range(len(torch_dataset)))
    if sample_label == class_idx
]



In [None]:
target_class_name = class_names[class_idx]
print(f"Found {len(class_images)} images of class {target_class_name}")
first_image = class_images[0][0]
print(f"Shape of first image: {first_image.shape}")

In [None]:
# Run every image of the target class through the model one at a time,
# recording which exit it took and whether it was classified correctly.
num_exits_taken = {}       # exit index (1-based) -> number of images that took it
misclassified_images = {}  # position in class_images -> (predicted name, label, exit index)
correct_images = {}        # position in class_images -> exit index

# NOTE(review): this uses whichever `model` was loaded last. If the baseline
# cell (which calls `model.clear_exits()`) ran, reload the early-exit model
# first or the exit statistics are meaningless.
model.eval()  # inference mode for normalisation/dropout layers
with torch.no_grad():  # inference only: don't build an autograd graph per image
    for i, (image, label) in enumerate(class_images):
        image = image.reshape(1, *image.shape).to(device)  # batch of one

        output_class = model(image)
        # Assumes num_exits_per_module holds the counts for the latest forward
        # pass, so with a batch of one its argmax is this image's exit — TODO confirm.
        exit_idx_taken = torch.tensor(model.num_exits_per_module, device=device).argmax() + 1
        exit_taken = exit_idx_taken.item()
        num_exits_taken[exit_taken] = num_exits_taken.get(exit_taken, 0) + 1

        # Predicted class id and its human-readable name.
        class_idx_prediction = torch.argmax(output_class, dim=1).item()
        class_name = class_names[class_idx_prediction]

        if class_idx_prediction != class_idx:
            misclassified_images[i] = (class_name, label, exit_taken)
        else:
            correct_images[i] = exit_taken

        if i % 100 == 0:
            # i is 0-based, so i + 1 images have been processed at this point.
            print(f"Processed {i + 1} images:", num_exits_taken)
            print(f"Misclassified images: {len(misclassified_images)}")
            print(f"Correct images: {len(correct_images)}")



In [None]:
# Count, separately for wrong and right predictions, how many images left the
# network at each exit index (keys appear in order of first occurrence).
exit_idxs_incorrect = {}
for _pred_name, _true_label, exit_taken in misclassified_images.values():
    exit_idxs_incorrect[exit_taken] = 1 + exit_idxs_incorrect.get(exit_taken, 0)

exit_idxs_correct = {}
for exit_taken in correct_images.values():
    exit_idxs_correct[exit_taken] = 1 + exit_idxs_correct.get(exit_taken, 0)

print("Correct Indices", exit_idxs_correct)
print("Incorrect Indices", exit_idxs_incorrect)

In [None]:
vis_root = "data_vis"
misclassified_dir = os.path.join(vis_root, "misclassified")
correct_dir = os.path.join(vis_root, "correct")

# Create the output folders (parent first) if they do not exist yet.
for folder in (vis_root, misclassified_dir, correct_dir):
    if not os.path.exists(folder):
        os.mkdir(folder)

# Empty both folders so images from a previous run don't linger.
for folder in (misclassified_dir, correct_dir):
    for stale_file in os.listdir(folder):
        os.remove(os.path.join(folder, stale_file))

# NOTE(review): images are saved exactly as the dataset returns them; if the
# dataset applies normalisation they will look off — confirm.
# Misclassified files are named "<index>-<predicted class>-idx<exit>.png".
for img_pos, (predicted_name, _true_label, exit_idx) in misclassified_images.items():
    torchvision.utils.save_image(
        class_images[img_pos][0],
        os.path.join(misclassified_dir, f"{img_pos}-{predicted_name}-idx{exit_idx}.png"),
    )

# Correct files are named "<index>-idx<exit>.png".
for img_pos, exit_idx in correct_images.items():
    torchvision.utils.save_image(
        class_images[img_pos][0],
        os.path.join(correct_dir, f"{img_pos}-idx{exit_idx}.png"),
    )
