In [None]:
import torch
import torch.nn as nn
from torch.optim import SGD
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.distributions as TD
from zmq import device
import torch.optim as optim
from datetime import datetime
import functools
from tqdm import tqdm
import math

# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

'''
This code reproduces the real data experiments using the CCLE data.
Preprocessing steps follow the code of W. Tansey at https://github.com/tansey/hrt.
And the following code is obtained from https://github.com/alexisbellot/GCIT.
'''


def load_ccle(drug_target='PLX4720', feature_type='both', normalize=False):
    '''
    :param drug target: specific drug we w ant to analyse
    :param normalize: normalize data
    :return: genetic features (mutations) as a 2d array for each cancer cell and corresponding drug response measured with Amax
    '''
    if feature_type in ['expression', 'both']:
        # Load gene expression
        expression = pd.read_csv('./ccle_data/expression.txt', delimiter='\t', header=2, index_col=1).iloc[:, 1:]
        expression.columns = [c.split(' (ACH')[0] for c in expression.columns]
        features = expression
    if feature_type in ['mutation', 'both']:
        # Load gene mutation
        mutations = pd.read_csv('./ccle_data/mutation.txt', delimiter='\t', header=2, index_col=1).iloc[:,1:]
        mutations = mutations.iloc[[c.endswith('_MUT') for c in mutations.index]]
        features = mutations
    if feature_type == 'both':
        # get cells having both expression and mutation data
        both_cells = set(expression.columns) & set(mutations.columns)
        z = {}
        for c in both_cells:
            exp = expression[c].values
            if len(exp.shape) > 1:
                exp = exp[:, 0]
            z[c] = np.concatenate([exp, mutations[c].values])
        both_df = pd.DataFrame(z, index=[c for c in expression.index] + [c for c in mutations.index])
        features = both_df

    print('Genetic features dimension = {} on {} cancer cells'.format(features.shape[0], features.shape[1]))

    # Get per-drug X and y regression targets
    response = pd.read_csv('./ccle_data/response.csv', header=0, index_col=[0, 2])

    # names of cell lines, there are 504
    cells = response.index.levels[0]
    # names of drugs, there are 24
    drugs = response.index.levels[1]

    X_drugs = [[] for _ in drugs]
    y_drugs = [[] for _ in drugs]

    # subset data to include only cells, mutations and response associated with chosen drug
    for j, drug in enumerate(drugs):
            if drug_target is not None and drug != drug_target:
                continue # return to beginning of the loop
            for i, cell in enumerate(cells):
                if cell not in features.columns or (cell, drug) not in response.index:
                    continue
                # all j empty except index that corresponds to target drug
                # for this j we iteratively append all the mutations on every cell
                X_drugs[j].append(features[cell].values) # store genetic features (mutations and expression) that appear in cells
                y_drugs[j].append(response.loc[(cell, drug), 'Amax']) # store response of the drug
            print('{}: Cell number = {}'.format(drug, len(y_drugs[j])))

    # convert to np array
    X_drugs = [np.array(x_i) for x_i in X_drugs]
    y_drugs = [np.array(y_i) for y_i in y_drugs]

    if normalize:
        X_drugs = [(x_i if (len(x_i) == 0) else (x_i - x_i.min(axis=0, keepdims=True)) /
                                                (x_i.max(axis=0, keepdims=True) - x_i.min(axis=0, keepdims=True)))
                   for x_i in X_drugs]
        y_drugs = [(y_i if (len(y_i) == 0 or y_i.std() == 0) else (y_i - y_i.min(axis=0, keepdims=True)) /
                                                (y_i.max(axis=0, keepdims=True) - y_i.min(axis=0, keepdims=True)))
                   for y_i in y_drugs]

    '''
    if normalize:
        X_drugs = [(x_i if (len(x_i) == 0) else (x_i - x_i.mean(axis=0, keepdims=True)) /
        x_i.std(axis=0).clip(1e-6)) for x_i in X_drugs]
        y_drugs = [(y_i if (len(y_i) == 0 or y_i.std() == 0) else (y_i - y_i.mean()) / y_i.std()) for y_i in y_drugs]
    '''
    drug_idx = drugs.get_loc(drug_target)
    # 2d array for features and 1d array for response
    X_drug, y_drug = X_drugs[drug_idx], y_drugs[drug_idx]

    return X_drug, y_drug, features


# X_drug, y_drug, features = load_ccle(feature_type='mutation')


def ccle_feature_filter(X, y, threshold=0.1):
    '''
    :param X: features
    :param y: response
    :param threshold: correlation threshold
    :return: logical array with False for all features that do not have at least pearson correlation at threshold with y
    and correlations for all variables
    '''
    corrs = np.array([np.abs(np.corrcoef(x, y)[0, 1]) if x.std() > 0 else 0 for x in X.T])
    selected = corrs >= threshold # True/False
    print(selected.sum(), selected.shape, corrs)
    return selected, corrs

# ccle_selected, corrs = ccle_feature_filter(X_drug, y_drug, threshold=0.1)

# features.index[ccle_selected]
# stats.describe(corrs[ccle_selected])


def fit_elastic_net_ccle(X, y, nfolds=3):
    '''
    :param X: features
    :param y: response
    :param nfolds: number of folds for hyperparameter tuning
    :return: fitted elastic net model
    '''
    from sklearn.linear_model import ElasticNetCV
    # The parameter l1_ratio corresponds to alpha in the glmnet R package
    # while alpha corresponds to the lambda parameter in glmnet
    # enet = ElasticNetCV(l1_ratio=np.linspace(0.2, 1.0, 10),
    #                     alphas=np.exp(np.linspace(-6, 5, 250)),
    #                     cv=nfolds)
    enet = ElasticNetCV(l1_ratio=0.2, # It always chooses l1_ratio=0.2
                        alphas=np.exp(np.linspace(-6, 5, 250)),
                        cv=nfolds)
    print('Fitting via CV')
    enet.fit(X, y)
    alpha, l1_ratio = enet.alpha_, enet.l1_ratio_
    print('Chose values: alpha={}, l1_ratio={}'.format(alpha, l1_ratio))
    return enet

# elastic_model = fit_elastic_net_ccle(X_drug[:,ccle_selected], y_drug)


def fit_random_forest_ccle(X, y):
    '''
    :param X: features
    :param y: response
    :param nfolds: number of folds for hyperparameter tuning
    :return: fitted elastic net model
    '''
    from sklearn.ensemble import RandomForestRegressor

    rf = RandomForestRegressor()

    rf.fit(X,y)

    return rf

# rf_model = fit_random_forest_ccle(X_drug[:,ccle_selected], y_drug)


def plot_ccle_predictions(model, X, y):
    from sklearn.metrics import r2_score
    plt.close()
    y_hat = model.predict(X)
    plt.scatter(y_hat, y, color='blue')
    plt.plot([min(y.min(), y_hat.min()), max(y.max(), y_hat.max())],
             [min(y.min(), y_hat.min()),max(y.max(), y_hat.max())], color='red', lw=3)
    plt.xlabel('Predicted')
    plt.ylabel('Truth')
    plt.title(' ($r^2$={:.4f})'.format( r2_score(y, y_hat)))
    plt.tight_layout()

# plot_ccle_predictions(elastic_model,X_drug[:,ccle_selected],y_drug)


def print_top_features(model):
    # model_weights = np.mean([m.coef_ for m in model.models], axis=0)
    if model == rf_model:
        model_weights = model.feature_importances_
    else:
        model_weights = model.coef_

    ccle_features = features[ccle_selected]

    print('Top by fit:')
    for idx, top in enumerate(np.argsort(np.abs(model_weights))[::-1]):
        print('{}. {}: {:.4f}'.format(idx+1, ccle_features.index[top], model_weights[top]))

# print_top_features(rf_model)
# print_top_features(elastic_model)


