In [53]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.stats import wasserstein_distance
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

In [54]:
# Data Preprocessing
#
# Load the raw eICU age records, keep the first field of each record as the
# age, and min-max scale the ages into [-1, 1] so they live in the same range
# as the generator's tanh output.

# Load the .npy file
raw_data = np.load('./eICU_age.npy')

# Flatten: keep the first field of every record.
# NOTE(review): assumes each record is indexable (e.g. array shape (N, k)) — confirm the file layout.
data = np.array([record[0] for record in raw_data])
len_data = len(data)
if len_data == 0:
    raise ValueError("eICU_age.npy contains no records; nothing to train on")

# Normalize data to [-1, 1]. plot_generated_data applies the inverse mapping
# (x + 1) * (max_age - min_age) / 2 + min_age.
max_age = data.max()
min_age = data.min()
age_range = max_age - min_age
if age_range == 0:
    # A constant column would make the scaling below divide by zero (all-NaN output).
    raise ValueError(f"All ages are equal ({min_age}); min-max scaling is undefined")
scaled_data = 2 * ((data - min_age) / age_range) - 1

In [67]:
class Generator(tf.keras.Model):
    """Map a batch of latent noise vectors to one scaled value per sample.

    Architecture: two hidden blocks of Dense(128) -> LeakyReLU(0.2) ->
    BatchNormalization, then a Dense(1) output squashed by tanh into (-1, 1),
    the same interval the preprocessing cell scales real ages into.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are deliberately unchanged:
        # save_weights() tracks sub-layers by attribute name and Keras derives
        # default layer names from creation order, so previously saved
        # weights presumably stay loadable.
        self.dense1 = tf.keras.layers.Dense(128, activation=None)
        self.leaky_relu1 = tf.keras.layers.LeakyReLU(0.2)
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.dense2 = tf.keras.layers.Dense(128, activation=None)
        self.leaky_relu2 = tf.keras.layers.LeakyReLU(0.2)
        self.batch_norm2 = tf.keras.layers.BatchNormalization()
        self.dense_out = tf.keras.layers.Dense(1, activation='tanh')

    def call(self, z):
        """Feed noise ``z`` (built as (batch, 100)) through the layer stack."""
        pipeline = (
            self.dense1, self.leaky_relu1, self.batch_norm1,
            self.dense2, self.leaky_relu2, self.batch_norm2,
            self.dense_out,
        )
        out = z
        for layer in pipeline:
            out = layer(out)
        return out

class Discriminator(tf.keras.Model):
    """Score a batch of scaled ages with the probability that each one is real.

    Architecture: two hidden Dense(128) -> LeakyReLU(0.2) blocks followed by a
    Dense(1) sigmoid output.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so checkpoint
        # keys and auto-generated layer names do not change.
        self.dense1 = tf.keras.layers.Dense(128, activation=None)
        self.leaky_relu1 = tf.keras.layers.LeakyReLU(0.2)
        self.dense2 = tf.keras.layers.Dense(128, activation=None)
        self.leaky_relu2 = tf.keras.layers.LeakyReLU(0.2)
        self.dense_out = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, img):
        """Classify ``img``: despite the name, a batch of scaled ages (built as (batch, 1))."""
        pipeline = (
            self.dense1, self.leaky_relu1,
            self.dense2, self.leaky_relu2,
            self.dense_out,
        )
        out = img
        for layer in pipeline:
            out = layer(out)
        return out

In [56]:
# Initialize and compile the discriminator.
# It is compiled here, while still trainable, so its own train_on_batch calls
# update its weights.
discriminator = Discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Initialize the generator
generator = Generator()

# Build the models so their weights exist before they are wired together:
# the generator consumes 100-dim noise, the discriminator a single scaled age.
generator.build((None, 100))
discriminator.build((None, 1))

# Define input for the combined model: a batch of 100-dim latent noise vectors.
z = Input(shape=(100,))
# Generated scaled age, shape (batch, 1). The name "img" is a leftover; this is not an image.
img = generator(z)

