In [1]:
import pandas as pd 
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import os, sys, time
sys.path.append("..")
from all_funcs import util
from model import Generator, Discriminator, train_discriminator, train_generator
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras import layers
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

tf.keras.backend.set_floatx('float64')
pd.set_option('display.max_rows',500)
pd.set_option('display.max_columns',500)

In [2]:
from numba import cuda
os.environ['CUDA_VISIBLE_DEVICES']="1"
print(tf.test.is_gpu_available())

True


In [3]:
df=pd.read_csv("../dataset/df_noOutliner_ana.csv",index_col=0)
df, imp_mode, imp_mean=util.FeatureArrange(df)

In [4]:
## reduce redundant features which can be assembled
dataset=df.drop(['NIHTotal','THD_ID','cortical_CT', 'subcortical_CT',
              'circulation_CT', 'CT_find', 'watershed_CT', 'Hemorrhagic_infarct_CT',
              'CT_left', 'CT_right',],axis=1)

In [5]:
## prepare for inverse tensor values from range(0,1) to original values
params=dict()
params['max']=dataset.max().to_numpy()
params['min']=dataset.min().to_numpy()

In [6]:
sc = MinMaxScaler()
dataset.loc[:,dataset.columns] = sc.fit_transform(dataset.loc[:,dataset.columns])

In [7]:
## setting hyperparameter
latent_dim = dataset.shape[1]
epochs = 15000
batch_size= 128
buffer_size = 6000
# save_interval = 50
n_critic = 5
checkpoint_dir = './training_checkpoints'


In [8]:
generator = Generator(latent_dim)
discriminator = Discriminator()

In [9]:
## create Cross Entropy
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [10]:
gen_opt = tf.keras.optimizers.Adam(0.0001,)
disc_opt = tf.keras.optimizers.Adam(0.00001,)

In [11]:
# save checkpoints
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=gen_opt,
                                 discriminator_optimizer=disc_opt,
                                 generator=generator,
                                 discriminator=discriminator)

In [12]:
X_train, X_test = train_test_split(dataset, test_size=0.2,shuffle=True,
                                   stratify=dataset['elapsed_class'],
                                   random_state=None)

In [None]:
# separate to 0,1 dataset
data_1=X_train.loc[X_train['elapsed_class']==1]
data_0=X_train.loc[X_train['elapsed_class']==0]
## store losses
### generator losses
losses_gen = np.array([])
best_loss_gen = np.inf
### discriminator losses
losses_dis = np.array([])
best_loss_dis = np.inf

for epoch in range(epochs):
    start = time.time()
    disc_loss = 0
    gen_loss = 0

    # resample the dataset
    data1_shape_0 = data_1.sample(data_0.shape[0])
    df_same_shape = pd.concat([data1_shape_0, data_0]).to_numpy()

#     # slices to data and labels
#     df_same_ = df_same_shape.iloc[:, :-1].to_numpy()
#     org_label = df_same_shape.iloc[:, -1].to_numpy().reshape(-1, 1)

    # create batch dataset
    training_dataset = tf.data.Dataset.from_tensor_slices(df_same_shape)\
        .shuffle(buffer_size).batch(batch_size, drop_remainder=True)

    for data in training_dataset:
        for _ in range(n_critic): # 5*discriminator times, 1*generator times
            disc_loss += train_discriminator(data, generator,
                                             discriminator, disc_opt, latent_dim)
#         if disc_opt.iterations.numpy() % n_critic == 0: ### using samples
        gen_loss+= train_generator(data, generator,
                                    discriminator, gen_opt, params, batch_size, latent_dim)
    
    losses_gen= np.append(losses_gen, gen_loss / batch_size)
    losses_dis= np.append(losses_dis, disc_loss / (batch_size*n_critic))
    
    print('Time for epoch {} is {} sec - gen_loss = {}, disc_loss = {}'.format(epoch + 1, time.time() - start,
                                                                               gen_loss / batch_size,
                                                                               disc_loss / (batch_size*n_critic)))
    # save best discriminator or generator
    if abs(best_loss_gen) > abs((gen_loss / batch_size)):
        best_loss_gen = (gen_loss / batch_size)
        generator.save_weights(checkpoint_prefix+"gen", save_format='tf')
        
    if abs(best_loss_dis) > abs((disc_loss / (batch_size*n_critic))):
        best_loss_dis = (disc_loss / (batch_size*n_critic))
        discriminator.save_weights(checkpoint_prefix+"dis", save_format='tf')
    



