In [1]:
import tempfile
!pip uninstall jax[cuda11] -y
!pip install --no-cache-dir --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install tensorflow-probability
!pip install numpyro
# !pip install distrax
!pip install git+https://github.com/blackjax-devs/blackjax.git

Found existing installation: jax 0.3.17
Uninstalling jax-0.3.17:
  Successfully uninstalled jax-0.3.17
[0mLooking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda]
  Downloading jax-0.3.17.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... [?25ldone
[?25h  Created wheel for jax: filename=jax-0.3.17-py3-none-any.whl size=1217849 sha256=aa59632f8c506485b13b2ca8c85b03abf6797acbb5fa7201c25c768ec65079f1
  Stored in directory: /tmp/pip-ephem-wheel-cache-v677ugag/wheels/36/cd/88/2d90379f7549c27d5654e893f74210f30f0c645c23a71e6f56
Successfully built jax
Installing collected packages: jax
Successfully installed jax-0.3.17
[0mCollecting numpyro
  Using cached numpyro-0.10.1-py3-none-any.whl (292 kB)
Collecting m

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
plt.style.use('ggplot')
%load_ext autoreload

In [3]:
jax.version.__version__

'0.3.17'

In [4]:
jax.default_backend()

'gpu'

In [5]:
# seed = 1234
# np.random.seed(seed)
# rng_key = jax.random.PRNGKey(seed)

In [6]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpyro as npyro
import scipy.stats as stats
import warnings
warnings.filterwarnings("ignore")

def run_fs_logreg(X_train, X_test, y_train, y_test, feats):
    results = {"log_cv_score": [], "log_test_score": []}
    for fts in feats:
        X_s_train, X_s_test = X_train[:,np.array(fts, dtype=np.int32)].astype(np.int64), X_test[:,np.array(fts, dtype=np.int32)].astype(np.int64)
        y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)
        if fts.size == 1:
            X_s_train, X_s_test = X_s_train.reshape(-1, 1), X_s_test.reshape(-1, 1)
        cv_score = np.mean(cross_val_score(LogisticRegression(), X_s_train, y_train, scoring="roc_auc"))
        log_est = LogisticRegression(max_iter=1000).fit(X_s_train, y_train)
        test_score = roc_auc_score(y_test, log_est.predict_proba(X_s_test)[:,1])
        # print({"moses_cv_score": cv_score, "moses_test_score": test_score, "log_cv_score": cv_score, "log_test_score": test_score})
        results["log_cv_score"].append(cv_score)
        results["log_test_score"].append(test_score)

    return pd.DataFrame(results)


def fisher_exact_test(X, y, thres=0.05):
    cols = X.columns
    p_values = np.zeros(len(cols))
    for i, col in enumerate(cols):
        table = pd.crosstab(y, X[col])
        _, p_val = stats.fisher_exact(table, alternative="two-sided")
        p_values[i] = p_val

    idx_sig = np.argwhere(p_values < thres)
    print(f"Total of {len(idx_sig)} variables are significant (p_val = {thres})")

    return idx_sig


def build_network(X):
    p = X.shape[1]
    J = np.zeros((p, p))
    cols = X.columns
    intrs = []
    intrs_rev = []
    for i, g1 in enumerate(cols):
        try:
            g_intrs = list(net_intr[g1])
            for g2 in g_intrs:
                if (g2, g1) not in intrs_rev: # check if we haven't encountered the reverse interaction
                    j = cols.get_loc(g2)
                    J[i, j] = 1.0
                    J[j, i] = 1.0
                    intrs.append((g1, g2))
        except KeyError:
            continue

        # Check the reverse direction
        try:
            g_intrs_rev = list(net_intr_rev[g1])
            for g2 in g_intrs_rev:
                if (g1, g2) not in intrs:
                    j = cols.get_loc(g2)
                    J[i, j] = 1.0
                    J[j, i] = 1.0
                    intrs_rev.append((g2, g1))

        except KeyError:
            continue


    return J

def get_ess(n_chain, samples):
    k = int(samples.shape[0] / n_chain)
    chains = samples.reshape(n_chain, k, samples.shape[-1])
    ess = npyro.diagnostics.effective_sample_size(jax.device_get(chains))
    ess[np.isnan(ess)] = 1.0
    return ess

In [7]:
from typing import Callable, NamedTuple
from blackjax.types import PRNGKey, PyTree


class MixedMALAState(NamedTuple):
    """Holds info about the discrete and the continuous r.vs in the mixed support"""

    discrete_position: PyTree
    contin_position: PyTree

    disc_logprob: float
    contin_logprob: float

    discrete_logprob_grad: PyTree
    contin_logprob_grad: PyTree

    disc_step_size: float
    contin_step_size: float


from blackjax.mcmc.diffusion import generate_gaussian_noise
from blackjax.mcmc.mala import MALAState

EPS = 1e-10


def diff_fn(state, step_size):
    theta = jax.tree_util.tree_map(lambda x, g: -0.5 * (g) * (2. * x - 1) - (1. / (2. * step_size)),
                                   state.position, state.logprob_grad)

    return jax.nn.sigmoid(theta)


def take_discrete_step(rng_key: PRNGKey, disc_state: MALAState, contin_state: MALAState,
                       logprob_fn: Callable, disc_grad_fn: Callable,
                       step_size: float) -> MALAState:
    _, key_rmh, key_accept = jax.random.split(rng_key, 3)
    theta_cur = disc_state.position

    u = jax.random.uniform(key_rmh, shape=disc_state.position.shape)
    p_curr = diff_fn(disc_state, step_size)
    ind = jnp.array(u < p_curr)
    pos_new = (1. - theta_cur) * ind + theta_cur * (1. - ind)

    logprob_new = logprob_fn(pos_new, contin_state.position)
    logprob_grad_new = disc_grad_fn(pos_new, contin_state.position)
    new_state = MALAState(pos_new, logprob_new, logprob_grad_new)  # No metropolis update just accept the move

    return new_state


def take_contin_step(rng_key: PRNGKey, disc_state: MALAState, contin_state: MALAState,
                     logprob_fn: Callable, contin_grad_fn: Callable,
                     step_size: float) -> MALAState:
    key_integrator, key_rmh = jax.random.split(rng_key)
    noise = generate_gaussian_noise(key_integrator, contin_state.position)
    new_position = jax.tree_util.tree_map(
        lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n,
        contin_state.position,
        contin_state.logprob_grad,
        noise,
    )

    logprob_new = logprob_fn(disc_state.position, new_position)
    logprob_grad_new = contin_grad_fn(disc_state.position, new_position)
    new_state = MALAState(new_position, logprob_new, logprob_grad_new)

    return new_state


def one_step(
        rng_key: PRNGKey, state: MixedMALAState,
        discrete_logprob_fn: Callable, contin_logprob_fn: Callable,
        discrete_step_size: float, contin_step_size: float
) -> MixedMALAState:
    disc_grad_fn = jax.grad(discrete_logprob_fn)
    contin_grad_fn = jax.grad(contin_logprob_fn, argnums=1)
    # Evolve each variable in tandem and combine the results

    disc_state = MALAState(state.discrete_position, state.disc_logprob, state.discrete_logprob_grad)
    contin_state = MALAState(state.contin_position, state.contin_logprob, state.contin_logprob_grad)
    # print(f"disc pos: {disc_state.position}, contin pos: {contin_state.position}")
    # Take a step for the discrete variable - sample from p(discrete | contin)
    new_disc_state = take_discrete_step(rng_key, disc_state, contin_state,
                                        discrete_logprob_fn, disc_grad_fn, discrete_step_size)
    # Take a step for the contin variable - sample from p(contin | new_discrete)
    new_contin_state = take_contin_step(rng_key, new_disc_state, contin_state,
                                        contin_logprob_fn, contin_grad_fn, contin_step_size)

    new_state = MixedMALAState(new_disc_state.position, new_contin_state.position,
                               new_disc_state.logprob, new_contin_state.logprob,
                               new_disc_state.logprob_grad, new_contin_state.logprob_grad,
                               discrete_step_size, contin_step_size)

    return new_state

def init(disc_position: PyTree,contin_position: PyTree,
         disc_logprob_fn: Callable, contin_logprob_fn: Callable,
         init_disc_step: float, init_contin_step: float) -> MixedMALAState:

    disc_logprob, disc_grad_logprob = jax.value_and_grad(disc_logprob_fn)(disc_position, contin_position)
    contin_logprob, contin_grad_logprob = jax.value_and_grad(contin_logprob_fn, argnums=1)(disc_position, contin_position)

    return MixedMALAState(disc_position, contin_position,
                          disc_logprob, contin_logprob,
                          disc_grad_logprob, contin_grad_logprob,
                          init_disc_step, init_contin_step)

In [8]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

def inference_loop_multiple_chains(rng_key, kernel, initial_state, num_samples, num_chains):

    @jax.jit
    def one_step(state, step_key):
        subkeys = jax.random.split(step_key, num_chains)
        state = jax.vmap(kernel)(subkeys, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    jnp.save(f"{data_dir}/exp_data_2/rand_test/local/split_keys.npy", keys)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [9]:
def gamma_energy(theta, J, eta, mu):
    xg = theta.T @ J
    xgx = xg @ theta
    return eta*xgx - mu*jnp.sum(theta)

def generate_disc_logprob_fn(X, y, J, mu, eta):

    def discrete_logprob_fn(gamma, beta):
        # beta = pos["beta"]
        X_gamma = (X @ jnp.diag(gamma))
        ising_logp = gamma_energy(gamma, J, eta, mu)
        ll_dist = tfd.Bernoulli(logits=(X_gamma @ beta))
        log_ll = jnp.sum(ll_dist.log_prob(y), axis=0)

        # print(f"gamma logp: {ising_logp}, log_ll: {log_ll}")

        return ising_logp + log_ll

    return discrete_logprob_fn


def generate_contin_logprob_fn(X, y, tau, c):
    n, p = X.shape
    cov = X.T @ X
    R = np.identity(p)
    v, l = 1., 1.

    def contin_logprob_fn(gamma, beta):
        # beta = pos["beta"]

        D = (gamma*c*tau) + (1 - gamma)*(tau)
        # D_inv = jnp.linalg.inv(jnp.diag(D))

        # A = jnp.linalg.inv((1./sigma**2)*cov + (D_inv @ R @ D_inv))
        beta_dist = tfd.MultivariateNormalDiag(loc=jnp.zeros(p), scale_diag=D)
        # print(beta_dist.sample(seed=rng_key))
        beta_logp = beta_dist.log_prob(beta)
        X_gamma = (X @ jnp.diag(gamma))
        ll_dist = tfd.Bernoulli(logits=(X_gamma @ beta))
        log_ll = jnp.sum(ll_dist.log_prob(y), axis=0)

        # print(f"beta logp: {beta_logp}, log_ll: {log_ll}")

        return beta_logp + log_ll

    return contin_logprob_fn

In [10]:
# data_dir = "/home/xabush/code/snet/moses-incons-pen-xp/data"
data_dir = "."
tamox_df = pd.read_csv(f"{data_dir}/tamoxBinaryEntrez.csv")
tamox_df.head()

Unnamed: 0,posOutcome,4111,4110,10661,131,4438,330,1109,2637,2642,...,7634,55769,7637,7644,741,54993,79364,7791,23140,26009
0,0,0,0,0,1,0,1,1,1,1,...,0,0,0,1,1,0,0,1,0,0
1,1,1,0,0,0,0,0,1,0,1,...,0,0,0,1,1,0,0,1,0,1
2,0,0,0,0,1,0,0,1,1,1,...,0,0,0,1,1,0,0,1,0,0
3,0,0,0,0,0,0,0,1,1,1,...,0,0,0,1,0,0,0,1,0,0
4,1,0,0,0,0,0,1,1,1,1,...,0,0,0,1,1,0,0,1,0,0


In [11]:
regnet_df = pd.read_table(f"{data_dir}/human.source", sep="\t", header=None, names= ["REGULATOR SYMBOL", "REGULATOR ID", "TARGET SYMBOL", "TARGET ID"])
print(f"Total interactions: {regnet_df.shape[0]}")
regnet_df.head()

Total interactions: 372774


Unnamed: 0,REGULATOR SYMBOL,REGULATOR ID,TARGET SYMBOL,TARGET ID
0,USF1,7391,S100A6,6277
1,USF1,7391,DUSP1,1843
2,USF1,7391,C4A,720
3,USF1,7391,ABCA1,19
4,TP53,7157,TP73,7161


In [12]:
net_intr = pd.Series(regnet_df["REGULATOR ID"].values, index=regnet_df["TARGET ID"])
net_intr_rev = pd.Series(regnet_df["TARGET ID"].values, index=regnet_df["REGULATOR ID"])
X_df, y_df = tamox_df.iloc[:, 1:], tamox_df["posOutcome"]


In [13]:

def run_exp(X_train, X_test, y_train, y_test, J, vals):

    num_samples = 10000
    num_chains = 3
    res_dict = {"eta": [], "mu": [], "thres": [], "cv_score": [], "test_score": [], "len": [], "num_edges": [],
                "beta_cv_score": [], "beta_test_score": [] }
    for i, eta in enumerate(tqdm(vals)):
        for j, mu in enumerate(vals):
            # print(f"eta - {eta:.2f}, mu - {mu:.2f}")

            contin_init_pos = init_beta
            disc_init_pos = init_gamma

            disc_logprob = generate_disc_logprob_fn(X_train, y_train, J, mu, eta)
            contin_logprob = generate_contin_logprob_fn(X_train, y_train, tau, c)
            kernel = jax.jit(lambda key, state: one_step(key, state, disc_logprob, contin_logprob, disc_step_size, contin_step_size))

            init_state = jax.vmap(init, in_axes=(0, 0, None, None, None, None))(disc_init_pos, contin_init_pos, disc_logprob,
                                                                                contin_logprob,
                                                                                disc_step_size, contin_step_size)
            states = inference_loop_multiple_chains(rng_key, kernel, init_state, num_samples=num_samples, num_chains=num_chains)

            gamma_samples = states.discrete_position[int(burn_in*num_samples):]
            beta_samples = states.contin_position[int(burn_in*num_samples):]
            gamma_samples = gamma_samples.reshape(-1, p)
            beta_samples = beta_samples.reshape(-1, p)

            for t in thresholds:
                res_dict["eta"].append(eta)
                res_dict["mu"].append(mu)
                res_dict["thres"].append(t)
                idx = jnp.squeeze(jnp.argwhere((jnp.mean(gamma_samples, axis=0) > t)))

                num_edges = jnp.count_nonzero(J[idx,:][:,idx]) if idx.size > 1 else 0


                res_dict["len"].append(idx.size)
                res_dict["num_edges"].append(num_edges)

                if idx.size > 0:
                    res_idx_df = run_fs_logreg(X_train, X_test, y_train, y_test, [idx])

                    res_dict["cv_score"].append(res_idx_df["log_cv_score"][0])
                    res_dict["test_score"].append(res_idx_df["log_test_score"][0])

                    beta_sel = jnp.mean(beta_samples[:,idx], axis=0)

                    if idx.size == 1:
                        beta_sel = beta_sel.reshape(-1, 1)
                        X_train_idx_sel = X_train[:,idx].reshape(-1, 1)
                        X_test_idx_sel = X_test[:,idx].reshape(-1, 1)

                    else:
                        X_train_idx_sel = X_train[:,idx]
                        X_test_idx_sel = X_test[:,idx]

                    train_roc = roc_auc_score(y_train, jax.nn.sigmoid((X_train_idx_sel @ beta_sel)))
                    test_roc = roc_auc_score(y_test, jax.nn.sigmoid((X_test_idx_sel @ beta_sel)))
                    res_dict["beta_cv_score"].append(train_roc)
                    res_dict["beta_test_score"].append(test_roc)

                else:
                    res_dict["cv_score"].append(np.nan)
                    res_dict["test_score"].append(np.nan)
                    res_dict["beta_cv_score"].append(np.nan)
                    res_dict["beta_test_score"].append(np.nan)

    # res_df = pd.DataFrame(res_dict)

    print(f"========= Done for seed - {seed} =========")

    return res_dict

In [14]:
# res_dict = run_exp(X_train, X_test, y_train, y_test, J_1, param_vals)

In [49]:
res_df_1 = pd.read_csv(f"{data_dir}/gibbs_sampling_v2_tamox_param_sweep.csv")
res_df_1

Unnamed: 0,eta,mu,thres,cv_score,test_score,len,num_edges,beta_cv_score,beta_test_score
0,0.0,0.0,0.1,0.858077,0.716653,1490,218,0.926824,0.701740
1,0.0,0.0,0.2,0.858077,0.716653,1490,218,0.926824,0.701740
2,0.0,0.0,0.3,0.858077,0.716653,1490,218,0.926824,0.701740
3,0.0,0.0,0.4,0.857444,0.715962,1486,218,0.926368,0.701326
4,0.0,0.0,0.5,0.879019,0.699531,746,24,0.931636,0.707263
...,...,...,...,...,...,...,...,...,...
1084,10.0,10.0,0.5,0.732805,0.683651,116,208,0.846935,0.696355
1085,10.0,10.0,0.6,0.732805,0.683651,116,208,0.846935,0.696355
1086,10.0,10.0,0.7,0.720049,0.693041,101,188,0.828647,0.685998
1087,10.0,10.0,0.8,0.726779,0.698702,97,184,0.827254,0.683513


In [15]:
import os
import glob
def load_mcmc_exp_res(data_dir, moses_res=False):
    df_lst = []

    for file in glob.glob(f"{data_dir}/*.csv"):
        df = pd.read_csv(file)
        if moses_res:
            df = df.sort_values(by="moses_val_score", ascending=False).head(1) # take the top row of the sorted df
        else:
            df = df.sort_values(by="cv_score", ascending=False).head(1)
        df_lst.append(df)

    res_df = pd.concat(df_lst, axis=0)
    res_df = res_df.reset_index(drop=True)
    return res_df

def load_mcmc_exp_len(data_dir, size):
    df_lst = []

    for file in glob.glob(f"{data_dir}/*.csv"):
        df = pd.read_csv(file)
        df = df[df["len"] <= size]
        df = df.sort_values(by="cv_score", ascending=False).head(1)
        df_lst.append(df)

    res_df = pd.concat(df_lst, axis=0)
    res_df = res_df.reset_index(drop=True)
    return res_df

In [13]:
mcmc_params_df = load_mcmc_exp_res(f"{data_dir}/exp_data_2")
mcmc_params_df

Unnamed: 0,seed,eta,mu,thres,cv_score,test_score,len,num_edges,beta_cv_score,beta_test_score
0,644,0.0,0.16681,0.5,0.856467,0.69815,230,4,0.903749,0.708506
1,490,0.0,0.1,0.5,0.85013,0.748412,408,14,0.914032,0.732532
2,805,0.464159,0.0,0.5,0.848774,0.731566,891,246,0.913501,0.718586
3,256,1.29155,0.0,0.5,0.880548,0.698426,661,160,0.902229,0.718034
4,675,0.0,0.16681,0.5,0.895344,0.694283,218,2,0.911499,0.719138
5,781,1.29155,0.0,0.5,0.874963,0.732118,616,160,0.901114,0.714996
6,350,0.16681,0.1,0.5,0.906605,0.718724,401,164,0.918262,0.694421
7,925,0.1,0.0,0.5,0.881103,0.673019,730,140,0.892072,0.681994
8,947,10.0,0.1,0.5,0.860583,0.705468,551,242,0.863323,0.660315
9,549,5.994843,0.0,0.5,0.876115,0.688484,635,120,0.883536,0.710577


In [14]:
mcmc_params_size_100_df = load_mcmc_exp_len(f"{data_dir}/exp_data_2", 100)
mcmc_params_size_100_df

Unnamed: 0,seed,eta,mu,thres,cv_score,test_score,len,num_edges,beta_cv_score,beta_test_score
0,644,0.1,0.0,0.6,0.816948,0.654377,26,22,0.72652,0.669566
1,490,0.0,0.278256,0.5,0.822075,0.675366,61,0,0.860106,0.664457
2,805,0.0,0.278256,0.5,0.820818,0.652168,67,0,0.832751,0.635874
3,256,0.0,0.278256,0.5,0.811286,0.689036,68,0,0.867427,0.678956
4,675,0.1,0.278256,0.5,0.839227,0.706573,83,32,0.842224,0.7219
5,781,0.16681,0.774264,0.4,0.800173,0.644988,55,58,0.812766,0.668186
6,350,0.0,0.278256,0.5,0.857158,0.70533,59,0,0.802558,0.609362
7,925,0.16681,0.278256,0.5,0.872636,0.699669,99,74,0.791489,0.661281
8,947,0.1,0.0,0.6,0.804452,0.73129,32,28,0.691033,0.589202
9,549,1.29155,0.0,0.6,0.816686,0.697045,83,120,0.837259,0.696631


In [14]:
# %cd /home/xabush/code/snet/moses-incons-pen-xp
import time
import joblib
import datetime
import itertools
# from notebooks.variable_selection.MosesEstimator import *


def run_gibbs_sampling(seed, X_df, y_df, eta, mu, thres):
    start_time = time.time()
    num_chains = 2
    disc_step_size = 0.1
    contin_step_size = 1e-5
    n_steps = 10000
    tau, c = 0.01, 1000
    burn_in = 0.1

    data_path = f"{data_dir}/exp_data_2"

    key = jax.random.PRNGKey(seed)
    jnp.save(f"{data_path}/rand_test/local/jax_key.npy", key)
    np.random.random(seed)
    X_train, X_test, y_train, y_test = train_test_split(X_df, y_df, shuffle=True, random_state=seed,
                                                        stratify=y_df, test_size=0.3)
    idx_sig = np.load(f"{data_dir}/exp_data_2/npy/idx_sig_s_{seed}.npy")
    X_train, X_test = X_train.iloc[:,idx_sig], X_test.iloc[:,idx_sig]
    np.save(f"{data_path}/rand_test/local/X_train.npy", X_train)
    np.save(f"{data_path}/rand_test/local/X_test.npy", X_test)
    np.save(f"{data_path}/rand_test/local/y_train.npy", y_train)
    np.save(f"{data_path}/rand_test/local/y_test.npy", y_test)
    J = build_network(X_train)
    p = J.shape[1]
    print(f"dim: {p}")
    np.save(f"{data_path}/rand_test/local/J_mat.npy", J)
    beta_dist = tfd.MultivariateNormalDiag(loc=jnp.zeros(p), scale_diag=10 * jnp.ones(p))
    gamma_dist = tfd.Bernoulli(probs=0.5 * jnp.ones(p))

    contin_init_pos = beta_dist.sample(seed=key, sample_shape=(num_chains,))
    disc_init_pos = gamma_dist.sample(seed=key, sample_shape=(num_chains,)) * 1.

    # contin_init_pos = beta_dist.sample(seed=key)
    # disc_init_pos = gamma_dist.sample(seed=key) * 1.

    jnp.save(f"{data_dir}/exp_data_2/rand_test/local/contin_init_pos.npy", contin_init_pos)
    jnp.save(f"{data_dir}/exp_data_2/rand_test/local/disc_init_pos.npy", disc_init_pos)

    X_train_dev, y_train_dev = jax.device_put(X_train.to_numpy()), jax.device_put(y_train.to_numpy())
    disc_logprob = generate_disc_logprob_fn(X_train_dev, y_train_dev, J, mu, eta)
    contin_logprob = generate_contin_logprob_fn(X_train_dev, y_train_dev, tau, c)

    kernel = jax.jit(lambda key, state: one_step(key, state, disc_logprob, contin_logprob, disc_step_size, contin_step_size))

    init_state = jax.vmap(init, in_axes=(0, 0, None, None, None, None))(disc_init_pos, contin_init_pos, disc_logprob, contin_logprob,
                                                                        disc_step_size, contin_step_size)

    # init_state = init(disc_init_pos, contin_init_pos, disc_logprob, contin_logprob,
    #                                  disc_step_size, contin_step_size)

    states = inference_loop_multiple_chains(key, kernel, init_state, num_samples=n_steps, num_chains=num_chains)
    # states = inference_loop(key, kernel, init_state, num_samples=n_steps)
    gamma_samples = states.discrete_position[int(burn_in*n_steps):]
    beta_samples = states.contin_position[int(burn_in*n_steps):]

    np.save(f"{data_path}/rand_test/local/gamma_samples.npy", gamma_samples)
    np.save(f"{data_path}/rand_test/local/beta_samples.npy", beta_samples)

    gamma_samples = gamma_samples.reshape(-1, p)

    # beta_samples = beta_samples.reshape(-1, p)

    gamma_means = jnp.mean(gamma_samples, axis=0)
    idx = jnp.squeeze(jnp.argwhere(gamma_means > thres))
    print(f"---- Inference took {(time.time() - start_time) : .2f} seconds -----")
    return idx


def run_moses_on_fs(X_train, X_test, y_train, y_test, seed, complx_ratios=None, tmps=None, div_pres=None,
                    feats=None, init_exemplar=None):

    s_time = time.time()

    if complx_ratios is None:
        complx_ratios = [5, 10, 100, 1000]
    if tmps is None:
        tmps = [10, 100, 1000, 2000]
    if div_pres is None:
        div_pres = [0.0, 0.3, 0.6, 0.9]

    if feats is None:
        X_s_train, X_s_test = X_train.astype(np.int64), X_test.astype(np.int64)
    else:
        X_s_train, X_s_test = X_train[:,feats].astype(np.int64), X_test[:,feats].astype(np.int64)
    y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)

    x_train_cv, x_test_cv, y_train_cv, y_test_cv = train_test_split(X_s_train, y_train, shuffle=True, random_state=seed, stratify=y_train, test_size=0.3)

    def run_moses(cr, temp, div):
        if init_exemplar is None:
            moses_opts = ["--complexity-temperature", f"{temp}", "--hc-crossover-min-neighbors", "500", "--hc-crossover-pop-size",
                          "100", "--hc-fraction-of-nn", "0.01", "--diversity-autoscale", "1",
                          "--diversity-pressure", f"{div}"]
        else:
            moses_opts = ["--complexity-temperature", f"{temp}", "--hc-crossover-min-neighbors", "500", "--hc-crossover-pop-size",
                          "100", "--hc-fraction-of-nn", "0.01", "-e", init_exemplar, "--diversity-autoscale", "1",
                          "--diversity-pressure", f"{div}"]
        start_time = time.time()

        moses_est_r = MosesEstimator(fs_algo=None, complexity_ratio=cr, num_models=100,
                                     random_state=seed, ensemble=False, num_evals=1000).fit(x_train_cv, y_train_cv, moses_params=moses_opts)

        cv_score = MosesEstimator.score(moses_est_r, x_test_cv, y_test_cv)

        # moses_est_r = MosesEstimator(fs_algo=None, complexity_ratio=cr, num_models=100,
        #                              random_state=seed, ensemble=False, num_evals=30000).fit(X_s_train, y_train, moses_params=moses_opts)

        # train_score = MosesEstimator.score(moses_est_r, X_s_train, y_train)
        test_score = MosesEstimator.score(moses_est_r, X_s_test, y_test)
        res = {"seed": seed, "complexity_ratio": cr, "complexity_tmp": temp, "div_pres": div,
               "moses_val_score": cv_score, "moses_test_score": test_score}
        end_time = time.time()

        print(f"cr: {cr:.2f}, tmp: {temp: .2f}, div: {div: .2f} ,moses_tr: {cv_score: .4f}, test_score: {test_score: .4f}")
        print(f"============== Took {datetime.timedelta(seconds=(end_time - start_time))} ===============")
        print(len(moses_est_r.models_))
        moses_est_r.cleanup()
        return res

    results = joblib.Parallel(n_jobs=joblib.cpu_count(), require="sharedmem")(
        joblib.delayed(run_moses)(cr, temp, div)
        for cr, temp, div in list(itertools.product(complx_ratios, tmps, div_pres))
    )

    df_dict = {"seed": [] ,"complexity_ratio": [], "complexity_tmp": [], "div_pres": [],
               "moses_val_score": [], "moses_test_score": []}
    for res in results:
        df_dict["seed"].append(res["seed"])
        df_dict["complexity_ratio"].append(res["complexity_ratio"])
        df_dict["complexity_tmp"].append(res["complexity_tmp"])
        df_dict["div_pres"].append(res["div_pres"])
        df_dict["moses_val_score"].append(res["moses_val_score"])
        df_dict["moses_test_score"].append(res["moses_test_score"])

    e_time = time.time()
    print(f"Total elapsed time: {datetime.timedelta(seconds=(e_time - s_time))}")

    return pd.DataFrame(df_dict)


def run_moses(seed, X, y, complexity_ratio=10,
              div_pres=0.6, temp=2000, init_exemplar=None, hnn=0.1, size=-1):

    start_time = time.time()


    X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=seed,
                                                        stratify=y, test_size=0.3)

    idx_sig = np.load(f"{data_dir}/exp_data_2/npy/idx_sig_s_{seed}.npy")

    X_train, X_test = X_train.iloc[:,idx_sig].to_numpy(), X_test.iloc[:,idx_sig].to_numpy()
    y_train, y_test = y_train.to_numpy(), y_test.to_numpy()

    if size == -1:
        feats = np.load(f"{data_dir}/exp_data_2/idx_sel/idx_sel_s_{seed}.npy")
    else:
        feats = np.load(f"{data_dir}/exp_data_2/idx_sel_{size}/idx_sel_s_{seed}.npy")
    print(f"Selected feats len: {feats.size}")

    if feats is None:
        X_s_train, X_s_test = X_train.astype(np.int64), X_test.astype(np.int64)
    else:
        X_s_train, X_s_test = X_train[:,feats].astype(np.int64), X_test[:,feats].astype(np.int64)
    y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)

    x_train_cv, x_test_cv, y_train_cv, y_test_cv = train_test_split(X_s_train, y_train, shuffle=True, random_state=seed, stratify=y_train,
                                                                    test_size=0.3)

    if init_exemplar is None:
        moses_opts = ["--complexity-temperature", f"{temp}", "--hc-crossover-min-neighbors", "500", "--hc-crossover-pop-size",
                      "100", "--hc-fraction-of-nn", f"{hnn}", "--diversity-autoscale", "1",
                      "--diversity-pressure", f"{div_pres}", "-l", "DEBUG"]
    else:
        moses_opts = ["--complexity-temperature", f"{temp}", "--hc-crossover-min-neighbors", "500", "--hc-crossover-pop-size",
                      "100", "--hc-fraction-of-nn", f"{hnn}", "-e", init_exemplar, "--diversity-autoscale", "1",
                      "--diversity-pressure", f"{div_pres}"]


    moses_est_r = MosesEstimator(fs_algo=None, complexity_ratio=complexity_ratio, num_models=100,
                                 random_state=seed, ensemble=False, num_evals=1000).fit(x_train_cv, y_train_cv, moses_params=moses_opts)

    cv_score = MosesEstimator.score(moses_est_r, x_test_cv, y_test_cv)

    # moses_est_r = MosesEstimator(fs_algo=None, complexity_ratio=cr, num_models=100,
    #                              random_state=seed, ensemble=False, num_evals=30000).fit(X_s_train, y_train, moses_params=moses_opts)

    # train_score = MosesEstimator.score(moses_est_r, X_s_train, y_train)
    test_score = MosesEstimator.score(moses_est_r, X_s_test, y_test)

    end_time = time.time()

    print(f"cr: {complexity_ratio:.2f}, tmp: {temp: .2f}, div: {div_pres: .2f} ,moses_tr: {cv_score: .4f}, test_score: {test_score: .4f}")
    print(f"============== Took {datetime.timedelta(seconds=(end_time - start_time))} ===============")

    return moses_est_r

