# Kernel Nadaraya-Watson Estimator

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from datetime import datetime
from copy import deepcopy
# utils for plotting
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors as colors
from joblib import Parallel, delayed
import pickle
import numpy as np

from functools import reduce
from operator import concat

from sklearn.metrics import mean_squared_error, accuracy_score

# utils for timing
from goodpoints.tictoc import tic, toc, TicToc
# utils for kernel Nadaraya-Watson estimators
from goodpoints.knw.util_estimators import get_estimator
# utils for evaluating kernels
# from goodpoints.krr.thin.util_k_mmd import kernel_eval, to_regression_kernel
# utils for generating samples from the data distribution
from goodpoints.krr.util_sample import get_Xy, ToyData , get_toy_dataset #, get_housing_dataset
# utils for dataset thinning
# from goodpoints.krr.thin.util_thin import sd_thin, kt_thin2
# # falkon baseline
# from goodpoints.krr.falkon.util_falkon_estimators import KernelRidgeFalkon

In [3]:
### Toy dataset parameters

# Number of training samples.
n = 2**13

# Covariate distribution. Options noted by the author:
#   ['unif', 'unif-[0,1]', 'gauss', 'mog']
X_name = 'unif'
# NOTE: X_var is currently not forwarded to ToyData (the argument is commented out there).
X_var = 1
# Regression function. Options noted by the author:
#   ['sin', 'stair', 'quad', 'sum-gauss', 'sum-laplace', 'step', 'dnc-paper']
f_name = 'sum-gauss'
# Noise level; the paper setting would be np.sqrt(0.2) (variance 1/5).
noise = 0.1

# Input dimension.
d = 1
# Number of anchor points for sum-gauss and sum-laplace.
k = 8

### Regression parameters

# Kernel choice. Earlier settings noted by the author:
#   kernel in ['sobolev', 'gauss', 'laplace'], alpha_pow = -1 (gauss) or -2/3 (sobolev)
kernel = 'epanechnikov'
alpha_pow = -1

# Bandwidth: the optimal rate is n^{-1/(2\beta+d)}, where \beta is the
# smoothness of the regression function; here the exponent is fixed to -1/3.
sigma = np.power(float(n), -1 / 3)
alpha = np.power(float(n), alpha_pow)

### Experiment parameters

n_repeats = 100
save = False
logn_lo, logn_hi = 8, 15
use_cross_validation = False
n_jobs = 1

### Thinning parameters

# Thinned dataset will have size n/2**m (an explicit alternative tried earlier: m=9).
m = None

In [4]:
# Derived settings shared by the rest of the notebook.
task = 'regression'
refit = 'neg_mean_squared_error'
postprocess = None

# Tag used to name saved figures and pickled results.
name_parts = [X_name, f_name, f'k={k}', f'noise={noise}']
filename = '_'.join(name_parts)
print(filename)

# Reference loss drawn as a dashed line in the result plots.
baseline_loss = noise**2
print(f'baseline loss: {baseline_loss}')

unif_sum-gauss_k=8_noise=0.1
baseline loss: 0.010000000000000002


In [5]:
# Noisy data source. X_var and M are left at ToyData's defaults.
toy_data_noise = ToyData(X_name, f_name, noise=noise, d=d, k=k)

# Draw order matters if the sampler is stateful: train, then test, then validation.
X_train, y_train = toy_data_noise.sample(n)
X_test, y_test = toy_data_noise.sample(10000)
# Validation set used for cross validation.
X_val, y_val = toy_data_noise.sample(10000)

# Noise-free copy of the same problem, for plotting purposes only.
toy_data_no_noise = ToyData(X_name, f_name, noise=0, d=d, k=k)
X_no_noise, y_no_noise = toy_data_no_noise.sample(n)

print(X_train.shape, y_train.shape)

(8192, 1) (8192,)


