
# Inspect Dataset — Notebook 自检

这个 Notebook 会：
- 读取 `configs/seqgat_k5_st.yaml`（也可以改成 Tactile 的配置，例如 `configs/tactile_gat_k5_s.yaml`）；
- 用 `data/pt_dataset.py` 加载数据（会按 `k_last_frames` 自动裁剪到最后 k 帧，默认 5 帧）；
- 对前几个样本做一致性检查（帧索引、空间/时序边、时序边占比、属性尺寸等）；
- 输出数据集级别的统计。


In [1]:

import os, json, math, random
import numpy as np
import torch
import yaml

from data.pt_dataset import load_dataset

# Path of the experiment config. You can switch it to
# 'configs/tactile_gat_k5_s.yaml' to inspect the Tactile setup instead.
CONFIG = 'configs/seqgat_k5_st.yaml'

# Use a context manager so the file handle is closed deterministically, and
# pin the encoding so parsing does not depend on the platform locale.
with open(CONFIG, 'r', encoding='utf-8') as cfg_file:
    cfg = yaml.safe_load(cfg_file)

tip_num = int(cfg.get('tip_num', 331))
k_last  = cfg.get('k_last_frames', 5)

ds = load_dataset(
    cfg['data_root'],
    glob_pattern=cfg.get('glob','Train_data_list.pt'),
    field_map=cfg.get('field_map',{}),
    tip_num=tip_num,
    k_last_frames=k_last,
)

print(f"Loaded {len(ds)} samples | tip_num={tip_num} | k_last_frames={k_last}")
if len(ds) > 0:
    # Quick look at the first sample: tensor shapes and frame-index range.
    d0 = ds[0]
    print('Sample[0] shapes:', 'x', tuple(d0.x.shape), '| y', tuple(d0.y.shape))
    print('t: min/max', int(d0.t.min()), int(d0.t.max()))
    # The edge tensors are optional; print None when one is missing (or unset)
    # instead of failing on `.shape`.
    for _name in ('edge_index_s', 'edge_index_t', 'edge_attr_t'):
        _value = getattr(d0, _name, None)
        print(f'{_name}:', _value.shape if _value is not None else None)


  from .autonotebook import tqdm as notebook_tqdm
