In [1]:
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots

# Qualitative Tableau-10 palette. The figure cell picks a colour by each method's
# position in its method loop, so reordering the methods changes their colours.
COLORS = px.colors.qualitative.T10

In [2]:
# Results are nested as analysis[dataset_id][method_id] -> sequence of run records,
# each carrying a "run_history" (see the cells below).
# torch.load unpickles the file, so only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which may reject
# non-tensor objects stored in this file — confirm it still loads.
analysis = torch.load("data/comparison/v3/analysis.pt")
dataset_names = list(analysis)
print(dataset_names)

['solar_10_minutes', 'traffic_hourly', 'london_smart_meters']


In [3]:
# Method identifiers available for the Solar dataset, in dict insertion order.
[*analysis["solar_10_minutes"]]

['ski-exact', 'ski', 'kernel-matmul', 'vnngp']

In [4]:
# Size of the run history of the first KernelMatmul run on Solar
# (the figure below treats each history entry as one trial).
first_km_run = analysis["solar_10_minutes"]["kernel-matmul"][0]
len(first_km_run["run_history"])

15245

In [7]:
# Display names for the method / dataset identifiers used as keys in `analysis`.
# Both SKI variants map to "SKI"; only "ski-exact" is plotted below.
human_readable_methods = {
    "vnngp": "VNNGP",
    "ski": "SKI",
    "ski-exact": "SKI",
    "kernel-matmul": "KernelMatmul",
}
human_readable_datasets = {
    "traffic_hourly": "Traffic",
    "solar_10_minutes": "Solar",
    "electricity_hourly": "Electricity",
    "london_smart_meters": "London Smart Meters",
}
# Figure columns (left to right) and box order / colour index within each column.
dataset_ids = ["london_smart_meters", "solar_10_minutes", "traffic_hourly"]
method_ids = ["kernel-matmul", "ski-exact", "vnngp"]

# Wall-clock budget of one run (4 h). Seconds per trial is estimated as
# budget / number of run-history entries, which assumes every run used the
# full budget — NOTE(review): confirm against the experiment configuration.
TIME_BUDGET_SECONDS = 4 * 60 * 60

fig = make_subplots(
    # One column per *plotted* dataset. Sizing this from `analysis` would let the
    # column count drift from `column_titles` and from the loop below.
    cols=len(dataset_ids),
    column_titles=[human_readable_datasets[x] for x in dataset_ids],
    shared_yaxes=True,
)
for col, dataset in enumerate(dataset_ids, start=1):
    for color_idx, method in enumerate(method_ids):
        label = human_readable_methods[method]
        fig.add_trace(
            go.Box(
                name=label,
                legendgroup=label,
                # Only the first column adds legend entries, so each method is listed once.
                showlegend=(col == 1),
                y=[TIME_BUDGET_SECONDS / len(run["run_history"]) for run in analysis[dataset][method]],
                marker_color=COLORS[color_idx],
                boxpoints="all",
            ),
            row=1,
            col=col,
        )
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(type="log")
fig.update_yaxes(title="Seconds per trial", row=1, col=1)
# Interactive preview at default size, before the compact export styling below.
fig.show()

# Compact styling for the exported SVG.
fig.update_layout(width=500, height=150, margin=dict(t=12, l=0, r=0, b=0))
font_size = 9
fig.update_layout(font_size=font_size, legend_font_size=font_size)
fig.update_annotations(font_size=font_size)  # column titles from make_subplots are layout annotations
tickfont_size = int(0.8 * font_size)
fig.update_xaxes(tickfont_size=tickfont_size, title_font_size=font_size)
fig.update_yaxes(tickfont_size=tickfont_size, title_font_size=font_size)
output_dir = Path("rebuttal")
output_dir.mkdir(parents=True, exist_ok=True)  # write_image fails if the directory is missing
fig.write_image(output_dir / "trial_time.svg")

In [6]:
# Pooled seconds-per-trial over every dataset in `analysis`, printed per method as
# "<name>: <mean> [<25th percentile>, <75th percentile>]".
# NOTE(review): the centre value is the mean while the bracket is the interquartile
# range; the median is the conventional partner of an IQR — confirm which is intended.
SECONDS_PER_RUN = 4 * 60 * 60  # assumes each run used the full 4 h budget
for method_id in ["kernel-matmul", "ski-exact", "vnngp"]:
    per_trial = np.array([
        SECONDS_PER_RUN / len(run["run_history"])
        for dataset_id in analysis
        for run in analysis[dataset_id][method_id]
    ])
    lower, upper = np.quantile(per_trial, [0.25, 0.75])
    print(f"{human_readable_methods[method_id]}: {np.mean(per_trial):.2f} [{lower:.2f}, {upper:.2f}]")

KernelMatmul: 2.07 [1.51, 2.63]
SKI: 1.68 [1.37, 2.12]
VNNGP: 17.39 [16.18, 18.51]
