## chengshuai — Preprocessing + HFO Detector + Group-Event Visualization

只做三件事（每件都可视化）：
- **Preprocessing**：读 EDF → 显式参考（Bipolar/CAR）→ 画波形（按 shaft 上色；波形避开红色系）
- **HFO Detector**：严格复用 `src/utils/bqk_utils.py` 的检测逻辑（Hilbert envelope + 双阈值 + merge + min_last）→ 画事件统计 + tick overlay（不遮波形）
- **Group Event Visualization**：读取/生成 GPU 中间结果（整段 crop 的 bandpassed + envelope 缓存），把 packedTimes 群体事件窗口拼接起来，画两张图：
  - 图1：80–250Hz bandpassed raster（横轴=拼接事件，纵轴=通道）
  - 图2：STFT 归一化功率背景 + 每事件质心点 + 质心路径（按最早→最晚连接）

注意：Detector 只依赖 `PreprocessingResult` 的 `data/sfreq/ch_names`，不做任何“推断”。

In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

sys.path.insert(0, '/home/honglab/leijiaxin/HFOsp/')

from src.preprocessing import SEEGPreprocessor
from src.hfo_detector import HFODetector, HFODetectionConfig
from src.group_event_analysis import precompute_envelope_cache
from src.visualization import (
    plot_from_result,
    plot_shaft_channels,
    plot_event_counts,
    detections_to_events,
    plot_raw_filtered_envelope,
    plot_group_events_band_raster,
    plot_group_events_tf_centroid_paths,
    plot_group_events_tf_centroids_per_channel,
)

# --- Input files for recording FC10477Q ---
EDF_PATH = Path('/mnt/yuquan_data/yuquan_24h_edf/chengshuai/FC10477Q.edf')
GPU_NPZ = Path('/mnt/yuquan_data/yuquan_24h_edf/chengshuai/FC10477Q_gpu.npz')
PACKED_TIMES = Path('/mnt/yuquan_data/yuquan_24h_edf/chengshuai/FC10477Q_packedTimes.npy')

# --- Segment to preprocess/plot, and the electrode shaft used in the single-shaft view ---
START_SEC, DURATION_SEC = 0, 100
SHAFT = 'K'

# --- Detector configuration (STRICT: must match src/utils/bqk_utils.py) ---
DET_CFG = HFODetectionConfig(
    algorithm='bqk',
    band='ripple',
    use_gpu=False,
    # bqk thresholds and event timing constraints (tune if needed)
    rel_thresh=2.0,
    abs_thresh=2.0,
    min_gap_ms=20.0,
    min_last_ms=20.0,
    # chunked processing keeps memory bounded for long recordings
    chunk_sec=100.0,
    chunk_overlap_sec=1.0,
)

# IPython magic (not plain Python): draw matplotlib figures inline in the notebook output.
%matplotlib inline
# Default DPI applied to every figure created after this point.
plt.rcParams['figure.dpi'] = 120

## 1) Bipolar 全通道（100s）

In [None]:
# Bipolar re-referencing; crop one second past the plotted window (START_SEC + DURATION_SEC + 1).
bip = SEEGPreprocessor(
    reference='bipolar',
    crop_seconds=START_SEC + DURATION_SEC + 1,
)
bip_res = bip.run(EDF_PATH)

# Plot every channel of the preprocessed result over the configured window.
fig = plot_from_result(
    bip_res,
    channels='all',
    start_sec=START_SEC,
    duration_sec=DURATION_SEC,
)
plt.show()

## 2) CAR 全通道（100s）

In [None]:
# Disabled cell: common-average-reference (CAR) variant of the bipolar cell above.
# Uncomment all lines below to run it.
# car = SEEGPreprocessor(reference='car', crop_seconds=START_SEC + DURATION_SEC + 1)
# car_res = car.run(EDF_PATH)

# fig = plot_from_result(car_res, start_sec=START_SEC, duration_sec=DURATION_SEC, channels='all')
# plt.show()

