In [21]:
"""
svm_dataset_train.py


One-click dataset builder + lightweight labeling + quick SVM trainer (Dash)
for motion-state classification from PPG + IMU.


This revision implements your requested changes:
1) Fix: import kurtosis/skew from scipy.stats (not scipy.signal).
2) Segment default range = full data duration.
3) Labels merged/split: Sit&Stand → one class; Transit and StrongMotion → two classes.
4) PSD preview x-range limited to 0-8 Hz.
5) Left control panel width doubled.
6) Training option (default ON): exclude Gyro & Jerk features; OFF = use all features.
7) Dataset source selector: in-memory (Store) by default, or pick a saved CSV from ./datasets.
8) In-memory dataset preview: first 3 rows + total count.
9) Spectral-shape features (entropy & main peak) added for IMU (AccMag, GyroMag, JerkMag) in addition to PPG.
"""

'\nsvm_dataset_train.py\n\n\nOne-click dataset builder + lightweight labeling + quick SVM trainer (Dash)\nfor motion-state classification from PPG + IMU.\n\n\nThis revision implements your requested changes:\n1) Fix: import kurtosis/skew from scipy.stats (not scipy.signal).\n2) Segment default range = full data duration.\n3) Labels merged/split: Sit&Stand → one class; Transit and StrongMotion → two classes.\n4) PSD preview x-range limited to 0-8 Hz.\n5) Left control panel width doubled.\n6) Training option (default ON): exclude Gyro & Jerk features; OFF = use all features.\n7) Dataset source selector: in-memory (Store) by default, or pick a saved CSV from ./datasets.\n8) In-memory dataset preview: first 3 rows + total count.\n9) Spectral-shape features (entropy & main peak) added for IMU (AccMag, GyroMag, JerkMag) in addition to PPG.\n'

In [None]:
'''
first split holdout test data 
then moving window
''' 

In [22]:
import webbrowser
from __future__ import annotations
import os
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from scipy import signal
from scipy.stats import kurtosis, skew # (1) FIX: stats module, not signal

import dash
from dash import dcc, html, Input, Output, State, dash_table
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import SVC, LinearSVC
from sklearn.model_selection import (StratifiedKFold, GroupKFold, train_test_split,
                                     GroupShuffleSplit, cross_validate)
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,           
    balanced_accuracy_score,    
    precision_recall_fscore_support
)

import joblib



In [23]:
# ========================
# Global config
# ========================
FS_DEFAULT = 400  # Sampling frequency in Hz
MIN_BPM = 40  # Minimum expected heart rate for artifact rejection
MAX_BPM = 180  # Maximum expected heart rate for artifact rejection
DATASET_DIR = Path("datasets")     # Feature CSV snapshots
MODEL_DIR = Path("models")         # Saved models (pkl)
PORT = 8051
G = 9.81
for p in (DATASET_DIR, MODEL_DIR):
    p.mkdir(parents=True, exist_ok=True)

In [24]:
envi = 1

windows_address_1 = ["/mnt/d/Tubcloud/Shared/PPG/Test Data",
                   "/mnt/d/Tubcloud/Shared/PPG/Test Data/25July25"]

ubuntu_address_0 = ["/home/trinker/only_view/Test Data", 
                  "/home/trinker/only_view/Test Data/25July25"]


if envi:
    DEFAULT_FOLDER_MAIN = windows_address_1[0]
    DEFAULT_FOLDER = windows_address_1[1]
else:
    DEFAULT_FOLDER_MAIN = ubuntu_address_0[0]
    DEFAULT_FOLDER = ubuntu_address_0[1]

In [25]:
# --------------------------------------
# Try to import your project core funcs
# --------------------------------------
try:
    from funcs import preprocess_ppg_min, imu_preprocess_with_kf
    print("import success")
    HAVE_CORE = True
except Exception:
    HAVE_CORE = False
    print("import fail")
    def preprocess_ppg_min(ppg, fs=FS_DEFAULT, hp_cut=0.2, mains=None):
        """Fallback: high-pass 0.2 Hz + optional mains notch (50/60 Hz)."""
        ppg = np.asarray(ppg, float).ravel()
        b, a = signal.butter(2, hp_cut/(0.5*fs), 'high')
        y = signal.filtfilt(b, a, ppg)
        if mains in (50, 60):
            b, a = signal.iirnotch(w0=float(mains), Q=30.0, fs=fs)
            y = signal.filtfilt(b, a, y)
        return y

    def imu_preprocess_with_kf(df: pd.DataFrame, fs=FS_DEFAULT, acc_fc=20, gyro_fc=40, static_secs=2.0):
        """Fallback: LP accel/gyro → magnitudes + jerk. (No EKF here.)"""
        G = 9.81
        acc = df[['AX','AY','AZ']].to_numpy(float) * G
        gyr = np.deg2rad(df[['GX','GY','GZ']].to_numpy(float))
        def lp(x, fc):
            b, a = signal.butter(4, fc/(0.5*fs), 'low')
            return signal.filtfilt(b, a, x, axis=0)
        acc_f = lp(acc, acc_fc)
        gyr_f = lp(gyr, gyro_fc)
        a_dyn = acc_f
        acc_mag  = np.linalg.norm(a_dyn, axis=1)
        gyro_mag = np.linalg.norm(gyr_f, axis=1)
        jerk     = np.diff(a_dyn, axis=0, prepend=a_dyn[:1]) * fs
        jerk_mag = np.linalg.norm(jerk, axis=1)
        return dict(acc_f=acc_f, gyr_f=gyr_f, a_dyn=a_dyn,
                    AccMag=acc_mag, GyroMag=gyro_mag, JerkMag=jerk_mag)

import success


In [26]:
# ========================
# Feature engineering
# ========================



