In [1]:
# Environment setup: a CUDA 11 / cuDNN 8.2 build of JAX plus the probabilistic-programming
# libraries used below (TFP's JAX substrate, numpyro diagnostics, distrax, blackjax).
# NOTE(review): none of these installs is pinned, and blackjax is taken from the git default
# branch. This notebook relies on blackjax.types, blackjax.mcmc.diffusion.generate_gaussian_noise
# and MALAState(position, logprob, logprob_grad); newer blackjax revisions may have changed
# these APIs — pin a commit/version so the notebook stays reproducible.
!pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install tensorflow-probability
!pip install numpyro
!pip install distrax
!pip install git+https://github.com/blackjax-devs/blackjax.git

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda11_cudnn82]
  Downloading jax-0.3.17.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m65.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting jaxlib==0.3.15+cuda11.cudnn82
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15%2Bcuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl (162.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.7/162.7 MB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... [?25ldone
[?25h  Created wheel for jax: filename=jax-0.3.17-py3-none-any.whl size=1217849 sha256=cf5d68c72c247ee1496af27707c457ba8b1b342321249ea0380d9fb1986b5dec
  Stored in directory: /root/.cache/pip/wheels/36/cd/88/2d90379f7549c27d5654e893f74210f30f

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from tensorflow_probability.substrates import jax as tfp
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
import scipy.stats as stats
import numpyro as npyro
import pandas as pd
from tqdm import tqdm

tfd = tfp.distributions
plt.style.use('ggplot')
%load_ext autoreload

In [3]:
# Sanity check: report which backend JAX is running on ('gpu' is expected after the CUDA install above).
jax.default_backend()

'gpu'

In [4]:
def run_fs_clf(clf, X_train, X_test, y_train, y_test, feats):
    """Score a classifier on each candidate feature subset.

    For every subset in ``feats`` the selected columns are cast to int64. Then:
    - the mean ROC-AUC from ``cross_val_score`` (its default folds) on the
      training split is recorded;
    - ``clf`` is refit on the full training split;
    - the test-split ROC-AUC of ``predict_proba(...)[:, 1]`` is recorded.

    Parameters
    ----------
    clf : sklearn-style classifier exposing ``predict_proba``
        Refit in place for every subset.
    X_train, X_test : 2-D arrays (numpy or JAX), rows = samples
    y_train, y_test : 1-D label arrays supporting ``.astype``
    feats : iterable of column-index collections
        Each element may be a list, a 1-D array, or a 0-d array (single index).

    Returns
    -------
    pandas.DataFrame
        Columns ``"cv_score"`` and ``"test_score"``, one row per subset.
    """
    # Labels do not depend on the subset, so convert them once instead of per iteration.
    y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)

    results = {"cv_score": [], "test_score": []}
    for fts in feats:
        # ravel() turns a 0-d index (one selected feature) into a length-1 vector, so
        # the column selection below is always 2-D and needs no reshape special case.
        cols = np.asarray(fts, dtype=np.int32).ravel()
        X_s_train = X_train[:, cols].astype(np.int64)
        X_s_test = X_test[:, cols].astype(np.int64)

        cv_score = np.mean(cross_val_score(clf, X_s_train, y_train, scoring="roc_auc"))
        clf_est = clf.fit(X_s_train, y_train)
        test_score = roc_auc_score(y_test, clf_est.predict_proba(X_s_test)[:, 1])

        results["cv_score"].append(cv_score)
        results["test_score"].append(test_score)

    return pd.DataFrame(results)


def fisher_exact_test(X, y, thres=0.05):
    """Two-sided Fisher exact test of association between each column of X and y.

    Intended for binary features and labels: each column is cross-tabulated
    against ``y`` and the 2x2 table is passed to ``scipy.stats.fisher_exact``.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix; every column is tested independently.
    y : pandas.Series or array-like
        Labels aligned with the rows of ``X``.
    thres : float
        Columns whose p-value is strictly below ``thres`` are reported.

    Returns
    -------
    numpy.ndarray
        Integer positions of the significant columns, shape ``(k, 1)`` as
        produced by ``np.argwhere`` (callers squeeze it).

    A column (or label) that takes a single value yields a table with fewer than
    two rows/columns. ``fisher_exact`` rejects such a table, and it carries no
    evidence of association, so its p-value is left at 1.0. Tables larger than
    2x2 are still passed through and raise as before.
    """
    cols = X.columns
    p_values = np.ones(len(cols))
    for i, col in enumerate(cols):
        table = pd.crosstab(y, X[col])
        if min(table.shape) < 2:
            # Constant feature/label in this split: nothing to test.
            continue
        _, p_val = stats.fisher_exact(table, alternative="two-sided")
        p_values[i] = p_val

    idx_sig = np.argwhere(p_values < thres)
    print(f"Total of {len(idx_sig)} variables are significant (p_val = {thres})")

    return idx_sig


def build_network(X, regulators=None, targets=None):
    """Build a symmetric 0/1 adjacency matrix over the columns of X from the regulatory network.

    Columns g1 and g2 are connected (``J[i, j] = J[j, i] = 1.0``) when g2 is a
    regulator of g1, or g1 is a regulator of g2. Network partners that are not
    columns of ``X`` are ignored.

    Parameters
    ----------
    X : pandas.DataFrame
        Only ``X.columns`` (gene IDs) and ``X.shape[1]`` are used.
    regulators : pandas.Series, optional
        Index = target ID, values = regulator ID (labels may repeat).
        Defaults to the global ``net_intr``.
    targets : pandas.Series, optional
        Index = regulator ID, values = target ID (labels may repeat).
        Defaults to the global ``net_intr_rev``.

    Returns
    -------
    numpy.ndarray
        ``(p, p)`` float array. A gene listed as its own partner gets 1.0 on the
        diagonal; callers clear it with ``np.fill_diagonal``.

    NOTE(review): lookups are by label, so the column labels must have the same
    type as the network IDs (``read_csv`` header labels are strings). Confirm they
    match; otherwise every lookup misses and J is all zeros.
    """
    if regulators is None:
        regulators = net_intr
    if targets is None:
        targets = net_intr_rev

    cols = X.columns
    p = X.shape[1]
    J = np.zeros((p, p))
    position = {gene: k for k, gene in enumerate(cols)}

    for i, g1 in enumerate(cols):
        for lookup in (regulators, targets):
            try:
                hits = lookup.loc[g1]
            except KeyError:
                # No entries for g1 in this direction; still check the other one.
                continue
            # A label that occurs once yields a scalar, a repeated label a Series.
            for g2 in np.atleast_1d(hits).tolist():
                j = position.get(g2)
                if j is None:
                    # Partner is not a retained feature; skip it but keep the others.
                    continue
                J[i, j] = 1.0
                J[j, i] = 1.0

    return J

def get_ess(n_chain, samples):
    """Mean effective sample size over parameters, using numpyro's ESS diagnostic.

    ``samples`` is reshaped in C order to ``(n_chain, draws_per_chain, n_params)``,
    so the first ``samples.shape[0] // n_chain`` rows are treated as chain 0, the
    next block as chain 1, and so on. Parameters whose ESS comes back NaN count
    as 1.0 in the average.

    NOTE(review): this assumes draws are grouped chain-by-chain along axis 0.
    Samples flattened from a ``(draws, chains, params)`` array are interleaved
    instead; confirm the caller's layout.
    """
    draws_per_chain = samples.shape[0] // n_chain
    per_chain = samples.reshape(n_chain, draws_per_chain, samples.shape[-1])
    ess_per_param = npyro.diagnostics.effective_sample_size(jax.device_get(per_chain))
    ess_per_param = np.where(np.isnan(ess_per_param), 1.0, ess_per_param)
    return np.mean(ess_per_param)

def prepare_data(seed, X, y, p_val=0.01, cache_dir=None):
    """Stratified 70/30 train/test split, then keep the Fisher-significant features.

    Significant column positions are computed on the training split only
    (``fisher_exact_test`` with threshold ``p_val``). They are cached as
    ``{cache_dir}/idx_sig_s_{seed}.npy`` and reloaded from there on later calls.

    Parameters
    ----------
    seed : int
        ``random_state`` for the split; also keys the cache file.
    X : pandas.DataFrame
    y : pandas.Series
    p_val : float
        Significance threshold for the Fisher test.
    cache_dir : str, optional
        Defaults to ``f"{data_dir}/exp_data_2/npy"``; created if missing.

    Returns
    -------
    X_train_sig, X_test_sig : pandas.DataFrame
        Always DataFrames, even when exactly one feature is significant.
    y_train, y_test : pandas.Series

    NOTE(review): the cache file name depends only on ``seed``, so a file written
    with a different ``p_val`` is silently reused; delete stale caches when
    changing the threshold.
    """
    import os

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True, stratify=y, random_state=seed)

    if cache_dir is None:
        cache_dir = f"{data_dir}/exp_data_2/npy"
    cache_path = f"{cache_dir}/idx_sig_s_{seed}.npy"

    try:
        idx_sig = np.load(cache_path)
    except FileNotFoundError:
        idx_sig = np.atleast_1d(np.squeeze(fisher_exact_test(X_train, y_train, p_val)))
        os.makedirs(cache_dir, exist_ok=True)
        np.save(cache_path, idx_sig)

    # np.squeeze turns a single hit into a 0-d array, which .iloc treats as a scalar
    # (returning a Series); force a 1-D index (also for older cache files).
    idx_sig = np.atleast_1d(idx_sig)
    X_train_sig, X_test_sig = X_train.iloc[:, idx_sig], X_test.iloc[:, idx_sig]

    return X_train_sig, X_test_sig, y_train, y_test

import logging
import logging.handlers  # RotatingFileHandler lives here; `import logging` alone does not load it
import os
import sys


def setup_logger(log_path, seed):
    """Reset the root logger to log INFO+ to stdout and to a rotating per-seed file.

    Records go to stdout and to ``{log_path}/logs/log_s_{seed}.log``. The file
    rotates at 5 MiB and keeps 7 backups. Handlers already attached to the root
    logger are removed *and closed* first, so calling this once per seed does not
    leak open files.

    Parameters
    ----------
    log_path : str
        Base directory; its ``logs`` subdirectory is created if missing.
    seed : int
        Used only to name the log file.

    Returns
    -------
    logging.Logger
        The (reconfigured) root logger.
    """
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()
    root.setLevel(logging.NOTSET)

    formatter = logging.Formatter("%(asctime)s [%(levelname)s], %(message)s")

    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)
    root.addHandler(console)

    log_dir = f"{log_path}/logs"
    os.makedirs(log_dir, exist_ok=True)
    rotating_handler = logging.handlers.RotatingFileHandler(filename=f"{log_dir}/log_s_{seed}.log",
                                                            maxBytes=(1048576 * 5),
                                                            backupCount=7)
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(formatter)
    root.addHandler(rotating_handler)

    return root

