In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from numpy import genfromtxt
import matplotlib.pyplot as plt
import time
import os
import numpy as np
import pandas as pd
import seaborn as sns
from numpy import array
from tensorflow.keras import layers
from tensorflow import keras
from sklearn.preprocessing import StandardScaler
from scipy import stats
from sklearn.manifold import TSNE
from sklearn import metrics
from sklearn.cluster import AffinityPropagation

In [None]:
# Show the installed TensorFlow version as this cell's output.
tf.__version__

In [None]:
# GPU configuration.
# The CUDA/TensorFlow environment variables are exported BEFORE the first
# device query: TensorFlow reads them when it first enumerates devices, so
# setting them after the query (as was done previously) comes too late.
use_gpu = True
if use_gpu:
    os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'   # allocate GPU memory on demand
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"     # stable device numbering
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"           # expose only the first GPU
    # Query the 'GPU' device type so the count matches the printed message.
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
    

# Augmentation

In [None]:
def gene_filtering(genepath):
    """Read the gene table at `genepath` and return its first two columns.

    The CSV is parsed once (it was previously parsed twice, once per column).

    Parameters
    ----------
    genepath : str
        Path of a CSV file readable by `pandas.read_csv`.

    Returns
    -------
    gnames : numpy.ndarray
        Values of the first column (Ensembl ID, per the original comment).
    glists : numpy.ndarray
        Values of the second column (gene symbol / gene name, per the
        original comment).
    """
    # Convert the whole frame to one array first so that both columns keep
    # exactly the dtype they had with the original per-column slicing.
    table = np.array(pd.read_csv(genepath))

    # Ensembl ID
    gnames = table[:, 0]

    # Gene symbol : gene name
    glists = table[:, 1]

    return gnames, glists

In [None]:
def extract_filtered_gene(datapath, gnames):
    """Load an expression table and keep only the requested gene rows.

    The file is read as comma-separated text. Its first row supplies the
    sample labels, its first column the gene identifiers, and the remaining
    cells the values. Double quotes are stripped from labels and identifiers.

    Parameters
    ----------
    datapath : str
        Path of the comma-separated file.
    gnames : sequence
        Gene identifiers to keep, in the desired output order. Each one must
        be present in the file; a missing identifier raises IndexError.

    Returns
    -------
    evalue : numpy.ndarray
        float32 matrix with one row per sample and one column per entry of
        `gnames`.
    egene : sequence
        `gnames`, returned unchanged.
    label : numpy.ndarray
        Sample labels taken from the header row.
    """
    # Load the whole table as strings, then split header row from body rows.
    table = np.genfromtxt(datapath, delimiter=',', dtype='str')
    header_row, body_rows = table[0], table[1:]

    label = np.array([name.replace('"', '') for name in header_row[1:]])
    gene = np.array([ident.replace('"', '') for ident in body_rows[:, 0]])
    # Transpose so that samples are rows and genes are columns.
    value = body_rows[:, 1:].T.astype(np.float32)

    # Column position of the first match for every requested gene.
    eidx = [np.where(wanted == gene)[0][0] for wanted in gnames]

    return value[:, eidx], gnames, label

In [None]:
def data_augmentation(evalue, egene, elabel, save=True):
    """Augment the samples by blending pairs inside each experimental group.

    Rows of `evalue` are addressed as `age*10 + treatment*5 + sample`, i.e.
    2 ages x 2 treatments x 5 samples. Inside each group of five, every raw
    sample is emitted once, and every pair (a, b) with a < b contributes nine
    weighted mixes `a*w + b*(1-w)` with w = 0.1 ... 0.9.

    Parameters
    ----------
    evalue : array-like
        Sample-by-gene values; only the first 20 rows are addressed.
    egene : any
        Gene list; stored and returned unchanged.
    elabel : sequence of str
        One label per row of `evalue`.
    save : bool
        When equal to True, the result is also written to
        'augmented_input_Tau_Union_LFC03_3767G_re.npz'.

    Returns
    -------
    list
        `[egene, aug_values, aug_labels]` with the last two as NumPy arrays.
    """
    group_size = 5
    age_count = 2
    treatment_count = 2

    mixed_values, mixed_labels = [], []

    for age in range(age_count):
        for treatment in range(treatment_count):
            group_start = age*10 + treatment*5
            for first in range(group_size):
                row_a = group_start + first

                # The raw sample itself comes first ...
                mixed_values.append(evalue[row_a])
                mixed_labels.append(elabel[row_a])

                # ... followed by its blends with every later sample of the group.
                for second in range(first + 1, group_size):
                    row_b = group_start + second
                    sample_a, sample_b = evalue[row_a], evalue[row_b]
                    for step in range(9):
                        mixed_values.append(sample_a*(step+1)*0.1 + sample_b*(9-step)*0.1)
                        mixed_labels.append(elabel[row_a]+'+'+elabel[row_b]+'+'+str(step+1))

    aug_values = np.array(mixed_values)
    aug_labels = np.array(mixed_labels)

    # Optionally persist the augmented set.
    if save == True:
        np.savez('augmented_input_Tau_Union_LFC03_3767G_re.npz',
                 genelist=egene, values=aug_values, labels=aug_labels)
        print ('Augmented input is saved!')

    return [egene, aug_values, aug_labels]

In [None]:
# Run switches for the notebook.
is_training = True            # run the training loop and the per-checkpoint evaluation
is_interpolation = True       # run the latent-space interpolation
re_data_augmentation = True   # rebuild the augmented input instead of loading the saved one

# 1. Preprocess dataset

In [None]:
## Gene filtering and augmentation
# Either rebuild the augmented data set from the gene table (which also
# writes the '_re' archive) or fall back to a previously saved archive.
if re_data_augmentation:
    gene_path = 'Tau_Union_rld_LFC03_remove_all_case_3767genes.csv'
    ex_gnames, ex_glists = gene_filtering(gene_path)
    ex_value, ex_gene, ex_label = extract_filtered_gene(gene_path, ex_gnames)
    augmented_data = data_augmentation(ex_value, ex_glists, ex_label, True)
    input_ = 'augmented_input_Tau_Union_LFC03_3767G_re.npz'
else:
    input_ = 'augmented_input_Tau_Union_LFC03_3767G.npz'

# allow_pickle=True lets NumPy unpickle object arrays from the archive;
# only load archives produced by this notebook.
dat = np.load(input_, allow_pickle=True)
rlds = dat['values']
genelist = dat['genelist']
augsample = dat['labels']
realidx = [0, 95, 190, 285]

print (rlds.shape)

In [None]:
# Label indexing
# Each condition occupies 95 consecutive rows of the augmented matrix; the
# five listed offsets are the raw (non-blended) samples inside one such run.
WT3M = np.array([0,37,65,84,94])
AD3M = WT3M + 95
WT6M = WT3M + 2*95
AD6M = WT3M + 3*95

# All rows (raw and blended) of each condition.
Aug_WT3M = np.arange(95)
Aug_AD3M = Aug_WT3M + 95
Aug_WT6M = Aug_WT3M + 2*95
Aug_AD6M = Aug_WT3M + 3*95