def welch_bandpower(x, fs, fmin, fmax, nperseg=None, noverlap=0.5):
    """Band power via Welch (Hann). Returns a scalar."""
    x = np.asarray(x, float).ravel()
    if nperseg is None:
        nperseg = int(min(len(x), 2*fs))
    nlap = int(noverlap * nperseg)
    f, P = signal.welch(x, fs=fs, window="hann", nperseg=nperseg, noverlap=nlap)
    m = (f >= fmin) & (f <= fmax)
    return float(np.trapz(P[m], f[m])) if np.any(m) else 0.0

def spectral_entropy(x, fs, fmax=10.0):
    """Shannon entropy of normalized Welch PSD up to fmax."""
    x = np.asarray(x, float).ravel()
    nperseg = int(min(len(x), 2*fs)) if len(x) else 256
    f, P = signal.welch(x, fs=fs, window="hann", nperseg=max(64, nperseg))
    m = f <= fmax
    p = P[m] + 1e-18
    p /= np.sum(p)
    return float(-(p * np.log(p)).sum())


def spectral_main_peak(x, fs, fmin=0.3, fmax=8.0):
    """Dominant frequency (Hz) within [fmin,fmax] from Welch PSD."""
    x = np.asarray(x, float).ravel()
    if x.size < 8:
        return 0.0
    nperseg = int(min(len(x), 2*fs))
    f, P = signal.welch(x, fs=fs, window="hann", nperseg=max(64, nperseg))
    band = (f >= fmin) & (f <= fmax)
    return float(f[band][np.argmax(P[band])]) if np.any(band) else 0.0


def compute_time_stats(x, prefix):
    """Basic time-domain stats; names carry a prefix (signal role)."""
    x = np.asarray(x, float).ravel()
    names = [f"{prefix}_{k}" for k in ("mean","std","rms","iqr","kurt","skew")]
    vals = [
            float(np.mean(x)),
            float(np.std(x)),
            float(np.sqrt(np.mean(x**2))),
            float(np.percentile(x, 75) - np.percentile(x, 25)),
            float(kurtosis(x, fisher=False)), # Pearson definition (normal=3)
            float(skew(x))
            ]
    return names, vals

def extract_features_window(ppg_seg: np.ndarray, imu_mag: Dict[str,np.ndarray], fs: float) -> Tuple[List[str], np.ndarray]:
    """
    Features per window. Includes PPG + IMU:
    • Time stats: PPG, AccMag, GyroMag, JerkMag
    • Bandpowers: 0.1–0.5 / 0.5–3 / 3–8 Hz for PPG, AccMag, GyroMag
    • Spectral shape: spectral entropy + dominant peak for PPG, AccMag, GyroMag, JerkMag
    """
    # Minimal PPG preprocessing to remove drift & mains
    ppgm = preprocess_ppg_min(ppg_seg, fs=fs, hp_cut=0.2, mains=50)
    acc_mag, gyro_mag, jerk_mag = imu_mag['AccMag'], imu_mag['GyroMag'], imu_mag['JerkMag']


    feat_names, feat_vals = [], []
    # Time-domain stats for each channel
    for sig, pfx in [
    (ppgm, "ppg"), (acc_mag, "accmag"), (gyro_mag, "gyromag"), (jerk_mag, "jerkmag")
    ]:
        n, v = compute_time_stats(sig, pfx)
        feat_names += n; feat_vals += v


    # Bandpowers for selected channels
    bands = [(0.1,0.5),(0.5,3.0),(3.0,8.0)]
    for lo, hi in bands:
        feat_names += [f"ppg_bp_{lo}-{hi}", f"acc_bp_{lo}-{hi}", f"gyro_bp_{lo}-{hi}"]
        feat_vals += [
        welch_bandpower(ppgm, fs, lo, hi),
        welch_bandpower(acc_mag, fs, lo, hi),
        welch_bandpower(gyro_mag, fs, lo, hi)
        ]


    # Spectral-shape features for PPG + IMU (9): entropy & main peak freq
    for sig, pfx in [
    (ppgm, "ppg"), (acc_mag, "accmag"), (gyro_mag, "gyromag"), (jerk_mag, "jerkmag")
    ]:
        feat_names += [f"{pfx}_spec_entropy", f"{pfx}_main_freq"]
        feat_vals += [spectral_entropy(sig, fs, fmax=10.0), spectral_main_peak(sig, fs, fmin=0.1, fmax=8.0)]


    return feat_names, np.array(feat_vals, float)


# Sliding window helper
def make_windows(N: int, fs: float, win_sec: float, hop_sec: float):
    W, H = int(win_sec*fs), int(hop_sec*fs)
    for s in range(0, max(1, N-W+1), H):
        yield s, s+W


In [27]:
#-------------------------Dash Func-------------------
def get_folder_options():
    """遍历 DEFAULT_FOLDER_MAIN 下的子文件夹，生成 Dropdown 选项；确保包含 DEFAULT_FOLDER。"""
    paths = []
    if os.path.isdir(DEFAULT_FOLDER_MAIN):
        for name in sorted(os.listdir(DEFAULT_FOLDER_MAIN)):
            p = os.path.join(DEFAULT_FOLDER_MAIN, name)
            if os.path.isdir(p):
                paths.append(p)
    # 确保 DEFAULT_FOLDER 在选项里（即使不在 DEFAULT_FOLDER_MAIN 下，也加入）
    if DEFAULT_FOLDER and os.path.exists(DEFAULT_FOLDER) and DEFAULT_FOLDER not in paths:
        paths.insert(0, DEFAULT_FOLDER)
    # label 显示目录名，value 为完整路径
    return [{'label': os.path.basename(p) or p, 'value': p} for p in paths]

In [28]:

# ========================
# Labels & colors
# ========================
LABEL_MAP = {
0: "Rest",
1: "Sit/Stand",
2: "Walk",
3: "Transition",
4: "StrongMotion",
}
LABEL_OPTIONS = [{"label": f"{k} - {v}", "value": k} for k,v in LABEL_MAP.items()]
LABEL_COLORS = {
0: "#2ecc71", # Rest – green
1: "#3498db", # Sit/Stand – blue
2: "#e67e22", # Walk – orange
3: "#f1c40f", # Transition – yellow
4: "#e74c3c", # StrongMotion – red
}