In [5]:
from typing import Callable, NamedTuple
from blackjax.types import PRNGKey, PyTree


class MixedMALAState(NamedTuple):
    """Sampler state for a joint binary (discrete) + continuous target.

    For each block it stores the position, the value and gradient of that
    block's conditional log-density (as last computed by ``init`` / ``one_step``),
    and the step size used for its moves. Field order is part of the interface:
    instances are also constructed positionally.
    """

    # Discrete block (gamma, flipped coordinate-wise) and continuous block (beta).
    discrete_position: PyTree
    contin_position: PyTree

    # Last computed values of the discrete / continuous conditional log-densities.
    disc_logprob: float
    contin_logprob: float

    # Gradients of each conditional w.r.t. its own block (discrete w.r.t. gamma,
    # continuous w.r.t. beta, i.e. argnums=1 of contin_logprob_fn).
    discrete_logprob_grad: PyTree
    contin_logprob_grad: PyTree

    # Step sizes of the discrete flip move and the continuous Langevin move.
    disc_step_size: float
    contin_step_size: float


from blackjax.mcmc.diffusion import generate_gaussian_noise
from blackjax.mcmc.mala import MALAState

# Presumably a numerical-stability epsilon. NOTE(review): not referenced by any code in this notebook section.
EPS = 1e-10


