In [1]:
import pandas as pd 
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import os, sys, time
sys.path.append("..")
from all_funcs import util
from model import Generator, Discriminator
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras import layers
from sklearn.impute import SimpleImputer
tf.keras.backend.set_floatx('float32')


  import pandas.util.testing as tm


In [21]:
# y_true = [3.]
# y_pred = [0.]

# bce = tf.keras.losses.MeanSquaredError()
# bce(y_true, y_pred).numpy()


9.0

In [2]:
from numba import cuda
os.environ['CUDA_VISIBLE_DEVICES']="0"
print(tf.test.is_gpu_available())

True


In [2]:
df=pd.read_csv("../dataset/df_noOutliner_ana.csv",index_col=0)

In [3]:
df, imp_mode, imp_mean=util.FeatureArrange(df)
sc = MinMaxScaler()
df = sc.fit_transform(df)
df.shape

(4778, 73)

In [5]:
## setting hyperparameter
latent_dim = 72
epochs = 1000
batch_size= 128
buffer_size = 6000
# save_interval = 50
n_critic = 5

In [6]:
generator = Generator()
discriminator = Discriminator()

In [7]:
gen_opt = tf.keras.optimizers.Adam(0.0001, 0.5, 0.9)
disc_opt = tf.keras.optimizers.Adam(0.0001, 0.5, 0.9)

In [8]:
training_dataset=tf.data.Dataset.from_tensor_slices(df.astype('float32'))\
    .shuffle(buffer_size).batch(batch_size, drop_remainder=True)

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [10]:
@tf.function
def train_discriminator(x, labels):

    noise = tf.random.normal([x.shape[0], latent_dim])

    with tf.GradientTape() as dis_tape:
        gen_data = generator(noise, labels, )

        dis_output, labels_fake = discriminator(gen_data, label_out=True)

        real_output, labels_real = discriminator(x, label_out=True)
        
        # calculated labels Loss
        labels_loss = cross_entropy(labels, labels_fake).numpy() +\
            cross_entropy(labels, labels_real).numpy()
        
        x_hat = util.random_weight_average(x, gen_data)
        d_hat = discriminator(x_hat, label_out=False)

        disc_loss = util.discriminator_loss(real_output, dis_output, d_hat, x_hat)+labels_loss
        
    grad_disc = dis_tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_opt.apply_gradients(zip(grad_disc, discriminator.trainable_variables))

    return disc_loss


@tf.function
def train_generator(labels):
    noise = tf.random.normal([batch_size, latent_dim])

    with tf.GradientTape() as gen_tape:
        gen_data = generator(noise, labels,)
        dis_output, judge_labels = discriminator(gen_data, label_out=True)

        gen_loss = util.generator_loss(dis_output)

        # sum all loss
        sum_loss = gen_loss+cross_entropy(labels, judge_labels).numpy()

    grad_gen = gen_tape.gradient(sum_loss, generator.trainable_variables)
    gen_opt.apply_gradients(zip(grad_gen, generator.trainable_variables))

    return gen_loss

In [11]:
for epoch in range(epochs):
    start = time.time()
    disc_loss = 0
    gen_loss = 0
    
    training_dataset=tf.data.Dataset.from_tensor_slices(df.astype('float32'))\
    .shuffle(buffer_size).batch(batch_size, drop_remainder=True)
    
    for data in training_dataset:
        disc_loss += train_discriminator(data)
    
        if disc_opt.iterations.numpy() % n_critic == 0:
            gen_loss += train_generator()
    print('Time for epoch {} is {} sec - gen_loss = {}, disc_loss = {}'.format(epoch + 1,time.time() - start, 
                                 gen_loss / batch_size, 
                                 disc_loss / (batch_size*n_critic)))

