In [1]:
!pip install -U jaxlib[cuda112]==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -U jax[cuda112]==0.3.17 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install optax
!pip install optuna
!pip install dm-haiku
!pip install tensorflow-probability==0.17

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jaxlib[cuda112]==0.3.15
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15%2Bcuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl (162.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.7/162.7 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.3.8+cuda11.cudnn82
    Uninstalling jaxlib-0.3.8+cuda11.cudnn82:
      Successfully uninstalled jaxlib-0.3.8+cuda11.cudnn82
Successfully installed jaxlib-0.3.15+cuda11.cudnn82
[0mLooking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda112]==0.3.17
  Downloading jax-0.3.17.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m66.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing

In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
SERVER = 1
import jax
import jax.random as random
import jax.numpy as jnp
import haiku as hk
import optax
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
from tqdm import tqdm
from nn_util import *
from optim_util import *

In [2]:
if SERVER:
    data_dir = "."
else:
    data_dir = "/home/xabush/code/snet/moses-incons-pen-xp/data"

In [3]:
N = 600
P = 500

def func(key, x):
    _, subkey = random.split(key, 2)
    noise = random.normal(subkey, shape=(N,))
    t1 = jnp.max(x[:,[0, 1]], axis=1)
    t2 = jnp.max(x[:,[2, 3, 4]], axis=1)
    y = (((10*jnp.sin(t1) + (t2**3)) / (1 + (x[:,0] + x[:,4])**2)) +
         jnp.sin(0.5*x[:,2])*(1 + jnp.exp(x[:,3] - 0.5*x[:,2])) + x[:,2]**2 + 2*jnp.sin(x[:,3]) + 2*x[:,4]) + noise

    return y

def generate_dataset_v1(seed):
    rng_key = random.PRNGKey(seed)
    key1, key2 = random.split(rng_key, 2)
    e = random.normal(key1, shape=(N,))
    z = random.normal(key2, shape=(N, P))
    dataset = jax.vmap(lambda a, b: 0.5*(a + b), in_axes=(None, 1), out_axes=1)(e, z)
    # dataset = 0.5*(e + z)
    output = func(key1, dataset)

    return dataset, output

In [4]:
#@title
class BayesNN():
    def __init__(self, sgd_optim, sgld_optim, temperature, sigma, data_size, hidden_sizes, act_fn=jax.nn.relu):
        self.hidden_sizes = hidden_sizes
        self.act_fn = act_fn
        self.sgd_optim = sgd_optim
        self.sgld_optim = sgld_optim
        self.optimiser = sgd_optim
        self._forward = hk.without_apply_rng(hk.transform(self._forward_fn))
        self.loss = jax.jit(self.loss)
        self.update = jax.jit(self.update)

        self.temperature = temperature
        self.sigma = sigma
        self.data_size = data_size
        self.add_noise = False

        # weight_decay = self.sigma*self.temperature
        # self.weight_prior = tfd.Normal(0, self.sigma)
        self.weight_prior = tfd.StudentT(df=2, loc=0, scale=self.sigma)
        # self.weight_prior = tfd.Laplace(0, self.sigma)

    def init(self, rng, x):
        params = self._forward.init(rng, x)
        opt_state = self.optimiser.init(params)
        return params, opt_state

    def apply(self, params, x):
        return self._forward.apply(params, x).ravel()


    def update(self, key, params, opt_state, x, y):
        if self.add_noise:
            self.optimiser = self.sgld_optim
        else:
            self.optimiser = self.sgd_optim
        grads = jax.grad(self.loss)(params, x, y)
        updates, opt_state = self.optimiser.update(key, grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    def _forward_fn(self, x):
        init_fn = hk.initializers.VarianceScaling()
        for hd in self.hidden_sizes:
            x = hk.Linear(hd, w_init=init_fn, b_init=init_fn)(x)
            x = self.act_fn(x)

        x = hk.Linear(1)(x)
        return x

    def log_prior(self, params):
        """Computes the Gaussian prior log-density."""
        logprob_tree = jax.tree_util.tree_leaves(jax.tree_util.tree_map(lambda x: jnp.sum(self.weight_prior.log_prob(x.reshape(-1))/self.temperature),
                                                                        params))

        return sum(logprob_tree)

    def log_likelihood(self, params, x, y):
        preds = self.apply(params, x).ravel()
        log_prob = jnp.sum(tfd.Normal(preds, self.temperature).log_prob(y))
        batch_size = x.shape[0]
        log_prob = (self.data_size / batch_size)*log_prob
        return log_prob

    def loss(self, params, x, y):
        logprob_prior = self.log_prior(params)
        logprob_likelihood = self.log_likelihood(params, x, y)
        return logprob_likelihood + logprob_prior

In [5]:
#@title
from sklearn.metrics import r2_score
import torch
from torch.utils import data
import numpy as np
torch.backends.cudnn.deterministic = True

def init_bnn_model(seed, train_loader, epochs, lr_0, num_cycles, temp, sigma, hidden_sizes, act_fn):
    torch.manual_seed(seed)
    num_batches = len(train_loader)
    data_size = X.shape[0]
    total_steps = num_batches*epochs
    step_size_fn = make_cyclical_lr_fn(lr_0, total_steps, num_cycles)
    sgd_optim = sgd_gradient_update(step_size_fn, momentum_decay=0, preconditioner=get_rmsprop_preconditioner())
    sgld_optim = sgld_gradient_update(step_size_fn, momentum_decay=0, preconditioner=get_rmsprop_preconditioner())

    model = BayesNN(sgd_optim, sgld_optim,
                      temp, sigma, data_size, hidden_sizes, act_fn)

    return model


def train_bnn_model(seed, train_loader, epochs, num_cycles, beta, lr_0,
                       hidden_sizes, temp, sigma, act_fn=jax.nn.relu):

    rng_key = jax.random.PRNGKey(seed)
    model = init_bnn_model(seed, train_loader, epochs, lr_0, num_cycles, temp, sigma, hidden_sizes, act_fn)

    cycle_len = epochs // num_cycles
    num_batches = len(train_loader)
    M = (epochs*num_batches) // num_cycles
    init_params, init_opt_state = model.init(rng_key, next(iter(train_loader))[0])


    states = []
    params, opt_state = init_params, init_opt_state
    step = 0
    key = rng_key
    for epoch in tqdm(range(epochs)):
        for batch_x, batch_y in train_loader:
            _, key = jax.random.split(key, 2)
            rk = (step % M) / M
            params, opt_state = model.update(key, params, opt_state, batch_x, batch_y)
            if rk > beta:
                model.add_noise = True
                states.append(params)
            else:
                model.add_noise = False
            step += 1

    return model, states

def eval_bnn_model(model, X, y, params):

    if isinstance(params, list):
        y_preds = np.zeros((len(params), len(y)))
        for i, param in enumerate(params):
            preds = model.apply(param, X).ravel()
            y_preds[i] = preds

        y_preds = np.mean(y_preds, axis=0)
        rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))
    else:
        y_preds = model.apply(params, X).ravel()
        rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))

    return rmse