# Set discriminator as non-trainable for the combined model.
# Keras applies the `trainable` flag when a model is compiled, so this must come
# AFTER discriminator.compile() above and BEFORE combined.compile() below.
# Do not reorder these statements.
discriminator.trainable = False
validity = discriminator(img)

# Define the combined model: noise -> generator -> discriminator -> sigmoid score.
# Training it toward the "valid" label updates only the generator, because the
# discriminator was frozen above when this model was compiled.
combined = Model(z, validity)
combined.compile(optimizer='adam', loss='binary_crossentropy')

In [65]:
# Plotting functions

def plot_generated_data(epoch, n_samples=None, latent_dim=100):
    """Sample the generator, compare with the real ages and save a histogram.

    Draws latent noise, maps the generator's output back from [-1, 1] to the
    age scale, prints the Wasserstein distance between the real ages and the
    rounded generated ages, and writes an overlaid histogram to
    ``plot_<epoch>.png``.

    Args:
        epoch: Training step number, used in the title and the file name.
        n_samples: Number of synthetic ages to draw; ``None`` means ``len_data``
            (as many as there are real records).
        latent_dim: Width of each noise vector; must match the generator input (100).

    Returns:
        The Wasserstein (earth mover's) distance as a float.
    """
    if n_samples is None:
        n_samples = len_data
    noise = np.random.normal(0, 1, (n_samples, latent_dim))
    # verbose=0 keeps Keras' per-call progress bar out of the training log.
    generated_ages = generator.predict(noise, verbose=0)[:, 0]
    # Invert the min-max scaling applied during preprocessing.
    generated_ages = (generated_ages + 1) * (max_age - min_age) / 2 + min_age
    # Compare against whole-number ages, as the original per-element round() did.
    int_generated_ages = np.rint(generated_ages).astype(int)
    earth_mover = wasserstein_distance(data, int_generated_ages)

    # Use one shared set of bin edges so the two histograms are directly comparable.
    # Independently chosen edges would misalign the overlay.
    bins = np.histogram_bin_edges(np.concatenate([data, generated_ages]), bins=30)
    fig, ax = plt.subplots()
    ax.hist(generated_ages, bins=bins, alpha=0.6, label='Generated Data')
    ax.hist(data, bins=bins, alpha=0.6, label='Real Data')
    ax.set_xlabel('Age')
    ax.set_ylabel('Count')
    ax.legend()
    ax.set_title(f"Epoch: {epoch}")
    print(f"Wasserstein Distance: {earth_mover}")
    fig.savefig(f"plot_{epoch}.png", dpi=300)
    # Close this specific figure; in a long training loop leaked figures accumulate.
    plt.close(fig)
    return earth_mover

