In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from scipy.stats import mode
from scipy import stats
from multiprocessing import Pool
import neurokit2 as nk
from biosppy.signals import ecg
import biosppy
import os

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from scipy import stats

import matplotlib.pyplot as plt

In [None]:
def load_ecg_from_id(ecg_id, base_path='/home/ngsci/project/NEJM_benchmark/waveforms_12lead_10sec'):
    """
    Read the saved waveform array for a single ECG.

    The file is looked up at ``<base_path>/<ecg_id>.npy``.

    Parameters:
    -----------
    ecg_id : str
        Identifier of the ECG; used as the file stem
    base_path : str, optional
        Directory holding the ``.npy`` waveform files

    Returns:
    --------
    numpy.ndarray
        The array stored in the file (used as 12-lead ECG data by the
        plotting code in this notebook)

    Raises:
    -------
    FileNotFoundError
        If nothing exists at the computed path
    """
    npy_path = os.path.join(base_path, '{}.npy'.format(ecg_id))

    # Happy path first: the file is there, hand back its contents.
    if os.path.exists(npy_path):
        return np.load(npy_path)

    # Fail with a message that names both the ID and the path that was tried.
    raise FileNotFoundError(f"ECG file for ID {ecg_id} not found at {npy_path}")

def plot_12lead_ecg(arr):
    """
    Plot 12-lead ECG with each lead in a separate row.

    For every lead the most frequent non-NaN sample value is subtracted as a
    baseline, remaining NaNs are replaced by 0, and leading/trailing zeros
    are trimmed before the trace is drawn on a red ECG-style grid.

    Parameters:
    -----------
    arr : numpy.ndarray or sequence of 1-D array-likes
        12-lead ECG data, iterated along its first axis (one lead per row).
        Rows are paired in order with the labels I, II, III, aVR, aVL, aVF,
        V1-V6; extra rows beyond the twelfth are ignored.

    Returns:
    --------
    None
        The figure is displayed with ``plt.show()`` and then closed.
    """
    # Lead labels in standard order
    lead_labels = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                   'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    # Create a figure with 12 rows (one for each lead)
    fig, axs = plt.subplots(12, 1, figsize=(10, 15))

    # Process and plot each lead; zip stops at the shortest of axes/leads/labels
    for ax, lead, label in zip(axs, arr, lead_labels):
        # Coerce per lead (not the whole input) so list rows and rows of
        # differing length are both accepted; np.isnan needs a float array.
        lead = np.asarray(lead, dtype=float)

        # Remove NaNs
        lead_without_nan = lead[~np.isnan(lead)]
        if len(lead_without_nan) == 0:
            print(f"No valid data for lead {label}")
            continue

        # Baseline correction: subtract the most frequent sample value.
        # Computed with np.unique instead of scipy.stats.mode(...)[0][0],
        # because SciPy >= 1.11 returns a scalar mode for 1-D input and the
        # double index raises IndexError there. np.unique returns sorted
        # values and argmax picks the first maximum, so ties resolve to the
        # smallest value, the same rule scipy.stats.mode uses.
        values, counts = np.unique(lead_without_nan, return_counts=True)
        mode_value = values[np.argmax(counts)]
        lead_corrected = lead - mode_value
        lead_corrected = np.nan_to_num(lead_corrected, nan=0)

        # Remove leading/trailing zeros that represent missing data.
        # NOTE(review): samples exactly equal to the baseline are also zero
        # after correction, so flat baseline at either end is trimmed too.
        nonzero_indices = np.where(lead_corrected != 0)[0]
        if len(nonzero_indices) > 0:
            lead_corrected = lead_corrected[nonzero_indices[0]:nonzero_indices[-1] + 1]
        else:
            print(f"All zeros for lead {label}")
            continue

        # Plot the lead
        ax.plot(lead_corrected, color='black', linewidth=0.75)

        # Add lead label, horizontally centred near the top of the panel
        ax.text(len(lead_corrected) // 2, 1.2, label, color='black',
                ha='center', va='bottom', fontsize=10)

        # Customize ECG-like grid: fixed y-range, major lines every
        # 100 samples / 0.5 units and minor lines every 20 samples / 0.1 units
        # (presumably mV on the y-axis; confirm against the data source).
        ax.set_ylim(-1.6, 1.6)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.yaxis.set_major_locator(MultipleLocator(0.5))
        ax.xaxis.set_minor_locator(MultipleLocator(20))
        ax.yaxis.set_minor_locator(MultipleLocator(0.1))
        ax.grid(which='major', color='red', linestyle='-', linewidth=0.5)
        ax.grid(which='minor', color='red', linestyle='-', linewidth=0.2)
        ax.set_facecolor('white')

        # Hide x-axis labels
        ax.tick_params(axis='x', which='major', labelbottom=False)
        ax.tick_params(axis='x', which='minor', labelbottom=False)

    plt.tight_layout()
    plt.show()
    # Release the figure so repeated calls in a loop do not pile up open figures
    plt.close(fig)

def process_ecg(ecg_id):
    """
    Load the ECG stored under ``ecg_id`` and draw it as a 12-lead plot.

    Parameters:
    -----------
    ecg_id : str
        Unique identifier for the ECG file
    """
    # Fetch the waveform and pass it straight to the plotting routine
    plot_12lead_ecg(load_ecg_from_id(ecg_id))

In [None]:
# Read the benchmark CSV into a DataFrame. Judging by the file name it holds
# IDs, labels and covariates for untested ECGs (TODO confirm schema); later
# cells read its `ecg_id_new` column.
test_df = pd.read_csv("/home/ngsci/project/NEJM_benchmark/all_ids_labels_untested_with_covars_all_final.csv")

In [None]:
# Take the first seven ECG identifiers from the `ecg_id_new` column
test_ids = test_df['ecg_id_new'].iloc[:7].tolist()

# Render each selected ECG, then print a divider line padded by blank lines
for idx in test_ids:
    process_ecg(idx)
    print('\n################################################################\n')