In [1]:

import matplotlib
# Report the active matplotlib backend. The recorded output is the inline backend, which
# renders static images: enough for the slider/preview workflow below, but the click-based
# choose_span_by_two_clicks() helper (ginput) needs an interactive GUI backend
# (e.g. `%matplotlib qt`) to receive mouse clicks.
print(matplotlib.get_backend())


module://matplotlib_inline.backend_inline


In [2]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import SpanSelector

# Episode filenames look like: episode_<YYYYMMDD>_<HHMMSS>_<anything>.npz
EP_RE = re.compile(r"episode_(\d{8})_(\d{6})_.*\.npz$")

def parse_time_key(path: str):
    """Return the episode timestamp as an int YYYYMMDDHHMMSS, or None.

    Only the basename of ``path`` is inspected; it must match EP_RE.
    """
    match = EP_RE.match(os.path.basename(path))
    if match is None:
        return None
    # groups() is (YYYYMMDD, HHMMSS); concatenating them gives a sortable integer key.
    return int("".join(match.groups()))

def list_day_episodes(data_dir: str, date_yyyymmdd: str):
    """List the episode files in ``data_dir`` that were recorded on ``date_yyyymmdd``.

    Only filenames matching EP_RE whose date group equals ``date_yyyymmdd`` are kept.
    Returns full paths sorted chronologically by the YYYYMMDDHHMMSS filename key.
    Episodes sharing a timestamp are ordered by filename so the result is deterministic:
    os.listdir() order is arbitrary, and the interactive processing derives each episode's
    seed from its position in this list.
    """
    day_files = []
    for fn in os.listdir(data_dir):
        if not fn.endswith(".npz"):
            continue
        m = EP_RE.match(fn)
        if m is None or m.group(1) != date_yyyymmdd:
            continue
        day_files.append(os.path.join(data_dir, fn))
    # Secondary key (basename) breaks timestamp ties reproducibly.
    day_files.sort(key=lambda p: (parse_time_key(p) or 0, os.path.basename(p)))
    return day_files

def load_npz(npz_path: str):
    """Open an .npz archive and return numpy's NpzFile mapping (arrays are read on access).

    The caller owns the open file handle: call ``.close()`` on the result or use it in a
    ``with`` block. allow_pickle=True lets object arrays load, which also means reading such
    an array from an untrusted file can execute arbitrary code.
    """
    return np.load(npz_path, allow_pickle=True)

def get_t_and_ft(data):
    """Extract a (time, force/torque) pair from a loaded episode mapping.

    data: any mapping supporting ``in`` and ``[]`` (e.g. an NpzFile or dict) with a
        "force_torque" array of shape (N, 6) and optionally "t_ref" or "t".
    Returns (t, ft) truncated to their common length; ft is a copy.
    Raises KeyError if "force_torque" is missing, ValueError if its shape is not (N, 6).
    """
    if "force_torque" not in data:
        raise KeyError("missing key: force_torque")
    ft = data["force_torque"]
    if not (ft.ndim == 2 and ft.shape[1] == 6):
        raise ValueError(f"bad force_torque shape: {ft.shape} (expect (N,6))")

    # Time axis preference: t_ref, then t, then a plain sample index.
    time_key = next((k for k in ("t_ref", "t") if k in data), None)
    t = np.arange(ft.shape[0]) if time_key is None else data[time_key]

    n_common = min(len(t), ft.shape[0])
    return t[:n_common], ft[:n_common].copy()

def short_title(npz_path: str):
    """Compact label for an episode file: "<HHMMSS>_<last underscore-separated token>".

    Falls back to the basename without its trailing ".npz" when the name doesn't match EP_RE.
    """
    fn = os.path.basename(npz_path)
    # Strip only the trailing extension; str.replace(".npz", "") would also delete ".npz"
    # occurring elsewhere in the name.
    base = fn[:-4] if fn.endswith(".npz") else fn
    m = EP_RE.match(fn)
    if m:
        hhmmss = m.group(2)
        suffix = base.split("_")[-1]
        return f"{hhmmss}_{suffix}"
    return base

