# Kernel Nadaraya-Watson Estimator

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from datetime import datetime
from copy import deepcopy
# utils for plotting
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors as colors
from joblib import Parallel, delayed
import pickle
import numpy as np

from functools import reduce
from operator import concat

from sklearn.metrics import mean_squared_error, accuracy_score

# utils for timing
from goodpoints.tictoc import tic, toc, TicToc
# utils for kernel Nadaraya-Watson estimators
from goodpoints.knw.util_estimators import get_estimator
# utils for evaluating kernels
# from goodpoints.krr.thin.util_k_mmd import kernel_eval, to_regression_kernel
# utils for generating samples from the data distribution
from goodpoints.krr.util_sample import get_Xy, ToyData , get_toy_dataset #, get_housing_dataset
# utils for dataset thinning
# from goodpoints.krr.thin.util_thin import sd_thin, kt_thin2
# # falkon baseline
# from goodpoints.krr.falkon.util_falkon_estimators import KernelRidgeFalkon

In [3]:
### Toy dataset parameters

# Number of training points drawn for the single-fit demo cells below.
n = 2**13

# X_name = 'unif-[0,1]' # 'unif-[0,1]'  # ['unif', 'unif-[0,1]', 'gauss', 'mog']
# f_name = 'dnc-paper' # 'dnc-paper' # ['sin', 'stair', 'quad', 'sum-gauss', 'sum-laplace', 'step', 'dnc-paper']

# Names of the input distribution and target function passed to ToyData.
X_name = 'unif'
# NOTE(review): X_var is only referenced by commented-out arguments in this notebook.
X_var = 1
f_name = 'sum-gauss'
# Label-noise level; the baseline loss below is computed as noise**2.
noise = 0.1 # np.sqrt( 0.2 ) # in paper, variance is 1/5

# Input dimension handed to the toy-data sampler.
d = 1
# M = 2
k = 8 # number of anchor points for sum-gauss and sum-laplace

### Regression parameters

# kernel = 'sobolev'  # ['sobolev', 'gauss', 'laplace']
# alpha_pow = -2/3  # gauss = -1, sobolev = -2/3

kernel = 'singular'
alpha_pow = -1

# Kernel parameter passed to get_estimator as `sigma`
# (presumably the bandwidth; the final figure labels this axis "h").
sigma =  0.25
# alpha = n ** alpha_pow.
# NOTE(review): alpha is not consumed by any active code in this notebook.
alpha = np.power(float(n),alpha_pow)

### Experiment parameters

# varying_variable = 'kernel' # ['kernel', 'd', 'M']
# NOTE(review): n_repeats is not read by the experiment loop, which hard-codes trials = 1.
n_repeats = 100
# When True, figures are written under figures/ and results pickled under results/.
save = False
# Training sizes swept in the experiment are 2**logn for logn in [logn_lo, logn_hi).
logn_lo = 8
logn_hi = 14
# When True, the experiment loop runs a GridSearchCV step before the final fit.
use_cross_validation = False
# Worker count passed to both GridSearchCV and joblib.Parallel.
n_jobs = 1
# method = 'st'

### Thinning parameters

m = None # Thinned dataset will have size n/2**m
# m=9

In [4]:
# Auxiliary settings derived from the parameters above.
task, refit, postprocess = 'regression', 'neg_mean_squared_error', None

# Tag embedded in the names of saved figures and result files.
name_parts = (X_name, f_name, f'k={k}', f'noise={noise}')
filename = '_'.join(name_parts)
print(filename)

# Reference loss level: the square of the label-noise parameter.
baseline_loss = noise ** 2
print(f'baseline loss: {baseline_loss}')

unif_sum-gauss_k=8_noise=0.1
baseline loss: 0.010000000000000002


In [5]:
# Noisy data source used for the train / test / validation samples.
toy_data_noise = ToyData(X_name, f_name, 
                        #  X_var=X_var, 
                         noise=noise, 
                        d=d, 
                        # M=M, 
                        k=k)

# Three separate draws from the same source: n training points and
# 10000 points each for testing and validation.
X_train, y_train = toy_data_noise.sample(n)
X_test, y_test = toy_data_noise.sample(10000)
# validation set used for cross validation
X_val, y_val = toy_data_noise.sample(10000)