To change all layers to have dtype float32 by default, call `tf.keras.backend.set_floatx('float32')`. To change just this layer, pass dtype='float32' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Time for epoch 1 is 48.269712924957275 sec - gen_loss = 77.87105334501982, disc_loss = 152.46211179744972
Time for epoch 2 is 1.1840739250183105 sec - gen_loss = 80.0898007909963, disc_loss = 102.51627272855497
Time for epoch 3 is -17.621154308319092 sec - gen_loss = 77.24401800871306, disc_loss = 71.22179209832171
Time for epoch 4 is 1.1960747241973877 sec - gen_loss = 76.85333435713474, disc_loss = 50.087021515986976
Time for epoch 5 is 1.1760737895965576 sec - gen_loss = 76.2474400400428, disc_loss = 36.099991554261216
Time for epoch 6 is 1.2560787200927734 sec - gen_loss = 72.669381443203, disc_loss = 26.17834107558905
Time for epoch 7 is 1.2840800285339355 sec - gen_loss = 66.576756413

Time for epoch 71 is 1.2000749111175537 sec - gen_loss = 32.26939839295901, disc_loss = -0.008636119395674583
Time for epoch 72 is -16.232255220413208 sec - gen_loss = 31.76943990107089, disc_loss = -0.009285941196897147
Time for epoch 73 is 1.2720794677734375 sec - gen_loss = 31.909030906477174, disc_loss = -0.00955500528686326
Time for epoch 74 is 1.2720794677734375 sec - gen_loss = 31.196410780951012, disc_loss = -0.010456479912634506
Time for epoch 75 is 1.3400835990905762 sec - gen_loss = 32.07054481543521, disc_loss = -0.009374314318603565
Time for epoch 76 is 1.3640851974487305 sec - gen_loss = 32.033077938998225, disc_loss = -0.010772017526340591
Time for epoch 77 is 1.3680856227874756 sec - gen_loss = 31.99375692682328, disc_loss = -0.011063887136002412
Time for epoch 78 is 1.4880931377410889 sec - gen_loss = 31.608920605168667, disc_loss = -0.01124273364075041
Time for epoch 79 is 1.5560970306396484 sec - gen_loss = 31.606000814053214, disc_loss = -0.0115709792064447
Time for

Time for epoch 145 is 1.4520907402038574 sec - gen_loss = 29.56665706106186, disc_loss = -0.002313377108760059
Time for epoch 146 is 1.5400962829589844 sec - gen_loss = 29.373630126158357, disc_loss = -0.001843240136597869
Time for epoch 147 is 1.5280954837799072 sec - gen_loss = 29.029634231911352, disc_loss = -0.00242583906302805
Time for epoch 148 is 1.6521029472351074 sec - gen_loss = 28.9215141170825, disc_loss = -0.0029390543622379154
Time for epoch 149 is 1.700106143951416 sec - gen_loss = 28.980461311894565, disc_loss = -0.0029664919939364318
Time for epoch 150 is 1.8041129112243652 sec - gen_loss = 29.06386924683502, disc_loss = -0.003410712293834711
Time for epoch 151 is 1.9121193885803223 sec - gen_loss = 28.430284002307708, disc_loss = -0.002951543712353636
Time for epoch 152 is 2.0241265296936035 sec - gen_loss = 29.37932304774968, disc_loss = -0.003262410721653671
Time for epoch 153 is 2.2281394004821777 sec - gen_loss = 29.03403900204717, disc_loss = -0.00292916571428755