def get_p_value_stat_1(boot_num, M, n, gen_x_all_torch, gen_y_all_torch, x_torch, y_torch, z_torch, sigma_w, sigma_u=1,
            sigma_v=1, boor_rv_type="gaussian", dy_g=1, dx_g=1):
  """
    Compute the p-value

    Input:
    - boot_num: Integer giving the number of bootstrap samples.
    - M: Integer giving the number of training samples per batch.
    - n: Integer giving the number of training samples.
    - gen_x_all_torch: PyTorch Tensor (batch_size, M) of generated data of X.
    - gen_y_all_torch: PyTorch Tensor (batch_size, M) of generated data of Y.
    - x_torch: PyTorch Tensor (batch_size) of training input X.
    - y_torch: PyTorch Tensor (batch_size) of training input Y.
    - z_torch: PyTorch Tensor (batch_size, dimension_Z) of training input Z
    - sigma_w: Float of the bandwith of the Laplace kernel.
    - sigma_u: Float of the bandwith of the Laplace kernel.
    - sigma_v: Float of the bandwith of the Laplace kernel.
    - boor_rv_type: "rademacher" or "gaussian", specifying the reference distribution.

    Output:
    - p_value: Float giving the p-value.
  """

    w_mx = torch.linalg.vector_norm(z_torch.repeat(n, 1, 1) - torch.swapaxes(z_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2)
    w_mx = torch.exp(-w_mx / sigma_w)

    u_mx_1 = torch.exp(-torch.linalg.vector_norm(y_torch.repeat(n, 1, 1) - torch.swapaxes(y_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2) / sigma_u)
    u_mx_2 = torch.mean(
        torch.exp(-torch.linalg.vector_norm(gen_y_all_torch.repeat(n, 1, 1).reshape(n, n, -1, dy_g) - y_torch.repeat(1, n).reshape(n, n, 1, dy_g), ord=1, dim=3) / sigma_u), dim=2)
    u_mx_3 = u_mx_2.T

    gen_y_all_torch_rep = gen_y_all_torch.repeat(n, 1, 1).reshape(n, n, -1, dy_g)

    temp_mx = torch.swapaxes(gen_y_all_torch_rep[:, :, 0, :], 0, 1)
    sum_mx = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_y_all_torch_rep - temp_mx.reshape(n, n, 1, dy_g), ord=1, dim=3) / sigma_u), dim=2)

    v_mx_1 = torch.exp(-torch.linalg.vector_norm(x_torch.repeat(n, 1, 1) - torch.swapaxes(x_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2) / sigma_v)
    v_mx_2 = torch.mean(
        torch.exp(-torch.linalg.vector_norm(gen_x_all_torch.repeat(n, 1, 1).reshape(n, n, -1, dx_g) - x_torch.repeat(1, n).reshape(n, n, 1, dx_g), ord=1, dim=3) / sigma_v), dim=2)
    v_mx_3 = v_mx_2.T

    gen_x_all_torch_rep = gen_x_all_torch.repeat(n, 1, 1).reshape(n, n, -1, dx_g)

    temp2_mx = torch.swapaxes(gen_x_all_torch_rep[:, :, 0, :], 0, 1)
    sum2_mx = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_x_all_torch_rep - temp2_mx.reshape(n, n, 1, dx_g), ord=1, dim=3) / sigma_v), dim=2)

    for i in range(1, M):
        temp_mx = torch.swapaxes(gen_y_all_torch_rep[:, :, i, :], 0, 1)
        temp_add_mx = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_y_all_torch_rep - temp_mx.reshape(n, n, 1, dy_g), ord=1, dim=3) / sigma_u), dim=2)
        sum_mx = sum_mx + temp_add_mx

        temp2_mx = torch.swapaxes(gen_x_all_torch_rep[:, :, i, :], 0, 1)
        temp2_add_mx = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_x_all_torch_rep - temp2_mx.reshape(n, n, 1, dx_g), ord=1, dim=3) / sigma_v), dim=2)
        sum2_mx = sum2_mx + temp2_add_mx

    u_mx_4 = 1 / M * sum_mx
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4
    v_mx_4 = 1 / M * sum2_mx
    v_mx = v_mx_1 - v_mx_2 - v_mx_3 + v_mx_4

    FF_mx = u_mx * v_mx * w_mx * (1 - torch.eye(n).to(device))

    stat = 1 / (n - 1) * torch.sum(FF_mx).item()

    # print("U_mx:", u_mx)
    # print("V_mx:", v_mx)
    # print("W_mx:", w_mx)
    boottemp = np.array([])
    if boor_rv_type == "rademacher":
        eboot = torch.sign(torch.randn(n, boot_num)).to(device)
    elif boor_rv_type == "gaussian":
        eboot = torch.randn(n, boot_num).to(device)
    for bb in range(boot_num):
        random_mx = torch.matmul(eboot[:, bb].reshape(-1, 1), eboot[:, bb].reshape(-1, 1).T)
        bootmatrix = FF_mx * random_mx
        stat_boot = 1 / (n - 1) * torch.sum(bootmatrix).item()
        boottemp = np.append(boottemp, stat_boot)
    return stat, boottemp


def find_loss(y_torch, gen_y_all_torch, z_torch, sigma_w, sigma_u, M, dim_y):
    """
    Compute the MMD loss via Laplace kernel.

    Inputs:
    - y_torch: PyTorch Tensor (batch_size) of training input. (X or Y)
    - gen_y_all_torch: PyTorch Tensor (batch_size, M) of generated data.
    - z_torch: PyTorch Tensor (batch_size, dimension_Z) of training input Z.
    - sigma_w: Float of the bandwith of the kernel.
    - sigma_u: Float of the bandwith of the kernel.
    - M: Number of training samples per batch.

    Outputs:
    - loss: PyTorch Tensor containing the MMD loss.
    """
    n = z_torch.shape[0]
    w_mx = torch.linalg.vector_norm(z_torch.repeat(n, 1, 1) - torch.swapaxes(z_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2)
    w_mx = torch.exp(-w_mx / sigma_w)

    u_mx_1 = torch.exp(-torch.linalg.vector_norm(y_torch.repeat(n, 1, 1) - torch.swapaxes(y_torch.repeat(n, 1, 1), 0, 1), ord=1, dim=2) / sigma_u)
    u_mx_2 = torch.mean(
        torch.exp(-torch.linalg.vector_norm(gen_y_all_torch.repeat(n, 1, 1).reshape(n, n, -1, dim_y) - y_torch.repeat(1, n).reshape(n, n, 1, dim_y), ord=1, dim=3) / sigma_u), dim=2)
    u_mx_3 = u_mx_2.T

    gen_y_all_torch_rep = gen_y_all_torch.repeat(n, 1, 1).reshape(n, n, -1, dim_y)
    temp_mx = torch.swapaxes(gen_y_all_torch_rep[:, :, 0, :], 0, 1)
    sum_mx = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_y_all_torch_rep - temp_mx.reshape(n, n, 1, dim_y), ord=1, dim=3) / sigma_u), dim=2)

    for i in range(1, M):
        temp_mx = torch.swapaxes(gen_y_all_torch_rep[:, :, i, :], 0, 1)
        temp_add_mx = torch.mean(torch.exp(-torch.linalg.vector_norm(gen_y_all_torch_rep - temp_mx.reshape(n, n, 1, dim_y), ord=1, dim=3) / sigma_u), dim=2)
        sum_mx = sum_mx + temp_add_mx

    u_mx_4 = 1 / M * sum_mx
    u_mx = u_mx_1 - u_mx_2 - u_mx_3 + u_mx_4

    FF_mx = u_mx * w_mx * (1 - torch.eye(n).to(device))

    loss = 1 / (n) * torch.sum(FF_mx)
    return loss


##### GAN architecture #####

class Generator(torch.nn.Module):
    """
    Specify the neural network architecture of the Generator.

    Here, we consider a FNN with a fully connected hidden layer with a width of 50,
    which is followed by a Leaky ReLU activation. The coefficient of Leaky ReLU needs to be
    specified. Batch normalization may be added prior to the activation function.
    The output layer a fully connected layer without activation.

    Inputs:
    - input_dimension: Integer giving the dimension of input Z.
    - output_dimension: Integer giving the dimension of output X or Y.
    - noise_dimension: Integer giving the dimension of random noise.
    - hidden_layer_size: Integer giving the size of the hidden layer of the generator.
    - BN_type: 'True' or 'False' specifying whether batch normalization is included.
    - ReLU_coef: Scalar giving the coefficient of the Leaky ReLU layer.
    - drop_out_p: Float giving the dropout probability.
    - drop_input: Boolean specifying whether to add dropout to the input layer.

    Returns:
    - x: PyTorch Tensor containing the (output_dimension,) output of the generator.
    """

    def __init__(self, input_dimension, output_dimension, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p,
                 drop_input = False):
      super(Generator, self).__init__()
      self.BN_type = BN_type
      self.ReLU_coef = ReLU_coef
      self.fc1 = torch.nn.Linear(input_dimension + noise_dimension, hidden_layer_size, bias=True)
      if BN_type:
        self.BN1 = torch.nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False)
        self.BN2 = torch.nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False)
        self.BN3 = torch.nn.BatchNorm1d(hidden_layer_size, 0.8, affine=False)
      self.leakyReLU1 = torch.nn.LeakyReLU(ReLU_coef)
      self.fc2 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
      self.fc3 = torch.nn.Linear(hidden_layer_size, hidden_layer_size, bias=True)
      self.fc_last = torch.nn.Linear(hidden_layer_size, output_dimension, bias=True)
      self.sigmoid = torch.nn.Sigmoid()
      self.drop_out0 = torch.nn.Dropout(p=drop_out_p)
      self.drop_out1 = torch.nn.Dropout(p=drop_out_p)
      self.drop_out2 = torch.nn.Dropout(p=drop_out_p)
      self.drop_out3 = torch.nn.Dropout(p=drop_out_p)
      self.drop_input = drop_input
      self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
      if self.BN_type:
        if self.drop_input:
            x = self.drop_out0(x)
        x = self.drop_out1(self.leakyReLU1(self.BN1(self.fc1(x))))
        x = self.drop_out2(self.leakyReLU1(self.BN2(self.fc2(x))))
        # x = self.drop_out3(self.leakyReLU1(self.BN3(self.fc3(x))))
        x = self.fc_last(x)
        x = self.sigmoid(x)
      else:
        if self.drop_input:
            x = self.drop_out0(x)
        x = self.drop_out1(self.leakyReLU1(self.fc1(x)))
        x = self.drop_out2(self.leakyReLU1(self.fc2(x)))
        # x = self.drop_out3(self.leakyReLU1(self.fc3(x)))
        x = self.fc_last(x)
        x = self.sigmoid(x)
      return x
##### Auxilliary functions #####

def sample_noise(sample_size, noise_dimension, noise_type, input_var):
    """
    Generate a PyTorch Tensor of random noise from the specified reference distribution.

    Input:
    - sample_size: the sample size of noise to generate.
    - noise_dimension: the dimension of noise to generate.
    - noise_type: "normal", "unif" or "Cauchy", giving the reference distribution.

    Output:
    - A PyTorch Tensor of shape (sample_size, noise_dimension).
    """

    if (noise_type == "normal"):
      noise_generator = TD.MultivariateNormal(
        torch.zeros(noise_dimension).to(device), input_var * torch.eye(noise_dimension).to(device))

      Z = noise_generator.sample((sample_size,))
    if (noise_type == "unif"):
      Z = torch.rand(sample_size, noise_dimension)
    if (noise_type == "Cauchy"):
      Z = TD.Cauchy(torch.tensor([0.0]), torch.tensor([1.0])).sample((sample_size, noise_dimension)).squeeze(2)

    return Z

