In [6]:
import os
import pandas as pd
import numpy as np
from tensorflow.keras.models import load_model
import sys
sys.path.append("../../common_tools/peaks_confusion_matrix")
from confusion_matrix import compute_confusion_matrix
sys.path.append("../../common_tools/signal_processing")
from normalize_signal import normalize_ppg

# Function to filter out close peaks
def filter_close_peaks(peaks, tolerance):
    """
    Remove peaks that are too close to each other within a given tolerance.

    The peaks are scanned in the order given. A peak is kept only when it
    lies more than ``tolerance`` after the most recently kept peak, so only
    the first peak in a group of close peaks is retained. The input is
    expected to be in ascending order; it is not sorted here.

    Args:
        peaks (sequence or None): Peak positions (list, tuple, NumPy array,
            or any other iterable). ``None`` is treated as "no peaks".
        tolerance (number): Minimum spacing, in the same unit as ``peaks``,
            that must be exceeded for a peak to be kept.

    Returns:
        list: The retained peaks, in their original order.
    """
    if peaks is None:
        return []

    # Materialise as a list and test its length: the truthiness check
    # ``not peaks`` raises ValueError for NumPy arrays holding more than
    # one element.
    peaks = list(peaks)
    if len(peaks) == 0:
        return []

    filtered_peaks = [peaks[0]]
    for peak in peaks[1:]:
        # Compare against the last *kept* peak, not the previous input peak.
        if peak - filtered_peaks[-1] > tolerance:
            filtered_peaks.append(peak)

    return filtered_peaks

# Function to detect peaks using a trained model
def detect_peaks_with_model(ppg_data, model, segment_length=256, threshold=0.5, fs=100,
                            min_peak_distance=0.075):
    """
    Detect PPG peaks using a trained model and filter out close peaks.

    The signal is cut into consecutive, non-overlapping segments of
    ``segment_length`` samples. Each segment is passed to ``model.predict``
    and every output value above ``threshold`` is reported as a peak at the
    corresponding sample index. The last segment is zero-padded to full
    length before prediction; predictions that fall on the padding are
    discarded so that no returned index lies outside the signal.

    Args:
        ppg_data (array): Input PPG signal data.
        model (keras.Model): Pre-trained model. It is called as
            ``model.predict(x, verbose=0)`` with ``x`` built by adding a
            leading and a trailing axis to one segment.
            NOTE(review): the index mapping assumes the flattened output
            holds one value per input sample -- confirm against the model.
        segment_length (int): Length of each processing segment.
        threshold (float): Probability threshold for peak detection.
        fs (int): Sampling frequency, default is 100 Hz.
        min_peak_distance (float): Minimum spacing between reported peaks,
            in seconds. Defaults to 0.075 s.

    Returns:
        list: List of filtered peak indices (plain Python ints).
    """
    ppg_data = np.asarray(ppg_data)
    tolerance = int(min_peak_distance * fs)  # Convert seconds to number of samples
    detected_peaks = []

    for start in range(0, len(ppg_data), segment_length):
        segment = ppg_data[start:start + segment_length]
        valid_length = len(segment)
        if valid_length < segment_length:
            # The model needs a fixed-length input, so pad the final segment.
            segment = np.pad(segment, (0, segment_length - valid_length),
                             'constant', constant_values=0)
        # Add batch and channel axes, e.g. (segment_length,) -> (1, segment_length, 1)
        batch = np.expand_dims(np.expand_dims(segment, axis=0), axis=-1)

        predictions = np.asarray(model.predict(batch, verbose=0)).flatten()
        # Ignore outputs that correspond to zero padding rather than signal.
        predictions = predictions[:valid_length]

        above_threshold = np.flatnonzero(predictions > threshold)
        # .tolist() keeps the result a list of Python ints.
        detected_peaks.extend((start + above_threshold).tolist())

    # Filter out close peaks
    return filter_close_peaks(detected_peaks, tolerance)

