# Divide-and-Conquer Paper

In [1]:
# install using `conda install -c conda-forge line_profiler`
%load_ext line_profiler
%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_is_fitted
import numpy as np
from numpy.linalg import LinAlgError
# set global seed
# np.random.seed(123)
from sklearn.model_selection import GridSearchCV, RepeatedKFold
from sklearn.metrics import r2_score, accuracy_score

import pandas as pd
from datetime import datetime
from copy import deepcopy
# utils for plotting
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors as colors
from joblib import Parallel, delayed
import pickle

from functools import reduce
from operator import concat

# utils for timing
from goodpoints.tictoc import tic, toc, TicToc
# utils for kernel ridge regression
from goodpoints.krr.util_estimators import get_estimator, get_sigma_heuristic
# utils for evaluating kernels
from goodpoints.krr.thin.util_k_mmd import kernel_eval, to_regression_kernel
# utils for generate samples from the data distribution
from goodpoints.krr.util_sample import get_Xy, ToyData , get_toy_dataset #, get_housing_dataset
# utils for dataset thinning
from goodpoints.krr.thin.util_thin import sd_thin, kt_thin2
# # falkon baseline
# from goodpoints.krr.falkon.util_falkon_estimators import KernelRidgeFalkon

In [3]:
# add this to be able to render plotly plots in non-vscode notebooks
import plotly.io as pio
pio.renderers.default = "notebook_connected"

## Setup

In [4]:
### Toy dataset parameters

n = 2**13

# X_name = 'unif-[0,1]' # 'unif-[0,1]'  # ['unif', 'unif-[0,1]', 'gauss', 'mog']
# f_name = 'dnc-paper' # 'dnc-paper' # ['sin', 'stair', 'quad', 'sum-gauss', 'sum-laplace', 'step', 'dnc-paper']

X_name = 'unif'
f_name = 'sum-gauss'
noise = 0.1 # np.sqrt( 0.2 ) # in paper, variance is 1/5

d = 1
# M = 2
k = 8 # number of anchor points for sum-gauss and sum-laplace

### Regression parameters

# kernel = 'sobolev'  # ['sobolev', 'gauss', 'laplace']
# alpha_pow = -2/3  # gauss = -1, sobolev = -2/3

kernel = 'gauss'
alpha_pow = -1

sigma =  0.25
alpha = np.power(float(n),alpha_pow)

### Experiment parameters

# varying_variable = 'kernel' # ['kernel', 'd', 'M']
n_repeats = 20
save = False
logn_lo = 8
logn_hi = 14
use_cross_validation = False
n_jobs = 1
# method = 'st'

### Thinning parameters

m = None # Thinned dataset will have size n/2**m
# m=9

In [5]:
print('alpha', alpha)

alpha 0.0001220703125


In [6]:
# Determine auxiliary parameters
task = 'regression'
refit = 'neg_mean_squared_error'
postprocess = None

filename = '_'.join([X_name, f_name, f'k={k}', f'noise={noise}'])
print(filename)

baseline_loss = noise**2
print(f'baseline loss: {baseline_loss}')

unif_sum-gauss_k=8_noise=0.1
baseline loss: 0.010000000000000002


In [7]:
toy_data_noise = ToyData(X_name, f_name, 
                        #  X_var=X_var, 
                         noise=noise, 
                        d=d, 
                        # M=M, 
                        k=k)

X_train, y_train = toy_data_noise.sample(n)
X_test, y_test = toy_data_noise.sample(10000)
# validation set used for cross validation
X_val, y_val = toy_data_noise.sample(10000)

# no noise version for plotting purposes
toy_data_no_noise = ToyData(X_name, f_name, noise=0, 
                    d=d, k=k)
X_no_noise, y_no_noise = toy_data_no_noise.sample(n)

print(X_train.shape, y_train.shape)

(8192, 1) (8192,)


In [8]:
np.std(X_train), np.std(y_train)

(1.00012207776399, 0.8493696190573192)

In [9]:
def trace_Xy(X, y, name=None, color=None, alpha=0.5):
    d = X.shape[-1]

    if d==1:
        return go.Scatter(
            x=X.squeeze(), 
            y=y, 
            mode='markers',
            name=name,
            opacity=alpha,
            marker=dict(
                color=color,
            )
        )

    elif d==2:
        x1,x2 = X[:,0], X[:,1]
        return go.Scatter3d(x=x1, y=x2, z=y, mode='markers', name=name, marker=dict(
            symbol='circle',
            opacity=alpha,
            color=color,
            size=2,
            # line=dict(width=1),
        ))

    else:
        print(f"cannot plot data with dimension {d}")

