<a href="https://colab.research.google.com/github/Vlasovets/MB-GAN/blob/master/MB_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! git clone https://github.com/zhanxw/MB-GAN.git

Cloning into 'MB-GAN'...
remote: Enumerating objects: 164, done.[K
remote: Counting objects: 100% (34/34), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 164 (delta 15), reused 8 (delta 0), pack-reused 130[K
Receiving objects: 100% (164/164), 250.10 MiB | 29.33 MiB/s, done.
Resolving deltas: 100% (54/54), done.
Checking out files: 100% (64/64), done.


In [2]:
%cd MB-GAN/

/content/MB-GAN


In [20]:
!python utils.py

In [21]:
import numpy as np
import pandas as pd
from keras import backend as K
from keras.models import load_model
from keras.layers import Layer
import pickle
from scipy.stats import describe
from utils import shannon_entropy, get_sparsity

import os
# Enumerate CUDA devices in PCI bus order and expose none of them: an empty
# device list hides every GPU from CUDA, so the work below runs on the CPU.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "",
})

# SEED is passed to predict() for reproducible sampling;
# TOL is the threshold handed to get_sparsity().
SEED, TOL = 256, 1e-4

In [22]:
class PhyloTransform(Layer):
    """Keras layer that right-multiplies its input by a fixed matrix.

    The matrix is held as a backend constant, so the layer adds no trainable
    weights.

    Args:
        transform: 2-D array-like used as the right operand of ``K.dot``.
            Its leading dimension has to match the last dimension of the
            layer input; its remaining dimensions define the output shape.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, transform, **kwargs):
        # Initialise the base Layer before assigning any attributes, as the
        # Keras subclassing guide recommends.
        super(PhyloTransform, self).__init__(**kwargs)
        # Keep a NumPy copy so the matrix can be emitted by get_config();
        # np.asarray also lets `transform` arrive as a nested list, which is
        # what from_config() passes when a saved model is reloaded.
        self._transform = np.asarray(transform, dtype='float32')
        self.output_dim = self._transform.shape[1:]
        self.kernel = K.constant(self._transform, dtype='float32')

    def call(self, x):
        """Return ``x`` multiplied by the fixed transform matrix."""
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        """Keep the leading (batch) dimension and append the matrix's output dims."""
        return (input_shape[0], ) + self.output_dim

    def get_config(self):
        """Return a serialisable config that includes the transform matrix.

        Without this override a model containing the layer cannot be saved
        and reloaded, because ``transform`` is a required constructor
        argument that the base config does not carry.
        """
        config = super(PhyloTransform, self).get_config()
        config['transform'] = self._transform.tolist()
        return config


def predict(generator, n_samples=1000, transform=None, seed=None):
    """Draw standard-normal latent vectors and run them through ``generator``.

    Args:
        generator: Model-like object exposing ``inputs`` (the last dimension
            of ``inputs[0].shape`` is used as the latent size) and a
            ``predict(z)`` method.
        n_samples: Number of latent vectors, i.e. rows of the result.
        transform: Optional callable applied to the generator output before
            it is returned.
        seed: Seed for the latent draw; ``None`` gives a non-deterministic
            draw.

    Returns:
        Whatever ``generator.predict`` returns for the sampled latent batch,
        passed through ``transform`` when one is given.
    """
    # A private RandomState produces exactly the same stream as
    # np.random.seed(seed) followed by np.random.normal(...), but leaves the
    # global NumPy RNG untouched for the rest of the notebook.
    rng = np.random.RandomState(seed)
    # Some backends report the dimension as a wrapper object rather than a
    # plain int; normalise it before using it as an array size.
    latent_dim = int(generator.inputs[0].shape[-1])
    z = rng.normal(0, 1, (n_samples, latent_dim))
    res = generator.predict(z)
    if transform is not None:
        res = transform(res)

    return res

In [25]:
## Load raw dataset
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load this file from a trusted source.
# The context manager closes the file handle once the frame is loaded.
with open("data/raw_data.pkl", 'rb') as raw_file:
    raw_data = pickle.load(raw_file)
