In [1]:
# dashapp_v7.py
import json, threading, traceback
from pathlib import Path
from datetime import datetime

import dash
from dash import dcc, html, Input, Output, State, callback_context, no_update
import plotly.graph_objs as go

# 导入训练逻辑（与上一个文件同目录）
import importlib.util, sys
# Load the training pipeline from a source file next to this notebook and
# register it as module "ppgpipe" so the callbacks can call ppgpipe.run_both_setups.
PIPELINE_PATH = Path("./pttppg_pipeline_v7.py")
if not PIPELINE_PATH.is_file():
    # Fail early with the resolved location instead of an opaque loader error.
    raise FileNotFoundError(f"Training pipeline not found: {PIPELINE_PATH.resolve()}")
spec = importlib.util.spec_from_file_location("ppgpipe", str(PIPELINE_PATH))
ppgpipe = importlib.util.module_from_spec(spec)
sys.modules["ppgpipe"] = ppgpipe
try:
    spec.loader.exec_module(ppgpipe)
except BaseException:
    # A failed exec (e.g. a SyntaxError inside the pipeline file) must not leave a
    # half-initialised module registered under "ppgpipe" for later imports to pick up.
    sys.modules.pop("ppgpipe", None)
    raise

app = dash.Dash(__name__)
app.title = "PTT-PPG v7 Trainer"

SyntaxError: keyword argument repeated: groups (pttppg_pipeline_v7.py, line 278)

In [2]:
# (Re)create the Dash app here so this cell works on its own even when cell 1
# stopped early (e.g. while loading the pipeline) before creating the app.
app = dash.Dash(__name__, title="PTT-PPG v7 Trainer")

def metric_bar(title, setup1, setup2, key, invert=False):
    """Return a ``dcc.Graph`` with one bar per setup for the metric ``key``.

    Args:
        title: Chart title; a "lower/higher is better" hint is appended.
        setup1, setup2: Metric dicts for each setup, or a falsy value when missing
            (the corresponding bar is then drawn with no value).
        key: Metric name looked up in each dict.
        invert: True when a lower value is better (changes only the title hint).
    """
    def _metric_of(results):
        return results.get(key) if results else None

    hint = " (↓更好)" if invert else " (↑更好)"
    bars = [
        go.Bar(name=label, x=[label], y=[_metric_of(results)])
        for label, results in (("Setup1", setup1), ("Setup2", setup2))
    ]
    figure = go.Figure(data=bars)
    figure.update_layout(barmode='group', title=title + hint)
    return dcc.Graph(figure=figure)

# Shared style for the four side-by-side parameter columns. "border-box" keeps each
# column's padding inside its percentage width, so 24% + 24% + 24% + 28% = 100% fits
# on one row (with the default content-box the padding pushes the last column down).
_COLUMN_STYLE = {"display": "inline-block", "verticalAlign": "top", "boxSizing": "border-box"}

# Page layout: data path, hyper-parameter columns, start button + status,
# a JSON summary of the last run and the Holdout comparison charts.
app.layout = html.Div(style={"padding":"16px", "maxWidth":"1100px", "margin":"auto"}, children=[
    html.H2("PTT-PPG AE + Denoiser 统一训练（v7）"),
    html.Div([
        html.Label("数据路径（PhysioNet 1.1.0 根目录或其上层）"),
        dcc.Input(id="data-root", type="text", placeholder="/path/to/physionet.org/files/pulse-transit-time-ppg/1.1.0",
                  style={"width":"100%", "boxSizing":"border-box"})
    ], style={"marginBottom":"12px"}),

    html.Div([
        html.Div([
            html.Label("采样率 fs (Hz)"), dcc.Input(id="fs", type="number", value=500, step=1),
            html.Br(), html.Label("窗口 win (s)"), dcc.Input(id="win", type="number", value=6, step=0.5),
            html.Br(), html.Label("步长 hop (s)"), dcc.Input(id="hop", type="number", value=1, step=0.5),
        ], style={**_COLUMN_STYLE, "width":"24%", "paddingRight":"10px"}),

        html.Div([
            html.Label("划分模式"),
            # Not clearable: a cleared dropdown would send split_mode=None to training.
            dcc.Dropdown(id="split-mode", options=[{"label":"GroupKFold","value":"kfold"},
                                                   {"label":"GroupHoldout(按比例)","value":"holdout"}],
                         value="kfold", clearable=False),
            html.Br(), html.Label("KFold 折数 / Holdout 训练占比"),
            # Range/step/marks are reconfigured by the split-mode callback.
            dcc.Slider(id="split-param", min=2, max=10, step=1, value=5,
                       marks={i:str(i) for i in range(2,11)}),
            html.Div(id="split-param-tip", style={"fontSize":"12px","color":"#555"}),
        ], style={**_COLUMN_STYLE, "width":"24%", "paddingRight":"10px"}),

        html.Div([
            html.Label("训练参数"),
            html.Br(), html.Label("epochs_ae"), dcc.Input(id="epochs-ae", type="number", value=20, step=1),
            html.Br(), html.Label("epochs_denoise"), dcc.Input(id="epochs-denoise", type="number", value=20, step=1),
            html.Br(), html.Label("学习率 lr"), dcc.Input(id="lr", type="number", value=1e-3, step=1e-4),
        ], style={**_COLUMN_STYLE, "width":"24%", "paddingRight":"10px"}),

        html.Div([
            html.Label("其他超参"),
            html.Br(), html.Label("运动阈值 motion_thresh"), dcc.Input(id="motion-thresh", type="number", value=0.8, step=0.1),
            html.Br(), html.Label("AE 阈值 detector_threshold（可空）"), dcc.Input(id="det-thr", type="number", value=None),
            html.Br(), html.Label("损失系数：spec/hr/cons"),
            html.Div([
                dcc.Input(id="lam-spec", type="number", value=0.5, step=0.1, style={"width":"30%"}),
                dcc.Input(id="lam-hr", type="number", value=0.1, step=0.05, style={"width":"30%", "marginLeft":"2%"}),
                dcc.Input(id="lam-cons", type="number", value=0.2, step=0.05, style={"width":"30%", "marginLeft":"2%"}),
            ]),
        ], style={**_COLUMN_STYLE, "width":"28%"}),
    ], style={"marginBottom":"12px"}),

    html.Div([
        html.Button("▶ 开始训练（跑 Setup1 & Setup2）", id="btn-start", n_clicks=0, style={"fontSize":"16px","padding":"10px 16px"}),
        html.Span(id="status", style={"marginLeft":"12px", "color":"#0070f3"}),
    ], style={"marginBottom":"16px"}),

    html.Hr(),
    html.H3("训练结果（每次运行会覆盖 results/ 下的文件）"),
    html.Div(id="results-summary", style={"whiteSpace":"pre-wrap", "fontFamily":"monospace", "background":"#fafafa", "padding":"8px"}),

    html.Hr(),
    html.H3("Holdout 指标对比"),
    html.Div(id="comparison-charts"),
])

