# Putting Example Difficulty in Practice with Early Exit

In the last notebook we explored the concept of example difficulty and prediction depth. How can we use this knowledge to our advantage? What kind of benefits can we see by leveraging our understanding of example difficulty. 

In many growing domains of machine learning, we often will have systems constraints like inference latency, energy usage, etc. 

As we noticed in the last homework, some examples actually don't need to be passed through the entire network. This concept was formalized as prediction depth.

So why do we need to pass the inputs through the entire network if they can be predicted correctly by an earlier layer?

Are there tradeoffs associated with not going through the entire network? 

Much of this homework was inspired by the following paper:

https://arxiv.org/abs/1709.01686

## Connecting to Example Difficulty

In the last homework, we used KNN classifiers to determine prediction depth and visualize patterns that related image size and shape to the prediction difficulty. However, we don't actually need these KNN classifiers. What if we just replaced the KNN classifers with output heads and sent gradients through? This is the main idea behind BranchyNet.

We seek to improve inference speed by simply exiting the network when a prediction is made with reasonably high confidence.

## Concepts of BranchyNet

We will now have N exits, as we did with the KNN classifiers. However, now each exit will contribute to the loss in the following manner

$L_{\text{early exit}}(\hat{y}, y; \theta) = \sum_{i=1}^N w_i L(\hat{y_{\text{exit}}}, y; \theta)$

We will set the total loss of the network be a weighted sum of the standard cross entropy losses at each exit


# Baseline ResNet-18

To properly see the effects of Early Exit, let's set up a resnet without early exit as a baseline for computation and inference speed.

In [5]:
import torch
import torch.nn as nn
import copy
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
import sklearn
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import entropy
import torchprofile
from tqdm import tqdm
from architectures import EarlyExitResNet18
from copy import deepcopy

In [2]:
batch_size = 256

## MACS

MACS stands for multiply and accumulate - In the hardware, this corresponds to multiplying, then adding a number to an accumulator. We use MACS as a measurement of the amount of computation used.

In [3]:
def get_model_macs(model, inputs) -> int:
    return torchprofile.profile_macs(model, inputs)

Let's set up our dataloader in the standard fashion. Similar to last homework, please download the data and put it in the same folder as this

In [13]:
data = np.load('data.npy', allow_pickle=True).item()
x_tensor = torch.FloatTensor(data['x'])
y_tensor = torch.LongTensor(data['y'])
dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor)

test_data = np.load('test_data.npy', allow_pickle=True).item()
x_tensor = torch.FloatTensor(test_data['x'])
y_tensor = torch.LongTensor(test_data['y'])
test_dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


transform = transforms.Compose([
    transforms.ToPILImage(),                       # Convert arrays to PIL images
    transforms.Grayscale(num_output_channels=3),   # Convert grayscale to RGB
    transforms.Resize((224, 224)),                 # Resize all images to 224x224
    transforms.ToTensor(),                      # Convert the images to PyTorch tensors
])


resnet_dataset = deepcopy(dataset)
resnet_test_dataset = deepcopy(test_dataset)

resnet_dataset.transform = transform
resnet_trainloader = torch.utils.data.DataLoader(resnet_dataset, batch_size=batch_size, num_workers=2, shuffle=True)

resnet_test_dataset.transform = transform
resnet_testloader = torch.utils.data.DataLoader(resnet_test_dataset, batch_size=128, num_workers=2, shuffle=False)

In [9]:
resnet = EarlyExitResNet18()
resnet = resnet.to(device)

In [10]:
criterion = nn.CrossEntropyLoss()
resnet_optimizer = optim.Adam(resnet.parameters(), lr=0.0001)

# Training

Let's train a standard ResNet. Don't forget to pass in an entropy tolerance and set early_exit to False

