# QM9 KRR Benchmark

Code to perform KRR on the QM9 dataset using different Nystrom methods, using 100k randomly selected molecules as training points. The l1 Laplace kernel is used with a bandwidth 5120, and the regularization parameter is 1e-8; both were chosen using cross-validation. This code was used to produce Figure 3 in the manuscript together with `matlab_plotting/make_krr_plots.m`

In [1]:
# %load_ext line_profiler
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')

import qml, os
from scipy.io import savemat, loadmat
import numpy as np
from sklearn.preprocessing import StandardScaler
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# import plotly.express as px
import plotly.colors as colors

from KRR_Nystrom import KRR_Nystrom
import rpcholesky
import leverage_score
import unif_sample
import matplotlib.pyplot as plt
import time
from functools import partial
import pickle
# util for parallelizing run trials
from joblib import Parallel, delayed

# kernel thinning
# utils for kernel ridge regression
from goodpoints.krr.util_estimators import get_estimator

`eigenpro2` is not installed...
Using `torch.linalg.solve` for training the kernel model

          and may cause an `Out-of-Memory` error
`eigenpro2` is a more scalable solver. To use, pass `method="eigenpro"` to `model.fit()`
To install `eigenpro2` visit https://github.com/EigenPro/EigenPro-pytorch/tree/pytorch/


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import joblib
print(joblib.__version__)
print(joblib.cpu_count())

1.2.0
8


In [4]:
# add this to be able to render plotly plots in non-vscode notebooks
import plotly.io as pio
pio.renderers.default = "notebook_connected"

In [5]:

def get_molecules(directory = "molecules/", max_atoms = 29, max_mols = np.Inf, output_index = 7):
    compounds = []
    energies = []
    for f in sorted(os.listdir("molecules/")):
        if len(compounds) >= max_mols:
            break

        try:
            mol = qml.Compound(xyz="molecules/"+f)
            mol.generate_coulomb_matrix(size=max_atoms, sorting="row-norm")
            with open("molecules/"+f) as myfile:
                line = list(myfile.readlines())[1]
                energies.append(float(line.split()[output_index]) * 27.2114) # Hartrees to eV
            compounds.append(mol)
        except ValueError:
            pass
    
    c = list(zip(compounds, energies))
    np.random.shuffle(c)
    compounds, energies = zip(*c)

    X = np.array([mol.representation for mol in compounds])
    Y = np.array(energies).reshape((X.shape[0],1))

    return X, Y 
    

In [6]:
if not os.path.isfile("data/homo.mat"):
    X, Y = get_molecules()
    data = { "X" : X, "Y" : Y }
    savemat("data/homo.mat", data)
else:
    data = loadmat("data/homo.mat")
    
feature = data['X']
target = data['Y'].flatten()
scaler = StandardScaler()
feature = scaler.fit_transform(feature)
n,d = np.shape(feature)


In [7]:
# From KT compress
def log4(n):
    return np.log2(n) / 2
def get_g(n):
    return int( np.ceil( log4(log4(n)) ) ) # Use default value
