In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# Location of the scaled neonatal BPM recordings.
file_path = "/home/Gurshan.R/Documents/GitHub/SYSC4907_Capstone/GAN_Bradycardia/Scaled_CSV_DATA/BPM_8hr_5min_Neonatal.csv"

# header=None: the file is assumed to have no header row, so the first line
# is parsed as data rather than as column names.
df = pd.read_csv(filepath_or_buffer=file_path, header=None)

# Preview the leading rows as a quick check of the parsed layout.
df.head()

In [None]:
# Materialize the frame as a float32 NumPy array; this is the un-normalized
# copy that later cells use for plotting and episode statistics.
# NOTE(review): expected layout is (num_samples, 96) - confirm against the CSV.
data = df.to_numpy(dtype=np.float32)

In [None]:
# Min-max scale the BPM values into [0, 1] using the global minimum and
# maximum of the whole array. min_val / max_val are kept so that generated
# samples can be mapped back to the BPM scale later.
min_val, max_val = data.min(), data.max()
value_range = max_val - min_val

# A constant dataset (range == 0) or one containing NaN/inf makes the scaling
# undefined; without this guard the division would silently fill the
# training data with NaN and the GAN would train on garbage.
if not np.isfinite(value_range) or value_range == 0:
    raise ValueError(
        f"Cannot min-max normalize BPM data: min={min_val}, max={max_val}. "
        "Check the input CSV for missing values or a constant signal."
    )

data_normalized = (data - min_val) / value_range

In [None]:
# Turn the normalized array into a float32 tensor that lives on the training
# device. Note that this rebinds `data_normalized` from a NumPy array to a tensor.
data_normalized = torch.tensor(data_normalized, dtype=torch.float32, device=device)