In [6]:
def trace_Xy(X, y, name=None, color=None, alpha=0.5):
    """Build a plotly marker trace for a dataset with 1 or 2 input features.

    Args:
    - X: array whose last axis indexes features (``X.shape[-1]`` is read).
    - y: responses plotted on the vertical axis (1-D input) or z axis (2-D input).
    - name: legend label of the trace.
    - color: marker color passed to plotly.
    - alpha: marker opacity.

    Returns:
    - ``go.Scatter`` for 1 feature, ``go.Scatter3d`` for 2 features, and
      ``None`` (after printing a message) for any other dimension.
    """
    n_features = X.shape[-1]

    # Guard: only 1-D and 2-D inputs can be drawn.
    if n_features not in (1, 2):
        print(f"cannot plot data with dimension {n_features}")
        return None

    if n_features == 1:
        marker_style = dict(color=color)
        return go.Scatter(
            x=X.squeeze(),
            y=y,
            mode='markers',
            name=name,
            opacity=alpha,
            marker=marker_style,
        )

    first_coord, second_coord = X[:, 0], X[:, 1]
    marker_style = dict(symbol='circle', opacity=alpha, color=color, size=2)
    return go.Scatter3d(
        x=first_coord,
        y=second_coord,
        z=y,
        mode='markers',
        name=name,
        marker=marker_style,
    )

In [7]:
# Noisy training sample overlaid with the noise-free sample.
fig = go.Figure()
fig.add_trace(trace_Xy(X_train, y_train, name='train'))
fig.add_trace(trace_Xy(X_no_noise, y_no_noise, name='no noise', color='black'))

fig.update_layout(
    title=f'X={X_name}, f={f_name}, std[f]={np.std(y_train):.4f}',
    height=400,
    width=800,
)
fig.show()

# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_function.png")

## KNW Full

In [8]:
# Nadaraya-Watson estimator of type 'full', using the shared kernel and bandwidth.
full_settings = dict(kernel=kernel, sigma=sigma)
knw_full = get_estimator(task, 'full', **full_settings)

In [9]:
knw_full

In [10]:
knw_full.fit(X_train, y_train)

In [11]:
# Predict on the held-out test set first, then on the training set.
pred_full, train_pred_full = knw_full.predict(X_test), knw_full.predict(X_train)

print('train MSE:', mean_squared_error(y_true=y_train, y_pred=train_pred_full))
print('MSE:', mean_squared_error(y_true=y_test, y_pred=pred_full))

(8192, 10000) (8192, 10000) (10000,) (8192,)
(8192, 8192) (8192, 8192) (8192,) (8192,)
train MSE: 0.009949943155792939
MSE: 0.009962294939537592


In [12]:
# Test responses (faint) overlaid with the full estimator's predictions.
fig = go.Figure()
fig.add_trace(trace_Xy(X_test, y_test, name='test', alpha=0.1))
fig.add_trace(trace_Xy(X_test, pred_full, name='pred', alpha=0.4))

fig.update_layout(title='full', height=400, width=800)
fig.show()

# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_full.png")

## KNW ST

In [13]:
# Nadaraya-Watson estimator of type 'st', using the shared kernel and bandwidth.
st_settings = dict(kernel=kernel, sigma=sigma)
knw_st = get_estimator(task, 'st', **st_settings)

In [14]:
knw_st

In [15]:
knw_st.fit(X_train, y_train)

In [16]:
%%time
# Predict on the held-out test set first, then on the training set.
pred_st, train_pred_st = knw_st.predict(X_test), knw_st.predict(X_train)

print('train MSE:', mean_squared_error(y_true=y_train, y_pred=train_pred_st))
print('MSE:', mean_squared_error(y_true=y_test, y_pred=pred_st))

(128, 10000) (128, 10000) (10000,) (128,)
(128, 8192) (128, 8192) (8192,) (128,)
train MSE: 0.17254992172426314
MSE: 0.17269715544214906
CPU times: user 295 ms, sys: 107 ms, total: 401 ms
Wall time: 84.5 ms



invalid value encountered in divide


invalid value encountered in divide



In [17]:
# Test responses (faint) overlaid with the 'st' estimator's predictions.
fig = go.Figure()
fig.add_trace(trace_Xy(X_test, y_test, name='test', alpha=0.1))
fig.add_trace(trace_Xy(X_test, pred_st, name='pred', alpha=0.4))

fig.update_layout(title='st', height=400, width=800)
fig.show()

# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_st.png")

## KNW KT