In [10]:
# def hnorm(k):
#     from goodpoints.krr.thin.util_k_mmd import laplacian, gauss
#     anchor_points = np.linspace(-1, 1, k)[:, np.newaxis]
#     # print(anchor_points)

#     if f_name == 'sum-gauss':
#         K = gauss(anchor_points, anchor_points, 0.25)
#     elif f_name == 'sum-laplace':
#         K = laplacian(anchor_points, anchor_points, 0.25)
#     else:
#         raise ValueError(f'f_name {f_name} not supported')
#     # print(K)
    
#     return np.sum(K)
#     # return 1/(k**2) * np.sum(K)

In [11]:
# hnorm(2)

In [12]:
fig = go.Figure(data=[
    trace_Xy(X_train, y_train, name='train'),
    # trace_Xy(X_test, y_test, name='test'),
    trace_Xy(X_no_noise, y_no_noise, name='no noise', color='black'),
])
fig.update_layout(
    title=f'X={X_name}, f={f_name}, std[f]={np.std(y_train):.4f}', #, hnorm[f]={hnorm(k):.4f}',
    height=400,
    width=800,
)
fig.show()
# save fig
if save: fig.write_image(f"figures/{filename}_function.png")

In [13]:
np.minimum(np.array([0,1])[:,np.newaxis],np.array([1,2,3])[:,np.newaxis].T)

array([[0, 0, 0],
       [1, 1, 1]])

## KRR Full

In [14]:
krr_full = get_estimator(
    task,
    'full',
    kernel=kernel, 
    sigma=sigma, 
    alpha=alpha
)

In [15]:
krr_full.fit(X_train, y_train)

array([[1.00000000e+00, 1.14269057e-08, 4.00215990e-01, ...,
        4.74388046e-06, 3.96558660e-01, 4.43241608e-04],
       [1.14269057e-08, 1.00000000e+00, 1.27564434e-12, ...,
        5.35869921e-27, 1.69239613e-05, 2.41662207e-22],
       [4.00215990e-01, 1.27564434e-12, 1.00000000e+00, ...,
        1.54396921e-03, 2.51891604e-02, 3.61917863e-02],
       ...,
       [4.74388046e-06, 5.35869921e-27, 1.54396921e-03, ...,
        1.00000000e+00, 2.23707045e-09, 5.93332156e-01],
       [3.96558660e-01, 1.69239613e-05, 2.51891604e-02, ...,
        2.23707045e-09, 1.00000000e+00, 8.38927680e-07],
       [4.43241608e-04, 2.41662207e-22, 3.61917863e-02, ...,
        5.93332156e-01, 8.38927680e-07, 1.00000000e+00]])

In [16]:
pred_full = krr_full.predict(X_test)
train_pred_full = krr_full.predict(X_train)

print('train MSE:', mean_squared_error(y_train, train_pred_full))
print('MSE:', mean_squared_error(y_test, pred_full))

train MSE: 0.009873844565867549
MSE: 0.010084646333420766


### sanity check

In [17]:
# make sure krr full is same as krr dnc (st) with 1 partition
krr_sd_thin_0 = get_estimator(
    task,
    'st',
    kernel=kernel, 
    sigma=sigma, 
    alpha=alpha,
    m=0,
    use_dnc=True
)

In [18]:
krr_sd_thin_0.fit(X_train, y_train)

In [19]:
print('train MSE:', mean_squared_error(y_train, krr_sd_thin_0.predict(X_train)))
print('MSE:', mean_squared_error(y_test, krr_sd_thin_0.predict(X_test)))

train MSE: 0.009873844565867549
MSE: 0.010084646333420766


In [20]:
krr_kt_thin_0 = get_estimator(
    task,
    'kt',
    kernel=kernel, 
    sigma=sigma, 
    alpha=alpha,
    m=0,
    use_dnc=True,
    use_compresspp=False,
)

In [21]:
krr_kt_thin_0.fit(X_train, y_train)

