In [222]:
import os
import pandas as pd
import numpy as np
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import time
import tensorflow as tf
from tensorflow import keras
from keras import layers, models, optimizers
from keras.layers import Dense, LSTM, Embedding, LeakyReLU, BatchNormalization, Dropout, Input, ReLU
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam


### Load the dataset

In [223]:
# Load the data
data_path = 'original_data/GSE158508_normalized_counts.tsv'
data = pd.read_csv(data_path, sep='\t', index_col=0)

data_shape = data.shape[1]  # 69 columns

# Save columns names for later use
col_names = data.columns.values

print(data.shape)

(57736, 69)


### Create the models

#### The Generator

In [225]:
def build_generator(latent_dim, data_shape):
    model = Sequential()
    
    model.add(Dense(64, input_dim=latent_dim))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(64, input_dim=latent_dim))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(128))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))
    
    model.add(Dense(128))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(256))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(256))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(512))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))
    
    model.add(Dense(512))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(1024))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(1024))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(2048))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(4096))
    model.add(LeakyReLU())
    model.add(BatchNormalization(momentum=0.8))
    
    model.add(Dense(data_shape, activation='relu'))
    
    noise = Input(shape=(latent_dim,))
    generated_data = model(noise)
    
    return Model(noise, generated_data)


#### The Discriminator

In [227]:
def build_discriminator(data_shape):
    model = Sequential()
    
    model.add(Dense(1024, input_dim=data_shape))
    model.add(ReLU())
    model.add(Dropout(0.4))
    
    model.add(Dense(1024))
    model.add(ReLU())
    model.add(Dropout(0.4))

    model.add(Dense(512))
    model.add(ReLU())
    model.add(Dropout(0.4))

    model.add(Dense(256))
    model.add(ReLU())
    model.add(Dropout(0.4))

    model.add(Dense(128))
    model.add(ReLU())
    model.add(Dropout(0.4))
    
    model.add(Dense(1, activation='sigmoid'))
    
    data = Input(shape=(data_shape,))
    validity = model(data)
    
    return Model(data, validity)


### Compile the models

In [228]:
# Wymiary przestrzeni ukrytej
latent_dim = 64

# Budowa i kompilacja dyskryminatora
discriminator = build_discriminator(data_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# Budowa generatora
generator = build_generator(latent_dim, data_shape)

# Generator bierze szum jako wejście i generuje dane
z = Input(shape=(latent_dim,))
generated_data = generator(z)

# Tylko generator jest trenowany
discriminator.trainable = False

# Dyskryminator bierze wygenerowane dane jako wejście i określa ich prawdziwość
validity = discriminator(generated_data)

# Połączony model (stacked generator and discriminator)
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))


### Save checkpoints

In [124]:
# checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# checkpoint = tf.train.Checkpoint(generator=generator,
#                                  discriminator=discriminator)

### Define the training loop

In [229]:
def scale_output(generated_data):
    min_val = 0.873159932419581
    return generated_data + min_val