def run_moses_on_mcmc_res(res_param_df, X, y):

    data_path = f"{data_dir}/exp_data_2"

    # rem_seeds = [23, 99, 763, 1234, 464]

    for i, row  in res_param_df.iterrows():
        seed = int(row["seed"])
        if seed != 1234: # For now, skip seed 1234 as there seems to be reproducibility issue
            np.random.seed(seed)
            idx_sig = np.load(f"{data_path}/npy/idx_sig_s_{seed}.npy")
            X_train, X_test, y_train, y_test =  train_test_split(X, y, test_size=0.3, random_state=seed, stratify=y, shuffle=True)
            X_train, X_test = X_train.iloc[:,idx_sig], X_test.iloc[:,idx_sig]
            eta, mu = row["eta"], row["mu"]
            thres, feat_size = row["thres"], int(row["len"])
            log_cv_score, log_test_score = row["cv_score"], row["test_score"]
            idx = np.load(f"{data_path}/idx_sel_100/idx_sel_s_{seed}.npy")
            print(f"seed; {seed}, eta: {eta : .3f}, mu: {mu: .3f}, thres: {thres}, len: {idx.size}, cv_score: {log_cv_score: .3f}, test_score: {log_test_score: .3f}")
            X_train, X_test = X_train.iloc[:,idx].to_numpy(), X_test.iloc[:,idx].to_numpy()
            y_train, y_test = y_train.to_numpy(), y_test.to_numpy()

            res_moses_df = run_moses_on_fs(X_train, X_test, y_train, y_test, seed)
            res_moses_df.to_csv(f"{data_path}/moses_hnn_001/res_moses_s_{seed}.csv", index=False)

            print(f" ====== Done for seed - {seed} ============")


    # for i, row  in res_param_df.iterrows():
    #     seed = int(row["seed"])
    #     if seed in rem_seeds: # For now, skip seed 1234 as there seems to be reproducibility issue
    #         np.random.seed(seed)
    #         key = jax.random.PRNGKey(seed)
    #         idx_sig = np.load(f"{data_path}/npy/idx_sig_s_{seed}.npy")
    #         X_train, X_test, y_train, y_test =  train_test_split(X, y, test_size=0.3, random_state=seed, stratify=y, shuffle=True)
    #         X_train, X_test = X_train.iloc[:,idx_sig], X_test.iloc[:,idx_sig]
    #         J = build_network(X_train)
    #         np.fill_diagonal(J, 0.0)
    #         eta, mu = row["eta"], row["mu"]
    #         thres, feat_size = row["thres"], int(row["len"])
    #         log_cv_score, log_test_score = row["cv_score"], row["test_score"]
    #         print(f"seed; {seed}, eta: {eta : .3f}, mu: {mu: .3f}, thres: {thres}, len: {feat_size}, cv_score: {log_cv_score: .3f}, test_score: {log_test_score: .3f}")
    #         idx = run_gibbs_sampling(key, X_train, y_train, J, eta, mu, thres)
    #         np.save(f"{data_path}/idx_sel_100/idx_sel_s_{seed}.npy", idx)
    #         # X_train, X_test = X_train.iloc[:,idx].to_numpy(), X_test.iloc[:,idx].to_numpy()
    #         # y_train, y_test = y_train.to_numpy(), y_test.to_numpy()
    #         #
    #         # res_moses_df = run_moses_on_fs(X_train, X_test, y_train, y_test, seed)
    #         # res_moses_df.to_csv(f"{data_path}/moses/res_moses_s_{seed}.csv")
    #
    #         print(f" ====== Done for seed - {seed} ============")