# no noise version for plotting purposes
toy_data_no_noise = ToyData(X_name, f_name, noise=0, 
                    d=d, k=k)
X_no_noise, y_no_noise = toy_data_no_noise.sample(n)

# Sanity check of the sampled array shapes.
print(X_train.shape, y_train.shape)

(8192, 1) (8192,)


In [6]:
def trace_Xy(X, y, name=None, color=None, alpha=0.5):
    """Build a plotly marker trace of targets ``y`` against inputs ``X``.

    Args:
        X: array of inputs indexed as ``X[:, j]``; the size of its last
            axis selects the kind of plot.
        y: target values, one per row of ``X``.
        name: legend label for the trace.
        color: marker colour (plotly picks one when None).
        alpha: opacity of the markers.

    Returns:
        A ``go.Scatter`` when ``X`` has one feature, a ``go.Scatter3d`` when
        it has two, and ``None`` (after printing a message) otherwise.
    """
    d = X.shape[-1]

    if d == 1:
        return go.Scatter(
            # Column indexing keeps a 1-D array even for a single row, where
            # squeeze() would collapse it to a 0-d scalar; it also matches
            # the indexing used by the 2-D branch below.
            x=X[:, 0],
            y=y,
            mode='markers',
            name=name,
            opacity=alpha,
            marker=dict(
                color=color,
            ),
        )

    if d == 2:
        x1, x2 = X[:, 0], X[:, 1]
        return go.Scatter3d(
            x=x1,
            y=x2,
            z=y,
            mode='markers',
            name=name,
            marker=dict(
                symbol='circle',
                opacity=alpha,
                color=color,
                size=2,
            ),
        )

    # Best-effort: unsupported dimensions are reported, not raised.
    print(f"cannot plot data with dimension {d}")
    return None

In [7]:
# Overlay the noisy training sample on the noiseless target function.
fig = go.Figure()
fig.add_trace(trace_Xy(X_train, y_train, name='train'))
fig.add_trace(trace_Xy(X_no_noise, y_no_noise, name='no noise', color='black'))
fig.update_layout(
    title=f'X={X_name}, f={f_name}, std[f]={np.std(y_train):.4f}',
    height=400, width=800,
)
fig.show()
# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_function.png")

## KNW Full

In [8]:
# Build the 'full' estimator variant for the configured task.
# NOTE(review): the kernel is hard-coded to 'box' here, whereas the ST cell
# below passes the notebook-level `kernel` — confirm this is intentional.
knw_full = get_estimator(
    task, 
    'full',
    kernel='box', 
    sigma=sigma,
)

In [9]:
knw_full

In [10]:
knw_full.fit(X_train, y_train)

In [11]:
# Predict on the held-out test sample and on the training sample itself.
pred_full = knw_full.predict(X_test)
train_pred_full = knw_full.predict(X_train)

# Report mean squared error on both; compare against baseline_loss (noise**2).
print('train MSE:', mean_squared_error(y_train, train_pred_full))
print('MSE:', mean_squared_error(y_test, pred_full))

train MSE: 0.0
MSE: 0.011173282255553078


In [12]:
# Test targets (faint) against the full estimator's predictions.
fig = go.Figure()
fig.add_trace(trace_Xy(X_test, y_test, name='test', alpha=0.1))
fig.add_trace(trace_Xy(X_test, pred_full, name='pred', alpha=0.4))
fig.update_layout(title='full', height=400, width=800)
fig.show()
# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_full.png")

## KNW ST

In [13]:
# Build the 'st' estimator variant with the notebook-level kernel and sigma.
# ('st' presumably stands for standard thinning, as mentioned in the
# experiment cell's comments — confirm against get_estimator.)
knw_st = get_estimator(
    task, 
    'st',
    kernel=kernel, 
    sigma=sigma,
)

In [14]:
knw_st

In [15]:
knw_st.fit(X_train, y_train)

In [16]:
%%time
# Predict on the held-out test sample and on the training sample itself.
pred_st = knw_st.predict(X_test)
train_pred_st = knw_st.predict(X_train)

# Report mean squared error on both; compare against baseline_loss (noise**2).
print('train MSE:', mean_squared_error(y_train, train_pred_st))
print('MSE:', mean_squared_error(y_test, pred_st))

train MSE: 0.012338862617609563
MSE: 0.01295145938382656
CPU times: user 214 ms, sys: 353 ms, total: 566 ms
Wall time: 104 ms