def score_bnn_model(model, X, y, params):
    if isinstance(params, list):
        y_preds = np.zeros((len(params), len(y)))
        for i, param in enumerate(params):
            preds = model.apply(param, X).ravel()
            # preds_mean = preds[::2]
            y_preds[i] = preds

        y_preds = np.mean(y_preds, axis=0)
        rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))
        if np.isfinite(y_preds).all():
            r2 = r2_score(y, y_preds)
        else:
            r2 = np.nan
    else:
        y_preds = model.apply(params, X)
        preds_mean = y_preds[::2]
        rmse = jnp.sqrt(jnp.mean((y - preds_mean)**2))
        if np.isfinite(y_preds).all():
            r2 = r2_score(y, y_preds)
        else:
            r2 = np.nan

    return rmse, r2

def eval_per_model_score(model, X, y, params):
    scores = []

    for param in params:
        preds = model.apply(param, X).ravel()
        # preds_mean = preds[::2]
        rmse = jnp.sqrt(jnp.mean((y - preds)**2))
        scores.append(rmse)



    return np.array(scores)

In [6]:
dataset, output = generate_dataset_v1(0)

In [7]:
from sklearn.model_selection import train_test_split
import pandas as pd
X_df, y_df = pd.DataFrame(dataset[0]), pd.Series(output[0])
X_train_outer_df, X_test_df, y_train_outer_df, y_test_df = train_test_split(X_df, y_df, shuffle=False, test_size=0.2)
X_train_df, X_val_df, y_train_df, y_val_df = train_test_split(X_train_outer_df, y_train_outer_df, test_size=0.2, shuffle=False)

ValueError: Found input variables with inconsistent numbers of samples: [500, 1]

In [None]:
train_indices, val_indices = X_train_df.index.to_list() ,X_val_df.index.to_list()

