In [5]:
import os
import pandas as pd
import numpy as np
import glob
from tensorflow.keras.models import load_model
import sys
sys.path.append("../../common_tools/peaks_confusion_matrix")
from confusion_matrix import compute_confusion_matrix
from plot_false_peaks import plot_fp_fn
sys.path.append("../../common_tools/signal_processing")
from normalize_signal import normalize_ppg

In [6]:
# Function to filter out close peaks
def filter_close_peaks(peaks, signal, tolerance):
    """
    Retains only the highest peak in each group of close peaks.

    Consecutive peaks whose index gap is <= ``tolerance`` are chained into the
    same group, so a group can span more than ``tolerance`` samples overall.
    From each group the index with the largest ``signal`` value is kept (the
    earliest one on ties). Indices at or beyond ``len(signal)`` are ignored.

    Args:
        peaks (list or array-like): Detected peak indices, assumed sorted ascending.
        signal (array): PPG signal data, indexed by the peak positions.
        tolerance (int): Peaks must be more than this many samples apart
            (from their neighbour) to be treated as separate.

    Returns:
        list: Filtered peak indices, one per group that has an in-range index.
    """
    # Use len() rather than truthiness so NumPy arrays are accepted as well as lists.
    if peaks is None or len(peaks) == 0:
        return []

    n_samples = len(signal)

    def _best_in_group(group):
        """Return the in-range index of ``group`` with the highest signal value, or None."""
        # Out-of-range indices can arise e.g. from predictions on zero-padding past the signal end.
        in_range = [p for p in group if p < n_samples]
        if not in_range:
            return None
        return max(in_range, key=lambda p: signal[p])

    filtered_peaks = []
    current_group = [peaks[0]]  # Initialize the first group

    for peak in peaks[1:]:
        if peak - current_group[-1] <= tolerance:
            current_group.append(peak)  # Still within tolerance of the previous peak
        else:
            best = _best_in_group(current_group)
            if best is not None:
                filtered_peaks.append(best)
            current_group = [peak]  # Start a new group

    # Flush the final group
    best = _best_in_group(current_group)
    if best is not None:
        filtered_peaks.append(best)

    return filtered_peaks