Time for epoch 219 is 2.0161259174346924 sec - gen_loss = 25.900179714817362, disc_loss = -0.0022883254279051
Time for epoch 220 is 2.172136068344116 sec - gen_loss = 25.90899039222061, disc_loss = -0.0023663786529572817
Time for epoch 221 is 2.3281455039978027 sec - gen_loss = 26.041381935264432, disc_loss = -0.002469112343864463
Time for epoch 222 is 2.4681544303894043 sec - gen_loss = 25.697062879281788, disc_loss = -0.0023339371502028297
Time for epoch 223 is 2.6801676750183105 sec - gen_loss = 25.79112262128537, disc_loss = -0.0024943334763070767
Time for epoch 224 is 3.284205198287964 sec - gen_loss = 25.657840525345147, disc_loss = -0.0022630994434698357
Time for epoch 225 is 3.680230140686035 sec - gen_loss = 25.918674380174153, disc_loss = -0.002419549965102759
Time for epoch 226 is -13.910802125930786 sec - gen_loss = 25.928305987732628, disc_loss = -0.002252227552678405
Time for epoch 227 is 8.032502174377441 sec - gen_loss = 25.935484204611807, disc_loss = -0.00231697612314

Time for epoch 293 is 4.420276165008545 sec - gen_loss = 21.643693578520097, disc_loss = -0.002423152077781871
Time for epoch 294 is 6.776423692703247 sec - gen_loss = 21.711692532217093, disc_loss = -0.002277857914513222
Time for epoch 295 is -18.266247272491455 sec - gen_loss = 22.107326474757414, disc_loss = -0.002385292814867165
Time for epoch 296 is 0.9920618534088135 sec - gen_loss = 22.009515515304507, disc_loss = -0.0024175788600375543
Time for epoch 297 is 0.9680607318878174 sec - gen_loss = 22.248579093852715, disc_loss = -0.002554579170695808
Time for epoch 298 is 1.0080630779266357 sec - gen_loss = 21.5846476417597, disc_loss = -0.002156582498558623
Time for epoch 299 is 0.9960622787475586 sec - gen_loss = 22.06966155903646, disc_loss = -0.0024275622209432675
Time for epoch 300 is 1.0600662231445312 sec - gen_loss = 21.96393344700256, disc_loss = -0.002469383562701969
Time for epoch 301 is 1.0640664100646973 sec - gen_loss = 21.974112009163388, disc_loss = -0.00216881991756

Time for epoch 367 is 1.048065423965454 sec - gen_loss = 19.307622477217695, disc_loss = -0.0021954404226649416
Time for epoch 368 is 1.040065050125122 sec - gen_loss = 19.356023344884512, disc_loss = -0.0020158600521999965
Time for epoch 369 is 1.0960681438446045 sec - gen_loss = 19.49930935631235, disc_loss = -0.0021440857530694423
Time for epoch 370 is 1.1440718173980713 sec - gen_loss = 19.240249941391816, disc_loss = -0.0019376279866390018
Time for epoch 371 is 1.1480717658996582 sec - gen_loss = 19.338645360482086, disc_loss = -0.0019684903801172886
Time for epoch 372 is 1.1720731258392334 sec - gen_loss = 19.474665305759416, disc_loss = -0.002057795700169546
Time for epoch 373 is 1.1960747241973877 sec - gen_loss = 19.220072552238282, disc_loss = -0.002017679048715788
Time for epoch 374 is 1.2440779209136963 sec - gen_loss = 19.399868487654885, disc_loss = -0.0019872821330011907
Time for epoch 375 is 1.2520780563354492 sec - gen_loss = 19.301707003619068, disc_loss = -0.00210798

