In [1]:
import h5py
import pandas as pd
import numpy as np
from Utilities import *
from tqdm import tqdm
from torch.utils.data import DataLoader

In [2]:
import numpy as np
from torch.utils.data import Dataset
import h5py
from scipy.signal import resample_poly
import torch
import pandas as pd


class CustomHD5Dataset(Dataset):
	"""
	This class has methods to read the hd5 file and dump the data into csv according to segment provided in csv_file_path.
	Since reading hd5 file is too slow, dumping data into csv removes this bottleneck.
	"""
	def __init__(self, hd5_path, csv_file_path, sampling_freq=125, sample_len_sec=3, original_freq=125, data_group_name='Waveforms/ART_na', timestamp_group_name='Waveforms/ART_na_Timestamps' ):
		"""
		
		Args:
			hd5_path:        Filepath to the hdf5 file
			csv_file_path:   This csv contains each interval to consider
		
		"""
		
		self.csv_file_path = csv_file_path
		self.hd5_path = hd5_path
		self.sample_len = sample_len_sec * sampling_freq
		self.sampling_freq = sampling_freq
		self.sample_len_sec = sample_len_sec
		
		with h5py.File(hd5_path, 'r') as file:
			dataset = file[data_group_name]
			timestamp = file[timestamp_group_name]
	
			data = dataset[:]
			timestamp = timestamp[:]
		
		self.intervals = np.genfromtxt(csv_file_path, delimiter=',', skip_header=1)
		
		if sampling_freq!=original_freq:
			down_val = int(original_freq/sampling_freq)
			# Resampling the data for efficiency
			self.data = resample_poly(data, up=1, down=down_val)
			self.timestamp = timestamp[::down_val]
		else:
			self.data = data
			self.timestamp = timestamp

		
		
		
	def __len__(self):
		return len(self.intervals)
	
	def __getitem__(self, idx):
		
		start_time, end_time, label = self.intervals[idx]
		
		# Efficiently find the start and end indices (assuming sorted timestamps)
		# This is a placeholder for a more efficient search method
		
		################ BOTTLE NECK #####################
		start_idx = np.searchsorted(self.timestamp, start_time, side='left')
		end_idx = start_idx + (self.sampling_freq *  self.sample_len_sec)
		# For 125Hz, the sample_len is 1250 for 10sec sample
		##################################################

		interval_data = self.data[start_idx:end_idx]

		# Define a fixed sequence length
		fixed_length = self.sample_len  # Example fixed length

		# Pad or truncate the sequence to the fixed length
		if len(interval_data) > fixed_length:
			interval_data = interval_data[:fixed_length]  # Truncate

		elif len(interval_data) < fixed_length:
			padding = fixed_length - len(interval_data)  # Calculate padding size
			interval_data = np.pad(interval_data, (0, padding), 'constant', constant_values=0)  # Pad

		# Convert to tensor
		interval_data = torch.tensor(interval_data, dtype=torch.float)

		return idx, interval_data, label
	


class CSVSignalDataset(Dataset):
	def __init__(self, csv_data_file, train_data_file):
		"""
		Args:
			csv_data_file: Path to the CSV file containing signals.
			train_data_file: To get the mean and std to standardize
		"""
		  
		self.csv_data_file = csv_data_file
		self.train_data_file = train_data_file

		self.data_frame = pd.read_csv(csv_data_file, skiprows=1)
		self.labels = torch.tensor(self.data_frame.iloc[:, 1].values).long()
		self.data = torch.tensor(self.data_frame.iloc[:, 2:].values).float()
		self.mean, self.std = self.get_mean_std()
		
	
	def get_mean_std(self):
		#################
		csv_path = self.train_data_file
		data = np.loadtxt(csv_path, delimiter=',',skiprows=1)[:,2:]
		
		# Calculate the mean and standard deviation
		mean = np.mean(data, axis=0)
		std = np.std(data, axis=0)

		# Ensure std is not zero to avoid division by zero
		std = np.where(std == 0, 1, std)
		return mean, std

	def __len__(self):
		# Return the number of rows in the DataFrame
		return len(self.data_frame)
	
	def __getitem__(self, idx):
		"""
		Args:
			idx: Index of the data sample.
			
		Returns:
			A tuple (ecg_sample, label) where ecg_sample is the ECG data as a tensor
			and label is the corresponding label as a tensor.
		"""
		sample = (self.data[idx] - self.mean) / self.std
		
		label = self.labels[idx]
		return idx, sample, label

In [13]:
hdf5_file_path = '/home/ms5267@drexel.edu/moberg-precicecap/data/Patient_2021-12-21_04_16.h5'
annotation_file = '/home/ms5267@drexel.edu/moberg-precicecap/data/20240207-annotations-export-workspace=precicecap-patient=7-annotation_group=90.csv'
annotation_metadata = {
	'modality':'ART'
	,'location':'na'
	,'scale_wrt_hd5':1e3
}
segment_length_sec = 10

data_group_name='Waveforms/ART_na'
timestamp_group_name = 'Waveforms/ART_na_Timestamps'
segment_length_sec = segment_length_sec
sampling_frequency = 125
signal_type = 'ABP'
num_segments = 10000



