## BME 544 Assignment 3

In [None]:
import os
import librosa
import librosa.display 
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import pandas as pd
from scipy.signal import hilbert
import re

# Seed NumPy's global random generator so repeated runs are reproducible.
np.random.seed(42)

# Set to True to overlay the reference interval from time_slices on the plots.
show_gt = False

# Folder that is scanned for .wav recordings.
wav_directory = "DC1_Recordings/"

# Reference [start, stop] interval in seconds for each recording (five per row).
# Entries are looked up by the recording's position in sorted filename order and
# are only used to annotate the diagnostic plots.
time_slices = [
    [20, 35], [36, 55], [1, 16], [11, 20], [0, 33],
    [10, 15], [9, 31], [9, 26], [9, 26], [3, 16],
    [7, 19], [3, 15], [18, 50], [8, 25], [13, 18],
    [6, 15], [10, 23], [1, 29], [13, 32], [0, 12],
    [3, 17], [14, 27], [0, 8], [14, 28], [0, 9],
    [9, 30], [0, 9], [0, 13], [11, 22], [2, 22],
]

# Normalize the audio signal
def normalize_audio(audio):
    """
    Scale audio so its largest absolute sample is 1, i.e. into [-1, 1].

    Silent (all-zero) input is returned as zeros instead of the NaNs a 0/0
    division would produce, and empty input is returned as an empty array
    instead of raising from np.max.

    Parameters:
    - audio (array-like): Audio samples.

    Returns:
    - ndarray: The scaled samples. A floating input keeps its dtype.
    """
    audio = np.asarray(audio)
    peak = np.max(np.abs(audio)) if audio.size else 0
    # Dividing by 1 when there is no signal leaves the zeros (and the dtype) intact.
    return audio / (peak if peak > 0 else 1)

# Apply Rolling Median Filter (Window-based)
def rolling_median_filter(audio, window_size):
    """
    Median-filter audio with a sliding window of window_size samples.

    Thin wrapper around scipy.signal.medfilt. That function zero-pads the
    edges and expects an odd window size.
    """
    kernel = window_size
    filtered = signal.medfilt(audio, kernel)
    return filtered

# Function to calculate signal envelope
def calculate_envelope(signal):
    """
    Return the amplitude envelope of a signal.

    The envelope is the magnitude of the analytic signal produced by the
    Hilbert transform.
    """
    return np.abs(hilbert(signal))

def percussive_enhance(audio):
    """Return the percussive component of audio, as extracted by librosa.effects.percussive."""
    percussive_component = librosa.effects.percussive(y=audio)
    return percussive_component

def normalize_array(arr):
    """
    Min-max scale an array so its minimum maps to 0 and its maximum to just under 1.

    A 1e-6 term in the denominator keeps constant arrays from dividing by
    zero (they map to all zeros). For the same reason the maximum lands
    slightly below 1.

    Parameters:
    - arr (array-like): Values to rescale.

    Returns:
    - normalized_arr (ndarray): (arr - min) / (max - min + 1e-6).
    """
    lo = np.min(arr)
    span = np.max(arr) - lo + 1e-6
    return (arr - lo) / span