Time for epoch 1 is 2.240140199661255 sec - gen_loss = 0.02712870016694069, disc_loss = 0.33137020468711853
Time for epoch 2 is 0.18801188468933105 sec - gen_loss = 0.018824489787220955, disc_loss = 0.03301863744854927
Time for epoch 3 is 0.18401145935058594 sec - gen_loss = 0.008542127907276154, disc_loss = -0.03147453814744949
Time for epoch 4 is 0.18401122093200684 sec - gen_loss = 0.004004958085715771, disc_loss = -0.04073657840490341
Time for epoch 5 is 0.19201207160949707 sec - gen_loss = 0.0023481126409024, disc_loss = -0.04780250042676926
Time for epoch 6 is 0.19201207160949707 sec - gen_loss = 0.0010885894298553467, disc_loss = -0.04963129013776779
Time for epoch 7 is 0.19601225852966309 sec - gen_loss = 0.000538758235052228, disc_loss = -0.05157894641160965
Time for epoch 8 is 0.18801188468933105 sec - gen_loss = 0.0003601380449254066, disc_loss = -0.05277132987976074
Time for epoch 9 is 0.18401145935058594 sec - gen_loss = 0.00019111733126919717, disc_loss = -0.0517725162208

Time for epoch 78 is 0.3000190258026123 sec - gen_loss = 0.0, disc_loss = -0.0477827712893486
Time for epoch 79 is 0.29201817512512207 sec - gen_loss = 0.0, disc_loss = -0.050231099128723145
Time for epoch 80 is 0.30801939964294434 sec - gen_loss = 0.0, disc_loss = -0.04736940190196037
Time for epoch 81 is 0.3000185489654541 sec - gen_loss = 0.0, disc_loss = -0.04416896402835846
Time for epoch 82 is 0.3040192127227783 sec - gen_loss = 0.0, disc_loss = -0.048702776432037354
Time for epoch 83 is 0.2960186004638672 sec - gen_loss = 0.0, disc_loss = -0.047127362340688705
Time for epoch 84 is 0.3200199604034424 sec - gen_loss = 0.0, disc_loss = -0.046310193836688995
Time for epoch 85 is 0.3440215587615967 sec - gen_loss = 0.0, disc_loss = -0.046308163553476334
Time for epoch 86 is 0.3280203342437744 sec - gen_loss = 0.0, disc_loss = -0.04830324649810791
Time for epoch 87 is 0.33602094650268555 sec - gen_loss = 0.0, disc_loss = -0.0443916954100132
Time for epoch 88 is 0.3280203342437744 sec 

Time for epoch 164 is 0.1480093002319336 sec - gen_loss = 0.0, disc_loss = -0.049048908054828644
Time for epoch 165 is 0.1480093002319336 sec - gen_loss = 0.0, disc_loss = -0.04879320040345192
Time for epoch 166 is 0.1480093002319336 sec - gen_loss = 0.0, disc_loss = -0.04471539705991745
Time for epoch 167 is 0.1480093002319336 sec - gen_loss = 0.0, disc_loss = -0.049226921051740646
Time for epoch 168 is 0.14000868797302246 sec - gen_loss = 0.0, disc_loss = -0.050024550408124924
Time for epoch 169 is 0.14400887489318848 sec - gen_loss = 0.0, disc_loss = -0.050862155854701996
Time for epoch 170 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.04899051412940025
Time for epoch 171 is 0.1480090618133545 sec - gen_loss = 0.0, disc_loss = -0.046679966151714325
Time for epoch 172 is 0.1480093002319336 sec - gen_loss = 0.0, disc_loss = -0.04756016284227371
Time for epoch 173 is 0.14400887489318848 sec - gen_loss = 0.0, disc_loss = -0.05044664815068245
Time for epoch 174 is 0.15200948

Time for epoch 248 is 0.20801305770874023 sec - gen_loss = 0.0, disc_loss = -0.047789353877305984
Time for epoch 249 is 0.2000124454498291 sec - gen_loss = 6.548361852765083e-11, disc_loss = -0.04603543505072594
Time for epoch 250 is 0.20801305770874023 sec - gen_loss = 5.638867150992155e-11, disc_loss = -0.04681510850787163
Time for epoch 251 is 0.20401263236999512 sec - gen_loss = 0.0, disc_loss = -0.04052365943789482
Time for epoch 252 is 0.2000124454498291 sec - gen_loss = 0.0, disc_loss = -0.046783145517110825
Time for epoch 253 is 0.19601225852966309 sec - gen_loss = 0.0, disc_loss = -0.048003584146499634
Time for epoch 254 is 0.2000124454498291 sec - gen_loss = 0.0, disc_loss = -0.04710625112056732
Time for epoch 255 is 0.19601225852966309 sec - gen_loss = 2.9103830456733704e-11, disc_loss = -0.048904627561569214
Time for epoch 256 is 0.19601225852966309 sec - gen_loss = 0.0, disc_loss = -0.04805698245763779
Time for epoch 257 is 0.21201300621032715 sec - gen_loss = 0.0, disc_lo