In [12]:
step = 0
resnet_losses = []
for epoch in tqdm(range(10)):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(resnet_trainloader, 0):
        step += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = inputs, labels = data[0].to(device), data[1].to(device)
        
        inputs = inputs.unsqueeze(1)
        inputs = inputs.repeat(1, 3, 1, 1)
        inputs = inputs.to(device)

        # zero the parameter gradients
        resnet_optimizer.zero_grad()
        

        # forward + backward + optimize
        # TODO
        outputs = ... 
        loss = ...
        loss.backward()
        resnet_optimizer.step()
        resnet_losses.append(loss.item())
        # print statistics
        running_loss += loss.item()
        if i % 50 == 49:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 50:.3f}')
            running_loss = 0.0
        
        

print('Finished Training')

  0%|          | 0/10 [00:00<?, ?it/s]

[1, 50] loss: 1.128
[1, 100] loss: 0.377


 10%|█         | 1/10 [00:08<01:15,  8.39s/it]

[2, 50] loss: 0.153
[2, 100] loss: 0.126


 20%|██        | 2/10 [00:16<01:08,  8.50s/it]

[3, 50] loss: 0.071
[3, 100] loss: 0.070


 30%|███       | 3/10 [00:25<00:59,  8.54s/it]

[4, 50] loss: 0.036
[4, 100] loss: 0.032


 40%|████      | 4/10 [00:34<00:51,  8.55s/it]

[5, 50] loss: 0.023
[5, 100] loss: 0.021


 50%|█████     | 5/10 [00:42<00:42,  8.54s/it]

[6, 50] loss: 0.014
[6, 100] loss: 0.014


 60%|██████    | 6/10 [00:51<00:34,  8.52s/it]

[7, 50] loss: 0.022
[7, 100] loss: 0.015


 70%|███████   | 7/10 [00:59<00:25,  8.50s/it]

[8, 50] loss: 0.007
[8, 100] loss: 0.009


 80%|████████  | 8/10 [01:08<00:16,  8.49s/it]

[9, 50] loss: 0.009
[9, 100] loss: 0.006


 90%|█████████ | 9/10 [01:16<00:08,  8.46s/it]

[10, 50] loss: 0.007
[10, 100] loss: 0.011


100%|██████████| 10/10 [01:24<00:00,  8.49s/it]

Finished Training





### Fill in the code below to evaluate the ResNet

In [15]:
resnet.eval()
total_macs = 0

for epoch in range(1):  # loop over the dataset multiple times

    total_correct = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(resnet_testloader, 0)):
            # get the inputs; data is a list of [inputs, labels]
            
            inputs, labels = inputs, labels = data[0].to(device), data[1].to(device)
            
            inputs = inputs.unsqueeze(1)
            inputs = inputs.repeat(1, 3, 1, 1)
            inputs = inputs.to(device)
            # forward + backward + optimize
            # TODO
            outputs = ...
            total_macs += get_model_macs(resnet, (inputs, 0, False))
            
            
            indices = torch.argmax(outputs, dim=1)
            
            total_correct += torch.sum(labels == indices)


print(f'Accuracy: {total_correct/6000} %')
print('Total MACS: ', total_macs)

47it [00:02, 19.61it/s]

Accuracy: 0.9078333377838135 %
Total MACS:  3336213504000





What was the accuracy with a regular ResNet-18?

Inference Speed?

Total MACS?

In [16]:
resnet_early = EarlyExitResNet18()
resnet_early = resnet.to(device)
criterion = nn.CrossEntropyLoss()
resnet_optimizer = optim.Adam(resnet.parameters(), lr=0.0001)

### Fill in the code below to train the early exit network

In [17]:
step = 0
resnet_losses = []
for epoch in tqdm(range(10)):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(resnet_trainloader, 0):
        step += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = inputs, labels = data[0].to(device), data[1].to(device)
        
        inputs = inputs.unsqueeze(1)
        inputs = inputs.repeat(1, 3, 1, 1)
        inputs = inputs.to(device)

        # zero the parameter gradients
        resnet_optimizer.zero_grad()
        

        # forward + backward + optimize
        # TODO Use w_0 = 1 and w_i = 0.3 for i > 0
        outputs = ...
        loss = ...
        loss.backward()
        resnet_optimizer.step()
        resnet_losses.append(loss.item())
        # print statistics
        running_loss += loss.item()
        if i % 50 == 49:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 50:.3f}')
            running_loss = 0.0
        
        