In [None]:
X_train, X_val, X_test = X_train_df.to_numpy(), X_val_df.to_numpy(), X_test_df.to_numpy()
y_train, y_val, y_test = y_train_df.to_numpy(), y_val_df.to_numpy(), y_test_df.to_numpy()

In [None]:
J = np.zeros((P, P))
J.shape

In [None]:
#@title
class BgBayesNN():
    def __init__(self, sgd_optim, sgld_optim, disc_sgd_optim, disc_sgld_optim, 
                        temperature, sigma, data_size, hidden_sizes, 
                        J, eta, mu,
                        act_fn=jax.nn.relu):
        self.hidden_sizes = hidden_sizes
        self.act_fn = act_fn
        self.sgd_optim = sgd_optim
        self.sgld_optim = sgld_optim
        self.optimiser = sgd_optim

        self.disc_optimiser = disc_sgd_optim
        self.disc_sgd_optim = disc_sgd_optim
        self.disc_sgld_optim = disc_sgld_optim

        self._forward = hk.transform(self._forward_fn)
        self.loss = jax.jit(self.loss)
        self.update = jax.jit(self.update)

        self.temperature = temperature
        self.sigma = sigma
        self.data_size = data_size
        self.add_noise = False
        self.J = J
        self.eta = eta
        self.mu = mu

        # weight_decay = self.sigma*self.temperature
        # self.weight_prior = tfd.Normal(0, self.sigma)
        self.weight_prior = tfd.StudentT(df=2, loc=0, scale=self.sigma)
        # self.weight_prior = tfd.Horseshoe(scale=self.sigma)
        # self.weight_prior = tfd.Laplace(0, self.sigma)

    def init(self, rng, x):
        gamma = tfd.Bernoulli(0.5*jnp.ones(x.shape[-1])).sample(seed=rng)*1.
        params = self._forward.init(rng, x, gamma)
        opt_state = self.optimiser.init(params)
        disc_opt_state = self.disc_optimiser.init(gamma)
        return params, gamma, opt_state, disc_opt_state

    def apply(self, params, gamma, x):
        return self._forward.apply(params, None, x, gamma).ravel()

    
    def loss(self, params, gamma, x, y):
        logprob_prior = self.log_prior(params)
        logprob_likelihood = self.log_likelihood(params, gamma, x, y)
        return logprob_likelihood + logprob_prior

    def update(self, key, params, gamma, opt_state, disc_opt_state, x, y):
        if self.add_noise:
            self.optimiser = self.sgld_optim
            self.disc_optimiser = self.disc_sgld_optim
        else:
            self.optimiser = self.sgd_optim
            self.disc_optimiser = self.disc_sgd_optim

        contin_loss = lambda p: self.log_prior(params) + self.log_likelihood(p, gamma, x, y)

        grads = jax.grad(contin_loss)(params)
        updates, opt_state = self.optimiser.update(key, grads, opt_state)
        params = optax.apply_updates(params, updates)

        disc_loss = lambda g: self.ising_prior(g) + self.log_likelihood(params, g, x, y)
        disc_logprob, disc_grads = jax.value_and_grad(disc_loss)(gamma)
        gamma, disc_opt_state = self.disc_optimiser.update(key, gamma, disc_grads, disc_opt_state)
        return params, gamma, opt_state, disc_opt_state

    def _forward_fn(self, x, gamma):
        x = x @ jnp.diag(gamma)
        init_fn = hk.initializers.VarianceScaling()
        for hd in self.hidden_sizes:
            x = hk.Linear(hd, w_init=init_fn)(x)
            x = self.act_fn(x)

        x = hk.Linear(1)(x)
        return x

    def log_prior(self, params):
        """Computes the Gaussian prior log-density."""
        logprob_tree = jax.tree_util.tree_leaves(jax.tree_util.tree_map(lambda x: jnp.sum(self.weight_prior.log_prob(x.reshape(-1))/self.temperature), 
                                                                            params))
        
        return sum(logprob_tree)

    def log_likelihood(self, params, gamma, x, y):
        preds = self.apply(params, gamma, x).ravel()
        # preds_mean, preds_std = preds[::2], preds[1::2]
        # print(preds.shape)
        # print(preds_mean.shape)
        # print(preds_std.shape)
        # preds_std = jax.nn.softplus(preds_std.squeeze())
        # preds_mean = preds_mean.squeeze()
        # preds_std = (preds_std**2)*self.temperature
        log_prob = jnp.sum(tfd.Normal(preds, self.temperature).log_prob(y))
        # log_prob = jnp.sum(tfd.MultivariateNormalDiag(preds_mean, preds_std).log_prob(y))
        batch_size = x.shape[0]
        log_prob = (self.data_size / batch_size)*log_prob
        return log_prob

    def ising_prior(self, gamma):
        """Log probability of the Ising model - prior over the discrete variables"""
        return  -self.mu*jnp.sum(gamma) / self.temperature
        # x = (2 * gamma) - 1
        # xg = x @ self.J
        # xgx = (xg * x).sum(-1)
        # return (0.5*self.eta*xgx + self.mu*jnp.sum(x)) / self.temperature