Time for epoch 331 is 0.3800239562988281 sec - gen_loss = 3.637978807091713e-12, disc_loss = -0.04995020106434822
Time for epoch 332 is 0.38402390480041504 sec - gen_loss = 0.0, disc_loss = -0.049038924276828766
Time for epoch 333 is 0.4040253162384033 sec - gen_loss = 0.0, disc_loss = -0.046439047902822495
Time for epoch 334 is 0.38802433013916016 sec - gen_loss = 0.0, disc_loss = -0.051014356315135956
Time for epoch 335 is 0.4040250778198242 sec - gen_loss = 0.0, disc_loss = -0.050205670297145844
Time for epoch 336 is 0.4240264892578125 sec - gen_loss = 0.0, disc_loss = -0.04929213598370552
Time for epoch 337 is 0.4000248908996582 sec - gen_loss = 0.0, disc_loss = -0.04939039796590805
Time for epoch 338 is 0.4680294990539551 sec - gen_loss = 8.330971468240023e-10, disc_loss = -0.05101697891950607
Time for epoch 339 is 0.468029260635376 sec - gen_loss = 0.0, disc_loss = -0.047633200883865356
Time for epoch 340 is 0.45602846145629883 sec - gen_loss = 0.0, disc_loss = -0.050223153084516

Time for epoch 412 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.04388440027832985
Time for epoch 413 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.046438977122306824
Time for epoch 414 is 0.15600967407226562 sec - gen_loss = 5.4569682106375694e-11, disc_loss = -0.04662328585982323
Time for epoch 415 is 0.15600967407226562 sec - gen_loss = 1.8189894035458565e-12, disc_loss = -0.047884419560432434
Time for epoch 416 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.05013594776391983
Time for epoch 417 is 0.15600991249084473 sec - gen_loss = 0.0, disc_loss = -0.04720147326588631
Time for epoch 418 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.04313565790653229
Time for epoch 419 is 0.15600967407226562 sec - gen_loss = 0.0, disc_loss = -0.0467359833419323
Time for epoch 420 is 0.1520097255706787 sec - gen_loss = 0.0, disc_loss = -0.044952403753995895
Time for epoch 421 is 0.16001009941101074 sec - gen_loss = 0.0, disc_loss = -0.0472142100334

Time for epoch 492 is 0.20401287078857422 sec - gen_loss = 9.094947017729282e-12, disc_loss = -0.04845592379570007
Time for epoch 493 is 0.20401263236999512 sec - gen_loss = 3.092281986027956e-11, disc_loss = -0.05073506385087967
Time for epoch 494 is 0.20801281929016113 sec - gen_loss = 5.4569682106375694e-12, disc_loss = -0.0458202138543129
Time for epoch 495 is 0.21601366996765137 sec - gen_loss = 0.0, disc_loss = -0.050015706568956375
Time for epoch 496 is 0.21201324462890625 sec - gen_loss = 0.0, disc_loss = -0.04728097468614578
Time for epoch 497 is 0.20801305770874023 sec - gen_loss = 0.0, disc_loss = -0.044407330453395844
Time for epoch 498 is 0.19601202011108398 sec - gen_loss = 0.0, disc_loss = -0.05030781775712967
Time for epoch 499 is 0.21601343154907227 sec - gen_loss = 0.0, disc_loss = -0.04947090148925781
Time for epoch 500 is 0.20801305770874023 sec - gen_loss = 3.637978807091713e-12, disc_loss = -0.047043655067682266
Time for epoch 501 is 0.21201324462890625 sec - gen_

Time for epoch 572 is 0.4040253162384033 sec - gen_loss = 0.0, disc_loss = -0.04913688078522682
Time for epoch 573 is 0.4480280876159668 sec - gen_loss = 0.0, disc_loss = -0.045836083590984344
Time for epoch 574 is 0.46002864837646484 sec - gen_loss = 0.0, disc_loss = -0.04574421048164368
Time for epoch 575 is 0.4680294990539551 sec - gen_loss = 0.0, disc_loss = -0.05109833925962448
Time for epoch 576 is 0.5040316581726074 sec - gen_loss = 0.0, disc_loss = -0.05001339316368103
Time for epoch 577 is 0.5000312328338623 sec - gen_loss = 0.0, disc_loss = -0.04645590856671333
Time for epoch 578 is 0.5000312328338623 sec - gen_loss = 1.0913936421275139e-11, disc_loss = -0.04592360928654671
Time for epoch 579 is 0.5080316066741943 sec - gen_loss = 0.0, disc_loss = -0.04865119978785515
Time for epoch 580 is 0.568035364151001 sec - gen_loss = 0.0, disc_loss = -0.045967958867549896
Time for epoch 581 is 0.5720357894897461 sec - gen_loss = 0.0, disc_loss = -0.04819583147764206
Time for epoch 582 