class DatasetSelect_GAN_ver2(torch.utils.data.Dataset):
  """
    Create a DatasetSelect object to generate the DataLoader in the learning process.

    Input:
    - X: PyTorch Tensor of shape (N, input_dimension) giving the training data of X.
    - Y: PyTorch Tensor of shape (N, output_dimension) giving the training data of Y.
    - batch_size: Integer giving the batch size.
  """

  def __init__(self, Y, Z, batch_size):
    self.Y_real = Y
    self.Z_real = Z
    self.batch_size = batch_size
    self.sample_size = Z.shape[0]

  def __len__(self):
    return self.sample_size

  def __getitem__(self, index):
    return self.Y_real[index], self.Z_real[index]


def train_label(Z, Y, rounded, noise_dimension, noise_type, G_lr, hidden_layer_size,
      DataLoader, BN_type, ReLU_coef,
      epochs_num=10,  sigma_z = 1, sigma_y = 1,
      lambda_1 = 1, wgt_decay = 0,
      lambda_3 = 0, drop_out_p = 0.2, M_train = 3,
      verbose = False):
    """
    Train loop for GAN.

    Inputs:
    - X: PyTorch Tensor (sample_size, dimension_X) of training input.
    - Y: PyTorch Tensor (sample_size, dimension_Y) of training input.
    - Z: PyTorch Tensor (sample_size, dimension_Z) of training input.
    - X_test: PyTorch Tensor (sample_size, dimension_X) of test input.
    - Y_test: PyTorch Tensor (sample_size, dimension_Y) of test input.
    - Z_test: PyTorch Tensor (sample_size, dimension_Z) of test input.
    - noise_dimension: Integer giving the dimension of random noise Z.
    - noise_type: "normal", "unif" or "Cauchy", giving the reference distribution.
    - G_lr: Float giving the learning rate of the generator.
    - hidden_layer_size: Integer giving the size of the hidden layer of the generator.
    - DataLoader: DataLoader object used to generate training batches.
    - BN_type: 'True' or 'False' specifying whether batch normalization is included.
    - ReLU_coef: Float giving the coefficient of the Leaky ReLU layer.
    - epochs_num: Number of epochs over the training dataset to use for training.
    - sigma_z: Float of the bandwith of the kernel.
    - sigma_x: Float of the bandwith of the kernel.
    - sigma_y: Float of the bandwith of the kernel.
    - normal_ini: Boolean specifying whether to initialize the generator with normal initialization.
    - lambda_1: Float giving the coefficient of the MMD loss using Laplace kernel.
    - lambda_2: Float giving the coefficient of the MMD loss using Gaussian kernel. (not using)
    - using_Gen: '1' or '2' specifying whether to use the first or second generator.(not using)
    - wgt_decay: Float giving the weight decay. (L2 regularization)
    - lambda_3: Scalar giving the coefficient of the L1 regularization.
    - drop_out_p: Float giving the dropout probability.
    - M_train: Number of training samples per batch used in the Laplace or Gaussian kernel.

    Outputs:
    - G_zy: PyTorch Net giving the trained generator.
    """

    input_dimension = Z.shape[1]
    output_dimension_y = Y.shape[1]
    M_eval = 50

    train_yz = DatasetSelect_GAN_ver2(Y, Z, 1)
    DataLoader_eval = torch.utils.data.DataLoader(train_yz, batch_size=1, shuffle=True)

    G_zy = Generator(input_dimension, output_dimension_y, noise_dimension, hidden_layer_size, BN_type, ReLU_coef, drop_out_p).to(device)
    G_zy_solver = optim.Adam(G_zy.parameters(), lr=G_lr, betas=(0.5, 0.999), weight_decay=wgt_decay) # betas=(0.5, 0.999),
    # G_zy_solver = optim.SGD(G_zy.parameters(), lr=G_lr, weight_decay=wgt_decay)

    G_zy = G_zy.eval()


    test_size = Z.shape[0]
    gen_y_all = torch.zeros(test_size, M_eval)
    z_all = torch.zeros(test_size, input_dimension)
    y_all = torch.zeros(test_size, output_dimension_y)

    cur_itr = 0

    for i, (y_test, z_test) in enumerate(DataLoader_eval):

      z_test_temp = z_test.repeat(M_eval,1).to(device)
      Noise_fake = sample_noise(z_test_temp.size()[0], noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
      fake_y = G_zy(torch.cat((z_test_temp, Noise_fake),dim=1)).reshape(1, -1)
      gen_y_all[cur_itr,:] = fake_y.detach().reshape(-1)
      y_all[cur_itr,:] = y_test
      z_all[cur_itr,:] = z_test
      cur_itr = cur_itr + 1
    gen_y_mean = torch.mean(gen_y_all, dim = 1).reshape(-1,1)
    if rounded:
      gen_y_mean = (gen_y_mean> 0.9).float()

    mse_y = torch.mean((gen_y_mean - y_all)**2).item()

    # print(f'Epoch [0], MSE y [{mse_y}]')

    Y_real = Y.to(device)
    Z_real = Z.to(device)

    batch_size = Z_real.shape[0]
    Z_real_repeat = Z_real.repeat(M_eval,1)
    # Generate fake data
    Noise_fake = sample_noise(Z_real_repeat.shape[0], noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
    Y_fake = G_zy(torch.cat((Z_real_repeat,Noise_fake),dim=1)).to(device)

    Y_fake = Y_fake.reshape(batch_size, M_eval, output_dimension_y)
    if rounded:
      Y_fake = (Y_fake> 0.9).float()
    MMD_loss = find_loss(Y_real, Y_fake, Z_real, sigma_z, sigma_y, M_eval, output_dimension_y)

    if verbose:
        print(f'Epoch [0],MSE [{mse_y}] MMD_loss [{MMD_loss.item()}]')

    iter_count = 0
    G_zy = G_zy.train()
    for epoch in tqdm(range(epochs_num)):
        # print('EPOCH: ', (epoch+1))
        batch_count = 0
        G_zy = G_zy.train()
        for Y_real, Z_real in DataLoader:
            Y_real = Y_real.to(device)
            Z_real = Z_real.to(device)

            batch_size = Z_real.shape[0]
            Z_real_repeat = Z_real.repeat(M_train,1)
            # Generate fake data
            Noise_fake = sample_noise(Z_real_repeat.shape[0], noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
            Y_fake = G_zy(torch.cat((Z_real_repeat,Noise_fake),dim=1)).to(device)

            Y_fake = Y_fake.reshape(batch_size, M_train, output_dimension_y)

            # standardise = True

            # if standardise:
            #     Y_fake = (Y_fake - torch.mean(Y_fake, dim=0, keepdim=True)) / torch.std(Y_fake, dim=0, keepdim=True)
            #     Y_real = (Y_real - torch.mean(Y_real, dim=0, keepdim=True)) / torch.std(Y_real, dim=0, keepdim=True)
            #     Z_real = (Z_real - torch.mean(Z_real, dim=0, keepdim=True)) / torch.std(Z_real, dim=0, keepdim=True)

            # Generator step
            g_zy_error = None
            G_zy_solver.zero_grad()

            l1_regularization = 0

            for param in G_zy.parameters():
                l1_regularization += torch.linalg.vector_norm(param, ord = 1)

            g_zy_error = lambda_1 * find_loss(Y_real, Y_fake, Z_real, sigma_z, sigma_y, M_train, output_dimension_y) + lambda_3 * l1_regularization

            g_zy_error.backward()
            torch.nn.utils.clip_grad_norm_(G_zy.parameters(), max_norm=0.5)
            G_zy_solver.step()

            iter_count += 1
            batch_count += 1

        G_zy = G_zy.eval()
        if((epoch+1) % 1000 == 0) and verbose:

          test_size = Z.shape[0]
          gen_y_all = torch.zeros(test_size, M_eval)
          z_all = torch.zeros(test_size, input_dimension)
          y_all = torch.zeros(test_size, output_dimension_y)

          cur_itr = 0

          for i, (y_test, z_test) in enumerate(DataLoader_eval):

            z_test_temp = z_test.repeat(M_eval,1).to(device)
            Noise_fake = sample_noise(z_test_temp.size()[0], noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
            fake_y = G_zy(torch.cat((z_test_temp, Noise_fake),dim=1)).reshape(1, -1)
            gen_y_all[cur_itr,:] = fake_y.detach().reshape(-1)
            y_all[cur_itr,:] = y_test
            z_all[cur_itr,:] = z_test
            cur_itr = cur_itr + 1
          gen_y_mean = torch.mean(gen_y_all, dim = 1).reshape(-1,1)
          if rounded:
            gen_y_mean = (gen_y_mean> 0.9).float()
          mse_y = torch.mean((gen_y_mean - y_all)**2).item()

          # print(f'Epoch [{epoch+1}/{epochs_num}], MSE y [{mse_y}]')

          Y_real = Y.to(device)
          Z_real = Z.to(device)

          batch_size = Z_real.shape[0]
          Z_real_repeat = Z_real.repeat(M_eval,1)
          # Generate fake data
          Noise_fake = sample_noise(Z_real_repeat.shape[0], noise_dimension, noise_type, input_var = 1.0/3.0).to(device)
          Y_fake = G_zy(torch.cat((Z_real_repeat,Noise_fake),dim=1)).to(device)

          Y_fake = Y_fake.reshape(batch_size, M_eval, output_dimension_y)
          if rounded:
            Y_fake = (Y_fake> 0.9).float()
          MMD_loss = find_loss(Y_real, Y_fake, Z_real, sigma_z, sigma_y, M_eval, output_dimension_y)

          if verbose:
              print(f'Epoch [{epoch+1}/{epochs_num}], MSE [{mse_y}] MMD_loss [{MMD_loss.item()}]')



    return G_zy



In [None]:
def mgan_ccle(ccle_index_input, ccle_threshold_input, batch_size,
        noise_dimension, G_zy_lr, G_zx_lr, k, M, boot_num,
        boor_rv_type, set_seed = 42, verbose = False, epcoh_y = 4000,
        epcoh_x = 4000):
    x_drug, y, features = load_ccle(feature_type='mutation', normalize=False)

    ccle_selected, corrs = ccle_feature_filter(x_drug, y, threshold=ccle_threshold_input)

    features.index[ccle_selected]

    var_names = ['BRAF.MC_MUT', 'BRAF.V600E_MUT', 'HIP1_MUT', 'CDC42BPA_MUT', 'THBS3_MUT', 'DNMT1_MUT', 'PRKD1_MUT',
                  'FLT3_MUT', 'PIP5K1A_MUT', 'MAP3K5_MUT']
    idx = []

    for var in var_names:
        idx.append(features.T.columns.get_loc(var))

    x = x_drug[:, idx[ccle_index_input]]
    ccle_selected[idx] = True
    ccle_selected[idx[ccle_index_input]] = False
    z = x_drug[:, ccle_selected]

    z_dim = z.shape[1]
    x = np.expand_dims(x, axis=1).astype(np.float64)
    y = np.expand_dims(y, axis=1).astype(np.float64)
    n = y.shape[0]

    z = (z - z.min()) / (z.max() - z.min())
    x = (x - x.min()) / (x.max() - x.min())
    y = (y - y.min()) / (y.max() - y.min())

    x, y, z = torch.from_numpy(x).float(), torch.from_numpy(y).float(), torch.from_numpy(z).float()

    w_mx = torch.linalg.vector_norm(z.repeat(n,1,1) - torch.swapaxes(z.repeat(n,1,1), 0, 1), ord = 1, dim = 2)
    sigma_w_train = torch.median(w_mx).item()

    u_mx = torch.linalg.vector_norm(y.repeat(n,1,1) - torch.swapaxes(y.repeat(n,1,1), 0, 1), ord = 1, dim = 2)
    sigma_u_train = torch.median(u_mx).item()

    v_mx = torch.linalg.vector_norm(x.repeat(n,1,1) - torch.swapaxes(x.repeat(n,1,1), 0, 1), ord = 1, dim = 2)
    sigma_v_train = 1.0 # torch.median(v_mx).item()

    test_size = math.floor(n/k)
    stat_all = torch.zeros(k, 1)
    boot_temp_all = torch.zeros(k, boot_num)
    cur_k = 0


    for k_fold in range(k):
        torch.manual_seed(set_seed)
        k_fold_start = int(n/k * k_fold)
        k_fold_end = int(n/k * (k_fold+1))

        X_test, Y_test, Z_test = x[k_fold_start:k_fold_end], y[k_fold_start:k_fold_end], z[k_fold_start:k_fold_end]
        X_train, Y_train, Z_train = torch.cat((x[0:k_fold_start], x[k_fold_end:])), torch.cat((y[0:k_fold_start], y[k_fold_end:])), torch.cat((z[0:k_fold_start], z[k_fold_end:]))


        train_yz = DatasetSelect_GAN_ver2(Y_train, Z_train, batch_size)
        DataLoader_yz = torch.utils.data.DataLoader(train_yz, batch_size=batch_size, shuffle=True)

        train_xz = DatasetSelect_GAN_ver2(X_train, Z_train, batch_size)
        DataLoader_xz = torch.utils.data.DataLoader(train_xz, batch_size=batch_size, shuffle=True)

        if verbose:
            print("Training Y")

        G_zy = train_label(Z = Z_train, Y = Y_train, rounded = False,
                  noise_dimension = noise_dimension, noise_type = "normal", G_lr = G_zy_lr, hidden_layer_size = 512,
                  DataLoader = DataLoader_yz, BN_type = True, ReLU_coef = 0.5,
                  epochs_num=epcoh_y,  sigma_z = sigma_w_train, sigma_y = sigma_u_train,
                  lambda_1 = 1, wgt_decay = 1e-5,
                  lambda_3 = 1e-5, drop_out_p = 0.2, M_train = 20, verbose = verbose)

        if verbose:
            print("Training X")

        G_zx = train_label(Z = Z_train, Y = X_train, rounded = False,
                  noise_dimension = noise_dimension, noise_type = "normal", G_lr = G_zx_lr, hidden_layer_size = 512,
                  DataLoader = DataLoader_xz, BN_type = True, ReLU_coef = 0.5,
                  epochs_num=epcoh_x,  sigma_z = sigma_w_train, sigma_y = sigma_v_train,
                  lambda_1 = 1, wgt_decay = 1e-5,
                  lambda_3 = 1e-5, drop_out_p = 0.2, M_train = 20, verbose = verbose)

        G_zx = G_zx.eval()
        G_zy = G_zy.eval()

        Z_test, X_test, Y_test = Z_test[0:(test_size),:], X_test[0:(test_size),:], Y_test[0:(test_size),:]
        Z_test_repeat = Z_test.repeat(M,1).to(device)


        # Generate fake data
        Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension, "normal", input_var = 1.0/3.0).to(device)
        gen_y_all = G_zy(torch.cat((Z_test_repeat,Noise_fake),dim=1)).to(device)

        Noise_fake = sample_noise(Z_test_repeat.shape[0], noise_dimension, "normal", input_var = 1.0/3.0).to(device)
        gen_x_all = G_zx(torch.cat((Z_test_repeat,Noise_fake),dim=1)).to(device)

        gen_x_all = gen_x_all.reshape(test_size, M).detach().to(device)
        gen_y_all = gen_y_all.reshape(test_size, M).detach().to(device)

        gen_x_all = (gen_x_all> 0.9).float()

        z_all = Z_test.to(device)
        x_all = X_test.to(device)
        y_all = Y_test.to(device)


        w_mx = torch.linalg.vector_norm(z_all.repeat(test_size,1,1) - torch.swapaxes(z_all.repeat(test_size,1,1), 0, 1), ord = 1, dim = 2)
        sigma_w = torch.median(w_mx).item()

        u_mx = torch.abs(y_all.repeat(1, test_size) - y_all.repeat(1, test_size).T)
        sigma_u = torch.median(u_mx).item()

        v_mx = torch.abs(x_all.repeat(1, test_size) - x_all.repeat(1, test_size).T)
        sigma_v = 1.0 # torch.median(v_mx).item()

        cur_stat, cur_boot_temp = get_p_value_stat_1(boot_num, M, test_size, gen_x_all.to(device), gen_y_all.to(device),
                                x_all.to(device), y_all.to(device), z_all.to(device), sigma_w, sigma_u, sigma_v,
                                boor_rv_type)
        stat_all[cur_k,:] = cur_stat
        boot_temp_all[cur_k,:] = torch.from_numpy(cur_boot_temp)
        cur_k = cur_k + 1

    p_val = np.mean(torch.mean(boot_temp_all, dim = 0).numpy() > torch.mean(stat_all).item() )
    return p_val


In [None]:
# ccle_index_input = 9
ccle_threshold_input = 0.05
batch_size = 256
noise_dimension = 1
G_zy_lr = 2e-5
G_zx_lr = 2e-3
k = 2
M = 100
boot_num = 1000
boor_rv_type = "gaussian"

for ccle_index_input in range(10):

    p_val = mgan_ccle(ccle_index_input = ccle_index_input, ccle_threshold_input = ccle_threshold_input, batch_size = batch_size,
                noise_dimension = noise_dimension, G_zy_lr = G_zy_lr, G_zx_lr = G_zx_lr, k = k, M = M, boot_num = boot_num,
                boor_rv_type = boor_rv_type, verbose = True, epcoh_y = 5000, epcoh_x = 5000, set_seed = 42)

    print(f'ccle_index [{ccle_index_input}], p_value [{p_val}]')

Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.0348055399954319] MMD_loss [54.37702560424805]


 20%|██        | 1005/5000 [00:29<03:16, 20.38it/s]

Epoch [1000/5000], MSE [0.057580478489398956] MMD_loss [-0.11196421831846237]


 40%|████      | 2006/5000 [01:00<02:39, 18.74it/s]

Epoch [2000/5000], MSE [0.05623731017112732] MMD_loss [-0.11982018500566483]


 60%|██████    | 3006/5000 [01:28<02:03, 16.13it/s]

Epoch [3000/5000], MSE [0.05635655298829079] MMD_loss [-0.11369892954826355]


 80%|████████  | 4006/5000 [01:57<00:49, 20.25it/s]

Epoch [4000/5000], MSE [0.05596894025802612] MMD_loss [-0.10913387686014175]


100%|██████████| 5000/5000 [02:26<00:00, 34.04it/s]

Epoch [5000/5000], MSE [0.05595039948821068] MMD_loss [-0.11129146814346313]
Training X





Epoch [0],MSE [0.25608760118484497] MMD_loss [61.05925750732422]


 20%|██        | 1004/5000 [00:28<03:13, 20.64it/s]

Epoch [1000/5000], MSE [0.11056926846504211] MMD_loss [0.02696552313864231]


 40%|████      | 2004/5000 [00:57<02:51, 17.48it/s]

Epoch [2000/5000], MSE [0.11390918493270874] MMD_loss [0.025274302810430527]


 60%|██████    | 3004/5000 [01:26<01:49, 18.29it/s]

Epoch [3000/5000], MSE [0.11392353475093842] MMD_loss [0.025170927867293358]


 80%|████████  | 4004/5000 [01:55<00:49, 19.96it/s]

Epoch [4000/5000], MSE [0.11813188344240189] MMD_loss [0.02819659374654293]


100%|██████████| 5000/5000 [02:24<00:00, 34.65it/s]

Epoch [5000/5000], MSE [0.11392389237880707] MMD_loss [0.025673365220427513]





Training Y
Epoch [0],MSE [0.03841855749487877] MMD_loss [55.8660888671875]


 20%|██        | 1004/5000 [00:28<03:27, 19.26it/s]

Epoch [1000/5000], MSE [0.07050642371177673] MMD_loss [-0.10503526031970978]


 40%|████      | 2004/5000 [00:57<02:52, 17.39it/s]

Epoch [2000/5000], MSE [0.06999997049570084] MMD_loss [-0.10716208070516586]


 60%|██████    | 3004/5000 [01:26<01:44, 19.18it/s]

Epoch [3000/5000], MSE [0.0698835477232933] MMD_loss [-0.11110168695449829]


 80%|████████  | 4004/5000 [01:55<00:57, 17.42it/s]

Epoch [4000/5000], MSE [0.06958652287721634] MMD_loss [-0.10536497086286545]


100%|██████████| 5000/5000 [02:24<00:00, 34.60it/s]

Epoch [5000/5000], MSE [0.06927219033241272] MMD_loss [-0.11159779131412506]
Training X





Epoch [0],MSE [0.2563707232475281] MMD_loss [61.088294982910156]


 20%|██        | 1004/5000 [00:28<03:13, 20.65it/s]

Epoch [1000/5000], MSE [0.12658236920833588] MMD_loss [0.02967488020658493]


 40%|████      | 2004/5000 [00:57<02:36, 19.18it/s]

Epoch [2000/5000], MSE [0.12658272683620453] MMD_loss [0.02944585122168064]


 60%|██████    | 3004/5000 [01:26<01:52, 17.68it/s]

Epoch [3000/5000], MSE [0.13079404830932617] MMD_loss [0.03094000555574894]


 80%|████████  | 4004/5000 [01:55<00:48, 20.51it/s]

Epoch [4000/5000], MSE [0.13080157339572906] MMD_loss [0.030878227204084396]


100%|██████████| 5000/5000 [02:23<00:00, 34.87it/s]

Epoch [5000/5000], MSE [0.13080157339572906] MMD_loss [0.030841954052448273]





ccle_index [0], p_value [0.0]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03480512648820877] MMD_loss [54.371891021728516]


 20%|██        | 1004/5000 [00:29<03:27, 19.29it/s]

Epoch [1000/5000], MSE [0.056835103780031204] MMD_loss [-0.10850682854652405]


 40%|████      | 2004/5000 [00:57<02:28, 20.20it/s]

Epoch [2000/5000], MSE [0.05573630705475807] MMD_loss [-0.11916789412498474]


 60%|██████    | 3004/5000 [01:26<01:36, 20.71it/s]

Epoch [3000/5000], MSE [0.055870626121759415] MMD_loss [-0.1129566878080368]


 80%|████████  | 4004/5000 [01:54<00:55, 17.94it/s]

Epoch [4000/5000], MSE [0.055722493678331375] MMD_loss [-0.10728370398283005]


100%|██████████| 5000/5000 [02:23<00:00, 34.88it/s]

Epoch [5000/5000], MSE [0.05611429736018181] MMD_loss [-0.11060234904289246]
Training X





Epoch [0],MSE [0.2562321126461029] MMD_loss [61.270233154296875]


 20%|██        | 1004/5000 [00:29<03:13, 20.63it/s]

Epoch [1000/5000], MSE [0.10548434406518936] MMD_loss [0.02823631465435028]


 40%|████      | 2004/5000 [00:57<02:26, 20.41it/s]

Epoch [2000/5000], MSE [0.10548478364944458] MMD_loss [0.027812613174319267]


 60%|██████    | 3004/5000 [01:26<01:55, 17.35it/s]

Epoch [3000/5000], MSE [0.11392363905906677] MMD_loss [0.028277738019824028]


 80%|████████  | 4004/5000 [01:54<00:51, 19.35it/s]

Epoch [4000/5000], MSE [0.11392363905906677] MMD_loss [0.028290795162320137]


100%|██████████| 5000/5000 [02:23<00:00, 34.77it/s]

Epoch [5000/5000], MSE [0.1223626434803009] MMD_loss [0.035406336188316345]





Training Y
Epoch [0],MSE [0.03841973841190338] MMD_loss [55.91128921508789]


 20%|██        | 1004/5000 [00:28<03:11, 20.83it/s]

Epoch [1000/5000], MSE [0.07065459340810776] MMD_loss [-0.11277399957180023]


 40%|████      | 2004/5000 [00:57<02:58, 16.81it/s]

Epoch [2000/5000], MSE [0.06985653936862946] MMD_loss [-0.11672922968864441]


 60%|██████    | 3004/5000 [01:25<01:42, 19.38it/s]

Epoch [3000/5000], MSE [0.06923563778400421] MMD_loss [-0.12071378529071808]


 80%|████████  | 4004/5000 [01:54<00:49, 20.11it/s]

Epoch [4000/5000], MSE [0.06837701797485352] MMD_loss [-0.11205480247735977]


100%|██████████| 5000/5000 [02:23<00:00, 34.83it/s]

Epoch [5000/5000], MSE [0.06855543702840805] MMD_loss [-0.11809282749891281]
Training X





Epoch [0],MSE [0.2561459243297577] MMD_loss [59.77839660644531]


 20%|██        | 1004/5000 [00:28<03:38, 18.28it/s]

Epoch [1000/5000], MSE [0.16033349931240082] MMD_loss [0.07928232848644257]


 40%|████      | 2004/5000 [00:56<02:23, 20.86it/s]

Epoch [2000/5000], MSE [0.16033490002155304] MMD_loss [0.07866454869508743]


 60%|██████    | 3004/5000 [01:25<01:38, 20.19it/s]

Epoch [3000/5000], MSE [0.16033709049224854] MMD_loss [0.07848178595304489]


 80%|████████  | 4004/5000 [01:54<00:48, 20.57it/s]

Epoch [4000/5000], MSE [0.16033723950386047] MMD_loss [0.07844037562608719]


100%|██████████| 5000/5000 [02:22<00:00, 34.97it/s]

Epoch [5000/5000], MSE [0.1603371948003769] MMD_loss [0.0784270167350769]





ccle_index [1], p_value [0.0]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03483430668711662] MMD_loss [56.54957580566406]


 20%|██        | 1004/5000 [00:28<03:13, 20.65it/s]

