# PCG Concatenation and Spike Removal Pipeline

This notebook replicates the MATLAB script `concatenate_and_save.m`.

## Functionality:
1. Read all `.wav` files per subject
2. Resample signals to 2000 Hz
3. Perform spike removal on each PCG channel
4. Concatenate signals channel-wise
5. Save one `.wav` file per subject

⚠️ Spike removal is implemented in this notebook as `schmidt_spike_removal`, a Python port of the MATLAB `schmidt_spike_removal.m`.


In [1]:
import os
import numpy as np
import pandas as pd
import soundfile as sf
import scipy.signal as ssg

def resample(signal: np.ndarray, fs_old: float, fs_new: float) -> np.ndarray:
    """Resample ``signal`` from ``fs_old`` Hz to ``fs_new`` Hz.

    Thin wrapper around ``scipy.signal.resample_poly`` (polyphase filtering,
    applied along axis 0, so a ``(samples, channels)`` array is resampled
    channel by channel).

    Parameters
    ----------
    signal : np.ndarray
        Input samples; time runs along the first axis.
    fs_old : float
        Sampling rate of ``signal`` in Hz. Must be positive.
    fs_new : float
        Desired sampling rate in Hz. Must be positive.

    Returns
    -------
    np.ndarray
        The resampled signal.

    Raises
    ------
    ValueError
        If either rate is not a positive number.
    """
    # Local import so this cell does not depend on another cell's imports.
    from fractions import Fraction

    if not (fs_old > 0 and fs_new > 0):
        raise ValueError("fs_old and fs_new must be positive sampling rates")

    if float(fs_old).is_integer() and float(fs_new).is_integer():
        # Integer-valued rates: hand them over unchanged (resample_poly
        # reduces them by their gcd itself).
        up, down = int(fs_new), int(fs_old)
    else:
        # resample_poly only accepts integer factors, so express the rate
        # ratio as a reduced fraction. The denominator cap keeps binary
        # float noise (e.g. 0.1) from producing astronomically large factors.
        ratio = (Fraction(float(fs_new)) / Fraction(float(fs_old))).limit_denominator(1_000_000)
        up, down = ratio.numerator, ratio.denominator

    return ssg.resample_poly(signal, up, down)

In [None]:
# Output directory.
# The production folder is kept here as a commented-out alternative: it used
# to be assigned and then immediately overwritten by the scratch folder below,
# so it never took effect. Swap the two lines to write to production.
# outdir = "/home/sparc/Desktop/TH_data_joined_PYTHON_SpRe/"
outdir = "/home/sparc/Desktop/testing/"
os.makedirs(outdir, exist_ok=True)

# Input directory (searched for .wav recordings)
data_dir = "/home/sparc/Desktop/TH_alldata_fil"

# Target sampling rate in Hz
fs_new = 2000

# Channel orders (MATLAB equivalent)
# Column indices into the multi-channel recording; which list applies is
# chosen per file in the processing cell.
order_1  = [0,1,3,4,5,6,7,10]   # MATLAB 1-based → Python 0-based
order_23 = [1,3,4,5,7,6,8,10]


In [5]:
import numpy as np

