In [34]:
import pandas as pd
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import expit
import jax
import math

from fava.inference.fit import GaussianSKIMFA
from fava.basis.maps import LinearBasis, RepeatedFiniteBasis, TreeBasis
from fava.misc.scheduler import constantScheduler
from fava.misc.logger import GausLogger
from fava.decomposers.tensor_product import TensorProductKernelANOVA, LinearANOVA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import RidgeCV
from sklearn.preprocessing import PolynomialFeatures


In [35]:
key = random.PRNGKey(0)
N = 1000
p = 500
X = random.normal(key, shape=(N, p))
frac_train = .8
N_train = int(N * frac_train)

f_X = X[:, 0] + X[:, 1] + X[:, 2] * X[:, 3]
f_probs = expit(f_X)

Y = jnp.array(2 * random.bernoulli(key, f_probs) - 1, dtype=jnp.float32)

X_train = X[:N_train, :]
Y_train = Y[:N_train]

X_valid = X[N_train:, :]
Y_valid = Y[N_train:]

In [36]:
key = random.PRNGKey(0)
p = X.shape[1]

kernel_params = dict()
Q = 2
kernel_params['U_tilde'] = jnp.ones(p)
kernel_params['eta'] = jnp.ones(Q+1)

hyperparams = dict()
hyperparams['sigma_sq'] = .5 #
hyperparams['c'] = 0.

opt_params = dict()
opt_params['cg'] = True
opt_params['cg_tol'] = .01
opt_params['M'] = 100
opt_params['gamma'] = .1
opt_params['T'] = 250

featprocessor = LinearBasis(X_train)

scheduler = constantScheduler()
logger = GausLogger(100)

opt_params['scheduler'] = scheduler

skim = GaussianSKIMFA(X_train, Y_train, X_valid, Y_valid, featprocessor)

skim.fit(key, hyperparams, kernel_params, opt_params, 
            logger=GausLogger())

  0%|                                                                                                                 | 0/250 [00:00<?, ?it/s]

There are 500 covariates selected.


  0%|▍                                                                                                      | 1/250 [00:19<1:18:54, 19.01s/it]

MSE (Validation)=0.9813737273216248.
R2 (Validation)=0.015080571174621582.
eta=[1.0000821 1.0003042 0.9995982]


 40%|█████████████████████████████████████████▏                                                             | 100/250 [14:59<32:43, 13.09s/it]

There are 500 covariates selected.


 40%|█████████████████████████████████████████▌                                                             | 101/250 [15:15<34:39, 13.96s/it]

MSE (Validation)=0.9771352410316467.
R2 (Validation)=0.019334375858306885.
eta=[1.0086261  1.0425109  0.94489825]


 80%|██████████████████████████████████████████████████████████████████████████████████▍                    | 200/250 [30:56<07:03,  8.48s/it]

There are 500 covariates selected.


 80%|██████████████████████████████████████████████████████████████████████████████████▊                    | 201/250 [31:10<08:21, 10.24s/it]

MSE (Validation)=0.9676838517189026.
R2 (Validation)=0.028819918632507324.
eta=[1.0183513 1.1183168 0.840724 ]


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [37:37<00:00,  9.03s/it]


In [37]:
print(f'SKIM-FA AUROC: {roc_auc_score(Y_valid, skim.predict(X_valid))}')
print(f'SKIM-FA Brier Score: {((Y_valid - skim.predict(X_valid)) ** 2).mean()}')


SKIM-FA AUROC: 0.6206342834203131
SKIM-FA Brier Score: 0.9676838517189026


In [38]:
from sklearn.linear_model import LogisticRegression

In [39]:
platt = LogisticRegression(penalty=None, fit_intercept=True)
skim_train_scores = np.array(skim.predict(X_train)).reshape((X_train.shape[0],1))
skim_val_scores = np.array(skim.predict(X_valid)).reshape((X_valid.shape[0],1))

In [40]:
clf = LogisticRegression(random_state=0).fit(skim_train_scores, np.array(Y_train))

In [41]:
print(f'SKIM-FA Platt Brier Score: {(( (Y_valid + 1) / 2 - clf.predict_proba(skim_val_scores)[:, 1]) ** 2).mean()}')


SKIM-FA Platt Brier Score: 0.23847563564777374


In [42]:
from sklearn.calibration import calibration_curve

In [43]:
calibration_curve((Y_valid + 1) / 2, clf.predict_proba(skim_val_scores)[:, 1], strategy='uniform')

(array([0.3559322 , 0.47663551, 0.64705882]),
 array([0.3404627 , 0.50086778, 0.65969714]))

In [47]:
skim.logger.get_final_params()[1]['U_tilde']

DeviceArray([1.1348423 , 1.2061379 , 0.9805922 , 1.0668308 , 1.0190736 ,
             1.0444973 , 1.0325366 , 0.9950604 , 1.0058379 , 1.0206655 ,
             1.0078546 , 1.0415909 , 0.9039527 , 0.97189903, 0.9842333 ,
             0.9624395 , 1.070714  , 0.9720688 , 1.0464646 , 1.0151324 ,
             1.0098654 , 1.0186387 , 1.0152644 , 0.99518305, 0.9726318 ,
             1.0363299 , 1.0600337 , 1.0121847 , 1.0043366 , 0.9453262 ,
             1.0331128 , 0.9759302 , 0.98486143, 1.0138837 , 0.9974554 ,
             1.0442245 , 1.0071584 , 1.0208308 , 1.0265887 , 1.0149293 ,
             0.99740726, 0.96419924, 0.972902  , 0.96886194, 1.0540681 ,
             0.9687943 , 0.9960209 , 0.93547916, 0.93620515, 1.0231416 ,
             0.994825  , 1.0314397 , 0.98058456, 1.090109  , 0.9711593 ,
             1.0565969 , 1.0879236 , 1.018731  , 1.0110527 , 0.96555936,
             0.9676396 , 1.0244248 , 0.98148066, 0.94676876, 0.9207853 ,
             0.9833672 , 0.9723256 , 1.0373062 , 1.