In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

In [None]:
# Load the per-method timing arrays for each dataset and stack them along a new
# leading axis, so axis 0 selects the method: index 0 = AM, index 1 = MU (the same
# order as the `names` list defined further down).
# Presumably axis 1 is the rank and axis 2 the repeated runs, judging by how later
# cells index/reduce these arrays — TODO confirm against the script that saved them.
times_20n_AM = np.load('data/20n_times_AM.npy')
times_20n_MU = np.load('data/20n_times_MU.npy')
times_20n = np.stack([times_20n_AM, times_20n_MU], axis=0)

times_hd_AM = np.load('data/hd_times_AM.npy')
times_hd_MU = np.load('data/hd_times_MU.npy')
times_hd = np.stack([times_hd_AM, times_hd_MU], axis=0)

In [None]:
# Central line of each plot: median over axis 2 for every (method, rank) cell.
# NOTE: despite the `_avg` suffix these are medians, not means; the names are kept
# because the plotting cells read them.
times_20n_avg, times_hd_avg = (np.median(arr, axis=2) for arr in (times_20n, times_hd))

In [None]:
# Shaded band of each plot: 25th and 75th percentiles over axis 2. Passing both
# quantile levels at once yields a leading axis of length 2, unpacked into low/high.
QUARTILES = [0.25, 0.75]
times_20n_low, times_20n_high = np.quantile(times_20n, QUARTILES, axis=2)
times_hd_low, times_hd_high = np.quantile(times_hd, QUARTILES, axis=2)

In [None]:
# Legend labels and line styles, one entry per method. Entry i describes row i along
# axis 0 of the stacked timing arrays.
names = ["AM", "MU"]
formats = [
    dict(marker='x', linestyle='-', color='#000000'),   # AM: black, solid, 'x' markers
    dict(marker='*', linestyle='--', color='#0072B2'),  # MU: blue, dashed, '*' markers
]

In [None]:
def filter_format(d):
    """Return a new dict with every entry of ``d`` except the ``'marker'`` key.

    The plotting cells reuse one style dict for both ``plot`` and
    ``fill_between``; the marker only makes sense for the line, so it is removed
    before the dict is handed to ``fill_between``. ``d`` itself is not modified.
    """
    trimmed = dict(d)
    trimmed.pop('marker', None)
    return trimmed

In [None]:
# Plot formatting
plt.rcParams.update({'font.size': 15})

# 20News: median runtime per rank for each method, with the 25th-75th percentile
# band shaded around it. Row i of the timing arrays matches formats[i] / names[i].
fig, ax = plt.subplots()  # fresh figure, so nothing stale from earlier cells is drawn on
ranks = np.arange(1, times_20n.shape[1] + 1)  # 1-based rank values for axis 1
for i, (f, n) in enumerate(zip(formats, names)):
    ax.plot(ranks, times_20n_avg[i, :], **f, label=n, linewidth=1.75)
    # The shared style dict carries a marker; fill_between needs it removed.
    ax.fill_between(ranks,
                    times_20n_low[i, :],
                    times_20n_high[i, :],
                    **filter_format(f), alpha=0.2)

ax.set_xlabel("Rank")
ax.set_ylabel("Time (s)")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.legend(fontsize=20)

# Kept for any later cell that may read it. The "loss" in the name looks like a
# copy-paste leftover: this is the y-range of the timing plot.
f_loss_ylim = ax.get_ylim()

# Create the output directory on a fresh checkout instead of failing inside savefig.
Path('plots').mkdir(parents=True, exist_ok=True)
fig.savefig('plots/20n_times.pdf', bbox_inches='tight', dpi=600)

print("20News Times")
plt.show()

In [None]:
# Plot formatting
plt.rcParams.update({'font.size': 15})

# Heart Disease: median runtime per rank for each method, with the 25th-75th
# percentile band shaded around it. Row i of the timing arrays matches
# formats[i] / names[i].
fig, ax = plt.subplots()  # fresh figure, so nothing stale from earlier cells is drawn on
ranks = np.arange(1, times_hd.shape[1] + 1)  # 1-based rank values for axis 1
for i, (f, n) in enumerate(zip(formats, names)):
    ax.plot(ranks, times_hd_avg[i, :], **f, label=n, linewidth=1.75)
    # The shared style dict carries a marker; fill_between needs it removed.
    ax.fill_between(ranks,
                    times_hd_low[i, :],
                    times_hd_high[i, :],
                    **filter_format(f), alpha=0.2)

ax.set_xlabel("Rank")
ax.set_ylabel("Time (s)")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.legend(fontsize=20)

# Kept for any later cell that may read it. The "loss" in the name looks like a
# copy-paste leftover: this is the y-range of the timing plot.
f_loss_ylim = ax.get_ylim()

# Create the output directory on a fresh checkout instead of failing inside savefig.
Path('plots').mkdir(parents=True, exist_ok=True)
fig.savefig('plots/hd_times.pdf', bbox_inches='tight', dpi=600)

print("Heart Disease Times")
plt.show()