In [85]:
import pandas as pd
import numpy as np

RANDOM_STATE = 404

In [86]:
df = pd.read_csv('../../data/HAD.csv')
df.head()

Unnamed: 0,AGE,AGE_MISSING,SEX_F,NIHSS_BL,NIHSS_BL_MISSING,SYS_BLOOD_PRESSURE,SYS_BLOOD_PRESSURE_MISSING,PREV_MRS,PREV_MRS_MISSING,ORAL_ANTICOAGULANT,...,ONSET_TO_ADMISSION,ONSET_TO_ADMISSION_MISSING,ONSET_TO_IMAGING,ONSET_TO_IMAGING_MISSING,ONSET_TO_TPA,ONSET_TO_TPA_MISSING,ONSET_TO_GROIN,ONSET_TO_GROIN_MISSING,MRS_90,MRS_90_DICHO
0,70,0,0,3,0,-1,1,0,0,0,...,64,0,96,0,180,0,-1,1,0,0
1,55,0,0,6,0,142,0,3,0,0,...,38,0,104,0,165,0,-1,1,4,1
2,73,0,0,3,0,170,0,0,0,1,...,-1,1,-1,1,-1,1,-1,1,2,0
3,81,0,0,10,0,-1,1,0,0,0,...,69,0,90,0,115,0,-1,1,3,1
4,81,0,1,11,0,-1,1,0,0,0,...,98,0,110,0,120,0,-1,1,0,0


#### Data Preprocessing

In [87]:
from sklearn.preprocessing import LabelEncoder

target_feature = 'MRS_90'
y_raw = df.filter([target_feature])
y = LabelEncoder().fit_transform(y_raw.values.ravel())
df = df.drop(columns=['MRS_90', 'MRS_90_DICHO'], axis=1)

In [88]:
import sys
from dill import load

# Load the scaler object
scalerFile = "..\predictive_models\HAD_scaler.pkl"
with open(scalerFile, "rb") as f:
    scaler = load(f)

# Now your code should be able to load the scaler object without encountering ModuleNotFoundError for ADT
df_scaled = scaler.preprocess_clinical_data(np.asarray(df, dtype=float))
X = pd.DataFrame(df_scaled, columns=df.columns)

X.head(5)

Unnamed: 0,AGE,AGE_MISSING,SEX_F,NIHSS_BL,NIHSS_BL_MISSING,SYS_BLOOD_PRESSURE,SYS_BLOOD_PRESSURE_MISSING,PREV_MRS,PREV_MRS_MISSING,ORAL_ANTICOAGULANT,...,CTA_CS,CTA_CS_MISSING,ONSET_TO_ADMISSION,ONSET_TO_ADMISSION_MISSING,ONSET_TO_IMAGING,ONSET_TO_IMAGING_MISSING,ONSET_TO_TPA,ONSET_TO_TPA_MISSING,ONSET_TO_GROIN,ONSET_TO_GROIN_MISSING
0,0.509804,0.0,0.0,0.071429,0.0,-1.0,1.0,0.0,0.0,0.0,...,-1.0,1.0,0.044444,0.0,0.066667,0.0,0.125,0.0,-1.0,1.0
1,0.362745,0.0,0.0,0.142857,0.0,0.368,0.0,0.6,0.0,0.0,...,-1.0,1.0,0.026389,0.0,0.072222,0.0,0.114583,0.0,-1.0,1.0
2,0.539216,0.0,0.0,0.071429,0.0,0.48,0.0,0.0,0.0,1.0,...,-1.0,1.0,-1.0,1.0,-1.0,1.0,-1.0,1.0,-1.0,1.0
3,0.617647,0.0,0.0,0.238095,0.0,-1.0,1.0,0.0,0.0,0.0,...,-1.0,1.0,0.047917,0.0,0.0625,0.0,0.079861,0.0,-1.0,1.0
4,0.617647,0.0,1.0,0.261905,0.0,-1.0,1.0,0.0,0.0,0.0,...,-1.0,1.0,0.068056,0.0,0.076389,0.0,0.083333,0.0,-1.0,1.0