Epoch [1000/5000], MSE [0.05527324229478836] MMD_loss [-0.09965790808200836]


 40%|████      | 2004/5000 [00:57<02:59, 16.70it/s]

Epoch [2000/5000], MSE [0.05393935739994049] MMD_loss [-0.11445501446723938]


 60%|██████    | 3004/5000 [01:26<01:38, 20.34it/s]

Epoch [3000/5000], MSE [0.05380670726299286] MMD_loss [-0.10910539329051971]


 80%|████████  | 4004/5000 [01:54<00:47, 20.82it/s]

Epoch [4000/5000], MSE [0.0538635179400444] MMD_loss [-0.101631760597229]


100%|██████████| 5000/5000 [02:22<00:00, 35.04it/s]

Epoch [5000/5000], MSE [0.05402238294482231] MMD_loss [-0.10843472182750702]
Training X





Epoch [0],MSE [0.2573137581348419] MMD_loss [67.51190185546875]


 20%|██        | 1004/5000 [00:29<04:06, 16.19it/s]

Epoch [1000/5000], MSE [0.04210715740919113] MMD_loss [-0.005066727753728628]


 40%|████      | 2004/5000 [00:57<02:24, 20.68it/s]

Epoch [2000/5000], MSE [0.04214689880609512] MMD_loss [-0.006222615949809551]


 60%|██████    | 3004/5000 [01:26<01:36, 20.64it/s]