In [22]:
print('train MSE:', mean_squared_error(y_train, krr_kt_thin_0.predict(X_train)))
print('MSE:', mean_squared_error(y_test, krr_kt_thin_0.predict(X_test)))

train MSE: 0.009873844565867549
MSE: 0.010084646333420766


In [23]:
fig = go.Figure(data=[
    # trace_Xy(krr_sd_thin.X_fit_, krr_sd_thin.y_fit_, name='train coreset', alpha=1),
    trace_Xy(X_test, y_test, name='test', alpha=0.1),
    trace_Xy(X_test, pred_full, name='pred', alpha=0.4)
])
fig.update_layout(
    title='full',
    height=400,
    width=800,
)
fig.show()
# save fig
if save: fig.write_image(f"figures/{filename}_full.png")

## DnC (Standard Thin)

In [24]:
print(m,n)

None 8192


In [25]:
krr_sd_thin = get_estimator(
    task, 
    'st', 
    alpha=alpha, # / np.power(n, 1/4), 
    kernel=kernel, 
    sigma=sigma, 
    m=m, 
    use_dnc=True,
    verbose=0,
)

In [26]:
%%time
krr_sd_thin.fit(X_train, y_train)

CPU times: user 18.3 ms, sys: 7.6 ms, total: 25.9 ms
Wall time: 31.8 ms


In [27]:
# krr_sd_thin.estimators_[0].sol_
# krr_sd_thin.estimators_[0].y_fit_
# krr_sd_thin.estimators_[0].X_fit_
# a = (k+alpha)^-1 y
# 0.63134153/(1+0.09131417 + alpha)

In [28]:
%%time
pred_sd, pred_sd_lst = krr_sd_thin.predict(X_test, return_all=True)
train_pred_sd = krr_sd_thin.predict(X_train)

print('train MSE:', mean_squared_error(y_train, train_pred_sd))
print('MSE:', mean_squared_error(y_test, pred_sd))

train MSE: 0.009891746170507707
MSE: 0.01007595092276503
CPU times: user 8.29 s, sys: 5.47 s, total: 13.8 s
Wall time: 3.12 s


In [29]:
# get test mse for all partitions
test_mse_sd = [mean_squared_error(y_test, p) for p in pred_sd_lst]

In [31]:
px.histogram(test_mse_sd)

In [33]:
np.min(test_mse_sd), np.max(test_mse_sd)

(0.011169806726692258, 0.029162829247778962)

In [34]:
fig = go.Figure(data=[
    # trace_Xy(krr_sd_thin.X_fit_, krr_sd_thin.y_fit_, name='train coreset', alpha=1),
    trace_Xy(X_test, y_test, name='test', alpha=0.1),
    trace_Xy(X_test, pred_sd, name='pred', alpha=0.4)
])
fig.update_layout(
    title='st',
    height=400,
    width=800,
)
fig.show()
# save fig
if save: fig.write_image(f"figures/{filename}_st.png")

## DnC (Kernel Thin)

In [35]:
n

8192

In [36]:
krr_kt_thin = get_estimator(
    task, 
    'kt', 
    alpha=alpha, # / np.power(n, 1/4), 
    kernel=kernel, 
    sigma=sigma, 
    m=m, 
    use_dnc=True,
    use_compresspp=False,
)

In [37]:
%%time
krr_kt_thin.fit(X_train, y_train)

CPU times: user 803 ms, sys: 451 ms, total: 1.25 s
Wall time: 2.02 s


In [38]:
%%time
pred_kt, pred_kt_lst = krr_kt_thin.predict(X_test, return_all=True)
train_pred_kt = krr_kt_thin.predict(X_train)

print('train MSE:', mean_squared_error(y_train, train_pred_kt))
print('MSE:', mean_squared_error(y_test, pred_kt))

train MSE: 0.009886723712250623
MSE: 0.010081666223716177
CPU times: user 7.1 s, sys: 7.73 s, total: 14.8 s
Wall time: 2.42 s


In [39]:
test_mse_kt = [mean_squared_error(y_test, p) for p in pred_kt_lst]

In [40]:
px.histogram(test_mse_kt)

In [42]:
np.min(test_mse_kt), np.max(test_mse_kt)

(0.010820330958050831, 0.014170609347377358)

In [41]:
fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=test_mse_sd,
        name='st',
        opacity=0.75,
    )
)
fig.add_trace(
    go.Histogram(
        x=test_mse_kt,
        name='kt',
        opacity=0.75,
    )
)
fig.show()

