# Preprocessing steps for the project

In [3]:
# Import libraries required
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt 
import os
import wfdb
import pickle
import sys
import glob
from scipy.signal import butter, lfilter


In [4]:
# Root folder of the local MIT-BIH Arrhythmia Database copy.
# Keep the trailing slash: later cells build record paths from this value by
# plain string concatenation and by slicing off its last character.
file_path = 'G:/Datasets/mit-bih-arrhythmia-database-1.0.0/'

## Check for GPU 
Background checks to make sure CUDA is set up and linked to PyTorch. See below for setup:

<a href="https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exe_local">Download CUDA</a>

```pip uninstall torch torchvision torchaudio```

```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118```

Then, to verify the installation, run the command below in a terminal:

```python -c "import torch; print('CUDA available:', torch.cuda.is_available(), 'Number of GPUs:', torch.cuda.device_count(), 'Current device:', torch.cuda.current_device(), 'Device name:', torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else 'No GPU found')"```

In [None]:
# Select the compute device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

if device.type == 'cuda':
    print('CUDA Device Name:', torch.cuda.get_device_name(0))
    print('CUDA Version:', torch.version.cuda)
    print('PyTorch Version:', torch.__version__)

    # Memory figures are reported in GiB (bytes / 1024**3), rounded to one decimal.
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 1024**3, 1), 'GB')

    # Total memory info
    total_memory = torch.cuda.get_device_properties(0).total_memory
    print('Total Memory:', round(total_memory / 1024**3, 1), 'GB')

    # Estimate only: total device memory minus what this process has reserved.
    free_memory = total_memory - torch.cuda.memory_reserved(0)
    print('Free Memory:', round(free_memory / 1024**3, 1), 'GB')

    # Sanity check: allocate the tensor on the selected device so the check
    # actually exercises the CUDA setup (a bare torch.rand() stays on the CPU).
    x = torch.rand(2, 3, device=device)
    print('Test Tensor:', x)
else:
    print('Using the CPU, no GPU found')


## Load Dataset, simple data exploration
To load the dataset, we will use the `WFDB` (Waveform Database) library. It is part of the PhysioNet project: a software package for reading, writing, and processing physiological signals, primarily ECG (electrocardiogram) signals.

In [6]:
def get_records(path):
    """Return the sorted record paths (without extension) found in a dataset folder.

    Each patient record consists of several files (.atr, .dat, .hea, .xws); a
    record is detected through its ``.atr`` annotation file.

    Args:
        path: Folder that holds the record files, with or without a trailing
            path separator.

    Returns:
        Sorted list of record paths with the ``.atr`` extension removed. The
        superseded ``102-0`` annotation file is excluded.
    """
    # os.path.join keeps an existing trailing separator and inserts one when it
    # is missing, so both 'dir/' and 'dir' produce a valid pattern.
    pattern = os.path.join(path, '*.atr')
    records = sorted(os.path.splitext(atr_path)[0] for atr_path in glob.glob(pattern))

    # PhysioNet note, 7/06/2018: file 102.atr has been edited. Annotation number
    # 1991 (0 indexed) was shifted from sample 590296 to 590262; the original is
    # kept as 102-0.atr. Compare on the base name so the exclusion does not
    # depend on which path separator the platform uses.
    return [record for record in records if os.path.basename(record) != '102-0']


In [7]:
def show_patient_info(record_path):
    """Print the metadata of one record, or a short message when it cannot be read.

    Args:
        record_path: Record path without file extension.
    """
    try:
        # Per the dataset/library note: rdsamp returns the ECG samples at
        # index 0 and the patient/record metadata at index 1.
        metadata = wfdb.rdsamp(record_path)[1]
        print(metadata)
        # Earlier drafts printed individual metadata entries instead of the
        # whole object: fs, sig_len, n_sig, base_date, base_time, units,
        # sig_name, comments.
        print()
    except FileNotFoundError:
        print(f"File not found: {record_path}, check file name + path xdd")
    except Exception as err:
        # Best effort: report any other loading problem instead of aborting the cell.
        print(f"couldn't load file, error: {err}")




In [None]:
# Collect every record path in the dataset folder and show them one per line.
record_paths = get_records(file_path)
if record_paths:
    print('\n'.join(record_paths))

# Show the metadata of a single patient: record 100 inside the dataset folder.
record_name = f'{file_path}100'
show_patient_info(record_name)

In [12]:
def load_record(file_path):
    """
    Read one record's ECG signal together with its 'atr' annotations.

    Args:
        file_path: Record path without file extension.

    Returns:
        signal: the record's ``p_signal`` array, 2D (samples, channels)
        annotations: annotation object read from the record's 'atr' file
    """
    # Read the signal first, then the annotations, both from the same record path.
    signal = wfdb.rdrecord(file_path).p_signal
    return signal, wfdb.rdann(file_path, 'atr')

def butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Compute Butterworth bandpass coefficients for denoising the ECG signal.

    Args:
        lowcut: lower cut-off frequency, in the same unit as ``fs``
        highcut: upper cut-off frequency, in the same unit as ``fs``
        fs: sampling frequency of the signal
        order: order passed to ``scipy.signal.butter``

    Returns:
        (b, a): numerator and denominator coefficient arrays
    """
    # butter() expects the band edges as fractions of the Nyquist frequency.
    nyquist = fs / 2
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, btype='band')

# Apply the bandpass filter to the ECG signal
def bandpass_filter(signal, lowcut=0.5, highcut=50.0, fs=360, order=3):
    """
    Run the ECG signal through a Butterworth bandpass filter.

    The coefficients come from ``butter_bandpass`` and are applied in a single
    pass with ``scipy.signal.lfilter``.

    Args:
        signal: samples to filter
        lowcut, highcut: band edges, in the same unit as ``fs``
        fs: sampling frequency of the signal
        order: filter order forwarded to ``butter_bandpass``

    Returns:
        The filtered signal.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    filtered = lfilter(numerator, denominator, signal)
    return filtered