Time for epoch 441 is 1.1520719528198242 sec - gen_loss = 16.60543215653441, disc_loss = -0.002116293447698117
Time for epoch 442 is 1.2080755233764648 sec - gen_loss = 16.687113048564555, disc_loss = -0.001854188428262208
Time for epoch 443 is 1.2440779209136963 sec - gen_loss = 16.831103310084174, disc_loss = -0.0018557124268301802
Time for epoch 444 is 1.2520780563354492 sec - gen_loss = 16.846761558529494, disc_loss = -0.0019054919405955284
Time for epoch 445 is 1.296081304550171 sec - gen_loss = 16.488117955362405, disc_loss = -0.0015537803970071384
Time for epoch 446 is 1.3120818138122559 sec - gen_loss = 16.885176306228562, disc_loss = -0.0019116845781084607
Time for epoch 447 is 1.3520846366882324 sec - gen_loss = 16.230669048469043, disc_loss = -0.0018198937904472657
Time for epoch 448 is 1.3880867958068848 sec - gen_loss = 16.48795048165087, disc_loss = -0.0017221314466719369
Time for epoch 449 is 1.4800925254821777 sec - gen_loss = 16.360589315272374, disc_loss = -0.00202579

Time for epoch 514 is 1.320082426071167 sec - gen_loss = 13.93265691141479, disc_loss = -0.0016028202595849074
Time for epoch 515 is 1.4320895671844482 sec - gen_loss = 13.75634786473575, disc_loss = -0.001525205649817482
Time for epoch 516 is 1.4000873565673828 sec - gen_loss = 13.837993208552192, disc_loss = -0.0015496809496602002
Time for epoch 517 is 1.4680919647216797 sec - gen_loss = 13.833504156659716, disc_loss = -0.0016613225710651349
Time for epoch 518 is 1.5240952968597412 sec - gen_loss = 13.637084560668473, disc_loss = -0.0014889690159418322
Time for epoch 519 is 1.5760986804962158 sec - gen_loss = 13.59628060712874, disc_loss = -0.0014525215219124297
Time for epoch 520 is 1.684105396270752 sec - gen_loss = 13.885524399337339, disc_loss = -0.0016113380442591182
Time for epoch 521 is 1.7481091022491455 sec - gen_loss = 13.65235245957745, disc_loss = -0.0015133914802067233
Time for epoch 522 is 1.8281145095825195 sec - gen_loss = 13.622804694908211, disc_loss = -0.0014584630

Time for epoch 587 is 1.512094497680664 sec - gen_loss = 11.657614554889179, disc_loss = -0.001263587089765017
Time for epoch 588 is 1.6001002788543701 sec - gen_loss = 11.918541543304267, disc_loss = -0.0012531655495375874
Time for epoch 589 is 1.6241014003753662 sec - gen_loss = 11.892325233405701, disc_loss = -0.0014653636959235735
Time for epoch 590 is 1.6681044101715088 sec - gen_loss = 12.098237411502247, disc_loss = -0.0012968944662643724
Time for epoch 591 is 1.804112434387207 sec - gen_loss = 11.899526379850029, disc_loss = -0.0012917588457854999
Time for epoch 592 is -6.809260368347168 sec - gen_loss = 12.063157025737004, disc_loss = -0.0013189994727841273
Time for epoch 593 is 2.080130100250244 sec - gen_loss = 11.6571169990162, disc_loss = -0.001348642696617586
Time for epoch 594 is 1.9641225337982178 sec - gen_loss = 11.635097444334628, disc_loss = -0.001391153223953171
Time for epoch 595 is 1.4680919647216797 sec - gen_loss = 11.658232930460928, disc_loss = -0.00131682187

Time for epoch 660 is 0.8000502586364746 sec - gen_loss = 10.696548294296171, disc_loss = -0.0012852254474953007
Time for epoch 661 is 0.7480466365814209 sec - gen_loss = 10.342199869666123, disc_loss = -0.001200428812478598
Time for epoch 662 is 0.7240452766418457 sec - gen_loss = 10.556507009175084, disc_loss = -0.0011289810507569124
Time for epoch 663 is 0.7200450897216797 sec - gen_loss = 10.495028527314103, disc_loss = -0.0012965710922057474
Time for epoch 664 is 0.7680478096008301 sec - gen_loss = 10.501722460178197, disc_loss = -0.0012448691273026966
Time for epoch 665 is 0.8080506324768066 sec - gen_loss = 10.414769975944276, disc_loss = -0.0011938242799539825
Time for epoch 666 is 0.7840492725372314 sec - gen_loss = 10.319581681798624, disc_loss = -0.0012197211776021478
Time for epoch 667 is 0.7840487957000732 sec - gen_loss = 10.263613536861612, disc_loss = -0.0011431789684133882
Time for epoch 668 is 0.7880492210388184 sec - gen_loss = 10.287989667303618, disc_loss = -0.0012