In [17]:
# Test targets (faint) against the ST estimator's predictions.
fig = go.Figure()
fig.add_trace(trace_Xy(X_test, y_test, name='test', alpha=0.1))
fig.add_trace(trace_Xy(X_test, pred_st, name='pred', alpha=0.4))
fig.update_layout(title='st', height=400, width=800)
fig.show()
# Optionally persist the figure to disk.
if save:
    fig.write_image(f"figures/{filename}_st.png")

## KNW KT

In [18]:
# knw_kt = get_estimator(
#     task, 
#     'kt',
#     kernel=kernel, 
#     sigma=sigma,
# )

In [19]:
# knw_kt

In [20]:
# knw_kt.fit(X_train, y_train)

In [21]:
# knw_kt.y_fit_

In [22]:
# %%time
# pred_kt = knw_kt.predict(X_test)
# train_pred_kt = knw_kt.predict(X_train)


In [23]:

# print('train MSE:', mean_squared_error(y_train, train_pred_kt))
# # print('MSE:', mean_squared_error(y_test, pred_kt))

In [24]:
# fig = go.Figure(data=[
#     # trace_Xy(krr_sd_thin.X_fit_, krr_sd_thin.y_fit_, name='train coreset', alpha=1),
#     trace_Xy(X_test, y_test, name='test', alpha=0.1),
#     trace_Xy(X_test, pred_kt, name='pred', alpha=0.4)
# ])
# fig.update_layout(
#     title='kt',
#     height=400,
#     width=800,
# )
# fig.show()
# # save fig
# if save: fig.write_image(f"figures/{filename}_kt.png")

## Experiment

In [25]:
from sklearn.model_selection import GridSearchCV, RepeatedKFold

In [26]:
# Name of the estimator keyword argument swept by the experiment, and the
# values it takes; each value is written into the estimator kwargs under
# this key inside the experiment loop.
varying_variable = 'kernel'
varying_variable_values = ['gauss', 'laplace', 'singular', 'box']

In [27]:
print('Running experiment with varying variable:', varying_variable)
print('taking values:', varying_variable_values)

Running experiment with varying variable: kernel
taking values: ['gauss', 'laplace', 'singular', 'box']


In [28]:
# Default parameter grids, used only when cross validation is enabled.
default_param_grid = {
    "sigma": [sigma],
}
falkon_param_grid = {
    "sigma": [sigma],
}


# Per-model configuration: the log2 sample sizes to run and the keyword
# arguments forwarded to get_estimator. Each model carries its own list of
# sizes so that expensive models can be restricted to smaller datasets.
model_configs = {}

# One 'full' configuration per kernel parameter in the sweep.
# The loop variable is deliberately NOT called `sigma`: reusing that name would
# overwrite the notebook-level `sigma` and silently change the value seen by
# any earlier cell that is re-run afterwards.
for sigma_value in np.arange(0.05, 0.45, 0.05):
    model_configs[f'full_{sigma_value}'] = {
        'logn': np.arange(logn_lo, logn_hi, 1),
        'kwargs': {
            'postprocess': postprocess,
            'sigma': sigma_value,
        },
        # No 'param_grid' entry: sigma is fixed per configuration.
    }

In [29]:
model_configs

{'full_0.05': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.05}},
 'full_0.1': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.1}},
 'full_0.15000000000000002': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.15000000000000002}},
 'full_0.2': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.2}},
 'full_0.25': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.25}},
 'full_0.30000000000000004': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.30000000000000004}},
 'full_0.35000000000000003': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.35000000000000003}},
 'full_0.4': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'postprocess': None, 'sigma': 0.4}}}

In [30]:
# Run experiment: for every model configuration, sample size and value of the
# varying variable, fit the estimator on a toy dataset of size 2**logn and
# score its predictions on the shared test set (X_test, y_test).

# The metric does not change between runs, so resolve (and validate) it once.
if refit == 'neg_mean_squared_error':
    score_fn = mean_squared_error
elif refit == 'accuracy':
    score_fn = accuracy_score
else:
    raise ValueError(f"invalid refit metric: {refit}")

results = []

