For the paper, we will probably want to use matplotlib instead of plotly.

In [1]:
import os
import pickle
import numpy as np
from numpy.linalg import LinAlgError

# utils for plotting
import plotly.graph_objects as go
import plotly.colors as colors

In [2]:
# --- I/O locations --------------------------------------------------------
# Pickled score files are read from `score_dir`; figures are written to
# `fig_dir`.
score_dir, fig_dir = "output/scores", "output/figures"

# --- experiment identifiers (all of these are embedded in the file names) --
# First component of the score file name.
# NOTE(review): presumably the thinning method -- confirm.
thin = "kt"

# Regression method and kernel: kernel ridge regression with a gaussian kernel.
# Alternative setting (Nadaraya-Watson): method = 'nw', kernel = 'epanechnikov'
method, kernel = "krr", "gaussian"

ground_truth = "sum_gauss"

# Bounds of the `logn` range named in the score/figure file names.
# NOTE(review): presumably log2 of the sample size n -- confirm.
logn_lo, logn_hi = 8, 14

# Trial count named in the score/figure file names.
n_trials = 100

In [3]:
# Legend label for each ablation variant. Position i corresponds to the score
# file with suffix "-ablation{i}" (variant 0 has no suffix).
# An earlier label for variant 0 was "k(x1,x2) * (1+ y1*y2)".
names = ["k^2(x1,x2) + k(x1,x2) * y1*y2", "k((x1,x2), (y1,y2))", "k(x1,x2)"]

# Flat list of result records from all variants, each tagged with its label.
results = []

# Iterate over the labels themselves so that adding a variant to `names`
# automatically loads its score file as well.
for a, name in enumerate(names):
    ablation_str = f"-ablation{a}" if a > 0 else ""
    score_file = os.path.join(
        score_dir,
        f"{thin}-{method}-k={kernel}-gt={ground_truth}-logn={logn_lo}_{logn_hi}-t{n_trials}{ablation_str}.pkl"
    )
    # NOTE: unpickling can execute arbitrary code; only load score files that
    # were produced by this project.
    with open(score_file, 'rb') as f:
        result = pickle.load(f)

    # Tag every record with its variant label; the plotting code groups and
    # colours the boxes by this 'name' key.
    for r in result:
        r['name'] = name
        results.append(r)

In [4]:
# Sanity check: total number of result records collected in `results`.
len(results)

12

## Test MSE

In [5]:
# Empty figure; the box traces are added to it by the plotting cell.
fig = go.Figure()

# One colour per entry of `names` (looked up later via names.index(name)).
# The palette is repeated often enough to cover every variant, so adding
# variants beyond the palette length cannot cause an IndexError.
palette = colors.qualitative.Plotly
colors_list = palette * (len(names) // len(palette) + 1)

# Colours that already have a legend entry; used so that each variant is
# listed in the legend only once even though it contributes several traces.
colors_used = set()

In [6]:
# Box plot of the test MSE per variant, one box per (variant, logn) pair.
# NOTE: this cell adds traces to the existing `fig`; re-running it without
# re-creating `fig` and `colors_used` duplicates the traces.
scale = 'log2'
baseline_loss = 0.01 # = noise**2


def _to_scale(values):
    """Map raw scores onto the y-axis scale selected by `scale`."""
    magnitude = np.abs(values)
    if scale == 'log2':
        return np.log2(magnitude)
    if scale == 'linear':
        return magnitude
    # Fail loudly instead of silently reusing stale values from a previous run.
    raise ValueError(f"unsupported scale {scale!r}; expected 'log2' or 'linear'")


# The baseline does not depend on the individual results, so compute it once.
# This also keeps it defined when `results` is empty.
hline = _to_scale(baseline_loss)

for result in results:
    name = result['name']
    color = colors_list[names.index(name)]
    y = _to_scale(result["scores"])

    trace = go.Box(
        # every score of this record shares the same x position
        x=[result['logn']]*len(result["scores"]),
        y=y,
        name=name,
        # opacity=0.5,
        legendgroup=name,
        line_color=color,
        offsetgroup=name,
        # only the first trace of each variant gets a legend entry
        showlegend=color not in colors_used,
        boxmean=True,
    )

    fig.add_trace(trace)
    colors_used.add(color)

# add line for baseline loss
fig.add_hline(
    y=hline,
    line_dash="dash",
)

fig.update_yaxes(title_text=f"{scale}(test MSE)")
fig.update_xaxes(title_text="log2(n)", type='linear')
fig.update_layout(
    width=800,
    height=600,
    title=f"Test MSE vs n (kernel={kernel}, ground_truth={ground_truth})",
    boxmode='group'
)

In [7]:
# save fig to file
# Create the output directory first so that a fresh checkout does not fail
# on a missing 'output/figures' folder.
os.makedirs(fig_dir, exist_ok=True)
fig_file = os.path.join(
    fig_dir,
    f"ablation-{method}-k={kernel}-gt={ground_truth}-logn={logn_lo}_{logn_hi}-t{n_trials}.png"
)
# Static image export; relies on plotly's image export engine being installed.
fig.write_image(fig_file)