In [15]:
idx = run_gibbs_sampling(644, X_df, y_df, 0.0,0.16681005372000587, 0.5)
len(idx)

dim: 1512
---- Inference took  43.85 seconds -----


280

In [15]:
# run_moses_on_mcmc_res(mcmc_params_size_100_df, X_df, y_df)

seed; 99, eta:  0.000, mu:  0.278, thres: 0.5, len: 68, cv_score:  0.808, test_score:  0.641
---- Inference took  39.87 seconds -----
seed; 1234, eta:  0.000, mu:  0.278, thres: 0.5, len: 63, cv_score:  0.845, test_score:  0.670
---- Inference took  136.57 seconds -----
seed; 23, eta:  0.000, mu:  0.278, thres: 0.5, len: 55, cv_score:  0.815, test_score:  0.634
---- Inference took  170.85 seconds -----
seed; 464, eta:  0.167, mu:  0.000, thres: 0.6, len: 43, cv_score:  0.791, test_score:  0.693
---- Inference took  89.76 seconds -----
seed; 763, eta:  0.100, mu:  0.278, thres: 0.5, len: 79, cv_score:  0.843, test_score:  0.717
---- Inference took  84.07 seconds -----


In [18]:
# run_moses_on_mcmc_res(mcmc_params_df, X_df, y_df)