def plot_ft_pair(t, ft, title_prefix="BEFORE"):
    """
    Plot force (columns 0-2) and torque (columns 3-5) of ``ft`` against ``t`` in two
    separate single-axes figures, then call plt.show() once for both.

    returns (fig_force, ax_force), (fig_torque, ax_torque)
    """
    panels = (
        ("FORCE", "force", (("fx", 0), ("fy", 1), ("fz", 2))),
        ("TORQUE", "torque", (("tx", 3), ("ty", 4), ("tz", 5))),
    )
    # Both figures are created up front, before any plotting.
    figs_axes = [plt.subplots(1, 1, figsize=(12, 3.2)) for _ in panels]

    for (fig, ax), (kind, ylabel, channels) in zip(figs_axes, panels):
        fig.suptitle(f"{title_prefix} | {kind}", fontsize=12)
        ax.set_title(title_prefix, fontsize=10)
        for label, col in channels:
            ax.plot(t, ft[:, col], label=label)
        ax.legend(loc="upper right", fontsize=8)
        ax.grid(True)
        ax.set_xlabel("t")
        ax.set_ylabel(ylabel)

    plt.show()
    (fig_force, ax_force), (fig_torque, ax_torque) = figs_axes
    return (fig_force, ax_force), (fig_torque, ax_torque)

def choose_span_by_two_clicks(ax):
    """Let the user pick an x-range on ``ax`` with two mouse clicks.

    Requires an interactive matplotlib backend (e.g. ``%matplotlib qt``); a static backend
    such as inline cannot deliver click events.

    Returns (x_start, x_end) as floats with x_start <= x_end, or None if fewer than two
    clicks were collected. The chosen span is shaded on ``ax``.
    """
    fig = ax.figure
    ax.text(0.01, 0.98, "Click START then END (2 clicks).",
            transform=ax.transAxes, va="top", fontsize=9,
            bbox=dict(boxstyle="round", alpha=0.1))
    fig.canvas.draw_idle()
    # Collect clicks on the figure that owns `ax` (Figure.ginput shows it itself).
    # The previous plt.show() + plt.ginput() pair could block until the window was closed
    # (or, inline, close the figure) and then listen on whatever plt.gcf() returned.
    pts = fig.ginput(2, timeout=0)  # timeout=0: wait indefinitely
    if len(pts) < 2:
        return None
    x0, x1 = sorted((pts[0][0], pts[1][0]))
    ax.axvspan(x0, x1, alpha=0.15)
    fig.canvas.draw_idle()
    return float(x0), float(x1)


def time_to_index(t, x):
    """Return the index of the sample in array ``t`` closest to ``x`` (the first one on ties)."""
    distance = np.abs(t - x)
    return int(distance.argmin())

def synthesize_noise_fill(
    ft_pre,
    L,
    seed=0,
    keep_level="last_mean",
    level_window=20,
    noise_scale=1.0,      # noise amplitude knob
):
    """
    Build an (L, C) replacement segment from baseline noise.

    Each output sample is ``level + noise_scale * residual``, where ``residual`` is drawn
    with replacement (independently per channel) from the baseline minus its per-channel mean.

    ft_pre: (M, C) baseline samples (in this notebook: ft[0:start_idx], C == 6)
    L: number of rows to generate
    seed: seed for numpy.random.default_rng
    keep_level: offset added to the noise
        'zero'      -> 0
        'mean'      -> mean of the whole baseline
        'last_mean' -> mean of the last ``level_window`` baseline rows
    noise_scale:
        1.0 = same jitter as the baseline
        0.5 = half the baseline jitter
        0.2 = almost flat
        2.0 = twice the baseline jitter

    Raises ValueError if the baseline has fewer than max(5, level_window) rows or if
    keep_level is not recognized.
    """
    M, C = ft_pre.shape
    if M < max(5, level_window):
        raise ValueError(f"baseline too short ({M}). Need >= {max(5, level_window)}")

    rng = np.random.default_rng(seed)

    # ===== baseline level (trend / offset) =====
    if keep_level == "zero":
        level = np.zeros(C)
    elif keep_level == "mean":
        level = ft_pre.mean(axis=0)
    elif keep_level == "last_mean":
        level = ft_pre[-level_window:].mean(axis=0)
    else:
        raise ValueError("keep_level must be one of: 'zero','mean','last_mean'")

    # ===== zero-mean residual noise =====
    mu = ft_pre.mean(axis=0)
    residual = ft_pre - mu[None, :]   # (M, C)

    # ===== fill =====
    # One rng.integers call per channel, in channel order, so a given seed reproduces
    # earlier results.
    fill = np.zeros((L, C), dtype=ft_pre.dtype)
    for c in range(C):
        idx = rng.integers(0, M, size=L)
        fill[:, c] = level[c] + noise_scale * residual[idx, c]

    return fill