In [89]:
# Warning is given while imputing missing values in 'SERUM_GLUCOSE' and 'VALV_HEART' columns due to missing all values thus they are removed, only to be readded for classification
columns_names_to_add_back_for_classification = ['SERUM_GLUCOSE', 'SERUM_GLUCOSE_MISSING', 'VALV_HEART']
for col in list(X.filter(regex='MISSING')):
    columns_names_to_add_back_for_classification.append(col)

X = X.drop(columns=columns_names_to_add_back_for_classification, axis=1)
X.shape[1]

27

In [90]:
X.describe()

Unnamed: 0,AGE,SEX_F,NIHSS_BL,SYS_BLOOD_PRESSURE,PREV_MRS,ORAL_ANTICOAGULANT,HYPERTENSION,HYPERCHOL,ISCH_HEART,SMOKING,...,OCCLUSION_M2,OCCLUSION_ICA,OCCLUSION_ACA,OCCLUSION_PCA,OCCLUSION_VB,CTA_CS,ONSET_TO_ADMISSION,ONSET_TO_IMAGING,ONSET_TO_TPA,ONSET_TO_GROIN
count,944.0,944.0,944.0,944.0,944.0,944.0,944.0,944.0,944.0,944.0,...,944.0,944.0,944.0,944.0,944.0,944.0,944.0,944.0,944.0,944.0
mean,0.503282,0.452331,0.204802,-0.42111,0.135381,0.113347,0.685381,0.501059,0.304025,0.220339,...,-0.276483,-0.205508,-0.389831,-0.370763,-0.380297,-0.439972,-0.294675,-0.270553,-0.648606,-0.577454
std,0.170349,0.497986,0.296881,0.693467,0.32814,0.336648,0.466887,0.504486,0.460237,0.419778,...,0.677496,0.748368,0.515448,0.547077,0.531583,0.734238,0.537091,0.543417,0.5223,0.591049
min,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
25%,0.421569,0.0,0.071429,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0
50%,0.519608,0.0,0.166667,-1.0,0.0,0.0,1.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,-1.0,0.036111,0.050694,-1.0,-1.0
75%,0.617647,1.0,0.404762,0.368,0.4,0.0,1.0,1.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.333333,0.076563,0.097222,0.084201,0.104167
max,0.803922,1.0,0.857143,0.716,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,0.814583,0.89375,0.618056,0.941667


#### Data splitting

In [91]:
from sklearn.model_selection import train_test_split

# Split into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_STATE)
X_train.shape, X_test.shape

((755, 27), (189, 27))

#### Model definition

In [92]:
import tensorflow as tf

# This script defines the generator and discriminator models for a Generative Adversarial Imputation Network (GAIN)
# using the Keras API in TensorFlow 2.x.

def build_generator(data_dim, hidden_dim):
    """
    Builds the generator model for a Generative Adversarial Imputation Network (GAIN).

    Args:
        data_dim (int): The dimensionality of the input data.
        hidden_dim (int): The number of hidden units in the encoder and decoder.

    Returns:
        tf.keras.Model: The generator model.
    """

    # Define the model architecture
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(data_dim,)),  # Input layer
        tf.keras.layers.Dense(hidden_dim, activation='relu'),  # Hidden layers
        tf.keras.layers.Dense(hidden_dim, activation='relu'),
        tf.keras.layers.Dense(data_dim, activation='sigmoid')  # Output layer
    ])

    return model


def build_discriminator(data_dim, hidden_dim):
    """
    Builds the discriminator model for a Generative Adversarial Imputation Network (GAIN).

    Args:
        data_dim (int): The dimensionality of the input data.
        hidden_dim (int): The number of hidden units in the encoder and decoder.

    Returns:
        tf.keras.Model: The discriminator model.
    """

    # Define the model architecture
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(data_dim,)),  # Input layer
        tf.keras.layers.Dense(hidden_dim, activation='relu'),  # Hidden layers
        tf.keras.layers.Dense(hidden_dim, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')  # Output layer
    ])

    return model

In [93]:
# Setup models
data_dim = X_train.shape[1]
hidden_dim = 128