def discrete_colorscale_from_map(map_k2hex: Dict[int,str]):
    ks = sorted(map_k2hex.keys())
    if not ks:
        return "Viridis"
    vmin, vmax = ks[0], ks[-1]
    scale = []
    for k in ks:
        v = 0.0 if vmax==vmin else (k - vmin) / (vmax - vmin)
        scale.append([v, map_k2hex[k]])
        scale.append([min(v+1e-6,1.0), map_k2hex[k]])
    return scale

In [29]:
# ========================
# Dash app layout
# ========================
external_stylesheets: List[str] = []
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
webbrowser.open(f"http://localhost:{PORT}")
app.title = "SVM Dataset Builder + Trainer (PPG+IMU)"
folder_options = get_folder_options()

left_panel_style = {"border":"1px solid #eee","borderRadius":"8px","padding":"10px"}
right_panel_style = {"border":"1px solid #eee","borderRadius":"8px","padding":"10px"}

table_style_table = {"maxHeight":"260px","overflowY":"auto","overflowX":"auto","maxWidth":"100%","minWidth":"100%"}
table_style_cell  = {"minWidth":"110px","width":"140px","maxWidth":"240px","whiteSpace":"normal","textAlign":"left"}

app.layout = html.Div(
    style={"backgroundColor": "white", "padding": "12px"},
    children=[
        html.H2("SVM Dataset Builder + Quick Trainer (PPG + IMU)", style={"marginBottom": "6px"}),
        html.Div(
            style={"display": "grid", "gridTemplateColumns": "1000px 1fr", "gap": "12px"},
            children=[
                # ------------------ Left control panel ------------------
                html.Div([
                    html.H4("1) File & Columns", style={"marginTop": "8px"}),
                    dcc.Dropdown(id="input-folder",             
                                 options=folder_options,
                                    value=DEFAULT_FOLDER,
                                    clearable=False,
                                    placeholder='Select data folder',
                                    style={'width': '80%'}
                                ),
                    #html.Button("Scan Folder", id="btn-scan", n_clicks=0, style={"width": "100%", "marginTop": 6}),
                    dcc.Dropdown(id="ddl-files", options=[], placeholder="Select a CSV file...", style={"marginTop":6}),
                    html.Div([
                        dcc.Input(id="input-ppg-col", type="text", value="IR", placeholder="PPG column (IR/RED/custom)", style={"width":"48%"}),
                        dcc.Input(id="input-fs", type="number", value=FS_DEFAULT, step=1, placeholder="Fs (Hz)", style={"width":"48%","float":"right"})
                    ], style={"marginTop":6}),
                    html.Div(id="div-head-preview", style={"marginTop":8}),

                    html.H4("2) Segment & Label"),
                    html.Div([
                        dcc.RangeSlider(id="rs-seg", min=0, max=10, step=0.5, value=[0,10], tooltip={"always_visible":True}),
                        html.Div(id="txt-seg", style={"marginTop":4})
                    ], style={"marginTop":6}),
                    dcc.Dropdown(id="ddl-label", options=LABEL_OPTIONS, value=0, style={"marginTop":6}),
                    html.Button("Add Labeled Segment → Feature rows", id="btn-add-label", n_clicks=0,
                                style={"width":"100%","marginTop":6,"backgroundColor":"#2ecc71","color":"white"}),
                    html.Button("Save Current Dataset CSV", id="btn-save-ds", n_clicks=0,
                                style={"width":"100%","marginTop":6}),
                    html.Div(id="txt-save-status", style={"marginTop":6,"color":"#2c3e50"}),

                    html.H4("3) Windowing for Features"),
                    html.Div([
                        dcc.Input(id="input-win", type="number", value=3.0, step=0.5, placeholder="win_sec", style={"width":"48%"}),
                        dcc.Input(id="input-hop", type="number", value=1.0, step=0.5, placeholder="hop_sec", style={"width":"48%","float":"right"}),
                    ], style={"marginTop":6}),

                    html.Hr(),
                    html.H4("4) Train SVM (quick)"),
                    dcc.RadioItems(id="ri-kernel", options=[{"label":"RBF","value":"rbf"},{"label":"Linear","value":"linear"}], value="rbf", inline=True),
                    html.Div([
                        dcc.Input(id="input-C", type="number", value=10.0, step=0.5, placeholder="C", style={"width":"48%"}),
                        dcc.Input(id="input-gamma", type="text", value="scale", placeholder="gamma (scale/auto/float)", style={"width":"48%","float":"right"}),
                    ], style={"marginTop":6}),
                    dcc.Checklist(id="chk-pca", options=[{"label":"Use PCA (var)", "value":"use"}], value=["use"], inline=True),
                    dcc.Slider(id="sl-pca-var", min=0.80, max=0.99, step=0.01, value=0.95, marks=None, tooltip={"always_visible":True}),
                    dcc.Checklist(id="chk-group-file", options=[{"label":"Group by file (CV)", "value":"group"}], value=["group"], inline=True),
                    dcc.Checklist(id="chk-excl-gj", options=[{"label":"Exclude Gyro & Jerk features (default ON)", "value":"excl"}], value=["excl"], inline=False, style={"marginTop":6}),
                    html.Div([
                        dcc.Input(id="input-cv", type="number", value=5, step=1, placeholder="CV folds", style={"width":"48%"}),
                        dcc.Input(id="input-test", type="number", value=0.2, step=0.05, placeholder="Holdout ratio", style={"width":"48%","float":"right"}),
                    ], style={"marginTop":6}),
                    html.Button("Train Now", id="btn-train", n_clicks=0,
                                style={"width":"100%","marginTop":6,"backgroundColor":"#34495e","color":"white"}),
                    html.Div(id="txt-train-status", style={"marginTop":6,"color":"#2c3e50"}),

                    html.Hr(),
                    html.H4("5) Training Dataset Source"),
                    dcc.RadioItems(id="ri-ds-source", options=[{"label":"In-memory (Store)","value":"mem"},{"label":"From saved CSV","value":"file"}], value="mem", inline=True),
                    html.Button("Scan Datasets", id="btn-scan-ds", n_clicks=0, style={"marginTop":6}),
                    dcc.Dropdown(id="ddl-ds-file", options=[], placeholder="Select a dataset CSV from ./datasets", style={"marginTop":6}),
                    html.Div(id="div-ds-preview", style={"marginTop":6}),

                    html.Hr(),
                    html.H4("6) Parameter Glossary / 参数说明"),
                    html.Details([
                        html.Summary("Click to expand / 展开查看"),
                        html.Ul([
                            html.Li("Folder, Scan Folder: choose data location / 选择数据目录并扫描 CSV"),
                            html.Li("File dropdown: select a CSV file / 选择 CSV 文件"),
                            html.Li("PPG column: IR/RED/custom / 选择或输入 PPG 列名"),
                            html.Li("Fs (Hz): sampling rate / 采样率"),
                            html.Li("Segment slider: time range for preview & labeling / 选择预览与标注的时间段"),
                            html.Li("Label: motion class for the selected segment / 给选定片段贴标签"),
                            html.Li("Add Labeled Segment: extract windowed features and append to dataset / 提取特征并追加到训练集"),
                            html.Li("Save Dataset CSV: export current dataset to ./datasets / 导出当前训练集"),
                            html.Li("win_sec & hop_sec: sliding window and hop in seconds / 滑窗长度与步进（秒）"),
                            html.Li("Kernel (RBF/Linear): SVM kernel / 核函数"),
                            html.Li("C: regularization strength / 正则强度（越大越贴合训练集）"),
                            html.Li("gamma: RBF scale (scale/auto/float) / RBF 尺度（可填数值）"),
                            html.Li("Use PCA (var): keep given variance ratio / PCA 保留方差比例"),
                            html.Li("Group by file (CV): use GroupKFold / 以文件分组做交叉验证，防泄漏"),
                            html.Li("Exclude Gyro & Jerk: train without those features / 训练时排除陀螺与冲击特征"),
                            html.Li("CV folds & Holdout ratio: k-folds and test split / 折数与留出比例"),
                            html.Li("Training Dataset Source: choose from Store or saved CSV / 训练集来源：内存或CSV")
                        ])
                    ])
                ], style=left_panel_style),

                # ------------------ Right visualization panel ------------------
                html.Div([
                    html.H4("Preview: Raw PPG & IMU (selected file / segment)"),
                    dcc.Graph(id="fig-raw"),

                    html.H4("Welch Spectra (PPG, AccMag, GyroMag)"),
                    dcc.Graph(id="fig-psd"),

                    html.H4("Feature Table (from labeled segments → sliding windows)"),
                    dash_table.DataTable(id="table-feats", page_size=10,
                                        style_table=table_style_table, style_cell=table_style_cell),

                    html.H4("Training Quick Results"),
                    html.Div(id="div-train-metrics"),
                    html.Div(style={"display":"grid","gridTemplateColumns":"1fr 1fr","gap":"8px"}, children=[
                        dcc.Graph(id="fig-cm"),
                        dcc.Graph(id="fig-f1")
                    ]),

                    html.H4("Inference Preview on Selected Segment (predicted labels)"),
                    dcc.Graph(id="fig-infer"),
                ], style=right_panel_style),
            ]),
            dcc.Store(id="store-folder"),
            dcc.Store(id="store-file-path"),
            dcc.Store(id="store-file-meta"),
            dcc.Store(id="store-labels", data=[]),
            dcc.Store(id="store-dataset", data=None),
            dcc.Store(id="store-model-info", data=None),
        ])