Epoch [3000/5000], MSE [0.042066995054483414] MMD_loss [-0.006397203542292118]


 80%|████████  | 4004/5000 [01:54<00:52, 19.11it/s]

Epoch [4000/5000], MSE [0.042174432426691055] MMD_loss [-0.006443167105317116]


100%|██████████| 5000/5000 [02:23<00:00, 34.74it/s]

Epoch [5000/5000], MSE [0.04218631610274315] MMD_loss [-0.006474025547504425]





Training Y
Epoch [0],MSE [0.038239266723394394] MMD_loss [58.3486213684082]


 20%|██        | 1004/5000 [00:28<03:14, 20.57it/s]

Epoch [1000/5000], MSE [0.06302927434444427] MMD_loss [-0.09375850856304169]


 40%|████      | 2004/5000 [00:56<02:24, 20.74it/s]

Epoch [2000/5000], MSE [0.06288537383079529] MMD_loss [-0.09739542752504349]


 60%|██████    | 3004/5000 [01:25<01:55, 17.30it/s]

Epoch [3000/5000], MSE [0.06266417354345322] MMD_loss [-0.09702224284410477]


 80%|████████  | 4007/5000 [01:54<00:48, 20.39it/s]

Epoch [4000/5000], MSE [0.061928343027830124] MMD_loss [-0.09281174093484879]


100%|██████████| 5000/5000 [02:22<00:00, 35.09it/s]

Epoch [5000/5000], MSE [0.06150650605559349] MMD_loss [-0.10507982224225998]
Training X





Epoch [0],MSE [0.2569647431373596] MMD_loss [66.93236541748047]


 20%|██        | 1004/5000 [00:28<03:52, 17.19it/s]

Epoch [1000/5000], MSE [0.06651630997657776] MMD_loss [-0.015687771141529083]


 40%|████      | 2004/5000 [00:57<03:00, 16.57it/s]

Epoch [2000/5000], MSE [0.06749975681304932] MMD_loss [-0.01708054170012474]


 60%|██████    | 3004/5000 [01:26<01:39, 20.09it/s]

