-
Notifications
You must be signed in to change notification settings - Fork 2
/
ppg_sqa.py
362 lines (289 loc) · 12 KB
/
ppg_sqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
# -*- coding: utf-8 -*-
import pickle
import os
from typing import Tuple, List
from scipy import stats, signal
import more_itertools as mit
import joblib
from utils import normalize_data, get_data, bandpass_filter, find_peaks, resample_signal
import warnings
import numpy as np
# Module-level side effect: silence all warnings (e.g. scipy filter warnings).
warnings.filterwarnings("ignore")
# Directory containing the pre-trained scaler and SQA model files.
MODEL_PATH = "models"
# Normalization scaler fitted on the training data (loaded with joblib).
SCALER_FILE_NAME = "Train_data_scaler.save"
# Pre-trained one-class SVM used to classify segments as clean/noisy (pickled).
SQA_MODEL_FILE_NAME = 'OneClassSVM_model.sav'
# Sampling frequency (Hz) the SQA model was trained at; inputs at other
# rates are resampled to this before classification.
SQA_MODEL_SAMPLING_FREQUENCY = 20
# Segment length (seconds) used when splitting the signal for assessment.
SEGMENT_SIZE = 30
# Shift (seconds) between consecutive overlapping segments.
SHIFTING_SIZE = 2
def segmentation(
        sig: np.ndarray,
        sig_indices: np.ndarray,
        sampling_rate: int,
        method: str = 'shifting',
        segment_size: int = 30,
        shift_size: int = 2,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Segments the signals (PPG) and their indices into fixed-size segments.

    Args:
        sig: Input signal (e.g., PPG).
        sig_indices: Corresponding indices for the input signal.
        sampling_rate: Sampling rate of the PPG signal.
        method: Segmentation method. Options: 'standard' or 'shifting'.
            Segments do not overlap for 'standard' and overlap with the
            size of (segment_size - shift_size) for 'shifting'.
        segment_size: Size of the segment (in seconds).
        shift_size: Size of the shift (in seconds) in segmentation
            in case method is 'shifting'.

    Return:
        segments_sig: List of segments (PPG).
        segments_indices: List of segments (indices).

    Raises:
        ValueError: If `method` is not 'standard' or 'shifting', or the
            chosen sizes yield a non-positive step.
    """
    signal_length = len(sig)
    segment_length = int(segment_size * sampling_rate)
    shift_length = int(shift_size * sampling_rate)
    # 'standard' advances by a full segment (no overlap); 'shifting'
    # advances by the shift length (overlap = segment_length - shift_length).
    if method == 'standard':
        step = segment_length
    elif method == 'shifting':
        step = shift_length
    else:
        raise ValueError("Invalid method. Use 'standard' or 'shifting'.")
    if step <= 0:
        raise ValueError("segment_size/shift_size must yield a positive step.")
    # Compute the start offsets once and share them between the signal and
    # its indices so the two output lists can never fall out of alignment.
    # Only full-length segments are kept (trailing leftovers are dropped).
    starts = [i for i in range(0, signal_length, step)
              if i + segment_length <= signal_length]
    segments_sig = [sig[i:i + segment_length] for i in starts]
    segments_indices = [sig_indices[i:i + segment_length] for i in starts]
    return segments_sig, segments_indices
def heart_cycle_detection(
        ppg: np.ndarray,
        sampling_rate: int,
) -> list:
    """
    Extract heart cycles from the PPG signal.

    Args:
        ppg: Input PPG signal.
        sampling_rate: Sampling rate of the PPG signal.

    Return:
        hc: List of heart cycles, each a fixed-width slice of the cleaned,
            upsampled signal centered on a systolic peak.
    """
    # Normalize the signal, then upsample it by a factor of 2.
    normalized = normalize_data(ppg)
    upsampled_rate = sampling_rate * 2
    upsampled = signal.resample(normalized, len(normalized) * 2)
    # Detect systolic peaks on the upsampled signal (also returns the
    # cleaned signal the peaks were found on).
    peaks, cleaned = find_peaks(
        ppg=upsampled, sampling_rate=upsampled_rate, return_sig=True)
    # At least two peaks are needed to estimate a beat interval.
    if len(peaks) < 2:
        return []
    # Half-width of the fixed window placed around every peak, derived from
    # the average peak-to-peak spacing.
    half_window = round((len(upsampled) / len(peaks)) / 2)
    hc = []
    # Skip the first and last peak so windows never cross the signal bounds.
    for peak in peaks[1:-1]:
        start = peak - half_window
        end = peak + half_window
        # Keep only full-length windows that lie entirely inside the signal.
        if start >= 0 and end < len(cleaned):
            beat = cleaned[start:end]
            if len(beat) >= half_window * 2:
                hc.append(beat)
    return hc
def energy_hc(hc: list) -> float:
    """
    Extract the variation of heart-cycle energy.

    Args:
        hc: List of heart cycles.

    Return:
        var_energy: Spread (max - min) of the per-beat energies, where a
            beat's energy is the sum of its squared samples; 0 when the
            list of heart cycles is empty.
    """
    # Energy of one beat = sum of squared samples.
    energies = [np.sum(beat * beat) for beat in hc]
    if not energies:
        return 0
    # Variation = range between the most and least energetic beats.
    return max(energies) - min(energies)
def template_matching_features(hc: list) -> Tuple[float, float]:
    """
    Extract template-matching features from heart cycles.

    Args:
        hc: List of heart cycles (all of equal length; empty beats are
            discarded before stacking).

    Return:
        tm_ave_eu: Average Euclidean distance between each beat and the template.
        tm_ave_corr: Average Pearson correlation between each beat and the template.
    """
    beats = np.array([np.array(beat) for beat in hc if len(beat) != 0])
    # The template is the element-wise mean of all heart cycles.
    template = np.mean(beats, axis=0)
    # Per-beat Euclidean distance and correlation against the template.
    distances = [np.linalg.norm(template - beat) for beat in beats]
    corrs = [np.corrcoef(template, beat)[0, 1] for beat in beats]
    tm_ave_eu = np.mean(distances)
    tm_ave_corr = np.mean(corrs)
    return tm_ave_eu, tm_ave_corr
def feature_extraction(
        ppg: np.ndarray,
        sampling_rate: int,
) -> List[float]:
    """
    Extract features from a PPG signal segment.

    Args:
        ppg: Input PPG signal.
        sampling_rate: Sampling rate of the PPG signal.

    Return:
        features: [interquartile range, std of power spectral density,
            heart-cycle energy variation, template-matching mean Euclidean
            distance, template-matching mean correlation]. The last three
            are NaN when no heart cycles are detected.
    """
    # Feature 1: interquartile range of the raw segment.
    iqr_rate = stats.iqr(ppg, interpolation='midpoint')
    # Feature 2: standard deviation of the power spectral density.
    _, psd = signal.periodogram(ppg, sampling_rate)
    std_p_spec = np.std(psd)
    # Features 3-5 require detected heart cycles.
    hc = heart_cycle_detection(ppg=ppg, sampling_rate=sampling_rate)
    if not hc:
        # No heart cycles: flag the cycle-based features as missing
        # (the caller classifies NaN-feature segments as noisy).
        return [iqr_rate, std_p_spec, np.nan, np.nan, np.nan]
    # Feature 3: variation in energy of heart cycles.
    var_energy = energy_hc(hc)
    # Features 4, 5: average Euclidean distance and correlation in
    # template matching.
    tm_ave_eu, tm_ave_corr = template_matching_features(hc)
    return [iqr_rate, std_p_spec, var_energy, tm_ave_eu, tm_ave_corr]
def sqa(
        sig: np.ndarray,
        sampling_rate: int,
        filter_signal: bool = True,
) -> Tuple[list, list]:
    """
    Perform PPG Signal Quality Assessment (SQA).

    This function assesses the quality of a PPG signal by classifying its segments
    as reliable (clean) or unreliable (noisy) using a pre-trained model.
    The clean indices represent parts of the PPG signal that are deemed reliable,
    while the noisy indices indicate parts that may be affected by noise or artifacts.

    Args:
        sig (np.ndarray): PPG signal.
        sampling_rate (int): Sampling rate of the PPG signal.
        filter_signal (bool): True if the signal has not been filtered using
            a bandpass filter.

    Return:
        clean_indices: A list of lists of consecutive clean indices.
        noisy_indices: A list of lists of consecutive noisy indices.

    Reference:
        Feli, M., Azimi, I., Anzanpour, A., Rahmani, A. M., & Liljeberg, P. (2023).
        An energy-efficient semi-supervised approach for on-device photoplethysmogram signal quality assessment.
        Smart Health, 28, 100390.
    """
    signal_length = len(sig)
    signal_indices = list(range(signal_length))
    # Load pre-trained model and normalization scaler.
    scaler = joblib.load(os.path.join(MODEL_PATH, SCALER_FILE_NAME))
    # NOTE(review): pickle.load can execute arbitrary code on untrusted
    # files; the model file is assumed to ship with this project.
    # Context manager closes the file handle (the previous bare open leaked it).
    with open(os.path.join(MODEL_PATH, SQA_MODEL_FILE_NAME), 'rb') as model_file:
        model = pickle.load(model_file)
    resampling_flag = False
    # Check if resampling is needed and perform resampling if necessary.
    if sampling_rate != SQA_MODEL_SAMPLING_FREQUENCY:
        sig = resample_signal(
            sig=sig, fs_origin=sampling_rate, fs_target=SQA_MODEL_SAMPLING_FREQUENCY)
        resampling_flag = True
        # Ratio used later to map model-rate indices back to the original rate.
        resampling_rate = sampling_rate / SQA_MODEL_SAMPLING_FREQUENCY
        sampling_rate = SQA_MODEL_SAMPLING_FREQUENCY
    # Apply bandpass filter if needed.
    if filter_signal:
        sig = bandpass_filter(
            sig=sig, fs=sampling_rate, lowcut=0.5, highcut=3)
    # Generate indices for the (possibly resampled) PPG signal.
    sig_indices = np.arange(len(sig))
    # Segment the PPG signal into overlapping windows.
    segments, segments_indices = segmentation(
        sig=sig,
        sig_indices=sig_indices,
        sampling_rate=sampling_rate,
        method='shifting',
        segment_size=SEGMENT_SIZE,
        shift_size=SHIFTING_SIZE,
    )
    # Indices of all segments classified as reliable.
    reliable_indices_all = []
    # Loop through the segments for feature extraction and classification.
    for idx, segment in enumerate(segments):
        # Feature extraction.
        features = feature_extraction(segment, sampling_rate)
        # Classification: segments with missing (NaN) heart-cycle features
        # are treated as noisy without consulting the model.
        if np.isnan(np.array(features)).any():
            pred = 1
        else:
            features_norm = scaler.transform([features])
            pred = model.predict(features_norm)
        # A prediction of 0 marks the segment as reliable.
        if pred == 0:
            reliable_indices_all.append(segments_indices[idx])
    # Flatten and deduplicate the indices of all reliable segments.
    clean_indices_flat = sorted(set(
        item for seg_idx in reliable_indices_all for item in seg_idx))
    # If resampling was performed, map indices back to the original sampling rate.
    if resampling_flag:
        # Group the clean indices into runs of consecutive values.
        groups = [list(group) for group in mit.consecutive_groups(clean_indices_flat)]
        # Scale each run's bounds back to the original sampling rate.
        clean_indices = [
            list(range(int(group[0] * resampling_rate),
                       int(group[-1] * resampling_rate) + 1))
            for group in groups]
        # Flatten the rescaled clean indices.
        clean_indices_flat = [item for sublist in clean_indices for item in sublist]
    # Set membership keeps the complement computation O(n) instead of the
    # O(n^2) list scan of `item not in clean_indices_flat`.
    clean_set = set(clean_indices_flat)
    # Indices absent from the clean set are noisy.
    noisy_indices_flat = [item for item in signal_indices if item not in clean_set]
    # Unflatten the clean indices list to separate clean parts.
    clean_indices = [list(group) for group in mit.consecutive_groups(clean_indices_flat)]
    # Unflatten the noisy indices list to separate noisy parts.
    noisy_indices = [list(group) for group in mit.consecutive_groups(noisy_indices_flat)]
    # Discard indices that have not been quality-assessed because the run is
    # shorter than the shifting size (leftovers at the end of the signal).
    noisy_indices = [group for group in noisy_indices if len(group) > SHIFTING_SIZE]
    return clean_indices, noisy_indices
# Example usage: run the SQA pipeline on a sample recording and print a
# summary of the clean and noisy portions found.
if __name__ == "__main__":
    # Import a sample data
    file_name = "201902020222_Data.csv"
    input_sig = get_data(file_name=file_name)
    # assumes the sample file was recorded at 20 Hz — TODO confirm
    input_sampling_rate = 20
    # Run PPG signal quality assessment.
    clean_ind, noisy_ind = sqa(sig=input_sig, sampling_rate=input_sampling_rate)
    # Display results
    print("Analysis Results:")
    print("------------------")
    print(f"Number of clean parts in the signal: {len(clean_ind)}")
    if clean_ind:
        print("Length of each clean part in the signal (in seconds):")
        for clean_seg in clean_ind:
            # Duration = number of samples / sampling rate.
            print(f" - {len(clean_seg)/input_sampling_rate:.2f}")
    print(f"Number of noisy parts in the signal: {len(noisy_ind)}")
    if noisy_ind:
        print("Length of each noise in the signal (in seconds):")
        for noise in noisy_ind:
            print(f" - {len(noise)/input_sampling_rate:.2f}")