def apply_trigger_artifact_replacement(ft, start_idx, end_idx, baseline_min_frames=10, fade=8, seed=0,
                                       level_window=10, noise_scale=0.8):
    """
    Replace ft[start_idx:end_idx+1] with noise synthesized from the samples before start_idx.

    ft: (N, C) array (force/torque in this notebook); it is not modified.
    start_idx, end_idx: inclusive index range to replace; both are clipped to [0, N-1].
    baseline_min_frames: minimum number of baseline rows (ft[:start_idx]) required.
    fade: number of frames cross-faded at each boundary to avoid hard cuts. At start_idx
        and end_idx the output equals the original samples, and it ramps linearly to the
        synthesized fill towards the inside of the segment.
    seed: forwarded to synthesize_noise_fill.
    level_window, noise_scale: forwarded to synthesize_noise_fill (the defaults are the
        values that used to be hard-coded here).

    Returns a new array. Raises ValueError for an empty/inverted range or a short baseline
    (synthesize_noise_fill additionally needs at least max(5, level_window) baseline rows).
    """
    N = ft.shape[0]
    start_idx = int(np.clip(start_idx, 0, N - 1))
    end_idx = int(np.clip(end_idx, 0, N - 1))
    if end_idx <= start_idx:
        raise ValueError("end_idx must be > start_idx")

    # Baseline: everything before the artifact starts.
    ft_pre = ft[:start_idx]
    if ft_pre.shape[0] < baseline_min_frames:
        raise ValueError(
            f"baseline too short ({ft_pre.shape[0]} frames). "
            f"Need >= {baseline_min_frames}. Move start_idx later."
        )

    L = end_idx - start_idx + 1
    fill = synthesize_noise_fill(
        ft_pre,
        L=L,
        seed=seed,
        keep_level="last_mean",
        level_window=level_window,
        noise_scale=noise_scale,
    )

    ft_new = ft.copy()
    ft_new[start_idx:end_idx + 1] = fill

    # Cross-fade at both boundaries so the replaced segment joins the original smoothly.
    f = int(max(0, fade))
    if f > 0:
        # Left boundary [start_idx, start_idx + f): fill weight 0 -> (n-1)/n.
        l0 = start_idx
        l1 = min(start_idx + f, end_idx + 1)
        if l1 > l0:
            alpha = np.linspace(0, 1, l1 - l0, endpoint=False)[:, None]
            ft_new[l0:l1] = (1 - alpha) * ft[l0:l1] + alpha * ft_new[l0:l1]

        # Right boundary (end_idx - f, end_idx]: mirror of the left ramp, fill weight
        # (n-1)/n -> 0, so the output meets the original signal at end_idx. (The weights
        # used to be reversed, which put a jump on both sides of this ramp.)
        r1 = end_idx + 1
        r0 = max(start_idx, end_idx + 1 - f)
        if r1 > r0:
            alpha = np.linspace(0, 1, r1 - r0, endpoint=False)[::-1, None]
            ft_new[r0:r1] = (1 - alpha) * ft[r0:r1] + alpha * ft_new[r0:r1]

    return ft_new

def save_npz_like(original_npz_path, out_dir, ft_new, suffix="_pre_processing"):
    """
    Write a copy of ``original_npz_path`` into ``out_dir`` with "force_torque" set to ``ft_new``.

    All other arrays are copied unchanged. The output name is the original basename with
    ``suffix`` inserted before ".npz" (".npz" is appended if the original lacks it).
    ``out_dir`` is created if needed. Returns the output path.

    NOTE: the source is loaded with allow_pickle=True, so object arrays in an untrusted file
    could execute code when read — only use on trusted data.
    """
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.basename(original_npz_path)
    stem = base[:-4] if base.endswith(".npz") else base
    out_path = os.path.join(out_dir, f"{stem}{suffix}.npz")

    # Copy every key, then replace force_torque. The context manager closes the archive's
    # file handle (previously leaked). force_torque is written even if the source lacked it,
    # rather than silently dropping ft_new.
    with np.load(original_npz_path, allow_pickle=True) as data:
        out_dict = {k: data[k] for k in data.files}
    out_dict["force_torque"] = ft_new
    np.savez(out_path, **out_dict)
    return out_path


In [4]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# NOTE: re-definition of EP_RE / parse_time_key from cell In [2]; whichever cell ran last wins.
EP_RE = re.compile(r"episode_(\d{8})_(\d{6})_.*\.npz$")

def parse_time_key(path: str):
    """Sort key for an episode path: int(YYYYMMDD + HHMMSS) from its basename, or None if it doesn't match EP_RE."""
    hit = EP_RE.match(os.path.basename(path))
    return None if not hit else int(hit.group(1) + hit.group(2))

