In [None]:
# Experiment settings, keyed by option name.
param = dict(
    model='mgcit',
    sample_size=2000,    # number of samples
    batch_size=64,
    z_dim=200,           # values listed in the original notes: 50, 100, 150, 200, 250, 600, 1300
    dx=1,
    dy=1,
    test='power',        # one of 'type1error', 'power'
    n_test=500,          # original value: 500
    n_iters=3000,        # original value: 1000
    eps_std=0.5,
    dist_z='gaussian',   # one of 'gaussian', 'laplace'
    alpha_x=0.75,        # only used under the alternative; listed values: 0.15, 0.30, 0.45, 0.60, 0.75
    m_value=100,
    k_value=2,
    b_value=30,          # original value: 30
    j_value=1000,        # bootstrap number; original value: 1000
    alpha=0.1,
    alpha1=0.05,
    set_seed=42,
)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Utilities related to Sinkhorn computations and training for TensorFlow 2.0
import tensorflow as tf
import logging
import tensorflow_probability as tfp
from sklearn.metrics.pairwise import rbf_kernel
from scipy.stats import rankdata, ks_2samp, wilcoxon
from sklearn.model_selection import KFold
from datetime import datetime
import decimal
import torch
# Set enable_cuda to False to keep torch on the CPU even when CUDA is available.
enable_cuda = True
device = torch.device('cpu' if not (torch.cuda.is_available() and enable_cuda) else 'cuda')
import gc # Garbage Collector
from tqdm import tqdm

# Disable the 'tensorflow' logger.
logging.getLogger('tensorflow').disabled = True
# Set the Keras backend default float type to float32.
# NOTE(review): the model classes later in this file create their variables explicitly as
# tf.float64 - confirm that a float32 default is intended.
tf.keras.backend.set_floatx('float32')


'''
This code reproduces the real data experiments using the CCLE data.
Preprocessing steps follow the code of W. Tansey at https://github.com/tansey/hrt.
And the following code is obtained from https://github.com/alexisbellot/GCIT.
'''


def load_ccle(drug_target='PLX4720', feature_type='both', normalize=False):
    '''
    Load CCLE genetic features and the Amax drug response for a single drug.

    Reads ./ccle_data/expression.txt and/or ./ccle_data/mutation.txt (depending on
    feature_type) and ./ccle_data/response.csv.

    :param drug_target: name of the drug we want to analyse; a KeyError is raised if it does
        not appear in the response table
    :param feature_type: 'expression', 'mutation' or 'both'
    :param normalize: if True, min-max scale every feature column and the response to [0, 1];
        constant feature columns are mapped to 0 instead of producing NaN
    :return: tuple (X_drug, y_drug, features) where X_drug is a 2d array with one row per cell
        line and one column per genetic feature, y_drug is the 1d array of Amax responses for
        the same cell lines, and features is the DataFrame of all loaded features
        (rows = features, columns = cell lines)
    :raises ValueError: if feature_type is not one of the supported values
    '''
    if feature_type not in ('expression', 'mutation', 'both'):
        raise ValueError("feature_type must be 'expression', 'mutation' or 'both', got {!r}".format(feature_type))

    if feature_type in ('expression', 'both'):
        # Load gene expression
        expression = pd.read_csv('./ccle_data/expression.txt', delimiter='\t', header=2, index_col=1).iloc[:, 1:]
        # Strip the ' (ACH...' suffix so column names match the cell names used elsewhere
        expression.columns = [c.split(' (ACH')[0] for c in expression.columns]
        features = expression
    if feature_type in ('mutation', 'both'):
        # Load gene mutation and keep only the rows whose name ends with '_MUT'
        mutations = pd.read_csv('./ccle_data/mutation.txt', delimiter='\t', header=2, index_col=1).iloc[:, 1:]
        mutations = mutations.iloc[[c.endswith('_MUT') for c in mutations.index]]
        features = mutations
    if feature_type == 'both':
        # Keep cells having both expression and mutation data (sorted for a reproducible column order)
        both_cells = sorted(set(expression.columns) & set(mutations.columns))
        z = {}
        for c in both_cells:
            exp = expression[c].values
            if len(exp.shape) > 1:
                # duplicated column name: keep the first occurrence
                exp = exp[:, 0]
            z[c] = np.concatenate([exp, mutations[c].values])
        features = pd.DataFrame(z, index=list(expression.index) + list(mutations.index))

    print('Genetic features dimension = {} on {} cancer cells'.format(features.shape[0], features.shape[1]))

    # Response table indexed by (cell line, drug)
    response = pd.read_csv('./ccle_data/response.csv', header=0, index_col=[0, 2])
    cells = response.index.levels[0]
    drugs = response.index.levels[1]

    # Fail fast with a KeyError when the requested drug is absent from the response table
    drugs.get_loc(drug_target)

    # Collect features and response for every cell line measured with the target drug
    X_rows, y_rows = [], []
    for cell in cells:
        if cell not in features.columns or (cell, drug_target) not in response.index:
            continue
        X_rows.append(features[cell].values)
        y_rows.append(response.loc[(cell, drug_target), 'Amax'])
    print('{}: Cell number = {}'.format(drug_target, len(y_rows)))

    X_drug = np.array(X_rows)
    y_drug = np.array(y_rows)

    if normalize:
        if len(X_drug) > 0:
            x_min = X_drug.min(axis=0, keepdims=True)
            x_range = X_drug.max(axis=0, keepdims=True) - x_min
            # A constant column has zero range; divide by 1 so it becomes 0 rather than NaN
            x_range = np.where(x_range == 0, 1, x_range)
            X_drug = (X_drug - x_min) / x_range
        if len(y_drug) > 0 and y_drug.std() != 0:
            y_min = y_drug.min(axis=0, keepdims=True)
            y_drug = (y_drug - y_min) / (y_drug.max(axis=0, keepdims=True) - y_min)

    return X_drug, y_drug, features


# X_drug, y_drug, features = load_ccle(feature_type='mutation')


def ccle_feature_filter(X, y, threshold=0.1):
    '''
    Screen features by their absolute Pearson correlation with the response.

    :param X: features, one column per feature
    :param y: response
    :param threshold: minimum absolute correlation a feature needs to be kept
    :return: (selected, corrs) - a boolean array that is True for features whose absolute
    correlation with y is at least threshold, and the absolute correlations of all features
    (a feature with zero standard deviation gets correlation 0)
    '''
    abs_corr = []
    for column in X.T:
        if column.std() > 0:
            abs_corr.append(np.abs(np.corrcoef(column, y)[0, 1]))
        else:
            # constant feature: correlation is undefined, treat it as 0
            abs_corr.append(0)
    corrs = np.array(abs_corr)
    selected = corrs >= threshold  # True/False per feature
    print(selected.sum(), selected.shape, corrs)
    return selected, corrs

# ccle_selected, corrs = ccle_feature_filter(X_drug, y_drug, threshold=0.1)

# features.index[ccle_selected]
# stats.describe(corrs[ccle_selected])