Time for epoch 734 is 1.1520721912384033 sec - gen_loss = 9.279634745743559, disc_loss = -0.0014897192586629088
Time for epoch 735 is 1.23207688331604 sec - gen_loss = 9.408134169120725, disc_loss = -0.0012095470040580426
Time for epoch 736 is 1.2360773086547852 sec - gen_loss = 9.32994348655991, disc_loss = -0.0014676415466131128
Time for epoch 737 is 1.316082239151001 sec - gen_loss = 9.562423498944547, disc_loss = -0.0013393551869087421
Time for epoch 738 is 1.2920808792114258 sec - gen_loss = 9.333198569372277, disc_loss = -0.0013910938670533518
Time for epoch 739 is 1.3840863704681396 sec - gen_loss = 9.453805941229609, disc_loss = -0.0013008468651701795
Time for epoch 740 is 1.4080877304077148 sec - gen_loss = 9.252037259333857, disc_loss = -0.0014420365456461856
Time for epoch 741 is -9.281580924987793 sec - gen_loss = 9.23061719888962, disc_loss = -0.0014262175935551632
Time for epoch 742 is 1.6361021995544434 sec - gen_loss = 9.388051680903263, disc_loss = -0.00137686992637235

Time for epoch 808 is 0.7640476226806641 sec - gen_loss = 8.600256178275705, disc_loss = -0.0014431270996596807
Time for epoch 809 is -19.943038940429688 sec - gen_loss = 8.558320383775687, disc_loss = -0.0014138784651189027
Time for epoch 810 is 0.8080503940582275 sec - gen_loss = 8.50731053964372, disc_loss = -0.0012939497451601339
Time for epoch 811 is 0.848052978515625 sec - gen_loss = 8.720851381469789, disc_loss = -0.0014347687249977787
Time for epoch 812 is 0.7960498332977295 sec - gen_loss = 8.553758170147578, disc_loss = -0.001393136728633716
Time for epoch 813 is 0.7480466365814209 sec - gen_loss = 8.73938500914286, disc_loss = -0.0014274237932855054
Time for epoch 814 is 0.8480532169342041 sec - gen_loss = 8.365620662290791, disc_loss = -0.0013032118895343068
Time for epoch 815 is 0.8640539646148682 sec - gen_loss = 8.551328941699712, disc_loss = -0.0015588761047284235
Time for epoch 816 is 0.852053165435791 sec - gen_loss = 8.590651740123844, disc_loss = -0.0014532234655404

In [None]:
## Record the 40000 gen_loss = , disc_loss = 

## show the training results

In [None]:
plt.title("ADS-GAN-constraint training Loss")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.grid()
plt.plot(losses_gen, label='Generator')
plt.plot(losses_dis, label='Discriminator')
plt.legend(loc='best')
plt.savefig("./ADS-GAN-constraint_LOSS.png",dpi=300)
plt.show()

In [None]:
## create matrix 0 row*latent_dim columns
arr=np.empty((0,latent_dim))
noise = tf.random.normal([128, latent_dim])

## batch testing data
testing_dataset = tf.data.Dataset.from_tensor_slices(X_test.to_numpy())\
        .shuffle(buffer_size).batch(batch_size, drop_remainder=True)

## generate dataset
for data in testing_dataset:
    gen_=generator(noise, data).numpy()
    arr=np.append(arr,gen_,axis=0)
arr.shape