# 2. Rescaling (normalization)

In [None]:
cond_list = [Aug_WT3M, Aug_WT6M, Aug_AD3M, Aug_AD6M]

# Per-gene standard deviation inside each condition (one row per condition),
# then its maximum and mean across the conditions.
rld_std = np.zeros((len(cond_list), rlds.shape[1]))
for c, members in enumerate(cond_list):
    rld_std[c] = np.std(rlds[members], axis=0)

max_rld_std = rld_std.max(axis=0)
avg_rld_std = np.average(rld_std, axis=0)

# Per-gene statistics over every sample at once.
rld_mean = np.average(rlds, axis=0)
rld_median = np.median(rlds, axis=0)
all_rld_std = np.std(rlds, axis=0)

In [None]:
## IQR filter
# Clamp both per-gene spread estimates to their own inter-quartile range,
# combine them by geometric mean, and use the result as the per-gene scale.
num_q1, num_q3 = 25, 75

all_q1, all_q3 = (stats.scoreatpercentile(all_rld_std, num_q1),
                  stats.scoreatpercentile(all_rld_std, num_q3))
max_q1, max_q3 = (stats.scoreatpercentile(max_rld_std, num_q1),
                  stats.scoreatpercentile(max_rld_std, num_q3))

# Lower clamp first, then upper clamp (the upper test looks at the raw values).
all_rld_std_filter = np.where(all_rld_std < all_q1, all_q1, all_rld_std)
all_rld_std_filter_f = np.where(all_rld_std > all_q3, all_q3, all_rld_std_filter)

max_rld_std_filter = np.where(max_rld_std < max_q1, max_q1, max_rld_std)
max_rld_std_filter_f = np.where(max_rld_std > max_q3, max_q3, max_rld_std_filter)

two_rld_std_f = np.array([all_rld_std_filter_f, max_rld_std_filter_f])
gmean_rld_std_f = stats.gmean(two_rld_std_f, axis=0)

# Centre every gene on its mean and divide by the combined scale.
re_rld_f = (rlds - rld_mean)/gmean_rld_std_f

plt.figure()
plt.hist(re_rld_f.flatten(), bins=1000)
plt.xlim(-5, 5)
plt.plot()

std_re_rld_f = np.std(re_rld_f.flatten())

### Within 95%
# Divide by 2 * 1.959 global standard deviations and shift by 0.5.
re_rld_norm_f = re_rld_f/(2*1.959*std_re_rld_f)+0.5

rld = re_rld_norm_f

plt.figure()
plt.hist(re_rld_norm_f.flatten(), bins=1000)
plt.xlim(-0.1, 1.1)
plt.plot()

# 3. Training network settings

In [None]:
# Split the rows of `rld` into a training part and a held-out part.
# NOTE(review): the tf.data pipeline built in the hyperparameter cell is fed
# with the full `rld`, not `xtr` — confirm that the split is only informative.
nrld = rld.flatten()
te_ratio = 0.1

x_idx = np.arange(len(rld))
# Draw the held-out row indices without replacement (unseeded global RNG).
te_idx = np.random.choice(len(rld), int(len(rld)*te_ratio), replace=False)
tr_idx = np.setdiff1d(x_idx, te_idx)

xtr = rld[tr_idx]
xte = rld[te_idx]
tr_max = np.max(xtr)
n_tr = len(xtr)
n_te = len(xte)

print ('xtr:', xtr.shape, ', xte:', xte.shape)
# n: number of training rows, p: number of columns (reused as network width).
n, p = xtr.shape
print(n, p)

In [None]:
# Print the rescaled matrix currently bound to `rld`.
print(rld)

In [None]:
# Hyperparameter setting
lr = 1e-5                         # learning rate
EPOCHS = 300000                   # number of training epochs
critic = 5                        # discriminator (critic) updates per generator update
gradient_penalty_weight = 10.0    # weight of the gradient-penalty term
BUFFER_SIZE = 20                  # shuffle buffer of the tf.data pipeline
BATCH_SIZE = 32                   # mini-batch size

# Hidden-layer widths of the generator and of the discriminator.
gen_size = 450
disc_size = 270

# Cast once to float32 and remember the largest value.
rld = rld.astype('float32')
rld_max = np.max(rld)

# NOTE(review): tf.data shuffles inside a window of BUFFER_SIZE elements; if
# `rld` has more rows than that, this is only a local shuffle — confirm intent.
train_dataset = tf.data.Dataset.from_tensor_slices(rld)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

In [None]:
# Optimizer setting
# Both Adam optimizers take their step size from the `lr` hyperparameter
# instead of repeating the literal, so changing `lr` actually takes effect.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Directory that holds logs, checkpoints and saved results.
# When training, results go to the "yours_dir" run folder (set this to your
# own run name); otherwise the notebook reads from the "20210309_1" run.
log_dir = "training_data/"
sub_path = log_dir + "Tau_union_LFC03_G3767_G450D270_300k/" + ("yours_dir" if is_training else "20210309_1")

if is_training:
    # TensorBoard summary writer for the per-epoch losses.
    summary_writer = tf.summary.create_file_writer(sub_path)

In [None]:
noise_dim = 100
nrld = rld.flatten()

def get_noise(batch_size):
    """Return a (batch_size, noise_dim) tensor of latent inputs.

    The entries are not drawn from a parametric distribution: each one is a
    value picked uniformly at random (with replacement) from the flattened
    data matrix `nrld`.
    """
    picks = np.random.randint(0, len(nrld), size=(batch_size, noise_dim))
    return tf.convert_to_tensor(nrld[picks])

# One batch worth of distinct training-row indices.
idx_ = np.random.choice(n_tr, BATCH_SIZE, replace=False)

In [None]:
gen_init = 0.3

def make_generator_model():
    """Build the generator: noise_dim -> gen_size -> gen_size -> p.

    Two LeakyReLU-activated hidden layers followed by a linear output layer.
    Every kernel is initialised uniformly in [-gen_init, gen_init].
    """
    def fresh_init():
        # A new initializer object per layer, as in the layer-list form.
        return tf.random_uniform_initializer(-gen_init, gen_init)

    model = keras.Sequential()
    model.add(layers.Dense(gen_size, input_shape=(noise_dim, ), kernel_initializer=fresh_init()))
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(gen_size, kernel_initializer=fresh_init()))
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(p, kernel_initializer=fresh_init()))
    return model

generator = make_generator_model()
generator.summary()

In [None]:
# Smoke test: push a single latent vector through the freshly built generator.
noise = get_noise(1)
generated_rlds = generator(noise, training=False)

In [None]:
def make_discriminator_model():
    """Build the discriminator (critic): p -> disc_size -> disc_size -> 1.

    Two LeakyReLU-activated hidden layers followed by a single linear output
    unit (no sigmoid).
    """
    model = keras.Sequential()
    model.add(layers.Dense(disc_size, input_shape=(p, )))
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(disc_size))
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(1))
    return model