def fit_elastic_net_ccle(X, y, nfolds=3):
    '''
    Fit an elastic net whose penalty strength is chosen by cross-validation.

    :param X: features
    :param y: response
    :param nfolds: number of cross-validation folds used for hyperparameter tuning
    :return: fitted ElasticNetCV model
    '''
    from sklearn.linear_model import ElasticNetCV

    # sklearn's l1_ratio corresponds to alpha in the glmnet R package, and sklearn's alpha
    # corresponds to glmnet's lambda. l1_ratio is fixed at 0.2 because, per the original
    # author's note, a search over np.linspace(0.2, 1.0, 10) always chose 0.2.
    alpha_grid = np.exp(np.linspace(-6, 5, 250))
    enet = ElasticNetCV(l1_ratio=0.2, alphas=alpha_grid, cv=nfolds)

    print('Fitting via CV')
    enet.fit(X, y)
    print('Chose values: alpha={}, l1_ratio={}'.format(enet.alpha_, enet.l1_ratio_))
    return enet

# elastic_model = fit_elastic_net_ccle(X_drug[:,ccle_selected], y_drug)


def fit_random_forest_ccle(X, y):
    '''
    Fit a random forest regressor with scikit-learn's default settings.

    :param X: features
    :param y: response
    :return: fitted RandomForestRegressor model
    '''
    from sklearn.ensemble import RandomForestRegressor

    # fit() returns the estimator itself, so the fitted model can be returned directly
    return RandomForestRegressor().fit(X, y)

# rf_model = fit_random_forest_ccle(X_drug[:,ccle_selected], y_drug)


def plot_ccle_predictions(model, X, y):
    '''
    Scatter the model predictions against the true response, with the identity line in red.

    :param model: fitted model exposing predict(X)
    :param X: features
    :param y: true response
    '''
    from sklearn.metrics import r2_score

    plt.close()
    y_hat = model.predict(X)

    # common axis limits so the red reference line is the diagonal y = x
    low = min(y.min(), y_hat.min())
    high = max(y.max(), y_hat.max())

    plt.scatter(y_hat, y, color='blue')
    plt.plot([low, high], [low, high], color='red', lw=3)
    plt.xlabel('Predicted')
    plt.ylabel('Truth')
    plt.title(' ($r^2$={:.4f})'.format(r2_score(y, y_hat)))
    plt.tight_layout()

# plot_ccle_predictions(elastic_model,X_drug[:,ccle_selected],y_drug)


def print_top_features(model, feature_names=None):
    '''
    Print the model's features ranked by decreasing absolute weight.

    :param model: fitted model exposing either feature_importances_ (e.g. a random forest)
        or coef_ (e.g. an elastic net)
    :param feature_names: names of the features, in the column order the model was fitted on;
        when None, the names are taken from the module-level features[ccle_selected].index
    '''
    # Pick the weight vector by capability instead of comparing against a global model
    # object, which may not exist when this function is called.
    if hasattr(model, 'feature_importances_'):
        model_weights = model.feature_importances_
    else:
        model_weights = model.coef_
    model_weights = np.asarray(model_weights)

    if feature_names is None:
        # fall back to the notebook-level selection
        feature_names = features[ccle_selected].index

    print('Top by fit:')
    ranking = np.argsort(np.abs(model_weights))[::-1]
    for rank, top in enumerate(ranking, start=1):
        print('{}. {}: {:.4f}'.format(rank, feature_names[top], model_weights[top]))

# print_top_features(rf_model)
# print_top_features(elastic_model)


def run_test_ccle(X, Y):
    '''
    Run GCIT once per feature column, testing that column against Y given all other columns,
    then print the features ordered by increasing p-value.

    :param X: features, one column per feature
    :param Y: response
    '''
    # the response is passed to the test as a column vector
    Y = Y.reshape((len(Y), 1))

    pval = []
    for col in range(X.shape[1]):
        rest = np.delete(X, col, axis=1)       # conditioning set: every other feature
        target = X[:, col].reshape((-1, 1))    # tested feature as a column vector
        pval.append(GCIT(target, Y, rest))

    ccle_features = features[ccle_selected]

    print('Top by fit:')
    for rank, top in enumerate(np.argsort(np.abs(pval)), start=1):
        print('{}. {}: {:.4f}'.format(rank, ccle_features.index[top], pval[top]))

# run_test_ccle(X_drug[:,ccle_selected],y_drug)


def cost_xy(x, y, scaling_coef):
    '''
    Scaled squared L2 distance between every row of x and every row of y. It is computed by
    broadcasting, so it materialises every pairwise difference and is memory intensive.
    :param x: x is tensor of shape [batch_size, x dims]
    :param y: y is tensor of shape [batch_size, y dims]
    :param scaling_coef: a scaling coefficient for distance between x and y
    :return: cost matrix whose entry (i, j) is scaling_coef * sum((x[i] - y[j])**2)
    '''
    pairwise_diff = tf.expand_dims(x, axis=1) - tf.expand_dims(y, axis=0)
    squared_dist = tf.reduce_sum(pairwise_diff ** 2, axis=-1)
    return squared_dist * scaling_coef


def benchmark_sinkhorn(x, y, scaling_coef, epsilon=1.0, L=10):
    '''
    Entropy-regularised optimal transport cost between two batches with uniform weights,
    computed with a fixed number of Sinkhorn iterations.

    :param x: a tensor of shape [batch_size_x, sequence length]
    :param y: a tensor of shape [batch_size_y, sequence length]; its batch size may differ
        from that of x
    :param scaling_coef: a scaling coefficient for squared distance between x and y
    :param epsilon: (float) entropic regularity constant
    :param L: (int) number of iterations
    :return: V: (float) value of regularized optimal transport
    '''
    n_x = x.shape[0]
    n_y = y.shape[0]

    c_xy = cost_xy(x, y, scaling_coef)  # shape: [n_x, n_y]
    dtype = c_xy.dtype  # follow the input precision instead of assuming float64

    # uniform marginals, kept as column vectors
    weights_x = tf.ones([n_x, 1], dtype=dtype) / tf.cast(n_x, dtype)
    weights_y = tf.ones([n_y, 1], dtype=dtype) / tf.cast(n_y, dtype)

    k = tf.exp(-c_xy / epsilon) + 1e-09  # add 1e-09 to prevent numerical issues
    k_t = tf.transpose(k)

    # scaling vectors: a rescales the rows of k (x side), b its columns (y side)
    a = tf.ones([n_x, 1], dtype=dtype)
    b = tf.ones([n_y, 1], dtype=dtype)

    for _ in range(L):
        b = weights_y / tf.matmul(k_t, a)  # shape: [n_y, 1]
        a = weights_x / tf.matmul(k, b)  # shape: [n_x, 1]

    # transport plan is a * k * b^T; return its total cost
    return tf.reduce_sum(a * k * tf.reshape(b, (1, -1)) * c_xy)