def weigh_peaks(peaks, heights, horizontal_gap_weight=1.75, vertical_gap_weight=0.5, horizontal_center_dist_weight=1.0, vertical_center_dist_weight=0.5):
    """
    Score each peak by how regular and central it is relative to the other peaks.

    Four per-peak features are computed, and each is min-max normalized with
    normalize_array:

    - horizontal gap: |mean spacing to neighbouring peak(s) - median spacing|
    - vertical gap: |mean absolute height change to neighbouring peak(s)
      - median absolute height change|
    - horizontal centre distance: |peak position - median peak position|
    - vertical centre distance: |peak height - median peak height|

    Each feature contributes weight / (feature + 0.01). A peak that is small
    in every feature (evenly spaced, smoothly varying, central) therefore
    scores highest. The summed score is min-max normalized with
    normalize_array.

    Parameters:
    - peaks (array-like): Peak positions (sample indices), expected in
      ascending order.
    - heights (array-like): Signal value at each peak; same length as peaks.
    - horizontal_gap_weight, vertical_gap_weight,
      horizontal_center_dist_weight, vertical_center_dist_weight (float):
      Relative importance of the four features.

    Returns:
    - weights (ndarray of float): One score per peak. With fewer than two
      peaks there are no neighbours to compare, so every peak gets 1.0.
    """
    peaks = np.asarray(peaks)
    # Work in float64 so the height features keep full precision whatever the input dtype.
    heights = np.asarray(heights, dtype=float)

    if len(peaks) < 2:
        return np.ones_like(peaks, dtype=float)  # Default weight if no neighbors exist

    def _mean_neighbour_step(steps):
        """Average each peak's left and right step; an end peak uses its only step."""
        # Repeat the edge step on each side so (left + right) / 2 collapses to
        # that single step for the first and last peak.
        left = np.concatenate((steps[:1], steps))
        right = np.concatenate((steps, steps[-1:]))
        return (left + right) / 2

    # Spacing between consecutive peaks, and absolute height change between them.
    position_steps = np.diff(peaks)
    height_steps = np.abs(np.diff(heights))

    # Deviation of each peak's neighbour step from the typical (median) step.
    # Height changes are absolute for every peak, including the two end peaks,
    # so the comparison with median(|diff(heights)|) is like-for-like and the
    # score does not depend on whether an end peak rises or falls.
    horizontal_gaps = np.abs(_mean_neighbour_step(position_steps) - np.median(position_steps))
    vertical_gaps = np.abs(_mean_neighbour_step(height_steps) - np.median(height_steps))

    # Normalize horizontal and vertical distances before applying the weights
    horizontal_gaps = normalize_array(horizontal_gaps)
    vertical_gaps = normalize_array(vertical_gaps)

    # Distance of each peak's position and height from the respective medians
    horizontal_center_dists = normalize_array(np.abs(peaks - np.median(peaks)))
    vertical_center_dists = normalize_array(np.abs(heights - np.median(heights)))

    # Small features -> large contributions; the 1e-2 floor bounds each term.
    weights = (horizontal_gap_weight / (horizontal_gaps + 1e-2) +
               vertical_gap_weight / (vertical_gaps + 1e-2) +
               horizontal_center_dist_weight / (horizontal_center_dists + 1e-2) +
               vertical_center_dist_weight / (vertical_center_dists + 1e-2))

    # Normalize the final weights
    return normalize_array(weights)

def extract_info(filename):
    """
    Parse a "<key>_<start_pressure>_<end_pressure>.wav" file name.

    Parameters:
    - filename (str): Bare file name (no directory), e.g. "S01_160_40.wav".

    Returns:
    - dict with "key" (str), "start_pressure" (int) and "end_pressure" (int),
      or None if the name does not follow the pattern.
    """
    # fullmatch (rather than match) rejects names with trailing text such as
    # "S01_160_40.wav.bak". The key may itself contain underscores; regex
    # backtracking leaves the last two "_<digits>" groups for the pressures.
    match = re.fullmatch(r"(\w+)_(\d+)_(\d+)\.wav", filename)
    if match is None:
        return None

    key, start_pressure, end_pressure = match.groups()
    return {
        "key": key,
        "start_pressure": int(start_pressure),
        "end_pressure": int(end_pressure),
    }
    
# --- Envelope preprocessing ---
downsample_factor = 3            # keep every 3rd envelope sample
median_filter_window_size = 15   # medfilt kernel length, in downsampled samples

# --- Peak detection ---
percentile_lower = 97            # median-filtered samples below this percentile are zeroed
percentile_upper = 99            # computed into a threshold but not used to build the mask
max_bpm = 120                    # sets the minimum spacing between detected peaks
peak_prominence = 0.01           # passed to find_peaks as the minimum peak *height*

# --- Search over the convolved peak weights ---
convolution_window_length_seconds = 8   # square smoothing window length
crossing_threshold = 0.3                # level on the normalized convolved weights

def _find_threshold_crossings(curve, threshold):
    """
    Locate where curve is at or above threshold on each side of its maximum.

    Returns (left_idx, right_idx):
    - left_idx: first index before argmax(curve) with curve >= threshold and a
      positive backward difference (still rising);
    - right_idx: last index after argmax(curve) with curve >= threshold and a
      negative backward difference (falling).
    Either value is None when no index satisfies its condition.
    """
    max_idx = int(np.argmax(curve))
    # Backward finite difference; prepend=0 makes derivative[0] == curve[0].
    derivative = np.diff(curve, prepend=0)
    above = curve >= threshold

    rising = np.flatnonzero(above[:max_idx] & (derivative[:max_idx] > 0))
    falling = np.flatnonzero(above[max_idx + 1:] & (derivative[max_idx + 1:] < 0))

    left_idx = int(rising[0]) if rising.size else None
    right_idx = int(max_idx + 1 + falling[-1]) if falling.size else None
    return left_idx, right_idx