In [30]:
# ========================
# Callbacks: scan & load
# ========================
@app.callback(
    Output("ddl-files", "options"),
    Output("store-folder", "data"),
    #Input("btn-scan", "n_clicks"),
    Input("input-folder", "value"),
    prevent_initial_call=True
)
def update_file_list(folder):
    """List CSV filenames in the selected directory for dropdown display."""
    if not folder or not os.path.exists(folder):
        print("Wrong Path")
        return [], None
    files = sorted([f for f in os.listdir(folder) if f.endswith('.csv')])
    opts = [{'label': Path(f).name, 'value': os.path.join(folder, f)} for f in files]
    default_val = files[0] if files else None
    return opts, default_val

@app.callback(
    Output("store-file-path", "data"),
    Output("div-head-preview", "children"),
    Output("store-file-meta", "data"),
    Output("rs-seg", "max"),
    Output("rs-seg", "value"),
    Input("ddl-files", "value"),
    State("input-fs", "value"),
)
def load_file(file_path, fs):
    if not file_path:
        return None, html.Div(""), None, 10, [0,10]
    try:
        df = pd.read_csv(file_path)
        head = df.head(3)
        table = dash_table.DataTable(
            data=head.to_dict("records"),
            columns=[{"name": c, "id": c} for c in head.columns],
            page_size=3,
            style_table={"overflowX":"auto","maxWidth":"100%"},
            style_cell=table_style_cell
        )
        N = len(df); fs = float(fs or FS_DEFAULT)
        dur = float(N/fs)
        meta = {"N": N, "duration": dur}
        new_max = max(5.0, round(dur, 2))
        return file_path, table, meta, new_max, [0.0, new_max]
    except Exception as e:
        return None, html.Div(f"Failed to load: {e}"), None, 10, [0,10]

@app.callback(Output("txt-seg", "children"), Input("rs-seg", "value"))
def seg_text(value):
    if not value:
        return "No segment selected."
    return f"Segment: {value[0]:.2f} s → {value[1]:.2f} s"

