In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
from sklearn.metrics import confusion_matrix
from tqdm.notebook import tqdm

# Threshold-based model cascades

#### The greedy heuristic for jumping between consecutive rep. sizes in MRL models is based on the model's prediction confidence. For each rep. size we first determine a threshold on the prediction confidence such that, if the confidence for an input image is below the threshold, we jump to the next (larger) model. We perform a grid search on the interval (0, 1) using a held-out set of 10000 images, and report results on the remaining 40000 images. Lastly, in the paper we report results averaged over 30 random seeds.

#### Note that, to obtain a reliable estimate of the thresholds, we use test-time augmentation; therefore, the inference scripts MUST be run with the `--tta` flag.

In [7]:
import collections

# Number of random grid-search / held-out splits that the cascade statistics are averaged over.
N_seeds = 30

In [None]:
# Ground-truth labels, saved as a torch tensor. Later cells index it by example
# position and compare it element-wise with the predicted class indices.
PATH_TO_GROUND_TRUTH = "gt_dataset=V1.pth"
# Softmax outputs of the model. Later cells index this tensor as [:, examples, :] and
# take the max over the last axis, i.e. it is presumably laid out as
# (rep. size, example, class) -- confirm against the inference script that wrote it.
# The file name indicates it was produced with test-time augmentation (tta=1).
PATH_TO_SOFTMAX_PROBABILITIES = "mrl=1_efficient=0_dataset=V1_tta=1_softmax.pth"

In [None]:
# Size of the full evaluation set and of the per-seed subset used for the threshold grid search.
N_IMAGES, N_GRID_SEARCH = 50000, 10000

# Both files are identical for every seed, so read them once instead of once per iteration.
# map_location="cpu" lets the notebook run on machines without the device the tensors were saved from.
gt_all = torch.load(PATH_TO_GROUND_TRUTH, map_location="cpu")
softmax_all = torch.load(PATH_TO_SOFTMAX_PROBABILITIES, map_location="cpu")

thrsh = np.linspace(0.1, 1, 100) # Grid-search candidates, shared by every seed and every pair of models.

input_idxes = []      # per seed: indices of the examples used for the grid search
thrsh_all_seeds = []  # per seed: list of the 8 selected thresholds (the greedy policy)
for l in range(N_seeds):
    # Seed the sampler with the loop index so the splits (and every number derived from them)
    # are reproducible under Restart & Run All.
    rng = np.random.default_rng(l)
    idx = rng.choice(N_IMAGES, N_GRID_SEARCH, replace=False) # Selecting the subset for grid search.
    input_idxes.append(idx) # Kept so that the complementary held-out set can be rebuilt later.
    greedy_thrsh = []
    gt = gt_all[idx]
    softmax = softmax_all[:, idx, :]
    # Per model and example: confidence = largest softmax value, prediction = its class index.
    confidence_, predictions_ = torch.max(softmax, dim=-1)
    n = len(gt)

    # 8 transitions between the 9 consecutive rep. sizes 8, 16, ..., 2048.
    for d1 in range(8):
        d2 = d1 + 1
        print("Searching thresholds between models of rep. size", 2**(d1+3), 2**(d2+3))
        confidence_d1, predictions_d1 = confidence_[d1], predictions_[d1]
        predictions_d2 = predictions_[d2]

        acc = []
        for t in thrsh:
            # Keep the smaller model's prediction where it is confident (> t); otherwise use the larger model's.
            preds = torch.where(confidence_d1 > t, predictions_d1, predictions_d2)
            acc.append(100*(((preds==gt).sum())/n).cpu().numpy()) # Accuracy (in %) of this two-model cascade

        acc = np.asarray(acc)
        # argmax returns the first maximum and thrsh is increasing, so this is the
        # smallest threshold that reaches the best accuracy.
        max_idx = acc.argmax()
        best_thrsh = thrsh[max_idx]
        print(f"Cascade Performance between dimension {2**(d1+3)} and {2**(d2+3)} is {acc[max_idx]} with threshold {best_thrsh}")
        greedy_thrsh.append(best_thrsh) # Saving the policy

    thrsh_all_seeds.append(greedy_thrsh) # One policy per random seed.

In [9]:
# For every seed, the held-out test set is the complement (within the 50000 examples)
# of the subset that was used for the threshold grid search.
sel = []
for search_idx in input_idxes:
    held_out_mask = np.ones(50000, dtype=bool)
    held_out_mask[search_idx] = False  # drop the examples already used to tune the thresholds
    sel.append(np.flatnonzero(held_out_mask))

### Evaluating Greedy Policy 

#### In the previous snippet we determined the greedy thresholds for switching between consecutive rep. sizes. We can now use them to navigate the cascade. While we could do this for all the models (that is, rep. sizes from 8 to 2048), we will also study how early stopping affects performance. This means we will set a cap on the maximum rep. size used in the cascade.