discriminator = make_discriminator_model()
discriminator.summary()

In [None]:
# Smoke test: score the generated sample with the freshly built discriminator.
decision = discriminator(generated_rlds)

In [None]:
# Checkpoint object tracking both networks and both optimizers; the files
# are written below `checkpoint_dir` by the training loop.
checkpoint_dir = sub_path + "/training_checkpoints"
checkpoint = tf.train.Checkpoint(**{
    'generator_optimizer': generator_optimizer,
    'discriminator_optimizer': discriminator_optimizer,
    'generator': generator,
    'discriminator': discriminator,
})

print(checkpoint_dir)

In [None]:
@tf.function
def train_gen(rlds, noise):
    """Run one generator update and return the generator loss.

    `rlds` is accepted for signature symmetry with `train_disc` but is not
    used; `noise` is the latent batch fed to the generator.
    """
    with tf.GradientTape() as tape:
        fake_scores = discriminator(generator(noise, training=True), training=True)
        loss = generator_loss(fake_scores)

    # Only the generator weights are updated in this step.
    gen_vars = generator.trainable_variables
    grads = tape.gradient(loss, gen_vars)
    generator_optimizer.apply_gradients(zip(grads, gen_vars))
    return loss

In [None]:
@tf.function
def train_disc(rlds, noise):
    """Run one discriminator (critic) update and return its loss.

    Parameters
    ----------
    rlds : tensor
        Batch of real samples.
    noise : tensor
        Latent batch fed to the generator to produce the fake samples.

    Returns
    -------
    tensor
        Wasserstein critic loss plus the gradient penalty.
    """
    with tf.GradientTape() as t:
        generated_rlds = generator(noise, training=True)
        fake_output = discriminator(generated_rlds, training=True)
        real_output = discriminator(rlds, training=True)

        # The penalty is computed inside the tape because it must itself be
        # differentiated with respect to the discriminator weights.
        gp = gradient_penalty(rlds, generated_rlds)
        loss = discriminator_loss(fake_output, real_output) + gp

    # Ask for the gradients AFTER leaving the tape context so the backward
    # pass itself is not recorded by the tape (it was previously computed
    # inside the `with` block).
    gradients = t.gradient(loss, discriminator.trainable_variables)

    discriminator_optimizer.apply_gradients(zip(gradients, discriminator.trainable_variables))
    return loss

In [None]:
def generator_loss(fake_output):
    """Generator loss: the negated mean critic score of the generated batch."""
    return -tf.reduce_mean(fake_output)

def discriminator_loss(fake_output, real_output):
    """Critic loss: mean score of the fake batch minus mean score of the real batch."""
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

In [None]:
# gradient penalty
def gradient_penalty(rlds, generated_rlds):
    """Return the weighted WGAN-GP gradient penalty for one batch.

    Random points `x_hat` are taken on the straight segments between each
    real sample and the generated sample paired with it, and the squared
    deviation of the critic's gradient norm from 1 at those points is
    averaged and scaled by `gradient_penalty_weight`.

    Parameters
    ----------
    rlds : tensor
        Batch of real samples.
    generated_rlds : tensor
        Batch of generated samples with the same shape as `rlds`.

    Returns
    -------
    tensor
        Scalar penalty term to add to the critic loss.
    """
    batch_size = tf.shape(rlds)[0]
    # The mixing coefficient must lie in [0, 1] so that x_hat is a convex
    # combination of the real and the generated sample. It was previously
    # drawn from [0, rld_max], which leaves the segment whenever rld_max != 1.
    epsilon = tf.random.uniform(shape=[batch_size, 1], minval=0., maxval=1.)
    x_hat = (epsilon * rlds) + ((1 - epsilon) * generated_rlds)
    with tf.GradientTape() as gp:
        # x_hat is a plain tensor, so it has to be watched explicitly.
        gp.watch(x_hat)
        d_hat = discriminator(x_hat, training=True)
    gradients = gp.gradient(d_hat, [x_hat])[0]
    # Per-sample L2 norm of the critic gradient.
    d_gradient = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
    d_regularizer = gradient_penalty_weight * tf.reduce_mean(tf.square(d_gradient-1.))
    return d_regularizer

In [None]:
# Full training run; skipped entirely when `is_training` is False.
if is_training :
    # Per-epoch mean losses and the generated samples captured at each snapshot.
    gen_loss_results = []
    disc_loss_results = []
    gen_rlds = []

    start = time.time()
    # Number of samples drawn from the generator at every snapshot.
    num_to_generate = 84

    for epoch in range(EPOCHS):
        # Fresh running means, so every epoch's averages are independent.
        epoch_gen_loss_avg = tf.keras.metrics.Mean()
        epoch_disc_loss_avg = tf.keras.metrics.Mean()

        for batch in train_dataset:
            # NOTE(review): one latent batch is drawn here and reused for all
            # `critic` discriminator steps and for the generator step below.
            noise = get_noise(batch.shape[0])
            for _ in range (critic):
                #train_disc(batch)
                d_loss = train_disc(batch, noise)           
                epoch_disc_loss_avg(d_loss)
            g_loss = train_gen(batch, noise)
            epoch_gen_loss_avg(g_loss)
            #train_gen(batch)

        gen_loss_results.append(epoch_gen_loss_avg.result())
        disc_loss_results.append(epoch_disc_loss_avg.result())

        # Every 500 epochs: save a checkpoint and keep a batch of generated
        # samples, so snapshot k (0-based) corresponds to epoch (k + 1) * 500.
        if (epoch + 1) % 500 == 0:
            checkpoint_prefix = checkpoint_dir + "/cp-" + str(epoch+1) + '.ckpt'
            checkpoint.save(file_prefix = checkpoint_prefix) 
            print("Saving checkpoint for epoch{} at {}".format(epoch+1, checkpoint_prefix))

            predictions = generator(get_noise(num_to_generate), training=False)
            gen_rlds.append(predictions.numpy())
            print("Epoch {:03d}: Gen_Loss: {:.3f}, Disc_Loss: {:.3f}". format(epoch+1, 
                                                                      epoch_gen_loss_avg.result(),
                                                                      epoch_disc_loss_avg.result()))

        #epoch_gen_loss_avg.reset_states()
        #epoch_disc_loss_avg.reset_states()

        # Log this epoch's mean losses through the summary writer.
        with summary_writer.as_default():
            tf.summary.scalar('gen_loss', epoch_gen_loss_avg.result(), step=epoch)
            tf.summary.scalar('disc_loss', epoch_disc_loss_avg.result(), step=epoch)

    # Report the wall-clock duration as HH:MM:SS.ss.
    elapsed_time = time.time() - start
    hours, rem = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(rem, 60)
    print ("Elapsed time: {:0>2}:{:0>2}:{:05.2f}".format(int(hours),int(minutes),seconds))

    # One extra snapshot from the final model, then persist the snapshots and
    # both loss histories in a single archive.
    predictions = generator(get_noise(num_to_generate), training=False)
    gen_rlds.append(predictions.numpy())
    gen_x_84 = np.array(gen_rlds)
    np.savez(sub_path+'/Loss_GenX84.npz', gen_x_84=gen_x_84,disc_loss=disc_loss_results,gen_loss=gen_loss_results) 