In [None]:
#@title
def init_bg_bnn_model(seed, train_loader, epochs, lr_0, disc_lr_0, num_cycles, temp, sigma, hidden_sizes, J, eta, mu, act_fn):
    torch.manual_seed(seed)
    num_batches = len(train_loader)
    data_size = X.shape[0]
    total_steps = num_batches*epochs
    step_size_fn = make_cyclical_lr_fn(lr_0, total_steps, num_cycles)
    # disc_step_size_fn = make_cyclical_lr_fn(disc_lr_0, total_steps, num_cycles)
    disc_step_size_fn = lambda count: disc_lr_0
    sgd_optim = sgd_gradient_update(step_size_fn, momentum_decay=0, preconditioner=get_rmsprop_preconditioner())
    sgld_optim = sgld_gradient_update(step_size_fn, momentum_decay=0, preconditioner=get_rmsprop_preconditioner())
    disc_sgd_optim = disc_sgld_gradient_update(disc_step_size_fn, momentum_decay=0, preconditioner=get_identity_preconditioner())
    disc_sgld_optim = disc_sgld_gradient_update(disc_step_size_fn, momentum_decay=0, preconditioner=get_identity_preconditioner())

    model = BgBayesNN(sgd_optim, sgld_optim, disc_sgd_optim, disc_sgld_optim,
                      temp, sigma, data_size, hidden_sizes,
                      J, eta, mu, act_fn)

    return model


def train_bg_bnn_model(seed, train_loader, epochs, num_cycles, beta, lr_0, disc_lr_0,
                    hidden_sizes, temp, sigma, eta, mu, J, act_fn=jax.nn.relu):

    rng_key = jax.random.PRNGKey(seed)
    model = init_bg_bnn_model(seed, train_loader, epochs, lr_0, disc_lr_0, num_cycles, temp, sigma, hidden_sizes, J, eta, mu, act_fn)

    cycle_len = epochs // num_cycles
    n_models_per_cycle = 3
    num_batches = len(train_loader)
    M = (epochs*num_batches) // num_cycles
    init_params, init_gamma, init_opt_state, init_disc_opt_state = model.init(rng_key, next(iter(train_loader))[0])


    states = []
    disc_states = []
    val_losses = []
    params, gamma, opt_state, disc_opt_state = init_params, init_gamma, init_opt_state, init_disc_opt_state
    step = 0
    key = rng_key
    for epoch in range(epochs):
        for batch_x, batch_y in train_loader:
            _, key = jax.random.split(key, 2)
            rk = (step % M) / M
            params, gamma, opt_state, disc_opt_state = model.update(key, params, gamma, opt_state, disc_opt_state, batch_x, batch_y)
            if rk > beta:
                states.append(params)
                disc_states.append(gamma)
                model.add_noise = True
            else:
                model.add_noise = False
            step += 1
        # if (epoch % cycle_len) + 1 > (cycle_len - n_models_per_cycle):
        #     # print(epoch)
        #     states.append(params)
        #     disc_states.append(gamma)
        # val_loss = eval_bg_bnn_model(model, X_val, y_val, params, gamma)
        # val_losses.append(val_loss)

    return model, states, disc_states

def eval_bg_bnn_model(model, X, y, params, gammas):

    if isinstance(params, list):
        y_preds = np.zeros((len(params), len(y)))
        for i, (param, gamma) in enumerate(zip(params, gammas)):
            preds = model.apply(param, gamma, X).ravel()
            y_preds[i] = preds

        y_preds = np.mean(y_preds, axis=0)
        rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))
    else:
        y_preds = model.apply(params, gammas, X).ravel()
        rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))

    return rmse