In [230]:
def train(generator, discriminator, combined, data, latent_dim, epochs, batch_size=128, save_interval=50):
    # Ładowanie i skalowanie danych
    X_train = data.values

    # Pętla po epokach
    for epoch in range(epochs):
        start = time.time()
        # ---------------------
        #  Trenowanie dyskryminatora
        # ---------------------

        # Wybieranie losowych próbek
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_data = X_train[idx]

        # Generowanie nowego szumu
        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        # Generowanie nowych danych
        generated_data = generator.predict(noise)
        generated_data = scale_output(generated_data)

        # Trenowanie dyskryminatora
        d_loss_real = discriminator.train_on_batch(real_data, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(generated_data, np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Trenowanie generatora
        # ---------------------

        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        # Chcemy, aby dyskryminator uznał wygenerowane dane za prawdziwe
        valid_y = np.array([1] * batch_size)

        # Trenowanie generatora
        g_loss = combined.train_on_batch(noise, valid_y)

        # Zapisywanie postępów
        if epoch % save_interval == 0:
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
            #checkpoint.save(file_prefix = checkpoint_prefix)


### Train the model

In [221]:
# Trenowanie modelu
train(generator, discriminator, combined, data, latent_dim, epochs=3000, batch_size=128)

0 [D loss: 0.852849, acc.: 26.93%] [G loss: 0.629460]
Time for epoch 1 is 4.526929616928101 sec
50 [D loss: 0.697878, acc.: 49.32%] [G loss: 0.721198]
Time for epoch 51 is 0.8264892101287842 sec
100 [D loss: 0.695164, acc.: 50.98%] [G loss: 0.698274]
Time for epoch 101 is 0.8019261360168457 sec
150 [D loss: 0.692809, acc.: 52.12%] [G loss: 0.693293]
Time for epoch 151 is 0.797844409942627 sec
200 [D loss: 0.696308, acc.: 51.81%] [G loss: 0.694075]
Time for epoch 201 is 1.1123497486114502 sec
250 [D loss: 0.692505, acc.: 53.22%] [G loss: 0.694788]
Time for epoch 251 is 1.3819751739501953 sec
300 [D loss: 0.694990, acc.: 51.39%] [G loss: 0.692912]
Time for epoch 301 is 1.0467572212219238 sec
350 [D loss: 0.693592, acc.: 50.95%] [G loss: 0.696139]
Time for epoch 351 is 1.0929861068725586 sec
400 [D loss: 0.694290, acc.: 47.71%] [G loss: 0.693498]
Time for epoch 401 is 1.134448528289795 sec
450 [D loss: 0.693870, acc.: 48.34%] [G loss: 0.690891]
Time for epoch 451 is 1.0806210041046143 sec

KeyboardInterrupt: 

### Generate after the final epoch

In [208]:
def generate_data(generator, n_samples):
    noise = np.random.normal(0, 1, size=(n_samples, latent_dim))
    generated_data = generator.predict(noise)
    generated_data = scale_output(generated_data)
    return generated_data

In [209]:
synthetic_data = generate_data(generator, n_samples=57736)

### Save the generated data to file

In [210]:
df = pd.DataFrame(synthetic_data)
# append column names to the data 
df.columns = col_names
df.to_csv('synthetic_data/generated_data.tsv', sep='\t', index=False, header=True)

In [211]:
data.describe()

Unnamed: 0,TR4190,TR4184,TR4193,TR4186,TR4185,TR4219,TR4268,TR4269,TR4271,TR4273,...,TR4191,TR4017,TR4329,TR4215,TR4189,TR4011,TR4044,TR4267,TR4149,TR4188
count,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,...,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0
mean,1.194206,1.338851,1.171084,1.233508,1.238979,1.281862,1.404603,1.240871,1.188165,1.385383,...,1.175039,1.33924,1.198407,1.205923,1.196342,1.547039,1.202872,1.28315,1.265179,1.345267
std,1.145634,1.292237,1.097246,1.157664,1.184453,1.203927,1.447374,1.170842,1.123312,1.437643,...,1.139818,1.238471,1.196844,1.118109,1.111701,1.591664,1.143378,1.242899,1.337912,1.370505
min,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
25%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
50%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
75%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
max,14.566485,16.289445,15.247593,14.408506,14.97066,14.526786,14.430398,14.516758,14.286098,13.278705,...,15.028946,14.327433,14.590482,14.083975,14.62942,13.412346,15.252446,14.718518,14.142212,13.527193


In [212]:
df.describe()

Unnamed: 0,TR4190,TR4184,TR4193,TR4186,TR4185,TR4219,TR4268,TR4269,TR4271,TR4273,...,TR4191,TR4017,TR4329,TR4215,TR4189,TR4011,TR4044,TR4267,TR4149,TR4188
count,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,...,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0
mean,2.051405,1.597122,0.873176,0.980953,0.897371,1.964139,1.400752,1.173359,0.877007,1.8514,...,0.87316,1.046476,1.55599,1.626892,0.87316,1.925425,1.038716,0.886786,0.873191,0.962128
std,0.775961,0.699174,0.001403,0.263535,0.103911,0.705737,0.779746,0.543532,0.025109,0.951902,...,0.00058,0.341178,0.693108,0.558726,0.000575,0.741452,0.325553,0.074302,0.002596,0.266894
min,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
25%,1.477926,0.87316,0.87316,0.87316,0.87316,1.341336,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.903075,0.87316,0.981285,0.87316,0.87316,0.87316,0.87316
50%,2.186714,1.504515,0.87316,0.87316,0.87316,2.139694,0.887716,0.87316,0.87316,1.728293,...,0.87316,0.87316,1.408911,1.693339,0.87316,2.05071,0.87316,0.87316,0.87316,0.87316
75%,2.632104,2.042601,0.87316,0.87316,0.87316,2.492109,1.748798,1.279757,0.87316,2.401282,...,0.87316,1.078574,2.005128,2.009491,0.87316,2.455536,1.046527,0.87316,0.87316,0.87316
max,4.092374,4.765678,1.014402,2.889914,2.222526,3.466793,6.070132,5.337439,1.462217,7.334809,...,0.891643,4.602888,4.616714,3.945812,0.87316,4.092665,3.023783,2.006119,1.246375,3.235507
