In [None]:
# Reload edited project modules (e.g. main.utils) automatically before each cell runs
%load_ext autoreload
%autoreload 2
# Move two directories up, presumably to the repository root: every path used
# below ('trained-models/...', 'main/figures/...') is relative to it
%cd ../../

In [None]:

import numpy as np
import json
from sklearn.metrics import confusion_matrix
from main.utils.get_data import get_dataset

from main.utils.analysis_utils import plot_macs_vs_acc,entropy

import matplotlib.pyplot as plt

# Global figure styling for every plot in this notebook.
# NOTE: the former plt.style.use('seaborn-colorblind') was dropped: that style
# was renamed 'seaborn-v0_8-colorblind' in Matplotlib 3.6 and removed in 3.8
# (raising OSError), and every setting it applied (colour cycle, patch face
# colour) is overridden by 'ggplot' below anyway.

height = 16
width = height*1.6
plt.style.use('ggplot')
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['axes.edgecolor'] = 'black'
plt.rcParams["figure.figsize"] = (width,height)
plt.rcParams['xtick.color'] = 'black'
plt.rcParams['xtick.major.width'] = 1.6
plt.rcParams['ytick.color'] = 'black'
plt.rcParams['axes.labelcolor'] = "black"
plt.rcParams['axes.linewidth'] = 1.6
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['font.size'] = 25

def get_model_params(directory):
    """Load ``model_params.json`` from ``directory``.

    A missing ``'loss'`` entry is filled in as ``'cross-entropy'``.

    Args:
        directory: folder containing ``model_params.json``.

    Returns:
        dict of the parsed parameters.
    """
    # `with` guarantees the file handle is closed even if parsing fails
    with open(directory + '/model_params.json') as model_param_file:
        model_params = json.load(model_param_file)
    model_params.setdefault('loss','cross-entropy')
    return model_params
def get_label(labels,label):
    """Return a label that is unique within ``labels`` and record it.

    If ``label`` is already taken, the suffixes ``_2``, ``_3``, ... are tried
    on the ORIGINAL label (giving ``x_3``, not ``x_2_3``) until an unused one
    is found.

    Args:
        labels: list of labels already in use; the chosen label is appended
            to it in place.
        label: the desired label.

    Returns:
        ``(labels, unique_label)``: the same (mutated) list and the label that
        was appended to it.
    """
    base = label
    suffix = 2
    while label in labels:
        label = base+'_'+str(suffix)
        suffix += 1
    labels.append(label)
    return labels,label

def make_data(id_set,ood_set):
    """Join in-distribution and out-of-distribution samples into one dataset.

    Returns:
        (samples, targets): ``samples`` is ``id_set`` followed by ``ood_set``;
        ``targets`` holds 0.0 for every ID sample and 1.0 for every OOD sample.
    """
    samples = np.concatenate([id_set,ood_set])
    targets = np.concatenate([np.zeros(len(id_set)),np.ones(len(ood_set))])
    return samples,targets

# Number of bins for the kNN-distance histograms plotted further down
n_bins = 100

In [None]:
#Taking kth value
# Histograms of the k-th nearest-neighbour embedding distance, one panel per
# branch: the clean test set in green, then one histogram per stored
# perturbation index, coloured progressively from green towards red.

green = np.array([83, 250, 0])/255
red = np.array([255, 40, 0])/255
n = 256  # number of values in the colormap
color_list = [green + (red - green) * i / (n - 1) for i in range(n)]
# Reversed so that cmap(1) is green and cmap(0) is red
cmap = plt.cm.colors.ListedColormap(color_list[::-1])

model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/'

dataset = 'CIFAR10'

# 4-D: (perturbation index, sample, branch, neighbour) -- axis order inferred
# from the indexing below
adversarial = np.load(model_directory + '/adversarial_analysis/knn_distances.npy')
epsilons = np.load(model_directory + '/adversarial_analysis/epsilons.npy')

# Clean test-set distance to the k-th (last stored) neighbour: (sample, branch)
clean_values = np.load(model_directory+'/knn_ood/train_'+dataset+'_train/test_'+dataset+'_test/all_k_distances.npy')[:,:,-1]

n_tests = adversarial.shape[0]
n_branches = adversarial.shape[2]

