In [222]:
import os
import pandas as pd
import numpy as np
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import time
import tensorflow as tf
from tensorflow import keras
from keras import layers, models, optimizers
from keras.layers import Dense, LSTM, Embedding, LeakyReLU, BatchNormalization, Dropout, Input, ReLU
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam


### Load the dataset

In [223]:
# Read the tab-separated table; its first column is used as the row index
data_path = 'original_data/GSE158508_normalized_counts.tsv'
data = pd.read_csv(data_path, sep='\t', index_col=0)

# Number of columns in the table; used as the feature dimension of both networks
data_shape = len(data.columns)

# Keep the column labels so the synthetic table can reuse them later
col_names = data.columns.values

print(data.shape)

(57736, 69)


### Create the models

#### The Generator

In [225]:
def build_generator(latent_dim, data_shape):
    """Build the generator network that maps latent noise to synthetic rows.

    The network is a stack of ``Dense -> LeakyReLU -> BatchNormalization``
    blocks of non-decreasing width, followed by a ReLU-activated output layer
    with ``data_shape`` units, so every generated value is non-negative.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        data_shape: Number of features in one generated row.

    Returns:
        A Keras ``Model`` mapping noise of shape ``(batch, latent_dim)`` to
        generated data of shape ``(batch, data_shape)``.
    """
    # Width of each hidden Dense layer, in order.
    hidden_units = (64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024, 2048, 4096)

    model = Sequential()
    for position, units in enumerate(hidden_units):
        if position == 0:
            # Only the first layer declares the input size; later layers
            # infer it from the layer before them.
            model.add(Dense(units, input_dim=latent_dim))
        else:
            model.add(Dense(units))
        model.add(LeakyReLU())
        model.add(BatchNormalization(momentum=0.8))

    # ReLU keeps the raw output >= 0.
    model.add(Dense(data_shape, activation='relu'))

    noise = Input(shape=(latent_dim,))
    generated_data = model(noise)

    return Model(noise, generated_data)


#### The Discriminator

In [227]:
def build_discriminator(data_shape):
    """Build the discriminator network that scores rows as real or generated.

    The network is a stack of ``Dense -> ReLU -> Dropout(0.4)`` blocks of
    widths 1024, 1024, 512, 256 and 128, followed by a single sigmoid unit.

    Args:
        data_shape: Number of features in one input row.

    Returns:
        A Keras ``Model`` mapping ``(batch, data_shape)`` inputs to a
        ``(batch, 1)`` sigmoid output.
    """
    layer_widths = [1024, 1024, 512, 256, 128]
    dropout_rate = 0.4

    net = Sequential()

    # The first block also declares the input size.
    net.add(Dense(layer_widths[0], input_dim=data_shape))
    net.add(ReLU())
    net.add(Dropout(dropout_rate))

    for width in layer_widths[1:]:
        net.add(Dense(width))
        net.add(ReLU())
        net.add(Dropout(dropout_rate))

    net.add(Dense(1, activation='sigmoid'))

    sample = Input(shape=(data_shape,))
    return Model(sample, net(sample))


### Compile the models

In [228]:
# Size of the latent (noise) space
latent_dim = 64

# Build the discriminator and compile it for stand-alone training
discriminator = build_discriminator(data_shape)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'],
)

# Build the generator
generator = build_generator(latent_dim, data_shape)

# Freeze the discriminator before stacking so that only the generator is
# trained through the combined model.
# NOTE(review): the discriminator is still trained directly later on; this
# relies on compile() above having captured its trainable state - confirm
# this holds for the installed Keras version.
discriminator.trainable = False

# Combined model (stacked generator and discriminator):
# noise -> generator -> discriminator -> validity
z = Input(shape=(latent_dim,))
generated_data = generator(z)
validity = discriminator(generated_data)

combined = Model(z, validity)
combined.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
)


### Save checkpoints

In [124]:
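# NOTE: checkpointing is currently disabled. To enable it, uncomment this cell
# and the `checkpoint.save(...)` line inside `train`.
# checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# checkpoint = tf.train.Checkpoint(generator=generator,
#                                  discriminator=discriminator)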

### Define the training loop

In [229]:
def scale_output(generated_data, min_val=0.873159932419581):
    """Shift generated values upward by a constant offset.

    Args:
        generated_data: Array (or scalar) of values to shift.
        min_val: Constant added to every value. The default equals the
            ``min`` reported by ``data.describe()`` for the original dataset
            (0.87316), so a raw value of 0 maps onto that minimum.

    Returns:
        ``generated_data + min_val``.
    """
    return generated_data + min_val

In [230]:
def train(generator, discriminator, combined, data, latent_dim, epochs, batch_size=128, save_interval=50):
    """Alternate discriminator and generator updates for ``epochs`` iterations.

    Each iteration ("epoch" here means one batch step, not a full pass over
    the data) draws ``batch_size`` random rows from ``data``, generates the
    same number of fake rows, updates the discriminator on both, and then
    updates the generator through ``combined``.

    Args:
        generator: Model mapping noise to generated rows.
        discriminator: Compiled model scoring rows; its ``train_on_batch``
            result is indexed as ``[loss, accuracy]`` when logging.
        combined: Compiled stacked model trained on noise with "real" labels.
        data: DataFrame whose ``.values`` rows are the real samples.
        latent_dim: Length of each noise vector.
        epochs: Number of training iterations to run.
        batch_size: Number of real and of fake rows per iteration.
        save_interval: Progress is printed every ``save_interval`` iterations.

    Raises:
        ValueError: If ``data`` has no rows.
    """
    # Real samples as a NumPy array
    X_train = data.values
    if X_train.shape[0] == 0:
        raise ValueError("data must contain at least one row to sample batches from")

    # The label arrays never change, so build them once instead of per iteration.
    # Both use shape (batch_size, 1), matching the discriminator's single output.
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        start = time.time()

        # ---------------------
        #  Train the discriminator
        # ---------------------

        # Pick random rows (sampled with replacement)
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_data = X_train[idx]

        # Draw fresh noise and generate fake rows from it. predict_on_batch is
        # used because this is a single small batch inside a loop.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generated_data = generator.predict_on_batch(noise)
        # NOTE(review): the discriminator sees the shifted output here, whereas
        # `combined` (as built in this notebook) passes the generator output to
        # the discriminator without this shift - confirm this is intended.
        generated_data = scale_output(generated_data)

        d_loss_real = discriminator.train_on_batch(real_data, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_data, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train the generator
        # ---------------------

        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        # The generator is rewarded when the discriminator labels its output as real
        g_loss = combined.train_on_batch(noise, real_labels)

        # Report progress
        if epoch % save_interval == 0:
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
            #checkpoint.save(file_prefix = checkpoint_prefix)


### Train the model

In [231]:
# Train the model: 3000 iterations, each on one random batch of 128 rows
train(generator, discriminator, combined, data, latent_dim, epochs=3000, batch_size=128)

0 [D loss: 0.838802, acc.: 30.47%] [G loss: 0.708507]
Time for epoch 1 is 3.23549747467041 sec
50 [D loss: 0.682237, acc.: 51.56%] [G loss: 0.979710]
Time for epoch 51 is 0.17813944816589355 sec
100 [D loss: 0.707888, acc.: 47.66%] [G loss: 0.749582]
Time for epoch 101 is 0.1814873218536377 sec
150 [D loss: 0.702724, acc.: 50.78%] [G loss: 0.708737]
Time for epoch 151 is 0.17972993850708008 sec
200 [D loss: 0.695175, acc.: 51.56%] [G loss: 0.701854]
Time for epoch 201 is 0.17410039901733398 sec
250 [D loss: 0.697000, acc.: 51.17%] [G loss: 0.694157]
Time for epoch 251 is 0.17612695693969727 sec
300 [D loss: 0.693497, acc.: 52.73%] [G loss: 0.699870]
Time for epoch 301 is 0.26914119720458984 sec
350 [D loss: 0.694649, acc.: 52.34%] [G loss: 0.691001]
Time for epoch 351 is 0.2644312381744385 sec
400 [D loss: 0.691035, acc.: 51.95%] [G loss: 0.694297]
Time for epoch 401 is 0.32695579528808594 sec
450 [D loss: 0.687431, acc.: 55.86%] [G loss: 0.694836]
Time for epoch 451 is 0.2956519126892

### Generate after the final epoch

In [232]:
def generate_data(generator, n_samples, latent_dim=None):
    """Generate ``n_samples`` synthetic rows with the given generator.

    Args:
        generator: Keras model mapping noise to generated rows.
        n_samples: Number of rows to generate.
        latent_dim: Length of each noise vector. When omitted it is read from
            the last dimension of ``generator.input_shape``, so the function
            does not depend on a notebook-level global.

    Returns:
        The generator's predictions after ``scale_output`` has been applied.
    """
    if latent_dim is None:
        latent_dim = generator.input_shape[-1]

    noise = np.random.normal(0, 1, size=(n_samples, latent_dim))
    generated_data = generator.predict(noise)
    return scale_output(generated_data)

In [233]:
# Generate as many synthetic rows as the original dataset has, so the two
# tables can be compared directly.
synthetic_data = generate_data(generator, n_samples=data.shape[0])

### Save the generated data to file

In [234]:
# to_csv does not create missing directories, so make sure the target exists
output_path = 'synthetic_data/generated_data.tsv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Wrap the generated array in a DataFrame that reuses the original column names
df = pd.DataFrame(synthetic_data, columns=col_names)
df.to_csv(output_path, sep='\t', index=False, header=True)

### Data distribution

In [235]:
# Per-column summary statistics of the original dataset
data.describe()

Unnamed: 0,TR4190,TR4184,TR4193,TR4186,TR4185,TR4219,TR4268,TR4269,TR4271,TR4273,...,TR4191,TR4017,TR4329,TR4215,TR4189,TR4011,TR4044,TR4267,TR4149,TR4188
count,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,...,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0
mean,1.194206,1.338851,1.171084,1.233508,1.238979,1.281862,1.404603,1.240871,1.188165,1.385383,...,1.175039,1.33924,1.198407,1.205923,1.196342,1.547039,1.202872,1.28315,1.265179,1.345267
std,1.145634,1.292237,1.097246,1.157664,1.184453,1.203927,1.447374,1.170842,1.123312,1.437643,...,1.139818,1.238471,1.196844,1.118109,1.111701,1.591664,1.143378,1.242899,1.337912,1.370505
min,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
25%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
50%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
75%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
max,14.566485,16.289445,15.247593,14.408506,14.97066,14.526786,14.430398,14.516758,14.286098,13.278705,...,15.028946,14.327433,14.590482,14.083975,14.62942,13.412346,15.252446,14.718518,14.142212,13.527193


In [236]:
# Per-column summary statistics of the generated dataset, for comparison with the original above
df.describe()

Unnamed: 0,TR4190,TR4184,TR4193,TR4186,TR4185,TR4219,TR4268,TR4269,TR4271,TR4273,...,TR4191,TR4017,TR4329,TR4215,TR4189,TR4011,TR4044,TR4267,TR4149,TR4188
count,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,...,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0,57736.0
mean,1.074218,1.466389,1.293475,1.369223,1.608798,1.138521,1.082749,1.142271,1.268538,1.430429,...,1.3708,1.365978,1.314876,1.129716,1.355922,1.350276,1.146059,1.470701,1.484422,1.444118
std,0.788013,2.307459,1.576355,1.940696,2.646473,1.083133,0.956349,1.055396,1.484436,2.332082,...,1.88422,1.804007,1.765342,1.041269,1.836795,1.966122,1.107575,2.298373,2.357991,2.293154
min,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
25%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
50%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
75%,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,...,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316,0.87316
max,6.690738,19.344109,12.99997,17.029339,22.010338,9.640545,9.076724,10.537092,13.403953,19.13648,...,15.758722,13.027126,14.900488,9.546752,14.753994,17.16638,11.719974,18.790796,18.779106,23.107731