def diff_fn(state, step_size):
    """Per-coordinate flip probabilities for the discrete Langevin move.

    For a coordinate at x with gradient g, the flip logit is
    ``-0.5 * g * (2x - 1) - 1 / (2 * step_size)``: a first-order estimate of the
    log-density change from the flip, minus a step-size penalty. The logits are
    mapped through a sigmoid. Reads ``state.position`` and ``state.logprob_grad``.
    """
    flip_penalty = 1. / (2. * step_size)

    def _flip_logit(x, g):
        return -0.5 * (g) * (2. * x - 1) - flip_penalty

    logits = jax.tree_util.tree_map(_flip_logit, state.position, state.logprob_grad)
    return jax.nn.sigmoid(logits)


def take_discrete_step(rng_key: PRNGKey, disc_state: MALAState, contin_state: MALAState,
                       logprob_fn: Callable, disc_grad_fn: Callable,
                       step_size: float) -> MALAState:
    """Flip each discrete coordinate independently, always accepting the move.

    Coordinate i flips with probability ``diff_fn(disc_state, step_size)[i]``.
    The continuous block is held at ``contin_state.position``. There is no
    Metropolis correction. The returned state's logprob and gradient are evaluated
    at (new discrete position, current continuous position).
    """
    # Keep the 3-way split so the uniform draws use the same sub-key as before.
    flip_key = jax.random.split(rng_key, 3)[1]
    current = disc_state.position

    flip_prob = diff_fn(disc_state, step_size)
    flips = jnp.array(jax.random.uniform(flip_key, shape=current.shape) < flip_prob)
    # Flipped coordinates become 1 - x, the rest stay x.
    proposed = (1. - current) * flips + current * (1. - flips)

    beta = contin_state.position
    return MALAState(proposed, logprob_fn(proposed, beta), disc_grad_fn(proposed, beta))