# Function to detect peaks using a trained model
def detect_peaks_with_model(ppg_data, model, segment_length=256, threshold=0.4, fs=100,
                            min_peak_distance_s=0.3, batch_size=32):
    """
    Detect PPG peaks using a trained model and filter out close peaks.

    The signal is cut into non-overlapping windows of ``segment_length`` samples
    (the last one zero-padded), all windows are scored in a single batched
    ``model.predict`` call, and every sample whose score exceeds ``threshold``
    becomes a candidate peak. Candidates closer than the merge window are then
    collapsed to the highest-signal one via ``filter_close_peaks``.

    Args:
        ppg_data (array): Input PPG signal data (1-D).
        model (keras.Model): Pre-trained model; assumed to output one score per
            input sample for each window — TODO confirm against the model definition.
        segment_length (int): Length of each processing segment.
        threshold (float): Probability threshold for peak detection.
        fs (int): Sampling frequency, default is 100 Hz.
        min_peak_distance_s (float): Candidates within this many seconds of each
            other are merged into one peak (default 0.3 s).
        batch_size (int): Batch size forwarded to ``model.predict``.

    Returns:
        list: List of filtered peak indices (plain ints).
    """
    ppg_data = np.asarray(ppg_data)
    n_samples = len(ppg_data)
    if n_samples == 0:
        return []

    tolerance = int(min_peak_distance_s * fs)

    # Zero-pad to a whole number of windows and stack them into one batch of shape
    # (n_segments, segment_length, 1) — same per-window input as predicting one at a time.
    n_segments = -(-n_samples // segment_length)  # ceiling division
    padded = np.pad(ppg_data, (0, n_segments * segment_length - n_samples),
                    'constant', constant_values=0)
    batch = padded.reshape(n_segments, segment_length, 1)

    predictions = np.asarray(model.predict(batch, batch_size=batch_size, verbose=0))
    predictions = predictions.reshape(n_segments, -1)  # one flattened score row per window

    detected_peaks = []
    for seg_idx, seg_scores in enumerate(predictions):
        hits = np.flatnonzero(seg_scores > threshold)
        detected_peaks.extend((seg_idx * segment_length + hits).tolist())

    # Filter out close peaks (also drops indices that fall in the zero-padding)
    return filter_close_peaks(detected_peaks, ppg_data, tolerance)

In [7]:
# Paths — the data root can be overridden with the PPG_PEAK_DATA_ROOT environment variable
DATA_ROOT = os.environ.get("PPG_PEAK_DATA_ROOT", "E:/dilated_cnn_peak_detection_model_data")
test_data_dir = f"{DATA_ROOT}/test/test_data/data"
label_dir = f"{DATA_ROOT}/test/test_data/label"
output_dir = f"{DATA_ROOT}/test/model_detected_peaks/"
plot_output_dir = f"{DATA_ROOT}/test/model_plot_results/"

# Create result directories and remove results from previous runs so stale files don't linger
for results_dir, stale_pattern in ((output_dir, "*.csv"), (plot_output_dir, "*.png")):
    os.makedirs(results_dir, exist_ok=True)
    for stale_file in glob.glob(os.path.join(results_dir, stale_pattern)):
        os.remove(stale_file)

In [8]:
# Load the model
model = load_model("../dilated_cnn_peak_detection_model.h5")

# Set to True to save FP/FN diagnostic plots into plot_output_dir
PLOT_FP_FN = False

summary_rows = []

# Test and save results; sorted so files are processed and reported in a reproducible order
for file_name in sorted(os.listdir(test_data_dir)):
    stem, ext = os.path.splitext(file_name)
    if ext.lower() != ".csv":
        continue  # ignore stray non-CSV files

    input_file = os.path.join(test_data_dir, file_name)
    # Build companion file names from the stem so only the extension is replaced
    label_file = os.path.join(label_dir, f"{stem}_labeled_peaks.csv")
    if not os.path.exists(label_file):
        print(f"Skipping {file_name}: label file not found ({label_file})")
        continue

    print(f"Processing: {file_name}")

    # Read and normalize data
    ppg_data = pd.read_csv(input_file, header=None).squeeze("columns").values
    ppg_data = normalize_ppg(ppg_data)
    labeled_peaks = pd.read_csv(label_file, header=None).squeeze("columns").values

    # Detect peaks
    detected_peaks = detect_peaks_with_model(ppg_data, model)

    # Save detected peaks
    detected_file = os.path.join(output_dir, f"{stem}_detected_peaks.csv")
    pd.DataFrame(detected_peaks).to_csv(detected_file, header=False, index=False)

    # Compute confusion matrix
    results = compute_confusion_matrix(detected_peaks, labeled_peaks)
    print(f"File: {file_name}")
    print(f"True Positives (TP): {results['tp']}")
    print(f"False Positives (FP): {results['fp']}")
    print(f"False Negatives (FN): {results['fn']}")
    print(f"Accuracy: {results['accuracy']:.2f}")
    print(f"Precision: {results['precision']:.2f}")
    print(f"Recall: {results['recall']:.2f}")
    print(f"Detection Error Rate: {results['der']:.2f}")
    print("False Positives (FP) at indices:", results['fp_list'])
    print("False Negatives (FN) at indices:", results['fn_list'])
    print("")

    summary_rows.append({
        "file": file_name,
        **{key: results[key] for key in ("tp", "fp", "fn", "accuracy", "precision", "recall", "der")},
    })

    # Plot False Positives and False Negatives
    if PLOT_FP_FN:
        plot_fp_fn(ppg_data, results['fp_list'], results['fn_list'], detected_peaks, file_name, plot_output_dir)

# One row per evaluated file, for side-by-side comparison of the metrics
summary_df = pd.DataFrame(summary_rows)
summary_df



Processing: golden_ppg_data_100hz.csv
File: golden_ppg_data_100hz.csv
True Positives (TP): 272
False Positives (FP): 0
False Negatives (FN): 0
Accuracy: 1.00
Precision: 1.00
Recall: 1.00
Detection Error Rate: 0.00
False Positives (FP) at indices: []
False Negatives (FN) at indices: []

Processing: p000946-2120-05-14-08-08.csv
File: p000946-2120-05-14-08-08.csv
True Positives (TP): 31298
False Positives (FP): 249
False Negatives (FN): 27
Accuracy: 0.99
Precision: 0.99
Recall: 1.00
Detection Error Rate: 0.01
False Positives (FP) at indices: [66, 466, 8355, 31818, 36433, 83143, 84989, 88999, 92395, 105904, 108751, 138463, 139135, 139490, 145435, 145668, 149916, 158884, 159262, 159533, 161205, 163293, 163527, 166798, 167681, 231170, 236192, 256001, 257403, 272897, 283517, 317559, 319742, 332198, 342017, 342458, 468564, 491521, 498728, 498866, 502833, 512510, 512635, 526122, 582732, 586754, 590212, 595280, 655438, 656055, 668749, 674890, 683005, 700152, 702779, 707148, 763137, 770817, 79727