def benchmark_loss(x, y, scaling_coef, sinkhorn_eps, sinkhorn_l, xp=None, yp=None):
    '''
    Sinkhorn loss: the x-to-y cost minus half of each within-distribution cost.

    :param x: real data of shape [batch size, sequence length]
    :param y: fake data of shape [batch size, sequence length]
    :param scaling_coef: a scaling coefficient
    :param sinkhorn_eps: Sinkhorn parameter - epsilon
    :param sinkhorn_l: Sinkhorn parameter - the number of iterations
    :param xp: second batch compared against x; defaults to x itself
    :param yp: second batch compared against y; defaults to y itself
    :return: final Sinkhorn loss
    '''
    xp = x if xp is None else xp
    yp = y if yp is None else yp

    # flatten everything after the batch axis
    x, y, xp, yp = [tf.reshape(t, [t.shape[0], -1]) for t in (x, y, xp, yp)]

    cross_cost = benchmark_sinkhorn(x, y, scaling_coef, sinkhorn_eps, sinkhorn_l)
    real_cost = benchmark_sinkhorn(x, xp, scaling_coef, sinkhorn_eps, sinkhorn_l)
    fake_cost = benchmark_sinkhorn(y, yp, scaling_coef, sinkhorn_eps, sinkhorn_l)

    return cross_cost - 0.5 * real_cost - 0.5 * fake_cost

class WGanGenerator(tf.keras.Model):
    '''
    WGAN generator: one ReLU hidden layer followed by a sigmoid output layer.
    Args:
        inputs, noise and confounding factor [v, z], of shape [batch size, z_dims + v_dims]
    return:
       fake samples of shape [batch size, x_dims]
    '''
    def __init__(self, n_samples, z_dims, h_dims, v_dims, x_dims, batch_size):
        super().__init__()
        self.n_samples = n_samples
        self.hidden_dims = h_dims
        self.batch_size = batch_size
        self.dz, self.dx, self.dv = z_dims, x_dims, v_dims

        self.input_dim = self.dz + self.dv
        self.input_shape1 = [self.input_dim, self.hidden_dims]
        self.input_shape2 = [self.hidden_dims, self.hidden_dims]
        self.input_shape3 = [self.hidden_dims, self.dx]

        # Variables are created in the order w1, b1, w2, b2, w3, b3.
        # w2/b2 belong to a second hidden layer that call() currently leaves out.
        self.w1 = self.xavier_var_creator(self.input_shape1)
        self.b1 = tf.Variable(tf.zeros(self.hidden_dims, dtype=tf.float64))

        self.w2 = self.xavier_var_creator(self.input_shape2)
        self.b2 = tf.Variable(tf.zeros(self.hidden_dims, dtype=tf.float64))

        self.w3 = self.xavier_var_creator(self.input_shape3)
        self.b3 = tf.Variable(tf.zeros(self.dx, dtype=tf.float64))

    def xavier_var_creator(self, input_shape):
        # normal initialisation with stddev 1 / sqrt(fan_in / 2), stored as float64
        fan_in = input_shape[0]
        xavier_stddev = 1.0 / tf.sqrt(fan_in / 2.0)
        initial = tf.cast(
            tf.random.normal(shape=input_shape, mean=0.0, stddev=xavier_stddev),
            tf.float64)
        return tf.Variable(initial, shape=tf.TensorShape(input_shape), trainable=True)

    def call(self, inputs, training=None, mask=None):
        # inputs are concatenations of z and v
        flat_inputs = tf.reshape(inputs, [-1, self.input_dim])
        hidden = tf.nn.relu(tf.matmul(flat_inputs, self.w1) + self.b1)
        # the second hidden layer (w2, b2) is not applied
        return tf.math.sigmoid(tf.matmul(hidden, self.w3) + self.b3)


class WGanDiscriminator(tf.keras.Model):
    '''
    WGAN discriminator: one ReLU hidden layer followed by a linear output layer.
    Args:
        inputs: real or fake samples concatenated with z, reshaped to [batch size, z_dims + v_dims]
    return:
       features f_x of shape [batch size, 1]
    '''
    def __init__(self, n_samples, z_dims, h_dims, v_dims, batch_size):
        super().__init__()
        self.n_samples = n_samples
        self.hidden_dims = h_dims
        self.batch_size = batch_size

        self.input_dim = z_dims + v_dims
        self.input_shape1 = [self.input_dim, self.hidden_dims]
        self.input_shape2 = [self.hidden_dims, self.hidden_dims]
        self.input_shape3 = [self.hidden_dims, 1]

        # Variables are created in the order w1, b1, w2, b2, w3, b3.
        # w2/b2 belong to a second hidden layer that call() currently leaves out.
        self.w1 = self.xavier_var_creator(self.input_shape1)
        self.b1 = tf.Variable(tf.zeros(self.hidden_dims, dtype=tf.float64))

        self.w2 = self.xavier_var_creator(self.input_shape2)
        self.b2 = tf.Variable(tf.zeros(self.hidden_dims, dtype=tf.float64))

        self.w3 = self.xavier_var_creator(self.input_shape3)
        self.b3 = tf.Variable(tf.zeros(1, dtype=tf.float64))

    def xavier_var_creator(self, input_shape):
        # normal initialisation with stddev 1 / sqrt(fan_in / 2), stored as float64
        fan_in = input_shape[0]
        xavier_stddev = 1.0 / tf.sqrt(fan_in / 2.0)
        initial = tf.cast(
            tf.random.normal(shape=input_shape, mean=0.0, stddev=xavier_stddev),
            tf.float64)
        return tf.Variable(initial, shape=tf.TensorShape(input_shape), trainable=True)

    def call(self, inputs, training=None, mask=None):
        # inputs are concatenations of a sample and z; the batch axis is forced to batch_size
        flat_inputs = tf.cast(tf.reshape(inputs, [self.batch_size, -1]), tf.float64)
        hidden = tf.nn.relu(tf.matmul(flat_inputs, self.w1) + self.b1)
        # linear output: no second hidden layer and no output activation
        return tf.matmul(hidden, self.w3) + self.b3


class MINEDiscriminator(tf.keras.layers.Layer):
    '''
    MINE discriminator for benchmark GCIT. Every weight is a vector of length in_dims that is
    applied elementwise (no matrix products).
    '''

    def __init__(self, in_dims, output_activation='linear'):
        super().__init__()
        self.output_activation = output_activation
        self.input_dim = in_dims

        # Variables are created in the order w1a, w1b, b1, w2a, w2b, b2, w3, b3.
        self.w1a = self.xavier_var_creator()
        self.w1b = self.xavier_var_creator()
        self.b1 = tf.Variable(tf.zeros([self.input_dim], tf.float64))

        self.w2a = self.xavier_var_creator()
        self.w2b = self.xavier_var_creator()
        self.b2 = tf.Variable(tf.zeros([self.input_dim], tf.float64))

        self.w3 = self.xavier_var_creator()
        self.b3 = tf.Variable(tf.zeros([self.input_dim], tf.float64))

    def xavier_var_creator(self):
        # normal initialisation with stddev 1 / sqrt(input_dim / 2), stored as float64
        xavier_stddev = 1.0 / tf.sqrt(self.input_dim / 2.0)
        initial = tf.cast(
            tf.random.normal(shape=[self.input_dim], mean=0.0, stddev=xavier_stddev),
            tf.float64)
        return tf.Variable(initial, shape=tf.TensorShape(self.input_dim), trainable=True)

    def mine_layer(self, x, x_hat, wa, wb, b):
        # elementwise affine combination of x and x_hat squashed by tanh
        pre_activation = wa * x + wb * x_hat + b
        return tf.math.tanh(pre_activation)

    def call(self, x, x_hat):
        # two parallel tanh layers on the same inputs, summed and rescaled elementwise
        branch_one = self.mine_layer(x, x_hat, self.w1a, self.w1b, self.b1)
        branch_two = self.mine_layer(x, x_hat, self.w2a, self.w2b, self.b2)
        out = self.w3 * (branch_one + branch_two) + self.b3
        # return the statistic together with its exponential
        return out, tf.exp(out)