Epoch [3000/5000], MSE [0.06748786568641663] MMD_loss [-0.017243480309844017]


 80%|████████  | 4004/5000 [01:54<00:48, 20.68it/s]

Epoch [4000/5000], MSE [0.06747454404830933] MMD_loss [-0.017267530784010887]


100%|██████████| 5000/5000 [02:22<00:00, 35.01it/s]

Epoch [5000/5000], MSE [0.0674208253622055] MMD_loss [-0.017275145277380943]





ccle_index [2], p_value [0.032]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03477438539266586] MMD_loss [54.24443817138672]


 20%|██        | 1004/5000 [00:29<03:16, 20.31it/s]

Epoch [1000/5000], MSE [0.05820157378911972] MMD_loss [-0.09684641659259796]


 40%|████      | 2004/5000 [00:57<02:45, 18.07it/s]

Epoch [2000/5000], MSE [0.05652119591832161] MMD_loss [-0.10703373700380325]


 60%|██████    | 3004/5000 [01:26<01:36, 20.72it/s]

Epoch [3000/5000], MSE [0.05656930431723595] MMD_loss [-0.10338924080133438]


 80%|████████  | 4004/5000 [01:54<00:48, 20.40it/s]

Epoch [4000/5000], MSE [0.056432854384183884] MMD_loss [-0.10418227314949036]


100%|██████████| 5000/5000 [02:23<00:00, 34.96it/s]

Epoch [5000/5000], MSE [0.05683358758687973] MMD_loss [-0.09937529265880585]
Training X





Epoch [0],MSE [0.2563624083995819] MMD_loss [63.06306457519531]


 20%|██        | 1004/5000 [00:29<03:44, 17.80it/s]

Epoch [1000/5000], MSE [0.05907093361020088] MMD_loss [-0.00835834164172411]


 40%|████      | 2004/5000 [00:57<02:24, 20.76it/s]

Epoch [2000/5000], MSE [0.05907130986452103] MMD_loss [-0.008896993473172188]


 60%|██████    | 3004/5000 [01:25<01:36, 20.63it/s]

Epoch [3000/5000], MSE [0.05907144024968147] MMD_loss [-0.009052261710166931]


 80%|████████  | 4004/5000 [01:54<00:52, 19.00it/s]

Epoch [4000/5000], MSE [0.059071481227874756] MMD_loss [-0.009108096361160278]


100%|██████████| 5000/5000 [02:23<00:00, 34.94it/s]

Epoch [5000/5000], MSE [0.059071436524391174] MMD_loss [-0.009103567339479923]





Training Y
Epoch [0],MSE [0.03828943520784378] MMD_loss [55.825321197509766]


 20%|██        | 1004/5000 [00:28<03:17, 20.19it/s]

Epoch [1000/5000], MSE [0.06679098308086395] MMD_loss [-0.08807038515806198]


 40%|████      | 2004/5000 [00:57<02:26, 20.45it/s]

Epoch [2000/5000], MSE [0.0657791942358017] MMD_loss [-0.09127803146839142]


 60%|██████    | 3004/5000 [01:25<01:37, 20.56it/s]

Epoch [3000/5000], MSE [0.06511480361223221] MMD_loss [-0.0941038578748703]


 80%|████████  | 4004/5000 [01:54<00:54, 18.32it/s]

Epoch [4000/5000], MSE [0.06426629424095154] MMD_loss [-0.08861224353313446]


100%|██████████| 5000/5000 [02:23<00:00, 34.93it/s]

Epoch [5000/5000], MSE [0.0640016496181488] MMD_loss [-0.09719129651784897]
Training X





Epoch [0],MSE [0.2567974627017975] MMD_loss [63.74665069580078]


 20%|██        | 1004/5000 [00:28<03:17, 20.18it/s]

Epoch [1000/5000], MSE [0.31645023822784424] MMD_loss [5.3491435050964355]


 40%|████      | 2004/5000 [00:57<02:26, 20.49it/s]

Epoch [2000/5000], MSE [0.07172981649637222] MMD_loss [0.0729629248380661]


 60%|██████    | 3004/5000 [01:25<01:38, 20.28it/s]

Epoch [3000/5000], MSE [0.0717298835515976] MMD_loss [0.07279478013515472]


 80%|████████  | 4004/5000 [01:54<00:49, 19.96it/s]

Epoch [4000/5000], MSE [0.0717298835515976] MMD_loss [0.07275433838367462]


100%|██████████| 5000/5000 [02:22<00:00, 34.99it/s]

Epoch [5000/5000], MSE [0.07172992825508118] MMD_loss [0.07270742952823639]





ccle_index [3], p_value [0.001]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03475587069988251] MMD_loss [56.594566345214844]


 20%|██        | 1004/5000 [00:28<03:13, 20.70it/s]

Epoch [1000/5000], MSE [0.056769490242004395] MMD_loss [-0.10341040045022964]


 40%|████      | 2006/5000 [00:57<02:25, 20.52it/s]

Epoch [2000/5000], MSE [0.0555301234126091] MMD_loss [-0.11130604892969131]


 60%|██████    | 3006/5000 [01:26<01:39, 20.13it/s]

Epoch [3000/5000], MSE [0.05546475574374199] MMD_loss [-0.11270763725042343]


 80%|████████  | 4006/5000 [01:55<00:54, 18.12it/s]

Epoch [4000/5000], MSE [0.05517425760626793] MMD_loss [-0.10854162275791168]


100%|██████████| 5000/5000 [02:23<00:00, 34.92it/s]

Epoch [5000/5000], MSE [0.05565483123064041] MMD_loss [-0.1111067607998848]
Training X





Epoch [0],MSE [0.2569239139556885] MMD_loss [68.60952758789062]


 20%|██        | 1004/5000 [00:28<03:55, 16.97it/s]

Epoch [1000/5000], MSE [0.016880350187420845] MMD_loss [2.435390342725441e-05]


 40%|████      | 2004/5000 [00:57<02:48, 17.82it/s]

Epoch [2000/5000], MSE [0.01687755435705185] MMD_loss [-0.0007734830141998827]


 60%|██████    | 3004/5000 [01:26<01:38, 20.27it/s]

Epoch [3000/5000], MSE [0.016877608373761177] MMD_loss [-0.0009465775219723582]


 80%|████████  | 4004/5000 [01:54<00:47, 20.87it/s]

Epoch [4000/5000], MSE [0.016877619549632072] MMD_loss [-0.0009786597220227122]


100%|██████████| 5000/5000 [02:22<00:00, 35.00it/s]

Epoch [5000/5000], MSE [0.016877610236406326] MMD_loss [-0.0009940849849954247]





Training Y
Epoch [0],MSE [0.038256824016571045] MMD_loss [58.3293571472168]


 20%|██        | 1004/5000 [00:29<03:56, 16.92it/s]

Epoch [1000/5000], MSE [0.06305590271949768] MMD_loss [-0.09144295752048492]


 40%|████      | 2004/5000 [00:58<02:25, 20.54it/s]

Epoch [2000/5000], MSE [0.06271130591630936] MMD_loss [-0.09134626388549805]


 60%|██████    | 3004/5000 [01:26<01:44, 19.02it/s]

Epoch [3000/5000], MSE [0.061960287392139435] MMD_loss [-0.10005171597003937]


 80%|████████  | 4004/5000 [01:54<00:49, 20.02it/s]

Epoch [4000/5000], MSE [0.061115484684705734] MMD_loss [-0.09275500476360321]


100%|██████████| 5000/5000 [02:24<00:00, 34.65it/s]

Epoch [5000/5000], MSE [0.060925111174583435] MMD_loss [-0.09828713536262512]
Training X





Epoch [0],MSE [0.2569269835948944] MMD_loss [68.02643585205078]


 20%|██        | 1004/5000 [00:28<03:20, 19.91it/s]

Epoch [1000/5000], MSE [0.08016836643218994] MMD_loss [0.46068382263183594]


 40%|████      | 2004/5000 [00:57<02:25, 20.59it/s]

Epoch [2000/5000], MSE [0.05063243210315704] MMD_loss [-0.006014969199895859]


 60%|██████    | 3004/5000 [01:25<01:36, 20.67it/s]

Epoch [3000/5000], MSE [0.050632696598768234] MMD_loss [-0.006313106510788202]


 80%|████████  | 4004/5000 [01:54<00:58, 16.94it/s]

Epoch [4000/5000], MSE [0.050632741302251816] MMD_loss [-0.006404436659067869]


100%|██████████| 5000/5000 [02:23<00:00, 34.94it/s]

Epoch [5000/5000], MSE [0.050632789731025696] MMD_loss [-0.006487613543868065]





ccle_index [4], p_value [0.048]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03468881919980049] MMD_loss [56.6180419921875]


 20%|██        | 1004/5000 [00:28<03:42, 17.99it/s]

Epoch [1000/5000], MSE [0.05797193571925163] MMD_loss [-0.1008518859744072]


 40%|████      | 2004/5000 [00:57<02:26, 20.48it/s]

Epoch [2000/5000], MSE [0.05593005567789078] MMD_loss [-0.11248696595430374]


 60%|██████    | 3004/5000 [01:26<01:36, 20.62it/s]

Epoch [3000/5000], MSE [0.055587880313396454] MMD_loss [-0.1124226525425911]


 80%|████████  | 4004/5000 [01:54<00:48, 20.53it/s]

Epoch [4000/5000], MSE [0.05543160066008568] MMD_loss [-0.1055239662528038]


100%|██████████| 5000/5000 [02:22<00:00, 34.98it/s]

Epoch [5000/5000], MSE [0.055540792644023895] MMD_loss [-0.10855050384998322]
Training X