def schmidt_spike_removal(signal, fs):
    """
    Python port of schmidt_spike_removal.m
    (Schmidt et al., Physiol. Meas., 2010).

    The signal is cut into 500 ms windows. While any window's maximum
    absolute amplitude (MAA) exceeds three times the median MAA, the window
    with the largest MAA is selected, the spike around its largest sample is
    delimited by the neighbouring zero crossings, and that stretch is
    overwritten with 0.0001. Samples that do not fill a whole window are
    passed through unchanged.

    Index handling follows the MATLAB code: its ``1:spike_position`` ranges
    are inclusive, so the zero crossing located at the spike sample itself
    belongs to the "before" search and is cleared before the "after" search.

    Parameters
    ----------
    signal : 1D numpy array
        Original PCG signal. It is not modified; integer input is promoted
        to float64 so the 0.0001 fill value is representable.
    fs : float
        Sampling frequency (Hz)

    Returns
    -------
    despiked_signal : 1D numpy array
        Signal with spikes removed, same length as the input.
    """

    # flatten() returns a copy, so the caller's array is never written to.
    signal = np.asarray(signal).flatten()
    if signal.dtype.kind in "iub":
        # An integer array would silently turn the 0.0001 fill into 0.
        signal = signal.astype(np.float64)
    original_len = len(signal)

    # --- Window size (500 ms) ---
    # MATLAB round() is half-away-from-zero; Python round() is half-to-even,
    # which would give a different window for odd sampling rates.
    windowsize = int(np.floor(fs / 2 + 0.5))

    # --- Trailing samples (do not fill a complete window) ---
    trailingsamples = original_len % windowsize

    # --- Reshape into windows: one window per column (MATLAB column-wise) ---
    sampleframes = signal[:original_len - trailingsamples] \
        .reshape((windowsize, -1), order='F')

    # Fill value in the frames' own dtype so the comparison below is exact.
    fill_value = sampleframes.dtype.type(0.0001)

    # --- Maximum Absolute Amplitudes, one per window ---
    MAAs = np.max(np.abs(sampleframes), axis=0)

    # --- Main spike removal loop (MATLAB-style) ---
    while MAAs.size > 0 and np.any(MAAs > np.median(MAAs) * 3):

        # Window with maximum MAA (first one on ties)
        window_num = int(np.argmax(MAAs))

        # Once the largest window is already at (or below) the fill level,
        # overwriting it cannot lower any MAA, so the loop would never end
        # (happens when the median MAA is 0 or tiny). Stop instead of hanging.
        if MAAs[window_num] <= fill_value:
            break

        window = sampleframes[:, window_num]

        # Spike position inside window (first one on ties)
        spike_position = int(np.argmax(np.abs(window)))

        # zero_crossings[k] is True when samples k and k+1 have opposite
        # signs; padded with False to keep the window length.
        zero_crossings = np.append(np.abs(np.diff(np.sign(window))) > 1, False)

        # Spike start: last zero crossing up to AND INCLUDING the spike
        # position (MATLAB 1:spike_position), else the start of the window.
        before = np.flatnonzero(zero_crossings[:spike_position + 1])
        spike_start = before[-1] if before.size else 0

        # Spike end: first zero crossing strictly after the spike position,
        # else the end of the window.
        zero_crossings[:spike_position + 1] = False
        after = np.flatnonzero(zero_crossings)
        spike_end = after[0] if after.size else windowsize - 1

        # Replace spike with small value (both ends inclusive, as in MATLAB)
        sampleframes[spike_start:spike_end + 1, window_num] = fill_value

        # Recalculate MAAs
        MAAs = np.max(np.abs(sampleframes), axis=0)

    # --- Reshape back to 1D (column-wise, undoing the reshape above) ---
    despiked_signal = sampleframes.flatten(order='F')

    # --- Append trailing samples ---
    despiked_signal = np.concatenate([
        despiked_signal,
        signal[len(despiked_signal):]
    ])

    return despiked_signal


In [6]:
# Reference table of subjects; the CSV is read without a header row.
ref_csv = "REFERENCE_ALLROUNDS_ExclusionCriteria.csv"
df_ref = pd.read_csv(ref_csv, header=None)

# First column of the table, converted to a plain list of strings.
subjects = list(df_ref.iloc[:, 0].astype(str))


In [7]:
# Names (not full paths) of every entry in data_dir that ends in ".wav".
all_files = list(filter(lambda name: name.endswith(".wav"), os.listdir(data_dir)))