def take_contin_step(rng_key: PRNGKey, disc_state: MALAState, contin_state: MALAState,
                     logprob_fn: Callable, contin_grad_fn: Callable,
                     step_size: float) -> MALAState:
    """Unadjusted Langevin move of the continuous block, discrete block held fixed.

    ``new = x + step_size * grad + sqrt(2 * step_size) * noise``, using the
    gradient stored in ``contin_state``; always accepted. The returned state's
    logprob and gradient are evaluated at (disc_state.position, new position).
    """
    noise_key = jax.random.split(rng_key)[0]
    noise = generate_gaussian_noise(noise_key, contin_state.position)

    def _langevin(x, grad, eps):
        return x + step_size * grad + jnp.sqrt(2 * step_size) * eps

    moved = jax.tree_util.tree_map(_langevin, contin_state.position, contin_state.logprob_grad, noise)

    gamma = disc_state.position
    return MALAState(moved, logprob_fn(gamma, moved), contin_grad_fn(gamma, moved))


def one_step(
        rng_key: PRNGKey, state: MixedMALAState,
        discrete_logprob_fn: Callable, contin_logprob_fn: Callable,
        discrete_step_size: float, contin_step_size: float
) -> MixedMALAState:
    """One Gibbs-style sweep: update the discrete block, then the continuous block.

    1. gamma' from a discrete flip step targeting p(gamma | beta)  (take_discrete_step)
    2. beta'  from an unadjusted Langevin step targeting p(beta | gamma')  (take_contin_step)

    Both moves are always accepted (no Metropolis-Hastings correction).

    Invariant: the logprob/gradient fields of the returned state are evaluated
    at the returned (discrete_position, contin_position). ``init`` establishes the
    same invariant. Each conditional changes when the *other* block moves, so:
    - the continuous gradient is re-evaluated at (gamma', beta) before the
      Langevin move;
    - the discrete one is re-evaluated at (gamma', beta') after it.
    Otherwise each move would use a gradient that is one block update stale.

    Parameters
    ----------
    rng_key : PRNGKey
        Split internally into independent keys for the two moves.
    state : MixedMALAState
    discrete_logprob_fn, contin_logprob_fn : Callable
        Both called as ``f(gamma, beta)``.
    discrete_step_size, contin_step_size : float
        Used for the moves and stored unchanged in the returned state.
    """
    disc_grad_fn = jax.grad(discrete_logprob_fn)
    contin_grad_fn = jax.grad(contin_logprob_fn, argnums=1)
    disc_value_and_grad = jax.value_and_grad(discrete_logprob_fn)
    contin_value_and_grad = jax.value_and_grad(contin_logprob_fn, argnums=1)

    # Separate randomness for the two moves instead of feeding one key to both.
    disc_key, contin_key = jax.random.split(rng_key)

    disc_state = MALAState(state.discrete_position, state.disc_logprob, state.discrete_logprob_grad)
    contin_state = MALAState(state.contin_position, state.contin_logprob, state.contin_logprob_grad)

    # Discrete move: sample from p(discrete | contin).
    new_disc_state = take_discrete_step(disc_key, disc_state, contin_state,
                                        discrete_logprob_fn, disc_grad_fn, discrete_step_size)

    # gamma changed: refresh the continuous conditional at (gamma', beta) so the
    # Langevin drift targets p(contin | new discrete).
    contin_logprob, contin_grad = contin_value_and_grad(new_disc_state.position, state.contin_position)
    contin_state = MALAState(state.contin_position, contin_logprob, contin_grad)
    new_contin_state = take_contin_step(contin_key, new_disc_state, contin_state,
                                        contin_logprob_fn, contin_grad_fn, contin_step_size)

    # beta changed: refresh the discrete conditional at (gamma', beta') for the
    # next sweep's flip probabilities.
    disc_logprob, disc_grad = disc_value_and_grad(new_disc_state.position, new_contin_state.position)

    return MixedMALAState(new_disc_state.position, new_contin_state.position,
                          disc_logprob, new_contin_state.logprob,
                          disc_grad, new_contin_state.logprob_grad,
                          discrete_step_size, contin_step_size)