In [None]:
# Read the snapshots and loss histories back from the archive written by the
# training cell.
Reload_LossGenX = np.load(sub_path+'/Loss_GenX84.npz')
Rel_gen_x_84, Rel_disc_loss, Rel_gen_loss = (
    Reload_LossGenX[key] for key in ('gen_x_84', 'disc_loss', 'gen_loss')
)
print(Rel_gen_x_84.shape, np.shape(Rel_disc_loss), np.shape(Rel_gen_loss))

In [None]:
# Continue with the reloaded arrays, so the cells below also work when training was skipped.
gen_x, disc_loss_results, gen_loss_results = Rel_gen_x_84, Rel_disc_loss, Rel_gen_loss

In [None]:
# Training-loss curves: one point per epoch for each network.
plt.figure(figsize=(7,5))
plt.plot(disc_loss_results, label='discriminator', c='blue')
plt.plot(gen_loss_results, label='generator', c='red')
plt.legend(loc=1, fontsize=14, bbox_to_anchor=(1.01,1.01))

# Axis labels.
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Loss', fontsize=16)

# Fixed y-range with hand-picked ticks.
plt.ylim(-6.0,10.0)
plt.yticks([-6,-2,0,2,6,10], ['-6', '-2', '0', '2', '6', '10'], fontsize=16)

# One x tick every 50K epochs.
plt.xticks(np.arange(0, 350000, 50000),
           ['0', '50K','100K','150K', '200K', '250K', '300K'],
           fontsize=16)
plt.plot()