def preprocess_signal(signal):
    """
    Bandpass-filter the signal, then standardise it to zero mean and unit variance.

    Args:
        signal: raw samples of a single channel

    Returns:
        The filtered signal with its mean subtracted and divided by its
        standard deviation. When the filtered signal has zero variance (for
        example an all-zero channel) the centred signal is returned without
        scaling, so the result stays finite instead of becoming NaN.
    """
    filtered_signal = bandpass_filter(signal)
    centered_signal = filtered_signal - np.mean(filtered_signal)
    spread = np.std(filtered_signal)
    if spread == 0:
        # Dividing by a zero standard deviation would turn every sample into NaN.
        return centered_signal
    return centered_signal / spread

def segment_signal(signal, segment_size=900):
    """
    Segment the signal into non-overlapping, equally sized segments.

    Trailing samples that do not fill a complete segment are dropped.

    Args:
        signal: array-like whose first axis is the sample axis
        segment_size: number of samples per segment; must be positive

    Returns:
        A new array of shape (num_segments, segment_size, ...). A signal
        shorter than one segment yields shape (0, segment_size, ...), so the
        result always has the segment axis and the sample axis.

    Raises:
        ValueError: if ``segment_size`` is not positive.
    """
    if segment_size <= 0:
        raise ValueError("segment_size must be a positive integer")

    samples = np.asarray(signal)
    num_segments = len(samples) // segment_size
    trimmed = samples[:num_segments * segment_size]
    # copy() keeps the result independent of the input, as the list-based version was.
    return trimmed.reshape((num_segments, segment_size) + samples.shape[1:]).copy()

def label_segments(segments, annotations, threshold=5,
                   normal_symbols=('N', '+', '~', '|', '"')):
    """
    Label each segment as irregular (1) or regular (0) from the record annotations.

    A segment is labelled 1 when at least one annotation whose symbol is not in
    ``normal_symbols`` falls inside it. Labelling on every annotation would
    also count normal beats, so nearly every segment would be marked irregular.

    Args:
        segments: equally sized segments, indexed as segments[i][sample]
        annotations: object with a ``sample`` sequence (annotation positions in
            samples) and, optionally, a parallel ``symbol`` sequence
        threshold: accepted for backward compatibility; currently unused
        normal_symbols: symbols that do NOT mark a segment as irregular. The
            default covers the normal beat ('N') plus the rhythm-change,
            signal-quality, isolated-artifact and comment markers.
            NOTE(review): confirm this set against the intended labelling scheme.

    Returns:
        1D float array with one label per segment.

    Raises:
        ValueError: if ``symbol`` and ``sample`` have different lengths.
    """
    num_segments = len(segments)
    labels = np.zeros(num_segments)
    if num_segments == 0:
        return labels

    segment_size = len(segments[0])
    if segment_size == 0:
        return labels

    samples = np.asarray(annotations.sample, dtype=np.intp)

    # Without symbols there is nothing to filter on; every annotation counts.
    symbols = getattr(annotations, 'symbol', None)
    if symbols is not None:
        if len(symbols) != len(samples):
            raise ValueError("annotations.symbol and annotations.sample differ in length")
        ignored = set(normal_symbols)
        keep = np.array([symbol not in ignored for symbol in symbols], dtype=bool)
        samples = samples[keep]

    segment_idx = samples // segment_size
    # Drop annotations in the discarded tail and guard against negative
    # positions, which would otherwise wrap around to the last segments.
    in_range = (segment_idx >= 0) & (segment_idx < num_segments)
    labels[segment_idx[in_range]] = 1  # 1 for irregular
    return labels

# Processing All Records and Saving
This section processes each record using the defined functions, then saves each patient's processed segments and labels into separate files for efficient loading later.

In [None]:
# Output folder: receives one segments file and one labels file per record.
preprocessed_dir = './preprocessed_data'
os.makedirs(preprocessed_dir, exist_ok=True)

# Dataset folder: file_path without its last character (the trailing slash).
data_dir = file_path[:-1]
# Record names are the .dat file names with the extension cut off.
record_files = [name[:-len('.dat')] for name in os.listdir(data_dir) if name.endswith('.dat')]

for record_name in record_files:
    # Raw two-channel signal plus the annotations of this record.
    signal, annotations = load_record(os.path.join(data_dir, record_name))

    # Filter and normalise the two channels independently ...
    processed_ch1, processed_ch2 = (preprocess_signal(signal[:, ch]) for ch in (0, 1))

    # ... then cut each of them into fixed-length segments.
    segments_ch1, segments_ch2 = (
        segment_signal(channel) for channel in (processed_ch1, processed_ch2)
    )

    # Channels become the last axis of the stacked array.
    segments = np.stack([segments_ch1, segments_ch2], axis=-1)

    # Labels depend only on annotation positions, so one channel is enough.
    labels = label_segments(segments_ch1, annotations)

    # Persist both arrays under the record's name.
    np.save(os.path.join(preprocessed_dir, f"{record_name}_segments.npy"), segments)
    np.save(os.path.join(preprocessed_dir, f"{record_name}_labels.npy"), labels)