generator = build_generator(data_dim, hidden_dim)
discriminator = build_discriminator(data_dim, hidden_dim)

# Print the model summary
generator.summary(), discriminator.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_24 (Dense)            (None, 128)               3584      
                                                                 
 dense_25 (Dense)            (None, 128)               16512     
                                                                 
 dense_26 (Dense)            (None, 27)                3483      
                                                                 
Total params: 23579 (92.11 KB)
Trainable params: 23579 (92.11 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_27 (Dense)            (None, 128)               3584      
                                                              

(None, None)

#### Loss functions and optimizers

In [94]:
# Loss function for the discriminator
def discriminator_loss(D_prob, M, X, G_sample):
    D_prob = tf.cast(D_prob, dtype=tf.float32)  # Cast to float32
    return -tf.reduce_mean(M * tf.math.log(D_prob + 1e-8) + (1 - M) * tf.math.log(1. - D_prob + 1e-8))

# Loss function for the generator
def generator_loss(D_prob, G_sample, M, X):
    # Cast all inputs to float32 to ensure consistent data types for operations
    D_prob = tf.cast(D_prob, dtype=tf.float32)
    G_sample = tf.cast(G_sample, dtype=tf.float32)
    M = tf.cast(M, dtype=tf.float32)
    X = tf.cast(X, dtype=tf.float32)
    
    # Compute the binary cross-entropy loss part
    BCE_loss = -tf.reduce_mean((1 - M) * tf.math.log(D_prob + tf.constant(1e-8, dtype=tf.float32)))

    # Compute the mean squared error loss part
    MSE_loss = tf.reduce_mean(M * tf.square(X - G_sample))

    # Weighting factor for the losses
    alpha = 0.5

    # Combine the losses
    total_loss = alpha * BCE_loss + (1 - alpha) * MSE_loss

    return total_loss

# Adam optimizer is a stochastic gradient descent method
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

#### Training function definition

In [95]:
@tf.function
def train_step(generator, discriminator, data, batch_size):
    noise = tf.random.normal([batch_size, data.shape[1]])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_data = generator(noise, training=True)
        real_output = discriminator(data, training=True)
        fake_output = discriminator(generated_data, training=True)
        gen_loss = generator_loss(fake_output, generated_data, data, noise)
        disc_loss = discriminator_loss(real_output, fake_output, data, noise)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

def train_gan(generator, discriminator, df, iterations, batch_size):
    for iteration in range(iterations):
        idx = np.random.choice(len(df), batch_size, replace=False)
        data_batch = df.iloc[idx]
        gen_loss, disc_loss = train_step(generator, discriminator, data_batch, batch_size)
        if iteration % 1000 == 0:
            print(f"Iteration {iteration}, Generator Loss: {gen_loss}, Discriminator Loss: {disc_loss}")

#### Training loop

In [96]:
# Start training
train_gan(generator, discriminator, X_train, iterations=10000, batch_size=128)

Iteration 0, Generator Loss: 0.36589857935905457, Discriminator Loss: 0.7027698755264282
Iteration 1000, Generator Loss: 3.19404935836792, Discriminator Loss: 0.012191230431199074
Iteration 2000, Generator Loss: 4.395117282867432, Discriminator Loss: 0.0012393519282341003
Iteration 3000, Generator Loss: 5.17070198059082, Discriminator Loss: 0.0003341997798997909
Iteration 4000, Generator Loss: 5.684363842010498, Discriminator Loss: 0.00011812297452706844
Iteration 5000, Generator Loss: 6.119815826416016, Discriminator Loss: 5.725143273593858e-05
Iteration 6000, Generator Loss: 6.472872734069824, Discriminator Loss: 2.247413431177847e-05
Iteration 7000, Generator Loss: 6.956284999847412, Discriminator Loss: 1.0969339200528339e-05
Iteration 8000, Generator Loss: 7.046768665313721, Discriminator Loss: 5.6455382946296595e-06
Iteration 9000, Generator Loss: 7.546322822570801, Discriminator Loss: 2.9745519896096084e-06


In [97]:
from joblib import dump

dump(generator, 'had_gain_generator.h5')

['had_gain_generator.h5']