def init(disc_position: PyTree, contin_position: PyTree,
         disc_logprob_fn: Callable, contin_logprob_fn: Callable,
         init_disc_step: float, init_contin_step: float) -> MixedMALAState:
    """Build the initial MixedMALAState.

    Both conditional log-densities and their gradients are evaluated at the
    given (disc_position, contin_position): the discrete one w.r.t. its first
    argument, the continuous one w.r.t. its second (argnums=1).
    """
    disc_value_and_grad = jax.value_and_grad(disc_logprob_fn)
    contin_value_and_grad = jax.value_and_grad(contin_logprob_fn, argnums=1)

    d_logprob, d_grad = disc_value_and_grad(disc_position, contin_position)
    c_logprob, c_grad = contin_value_and_grad(disc_position, contin_position)

    return MixedMALAState(
        discrete_position=disc_position,
        contin_position=contin_position,
        disc_logprob=d_logprob,
        contin_logprob=c_logprob,
        discrete_logprob_grad=d_grad,
        contin_logprob_grad=c_grad,
        disc_step_size=init_disc_step,
        contin_step_size=init_contin_step,
    )

In [6]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run one chain for ``num_samples`` kernel steps and return every visited state.

    ``kernel(key, state) -> state``. The result stacks the states along a new
    leading axis of length ``num_samples`` (``lax.scan`` output).
    """
    step_keys = jax.random.split(rng_key, num_samples)

    @jax.jit
    def _advance(carry, key):
        nxt = kernel(key, carry)
        return nxt, nxt

    _, trace = jax.lax.scan(_advance, initial_state, step_keys)
    return trace

def inference_loop_multiple_chains(rng_key, kernel, initial_state, num_samples, num_chains):
    """Run ``num_chains`` chains in lockstep for ``num_samples`` steps.

    ``kernel(key, state) -> state`` is vmapped over the chain axis, and each step
    gets ``num_chains`` fresh keys. ``initial_state`` must carry a leading chain
    axis. The result has layout ``(num_samples, num_chains, ...)``: the scan axis
    comes first, then the vmapped chain axis.
    """
    batched_kernel = jax.vmap(kernel)

    def _advance_all(carry, key):
        chain_keys = jax.random.split(key, num_chains)
        nxt = batched_kernel(chain_keys, carry)
        return nxt, nxt

    step_keys = jax.random.split(rng_key, num_samples)
    _, trace = jax.lax.scan(_advance_all, initial_state, step_keys)
    return trace

In [7]:
def gamma_energy(theta, J, eta, mu):
    """Unnormalised Ising log-prior on the inclusion indicators.

    Returns ``eta * theta^T J theta - mu * sum(theta)``: ``eta`` rewards selecting
    connected pairs in ``J`` and ``mu`` penalises the number of selected features.
    """
    interaction = theta.T @ J @ theta
    sparsity = mu * jnp.sum(theta)
    return eta * interaction - sparsity

def generate_disc_logprob_fn(X, y, J, mu, eta):
    """Build the log-density used for the discrete block (gamma) given beta.

    ``discrete_logprob_fn(gamma, beta)`` returns the sum of two terms:
    - the Ising log-prior ``gamma_energy(gamma, J, eta, mu)``;
    - the logistic-regression log-likelihood
      ``sum_i log Bernoulli(y_i | logits = (X * gamma) @ beta)``, in which
      columns with gamma_j = 0 are zeroed out.

    NOTE(review): the beta prior in ``generate_contin_logprob_fn`` has a scale that
    depends on gamma, but that term is not included here, so it does not
    influence the gamma update. Confirm this is intended.
    """

    def discrete_logprob_fn(gamma, beta):
        # For 1-D gamma, X * gamma scales column j by gamma_j. It equals
        # X @ diag(gamma), but costs O(n p) instead of materialising a (p, p)
        # matrix and doing an O(n p^2) matmul on every call.
        X_gamma = X * gamma
        ising_logp = gamma_energy(gamma, J, eta, mu)
        ll_dist = tfd.Bernoulli(logits=(X_gamma @ beta))
        log_ll = jnp.sum(ll_dist.log_prob(y), axis=0)

        return ising_logp + log_ll

    return discrete_logprob_fn


def generate_contin_logprob_fn(X, y, tau, c):
    """Build the log-density used for the continuous block (beta) given gamma.

    ``contin_logprob_fn(gamma, beta)`` returns the sum of two terms:
    - a zero-mean ``MultivariateNormalDiag`` log-prior on beta with per-coordinate
      scale ``D = gamma * c * tau + (1 - gamma) * tau``, i.e. ``c * tau`` where
      gamma_j = 1 and ``tau`` where gamma_j = 0;
    - the logistic-regression log-likelihood with logits ``(X * gamma) @ beta``.
    """
    p = X.shape[1]
    prior_loc = jnp.zeros(p)

    def contin_logprob_fn(gamma, beta):
        D = (gamma * c * tau) + (1 - gamma) * (tau)
        beta_logp = tfd.MultivariateNormalDiag(loc=prior_loc, scale_diag=D).log_prob(beta)

        # Same as X @ diag(gamma) for 1-D gamma, without the O(n p^2) matmul.
        X_gamma = X * gamma
        ll_dist = tfd.Bernoulli(logits=(X_gamma @ beta))
        log_ll = jnp.sum(ll_dist.log_prob(y), axis=0)

        return beta_logp + log_ll

    return contin_logprob_fn

In [8]:
# Root directory for the inputs (tamoxBinaryEntrez.csv, human.source, seeds.txt) and for the
# experiment outputs (exp_data_2/, exp_data_3/); adjust for your machine.
data_dir = "."
tamox_df = pd.read_csv(f"{data_dir}/tamoxBinaryEntrez.csv")
tamox_df.describe()

Unnamed: 0,posOutcome,4111,4110,10661,131,4438,330,1109,2637,2642,...,7634,55769,7637,7644,741,54993,79364,7791,23140,26009
count,642.0,642.0,642.0,642.0,642.0,642.0,642.0,642.0,642.0,642.0,...,642.0,642.0,642.0,642.0,642.0,642.0,642.0,642.0,642.0,642.0
mean,0.733645,0.225857,0.450156,0.255452,0.459502,0.367601,0.490654,0.311526,0.478193,0.481308,...,0.404984,0.5,0.5,0.5,0.356698,0.5,0.5,0.5,0.470405,0.5
std,0.442397,0.418471,0.497897,0.436455,0.498746,0.482528,0.500302,0.463479,0.499914,0.50004,...,0.491272,0.50039,0.50039,0.50039,0.479398,0.50039,0.50039,0.50039,0.499513,0.50039
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.5,0.5,0.5,0.0,0.5,0.5,0.5,0.0,0.5
75%,1.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [9]:
# Regulator -> target interactions (tab-separated, no header row).
regnet_df = pd.read_table(
    f"{data_dir}/human.source",
    sep="\t",
    header=None,
    names=["REGULATOR SYMBOL", "REGULATOR ID", "TARGET SYMBOL", "TARGET ID"],
)
print(f"Total interactions: {regnet_df.shape[0]}")
regnet_df.head()

Total interactions: 372774


  regnet_df = pd.read_table(f"{data_dir}/human.source", sep="\t", header=None, names= ["REGULATOR SYMBOL", "REGULATOR ID", "TARGET SYMBOL", "TARGET ID"])


Unnamed: 0,REGULATOR SYMBOL,REGULATOR ID,TARGET SYMBOL,TARGET ID
0,USF1,7391,S100A6,6277
1,USF1,7391,DUSP1,1843
2,USF1,7391,C4A,720
3,USF1,7391,ABCA1,19
4,TP53,7157,TP73,7161


In [10]:
# Regulatory network as two label-indexed lookup Series (index labels may repeat):
#   net_intr:     TARGET ID    -> REGULATOR ID  (regulators of a target)
#   net_intr_rev: REGULATOR ID -> TARGET ID     (targets of a regulator)
# build_network looks the feature column labels up in both.
# NOTE(review): confirm these IDs have the same type as the feature column labels.
net_intr = pd.Series(regnet_df["REGULATOR ID"].values, index=regnet_df["TARGET ID"])
net_intr_rev = pd.Series(regnet_df["TARGET ID"].values, index=regnet_df["REGULATOR ID"])

In [11]:
# Experiment seeds: one integer per non-blank line of seeds.txt. Comprehensions keep the
# loop variables from leaking into the notebook namespace as stale globals.
with open(f"{data_dir}/seeds.txt", "r") as fp:
    exp_seeds = [int(line.strip()) for line in fp if line.strip()]

print(len(exp_seeds))
# One JAX PRNG key per experiment seed, in the same order.
jax_exp_seeds = [jax.random.PRNGKey(seed) for seed in exp_seeds]

50


In [12]:
# Features / label split. NOTE(review): iloc[:, 1:] assumes posOutcome is the first column. It is
# the first column shown by describe() above, but describe() lists numeric columns only.
# drop(columns="posOutcome") would not depend on column order.
X_df, y_df = tamox_df.iloc[:,1:], tamox_df["posOutcome"]

In [13]:
# Posterior-inclusion thresholds 0.1, 0.2, ..., 0.9 (rounded to remove linspace noise).
thresholds = np.linspace(0.1, 0.9, 9).round(1)
# Grid used for both eta and mu: 0 followed by 10 log-spaced values in [0.1, 10].
param_vals = np.concatenate(([0.], np.logspace(-1, 1, 10)))
param_vals

array([ 0.        ,  0.1       ,  0.16681005,  0.27825594,  0.46415888,
        0.77426368,  1.29154967,  2.15443469,  3.59381366,  5.9948425 ,
       10.        ])

In [14]:
import time
import datetime
from sklearn.svm import SVC
import warnings
warnings.filterwarnings("ignore")

def _score_selection(idx, J, X_train, X_test, y_train, y_test, beta_samples):
    """Score one selected feature subset and return the metrics run_exp records.

    - cv_score / test_score: RBF-SVC ROC-AUCs from run_fs_clf (cross-validated on
      train / held-out test).
    - beta_cv_score / beta_test_score: ROC-AUC of sigmoid(X_sel @ mean beta) on the
      training and test sets. beta_cv_score is a training AUC, not cross-validated.
    - num_edges: nonzero entries of J restricted to the subset. For a symmetric J
      each interaction counts twice. NOTE(review): confirm that is the intended unit.

    All scores are NaN when the subset is empty.
    """
    idx_np = np.asarray(idx)
    row = {
        "len": idx.size,
        "num_edges": int(np.count_nonzero(J[idx_np, :][:, idx_np])) if idx.size > 1 else 0,
    }
    if idx.size == 0:
        row.update(cv_score=np.nan, test_score=np.nan, beta_cv_score=np.nan, beta_test_score=np.nan)
        return row

    clf = SVC(kernel="rbf", probability=True)
    res_idx_df = run_fs_clf(clf, X_train, X_test, y_train, y_test, [idx])
    row["cv_score"] = res_idx_df["cv_score"][0]
    row["test_score"] = res_idx_df["test_score"][0]

    beta_sel = jnp.mean(beta_samples[:, idx], axis=0)
    X_train_sel, X_test_sel = X_train[:, idx], X_test[:, idx]
    if idx.size == 1:
        # A single squeezed (0-d) index drops the column axis; restore it.
        beta_sel = beta_sel.reshape(-1, 1)
        X_train_sel = X_train_sel.reshape(-1, 1)
        X_test_sel = X_test_sel.reshape(-1, 1)

    row["beta_cv_score"] = roc_auc_score(y_train, jax.nn.sigmoid((X_train_sel @ beta_sel)))
    row["beta_test_score"] = roc_auc_score(y_test, jax.nn.sigmoid((X_test_sel @ beta_sel)))
    return row


def run_exp(X, y, seeds, jax_seeds, num_samples=10000, num_chains=3, burn_in=0.1,
            tau=0.01, c=1000, disc_step_size=0.1, contin_step_size=1e-5):
    """Run the Ising spike-and-slab feature-selection sweep for every seed.

    Per seed:
    1. Split and Fisher-filter the features (prepare_data, p_val=0.05).
    2. Build the interaction matrix J (build_network, diagonal cleared).
    3. For every (eta, mu) in param_vals x param_vals, run ``num_chains`` chains
       of the mixed sampler for ``num_samples`` steps.
    4. Drop the first ``burn_in`` fraction of steps and pool the chains.
    5. For every threshold in ``thresholds``, select features whose posterior
       inclusion probability exceeds it and score them (_score_selection).

    Results go to ``{data_dir}/exp_data_3/res_param_seed_{seed}.csv``; logs go
    under ``{data_dir}/exp_data_3/logs``.

    Reads the module-level globals data_dir, param_vals and thresholds.

    Parameters
    ----------
    X, y : pandas.DataFrame, pandas.Series
    seeds : sequence of int
    jax_seeds : sequence of PRNG keys, aligned with ``seeds``
    num_samples, num_chains, burn_in, tau, c, disc_step_size, contin_step_size
        Sampler / prior settings; the defaults are the values previously hard-coded.
    """
    import os

    out_dir = f"{data_dir}/exp_data_3"
    # setup_logger writes into {out_dir}/logs and the CSVs go to out_dir.
    os.makedirs(f"{out_dir}/logs", exist_ok=True)
    n_burn = int(burn_in * num_samples)

    for s, seed in enumerate(seeds):
        np.random.seed(seed)
        # Independent streams for the initial beta, the initial gamma and the chains.
        # Drawing all three from one key correlates the initial positions with each
        # other and with the chain noise.
        beta_key, gamma_key, chain_key = jax.random.split(jax_seeds[s], 3)
        log = setup_logger(out_dir, seed)
        start_time = time.time()
        log.info(f"========= Running seed - {seed} =========")
        X_train_sig, X_test_sig, y_train, y_test = prepare_data(seed, X, y, p_val=0.05)
        J = build_network(X_train_sig)
        p = J.shape[1]
        np.fill_diagonal(J, 0.)
        log.info(f"Num of sig feats - {p}")
        beta_dist = tfd.MultivariateNormalDiag(loc=jnp.zeros(p), scale_diag=10 * jnp.ones(p))
        gamma_dist = tfd.Bernoulli(probs=0.5 * jnp.ones(p))

        res_dict = {"seed": [], "eta": [], "mu": [], "thres": [], "cv_score": [], "test_score": [], "len": [], "num_edges": [],
                    "beta_cv_score": [], "beta_test_score": [] }

        init_beta = beta_dist.sample(seed=beta_key, sample_shape=(num_chains,))
        init_gamma = gamma_dist.sample(seed=gamma_key, sample_shape=(num_chains,)) * 1.

        X_train, X_test = jax.device_put(X_train_sig.to_numpy()), jax.device_put(X_test_sig.to_numpy())
        y_train, y_test = jax.device_put(y_train.to_numpy()), jax.device_put(y_test.to_numpy())

        # The beta conditional does not depend on (eta, mu); build it once per seed.
        contin_logprob = generate_contin_logprob_fn(X_train, y_train, tau, c)

        for eta in param_vals:
            for mu in param_vals:
                disc_logprob = generate_disc_logprob_fn(X_train, y_train, J, mu, eta)
                kernel = jax.jit(lambda key, state: one_step(key, state, disc_logprob, contin_logprob,
                                                             disc_step_size, contin_step_size))

                init_state = jax.vmap(init, in_axes=(0, 0, None, None, None, None))(
                    init_gamma, init_beta, disc_logprob, contin_logprob, disc_step_size, contin_step_size)
                # The same chain key is reused for every (eta, mu), so grid points share random numbers.
                states = inference_loop_multiple_chains(chain_key, kernel, init_state,
                                                        num_samples=num_samples, num_chains=num_chains)

                # (num_samples, num_chains, p) -> drop burn-in -> pool chains: (draws * chains, p).
                gamma_samples = states.discrete_position[n_burn:].reshape(-1, p)
                beta_samples = states.contin_position[n_burn:].reshape(-1, p)
                incl_prob = jnp.mean(gamma_samples, axis=0)

                for t in thresholds:
                    idx = jnp.squeeze(jnp.argwhere(incl_prob > t))
                    row = _score_selection(idx, J, X_train, X_test, y_train, y_test, beta_samples)
                    row.update(seed=seed, eta=eta, mu=mu, thres=t)
                    for col, value in row.items():
                        res_dict[col].append(value)

        res_df = pd.DataFrame(res_dict)

        res_df.to_csv(f"{out_dir}/res_param_seed_{seed}.csv", index=False)
        elapsed_time = time.time() - start_time
        log.info(f"========= Done for seed - {seed} , Elapsed time - {datetime.timedelta(seconds=elapsed_time)} =========")


In [16]:
# Index into exp_seeds of the first seed to run in this session.
# NOTE(review): presumably exp_seeds[:18] were processed in an earlier session. Confirm this; skipping
# seeds whose res_param_seed_{seed}.csv already exists would avoid a hand-edited offset.
k = 18
print(exp_seeds[k:])

[919, 256, 206, 507, 98, 97, 675, 539, 425, 233, 781, 886, 837, 350, 482, 503, 713, 950, 947, 337, 5, 356, 38, 210, 359, 67, 61, 499, 925, 506, 805, 490]


In [17]:
# Seeds (and their matching JAX keys) still to run, starting at offset k.
curr_seeds, curr_jax_seeds = exp_seeds[k:], jax_exp_seeds[k:]
len(curr_seeds)

32

In [None]:
# Long-running: writes one results CSV and one log file per seed under {data_dir}/exp_data_3.
run_exp(X=X_df, y=y_df, seeds=curr_seeds, jax_seeds=curr_jax_seeds)

2022-09-21 11:38:46,104 [INFO], Num of sig feats - 1076