class CharacteristicFunction:
    '''
    Randomly initialised two-layer sigmoid network used to represent the characteristic
    function; update() draws a fresh set of weights.
    '''

    def __init__(self, size, x_dims, z_dims, test_size):
        self.n_samples = size
        self.hidden_dims = 20
        self.test_size = test_size

        self.z_dims = z_dims
        self.x_dims = x_dims
        self.input_dim = z_dims + x_dims
        self.input_shape1x = [self.x_dims, self.hidden_dims]
        self.input_shape1z = [self.z_dims, self.hidden_dims]
        self.input_shape1 = [self.input_dim, self.hidden_dims]
        self.input_shape2 = [self.hidden_dims, 1]

        # random draws happen in the order w1x, b1, w2
        self.w1x = self.xavier_var_creator(self.input_shape1x)
        self.b1 = tf.squeeze(self.xavier_var_creator([self.hidden_dims, 1]))

        self.w2 = self.xavier_var_creator(self.input_shape2)
        self.b2 = tf.Variable(tf.zeros(1, dtype=tf.float64))

    def xavier_var_creator(self, input_shape):
        # normal initialisation with stddev sqrt(2 / fan_in), stored as float64
        fan_in = input_shape[0]
        xavier_stddev = tf.sqrt(2.0 / fan_in)
        initial = tf.cast(
            tf.random.normal(shape=input_shape, mean=0.0, stddev=xavier_stddev),
            tf.float64)
        return tf.Variable(initial, shape=tf.TensorShape(input_shape), trainable=True)

    def update(self):
        # re-draw w1x, b1 and w2; b2 is left as it is
        self.w1x = self.xavier_var_creator(self.input_shape1x)
        self.b1 = tf.squeeze(self.xavier_var_creator([self.hidden_dims, 1]))
        self.w2 = self.xavier_var_creator(self.input_shape2)

    def call(self, x, z):
        # x and z are reshaped to [test_size, -1, dims]; only x enters the computation
        x = tf.reshape(x, [self.test_size, -1, self.x_dims])
        z = tf.reshape(z, [self.test_size, -1, self.z_dims])
        # we assume parameter b for z to be 0
        hidden = tf.nn.sigmoid(tf.matmul(x, self.w1x) + self.b1)
        return tf.nn.sigmoid(tf.matmul(hidden, self.w2))


#
# The generate_samples_random function and the rdc function were inspired by the
# GCIT GitHub repository by Alexis Bellot and Mihaela van der Schaar.
# Source: https://github.com/alexisbellot/GCIT
#

def generate_samples_random(size=1000, sType='CI', dx=1, dy=1, dz=20, nstd=0.05, alpha_x=0.05,
               normalize=True, seed=None, dist_z='gaussian', scaling_z = False):
    '''
    Generate CI, I or dependent post-nonlinear samples
    1. Z is independent Gaussian or Laplace
    2. X = f1(<a,Z> + b + noise) and Y = f2(<c,Z> + d + noise) in case of CI
    Arguments:
        size : number of samples
        sType: 'CI' (X and Y conditionally independent given Z), 'I' (X and Y independent of Z),
               any other value gives the dependent case where Y also depends on X
        dx: Dimension of X
        dy: Dimension of Y
        dz: Dimension of Z
        nstd: noise standard deviation
        alpha_x: weight of X in Y, only used in the dependent case
        normalize: if True, min-max scale X, Y and Z to [0, 1] (over all entries of each array)
        seed: if given, samples are drawn from a private RandomState seeded with it, so the
              output is reproducible and the global NumPy random state is left untouched;
              if None, the global NumPy random state is used
        dist_z: distribution of Z, 'gaussian' or 'laplace'
        scaling_z: if True, divide Z by its maximum (after the optional normalisation)
        we set f1 to be sin function and f2 to be cos function.
    Output:
        Samples X, Y, Z
    Raises:
        ValueError if dist_z is not 'gaussian' or 'laplace'
    '''
    # np.random and RandomState expose the same sampling functions used below
    rng = np.random if seed is None else np.random.RandomState(seed)
    num = size

    if dist_z == 'gaussian':
        Z = rng.multivariate_normal(np.zeros(dz), np.eye(dz), num)
    elif dist_z == 'laplace':
        Z = rng.laplace(loc=0.0, scale=1.0, size=num * dz)
        Z = np.reshape(Z, (num, dz))
    else:
        raise ValueError("dist_z must be 'gaussian' or 'laplace', got {!r}".format(dist_z))

    # random non-negative mixing weights, each column normalised to unit L1 norm
    Ax = rng.rand(dz, dx)
    for i in range(dx):
        Ax[:, i] = Ax[:, i] / np.linalg.norm(Ax[:, i], ord=1)
    Ay = rng.rand(dz, dy)
    for i in range(dy):
        Ay[:, i] = Ay[:, i] / np.linalg.norm(Ay[:, i], ord=1)

    Axy = np.ones((dx, dy)) * alpha_x

    def _noise(dim):
        # isotropic Gaussian noise of shape [num, dim] with standard deviation nstd
        return nstd * rng.multivariate_normal(np.zeros(dim), np.eye(dim), num)

    if sType == 'CI':
        X = np.sin(np.matmul(Z, Ax) + _noise(dx))
        Y = np.cos(np.matmul(Z, Ay) + _noise(dy))
    elif sType == 'I':
        X = np.sin(_noise(dx))
        Y = np.cos(_noise(dy))
    else:
        X = np.sin(np.matmul(Z, Ax) + _noise(dx))
        # the noise added to Y must have Y's dimension (dy), not X's
        Y = np.cos(np.matmul(X, Axy) + np.matmul(Z, Ay) + _noise(dy))

    if normalize:
        Z = (Z - Z.min()) / (Z.max() - Z.min())
        X = (X - X.min()) / (X.max() - X.min())
        Y = (Y - Y.min()) / (Y.max() - Y.min())
    if scaling_z:
        Z = Z / Z.max()
    return np.array(X), np.array(Y), np.array(Z)


#
# test statistics
#