In [None]:
# Neither file depends on the seed or on the early-stopping size, so read them once
# instead of once per (early-stopping size, seed) combination.
gt_all = torch.load(PATH_TO_GROUND_TRUTH, map_location="cpu")
softmax_all = torch.load(PATH_TO_SOFTMAX_PROBABILITIES, map_location="cpu")
# The max over classes is computed per example, so it can be taken once on the full
# set and sliced per seed afterwards.
confidence_all, predictions_all = torch.max(softmax_all, dim=-1)

all_expected_dimensions=[[] for l in range(8)] # For every max rep. size for cascading, we store the expected dimension used for prediction to later average across random seeds.
all_accuracies=[[] for l in range(8)]

for early_stopping_dim in range(8): # Early stopping dimension
    print(early_stopping_dim+1)
    max_cascading = early_stopping_dim+1 # index of the largest model the cascade may fall through to
    for seed in range(N_seeds):
        held_out = sel[seed] # choosing the held out set
        gt = gt_all[held_out]
        confidence_, predictions_ = confidence_all[:, held_out], predictions_all[:, held_out]
        greedy_thrsh = thrsh_all_seeds[seed] # Greedy policy for that seed

        n_examples = predictions_.shape[-1]
        # model[i] is the index of the model used to predict example i. Start from the
        # fall-through case (no smaller model was confident -> use the largest allowed model) ...
        model = torch.full((n_examples,), max_cascading, dtype=torch.long)
        # ... then walk from the largest to the smallest candidate, so that the smallest
        # confident model is written last and wins. This matches stopping at the first
        # model whose confidence exceeds its threshold, without a per-example Python loop.
        for j in reversed(range(max_cascading)):
            model[confidence_[j] > greedy_thrsh[j]] = j

        chosen_predictions = predictions_[model, torch.arange(n_examples)]
        acc = (chosen_predictions == gt).sum() # number of correct predictions

        counter=collections.Counter(model.tolist()) # A counter over the models used for predicting
        probs = {2**(j+3): counter[j]/len(gt) for j in counter.keys()} # probability distribution, used to compute expected representation size for prediction

        expected_dim = sum(rep_size*prob for rep_size, prob in probs.items())

        # Append plain scalars so the stacked arrays are always (8, N_seeds); a trailing
        # squeeze() would also drop the seed axis when N_seeds == 1.
        all_expected_dimensions[early_stopping_dim].append(expected_dim) # Saving expected dimensionality for every seed and maximum cascade
        all_accuracies[early_stopping_dim].append((acc/len(gt)).item())  # Saving accuracy for every seed and maximum cascade

all_expected_dimensions = np.asarray(all_expected_dimensions)
all_accuracies = np.asarray(all_accuracies)


### Expected dimension statistics for different maximum cascade rep. size

In [14]:
# Statistics across seeds (last axis), one value per maximum cascade rep. size.
expected_dim_mean = all_expected_dimensions.mean(axis=-1)
expected_dim_std = all_expected_dimensions.std(axis=-1)
print(expected_dim_mean)
print(expected_dim_std)

[ 13.18502     18.02794     25.73146     35.28671333  45.66122
  62.05503333  87.92042    121.67231333]
[ 0.72321347  1.60812216  2.80969956  5.28050878  7.60811513 10.94328704
 20.18233146 32.94076334]


### Cascade accuracy statistics for different maximum cascade rep. size

In [15]:
# Mean and standard deviation of the cascade accuracy across seeds (last axis),
# one value per maximum cascade rep. size.
accuracy_mean = all_accuracies.mean(axis=-1)
accuracy_std = all_accuracies.std(axis=-1)
print(accuracy_mean)
print(accuracy_std)

[0.73790917 0.75242666 0.76048249 0.76273834 0.76418666 0.76514667
 0.76533666 0.76541667]
[0.00103216 0.00091551 0.00139763 0.00170652 0.00188398 0.0019996
 0.002036   0.0020418 ]


In [16]:
# Stack the per-seed threshold lists into one array: one row per seed, one column per
# pair of consecutive rep. sizes, so statistics across seeds can be taken along axis 0.
thrsh_all_seeds = np.asarray(thrsh_all_seeds) # threshold for different random seeds

### Greedy threshold statistics for each pair of consecutive rep. sizes

In [19]:
# Mean and standard deviation of the selected thresholds across seeds (axis 0),
# one value per pair of consecutive rep. sizes.
threshold_mean = thrsh_all_seeds.mean(axis=0)
threshold_std = thrsh_all_seeds.std(axis=0)
print(threshold_mean)
print(threshold_std)

[0.78636364 0.70939394 0.7930303  0.65151515 0.55333333 0.55636364
 0.52212121 0.44484848]
[0.07305796 0.11164888 0.07465931 0.1491451  0.05973926 0.06952812
 0.13982837 0.11729113]