In [None]:
# Generator half of the GAN: maps a noise vector to one synthetic sequence.
class Generator(nn.Module):
    """Fully connected generator.

    Architecture: noise_dim -> 128 -> 256 -> output_dim, with ReLU between the
    hidden layers and a final Sigmoid so every output value lies in [0, 1].

    Args:
        noise_dim: Size of the input noise vector.
        output_dim: Number of values in each generated sequence.
    """

    def __init__(self, noise_dim, output_dim):
        super().__init__()
        hidden_widths = (128, 256)

        # Layers are appended in the same order as a hand-written Sequential,
        # so parameter names (model.0, model.2, model.4) are unchanged.
        layers = []
        in_features = noise_dim
        for width in hidden_widths:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, output_dim))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, noise):
        """Return generated samples of shape (batch, output_dim) for `noise` of shape (batch, noise_dim)."""
        generated = self.model(noise)
        return generated

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator.

    Architecture: input_dim -> 128 -> 64 -> 1, with ReLU between the hidden
    layers and a final Sigmoid, so the single output is a value in [0, 1]
    that the training code compares against real (1) / fake (0) labels.

    Args:
        input_dim: Number of values in each input sequence.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_widths = (128, 64)

        # Same construction order as a hand-written Sequential, so parameter
        # names (model.0, model.2, model.4) are unchanged.
        layers = []
        in_features = input_dim
        for width in hidden_widths:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, 1) tensor of scores for `x` of shape (batch, input_dim)."""
        score = self.model(x)
        return score

In [None]:
# Seed PyTorch's random number generators before any weights are created so
# that weight initialization, batch sampling and noise draws repeat across a
# Restart & Run All. (Some GPU kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

# Set dimensions
noise_dim = 10  # Length of the random noise vector fed to the generator
output_dim = 96  # Length of each generated sequence; must equal the column count of the training data

# Initialize GAN models on the selected device
generator = Generator(noise_dim, output_dim).to(device)
discriminator = Discriminator(output_dim).to(device)

# Binary cross-entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network so each can be stepped independently.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

In [None]:
# Training hyper-parameters
epochs = 5000      # number of training iterations
batch_size = 16    # sequences drawn per iteration

# Constant BCE targets, allocated once on the training device:
# 1 marks a sample as real, 0 marks it as generated.
label_shape = (batch_size, 1)
real_labels = torch.ones(label_shape, device=device)
fake_labels = torch.zeros(label_shape, device=device)

In [None]:
# Adversarial training. Note that each "epoch" here is a single optimization
# step on one randomly drawn mini-batch, not a full pass over the dataset.
for epoch in range(epochs):
    # ---- Train Discriminator ----
    # Clear discriminator gradients, including any that the previous
    # iteration's generator backward pass accumulated in it.
    discriminator.zero_grad()
    # Draw batch_size row indices uniformly at random (with replacement).
    idx = torch.randint(0, data_normalized.size(0), (batch_size,))
    real_data = data_normalized[idx].to(device)
    real_loss = criterion(discriminator(real_data), real_labels)

    noise = torch.randn(batch_size, noise_dim).to(device)
    # detach(): the discriminator update must not backpropagate into the generator.
    fake_data = generator(noise).detach()
    fake_loss = criterion(discriminator(fake_data), fake_labels)
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer_D.step()

    # ---- Train Generator ----
    generator.zero_grad()
    # Fresh noise for the generator step.
    noise = torch.randn(batch_size, noise_dim).to(device)
    fake_data = generator(noise)
    # Fakes are scored against the "real" labels, so the loss is low when the
    # discriminator classifies generated samples as real.
    g_loss = criterion(discriminator(fake_data), real_labels)
    g_loss.backward()
    optimizer_G.step()

    # Report the most recent losses every 500 iterations.
    if epoch % 500 == 0:
        print(f"Epoch [{epoch}/{epochs}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

# Generate synthetic BPM sequences (still in the normalized [0, 1] scale).
# no_grad(): sampling only, so no autograd graph is needed.
with torch.no_grad():
    noise = torch.randn(10, noise_dim).to(device)  # 10 noise vectors -> 10 synthetic sequences
    # Move the result to the CPU and convert to NumPy for saving and plotting.
    synthetic_data = generator(noise).cpu().numpy()

In [None]:
# Invert the min-max scaling: map generator output from [0, 1] back onto the
# BPM scale of the original data.
bpm_span = max_val - min_val
synthetic_data_denormalized = synthetic_data * bpm_span + min_val

In [None]:
# Write the denormalized synthetic sequences to CSV, one sequence per row.
synthetic_output_file = "/home/Gurshan.R/Documents/GitHub/SYSC4907_Capstone/GAN_Bradycardia/Synthetic_CSV_DATA/BPM_8hr_5min_Synthetic.csv"
np.savetxt(fname=synthetic_output_file, X=synthetic_data_denormalized, delimiter=",")

print(f"Synthetic BPM data saved: {synthetic_output_file}")

In [None]:
# Visual sanity check: overlay the first synthetic sequence on the mean of
# the first ten real sequences.
fig, ax = plt.subplots(figsize=(10, 6))

first_synthetic = synthetic_data_denormalized[0]
real_mean_first10 = data[:10].mean(axis=0)

ax.plot(first_synthetic, label="Synthetic BPM", linestyle="--", color="orange", alpha=0.8)
ax.plot(real_mean_first10, label="Real BPM (Average)", linestyle="-", color="blue", alpha=0.6)

ax.set_xlabel("Time Steps (5-min intervals)")
ax.set_ylabel("BPM")
ax.set_title("Comparison of Real and Synthetic BPM Data")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Persist only the generator's learned parameters (its state_dict), not the
# discriminator or the optimizer state.
generator_weights_path = "/home/Gurshan.R/Documents/GitHub/SYSC4907_Capstone/GAN_Bradycardia/trained_generator_Bradycardia.pth"
generator_weights = generator.state_dict()
torch.save(generator_weights, generator_weights_path)

In [None]:
# ----- Bradycardia Episode Evaluation Code -----

from scipy.stats import ks_2samp

# BPM value below which a reading counts as bradycardia (tune as needed).
brady_threshold = 100

def extract_episodes(sequence, threshold):
    """
    Split a BPM sequence into runs of consecutive below-threshold readings.

    A reading belongs to an episode only when it is strictly less than
    `threshold`; any other reading closes the episode currently in progress.

    Returns a list of episodes in order of occurrence, each a list holding
    the BPM values of that run. The list is empty when nothing is below
    the threshold.
    """
    found = []
    run = []
    for value in sequence:
        if not value < threshold:
            # Reading is not below the threshold: close the open run, if any.
            if run:
                found.append(run)
                run = []
            continue
        run.append(value)
    # A run that reaches the end of the sequence is still a complete episode.
    if run:
        found.append(run)
    return found

def _episode_metrics(sequences, threshold):
    """Collect (duration, minimum BPM) for every bradycardia episode.

    Args:
        sequences: Iterable of BPM sequences (e.g. the rows of a 2-D array).
        threshold: BPM value below which a reading counts as bradycardia.

    Returns:
        Float array of shape (n_episodes, 2): column 0 is the episode duration
        in timesteps, column 1 the minimum BPM in the episode. The array is
        always 2-D, with shape (0, 2) when no episode is found, so callers can
        slice columns without special-casing the empty result.
    """
    rows = [
        (len(ep), min(ep))
        for sequence in sequences
        for ep in extract_episodes(sequence, threshold)
    ]
    return np.array(rows, dtype=float).reshape(-1, 2)


def _report_episodes(label, metrics):
    """Print count, mean duration and mean minimum BPM, or a notice when `metrics` is empty."""
    if len(metrics) == 0:
        print(f"No bradycardia episodes found in {label.lower()} data.")
        return
    durations, min_bpm = metrics[:, 0], metrics[:, 1]
    print(f"{label} Bradycardia Episodes: Count = {len(durations)}, "
          f"Mean Duration = {np.mean(durations):.2f} timesteps, "
          f"Mean Min BPM = {np.mean(min_bpm):.2f}")


# --- Evaluate Episodes in Real Data ---
# 'data' is the original (un-normalized) real data loaded from the CSV.
real_episode_metrics = _episode_metrics(data, brady_threshold)
real_durations = real_episode_metrics[:, 0]
real_min_bpm = real_episode_metrics[:, 1]
_report_episodes("Real", real_episode_metrics)

# --- Evaluate Episodes in Synthetic Data ---
synthetic_episode_metrics = _episode_metrics(synthetic_data_denormalized, brady_threshold)
synthetic_durations = synthetic_episode_metrics[:, 0]
synthetic_min_bpm = synthetic_episode_metrics[:, 1]
_report_episodes("Synthetic", synthetic_episode_metrics)

# --- Quantitative Comparison Using KS Tests ---
# Both metric containers are always arrays here, so the emptiness check can no
# longer fail with AttributeError when one side has no episodes.
if len(real_episode_metrics) > 0 and len(synthetic_episode_metrics) > 0:
    # Compare episode durations
    duration_ks_stat, duration_ks_p = ks_2samp(real_durations, synthetic_durations)
    # Compare minimum BPM during episodes
    min_bpm_ks_stat, min_bpm_ks_p = ks_2samp(real_min_bpm, synthetic_min_bpm)

    print(f"KS test on episode durations: Statistic = {duration_ks_stat:.4f}, p-value = {duration_ks_p:.4f}")
    print(f"KS test on episode minimum BPM: Statistic = {min_bpm_ks_stat:.4f}, p-value = {min_bpm_ks_p:.4f}")
else:
    print("Skipping KS tests: at least one dataset has no bradycardia episodes.")