def _centred_laplace_kernel(obs, gen, sigma, n, M):
    """Return the n x n Laplace kernel matrix of obs centred with the generated samples gen.

    Entry (i, j) is k(o_i, o_j) - mean_m k(g_jm, o_i) - mean_m k(g_im, o_j)
    + (1/M) * sum_{m'} mean_m k(g_jm, g_im'), with k(a, b) = exp(-|a - b| / sigma).
    """
    obs = obs.reshape(n, 1)
    gen_by_col = gen.unsqueeze(0)  # [1, n, M]: entry (., j, m) is gen[j, m]

    k_obs = torch.exp(-torch.abs(obs - obs.T) / sigma)
    k_cross = torch.mean(torch.exp(-torch.abs(gen_by_col - obs.reshape(n, 1, 1)) / sigma), dim=2)

    # accumulate one generated replicate at a time to avoid an [n, n, M, M] tensor
    k_gen = 0
    for m in range(M):
        reference = gen[:, m].reshape(n, 1, 1)
        k_gen = k_gen + torch.mean(torch.exp(-torch.abs(gen_by_col - reference) / sigma), dim=2)

    return k_obs - k_cross - k_cross.T + k_gen / M


def get_p_value_stat_1(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, sigma_w, sigma_u = 1, sigma_v = 1,
                       boor_rv_type = "gaussian"):
    """Compute the test statistic and its multiplier-bootstrap replicates.

    All tensors produced here live on the device (and use the dtype) of the inputs, so the
    function does not depend on any module-level device setting.

    :param boot_num: number of bootstrap replicates
    :param M: number of generated samples per observation used in the kernel average
    :param n: number of observations
    :param gen_x_all_torch: generated x samples, tensor of shape [n, M]
    :param gen_y_all_torch: generated y samples, tensor of shape [n, M]
    :param x_torch: observed x, tensor of shape [n, 1]
    :param y_torch: observed y, tensor of shape [n, 1]
    :param z_torch: observed z, tensor of shape [n, z dims]
    :param sigma_w: bandwidth of the Laplace kernel on z (L1 distance)
    :param sigma_u: bandwidth of the Laplace kernel on y
    :param sigma_v: bandwidth of the Laplace kernel on x
    :param boor_rv_type: bootstrap multiplier distribution, "gaussian" or "rademacher"
    :return: (stat, boottemp) - the statistic as a float and a float64 numpy array of
        boot_num bootstrap statistics
    :raises ValueError: if boor_rv_type is not "gaussian" or "rademacher"
    """
    if boor_rv_type not in ("gaussian", "rademacher"):
        raise ValueError('boor_rv_type must be "gaussian" or "rademacher", got {!r}'.format(boor_rv_type))

    # Laplace kernel on z using pairwise L1 distances
    z_dist = torch.linalg.vector_norm(z_torch.unsqueeze(0) - z_torch.unsqueeze(1), ord=1, dim=2)
    w_mx = torch.exp(-z_dist / sigma_w)

    u_mx = _centred_laplace_kernel(y_torch, gen_y_all_torch, sigma_u, n, M)
    v_mx = _centred_laplace_kernel(x_torch, gen_x_all_torch, sigma_v, n, M)

    # zero the diagonal so only pairs of distinct observations contribute
    off_diagonal = 1 - torch.eye(n, device=w_mx.device, dtype=w_mx.dtype)
    FF_mx = u_mx * v_mx * w_mx * off_diagonal

    stat = 1/(n-1) * torch.sum(FF_mx).item()

    # Multipliers are drawn on the CPU and then moved next to FF_mx
    eboot = torch.randn(n, boot_num)
    if boor_rv_type == "rademacher":
        eboot = torch.sign(eboot)
    eboot = eboot.to(device=FF_mx.device, dtype=FF_mx.dtype)

    # e_b^T FF e_b for every replicate b at once; equals sum(FF * outer(e_b, e_b))
    boot_stats = torch.sum(eboot * torch.matmul(FF_mx, eboot), dim=0) / (n - 1)
    boottemp = boot_stats.detach().cpu().numpy().astype(np.float64)

    return stat, boottemp