# Paths
test_data_dir = 'E:/dilated_cnn_peak_detection_model_data/test/test_data/data'
label_dir = 'E:/dilated_cnn_peak_detection_model_data/test/test_data/label'
output_dir = 'E:/dilated_cnn_peak_detection_model_data/test/model_detected_peaks/'

# Create directory to save results
os.makedirs(output_dir, exist_ok=True)

# Load the model
model = load_model("../dilated_cnn_peak_detection_model.h5")

# Test and save results.
# os.listdir returns entries in arbitrary order, so sort them to make the
# report order reproducible between runs and machines.
for file_name in sorted(os.listdir(test_data_dir)):
    # Derive companion file names from the stem instead of str.replace, which
    # would rewrite every ".csv" occurrence in the name and let non-CSV
    # entries through unchanged.
    stem, extension = os.path.splitext(file_name)
    if extension != ".csv":
        continue

    input_file = os.path.join(test_data_dir, file_name)
    label_file = os.path.join(label_dir, f"{stem}_labeled_peaks.csv")
    if not os.path.exists(label_file):
        # Make it visible when a recording is left out of the evaluation.
        print(f"Skipping: {file_name} (no label file found at {label_file})")
        continue

    print(f"Processing: {file_name}")

    # Read and normalize data (each CSV is read without a header row and
    # reduced to a 1-D array of values)
    ppg_data = pd.read_csv(input_file, header=None).squeeze("columns").values
    ppg_data = normalize_ppg(ppg_data)
    labeled_peaks = pd.read_csv(label_file, header=None).squeeze("columns").values

    # Detect peaks
    detected_peaks = detect_peaks_with_model(ppg_data, model)

    # Save detected peaks, one index per row, without header or index column
    detected_file = os.path.join(output_dir, f"{stem}_detected_peaks.csv")
    pd.DataFrame(detected_peaks).to_csv(detected_file, header=False, index=False)

    # Compute confusion matrix between detected and labeled peak indices
    results = compute_confusion_matrix(detected_peaks, labeled_peaks)
    print(f"File: {file_name}")
    print(f"True Positives (TP): {results['tp']}")
    print(f"False Positives (FP): {results['fp']}")
    print(f"False Negatives (FN): {results['fn']}")
    print(f"Accuracy: {results['accuracy']:.2f}")
    print(f"Precision: {results['precision']:.2f}")
    print(f"Recall: {results['recall']:.2f}")
    print(f"Detection Error Rate: {results['der']:.2f}")
    #print("\nFalse Positives (FP) at indices:", results['fp_list'])
    #print("False Negatives (FN) at indices:", results['fn_list'])




Processing: golden_ppg_data_100hz.csv
File: golden_ppg_data_100hz.csv
True Positives (TP): 272
False Positives (FP): 1
False Negatives (FN): 1
Accuracy: 0.99
Precision: 1.00
Recall: 1.00
Detection Error Rate: 0.01
Processing: mimic_perform_af_001_data.csv
File: mimic_perform_af_001_data.csv
True Positives (TP): 1779
False Positives (FP): 7
False Negatives (FN): 51
Accuracy: 0.97
Precision: 1.00
Recall: 0.97
Detection Error Rate: 0.03
Processing: mimic_perform_af_002_data.csv
File: mimic_perform_af_002_data.csv
True Positives (TP): 1336
False Positives (FP): 4
False Negatives (FN): 24
Accuracy: 0.98
Precision: 1.00
Recall: 0.98
Detection Error Rate: 0.02
Processing: mimic_perform_af_003_data.csv
File: mimic_perform_af_003_data.csv
True Positives (TP): 1434
False Positives (FP): 2
False Negatives (FN): 21
Accuracy: 0.98
Precision: 1.00
Recall: 0.99
Detection Error Rate: 0.02
Processing: mimic_perform_non_af_001_data.csv
File: mimic_perform_non_af_001_data.csv
True Positives (TP): 1185
Fa