fig, axes = plt.subplots(2, n_branches//2)

for branch,ax in enumerate(axes.flatten()):
    clean_branch = clean_values[:,branch]
    ax.set_title('Branch: '+ str(branch+1))
    ax.hist(clean_branch,bins=n_bins,label='in distribution',alpha=0.4,color=green,density=True)
    for ood_test in range(n_tests):
        adversarial_branch = adversarial[ood_test,:,branch,-1]
        ax.hist(adversarial_branch,bins=n_bins,alpha=0.4,density=True,color=cmap(1-(ood_test/n_tests)))

plt.tight_layout()
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_array([])
cbar = fig.colorbar(sm, ax=axes, location='right')
cbar.set_alpha(0.7)
# Colorbar.draw_all() was deprecated in Matplotlib 3.6 (replacement:
# Figure.draw_without_rendering) and removed in 3.8.
if hasattr(fig, 'draw_without_rendering'):
    fig.draw_without_rendering()
else:
    cbar.draw_all()
cbar.set_ticks([0,0.33,0.66,1])
cbar.set_ticklabels([0.3,0.2,0.1,0])

# Raw strings: '\e' is an invalid escape sequence in a normal string literal
cbar.set_label(r'$\epsilon$',rotation=0,fontsize=34)

# The histograms use density=True, so the y axis is a probability density
fig.text(0.0, 0.5, 'Density', ha='center', va='center', rotation='vertical',fontsize=36)
fig.text(0.5, 0.0, 'k$^{th}$ embedding separation', ha='center', va='center', rotation='horizontal',fontsize=36)
plt.savefig('main/figures/adv_dists.pdf',bbox_inches='tight')
plt.show()

In [None]:
def _percentile_roc_auc(clean_branch,perturbed_branch,percentiles):
    """AUROC for separating perturbed from clean distances with percentile cutoffs.

    A sample is flagged as perturbed when its distance is strictly greater
    than the given percentile of ``clean_branch``. There is one ROC point per
    entry of ``percentiles``, which must be in descending order so that FPR
    is non-decreasing. The curve is anchored at (0, 0) and (1, 1): even the
    0th-percentile cutoff leaves the smallest clean distance unflagged, so
    FPR would otherwise stop short of 1 and bias the area low.
    """
    clean_branch = np.asarray(clean_branch)
    perturbed_branch = np.asarray(perturbed_branch)
    cutoffs = np.percentile(clean_branch,percentiles)
    # (n_percentiles, n_samples) comparisons -> fraction above each cutoff
    fpr = np.mean(clean_branch[None,:] > cutoffs[:,None],axis=1)
    tpr = np.mean(perturbed_branch[None,:] > cutoffs[:,None],axis=1)
    fpr = np.concatenate([[0.0],fpr,[1.0]])
    tpr = np.concatenate([[0.0],tpr,[1.0]])
    # Trapezoidal rule written out (np.trapz is deprecated in NumPy 2.0)
    return float(np.sum(np.diff(fpr)*(tpr[1:]+tpr[:-1])/2.0))


def get_AUROC(directory,epsilons,n_thresh=100,dataset='CIFAR10'):
    """Per-branch AUROC of kNN-distance detection for every perturbation strength.

    Args:
        directory: model directory containing ``knn_ood/`` and
            ``adversarial_analysis/``.
        epsilons: perturbation strengths; entry ``i`` is paired with slice
            ``i`` of ``adversarial_analysis/knn_distances.npy``.
        n_thresh: number of percentile cutoffs (0..100) swept per ROC curve.
        dataset: dataset name used to locate the clean kNN distances.

    Returns:
        ``(len(epsilons), n_branches)`` array of AUROC values.
    """
    # Distance to the k-th (last stored) neighbour only
    clean_values = np.load(directory+'/knn_ood/train_'+dataset+'_train/test_'+dataset+'_test/all_k_distances.npy')[:,:,-1]
    adversarial = np.load(directory + '/adversarial_analysis/knn_distances.npy')[:,:,:,-1]
    n_branches = adversarial.shape[2]
    areas = np.zeros((len(epsilons),n_branches))
    # Strictest cutoff (100th percentile) first so FPR increases along the curve
    percentiles = np.linspace(0,100,n_thresh)[::-1]
    for test_idx in range(len(epsilons)):
        perturbed = adversarial[test_idx,:,:]
        for branch in range(n_branches):
            areas[test_idx,branch] = _percentile_roc_auc(clean_values[:,branch],perturbed[:,branch],percentiles)
    return areas

In [None]:
# Independent training runs of the same CIFAR100 architecture
directories = ['trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch/',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_1/',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_2/',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_3/',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_4/']

# Epsilon grid of the first run, reused for every run
# (assumes all runs were attacked with the same grid -- confirm)
epsilons = np.load(directories[0] + '/adversarial_analysis/epsilons.npy')

# One (n_epsilons, n_branches) AUROC table per run, stacked along a new run axis
AUROC_curves = [get_AUROC(run_dir,epsilons,dataset='CIFAR100') for run_dir in directories]
all_curves = np.stack(AUROC_curves,axis=0)

np.save(directories[0]+'/adversarial_analysis/AUROC_adv.npy',all_curves)


In [None]:
model_directory = 'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch/'
# all_curves: (n_runs, n_epsilons, n_branches) AUROC values
all_curves = np.load(model_directory+'/adversarial_analysis/AUROC_adv.npy')
# Load the epsilon grid that belongs to these results instead of relying on an
# `epsilons` left in the kernel by another cell
epsilons = np.load(model_directory + '/adversarial_analysis/epsilons.npy')
assert all_curves.shape[1] == len(epsilons), 'epsilon grid does not match the saved AUROC table'

n_branches = all_curves.shape[2]

# Mean and spread across training runs
mean_area = np.mean(all_curves,axis=0)
std_area = np.std(all_curves,axis=0)

for test_idx,test in enumerate(epsilons):
    print('Data:',test)
    for branch in range(n_branches):
        print('Branch:',branch+1,np.round(mean_area[test_idx,branch],2),'+-',np.round(std_area[test_idx,branch],2))

In [None]:
# ROC curves for separating clean from perturbed test inputs by thresholding
# the k-th nearest-neighbour embedding distance at percentiles of the clean
# distances: one panel per branch, one curve per perturbation index.
n_thresh = 100
thresholds = np.linspace(0,100,n_thresh)

green = np.array([83, 250, 0])/255
red = np.array([255, 40, 0])/255
n = 256  # number of values in the colormap
color_list = [green + (red - green) * i / (n - 1) for i in range(n)]
# Reversed so that cmap(1) is green and cmap(0) is red
cmap = plt.cm.colors.ListedColormap(color_list[::-1])

model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/'

dataset = 'CIFAR10'

# (perturbation index, sample, branch): distance to the k-th neighbour only
adversarial = np.load(model_directory + '/adversarial_analysis/knn_distances.npy')[:,:,:,-1]
epsilons = np.load(model_directory + '/adversarial_analysis/epsilons.npy')

clean_values = np.load(model_directory+'/knn_ood/train_'+dataset+'_train/test_'+dataset+'_test/all_k_distances.npy')[:,:,-1]

n_tests = adversarial.shape[0]
n_branches = adversarial.shape[2]

fig, axes = plt.subplots(2, n_branches//2)
plt.tight_layout()
areas = np.zeros((n_tests,n_branches))
# Strictest cutoff (100th percentile) first so FPR increases along each curve
roc_percentiles = thresholds[::-1]

for perturbation in range(n_tests):
    perturbed = adversarial[perturbation,:,:]
    for branch,ax in enumerate(axes.flatten()):
        clean_branch = clean_values[:,branch]
        perturbed_branch = perturbed[:,branch]
        # A sample is flagged as perturbed when its distance exceeds the cutoff
        cutoffs = np.percentile(clean_branch,roc_percentiles)
        FPR = np.mean(clean_branch[None,:] > cutoffs[:,None],axis=1)
        TPR = np.mean(perturbed_branch[None,:] > cutoffs[:,None],axis=1)
        # Anchor at (0,0) and (1,1): even the 0th-percentile cutoff leaves the
        # smallest clean distance unflagged, so FPR never reaches 1 on its own
        FPR = np.concatenate([[0.0],FPR,[1.0]])
        TPR = np.concatenate([[0.0],TPR,[1.0]])
        # Trapezoidal rule written out (np.trapz is deprecated in NumPy 2.0)
        AUROC = float(np.sum(np.diff(FPR)*(TPR[1:]+TPR[:-1])/2.0))
        areas[perturbation,branch] = AUROC
        ax.set_title('Branch: '+str(branch+1))
        ax.plot(FPR,TPR,color=cmap(1-(perturbation/n_tests)),label=(r'$\epsilon$'+str(round(epsilons[perturbation],2))+' : '+str(round(AUROC,3))))

print(np.round(areas[1:],3))
print(np.round(epsilons[1:],2))

fig.text(0.0, 0.5, 'TPR', ha='center', va='center', rotation='vertical',fontsize=36)
fig.text(0.5, 0.0, 'FPR', ha='center', va='center', rotation='horizontal',fontsize=36)

sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_array([])
cbar = fig.colorbar(sm, ax=axes, location='right')
cbar.set_alpha(0.7)
# Colorbar.draw_all() was deprecated in Matplotlib 3.6 (replacement:
# Figure.draw_without_rendering) and removed in 3.8.
if hasattr(fig, 'draw_without_rendering'):
    fig.draw_without_rendering()
else:
    cbar.draw_all()
cbar.set_ticks([0,0.33,0.66,1])
cbar.set_ticklabels([0.3,0.2,0.1,0])

cbar.set_label(r'$\epsilon$',rotation=0,fontsize=34)
plt.savefig('main/figures/adv_roc.pdf',bbox_inches='tight')
plt.show()

In [None]:
import numpy as np
def find_nearest(array, value):
    """Locate the element of ``array`` closest to ``value``.

    Returns:
        ``(idx, element)``: the index of the closest element (the first one
        on ties; a flat index for the assumed 1-D input) and that element.
    """
    values = np.asarray(array)
    distances = np.abs(values - value)
    nearest = distances.argmin()
    return nearest, values[nearest]

def log_spacing(value,n_thresh):
    """Return ``n_thresh`` points between ``value`` and 1 with log spacing.

    The offsets from ``value`` are ``(1 - value) * 10**t`` for ``t`` evenly
    spaced in [-2, 0], so the points cluster near ``value`` and the last one
    is ``1.0`` (up to rounding).
    """
    span = 1-value
    return value + span*np.logspace(-2,0,n_thresh)


def run_adversarial_inference(model_directory,train_dataset,epsilon,n_thresh=100,knn_percentile=0.95,detect_ood=True,adaptive=False):
    """Simulate entropy-based early exiting combined with kNN perturbation detection.

    For every entropy threshold of a sweep, each clean test input walks the
    branches in order and exits at the first branch where either
      * its k-th neighbour distance exceeds that branch's kNN threshold
        (the input is flagged: prediction -1), or
      * the value returned by ``entropy`` for that branch's output is below
        the entropy threshold (prediction = that branch's argmax).
    Inputs that trigger neither exit at the last branch. Perturbed inputs are
    only checked against the kNN thresholds; if none fires they reach the last
    branch and are scored -1 when ``accuracies[inp_idx, -1, index] == 1``,
    otherwise -2 (NOTE(review): presumably "-1 = still handled correctly";
    confirm against the file that produced accuracies.npy).

    Args:
        model_directory: directory with ``outputs.npy``, ``labels.npy`` and
            ``power_usage.npy``; its parent must contain ``knn_ood/`` and
            ``adversarial_analysis/``.
        train_dataset: dataset name used to locate the clean kNN distances.
        epsilon: requested perturbation strength; the nearest stored one is used.
        n_thresh: number of entropy thresholds, from log(n_classes) down to 0.
        knn_percentile: fraction (0-1) of the clean k-th neighbour distances
            used as each branch's detection threshold.
        detect_ood: if False, perturbed inputs are never flagged: all run to
            the last branch with prediction -2.
        adaptive: if True, the kNN threshold changes along the sweep, using
            ``log_spacing(knn_percentile, n_thresh)`` percentiles.

    Returns:
        dict with per-threshold predictions, exits (n_thresh, n_inputs,
        n_branches), mean power per input, labels, accuracies (ID, perturbed,
        all) and the sample counts.
    """
    func_outputs = dict()

    id_outputs = np.load(model_directory+'/outputs.npy')
    n_branches = id_outputs.shape[1]
    id_labels = np.load(model_directory+'/labels.npy')

    powers = np.load(model_directory+'/power_usage.npy')
    epsilons = np.load(model_directory+'/../adversarial_analysis/epsilons.npy')
    accuracies = np.load(model_directory+'/../adversarial_analysis/accuracies.npy')

    # Use the stored perturbation strength closest to the requested one
    index,_ = find_nearest(epsilons,epsilon)

    id_knn = np.load(model_directory+'/../knn_ood/train_'+train_dataset+'_train/test_'+train_dataset+'_test/all_k_distances.npy')
    adversarial_knn = np.load(model_directory+'/../adversarial_analysis/knn_distances.npy')[index,:,:,:]

    branch_predictions = np.argmax(id_outputs,axis=2)

    n_id_inputs = branch_predictions.shape[0]
    n_ood_inputs = adversarial_knn.shape[0]
    last_branch = n_branches-1

    # A perturbed input counts as correct when its prediction is -1
    ood_labels = np.full(n_ood_inputs,-1)
    all_labels = np.concatenate([id_labels,ood_labels])

    id_entropies = np.zeros((n_id_inputs,n_branches))
    for input_idx in range(n_id_inputs):
        for branch_idx in range(n_branches):
            id_entropies[input_idx,branch_idx] = entropy(id_outputs[input_idx,branch_idx,:])

    id_exits = np.zeros((n_thresh,n_id_inputs,n_branches))
    id_predictions = np.zeros((n_thresh,n_id_inputs))

    ood_exits = np.zeros((n_thresh,n_ood_inputs,n_branches))
    ood_predictions = np.zeros((n_thresh,n_ood_inputs))

    # log(n_classes), presumably the largest value `entropy` can return
    max_entropy = np.log(id_outputs.shape[2])
    thresholds = np.linspace(max_entropy,0,n_thresh)

    # Per-branch kNN distance threshold for every point of the sweep
    knn_thresholds = np.zeros((n_branches,n_thresh))
    for branch in range(n_branches):
        id_branch = id_knn[:,branch,-1]
        if adaptive == True:
            knn_thresholds[branch,:] = np.percentile(id_branch,log_spacing(knn_percentile,n_thresh)*100)
        else:
            knn_thresholds[branch,:] = np.percentile(id_branch,knn_percentile*100)

    if not detect_ood:
        # Detection disabled: every perturbed input reaches the last branch undetected
        ood_exits[:,:,last_branch] = 1
        ood_predictions[:,:] = -2

    for thresh_idx,threshold in enumerate(thresholds):
        # Clean inputs: first branch that flags the input or is confident enough
        for inp_idx in range(n_id_inputs):
            for branch_idx,branch_entropy in enumerate(id_entropies[inp_idx,:]):
                if id_knn[inp_idx,branch_idx,-1] > knn_thresholds[branch_idx,thresh_idx]:
                    id_exits[thresh_idx,inp_idx,branch_idx] = 1
                    id_predictions[thresh_idx,inp_idx] = -1
                    break
                if branch_entropy < threshold:
                    id_exits[thresh_idx,inp_idx,branch_idx] = 1
                    id_predictions[thresh_idx,inp_idx] = branch_predictions[inp_idx,branch_idx]
                    break
            else:
                # No branch triggered: use the final exit
                id_exits[thresh_idx,inp_idx,last_branch] = 1
                id_predictions[thresh_idx,inp_idx] = branch_predictions[inp_idx,last_branch]

        if detect_ood:
            # Perturbed inputs: only the kNN detector can stop them early
            for inp_idx in range(n_ood_inputs):
                for branch_idx,knn_distance in enumerate(adversarial_knn[inp_idx,:,-1]):
                    if knn_distance > knn_thresholds[branch_idx,thresh_idx]:
                        ood_exits[thresh_idx,inp_idx,branch_idx] = 1
                        ood_predictions[thresh_idx,inp_idx] = -1
                        break
                else:
                    ood_exits[thresh_idx,inp_idx,last_branch] = 1
                    if accuracies[inp_idx,-1,index] == 1:
                        ood_predictions[thresh_idx,inp_idx] = -1
                    else:
                        ood_predictions[thresh_idx,inp_idx] = -2

    # Concatenate once after the sweep (previously rebuilt on every threshold)
    all_exits = np.concatenate([id_exits,ood_exits],axis=1)
    all_predictions = np.concatenate([id_predictions,ood_predictions],axis=1)

    # Total power per threshold: exit counts per branch weighted by `powers`
    power_usage = np.zeros(n_thresh)
    for thresh_idx in range(n_thresh):
        power_usage[thresh_idx] = np.dot(np.sum(all_exits[thresh_idx,:],axis=0),powers)

    id_accuracy = np.mean(id_predictions == id_labels,axis=1)
    ood_accuracy = np.mean(ood_predictions == ood_labels,axis=1)
    all_accuracy = np.mean(all_predictions == all_labels,axis=1)

    func_outputs['all_predictions'] = all_predictions
    func_outputs['power_usage'] = power_usage/len(all_labels)
    func_outputs['all_labels'] = all_labels
    func_outputs['all_exits'] = all_exits
    func_outputs['id_predictions'] = id_predictions
    func_outputs['id_labels'] = id_labels
    func_outputs['ood_labels'] = ood_labels
    func_outputs['id_accuracy'] = id_accuracy
    func_outputs['ood_accuracy'] = ood_accuracy
    func_outputs['all_accuracy'] = all_accuracy
    func_outputs['n_id_samples'] = n_id_inputs
    func_outputs['n_ood_samples'] = n_ood_inputs

    return func_outputs


def plot_ood_power(ax,model_directory,data,epsilon,n_thresh,percentile,detect_ood=True,adaptive=False):
    """Plot clean-input accuracy against mean power for one detector setting.

    The curve is labelled with ``percentile`` ('None' when detection is off)
    and drawn with a shaded band of +/- 1/sqrt(n_id_samples).
    NOTE(review): 1/sqrt(n) is twice the largest binomial standard error
    (0.5/sqrt(n)); presumably a deliberately conservative band.
    """
    result = run_adversarial_inference(model_directory,data,epsilon,n_thresh=n_thresh,knn_percentile=percentile,detect_ood=detect_ood,adaptive=adaptive)
    curve_label = 'None' if detect_ood==False else str(percentile)

    power = result['power_usage']
    accuracy = result['id_accuracy']
    half_width = 1/np.sqrt(result['n_id_samples'])

    ax.plot(power,accuracy,label=curve_label)
    ax.fill_between(power,accuracy-half_width,accuracy+half_width,alpha=0.3)

def plot_improvement_power_adv(ax,model_directory,data,epsilon,n_thresh,percentile,detect_ood=True,adaptive=False):
    """Plot the relative power saving of a detector setting against accuracy drop.

    The plain entropy-exit baseline (detection off) and the requested
    setting are both simulated. Their power curves are interpolated onto a
    common ID-accuracy grid, and the saving ``(base - test) / base`` (in %) is
    plotted against the drop from the baseline's best accuracy (in %). The
    curve stops at the first grid point where the saving is not positive.

    Args:
        ax: matplotlib axes to draw on.
        model_directory, data, epsilon, n_thresh: forwarded to
            ``run_adversarial_inference``.
        percentile: kNN detection percentile of the tested setting.
        detect_ood: whether the tested setting uses kNN detection.
        adaptive: whether the tested setting uses adaptive kNN thresholds.
    """
    base_value_dict = run_adversarial_inference(model_directory,data,epsilon,n_thresh=n_thresh,knn_percentile=1.0,detect_ood=False,adaptive=False)
    base_power,base_acc=base_value_dict['power_usage'],base_value_dict['id_accuracy']

    func_out = run_adversarial_inference(model_directory,data,epsilon,n_thresh=n_thresh,knn_percentile=percentile,detect_ood=detect_ood,adaptive=adaptive)
    power,acc=func_out['power_usage'],func_out['id_accuracy']

    acc_range = np.linspace(base_acc[0],np.max(acc),500)
    max_acc = np.max(base_acc)

    # np.interp silently returns nonsense unless xp is increasing, so each
    # sweep is sorted by accuracy first (a no-op when already monotonic)
    base_order = np.argsort(base_acc,kind='stable')
    interp_power = np.interp(acc_range,base_acc[base_order],base_power[base_order])
    test_order = np.argsort(acc,kind='stable')
    interp_power_test = np.interp(acc_range,acc[test_order],power[test_order])

    if detect_ood==False:
        label = 'None'
    else:
        label = str(percentile)

    power_diff = ((interp_power-interp_power_test)/interp_power)*100
    # Index of the first non-positive saving. argmin of an all-True mask is
    # also 0; in that case (or if the first point is already non-positive)
    # the whole curve is plotted.
    cutoff_idx = np.argmin(power_diff>0)
    if cutoff_idx == 0:
        cutoff_idx = len(power_diff)

    acc_range=(max_acc-acc_range)*100

    ax.plot(acc_range[:cutoff_idx],power_diff[:cutoff_idx],label=label)

    sigma=3
    # Clip so that negative savings do not produce NaN under the square root
    interval = sigma*np.sqrt(np.clip(power_diff,0,None)/func_out['n_id_samples'])

    ax.fill_between(acc_range[:cutoff_idx],power_diff[:cutoff_idx]-interval[:cutoff_idx],power_diff[:cutoff_idx]+interval[:cutoff_idx],alpha=0.3)

In [None]:
model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/analysis'
# Load this model's epsilon grid explicitly instead of relying on an
# `epsilons` left behind by an earlier cell
epsilons = np.load(model_directory+'/../adversarial_analysis/epsilons.npy')
n_tests = 4
fig, axes = plt.subplots(2, 2,figsize=(width,height))
knn_ood_tests = np.linspace(0,0.3,n_tests)
# (percentile, detect_ood): the first entry is the no-detection baseline
power_configs = [(1.0,False),(1.0,True),(0.999,True),(0.995,True),(0.99,True)]

for test_idx,epsilon in enumerate(knn_ood_tests):
    _,epsilon = find_nearest(epsilons,epsilon)
    print('Using closest value for epsilon:',epsilon)
    ax = axes.flatten()[test_idx]
    ax.set_title(r'$\epsilon = $'+ str(round(epsilon,2)))
    for percentile,detect_ood in power_configs:
        plot_ood_power(ax,model_directory,'CIFAR10',epsilon,n_thresh=25,percentile=percentile,detect_ood=detect_ood)

h, l = axes.flatten()[-1].get_legend_handles_labels()
# One legend column per curve (not per panel)
legend = fig.legend(h,l,title='Adversarial detection percentile',loc='center',bbox_to_anchor=(0.5,1.02),ncol=len(h))
plt.setp(legend.get_title(), multialignment='center')
plt.tight_layout()

fig.text(0.0, 0.5, 'Perturbation-Aware Accuracy', ha='center', va='center', rotation='vertical')
fig.text(0.5, 0, 'Power usage (MACs)', ha='center', va='center', rotation='horizontal')
plt.savefig('main/figures/adv_powers.pdf',bbox_inches='tight')
plt.show()

In [None]:
model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/analysis'
# Load this model's epsilon grid explicitly instead of relying on an
# `epsilons` left behind by an earlier cell
epsilons = np.load(model_directory+'/../adversarial_analysis/epsilons.npy')
n_tests = 4
fig, axes = plt.subplots(2, 2, figsize= (width,height))
knn_ood_tests = np.linspace(0,0.3,n_tests)
# (percentile, detect_ood, adaptive): no-detection baseline, fixed 1.0
# percentile, then adaptive thresholds at decreasing percentiles
adaptive_configs = [(1.0,False,False),(1.0,True,False),
                    (0.999,True,True),(0.995,True,True),(0.99,True,True),
                    (0.95,True,True),(0.9,True,True)]

for test_idx,epsilon in enumerate(knn_ood_tests):
    _,epsilon = find_nearest(epsilons,epsilon)
    print('Using closest value for epsilon:',epsilon)
    ax = axes.flatten()[test_idx]
    ax.set_title(r'$\epsilon = $'+ str(round(epsilon,2)))
    for percentile,detect_ood,adaptive in adaptive_configs:
        plot_ood_power(ax,model_directory,'CIFAR10',epsilon,n_thresh=250,percentile=percentile,detect_ood=detect_ood,adaptive=adaptive)

h, l = axes.flatten()[-1].get_legend_handles_labels()
legend = fig.legend(h,l,title=r'Adversarial detection percentile: $\delta$',loc='center',bbox_to_anchor=(0.5,1.04),ncol=len(h),fontsize=34)
plt.setp(legend.get_title(), multialignment='center',fontsize=34)
plt.tight_layout()

fig.text(0.0, 0.5, 'Perturbation-Aware Accuracy', ha='center', va='center', rotation='vertical',fontsize=36)
fig.text(0.5, 0, 'Power usage (MACs)', ha='center', va='center', rotation='horizontal',fontsize=36)
plt.savefig('main/figures/adv_powers_adaptive.pdf',bbox_inches='tight')
plt.show()

In [None]:
model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/analysis'
# Load this model's epsilon grid explicitly instead of relying on an
# `epsilons` left behind by an earlier cell
epsilons = np.load(model_directory+'/../adversarial_analysis/epsilons.npy')
n_tests = 6
fig, axes = plt.subplots(3, n_tests//3,figsize=(25, 25))
knn_ood_tests = np.linspace(0,0.3,n_tests)
# (percentile, adaptive): fixed 1.0 percentile, then adaptive thresholds
improvement_configs = [(1.0,False),(0.999,True),(0.995,True),(0.99,True),(0.95,True),(0.9,True)]

for test_idx,epsilon in enumerate(knn_ood_tests):
    _,epsilon = find_nearest(epsilons,epsilon)
    print('Using closest value for epsilon:',epsilon)
    ax = axes.flatten()[test_idx]
    ax.set_title(r'$\epsilon = $'+ str(round(epsilon,2)))
    for percentile,adaptive in improvement_configs:
        plot_improvement_power_adv(ax,model_directory,'CIFAR10',epsilon,n_thresh=50,percentile=percentile,detect_ood=True,adaptive=adaptive)

h, l = axes.flatten()[-1].get_legend_handles_labels()
# One legend column per curve
legend = fig.legend(h,l,title='Adversarial detection percentile',loc='center',bbox_to_anchor=(0.5,1.02),ncol=len(h))
plt.setp(legend.get_title(), multialignment='center')
plt.tight_layout()

fig.text(0.5, 0.0,  'Accuracy Drop (%)',ha='center', va='center', rotation='horizontal')
fig.text(0.0, 0.5, 'Power Improvement (%)', ha='center', va='center', rotation='vertical')
plt.savefig('main/figures/power_improvement_adv.pdf',bbox_inches='tight')
plt.show()

In [None]:
def plot_adv_accuracy(ax,model_directory,data,ood_data,n_thresh,percentile,detect_ood=True,adaptive=False):
    """Plot perturbed-input accuracy against relative clean (ID) accuracy.

    ID accuracy is expressed as a percentage of the no-detection baseline's
    best ID accuracy. The perturbed-input accuracy (in %) of the tested
    setting is interpolated onto that grid.

    Args:
        ax: matplotlib axes to draw on.
        model_directory, data, n_thresh: forwarded to ``run_adversarial_inference``.
        ood_data: perturbation strength, forwarded as its ``epsilon`` argument.
        percentile, detect_ood, adaptive: detector setting under test.

    Returns:
        ``(relative_id_acc, perturbed_acc)`` at the grid points closest to
        100, 99, 95 and 90 % relative ID accuracy (rounded to 0 and 3 decimals).
    """
    base_value_dict = run_adversarial_inference(model_directory,data,ood_data,n_thresh=n_thresh,knn_percentile=1.0,detect_ood=False,adaptive=False)
    base_acc=base_value_dict['id_accuracy']

    func_out = run_adversarial_inference(model_directory,data,ood_data,n_thresh=n_thresh,knn_percentile=percentile,detect_ood=detect_ood,adaptive=adaptive)
    id_acc,ood_acc=func_out['id_accuracy'],func_out['ood_accuracy']

    acc_range = np.linspace(min(base_acc),max(base_acc),500)
    max_acc = max(base_acc)

    label = str(percentile)
    if adaptive:
        label = label + ' adaptive'

    # np.interp silently returns nonsense unless xp is increasing: sort the
    # sweep by ID accuracy first (a no-op when already monotonic)
    order = np.argsort(id_acc,kind='stable')
    interp_ood_acc_test = np.interp(acc_range,id_acc[order],ood_acc[order])*100

    acc_range = (acc_range/max_acc)*100

    ax.plot(acc_range,interp_ood_acc_test,label=label)

    sigma=3
    # The plotted quantity is accuracy on the PERTURBED inputs, so the band
    # scales with their count, not with the number of clean inputs
    interval = sigma*np.sqrt(interp_ood_acc_test/func_out['n_ood_samples'])

    ax.fill_between(acc_range,interp_ood_acc_test-interval,interp_ood_acc_test+interval,alpha=0.3)

    targets = np.array([100,99,95,90])
    # Closest grid point to each target relative ID accuracy
    key_indices = np.abs(acc_range[None,:] - targets[:,None]).argmin(axis=1)

    return np.round(acc_range[key_indices],0),np.round(interp_ood_acc_test[key_indices],3)

In [None]:
model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/analysis'
# Load this model's epsilon grid explicitly instead of relying on an
# `epsilons` left behind by an earlier cell
epsilons = np.load(model_directory+'/../adversarial_analysis/epsilons.npy')
fig, axes = plt.subplots(2, 2,figsize=(25, 25))
knn_ood_tests = np.linspace(0,0.3,4)
adv_percentiles = [0.999,0.995,0.99,0.95,0.9]

for test_idx,epsilon in enumerate(knn_ood_tests):
    _,epsilon = find_nearest(epsilons,epsilon)
    print('Using closest value for epsilon:',epsilon)
    # One panel per epsilon, in increasing order
    ax = axes.flatten()[test_idx]
    ax.set_title(r'$\epsilon = $'+ str(round(epsilon,2)))
    for percentile in adv_percentiles:
        # Print the percentile actually used for each run
        print(np.round(epsilon,2),percentile,plot_adv_accuracy(ax,model_directory,'CIFAR10',epsilon,n_thresh=50,percentile=percentile,adaptive=True))
    ax.invert_xaxis()

h, l = axes.flatten()[-1].get_legend_handles_labels()
legend = fig.legend(h,l,title='Adversarial detection percentile',loc='center',bbox_to_anchor=(0.5,1.02),ncol=len(knn_ood_tests))
plt.setp(legend.get_title(), multialignment='center')
plt.tight_layout()

fig.text(0.5, 0.0,  'ID Accuracy (%)',ha='center', va='center', rotation='horizontal')
fig.text(0.0, 0.5, 'OOD accuracy (%)', ha='center', va='center', rotation='vertical')
plt.savefig('main/figures/adv_detection_acc.pdf',bbox_inches='tight')
plt.show()

In [None]:
# Schematic figure: power-vs-accuracy curves of conventional entropy exiting
# and distribution-aware exiting, with the gap between them shaded.
model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/analysis'

fig,ax = plt.subplots(figsize = (15,10))

epsilon = 0.3
n_thresh = 35

base = run_adversarial_inference(model_directory,'CIFAR10',epsilon,n_thresh=n_thresh,knn_percentile=1.0,detect_ood=False,adaptive=False)
test = run_adversarial_inference(model_directory,'CIFAR10',epsilon,n_thresh=n_thresh,knn_percentile=0.99,detect_ood=True,adaptive=True)
base_acc = base['id_accuracy']
test_acc = test['id_accuracy']
base_powers=base['power_usage']
test_powers=test['power_usage']

acc_range = np.linspace(min(base_acc),max(base_acc),500)

# Each power curve is interpolated against its OWN accuracy sweep, sorted so
# that np.interp's xp is increasing
base_order = np.argsort(base_acc,kind='stable')
test_order = np.argsort(test_acc,kind='stable')
interp_power_base = np.interp(acc_range,base_acc[base_order],base_powers[base_order])
interp_power_test = np.interp(acc_range,test_acc[test_order],test_powers[test_order])

# Hand-placed annotation arrows (figure-fraction coordinates)
arrowprops = dict(arrowstyle="<->", linewidth=2, color='dimgrey',connectionstyle="arc3")
ax.annotate(r'$\Delta$Power', xy=(0, 0), xycoords='figure fraction',xytext=(0.225,0.185), textcoords='figure fraction',fontsize=24,color='dimgrey')
ax.annotate("",xy=(0.11, 0.17), xycoords='figure fraction',xytext=(0.42, 0.17), textcoords='figure fraction',arrowprops=arrowprops)

ax.annotate(r'$\Delta$Acc', xy=(0, 0), xycoords='figure fraction',xytext=(0.445, 0.745), textcoords='figure fraction',fontsize=24,color='dimgrey')
ax.annotate("",xy=(0.51, 0.87), xycoords='figure fraction',xytext=(0.51, 0.62), textcoords='figure fraction',arrowprops=arrowprops)

ax.spines[['right', 'top']].set_visible(False)
ax.spines[['left', 'bottom']].set_linewidth(1.5)

# Schematic only: no tick values
ax.set_xticks([])
ax.set_yticks([])

ax.plot(interp_power_base,acc_range,label='Conventional',linewidth=4,color ='sandybrown')
ax.plot(interp_power_test,acc_range,label='Distribution Aware',linewidth=4,color='lightseagreen')
ax.fill_betweenx(acc_range,interp_power_base,interp_power_test,alpha=0.2,color='greenyellow')

ax.set_ylabel('Accuracy')
ax.set_xlabel('Power usage')
legend = ax.legend(title='Exiting Algorithm')
plt.setp(legend.get_title(), multialignment='center')
plt.tight_layout()
plt.savefig('main/figures/power_improvement_diagram.pdf',bbox_inches='tight')
plt.show()


In [None]:
def get_metrics_accuracy(model_directory,data,epsilon,n_thresh,percentile,detect_ood=True,adaptive=False,verbose=True):
    """Summarise the best-case gains of a detector setting over plain entropy exiting.

    Returns:
        ``(peak_accuracy_gain, peak_power_gain)``, both in percent:
          * accuracy gain: at the baseline's first operating point (highest
            entropy threshold), the ID-accuracy difference against the
            candidate point whose power is closest;
          * power gain: at the threshold index with the largest absolute
            power saving, that saving relative to the baseline power.
    """
    # Baseline with detection off. Its epsilon argument (1.0) only selects
    # which perturbed slice is loaded; with detect_ood=False the distances in
    # that slice are never compared.
    baseline = run_adversarial_inference(model_directory,data,1.0,n_thresh=n_thresh,knn_percentile=1.0,detect_ood=False,adaptive=False)
    base_acc = baseline['id_accuracy']
    base_powers = baseline['power_usage']

    candidate = run_adversarial_inference(model_directory,data,epsilon,n_thresh=n_thresh,knn_percentile=percentile,detect_ood=detect_ood,adaptive=adaptive)
    ood_powers = candidate['power_usage']
    ood_acc = candidate['id_accuracy']

    matched = np.argmin(np.abs(ood_powers-base_powers[0]))
    peak_accuracy_gain = (ood_acc[matched]-base_acc[0])*100

    savings = base_powers - ood_powers
    best = np.argmax(savings)
    peak_power_gain = (savings[best]/base_powers[best])*100

    if verbose:
        print('peak accuracy increase: ',np.round(peak_accuracy_gain,4))
        print('peak power increase: ',np.round(peak_power_gain,4))

    return peak_accuracy_gain,peak_power_gain
    

In [None]:
model_directory = 'trained-models/CIFAR10/resnet_4_branch/w1.0_d18/300-epoch/analysis/'
epsilons = np.load(model_directory + '../adversarial_analysis/epsilons.npy')
percentiles = [1.0,0.999,0.995,0.99,0.95,0.9]

# Every third stored epsilon keeps the printout short
for epsilon in epsilons[::3]:
    print('\nEpsilon:',np.round(epsilon,3))
    for percentile in percentiles:
        print('Percentile:',percentile)
        get_metrics_accuracy(model_directory,'CIFAR10',epsilon,n_thresh=50,percentile=percentile,detect_ood=True)

In [None]:
directories = ['trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_1/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_2/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_3/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_4/analysis']

epsilons = np.load(directories[0] + '/../adversarial_analysis/epsilons.npy')[::3]
percentiles = [1.0,0.999,0.995,0.99,0.95,0.9]

# (run, epsilon, percentile, metric): metric 0 = peak accuracy gain,
# metric 1 = peak power gain (both %, from get_metrics_accuracy)
all_values = np.zeros((len(directories),len(epsilons),len(percentiles),2))

for dir_index,directory in enumerate(directories):
    print('in directory:',directory)
    for epsilon_idx,epsilon in enumerate(epsilons):
        print('\nepsilon:',epsilon)
        for p_index,percentile in enumerate(percentiles):
            print(percentile)
            # `power_gain`, not `pow`: a global `pow` would shadow the builtin
            # for every later cell in the kernel
            acc_gain,power_gain = get_metrics_accuracy(directory,'CIFAR100',epsilon,n_thresh=50,percentile=percentile,detect_ood=True,verbose=False)

            all_values[dir_index,epsilon_idx,p_index,0] = acc_gain
            all_values[dir_index,epsilon_idx,p_index,1] = power_gain

np.save(directories[0]+'/peak_values_multirun_adv.npy',all_values)

In [None]:
directory = 'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch/analysis'

# (run, epsilon, percentile, metric) as saved by the multi-run cell
all_vals = np.load(directory+'/peak_values_multirun_adv.npy')

# Read the epsilon grid from this cell's own `directory` instead of a
# `directories` list that only exists if another cell ran first
epsilons = np.load(directory + '/../adversarial_analysis/epsilons.npy')[::3]
percentiles = [1.0,0.999,0.995,0.99,0.95,0.9]

for epsilon_idx,epsilon in enumerate(epsilons):
    print('\nepsilon:',epsilon)
    for p_index,percentile in enumerate(percentiles):
        # Mean / spread across training runs
        mean_acc = np.mean(all_vals[:,epsilon_idx,p_index,0])
        std_acc = np.std(all_vals[:,epsilon_idx,p_index,0])

        mean_power = np.mean(all_vals[:,epsilon_idx,p_index,1])
        std_power = np.std(all_vals[:,epsilon_idx,p_index,1])
        print(percentile,'\tacc:',np.round(mean_acc,2),'+-',np.round(std_acc,2))
        print(percentile,'\tpower:',np.round(mean_power,2),'+-',np.round(std_power,2))

In [None]:
directories = ['trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_1/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_2/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_3/analysis',
               'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch_4/analysis']

epsilons = np.load(directories[0] + '/../adversarial_analysis/epsilons.npy')[::3]
percentiles = [1.0,0.999,0.995,0.99,0.95,0.9]
# plot_adv_accuracy reports detection accuracy at 4 target ID accuracies
n_targets = 4
detection_values = np.zeros((len(directories),len(epsilons),len(percentiles),n_targets))

# One panel per epsilon, created exactly once (a second plt.subplots call
# would leave an empty figure behind); all runs overlay on the same panels
n_cols = 2
n_rows = max(1,int(np.ceil(len(epsilons)/n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
for dir_idx,directory in enumerate(directories):
    print('Directory:',directory)
    for epsilon_idx,epsilon in enumerate(epsilons):
        print('epsilon:',epsilon)
        ax = axes.flatten()[epsilon_idx]
        ax.set_title(epsilon)
        for p_index,percentile in enumerate(percentiles):
            accs,detect_accs = plot_adv_accuracy(ax,directory,'CIFAR100',epsilon,n_thresh=50,percentile=percentile,detect_ood=True,adaptive=True)
            print(percentile,accs,detect_accs)
            detection_values[dir_idx,epsilon_idx,p_index,:] = detect_accs

np.save(directories[0]+'/detection_accuracies_adv.npy',detection_values)

In [None]:
directory = 'trained-models/CIFAR100/resnet_4_branch/w1.0_d18/300-epoch/analysis'

# (run, epsilon, percentile, target) as saved by the multi-run detection cell
detection_accuracies_adv = np.load(directory+'/detection_accuracies_adv.npy')

# Read the epsilon grid from this cell's own `directory` instead of a
# `directories` list that only exists if another cell ran first
epsilons = np.load(directory + '/../adversarial_analysis/epsilons.npy')[::3]
percentiles = [1.0,0.999,0.995,0.99,0.95,0.9]

# Must match the targets used inside plot_adv_accuracy
targets = np.array([100,99,95,90])

for epsilon_idx,epsilon in enumerate(epsilons):
    print('\n',epsilon)
    for t_index, target in enumerate(targets):
        print('target',target)
        for p_index,percentile in enumerate(percentiles):
            acc_mean = np.mean(detection_accuracies_adv[:,epsilon_idx,p_index,t_index])
            acc_std =  np.std(detection_accuracies_adv[:,epsilon_idx,p_index,t_index])
            print(percentile,'\tacc:',np.round(acc_mean,2),'+-',np.round(acc_std,2))