In [34]:
fig = go.Figure(data=[
    # trace_Xy(krr_sd_thin.X_fit_, krr_sd_thin.y_fit_, name='train coreset', alpha=1),
    trace_Xy(X_test, y_test, name='test', alpha=0.1),
    trace_Xy(X_test, pred_kt, name='pred', alpha=0.4)
])
fig.update_layout(
    title='kt',
    height=400,
    width=800,
)
fig.show()
# save fig
if save: fig.write_image(f"figures/{filename}_kt.png")

## KRR-KT (no DnC)

## Experiment

In [35]:
"""
Varying variables (during grid search, these are NOT parallelized)
"""

varying_variable = 'alpha'
varying_variable_values = ['under', 'not-under'] #[2.0**-logn for logn in np.arange(logn_lo, logn_hi, 2)]

In [36]:
print('Running experiment with varying variable:', varying_variable)
print('taking values:', varying_variable_values)

Running experiment with varying variable: alpha
taking values: ['under', 'not-under']


In [37]:
# # Default param grid to search for each model
default_param_grid = {
    "sigma" :   [sigma,],
    # "alpha" :   [1e1, 1e-0, 1e-1, 1e-2, 1e-3, 1e-4],

    # "sigma" :   1/np.sqrt(2*np.array([0.5, 1., 2, 5.])),
    # "alpha" :   [0.01, 0.02]
}
falkon_param_grid = {
    "sigma" :   [sigma,],
    # "alpha" :   [1e-4, 1e-5,1e-6, 1e-7, 0],
}


# Model constructors and data size for each model
# We allow for different data sizes to avoid running Full KR on large datasets
model_configs = {
    # 'full' : {
    #     'logn' : np.arange(logn_lo, logn_hi, 2),
    #     # 'logn' : np.arange(8, 13, 2),
    #     'kwargs': {
    #         'postprocess' : postprocess,
    #         # 'alpha' : alpha,
    #         # 'sigma' : sigma,
    #     },
    #     'param_grid' : default_param_grid,
    # },
    
}
for m in range(0, logn_hi):
    # if method == 'st':
    model_configs[f'st_{m}'] = {
        'logn' : np.arange(logn_lo, logn_hi, 1),
        # 'logn' : np.arange(8, 13, 2),
        'kwargs' : {
            'm' : m,
            'postprocess' : postprocess,
            # 'alpha' : alpha,
            # 'sigma' : sigma,
            'use_dnc' : True,
        },
        'param_grid' : default_param_grid,
    }
    # elif method == 'kt':
    model_configs[f'kt_{m}'] = {
        'logn' : np.arange(logn_lo, logn_hi, 1),
        # 'logn' : np.arange(8, 13, 2),
        'kwargs' : {
            'm' : m,
            'postprocess' : postprocess,
            # 'alpha' : alpha,
            # 'sigma' : sigma,
            'use_dnc' : True,
            'use_compresspp' : False,
        },
        'param_grid' : default_param_grid,
    }
    # else:
    #     raise ValueError(f'unknown method {method}')

In [38]:
model_configs