# Every column after the first is treated as a taxon abundance and divided by
# 100 (presumably percentages -> proportions; confirm against the data source).
dataset = raw_data.iloc[:,1:].values/100.
# The "group" column holds the per-sample label ('case' / 'ctrl').
labels = raw_data["group"].values
taxa_list = raw_data.columns[1:]
data_o_case = dataset[labels == 'case']
data_o_ctrl = dataset[labels == 'ctrl']

In [28]:
## Generate data
GENERATOR_CASE_PATH = os.path.join('models', 'stool_2_case_generator.h5')
GENERATOR_CTRL_PATH = os.path.join('models', 'stool_2_ctrl_generator.h5')

# load_model is called without custom_objects; pass
# custom_objects={'PhyloTransform': PhyloTransform} if a saved model
# contains that layer.
generator_case, generator_ctrl = (
    load_model(model_path)
    for model_path in (GENERATOR_CASE_PATH, GENERATOR_CTRL_PATH)
)

# Both groups are sampled with the same seed and the same sample count.
data_g_case, data_g_ctrl = (
    predict(group_generator, n_samples=1000, seed=SEED)
    for group_generator in (generator_case, generator_ctrl)
)



In [29]:
## Show data statistics
# Row order is shared by both tables: original vs. GAN samples, ctrl then case.
_stat_rows = ['Original ctrl', 'GAN ctrl', 'Original case', 'GAN case']
_stat_samples = [data_o_ctrl, data_g_ctrl, data_o_case, data_g_case]

print("Sparsity")
sparsity_summary = [describe(get_sparsity(sample, TOL)) for sample in _stat_samples]
display(pd.DataFrame(sparsity_summary, index=_stat_rows))

print("Shannon Entropy")
entropy_summary = [describe(shannon_entropy(sample)) for sample in _stat_samples]
display(pd.DataFrame(entropy_summary, index=_stat_rows))

Sparsity


Unnamed: 0,nobs,minmax,mean,variance,skewness,kurtosis
Original ctrl,248,"(0.8191933240611962, 0.9457579972183588)",0.892122,0.000268,-0.093724,1.696552
GAN ctrl,1000,"(0.7858136300417247, 0.9624478442280946)",0.872316,0.000642,0.042301,0.522135
Original case,148,"(0.8219749652294854, 0.9429763560500696)",0.887813,0.000424,0.309246,0.427984
GAN case,1000,"(0.7649513212795549, 0.9707927677329624)",0.866573,0.00124,0.059911,-0.264223


Shannon Entropy


  return -np.sum(np.where(x > tol, x * np.log(x), 0), axis=-1)
  return -np.sum(np.where(x > tol, x * np.log(x), 0), axis=-1)


Unnamed: 0,nobs,minmax,mean,variance,skewness,kurtosis
Original ctrl,248,"(1.4803902227728476, 3.8584069736393483)",2.97272,0.156112,-1.144235,1.845062
GAN ctrl,1000,"(0.67656565, 3.8749592)",2.972613,0.185524,-1.010242,1.546127
Original case,148,"(1.7056071992903896, 3.8471567050900752)",3.077932,0.17017,-1.038462,0.986683
GAN case,1000,"(0.91057855, 3.9654443)",3.084653,0.217414,-1.323819,2.07317


In [31]:
## Save simulated data
# NOTE(review): nothing is written to disk here; this cell only displays the
# generated case samples as its output.
data_g_case

array([[2.1831600e-03, 4.3323366e-06, 1.9520120e-01, ..., 1.0500869e-09,
        9.8865605e-10, 2.6257508e-08],
       [4.5983076e-01, 1.8790088e-10, 1.8271598e-08, ..., 4.2099310e-13,
        9.0345765e-13, 2.5478007e-12],
       [3.3938554e-07, 1.6933550e-04, 4.9903840e-02, ..., 1.3689702e-11,
        3.7363584e-12, 1.7654847e-10],
       ...,
       [4.7019008e-12, 6.3298846e-09, 2.6364378e-03, ..., 8.7805333e-13,
        2.2206721e-13, 1.5106872e-11],
       [8.5995406e-09, 1.5247123e-05, 9.3870014e-03, ..., 6.1989268e-11,
        8.8506008e-12, 6.1393440e-10],
       [1.9499138e-11, 1.6391055e-13, 2.5953893e-06, ..., 2.7525822e-13,
        1.2412178e-14, 1.7378946e-12]], dtype=float32)