Time for epoch 653 is 0.16401028633117676 sec - gen_loss = 0.0, disc_loss = -0.04969382658600807
Time for epoch 654 is 0.16001009941101074 sec - gen_loss = 0.0, disc_loss = -0.04613451659679413
Time for epoch 655 is 0.16401004791259766 sec - gen_loss = 0.0, disc_loss = -0.04971695318818092
Time for epoch 656 is 0.16001009941101074 sec - gen_loss = 0.0, disc_loss = -0.04742036014795303
Time for epoch 657 is 0.16401052474975586 sec - gen_loss = 0.0, disc_loss = -0.04484429210424423
Time for epoch 658 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.0439843088388443
Time for epoch 659 is 0.15600967407226562 sec - gen_loss = 3.637978807091713e-12, disc_loss = -0.046131011098623276
Time for epoch 660 is 0.16000986099243164 sec - gen_loss = 0.0, disc_loss = -0.0472024530172348
Time for epoch 661 is 0.16001009941101074 sec - gen_loss = 0.0, disc_loss = -0.04860015958547592
Time for epoch 662 is 0.16000986099243164 sec - gen_loss = 0.0, disc_loss = -0.04934098571538925
Time for epoch

Time for epoch 735 is 0.21601343154907227 sec - gen_loss = 0.0, disc_loss = -0.04705057293176651
Time for epoch 736 is 0.21601343154907227 sec - gen_loss = 0.0, disc_loss = -0.04806330054998398
Time for epoch 737 is 0.2000124454498291 sec - gen_loss = 0.0, disc_loss = -0.049327023327350616
Time for epoch 738 is 0.2240140438079834 sec - gen_loss = 0.0, disc_loss = -0.048112113028764725
Time for epoch 739 is 0.20801305770874023 sec - gen_loss = 0.0, disc_loss = -0.046945057809352875
Time for epoch 740 is 0.22001361846923828 sec - gen_loss = 0.0, disc_loss = -0.04957412928342819
Time for epoch 741 is 0.22001385688781738 sec - gen_loss = 0.0, disc_loss = -0.046789031475782394
Time for epoch 742 is 0.22001385688781738 sec - gen_loss = 0.0, disc_loss = -0.045274026691913605
Time for epoch 743 is 0.23201441764831543 sec - gen_loss = 0.0, disc_loss = -0.04957934096455574
Time for epoch 744 is 0.23601484298706055 sec - gen_loss = 0.0, disc_loss = -0.04725588113069534
Time for epoch 745 is 0.220

Time for epoch 819 is 0.660041332244873 sec - gen_loss = 0.0, disc_loss = -0.04945690557360649
Time for epoch 820 is 0.6960434913635254 sec - gen_loss = 0.0, disc_loss = -0.041458211839199066
Time for epoch 821 is 0.6760423183441162 sec - gen_loss = 0.0, disc_loss = -0.047358714044094086
Time for epoch 822 is 0.6680417060852051 sec - gen_loss = 7.275957614183426e-12, disc_loss = -0.044910892844200134
Time for epoch 823 is 0.8160510063171387 sec - gen_loss = 0.0, disc_loss = -0.047465600073337555
Time for epoch 824 is 0.856053352355957 sec - gen_loss = 0.0, disc_loss = -0.04800380766391754
Time for epoch 825 is 0.9640600681304932 sec - gen_loss = 1.8189894035458565e-12, disc_loss = -0.04651516303420067
Time for epoch 826 is 1.1480717658996582 sec - gen_loss = 0.0, disc_loss = -0.04800712317228317
Time for epoch 827 is 1.3520843982696533 sec - gen_loss = 0.0, disc_loss = -0.048527706414461136
Time for epoch 828 is 1.6121008396148682 sec - gen_loss = 0.0, disc_loss = -0.04809989780187607