In [None]:
# Main processing: for each subject, resample every recording to fs_new,
# despike the PCG channels, stack the recordings along time and write one
# WAV file per subject into outdir.
# NOTE(review): subjects[0:1] restricts the run to the first subject only —
# looks like a testing leftover; widen the slice for a full run.
for sub in subjects[0:1]:
    print(f"Processing subject: {sub}")

    # Recordings whose file name contains the subject ID, in sorted order
    sub_files = sorted(f for f in all_files if sub in f)
    data_concat = []

    for file in sub_files:
        print(file)
        file_path = os.path.join(data_dir, file)

        # Select channel order from the second character of the file name
        if len(file) > 1 and file[1] in ['c', 'v']:
            order = order_23
        else:
            order = order_1

        # Read WAV
        data_wav, fs = sf.read(file_path, dtype="float64")

        # Resample to the target rate
        data_wav = resample(data_wav, fs, fs_new)

        # Only the stretch [start:end) is despiked: the first fs_new - 200
        # samples and the last 1000 samples of each recording are left as-is.
        start = fs_new - 200
        end = len(data_wav) - 1000

        # Spike removal per PCG channel: the first 7 entries of the order
        # list (the lists hold 8; the last entry is not despiked here).
        for chan in range(7):
            col = order[chan]
            data_wav[start:end, col] = schmidt_spike_removal(
                data_wav[start:end, col], fs_new
            )

        # Concatenate
        data_concat.append(data_wav)

    # np.vstack raises on an empty list, so skip subjects without recordings.
    if len(data_concat) == 0:
        print(f"No recordings found for subject: {sub}")
        continue

    data_concat = np.vstack(data_concat)

    # Save WAV (32-bit float samples)
    out_file = os.path.join(outdir, f"{sub}_Sit.wav")
    sf.write(out_file, data_concat, fs_new, subtype='FLOAT')

#     print(f"Saved: {out_file}")
# print("All subjects processed successfully.")



Processing subject: i002
i002_Dia1_Fing_60s.wav
i002_Dia1_Vest_60s.wav


In [None]:
import scipy.io as sio
from typing import Optional, Tuple

def read_signal_wav(filename: str) -> Tuple[np.ndarray, int]:
    """
    Read a WAV file and return its samples scaled the way MATLAB's
    ``audioread`` scales them, together with the sampling frequency.

    Integer PCM is divided by 2**(bits - 1), so the most negative sample
    maps to exactly -1.0; 8-bit unsigned PCM is first centred on 128.
    Floating-point files are returned unscaled.

    Parameters
    ----------
    filename : str
        Path to the file. ".wav" is appended when the name does not already
        end with that extension (case-insensitive).

    Returns
    -------
    signal : np.ndarray
        Samples as float32.
    Fs : int
        Sampling frequency in Hz.

    Raises
    ------
    ValueError
        If the file's sample dtype is not one of the handled types.
    """
    # Import the submodule explicitly; `import scipy.io` alone does not
    # guarantee that `scipy.io.wavfile` is loaded on every SciPy version.
    from scipy.io import wavfile

    # Check the extension itself rather than searching the whole path.
    if not filename.lower().endswith(".wav"):
        filename += ".wav"

    Fs, signal = wavfile.read(filename)

    if signal.dtype == np.float32 or signal.dtype == np.float64:
        return signal.astype(np.float32), Fs

    if signal.dtype == np.uint8:
        # 8-bit PCM is unsigned with mid-scale at 128.
        return ((signal.astype(np.float32) - 128.0) / 128.0).astype(np.float32), Fs

    if signal.dtype in (np.int16, np.int32, np.int64):
        # Full scale is 2**(bits - 1), i.e. -min; a float divisor also avoids
        # integer overflow for the 64-bit case.
        full_scale = -float(np.iinfo(signal.dtype).min)
        return (signal / full_scale).astype(np.float32), Fs

    raise ValueError("Unsupported data type")


# Load the Python pipeline's output for subject i002.
# read_signal_wav returns (signal, Fs); data_sam keeps only the signal.
# NOTE(review): this path is the TH_data_joined_PYTHON_SpRe folder, while the
# config cell's effective outdir is the testing folder — confirm the file is current.
fname = '/home/sparc/Desktop/TH_data_joined_PYTHON_SpRe/i002_Sit.wav'
data=read_signal_wav(fname)
data_sam = data[0]


In [None]:
# Load the same subject from a second folder for comparison
# (presumably the MATLAB pipeline's output — confirm).
# read_signal_wav returns (signal, Fs); data_sam1 keeps only the signal.
fname1 = '/home/sparc/Desktop/TH_data_joined_indicies_SpikeRemoval/i002_Sit.wav'
data1=read_signal_wav(fname1)
data_sam1 = data1[0]

  Fs, signal = sio.wavfile.read(filename)