def largest_power_of_four(n):
    """Returns largest power of four less than or equal to n
    """
    return 4**( (n.bit_length() - 1 )//2)

def get_coreset_size(n, m=1):
    if get_g(n) <= m:
        # with TicToc('compresspp', print_toc=PRINT_TOC):
        # Compress with g'=g+inflation (compressing returns set of size 2^(g+inflation) sqrt(n) )
        # Thin with g'=g (thinning returns set of size 2^inflation sqrt(n) )
        largest_pow_four = largest_power_of_four(n)
        log2n = n.bit_length() - 1
        scale = n // largest_pow_four
        return 2**( 2*(log2n//2) - m ) * scale
    else:
        return int(n / 2**m)

In [8]:

# num_train = 20000
# num_test = 20000 

num_train = 100000
num_test = n - num_train


# ks = range(200, 1200, 200)
sqrt_n = get_coreset_size(num_train, int(log4(num_train)))
ks = [ sqrt_n, sqrt_n* 2, sqrt_n*4 ]
print('ks', ks)

train_sample = feature[:num_train]
train_sample_target = target[:num_train]
test_sample = feature[num_train:num_train+num_test]
test_sample_target = target[num_train:num_train+num_test]


ks [256, 512, 1024]


In [9]:

def mean_squared_error(true, pred):
    return np.mean((true - pred)**2)
def mean_average_error(true, pred):
    return np.mean(np.abs(true - pred))
def SMAPE(true,pred):
    return np.mean(abs(true - pred)/((abs(true)+abs(pred))/2))


In [10]:

methods = { 
    # deterministic methods
    'Greedy' : rpcholesky.greedy, 
    
    # random methods
    'Uniform' : unif_sample.uniform_sample,
    'RPCholesky' : rpcholesky.rpcholesky,
    'RLS' : leverage_score.recursive_rls_acc,
    'block50RPCholesky' : partial(rpcholesky.block_rpcholesky,b=50),

    'kt': None,
    'st' : None
}

num_trials = 100
lamb = 1.0e-8
sigma = 5120.0
result = dict()
# n_jobs = 2 
n_jobs = 8
savepath = f"data/molecule{num_train // 1000}k-trials={num_trials}.pkl"
print(savepath)

data/molecule100k-trials=100.pkl


In [11]:
def train_predict(name, method, train_sample, test_sample, k, idx_k):
    start = time.time()
    if name == 'Greedy':
        model = KRR_Nystrom(kernel = "laplace", 
                    bandwidth = sigma)
        model.fit_Nystrom(train_sample, train_sample_target, lamb = lamb, sample_num = k, sample_method = method, solve_method = solve_method)
        preds = model.predict_Nystrom(test_sample)
    elif name in ['kt', 'st']: # our methods
        while True:
            try:
                # print(f"Trial {i}")
                model = get_estimator(
                    'regression', 
                    name.lower(), 
                    kernel='laplace',
                    alpha=lamb,
                    sigma=sigma,
                    m=int(log4(num_train))-idx_k,
                )

                model.fit(train_sample, train_sample_target)
                assert len(model.sol_) == k, f"len(model.sol_)={len(model.sol_)} should be same as k={k}"
                preds = model.predict(test_sample)

                break
            except np.linalg.LinAlgError:
                continue
    else:
        while True:
            try:
                # print(f"Trial {i}")
                # # Original
                # model = KRR_Nystrom(kernel = "gaussian", bandwidth = sigma)
                # # Bug fix
                model = KRR_Nystrom(kernel = "laplace", bandwidth = sigma)
                
                model.fit_Nystrom(train_sample, train_sample_target, lamb = lamb, sample_num = k, sample_method = method, solve_method = solve_method)
                preds = model.predict_Nystrom(test_sample)
                break
            except np.linalg.LinAlgError:
                continue
    end = time.time()
    return preds, end - start
    # return preds

In [12]:

solve_method = 'Direct'

for name, method in methods.items():
    result[name] = dict()
    print(f'------------- Method: {name} -------------')
    result[name]["trace_errors"] = {} #np.zeros((len(ks),2))
    result[name]["KRRMSE"] = {} #np.zeros((len(ks),2))
    result[name]["KRRMAE"] = {} #np.zeros((len(ks),2))
    result[name]["KRRSMAPE"] = {}   #np.zeros((len(ks),2))
    result[name]["queries"] = {}    #np.zeros((len(ks),2))

    for idx_k in range(len(ks)):
        k = ks[idx_k]
        print(f'k = {k}')
        trace_err = []
        runtime = []
        queries = []
        KRRmse = []
        KRRmae = []
        KRRsmape = []

        if name == 'Greedy':
            trials = 1 # deterministic
        else:
            trials = num_trials # stochastic

        parallel = Parallel(n_jobs= n_jobs) #, return_as="generator") need joblib>=1.3 for return_as functionality
        output_generator = parallel(delayed(train_predict)(
            name, 
            method, 
            train_sample, 
            test_sample,
            k,
            idx_k,
        ) for _ in range(trials))

        for preds, elapsed_time in output_generator:
            KRRmse.append(mean_squared_error(test_sample_target, preds))
            KRRmae.append(mean_average_error(test_sample_target, preds))
            KRRsmape.append(SMAPE(test_sample_target, preds))
            # queries.append(model.queries)
            # trace_err.append(model.reltrace_err) 
            
            # TODO: placeholder for now
            queries.append(np.nan)
            trace_err.append(np.nan)

            print(f'KRR acc: mse {KRRmse[-1]}, mae {KRRmae[-1]}, smape {KRRsmape[-1]}')
            # print(f'time: sample {model.sample_time}, linsolve {model.linsolve_time}, pred {model.pred_time}')
            print(f'time: {elapsed_time}')
            
        result[name]["trace_errors"][k] = trace_err   # [np.mean(trace_err),np.std(trace_err)]
        result[name]["KRRMSE"][k] = KRRmse    # [np.mean(KRRmse),np.std(KRRmse)]
        result[name]["KRRMAE"][k] = KRRmae    #[np.mean(KRRmae),np.std(KRRmae)]
        result[name]["KRRSMAPE"][k] = KRRsmape    #[np.mean(KRRsmape),np.std(KRRsmape)]
        result[name]["queries"][k] = queries  #[np.mean(queries)/float(num_train**2),np.std(queries)/float(num_train**2)]

        # savemat("data/{}_molecule100k.mat".format(name), result[name])
        # use pickle to save the result periodically
        with open(savepath, 'wb') as f:
            pickle.dump(result, f)


------------- Method: Greedy -------------
k = 256


KRR acc: mse 0.2953552168387065, mae 0.4035185379305591, smape 0.06261773763928957
time: 39.453399658203125
k = 512


KRR acc: mse 0.24421373487415501, mae 0.36844320535362246, smape 0.05722567879137312
time: 81.02892112731934
k = 1024


KRR acc: mse 0.19397560568887487, mae 0.329450039528845, smape 0.05118195187590928
time: 173.9728467464447
------------- Method: Uniform -------------
k = 256


KRR acc: mse 0.22343088147622003, mae 0.350080467365929, smape 0.05442110121672776
time: 10.7155921459198
KRR acc: mse 0.2230129292819063, mae 0.3494556090531451, smape 0.05429602304705962
time: 10.601705074310303
KRR acc: mse 0.22540637721269666, mae 0.3521572927883694, smape 0.05473928119766837
time: 10.754739999771118
KRR acc: mse 0.22136375513269888, mae 0.34923754360501186, smape 0.05425512593196117
time: 10.733570337295532
KRR acc: mse 0.21442888385545614, mae 0.3432807032223982, smape 0.053369637680714364
time: 10.873241186141968
KRR acc: mse 0.2157875924269255, mae 0.3444035218294266, smape 0.05352094965161737
time: 10.817147493362427
KRR acc: mse 0.20868490939111403, mae 0.340180317477612, smape 0.05285615206592838
time: 10.745110750198364
KRR acc: mse 0.21376915877021563, mae 0.34514093847111643, smape 0.0536067613840602
time: 10.839035034179688
KRR acc: mse 0.2081954532904643, mae 0.34003822374478593, smape 0.05284500983836238
time: 10.69798755645752
KRR acc: mse 0.216116320

KRR acc: mse 0.17607963746263325, mae 0.3147422102439733, smape 0.048887696063317276
time: 21.55115818977356
KRR acc: mse 0.1837533110752012, mae 0.31866436937986703, smape 0.04947441027879041
time: 21.65778875350952
KRR acc: mse 0.17351995745262255, mae 0.3120524389691739, smape 0.04846634739443829
time: 21.74245548248291
KRR acc: mse 0.18027539702586404, mae 0.3181327455005253, smape 0.04940416361928604
time: 21.821916341781616
KRR acc: mse 0.18177678911608655, mae 0.317736795852945, smape 0.049351048961995034
time: 21.518473625183105
KRR acc: mse 0.17813392127612515, mae 0.31699411708297653, smape 0.04923473625839034
time: 21.670578002929688
KRR acc: mse 0.17108153991695926, mae 0.31070779139883037, smape 0.04826648814072252
time: 21.74284315109253
KRR acc: mse 0.18455414246805374, mae 0.320979899055722, smape 0.049871749623979035
time: 21.48123598098755
KRR acc: mse 0.18133592857769917, mae 0.31730007534741406, smape 0.0492805286328102
time: 21.523186683654785
KRR acc: mse 0.179292

  self.sol = scipy.linalg.solve(KMn @ KnM + KnM.shape[0]*lamb*KMM + 100*KMM.max()*np.finfo(float).eps*np.identity(sample_num), KMn @ Ytr, assume_a='pos')


  self.sol = scipy.linalg.solve(KMn @ KnM + KnM.shape[0]*lamb*KMM + 100*KMM.max()*np.finfo(float).eps*np.identity(sample_num), KMn @ Ytr, assume_a='pos')


  self.sol = scipy.linalg.solve(KMn @ KnM + KnM.shape[0]*lamb*KMM + 100*KMM.max()*np.finfo(float).eps*np.identity(sample_num), KMn @ Ytr, assume_a='pos')


  self.sol = scipy.linalg.solve(KMn @ KnM + KnM.shape[0]*lamb*KMM + 100*KMM.max()*np.finfo(float).eps*np.identity(sample_num), KMn @ Ytr, assume_a='pos')


KRR acc: mse 0.14187375427107907, mae 0.2812318988810599, smape 0.04364980938517731
time: 45.04656100273132
KRR acc: mse 0.135937006491588, mae 0.27835627325245765, smape 0.043174576419473656
time: 45.2932550907135
KRR acc: mse 0.14053390422220263, mae 0.2792479128342109, smape 0.04336065581075389
time: 45.194960594177246
KRR acc: mse 0.14148661089665684, mae 0.28129033933540476, smape 0.043652105052614736
time: 45.26844596862793
KRR acc: mse 0.13824756441262262, mae 0.2810475043194215, smape 0.0435956649738083
time: 44.323896646499634
KRR acc: mse 0.13908160776329048, mae 0.27917848569161124, smape 0.04335795056560413
time: 44.47974753379822
KRR acc: mse 0.14060459994235555, mae 0.2802054115130715, smape 0.043474132294341564
time: 44.71091175079346
KRR acc: mse 0.1399757078012694, mae 0.27954658302199625, smape 0.04339316425560432
time: 45.0292181968689
KRR acc: mse 0.13955229101964622, mae 0.27903285211775375, smape 0.04332044607278284
time: 44.34794759750366
KRR acc: mse 0.136889031

KRR acc: mse 0.21138728876126114, mae 0.3443415578070268, smape 0.05348353981373336
time: 45.698691606521606
KRR acc: mse 0.21926092454855434, mae 0.3501507573675779, smape 0.054394337823932555
time: 45.908162117004395
KRR acc: mse 0.21264392128354445, mae 0.34533561477441443, smape 0.05365611796539146
time: 45.38870906829834
KRR acc: mse 0.21392422620341817, mae 0.34761631825396877, smape 0.05401186831517695
time: 45.73367691040039
KRR acc: mse 0.2125171363438798, mae 0.344207605846156, smape 0.05349172964443368
time: 45.36164212226868
KRR acc: mse 0.22274133181036396, mae 0.3540403736462429, smape 0.05500496680430834
time: 45.53053617477417
KRR acc: mse 0.21228487597036916, mae 0.3466432568654447, smape 0.05385765853030277
time: 45.53023815155029
KRR acc: mse 0.21305736541726497, mae 0.34674093570894426, smape 0.05387298581410101
time: 45.55008125305176
KRR acc: mse 0.21117013780819874, mae 0.3449558818654263, smape 0.053602959672412784
time: 45.68190908432007
KRR acc: mse 0.21087285

KRR acc: mse 0.1752340413669396, mae 0.31688163191148766, smape 0.04920615760375579
time: 93.04933071136475
KRR acc: mse 0.1798132921643946, mae 0.32187671946744506, smape 0.04997231981973246
time: 93.85233545303345
KRR acc: mse 0.17662947030281462, mae 0.3164492712110187, smape 0.049140069269246
time: 93.64322924613953
KRR acc: mse 0.1817333946257958, mae 0.3220350090994003, smape 0.05000445984836154
time: 94.27462792396545
KRR acc: mse 0.17890934757000423, mae 0.3198027558010472, smape 0.04963419560099992
time: 92.83134007453918
KRR acc: mse 0.1787619065735982, mae 0.31968997523444564, smape 0.049655844405715514
time: 92.8592643737793
KRR acc: mse 0.17519612226249287, mae 0.3172048317374194, smape 0.04923026613874764
time: 93.23465061187744
KRR acc: mse 0.1780408644028838, mae 0.3197397802355828, smape 0.049655364930273976
time: 93.67913150787354
KRR acc: mse 0.17803012180231162, mae 0.3185238618636491, smape 0.0494432753096491
time: 93.0277783870697
KRR acc: mse 0.1788779712140007, 

KRR acc: mse 0.14063957266915855, mae 0.2837133796727442, smape 0.0440313688691581
time: 202.41067671775818
KRR acc: mse 0.14141785553895747, mae 0.2844944674785357, smape 0.04416439581936814
time: 202.07690024375916
KRR acc: mse 0.14377326497948695, mae 0.28743571858985884, smape 0.04461295822193428
time: 202.65485620498657
KRR acc: mse 0.14143373004555845, mae 0.28489823696433464, smape 0.04419661578812029
time: 202.66084933280945
KRR acc: mse 0.1413016134428986, mae 0.2844566677744404, smape 0.04413722514345081
time: 198.7896511554718
KRR acc: mse 0.14337173463271008, mae 0.28615907777154315, smape 0.04437873095360086
time: 199.47917795181274
KRR acc: mse 0.14043095116244628, mae 0.2843979917467744, smape 0.04412725333206386
time: 199.7887818813324
KRR acc: mse 0.14233192520665439, mae 0.2856178385976356, smape 0.04430761748205083
time: 201.60865378379822
KRR acc: mse 0.14448318215483458, mae 0.2874942152749719, smape 0.04464368906725232
time: 198.22815418243408
KRR acc: mse 0.14202

KRR acc: mse 0.21115337060794695, mae 0.34388127679431335, smape 0.05343257545383293
time: 16.997455835342407
KRR acc: mse 0.21035701910867272, mae 0.3440721967421283, smape 0.053481639464797934
time: 17.68487524986267
KRR acc: mse 0.21433478231576145, mae 0.3475483099386972, smape 0.05402752904102812
time: 17.288947343826294
KRR acc: mse 0.2116952369465441, mae 0.3442806119005283, smape 0.05347939249426432
time: 17.390276193618774
KRR acc: mse 0.20717502809351007, mae 0.34187553182778563, smape 0.053133540680026896
time: 17.426438331604004
KRR acc: mse 0.21234643498795377, mae 0.34554354057817144, smape 0.05367334699652959
time: 17.02574872970581
KRR acc: mse 0.216819933449447, mae 0.3465635176629244, smape 0.05386792692017623
time: 17.0366849899292
KRR acc: mse 0.2100460859472401, mae 0.34333432866303104, smape 0.05336774230346173
time: 17.584407806396484
KRR acc: mse 0.21653695483422591, mae 0.3463962246133931, smape 0.053827787890116226
time: 17.39273953437805
KRR acc: mse 0.215080

  self.sol = scipy.linalg.solve(KMn @ KnM + KnM.shape[0]*lamb*KMM + 100*KMM.max()*np.finfo(float).eps*np.identity(sample_num), KMn @ Ytr, assume_a='pos')


KRR acc: mse 0.17764098842172468, mae 0.31705010244273996, smape 0.04923548849801486
time: 34.3268768787384
KRR acc: mse 0.17710003543228794, mae 0.318074583285096, smape 0.04941634803887721
time: 33.443931341171265
KRR acc: mse 0.1780478819067894, mae 0.31747716769481327, smape 0.049355071854011175
time: 33.913209438323975
KRR acc: mse 0.18009698627979512, mae 0.3202466640335214, smape 0.049709769553951184
time: 34.85738205909729
KRR acc: mse 0.18049332859856881, mae 0.3201400475353116, smape 0.049739752105060604
time: 32.97246789932251
KRR acc: mse 0.17050007122108965, mae 0.3130249667005912, smape 0.04861048206341081
time: 33.42473077774048
KRR acc: mse 0.17769287422987234, mae 0.31812030866812235, smape 0.04941000664523382
time: 33.54075360298157
KRR acc: mse 0.17458928541370455, mae 0.3174762118379543, smape 0.04928854309652557
time: 33.81065487861633
KRR acc: mse 0.17789381094428242, mae 0.3181554606089792, smape 0.04942781901814601
time: 32.77408719062805
KRR acc: mse 0.17644204

KRR acc: mse 0.13894643872944454, mae 0.2816785451298462, smape 0.04373612314867245
time: 70.92500734329224
KRR acc: mse 0.13836039083099916, mae 0.2816444778428111, smape 0.043708332344353916
time: 70.97019815444946
KRR acc: mse 0.14180597472495024, mae 0.2848332332265937, smape 0.04417841351424615
time: 72.38344287872314
KRR acc: mse 0.13980790530945733, mae 0.28357484942584316, smape 0.04404057422823912
time: 71.71982526779175
KRR acc: mse 0.1394158526553277, mae 0.2832941631538113, smape 0.04397581088703215
time: 70.29313731193542
KRR acc: mse 0.13753725418031684, mae 0.2814652140684072, smape 0.04367218947762679
time: 71.68643283843994
KRR acc: mse 0.1407783035252812, mae 0.2841682178624577, smape 0.04412373244892965
time: 69.92167329788208
KRR acc: mse 0.13716789127833476, mae 0.2807220939425477, smape 0.0435647226108897
time: 70.77108383178711
KRR acc: mse 0.14169770722560446, mae 0.284644175187317, smape 0.0441690426856466
time: 69.4543809890747
KRR acc: mse 0.13906645433303422

KRR acc: mse 0.21563774693792384, mae 0.346902203339851, smape 0.05391011448035463
time: 12.264813899993896
KRR acc: mse 0.21440604760208762, mae 0.3469397257416902, smape 0.05390168147008618
time: 12.23325228691101
KRR acc: mse 0.21958111959329807, mae 0.3503696990603966, smape 0.05443103074430905
time: 12.241312026977539
KRR acc: mse 0.2133328210880379, mae 0.3467697851882933, smape 0.053891206950143884
time: 12.235602617263794
KRR acc: mse 0.21517984744303503, mae 0.3465862496592209, smape 0.05384090724013947
time: 11.969645261764526
KRR acc: mse 0.20902688854939475, mae 0.34365229246266843, smape 0.053386143160881715
time: 12.1319739818573
KRR acc: mse 0.21249845195136186, mae 0.34668296269474863, smape 0.053869115876281064
time: 12.106831789016724
KRR acc: mse 0.21867528960384008, mae 0.3500957256636071, smape 0.054401756080999925
time: 12.193360567092896
KRR acc: mse 0.21315919490023794, mae 0.3473008372471172, smape 0.05393466091985036
time: 11.973363399505615
KRR acc: mse 0.211

KRR acc: mse 0.17448650231163002, mae 0.3160398558321291, smape 0.04906051555682613
time: 23.783677101135254
KRR acc: mse 0.17995556791015757, mae 0.3202012559379759, smape 0.0497074065744198
time: 23.820955753326416
KRR acc: mse 0.18001394929374692, mae 0.32088904623324027, smape 0.04985830904511795
time: 23.810956716537476
KRR acc: mse 0.17877154404276868, mae 0.3196760281991747, smape 0.049635966037939176
time: 23.80834698677063
KRR acc: mse 0.17571159084428206, mae 0.3167290524510405, smape 0.049204168866251785
time: 23.267433166503906
KRR acc: mse 0.18212582225752558, mae 0.321539299044864, smape 0.04991847564588605
time: 23.537238121032715
KRR acc: mse 0.17407639580059545, mae 0.3151340369674215, smape 0.04892914003910123
time: 23.766276121139526
KRR acc: mse 0.1797517731399827, mae 0.32051756660528447, smape 0.04978729592283
time: 23.716209173202515
KRR acc: mse 0.1790947605442601, mae 0.3194017365895351, smape 0.04960892888155733
time: 23.22209143638611
KRR acc: mse 0.177501640

KRR acc: mse 0.14227551813786177, mae 0.28636717419503305, smape 0.04442425075362669
time: 48.795745611190796
KRR acc: mse 0.13826066135612183, mae 0.2825072280830979, smape 0.043843905942073345
time: 48.801443099975586
KRR acc: mse 0.14209918696057514, mae 0.28622195016435986, smape 0.04441026688853434
time: 48.783567667007446
KRR acc: mse 0.14256715532494343, mae 0.2858956450084181, smape 0.04438675455437578
time: 48.53784966468811
KRR acc: mse 0.14145001501821353, mae 0.2846282196878135, smape 0.044166089390028605
time: 47.587356090545654
KRR acc: mse 0.1411216539805289, mae 0.2846189008261651, smape 0.04416969390470749
time: 48.13260841369629
KRR acc: mse 0.13743840772122207, mae 0.28102162066855074, smape 0.043606094591035705
time: 48.217973947525024
KRR acc: mse 0.1419303356283296, mae 0.28564636950937394, smape 0.044351528661588785
time: 48.51805639266968
KRR acc: mse 0.1403002093023481, mae 0.28409879044993763, smape 0.04411434705263213
time: 47.13744020462036
KRR acc: mse 0.14

KRR acc: mse 0.3021067573952497, mae 0.408934255944167, smape 0.06344246302458745
time: 35.12693524360657
KRR acc: mse 0.3206561816524344, mae 0.42330399259636475, smape 0.06568918240854951
time: 35.33640122413635
KRR acc: mse 0.3180672942700211, mae 0.4177547807491926, smape 0.0647646599280881
time: 37.817166566848755
KRR acc: mse 0.3164322832517267, mae 0.42156845734782394, smape 0.06540492165508026
time: 38.26119685173035
KRR acc: mse 0.32763649220186347, mae 0.4214027219443896, smape 0.0653259461463449
time: 36.027647972106934
KRR acc: mse 0.3052386992336667, mae 0.4101140572833136, smape 0.06362622596443798
time: 37.57387709617615
KRR acc: mse 0.31796313396867637, mae 0.41939703805705586, smape 0.06500478301403435
time: 38.474459171295166
KRR acc: mse 0.30319533021799144, mae 0.41105785176554704, smape 0.06379130131298365
time: 36.67221474647522
KRR acc: mse 0.3114340901470796, mae 0.4105462351349536, smape 0.06362818745087048
time: 38.04887390136719
KRR acc: mse 0.304617334918673

KRR acc: mse 0.2931328780388128, mae 0.4044937613264973, smape 0.06280767692663468
time: 152.82203340530396
KRR acc: mse 0.27788851259688785, mae 0.39086762548061393, smape 0.06069525746961008
time: 149.69806933403015
KRR acc: mse 0.2702656933547271, mae 0.38131272639400415, smape 0.05926418574083677
time: 147.3258306980133
KRR acc: mse 0.2799927567441056, mae 0.39261938617758874, smape 0.06096828271574772
time: 147.32163906097412
KRR acc: mse 0.2698300991164483, mae 0.3843434359121515, smape 0.05973426884574436
time: 131.80755138397217
KRR acc: mse 0.27452965857577477, mae 0.3894057017923185, smape 0.060528178341551865
time: 153.0134997367859
KRR acc: mse 0.2731665920864147, mae 0.3854251266026881, smape 0.059836967080101784
time: 131.83930087089539
KRR acc: mse 0.2721436569835743, mae 0.38302985855363925, smape 0.05954839181274681
time: 138.48391246795654
KRR acc: mse 0.2814363560874731, mae 0.3907826992549248, smape 0.06072490075089531
time: 129.69393682479858
KRR acc: mse 0.2745506

`eigenpro2` is not installed...
Using `torch.linalg.solve` for training the kernel model

          and may cause an `Out-of-Memory` error
`eigenpro2` is a more scalable solver. To use, pass `method="eigenpro"` to `model.fit()`
To install `eigenpro2` visit https://github.com/EigenPro/EigenPro-pytorch/tree/pytorch/


`eigenpro2` is not installed...
Using `torch.linalg.solve` for training the kernel model

          and may cause an `Out-of-Memory` error
`eigenpro2` is a more scalable solver. To use, pass `method="eigenpro"` to `model.fit()`
To install `eigenpro2` visit https://github.com/EigenPro/EigenPro-pytorch/tree/pytorch/


KRR acc: mse 0.2449796885564612, mae 0.3617801040578157, smape 0.056257150213205355
time: 612.8018391132355
KRR acc: mse 0.24026014825971392, mae 0.3603353907105764, smape 0.0560300970416014
time: 625.9755816459656
KRR acc: mse 0.24521589820027467, mae 0.36548426159429537, smape 0.05682489377650409
time: 540.1509802341461
KRR acc: mse 0.2386337863855084, mae 0.36116438493964526, smape 0.056172354771911834
time: 633.4110040664673
KRR acc: mse 0.2351859574404182, mae 0.359838633021613, smape 0.05595506769003765
time: 605.8466169834137
KRR acc: mse 0.24689038130666238, mae 0.3695654437097423, smape 0.05745247380886873
time: 638.1866703033447
KRR acc: mse 0.24426902152993338, mae 0.36623310995865227, smape 0.056937634595536454
time: 637.9279382228851
KRR acc: mse 0.24190903763590738, mae 0.36361813762684486, smape 0.05656339873237848
time: 580.8703162670135
KRR acc: mse 0.23623999638891935, mae 0.3594968492246981, smape 0.05591797667236665
time: 525.363608121872
KRR acc: mse 0.249523628138

KRR acc: mse 0.30352164930348413, mae 0.40585583854316853, smape 0.0630255233254316
time: 0.30269742012023926
KRR acc: mse 0.28902236200212605, mae 0.3955031774611388, smape 0.06142254157100841
time: 0.3037123680114746
KRR acc: mse 0.3229577597908488, mae 0.42151224494418515, smape 0.06539375047288618
time: 0.30364227294921875
KRR acc: mse 0.31954895375791337, mae 0.4167844505079748, smape 0.06466513899991212
time: 0.3022754192352295
KRR acc: mse 0.2977470504593079, mae 0.40414681994955404, smape 0.06274862628185432
time: 0.29883766174316406
KRR acc: mse 0.3073525070027638, mae 0.4109355422444997, smape 0.06381606606728867
time: 0.3041951656341553
KRR acc: mse 0.31175023905110005, mae 0.40799204649041176, smape 0.06329319884041494
time: 0.3370335102081299
KRR acc: mse 0.32053351684629894, mae 0.4272132091645062, smape 0.0663115464223345
time: 0.33778929710388184
KRR acc: mse 0.3079533940324442, mae 0.4095152215673294, smape 0.06356080273486374
time: 0.30503201484680176
KRR acc: mse 0.3

KRR acc: mse 0.27771200376080707, mae 0.3892323120573496, smape 0.060444604998348854
time: 0.5588119029998779
KRR acc: mse 0.2737291112776941, mae 0.38774724152601286, smape 0.0602787324185591
time: 0.5609698295593262
KRR acc: mse 0.2739404012217588, mae 0.3849252291122066, smape 0.05984080906719013
time: 0.5598971843719482
KRR acc: mse 0.27279373350940145, mae 0.384948113023668, smape 0.05980506551035177
time: 0.5611460208892822
KRR acc: mse 0.27954243917098914, mae 0.39577243907050924, smape 0.06151396253876678
time: 0.5633182525634766
KRR acc: mse 0.2785614912406888, mae 0.3952108953497294, smape 0.06136333749813395
time: 0.5636563301086426
KRR acc: mse 0.28215695524088913, mae 0.3922358094090068, smape 0.06095865190162618
time: 0.5533661842346191
KRR acc: mse 0.28929327032042235, mae 0.39794823594654216, smape 0.06183504976653982
time: 0.5563220977783203
KRR acc: mse 0.28745646914953965, mae 0.40168843502742385, smape 0.0623824531339162
time: 0.5580735206604004
KRR acc: mse 0.28186

KRR acc: mse 0.23741303038469186, mae 0.3611944614381893, smape 0.056186569310182684
time: 1.0845158100128174
KRR acc: mse 0.24129646293527932, mae 0.3666698159902612, smape 0.05699510617712132
time: 1.0794346332550049
KRR acc: mse 0.2404881100487247, mae 0.3648230244371633, smape 0.05671080408582734
time: 1.077453374862671
KRR acc: mse 0.24408409853356805, mae 0.362526046127279, smape 0.05637625641218269
time: 1.1043992042541504
KRR acc: mse 0.2585642972136433, mae 0.3759216350766435, smape 0.05845126013499966
time: 1.0705580711364746
KRR acc: mse 0.24902814858747513, mae 0.3688377657859851, smape 0.05733951655883235
time: 1.0691072940826416
KRR acc: mse 0.23936462982366583, mae 0.3607624688054441, smape 0.056081668506378725
time: 1.0879130363464355
KRR acc: mse 0.24665438178655116, mae 0.36862501719926627, smape 0.05732299940278137
time: 1.0959758758544922
KRR acc: mse 0.24525463266585099, mae 0.3644088070870553, smape 0.0567158777888392
time: 1.079817295074463
KRR acc: mse 0.2461094

In [13]:
result

{'Greedy': {'trace_errors': {256: [nan], 512: [nan], 1024: [nan]},
  'KRRMSE': {256: [0.2953552168387065],
   512: [0.24421373487415501],
   1024: [0.19397560568887487]},
  'KRRMAE': {256: [0.4035185379305591],
   512: [0.36844320535362246],
   1024: [0.329450039528845]},
  'KRRSMAPE': {256: [0.06261773763928957],
   512: [0.05722567879137312],
   1024: [0.05118195187590928]},
  'queries': {256: [nan], 512: [nan], 1024: [nan]}},
 'Uniform': {'trace_errors': {256: [nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    nan,
    

## Plot results

In [14]:
with open(savepath, "rb") as f:
    result = pickle.load(f)
# with open("data/molecule20k-kt.pkl", "rb") as f:
#     result2 = pickle.load(f)
# concatenate the results
# result.update(result2)

In [15]:
# plot line graph with error bars
metrics = ['KRRSMAPE', 'KRRMSE']
fig = make_subplots(rows=len(metrics), cols=1, subplot_titles=metrics, shared_xaxes=True, vertical_spacing=0.1)
colors_list = colors.qualitative.Plotly 
# * (
#     len(model_names) // len(colors.qualitative.Plotly) + 1
# )
model_names = list(result.keys())

for name in result.keys():
    color = colors_list[model_names.index(name)]
    for r, metric in enumerate(metrics):
        if name in ['kt', 'st']:
            for k, vals in result[name][metric].items():
                # print(k, krrmse)
                fig.add_trace(go.Box(
                    x=[k] *len(vals),
                    y=vals,
                    name=name,
                    # opacity=0.5,
                    legendgroup=name,
                    # line_color=color,
                    # offsetgroup=model_name_prefix,
                    # showlegend=color not in colors_used,
                    boxmean=True,
                    line_color = color
                ), row=r+1, col=1)
        else:
            means = [np.mean(krrmse) for _, krrmse in result[name][metric].items() ]
            stds = [np.std(krrmse) for _, krrmse in result[name][metric].items() ]
            # print(ks, means, stds)
            fig.add_trace(go.Scatter(
                x=ks, 
                y=means, 
                mode='lines+markers', 
                name=name, 
                error_y=dict(
                    type='data', # value of error bar given in data coordinates
                    array=stds,
                    visible=True
                ),
                legendgroup=name,
                line_color = color
            ), row=r+1, col=1)
    fig.update_xaxes(title_text="k", row=r+1, col=1)
        
fig.update_layout(title='MSE vs. k (columns)',
                    #  xaxis_title='Coreset Size',
                    #  yaxis_title='MSE',
                     height=800, width=800)
fig.show()