In [4]:
df_annotation = pd.read_csv(annotation_file)
df_annotation_filtered = df_annotation[(df_annotation['modality']==annotation_metadata['modality']) & (df_annotation['location']==annotation_metadata['location'])]
artifacts = df_annotation_filtered[["start_time","end_time"]].to_numpy() * int(annotation_metadata['scale_wrt_hd5'])		 

In [5]:
with h5py.File(hdf5_file_path, 'r') as file:
	dataset = file[data_group_name][:]
	timestamp = file[timestamp_group_name][:]

In [21]:
segment_length = segment_length_sec * sampling_frequency

artifact_raw = []
for art in artifacts:
    start_idx = np.searchsorted(timestamp, art[0], side='left')
    end_idx = np.searchsorted(timestamp, art[1], side='left')

    interval_data = dataset[start_idx:end_idx]
    
    if len(interval_data) == segment_length:
        artifact_raw.append(interval_data)
    elif len(interval_data) < segment_length:
        # Calculate the amount of indices needed to reach the segment_length
        diff = segment_length - len(interval_data)
        new_start_idx = max(0, start_idx - diff // 2)  # Ensure new_start_idx is not negative
        new_end_idx = new_start_idx + segment_length  # Set the new end index

        # If the new_end_idx goes beyond the length of the dataset, adjust new_start_idx back
        if new_end_idx > len(dataset):
            new_end_idx = len(dataset)
            new_start_idx = max(0, new_end_idx - segment_length)  # Ensure new_start_idx is not negative
        
        # Update interval_data with the new indices
        interval_data = dataset[new_start_idx:new_end_idx]
        artifact_raw.append(interval_data)
    else:
        # Break down interval_data into multiple segments of segment_length
        num_segments = len(interval_data) // segment_length
        for i in range(num_segments):
            start_segment_idx = start_idx + i * segment_length
            end_segment_idx = start_segment_idx + segment_length
            segment_data = dataset[start_segment_idx:end_segment_idx]
            artifact_raw.append(segment_data)

        # Handle any remaining data that doesn't fit into a full segment
        remaining_data_start = start_idx + num_segments * segment_length
        if remaining_data_start < end_idx:
            remaining_data_end = min(remaining_data_start + segment_length, len(dataset))
            remaining_data = dataset[remaining_data_start:remaining_data_end]
            artifact_raw.append(remaining_data)

In [38]:
num_positive_samples = len(artifact_raw)

# Randomly get a segment that is of length given as segment_length_sec*sampling_frequency
# If has artifact, then append to artifact list else append to non-artifact list

reduced_range = int(len(timestamp)/segment_length)

# Generate num_positive_samples*2 unique random values from 0 to 58360000 without replacement
random_values = np.random.choice(range(reduced_range), num_positive_samples*2, replace=False) * segment_length
len(random_values)

count_negative, i = 0, 0

non_artifact_raw=[]
while count_negative<num_positive_samples:
    start_idx = random_values[i] + segment_length
    temp_ts = timestamp[start_idx:start_idx + segment_length]
    if not has_artifact(temp_ts, artifacts):
        non_artifact_raw.append(dataset[start_idx:start_idx + segment_length])
        count_negative+=1
    i+=1
    
print(f"There are total of {len(artifact_raw)} positive samples and {len(non_artifact_raw)} negative samples.")

There are total of 1477 positive samples and 1477 negative samples.


In [41]:
# Append a label to each sublist
artifact_labeled = [sample + [1] for sample in artifact_raw]
non_artifact_labeled = [sample + [0] for sample in non_artifact_raw]

# Now combine them into a single list
combined_data = artifact_labeled + non_artifact_labeled


In [42]:
combined_data

[array([155.5625, 158.0625, 159.9375, ..., 100.5625,  96.1875,  98.    ]),
 array([104.125 , 112.3125, 120.9375, ...,  70.25  ,  78.5   ,  87.6875]),
 array([58.9375, 59.    , 59.625 , ..., 82.375 , 81.5   , 80.625 ]),
 array([126.1875, 124.5625, 122.625 , ...,  61.0625,  60.875 ,  60.8125]),
 array([100.375 ,  99.375 ,  98.375 , ...,  88.0625,  87.3125,  86.5625]),
 array([128.8125, 127.625 , 126.5625, ..., 128.625 , 127.375 , 126.125 ]),
 array([ 93.375 ,  91.875 ,  90.0625, ..., -49.    , -49.    , -49.    ]),
 array([-49.    , -49.    , -49.    , ..., 120.625 , 110.0625, 102.25  ]),
 array([96.6875, 91.375 , 86.    , ..., 65.25  , 64.75  , 64.25  ]),
 array([63.75  , 63.25  , 62.75  , ..., 20.5   , 21.    , 21.4375]),
 array([ 21.8125,  22.1875,  22.4375, ..., 462.9375, 462.9375, 462.9375]),
 array([462.9375, 462.9375, 462.9375, ..., 219.75  , 219.5625, 219.375 ]),
 array([219.4375, 219.6875, 220.0625, ..., 424.3125, 408.9375, 394.8125]),
 array([381.8125, 369.9375, 358.9375, ..., 