i = 0  # number of runs completed so far (skipped models are not counted)
for name, config in model_configs.items():
    for logn in config['logn']:

        for v in varying_variable_values:
            kwargs = deepcopy(config['kwargs'])
            kwargs[varying_variable] = v

            model_name = f"{name}_{v}"
            trials = 1

            X, y = get_toy_dataset(
                X_name=X_name,
                f_name=f_name,
                n=2**logn,
                # X_var=X_var,
                d=kwargs.get('d', d),
                noise=noise,
                # M=kwargs['M'] if 'M' in kwargs else M,
                k=k,
            )

            # Fall back to the notebook-level kernel when the config sets none
            kwargs.setdefault('kernel', kernel)

            model = get_estimator(task, name=name, **kwargs)
            if model is None: continue
            # if kwargs['m'] > logn: continue

            print(f'i={i+1}: logn={logn}, model={model}')

            # STEP 2: Get optimal parameters through grid search
            # NOTE: we do something slightly better than k-fold cross validation.
            # Namely, we are trying to get rid of randomness in the Kernel Thinning (or Standard Thinning) routine,
            # but if we did 100-fold CV, then the validation set would be 1% of the data
            # (which is too small to get a good estimate of the validation score).
            # Instead we use the same train-val split for each parameter setting and repeat `trials` times
            if use_cross_validation:
                X_concat, y_concat = np.concatenate([X, X_val]), np.concatenate([y, y_val])
                # Train on the rows of X, validate on the rows of X_val
                train_idx = np.arange(len(X))
                val_idx = np.arange(len(X), len(X) + len(X_val))
                split = [(train_idx, val_idx) for _ in range(trials)]
                grid_search = GridSearchCV(
                    estimator=model,
                    # Configs without their own grid use the default one
                    param_grid=config.get('param_grid', default_param_grid),
                    return_train_score=True,
                    cv=split,
                    scoring=refit,
                    refit=False,
                    n_jobs=n_jobs,
                ).fit(X_concat, y_concat)
                # get validation scores
                cv_results = pd.DataFrame(grid_search.cv_results_)
                best_index = grid_search.best_index_
                best_row = cv_results.iloc[best_index]
                # The comprehension variable is local, so the run counter `i`
                # is left untouched (a plain `for i in ...` would clobber it).
                val_scores = [
                    best_row[f'split{split_idx}_test_score']
                    for split_idx in range(trials)
                ]

                # get optimal parameters
                best_params = grid_search.best_params_
            else:
                # Dummy values
                cv_results, best_index = None, 0
                val_scores = [1,] * trials

                best_params = {
                    # 'sigma' : sigma,
                    # 'alpha' : alpha, # * (len(X_train)**(1/4) if name in ['st', 'kt'] else 1),
                }

            def run():
                """Fit a fresh estimator on (X, y) and return its test predictions."""
                best_model = get_estimator(task, name=name,
                                        # sigma=best_params['sigma'],
                                        #    alpha=best_params['alpha'],
                                        **kwargs)
                best_model.fit(X, y)
                return best_model.predict(X_test).squeeze()

            parallel = Parallel(n_jobs=n_jobs)
            output_generator = parallel(delayed(run)() for _ in range(trials))

            # One test score per trial
            mean_scores = [
                np.mean(score_fn(y_test, test_pred))
                for test_pred in output_generator
            ]

            results.append({
                "logn": logn,
                "model": model_name,
                "cv_results": cv_results,
                "best_index_": best_index,
                "mean_scores": mean_scores,
                # "std_score" : np.std(mean_scores),
            })

            # Re-written after every run so partial results survive an interruption
            if save:
                with open(f"results/{filename}.pkl", 'wb') as f:
                    pickle.dump(results, f)

            i += 1

sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=1: logn=8, model=KernelNadarayaWatsonRegressor(kernel='gauss', sigma=0.05)
i=2: logn=8, model=KernelNadarayaWatsonRegressor(sigma=0.05)
i=3: logn=8, model=KernelNadarayaWatsonRegressor(kernel='singular', sigma=0.05)
i=4: logn=8, model=KernelNadarayaWatsonRegressor(kernel='box', sigma=0.05)
sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=5: logn=9, model=KernelNadarayaWatsonRegressor(kernel='gauss', sigma=0.05)
i=6: logn=9, model=KernelNadarayaWatsonRegressor(sigma=0.05)
i=7: logn=9, model=KernelNadarayaWatsonRegressor(kernel='singular', sigma=0.05)
i=8: logn=9, model=KernelNadarayaWatsonRegressor(kernel='box', sigma=0.05)
sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=9: logn=10, model=KernelNadarayaWatsonRegressor(kernel='gauss', sigma=0.05)
i=10: logn=10, model=Kernel