seed; 464, eta:  5.995, mu:  0.000, thres: 0.5, len: 703, cv_score:  0.869, test_score:  0.694
cr: 5.00, tmp:  10.00, div:  0.90 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 5.00, tmp:  10.00, div:  0.30 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 5.00, tmp:  10.00, div:  0.60 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 5.00, tmp:  10.00, div:  0.00 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 10.00, tmp:  10.00, div:  0.00 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 5.00, tmp:  1000.00, div:  0.30 ,moses_tr:  0.8161, test_score:  0.7145
100
cr: 10.00, tmp:  10.00, div:  0.30 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 10.00, tmp:  10.00, div:  0.90 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 5.00, tmp:  2000.00, div:  0.60 ,moses_tr:  0.7765, test_score:  0.6912
100
cr: 10.00, tmp:  10.00, div:  0.60 ,moses_tr:  0.6105, test_score:  0.6038
2
cr: 5.00, tmp:  2000.00, div:  0.00 ,moses_tr:  0.7890, test_score:  0.6729
100
cr: 5.00, tmp:  2000.00, div:  0.90 ,moses_tr:  0.761

In [22]:
mcmc_params_df

Unnamed: 0,seed,eta,mu,thres,cv_score,test_score,len,num_edges,beta_cv_score,beta_test_score
0,644,0.0,0.16681,0.5,0.856467,0.69815,230,4,0.903749,0.708506
1,490,0.0,0.1,0.5,0.85013,0.748412,408,14,0.914032,0.732532
2,805,0.464159,0.0,0.5,0.848774,0.731566,891,246,0.913501,0.718586
3,256,1.29155,0.0,0.5,0.880548,0.698426,661,160,0.902229,0.718034
4,675,0.0,0.16681,0.5,0.895344,0.694283,218,2,0.911499,0.719138
5,781,1.29155,0.0,0.5,0.874963,0.732118,616,160,0.901114,0.714996
6,350,0.16681,0.1,0.5,0.906605,0.718724,401,164,0.918262,0.694421
7,925,0.1,0.0,0.5,0.881103,0.673019,730,140,0.892072,0.681994
8,947,10.0,0.1,0.5,0.860583,0.705468,551,242,0.863323,0.660315
9,549,5.994843,0.0,0.5,0.876115,0.688484,635,120,0.883536,0.710577