def list_day_episodes(data_dir: str, date_yyyymmdd: str):
    """List the episode files in ``data_dir`` that were recorded on ``date_yyyymmdd``.

    Only filenames matching EP_RE whose date group equals ``date_yyyymmdd`` are kept.
    Returns full paths sorted chronologically by the YYYYMMDDHHMMSS filename key.
    Episodes sharing a timestamp are ordered by filename so the result is deterministic:
    os.listdir() order is arbitrary, and interactive_process_batch derives each episode's
    seed from its position in this list.
    """
    day_files = []
    for fn in os.listdir(data_dir):
        if not fn.endswith(".npz"):
            continue
        m = EP_RE.match(fn)
        if m is None or m.group(1) != date_yyyymmdd:
            continue
        day_files.append(os.path.join(data_dir, fn))
    # Secondary key (basename) breaks timestamp ties reproducibly.
    day_files.sort(key=lambda p: (parse_time_key(p) or 0, os.path.basename(p)))
    return day_files

def load_t_ft(npz_path):
    """Load (t, force_torque) from an episode .npz file.

    The time axis is "t_ref" if present, else "t", else a sample index 0..N-1.
    Both arrays are truncated to their common length; ft is a copy.
    Raises KeyError if the archive has no "force_torque".
    """
    # The context manager closes the archive's file handle; every data[...] access below
    # reads the array fully into memory, so the results remain valid after closing.
    with np.load(npz_path, allow_pickle=True) as data:
        ft = data["force_torque"]
        if "t_ref" in data:
            t = data["t_ref"]
        elif "t" in data:
            t = data["t"]
        else:
            t = np.arange(ft.shape[0])
    N = min(len(t), ft.shape[0])
    return t[:N], ft[:N].copy()

def save_npz_like_processed(original_npz_path, out_dir, ft_new, suffix="_processed"):
    """Copy ``original_npz_path`` into ``out_dir`` with "force_torque" replaced by ``ft_new``.

    The output name is the original basename with ``suffix`` inserted before ".npz".
    Every other array is copied unchanged. ``out_dir`` is created if needed.
    Returns the output path.
    """
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.basename(original_npz_path)
    out_name = base[:-4] + f"{suffix}.npz" if base.endswith(".npz") else base + f"{suffix}.npz"
    out_path = os.path.join(out_dir, out_name)

    # Read everything, then close the archive before writing (the handle used to leak).
    with np.load(original_npz_path, allow_pickle=True) as data:
        out_dict = {k: data[k] for k in data.files}
    out_dict["force_torque"] = ft_new
    np.savez(out_path, **out_dict)
    return out_path

def plot_before_after(t, ft_before, ft_after, start_idx, end_idx, title=""):
    """
    Show two 2-row figures (force on top, torque below): first ``ft_before``, then ``ft_after``.

    Each panel shades t[start_idx]..t[end_idx]; each figure is displayed with plt.show()
    before the next one is created.
    """
    for tag, series in (("BEFORE", ft_before), ("AFTER", ft_after)):
        _, (ax_force, ax_torque) = plt.subplots(2, 1, figsize=(13, 6), sharex=True)

        for col, name in enumerate(("fx", "fy", "fz")):
            ax_force.plot(t, series[:, col], label=name)
        ax_force.axvspan(t[start_idx], t[end_idx], alpha=0.15)
        ax_force.set_ylabel("force")
        ax_force.grid(True)
        ax_force.legend(loc="upper right", fontsize=8)
        ax_force.set_title(f"{tag} (force) {title}")

        for col, name in enumerate(("tx", "ty", "tz"), start=3):
            ax_torque.plot(t, series[:, col], label=name)
        ax_torque.axvspan(t[start_idx], t[end_idx], alpha=0.15)
        ax_torque.set_ylabel("torque")
        ax_torque.set_xlabel("t")
        ax_torque.grid(True)
        ax_torque.legend(loc="upper right", fontsize=8)

        plt.show()