print('Finished Training')

  0%|          | 0/10 [00:00<?, ?it/s]

[1, 50] loss: 2.892
[1, 100] loss: 1.327


 10%|█         | 1/10 [00:08<01:14,  8.25s/it]

[2, 50] loss: 0.741
[2, 100] loss: 0.696


 20%|██        | 2/10 [00:16<01:07,  8.42s/it]

[3, 50] loss: 0.506
[3, 100] loss: 0.471


 30%|███       | 3/10 [00:25<00:59,  8.49s/it]

[4, 50] loss: 0.357
[4, 100] loss: 0.366


 40%|████      | 4/10 [00:33<00:51,  8.52s/it]

[5, 50] loss: 0.269
[5, 100] loss: 0.269


 50%|█████     | 5/10 [00:42<00:42,  8.55s/it]

[6, 50] loss: 0.209
[6, 100] loss: 0.212


 60%|██████    | 6/10 [00:51<00:34,  8.56s/it]

[7, 50] loss: 0.155
[7, 100] loss: 0.165


 70%|███████   | 7/10 [00:59<00:25,  8.56s/it]

[8, 50] loss: 0.127
[8, 100] loss: 0.143


 80%|████████  | 8/10 [01:08<00:17,  8.57s/it]

[9, 50] loss: 0.092
[9, 100] loss: 0.087


 90%|█████████ | 9/10 [01:16<00:08,  8.57s/it]

[10, 50] loss: 0.056
[10, 100] loss: 0.062


100%|██████████| 10/10 [01:25<00:00,  8.54s/it]

Finished Training





### Fill in the code below to evaluate the early exit network

In [22]:
resnet_early.eval()
exiting_points = {0:0, 1:0, 2:0, 3:0}
entropies = []
total_early_macs = 0

for epoch in range(1):  # loop over the dataset multiple times

    total_early_correct = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(resnet_testloader, 0)):
            # get the inputs; data is a list of [inputs, labels]
            
            inputs, labels = inputs, labels = data[0].to(device), data[1].to(device)
            
            inputs = inputs.unsqueeze(1)
            inputs = inputs.repeat(1, 3, 1, 1)
            inputs = inputs.to(device)
            # forward + backward + optimize
            # TODO Use an entropy tolerance of 0.05
            outputs, num, curr_entropy = ...
            entropies.append(curr_entropy)
            total_early_macs += get_model_macs(resnet_early, (inputs, 0.05))
            
            
            exiting_points[num] += 1
            
            
            indices = torch.argmax(outputs, dim=1)
            
            total_early_correct += torch.sum(labels == indices)


print(f'Accuracy: {total_correct/6000} %')
print('Num Exiting: ', exiting_points)
print('Total MACS: ', total_early_macs)
entropies = sorted(entropies)
print(len(entropies))
print('Entropies: ', entropies[0], entropies[len(entropies)//4], entropies[len(entropies)//2], entropies[3*len(entropies)//4], entropies[-1])

47it [00:01, 23.53it/s]

Accuracy: 0.9078333377838135 %
Num Exiting:  {0: 17, 1: 16, 2: 8, 3: 6}
Total MACS:  2305700339712
47
Entropies:  0.004162853 0.019386116 0.06340146 0.08171481 0.098951355





What was the accuracy with an early exit ResNet-18?

Inference Speed?

Total MACS?

In [23]:
print(f'Ratio of Standard to Early Exit MACS: {total_macs/total_early_macs}')

1.446941498224665

## Entropy

Play around with the entropy tolerance to see how low you can get the MACS while keeping 90 percent or greater accuracy.

How did early exit do? Compare accuracy and MACS.

## Open Question

No solutions will be provided for this question:

When would we use early exit, versus just using a smaller model? What factors should we consider? 

How does early exit relate to example difficulty?