In [16]:
mcmc_moses_res_df = load_mcmc_exp_res(f"{data_dir}/exp_data_2/moses", moses_res=True)
mcmc_moses_res_df

Unnamed: 0,seed,complexity_ratio,complexity_tmp,div_pres,moses_val_score,moses_test_score
0,425,1000,2000,0.9,0.739057,0.667633
1,805,100,100,0.6,0.746633,0.599489
2,440,10,100,0.3,0.781566,0.643469
3,350,10,1000,0.3,0.763047,0.736054
4,221,100,2000,0.9,0.718996,0.625794
5,359,10,2000,0.6,0.728114,0.571665
6,806,1000,1000,0.6,0.744388,0.632974
7,490,10,2000,0.6,0.752946,0.636081
8,919,5,100,0.3,0.703704,0.719345
9,886,1000,100,0.6,0.751684,0.643538


In [20]:
mcmc_params_df[mcmc_params_df["seed"] != 1234][["cv_score", "test_score"]].mean()

cv_score      0.865948
test_score    0.702281
dtype: float64

In [36]:
mcmc_moses_res_df[["moses_val_score", "moses_test_score"]].mean()

moses_val_score     0.729970
moses_test_score    0.654115
dtype: float64

In [18]:
from notebooks.variable_selection.parse_log import parse_log
import tempfile
moses_est = run_moses(425, X_df, y_df, complexity_ratio=5, temp=2000, div_pres=0.0, hnn=0.1)
log_file  = tempfile.NamedTemporaryFile().name