def plot_training_metrics(epoch, d_losses, g_losses, d_accuracies):
    """Save side-by-side training curves to ``loss_plot_<epoch>.png``.

    Left panel: discriminator and generator loss per recorded step.
    Right panel: discriminator accuracy per recorded step.
    ``epoch`` is used only in the file name; the x-axis is the list index.
    """
    steps = range(len(d_losses))
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Discriminator and Generator Loss
    for series, label in ((d_losses, 'Discriminator Loss'), (g_losses, 'Generator Loss')):
        loss_ax.plot(steps, series, label=label)
    loss_ax.set_title('Discriminator and Generator Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Discriminator Accuracy
    acc_ax.plot(steps, d_accuracies, label='Discriminator Accuracy')
    acc_ax.set_title('Discriminator Accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    fig.savefig(f"loss_plot_{epoch}.png", dpi=300)
    plt.close(fig)

In [68]:
def train(epochs, batch_size, sample_interval):
    """Run the alternating GAN training loop, plotting and checkpointing periodically.

    Each loop iteration is ONE optimisation step on a single random
    mini-batch. The loop variable is named ``epoch`` historically; it is not
    a full pass over ``scaled_data``. Every ``sample_interval`` steps, and
    always on the final step, the function saves a sample histogram, the
    metric curves and the generator weights.

    Args:
        epochs: Number of training steps to run.
        batch_size: Mini-batch size used for both networks.
        sample_interval: Positive number of steps between plots/checkpoints.

    Raises:
        ValueError: If ``sample_interval`` is not positive.
    """
    if sample_interval <= 0:
        raise ValueError(f"sample_interval must be positive, got {sample_interval}")

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # Per-step metric history, passed to plot_training_metrics.
    d_losses = []
    g_losses = []
    d_accuracies = []

    for epoch in range(epochs):
        # --- Discriminator step: one real batch labelled 1, one generated batch labelled 0.
        idx = np.random.randint(0, scaled_data.shape[0], batch_size)
        real_ages = scaled_data[idx].reshape(batch_size, -1)

        noise = np.random.normal(0, 1, (batch_size, 100))
        # Direct call in inference mode: for a single batch this is much cheaper
        # than predict() and prints no progress bar. BatchNormalization uses its
        # moving statistics exactly as it would under predict().
        gen_ages = generator(noise, training=False).numpy().reshape(batch_size, -1)

        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_ages, valid)
        d_loss_fake = discriminator.train_on_batch(gen_ages, fake)
        # Each result is [loss, accuracy] (compiled with metrics=['accuracy']);
        # average the real and fake halves.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: push the frozen discriminator's output toward "valid".
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = combined.train_on_batch(noise, valid)

        d_losses.append(d_loss[0])
        g_losses.append(g_loss)
        d_accuracies.append(d_loss[1])

        print(f"{epoch}/{epochs} [D loss: {d_loss[0]} | D accuracy: {d_loss[1]}] [G loss: {g_loss}]")

        # Also checkpoint the final step. Otherwise the trailing
        # (epochs - 1) % sample_interval steps of training would never be saved.
        is_last_step = epoch == epochs - 1
        if epoch % sample_interval == 0 or is_last_step:
            plot_generated_data(epoch)
            plot_training_metrics(epoch, d_losses, g_losses, d_accuracies)
            generator.save_weights("g_weights_" + str(epoch))

In [69]:
# Launch training: each "epoch" is one single-batch step. Plots and generator
# weights are written every SAMPLE_INTERVAL steps.
TRAIN_STEPS = 5000
BATCH_SIZE = 64
SAMPLE_INTERVAL = 500
train(epochs=TRAIN_STEPS, batch_size=BATCH_SIZE, sample_interval=SAMPLE_INTERVAL)



0/5000 [D loss: 0.6931722462177277 | D accuracy: 0.4921875] [G loss: 0.715831458568573]
Wasserstein Distance: 2.682142857142857
1/5000 [D loss: 0.6958407163619995 | D accuracy: 0.4296875] [G loss: 0.7154825329780579]
2/5000 [D loss: 0.6963367462158203 | D accuracy: 0.46875] [G loss: 0.7151400446891785]
3/5000 [D loss: 0.6905187368392944 | D accuracy: 0.53125] [G loss: 0.7140082120895386]
4/5000 [D loss: 0.6965427398681641 | D accuracy: 0.453125] [G loss: 0.712769627571106]
5/5000 [D loss: 0.6952250897884369 | D accuracy: 0.4921875] [G loss: 0.7131716012954712]
6/5000 [D loss: 0.6973019540309906 | D accuracy: 0.4375] [G loss: 0.7122743129730225]
7/5000 [D loss: 0.6932879984378815 | D accuracy: 0.5078125] [G loss: 0.7109990119934082]
8/5000 [D loss: 0.6958081722259521 | D accuracy: 0.46875] [G loss: 0.7086180448532104]
9/5000 [D loss: 0.6968276798725128 | D accuracy: 0.4296875] [G loss: 0.7093995213508606]
10/5000 [D loss: 0.6944271326065063 | D accuracy: 0.515625] [G loss: 0.70783603191