def score_bg_bnn_model(model, X, y, params, gammas):
    if isinstance(params, list):
        y_preds = np.zeros((len(params), len(y)))
        for i, (param, gamma) in enumerate(zip(params, gammas)):
            preds = model.apply(param, gamma, X).ravel()
            # preds_mean = preds[::2]
            y_preds[i] = preds

        y_preds = np.mean(y_preds, axis=0)
        rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))
        if np.isfinite(y_preds).all():
            r2 = r2_score(y, y_preds)
        else:
            r2 = np.nan
    else:
        y_preds = model.apply(params, gammas, X)
        preds_mean = y_preds[::2]
        rmse = jnp.sqrt(jnp.mean((y - preds_mean)**2))
        if np.isfinite(y_preds).all():
            r2 = r2_score(y, y_preds)
        else:
            r2 = np.nan

    return rmse, r2

def eval_per_model_score_bg(model, X, y, params, gammas):
    scores = []

    for param, gamma in zip(params, gammas):
        preds = model.apply(param, gamma, X).ravel()
        # preds_mean = preds[::2]
        rmse = jnp.sqrt(jnp.mean((y - preds)**2))
        scores.append(rmse)



    return np.array(scores)

In [8]:
import optuna

seed = 0
epochs = 100
num_cycles = 10
batch_size = 80
lr_0 = 1e-3
hidden_sizes = [6]
act_fn = jax.nn.swish

def objective(trial, seed, x_train, x_val, y_train, y_val, J, epochs, num_cycles,
              batch_size, hidden_sizes, lr_0, act_fn):

    disc_lr = trial.suggest_float("disc_lr", 0.1, 0.9)
    temp = trial.suggest_categorical("temp", [1e-3, 1e-2, 1e-1, 0.5, 1.])
    beta = trial.suggest_float("beta", 0.7, 0.9)
    eta = 1.0
    mu = trial.suggest_float("mu", 1.0, 1e2, log=True)
    sigma = 1.0
    torch.manual_seed(seed)
    data_loader = NumpyLoader(NumpyData(x_train, y_train), batch_size=batch_size, shuffle=False)

    bg_bnn_model, states, disc_states = train_bg_bnn_model(seed, data_loader, epochs, num_cycles, beta, lr_0, disc_lr,
                                                           hidden_sizes, temp, sigma, eta, mu, J, act_fn)

    rmse = eval_bg_bnn_model(bg_bnn_model, x_val, y_val, states, disc_states)
    return rmse

In [None]:
sampler = optuna.samplers.TPESampler(seed=seed)
study = optuna.create_study(sampler=sampler)
study.optimize(lambda trial: objective(trial, seed, X_train, X_val, y_train, y_val, J, epochs, num_cycles, batch_size, hidden_sizes, lr_0, act_fn), timeout=60)

In [111]:
bnn_config = study.best_params
print(bnn_config)

{'disc_lr': 0.7695758245603153, 'temp': 0.1, 'beta': 0.8435518997739255, 'mu': 4.555786039977584}


In [9]:
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.metrics import r2_score
from tree_utils import tree_stack

def false_selection_rate(ft_idxs):
    causal_fts = {0, 1, 2, 3, 4}
    sel_fts = set(ft_idxs)
    diff = sel_fts.difference(causal_fts)

    return (len(diff)/len(ft_idxs))

def negative_selection_rate(ft_idxs):
    causal_fts = {0, 1, 2, 3, 4}
    sel_fts = set(ft_idxs)
    diff = causal_fts.difference(sel_fts)

    return (len(diff)/len(causal_fts))


def evaluate_bnn_bg_models(model, X, y, params, gammas):
    eval_fn = lambda p, g: model.apply(p, g, X).ravel()
    preds = jax.vmap(eval_fn)(params, gammas)
    preds = preds.reshape(-1, preds.shape[-1])
    losses = jax.vmap(lambda x, z: jnp.sqrt(jnp.mean((x - z)**2)), in_axes=(0, None))(preds, y)
    # mean_loss = jnp.sqrt(jnp.mean(losses, axis=-1))
    return jnp.mean(losses)