def interactive_process_batch(npz_paths, out_dir, baseline_min_frames=20, fade=8, seed0=0):
    """
    Step through ``npz_paths`` one episode at a time with an ipywidgets UI:
      - the range slider selects the inclusive index range [start, end] to replace
      - "Apply + Preview" runs apply_trigger_artifact_replacement and plots before/after
      - "Save + Next" saves the previewed result (save_npz_like_processed) and advances
      - "Skip" advances without saving; "Prev" goes back one episode

    Episode i is processed with seed ``seed0 + i``. Moving the slider discards the current
    preview, so "Save + Next" only ever saves data computed from the selected range.
    """
    assert len(npz_paths) > 0, "npz_paths is empty"

    last_i = len(npz_paths) - 1
    idx_state = {"i": 0}
    state = {"ft_new": None, "t": None, "ft": None, "path": None}

    header = widgets.HTML()
    info = widgets.HTML()
    range_slider = widgets.IntRangeSlider(description="idx", continuous_update=False,
                                          layout=widgets.Layout(width="800px"))

    btn_apply = widgets.Button(description="Apply + Preview", button_style="primary")
    btn_save_next = widgets.Button(description="Save + Next", button_style="success")
    btn_skip = widgets.Button(description="Skip", button_style="")
    btn_prev = widgets.Button(description="Prev", button_style="")
    out = widgets.Output()

    def load_current():
        """Load episode idx_state["i"], drop any preview, and reset the slider."""
        i = idx_state["i"]
        path = npz_paths[i]
        t, ft = load_t_ft(path)
        N = ft.shape[0]

        state["t"], state["ft"], state["path"] = t, ft, path
        state["ft_new"] = None

        # Default selection [50, 150], shrunk to fit short episodes.
        s0 = min(50, N-2)
        e0 = min(150, N-1)
        range_slider.min = 0
        range_slider.max = N-1
        range_slider.value = (s0, e0)

        header.value = f"<b>({i+1}/{len(npz_paths)})</b> {os.path.basename(path)}"
        refresh_info()

    def refresh_info(*_):
        s, e = range_slider.value
        N = range_slider.max + 1
        info.value = f"N={N} | selected idx: <b>{s}</b> → <b>{e}</b> (len={e-s+1})"

    def on_range_change(*_):
        # A preview computed for the previous range must not be saved for the new one;
        # require a fresh Apply + Preview.
        state["ft_new"] = None
        refresh_info()

    def on_apply(_):
        with out:
            clear_output(wait=True)
            s, e = range_slider.value
            try:
                ft_new = apply_trigger_artifact_replacement(
                    state["ft"], s, e,
                    baseline_min_frames=baseline_min_frames,
                    fade=fade,
                    seed=seed0 + idx_state["i"]
                )
                state["ft_new"] = ft_new
                plot_before_after(state["t"], state["ft"], ft_new, s, e,
                                  title=f"| {os.path.basename(state['path'])}")
            except Exception as ex:
                state["ft_new"] = None
                print("❌ apply failed:", ex)

    def goto(i_new):
        i_new = int(np.clip(i_new, 0, len(npz_paths)-1))
        idx_state["i"] = i_new
        with out:
            clear_output(wait=True)
        load_current()

    def on_save_next(_):
        with out:
            if state["ft_new"] is None:
                print("❌ nothing to save. click Apply + Preview first.")
                return
            try:
                out_path = save_npz_like_processed(state["path"], out_dir, state["ft_new"], suffix="_processed")
            except Exception as ex:
                # Stay on this episode and report inside the widget; exceptions raised from
                # button callbacks may not be visible in the notebook output.
                print("❌ save failed:", ex)
                return
            print("✅ saved:", out_path)
            if idx_state["i"] >= last_i:
                print("🏁 last episode reached.")
        goto(idx_state["i"] + 1)

    def on_skip(_):
        with out:
            clear_output(wait=True)
            print("⚠️ skipped:", os.path.basename(state["path"]))
            if idx_state["i"] >= last_i:
                print("🏁 last episode reached.")
        goto(idx_state["i"] + 1)

    def on_prev(_):
        goto(idx_state["i"] - 1)

    range_slider.observe(on_range_change, names="value")
    btn_apply.on_click(on_apply)
    btn_save_next.on_click(on_save_next)
    btn_skip.on_click(on_skip)
    btn_prev.on_click(on_prev)

    display(widgets.VBox([
        header,
        info,
        range_slider,
        widgets.HBox([btn_prev, btn_apply, btn_save_next, btn_skip]),
        out
    ]))

    load_current()


In [5]:
# Batch configuration: edit these for another dataset or recording day.
data_dir = "/mnt/WDC10T/tailai_ws/dataset/one_clip_mounting/raw_data"
date = "20260227"  # YYYYMMDD, matched against the date field of the episode filename
out_dir = "/mnt/WDC10T/tailai_ws/dataset/one_clip_mounting/data_processed"

npz_paths = list_day_episodes(data_dir, date)
print("found:", len(npz_paths))
if npz_paths:
    interactive_process_batch(npz_paths, out_dir, baseline_min_frames=10, fade=8, seed0=0)
else:
    # interactive_process_batch asserts on an empty list; report instead of raising.
    print(f"no episodes for {date} in {data_dir}")


found: 90


VBox(children=(HTML(value=''), HTML(value=''), IntRangeSlider(value=(25, 75), continuous_update=False, descrip…