@app.callback(
    Output("split-param-tip","children"),
    Output("split-param","min"),
    Output("split-param","max"),
    Output("split-param","step"),
    Output("split-param","marks"),
    Output("split-param","value"),
    Input("split-mode","value"),
)
def _update_split_tip(mode):
    """Reconfigure the split-parameter slider for the selected split mode.

    Holdout uses a training fraction in [0.5, 0.9] with step 0.1; any other mode
    (GroupKFold) uses an integer fold count in [2, 10]. Step and value are reset
    too: keeping step=1 in holdout mode would make 0.5 the only selectable
    fraction, and keeping the previous value would leave it out of range
    (e.g. a fraction of 5.0, or int(0.8) == 0 folds).

    Returns:
        (tip text, min, max, step, marks, value) for the slider.
    """
    if mode == "holdout":
        marks = {i / 10: f"{i / 10:.1f}" for i in range(5, 10)}
        return "Holdout：训练集比例（0.5~0.9）", 0.5, 0.9, 0.1, marks, 0.8
    marks = {i: str(i) for i in range(2, 11)}
    return "KFold：折数（2~10）", 2, 10, 1, marks, 5

def run_training_async(kwargs, outdict):
    """Run ``ppgpipe.run_both_setups(**kwargs)`` and record the outcome in ``outdict``.

    Despite its name this runs synchronously in the calling thread. Exceptions
    are not propagated; instead ``outdict`` receives:
      * on success: ``ok=True`` plus ``res1``, ``res2`` and ``comp``;
      * on failure: ``ok=False`` and ``err`` holding the formatted traceback.
    """
    try:
        first, second, comparison = ppgpipe.run_both_setups(**kwargs)
    except Exception:
        outdict.update(ok=False, err=traceback.format_exc())
        return
    outdict.update(ok=True, res1=first, res2=second, comp=comparison)

# Result files reported in the summary, relative to the output directory
# (presumably written by ppgpipe.run_both_setups — confirm against the pipeline).
_RESULT_FILES = (
    "setup1/detector_results.json",
    "setup1/denoiser_results.json",
    "setup2/detector_results.json",
    "setup2/denoiser_results.json",
    "compare.json",
)


def _build_training_kwargs(data_root, fs, win, hop, split_mode, split_param, epochs_ae,
                           epochs_den, lr, motion_thresh, det_thr, lam_spec, lam_hr,
                           lam_cons, outdir):
    """Validate raw form values and convert them into ``run_both_setups`` kwargs.

    Returns:
        ``(kwargs, None)`` when every field is usable, otherwise ``(None, message)``
        where ``message`` is a user-facing description of the first problem found.
    """
    required = {
        "fs": fs, "win": win, "hop": hop, "split_param": split_param,
        "epochs_ae": epochs_ae, "epochs_denoise": epochs_den, "lr": lr,
        "motion_thresh": motion_thresh,
        "lam_spec": lam_spec, "lam_hr": lam_hr, "lam_cons": lam_cons,
    }
    # Empty number inputs arrive as None (or ""); float()/int() would raise an
    # uncaught TypeError/ValueError and abort the callback without a status message.
    missing = [name for name, value in required.items() if value is None or value == ""]
    if missing:
        return None, "参数不能为空或无效：" + ", ".join(missing)
    if split_mode not in ("kfold", "holdout"):
        return None, "请选择划分模式（GroupKFold 或 GroupHoldout）。"

    try:
        if split_mode == "kfold":
            n_splits, train_size = int(split_param), 0.8
        else:
            n_splits, train_size = 5, float(split_param)
        kwargs = dict(
            data_root=data_root, fs=float(fs), win=float(win), hop=float(hop),
            motion_thresh=float(motion_thresh), split_mode=split_mode,
            n_splits=n_splits, train_size=train_size,
            epochs_ae=int(epochs_ae), epochs_denoise=int(epochs_den), lr=float(lr),
            # The AE threshold is optional: an empty field means "let the pipeline decide".
            detector_threshold=(None if det_thr in (None, "") else float(det_thr)),
            outdir=outdir, lam_spec=float(lam_spec), lam_hr=float(lam_hr), lam_cons=float(lam_cons),
        )
    except (TypeError, ValueError) as exc:
        return None, f"参数格式错误：{exc}"

    # The slider can hold a value from the other mode's range (e.g. 5 in holdout mode).
    if split_mode == "kfold" and n_splits < 2:
        return None, f"KFold 折数必须 ≥ 2（当前为 {split_param}）。"
    if split_mode == "holdout" and not 0.0 < train_size < 1.0:
        return None, f"Holdout 训练占比必须在 (0, 1) 之间（当前为 {split_param}）。"
    return kwargs, None


def _comparison_charts(det1, det2, den1, den2):
    """Build the Holdout comparison section: one bar chart per detector/denoiser metric."""
    return html.Div(children=[
        html.H4("Detector — Holdout 对比"),
        metric_bar("PR-AUC", det1, det2, "pr_auc"),
        metric_bar("ROC-AUC", det1, det2, "roc_auc"),
        metric_bar("F1", det1, det2, "f1"),
        metric_bar("Balanced Accuracy", det1, det2, "bal_acc"),

        html.Hr(),
        html.H4("Denoiser — Holdout 对比"),
        metric_bar("L1", den1, den2, "l1", invert=True),
        metric_bar("SNR Improvement (dB)", den1, den2, "snr_improvement_db"),
        metric_bar("HR-band Loss", den1, den2, "hr_band_loss", invert=True),
    ])


@app.callback(
    Output("status","children"),
    Output("results-summary","children"),
    Output("comparison-charts","children"),
    Input("btn-start","n_clicks"),
    State("data-root","value"),
    State("fs","value"),
    State("win","value"),
    State("hop","value"),
    State("split-mode","value"),
    State("split-param","value"),
    State("epochs-ae","value"),
    State("epochs-denoise","value"),
    State("lr","value"),
    State("motion-thresh","value"),
    State("det-thr","value"),
    State("lam-spec","value"),
    State("lam-hr","value"),
    State("lam-cons","value"),
    prevent_initial_call=True
)
def on_start(n, data_root, fs, win, hop, split_mode, split_param, epochs_ae, epochs_den, lr, motion_thresh, det_thr, lam_spec, lam_hr, lam_cons):
    """Validate the form, train both setups and render the summary and charts.

    Returns:
        A 3-tuple (status text, JSON summary text, comparison charts). Outputs that
        should keep their previous content are returned as ``no_update``.
    """
    if not data_root:
        return "请先填写数据路径。", no_update, no_update

    outdir = "results"
    kwargs, error = _build_training_kwargs(
        data_root, fs, win, hop, split_mode, split_param, epochs_ae, epochs_den, lr,
        motion_thresh, det_thr, lam_spec, lam_hr, lam_cons, outdir)
    if error:
        return error, no_update, no_update

    out = {}
    # Runs synchronously: this callback (and the browser request) blocks until training ends.
    run_training_async(kwargs, out)
    if not out.get("ok"):
        return "训练失败", f"[Error]\n{out.get('err')}", no_update

    comp = out["comp"]
    det1, det2 = comp["detector"]["setup1_holdout"], comp["detector"]["setup2_holdout"]
    den1, den2 = comp["denoiser"]["setup1_holdout"], comp["denoiser"]["setup2_holdout"]

    summary = {
        "detector": {"setup1": det1, "setup2": det2},
        "denoiser": {"setup1": den1, "setup2": den2},
        "saved": {rel: str((Path(outdir) / rel).resolve()) for rel in _RESULT_FILES},
    }
    txt = json.dumps(summary, ensure_ascii=False, indent=2)

    return "训练完成 ✅", txt, _comparison_charts(det1, det2, den1, den2)



In [None]:
if __name__ == "__main__":
    # Dash 2.x deprecates run_server in favour of run (Dash 3 removes run_server);
    # fall back to run_server only for old releases that lack run.
    _run_app = getattr(app, "run", None) or app.run_server
    # With debug=True Flask's reloader re-executes sys.argv in a child process,
    # which fails inside a Jupyter kernel; keep the debug tools, disable the reloader.
    _run_app(debug=True, use_reloader=False)