def get_feats_dropout_loss(model, params, gammas, X, y):
    var_loss_dict = {"feats_idx": [], "num_models": [] , "loss_on": [], "loss_off": [], "loss_diff": []}

    disc_states = tree_stack(gammas)
    contin_states = tree_stack(params)


    # print(disc_states.shape)
    eval_fn = jax.jit(lambda X, y, params, gammas: evaluate_bnn_bg_models(model, X, y, params, gammas))
    p = X.shape[1]

    for idx in range(p):
        idx_on = np.argwhere(disc_states[:,idx] == 1.).ravel()
        loss_on, loss_off = 0., 0.
        if idx_on.size == 0: ## irrelevant feature
            loss_diff = 1e9
        else:
            disc_states_on = disc_states[idx_on]
            params_on = jax.tree_util.tree_map(lambda x: x[idx_on], contin_states)
            loss_on = eval_fn(X, y, params_on, disc_states_on)

            # Turn-off the variable, and see how the loss changes
            disc_states_off = disc_states_on.at[:,idx].set(0)
            loss_off = eval_fn(X, y, params_on, disc_states_off)

            # loss_diff = (loss_on - loss_off) * (len(idx_on) / num_models)
            loss_diff = (loss_on - loss_off)


        var_loss_dict["feats_idx"].append(idx)
        var_loss_dict["num_models"].append(idx_on.size)
        var_loss_dict["loss_on"].append(loss_on)
        var_loss_dict["loss_off"].append(loss_off)
        var_loss_dict["loss_diff"].append(loss_diff)


    var_loss_df = pd.DataFrame(var_loss_dict).sort_values(by="loss_diff")

    return var_loss_df

def train_rf_model(seed, X, y, train_idxs, val_idxs):

    # cv = KFold(n_splits=5)
    cv = [(train_idxs, val_idxs) for _ in range(5)]
    param_grid = {
        'max_depth': [80, 100, 120],
        'max_features': [2, 3],
        'min_samples_leaf': [3, 4, 5],
        'min_samples_split': [8, 10, 12],
        'n_estimators': [100, 500, 1000]
    }

    rf_reg = RandomForestRegressor(random_state=seed, max_samples=1.0)
    grid_cv = GridSearchCV(estimator = rf_reg, param_grid = param_grid,
                           cv = cv, n_jobs = -1, verbose = 0, scoring="neg_root_mean_squared_error").fit(X, y)

    rf_reg = RandomForestRegressor(random_state=seed, max_samples=1.0, **grid_cv.best_params_)
    rf_reg.fit(X, y)

    return rf_reg

def eval_rf_model(model, X, y):
    y_preds = model.predict(X)
    rmse = jnp.sqrt(jnp.mean((y - y_preds)**2))
    r2 = r2_score(y, y_preds)
    return rmse, r2

In [113]:
epochs = 200
num_cycles = 10
batch_size = 80
lr_0 = 1e-3
sigma = 1.0
eta = 1.0
disc_lr_0 = bnn_config["disc_lr"]
mu = bnn_config["mu"]
temp = bnn_config["temp"]
beta = bnn_config["beta"]

data_loader = NumpyLoader(NumpyData(X_train, y_train), batch_size=batch_size, shuffle=False)
bg_bnn_model, bnn_bg_states, bg_disc_states = train_bg_bnn_model(seed, data_loader, epochs, num_cycles, beta, lr_0, disc_lr_0,
                                                                     hidden_sizes, temp, sigma, eta, mu, J, act_fn)

dropout_loss = get_feats_dropout_loss(bg_bnn_model, bnn_bg_states, bg_disc_states, X_val, y_val)

print(false_selection_rate(dropout_loss["feats_idx"][:5].to_list()))

0.0


In [114]:
len(bg_disc_states)

150

In [115]:
dropout_loss

Unnamed: 0,feats_idx,num_models,loss_on,loss_off,loss_diff
1,1,126,3.1637871,3.233195,-0.06940794
3,3,139,3.3348029,3.3934684,-0.058665514
2,2,116,3.0309684,3.070101,-0.039132595
4,4,98,2.907516,2.9180067,-0.010490656
0,0,99,2.9239378,2.9320858,-0.008147955
...,...,...,...,...,...
339,339,94,2.8699884,2.8593206,0.010667801
469,469,100,2.8946939,2.8834436,0.0112502575
321,321,97,2.9019895,2.8899856,0.012003899
397,397,93,2.83549,2.8224888,0.013001204


In [116]:
bnn_bg_rmse_test, bnn_bg_r2_test = score_bg_bnn_model(bg_bnn_model, X_test, y_test, bnn_bg_states, bg_disc_states)
print(f"Test RMSE: {bnn_bg_rmse_test}, r2_score: {bnn_bg_r2_test}")

Test RMSE: 3.1551499366760254, r2_score: 0.5481285440288763


In [10]:
save_dir = f"{data_dir}/exp_data_5/synthetic/exp1_mao"

