In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import FuncFormatter

# Global seaborn theme for every figure below: 'whitegrid' style, all fonts scaled 2x.
sns.set(style='whitegrid', font_scale=2.0)

In [None]:
# Peak GPU memory (GB) for each (dataset, method) pair that has a measurement.
# Coverage is uneven: Houseelectric only has SGPR and Simplex-GP rows, and
# Precipitation has no 'Exact GP' row, so those bars are simply absent in the plot.
raw_mem_data = pd.DataFrame(
    [
        ('Elevators', 'Exact GP', 3.6),
        ('Elevators', 'SGPR', 1.5),
        ('Elevators', 'SKIP', 3),
        ('Elevators', 'Simplex-GP', 1),
        ('Houseelectric', 'SGPR', 16.5),
        ('Houseelectric', 'Simplex-GP', 2.5),
        ('Keggdirected', 'Exact GP', 23),
        ('Keggdirected', 'SGPR', 2.5),
        ('Keggdirected', 'SKIP', 11.5),
        ('Keggdirected', 'Simplex-GP', 1.5),
        ('Precipitation', 'SGPR', 19.2),
        ('Precipitation', 'SKIP', 12),
        ('Precipitation', 'Simplex-GP', 1.5),
        ('Protein', 'Exact GP', 20),
        ('Protein', 'SGPR', 2.4),
        ('Protein', 'SKIP', 4.5),
        ('Protein', 'Simplex-GP', 1.5),
    ],
    columns=['dataset', 'method', 'peak_mem'],
)
raw_mem_data

In [None]:
# Horizontal grouped bar chart: peak GPU memory per dataset, one bar per method.
# Datasets are drawn in a fixed order instead of their order of appearance.
g = sns.catplot(
    data=raw_mem_data,
    kind='bar',
    x='peak_mem',
    y='dataset',
    hue='method',
    order=['Houseelectric', 'Precipitation', 'Protein', 'Keggdirected', 'Elevators'],
    palette=sns.color_palette('husl', 4),
)

# The unit lives in the title, so both axis labels are blanked.
g.ax.set(title='Peak GPU Memory Usage (GB)', xlabel='', ylabel='')

# Hide seaborn's own legend and re-create it as a figure-level legend anchored
# just past the right edge of the figure (x > 1 in figure coordinates).
g._legend.set_visible(False)
handles, labels = g.ax.get_legend_handles_labels()
g.fig.legend(
    handles=handles,
    labels=labels,
    title='Method',
    loc='lower center',
    bbox_to_anchor=(1.05, .4, .25, 0.),
    borderaxespad=-0.25,
    ncol=1,
    frameon=True,
)

g.fig.tight_layout()
# Uncomment to export this figure; bbox_inches='tight' keeps the outside legend in frame.
# g.fig.savefig('mem_usage.pdf', bbox_inches='tight')

In [None]:
# Per-dataset MVM timings for the exact GP and Simplex-GP, with input dimension d
# and dataset size n. `ratio` = exact_mvm_t / simplex_mvm_t, so values above 1
# mean the Simplex-GP MVM is faster. Rows are ordered from largest n to smallest
# (the original integer index is kept, not reset).
raw_speed_data = (
    pd.DataFrame(
        [
            ('Elevators', 17, 16599, 0.008, 0.083),
            ('Houseelectric', 11, 2049280, 17.1, 1.756),
            ('Keggdirected', 20, 48827, 0.033, 0.134),
            ('Precipitation', 3, 628474, 0.549, 0.082),
            ('Protein', 9, 45730, 0.014, 0.034),
        ],
        columns=['dataset', 'd', 'n', 'exact_mvm_t', 'simplex_mvm_t'],
    )
    .assign(ratio=lambda frame: frame['exact_mvm_t'] / frame['simplex_mvm_t'])
    .sort_values(by='n', ascending=False)
)

raw_speed_data

In [None]:
# MVM speedup (exact_mvm_t / simplex_mvm_t) versus dataset size n, both on log
# scales. Points above the dashed 1x line are datasets where Simplex-GP's MVM
# is faster than the exact GP's.
g = sns.relplot(data=raw_speed_data, x='n', y='ratio', hue='dataset', kind='scatter',
                s=400, edgecolor='black', palette=sns.color_palette('husl', 5))
g.ax.set_xlabel(r'Dataset Size ($n$)')
g.ax.set_ylabel('MVM Speedup')
g.ax.set_xscale('log')
g.ax.set_yscale('log')
g.ax.set_xlim([10**4, 2.5 * 10**6])

# Break-even reference at ratio == 1. axhline spans the full axes width, so there
# is no need to materialise a ~2.5M-element array just to draw a flat line.
g.ax.axhline(1, linestyle='--', color='gray', linewidth=4)
g.ax.text(5 * 10**5, 1.2, 'KeOps Exact GP', fontsize=15, color='black')

g.ax.set_xticks([10**4, 10**5, 10**6])
# Label y ticks as multipliers with a formatter so labels stay attached to the
# ticks (set_yticklabels hard-codes them). '{:g}' keeps sub-1x ticks readable:
# 0.1 renders as '0.1x' instead of the '0x' produced by '{:.0f}'.
g.ax.yaxis.set_major_formatter(FuncFormatter(lambda t, _pos: f'{t:g}x'))

# Replace seaborn's legend with a figure-level one to the right of the plot.
g._legend.set_visible(False)
handles, labels = g.ax.get_legend_handles_labels()
for h in handles:
    # Enlarge legend markers to match the plotted points (s=400) via the public API.
    h.set_sizes([400])
    h.set(edgecolor='black')
g.fig.legend(handles=handles, labels=labels, bbox_to_anchor=(1., 0.5, .25, 0.),
             loc='lower center', ncol=1, borderaxespad=-2, frameon=True, title='Dataset')
g.fig.tight_layout()

g.fig.savefig('mvm_speedup.pdf', bbox_inches='tight')

In [None]:
# Alternative (disabled) view: horizontal bar chart of the MVM speedup per dataset.
# NOTE(review): if this cell is re-enabled, its savefig target 'mvm_speedup.pdf' is
# the same file the scatter-plot cell saves, so whichever runs last overwrites the
# other. Give it a distinct filename before re-enabling.
# fig, ax = plt.subplots(figsize=(10,5))

# sns.barplot(data=raw_speed_data, x='ratio', y='dataset',
#             palette=sns.color_palette('husl', 5),
#             order=['Houseelectric', 'Precipitation', 'Protein', 'Keggdirected', 'Elevators'])

# ax.set_xticklabels([f'{t:.0f}x' for t in ax.get_xticks()])
# ax.set_xlabel('MVM Speedup')
# ax.set_ylabel('Dataset')

# fig.tight_layout()

# fig.savefig('mvm_speedup.pdf', bbox_inches='tight')