In [23]:
import pandas as pd
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import expit
import jax
import math

from fava.inference.fit import GaussianSKIMFA
from fava.basis.maps import LinearBasis, RepeatedFiniteBasis, TreeBasis
from fava.misc.scheduler import constantScheduler
from fava.misc.logger import GausLogger
from fava.decomposers.tensor_product import TensorProductKernelANOVA, LinearANOVA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import RidgeCV
from sklearn.preprocessing import PolynomialFeatures


In [24]:
key = random.PRNGKey(0)
N = 1000
p = 10
X = random.normal(key, shape=(N, p))
frac_train = .8
N_train = int(N * frac_train)

f_X = X[:, 0] + X[:, 1] + X[:, 2] * X[:, 3]
f_probs = expit(f_X)

Y = jnp.array(2 * random.bernoulli(key, f_probs) - 1, dtype=jnp.float32)

X_train = X[:N_train, :]
Y_train = Y[:N_train]

X_valid = X[N_train:, :]
Y_valid = Y[N_train:]

In [25]:
key = random.PRNGKey(0)
p = X.shape[1]

kernel_params = dict()
Q = 2
kernel_params['U_tilde'] = jnp.ones(p)
kernel_params['eta'] = jnp.ones(Q+1)

hyperparams = dict()
hyperparams['sigma_sq'] = .5 #
hyperparams['c'] = 0.

opt_params = dict()
opt_params['cg'] = True
opt_params['cg_tol'] = .01
opt_params['M'] = 100
opt_params['gamma'] = .1
opt_params['T'] = 250

featprocessor = LinearBasis(X_train)

scheduler = constantScheduler()
logger = GausLogger(100)

opt_params['scheduler'] = scheduler

skim = GaussianSKIMFA(X_train, Y_train, X_valid, Y_valid, featprocessor)

skim.fit(key, hyperparams, kernel_params, opt_params, 
            logger=GausLogger())

  0%|▍                                                                                                        | 1/250 [00:00<01:20,  3.10it/s]

There are 10 covariates selected.
MSE (Validation)=0.750066339969635.
R2 (Validation)=0.24380850791931152.
eta=[1.0000012  0.99953175 0.99983126]
kappa=[0.4998755  0.49995202 0.5000333  0.49998212 0.49995038 0.4998691
 0.5000507  0.4999922  0.49990714 0.49998403]


 40%|█████████████████████████████████████████▌                                                             | 101/250 [00:21<00:35,  4.24it/s]

There are 10 covariates selected.
MSE (Validation)=0.7519117593765259.
R2 (Validation)=0.24194806814193726.
eta=[0.9999719  0.96365047 0.9760169 ]
kappa=[0.4896515  0.49566108 0.50273377 0.4973725  0.4946751  0.4861879
 0.5037017  0.49847153 0.49094585 0.49770862]


 80%|██████████████████████████████████████████████████████████████████████████████████▊                    | 201/250 [00:42<00:11,  4.30it/s]

There are 10 covariates selected.
MSE (Validation)=0.7513584494590759.
R2 (Validation)=0.24250584840774536.
eta=[0.99995965 0.92856145 0.9409224 ]
kappa=[0.4800974  0.4889873  0.5049225  0.49380374 0.4887226  0.4685794
 0.50669765 0.49626207 0.47936398 0.49455035]


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:52<00:00,  4.79it/s]


In [26]:
print(f'SKIM-FA AUROC: {roc_auc_score(Y_valid, skim.predict(X_valid))}')
print(f'SKIM-FA Brier Score: {((Y_valid - skim.predict(X_valid)) ** 2).mean()}')


SKIM-FA AUROC: 0.8157072285512652
SKIM-FA Brier Score: 0.7513584494590759


In [27]:
from sklearn.linear_model import LogisticRegression

In [28]:
platt = LogisticRegression(penalty=None, fit_intercept=True)
skim_train_scores = np.array(skim.predict(X_train)).reshape((X_train.shape[0],1))
skim_val_scores = np.array(skim.predict(X_valid)).reshape((X_valid.shape[0],1))

In [29]:
clf = LogisticRegression(random_state=0).fit(skim_train_scores, np.array(Y_train))

In [30]:
print(f'SKIM-FA Platt Brier Score: {(( (Y_valid + 1) / 2 - clf.predict_proba(skim_val_scores)[:, 1]) ** 2).mean()}')


SKIM-FA Platt Brier Score: 0.1769619882106781


In [31]:
from sklearn.calibration import calibration_curve

In [33]:
calibration_curve((Y_valid + 1) / 2, clf.predict_proba(skim_val_scores)[:, 1], strategy='uniform')

(array([0.18181818, 0.34375   , 0.53125   , 0.70212766, 0.88888889]),
 array([0.08680025, 0.2946468 , 0.49819171, 0.7106324 , 0.89725811]))