In [1]:
# ============================================================
# 01 FEATURE EXTRACTION
# ============================================================
import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy import signal, stats
from scipy.signal import find_peaks, welch
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)


In [2]:
# ============================================================
# 1 Configuration
# ============================================================

DATA_ROOT = "/Volumes/blue_nateck/WESAD"     # path to WESAD dataset root
OUTPUT_PATH = "features_wesad_raw.csv"         # output CSV file


In [3]:
# Windowing parameters (10 s window, 5 s stride)
WINDOW_SEC = 7
STRIDE_SEC = 3.5
PURITY_THRESHOLD = 0.9  # ≥90% jednej klasy w oknie

In [4]:
# Sampling frequencies per modality (from WESAD documentation)
ECG_FS  = 700
BVP_FS  = 64
EDA_FS  = 4
RESP_FS = 700
TEMP_FS = 700
ACC_FS  = 32

In [5]:
CLASS_MAP = {0: "baseline", 1: "amusement", 2: "stress"}
INV_CLASS_MAP = {v: k for k, v in CLASS_MAP.items()}

In [6]:
# ============================================================
# FUNKCJE POMOCNICZE
# ============================================================

def load_wesad_subject(pkl_path: str):
    with open(pkl_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")
    return data

def to_seconds(idx: int, fs: int) -> float:
    return idx / float(fs)

def to_index(t: float, fs: int) -> int:
    return int(round(t * fs))

def label_spans_from_array(y: np.ndarray, fs_label: int):
    spans = {}
    i = 0
    n = len(y)
    while i < n:
        cls = y[i]; j = i + 1
        while j < n and y[j] == cls:
            j += 1
        if cls in CLASS_MAP:
            t1, t2 = to_seconds(i, fs_label), to_seconds(j, fs_label)
            spans.setdefault(CLASS_MAP[cls], []).append((t1, t2))
        i = j
    return spans

def assign_label_for_window(t_start: float, t_end: float, spans, purity: float):
    ts = np.arange(t_start, t_end, 0.1)
    if len(ts) == 0: return None
    cover = {c:0 for c in spans.keys()}
    for c, lst in spans.items():
        for a,b in lst:
            cover[c] += np.sum((ts >= a) & (ts < b))
    total = sum(cover.values())
    if total == 0: return None
    cls, cnt = max(cover.items(), key=lambda kv: kv[1])
    return cls if (cnt/total) >= purity else None


In [7]:
# ============================================================
# FUNKCJE CECH SYGNAŁOWYCH
# ============================================================

def rr_interval_features(rr_seconds: np.ndarray):
    feats = {}
    if len(rr_seconds) < 3:
        return {"RR_mean": np.nan, "RR_std": np.nan, "IBI_median": np.nan, "IBI_IQR": np.nan, "TINN": np.nan}

    RR_mean = float(np.mean(rr_seconds))
    RR_std  = float(np.std(rr_seconds))
    IBI_median = float(np.median(rr_seconds))
    q75, q25 = np.percentile(rr_seconds, [75,25])
    IBI_IQR = float(q75 - q25)

    try:
        hist, bin_edges = np.histogram(rr_seconds, bins=min(64, max(16, len(rr_seconds)//2)))
        peak_idx = np.argmax(hist)
        left = peak_idx
        while left > 0 and hist[left] > 0: left -= 1
        right = peak_idx
        while right < len(hist)-1 and hist[right] > 0: right += 1
        left_val = (bin_edges[left] + bin_edges[left+1]) / 2.0
        right_val = (bin_edges[right] + bin_edges[right+1]) / 2.0
        TINN = float(right_val - left_val)
    except Exception:
        TINN = np.nan

    return {"RR_mean": RR_mean, "RR_std": RR_std, "IBI_median": IBI_median, "IBI_IQR": IBI_IQR, "TINN": TINN}

def resp_band_energy_and_var(sig: np.ndarray, fs: int):
    if len(sig) < fs:
        return {"Resp_band_energy": np.nan, "Resp_rate_var": np.nan}
    f, Pxx = welch(sig, fs=fs, nperseg=min(256, len(sig)))
    band = (f >= 0.1) & (f <= 0.35)
    band_energy = float(np.trapezoid(Pxx[band], f[band])) if np.any(band) else np.nan

    peaks, _ = find_peaks(sig, distance=max(1, int(fs/2)))
    if len(peaks) >= 3:
        ibi = np.diff(peaks) / fs
        rate_series = 60.0 / np.clip(ibi, 1e-6, None)
        rate_var = float(np.var(rate_series))
    else:
        rate_var = np.nan
    return {"Resp_band_energy": band_energy, "Resp_rate_var": rate_var}

def acc_axis_stats(acc_xyz: np.ndarray):
    if acc_xyz.ndim != 2 or acc_xyz.shape[1] != 3:
        return {"ACC_rms_x": np.nan, "ACC_rms_y": np.nan, "ACC_rms_z": np.nan,
                "ACC_mean_x": np.nan, "ACC_mean_y": np.nan, "ACC_mean_z": np.nan}
    rms = np.sqrt(np.mean(acc_xyz**2, axis=0))
    mean_axes = np.mean(acc_xyz, axis=0)
    return {
        "ACC_rms_x": float(rms[0]), "ACC_rms_y": float(rms[1]), "ACC_rms_z": float(rms[2]),
        "ACC_mean_x": float(mean_axes[0]), "ACC_mean_y": float(mean_axes[1]), "ACC_mean_z": float(mean_axes[2]),
    }

def eda_tonic_phasic_features(sig: np.ndarray, fs: int):
    if len(sig) < fs:
        return {"SCL_tonic_mean": np.nan, "SCR_rise_time_mean": np.nan}
    win = max(1, int(5*fs))
    scl = pd.Series(sig).rolling(window=win, center=True, min_periods=max(1, win//2)).median().to_numpy()
    scl = np.nan_to_num(scl, nan=np.nanmedian(scl))
    phasic = sig - scl

    peaks, _ = find_peaks(phasic, distance=max(1, int(0.5*fs)), prominence=0.02)
    rise_times = []
    for p in peaks:
        start = max(0, p - int(3*fs))
        segment = phasic[start:p+1]
        if len(segment) > 1:
            st_idx = np.argmin(segment)
            rt = (p - (start + st_idx)) / fs
            rise_times.append(rt)
    return {
        "SCL_tonic_mean": float(np.nanmean(scl)),
        "SCR_rise_time_mean": float(np.mean(rise_times)) if len(rise_times) else np.nan
    }


In [8]:
# ============================================================
# KLASY FEATURÓW
# ============================================================

class ECG_BVP_Features:
    """Extract heart rate and HRV-related features from ECG or BVP signals."""

    def __init__(self, fs):
        self.fs = fs

    def compute(self, sig):
        # --- Zabezpieczenie pasma filtra ---
        nyq = 0.5 * self.fs
        low = 0.5 / nyq
        high = 40 / nyq

        # Korekta, żeby 0 < low < high < 1
        high = min(high, 0.99)
        low = max(low, 0.001)

        # Filtrowanie sygnału (z obsługą błędów)
        try:
            b, a = signal.butter(3, [low, high], btype='band')
            sig_f = signal.filtfilt(b, a, sig)
        except ValueError:
            return {k: np.nan for k in [
                "HR_mean","HR_std","HRV_SDNN","HRV_RMSSD","pNN50","LF_HF_ratio",
                "RR_mean","RR_std","IBI_median","IBI_IQR","TINN"
            ]}

        # --- Detekcja pików i obliczenia HR/HRV ---
        peaks, _ = find_peaks(sig_f, distance=0.3 * self.fs, prominence=np.std(sig_f) * 0.3)
        if len(peaks) < 3:
            return {k: np.nan for k in [
                "HR_mean","HR_std","HRV_SDNN","HRV_RMSSD","pNN50","LF_HF_ratio",
                "RR_mean","RR_std","IBI_median","IBI_IQR","TINN"
            ]}

        rr = np.diff(peaks) / self.fs
        hr = 60.0 / np.clip(rr, 1e-6, None)

        HR_mean = np.mean(hr)
        HR_std = np.std(hr)
        HRV_SDNN = np.std(rr)
        HRV_RMSSD = np.sqrt(np.mean(np.square(np.diff(rr)))) if len(rr) > 2 else np.nan
        pNN50 = np.sum(np.abs(np.diff(rr)) > 0.05) / len(rr) if len(rr) > 2 else np.nan

        # --- Analiza widmowa RR (LF/HF) ---
        rr_d = rr - np.mean(rr)
        f, pxx = welch(rr_d, fs=4.0, nperseg=min(256, len(rr_d)))
        LF_mask = (f >= 0.04) & (f < 0.15)
        HF_mask = (f >= 0.15) & (f < 0.40)
        LF = np.trapezoid(pxx[LF_mask], f[LF_mask]) if np.any(LF_mask) else np.nan
        HF = np.trapezoid(pxx[HF_mask], f[HF_mask]) if np.any(HF_mask) else np.nan
        LF_HF_ratio = float(LF / HF) if (isinstance(HF, float) and HF > 0) else np.nan

        # --- Zbiór wszystkich cech ---
        feats = {
            "HR_mean": HR_mean,
            "HR_std": HR_std,
            "HRV_SDNN": HRV_SDNN,
            "HRV_RMSSD": HRV_RMSSD,
            "pNN50": pNN50,
            "LF_HF_ratio": LF_HF_ratio
        }

        # --- Dodatkowe cechy RR (IBI, TINN itd.) ---
        feats.update(rr_interval_features(rr))
        return feats

class EDA_Features:
    def __init__(self, fs): self.fs = fs
    def compute(self, sig):
        mean = np.mean(sig)
        std  = np.std(sig)
        slope = stats.linregress(np.arange(len(sig)), sig).slope if len(sig) > 1 else np.nan
        peaks, props = find_peaks(sig, distance=max(1,int(0.5*self.fs)), prominence=0.02)
        scr_rate = len(peaks) / (len(sig)/self.fs/60) if len(sig) else np.nan
        mean_amp = np.mean(props["prominences"]) if len(peaks)>0 else 0.0
        median_val = np.median(sig)
        eda_extra = eda_tonic_phasic_features(sig, self.fs)
        return {
            "EDA_mean": mean, "EDA_std": std, "EDA_median": median_val,
            "EDA_slope": slope, "SCR_peaks_per_min": scr_rate, "SCR_mean_amp": mean_amp,
            **eda_extra
        }


class RESP_Features:
    def __init__(self, fs): self.fs = fs
    def compute(self, sig):
        peaks, _ = find_peaks(sig, distance=max(1,int(self.fs/2)))
        resp_rate = len(peaks) / (len(sig)/self.fs/60) if len(sig) else np.nan
        amp_mean = np.mean(sig)
        amp_std  = np.std(sig)
        extras = resp_band_energy_and_var(sig, self.fs)
        return {"Resp_rate": resp_rate, "Resp_amp_mean": amp_mean, "Resp_amp_std": amp_std, **extras}


class TEMP_Features:
    def compute(self, sig):
        mean = np.mean(sig)
        std = np.std(sig)
        slope = stats.linregress(np.arange(len(sig)), sig).slope if len(sig) > 1 else np.nan
        return {"TEMP_mean": mean, "TEMP_std": std, "TEMP_slope": slope}


class ACC_Features:
    def compute(self, acc_xyz):
        mag = np.linalg.norm(acc_xyz, axis=1)
        mean_mag = np.mean(mag)
        std = np.std(mag)
        energy = np.sum(mag**2)/len(mag) if len(mag) else np.nan
        axis_feats = acc_axis_stats(acc_xyz)
        return {"ACC_mean_magnitude": mean_mag, "ACC_std": std, "ACC_energy": energy, **axis_feats}



In [13]:
# ============================================================
# GŁÓWNA PĘTLA EKSTRAKCJI
# ============================================================

records = []

for subject in tqdm(["S2","S3","S4","S5","S6","S7","S8","S9","S10","S11","S13","S14","S15","S16","S17"], desc="Processing subjects"):
    pkl_path = os.path.join(DATA_ROOT, subject, f"{subject}.pkl")
    if not os.path.exists(pkl_path):
        print(f"{pkl_path} not found, skipping...")
        continue

    data = load_wesad_subject(pkl_path)
    y = data['label'].astype(int).squeeze()
    spans = label_spans_from_array(y, ECG_FS)

    ecg = data['signal']['chest']['ECG'].squeeze()
    bvp = data['signal']['wrist']['BVP'].squeeze()
    eda = data['signal']['wrist']['EDA'].squeeze()
    resp = data['signal']['chest']['Resp'].squeeze()
    temp = data['signal']['wrist']['TEMP'].squeeze()
    acc = data['signal']['wrist']['ACC']

    ecg_ex = ECG_BVP_Features(ECG_FS)
    bvp_ex = ECG_BVP_Features(BVP_FS)
    eda_ex = EDA_Features(EDA_FS)
    resp_ex = RESP_Features(RESP_FS)
    temp_ex = TEMP_Features()
    acc_ex = ACC_Features()

    win = int(WINDOW_SEC * ECG_FS)
    step = int(STRIDE_SEC * ECG_FS)

    for start in range(0, len(ecg) - win, step):
        t1, t2 = start/ECG_FS, (start+win)/ECG_FS
        lbl_name = assign_label_for_window(t1, t2, spans, PURITY_THRESHOLD)
        if lbl_name is None:
            continue
        label = INV_CLASS_MAP[lbl_name]

        ecg_win = ecg[start:start+win]
        idx_eda = slice(int(start*EDA_FS/ECG_FS), int((start+win)*EDA_FS/ECG_FS))
        idx_bvp = slice(int(start*BVP_FS/ECG_FS), int((start+win)*BVP_FS/ECG_FS))
        idx_resp = slice(int(start*RESP_FS/ECG_FS), int((start+win)*RESP_FS/ECG_FS))
        idx_temp = slice(int(start*TEMP_FS/ECG_FS), int((start+win)*TEMP_FS/ECG_FS))
        idx_acc = slice(int(start*ACC_FS/ECG_FS), int((start+win)*ACC_FS/ECG_FS))

        feats = {"subject": subject, "label": label}

        feats.update(ecg_ex.compute(ecg_win))
        feats.update(bvp_ex.compute(bvp[idx_bvp]))
        feats.update(eda_ex.compute(eda[idx_eda]))
        feats.update(resp_ex.compute(resp[idx_resp]))
        feats.update(temp_ex.compute(temp[idx_temp]))
        feats.update(acc_ex.compute(acc[idx_acc]))

        records.append(feats)

Processing subjects: 100%|██████████| 15/15 [02:02<00:00,  8.19s/it]


In [14]:
# ============================================================
# ZAPIS DANYCH
# ============================================================

df = pd.DataFrame(records)
df.to_csv(OUTPUT_PATH, index=False)
print(f"\nZapisano cechy: {OUTPUT_PATH}")
print(f"Kształt: {df.shape}")
display(df.head())


Zapisano cechy: features_wesad_raw.csv
Kształt: (19218, 38)


Unnamed: 0,subject,label,HR_mean,HR_std,HRV_SDNN,HRV_RMSSD,pNN50,LF_HF_ratio,RR_mean,RR_std,...,TEMP_slope,ACC_mean_magnitude,ACC_std,ACC_energy,ACC_rms_x,ACC_rms_y,ACC_rms_z,ACC_mean_x,ACC_mean_y,ACC_mean_z
0,S2,0,82.972527,32.136162,0.289972,0.472769,0.857143,,0.830357,0.289972,...,4.9e-05,64.834768,9.564265,4295.022321,54.175887,20.424513,30.705615,53.785714,19.1875,26.433036
1,S2,0,71.26784,4.871916,0.059895,0.082186,0.428571,,0.845982,0.059895,...,-0.000146,63.281737,0.794698,4005.209821,53.768477,20.13692,26.620766,53.763393,20.129464,26.611607
2,S2,0,71.769454,6.391315,0.087197,0.134563,0.428571,,0.84375,0.087197,...,-0.000397,63.481564,4.885925,4053.78125,43.470433,41.739627,20.540357,38.919643,-14.383929,17.236607
3,S2,0,88.117261,34.317224,0.24569,0.264309,0.875,,0.773438,0.24569,...,-0.00064,63.397065,5.220551,4046.441964,23.278209,57.685841,13.300779,18.946429,-54.290179,10.9375
4,S2,0,100.569534,31.057347,0.193543,0.274273,0.666667,,0.654514,0.193543,...,-9.3e-05,62.879988,2.035378,3958.035714,15.359676,59.200967,14.743188,15.125,-59.15625,14.40625