In [None]:
output_dataset = pd.DataFrame(np.round(sc.inverse_transform(arr)), columns=[
    'BT_NM', 'HR_NM', 'RR_NM', 'HB_NM', 'HCT_NM', 'PLATELET_NM', 'WBC_NM',
    'PTT1_NM', 'PTT2_NM', 'PTINR_NM', 'ER_NM', 'BUN_NM', 'CRE_NM', 'BMI',
    'age', 'PPD', 'THDA_FL', 'THDH_FL', 'THDI_FL', 'THDAM_FL', 'THDV_FL',
    'THDE_FL', 'THDM_FL', 'THDR_FL', 'THDP_FL', 'THDOO_FL', 'Gender',
    'cortical_ACA_ctr', 'cortical_MCA_ctr', 'subcortical_ACA_ctr',
    'subcortical_MCA_ctr', 'PCA_cortex_ctr', 'thalamus_ctr',
    'brainstem_ctr', 'cerebellum_ctr', 'Watershed_ctr',
    'Hemorrhagic_infarct_ctr', 'cortical_ACA_ctl', 'cortical_MCA_ctl',
    'subcortical_ACA_ctl', 'subcortical_MCA_ctl', 'PCA_cortex_ctl',
    'thalamus_ctl', 'brainstem_ctl', 'cerebellum_ctl', 'Watershed_ctl',
    'Hemorrhagic_infarct_ctl', 'NIHS_1a_in', 'NIHS_1b_in', 'NIHS_1c_in',
    'NIHS_2_in', 'NIHS_3_in', 'NIHS_4_in', 'NIHS_5aL_in', 'NIHS_5bR_in',
    'NIHS_6aL_in', 'NIHS_6bR_in', 'NIHS_7_in', 'NIHS_8_in', 'NIHS_9_in',
    'NIHS_10_in', 'NIHS_11_in','elapsed_class'
])
output_dataset

In [None]:
output_dataset.to_csv("../dataset/output_dataset/ADS-GAN-constraint_models.csv",encoding='utf_8_sig')

In [None]:
X_TEST_dataset = pd.DataFrame(np.round(sc.inverse_transform(X_test)), columns=[
    'BT_NM', 'HR_NM', 'RR_NM', 'HB_NM', 'HCT_NM', 'PLATELET_NM', 'WBC_NM',
    'PTT1_NM', 'PTT2_NM', 'PTINR_NM', 'ER_NM', 'BUN_NM', 'CRE_NM', 'BMI',
    'age', 'PPD', 'THDA_FL', 'THDH_FL', 'THDI_FL', 'THDAM_FL', 'THDV_FL',
    'THDE_FL', 'THDM_FL', 'THDR_FL', 'THDP_FL', 'THDOO_FL', 'Gender',
    'cortical_ACA_ctr', 'cortical_MCA_ctr', 'subcortical_ACA_ctr',
    'subcortical_MCA_ctr', 'PCA_cortex_ctr', 'thalamus_ctr',
    'brainstem_ctr', 'cerebellum_ctr', 'Watershed_ctr',
    'Hemorrhagic_infarct_ctr', 'cortical_ACA_ctl', 'cortical_MCA_ctl',
    'subcortical_ACA_ctl', 'subcortical_MCA_ctl', 'PCA_cortex_ctl',
    'thalamus_ctl', 'brainstem_ctl', 'cerebellum_ctl', 'Watershed_ctl',
    'Hemorrhagic_infarct_ctl', 'NIHS_1a_in', 'NIHS_1b_in', 'NIHS_1c_in',
    'NIHS_2_in', 'NIHS_3_in', 'NIHS_4_in', 'NIHS_5aL_in', 'NIHS_5bR_in',
    'NIHS_6aL_in', 'NIHS_6bR_in', 'NIHS_7_in', 'NIHS_8_in', 'NIHS_9_in',
    'NIHS_10_in', 'NIHS_11_in','elapsed_class'
])
X_TEST_dataset.to_csv("../dataset/output_dataset/ADS-GAN-constraint_XTEST.csv",encoding='utf_8_sig')