In [18]:
# Nadaraya-Watson estimator of type 'kt'.
# Use the shared bandwidth `sigma` (as the 'full' and 'st' estimators do)
# instead of a hard-coded 0.05, so that all three estimators are compared at
# the same bandwidth and the value keeps tracking `n` when `n` is changed.
knw_kt = get_estimator(
    task,
    'kt',
    kernel=kernel,
    sigma=sigma,
)

In [19]:
knw_kt

In [20]:
knw_kt.fit(X_train, y_train)

: 

: 

In [55]:
y_train.max()

2.523477220662887

In [56]:
%%time
pred_kt = knw_kt.predict(X_test)
train_pred_kt = knw_kt.predict(X_train)


(128, 10000) (128, 10000) (10000,) (128,)
(128, 8192) (128, 8192) (8192,) (128,)
CPU times: user 110 ms, sys: 224 ms, total: 335 ms
Wall time: 69 ms


In [45]:
print('train MSE:', mean_squared_error(y_train, train_pred_kt))
print('MSE:', mean_squared_error(y_test, pred_kt))

train MSE: 0.2632339879878194
MSE: 0.26370414829270317


In [46]:
# Test responses (faint) overlaid with the 'kt' estimator's predictions.
fig = go.Figure()
fig.add_trace(trace_Xy(X_test, y_test, name='test', alpha=0.1))
fig.add_trace(trace_Xy(X_test, pred_kt, name='pred', alpha=0.4))

fig.update_layout(title='kt', height=400, width=800)
fig.show()

# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_kt.png")

## Experiment

In [19]:
from sklearn.model_selection import GridSearchCV, RepeatedKFold

In [20]:
varying_variable = 'kernel'
varying_variable_values = ['gauss', 'laplace', 'singular', 'box']

In [21]:
# Parameter grids handed to the grid search (only the bandwidth is searched).
default_param_grid = {"sigma": [sigma]}
# Grid for the falkon baseline, whose model config is currently disabled.
falkon_param_grid = {"sigma": [sigma]}

# Constructor keyword arguments per model type; the thinned variants
# additionally receive the thinning parameter `m`.
_kwargs_by_model = {
    'full': {'postprocess': postprocess},
    'st': {'m': m, 'postprocess': postprocess},
    'kt': {'m': m, 'postprocess': postprocess},
}

# Model constructors and data sizes for each model.
# Every model carries its own 'logn' array so that different data sizes can be
# used per model (e.g. to avoid running the full estimator on large datasets).
model_configs = {
    model_key: {
        'logn': np.arange(logn_lo, logn_hi, 2),
        'kwargs': model_kwargs,
        'param_grid': default_param_grid,
    }
    for model_key, model_kwargs in _kwargs_by_model.items()
}

# cv = RepeatedKFold(n_repeats=n_repeats, n_splits=k_fold)

In [22]:
model_configs

{'full': {'logn': array([ 8, 10, 12, 14]),
  'kwargs': {'postprocess': None},
  'param_grid': {'sigma': [0.25]}},
 'st': {'logn': array([ 8, 10, 12, 14]),
  'kwargs': {'m': None, 'postprocess': None},
  'param_grid': {'sigma': [0.25]}},
 'kt': {'logn': array([ 8, 10, 12, 14]),
  'kwargs': {'m': None, 'postprocess': None},
  'param_grid': {'sigma': [0.25]}}}

In [23]:
# Run experiment (depending on experiment_type)
#
# For every (model type, log2 sample size, varying-variable value) combination:
#   1. sample a toy dataset of size 2**logn,
#   2. pick hyper-parameters (grid search, or the fixed `sigma` when
#      `use_cross_validation` is False),
#   3. fit/predict `trials` times and record the test score of each trial.
# One dict per combination is appended to `results`.

results = []