Epoch [0],MSE [0.2573345899581909] MMD_loss [66.94583129882812]


 20%|██        | 1004/5000 [00:29<03:16, 20.31it/s]

Epoch [1000/5000], MSE [0.16877619922161102] MMD_loss [1.9216294288635254]


 40%|████      | 2004/5000 [00:57<02:24, 20.69it/s]

Epoch [2000/5000], MSE [0.042193811386823654] MMD_loss [-0.007159647066146135]


 60%|██████    | 3004/5000 [01:25<01:36, 20.58it/s]

Epoch [3000/5000], MSE [0.04219396412372589] MMD_loss [-0.007334080990403891]


 80%|████████  | 4004/5000 [01:54<01:01, 16.26it/s]

Epoch [4000/5000], MSE [0.04219401627779007] MMD_loss [-0.0074211712926626205]


100%|██████████| 5000/5000 [02:23<00:00, 34.82it/s]

Epoch [5000/5000], MSE [0.04219389706850052] MMD_loss [-0.00744617311283946]





Training Y
Epoch [0],MSE [0.03817135840654373] MMD_loss [58.36129379272461]


 20%|██        | 1004/5000 [00:28<03:12, 20.78it/s]

Epoch [1000/5000], MSE [0.06297954171895981] MMD_loss [-0.09328000992536545]


 40%|████      | 2004/5000 [00:56<02:35, 19.22it/s]

Epoch [2000/5000], MSE [0.06192744895815849] MMD_loss [-0.09053927659988403]


 60%|██████    | 3004/5000 [01:26<01:42, 19.44it/s]

Epoch [3000/5000], MSE [0.06172812357544899] MMD_loss [-0.10241606086492538]


 80%|████████  | 4004/5000 [01:54<00:48, 20.40it/s]

Epoch [4000/5000], MSE [0.06072806194424629] MMD_loss [-0.09330160915851593]


100%|██████████| 5000/5000 [02:22<00:00, 35.01it/s]

Epoch [5000/5000], MSE [0.06033165007829666] MMD_loss [-0.09893419593572617]
Training X





Epoch [0],MSE [0.2573607861995697] MMD_loss [68.16181945800781]


 20%|██        | 1004/5000 [00:28<03:12, 20.76it/s]

Epoch [1000/5000], MSE [0.2658226788043976] MMD_loss [5.954656600952148]


 40%|████      | 2004/5000 [00:57<02:57, 16.87it/s]

Epoch [2000/5000], MSE [0.04641024395823479] MMD_loss [-0.009244777262210846]


 60%|██████    | 3007/5000 [01:26<01:40, 19.86it/s]

Epoch [3000/5000], MSE [0.04640018567442894] MMD_loss [-0.009409982711076736]


 80%|████████  | 4007/5000 [01:55<00:48, 20.53it/s]

Epoch [4000/5000], MSE [0.04641182720661163] MMD_loss [-0.00948723778128624]


100%|██████████| 5000/5000 [02:23<00:00, 34.84it/s]

Epoch [5000/5000], MSE [0.04640815407037735] MMD_loss [-0.009523613378405571]





ccle_index [5], p_value [0.001]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03480531647801399] MMD_loss [56.656551361083984]


 20%|██        | 1004/5000 [00:28<03:14, 20.52it/s]

Epoch [1000/5000], MSE [0.05536044389009476] MMD_loss [-0.10133311152458191]


 40%|████      | 2004/5000 [00:56<02:41, 18.61it/s]

Epoch [2000/5000], MSE [0.05388732999563217] MMD_loss [-0.11167823523283005]


 60%|██████    | 3004/5000 [01:26<01:40, 19.96it/s]

Epoch [3000/5000], MSE [0.05375286191701889] MMD_loss [-0.10856033861637115]


 80%|████████  | 4004/5000 [01:54<00:48, 20.55it/s]

Epoch [4000/5000], MSE [0.05369408056139946] MMD_loss [-0.0991986095905304]


100%|██████████| 5000/5000 [02:22<00:00, 35.02it/s]

Epoch [5000/5000], MSE [0.054169222712516785] MMD_loss [-0.1033748909831047]
Training X





Epoch [0],MSE [0.2569020092487335] MMD_loss [67.70861053466797]


 20%|██        | 1004/5000 [00:29<03:41, 18.01it/s]

Epoch [1000/5000], MSE [0.04641296714544296] MMD_loss [0.010963695123791695]


 40%|████      | 2004/5000 [00:57<02:25, 20.55it/s]

Epoch [2000/5000], MSE [0.0464133694767952] MMD_loss [0.010374296456575394]


 60%|██████    | 3007/5000 [01:26<01:36, 20.63it/s]

Epoch [3000/5000], MSE [0.04641340672969818] MMD_loss [0.010262394323945045]


 80%|████████  | 4007/5000 [01:55<00:49, 20.08it/s]

Epoch [4000/5000], MSE [0.04641343280673027] MMD_loss [0.010216055437922478]


100%|██████████| 5000/5000 [02:23<00:00, 34.75it/s]

Epoch [5000/5000], MSE [0.046413443982601166] MMD_loss [0.010206058621406555]





Training Y
Epoch [0],MSE [0.03827346861362457] MMD_loss [58.16093444824219]


 20%|██        | 1004/5000 [00:28<03:12, 20.75it/s]

Epoch [1000/5000], MSE [0.061898842453956604] MMD_loss [-0.09737042337656021]


 40%|████      | 2004/5000 [00:56<02:25, 20.53it/s]

Epoch [2000/5000], MSE [0.06104040518403053] MMD_loss [-0.09643231332302094]


 60%|██████    | 3004/5000 [01:25<01:38, 20.37it/s]

Epoch [3000/5000], MSE [0.06044083461165428] MMD_loss [-0.10052194446325302]


 80%|████████  | 4004/5000 [01:54<00:55, 17.99it/s]

Epoch [4000/5000], MSE [0.05984504148364067] MMD_loss [-0.09422612935304642]


100%|██████████| 5000/5000 [02:22<00:00, 34.99it/s]

Epoch [5000/5000], MSE [0.05994515120983124] MMD_loss [-0.10076786577701569]
Training X





Epoch [0],MSE [0.25683557987213135] MMD_loss [68.10478210449219]


 20%|██        | 1004/5000 [00:28<03:13, 20.60it/s]

Epoch [1000/5000], MSE [0.2911382019519806] MMD_loss [6.405304431915283]


 40%|████      | 2005/5000 [00:57<02:24, 20.70it/s]

Epoch [2000/5000], MSE [0.03797449916601181] MMD_loss [-0.007650856859982014]


 60%|██████    | 3005/5000 [01:25<01:52, 17.78it/s]

Epoch [3000/5000], MSE [0.0379745177924633] MMD_loss [-0.007972919382154942]


 80%|████████  | 4005/5000 [01:54<00:47, 20.77it/s]

Epoch [4000/5000], MSE [0.03797464072704315] MMD_loss [-0.008047353476285934]


100%|██████████| 5000/5000 [02:22<00:00, 34.99it/s]

Epoch [5000/5000], MSE [0.03797461465001106] MMD_loss [-0.008114240132272243]





ccle_index [6], p_value [0.0]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.034750521183013916] MMD_loss [56.69744873046875]


 20%|██        | 1004/5000 [00:28<03:30, 18.98it/s]

Epoch [1000/5000], MSE [0.05691926181316376] MMD_loss [-0.10314657539129257]


 40%|████      | 2004/5000 [00:57<02:29, 19.98it/s]

Epoch [2000/5000], MSE [0.05520586296916008] MMD_loss [-0.108623206615448]


 60%|██████    | 3004/5000 [01:26<01:35, 20.79it/s]

Epoch [3000/5000], MSE [0.05501437559723854] MMD_loss [-0.10901818424463272]


 80%|████████  | 4004/5000 [01:54<00:53, 18.49it/s]

Epoch [4000/5000], MSE [0.055062513798475266] MMD_loss [-0.10610201209783554]


100%|██████████| 5000/5000 [02:23<00:00, 34.80it/s]

Epoch [5000/5000], MSE [0.05544496327638626] MMD_loss [-0.10651513934135437]
Training X





Epoch [0],MSE [0.25727561116218567] MMD_loss [66.30259704589844]


 20%|██        | 1004/5000 [00:28<03:14, 20.59it/s]

Epoch [1000/5000], MSE [0.16033734381198883] MMD_loss [1.4103009700775146]


 40%|████      | 2004/5000 [00:56<02:27, 20.30it/s]

Epoch [2000/5000], MSE [0.06328967958688736] MMD_loss [-0.015928860753774643]


 60%|██████    | 3004/5000 [01:25<02:10, 15.30it/s]

Epoch [3000/5000], MSE [0.06328132748603821] MMD_loss [-0.01627163030207157]


 80%|████████  | 4005/5000 [01:55<00:49, 20.19it/s]

Epoch [4000/5000], MSE [0.06329088658094406] MMD_loss [-0.016462991014122963]


100%|██████████| 5000/5000 [02:23<00:00, 34.86it/s]

Epoch [5000/5000], MSE [0.0632908046245575] MMD_loss [-0.016515880823135376]





Training Y
Epoch [0],MSE [0.03821859881281853] MMD_loss [58.300045013427734]


 20%|██        | 1004/5000 [00:28<03:12, 20.75it/s]

Epoch [1000/5000], MSE [0.06389980763196945] MMD_loss [-0.09269346296787262]


 40%|████      | 2004/5000 [00:57<02:53, 17.25it/s]

Epoch [2000/5000], MSE [0.06314024329185486] MMD_loss [-0.09312210232019424]


 60%|██████    | 3004/5000 [01:26<01:38, 20.25it/s]

Epoch [3000/5000], MSE [0.06322168558835983] MMD_loss [-0.09912716597318649]


 80%|████████  | 4004/5000 [01:54<00:48, 20.68it/s]

