In [2]:
import os
from utils.fns import load_graph
import torch
from torch_geometric.data import HeteroData

def _print_store_fields(store):
    """Print one summary line per field of a single node/edge store.

    Tensor fields are reported with their dtype and shape; any other value
    is reported with its Python type.

    Args:
        store: Mapping-like object exposing ``items()`` that yields
            ``(key, value)`` pairs, e.g. ``data[node_type]`` or
            ``data[edge_type]``.
    """
    for key, value in store.items():
        if isinstance(value, torch.Tensor):
            # torch.dtype does not accept the "<12s" format spec, so it is
            # converted to a string before alignment.
            dtype_str = str(value.dtype)
            print(f"  • {key:<20s} dtype={dtype_str:<12s} shape={tuple(value.shape)}")
        else:
            print(f"  • {key:<20s} (非 Tensor 类型: {type(value)})")


def inspect_heterodata(path: str) -> None:
    """Load one graph file and print the dtype and shape of every field.

    The file is read with ``load_graph`` and is expected to yield a
    ``HeteroData`` object (the original author notes that the ``.dgl``
    files actually store ``HeteroData``). For each node type and each edge
    type, every stored field is listed.

    Args:
        path: Path of the graph file to inspect.

    Returns:
        None. All results are written to stdout. If the loaded object is an
        empty list or is not a ``HeteroData`` instance, a short message is
        printed and the function returns early.
    """
    data = load_graph(path)
    # load_graph may return a list (a batch); inspect only its first element.
    if isinstance(data, list):
        if not data:
            # Previously an empty list raised IndexError on data[0].
            print(f"{os.path.basename(path)} 是空列表，没有可检查的图")
            return
        data = data[0]

    if not isinstance(data, HeteroData):
        print(f"{os.path.basename(path)} 不是 HeteroData，而是 {type(data)}")
        return

    print(f"\n=== Inspect {os.path.basename(path)} ===")

    # All node types.
    print("Node types:", data.node_types)
    for ntype in data.node_types:
        print(f"\n[NodeType: {ntype}]")
        _print_store_fields(data[ntype])

    # All edge types.
    print("\nEdge types:", data.edge_types)
    for etype in data.edge_types:
        print(f"\n[EdgeType: {etype}]")
        _print_store_fields(data[etype])


if __name__ == "__main__":
    # Directory containing the graph files; change this to your own path.
    graph_dir = "datas/cas15/graph_train"
    # Sorted so that the "first" file is deterministic across runs.
    files = sorted(f for f in os.listdir(graph_dir) if f.endswith(".dgl"))
    print(f"Found {len(files)} .dgl files in {graph_dir}")

    # Only the first file is inspected; loop over `files` to inspect all.
    if files:
        inspect_heterodata(os.path.join(graph_dir, files[0]))
    else:
        # Previously files[0] raised IndexError when nothing matched.
        print(f"No .dgl files to inspect in {graph_dir}")


Found 2000 .dgl files in datas/cas15/graph_train

=== Inspect 1A0D_A_336_346.dgl ===
Node types: ['protein', 'protein_atom', 'loop']

[NodeType: protein]
  • node_s               dtype=torch.float32 shape=(118, 95)
  • node_v               dtype=torch.float32 shape=(118, 3, 3)
  • xyz                  dtype=torch.float32 shape=(118, 3)
  • xyz_full             dtype=torch.float32 shape=(118, 24, 3)
  • seq                  dtype=torch.int32  shape=(118,)

[NodeType: protein_atom]
  • node_s               dtype=torch.float32 shape=(1058, 89)

[NodeType: loop]
  • node_s               dtype=torch.float32 shape=(44, 1)
  • xyz                  dtype=torch.float32 shape=(44, 3)
  • seq_parent_forward   dtype=torch.int64  shape=(44,)
  • seq_order_forward    dtype=torch.int64  shape=(44,)
  • seq_parent_reverse   dtype=torch.int64  shape=(44,)
  • seq_order_reverse    dtype=torch.int64  shape=(44,)

Edge types: [('protein', 'p2p', 'protein'), ('protein_atom', 'pa2pa', 'protein_atom'), ('loo