In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rc
import seaborn as sns
from hdf5storage import loadmat
from scipy.stats import spearmanr

# Load data

In [None]:
# Read the precomputed results file; the arrays used below are looked up in `dat` by key.
mat_path = './data/bhscore_abs_pvalue_20200211.mat'
dat = loadmat(mat_path)

In [None]:
# Pull the network names, BH scores and best-ROI distributions out of `dat` as numpy arrays.
networks, bhscores, roi_dists = (
    np.array(dat[field])
    for field in ('networks', 'bhscores', 'best_roi_distributions')
)

# Reorder all three arrays with the same index so they stay aligned:
# argsort is ascending, so reversing it puts the largest BH score first.
sort_ind = np.argsort(bhscores)[::-1]
networks, bhscores, roi_dists = (
    arr[sort_ind] for arr in (networks, bhscores, roi_dists)
)

# Figure 3A

In [None]:
# Global plot styling for this figure.
rc('font', family='serif', serif=['Avenir'])
plt.rcParams["font.size"] = 8
plt.rcParams["pdf.fonttype"] = 42

fig, ax = plt.subplots(figsize=(6, 6))

# Horizontal bars: one row per network, bar length = BH score (arrays are already sorted).
sns.barplot(x=bhscores, y=networks, orient='h', ax=ax, color='cornflowerblue')

# Write each score (rounded to 2 decimals) in white, shifted left of the bar's end
# and slightly below the row centre.
for row in range(len(networks)):
    score = bhscores[row]
    ax.text(x=score - 0.028, y=row + 0.25, s=str(np.round(score, 2)), color='white')

# Strip the frame down to the left spine and drop the x ticks.
for side in ('top', 'right', 'bottom'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])


fig.savefig('figure_3A.pdf')

# Figure 3B

In [None]:
# ImageNet top-1 accuracy (%) for each DNN, keyed by network name.
dnn_acc_map = {
    'AlexNet': 57.7,
    'VGG-16': 71.5,
    'VGG-19': 71.1,
    'VGG-S': 63.3,
    'VGG-F': 58.9,
    'VGG-M': 62.7,
    'CORnet-R': 56,
    'CORnet-S': 75,
    'CORnet-Z': 48,
    'DenseNet-121': 74.91,
    'DenseNet-161': 77.64,
    'DenseNet-169': 76.09,
    'DenseNet-201': 77.31,
    'Inception-ResNet-v2': 80.4,
    'Inception-v1': 69.8,
    'Inception-v2': 73.9,
    'Inception-v3': 78.0,
    'Inception-v4': 80.2,
    'NASNet-Large': 82.7,
    'NASNet-Mobile': 74.0,
    'PNASNet-Large': 82.9,
    'ResNet-50-v2': 75.6,
    'ResNet-101-v2': 77.0,
    'ResNet-152-v2': 77.8,
    'ResNet-18': 69.8,
    'ResNet-34': 73.3,
    'SqueezeNet-1.0': 57.5,
    'SqueezeNet-1.1': 57.5,
    'MobileNet-v2-1.4-224': 75.0,
}


In [None]:
# Look up the accuracy for every network, in the same order as `networks`,
# so that img_accs[i] pairs with bhscores[i].
img_accs = [dnn_acc_map[n] for n in networks]

In [None]:
# scatter

def make_scatter(ax, array1, array2, labels, rho_xy=(45, 0.53)):
    """Draw a labelled scatter plot and annotate it with Spearman's rho.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on (only ``scatter``, ``text`` and ``spines`` are used).
    array1, array2 : sequence of numbers
        x and y values of the points; also the two samples passed to
        ``spearmanr``.
    labels : sequence
        One text label per point, drawn at the point's coordinates.
    rho_xy : tuple of (x, y), optional
        Data coordinates at which the ``rho = ...`` annotation is written.
        The default (45, 0.53) is the position that was previously
        hard-coded, so existing calls render exactly as before; pass another
        position when the axis limits differ.

    Returns
    -------
    None
    """
    ax.scatter(array1, array2, color=np.array([79, 139, 185]) / 255, s=15)
    for x_val, y_val, label in zip(array1, array2, labels):
        ax.text(x_val, y_val, label)

    # Hide the top and right frame lines.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    # Element 0 of the spearmanr result is the correlation coefficient.
    rho = spearmanr(array1, array2)[0]
    rho_x, rho_y = rho_xy
    ax.text(rho_x, rho_y, r'$\rho$ = %.2f' % np.round(rho, 2))
 
# https://stackoverflow.com/questions/50057591/matplotlib-scale-axis-lengths-to-be-equal
def make_square_axes(ax):
    """Give ``ax`` a square box on screen, whatever its data limits are.

    The aspect is set to the reciprocal of the axes' current data ratio, so
    call this only after plotting and after the final axis limits are set.
    """
    data_ratio = ax.get_data_ratio()
    ax.set_aspect(1 / data_ratio)

In [None]:
# Figure size in inches: millimetre dimensions converted (25.4 mm per inch) and scaled to 35 %.
# NOTE(review): 210 x 294 mm is close to A4 (210 x 297 mm); 294 is kept as written - confirm it is intended.
plt.rcParams["figure.figsize"] = [mm / 25.4 * 0.35 for mm in (210, 294)]
rc('font', family='serif', serif=['Avenir'])
plt.rcParams["font.size"] = 7
plt.rcParams["pdf.fonttype"] = 42

fig = plt.figure()

# One point per network: ImageNet accuracy on x, BH score on y.
ax1 = fig.add_subplot(111)
make_scatter(ax1, img_accs, bhscores, networks)
ax1.set(
    xlabel='ImageNet top-1 accuracy (%)',
    ylabel='BH score',
    xlim=[40, 90],
    ylim=[0, 0.6],
)
# Must come after the limits are fixed, since the aspect is derived from them.
make_square_axes(ax1)

fig.savefig('figure_3B.pdf')