In [3]:
from __future__ import annotations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt


# ---------- Helpers ----------
def infer_fs_from_time(time_s: np.ndarray) -> float:
    """Estimate the sampling rate (Hz) from a vector of timestamps.

    The rate is the reciprocal of the median successive time step. Only
    finite, strictly positive steps are considered, so NaN stamps and
    repeated stamps are ignored. Using the median limits the influence of a
    few irregular steps.

    Raises:
        ValueError: if no finite, strictly positive time step exists
            (e.g. fewer than two samples, or all stamps identical).
    """
    steps = np.diff(time_s)
    usable = steps[np.logical_and(np.isfinite(steps), steps > 0)]
    if not usable.size:
        raise ValueError("Cannot infer fs.")
    median_step = np.median(usable)
    return float(1.0 / median_step)


def bandpass_filter(x: np.ndarray, fs: float,
                    low_hz: float = 20.0,
                    high_hz: float = 45.0,
                    order: int = 4) -> np.ndarray:
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    The filter is designed with ``scipy.signal.butter``. It is run forward and
    backward with ``filtfilt``, so the output has no phase shift and the
    effective magnitude response is the square of the designed one.

    Args:
        x: Signal array; filtering runs along axis 0 (samples), so a 2-D
            array is filtered column by column.
        fs: Sampling rate in Hz.
        low_hz: Lower cut-off frequency in Hz.
        high_hz: Upper cut-off frequency in Hz.
        order: Butterworth prototype order.

    Returns:
        Filtered array with the same shape as ``x``.

    Raises:
        ValueError: if ``fs`` is not positive, or if the band does not
            satisfy ``0 < low_hz < high_hz < fs / 2`` (the Nyquist frequency).
    """
    # `not fs > 0` also rejects NaN.
    if not fs > 0:
        raise ValueError(f"fs must be positive, got {fs!r}.")
    nyq = 0.5 * fs
    # Check explicitly so callers get an actionable message. Otherwise an
    # out-of-range band (e.g. fs <= 90 Hz with the default 45 Hz upper edge)
    # fails deep inside scipy.
    if not 0.0 < low_hz < high_hz < nyq:
        raise ValueError(
            f"Band [{low_hz}, {high_hz}] Hz must satisfy "
            f"0 < low_hz < high_hz < Nyquist ({nyq:g} Hz)."
        )
    b, a = butter(order, [low_hz / nyq, high_hz / nyq], btype="bandpass")
    return filtfilt(b, a, x, axis=0)


def zscore_per_channel(x: np.ndarray) -> np.ndarray:
    """Standardise each column of ``x`` to zero mean and unit variance.

    Mean and standard deviation are taken along axis 0. The std is the
    population value (numpy's default ddof=0). A tiny epsilon is added to the
    denominator so a perfectly flat column yields zeros instead of dividing
    by zero.
    """
    eps = 1e-8
    centred = x - np.mean(x, axis=0, keepdims=True)
    spread = np.std(x, axis=0, keepdims=True) + eps
    return centred / spread


# ---------- Main ----------
if __name__ == "__main__":

    df = pd.read_csv("S1_A1_E1_export.csv")

    time_col = "Time"
    if time_col not in df.columns:
        raise ValueError("Time column not found.")

    # str() guards against non-string headers (e.g. integer column labels).
    emg_cols = [c for c in df.columns if str(c).startswith("EMG_")]
    if not emg_cols:
        raise ValueError("No EMG columns found.")

    # Time axis, read once. It is used both to infer fs and as the plot x-axis.
    t = df[time_col].to_numpy(float)
    fs = infer_fs_from_time(t)
    print("fs =", fs)

    # Rows are samples, columns follow the order of emg_cols.
    x = df[emg_cols].to_numpy(float)

    # Remove per-channel offset, band-pass, then standardise each channel.
    x = x - np.mean(x, axis=0, keepdims=True)
    x = bandpass_filter(x, fs, 20.0, 45.0)
    x = zscore_per_channel(x)

    # Plain stride decimation (no anti-alias filter), used only to keep the
    # figures light. The processed array x itself is left at full rate.
    max_points = 50_000
    stride = max(1, len(t) // max_points)
    t_plot = t[::stride]
    x_plot = x[::stride, :]

    # ---------- Plot each channel separately ----------
    for i, col in enumerate(emg_cols):
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(t_plot, x_plot[:, i], linewidth=0.8)

        ax.set_xlabel("Time (s)")
        # Label with the real column name rather than assuming EMG_1..EMG_n order.
        ax.set_ylabel(f"{col} (z)")
        ax.set_title(f"Processed EMG Channel {i+1}")
        fig.tight_layout()

        filename = f"processed_emg_ch{i+1}.png"
        fig.savefig(filename, dpi=200)
        plt.close(fig)

        print(f"Saved {filename}")

    # Report the actual channel count instead of a hard-coded 10.
    print(f"All {len(emg_cols)} channel plots saved.")

fs = 100.00000000009095
Saved processed_emg_ch1.png
Saved processed_emg_ch2.png
Saved processed_emg_ch3.png
Saved processed_emg_ch4.png
Saved processed_emg_ch5.png
Saved processed_emg_ch6.png
Saved processed_emg_ch7.png
Saved processed_emg_ch8.png
Saved processed_emg_ch9.png
Saved processed_emg_ch10.png
All 10 channel plots saved.
