In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")) + '/lib/')
import cv2
import time
import torch
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from lib.blob_extraction import img_preprocess, blob_detect, get_nodes_pos
from lib.graph_generate import Delaunay_graph_generate
from lib.voronoi_generate import TransformVoronoi_331

try:
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False


  from .autonotebook import tqdm as notebook_tqdm


## 1. Define Voronoi Graph Dataset

In [2]:
def voronoi_graph_dataset_generate_with_temporal_edges(
    raw_data_dir, train_data_dir, test_data_dir,
    tip_num=331, window_size=20, stride=20
):
    """
    生成含“逐边门控”时序边的图数据：
    - 帧内（空间）边：基于重排后的 nodes_pos 做 Delaunay；
    - 帧间（时序）边：先做帧间匹配，逐节点计算匹配距离 dmatch，动态阈值筛选后才连边；
    - 额外保存 temporal_edge_index（和 temporal_conf），便于 SeqGAT 两路注意力使用。
    """
    import os, time
    import numpy as np
    import torch
    import pandas as pd
    from torch_geometric.data import Data
    import cv2

    # 尝试用匈牙利匹配（更稳）；没有 scipy 就退化到最近邻
    try:
        from scipy.optimize import linear_sum_assignment
        from scipy.spatial.distance import cdist
        _HAS_SCIPY = True
    except Exception:
        _HAS_SCIPY = False

    os.makedirs(train_data_dir, exist_ok=True)
    os.makedirs(test_data_dir, exist_ok=True)

    # Read label CSV
    label_df = pd.read_csv(os.path.join(os.path.abspath(os.path.join(raw_data_dir, "..")), 'targets.csv'))

    # Group by object ID
    grouped = label_df.groupby('obj_id')
    data_list = []
    start = time.time()

    for obj_id, group in grouped:
        group = group.sort_values(by='pose_id').reset_index(drop=True)
        total_frames = len(group)

        for start_idx in range(0, total_frames - window_size + 1, stride):
            sub_group = group.iloc[start_idx:start_idx + window_size]

            x_list = []
            edge_list = []

            # 额外：单独收集时序边及其置信度（供 SeqGAT 使用）
            temporal_edges_seq = []
            temporal_conf_seq  = []

            label = [sub_group.iloc[-1]['pose_2'], sub_group.iloc[-1]['pose_6']]

            node_offset = 0

            # —— 本序列对齐状态 —— #
            anchor_XY_canon = None    # 第一帧锚点
            prev_XY_canon   = None    # 上一帧对齐后的 XY
            TAU_NODE_BASE   = None    # 动态阈值基准（由第一帧的最近邻尺度估计）

            for row in sub_group.itertuples():
                image_path = os.path.join(raw_data_dir, row.image_name)
                img = cv2.imread(image_path)
                if img is None:
                    print(f"[Skipped Frame] ❌ Image not found: {row.image_name}")
                    continue

                processed = img_preprocess(
                    img, erosion=False, kernel_size=1,
                    resize_x=300, resize_y=300,
                    binary_threshold=100,
                    circle_x_bias=-1, circle_y_bias=-2, circle_radius_bias=-14
                )

                keypoints = blob_detect(
                    processed, minArea=10, blobColor=255,
                    minCircularity=0.01, minConvexity=0.01,
                    minInertiaRatio=0.01, thresholdStep=5,
                    minDistBetweenBlobs=3.0, minRepeatability=3
                )

                nodes_pos = get_nodes_pos(keypoints)
                if nodes_pos.shape[0] != tip_num:
                    print(f"[Skipped Frame] ❌ Blob detection failed: {row.image_name}, detected tips = {nodes_pos.shape[0]}")
                    continue

                Axx_canon, Cxx_canon, Cyy_canon, XY_canon = TransformVoronoi_331(borderScale=1.1).transform(nodes_pos)
                if len(Axx_canon) != tip_num:
                    print(f"[Skipped Frame] ❌ Voronoi transform error: {row.image_name}, Axx_canon size = {len(Axx_canon)}")
                    continue

                # —— 统一 numpy —— #
                XY_canon  = np.asarray(XY_canon,  dtype=np.float32)      # (N,2)
                Axx_canon = np.asarray(Axx_canon, dtype=np.float32).reshape(-1)
                nodes_pos = np.asarray(nodes_pos, dtype=np.float32)

                # —— 帧间匹配并重排 —— #
                if prev_XY_canon is None:
                    # 第一帧：设定锚点与尺度
                    anchor_XY_canon = XY_canon.copy()
                    prev_XY_canon   = XY_canon.copy()
                    # 估计最近邻间距中位数，作为动态阈值的尺度
                    d_nn = ((anchor_XY_canon[:,None,:]-anchor_XY_canon[None,:,:])**2).sum(-1)**0.5
                    np.fill_diagonal(d_nn, 1e9)
                    nn_med = np.median(d_nn.min(axis=1))
                    TAU_NODE_BASE = 0.30 * nn_med      # 可调：0.30~0.50
                    perm = np.arange(tip_num, dtype=np.int64)
                else:
                    # 与上一帧、与首帧都匹配，取更好者（以95分位成本为准）
                    if _HAS_SCIPY:
                        cost_prev   = cdist(prev_XY_canon,   XY_canon)  # (N,N)
                        _, col_prev = linear_sum_assignment(cost_prev)
                        p95_prev    = np.percentile(cost_prev[np.arange(tip_num), col_prev], 95)

                        cost_anchor   = cdist(anchor_XY_canon, XY_canon)
                        _, col_anchor = linear_sum_assignment(cost_anchor)
                        p95_anchor    = np.percentile(cost_anchor[np.arange(tip_num), col_anchor], 95)

                        perm = col_prev if p95_prev <= p95_anchor else col_anchor
                    else:
                        # 退化：最近邻近似
                        dist_prev   = ((prev_XY_canon[:,None,:]-XY_canon[None,:,:])**2).sum(-1)
                        col_prev    = dist_prev.argmin(axis=1)
                        p95_prev    = float(np.percentile(np.sqrt(dist_prev[np.arange(tip_num), col_prev]), 95))

                        dist_anchor = ((anchor_XY_canon[:,None,:]-XY_canon[None,:,:])**2).sum(-1)
                        col_anchor  = dist_anchor.argmin(axis=1)
                        p95_anchor  = float(np.percentile(np.sqrt(dist_anchor[np.arange(tip_num), col_anchor]), 95))

                        perm = col_prev if p95_prev <= p95_anchor else col_anchor

                    # 按 perm 重排（！！！三者一致重排）
                    XY_canon  = XY_canon[perm]
                    Axx_canon = Axx_canon[perm]
                    nodes_pos = nodes_pos[perm]

                # —— 节点特征 —— #
                node_feats = np.asarray([XY_canon[:, 0], XY_canon[:, 1], Axx_canon]).T
                x_list.append(torch.tensor(node_feats, dtype=torch.float32))

                # —— 帧内（空间）边 —— #
                spatial_edges = np.array(Delaunay_graph_generate(nodes_pos)).T  # [2, E_s]
                spatial_edges = spatial_edges + node_offset
                edge_list.append(torch.tensor(spatial_edges, dtype=torch.long))

                # —— 帧间（时序）边：逐边门控 —— #
                if prev_XY_canon is not None:
                    # 匹配距离
                    dmatch = ((prev_XY_canon - XY_canon)**2).sum(-1)**0.5  # [N]
                    # 动态阈值：不小于首帧尺度；也不低于当前 90 分位（可试 85/95）
                    tau_node = max(TAU_NODE_BASE, np.percentile(dmatch, 80))
                    good = np.where(dmatch <= tau_node)[0]                  # 对齐较好的节点

                    if good.size > 0:
                        te = np.vstack([good + node_offset - tip_num, good + node_offset])  # [2, |good|]
                        te_t = torch.tensor(te, dtype=torch.long)
                        edge_list.append(te_t)
                        temporal_edges_seq.append(te_t)
                        # 可选：置信度（0~1），供 GAT 使用
                        conf = 1.0 - (dmatch[good] / (tau_node + 1e-8))
                        temporal_conf_seq.append(torch.tensor(conf, dtype=torch.float32))

                # —— 更新参照与偏移 —— #
                prev_XY_canon = XY_canon.copy()
                node_offset  += tip_num

            # === 完整性检查 ===
            if len(x_list) != window_size:
                print(f"[Skipped Sequence] ❌ Incomplete seq obj_id={obj_id}, frames {start_idx}-{start_idx + window_size - 1}, valid frames={len(x_list)}")
                continue

            # === 拼接并打包 ===
            x = torch.cat(x_list, dim=0).float()                         # [T*N, 3]
            edge_index = torch.cat(edge_list, dim=1).long().contiguous() # [2, E_total]

            # 可选：单独的 temporal_edge_index / temporal_conf
            if len(temporal_edges_seq) > 0:
                temporal_edge_index = torch.cat(temporal_edges_seq, dim=1).long().contiguous()
                d = Data(
                    x=x,
                    edge_index=edge_index,
                    y=torch.tensor(label, dtype=torch.float32),
                    temporal_edge_index=temporal_edge_index
                )
                if len(temporal_conf_seq) > 0:
                    temporal_conf = torch.cat(temporal_conf_seq, dim=0).contiguous()  # [E_t]
                    d.temporal_conf = temporal_conf
                kept_cnt = int(temporal_edge_index.size(1))
                max_possible = (window_size - 1) * tip_num
            else:
                d = Data(x=x, edge_index=edge_index, y=torch.tensor(label, dtype=torch.float32))
                kept_cnt = 0
                max_possible = (window_size - 1) * tip_num

            total_cnt = int(edge_index.size(1))
            print(f"[Edges] temporal kept: {kept_cnt}/{max_possible}  "
                  f"= {kept_cnt/max_possible if max_possible>0 else 0:.2%} | "
                  f"temp/all = {kept_cnt/total_cnt if total_cnt>0 else 0:.2%}")

            data_list.append(d)
            print(f"Processed obj_id {obj_id}, frames {start_idx}-{start_idx + window_size - 1}")

    # Split train/val/test with safe fallback
    if len(data_list) < 4:
        print(f"[警告] 样本数仅 {len(data_list)} 个，将全部用于训练集。")
        train_data, test_data = data_list, []
    else:
        np.random.shuffle(data_list)
        split_idx = int(len(data_list) * 0.75)
        train_data = data_list[:split_idx]
        test_data  = data_list[split_idx:]

    torch.save(train_data, os.path.join(train_data_dir, 'Train_data_list.pt'))
    torch.save(test_data,  os.path.join(test_data_dir,  'Test_data_list.pt'))

    print(f"\n✅ Completed：{len(train_data)} train, {len(test_data)} test.  (tip_num={tip_num}, window={window_size}, stride={stride})")
    return train_data, test_data