In [153]:
epochs = 200
hpo_epochs = 100
num_cycles = 10
batch_size = 80
lr_0 = 1e-3
hidden_sizes = [6]
sigma = 1.0
eta = 1.0
act_fn = jax.nn.swish
k = dataset.shape[0]

dropout_loss_lst = []
bnn_fdr = []
rf_fdr = []
optuna.logging.set_verbosity(optuna.logging.WARNING)

res_ft_dict = {"model": [], "rmse": [], "fdr": []}

seeds = [422,261,968,282,739,573,220,413,745,775]
s = 5 # change this? See Wojats et.al 2020

for seed in tqdm(seeds):
    dataset, output = generate_dataset_v1(seed)
    X_df, y_df = pd.DataFrame(dataset), pd.Series(output)
    X_train_outer_df, X_test_df, y_train_outer_df, y_test_df = train_test_split(X_df, y_df, shuffle=False, test_size=0.2)
    X_train_df, X_val_df, y_train_df, y_val_df = train_test_split(X_train_outer_df, y_train_outer_df, test_size=0.2, shuffle=False)
    train_indices, val_indices = X_train_df.index.to_list(), X_val_df.index.to_list()
    X_train, X_val, X_test = X_train_df.to_numpy(), X_val_df.to_numpy(), X_test_df.to_numpy()
    y_train, y_val, y_test = y_train_df.to_numpy(), y_val_df.to_numpy(), y_test_df.to_numpy()
    data_loader = NumpyLoader(NumpyData(X_train, y_train), batch_size=batch_size, shuffle=False)
    study = optuna.create_study(sampler=sampler)
    study.optimize(lambda trial: objective(trial, seed, X_train, X_val, y_train, y_val, J, hpo_epochs, num_cycles, batch_size, hidden_sizes, lr_0, act_fn), timeout=100)
    bnn_config = study.best_params
    disc_lr_0 = bnn_config["disc_lr"]
    mu = bnn_config["mu"]
    temp = bnn_config["temp"]
    beta = bnn_config["beta"]
    bg_bnn_model, bnn_bg_states, bg_disc_states = train_bg_bnn_model(seed, data_loader, epochs, num_cycles, beta, lr_0, disc_lr_0,
                                                                 hidden_sizes, temp, sigma, eta, mu, J, act_fn)
    bnn_bg_rmse_test, _ = score_bg_bnn_model(bg_bnn_model, X_test, y_test, bnn_bg_states, bg_disc_states)
    dropout_loss = get_feats_dropout_loss(bg_bnn_model, bnn_bg_states, bg_disc_states, X_val, y_val)
    dropout_loss_lst.append(dropout_loss)
    bnn_fdr = false_selection_rate(dropout_loss["feats_idx"][:s].to_list())

    rf_model = train_rf_model(seed, X_train_outer_df, y_train_outer_df, train_indices, val_indices)
    rf_rmse_test, _ = eval_rf_model(rf_model, X_test, y_test)
    rf_fdr = false_selection_rate(np.argsort(rf_model.feature_importances_)[::-1][:s])

    res_ft_dict["model"].append("BNN")
    res_ft_dict["rmse"].append(bnn_bg_rmse_test)
    res_ft_dict["fdr"].append(bnn_fdr)

    res_ft_dict["model"].append("RF")
    res_ft_dict["rmse"].append(rf_rmse_test)
    res_ft_dict["fdr"].append(rf_fdr)

100%|██████████| 10/10 [38:35<00:00, 231.59s/it]


In [13]:
res_df = pd.read_csv(f"{save_dir}/res_df_hn_6.csv")
res_df

Unnamed: 0,model,rmse,fdr
0,BNN,3.22449,0.2
1,RF,2.693497,0.8
2,BNN,3.288866,0.4
3,RF,2.935442,0.8
4,BNN,3.332286,0.0
5,RF,3.190772,0.8
6,BNN,3.068228,0.2
7,RF,2.737681,1.0
8,BNN,3.095022,0.2
9,RF,2.853587,0.4


In [14]:
res_df.groupby(['model'])["rmse"].mean()

model
BNN    3.107310
RF     2.818817
Name: rmse, dtype: float64

In [15]:
res_df.groupby(['model'])["fdr"].mean()

model
BNN    0.30
RF     0.76
Name: fdr, dtype: float64

In [174]:
import pickle

epochs = 200
hpo_epochs = 100
num_cycles = 10
batch_size = 80
lr_0 = 1e-3
hidden_sizes = [100]
sigma = 1.0
eta = 1.0
act_fn = jax.nn.swish
k = dataset.shape[0]