## 3) 单电极串（例：K，Bipolar）

In [None]:
# Disabled cell: plot only the channels of one shaft (SHAFT) from the bipolar result,
# limited to at most 30 s. Uncomment all lines below to run it.
# fig = plot_shaft_channels(
#     bip_res.data, bip_res.sfreq, bip_res.ch_names,
#     shaft=SHAFT,
#     start_sec=START_SEC,
#     duration_sec=min(30, DURATION_SEC),
#     reference_type=bip_res.reference_type
# )
# plt.show()

## 4) HFO Detector（Ripple）+ 可视化（counts + overlay）

In [None]:
# Detect ripples on the bipolar-preprocessed recording.
# NOTE: detection can be slow on long segments; the input was cropped with crop_seconds=101 above.
det = HFODetector(DET_CFG)
det_res = det.detect(bip_res)

# Summarise: total number of detections and the ten channels with the most events.
counts = det_res.events_count
ranked = counts.argsort()[::-1]
top_channels = [det_res.ch_names[i] for i in ranked[:10]]

print('Detections:', counts.sum(), 'total')
print('Top channels:', top_channels)

In [None]:
# 4.1 Per-channel event counts (a quick way to spot artifact channels)
fig = plot_event_counts(
    det_res,
    title='Ripple detections per channel (top 30)',
    top_k=30,
)
plt.show()

In [None]:
# 4.2 Overlay detections onto the waveform.
# Detections are drawn as thin per-channel ticks so the waveform stays visible.
VIEW_START = 0   # start of the displayed window, in seconds
VIEW_DUR = 10    # length of the displayed window, in seconds
# Channel subset reused by the raw/bandpass/envelope cell (4.3); this cell plots all channels.
VIEW_CHANNELS = ['K1-K2', 'K2-K3', 'K3-K4']

overlay_events = detections_to_events(
    det_res,
    color='#e94560',
    alpha=0.35,
    max_events_per_channel=200,
    style='tick',
    linewidth=2.0,
)

fig = plot_from_result(
    bip_res,
    start_sec=VIEW_START,
    # Pass the window LENGTH. The previous `VIEW_START + 10` was an end time that also
    # ignored VIEW_DUR, so the plot and its title disagreed for any other setting.
    duration_sec=VIEW_DUR,
    channels='all',
    events=overlay_events,
    title=f'Bipolar waveform + Ripple detections ({VIEW_START}-{VIEW_START+VIEW_DUR}s)'
)
plt.show()

In [None]:
# 4.3 Raw vs bandpass vs envelope (RP/FR) for the channels listed in VIEW_CHANNELS.
# The fast-ripple (FR) panel is only meaningful when sfreq >= 1200 Hz (ideally 2000 Hz);
# at 1000 Hz keep show_fast_ripple=False or rely on the `band` parameter.
fig = plot_raw_filtered_envelope(
    bip_res.data,
    bip_res.sfreq,
    bip_res.ch_names,
    channels=VIEW_CHANNELS,
    band=DET_CFG.band,  # same band as the detector ('ripple' or 'fast_ripple')
    show_fast_ripple=(bip_res.sfreq >= 1200),  # enable FR only when the sampling rate allows it
    start_sec=VIEW_START,
    duration_sec=2.0,
)
plt.show()

In [None]:
# --- 5) Group event visualization (packedTimes) ---
# We generate a GPU cache for the crop (bandpassed + envelope), then plot Fig1/Fig2.

CROP_SEC = 60.0
# The cache filename is derived from CROP_SEC so that changing the crop length cannot
# silently reuse a cache built for a different length. For CROP_SEC = 60.0 this resolves
# to ".../FC10477Q_envCache_ripple_bipolar_alias_crop60s_viz.npz", as before.
CACHE_PATH = EDF_PATH.with_name(
    f'{EDF_PATH.stem}_envCache_ripple_bipolar_alias_crop{CROP_SEC:g}s_viz.npz'
)