In [None]:
# Plot tSNE
# Snapshots were stored every 500 epochs, so snapshot k (0-based) holds the
# samples generated after epoch (k + 1) * 500. The indices below therefore
# use epoch // 500 - 1 (previously 100/200/400 were used, which are one
# snapshot past the epochs named in the panel titles).
initx = np.concatenate((rld, gen_x[50000 // 500 - 1]), axis=0)    # 50K
med1x = np.concatenate((rld, gen_x[100000 // 500 - 1]), axis=0)   # 100K
med2x = np.concatenate((rld, gen_x[200000 // 500 - 1]), axis=0)   # 200K
# Last stored snapshot: the extra one drawn from the final model after training.
endx = np.concatenate((rld, gen_x[-1]), axis=0)                   # 300K
print (initx.shape, med1x.shape, med2x.shape, endx.shape)

# Each stack (real rows first, generated rows after) is embedded separately.
tsne = TSNE()
tsne_init = tsne.fit_transform(initx)
tsne_med1 = tsne.fit_transform(med1x)
tsne_med2 = tsne.fit_transform(med2x)
tsne_end = tsne.fit_transform(endx)
print (tsne_init.shape, tsne_med1.shape, tsne_med2.shape, tsne_end.shape)
print (len(rld), len(gen_x[0]))

n_rld = len(rld)
# Number of generated rows inside ONE snapshot. `len(gen_x)` (used before)
# is the number of snapshots and only worked because slicing past the end
# of an array is silently truncated.
n_gen = len(gen_x[0])

In [None]:
# Four side-by-side t-SNE panels. Red: all rows of `rld`; blue: generated
# rows; gold: the raw sample rows of the four conditions.
plt.figure(figsize=(16,3))

for panel_no, (panel_title, embedding) in enumerate(
        [("50K epoch", tsne_init), ("100K epoch", tsne_med1),
         ("200K epoch", tsne_med2), ("300K epoch", tsne_end)], start=1):
    plt.subplot(1, 4, panel_no)
    plt.title(panel_title)

    plt.scatter(embedding[:n_rld,0], embedding[:n_rld,1], alpha=0.5, c='red')
    plt.scatter(embedding[n_rld:n_rld+n_gen,0], embedding[n_rld:n_rld+n_gen,1], alpha=0.5, c='blue')
    for raw_rows in (WT3M, AD3M, WT6M, AD6M):
        plt.scatter(embedding[raw_rows,0], embedding[raw_rows,1], alpha=0.7, c='gold')

    # The shared legend is attached to the last panel only.
    if panel_no == 4:
        plt.legend(['Training','Generated', 'Original Real'], ncol=3, bbox_to_anchor=(-.22, -.2), fontsize=14)

    plt.xticks([])
    plt.yticks([])

In [None]:
# Draw `n_noise` latent vectors and generate one sample from each.
n_noise = 10000
  
Fake10000z = get_noise(n_noise)
Fake10000 = generator(Fake10000z, training=False)
# Prints the generator summary, then the return value of `summary()`.
print(generator.summary())

In [None]:
#if is_training :
# Generate one matched sample for every row of `rld`:
#   1. correlate the row with each of the `n_noise` generated samples,
#   2. average the latent vectors of the 10 best-correlated ones,
#   3. run that averaged latent vector through the generator,
#   4. record the Pearson correlation between the row and the result.
fake_real_corr = np.zeros((n_rld,))
Fakeallz, Fakeallx = np.zeros((n_rld,Fake10000z.shape[1])), np.zeros((n_rld,Fake10000.shape[1]))
print (fake_real_corr.shape)

genaugx = []
for i in range(n_rld):
    scorr_ary = np.zeros((n_noise,))
    tmp1 = rld[i]
    # NOTE(review): n_rld * n_noise individual pearsonr calls — expect this cell to be slow.
    for j in range(n_noise):
        tmp2 = Fake10000[j]
        scorr_ary[j] = stats.pearsonr(tmp1, tmp2)[0]
    # Positions of the 10 highest correlations (argsort is ascending).
    smaxidx = np.argsort(scorr_ary)[-10:]

    Fakeallz[i] = np.average(tf.gather(Fake10000z, smaxidx), axis=0)
    # The literal 100 is the latent width; it matches `noise_dim` as defined above.
    Fakeallx[i] = generator(np.reshape(Fakeallz[i], (1, 100)), training=False)
    fake_real_corr[i] = stats.pearsonr(tmp1, Fakeallx[i])[0]
    genaugx.append(Fakeallx[i])
# `genaugx` collects the same rows as `Fakeallx`, reshaped to the shape of `rld`.
genaugx = np.array(genaugx)
genaugx = genaugx.reshape((rld.shape))
print(genaugx.shape)
# Save the values so later cells can reload them.
np.savez(sub_path+'/FakeAll_atEnd.npz', fake_real_corr=fake_real_corr,Fakeallx=Fakeallx,genaugx=genaugx) 

In [None]:
# Read the matched generated samples back from the archive written above.
Reload_FakeA_atE = np.load(sub_path+'/FakeAll_atEnd.npz')
Rel_fake_real_corr, Rel_Fakeallx, Rel_genaugx = (
    Reload_FakeA_atE[key] for key in ('fake_real_corr', 'Fakeallx', 'genaugx')
)
print(Rel_fake_real_corr.shape, np.shape(Rel_Fakeallx), np.shape(Rel_genaugx))

In [None]:
# Continue with the reloaded arrays in place of the freshly computed ones.
fake_real_corr, Fakeallx, genaugx = Rel_fake_real_corr, Rel_Fakeallx, Rel_genaugx

In [None]:
# Worst case: the row whose matched generated sample correlates least with it.
real_min = np.argmin(fake_real_corr)
print(fake_real_corr.min(axis=0))
plt.scatter(Fakeallx[real_min], rld[real_min])
plt.plot([0,1], color='red')   # reference diagonal from (0, 0) to (1, 1)
plt.plot()

In [None]:
# Best case: the row whose matched generated sample correlates most with it.
real_max = np.argmax(fake_real_corr)
print(fake_real_corr.max(axis=0))
plt.scatter(Fakeallx[real_max], rld[real_max])
plt.plot([0,1], color='red')   # reference diagonal from (0, 0) to (1, 1)
plt.plot()

In [None]:
def plot_hist_re_rld(real, fake, bins):
    """Draw side-by-side value histograms of real and generated data and save them.

    Parameters
    ----------
    real, fake : numpy.ndarray
        Arrays whose flattened values are histogrammed (left: `real`,
        right: `fake`, drawn in red).
    bins : int or sequence
        Passed straight to `plt.hist`.

    The figure is written to 'publish_data/histogram_RLD.jpeg'. The output
    directory is created when it does not exist yet; previously `savefig`
    failed on a fresh checkout.
    """
    plt.figure(figsize=(20,7))
    plt.suptitle('Histogram of rescaled RLD', fontsize=36, y=1.03)

    panels = ((121, real, 'Real', {}),
              (122, fake, 'Generated', {'color': 'red'}))
    for position, values, panel_title, hist_style in panels:
        plt.subplot(position)
        plt.hist(values.flatten(), bins=bins, **hist_style)
        plt.xlim(-0.1, 1.1)
        plt.ylim(0, 10000)
        plt.xticks(fontsize=34)
        plt.yticks([])
        plt.xlabel('Rescaled RLD', fontsize=35)
        # Only the left panel carries the y-axis label.
        if position == 121:
            plt.ylabel('Counts', fontsize=35)
        plt.title(panel_title, fontsize=35)
    plt.subplots_adjust(wspace=0.05)

    publish_save_dir = 'publish_data/'
    file_name = 'histogram_RLD.jpeg'

    os.makedirs(publish_save_dir, exist_ok=True)
    plt.savefig(publish_save_dir+file_name, bbox_inches='tight',  pad_inches=0, dpi=300)

In [None]:
# Bin count that would give a bin width of 0.005 over the value range of `rld`.
# NOTE(review): `numBin` is not passed on; the call below uses a fixed 3000 bins — confirm which is intended.
numBin = int((np.max(rld.flatten())-np.min(rld.flatten()))/0.005)
plot_hist_re_rld(rld, genaugx, 3000)

In [None]:
# Validate the network by comparing the distribution of real-vs-real
# correlations with that of real-vs-generated correlations.
#
# Both matrices hold Pearson coefficients between rows. They are computed
# with one vectorised np.corrcoef call each instead of calling
# stats.pearsonr once per pair inside nested Python loops.

# real_corr[i, j] = corr(rld[i], rld[j])
real_corr = np.corrcoef(rld)

# np.corrcoef(x, y) stacks the rows of x on top of the rows of y, so the
# upper-right block of the result is corr(rld[i], genaugx[j]).
aug_corr = np.corrcoef(rld, genaugx)[:rld.shape[0], rld.shape[0]:]

In [None]:
def plot_hist_corr(real, fake, bins=100):
    """Draw side-by-side histograms of two sets of correlation coefficients.

    Left panel ('Real vs. Real') shows the flattened values of `real`; right
    panel ('Real vs. Generated') shows those of `fake` in red. Both x-axes
    span [-1, 1]. `bins` is forwarded to `plt.hist`.
    """
    plt.figure(figsize=(20,7))
    plt.suptitle('Histogram of correlation coefficient', fontsize=36, y=1.03)

    layout = ((1, real, 'Real vs. Real', {}),
              (2, fake, 'Real vs. Generated', {'color': 'red'}))
    for slot, coefficients, panel_title, hist_style in layout:
        plt.subplot(1, 2, slot)
        if slot == 2:
            plt.subplots_adjust(wspace=0.12)
        plt.hist(coefficients.flatten(), bins=bins, **hist_style)
        plt.xlim(-1,1)
        plt.xticks([-1, -0.5, 0.0, 0.5,  1], fontsize=24)
        plt.yticks([])
        plt.xlabel('Correlation', fontsize=35)
        # Only the left panel carries the y-axis label.
        if slot == 1:
            plt.ylabel('Counts', fontsize=35)
        plt.title(panel_title, fontsize=35)
    plt.plot()

In [None]:
# Compare the real-vs-real and real-vs-generated correlation distributions.
plot_hist_corr(real_corr, aug_corr)

In [None]:
# Label indexing (same layout as in the preprocessing section)
# Each condition occupies 95 consecutive rows; the five listed offsets are
# the raw samples inside one such run.
WT3M = np.array([0,37,65,84,94])
AD3M = WT3M + 95
WT6M = WT3M + 2*95
AD6M = WT3M + 3*95

Aug_WT3M = np.arange(95)
Aug_AD3M = Aug_WT3M + 95
Aug_WT6M = Aug_WT3M + 2*95
Aug_AD6M = Aug_WT3M + 3*95

# Rows of each condition that are NOT raw samples, in ascending order.
OnlyAug_WT3M = list(np.setdiff1d(Aug_WT3M, WT3M))
OnlyAug_WT6M = list(np.setdiff1d(Aug_WT6M, WT6M))
OnlyAug_AD3M = list(np.setdiff1d(Aug_AD3M, AD3M))
OnlyAug_AD6M = list(np.setdiff1d(Aug_AD6M, AD6M))

# Rows of `rld` belonging to the raw samples of each condition ...
x_WT3M, x_WT6M, x_AD3M, x_AD6M = (
    rld[rows, :] for rows in (WT3M, WT6M, AD3M, AD6M)
)
# ... and to the blended-only rows of each condition.
x_OnlyAug_WT3M, x_OnlyAug_WT6M, x_OnlyAug_AD3M, x_OnlyAug_AD6M = (
    rld[rows, :] for rows in (OnlyAug_WT3M, OnlyAug_WT6M, OnlyAug_AD3M, OnlyAug_AD6M)
)

# Stack the conditions in the order WT3M, WT6M, AD3M, AD6M.
x_real20 = np.concatenate([x_WT3M, x_WT6M, x_AD3M, x_AD6M], axis=0)
x_OnlyAug360 = np.concatenate(
    [x_OnlyAug_WT3M, x_OnlyAug_WT6M, x_OnlyAug_AD3M, x_OnlyAug_AD6M], axis=0)
print(x_real20.shape, x_OnlyAug360.shape)

In [None]:
# Draw a fresh set of `n_noise` latent vectors and generate one sample from each.
n_noise = 10000
  
Fake10000z = get_noise(n_noise)
Fake10000 = generator(Fake10000z, training=False)
print(Fake10000z.shape, Fake10000.shape)

In [None]:
from scipy import stats
# Best-match search for the raw samples in `x_real20`, using the same recipe
# as the all-rows search: correlate with every generated sample, average the
# latent vectors of the top 10, regenerate, and record the correlation.
n_real = 20
# NOTE(review): this rebinds `fake_real_corr`, which earlier cells use for the all-rows result.
fake_real_corr = np.zeros((n_real,))
Fake20z, Fake20x = np.zeros((n_real,Fake10000z.shape[1])), np.zeros((n_real,Fake10000.shape[1]))


# NOTE(review): the loop runs for i = 0 only, so only the first row of the
# arrays above is filled — confirm whether range(n_real) was intended.
for i in range(1):
    scorr_ary = np.zeros((n_noise,))
    tmp1 = x_real20[i]
    for j in range(n_noise):
        tmp2 = Fake10000[j]
        scorr_ary[j] = stats.pearsonr(tmp1, tmp2)[0]
    # Positions of the 10 highest correlations (argsort is ascending).
    smaxidx = np.argsort(scorr_ary)[-10:]
    Fake20z[i]=np.average(tf.gather(Fake10000z, smaxidx), axis=0)
    Fake20x[i]=generator(np.reshape(Fake20z[i],(1,100)), training=False)
    fake_real_corr[i] = stats.pearsonr(tmp1, Fake20x[i])[0]


In [None]:
# Epoch numbers 500, 1000, ..., 300000, matching the 500-epoch checkpoint interval of the training loop.
epochs = np.arange(500, 300001, 500)
# Raw-sample row indices and the matching rows of `rld`, in the same condition order.
Pheno_index = [WT3M,WT6M,AD3M,AD6M]
Pheno_rld = [x_WT3M,x_WT6M,x_AD3M,x_AD6M]

In [None]:
# For every saved checkpoint, repeat the best-match search for each raw
# sample of each condition and collect latent vectors, generated samples and
# correlations.
if is_training :
    AllE_Fake20z, AllE_Fake20x, AllE_Fake20_corr = [], [], []
    for ep in range(len(epochs)):
        # File names follow the pattern used when saving: the training loop's
        # prefix 'cp-<epoch>.ckpt' followed by '-<n>', where n = ep + 1 here.
        ckpath = checkpoint_dir + "/cp-" + str(epochs[ep]) + '.ckpt-' + str(ep+1)
        print("restoring from " + ckpath)
        checkpoint.restore(ckpath)


        # Fresh latent vectors and their generated samples for this checkpoint.
        noise_tmp = get_noise(n_noise)
        gen_tmp = generator(noise_tmp, training=False)
        print(noise_tmp.shape, gen_tmp.shape)
        OneE_Fake20z, OneE_Fake20x, OneE_Fake20_corr = [], [], []
        for pheno in range (len(Pheno_index)):
            tmpz_rld, tmpx_rld, tmp_corr = [], [], []
            for i in range (len(Pheno_index[pheno])):
                scorr_arr = np.zeros((n_noise,))
                pheno1_rld = Pheno_rld[pheno][i]
                for j in range (n_noise):
                    scorr_arr[j] = stats.pearsonr(pheno1_rld, gen_tmp[j])[0]
                # Average the latent vectors of the 10 best-correlated samples
                # and regenerate from that average.
                smaxidx = np.argsort(scorr_arr)[-10:]
                tmpz_avg = np.average(tf.gather(noise_tmp, smaxidx), axis=0)
                tmpx_avg = generator(np.reshape(tmpz_avg, (1, 100)), training=False)
                tmpz_rld.append(tmpz_avg)
                tmpx_rld.append(tmpx_avg)
                # tmpx_avg has a leading batch axis of length 1, hence [0].
                scorr_avg = stats.pearsonr(pheno1_rld, tmpx_avg[0])[0]
                tmp_corr.append(scorr_avg)

            OneE_Fake20z.append(tmpz_rld)
            OneE_Fake20x.append(tmpx_rld)
            OneE_Fake20_corr.append(tmp_corr)
        AllE_Fake20z.append(OneE_Fake20z)
        AllE_Fake20x.append(OneE_Fake20x)
        AllE_Fake20_corr.append(OneE_Fake20_corr)

    # Leading axes after conversion: checkpoint, condition, sample.
    AllE_Fake20z, AllE_Fake20x, AllE_Fake20_corr = np.array(AllE_Fake20z), np.array(AllE_Fake20x), np.array(AllE_Fake20_corr)
    print(AllE_Fake20z.shape, AllE_Fake20x.shape, AllE_Fake20_corr.shape)
    # Save the values so later cells can reload them.
    np.savez(sub_path+'/AllE_Fake20.npz', AllE_Fake20z=AllE_Fake20z,AllE_Fake20x=AllE_Fake20x,
             AllE_Fake20_corr=AllE_Fake20_corr) 

In [None]:
# Read the per-checkpoint evaluation back from the archive written above.
Reload = np.load(sub_path+'/AllE_Fake20.npz')
ReAllE_Fake20z, ReAllE_Fake20x, ReAllE_Fake20_corr = (
    Reload[key] for key in ('AllE_Fake20z', 'AllE_Fake20x', 'AllE_Fake20_corr')
)
print(ReAllE_Fake20z.shape, ReAllE_Fake20x.shape, ReAllE_Fake20_corr.shape)

In [None]:
# Correlation between each raw sample and its matched generated sample as a
# function of the checkpoint: one panel per condition plus one panel with
# the average over all conditions and samples.
suptitle=['[WT3M]','[WT6M]','[AD3M]','[AD6M]']
plt.figure(figsize=(20,28))

for i, panel_title in enumerate(suptitle + ['[Average]']):
    plt.subplot(5,1,i+1)
    plt.subplots_adjust(hspace=0.35)
    plt.title(panel_title, fontsize=20, loc='left', fontweight= 'bold')
    plt.ylim(0.7,1.0)
    plt.xlim(0,600)
    plt.yticks(fontweight= 'bold',fontsize=24)
    # One checkpoint per 500 epochs, so tick position 50 corresponds to 25K epochs.
    plt.xticks([0,50,100,150,200,250,300,350,400,450,500,550,600],
               ['0','25K','50K','75K','100K','125K','150K','175K','200K','225K','250K','275K','300K'],
               fontweight= 'bold',fontsize=24)
    plt.grid(alpha=0.3)

    if i < len(suptitle):
        # One curve per sample of the condition. The sample count is read
        # from the array; the previous fixed range(4) left out the last of
        # the five samples per condition.
        for j in range(ReAllE_Fake20_corr.shape[2]):
            plt.plot(ReAllE_Fake20_corr[:,i,j], linewidth=3.0)
    else:
        plt.plot(np.average(np.average(ReAllE_Fake20_corr,axis=2),axis=1), linewidth=3.0)


In [None]:
# Interpolation setting
# Conditions in the order WT3M, WT6M, AD3M, AD6M, and the four
# (start, end) pairs to interpolate between, named in `union_condition`.
cond_list = [Aug_WT3M,Aug_WT6M,Aug_AD3M,Aug_AD6M]
Latent_inlist = [[cond_list[begin], cond_list[finish]]
                 for begin, finish in ((0, 1), (2, 3), (0, 2), (1, 3))]

union_condition = ['WT3MtoWT6M','AD3MtoAD6M','WT3MtoAD3M','WT6MtoAD6M']

# Window of checkpoints used for the interpolation: start_ep .. start_ep + 50000
# in steps of 500 epochs.
start_ep = 225000
ep_ary = np.arange(0,50500,500)+start_ep
# Checkpoint counter of the first checkpoint in the window (one per 500 epochs).
start_ckpt = int(start_ep/500)
# Tag used in output file names, e.g. '225K275K'. It is built from the epoch
# count in thousands rather than from the first three characters of the
# number, so it stays correct for epoch numbers that are not six digits long.
LatInt_epoch_txt = str(start_ep // 1000) + 'K' + str((start_ep + 50000) // 1000) + 'K'
print(LatInt_epoch_txt)

In [None]:
# Interpolation processing
# For every condition pair and every checkpoint in the window:
#   1. find, for each augmented row, the latent vectors that best reproduce
#      its start-condition and its end-condition sample,
#   2. take the mean latent displacement from start to end,
#   3. walk from each start latent vector along that displacement in 101
#      steps and record the generated samples.
if is_interpolation :
    from multiprocessing import Pool

    first_start = time.time()
    # Worker count: twice the logical CPU count minus 10, but never below one
    # (the unclamped formula is <= 0 on machines with five cores or fewer, and
    # os.cpu_count() may return None).
    num_cores = max(1, (os.cpu_count() or 1)*2 - 10)

    inlist = Latent_inlist
    def work_func(j):
        """Correlate generated sample `j` with the current start and end rows.

        Reads the notebook globals `s_`, `e_`, `s` and `tmp_genx_as_array`;
        this works because the worker pool is created after they are set.
        """
        scorr_ary = stats.pearsonr(rld[s_[s]], tmp_genx_as_array[j])[0]
        ecorr_ary = stats.pearsonr(rld[e_[s]], tmp_genx_as_array[j])[0]

        return [scorr_ary,ecorr_ary]

    for i in range(len(inlist)):
        LatInt_z, LatInt_avgz, LatInt_genx  = [], [], []
        print("Start " + union_condition[i])

        for ep in range(len(ep_ary)):
            ckpath = checkpoint_dir + "/cp-" + str(ep_ary[ep]) + '.ckpt-' + str(ep+start_ckpt)
            print("restoring from " + ckpath)
            checkpoint.restore(ckpath)

            tmpz = get_noise(n_noise)
            tmp_genx = generator(tmpz, training=False)
            # Plain NumPy copy of the generated samples, so the worker
            # processes do not have to touch TensorFlow objects.
            tmp_genx_as_array=np.array(tmp_genx,np.float32)
            gens, gene = [], []

            s_, e_ = inlist[i][0], inlist[i][1]

            tmps, tmpe = [], []
            for s in range(len(s_)):
                # Parallel correlation of row s against every generated sample.
                # A new pool is created per row so that the workers see the
                # current values of the globals read by work_func.
                if __name__ == "__main__":
                    with Pool(num_cores) as pool:
                        work_result = np.array(pool.map(work_func, range(n_noise)))

                # Latent vectors of the 10 best matches, averaged.
                smaxidx = np.argsort(work_result.T[0])[-10:]
                emaxidx = np.argsort(work_result.T[1])[-10:]
                tmps.append(np.average(tf.gather(tmpz, smaxidx),axis=0))
                tmpe.append(np.average(tf.gather(tmpz, emaxidx),axis=0))
            gens.append(tmps)
            gene.append(tmpe)

            gens, gene = np.array(gens), np.array(gene)

            # Mean displacement from the start latents to the end latents.
            delta_z = np.average(gene,axis=1)-np.average(gens,axis=1)
            inter_ = np.linspace(0., 1., num=101)

            LatInt_z.append([gens, gene])
            LatInt_avgz.append(delta_z)

            # Generate samples along the interpolated latent path.
            genx = []
            for ag in range(gens.shape[0]):
                tmpgenx = []
                for aug_sp in range(gens.shape[1]): # augmented samples
                    ttmpgenx = []
                    for z_ in inter_:
                        # Separate name so the noise tensor `tmpz` is not clobbered.
                        z_interp = gens[ag][aug_sp] + (delta_z[ag]*z_)
                        tmpx = generator(z_interp.reshape((1, -1)), training=False)
                        ttmpgenx.append(tmpx)
                    tmpgenx.append(ttmpgenx)
                genx.append(tmpgenx)
            genx = np.array(genx)
            genx = genx.reshape((len(delta_z), np.shape(gens)[1],len(inter_), p))

            LatInt_genx.append(genx)

        LatInt_z, LatInt_avgz = np.array(LatInt_z), np.array(LatInt_avgz)
        LatInt_genx = np.array(LatInt_genx)

        # One archive per condition pair.
        file_name = '/LatInt_'+LatInt_epoch_txt+'_'+union_condition[i]+'.npz'
        np.savez(sub_path+file_name, LatInt_genx=LatInt_genx)

    # Report the wall-clock duration as HH:MM:SS.ss.
    elapsed_time = time.time() - first_start
    hours, rem = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(rem, 60)
    print ("Elapsed time: {:0>2}:{:0>2}:{:05.2f}".format(int(hours),int(minutes),seconds))

In [None]:
# Load the interpolation archive of every condition pair and reduce it to a
# mean curve and a spread curve.
genx_A, genx_S = [], []
for i, pair_name in enumerate(union_condition):
    file_name = '/LatInt_'+LatInt_epoch_txt+'_'+pair_name+'.npz'
    Reload_LatInt = np.load(sub_path+file_name)
    Rel_LatInt_genx = Reload_LatInt['LatInt_genx']

    # Average over axis 2 first, then take mean and standard deviation over
    # axis 0 of that result.
    averaged_over_axis2 = np.average(Rel_LatInt_genx, axis=2)
    genx_A_ind = np.average(averaged_over_axis2, axis=0)
    genx_S_ind = np.std(averaged_over_axis2, axis=0)

    genx_A.append(genx_A_ind[0])
    genx_S.append(genx_S_ind[0])

genx_A = np.array(genx_A)
genx_S = np.array(genx_S)
print(genx_A.shape, genx_S.shape)

# Raw-sample row indices keyed by condition name.
smaple_list = {'WT3M' : WT3M, 'WT6M' : WT6M, 'AD3M' : AD3M, 'AD6M' : AD6M}

# 4. Four-way transition curves

In [None]:
# Chain the four mean interpolation curves into one closed path per gene.
# `union_condition` order is WT3MtoWT6M, AD3MtoAD6M, WT3MtoAD3M, WT6MtoAD6M,
# so the segments below run AD6M -> WT6M -> WT3M -> AD3M -> AD6M; the first
# two are flipped along the step axis to reverse their direction.
genx_A_reserved = np.array([np.flip(genx_A[3],axis=0),np.flip(genx_A[0],axis=0),genx_A[2],genx_A[1]])
# Move the gene axis to the front ...
genx_A_sw1 = genx_A_reserved.swapaxes(1,2)
genx_A_sw2 = genx_A_sw1.swapaxes(0,1)
# ... and concatenate the segments of every gene into one row. The row count
# is taken from the array instead of the hard-coded 3767.
genx_A_4way = genx_A_sw2.reshape(genx_A_sw2.shape[0], -1)
print(union_condition)
print(genx_A_4way)

In [None]:
# APC (Affinity Propagation Clustering) of the per-gene four-way curves.
cluster = AffinityPropagation(random_state=0).fit(genx_A_4way)
# One row per clustered curve; column 'mapping' holds its cluster label.
cluster_id = pd.DataFrame(cluster.labels_, columns=['mapping'])

In [None]:
# Pattern setting: cluster labels grouped by hand into patterns.
# One inner list per pattern; the driver cell names them P1..P8 and 'Others'.
pattern_marking_data = [
    [3, 10, 15, 16, 18, 22, 24, 28, 30, 32],
    [4, 7, 8, 9, 11, 13, 14, 17, 20, 21, 34],
    [2, 5, 6, 19, 23, 27],
    [50, 51, 42, 45, 48],
    [54, 55],
    [35, 37, 41, 43, 44],
    [12, 25, 29, 39],
    [26, 33, 36],
    [0, 1, 31, 38, 40, 46, 47, 49, 52, 53],
]

In [None]:
# Gene mapping index data: one row per entry of `genelist`, numbered consecutively.
gene_index_data = pd.DataFrame(data=genelist,columns=['gene_symbol'])
gene_index_data["gene_id"]=np.arange(len(gene_index_data))
gene_DKres_data = pd.read_csv('Tau_Union_res_LFC03_remove_all_case_3767genes.csv', delimiter=',')

# Rename the 'X' column of the results table to 'ensemble_id'.
gene_DKres_data = gene_DKres_data.rename(columns={'X':'ensemble_id'})

# NOTE(review): the values looked up in this loop are not stored anywhere;
# after it finishes only the last row's values remain in the `get_*` names —
# confirm whether an assignment into `gene_index_data` is missing.
for i in range(len(gene_DKres_data)):
    get_compare_data = gene_DKres_data.iloc[i]
    get_comparison_value = get_compare_data['log2FoldChange']
    get_compare_name = get_compare_data['gene_name']
    get_mapping_id = gene_index_data[gene_index_data['gene_symbol']==get_compare_name].index  
    
# Display the index table as the cell output.
gene_index_data

In [None]:
def marking_mapping_plot(genx_A_reshape_cluster, cluster_id, cluster_centers_indices, mapping_pattern, title,
                         marking_pattern_fl, index):
    """Plot every cluster of one pattern as a panel of transition curves.

    Each panel shows all member curves of a cluster in gray and the cluster's
    exemplar curve in red. Panels are laid out 11 per row; a pattern with
    fewer than 11 clusters is padded with empty, frameless panels so panel
    sizes stay comparable across patterns.

    Parameters
    ----------
    genx_A_reshape_cluster : numpy.ndarray
        One curve per row; rows are addressed by the index of `cluster_id`.
    cluster_id : pandas.DataFrame
        Frame with a 'mapping' column holding the cluster label of each row.
    cluster_centers_indices : sequence
        For each cluster label, the row index of its exemplar curve.
    mapping_pattern : list
        Cluster labels belonging to the pattern being plotted.
    title : str
        Pattern name; printed before the figure is drawn.
    marking_pattern_fl : list
        Flat list of all cluster labels; a cluster's position in it (plus
        one) is used as its 'APC' number in the panel title.
    index : int
        Position of the pattern; kept for interface compatibility, unused.
    """
    columns = 11
    n_clusters_ = len(mapping_pattern)
    # Always draw at least one full row of panels.
    n_panels = max(n_clusters_, columns)
    n_rows = (n_panels // columns) + 1

    # A fresh figure is opened directly. The former plt.clf() before it acted
    # on whatever figure was current (creating an empty one if none existed).
    plt.figure(figsize=(100, 15))
    print(title)

    for i in range(n_panels):
        plt.subplot(n_rows, columns, i+1)
        plt.subplots_adjust(wspace=0.2, hspace=0.2)
        axes = plt.gca()

        if i < n_clusters_:
            label = mapping_pattern[i]
            num_pattern = marking_pattern_fl.index(label)+1
            cluster_mapping_index = np.array(cluster_id[cluster_id['mapping']==label].index)
            plt.title('APC '+str(num_pattern)+'('+str(len(cluster_mapping_index))+')',
                      fontsize=40, fontweight= 'bold')

            # Member curves in gray, exemplar curve on top in red.
            for member in cluster_mapping_index:
                plt.plot(genx_A_reshape_cluster[member], c='gray', alpha=0.5)
            plt.plot(genx_A_reshape_cluster[cluster_centers_indices[label]], c='red', linewidth=3)

            plt.grid(True, axis='x',linewidth=3)

            axes.spines['right'].set_visible(False)   # remove the right border
            axes.spines['top'].set_visible(False)     # remove the top border
            axes.spines['left'].set_linewidth(4)
            axes.spines['bottom'].set_linewidth(4)

            # Tick positions mark the joints between the chained segments.
            plt.xticks([0,100,202,303,404], ('AD6M', 'WT6M','WT3M','AD3M','AD6M'),fontsize=20,
                       fontweight= 'bold')
            plt.yticks(fontsize=30)

            if i == 0:
                plt.ylabel('Rescaled RLD', fontsize=40,fontweight= 'bold')
        else:
            # Padding panel: no ticks and no visible frame.
            plt.xticks([])
            plt.yticks([])
            axes.spines['right'].set_visible(False)   # remove the right border
            axes.spines['top'].set_visible(False)     # remove the top border
            axes.spines['left'].set_linewidth(0)
            axes.spines['bottom'].set_linewidth(0)

    plt.show()

In [None]:
# Flat list of every cluster label in pattern order; a label's position in it determines its 'APC' number.
pattern_marking_data_flatten=[element for array in pattern_marking_data for element in array]
# Display names, one per entry of `pattern_marking_data`.
pattern_marking_name=['P1','P2','P3','P4','P5','P6','P7','P8','Others']

# Draw one figure per pattern.
for i in range(len(pattern_marking_data)) :
    marking_mapping_plot(genx_A_4way, cluster_id,
                         cluster.cluster_centers_indices_,pattern_marking_data[i],pattern_marking_name[i],
                         pattern_marking_data_flatten,i)