def _index_to_pressure(idx, n_samples, start_pressure, end_pressure):
    """
    Linearly interpolate a pressure for sample idx of an n_samples-long sweep.

    Index 0 maps to start_pressure and index n_samples - 1 to end_pressure.
    Returns NaN when idx is None or the sweep has fewer than two samples.
    """
    if idx is None or n_samples < 2:
        return np.nan
    return start_pressure + (idx / (n_samples - 1)) * (end_pressure - start_pressure)


def _mark_ground_truth(start, end, with_labels=False):
    """Draw the reference interval [start, end] as dashed vertical lines on the current axes."""
    start_kwargs = {"label": "Start"} if with_labels else {}
    end_kwargs = {"label": "End"} if with_labels else {}
    plt.axvline(x=start, color='r', linestyle='--', **start_kwargs)
    plt.axvline(x=end, color='g', linestyle='--', **end_kwargs)


output = []

# Index only the .wav files so time_slices[k] pairs with the k-th recording in
# sorted order, even if the folder also holds other files (e.g. .DS_Store).
wav_files = sorted(f for f in os.listdir(wav_directory) if f.endswith(".wav"))
if len(wav_files) != len(time_slices):
    print(f"Warning: {len(wav_files)} .wav files but {len(time_slices)} time slices; "
          "ground-truth overlays may be missing or misaligned.")