def mgcit(n=500, z_dim=100, simulation='type1error', batch_size=64, n_iter=1000, train_writer=None,
          current_iters=0, nstd=1.0, z_dist='gaussian', x_dims=1, y_dims=1, a_x=0.05, M=500, k=2,
          var_idx=1, b=30, j=1000):
    # generate samples x, y, z
    # arguments: size, sType='CI', dx=1, dy=1, dz=20, nstd=1, fixed_function='linear',
    # debug=False, normalize=True, seed=None, dist_z='gaussian'
    if simulation == 'type1error':
        # generate samples x, y, z under null hypothesis - x and y are conditional independent
        x, y, z = generate_samples_random(size=n, sType='CI', dx=x_dims, dy=y_dims, dz=z_dim, nstd=nstd, alpha_x=a_x,
                            dist_z=z_dist, scaling_z=False, normalize=True)

    elif simulation == 'power':
        # generate samples x, y, z under alternative hypothesis - x and y are dependent
        x, y, z = generate_samples_random(size=n, sType='dependent', dx=x_dims, dy=y_dims, dz=z_dim, nstd=nstd,
                            alpha_x=a_x, dist_z=z_dist, scaling_z=False, normalize=True)

    elif simulation == 'ccle':
        x_drug, y, features = load_ccle(feature_type='mutation', normalize=False)

        ccle_selected, corrs = ccle_feature_filter(x_drug, y, threshold=0.05)

        features.index[ccle_selected]

        var_names = ['BRAF.MC_MUT', 'BRAF.V600E_MUT', 'HIP1_MUT', 'CDC42BPA_MUT', 'THBS3_MUT', 'DNMT1_MUT', 'PRKD1_MUT',
                     'FLT3_MUT', 'PIP5K1A_MUT', 'MAP3K5_MUT']
        idx = []

        for var in var_names:
            idx.append(features.T.columns.get_loc(var))

        x = x_drug[:, idx[5]]
        z = np.delete(x_drug, (idx[5]), axis=1).astype(np.float64)
        z_dim = z.shape[1]

        x = np.expand_dims(x, axis=1).astype(np.float64)
        y = np.expand_dims(y, axis=1)
        n = y.shape[0]
    elif simulation == 'brain':
        path = './data/ADNI-Mediation-new.csv'
        df = pd.read_csv(path, header=None)
        y = df.loc[:, 7].values
        age = df.loc[:, 5].values
        tr_measures = df.loc[:, 12:79].values
        ct_measures = df.loc[:, 80:].values
        all_data = np.concatenate((np.expand_dims(age, axis=1), tr_measures), axis=1)
        all_data = np.concatenate((all_data, ct_measures), axis=1)
        x_idx = np.argwhere(np.isnan(all_data))[:, 0]
        y_idx = np.argwhere(np.isnan(y))[:, 0]
        idx = np.concatenate([x_idx, y_idx])
        idx = np.unique(idx)
        idx_diff = np.arange(0, idx.shape[0])
        remove_idx = idx - idx_diff
        for i in remove_idx:
            all_data = np.delete(all_data, i, axis=0)
            y = np.delete(y, i, axis=0)

        all_data = np.delete(all_data, i, axis=0)
        y = np.delete(y, i, axis=0)
        x = all_data[:, var_idx]
        z = np.delete(all_data, var_idx, axis=1).astype(np.float64)
        z_dim = z.shape[1]

        z = (z - z.min()) / (z.max() - z.min())
        x = (x - x.min()) / (x.max() - x.min())
        y = (y - y.min()) / (y.max() - y.min())

        x = np.expand_dims(x, axis=1).astype(np.float64)
        y = np.expand_dims(y, axis=1)
        n = y.shape[0]

    else:
        raise ValueError('Test does not exist.')

    # w_mx = np.linalg.norm(np.tile(z,(n,1,1)) - np.swapaxes(np.tile(z,(n,1,1)), 0, 1), ord = 1, axis = 2)
    # sigma_w = np.median(w_mx).item()

    # u_mx = np.abs(np.tile(y,(1, n)) - np.tile(y,(1, n)).T)
    # sigma_u = np.median(u_mx).item()

    # v_mx = np.abs(np.tile(x,(1, n)) - np.tile(x,(1, n)).T) )
    # sigma_v = np.median(v_mx).item()

    test_size = int(n/k)
    stat_all = torch.zeros(k, 1)
    boot_temp_all = torch.zeros(k, j)

    cur_k = 0

    # split the train-test sets to k folds
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    epochs = int(n_iter)

    for train_idx, test_idx in kf.split(x):
        x_train, y_train, z_train = x[train_idx], y[train_idx], z[train_idx]

        dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, z_train))
        # Repeat n epochs
        training = dataset.repeat(epochs)
        training_dataset = training.shuffle(100).batch(batch_size * 2)
        # test-set is the one left
        testing_dataset = tf.data.Dataset.from_tensor_slices((x[test_idx], y[test_idx], z[test_idx]))

        if z_dim <= 20:
            v_dims = int(3)
            h_dims = int(3)

        else:
            v_dims = int(50)
            h_dims = int(512)

        v_dist = tfp.distributions.Normal(0, scale=tf.sqrt(1.0 / 3.0))
        # create instance of G & D
        lr = 0.0005
        generator_x = WGanGenerator(n, z_dim, h_dims, v_dims, x_dims, batch_size)
        generator_y = WGanGenerator(n, z_dim, h_dims, v_dims, y_dims, batch_size)
        discriminator_x = WGanDiscriminator(n, z_dim, h_dims, x_dims, batch_size)
        discriminator_y = WGanDiscriminator(n, z_dim, h_dims, y_dims, batch_size)

        gen_clipping_val = 0.5
        gen_clipping_norm = 1.0
        w_clipping_val = 0.5
        w_clipping_norm = 1.0
        scaling_coef = 1.0
        sinkhorn_eps = 0.8
        sinkhorn_l = 30

        gx_optimiser = tf.keras.optimizers.Adam(lr, beta_1=0.5, clipnorm=gen_clipping_norm, clipvalue=gen_clipping_val)
        dx_optimiser = tf.keras.optimizers.Adam(lr, beta_1=0.5, clipnorm=w_clipping_norm, clipvalue=w_clipping_val)
        gy_optimiser = tf.keras.optimizers.Adam(lr, beta_1=0.5, clipnorm=gen_clipping_norm, clipvalue=gen_clipping_val)
        dy_optimiser = tf.keras.optimizers.Adam(lr, beta_1=0.5, clipnorm=w_clipping_norm, clipvalue=w_clipping_val)

        @tf.function
        def x_update_d(real_x, real_x_p, real_z, real_z_p, v, v_p):
            gen_inputs = tf.concat([real_z, v], axis=1)
            gen_inputs_p = tf.concat([real_z_p, v_p], axis=1)
            # concatenate real inputs for WGAN discriminator (x, z)
            d_real = tf.concat([real_x, real_z], axis=1)
            d_real_p = tf.concat([real_x_p, real_z_p], axis=1)
            fake_x = generator_x.call(gen_inputs)
            fake_x_p = generator_x.call(gen_inputs_p)
            d_fake = tf.concat([fake_x, real_z], axis=1)
            d_fake_p = tf.concat([fake_x_p, real_z_p], axis=1)

            with tf.GradientTape() as disc_tape:
                f_real = discriminator_x.call(d_real)
                f_fake = discriminator_x.call(d_fake)
                f_real_p = discriminator_x.call(d_real_p)
                f_fake_p = discriminator_x.call(d_fake_p)
                # call compute loss using @tf.function + autograph

                loss1 = benchmark_loss(f_real, f_fake, scaling_coef, sinkhorn_eps, sinkhorn_l,
                                                f_real_p, f_fake_p)
                # disc_loss = - tf.math.minimum(loss1, 1)
                disc_loss = - loss1
            # update discriminator parameters
            d_grads = disc_tape.gradient(disc_loss, discriminator_x.trainable_variables)
            dx_optimiser.apply_gradients(zip(d_grads, discriminator_x.trainable_variables))

        @tf.function
        def x_update_g(real_x, real_x_p, real_z, real_z_p, v, v_p):
            '''
            Apply one optimiser step to generator_x and return the loss used for it.

            :param real_x: first batch of observed x values
            :param real_x_p: second batch of observed x values
            :param real_z: conditioning batch paired with real_x
            :param real_z_p: conditioning batch paired with real_x_p
            :param v: noise concatenated with real_z to form the generator input
            :param v_p: noise concatenated with real_z_p to form the generator input
            :return: value of benchmark_loss that was differentiated w.r.t. generator_x
            '''
            # generator inputs are (z, noise) joined along the feature axis
            noise_in = tf.concat([real_z, v], axis=1)
            noise_in_p = tf.concat([real_z_p, v_p], axis=1)
            # observed (x, z) pairs fed to the discriminator
            pair_real = tf.concat([real_x, real_z], axis=1)
            pair_real_p = tf.concat([real_x_p, real_z_p], axis=1)
            with tf.GradientTape() as tape:
                # the generator forward pass must be recorded so gradients can reach its weights
                x_hat = generator_x.call(noise_in)
                x_hat_p = generator_x.call(noise_in_p)
                pair_fake = tf.concat([x_hat, real_z], axis=1)
                pair_fake_p = tf.concat([x_hat_p, real_z_p], axis=1)
                score_real = discriminator_x.call(pair_real)
                score_fake = discriminator_x.call(pair_fake)
                score_real_p = discriminator_x.call(pair_real_p)
                score_fake_p = discriminator_x.call(pair_fake_p)
                gen_loss = benchmark_loss(score_real, score_fake, scaling_coef, sinkhorn_eps, sinkhorn_l,
                                          score_real_p, score_fake_p)
            # only the generator's variables are updated here; the discriminator is left as is
            g_vars = generator_x.trainable_variables
            g_grads = tape.gradient(gen_loss, g_vars)
            gx_optimiser.apply_gradients(zip(g_grads, g_vars))
            return gen_loss

        @tf.function
        def y_update_d(real_x, real_x_p, real_z, real_z_p, v, v_p):
            '''
            Apply one optimiser step to discriminator_y; returns nothing.

            The discriminator is trained on the negated benchmark_loss, i.e. it
            ascends the loss that the generator update descends.

            :param real_x: first batch of observed values (the training loop in this file passes the y batch here)
            :param real_x_p: second batch of observed values
            :param real_z: conditioning batch paired with real_x
            :param real_z_p: conditioning batch paired with real_x_p
            :param v: noise concatenated with real_z to form the generator input
            :param v_p: noise concatenated with real_z_p to form the generator input
            '''
            # generator inputs are (z, noise) joined along the feature axis
            noise_in = tf.concat([real_z, v], axis=1)
            noise_in_p = tf.concat([real_z_p, v_p], axis=1)
            # observed (value, z) pairs fed to the discriminator
            pair_real = tf.concat([real_x, real_z], axis=1)
            pair_real_p = tf.concat([real_x_p, real_z_p], axis=1)
            # samples are generated outside the tape: the generator is not being trained in this step
            sample = generator_y.call(noise_in)
            sample_p = generator_y.call(noise_in_p)
            pair_fake = tf.concat([sample, real_z], axis=1)
            pair_fake_p = tf.concat([sample_p, real_z_p], axis=1)

            with tf.GradientTape() as tape:
                score_real = discriminator_y.call(pair_real)
                score_fake = discriminator_y.call(pair_fake)
                score_real_p = discriminator_y.call(pair_real_p)
                score_fake_p = discriminator_y.call(pair_fake_p)
                objective = benchmark_loss(score_real, score_fake, scaling_coef, sinkhorn_eps, sinkhorn_l,
                                           score_real_p, score_fake_p)
                # sign flip turns the optimiser's minimisation into maximisation of the objective
                disc_loss = - objective
            d_vars = discriminator_y.trainable_variables
            d_grads = tape.gradient(disc_loss, d_vars)
            dy_optimiser.apply_gradients(zip(d_grads, d_vars))

        @tf.function
        def y_update_g(real_x, real_x_p, real_z, real_z_p, v, v_p):
            '''
            Apply one optimiser step to generator_y and return the loss used for it.

            :param real_x: first batch of observed values (the training loop in this file passes the y batch here)
            :param real_x_p: second batch of observed values
            :param real_z: conditioning batch paired with real_x
            :param real_z_p: conditioning batch paired with real_x_p
            :param v: noise concatenated with real_z to form the generator input
            :param v_p: noise concatenated with real_z_p to form the generator input
            :return: value of benchmark_loss that was differentiated w.r.t. generator_y
            '''
            # generator inputs are (z, noise) joined along the feature axis
            noise_in = tf.concat([real_z, v], axis=1)
            noise_in_p = tf.concat([real_z_p, v_p], axis=1)
            # observed (value, z) pairs fed to the discriminator
            pair_real = tf.concat([real_x, real_z], axis=1)
            pair_real_p = tf.concat([real_x_p, real_z_p], axis=1)
            with tf.GradientTape() as tape:
                # the generator forward pass must be recorded so gradients can reach its weights
                sample = generator_y.call(noise_in)
                sample_p = generator_y.call(noise_in_p)
                pair_fake = tf.concat([sample, real_z], axis=1)
                pair_fake_p = tf.concat([sample_p, real_z_p], axis=1)
                score_real = discriminator_y.call(pair_real)
                score_fake = discriminator_y.call(pair_fake)
                score_real_p = discriminator_y.call(pair_real_p)
                score_fake_p = discriminator_y.call(pair_fake_p)
                gen_loss = benchmark_loss(score_real, score_fake, scaling_coef, sinkhorn_eps, sinkhorn_l,
                                          score_real_p, score_fake_p)
            # only the generator's variables are updated here; the discriminator is left as is
            g_vars = generator_y.trainable_variables
            g_grads = tape.gradient(gen_loss, g_vars)
            gy_optimiser.apply_gradients(zip(g_grads, g_vars))
            return gen_loss

        for x_batch, y_batch, z_batch in training_dataset.take(n_iter):

            if x_batch.shape[0] != batch_size * 2:
                continue

            # separate the batch into two parts to train the two GANs
            x_batch1 = tf.convert_to_tensor(x_batch[:batch_size, ...])
            x_batch2 = tf.convert_to_tensor(x_batch[batch_size:, ...])
            y_batch1 = tf.convert_to_tensor(y_batch[:batch_size, ...])
            y_batch2 = tf.convert_to_tensor(y_batch[batch_size:, ...])
            z_batch1 = tf.convert_to_tensor(z_batch[:batch_size, ...])
            z_batch2 = tf.convert_to_tensor(z_batch[batch_size:, ...])

            noise_v = v_dist.sample([batch_size, v_dims])
            noise_v = tf.cast(noise_v, tf.float64)
            noise_v_p = v_dist.sample([batch_size, v_dims])
            noise_v_p = tf.cast(noise_v_p, tf.float64)
            x_update_d(x_batch1, x_batch2, z_batch1, z_batch2, noise_v, noise_v_p)
            loss_x = x_update_g(x_batch1, x_batch2, z_batch1, z_batch2, noise_v, noise_v_p)
            y_update_d(y_batch1, y_batch2, z_batch1, z_batch2, noise_v, noise_v_p,)
            loss_y = y_update_g(y_batch1, y_batch2, z_batch1, z_batch2, noise_v, noise_v_p)

            current_iters += 1
        tf.keras.backend.clear_session()
        gen_x_all = torch.zeros(test_size, M)
        gen_y_all = torch.zeros(test_size, M)
        z_all = torch.zeros(test_size, z_dim)
        x_all = torch.zeros(test_size, x_dims)
        y_all = torch.zeros(test_size, y_dims)

        cur_itr = 0
        # the following code generates x_1, ..., x_400 for all B; it takes about 61 secs for one test
        for test_x, test_y, test_z in testing_dataset:
            test_z = tf.reshape(test_z, (1, z_dim))
            test_y = tf.reshape(test_y, (1, y_dims))
            test_x = tf.reshape(test_x, (1, x_dims))

            tiled_z = tf.tile(test_z, [M, 1])
            noise_v = v_dist.sample([M, v_dims])
            noise_v = tf.cast(noise_v, tf.float64)
            g_inputs = tf.concat([tiled_z, noise_v], axis=1)
            # generator samples from G and evaluate from D
            fake_x = generator_x.call(g_inputs, training=False)
            fake_y = generator_y.call(g_inputs, training=False)
            gen_x_all[cur_itr,:] = torch.from_numpy(fake_x.numpy()).reshape(-1)
            gen_y_all[cur_itr,:] = torch.from_numpy(fake_y.numpy()).reshape(-1)
            x_all[cur_itr,:] = torch.from_numpy(test_x.numpy())
            y_all[cur_itr,:] = torch.from_numpy(test_y.numpy())
            z_all[cur_itr,:] = torch.from_numpy(test_z.numpy())
            cur_itr = cur_itr + 1
        standardise = True

        if standardise:
            gen_x_all = (gen_x_all - torch.mean(gen_x_all, dim=0, keepdim=True)) / torch.std(gen_x_all, dim=0, keepdim=True)
            gen_y_all = (gen_y_all - torch.mean(gen_y_all, dim=0, keepdim=True)) / torch.std(gen_y_all, dim=0, keepdim=True)
            x_all = (x_all - torch.mean(x_all, dim=0, keepdim=True)) / torch.std(x_all, dim=0, keepdim=True)
            y_all = (y_all - torch.mean(y_all, dim=0, keepdim=True)) / torch.std(y_all, dim=0, keepdim=True)
            z_all = (z_all - torch.mean(z_all, dim=0, keepdim=True)) / torch.std(z_all, dim=0, keepdim=True)

        w_mx = torch.linalg.vector_norm(z_all.repeat(test_size,1,1) - torch.swapaxes(z_all.repeat(test_size,1,1), 0, 1), ord = 1, dim = 2)
        sigma_w = torch.median(w_mx).item()

        u_mx = torch.abs(y_all.repeat(1, test_size) - y_all.repeat(1, test_size).T)
        sigma_u = torch.median(u_mx).item()

        v_mx = torch.abs(x_all.repeat(1, test_size) - x_all.repeat(1, test_size).T)
        sigma_v = torch.median(v_mx).item()

        cur_stat, cur_boot_temp = get_p_value_stat_1(boot_num = j, M = M, n = test_size, gen_x_all_torch = gen_x_all.to(device), gen_y_all_torch = gen_y_all.to(device), x_torch = x_all.to(device), y_torch = y_all.to(device), z_torch = z_all.to(device), sigma_w = sigma_w, sigma_u = sigma_u, sigma_v = sigma_v,boor_rv_type = "gaussian")
        stat_all[cur_k,:] = cur_stat
        boot_temp_all[cur_k,:] = torch.from_numpy(cur_boot_temp)
        cur_k = cur_k + 1
    torch.cuda.empty_cache()
    return np.mean(torch.mean(boot_temp_all, dim = 0).numpy() > torch.mean(stat_all).item() )