In [31]:
results

[{'logn': 8,
  'model': 'full_0.05_gauss',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010865750242829498]},
 {'logn': 8,
  'model': 'full_0.05_laplace',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.01096672669810978]},
 {'logn': 8,
  'model': 'full_0.05_singular',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.014231075771837362]},
 {'logn': 8,
  'model': 'full_0.05_box',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.012020346775352143]},
 {'logn': 9,
  'model': 'full_0.05_gauss',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.010582393592570858]},
 {'logn': 9,
  'model': 'full_0.05_laplace',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.01069707629131954]},
 {'logn': 9,
  'model': 'full_0.05_singular',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.012497833569027514]},
 {'logn': 9,
  'model': 'full_0.05_box',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.01

## Plot results

In [32]:
# One entry per subplot row. The plotting cell below draws the excess risk on
# a log2 scale (scale='excess-log2'), so the title must say Log2, not Log10.
row_subplot_titles = [
    'Log2 Excess Risk vs sigma',
]

# Grid with one column per value of the varying variable.
# NOTE(review): only len(row_subplot_titles) titles are supplied for
# rows * cols subplots — confirm whether per-column titles are wanted.
fig = make_subplots(
    rows=len(row_subplot_titles),
    cols=len(varying_variable_values),
    shared_yaxes=True,
    subplot_titles=row_subplot_titles,
    vertical_spacing=0.1,
)

# Colour palette repeated four times; it is indexed by (logn - logn_lo) when plotting.
colors_list = colors.qualitative.Plotly * 4
# Colours that already own a legend entry; filled in while traces are added.
colors_used = set()

In [33]:
def plot_vs_sigma(fig, results, vvv, r, c, is_better='higher', scale='log2'):
    """
    Plot Figure 2 from DnC paper
    - x-axis: sigma
    - y-axis: test score (or excess risk)
    - different lines correspond to different values of logn

    One box trace is added to subplot (r, c) of `fig` for every entry of
    `results` whose model name ends in `vvv`. The function reads the notebook
    globals `logn_lo`, `logn_hi`, `colors_list` and `baseline_loss`, and
    records the colours it uses in the global `colors_used` so that each logn
    gets a single legend entry.

    Args:
    - fig: plotly figure with a subplot grid
    - results: list of dicts with keys "model" (formatted as
      '<name>_<sigma>_<value>'), "logn" and "mean_scores"
    - vvv: varying variable value
    - r, c: subplot row and column passed to plotly
    - is_better: unused; kept so existing calls remain valid
    - scale: 'log2', 'log10', 'excess-log2', 'excess-log10' or 'linear'.
      The 'excess-*' scales take the log of (score - baseline), which is NaN
      whenever a score falls below the baseline.

    Returns:
    - dict mapping logn to a list of (sigma, mean(y), std(y)) tuples

    Raises:
    - ValueError: if `scale` is not one of the supported values
    """
    abs_baseline = np.abs(baseline_loss)

    # Resolve the score transform and the baseline reference level up front,
    # so an unknown scale fails with a clear error and the reference level is
    # defined even when no result matches `vvv`.
    if scale == 'log2':
        transform, hline = np.log2, np.log2(abs_baseline)
    elif scale == 'log10':
        transform, hline = np.log10, np.log10(abs_baseline)
    elif scale == 'excess-log10':
        transform, hline = (lambda s: np.log10(s - abs_baseline)), 0
    elif scale == 'excess-log2':
        transform, hline = (lambda s: np.log2(s - abs_baseline)), 0
    elif scale == 'linear':
        transform, hline = (lambda s: s), abs_baseline
    else:
        raise ValueError(f"invalid scale: {scale}")

    result_logn = {logn: [] for logn in range(logn_lo, logn_hi)}

    for result in results:
        # Split from the right so a base model name containing '_' still parses
        _, sigma, vv_name = result["model"].rsplit('_', 2)

        # only select results with the correct varying variable value
        if vv_name != vvv:
            continue

        logn = result['logn']
        color = colors_list[logn - logn_lo]
        scores = result["mean_scores"]
        y = transform(np.abs(scores))

        trace = go.Box(
            x=[float(sigma),] * len(scores),
            y=y,
            name=f'logn={logn}',
            legendgroup=str(logn),
            line_color=color,
            offsetgroup=str(logn),
            # only the first trace of each colour gets a legend entry
            showlegend=color not in colors_used,
            boxmean=True,
        )

        fig.add_trace(trace, row=r, col=c)
        colors_used.add(color)

        # setdefault tolerates a logn outside the pre-seeded range
        result_logn.setdefault(logn, []).append(
            (float(sigma), np.mean(y), np.std(y))
        )

    # add line for baseline loss
    if 'excess' not in scale:
        fig.add_hline(
            y=hline,
            row=r, col=c, line_dash="dash",
        )

    # The x values are the sigma parsed from each model name
    fig.update_xaxes(title_text="sigma", type='linear', row=r, col=c)
    fig.update_yaxes(type='linear', row=r, col=c)
    fig.update_layout(boxmode='group')

    return result_logn

In [34]:
# Fill one subplot column per value of the varying variable and keep the
# per-logn summary series that each call returns.
result_logn = dict()
for col, value in enumerate(varying_variable_values, start=1):
    summary = plot_vs_sigma(fig, results, str(value), 1, col, scale='excess-log2')
    result_logn[value] = summary



In [35]:
fig.show()

In [36]:
result_logn

{'gauss': {8: [(0.05, -10.173761492473337, 0.0),
   (0.1, -9.596440148437656, 0.0),
   (0.15000000000000002, -8.041990009929998, 0.0),
   (0.2, -6.6307733948687675, 0.0),
   (0.25, -5.527008093531952, 0.0),
   (0.30000000000000004, -4.677241662738355, 0.0),
   (0.35000000000000003, -4.016384702152628, 0.0),
   (0.4, -3.493512448050078, 0.0)],
  9: [(0.05, -10.745717896294952, 0.0),
   (0.1, -9.755342019407195, 0.0),
   (0.15000000000000002, -7.962514672811414, 0.0),
   (0.2, -6.53918690924588, 0.0),
   (0.25, -5.453772809040967, 0.0),
   (0.30000000000000004, -4.6203344272821925, 0.0),
   (0.35000000000000003, -3.9712957925665, 0.0),
   (0.4, -3.4567606610976975, 0.0)],
  10: [(0.05, -12.038445086581762, 0.0),
   (0.1, -10.40873795072227, 0.0),
   (0.15000000000000002, -8.243971531975179, 0.0),
   (0.2, -6.684433878855189, 0.0),
   (0.25, -5.537021333567144, 0.0),
   (0.30000000000000004, -4.671566941954006, 0.0),
   (0.35000000000000003, -4.004500620997543, 0.0),
   (0.4, -3.479180178

In [38]:

# Summary figure: one panel per value of the varying variable, one line per
# training size, showing the mean transformed score against the kernel
# parameter, with the standard deviation as error bars.
fig = make_subplots(
    rows=1,
    cols=len(varying_variable_values),
    shared_yaxes=True,
    subplot_titles=[f'{varying_variable}={vvv}' for vvv in varying_variable_values],
)
colors_used = set()

for c, vvv in enumerate(varying_variable_values):
    for logn, series in result_logn[vvv].items():
        if not series:
            # No results were recorded for this sample size; unpacking
            # zip(*[]) into three names would raise a ValueError.
            continue

        x, y, y_std = zip(*series)
        color = colors_list[logn - logn_lo]

        fig.add_trace(go.Scatter(
            x=x,
            y=y,
            mode='lines+markers',
            name=f'N={2**logn}',
            error_y=dict(
                type='data',
                array=y_std,
                visible=True,
            ),
            line=dict(
                width=1,
                color=color,
                dash='solid',
            ),
            legendgroup=str(logn),
            # only the first trace of each colour gets a legend entry
            showlegend=color not in colors_used,
        ), row=1, col=c+1)
        colors_used.add(color)

    # The y axis is shared, so only the first panel needs a title
    if c == 0:
        fig.update_yaxes(title_text="log2(Excess Risk)", row=1, col=c+1)

# update x axis title
fig.update_xaxes(title_text="h")
fig.update_layout(title=f'Figure 2: f={f_name}(k={k}, noise={noise}) - Full')
fig.show()