# Generate the cache only if it is missing
if not CACHE_PATH.exists():
    _ = precompute_envelope_cache(
        edf_path=str(EDF_PATH),
        out_npz_path=str(CACHE_PATH),
        band='ripple',
        crop_seconds=CROP_SEC,
        reference='bipolar',
        alias_bipolar_to_left=True,
        alias_filter_using_gpu_npz=str(GPU_NPZ),
        use_gpu=True,
        save_bandpass=True,
        dtype='float32',
    )

# Pick the first N packedTimes events for a quick plot
EVENTS = list(range(15))

# Fig1: band-passed traces for the selected events
fig1 = plot_group_events_band_raster(
    cache_npz_path=str(CACHE_PATH),
    packed_times_path=str(PACKED_TIMES),
    event_indices=EVENTS,
    max_events=len(EVENTS),
    mode='bandpassed',
    plot_style='trace',
    value_transform='none',
    downsample_ms=None,
    x_axis='seconds',
    figsize=(12, 8),
)
plt.show()

# Fig2: per-channel time-frequency centroids restricted to the 80-250 Hz band
fig2 = plot_group_events_tf_centroids_per_channel(
    cache_npz_path=str(CACHE_PATH),
    packed_times_path=str(PACKED_TIMES),
    detections_npz_path=str(GPU_NPZ),
    event_indices=EVENTS,
    max_events=len(EVENTS),
    freq_band=(80.0, 250.0),
    nperseg=256,
    noverlap=192,
    mask_by_detections=False,  # avoid hiding all dots due to name mismatch
    show_colorbar=True,
    font_size=14,
    title_font_size=16,
    tick_font_size=12,
    hspace=0.02,
    figsize=(12, 8),
)
plt.show()

In [None]:
# --- 6) Lag heatmaps (channels × events): packedTimes+lagPat vs our bqk windows ---
# For one EDF file, one band (ripple), we produce 3 heatmaps:
#   (a) event energy (envelope)
#   (b) centroid rank (0=earliest)
#   (c) lag (ms) aligned to first centroid

from src.group_event_analysis import (
    load_envelope_cache,
    build_windows_from_packed_times,
    compute_centroid_matrix_from_envelope_cache,
    lag_rank_from_centroids,
    build_windows_from_detections,
)
from src.visualization import plot_lag_heatmaps


def _window_energy(env, windows, present, sfreq, n_channels, like):
    """Return per-(channel, event) envelope energy: sum(env**2) inside each window.

    Args:
        env: 2-D array indexed as ``env[channel, sample]``.
        windows: sequence of objects exposing ``.start`` and ``.end``; both are
            multiplied by ``sfreq`` and rounded to obtain sample indices.
        present: 2-D truthy mask indexed as ``present[channel, event]``; entries
            that are falsy are left as NaN in the output.
        sfreq: sampling frequency used to convert window bounds to samples.
        n_channels: number of leading channel rows to evaluate.
        like: array whose shape the output copies.

    Returns:
        float64 array shaped like ``like``, NaN wherever ``present`` is falsy.
    """
    energy = np.full_like(like, np.nan, dtype=np.float64)
    for ei, w in enumerate(windows):
        i0 = int(round(w.start * sfreq))
        i1 = int(round(w.end * sfreq))
        for ci in range(n_channels):
            if not present[ci, ei]:
                continue
            seg = env[ci, i0:i1].astype(np.float64)
            energy[ci, ei] = float(np.sum(seg * seg))
    return energy


# Load cache
cache = load_envelope_cache(str(CACHE_PATH))
# Presumably 'x_band' is the band-passed signal stored via save_bandpass=True — TODO confirm.
assert cache['x_band'] is not None

# Use core channels ordering (same as lagPat).
# CORE_CH is otherwise only defined by a later cell, so load it here when the notebook
# is executed top-to-bottom (same guard style as the `bip_res` fallback below).
if 'CORE_CH' not in globals():
    LAGPAT = Path('/mnt/yuquan_data/yuquan_24h_edf/chengshuai/FC10477Q_lagPat.npz')
    with np.load(LAGPAT, allow_pickle=True) as _lagpat_npz:
        CORE_CH = [str(x) for x in _lagpat_npz['chnNames'].tolist()]