{'st_0': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'m': 0, 'postprocess': None, 'use_dnc': True},
  'param_grid': {'sigma': [0.25]}},
 'kt_0': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'m': 0,
   'postprocess': None,
   'use_dnc': True,
   'use_compresspp': False},
  'param_grid': {'sigma': [0.25]}},
 'st_2': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'m': 2, 'postprocess': None, 'use_dnc': True},
  'param_grid': {'sigma': [0.25]}},
 'kt_2': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'m': 2,
   'postprocess': None,
   'use_dnc': True,
   'use_compresspp': False},
  'param_grid': {'sigma': [0.25]}},
 'st_4': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'m': 4, 'postprocess': None, 'use_dnc': True},
  'param_grid': {'sigma': [0.25]}},
 'kt_4': {'logn': array([ 8,  9, 10, 11, 12, 13]),
  'kwargs': {'m': 4,
   'postprocess': None,
   'use_dnc': True,
   'use_compresspp': False},
  'param_grid': {'sigma': [0.25]}},
 'st_6': {'logn'

In [37]:
# Run experiment (depending on experiment_type)

results = []

i = 0
for name, config in model_configs.items():
    for logn in config['logn']:

        for v in varying_variable_values:
            kwargs = deepcopy(config['kwargs'])
            # kwargs[varying_variable] = v
            if v == 'under':
                kwargs['alpha'] = np.power(float(2**logn), alpha_pow)
            else:
                kwargs['alpha'] = np.power(float(2**logn / 2**kwargs['m']), alpha_pow)

            model_name = f"{name}_{v}"
            # trials = (1 if name in ['full'] else n_repeats)
            # st_0 is equivalent to KRR-Full
            trials = (1 if kwargs['m'] == 0 else n_repeats)

            X, y = get_toy_dataset(
                X_name=X_name,
                f_name=f_name,
                n=2**logn,
                # X_var=X_var,
                d=kwargs['d'] if 'd' in kwargs else d,
                noise=noise,
                # M=kwargs['M'] if 'M' in kwargs else M,
                k=k,
            )
            
            # Set kernel, alpha, sigma params
            if 'kernel' not in kwargs:
                kwargs['kernel'] = kernel
            
            model = get_estimator(task, name=name, **kwargs)
            if model is None: continue
            if kwargs['m'] > logn: continue

            print(f'i={i+1}: logn={logn}, model={model}')

            # STEP 2: Get optimal parameters through grid search
            # NOTE: we do something slightly better than k-fold cross validation.
            # Namely, we are trying to get rid of randomness in the Kernel Thinning (or Standard Thinning) routine,
            # but if we did 100-fold CV, then the validation set would be 1% of the data
            # (which is too small to get a good estimate of the validation score).
            # Instead we use the same train-val split for each parameter setting and repeat `trials` times
            if use_cross_validation:
                X_concat, y_concat = np.concatenate([X, X_val]), np.concatenate([y, y_val])
                split = [(np.arange(len(X)), np.arange(len(X), len(X)+len(X_val))) for _ in range(trials)]
                grid_search = GridSearchCV(
                    estimator=model,
                    param_grid=config['param_grid'],
                    return_train_score=True,
                    cv=split,
                    scoring=refit,
                    refit=False,
                    n_jobs=n_jobs,
                ).fit(X_concat, y_concat)
                # get validation scores
                cv_results = pd.DataFrame(grid_search.cv_results_)
                val_scores = []
                for i in range(trials):
                    val_scores.append( cv_results.iloc[grid_search.best_index_][f'split{i}_test_score'] )

                # get optimal parameters
                best_params = grid_search.best_params_
            else:
                # Dummy values
                val_scores = [1,] * trials
                
                best_params = {
                    'sigma' : sigma,
                    # 'alpha' : alpha, # * (len(X_train)**(1/4) if name in ['st', 'kt'] else 1),
                }    

            def run():
                # print(f"best params: {best_params}")            
                best_model = get_estimator(task, name=name, 
                                        sigma=best_params['sigma'],
                                        #    alpha=best_params['alpha'],
                                        **kwargs)
                best_model.fit(X, y)
                return best_model.predict(X_test).squeeze()

            parallel = Parallel(n_jobs=n_jobs)
            output_generator = parallel(delayed(run)() for _ in range(trials))

            mean_scores = []
            for test_pred in output_generator:
                if refit == 'neg_mean_squared_error':
                    test_scores = mean_squared_error(y_test, test_pred)
                elif refit == 'accuracy':
                    test_scores = accuracy_score(y_test, test_pred)
                else:
                    raise ValueError(f"invalid refit metric: {refit}")

                mean_scores.append( np.mean(test_scores) )

            results.append({
                "logn": logn, 
                "model": model_name, 
                "cv_results": pd.DataFrame(grid_search.cv_results_) if use_cross_validation else None,
                "best_index_" : grid_search.best_index_ if use_cross_validation else 0,
                "mean_scores" : mean_scores,
                # "std_score" : np.std(mean_scores),
            })

            if save:
                with open(f"results/{filename}.pkl", 'wb') as f:
                    pickle.dump(results, f)

            i += 1

i=1: logn=8, model=KernelRidgeSTRegressor(alpha=0.00390625, kernel='gauss', m=0, use_dnc=True)
fitting estimator 0 on coreset of size 256...
i=2: logn=8, model=KernelRidgeSTRegressor(alpha=0.00390625, kernel='gauss', m=0, use_dnc=True)
fitting estimator 0 on coreset of size 256...
sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=3: logn=9, model=KernelRidgeSTRegressor(alpha=0.001953125, kernel='gauss', m=0, use_dnc=True)
fitting estimator 0 on coreset of size 512...
i=4: logn=9, model=KernelRidgeSTRegressor(alpha=0.001953125, kernel='gauss', m=0, use_dnc=True)
fitting estimator 0 on coreset of size 512...
sampling dataset with params ToyData(X_name=unif, f_name=sum-gauss, X_var=1, d=1, noise=0.1, M=4, k=8)
i=5: logn=10, model=KernelRidgeSTRegressor(alpha=0.0009765625, kernel='gauss', m=0, use_dnc=True)
fitting estimator 0 on coreset of size 1024...
i=6: logn=10, model=KernelRidgeSTRegressor(alpha=0.0009765625, kernel='gauss', m=0,

KeyboardInterrupt: 

In [None]:
results

[{'logn': 8,
  'model': 'kt_0_under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859]},
 {'logn': 8,
  'model': 'kt_0_not-under',
  'cv_results': None,
  'best_index_': 0,
  'mean_scores': [0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0.011321615398064859,
   0

In [None]:
st_results = [result for result in results if 'st' in result['model']]
kt_results = [result for result in results if 'kt' in result['model']]

## Figure 1

In [None]:
row_subplot_titles = [
    # "Test Score vs n", 
    # "Log10 Test Score vs n", 
    'Log10 Excess Risk vs m'
]

fig = make_subplots(
    rows=len(row_subplot_titles),
    cols=len(varying_variable_values),
    shared_yaxes=True,
    subplot_titles=reduce(concat, [[f'{varying_variable}={v}' for v in varying_variable_values] for _ in row_subplot_titles]),
    vertical_spacing=0.1,
)

# model_names = list(model_configs.keys())
colors_list = colors.qualitative.Plotly * 4
colors_used = set()

In [None]:
def plot_vs_n(fig, results, print_name, vvv, r, c, is_better='higher', scale='log2', dash='solid'):
    """
    Plot Figure 1 from DnC paper
    - x-axis: log2(n)
    - y-axis: test score (or excess risk)
    - different lines correspond to different values of m

    Args:
    - vvv: varying variable value
    """
    result_m = {m: [] for m in range(0, logn_hi)}
    
    for result in results:
        model_name, m, vv_name = result["model"].split('_')

        # only select results with the correct varying variable value
        if vv_name != vvv:
            continue

        # color = colors_list[model_names.index(f'{model_name}_{m}')]
        color = colors_list[int(m)]

        if scale == 'log2':
            y = np.log2(np.abs(result[f"mean_scores"]))
            hline = np.log2(np.abs(baseline_loss))

        elif scale == 'log10':
            y = np.log10(np.abs(result[f"mean_scores"]))
            hline = np.log10(np.abs(baseline_loss))

        elif scale == 'linear':
            hline = np.abs(baseline_loss)
            y = np.abs(result[f"mean_scores"])

        elif scale == 'excess-log10':
            y = np.log10(np.abs(result[f"mean_scores"]) - np.abs(baseline_loss))
            hline = 0

        trace = go.Box(
            x=[result['logn']]*len(result[f"mean_scores"]),
            y=y,
            name=f'm={m}',
            # opacity=0.5,
            legendgroup=str(m),
            line_color=color,
            offsetgroup=str(m),
            showlegend=color not in colors_used,
            boxmean=True,
            # line=dict(
            #     width=1,
            #     color=color,
            #     dash=dash,
            # ),
        )

        fig.add_trace(trace, row=r, col=c)
        colors_used.add(color)

        result_m[int(m)].append((result['logn'], np.mean(y), np.std(y)))

    if 'excess' not in scale:
        # add line for baseline loss
        fig.add_hline(
            y=hline,
            row=r, col=c, line_dash="dash",
        )

    if c == 1: fig.update_yaxes(title_text=f"{scale}({print_name}) - {is_better} is better", row=r, col=c)
    fig.update_xaxes(title_text="log2(n)", type='linear', row=r, col=c)
    fig.update_yaxes(type='linear', row=r, col=c)
    fig.update_layout(boxmode='group')

    return result_m

In [None]:
st_result_m = {}
kt_result_m = {}
for c, vvv in enumerate(varying_variable_values):
    # plot_test_score_vs_n(str(vvv), 1, c+1, scale='linear')
    # plot_test_score_vs_n(str(vvv), 2, c+1, scale='log10')
    # result_m[vvv] = plot_test_score_vs_n(str(vvv), 3, c+1, scale='excess-log10')
    st_result_m[vvv] = plot_vs_n(fig, st_results, 'st', str(vvv), 1, c+1, scale='excess-log10', dash='solid')
    kt_result_m[vvv] = plot_vs_n(fig, kt_results, 'kt', str(vvv), 1, c+1, scale='excess-log10', dash='dot')

fig = make_subplots(
    rows=1,
    cols=len(varying_variable_values),
    shared_yaxes=True,
)

for c, vvv in enumerate(varying_variable_values):
    
    for m, result in st_result_m[vvv].items():
        x,y, y_std = zip(*st_result_m[vvv][m])
        fig.add_trace(go.Scatter(
            x=x, y=y, 
            mode='lines+markers', 
            name=f'm={m}', 
            error_y=dict(type='data', array=y_std, visible=True),
            legendgroup=str(m),
            line=dict(
                width=1,
                color=colors_list[m],
                dash='solid',
            ),
        ), row=1, col=c+1)

    for m, result in kt_result_m[vvv].items():
        x,y, y_std = zip(*kt_result_m[vvv][m])
        fig.add_trace(go.Scatter(
            x=x, y=y, 
            mode='lines+markers', 
            name=f'm={m}', 
            error_y=dict(type='data', array=y_std, visible=True),
            legendgroup=str(m),
            line=dict(
                width=1,
                color=colors_list[m],
                dash='dot',
            ),
        ), row=1, col=c+1)

# update x axis title
fig.update_xaxes(title_text="logn")
fig.update_yaxes(title_text="log10(Excess Risk)", row=1, col=1)
# fig.update_layout(title=f'Figure 1: f={f_name}(sigma={noise})')
fig.show()


invalid value encountered in log10



## Figure 2

In [None]:
row_subplot_titles = [
    # "Test Score vs m", 
    # "Log10 Test Score vs m", 
    'Log10 Excess Risk vs m'
] #, "Train time vs n", "Predict time vs n"]

fig = make_subplots(
    rows=len(row_subplot_titles),
    cols=1, #len(varying_variable_values),
    shared_yaxes=True,
    # subplot_titles=reduce(concat, [[f'{varying_variable}={v}' for v in varying_variable_values] for _ in row_subplot_titles]),
    subplot_titles=row_subplot_titles,
    vertical_spacing=0.1,
)

# model_names = list(model_configs.keys())
colors_list = colors.qualitative.Plotly * (
    (logn_hi-logn_lo) // len(colors.qualitative.Plotly) + 1
)
colors_used = set()

In [None]:
def plot_vs_m(fig, results, vvv, r, c, is_better='higher', scale='log2'):
    """
    Plot Figure 2 from DnC paper
    - x-axis: m
    - y-axis: test score (or excess risk)
    - different lines correspond to different values of logn

    Args:
    - vvv: varying variable value
    """
    result_logn = {logn: [] for logn in range(logn_lo, logn_hi)}
    
    for result in results:
        model_name, m, vv_name = result["model"].split('_')

        # only select results with the correct varying variable value
        if vv_name != vvv:
            continue
        
        logn = result['logn']
        color = colors_list[logn - logn_lo]
        # print(model_name, m, f'logn: {logn}')   
        
        if scale == 'log2':
            y = np.log2(np.abs(result[f"mean_scores"]))
            hline = np.log2(np.abs(baseline_loss))

        elif scale == 'log10':
            y = np.log10(np.abs(result[f"mean_scores"]))
            hline = np.log10(np.abs(baseline_loss))

        elif scale == 'excess-log10':
            y = np.log10(np.abs(result[f"mean_scores"]) - np.abs(baseline_loss))
            hline = 0

        elif scale == 'linear':
            hline = np.abs(baseline_loss)
            y = np.abs(result[f"mean_scores"])

        trace = go.Box(
            x=[m,]*len(result[f"mean_scores"]),
            y=y,
            name=f'logn={logn}',
            # opacity=0.5,
            legendgroup=str(logn),
            line_color=color,
            offsetgroup=str(logn),
            showlegend=color not in colors_used,
            boxmean=True,
        )

        fig.add_trace(trace, row=r, col=c)
        colors_used.add(color)

        result_logn[logn].append((int(m)/logn, np.mean(y), np.std(y)))

    # add line for baseline loss
    if 'excess' not in scale:
        fig.add_hline(
            y=hline,
            row=r, col=c, line_dash="dash",
        )

    # if c == 1: fig.update_yaxes(title_text=f"{scale}({print_name}) - {is_better} is better", row=r, col=c)
    fig.update_xaxes(title_text="m", type='linear', row=r, col=c)
    fig.update_yaxes(type='linear', row=r, col=c)
    fig.update_layout(boxmode='group')

    return result_logn

In [None]:
for c, vvv in enumerate(['under',]):
    # plot_test_score_vs_m(str(vvv), 1, c+1, scale='linear')
    # plot_test_score_vs_m(str(vvv), 2, c+1, scale='log10')
    st_result_logn = plot_vs_m(fig, st_results, str(vvv), 1, c+1, scale='excess-log10')
    kt_result_logn = plot_vs_m(fig, kt_results, str(vvv), 1, c+1, scale='excess-log10')

fig = go.Figure()
for logn, result in st_result_logn.items():
    x,y, y_std = zip(*st_result_logn[logn])
    fig.add_trace(go.Scatter(
        x=x, 
        y=y, 
        mode='lines+markers', 
        name=f'N={2**logn}', 
        error_y=dict(
            type='data', 
            array=y_std, 
            visible=True
        ),
        line=dict(
            width=1,
            color=colors_list[logn - logn_lo],
            dash='solid',
        ),
        legendgroup=str(logn),
    ))
for logn, result in kt_result_logn.items():
    x,y, y_std = zip(*kt_result_logn[logn])
    fig.add_trace(go.Scatter(
        x=x, 
        y=y, 
        mode='lines+markers', 
        name=f'N={2**logn}', 
        error_y=dict(
            type='data', 
            array=y_std, 
            visible=True
        ),
        line=dict(
            width=1,
            color=colors_list[logn - logn_lo],
            dash='dot',
        ),
        legendgroup=str(logn),
    ))

# update x axis title
fig.update_xaxes(title_text="m/logn")
fig.update_yaxes(title_text="log10(Excess Risk)")
# fig.update_layout(title=f'Figure 2: f={f_name}(sigma={noise})')
fig.show()

kt 0 logn: 8
kt 0 logn: 9
kt 0 logn: 10
kt 0 logn: 11
kt 0 logn: 12
kt 0 logn: 13
kt 2 logn: 8
kt 2 logn: 9
kt 2 logn: 10
kt 2 logn: 11
kt 2 logn: 12
kt 2 logn: 13
kt 4 logn: 8
kt 4 logn: 9
kt 4 logn: 10
kt 4 logn: 11
kt 4 logn: 12
kt 4 logn: 13
kt 6 logn: 8
kt 6 logn: 9
kt 6 logn: 10
kt 6 logn: 11
kt 6 logn: 12
kt 6 logn: 13
kt 8 logn: 8
kt 8 logn: 9
kt 8 logn: 10
kt 8 logn: 11
kt 8 logn: 12
kt 8 logn: 13
kt 10 logn: 10
kt 10 logn: 11
kt 10 logn: 12
kt 10 logn: 13
kt 12 logn: 12
kt 12 logn: 13
kt 0 logn: 8
kt 0 logn: 9
kt 0 logn: 10
kt 0 logn: 11
kt 0 logn: 12
kt 0 logn: 13
kt 2 logn: 8
kt 2 logn: 9
kt 2 logn: 10
kt 2 logn: 11
kt 2 logn: 12
kt 2 logn: 13
kt 4 logn: 8
kt 4 logn: 9
kt 4 logn: 10
kt 4 logn: 11
kt 4 logn: 12
kt 4 logn: 13
kt 6 logn: 8
kt 6 logn: 9
kt 6 logn: 10
kt 6 logn: 11
kt 6 logn: 12
kt 6 logn: 13
kt 8 logn: 8
kt 8 logn: 9
kt 8 logn: 10
kt 8 logn: 11
kt 8 logn: 12
kt 8 logn: 13
kt 10 logn: 10
kt 10 logn: 11
kt 10 logn: 12
kt 10 logn: 13
kt 12 logn: 12
kt 12 logn: 13



invalid value encountered in log10