def run_experiment(params):
    '''
    Run the conditional-independence test repeatedly and report empirical rejection rates.

    The configuration is read from ``params``, TensorFlow / NumPy / PyTorch are
    seeded with ``set_seed``, and ``mgcit`` is called ``n_test`` times. The
    fraction of p-values below ``alpha`` and below ``alpha1`` is printed once
    all repetitions have finished.

    :param params: mapping with the keys 'model', 'sample_size', 'batch_size', 'z_dim',
        'dx', 'dy', 'test', 'n_test', 'n_iters', 'eps_std', 'dist_z', 'alpha_x',
        'm_value', 'k_value', 'b_value', 'j_value', 'alpha', 'alpha1', 'set_seed'
    :return: 1-d float numpy array holding the p-value of every repetition, in run order
    :raises KeyError: if a required key is missing from ``params``
    :raises ValueError: if ``params['test']`` is neither 'type1error' nor 'power',
        or if ``params['model']`` is not 'mgcit'
    '''
    model = params['model']
    sample_size = params['sample_size']
    batch_size = params['batch_size']
    z_dim = params['z_dim']
    dx = params['dx']
    dy = params['dy']
    test = params['test']
    n_test = params['n_test']
    n_iters = params['n_iters']
    eps_std = params['eps_std']
    dist_z = params['dist_z']
    alpha_x = params['alpha_x']
    m_value = params['m_value']
    k_value = params['k_value']
    b_value = params['b_value']
    j_value = params['j_value']
    alpha = params['alpha']
    alpha1 = params['alpha1']
    set_seed = params['set_seed']

    # Validate before any side effect (log directory, seeding, training) so a
    # typo in the configuration fails immediately with a clear message instead
    # of an unbound-variable error at the final print.
    if test not in ('type1error', 'power'):
        raise ValueError("Unknown test {!r}; expected 'type1error' or 'power'.".format(test))
    if model != 'mgcit':
        raise ValueError('Test does not exist.')

    # Take a single timestamp so the pieces of the run name cannot straddle a
    # minute/hour boundary.
    saved_file = "{}-{}".format(model, datetime.now().strftime("%h%d-%H-%M"))
    log_dir = "./trained/{}/log".format(saved_file)
    train_writer = tf.summary.create_file_writer(logdir=log_dir)

    tf.random.set_seed(set_seed)
    np.random.seed(set_seed)
    torch.manual_seed(set_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(set_seed)

    # Both settings run the same loop; `test` is forwarded to mgcit as `simulation`.
    collected = []
    for test_count in range(n_test):
        p_value = mgcit(n=sample_size, z_dim=z_dim, simulation=test, batch_size=batch_size,
                        n_iter=n_iters, train_writer=train_writer, current_iters=test_count * n_test,
                        nstd=eps_std, z_dist=dist_z, x_dims=dx, y_dims=dy, a_x=alpha_x, M=m_value,
                        k=k_value, b=b_value, j=j_value)
        collected.append(p_value)
        # release per-repetition objects before the next run
        gc.collect()

    p_values = np.asarray(collected, dtype=float)

    # Rejection rates only depend on the complete set of p-values, so compute
    # them once here; with zero repetitions there is nothing to average.
    if p_values.size:
        final_result = np.mean(p_values < alpha)
        final_result1 = np.mean(p_values < alpha1)
    else:
        final_result = final_result1 = float('nan')

    print('Emp Rej Rate: {} for z dimension {} with significance level {}'.format(final_result,z_dim, alpha))
    print('Emp Rej Rate: {} for z dimension {} with significance level {}'.format(final_result1,z_dim, alpha1))

    return p_values

In [None]:
# @title code to get Empirical Rejection Rate for Figure 8 (a) (b) GANS

# Switch the shared configuration dict to the type-I-error setting; every other
# entry of `param` keeps whatever value it currently has.
param["test"] = "type1error"

# One full experiment per z dimension. run_experiment prints the rejection
# rates itself; p_val_list is overwritten on each pass, so after the loop it
# holds only the p-values of the last z_dim (250).
for z_dim in [50, 100, 150, 200, 250]:
    param["z_dim"] = z_dim
    p_val_list = run_experiment(param)

In [None]:
# @title code to get Size Adjusted Power for Figure 8 (c) (d) GANS

# Switch the shared configuration dict to the power setting at a fixed z dimension.
param["test"] = "power"
param["z_dim"] = 200

# Size-adjusted rejection thresholds, hard-coded from an earlier type-I-error run.
# NOTE(review): after the previous cell's loop, p_val_list holds the z_dim=250
# run, so these quantiles cannot be recomputed from it as is; the z_dim=200
# p-values would have to be kept separately — confirm how the numbers were obtained.
param["alpha"] = 0.043 # 5% quantile of the p_val_list from previous block when param["test"] = "type1error" param["z_dim"] = 200
param["alpha1"] = 0.084 # 10% quantile of the p_val_list from previous block when param["test"] = "type1error" param["z_dim"] = 200


# One full experiment per alpha_x value; run_experiment prints the rejection
# rates itself and p_val_list is overwritten on each pass.
for alpha_x in [0.15, 0.30, 0.45, 0.60, 0.75]:
    param["alpha_x"] = alpha_x
    p_val_list = run_experiment(param)