# Running count of evaluated configurations; only used in the log line below.
i = 0
for name, config in model_configs.items():
    for logn in config['logn']:

        for v in varying_variable_values:
            # note: haven't figured out KT for singular kernels yet
            if name in ['kt',] and v in ['singular', 'box']:
                continue

            kwargs = deepcopy(config['kwargs'])
            kwargs[varying_variable] = v
            model_name = f"{name}_{v}"
            # 'full' is fit once; the other models are repeated n_repeats times
            # (presumably because thinning is randomized -- see the note below).
            trials = (1 if name in ['full'] else n_repeats)

            X, y = get_toy_dataset(
                X_name=X_name,
                f_name=f_name,
                n=2**logn,
                # X_var=X_var,
                d=kwargs['d'] if 'd' in kwargs else d,
                noise=noise,
                # M=kwargs['M'] if 'M' in kwargs else M,
                k=k,
            )

            # Set kernel, alpha, sigma params
            if 'kernel' not in kwargs:
                kwargs['kernel'] = kernel

            # if name in ['st', 'kt']:
            #     # NOTE: I think you need to set alpha to be proportional to sqrt(n)
            #     kwargs['alpha'] /= np.power(2**logn, 1/4)

            model = get_estimator(task, name=name, **kwargs)
            if model is None:
                continue
            print(f'i={i+1}: logn={logn}, model={model}')

            # STEP 2: Get optimal parameters through grid search
            # NOTE: we do something slightly better than k-fold cross validation.
            # Namely, we are trying to get rid of randomness in the Kernel Thinning (or Standard Thinning) routine,
            # but if we did 100-fold CV, then the validation set would be 1% of the data
            # (which is too small to get a good estimate of the validation score).
            # Instead we use the same train-val split for each parameter setting and repeat `trials` times
            if use_cross_validation:
                X_concat, y_concat = np.concatenate([X, X_val]), np.concatenate([y, y_val])
                split = [(np.arange(len(X)), np.arange(len(X), len(X)+len(X_val))) for _ in range(trials)]
                grid_search = GridSearchCV(
                    estimator=model,
                    param_grid=config['param_grid'],
                    return_train_score=True,
                    cv=split,
                    scoring=refit,
                    refit=False,
                    n_jobs=n_jobs,
                ).fit(X_concat, y_concat)
                # get validation scores
                cv_results = pd.DataFrame(grid_search.cv_results_)
                best_index = grid_search.best_index_
                # A comprehension keeps the trial index local, so the outer
                # run counter `i` is no longer overwritten here.
                val_scores = [
                    cv_results.iloc[best_index][f'split{trial}_test_score']
                    for trial in range(trials)
                ]

                # get optimal parameters
                best_params = grid_search.best_params_
            else:
                # Dummy values
                cv_results, best_index = None, 0
                val_scores = [1,] * trials

                best_params = {
                    'sigma' : sigma,
                    # 'alpha' : alpha, # * (len(X_train)**(1/4) if name in ['st', 'kt'] else 1),
                }

            print(f"best params: {best_params}")
            best_model = get_estimator(task, name=name,
                                       sigma=best_params['sigma'],
                                    #    alpha=best_params['alpha'],
                                       **kwargs)

            mean_scores = []
            for _ in range(trials):
                best_model.fit(X, y)

                # compute test score
                test_pred = best_model.predict(X_test).squeeze()

                if refit == 'neg_mean_squared_error':
                    test_scores = mean_squared_error(y_test, test_pred)
                elif refit == 'accuracy':
                    test_scores = accuracy_score(y_test, test_pred)
                    # test_scores = [accuracy_score([y], [pred]) for y, pred in zip(y_test, test_pred)]
                else:
                    raise ValueError(f"invalid refit metric: {refit}")

                mean_scores.append( np.mean(test_scores) )
                # std_scores.append( np.std(test_scores) / np.sqrt(len(test_scores)-1) ) # biased estimator of std

            results.append({
                "logn": logn,
                "model": model_name,
                "cv_results": cv_results,
                "best_index_" : best_index,
                "mean_scores" : mean_scores,
                # "std_score" : np.std(mean_scores),
            })

            i += 1

sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=1: logn=8, model=KernelNadarayaWatsonRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=2: logn=8, model=KernelNadarayaWatsonRegressor()
best params: {'sigma': 0.25}
i=3: logn=8, model=KernelNadarayaWatsonRegressor(kernel='singular')
best params: {'sigma': 0.25}
i=4: logn=8, model=KernelNadarayaWatsonRegressor(kernel='box')
best params: {'sigma': 0.25}
sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=5: logn=10, model=KernelNadarayaWatsonRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=6: logn=10, model=KernelNadarayaWatsonRegressor()
best params: {'sigma': 0.25}
i=7: logn=10, model=KernelNadarayaWatsonRegressor(kernel='singular')
best params: {'sigma': 0.25}
i=8: logn=10, model=KernelNadarayaWatsonRegressor(kernel='box')
best params: {'sigma': 0.25}
sampling dataset with params ToyData(X_name=unif, f_name=sum-gau


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid val

i=20: logn=8, model=KernelNadarayaWatsonSTRegressor(kernel='box')
best params: {'sigma': 0.25}



invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid val

i=21: logn=10, model=KernelNadarayaWatsonSTRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=22: logn=10, model=KernelNadarayaWatsonSTRegressor()
best params: {'sigma': 0.25}
i=23: logn=10, model=KernelNadarayaWatsonSTRegressor(kernel='singular')
best params: {'sigma': 0.25}



invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid val

i=24: logn=10, model=KernelNadarayaWatsonSTRegressor(kernel='box')
best params: {'sigma': 0.25}



invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide



i=25: logn=12, model=KernelNadarayaWatsonSTRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=26: logn=12, model=KernelNadarayaWatsonSTRegressor()
best params: {'sigma': 0.25}
i=27: logn=12, model=KernelNadarayaWatsonSTRegressor(kernel='singular')
best params: {'sigma': 0.25}



invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide



i=28: logn=12, model=KernelNadarayaWatsonSTRegressor(kernel='box')
best params: {'sigma': 0.25}



invalid value encountered in divide



i=29: logn=14, model=KernelNadarayaWatsonSTRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=30: logn=14, model=KernelNadarayaWatsonSTRegressor()
best params: {'sigma': 0.25}
i=31: logn=14, model=KernelNadarayaWatsonSTRegressor(kernel='singular')
best params: {'sigma': 0.25}
i=32: logn=14, model=KernelNadarayaWatsonSTRegressor(kernel='box')
best params: {'sigma': 0.25}
i=33: logn=8, model=KernelNadarayaWatsonKTRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=34: logn=8, model=KernelNadarayaWatsonKTRegressor()
best params: {'sigma': 0.25}
i=35: logn=10, model=KernelNadarayaWatsonKTRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=36: logn=10, model=KernelNadarayaWatsonKTRegressor()
best params: {'sigma': 0.25}
i=37: logn=12, model=KernelNadarayaWatsonKTRegressor(kernel='gauss')
best params: {'sigma': 0.25}
i=38: logn=12, model=KernelNadarayaWatsonKTRegressor()
best params: {'sigma': 0.25}
i=39: logn=14, model=KernelNadarayaWatsonKTRegressor(kernel='gauss')
best par

In [24]:
results

[{'logn': 8,
  'model': 'full_gauss',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.030337317646805553]},
 {'logn': 8,
  'model': 'full_laplace',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.047020330298266536]},
 {'logn': 8,
  'model': 'full_singular',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.01138289734971634]},
 {'logn': 8,
  'model': 'full_box',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.011438852380664473]},
 {'logn': 10,
  'model': 'full_gauss',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.030970756159751714]},
 {'logn': 10,
  'model': 'full_laplace',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.04818884096863406]},
 {'logn': 10,
  'model': 'full_singular',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010533471802844573]},
 {'logn': 10,
  'model': 'full_box',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.011136505463488815]},
 {'logn': 12,
  

In [25]:
# Persist the experiment results with pickle (only when saving is enabled).
if save:
    import pickle

    pickle_file = f'{filename}.p'
    print(pickle_file)

    with open(pickle_file, 'wb') as f:
        pickle.dump(results, f)

## Plot results

In [26]:
from functools import reduce
from operator import concat

### Test scores

In [27]:
# One subplot row per score scale, one column per varying-variable value.
row_subplot_titles = ["Test Score vs n", "Log2 Test Score vs n"]

fig = make_subplots(
    rows=len(row_subplot_titles),
    cols=len(varying_variable_values),
    shared_yaxes=True,
    # Titles are listed row by row; each row repeats the per-column labels.
    subplot_titles=[
        f'{varying_variable}={v}'
        for _ in row_subplot_titles
        for v in varying_variable_values
    ],
    vertical_spacing=0.1,
)

# One color per model type; the palette is repeated so it is always long enough.
model_names = list(model_configs)
n_palette_copies = len(model_names) // len(colors.qualitative.Plotly) + 1
colors_list = colors.qualitative.Plotly * n_palette_copies
# Colors already shown in the legend (filled in while plotting).
colors_used = set()

In [28]:
def plot_vs_n(print_name, attr_name, vvv, r, c, is_better='higher', scale='log2'):
    """Add grouped box plots of test scores versus log2(n) to one subplot.

    Reads the module-level `results`, `model_names`, `colors_list`,
    `colors_used` and `baseline_loss`, and mutates the module-level `fig`
    and `colors_used`.

    Args:
    - print_name: label used in the y-axis title.
    - attr_name: unused; kept so existing callers keep working.
    - vvv: varying variable value; only results whose model name is
      '<model>_<vvv>' are plotted.
    - r: 1-based subplot row.
    - c: 1-based subplot column.
    - is_better: text ('higher' or 'lower') shown in the y-axis title.
    - scale: 'log2' to plot log2(|score|), 'linear' to plot |score|.

    Raises:
    - ValueError: if `scale` is neither 'log2' nor 'linear'.
    """
    # Validate up front: previously an unknown scale only failed later with an
    # unbound-variable error, after traces had already been added to the figure.
    if scale not in ('log2', 'linear'):
        raise ValueError(f"invalid scale: {scale!r} (expected 'log2' or 'linear')")

    def rescale(values):
        # Scores may be negated losses, hence the absolute value.
        magnitudes = np.abs(values)
        return np.log2(magnitudes) if scale == 'log2' else magnitudes

    for result in results:
        # E.g., kt_rbf -> (kt, rbf). Split only on the first underscore so
        # values that themselves contain '_' do not break the unpacking.
        model_name_prefix, vv_name = result["model"].split('_', 1)

        # only select results with the correct varying variable value
        if vv_name != vvv:
            continue

        color = colors_list[model_names.index(model_name_prefix)]
        scores = result["mean_scores"]

        trace = go.Box(
            x=[result['logn']] * len(scores),
            y=rescale(scores),
            name=model_name_prefix,
            legendgroup=model_name_prefix,
            line_color=color,
            offsetgroup=model_name_prefix,
            # show each model only once in the legend
            showlegend=color not in colors_used,
            boxmean=True,
        )

        fig.add_trace(trace, row=r, col=c)
        colors_used.add(color)

    # add line for baseline loss
    fig.add_hline(
        y=rescale(baseline_loss),
        row=r, col=c, line_dash="dash",
    )

    if c == 1:
        fig.update_yaxes(title_text=f"{scale}({print_name}) - {is_better} is better", row=r, col=c)
    fig.update_xaxes(title_text="log2(n)", type='linear', row=r, col=c)
    fig.update_yaxes(type='linear', row=r, col=c)
    fig.update_layout(boxmode='group')

def plot_test_score_vs_n(vvv, r, c, scale):
    """Plot the test MSE (lower is better) for value `vvv` in subplot (r, c)."""
    plot_vs_n(
        "Test MSE",
        "score",
        vvv,
        r,
        c,
        is_better='lower',
        scale=scale,
    )

# def plot_val_score_vs_n(vvv, r, c):
#     plot_vs_n("Val MSE score", "test_score", vvv, r, c, is_better='lower')

# def plot_train_time_vs_n(vvv, r, c):
#     plot_vs_n("Train time", "fit_time", vvv, r, c, is_better='lower')

# def plot_test_time_vs_n(vvv, r, c):
#     plot_vs_n("Test time", "score_time", vvv, r, c, is_better='lower')

In [29]:
# Column per varying-variable value; row 1 is the linear scale, row 2 the log2 scale.
# (Validation-score and timing rows are not plotted.)
for col, vvv in enumerate(varying_variable_values, start=1):
    for row, axis_scale in enumerate(('linear', 'log2'), start=1):
        plot_test_score_vs_n(str(vvv), row, col, scale=axis_scale)


In [30]:
fig.update_layout(
    # legend=dict(traceorder="normal", borderwidth=1),
    width=1200,
    height=800,
    showlegend=True,
    title=f'f={f_name}(k={k}, noise={noise})',
)

In [None]:
# Optionally write the results figure to disk.
if save:
    fig_file = f'figures/{filename}_results.png'
    print(fig_file)
    fig.write_image(fig_file)