parse_log(moses_est.log_file_, {"prefix": "DemesHill", "output_file": log_file})
pd.read_csv(log_file)

Selected feats len: 544
cr: 5.00, tmp:  2000.00, div:  0.00 ,moses_tr:  0.7138, test_score:  0.6196


Unnamed: 0,time,demeID,iteration,total_steps,total_evals,microseconds,new_instances,num_instances,inst_RAM,num_evals,has_improved,best_weighted_score,delta_weighted,best_raw,delta_raw,complexity
0,0.0,1,1,1,1,742,1,1,0.00106,1,1,0.0,3.40282e+38,0.0,3.40282e+38,0
1,0.56,1,2,2,1000,560189,999,1000,1.06049,1000,1,64.217,64.217,64.617,64.617,2


In [20]:
# run_moses_on_mcmc_res(mcmc_params_size_100_df, X_df, y_df)

seed; 644, eta:  0.100, mu:  0.000, thres: 0.6, len: 26, cv_score:  0.817, test_score:  0.654
cr: 5.00, tmp:  10.00, div:  0.90 ,moses_tr:  0.7029, test_score:  0.6233
99
cr: 5.00, tmp:  2000.00, div:  0.30 ,moses_tr:  0.7093, test_score:  0.6617
100
cr: 5.00, tmp:  2000.00, div:  0.90 ,moses_tr:  0.7118, test_score:  0.6700
99
cr: 5.00, tmp:  1000.00, div:  0.00 ,moses_tr:  0.6813, test_score:  0.6368
100
cr: 5.00, tmp:  1000.00, div:  0.30 ,moses_tr:  0.7079, test_score:  0.6417
100
cr: 5.00, tmp:  10.00, div:  0.00 ,moses_tr:  0.6720, test_score:  0.5972
100
cr: 5.00, tmp:  2000.00, div:  0.60 ,moses_tr:  0.6582, test_score:  0.5933
100
cr: 5.00, tmp:  10.00, div:  0.30 ,moses_tr:  0.7149, test_score:  0.6123
100
cr: 5.00, tmp:  10.00, div:  0.60 ,moses_tr:  0.7041, test_score:  0.6252
100
cr: 5.00, tmp:  100.00, div:  0.60 ,moses_tr:  0.7041, test_score:  0.6252
100
cr: 5.00, tmp:  1000.00, div:  0.60 ,moses_tr:  0.6893, test_score:  0.6194
100
cr: 5.00, tmp:  100.00, div:  0.30 ,m

