# Comparison

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_is_fitted
import numpy as np
from numpy.linalg import LinAlgError
# set global seed
# np.random.seed(123)
from sklearn.model_selection import GridSearchCV, RepeatedKFold
from sklearn.metrics import r2_score, accuracy_score

import pandas as pd
from datetime import datetime
from copy import deepcopy
# utils for plotting
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors as colors
from joblib import Parallel, delayed
import pickle


In [7]:

from functools import reduce
from operator import concat

# # utils for timing
# from goodpoints.tictoc import tic, toc, TicToc
# # utils for kernel ridge regression
# from goodpoints.krr.util_estimators import get_estimator, get_sigma_heuristic
# # utils for evaluating kernels
# from goodpoints.krr.thin.util_k_mmd import kernel_eval, to_regression_kernel
# # utils for generate samples from the data distribution
# from goodpoints.krr.util_sample import get_Xy, ToyData , get_toy_dataset #, get_housing_dataset
# # utils for dataset thinning
# from goodpoints.krr.thin.util_thin import sd_thin, kt_thin2
# # # falkon baseline
# # from goodpoints.krr.falkon.util_falkon_estimators import KernelRidgeFalkon

In [41]:
# with open('results/unif_sum-gauss_k=8_alpha=0.0001220703125_sigma=0.25_method=st.pkl', 'rb') as f:
#     st_results = pickle.load(f)
# with open('results/unif_sum-gauss_k=8_alpha=0.0001220703125_sigma=0.25_method=kt.pkl', 'rb') as f:
#     kt_results = pickle.load(f)
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only load result files that were produced by this project.
with open('results/unif_sum-gauss_k=8_noise=0.1.pkl', 'rb') as f:
    results = pickle.load(f)

# Model names have the form '<method>_<m>_<variant>' (e.g. 'st_0_under'), which is
# how the plotting helpers in this notebook parse them. Match on the method prefix
# instead of on a substring, so that an 'st'/'kt' occurring elsewhere in a name
# cannot put a run into the wrong group.
st_results = [result for result in results if result['model'].split('_')[0] == 'st']
kt_results = [result for result in results if result['model'].split('_')[0] == 'kt']

In [42]:
# Inspect the ST subset of the loaded results (a list of per-run dicts).
# NOTE(review): this displays the whole list; consider slicing it to keep the output short.
st_results

