In [9]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

In [18]:
import pandas as pd
import wfdb
import os

crfs_df = pd.read_csv('CRFs.csv')

def load_physionet_data(directory):
    ecg_data = []
    for file in os.listdir(directory):
        if file.endswith(".hea"):
            # Extract the record name (without the .hea extension)
            record_name = file.split(".")[0]
            # Read the signal and metadata
            record = wfdb.rdrecord(os.path.join(directory, record_name), sampfrom=20000, sampto=20000+(128*60))
            # Get the Sample ID (assume record_name is the ID)
            sample_id = record_name.lstrip('0')
            # Get the signals (columns represent leads)
            signals = record.p_signal
            # Separate into individual leads (assuming 3 leads)
            lead_II = signals[:, 0]
            lead_V3 = signals[:, 1]
            lead_V5 = signals[:, 2] if signals.shape[1] > 2 else None
            ecg_data.append({
                "Record": int(sample_id),
                "Lead II": lead_II,
                "Lead V3": lead_V3,
                "Lead V5": lead_V5
            })
    return pd.DataFrame(ecg_data)

ecg_df = load_physionet_data("dataset/")
combined_df = pd.merge(crfs_df, ecg_df, on='Record', how='inner')

ecg_signals = np.array([np.concatenate([row['Lead II'], row['Lead V3'], row['Lead V5']]) for _, row in combined_df.iterrows()])
crfs = combined_df[['Gender', 'Age', 'Weight', 'Height', 'BSA', 'BMI', 'Smoker', 'SBP', 'DBP']].values

ecg_signal_length = ecg_signals.shape[1]
real_ecg_signals = ecg_signals.reshape(-1, ecg_signal_length)

real_data = np.hstack((real_ecg_signals, crfs))
real_data

array([[0.11428571428571428, 0.17142857142857143, 0.11428571428571428,
        ..., 'yes', 140.0, 80.0],
       [0.11428571428571428, 0.11428571428571428, 0.08571428571428572,
        ..., 'no', 130.0, 75.0],
       [0.0, 0.0, 0.05128205128205128, ..., 'no', 177.0, 75.0],
       ...,
       [-0.027777777777777776, -0.08333333333333333, 0.1388888888888889,
        ..., 'no', 125.0, 65.0],
       [0.027777777777777776, 0.05555555555555555, 0.027777777777777776,
        ..., 'no', 120.0, 80.0],
       [-0.16666666666666666, -0.16666666666666666, -0.05555555555555555,
        ..., 'yes', 120.0, 80.0]], dtype=object)

In [21]:
from keras import metrics, optimizers, Sequential, layers

# Define the generator model
def build_generator(latent_dim, crf_dim, ecg_signal_length):
    model = Sequential()
    model.add(layers.Dense(128, activation='relu', input_dim=latent_dim))
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.Dense(ecg_signal_length + crf_dim, activation='tanh'))  # Output both ECG signal and CRFs
    return model

# Define the discriminator model
def build_discriminator(ecg_signal_length, crf_dim):
    model = Sequential()
    model.add(layers.Dense(512, activation='relu', input_shape=(ecg_signal_length + crf_dim,)))
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dense(128, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

# Define the GAN model
def build_gan(generator, discriminator):
    discriminator.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model

# Set the dimensions
latent_dim = 100
crf_dim = crfs.shape[1]

# Build and compile the discriminator
discriminator = build_discriminator(ecg_signal_length, crf_dim)
discriminator.compile(optimizer=optimizers.Adam(), loss='binary_crossentropy', metrics=['accuracy'])

# Build the generator
generator = build_generator(latent_dim, crf_dim, ecg_signal_length)

# Build and compile the GAN
gan = build_gan(generator, discriminator)
gan.compile(optimizer=optimizers.Adam(), loss='binary_crossentropy', metrics=['accuracy'])

# Function to generate synthetic data
def generate_synthetic_data(generator, latent_dim, num_samples):
    noise = np.random.normal(0, 1, (num_samples, latent_dim))
    synthetic_data = generator.predict(noise)
    return synthetic_data

# Training the GAN
def train_gan(generator, discriminator, gan, latent_dim, epochs, batch_size, real_data):
    for epoch in range(epochs):
        # Generate random noise
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        
        # Generate synthetic data (ECG signals and CRFs)
        generated_data = generator.predict(noise)
        
        # Sample real data from the dataset
        idx = np.random.randint(0, real_data.shape[0], batch_size)
        real_batch = real_data[idx]
        print(real_batch)
        
        # Create labels
        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))
        
        # Train the discriminator
        d_loss_real = discriminator.train_on_batch(real_batch, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_data, fake_labels)
        
        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real_labels)
        
        # Print the progress
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        print(f"Epoch {epoch + 1}/{epochs} [D loss: {d_loss}] [G loss: {g_loss}]")

# Set parameters
epochs = 10000
batch_size = 32

# Train the GAN with the provided dataset
train_gan(generator, discriminator, gan, latent_dim, epochs, batch_size, real_data)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 477ms/step
[[0.14285714285714285 0.14285714285714285 0.11428571428571428 ... 'no'
  135.0 75.0]
 [0.08333333333333333 0.08333333333333333 0.1111111111111111 ... 'no'
  125.0 85.0]
 [0.0 0.02857142857142857 0.05714285714285714 ... 'no' 140.0 70.0]
 ...
 [0.4166666666666667 0.4444444444444444 0.3888888888888889 ... 'yes'
  150.0 86.0]
 [0.17142857142857143 0.17142857142857143 0.08571428571428572 ... 'no'
  105.0 70.0]
 [0.05555555555555555 0.05555555555555555 0.08333333333333333 ... 'no'
  120.0 75.0]]


AttributeError: 'NoneType' object has no attribute 'update_state'