In [45]:
mcmc_moses_res_100_df = load_mcmc_exp_res(f"{data_dir}/exp_data_2/moses_hnn_001", moses_res=True)
mcmc_moses_res_100_df

Unnamed: 0,seed,complexity_ratio,complexity_tmp,div_pres,moses_val_score,moses_test_score
0,425,10,1000,0.0,0.778479,0.613366
1,805,100,1000,0.3,0.708614,0.575255
2,440,1000,2000,0.3,0.743266,0.621306
3,350,5,2000,0.9,0.695988,0.664388
4,221,5,2000,0.9,0.743126,0.629039
5,359,10,1000,0.0,0.737374,0.640983
6,806,100,2000,0.3,0.743126,0.604667
7,490,5,2000,0.3,0.75,0.629453
8,919,10,2000,0.0,0.763187,0.699531
9,886,5,10,0.9,0.744388,0.568696


In [35]:
mcmc_moses_res_100_df[["moses_val_score", "moses_test_score"]].mean()

moses_val_score     0.740927
moses_test_score    0.617083
dtype: float64

In [47]:
moses_est = run_moses(425, X_df, y_df, complexity_ratio=10, temp=1000, div_pres=0.0, size=100, hnn=0.01)
from notebooks.variable_selection.parse_log import parse_log
import tempfile