Epoch [4000/5000], MSE [0.06272998452186584] MMD_loss [-0.09160436689853668]


100%|██████████| 5000/5000 [02:23<00:00, 34.84it/s]

Epoch [5000/5000], MSE [0.06189712509512901] MMD_loss [-0.1021689921617508]
Training X





Epoch [0],MSE [0.2572990655899048] MMD_loss [68.99493408203125]


 20%|██        | 1004/5000 [00:29<03:54, 17.05it/s]

Epoch [1000/5000], MSE [0.046412739902734756] MMD_loss [0.008062307722866535]


 40%|████      | 2005/5000 [00:57<02:25, 20.54it/s]

Epoch [2000/5000], MSE [0.04641326144337654] MMD_loss [0.006931172218173742]


 60%|██████    | 3005/5000 [01:25<01:36, 20.78it/s]

Epoch [3000/5000], MSE [0.0464133620262146] MMD_loss [0.006776230875402689]


 80%|████████  | 4005/5000 [01:54<00:52, 18.91it/s]

Epoch [4000/5000], MSE [0.04641338065266609] MMD_loss [0.00668730866163969]


100%|██████████| 5000/5000 [02:23<00:00, 34.87it/s]

Epoch [5000/5000], MSE [0.04641338065266609] MMD_loss [0.006667293608188629]





ccle_index [7], p_value [0.0]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03477036580443382] MMD_loss [56.616600036621094]


 20%|██        | 1004/5000 [00:28<03:13, 20.65it/s]

Epoch [1000/5000], MSE [0.05699048191308975] MMD_loss [-0.1050431877374649]


 40%|████      | 2004/5000 [00:57<02:48, 17.75it/s]

Epoch [2000/5000], MSE [0.05570747330784798] MMD_loss [-0.10806725174188614]


 60%|██████    | 3004/5000 [01:26<01:41, 19.74it/s]

Epoch [3000/5000], MSE [0.0555347204208374] MMD_loss [-0.1088612973690033]


 80%|████████  | 4004/5000 [01:54<00:49, 20.29it/s]

Epoch [4000/5000], MSE [0.05547724664211273] MMD_loss [-0.10566779971122742]


100%|██████████| 5000/5000 [02:23<00:00, 34.91it/s]

Epoch [5000/5000], MSE [0.05580579489469528] MMD_loss [-0.10482333600521088]
Training X





Epoch [0],MSE [0.256813108921051] MMD_loss [68.27696990966797]


 20%|██        | 1004/5000 [00:28<03:52, 17.16it/s]

Epoch [1000/5000], MSE [0.033755093812942505] MMD_loss [0.0458625890314579]


 40%|████      | 2004/5000 [00:57<02:29, 19.97it/s]

Epoch [2000/5000], MSE [0.03375523164868355] MMD_loss [0.045276496559381485]


 60%|██████    | 3004/5000 [01:26<01:38, 20.37it/s]

Epoch [3000/5000], MSE [0.03375522047281265] MMD_loss [0.045193955302238464]


 80%|████████  | 4004/5000 [01:55<00:48, 20.66it/s]

Epoch [4000/5000], MSE [0.03375523164868355] MMD_loss [0.04518427327275276]


100%|██████████| 5000/5000 [02:23<00:00, 34.85it/s]

Epoch [5000/5000], MSE [0.03375524282455444] MMD_loss [0.04517602548003197]





Training Y
Epoch [0],MSE [0.03829937055706978] MMD_loss [58.30783462524414]


 20%|██        | 1004/5000 [00:28<03:18, 20.10it/s]

Epoch [1000/5000], MSE [0.06767496466636658] MMD_loss [-0.09242306649684906]


 40%|████      | 2004/5000 [00:57<02:27, 20.32it/s]

Epoch [2000/5000], MSE [0.06711553782224655] MMD_loss [-0.09270232915878296]


 60%|██████    | 3004/5000 [01:26<01:36, 20.65it/s]

Epoch [3000/5000], MSE [0.06652805209159851] MMD_loss [-0.10074973106384277]


 80%|████████  | 4004/5000 [01:54<00:58, 16.97it/s]

Epoch [4000/5000], MSE [0.06551089137792587] MMD_loss [-0.09128040820360184]


100%|██████████| 5000/5000 [02:23<00:00, 34.85it/s]

Epoch [5000/5000], MSE [0.06580416858196259] MMD_loss [-0.10166516900062561]
Training X





Epoch [0],MSE [0.2569393813610077] MMD_loss [68.01393127441406]


 20%|██        | 1004/5000 [00:28<03:15, 20.40it/s]

Epoch [1000/5000], MSE [0.2742614150047302] MMD_loss [6.121667861938477]


 40%|████      | 2004/5000 [00:57<02:24, 20.67it/s]

Epoch [2000/5000], MSE [0.04218969866633415] MMD_loss [-0.007485846523195505]


 60%|██████    | 3004/5000 [01:25<01:56, 17.09it/s]

Epoch [3000/5000], MSE [0.04219381883740425] MMD_loss [-0.00799330323934555]


 80%|████████  | 4004/5000 [01:54<00:48, 20.64it/s]

Epoch [4000/5000], MSE [0.042193807661533356] MMD_loss [-0.008152108639478683]


100%|██████████| 5000/5000 [02:23<00:00, 34.93it/s]

Epoch [5000/5000], MSE [0.04219392314553261] MMD_loss [-0.008175318129360676]





ccle_index [8], p_value [0.013]
Genetic features dimension = 1638 on 1030 cancer cells
PLX4720: Cell number = 474
466 (1638,) [0.00793665 0.03243895 0.0438247  ... 0.04631409 0.51928232 0.07857439]
Training Y
Epoch [0],MSE [0.03484245762228966] MMD_loss [56.49945831298828]


 20%|██        | 1004/5000 [00:28<03:15, 20.45it/s]

Epoch [1000/5000], MSE [0.05636601895093918] MMD_loss [-0.0946066677570343]


 40%|████      | 2004/5000 [00:57<02:26, 20.47it/s]

Epoch [2000/5000], MSE [0.05450903996825218] MMD_loss [-0.1111038401722908]


 60%|██████    | 3004/5000 [01:25<01:38, 20.25it/s]

Epoch [3000/5000], MSE [0.05439252778887749] MMD_loss [-0.10680484026670456]


 80%|████████  | 4004/5000 [01:54<00:50, 19.66it/s]

Epoch [4000/5000], MSE [0.05421801283955574] MMD_loss [-0.10286134481430054]


100%|██████████| 5000/5000 [02:23<00:00, 34.91it/s]

Epoch [5000/5000], MSE [0.05446042865514755] MMD_loss [-0.1052364930510521]
Training X





Epoch [0],MSE [0.25718599557876587] MMD_loss [67.8177261352539]


 20%|██        | 1004/5000 [00:28<03:16, 20.33it/s]

Epoch [1000/5000], MSE [0.14345954358577728] MMD_loss [1.2057331800460815]


 40%|████      | 2004/5000 [00:57<02:30, 19.85it/s]

Epoch [2000/5000], MSE [0.0379745289683342] MMD_loss [0.0010199362877756357]


 60%|██████    | 3004/5000 [01:25<01:39, 20.16it/s]

Epoch [3000/5000], MSE [0.037974677979946136] MMD_loss [0.0007909052073955536]


 80%|████████  | 4004/5000 [01:54<00:47, 20.80it/s]

Epoch [4000/5000], MSE [0.03797463700175285] MMD_loss [0.0007138365181162953]


100%|██████████| 5000/5000 [02:22<00:00, 35.10it/s]

Epoch [5000/5000], MSE [0.03797494247555733] MMD_loss [0.0006946197827346623]





Training Y
Epoch [0],MSE [0.038276609033346176] MMD_loss [58.35866928100586]


 20%|██        | 1004/5000 [00:29<03:39, 18.19it/s]

Epoch [1000/5000], MSE [0.0623449832201004] MMD_loss [-0.09160654246807098]


 40%|████      | 2004/5000 [00:57<02:27, 20.32it/s]

Epoch [2000/5000], MSE [0.06216253712773323] MMD_loss [-0.08945527672767639]


 60%|██████    | 3004/5000 [01:26<01:46, 18.79it/s]

Epoch [3000/5000], MSE [0.062168266624212265] MMD_loss [-0.09759453684091568]


 80%|████████  | 4004/5000 [01:54<00:48, 20.42it/s]

Epoch [4000/5000], MSE [0.061456095427274704] MMD_loss [-0.0929020345211029]


100%|██████████| 5000/5000 [02:23<00:00, 34.83it/s]

Epoch [5000/5000], MSE [0.06099855154752731] MMD_loss [-0.09899315237998962]
Training X





Epoch [0],MSE [0.25683143734931946] MMD_loss [66.2386703491211]


 20%|██        | 1004/5000 [00:28<03:13, 20.70it/s]

Epoch [1000/5000], MSE [0.10126423090696335] MMD_loss [0.03367064148187637]


 40%|████      | 2004/5000 [00:57<02:40, 18.64it/s]

Epoch [2000/5000], MSE [0.10126522928476334] MMD_loss [0.03310330584645271]


 60%|██████    | 3004/5000 [01:26<01:37, 20.50it/s]

Epoch [3000/5000], MSE [0.10126543045043945] MMD_loss [0.03292929381132126]


 80%|████████  | 4004/5000 [01:55<00:59, 16.65it/s]

Epoch [4000/5000], MSE [0.10126549750566483] MMD_loss [0.03290821611881256]


100%|██████████| 5000/5000 [02:23<00:00, 34.84it/s]

Epoch [5000/5000], MSE [0.10126557946205139] MMD_loss [0.03287840634584427]





ccle_index [9], p_value [0.013]