Processing...
  f_u = (u // tip_num); f_v = (v // tip_num)


Loaded 144 samples | tip_num=331 | k_last_frames=5
Sample[0] shapes: x (1655, 3) | y (2,)
t: min/max 15 19
edge_index_s: torch.Size([2, 9300])
edge_index_t: torch.Size([2, 1192])
edge_attr_t: torch.Size([1192, 1])


Done!


## 定义检查函数

In [2]:

def inspect_sample(d, tip_num: int):
    """Run structural consistency checks on a single graph sample.

    Args:
        d: Sample object exposing ``x`` (node features, at least 2-D) and
            ``t`` (one integer frame index per node). The attributes
            ``edge_index_s``, ``edge_index_t`` and ``edge_attr_t`` are
            optional and are looked up with ``getattr``.
        tip_num: Expected number of nodes in every frame.

    Returns:
        A dict with node/frame counts (``N_nodes``, ``F``, ``t_min``,
        ``t_max``, ``K_frames``, ``counts_first5``,
        ``all_frames_have_tipnum_nodes``), edge counts (``E_s``, ``E_t``),
        edge sanity flags (``spatial_same_frame``, ``temporal_next_frame``,
        ``temporal_v_minus_u_eq_tip``), temporal statistics
        (``temporal_jump_median``, ``temporal_kept_fraction``; both ``None``
        when there are no temporal edges) and ``edge_attr_t_shape`` /
        ``edge_attr_t_align`` (both ``None`` when ``edge_attr_t`` is absent).
        For a sample with no nodes, ``t_min``/``t_max`` are ``None`` and
        ``K_frames`` is 0.
    """
    res = {}
    N = d.x.size(0)
    F = d.x.size(1)
    # Flatten and cast so a column-shaped (N, 1) or non-int64 frame index is
    # accepted by bincount and by the per-edge lookups below.
    t = d.t.reshape(-1).long()

    if t.numel() > 0:
        t_min = int(t.min()); t_max = int(t.max())
        K = t_max - t_min + 1
        # Number of nodes in each frame, frames counted from t_min.
        counts = torch.bincount(t - t_min, minlength=K).cpu().numpy()
    else:
        # min()/max() are undefined on an empty tensor; report "no frames".
        t_min = t_max = None
        K = 0
        counts = np.zeros(0, dtype=np.int64)

    res['N_nodes'] = int(N)
    res['F'] = int(F)
    res['t_min'] = t_min
    res['t_max'] = t_max
    res['K_frames'] = int(K)
    res['counts_first5'] = counts[:min(5, K)].tolist()
    # K > 0 keeps an empty sample from passing vacuously.
    res['all_frames_have_tipnum_nodes'] = bool(K > 0 and np.all(counts == tip_num))

    # Edge checks
    e_s = getattr(d, 'edge_index_s', None)
    e_t = getattr(d, 'edge_index_t', None)
    ea_t = getattr(d, 'edge_attr_t', None)

    res['E_s'] = 0 if (e_s is None) else int(e_s.size(1))
    res['E_t'] = 0 if (e_t is None) else int(e_t.size(1))

    # Spatial edges must connect two nodes of the same frame.
    ok_s = True
    if e_s is not None and e_s.numel() > 0:
        ok_s = torch.all(t[e_s[0]] == t[e_s[1]]).item()
    res['spatial_same_frame'] = bool(ok_s)

    ok_t_frame = True
    ok_t_jump  = True
    median_jump = None
    kept_frac = None
    if e_t is not None and e_t.numel() > 0:
        # Temporal edges must go from frame f to frame f + 1 ...
        ok_t_frame = torch.all(t[e_t[1]] == t[e_t[0]] + 1).item()
        # ... and the target index must be exactly tip_num above the source.
        jumps = (e_t[1] - e_t[0]).detach().cpu()
        median_jump = int(torch.median(jumps).item())
        ok_t_jump  = bool(torch.all(jumps == tip_num).item())
        # Fraction of the (K-1)*tip_num possible next-frame links present.
        kept_frac = float(res['E_t'] / max((K-1)*tip_num, 1))
    res['temporal_next_frame'] = bool(ok_t_frame)
    res['temporal_v_minus_u_eq_tip'] = bool(ok_t_jump)
    res['temporal_jump_median'] = median_jump
    res['temporal_kept_fraction'] = kept_frac

    if ea_t is not None:
        # Report the full shape so a 1-D attribute vector does not raise.
        res['edge_attr_t_shape'] = tuple(int(s) for s in ea_t.shape)
        # One attribute row is expected per temporal edge.
        res['edge_attr_t_align'] = (ea_t.size(0) == res['E_t'])
    else:
        res['edge_attr_t_shape'] = None
        res['edge_attr_t_align'] = None

    return res


## 检查前 3 个样本

In [3]:

# Print the inspection report for at most the first three samples.
N_SHOW = min(3, len(ds))
for sample_idx in range(N_SHOW):
    print(f"--- Sample {sample_idx} ---")
    report = inspect_sample(ds[sample_idx], tip_num)
    # Keys are right-aligned in a 28-character column so the values line up.
    for key in report:
        print(f"{key:>28}: {report[key]}")


--- Sample 0 ---
                     N_nodes: 1655
                           F: 3
                       t_min: 15
                       t_max: 19
                    K_frames: 5
               counts_first5: [331, 331, 331, 331, 331]
all_frames_have_tipnum_nodes: True
                         E_s: 9300
                         E_t: 1192
          spatial_same_frame: True
         temporal_next_frame: True
   temporal_v_minus_u_eq_tip: True
        temporal_jump_median: 331
      temporal_kept_fraction: 0.9003021148036254
           edge_attr_t_shape: (1192, 1)
           edge_attr_t_align: True
--- Sample 1 ---
                     N_nodes: 1655
                           F: 3
                       t_min: 15
                       t_max: 19
                    K_frames: 5
               counts_first5: [331, 331, 331, 331, 331]
all_frames_have_tipnum_nodes: True
                         E_s: 9300
                         E_t: 1060
          spatial_same_frame: True
         tempora

## 数据集级别统计

In [4]:

# Inspect every sample once, then aggregate the per-sample reports.
reports = [inspect_sample(sample, tip_num) for sample in ds]

tot_Es = sum(rep['E_s'] for rep in reports)
tot_Et = sum(rep['E_t'] for rep in reports)
tot_nodes = sum(rep['N_nodes'] for rep in reports)

# Samples without temporal edges report None and are left out of the average.
kept_fracs = [
    rep['temporal_kept_fraction']
    for rep in reports
    if rep['temporal_kept_fraction'] is not None
]

bad_spatial = sum(1 for rep in reports if not rep['spatial_same_frame'])
# A sample counts as bad only if it has temporal edges and fails either check.
bad_temporal = sum(
    1
    for rep in reports
    if rep['E_t'] > 0
    and not (rep['temporal_next_frame'] and rep['temporal_v_minus_u_eq_tip'])
)

# Guard the denominators so an empty dataset prints zeros instead of raising.
avg_Es = tot_Es / max(len(ds),1)
avg_Et = tot_Et / max(len(ds),1)
avg_nodes = tot_nodes / max(len(ds),1)
if kept_fracs:
    avg_kept = sum(kept_fracs)/len(kept_fracs)
else:
    avg_kept = None

print(f"Avg nodes per sample: {avg_nodes:.1f}")
print(f"Avg spatial edges Es: {avg_Es:.1f}")
print(f"Avg temporal edges Et: {avg_Et:.1f}")
if avg_kept is not None:
    print(f"Avg temporal kept fraction (Et/((K-1)*N)): {avg_kept:.3f}")
else:
    print("No temporal edges found.")
print(f"Samples with bad spatial edges (not same-frame): {bad_spatial}")
print(f"Samples with bad temporal edges (not next-frame or jump!=tip_num): {bad_temporal}")


Avg nodes per sample: 1655.0
Avg spatial edges Es: 9300.1
Avg temporal edges Et: 1121.7
Avg temporal kept fraction (Et/((K-1)*N)): 0.847
Samples with bad spatial edges (not same-frame): 0
Samples with bad temporal edges (not next-frame or jump!=tip_num): 0