for file_idx, filename in enumerate(wav_files):
    # extract_info signals an unparseable name by returning None, not by raising.
    info = extract_info(filename)
    if info is None:
        print(f"Skipping {filename}: Unable to parse pressure values")
        continue

    try:
        file_path = os.path.join(wav_directory, filename)
        y, sr = librosa.load(file_path, sr=None)  # Load audio file with its original sampling rate

        # Reference interval for this file (used only for plot annotations).
        if file_idx < len(time_slices):
            start_time, end_time = time_slices[file_idx]
            start_sample = int(start_time * sr)
            end_sample = int(end_time * sr)
        else:
            start_time = end_time = start_sample = end_sample = None
        draw_gt = show_gt and start_sample is not None

        # --- Signal conditioning ---
        y_percussive = normalize_audio(percussive_enhance(y))
        y_envelope = calculate_envelope(y_percussive)
        y_downsampled = normalize_array(y_envelope[::downsample_factor])
        y_median = rolling_median_filter(y_downsampled, median_filter_window_size)

        # Keep only samples at or above the lower percentile of the filtered envelope.
        percentile_lower_thresh = np.percentile(y_median, percentile_lower)
        masked_signal = np.where(y_median >= percentile_lower_thresh, y_median, 0)

        # --- Peak detection: peaks at least one beat (at max_bpm) apart ---
        min_peak_distance = int((sr * 60 / max_bpm) / downsample_factor)
        y_peaks, _ = signal.find_peaks(masked_signal, distance=min_peak_distance, height=peak_prominence)
        y_peaks = y_peaks[1:-1]  # Ignore the first and last peaks
        weights = weigh_peaks(y_peaks, y_median[y_peaks])

        # --- Spread peak weights with a square window; the densest region is the target ---
        window_length_samples = int(convolution_window_length_seconds * sr / downsample_factor)
        square_window = np.ones(window_length_samples)
        weights_with_zeros = np.zeros_like(y_median)
        weights_with_zeros[y_peaks] = weights
        convolved_weights = normalize_array(signal.convolve(weights_with_zeros, square_window, mode='same'))

        left_crossing_idx, right_crossing_idx = _find_threshold_crossings(convolved_weights, crossing_threshold)

        # Crossing indices are in the downsampled domain, so interpolate over
        # len(convolved_weights) rather than len(y).
        n_downsampled = len(convolved_weights)
        systolic_pressure_est = _index_to_pressure(
            left_crossing_idx, n_downsampled, info["start_pressure"], info["end_pressure"])
        diastolic_pressure_est = _index_to_pressure(
            right_crossing_idx, n_downsampled, info["start_pressure"], info["end_pressure"])
        if left_crossing_idx is None or right_crossing_idx is None:
            print(f"Warning: {filename}: threshold crossing not found; estimate set to NaN")

        output.append({
            "filename": filename,
            "systolic_pressure_est": systolic_pressure_est,
            "diastolic_pressure_est": diastolic_pressure_est
        })

        # --- Diagnostic plots ---
        plt.figure(figsize=(32, 14))

        # Reference markers in the downsampled sample domain (panels 3-9).
        ds_start = start_sample // downsample_factor if draw_gt else None
        ds_end = end_sample // downsample_factor if draw_gt else None

        # 1. Original Signal (original sample rate)
        plt.subplot(3, 3, 1)
        if draw_gt:
            _mark_ground_truth(start_sample, end_sample, with_labels=True)
        # Labelled so plt.legend() has an entry even when show_gt is False.
        plt.plot(normalize_audio(y), label="Signal")
        plt.title("Original Signal")
        plt.legend()
        plt.xlabel("Samples")
        plt.ylabel("Amplitude")

        # 2-5. Intermediate signals without legends. Panel 2 is at the
        # original rate; panels 3-5 are downsampled.
        for position, data, title, gt_start, gt_end in (
            (2, y_percussive, "Percussive Enhanced", start_sample, end_sample),
            (3, y_downsampled, "Enveloped + Rectified, Downsampled", ds_start, ds_end),
            (4, y_median, "Median Filtered", ds_start, ds_end),
            (5, masked_signal, "Percentile Masked", ds_start, ds_end),
        ):
            plt.subplot(3, 3, position)
            if draw_gt:
                _mark_ground_truth(gt_start, gt_end)
            plt.plot(data)
            plt.title(title)
            plt.xlabel("Samples")
            plt.ylabel("Amplitude")

        # 6. Median Filtered with Peaks
        plt.subplot(3, 3, 6)
        if draw_gt:
            _mark_ground_truth(ds_start, ds_end)
        plt.plot(y_median, label="Filtered Signal")
        plt.plot(y_peaks, y_median[y_peaks], "rx", label="Detected Peaks")
        plt.title("Median Filtered with Peaks")
        plt.legend()
        plt.xlabel("Samples")
        plt.ylabel("Amplitude")

        # 7. Median Filtered with Weighted Peaks
        plt.subplot(3, 3, 7)
        if draw_gt:
            _mark_ground_truth(ds_start, ds_end)
        plt.plot(y_median, label="Filtered Signal")
        plt.plot(y_peaks, weights, "rx", label="Weighted Peaks")
        plt.title("Median Filtered with Weighted Peaks")
        plt.legend()
        plt.xlabel("Samples")
        plt.ylabel("Amplitude")

        # 8. Convolved
        plt.subplot(3, 3, 8)
        if draw_gt:
            _mark_ground_truth(ds_start, ds_end)
        plt.plot(convolved_weights, "g--", label="Convolved Weights")
        plt.title("Convolved Weighted Peaks")
        plt.legend()
        plt.xlabel("Samples")
        plt.ylabel("Amplitude")

        # 9. Normalized Convolved with Left & Right Points Found
        plt.subplot(3, 3, 9)
        if draw_gt:
            _mark_ground_truth(ds_start, ds_end)
        plt.axhline(y=crossing_threshold, color='b', linestyle='--', label='Threshold')
        plt.plot(convolved_weights, "g--")
        if left_crossing_idx is not None:
            plt.axvline(left_crossing_idx, color='purple', linestyle='-.', label="Estimated Systolic Pressure")
        if right_crossing_idx is not None:
            plt.axvline(right_crossing_idx, color='orange', linestyle='-.', label="Estimated Diastolic Pressure")
        plt.title("Convolved Weighted Peaks with Systolic & Diastolic Sample Index")
        plt.legend()
        plt.xlabel("Samples")
        plt.ylabel("Amplitude")

        # Layout & Title
        plt.tight_layout()
        if draw_gt:
            plt.suptitle(f'File {filename} | Interval: {start_time}s - {end_time}s', fontsize=16, y=1.05)
        else:
            plt.suptitle(f'File {filename}', fontsize=16, y=1.05)
        plt.show()

    except ValueError as err:
        print(f"Skipping {filename}: processing failed ({err})")

df = pd.DataFrame(output)
# NOTE(review): the output name says "Assignment_2" while this notebook is
# titled Assignment 3 -- confirm which file name is intended.
df.to_csv("Assignment_2_output.csv", index=False)