Time for epoch 904 is 0.16401028633117676 sec - gen_loss = 0.0, disc_loss = -0.04815516620874405
Time for epoch 905 is 0.16401004791259766 sec - gen_loss = 0.0, disc_loss = -0.04877431318163872
Time for epoch 906 is 0.16000986099243164 sec - gen_loss = 0.0, disc_loss = -0.04753304272890091
Time for epoch 907 is 0.15600967407226562 sec - gen_loss = 0.0, disc_loss = -0.048138950020074844
Time for epoch 908 is 0.16801047325134277 sec - gen_loss = 0.0, disc_loss = -0.04662309214472771
Time for epoch 909 is 0.16801047325134277 sec - gen_loss = 0.0, disc_loss = -0.04740063101053238
Time for epoch 910 is 0.1720106601715088 sec - gen_loss = 0.0, disc_loss = -0.04600663483142853
Time for epoch 911 is 0.1520094871520996 sec - gen_loss = 0.0, disc_loss = -0.04580608755350113
Time for epoch 912 is 0.16801047325134277 sec - gen_loss = 0.0, disc_loss = -0.047373220324516296
Time for epoch 913 is 0.16801047325134277 sec - gen_loss = 0.0, disc_loss = -0.05020096153020859
Time for epoch 914 is 0.168010

Time for epoch 989 is 0.24001479148864746 sec - gen_loss = 0.0, disc_loss = -0.050554316490888596
Time for epoch 990 is 0.24401545524597168 sec - gen_loss = 0.0, disc_loss = -0.04830825328826904
Time for epoch 991 is 0.23601460456848145 sec - gen_loss = 0.0, disc_loss = -0.04873283952474594
Time for epoch 992 is 0.2480156421661377 sec - gen_loss = 0.0, disc_loss = -0.04954100772738457
Time for epoch 993 is 0.2560157775878906 sec - gen_loss = 0.0, disc_loss = -0.04567306861281395
Time for epoch 994 is 0.23601484298706055 sec - gen_loss = 0.0, disc_loss = -0.047731664031744
Time for epoch 995 is 0.24001526832580566 sec - gen_loss = 0.0, disc_loss = -0.04690162092447281
Time for epoch 996 is 0.2480154037475586 sec - gen_loss = 0.0, disc_loss = -0.046274732798337936
Time for epoch 997 is 0.23601484298706055 sec - gen_loss = 0.0, disc_loss = -0.04937570542097092
Time for epoch 998 is 0.23601460456848145 sec - gen_loss = 0.0, disc_loss = -0.048450272530317307
Time for epoch 999 is 0.25601601

In [12]:
# disc_opt.iterations.numpy()

In [16]:
noise = tf.random.normal([128, latent_dim])
a=generator(noise).numpy()
(sc.inverse_transform(a))
for i in sc.inverse_transform(a):
    print(np.round(i))

[ 40.  63.   9.   6.  23. 818.   2.  16. 170.   4. 138.   9.   1.  16.
  94.   3.  32.   0.   0.   0.   1.   0.   1.   0.   0.   1.   1.   1.
   1.   0.   1.   1.   1.   0.   0.   1.   1.   1.   1.   1.   1.   0.
   1.   1.   1.   0.   1.   0.   1.   0.   0.   1.   0.   1.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   4.   0.   4.   4.   0.   2.   3.
   0.   0.   0.]
[ 40.  65.   9.   6.  25. 825.   2.  14. 178.   3. 111.  11.   1.  16.
  95.   7.  27.   0.   0.   0.   1.   0.   1.   0.   0.   1.   1.   1.
   1.   0.   1.   1.   1.   0.   0.   1.   1.   1.   1.   1.   1.   0.
   1.   1.   1.   0.   1.   0.   1.   0.   0.   1.   0.   1.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   3.   0.   4.   4.   0.   2.   3.
   0.   0.   0.]
[ 39.  74.  10.   7.  20. 766.   3.  20. 166.   3. 143.  16.   1.  17.
  89.   7.  33.   0.   0.   0.   1.   0.   1.   0.   0.   1.   1.   1.
   1.   0.   1.   1.   1.   0.   0.   1.   1.   1.   1.   1.   1.   0.
   1.   1.   1.   0.   1.   0.   1.   0.   