[{'logn': 8,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010633963918066066]},
 {'logn': 8,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010633963918066066]},
 {'logn': 9,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010515464164740772]},
 {'logn': 9,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010515464164740772]},
 {'logn': 10,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010431599535523589]},
 {'logn': 10,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010431599535523589]},
 {'logn': 11,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010299006020484204]},
 {'logn': 11,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010299006020484204]},
 {'l

In [43]:
# Keep only the runs whose model name carries the '_0_' tag (e.g. 'st_0_under').
st_0_results = list(filter(lambda run: '_0_' in run['model'], st_results))
kt_0_results = list(filter(lambda run: '_0_' in run['model'], kt_results))

In [44]:
# Inspect the ST runs selected by the '_0_' filter.
st_0_results

[{'logn': 8,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010633963918066066]},
 {'logn': 8,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010633963918066066]},
 {'logn': 9,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010515464164740772]},
 {'logn': 9,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010515464164740772]},
 {'logn': 10,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010431599535523589]},
 {'logn': 10,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010431599535523589]},
 {'logn': 11,
  'model': 'st_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010299006020484204]},
 {'logn': 11,
  'model': 'st_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010299006020484204]},
 {'l

In [45]:
# Inspect the KT runs selected by the '_0_' filter.
kt_0_results

[{'logn': 8,
  'model': 'kt_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010633963918066066]},
 {'logn': 8,
  'model': 'kt_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010633963918066066]},
 {'logn': 9,
  'model': 'kt_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010515464164740772]},
 {'logn': 9,
  'model': 'kt_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010515464164740772]},
 {'logn': 10,
  'model': 'kt_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010431599535523589]},
 {'logn': 10,
  'model': 'kt_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010431599535523589]},
 {'logn': 11,
  'model': 'kt_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010299006020484204]},
 {'logn': 11,
  'model': 'kt_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010299006020484204]},
 {'l

In [46]:
# Bounds on log2(n) used by the plotting helpers: logn_lo is the first value and
# logn_hi is passed to range() as the (exclusive) stop.
logn_lo, logn_hi = 8, 14

# Reference loss: subtracted from the scores to form the excess risk and drawn as
# a horizontal line on the non-excess scales.
# NOTE(review): presumably the noise variance of the noise=0.1 results file - confirm.
baseline_loss = 0.1 ** 2

# Factor compared across subplot columns, and the values it takes.
varying_variable = 'alpha'
varying_variable_values = ['under', 'not-under']

## Figure 1

In [50]:
# Only the number of entries is used here: it fixes the number of subplot rows.
# NOTE(review): the label says "vs m" although Figure 1 is plotted against n - confirm.
row_subplot_titles = ['Log10 Excess Risk vs m']

# One title per subplot, row by row: each row repeats the per-column labels.
column_titles = [
    f'{varying_variable}={value}'
    for _ in row_subplot_titles
    for value in varying_variable_values
]

fig = make_subplots(
    rows=len(row_subplot_titles),
    cols=len(varying_variable_values),
    shared_yaxes=True,
    subplot_titles=column_titles,
    vertical_spacing=0.1,
)

# Palette repeated four times; plot_vs_n indexes it with int(m)//2.
colors_list = colors.qualitative.Plotly * 4
# Colours that already own a legend entry (updated by the plotting helper).
colors_used = set()

In [51]:
def plot_vs_n(fig, results, print_name, vvv, r, c, is_better='higher', scale='log2', dash='solid'):
    """
    Plot Figure 1 from DnC paper
    - x-axis: log2(n)
    - y-axis: test score (or excess risk), transformed according to `scale`
    - different boxes/colours correspond to different values of m

    Reads the notebook globals `logn_hi`, `baseline_loss`, `colors_list` and
    `colors_used`; `colors_used` is updated in place so that each colour gets a
    legend entry only once.

    Args:
    - fig: figure to which the box traces are added
    - results: iterable of dicts with keys 'model' (of the form
      '<name>_<m>_<variant>'), 'logn' and 'mean_scores'
    - print_name: label used in the y-axis title
    - vvv: varying variable value; only results whose variant equals it are used
    - r, c: subplot row and column passed on to the figure
    - is_better: text used in the y-axis title
    - scale: one of 'log2', 'log10', 'linear', 'excess-log10'
    - dash: kept for backward compatibility; currently unused

    Returns:
    - dict mapping m (int) to a list of (logn, mean(y), std(y)) tuples, one per
      matching result. Every even m below `logn_hi` is present, possibly with an
      empty list.

    Raises:
    - ValueError: if `scale` is not one of the supported values.
    """
    baseline = np.abs(baseline_loss)

    # The transform and the baseline line depend only on `scale`, so resolve them
    # once up front. This also keeps `hline` defined when no result matches.
    if scale == 'log2':
        transform, hline = np.log2, np.log2(baseline)
    elif scale == 'log10':
        transform, hline = np.log10, np.log10(baseline)
    elif scale == 'linear':
        transform, hline = (lambda scores: scores), baseline
    elif scale == 'excess-log10':
        # Scores at or below the baseline yield -inf/NaN here (unchanged behaviour).
        transform, hline = (lambda scores: np.log10(scores - baseline)), 0
    else:
        raise ValueError(f"unsupported scale: {scale!r}")

    result_m = {m: [] for m in range(0, logn_hi, 2)}

    for result in results:
        _, m, vv_name = result["model"].split('_')

        # only select results with the correct varying variable value
        if vv_name != vvv:
            continue

        scores = result["mean_scores"]
        # wrap around the palette rather than failing on a large m
        color = colors_list[(int(m)//2) % len(colors_list)]
        y = transform(np.abs(scores))

        trace = go.Box(
            x=[result['logn']]*len(scores),
            y=y,
            name=f'm={m}',
            legendgroup=str(m),
            line_color=color,
            offsetgroup=str(m),
            showlegend=color not in colors_used,
            boxmean=True,
        )

        fig.add_trace(trace, row=r, col=c)
        colors_used.add(color)

        # setdefault: tolerate values of m that were not pre-populated above
        result_m.setdefault(int(m), []).append((result['logn'], np.mean(y), np.std(y)))

    if 'excess' not in scale:
        # add line for baseline loss
        fig.add_hline(
            y=hline,
            row=r, col=c, line_dash="dash",
        )

    if c == 1:
        fig.update_yaxes(title_text=f"{scale}({print_name}) - {is_better} is better", row=r, col=c)
    fig.update_xaxes(title_text="log2(n)", type='linear', row=r, col=c)
    fig.update_yaxes(type='linear', row=r, col=c)
    fig.update_layout(boxmode='group')

    return result_m

In [53]:
# Collect, for each value of the varying variable, the summaries returned by
# plot_vs_n: {m: [(logn, mean, std), ...]}. The box traces plot_vs_n adds to the
# current `fig` are a by-product; that figure is replaced below without being shown.
st_result_m = {}
kt_result_m = {}
for c, vvv in enumerate(varying_variable_values):
    st_result_m[vvv] = plot_vs_n(fig, st_results, 'st', str(vvv), 1, c+1, scale='excess-log10', dash='solid')
    kt_result_m[vvv] = plot_vs_n(fig, kt_results, 'kt', str(vvv), 1, c+1, scale='excess-log10', dash='dot')

fig = make_subplots(
    rows=1,
    cols=len(varying_variable_values),
    shared_yaxes=True,
)

for c, vvv in enumerate(varying_variable_values):
    # ST curves are drawn solid and KT curves dotted; the colour encodes m.
    for summaries, dash in ((st_result_m[vvv], 'solid'), (kt_result_m[vvv], 'dot')):
        for m, points in summaries.items():
            if not points:
                # plot_vs_n pre-populates every candidate m; an empty series
                # cannot be unpacked by zip(*...) and has nothing to draw.
                continue

            # order by logn so the connecting line runs left to right
            x, y, y_std = zip(*sorted(points, key=lambda point: point[0]))
            fig.add_trace(go.Scatter(
                x=x, y=y,
                mode='lines+markers',
                name=f'm={m}',
                error_y=dict(type='data', array=y_std, visible=True),
                legendgroup=str(m),
                line=dict(
                    width=1,
                    color=colors_list[m//2],
                    dash=dash,
                ),
            ), row=1, col=c+1)

# update x axis title
fig.update_xaxes(title_text="logn")
fig.update_yaxes(title_text="log10(Excess Risk)", row=1, col=1)
fig.show()

## Figure 2

In [61]:
# A single row: log10 excess risk against m.
row_subplot_titles = ['Log10 Excess Risk vs m']

fig = make_subplots(
    rows=len(row_subplot_titles),
    cols=1,
    shared_yaxes=True,
    subplot_titles=row_subplot_titles,
    vertical_spacing=0.1,
)

# Repeat the qualitative palette often enough to cover (logn_hi - logn_lo) entries;
# plot_vs_m indexes it with logn - logn_lo.
base_palette = colors.qualitative.Plotly
n_repeats = (logn_hi - logn_lo) // len(base_palette) + 1
colors_list = base_palette * n_repeats
# Colours that already own a legend entry (updated by the plotting helper).
colors_used = set()

In [62]:
def plot_vs_m(fig, results, vvv, r, c, is_better='higher', scale='log2'):
    """
    Plot Figure 2 from DnC paper
    - x-axis: m
    - y-axis: test score (or excess risk), transformed according to `scale`
    - different boxes/colours correspond to different values of logn

    Reads the notebook globals `logn_lo`, `logn_hi`, `baseline_loss`,
    `colors_list` and `colors_used`; `colors_used` is updated in place so that
    each colour gets a legend entry only once.

    Args:
    - fig: figure to which the box traces are added
    - results: iterable of dicts with keys 'model' (of the form
      '<name>_<m>_<variant>'), 'logn' and 'mean_scores'
    - vvv: varying variable value; only results whose variant equals it are used
    - r, c: subplot row and column passed on to the figure
    - is_better: kept for backward compatibility; currently unused
    - scale: one of 'log2', 'log10', 'linear', 'excess-log10'

    Returns:
    - dict mapping logn to a list of (m/logn, mean(y), std(y)) tuples, one per
      matching result. Every logn in range(logn_lo, logn_hi) is present,
      possibly with an empty list.

    Raises:
    - ValueError: if `scale` is not one of the supported values.
    """
    baseline = np.abs(baseline_loss)

    # The transform and the baseline line depend only on `scale`, so resolve them
    # once up front. This also keeps `hline` defined when no result matches.
    if scale == 'log2':
        transform, hline = np.log2, np.log2(baseline)
    elif scale == 'log10':
        transform, hline = np.log10, np.log10(baseline)
    elif scale == 'excess-log10':
        # Scores at or below the baseline yield -inf/NaN here (unchanged behaviour).
        transform, hline = (lambda scores: np.log10(scores - baseline)), 0
    elif scale == 'linear':
        transform, hline = (lambda scores: scores), baseline
    else:
        raise ValueError(f"unsupported scale: {scale!r}")

    result_logn = {logn: [] for logn in range(logn_lo, logn_hi)}

    for result in results:
        _, m, vv_name = result["model"].split('_')

        # only select results with the correct varying variable value
        if vv_name != vvv:
            continue

        logn = result['logn']
        scores = result["mean_scores"]
        # wrap around the palette rather than failing on a large logn
        color = colors_list[(logn - logn_lo) % len(colors_list)]
        y = transform(np.abs(scores))

        trace = go.Box(
            x=[m,]*len(scores),
            y=y,
            name=f'logn={logn}',
            legendgroup=str(logn),
            line_color=color,
            offsetgroup=str(logn),
            showlegend=color not in colors_used,
            boxmean=True,
        )

        fig.add_trace(trace, row=r, col=c)
        colors_used.add(color)

        # setdefault: tolerate logn values outside range(logn_lo, logn_hi);
        # the summary x-coordinate is m normalised by logn
        result_logn.setdefault(logn, []).append((int(m)/logn, np.mean(y), np.std(y)))

    # add line for baseline loss
    if 'excess' not in scale:
        fig.add_hline(
            y=hline,
            row=r, col=c, line_dash="dash",
        )

    fig.update_xaxes(title_text="m", type='linear', row=r, col=c)
    fig.update_yaxes(type='linear', row=r, col=c)
    fig.update_layout(boxmode='group')

    return result_logn

In [63]:
# plot_vs_m returns {logn: [(m/logn, mean, std), ...]}. The box traces it adds to
# the current `fig` are a by-product; that figure is replaced below without being shown.
for c, vvv in enumerate(['under',]):
    st_result_logn = plot_vs_m(fig, st_results, str(vvv), 1, c+1, scale='excess-log10')
    kt_result_logn = plot_vs_m(fig, kt_results, str(vvv), 1, c+1, scale='excess-log10')

fig = go.Figure()
# ST curves are drawn solid and KT curves dotted; the colour encodes logn.
for summaries, dash in ((st_result_logn, 'solid'), (kt_result_logn, 'dot')):
    for logn, points in summaries.items():
        if not points:
            # plot_vs_m pre-populates every logn in range(logn_lo, logn_hi); an
            # empty series cannot be unpacked by zip(*...) and has nothing to draw.
            continue

        # order by m/logn so the connecting line runs left to right
        x, y, y_std = zip(*sorted(points, key=lambda point: point[0]))
        fig.add_trace(go.Scatter(
            x=x,
            y=y,
            mode='lines+markers',
            name=f'N={2**logn}',
            error_y=dict(
                type='data',
                array=y_std,
                visible=True
            ),
            line=dict(
                width=1,
                color=colors_list[logn - logn_lo],
                dash=dash,
            ),
            legendgroup=str(logn),
        ))

# update x axis title
fig.update_xaxes(title_text="m/logn")
fig.update_yaxes(title_text="log10(Excess Risk)")
fig.show()