ch_order = CORE_CH
core_set = set(ch_order)  # O(1) membership tests when aliasing bipolar names below

# NOTE(review): everything below indexes cache['env'] rows by position in ch_order.
# That is only valid if the cache rows are stored in exactly this channel order — verify
# against the channel names saved in the cache.

# A) Baseline: packedTimes + lagPat
packed = np.load(PACKED_TIMES, allow_pickle=True)
windows = build_windows_from_packed_times(packed)
# Keep only windows that start inside the cached crop (a window may still end past it).
windows = [w for w in windows if w.start < CROP_SEC]

# detections mask from GPU (for events_bool)
gpu = np.load(GPU_NPZ, allow_pickle=True)
# An .npz archive is read lazily: every gpu[...] access loads the array again,
# so fetch each array once instead of once per channel.
gpu_names = gpu['chns_names'].tolist()
whole_dets = gpu['whole_dets']
name_to_idx = {str(n): i for i, n in enumerate(gpu_names)}
dets = {
    ch: np.asarray(whole_dets[name_to_idx[ch]], dtype=np.float64)
    for ch in ch_order
    if ch in name_to_idx
}

cent, eb = compute_centroid_matrix_from_envelope_cache(
    windows=windows,
    detections=dets,
    ch_names=ch_order,
    env=cache['env'],
    sfreq=cache['sfreq'],
    centroid_power=2.0,
)
lag_rel, rank_rel = lag_rank_from_centroids(cent, eb, align='first_centroid', tie_tol_ms=0.0)

# energy per event: sum(env^2) within window
sf = cache['sfreq']
energy = _window_energy(cache['env'], windows, eb, sf, len(ch_order), lag_rel)

# Plot first N events for readability
N_SHOW = min(200, len(windows))
figE, figR, figL = plot_lag_heatmaps(
    energy=energy[:, :N_SHOW],
    lag_ms=lag_rel[:, :N_SHOW] * 1000.0,
    rank=rank_rel[:, :N_SHOW],
    ch_names=ch_order,
    event_ids=list(range(N_SHOW)),
    figsize=(14, 4),
)
plt.show()

# B) Our pipeline: bqk detections -> build_windows_from_detections (same crop)
# For notebook speed, we reuse the already-loaded bip_res if available; otherwise preprocess quickly.

if 'bip_res' not in globals():
    bip_res = SEEGPreprocessor(reference='bipolar', crop_seconds=CROP_SEC + 1).run(EDF_PATH)

# Run bqk detector on bipolar data
cfg = HFODetectionConfig(
    algorithm='bqk',
    band='ripple',
    chunk_sec=30.0,
    chunk_overlap_sec=1.0,
    rel_thresh=2.5,
    abs_thresh=2.5,
    min_gap_ms=20.0,
    min_last_ms=20.0,
    use_gpu=False,
)
res = HFODetector(cfg).detect(bip_res)

# Alias bipolar names to left contact (A1-A2 -> A1) and restrict to core
bqk_dets = {}
for ch_name, ev in zip(res.ch_names, res.events_by_channel):
    if '-' not in ch_name:
        continue
    left = ch_name.split('-', 1)[0].strip().upper()
    if left not in core_set:
        continue
    bqk_dets[left] = np.asarray(ev, dtype=np.float64)

# Use packedTimes window length as reference for window_sec
win_sec = float(np.median(packed[:, 1] - packed[:, 0]))
windows2 = build_windows_from_detections(bqk_dets, window_sec=win_sec)
windows2 = [w for w in windows2 if w.start < CROP_SEC]