# ========================
# Previews
# ========================
@app.callback(
    Output("fig-raw", "figure"),
    Output("fig-psd", "figure"),
    Input("ddl-files", "value"),
    Input("rs-seg", "value"),
    State("input-ppg-col", "value"),
    State("input-fs", "value"),
)
def preview_signals(file_path, seg, ppg_col, fs):
    fig_empty = go.Figure().update_layout(height=240, paper_bgcolor="white", plot_bgcolor="white")
    if not file_path or not seg:
        return fig_empty, fig_empty
    fs = float(fs or FS_DEFAULT)
    try:
        df = pd.read_csv(file_path)
        if ppg_col not in df.columns:
            ppg_col = "IR" if "IR" in df.columns else df.columns[0]
        t = np.arange(len(df)) / fs
        m = (t >= seg[0]) & (t <= seg[1])
        ppg = df[ppg_col].to_numpy(float)
        imu = imu_preprocess_with_kf(df, fs=fs)
        acc_mag, gyro_mag = imu['AccMag'], imu['GyroMag']

        fig_t = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.04,
                              subplot_titles=(f"PPG ({ppg_col})", "AccMag (m/s²)", "GyroMag (rad/s)"))
        fig_t.add_trace(go.Scatter(x=t[m], y=ppg[m], name="PPG", line=dict(color="purple")), row=1, col=1)
        fig_t.add_trace(go.Scatter(x=t[m], y=acc_mag[m], name="AccMag", line=dict(color="crimson")), row=2, col=1)
        fig_t.add_trace(go.Scatter(x=t[m], y=gyro_mag[m], name="GyroMag", line=dict(color="royalblue")), row=3, col=1)
        fig_t.update_layout(height=520, margin=dict(l=50,r=30,t=60,b=40), paper_bgcolor="white", plot_bgcolor="white")
        for r in (1,2,3):
            fig_t.update_yaxes(showgrid=True, gridcolor="rgba(0,0,0,0.08)", row=r, col=1)
        fig_t.update_xaxes(title_text="Time (s)", row=3, col=1)

        def psd_curve(x, fs):
            x = x[m]
            if len(x) < int(2*fs):
                nperseg = max(128, int(len(x)//2))
            else:
                nperseg = int(2*fs)
            f, P = signal.welch(x, fs=fs, window="hann", nperseg=nperseg, noverlap=int(0.5*nperseg))
            return f, P
        f1,P1 = psd_curve(ppg, fs)
        f2,P2 = psd_curve(acc_mag, fs)
        f3,P3 = psd_curve(gyro_mag, fs)

        fig_f = go.Figure()
        fig_f.add_trace(go.Scatter(x=f1, y=P1, name="PPG", line=dict(color="purple")))
        fig_f.add_trace(go.Scatter(x=f2, y=P2, name="AccMag", line=dict(color="crimson")))
        fig_f.add_trace(go.Scatter(x=f3, y=P3, name="GyroMag", line=dict(color="royalblue")))
        fig_f.update_layout(height=320, margin=dict(l=50,r=30,t=50,b=40), paper_bgcolor="white", plot_bgcolor="white",
                            xaxis_title="Frequency (Hz)", yaxis_title="PSD", xaxis=dict(range=[0,8]))
        fig_f.update_xaxes(showgrid=True, gridcolor="rgba(0,0,0,0.08)")
        fig_f.update_yaxes(showgrid=True, gridcolor="rgba(0,0,0,0.08)")
        return fig_t, fig_f
    except Exception as e:
        fig = go.Figure().update_layout(title=f"Error: {e}", height=240, paper_bgcolor="white", plot_bgcolor="white")
        return fig, fig

# ========================
# Add labeled segment → features accumulation (writes back to Store)
# ========================
@app.callback(
    Output("store-labels", "data"),
    Output("table-feats", "columns"),
    Output("table-feats", "data"),
    Output("txt-save-status", "children"),
    Output("store-dataset", "data", allow_duplicate=True),
    Input("btn-add-label", "n_clicks"),
    State("ddl-files", "value"),
    State("input-ppg-col", "value"),
    State("input-fs", "value"),
    State("rs-seg", "value"),
    State("ddl-label", "value"),
    State("input-win", "value"),
    State("input-hop", "value"),
    State("store-labels", "data"),
    State("store-dataset", "data"),
    prevent_initial_call=True
)
def add_labeled_segment(nc, file_path, ppg_col, fs, seg, label_id, win_sec, hop_sec, labels_data, ds_data):
    if not file_path or not seg or win_sec is None or hop_sec is None:
        return (labels_data or [], no_update, no_update, "No file/segment or window params.", ds_data)
    fs = float(fs or FS_DEFAULT)
    t0, t1 = float(seg[0]), float(seg[1])
    if t1 <= t0:
        return (labels_data or [], no_update, no_update, "Invalid segment: end ≤ start.", ds_data)
    try:
        df = pd.read_csv(file_path)
        if ppg_col not in df.columns:
            ppg_col = "IR" if "IR" in df.columns else df.columns[0]
        N = len(df)
        s_idx = int(max(0, min(N-1, np.floor(t0*fs))))
        e_idx = int(max(1, min(N,     np.ceil (t1*fs))))
        imu_all = imu_preprocess_with_kf(df, fs=fs)
        acc_mag_all, gyro_mag_all, jerk_mag_all = imu_all['AccMag'], imu_all['GyroMag'], imu_all['JerkMag']
        rows = []
        feat_names_ref = None
        ppg_arr = df[ppg_col].to_numpy(float)
        for s, e in make_windows(N, fs, float(win_sec), float(hop_sec)):
            if s >= s_idx and e <= e_idx:
                ppg_seg = ppg_arr[s:e]
                imu_win = {"AccMag":acc_mag_all[s:e], "GyroMag":gyro_mag_all[s:e], "JerkMag":jerk_mag_all[s:e]}
                feat_names, feats = extract_features_window(ppg_seg, imu_win, fs)
                if feat_names_ref is None:
                    feat_names_ref = feat_names
                rows.append(feats.tolist() + [int(label_id), float(s/fs), float(e/fs), os.path.basename(file_path)])
        existing = ds_data or getattr(dash.get_app(), "_cache_dataset", None)
        if not rows:
            return (labels_data or [], no_update, no_update, "No window fully inside the selected interval.", existing or ds_data)
        cols = feat_names_ref + ["label","t_start","t_end","file"]
        new_df = pd.DataFrame(rows, columns=cols)
        if existing:
            base_df = pd.DataFrame(existing["data"], columns=existing["columns"])
            if list(base_df.columns) == cols:
                ds_df = pd.concat([base_df, new_df], ignore_index=True)
            else:
                ds_df = pd.concat([base_df.reindex(columns=cols), new_df.reindex(columns=cols)], ignore_index=True)
        else:
            ds_df = new_df
        ds_payload = {"columns": cols, "data": ds_df.values.tolist()}
        dash.get_app()._cache_dataset = ds_payload
        labels_data = (labels_data or []) + [{"file": os.path.basename(file_path), "t0": t0, "t1": t1, "label": int(label_id)}]
        feat_cols = [{"name": c, "id": c} for c in cols]
        feat_rows = ds_df.tail(200).to_dict("records")
        msg = f"Added {len(rows)} windows. Dataset size: {len(ds_df)} rows."
        return labels_data, feat_cols, feat_rows, msg, ds_payload
    except Exception as e:
        prev = ds_data or getattr(dash.get_app(), "_cache_dataset", None)
        return (labels_data or [], no_update, no_update, f"Error: {e}", prev)

# ========================
# Dataset source & preview
# ========================
@app.callback(Output("ddl-ds-file", "options"), Input("btn-scan-ds", "n_clicks"), prevent_initial_call=True)
def scan_datasets(_):
    files = sorted(DATASET_DIR.glob("*.csv"))
    return [{"label": f.name, "value": str(f)} for f in files]

@app.callback(
    Output("div-ds-preview", "children"),
    Input("ri-ds-source", "value"),
    Input("store-dataset", "data"),
    Input("ddl-ds-file", "value"),
)
def preview_dataset(src, ds_mem, ds_file):
    try:
        if src == "file" and ds_file:
            df = pd.read_csv(ds_file)
            head = df.head(3)
            return html.Div([
                html.Div(f"Loaded CSV: {Path(ds_file).name} | Rows: {len(df)}"),
                dash_table.DataTable(data=head.to_dict("records"), columns=[{"name":c, "id":c} for c in head.columns],
                                     page_size=3, style_table={"overflowX":"auto","maxWidth":"100%"}, style_cell=table_style_cell)
            ])
        elif src == "mem":
            payload = getattr(dash.get_app(), "_cache_dataset", None) or ds_mem
            if not payload:
                return html.Div("In-memory dataset is empty.")
            df = pd.DataFrame(payload["data"], columns=payload["columns"]) if isinstance(payload, dict) else pd.DataFrame()
            head = df.head(3)
            return html.Div([
                html.Div(f"In-memory dataset | Rows: {len(df)}"),
                dash_table.DataTable(data=head.to_dict("records"), columns=[{"name":c, "id":c} for c in head.columns],
                                     page_size=3, style_table={"overflowX":"auto","maxWidth":"100%"}, style_cell=table_style_cell)
            ])
        else:
            return html.Div("Select a dataset source.")
    except Exception as e:
        return html.Div(f"Preview failed: {e}")

# ========================
# Save dataset
# ========================
@app.callback(
    Output("store-dataset", "data", allow_duplicate=True),
    Output("txt-save-status", "children", allow_duplicate=True),
    Input("btn-save-ds", "n_clicks"),
    State("store-dataset", "data"),
    prevent_initial_call=True
)
def save_dataset(nc, ds_data):
    try:
        payload = getattr(dash.get_app(), "_cache_dataset", None) or ds_data
        if payload is None:
            return ds_data, "Nothing to save yet."
        cols, data = payload["columns"], payload["data"]
        df = pd.DataFrame(data, columns=cols)
        tag = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
        out = DATASET_DIR / f"motion_dataset_{tag}.csv"
        df.to_csv(out, index=False)
        return payload, f"Saved dataset → {out.as_posix()} ({len(df)} rows)."
    except Exception as e:
        return ds_data, f"Save failed: {e}"

# ========================
# Training helpers
# ========================

def build_Xy_from_df(df: pd.DataFrame):
    meta_cols = {"label","t_start","t_end","file"}
    feat_cols = [c for c in df.columns if c not in meta_cols]
    X = df[feat_cols].to_numpy(float)
    y = df["label"].to_numpy(int)
    groups = df["file"].to_numpy(object) if "file" in df.columns else None
    return X, y, groups, feat_cols

def select_feat_cols(feat_cols: List[str], exclude_gj: bool) -> List[str]:
    if not exclude_gj:
        return feat_cols
    drop_prefixes = ("gyromag_", "gyro_bp_", "gyromag_spec_entropy", "gyromag_main_freq",
                     "jerkmag_", "jerk_bp_", "jerkmag_spec_entropy", "jerkmag_main_freq")
    keep = [c for c in feat_cols if not c.startswith(drop_prefixes)]
    return keep

def train_pipeline_quick(X, y, groups, feat_cols_used: List[str], kernel: str, C_val: float, gamma_val: str|float,
                         use_pca: bool, pca_var: float, cv_folds: int, test_ratio: float,
                         group_by_file: bool, random_state: int = 42):
    steps = [("scaler", StandardScaler())]
    if use_pca:
        steps.append(("pca", PCA(n_components=pca_var, svd_solver="full", whiten=False, random_state=random_state)))
    try:
        gval = float(gamma_val)
    except Exception:
        gval = gamma_val
    clf = SVC(kernel=kernel, C=C_val, gamma=gval if kernel=="rbf" else "scale",
              class_weight='balanced', probability=True, random_state=random_state)
    steps.append(("clf", clf))
    pipe = Pipeline(steps)

    scoring = ["accuracy","balanced_accuracy","precision_macro","recall_macro","f1_macro","f1_weighted"]
    if group_by_file and groups is not None:
        cv = GroupKFold(n_splits=max(2, int(cv_folds)))
        scores = cross_validate(pipe, X, y, groups=groups, cv=cv, scoring=scoring, n_jobs=-1, return_train_score=False)
    else:
        scores = cross_validate(pipe, X, y, cv=max(2, int(cv_folds)), scoring=scoring, n_jobs=-1, return_train_score=False)

    # Holdout split (group-aware if requested)
    if group_by_file and groups is not None:
        gss = GroupShuffleSplit(n_splits=1, train_size=1.0-float(test_ratio), random_state=random_state)
        idx_tr, idx_va = next(gss.split(X, y, groups))
        X_tr, X_va, y_tr, y_va = X[idx_tr], X[idx_va], y[idx_tr], y[idx_va]
    else:
        X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=float(test_ratio), random_state=random_state, stratify=y)

    pipe.fit(X_tr, y_tr)
    y_pred = pipe.predict(X_va)

    # Holdout metrics
    acc_h   = accuracy_score(y_va, y_pred)
    balc_h  = balanced_accuracy_score(y_va, y_pred)
    prec_h, rec_h, f1_h, _ = precision_recall_fscore_support(y_va, y_pred, average='macro', zero_division=0)
    f1w_h = precision_recall_fscore_support(y_va, y_pred, average='weighted', zero_division=0)[2]

    report = classification_report(y_va, y_pred, output_dict=True, zero_division=0)
    cm = confusion_matrix(y_va, y_pred, labels=sorted(np.unique(y)))

    tag = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    model_path = MODEL_DIR / f"svm_motion_{kernel}_{tag}.pkl"
    joblib.dump(dict(pipeline=pipe, meta=dict(tag=tag, kernel=kernel, C=C_val, gamma=str(gamma_val), pca=use_pca,
                                              pca_var=pca_var, cv_folds=cv_folds, test_ratio=test_ratio,
                                              feat_cols_used=feat_cols_used)), model_path)

    # Aggregate CV metrics (mean ± std)
    cv_summary = {k.replace('test_',''): (float(np.mean(v)), float(np.std(v))) for k,v in scores.items() if k.startswith('test_')}

    return pipe, dict(
        cv=cv_summary, holdout=dict(acc=acc_h, balc=balc_h, prec=prec_h, rec=rec_h, f1=f1_h, f1w=f1w_h,
                                    n=int(len(y_va))),
        report=report, cm=cm, model_path=str(model_path), classes=sorted(np.unique(y)), feat_cols_used=feat_cols_used
    )

# ========================
# Train now
# ========================
@app.callback(
    Output("txt-train-status", "children"),
    Output("div-train-metrics", "children"),
    Output("fig-cm", "figure"),
    Output("fig-f1", "figure"),
    Output("store-model-info", "data"),
    Input("btn-train", "n_clicks"),
    State("ri-ds-source", "value"),
    State("ddl-ds-file", "value"),
    State("store-dataset", "data"),
    State("ri-kernel", "value"),
    State("input-C", "value"),
    State("input-gamma", "value"),
    State("chk-pca", "value"),
    State("sl-pca-var", "value"),
    State("input-cv", "value"),
    State("input-test", "value"),
    State("chk-group-file", "value"),
    State("chk-excl-gj", "value"),
    prevent_initial_call=True
)
def train_now(nc, ds_src, ds_file, ds_mem, kernel, C_val, gamma_val, chk_pca, pca_var, cv_folds, test_ratio, chk_group, chk_excl):
    try:
        if ds_src == "file" and ds_file:
            df = pd.read_csv(ds_file)
        else:
            payload = getattr(dash.get_app(), "_cache_dataset", None) or ds_mem
            if payload is None:
                return "No dataset. Add labeled windows or choose a saved CSV.", "", go.Figure(), go.Figure(), None
            df = pd.DataFrame(payload["data"], columns=payload["columns"]) if isinstance(payload, dict) else pd.DataFrame()
        if df.empty:
            return "Dataset is empty.", "", go.Figure(), go.Figure(), None

        X_all, y, groups, feat_cols_all = build_Xy_from_df(df)
        exclude_gj = (chk_excl is not None and "excl" in chk_excl)
        feat_cols_used = select_feat_cols(feat_cols_all, exclude_gj)
        X = df[feat_cols_used].to_numpy(float)

        use_pca = (chk_pca is not None and "use" in chk_pca)
        group_by_file = (chk_group is not None and "group" in chk_group)
        pipe, info = train_pipeline_quick(X, y, groups, feat_cols_used, kernel, float(C_val), gamma_val,
                                          use_pca, float(pca_var), int(cv_folds), float(test_ratio), group_by_file)

        # Training quick results: richer metrics
        cv = info['cv']; ho = info['holdout']
        def fmt(mu_std):
            mu, sd = mu_std
            return f"{mu:.3f}±{sd:.3f}"
        txt = f"Model saved: {info['model_path']}"
        metrics_div = html.Div([
            html.Div([html.Strong("Cross-Validation (mean±std): "),
                      html.Span(f"Acc {fmt(cv['accuracy'])} | BalAcc {fmt(cv['balanced_accuracy'])} | "+
                                f"Prec_macro {fmt(cv['precision_macro'])} | Rec_macro {fmt(cv['recall_macro'])} | "+
                                f"F1_macro {fmt(cv['f1_macro'])} | F1_weighted {fmt(cv['f1_weighted'])}")]),
            html.Div([html.Strong("Holdout: "),
                      html.Span(f"Acc {ho['acc']:.3f} | BalAcc {ho['balc']:.3f} | Prec_macro {ho['prec']:.3f} | "+
                                f"Rec_macro {ho['rec']:.3f} | F1_macro {ho['f1']:.3f} | F1_weighted {ho['f1w']:.3f} | "+
                                f"Samples {ho['n']}")])
        ])

        cm = info['cm']
        classes = [f"{c}:{LABEL_MAP.get(c,'c'+str(c))}" for c in info['classes']]
        fig_cm = go.Figure(data=go.Heatmap(z=cm, x=classes, y=classes, colorscale="Blues", showscale=True))
        fig_cm.update_layout(height=320, margin=dict(l=50,r=30,t=40,b=40), paper_bgcolor="white", plot_bgcolor="white",
                             title="Confusion Matrix (holdout)")
        f1_vals = [info['report'].get(str(c),{}).get('f1-score', 0.0) for c in info['classes']]
        fig_f1 = go.Figure(data=go.Bar(x=classes, y=f1_vals, name="F1 (holdout)"))
        fig_f1.update_layout(height=320, margin=dict(l=50,r=30,t=40,b=40), paper_bgcolor="white", plot_bgcolor="white",
                             yaxis_title="F1-score", title="Per-class F1 (holdout)")

        model_info = dict(path=info['model_path'], classes=info['classes'], feat_cols_used=feat_cols_used)
        return txt, metrics_div, fig_cm, fig_f1, model_info
    except Exception as e:
        return f"Training failed: {e}", "", go.Figure(), go.Figure(), None

# ========================
# Inference preview (with legend)
# ========================
@app.callback(
    Output("fig-infer", "figure"),
    Input("store-model-info", "data"),
    Input("ddl-files", "value"),
    Input("rs-seg", "value"),
    State("input-ppg-col", "value"),
    State("input-fs", "value"),
    State("input-win", "value"),
    State("input-hop", "value"),
)
def infer_preview(model_info, file_path, seg, ppg_col, fs, win_sec, hop_sec):
    fig_empty = go.Figure().update_layout(height=220, paper_bgcolor="white", plot_bgcolor="white",
                                          title="(Train a model to see predictions)")
    if not (model_info and file_path and seg):
        return fig_empty
    try:
        bundle = joblib.load(model_info['path'])
        pipe = bundle["pipeline"]
        feat_cols_used = bundle["meta"].get("feat_cols_used", model_info.get("feat_cols_used", []))
        fs = float(fs or FS_DEFAULT)
        df = pd.read_csv(file_path)
        if ppg_col not in df.columns:
            ppg_col = "IR" if "IR" in df.columns else df.columns[0]
        N = len(df)
        t = np.arange(N)/fs
        msel = (t >= seg[0]) & (t <= seg[1])
        imu_all = imu_preprocess_with_kf(df, fs=fs)
        acc_mag, gyro_mag, jerk_mag = imu_all['AccMag'], imu_all['GyroMag'], imu_all['JerkMag']
        ids = np.zeros(N, dtype=int)
        for s, e in make_windows(N, fs, float(win_sec), float(hop_sec)):
            if e <= int(seg[0]*fs) or s >= int(seg[1]*fs):
                continue
            ppg_seg = df[ppg_col].to_numpy(float)[s:e]
            imu_win = {'AccMag':acc_mag[s:e], 'GyroMag':gyro_mag[s:e], 'JerkMag':jerk_mag[s:e]}
            feat_names, fv = extract_features_window(ppg_seg, imu_win, fs)
            idx_map = {name:i for i,name in enumerate(feat_names)}
            sel = [fv[idx_map[c]] for c in feat_cols_used]
            y_hat = int(pipe.predict(np.asarray(sel)[None,:])[0])
            ids[s:e] = y_hat
        scale = discrete_colorscale_from_map(LABEL_COLORS)
        fig = go.Figure()
        # Heatmap
        fig.add_trace(go.Heatmap(z=ids[np.newaxis,:], x=t, y=["State"], colorscale=scale, showscale=False,
                                 zmin=min(LABEL_COLORS.keys()), zmax=max(LABEL_COLORS.keys()), name="Predicted"))
        # Legend proxies (dummy scatter for legend only)
        for k, name in LABEL_MAP.items():
            fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color=LABEL_COLORS[k], size=10),
                                     name=f"{k}:{name}", showlegend=True))
        fig.update_xaxes(range=[seg[0], seg[1]])
        fig.update_layout(height=220, margin=dict(l=50,r=30,t=40,b=40), paper_bgcolor="white", plot_bgcolor="white",
                          title="Predicted motion labels (windowed → expanded)")
        return fig
    except Exception as e:
        return go.Figure().update_layout(height=220, title=f"Inference failed: {e}", paper_bgcolor="white", plot_bgcolor="white")


/usr/bin/xdg-open: 882: x-www-browser: not found
/usr/bin/xdg-open: 882: firefox: not found


In [None]:
# ========================
# Entrypoint
# ========================

def main():
    app.run(debug=True, port=8051)

if __name__ == "__main__":
    main()

/usr/bin/xdg-open: 882: iceweasel: not found
/usr/bin/xdg-open: 882: seamonkey: not found
/usr/bin/xdg-open: 882: mozilla: not found
/usr/bin/xdg-open: 882: epiphany: not found
/usr/bin/xdg-open: 882: konqueror: not found
/usr/bin/xdg-open: 882: chromium: not found
/usr/bin/xdg-open: 882: chromium-browser: not found
/usr/bin/xdg-open: 882: google-chrome: not found
/usr/bin/xdg-open: 882: www-browser: not found
/usr/bin/xdg-open: 882: links2: not found
/usr/bin/xdg-open: 882: elinks: not found
/usr/bin/xdg-open: 882: links: not found
/usr/bin/xdg-open: 882: lynx: not found
/usr/bin/xdg-open: 882: w3m: not found
xdg-open: no method available for opening 'http://localhost:8051'


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr