# WaveformAnalysis 高级示例：插件开发 + 多 run 处理 + 自定义特征

本 notebook 是一个**独立完整**的高级示例，涵盖：
- 自定义插件开发（基于 `st_waveforms` 计算特征）
- 多 run 批处理（并行执行）
- 自定义特征整合与可视化


## 1. 环境与依赖

请确保已安装 WaveformAnalysis（推荐在项目根目录执行 `./install.sh`）。


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from waveform_analysis import DAQAnalyzer
from waveform_analysis.core.context import Context
from waveform_analysis.core.data.export import BatchProcessor
from waveform_analysis.core.plugins.builtin.cpu import (
    BasicFeaturesPlugin,
    RawFileNamesPlugin,
    StWaveformsPlugin,
    WaveformsPlugin,
)
from waveform_analysis.core.plugins.core.base import Option, Plugin

# Font setup for Chinese text in figures (only needed when labels contain CJK).
plt.rcParams.update({
    # Preferred sans-serif families, tried in order.
    "font.sans-serif": ["SimHei", "DejaVu Sans"],
    # Render minus signs with the ASCII hyphen instead of the Unicode minus.
    "axes.unicode_minus": False,
})


## 2. 扫描 DAQ runs

扫描 `DAQ` 目录下的 runs，并挑选少量 runs 做示例。


In [None]:
# Root directory handed to DAQAnalyzer; adjust to the local data layout.
daq_root = "/mnt/data/Run3/DAQ"

da = DAQAnalyzer(daq_adapter="vx2730", daq_root=daq_root)
da.scan_all_runs()

# Keys of `da.runs` are used as the run identifiers in the cells below.
run_ids = [*da.runs.keys()]
print(f"Found {len(run_ids)} runs")
# Cell output: preview of the first few run ids.
run_ids[:5]


## 3. 自定义插件：脉冲宽度特征

基于 `st_waveforms` 的波形数据，计算每个事件的脉冲宽度：
- `width_samples`：超过 `baseline + threshold` 的采样点总数（不要求连续）
- `width_ns`：转换为纳秒（基于 `dt_ns`）


In [None]:
# Record layout of one pulse-width row: two float32 widths followed by the
# int64 timestamp and int16 channel of the event.
PULSE_WIDTH_DTYPE = np.dtype({
    "names": ["width_samples", "width_ns", "timestamp", "channel"],
    "formats": ["f4", "f4", "i8", "i2"],
})


class PulseWidthPlugin(Plugin):
    """Count, per event, the samples that exceed ``baseline + threshold``.

    Reads ``st_waveforms`` from the context and emits one
    ``PULSE_WIDTH_DTYPE`` row per input event, carrying the event's
    ``timestamp`` and ``channel`` alongside the width.

    ``width_samples`` is the total number of samples above the level, not the
    length of one contiguous run, so separate excursions within the same
    waveform are summed. Only samples *above* the level are counted.
    NOTE(review): negative-going pulses would need the comparison flipped;
    confirm the polarity of the ``wave`` field.
    """

    provides = "pulse_width"
    depends_on = ["st_waveforms"]
    save_when = "always"
    output_dtype = PULSE_WIDTH_DTYPE
    version = "0.2.0"
    options = {
        "threshold": Option(default=5.0, type=float, help="超过 baseline 的阈值"),
        "dt_ns": Option(default=2.0, type=float, help="采样间隔 (ns)"),
        "min_width": Option(default=1, type=int, help="最小宽度 (采样点)"),
    }

    def compute(self, context, run_id, **kwargs) -> np.ndarray:
        """Build the pulse-width record array for one run.

        Args:
            context: Object providing ``get_data(run_id, name)`` and
                ``get_config(plugin, option)``.
            run_id: Run identifier forwarded to ``context.get_data``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Array of ``PULSE_WIDTH_DTYPE`` with ``len(st_waveforms)`` rows.
            ``timestamp`` and ``channel`` are copied from the input for every
            non-empty result. Widths stay 0 when ``wave`` is not
            two-dimensional, and for events with fewer than ``min_width``
            samples above the level.
        """
        st_waveforms = context.get_data(run_id, "st_waveforms")
        threshold = context.get_config(self, "threshold")
        dt_ns = context.get_config(self, "dt_ns")
        min_width = context.get_config(self, "min_width")

        n_events = len(st_waveforms)
        out = np.zeros(n_events, dtype=PULSE_WIDTH_DTYPE)
        if n_events == 0:
            return out

        # Copy event identity first so that every return path below still
        # yields rows that can be matched back to their source events.
        out["timestamp"] = st_waveforms["timestamp"]
        out["channel"] = st_waveforms["channel"]

        waves = st_waveforms["wave"]
        baselines = st_waveforms["baseline"].astype(np.float32)

        if waves.ndim != 2:
            # The row-wise count below needs an (events, samples) matrix;
            # for any other layout the widths are left at 0.
            return out

        # Broadcast each event's level across its samples and count the
        # samples that lie above it.
        mask = waves > (baselines[:, None] + threshold)
        widths = mask.sum(axis=1).astype(np.float32)
        # Events with too few samples above the level report width 0.
        widths = np.where(widths < min_width, 0.0, widths)

        out["width_samples"] = widths
        out["width_ns"] = widths * float(dt_ns)
        return out

## 4. 创建 Context 并注册插件

注册基础插件 + 自定义插件，并设置统一配置（推荐全局设置 `daq_adapter`）。


In [None]:
ctx = Context(storage_dir="./strax_data")

# Built-in plugins plus the custom PulseWidthPlugin, registered in one call.
plugins = (
    RawFileNamesPlugin(),
    WaveformsPlugin(),
    StWaveformsPlugin(),
    BasicFeaturesPlugin(),
    PulseWidthPlugin(),
)
ctx.register(*plugins)

# Settings passed without a plugin name.
shared_config = {
    "data_root": daq_root,
    "daq_adapter": "vx2730",
    "n_channels": 2,
    "show_progress": True,
}
ctx.set_config(shared_config)

# Values for the options declared by PulseWidthPlugin, addressed through
# plugin_name="pulse_width".
pulse_width_config = {
    "threshold": 6.0,
    "dt_ns": 2.0,
    "min_width": 2,
}
ctx.set_config(pulse_width_config, plugin_name="pulse_width")

print("Context ready.")


## 5. 单个 run 试跑与特征整合

选择一个 run，计算基础特征、脉冲宽度，并构建自定义 DataFrame。


In [None]:
# Stop with a clear message when the scan found nothing to process.
if not run_ids:
    raise RuntimeError("No runs found under DAQ root.")

run_id = run_ids[0]
print(f"Using run: {run_id}")

# Request the three data products for the chosen run, in this order.
st_waveforms, basic_features, pulse_width = (
    ctx.get_data(run_id, data_name)
    for data_name in ("st_waveforms", "basic_features", "pulse_width")
)

print(f"总事件数: {len(st_waveforms)}")

# Per-channel event counts.
channel_column = st_waveforms["channel"]
channels = np.unique(channel_column)
for ch in channels:
    n_in_channel = (channel_column == ch).sum()
    print(f"  ch{ch}: {n_in_channel} events")

In [None]:
# Merge the custom features into one DataFrame (each input is a single flat
# array). Rows are paired by position, so every input is cut to the shortest
# of the three lengths.
n_events = min(map(len, (st_waveforms, basic_features, pulse_width)))

if n_events <= 0:
    df_custom = pd.DataFrame()
else:
    # (output column, array that provides the field of the same name)
    column_sources = (
        ("timestamp", st_waveforms),
        ("channel", st_waveforms),
        ("height", basic_features),
        ("area", basic_features),
        ("width_ns", pulse_width),
    )
    df_custom = pd.DataFrame(
        {column: source[column][:n_events] for column, source in column_sources}
    )
    df_custom = df_custom.sort_values("timestamp")

df_custom.head()

In [None]:
# Example: derive a custom feature and plot a random subsample.
if len(df_custom) == 0:
    print("No events to visualize.")
else:
    # The 1e-6 offset keeps the ratio finite for zero-width events.
    df_custom["area_over_width"] = df_custom["area"] / (df_custom["width_ns"] + 1e-6)

    # Cap the scatter at 2000 points; the fixed seed makes the subsample
    # reproducible.
    n_points = min(len(df_custom), 2000)
    sample = df_custom.sample(n_points, random_state=42)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(sample["width_ns"], sample["area"], s=6, alpha=0.4)
    ax.set_xlabel("Width (ns)")
    ax.set_ylabel("Area")
    ax.set_title("Area vs Width (sampled)")
    fig.tight_layout()
    plt.show()


## 6. 多 run 批处理

使用 `BatchProcessor` 并行处理多个 runs，并汇总每个 run 的事件数与脉冲宽度中位数。


In [None]:
# Batch processing needs at least one scanned run.
if not run_ids:
    raise RuntimeError("No runs found for batch processing.")

# Keep the example quick: only the first three runs.
selected_runs = run_ids[:3]
print("Selected runs:", selected_runs)

processor = BatchProcessor(ctx)
batch_options = {
    "run_ids": selected_runs,
    "data_name": "pulse_width",
    "max_workers": 2,
    "executor_type": "thread",
}
batch_result = processor.process_runs(**batch_options)

print("Errors:", batch_result["errors"])


In [None]:
# Summarize pulse-width statistics for each run of the batch.
# The loop names deliberately differ from the notebook-level `run_id`: a
# top-level `for` target stays bound after the loop, so reusing `run_id` here
# would silently replace the run selected for the single-run cells.
summaries = []
for batch_run_id, run_data in batch_result["results"].items():
    # Runs with no result or no rows contribute nothing to the summary.
    if run_data is None or len(run_data) == 0:
        continue
    run_widths = run_data["width_ns"]
    # Guard np.median against an empty column.
    if len(run_widths) == 0:
        continue
    summaries.append({
        "run_id": batch_run_id,
        "n_events": int(len(run_widths)),
        "median_width_ns": float(np.median(run_widths)),
    })

summary_df = pd.DataFrame(summaries)
# Cell output: runs ordered by median width. The unsorted (empty) frame is
# shown when there is nothing to sort, since it then has no such column.
summary_df.sort_values("median_width_ns") if len(summary_df) > 0 else summary_df

## 7. 下一步建议

- 将自定义插件拆分到独立模块，并使用 `PluginSpec` 做严格校验