## 2. Generate Voronoi Graph Dataset

In [3]:
# raw_data_dir = r'..\data\331\model_surface2d\frames_bw'  ## raw image dir
# train_data_dir = r'..\result\train'  ## train data saving dir
# test_data_dir = r'..\result\test'  ## test data saving dir


project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
raw_data_dir = os.path.join(project_root, "data", "331", "model_surface2d", "frames_bw")
train_data_dir = os.path.join(project_root, "result", "train")
test_data_dir= os.path.join(project_root, "result", "test")

## run the dataset generation...
voronoi_graph_data_list = voronoi_graph_dataset_generate_with_temporal_edges(raw_data_dir, train_data_dir, test_data_dir, tip_num = 331)

[Skipped Frame] ❌ Blob detection failed: frame_13_0.png, detected tips = 328
[Skipped Sequence] ❌ Incomplete seq obj_id=1, frames 0-19, valid frames=19
[Skipped Frame] ❌ Blob detection failed: frame_26_0.png, detected tips = 330
[Skipped Sequence] ❌ Incomplete seq obj_id=1, frames 20-39, valid frames=19
[Skipped Frame] ❌ Blob detection failed: frame_45_0.png, detected tips = 329
[Skipped Frame] ❌ Blob detection failed: frame_50_0.png, detected tips = 326
[Skipped Sequence] ❌ Incomplete seq obj_id=1, frames 40-59, valid frames=18
[Skipped Frame] ❌ Blob detection failed: frame_70_0.png, detected tips = 318
[Skipped Frame] ❌ Blob detection failed: frame_79_0.png, detected tips = 326
[Skipped Sequence] ❌ Incomplete seq obj_id=1, frames 60-79, valid frames=18
[Edges] temporal kept: 5673/6289  = 90.21% | temp/all = 13.23%
Processed obj_id 1, frames 80-99
[Skipped Frame] ❌ Blob detection failed: frame_107_0.png, detected tips = 326
[Skipped Frame] ❌ Blob detection failed: frame_112_0.png, det

In [4]:
train_val_data = voronoi_graph_data_list[0]
print(train_val_data[0])
print(train_val_data[0].x)  # tensor of shape [331, 3]
print(train_val_data[0].edge_index)  # tensor of shape [2, num_edges]
print(train_val_data[0].y)  # tensor of shape [2]



Data(x=[6620, 3], edge_index=[2, 42936], y=[2], temporal_edge_index=[2, 5736], temporal_conf=[5736])
tensor([[   2.7284,  113.2284,  157.2170],
        [  -9.5316,  113.0573,  154.6758],
        [  14.7904,  111.5826,  152.7911],
        ...,
        [   8.9214, -114.3650,  146.7135],
        [ -14.4119, -113.9119,  145.3327],
        [  -3.0624, -114.7237,  144.0359]])
tensor([[   0,    0,    0,  ..., 6286, 6287, 6288],
        [   1,    2,   10,  ..., 6617, 6618, 6619]])
tensor([-3.9756, -9.3399])
