Training

In [5]:
import torch.optim as optim
import torch
import torch.nn as nn

# Initialize models
generator = Pulse2PulseGenerator()
discriminator = Pulse2PulseDiscriminator()

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss function
loss_fn = nn.BCEWithLogitsLoss()

# Ensure the data is correctly shaped: [batch_size, channels, sequence_length]
input_data = torch.tensor(input_data, dtype=torch.float32).permute(0, 2, 1)  # Shape: (batch_size, 2, 5000)
output_data = torch.tensor(output_data, dtype=torch.float32).permute(0, 2, 1)  # Shape: (batch_size, 2, 5000)

# Gradient penalty function (for stabilizing discriminator)
def gradient_penalty(discriminator, real_data, fake_data):
    batch_size = real_data.size(0)
    epsilon = torch.rand(batch_size, 1, 1).to(real_data.device)
    epsilon = epsilon.expand_as(real_data)

    # Interpolate between real and fake data
    interpolated = epsilon * real_data + (1 - epsilon) * fake_data
    interpolated.requires_grad_(True)

    # Discriminator's prediction on interpolated data
    interpolated_preds = discriminator(interpolated)

    # Calculate gradients of the discriminator's prediction wrt interpolated data
    gradients = torch.autograd.grad(
        outputs=interpolated_preds,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interpolated_preds),
        create_graph=True,  # Necessary to retain the graph for second backward pass
        retain_graph=True,  # Retain graph for future use
        only_inputs=True,
    )[0]

    # Flatten the gradients to compute their norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Gradient penalty
    return ((gradient_norm - 1) ** 2).mean()

# Training loop
epochs = 5
batch_size = 32  # Adjust batch size based on memory constraints
lambda_gp = 5  # Weight for the gradient penalty

for epoch in range(epochs):
    for i in range(0, len(input_data), batch_size):  # Loop through the dataset in batches

        # Get real and input data for the batch
        real_lead_II = output_data[i:i + batch_size]  # Shape: [batch_size, 2, 5000]
        input_lead_I = input_data[i:i + batch_size]   # Shape: [batch_size, 2, 5000]

        # Discriminator Training
        d_optimizer.zero_grad()

        # Real data loss
        real_preds = discriminator(real_lead_II)  # Forward pass with real {Lead II, Lead II}
        real_loss = loss_fn(real_preds, torch.full_like(real_preds, 0.8))  # Real labels are 1

        # Fake data loss
        generated_lead_II = generator(input_lead_I)  # Generator creates fake {Lead II, Lead II}
        fake_preds = discriminator(generated_lead_II.detach())  # Forward pass with fake data
        fake_loss = loss_fn(fake_preds, torch.full_like(fake_preds, 0.2))  # Fake labels are 0

        # Gradient penalty
        gp = gradient_penalty(discriminator, real_lead_II, generated_lead_II)

        # Total discriminator loss with gradient penalty
        d_loss = real_loss + fake_loss + lambda_gp * gp
        d_loss.backward(retain_graph=True)  # Use retain_graph=True here to avoid freeing the graph prematurely
        if epoch % 2 == 0:
            d_optimizer.step()  # Update discriminator

        # Generator Training
        g_optimizer.zero_grad()

        # Generator tries to fool discriminator
        fake_preds = discriminator(generated_lead_II)  # Discriminator evaluates fake data
        g_loss = loss_fn(fake_preds, torch.ones_like(fake_preds))  # Generator wants fake data to be labeled as real (1)
        g_loss.backward()  # Backpropagate
        g_optimizer.step()  # Update generator

    # Logging
    print(f'Epoch [{epoch + 1}/{epochs}], Generator Loss: {g_loss.item()}, Discriminator Loss: {d_loss.item()}')


Epoch [1/5], Generator Loss: 0.3132617175579071, Discriminator Loss: 6.626523017883301
Epoch [2/5], Generator Loss: 0.3132617175579071, Discriminator Loss: 6.626523017883301


KeyboardInterrupt: 

convert

In [None]:
import os
import wfdb
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Define the path to your original dataset directory (.dat and .hea files)
DATA_DIR = "C:/Users/M2-Winterfell/Downloads/electrocardiography-dataset-1.0.3/records500/02000"

# Define the destination directory for the converted .asc files
DEST_DIR = "C:/Users/M2-Winterfell/Downloads/converted_asc"
if not os.path.exists(DEST_DIR):
    os.makedirs(DEST_DIR)

# Function to convert a single record into .asc format without modifying its content
def convert_ecg_record_to_asc(file_path, out_filepath):
    try:
        # Load the record using wfdb (the file_path should be provided without extension)
        record = wfdb.rdrecord(file_path)
        # Save the complete signal data (all channels) into an .asc file
        np.savetxt(out_filepath, record.p_signal, fmt="%.6f")
    except Exception as e:
        print(f"Error converting {file_path}: {e}")

# Function to convert all ECG records in the folder to .asc format
def convert_all_ecg_to_asc(data_dir, dest_dir):
    for root, dirs, files in os.walk(data_dir):
        for file in tqdm(files, desc="Converting ECG files"):
            if file.endswith("_hr.dat"):
                # Construct the base filename (removing the '_hr.dat' extension)
                base_filename = file.replace("_hr.dat", "_hr")
                record_path = os.path.join(root, base_filename)
                # Define the output .asc file path
                out_filename = base_filename + ".asc"
                out_filepath = os.path.join(dest_dir, out_filename)
                convert_ecg_record_to_asc(record_path, out_filepath)

# Convert all records in the DATA_DIR to .asc format
convert_all_ecg_to_asc(DATA_DIR, DEST_DIR)