log_file = tempfile.NamedTemporaryFile().name

parse_log(moses_est.log_file_, {"prefix": "DemesHill", "output_file": log_file})
pd.read_csv(log_file)

Selected feats len: 86
cr: 10.00, tmp:  1000.00, div:  0.00 ,moses_tr:  0.7785, test_score:  0.6134


Unnamed: 0,time,demeID,iteration,total_steps,total_evals,microseconds,new_instances,num_instances,inst_RAM,num_evals,has_improved,best_weighted_score,delta_weighted,best_raw,delta_raw,complexity
0,0.0,1,1,1,1,179,1,1,0.000191,1,1,0.0,3.40282e+38,0.0,3.40282e+38,0
1,0.006,1,2,2,101,6289,100,101,0.019264,101,1,64.4222,64.4222,64.6222,64.6222,2
2,0.015,1,3,3,201,8368,100,201,0.038338,201,1,67.181,2.7588,67.381,2.7588,2
3,0.019,1,4,4,301,4719,100,301,0.057411,301,0,67.181,0.0,67.381,0.0,2
4,0.169,2,1,5,302,264,1,1,0.000267,1,1,54.2462,3.40282e+38,54.4462,3.40282e+38,2
5,0.18,2,2,6,402,11765,100,101,0.02697,101,1,61.8464,7.6002,62.2464,7.80021,4
6,0.193,2,3,7,502,12837,100,201,0.053673,201,1,66.2391,4.39275,66.7391,4.49275,5
7,0.202,2,4,8,602,8634,100,301,0.080376,301,1,67.8352,1.59607,68.5352,1.79607,7
8,0.21,2,5,9,702,8518,100,401,0.107079,401,0,67.8352,0.0,68.5352,0.0,7
9,0.348,3,1,10,703,238,1,1,0.000229,1,1,53.6992,3.40282e+38,53.7992,3.40282e+38,1
