## Overview

We compare the performance of NTK-based (Neural Tangent Kernel) Kernel Ridge Regression and linear Ridge Regression on the [Boston Housing dataset](https://www.cs.toronto.edu/~delve/data/boston/bostonDetail.html).

## Imports and Constants Definition

In [67]:
import numpy as np
import pandas as pd
from collections import OrderedDict
from operator import itemgetter
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.linear_model import Ridge
from sklearn.kernel_ridge import KernelRidge
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import ShuffleSplit
from NTK import NtkIterator

# Regularization coefficients (alpha values) for (Kernel) Ridge Regression
C_LIST = [10.0 ** i for i in range(-2, 3)]

# Number of splits in cross validation
N_SPLITS = 5
# Fraction of data held out for validation in each cross-validation split
TEST_FRAC = 0.25

## Load Boston Housing Dataset
This dataset has been studied extensively. For a basic analysis of its features, please refer to: https://towardsdatascience.com/linear-regression-on-boston-housing-dataset-f409b7e4a155

In [81]:
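# load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so this cell needs scikit-learn < 1.2
boston = datasets.load_boston()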
print(boston.data.shape, boston.target.shape)

X_train_val, X_test, y_train_val, y_test = train_test_split(
    boston.data, boston.target, test_size=0.4, random_state=0)
print(X_train_val.shape, X_test.shape)

pd.DataFrame(boston.data, columns=boston.feature_names).head()

(506, 13) (506,)
(303, 13) (203, 13)


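|   | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.00632 | 18.0 | 2.31 | 0.0 | 0.538 | 6.575 | 65.2 | 4.09 | 1.0 | 296.0 | 15.3 | 396.9 | 4.98 |
| 1 | 0.02731 | 0.0 | 7.07 | 0.0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2.0 | 242.0 | 17.8 | 396.9 | 9.14 |
| 2 | 0.02729 | 0.0 | 7.07 | 0.0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2.0 | 242.0 | 17.8 | 392.83 | 4.03 |
| 3 | 0.03237 | 0.0 | 2.18 | 0.0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3.0 | 222.0 | 18.7 | 394.63 | 2.94 |
| 4 | 0.06905 | 0.0 | 2.18 | 0.0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3.0 | 222.0 | 18.7 | 396.9 | 5.33 |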


## Use Cross Validation to Select the Best Parameters for the Neural Tangent Kernel

In [116]:
D_MAX = 6  # maximum depth passed to NtkIterator
# Maps (num_layers, num_layers_fixed, alpha_index) -> validation RMSE summed over all CV folds
param_to_metric = OrderedDict()

def kernel_ridge_regression_precomputed(K1, K2, y1, y2, alpha):
    """Fit kernel ridge regression on the train Gram matrix K1 and predict
    with the (eval x train) kernel K2. y2 is unused; it is kept so existing
    call sites keep working."""
    model = KernelRidge(kernel="precomputed", alpha=alpha)
    model.fit(K1, y1)
    return model.predict(K2)

rs = ShuffleSplit(n_splits=N_SPLITS, test_size=TEST_FRAC, random_state=0)
for train_index, val_index in rs.split(X_train_val):
    X_train, y_train = X_train_val[train_index], y_train_val[train_index]
    X_val, y_val = X_train_val[val_index], y_train_val[val_index]
    # Kernels between train/train and val/train points, up to depth D_MAX
    ntk_train = NtkIterator(X_train, X_train, D_MAX)
    ntk_val = NtkIterator(X_val, X_train, D_MAX)
    for fix_dep in range(D_MAX - 1):
        ntk_train.set_fix_dep(fix_dep)
        ntk_val.set_fix_dep(fix_dep)
        # Advance both iterators in lockstep, one layer at a time;
        # .H holds the current kernel matrix
        while ntk_train.has_next():
            ntk_train.next()
            ntk_val.next()
            for a_index, alpha in enumerate(C_LIST):
                param = (ntk_train.dep - 1, fix_dep, a_index)
                y_hat = kernel_ridge_regression_precomputed(
                    ntk_train.H, ntk_val.H, y_train, y_val, alpha)
                metric = np.sqrt(mean_squared_error(y_val, y_hat))
                # Sum RMSE across folds (every param sees the same folds, so sum ranks like mean)
                param_to_metric[param] = param_to_metric.get(param, 0.0) + metric
                print(param, metric, r2_score(y_val, y_hat))

(1, 0, 0) 4.130569189452867 0.7648167043473916
(1, 0, 1) 4.130581698056566 0.7648152799339698
(1, 0, 2) 4.130688104592367 0.7648031627478213
(1, 0, 3) 4.1317661425309655 0.7646803821391691
(1, 0, 4) 4.143220648422896 0.7633738191899152
(2, 0, 0) 4.1643003409410495 0.7609599018839122
(2, 0, 1) 4.164305364686365 0.7609593251352633
(2, 0, 2) 4.164339951078414 0.7609553544423842
(2, 0, 3) 4.1646972049286095 0.7609143379623831
(2, 0, 4) 4.168343766766312 0.7604954732549992
(3, 0, 0) 4.171135987382906 0.7601744951878575
(3, 0, 1) 4.171138725688256 0.7601741803020899
(3, 0, 2) 4.171155495801645 0.7601722518530312
(3, 0, 3) 4.171331261619076 0.7601520395121683
(3, 0, 4) 4.17310765094956 0.7599477142969373
(4, 0, 0) 4.171743240259191 0.7601046603375309
(4, 0, 1) 4.1717449062912335 0.7601044687277508
(4, 0, 2) 4.171754747399061 0.7601033369037165
(4, 0, 3) 4.171858382801558 0.7600914176526801
(4, 0, 4) 4.172901766595511 0.7599714001419944
(5, 0, 0) 4.1700680055228245 0.7602972897933072
(5, 0, 1)

(2, 1, 0) 5.053530361215741 0.6758448623439413
(2, 1, 1) 5.053624782145068 0.6758327491030626
(2, 1, 2) 5.054515191661275 0.675718507525703
(2, 1, 3) 5.062044717163419 0.6747516474730271
(2, 1, 4) 5.089478811728113 0.6712166833726202
(3, 1, 0) 5.107578967866605 0.6688739635504636
(3, 1, 1) 5.107610363539756 0.668869892754171
(3, 1, 2) 5.1078288635910765 0.6688415611126164
(3, 1, 3) 5.1099525773849885 0.6685661282552647
(3, 1, 4) 5.123355229087492 0.6668252439104508
(4, 1, 0) 5.132989861232261 0.6655709742610227
(4, 1, 1) 5.1330047856201935 0.6655690295247553
(4, 1, 2) 5.1330983073352074 0.6655568429618575
(4, 1, 3) 5.134046612666674 0.6654332593069362
(4, 1, 4) 5.141402617367445 0.6644738454943552
(5, 1, 0) 5.1497796050963585 0.663379596142051
(5, 1, 1) 5.149788091081243 0.663378486751657
(5, 1, 2) 5.1498393543471455 0.6633717849597676
(5, 1, 3) 5.150367937890948 0.6633026778461828
(5, 1, 4) 5.154880365093018 0.662712433521601
(3, 2, 0) 5.061958475199884 0.6747627298791226
(3, 2, 1) 5.

(4, 2, 0) 4.780895228597294 0.6553197617276949
(4, 2, 1) 4.780931890239397 0.6553144754407626
(4, 2, 2) 4.781187273312663 0.6552776503189416
(4, 2, 3) 4.7837095287135 0.6549138464770863
(4, 2, 4) 4.802267478601492 0.652231194639578
(5, 2, 0) 4.818219455575849 0.6499169487562082
(5, 2, 1) 4.818236193230727 0.6499145164969382
(5, 2, 2) 4.818350743395499 0.6498978702285056
(5, 2, 3) 4.819508309955025 0.6497296320821799
(5, 2, 4) 4.829241456785744 0.6483134397124024
(4, 3, 0) 4.708408514860334 0.6656924348221553
(4, 3, 1) 4.708508569512714 0.6656782264617229
(4, 3, 2) 4.709448491502388 0.6655447371641303
(4, 3, 3) 4.7177633569160635 0.6643626855990292
(4, 3, 4) 4.762084774175667 0.6580267179530255
(5, 3, 0) 4.793600586384436 0.6534853334147428
(5, 3, 1) 4.793636083760466 0.6534802014031197
(5, 3, 2) 4.793883377424036 0.6534444480170316
(5, 3, 3) 4.796326882480909 0.6530910701844064
(5, 3, 4) 4.814373097397473 0.6504756644187071
(5, 4, 0) 4.725224749623835 0.6633001896748675
(5, 4, 1) 4.725
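
For reference, the `KernelRidge(kernel="precomputed")` call inside `kernel_ridge_regression_precomputed` solves (K1 + alpha * I) c = y1 for the dual coefficients c and then predicts K2 @ c. The sketch below checks this equivalence on random data. A linear kernel stands in for the NTK matrices `ntk_train.H` and `ntk_val.H`. The data, sizes, and alpha value are illustrative assumptions only.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge

rng = np.random.default_rng(0)
X_tr = rng.normal(size=(40, 5))   # illustrative training inputs
X_te = rng.normal(size=(10, 5))   # illustrative evaluation inputs
y_tr = rng.normal(size=40)
alpha = 0.1

# Any positive semi-definite kernel works; a linear kernel stands in for the NTK here
K_train = X_tr @ X_tr.T           # shape (n_train, n_train), plays the role of ntk_train.H
K_test = X_te @ X_tr.T            # shape (n_test, n_train), plays the role of ntk_val.H

# scikit-learn with a precomputed kernel
sk_pred = KernelRidge(kernel="precomputed", alpha=alpha).fit(K_train, y_tr).predict(K_test)

# Closed form: dual coefficients c = (K_train + alpha * I)^{-1} y, predictions K_test @ c
c = np.linalg.solve(K_train + alpha * np.eye(len(y_tr)), y_tr)
manual_pred = K_test @ c

print(np.allclose(sk_pred, manual_pred))
```

Because only K2 @ c is needed at prediction time, the evaluation kernel must have shape (number of evaluation points, number of training points). This is why `NtkIterator` is built with `X_val` first and `X_train` second.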

## List Test Performance of the Top 5 NTK Models

In [117]:
# Rank parameter combinations by RMSE summed over the CV folds (lowest first)
ranked_params = sorted(param_to_metric.items(), key=itemgetter(1))
res = OrderedDict()
for i in range(5):
    best_dep, best_fix_dep, best_a_index = ranked_params[i][0]

    # Rebuild the kernels on the full train/validation split and on test-vs-train pairs
    ntk_train = NtkIterator(X_train_val, X_train_val, D_MAX)
    ntk_test = NtkIterator(X_test, X_train_val, D_MAX)
    ntk_train.set_fix_dep(best_fix_dep)
    ntk_test.set_fix_dep(best_fix_dep)
    # Advance to the depth scored during CV (which recorded ntk_train.dep - 1 as num_layers)
    while ntk_train.dep <= best_dep:
        ntk_train.next()
        ntk_test.next()
    y_hat = kernel_ridge_regression_precomputed(
        ntk_train.H, ntk_test.H, y_train_val, y_test, C_LIST[best_a_index])
    res[i] = {"num_layers": best_dep, "num_layers_fixed": best_fix_dep,
              "alpha": C_LIST[best_a_index],
              "RMSE": np.sqrt(mean_squared_error(y_test, y_hat)),
              "R2": r2_score(y_test, y_hat)}
pd.DataFrame(res).T

|   | num_layers | num_layers_fixed | alpha | RMSE | R2 |
|---|---|---|---|---|---|
| 0 | 1.0 | 0.0 | 0.01 | 4.84038 | 0.7168 |
| 1 | 1.0 | 0.0 | 0.1 | 4.840423 | 0.716795 |
| 2 | 1.0 | 0.0 | 1.0 | 4.840839 | 0.716746 |
| 3 | 1.0 | 0.0 | 10.0 | 4.844538 | 0.716313 |
| 4 | 2.0 | 1.0 | 0.01 | 4.852645 | 0.715363 |


## Compare with Linear Ridge Regression

In [115]:
res = OrderedDict()
# Maps alpha index -> validation RMSE summed over all CV folds
param_to_metric = OrderedDict()

rs = ShuffleSplit(n_splits=N_SPLITS, test_size=TEST_FRAC, random_state=0)
for train_index, val_index in rs.split(X_train_val):
    X_train, y_train = X_train_val[train_index], y_train_val[train_index]
    X_val, y_val = X_train_val[val_index], y_train_val[val_index]
    for a_index, alpha in enumerate(C_LIST):
        y_hat = Ridge(alpha=alpha).fit(X_train, y_train).predict(X_val)
        # Same metric as the NTK search (RMSE), accumulated across folds per alpha index
        metric = np.sqrt(mean_squared_error(y_val, y_hat))
        param_to_metric[a_index] = param_to_metric.get(a_index, 0.0) + metric

# Refit each alpha on the full train/validation split, in CV-ranked order, and score on test
ranked_params = sorted(param_to_metric.items(), key=itemgetter(1))
for i in range(len(C_LIST)):
    best_a_index = ranked_params[i][0]
    y_hat = Ridge(alpha=C_LIST[best_a_index]).fit(X_train_val, y_train_val).predict(X_test)
    res[i] = {"alpha": C_LIST[best_a_index],
              "RMSE": np.sqrt(mean_squared_error(y_test, y_hat)),
              "R2": r2_score(y_test, y_hat)}

pd.DataFrame(res).T

|   | alpha | RMSE | R2 |
|---|---|---|---|
| 0 | 0.1 | 5.076836 | 0.688455 |
| 1 | 0.01 | 5.07815 | 0.688294 |
| 2 | 1.0 | 5.093924 | 0.686354 |
| 3 | 10.0 | 5.16419 | 0.677642 |
| 4 | 100.0 | 5.346861 | 0.654433 |
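
Linear Ridge Regression is a natural baseline for this comparison. Without an intercept, it is exactly kernel ridge regression with the linear kernel K(x, x') = x · x'. The cell above uses scikit-learn's default `fit_intercept=True`, so the following sketch does not reproduce the numbers above. It sets `fit_intercept=False` and uses random, illustrative data only to show the primal/dual equivalence.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.kernel_ridge import KernelRidge

rng = np.random.default_rng(1)
X_tr = rng.normal(size=(50, 4))   # illustrative data, not the Boston features
X_te = rng.normal(size=(8, 4))
y_tr = X_tr @ np.array([1.0, -2.0, 0.5, 0.0]) + 0.1 * rng.normal(size=50)
alpha = 1.0

# Primal: minimize ||y - X w||^2 + alpha * ||w||^2 with no intercept
primal = Ridge(alpha=alpha, fit_intercept=False).fit(X_tr, y_tr).predict(X_te)

# Dual: the same objective solved through the linear kernel K(x, x') = x . x'
dual = KernelRidge(kernel="linear", alpha=alpha).fit(X_tr, y_tr).predict(X_te)

# The difference should be at floating-point round-off level
print(np.max(np.abs(primal - dual)))
```

In this view, the NTK models differ from the baseline only in the kernel matrix handed to the same ridge solver.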


## Discussions

In this example, NTK-based Kernel Ridge Regression outperformed linear Ridge Regression on the test set. For the best model of each kind, test RMSE was 4.84 vs. 5.08 and R² was 0.717 vs. 0.688. This simple example is not enough to conclude that NTK is better in general, but the author thinks that some of the properties the NTK model demonstrates are quite appealing:
* NTK describes the behavior of infinitely wide neural networks trained with L2 (squared) loss, yet yields an exact, closed-form solution. That solution does not depend on training hyperparameters such as the random weight initialization, learning rate, or batch size.
* There are two main hyperparameters: the number of layers and the number of fixed layers. Different combinations of the two generate many different kernels. This offers an efficient way to explore the space of kernel functions, potentially more efficient than tuning popular kernels such as polynomial or RBF kernels.
* This dataset does not seem to have complicated non-linear interactions. Cross validation consistently shows that fewer layers, more fixed layers, and less regularization lead to better results in NTK kernel ridge regression. This suggests that, given enough data, cross validation can choose an appropriate level of model complexity for NTK kernel ridge regression.
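
To make the two hyperparameters more concrete, here is a minimal NumPy sketch of the NTK of a fully connected ReLU network, built with the standard arc-cosine recursion. It is not the implementation behind `NtkIterator`. The following are all assumptions made for illustration:

* the hypothetical function name `relu_ntk`;
* the ReLU activation and the scaling c_σ = 2;
* the meaning of `fix_dep`. Here the gradient terms of the bottom `fix_dep` weight layers are dropped, so `fix_dep = depth` leaves only the last-layer Gaussian-process kernel.

```python
import numpy as np


def relu_ntk(X1, X2, depth, fix_dep=0):
    """Hypothetical sketch of the NTK of a fully connected ReLU network.

    depth   -- number of hidden ReLU layers L
    fix_dep -- number of bottom weight layers assumed frozen; their gradient
               terms are dropped (assumed semantics, for illustration only)
    Returns the (len(X1), len(X2)) kernel matrix.
    """
    if not 0 <= fix_dep <= depth:
        raise ValueError("fix_dep must satisfy 0 <= fix_dep <= depth")
    s12 = X1 @ X2.T                    # Sigma^(0)(x, x') = <x, x'>
    s11 = np.sum(X1 * X1, axis=1)      # Sigma^(0)(x, x) for rows of X1
    s22 = np.sum(X2 * X2, axis=1)      # Sigma^(0)(x', x') for rows of X2
    # With c_sigma = 2 the diagonal Sigma^(h)(x, x) stays equal to ||x||^2,
    # so the normalizer sqrt(Sigma(x, x) * Sigma(x', x')) is the same at every layer
    norm = np.maximum(np.sqrt(np.outer(s11, s22)), 1e-12)

    sigmas = [s12]   # Sigma^(h), h = 0..L
    dots = [None]    # Sigma-dot^(h), h = 1..L (index 0 unused)
    for _ in range(depth):
        cos = np.clip(sigmas[-1] / norm, -1.0, 1.0)
        theta = np.arccos(cos)
        # Arc-cosine kernel formulas for ReLU, scaled by c_sigma = 2
        sigmas.append(norm * (np.sin(theta) + (np.pi - theta) * cos) / np.pi)
        dots.append((np.pi - theta) / np.pi)

    # Theta = sum_{h=fix_dep}^{L} Sigma^(h) * prod_{h'=h+1}^{L} Sigma-dot^(h')
    ntk = np.zeros_like(s12)
    for h in range(fix_dep, depth + 1):
        term = sigmas[h]
        for hp in range(h + 1, depth + 1):
            term = term * dots[hp]
        ntk += term
    return ntk


X = np.random.default_rng(0).normal(size=(6, 3))
print(relu_ntk(X, X, depth=2, fix_dep=1).shape)
```

Under these assumptions, `relu_ntk(X_train, X_train, depth, fix_dep)` and `relu_ntk(X_val, X_train, depth, fix_dep)` would play the roles of `ntk_train.H` and `ntk_val.H` in the cross-validation cell. The depth indexing may not match the one `NtkIterator` uses. With `depth=0` the sketch reduces to the linear kernel, which ties it back to the linear Ridge baseline.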