cent2, eb2 = compute_centroid_matrix_from_envelope_cache(
    windows=windows2,
    detections=bqk_dets,
    ch_names=ch_order,
    env=cache['env'],
    sfreq=cache['sfreq'],
    centroid_power=2.0,
)
lag2, rank2 = lag_rank_from_centroids(cent2, eb2, align='first_centroid', tie_tol_ms=0.0)

energy2 = _window_energy(cache['env'], windows2, eb2, sf, len(ch_order), lag2)

N_SHOW2 = min(200, len(windows2))
figE2, figR2, figL2 = plot_lag_heatmaps(
    energy=energy2[:, :N_SHOW2],
    lag_ms=lag2[:, :N_SHOW2] * 1000.0,
    rank=rank2[:, :N_SHOW2],
    ch_names=ch_order,
    event_ids=list(range(N_SHOW2)),
    figsize=(14, 4),
)
plt.show()

In [None]:
# --- 5b) Group event visualization (packedTimes), core channels only ---
# We generate a GPU cache for the crop (bandpassed + envelope), then plot Fig1/Fig2.
# Differences from the earlier group-event cell:
# - channels are restricted to (and ordered by) the lagPat core channels
# - x-axis is in seconds (not samples)
# - Fig1 is drawn with mode='bandpassed'
#   (NOTE(review): an earlier note here said Fig1 uses the envelope; the call below does not)
# - You can select channels + events

import numpy as np

CROP_SEC = 60.0
# The cache filename is derived from CROP_SEC so that changing the crop length cannot
# silently reuse a cache built for a different length. For CROP_SEC = 60.0 this resolves
# to ".../FC10477Q_envCache_ripple_bipolar_alias_crop60s_viz.npz", as before.
CACHE_PATH = EDF_PATH.with_name(
    f'{EDF_PATH.stem}_envCache_ripple_bipolar_alias_crop{CROP_SEC:g}s_viz.npz'
)

# Generate the cache only if it is missing
if not CACHE_PATH.exists():
    _ = precompute_envelope_cache(
        edf_path=str(EDF_PATH),
        out_npz_path=str(CACHE_PATH),
        band='ripple',
        crop_seconds=CROP_SEC,
        reference='bipolar',
        alias_bipolar_to_left=True,
        alias_filter_using_gpu_npz=str(GPU_NPZ),
        use_gpu=True,
        save_bandpass=True,
        dtype='float32',
    )

# Core channels from lagPat (core-only view)
LAGPAT = Path('/mnt/yuquan_data/yuquan_24h_edf/chengshuai/FC10477Q_lagPat.npz')
lag = np.load(LAGPAT, allow_pickle=True)
CORE_CH = [str(x) for x in lag['chnNames'].tolist()]

# Pick events (event ids)
EVENTS = list(range(20))

# Fig1: band-passed traces of the core channels for the selected events
fig1 = plot_group_events_band_raster(
    cache_npz_path=str(CACHE_PATH),
    packed_times_path=str(PACKED_TIMES),
    channel_order=CORE_CH,
    event_indices=EVENTS,
    max_events=len(EVENTS),
    mode='bandpassed',
    plot_style='trace',
    value_transform='none',
    downsample_ms=None,
    x_axis='seconds',
    trace_spacing=6.0,
    figsize=(24, 10),
)
plt.show()

# Fig2: per-channel time-frequency centroids restricted to the 80-250 Hz band
fig2 = plot_group_events_tf_centroids_per_channel(
    cache_npz_path=str(CACHE_PATH),
    packed_times_path=str(PACKED_TIMES),
    detections_npz_path=str(GPU_NPZ),
    channel_order=CORE_CH,
    event_indices=EVENTS,
    max_events=len(EVENTS),
    freq_band=(80.0, 250.0),
    nperseg=256,
    noverlap=192,
    mask_by_detections=False,  # avoid hiding all dots due to name mismatch
    show_colorbar=True,
    font_size=16,
    title_font_size=18,
    y_label_font_size=14,
    tick_font_size=12,
    hspace=0.02,
    centroid_power='power2',
    figsize=(24, 10),
)
plt.show()