dropout_loss_lst = []
bnn_fdr = []
rf_fdr = []
optuna.logging.set_verbosity(optuna.logging.WARNING)

res_ft_dict_2 = {"model": [], "rmse": [], "fdr": []}

seeds = [422,261,968,282,739,573,220,413,745,775]
s = 5 # change this? See Wojats et.al 2020

for seed in tqdm(seeds):
    dataset, output = generate_dataset_v1(seed)
    X_df, y_df = pd.DataFrame(dataset), pd.Series(output)
    X_train_outer_df, X_test_df, y_train_outer_df, y_test_df = train_test_split(X_df, y_df, shuffle=False, test_size=0.2)
    X_train_df, X_val_df, y_train_df, y_val_df = train_test_split(X_train_outer_df, y_train_outer_df, test_size=0.2, shuffle=False)
    train_indices, val_indices = X_train_df.index.to_list(), X_val_df.index.to_list()
    X_train, X_val, X_test = X_train_df.to_numpy(), X_val_df.to_numpy(), X_test_df.to_numpy()
    y_train, y_val, y_test = y_train_df.to_numpy(), y_val_df.to_numpy(), y_test_df.to_numpy()
    data_loader = NumpyLoader(NumpyData(X_train, y_train), batch_size=batch_size, shuffle=False)
    study = optuna.create_study(sampler=sampler)
    study.optimize(lambda trial: objective(trial, seed, X_train, X_val, y_train, y_val, J, hpo_epochs, num_cycles, batch_size, hidden_sizes, lr_0, act_fn), timeout=100)
    bnn_config = study.best_params
    disc_lr_0 = bnn_config["disc_lr"]
    mu = bnn_config["mu"]
    temp = bnn_config["temp"]
    beta = bnn_config["beta"]

    pickle.dump(bnn_config, open(f"{save_dir}/bnn_config_s_{seed}.pkl", "wb"))

    bg_bnn_model, bnn_bg_states, bg_disc_states = train_bg_bnn_model(seed, data_loader, epochs, num_cycles, beta, lr_0, disc_lr_0,
                                                                     hidden_sizes, temp, sigma, eta, mu, J, act_fn)
    bnn_bg_rmse_test, _ = score_bg_bnn_model(bg_bnn_model, X_test, y_test, bnn_bg_states, bg_disc_states)
    dropout_loss = get_feats_dropout_loss(bg_bnn_model, bnn_bg_states, bg_disc_states, X_val, y_val)
    dropout_loss_lst.append(dropout_loss)
    bnn_fdr = false_selection_rate(dropout_loss["feats_idx"][:s].to_list())

    rf_model = train_rf_model(seed, X_train_outer_df, y_train_outer_df, train_indices, val_indices)
    rf_rmse_test, _ = eval_rf_model(rf_model, X_test, y_test)
    rf_fdr = false_selection_rate(np.argsort(rf_model.feature_importances_)[::-1][:s])

    pickle.dump(rf_model, open(f"{save_dir}/rf_model_{seed}.pkl", "wb"))

    res_ft_dict_2["model"].append("BNN")
    res_ft_dict_2["rmse"].append(bnn_bg_rmse_test)
    res_ft_dict_2["fdr"].append(bnn_fdr)

    res_ft_dict_2["model"].append("RF")
    res_ft_dict_2["rmse"].append(rf_rmse_test)
    res_ft_dict_2["fdr"].append(rf_fdr)

100%|██████████| 10/10 [39:38<00:00, 237.81s/it]


In [16]:
res_df_2 = pd.read_csv(f"{save_dir}/res_df_hn_100.csv")
res_df_2

Unnamed: 0,model,rmse,fdr
0,BNN,2.70159,0.2
1,RF,2.693497,0.8
2,BNN,2.766211,0.2
3,RF,2.935442,0.8
4,BNN,3.048509,0.0
5,RF,3.190772,0.8
6,BNN,2.827833,0.0
7,RF,2.737681,1.0
8,BNN,2.594652,0.2
9,RF,2.853587,0.4


In [17]:
res_df_2.groupby(['model'])["rmse"].mean()

model
BNN    2.775331
RF     2.818817
Name: rmse, dtype: float64

In [18]:
res_df_2.groupby(['model'])["fdr"].mean()

model
BNN    0.20
RF     0.76
Name: fdr, dtype: float64

### Wojtas et.al 2020