In [1]:
import os
import glob
import json
from typing import Dict, Any, List, Tuple, Optional
from dataclasses import dataclass
from collections import Counter, defaultdict

import torch
from torch.utils.data import Dataset

# ====== label space（沿用你之前那套 14 类；后续你想换可以改）======
ID2LABEL_14 = [
    "Title", "Author", "Mail", "Affiliation", "Section",
    "First-Line", "Para-Line", "Equation", "Table", "Figure",
    "Caption", "Page-Footer", "Page-Header", "Footnote"
]
LABEL2ID_14 = {k:i for i,k in enumerate(ID2LABEL_14)}

REL2ID = {"connect": 0, "contain": 1, "equality": 2, "meta": 3}
ID2REL = {v:k for k,v in REL2ID.items()}


def map_hrd_class_to_14(c: str) -> int:
    c = (c or "").lower().strip()
    if c == "title":
        return LABEL2ID_14["Title"]
    if c == "author":
        return LABEL2ID_14["Author"]
    if c in ("affili", "affiliation"):
        return LABEL2ID_14["Affiliation"]
    if c in ("header",):
        return LABEL2ID_14["Page-Header"]
    if c in ("footer",):
        return LABEL2ID_14["Page-Footer"]
    if c in ("fnote",):
        return LABEL2ID_14["Footnote"]
    if c.startswith("sec"):
        return LABEL2ID_14["Section"]
    if c in ("fstline",):
        return LABEL2ID_14["First-Line"]
    if c in ("para", "opara"):
        return LABEL2ID_14["Para-Line"]
    # 兜底
    return LABEL2ID_14["Para-Line"]



def get_image_path(image_root, doc_id, page_id):
    """
    Try multiple naming conventions:
    1) HRDH-style: <image_root>/<doc_id>/<page_id>.png
    2) HRDH-style with other ext: <image_root>/<doc_id>/<page_id>.(jpg/jpeg/png)
    3) HRDS-style: <image_root>/<doc_id>/<doc_id>_<page_id>.(jpg/jpeg/png)
    4) HRDS-style flat (just in image_root): <image_root>/<doc_id>_<page_id>.(jpg/jpeg/png)

    Returns existing path; otherwise raises FileNotFoundError with helpful info.
    """
    exts = ("png", "jpg", "jpeg", "webp")

    # --- A) in subfolder <doc_id>/ ---
    doc_dir = os.path.join(image_root, doc_id)

    # 1) <page_id>.<ext>
    for ext in exts:
        p = os.path.join(doc_dir, f"{page_id}.{ext}")
        if os.path.exists(p):
            return p

    # 2) <doc_id>_<page_id>.<ext>
    for ext in exts:
        p = os.path.join(doc_dir, f"{doc_id}_{page_id}.{ext}")
        if os.path.exists(p):
            return p

    # --- B) flat in image_root ---
    for ext in exts:
        p = os.path.join(image_root, f"{doc_id}_{page_id}.{ext}")
        if os.path.exists(p):
            return p

    # --- C) glob fallback (covers weird ext/case) ---
    # in doc folder
    pats = [
        os.path.join(doc_dir, f"{page_id}.*"),
        os.path.join(doc_dir, f"{doc_id}_{page_id}.*"),
        os.path.join(image_root, f"{doc_id}_{page_id}.*"),
    ]
    for pat in pats:
        hits = glob.glob(pat)
        if hits:
            # pick the first deterministically (sorted)
            hits = sorted(hits)
            return hits[0]

    raise FileNotFoundError(
        f"Missing page image for doc_id={doc_id}, page_id={page_id}. "
        f"Tried under: {doc_dir} and {image_root}. "
        f"Example expected: {os.path.join(doc_dir, str(page_id)+'.png')} or {os.path.join(doc_dir, f'{doc_id}_{page_id}.jpg')}"
    )

import heapq
from collections import defaultdict

def topo_sort_with_reading_priority(items):
    """
    items: HRDH json list，每个元素至少包含 box/page/parent_id
    返回：满足 parent 一定在 child 前的顺序（同时尽量贴近阅读序）
    """
    n = len(items)

    # reading priority rank
    ranks = []
    for i, it in enumerate(items):
        x0, y0, x1, y1 = it["box"]
        ranks.append((int(it["page"]), float(y0), float(x0), i))

    # parent -> child graph
    indeg = [0] * n
    g = defaultdict(list)
    for child in range(n):
        p = int(items[child].get("parent_id", -1))
        if p < 0:
            continue
        if 0 <= p < n:
            g[p].append(child)
            indeg[child] += 1

    # Kahn with heap prioritized by reading rank
    heap = []
    for i in range(n):
        if indeg[i] == 0:
            heapq.heappush(heap, (ranks[i], i))

    order = []
    while heap:
        _, u = heapq.heappop(heap)
        order.append(u)
        for v in g[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                heapq.heappush(heap, (ranks[v], v))

    # fallback if cycles/noise exist
    if len(order) < n:
        remaining = [i for i in range(n) if i not in set(order)]
        remaining.sort(key=lambda i: ranks[i])
        order.extend(remaining)

    return order


def load_hrdh_json(json_path: str) -> Dict[str, Any]:
    """
    读取单个 HRDH json，返回 doc dict（units + labels + parent + relation），并重映射 parent 到新排序 index。
    当前排序：按 (page, y0, x0) 近似阅读顺序（足够用于训练闭环；后续可替换为双栏阅读序）。
    """
    with open(json_path, "r", encoding="utf-8") as f:
        items = json.load(f)

    def sort_key(idx_item):
        idx, it = idx_item
        x0, y0, x1, y1 = it["box"]
        return (int(it["page"]), int(y0), int(x0), idx)

    order_old = topo_sort_with_reading_priority(items)
    indexed_sorted = [(i, items[i]) for i in order_old]


    old2new = {old_i: new_i for new_i, (old_i, _) in enumerate(indexed_sorted)}

    units = []
    y_parent, y_rel, y_cls, is_meta = [], [], [], []

    for new_i, (old_i, it) in enumerate(indexed_sorted):
        text = (it.get("text") or "").strip()
        x0, y0, x1, y1 = it["box"]
        page_id = int(it["page"])
        cls_raw = it.get("class", "para")
        rel_raw = it.get("relation", "connect")
        meta_flag = bool(it.get("is_meta", False))
        parent_old = int(it.get("parent_id", -1))
        parent_new = -1 if parent_old == -1 else old2new.get(parent_old, -1)

        units.append({
            "text": text,
            "bbox": (float(x0), float(y0), float(x1), float(y1)),  # pixel bbox
            "page_id": page_id,
            "order_id": new_i,
            "class_raw": cls_raw,
        })
        y_parent.append(parent_new)
        y_rel.append(REL2ID.get(rel_raw, 0))
        y_cls.append(map_hrd_class_to_14(cls_raw))
        is_meta.append(meta_flag)

    doc_id = os.path.basename(json_path).replace(".json", "")
    return {
        "doc_id": doc_id,
        "json_path": json_path,
        "units": units,
        "y_parent": y_parent,
        "y_rel": y_rel,
        "y_cls": y_cls,
        "is_meta": is_meta,
    }


class HRDHDataset(Dataset):
    def __init__(self, root_dir: str, split: str = "train", max_len: int = 512):
        """
        root_dir: .../HRDH
        split: train 或 test
        max_len: 截断长度（论文常用 512）
        """
        assert split in ("train", "test")
        self.root_dir = root_dir
        self.split = split
        self.max_len = max_len

        self.json_dir = os.path.join(root_dir, split)
        self.image_root = os.path.join(root_dir, "images")  

        self.json_paths = sorted(glob.glob(os.path.join(self.json_dir, "*.json")))
        if not self.json_paths:
            raise FileNotFoundError(f"No json files found in: {self.json_dir}")

    def __len__(self):
        return len(self.json_paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        json_path = self.json_paths[idx]
        doc = load_hrdh_json(json_path)

        # 截断（保持 parent 合法）
        if len(doc["units"]) > self.max_len:
            keep = self.max_len
            doc["units"] = doc["units"][:keep]
            doc["y_cls"] = doc["y_cls"][:keep]
            doc["y_rel"] = doc["y_rel"][:keep]
            doc["is_meta"] = doc["is_meta"][:keep]
            # parent: 超出范围的 parent 置为 0；指向被截断的也置 0（或 -1）
            y_parent = []
            for i, p in enumerate(doc["y_parent"][:keep]):
                if p == -1:
                    y_parent.append(-1)
                elif 0 <= p < keep:
                    y_parent.append(p)
                else:
                    y_parent.append(-1)
            doc["y_parent"] = y_parent

        # 为每个 unit 找到对应页图路径（按 page_id）
        # 注意：同一页只加载一次图片，训练时可以再做缓存/预处理
        page_ids = sorted(set(u["page_id"] for u in doc["units"]))
        page_images = {}
        for pid in page_ids:
            page_images[pid] = get_image_path(self.image_root, doc["doc_id"], pid)
        doc["page_images"] = page_images
        for i, p in enumerate(doc["y_parent"]):
            if p >= 0:
                assert p < i, f"parent not causal: i={i}, p={p}"
            if "is_meta" in doc:
                for i, flag in enumerate(doc["is_meta"]):
                    if flag:
                        doc["y_parent"][i] = -1  # ROOT
        
        return doc


In [2]:
REL3 = {"connect": 0, "contain": 1, "equality": 2}  # 论文三类
REL3_INV = {v:k for k,v in REL3.items()}

def filter_meta_and_remap(units_raw):
    """
    units_raw: list of dict from json, each has:
      text, box, class, page, is_meta, parent_id, relation
    return:
      units_kept: list of dict with keys {text, box, page_id, cls_name, parent, rel}
      old2new: dict old_index -> new_index  (only for kept)
    """
    keep_idx = [i for i,u in enumerate(units_raw) if not u.get("is_meta", False)]
    old2new = {old:new for new,old in enumerate(keep_idx)}

    units_kept = []
    for old_i in keep_idx:
        u = units_raw[old_i]
        p_old = int(u.get("parent_id", -1))
        # 如果 parent 被过滤掉，或本来就是 -1，则设为 ROOT(-1)
        p_new = old2new.get(p_old, -1) if p_old >= 0 else -1

        rel = u.get("relation", None)
        # 过滤后不应该再出现 meta；如果仍然出现，直接跳过该样本或置默认
        if rel == "meta":
            # 这里选择：直接将该单元丢弃（更干净）
            continue

        if rel not in REL3:
            raise ValueError(f"Unknown relation: {rel}")

        units_kept.append({
            "text": u.get("text",""),
            "box": u.get("box"),
            "page_id": int(u.get("page", 0)),
            "cls_name": u.get("class"),
            "parent": p_new,
            "rel": REL3[rel],
        })

    return units_kept, old2new


## 1. 环境与设备检查

- 检查 PyTorch / CUDA 可用性
- 设置 device


In [3]:
import torch

print("torch version:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda device:", torch.cuda.get_device_name(0))
    print("cuda capability:", torch.cuda.get_device_capability(0))


torch version: 2.5.1
cuda available: True
cuda device: NVIDIA GeForce RTX 4060 Laptop GPU
cuda capability: (8, 9)


## 2. 数据集路径与快速 sanity-check

- 指定 HRDH_ROOT
- 初始化数据集与样本查看

说明：原 Notebook 中此部分出现了两段几乎等价的 quick-check（为保持对齐，这里全部保留）。


In [4]:
HRDH_ROOT = r"C:\Users\tomra\Desktop\PAPER\final OBJ\HRDH"
ds = HRDHDataset(HRDH_ROOT, split="train", max_len=512)

doc = ds[0]
print("doc_id:", doc["doc_id"])
print("num_units:", len(doc["units"]))
print("page_images sample:", list(doc["page_images"].items())[:3])


doc_id: 1401.6399
num_units: 512
page_images sample: [(0, 'C:\\Users\\tomra\\Desktop\\PAPER\\final OBJ\\HRDH\\images\\1401.6399\\0.png'), (1, 'C:\\Users\\tomra\\Desktop\\PAPER\\final OBJ\\HRDH\\images\\1401.6399\\1.png'), (2, 'C:\\Users\\tomra\\Desktop\\PAPER\\final OBJ\\HRDH\\images\\1401.6399\\2.png')]


In [5]:
HRDH_ROOT = r"C:\Users\tomra\Desktop\PAPER\final OBJ\HRDH"

ds = HRDHDataset(HRDH_ROOT, split="train", max_len=512)
sample = ds[0]

print("doc_id:", sample["doc_id"])
print("num_units:", len(sample["units"]))
print("pages:", sorted(sample["page_images"].keys()))
print("missing_images:", [p for p,v in sample["page_images"].items() if v is None][:10])

# parent 合法性检查
L = len(sample["units"])
bad = [(i,p) for i,p in enumerate(sample["y_parent"]) if not (p==-1 or 0<=p<L)]
print("bad_parent_count:", len(bad), "example:", bad[:5])

# 简单预览
for i in range(min(10, L)):
    u = sample["units"][i]
    print(i, "p", u["page_id"], ID2LABEL_14[sample["y_cls"][i]], "par", sample["y_parent"][i], "|", u["text"][:80])


doc_id: 1401.6399
num_units: 512
pages: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
missing_images: []
bad_parent_count: 0 example: []
0 p 0 Title par -1 | SIMD Compression and the Intersection
1 p 0 Title par -1 | of Sorted Integers
2 p 0 Author par -1 | D. Lemire1 *, L. Boytsov2, N. Kurz3
3 p 0 Affiliation par -1 | 1LICEF Research Center, TELUQ, Montreal, QC, Canada
4 p 0 Affiliation par -1 | 2 Carnegie Mellon University, Pittsburgh, PA USA
5 p 0 Affiliation par -1 | 3 Verse Communications, Orinda, CA USA
6 p 0 First-Line par -1 | KEY WORDS: performance; measurement; index compression; vector processing
7 p 0 Section par -1 | 1. INTRODUCTION
8 p 0 First-Line par 7 | An inverted index maps terms to lists of document identifiers. A column index in
9 p 0 Para-Line par 8 | might, similarly, map attribute values to row identifiers. Storing all these lis


## 3. 先验统计：从数据集中估计 $M_{cp}$（Class Prior / Conditional Prior）

说明：原 Notebook 内存在两个版本的 `compute_M_cp_from_dataset`（一个较简版、一个带 `defaultdict`）。为保证“一个逻辑都不许少”，两段均保留，按原顺序排列。


In [6]:
import numpy as np

def compute_M_cp_from_dataset(dataset: HRDHDataset, num_classes: int, pseudo_count: float = 5.0):
    """
    返回 M_cp: shape (num_classes+1, num_classes)
    行：parent_class + ROOT
    列：child_class
    """
    ROOT = num_classes  # extra row index
    counts = np.zeros((num_classes + 1, num_classes), dtype=np.float64)

    for i in range(len(dataset)):
        doc = dataset[i]
        y_cls = doc["y_cls"]
        y_parent = doc["y_parent"]
        L = len(y_cls)

        for child in range(L):
            c = y_cls[child]
            p = y_parent[child]
            if p == -1:
                pc = ROOT
            else:
                pc = y_cls[p]
            counts[pc, c] += 1.0

    # additive smoothing per column
    counts += pseudo_count
    # normalize columns -> probability
    col_sum = counts.sum(axis=0, keepdims=True)
    M_cp = counts / np.clip(col_sum, 1e-12, None)
    return M_cp

train_ds = HRDHDataset(HRDH_ROOT, split="train", max_len=512)
M_cp = compute_M_cp_from_dataset(train_ds, num_classes=len(ID2LABEL_14), pseudo_count=5.0)
print("M_cp shape:", M_cp.shape, "colsum (should be 1):", M_cp.sum(axis=0)[:5])

np.save("M_cp_hrdh.npy", M_cp)
print("saved: M_cp_hrdh.npy")


M_cp shape: (15, 14) colsum (should be 1): [1. 1. 1. 1. 1.]
saved: M_cp_hrdh.npy


In [7]:
import numpy as np
from collections import defaultdict

def compute_M_cp_from_dataset(dataset, num_classes: int, pseudo_count: float = 5.0):
    """
    M_cp shape: (num_classes+1, num_classes)
      rows: parent_class + ROOT(row = num_classes)
      cols: child_class
    Column-normalized: sum over rows for each col = 1
    """
    ROOT = num_classes
    counts = np.zeros((num_classes + 1, num_classes), dtype=np.float64)

    for di in range(len(dataset)):
        doc = dataset[di]
        y_cls = doc["y_cls"]
        y_parent = doc["y_parent"]
        L = len(y_cls)

        for child in range(L):
            c = int(y_cls[child])
            p = int(y_parent[child])
            if p == -1:
                pc = ROOT
            else:
                pc = int(y_cls[p])
            counts[pc, c] += 1.0

    # Additive smoothing
    counts += float(pseudo_count)

    # Column normalize
    col_sum = counts.sum(axis=0, keepdims=True)
    M_cp = counts / np.clip(col_sum, 1e-12, None)
    return M_cp, counts

def top_parent_for_each_child(M_cp: np.ndarray, id2label: list, topk: int = 6):
    """
    For each child class, list top-k parent classes (including ROOT).
    """
    num_classes = len(id2label)
    ROOT = num_classes

    rows, cols = M_cp.shape
    assert rows == num_classes + 1 and cols == num_classes

    for child in range(num_classes):
        probs = M_cp[:, child]
        idxs = np.argsort(-probs)[:topk]
        child_name = id2label[child]
        print(f"\n[Child] {child} {child_name}")
        for r in idxs:
            parent_name = "ROOT" if r == ROOT else id2label[r]
            print(f"  parent={parent_name:<12s}  P={probs[r]:.4f}")

def top_edges_global(counts: np.ndarray, id2label: list, topn: int = 30):
    """
    Print global most frequent (parent, child) pairs from raw counts (before normalization).
    """
    num_classes = len(id2label)
    ROOT = num_classes

    flat = []
    for pc in range(num_classes + 1):
        for cc in range(num_classes):
            flat.append((counts[pc, cc], pc, cc))
    flat.sort(reverse=True, key=lambda x: x[0])

    print(f"\nTop {topn} (parent, child) by COUNT (after smoothing included):")
    for k in range(topn):
        cnt, pc, cc = flat[k]
        parent_name = "ROOT" if pc == ROOT else id2label[pc]
        child_name = id2label[cc]
        print(f"{k:02d}  {parent_name:<12s} -> {child_name:<12s}  count={cnt:.1f}")

# ==== run ====
HRDH_ROOT = r"C:\Users\tomra\Desktop\PAPER\final OBJ\HRDH"
train_ds = HRDHDataset(HRDH_ROOT, split="train", max_len=512)

M_cp, counts = compute_M_cp_from_dataset(train_ds, num_classes=len(ID2LABEL_14), pseudo_count=5.0)

print("M_cp shape:", M_cp.shape)
print("col sums (first 8):", np.round(M_cp.sum(axis=0)[:8], 6))

# sanity: should be ~1.0
assert np.allclose(M_cp.sum(axis=0), 1.0, atol=1e-6), "Column sums not 1.0"

top_edges_global(counts, ID2LABEL_14, topn=25)
top_parent_for_each_child(M_cp, ID2LABEL_14, topk=6)

np.save("M_cp_hrdh.npy", M_cp)
print("\nSaved: M_cp_hrdh.npy")


M_cp shape: (15, 14)
col sums (first 8): [1. 1. 1. 1. 1. 1. 1. 1.]

Top 25 (parent, child) by COUNT (after smoothing included):
00  Para-Line    -> Para-Line     count=288776.0
01  First-Line   -> Para-Line     count=57570.0
02  First-Line   -> First-Line    count=54224.0
03  ROOT         -> Para-Line     count=14250.0
04  Section      -> First-Line    count=10697.0
05  Section      -> Section       count=10588.0
06  ROOT         -> Page-Header   count=6025.0
07  ROOT         -> Affiliation   count=3185.0
08  ROOT         -> Footnote      count=2470.0
09  ROOT         -> Title         count=1762.0
10  ROOT         -> Author        count=1641.0
11  ROOT         -> Section       count=1005.0
12  ROOT         -> First-Line    count=566.0
13  Section      -> Para-Line     count=455.0
14  Title        -> Title         count=5.0
15  Title        -> Author        count=5.0
16  Title        -> Mail          count=5.0
17  Title        -> Affiliation   count=5.0
18  Title        -> Section      

## 4. 训练复现性与全局配置

- 随机种子
- CFG 超参集中管理
- bbox 归一化等通用函数


In [8]:
import random
import math
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from PIL import Image
import torchvision
from torchvision import transforms


In [9]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [10]:
@dataclass
class CFG:
    # data
    max_len: int = 512
    batch_size: int = 1          # 强烈建议 doc-level batch=1（结构任务依赖整篇文档序列）
    num_workers: int = 2

    # model dims
    d_model: int = 256
    nhead: int = 8
    num_layers: int = 4
    dropout: float = 0.1

    # embedding vocab sizes
    max_1d_pos: int = 512
    max_pages: int = 32          # 超过则截断或 clamp
    layout_bins: int = 1001      # bbox 归一化到 [0,1000]
    
    # visual
    vis_crop_size: int = 224
    vis_out_dim: int = 256       # 视觉向量投影到 d_model
    
    # losses
    alpha_parent: float = 1.0
    alpha_rel: float = 1.0
    focal_gamma: float = 2.0
    focal_alpha: float = 0.25

    # optim
    lr: float = 2e-4
    weight_decay: float = 1e-2
    epochs: int = 10
    
cfg = CFG()

In [11]:
def normalize_box_xyxy(box, page_w, page_h, bins=1000):
    """
    输入：像素坐标 [x0,y0,x1,y1]
    输出：整数归一化到 [0,bins]
    """
    x0, y0, x1, y1 = box
    x0 = int(np.clip(round(x0 / max(page_w, 1) * bins), 0, bins))
    x1 = int(np.clip(round(x1 / max(page_w, 1) * bins), 0, bins))
    y0 = int(np.clip(round(y0 / max(page_h, 1) * bins), 0, bins))
    y1 = int(np.clip(round(y1 / max(page_h, 1) * bins), 0, bins))
    return [x0, y0, x1, y1]


def collate_doc(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    assert len(batch) == 1, "当前实现是 doc-level batch=1"
    return batch[0]

## 5. 特征编码模块

- 文本：SBERT/Transformer 句向量 + 投影
- 布局：位置/页面 embedding
- 视觉：图像 crop embedding


In [12]:
class SBERTTextEmbedder(nn.Module):
    """
    用 sentence-transformers 得到每个 unit 的句向量，然后投影到 d_model。
    - 支持简单的 dict 缓存：key=(doc_id, unit_idx)
    """
    def __init__(self, d_model: int, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        super().__init__()
        self.d_model = d_model
        self.model_name = model_name
        
        self._sbert = None
        self.proj = nn.Linear(384, d_model)  # all-MiniLM-L6-v2 输出 384
        self.cache: Dict[Tuple[str, int], torch.Tensor] = {}

    def _lazy_load(self):
        if self._sbert is None:
            from sentence_transformers import SentenceTransformer
            self._sbert = SentenceTransformer(self.model_name)

    @torch.no_grad()
    def encode_texts(self, texts: List[str]) -> torch.Tensor:
        self._lazy_load()
        emb = self._sbert.encode(texts, convert_to_tensor=True, show_progress_bar=False)  # (L,384)
        return emb

    def forward(self, doc_id: str, texts: List[str]) -> torch.Tensor:
        """
        返回 (L, d_model)，保证在 self.proj.weight.device 上
        cache 存 CPU，取用时搬到目标 device
        """
        device = self.proj.weight.device
        L = len(texts)
        out = [None] * L
        missing_idx, missing_text = [], []

        # 1) cache hit：取 CPU -> to(device)
        for i, t in enumerate(texts):
            key = (doc_id, i)
            if key in self.cache:
                out[i] = self.cache[key].to(device)
            else:
                missing_idx.append(i)
                missing_text.append(t)

        # 2) cache miss：SBERT encode -> detach/clone -> proj -> 存 CPU
        if len(missing_idx) > 0:
            emb_384 = self.encode_texts(missing_text)          # 可能是 inference tensor
            emb_384 = emb_384.detach().clone().to(device)      # 关键：变普通 tensor + 上 GPU
            emb_d = self.proj(emb_384)                         # (M,d_model) on GPU

            for k, i in enumerate(missing_idx):
                key = (doc_id, i)
                self.cache[key] = emb_d[k].detach().cpu()      # cache 存 CPU，省显存
                out[i] = emb_d[k]                              # 当前 batch 直接用 GPU tensor

        return torch.stack(out, dim=0)
    

In [13]:
class LayoutPosPageEmbedder(nn.Module):
    def __init__(self, d_model: int, layout_bins: int = 1001, max_pos: int = 512, max_pages: int = 32):
        super().__init__()
        self.d_model = d_model
        self.layout_bins = layout_bins

        # LayoutLMv2 风格：x0,x1,w 与 y0,y1,h，各自 embedding 再 concat -> proj
        self.emb_x0 = nn.Embedding(layout_bins, d_model//4)
        self.emb_x1 = nn.Embedding(layout_bins, d_model//4)
        self.emb_w  = nn.Embedding(layout_bins, d_model//4)
        self.emb_y0 = nn.Embedding(layout_bins, d_model//4)
        self.emb_y1 = nn.Embedding(layout_bins, d_model//4)
        self.emb_h  = nn.Embedding(layout_bins, d_model//4)

        self.proj_layout = nn.Linear((d_model//4)*6, d_model)

        self.emb_pos = nn.Embedding(max_pos, d_model)
        self.emb_page = nn.Embedding(max_pages, d_model)

        self.ln = nn.LayerNorm(d_model)

    def forward(self, units: List[Dict[str, Any]], page_images: Dict[int, str]) -> torch.Tensor:
        """
        return: (L, d_model)
        """
        L = len(units)
        # 逐页拿宽高
        page_wh = {}
        for pid, pth in page_images.items():
            img = Image.open(pth)
            page_wh[int(pid)] = (img.width, img.height)

        layout_vecs = []
        pos_ids = []
        page_ids = []

        for i,u in enumerate(units):
            pid = int(u["page_id"])
            w, h = page_wh[pid]
            b = u.get("box", None)
            if b is None:
                b = u.get("bbox", None)
            if b is None:
                raise KeyError("Unit missing both 'box' and 'bbox'")
            x0, y0, x1, y1 = b

            nb = normalize_box_xyxy([x0,y0,x1,y1], w, h, bins=self.layout_bins-1)
            nx0, ny0, nx1, ny1 = nb
            nw = int(np.clip(nx1 - nx0, 0, self.layout_bins-1))
            nh = int(np.clip(ny1 - ny0, 0, self.layout_bins-1))

            layout_vecs.append([nx0, nx1, nw, ny0, ny1, nh])
            pos_ids.append(i)
            page_ids.append(min(pid, self.emb_page.num_embeddings-1))

        layout = torch.tensor(layout_vecs, dtype=torch.long, device=self.emb_pos.weight.device)  # (L,6)
        pos = torch.tensor(pos_ids, dtype=torch.long, device=self.emb_pos.weight.device)        # (L,)
        pages = torch.tensor(page_ids, dtype=torch.long, device=self.emb_pos.weight.device)     # (L,)

        x0 = self.emb_x0(layout[:,0])
        x1 = self.emb_x1(layout[:,1])
        ww = self.emb_w(layout[:,2])
        y0 = self.emb_y0(layout[:,3])
        y1 = self.emb_y1(layout[:,4])
        hh = self.emb_h(layout[:,5])
        layout_cat = torch.cat([x0,x1,ww,y0,y1,hh], dim=-1)
        layout_emb = self.proj_layout(layout_cat)

        pos_emb = self.emb_pos(pos)
        page_emb = self.emb_page(pages)

        out = self.ln(layout_emb + pos_emb + page_emb)
        return out

In [14]:
import torch
import torch.nn as nn
from PIL import Image
import torchvision
from torchvision import transforms
from torchvision.ops import roi_align
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone


class VisualFPNRoIEmbedder(nn.Module):
    """
    Paper-style visual embedding:
    page image -> ResNet50+FPN -> RoIAlign by bbox -> pooled -> Linear -> (L, d_model)

    Interface kept identical to your existing VisualCropEmbedder:
      forward(units, page_images) -> Tensor[L, d_model]
    where:
      - units: list of dict, each must have:
          u["page_id"] : int
          u["box"]     : [x0, y0, x1, y1] in *pixel coords of the original page image*
      - page_images: dict or list-like indexed by page_id -> image path
    """
    def __init__(
        self,
        d_model: int,
        roi_out_size: int = 7,     # RoIAlign output spatial size (7x7 typical)
        roi_sampling_ratio: int = 2,
        fpn_level: str = "0",      # torchvision FPN returns keys like "0","1","2","3","pool"
    ):
        super().__init__()
        self.d_model = d_model
        self.roi_out_size = roi_out_size
        self.roi_sampling_ratio = roi_sampling_ratio
        self.fpn_level = fpn_level

        # ResNet50 + FPN backbone (detection-style)
        # returns dict of feature maps: {"0":P2,"1":P3,"2":P4,"3":P5,"pool":P6}
        self.backbone = resnet_fpn_backbone(
            backbone_name="resnet50",
            weights=torchvision.models.ResNet50_Weights.DEFAULT,
            trainable_layers=3,   # 可按需调；最小化侵入先这么设
        )

        # For resnet_fpn_backbone, each pyramid level channel is 256
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Linear(256, d_model)

        # ImageNet normalization
        w = torchvision.models.ResNet50_Weights.DEFAULT
        mean, std = w.transforms().mean, w.transforms().std
        self.tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def _load_page(self, page_path: str):
        img = Image.open(page_path).convert("RGB")
        return img

    def _to_tensor(self, pil_img: Image.Image, device: torch.device):
        x = self.tf(pil_img).unsqueeze(0).to(device)  # (1,3,H,W)
        return x

    @staticmethod
    def _boxes_to_roi_format(boxes_xyxy, batch_idx: int = 0, device=None):
        """
        roi_align expects boxes as Tensor[K,5] = (batch_idx, x1, y1, x2, y2)
        in input image pixel coordinates when spatial_scale is set properly.
        """
        b = torch.tensor(boxes_xyxy, dtype=torch.float32, device=device)
        if b.numel() == 0:
            return b.new_zeros((0, 5))
        idx = torch.full((b.shape[0], 1), float(batch_idx), device=device)
        return torch.cat([idx, b], dim=1)

    def forward(self, units, page_images):
        device = self.proj.weight.device

        # group unit indices by page_id (so each page runs backbone once)
        page_to_indices = {}
        for i, u in enumerate(units):
            pid = int(u["page_id"])
            page_to_indices.setdefault(pid, []).append(i)

        out = torch.zeros((len(units), self.d_model), device=device)

        for pid, idxs in page_to_indices.items():
            page_path = page_images[pid]
            pil_img = self._load_page(page_path)
            W, H = pil_img.size

            x = self._to_tensor(pil_img, device=device)  # (1,3,H,W)

            feats = self.backbone(x)  # dict of feature maps
            if self.fpn_level not in feats:
                # fallback: use the highest-resolution level if key missing
                # typical keys: "0","1","2","3","pool"
                key = sorted([k for k in feats.keys() if k != "pool"])[0]
            else:
                key = self.fpn_level

            fmap = feats[key]  # (1,256,hf,wf)
            hf, wf = fmap.shape[-2], fmap.shape[-1]

            # spatial_scale maps original image coords -> feature map coords
            # roi_align uses a single scalar; assume isotropic scaling (works because fmap came from that image)
            spatial_scale = wf / float(W)

            boxes_xyxy = [units[i]["box"] for i in idxs]  # list of [x0,y0,x1,y1] in image pixels
            rois = self._boxes_to_roi_format(boxes_xyxy, batch_idx=0, device=device)  # (K,5)

            roi_feat = roi_align(
                input=fmap,
                boxes=rois,
                output_size=(self.roi_out_size, self.roi_out_size),
                spatial_scale=spatial_scale,
                sampling_ratio=self.roi_sampling_ratio,
                aligned=True,
            )  # (K,256,roi,roi)

            roi_feat = self.pool(roi_feat).flatten(1)  # (K,256)
            emb = self.proj(roi_feat)                  # (K,d_model)

            out[idxs] = emb

        return out


## 6. DSPS 模型主体

- 多模态编码
- 单元分类（node class）
- 关系分类（edge/rel）
- 论文中的 DSPS 思路在此实现


In [15]:
class DSPSModel(nn.Module):
    def __init__(
        self,
        num_classes: int,
        num_rel: int,
        M_cp: np.ndarray,
        cfg: CFG,
        use_text: bool = False,
        use_visual: bool = True,
        use_softmask: bool = True,
    ):
        super().__init__()
        self.use_softmask = use_softmask
        self.num_classes = num_classes
        self.num_rel = num_rel
        self.cfg = cfg
        self.use_text = use_text
        self.use_visual = use_visual

        # embeddings
        self.layout_pos_page = LayoutPosPageEmbedder(
            d_model=cfg.d_model,
            layout_bins=cfg.layout_bins,
            max_pos=cfg.max_1d_pos,
            max_pages=cfg.max_pages,
        )

        if use_text:
            self.text_emb = SBERTTextEmbedder(d_model=cfg.d_model)
        else:
            self.text_emb = None

        if use_visual:
            self.vis_emb = VisualFPNRoIEmbedder(d_model=cfg.d_model)
        else:
            self.vis_emb = None

        self.fuse_ln = nn.LayerNorm(cfg.d_model)

        # encoder (bidirectional)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model,
            nhead=cfg.nhead,
            dim_feedforward=cfg.d_model * 4,
            dropout=cfg.dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.num_layers)

        # subtask1: class
        self.cls_head = nn.Linear(cfg.d_model, num_classes)

        # decoder: structure-aware GRU
        self.gru = nn.GRU(input_size=cfg.d_model, hidden_size=cfg.d_model, batch_first=True)

        # attention projections for parent finding
        self.Wq = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.Wk = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        # relation head (concat)
        self.rel_head = nn.Sequential(
            nn.Linear(cfg.d_model * 2, cfg.d_model),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.d_model, num_rel),
        )

        # store M_cp (torch)
        # shape: (num_classes+1, num_classes)   rows=parent_class + ROOT(row=num_classes), cols=child_class
        M = torch.tensor(M_cp, dtype=torch.float32)
        self.register_buffer("M_cp", M)

    def forward(self, doc: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """
        doc keys from your dataset:
          doc_id, units(list), y_cls(list), y_parent(list), y_rel(list), page_images(dict)
        returns logits:
          cls_logits: (L,C)
          par_logits: list of length L where par_logits[i] is (i,) logits for j in [0..i-1] + ROOT at index 0?
          rel_logits: (L, R) using GT parent during training convenience (可在 loss 里选)
        """
        units = doc["units"]
        L = len(units)
        doc_id = doc["doc_id"]
        page_images = doc["page_images"]

        # ---- embeddings sum: text + visual + layout/pos/page ----
        x = self.layout_pos_page(units, page_images)  # (L,d)

        if self.use_text:
            texts = [u.get("text","") for u in units]
            x = x + self.text_emb(doc_id, texts)

        if self.use_visual:
            x = x + self.vis_emb(units, page_images)

        x = self.fuse_ln(x)            # (L,d)
        x = x.unsqueeze(0)             # (1,L,d) for transformer batch_first

        # ---- encoder ----
        x_star = self.encoder(x)       # (1,L,d)
        x_star = x_star.squeeze(0)     # (L,d)

        # ---- class logits ----
        cls_logits = self.cls_head(x_star)  # (L,C)
        cls_prob = F.softmax(cls_logits, dim=-1)  # (L,C)

        # ---- ROOT representation ----
        root = x_star.mean(dim=0, keepdim=True)  # (1,d)

        # ---- GRU decoder (causal) ----
        # 论文写用 x*_{i-1} 驱动，这里用全序列输入 + 自己取 h_i（等价实现）
        h_seq, _ = self.gru(x_star.unsqueeze(0))  # (1,L,d)
        h_seq = h_seq.squeeze(0)                  # (L,d)

        # ---- parent logits with soft-mask ----
        # 我们把 candidate set 定义为 [ROOT] + [0..i-1]
        # 输出 par_logits[i] 形状 (i+1,) 其中 index0=ROOT, index k>0 对应 parent=j=k-1
        par_logits = []
        eps = 1e-8

        # 预先准备 ROOT 的 class distribution：用 uniform 或者用 mean prob；论文是扩展 P_cls(0) 为 (C+1)，ROOT=1
        # 我这里做：P_cls_root_over_parentclass = one-hot at ROOT row
        # 计算 P_dom 时使用 rows=parentclass+ROOT
        # 具体：P̃_cls(j) = [P_cls(j), 0] for real nodes；ROOT 用 [0..0,1]
        root_one = torch.zeros((1, self.num_classes + 1), device=x_star.device)
        root_one[0, self.num_classes] = 1.0

        # child prob 扩成 (C) 即原本；parent prob 扩成 (C+1)
        # 对每个 i:
        for i in range(L):
            q = self.Wq(h_seq[i:i+1])                # (1,d)
            # keys = [ROOT] + past h
            k_root = self.Wk(root)                   # (1,d)
            if i == 0:
                keys = k_root                        # (1,d)
            else:
                k_past = self.Wk(h_seq[:i])          # (i,d)
                keys = torch.cat([k_root, k_past], dim=0)  # (i+1,d)

            # dot-product scores
            score = (q @ keys.t()).squeeze(0)        # (i+1,)

            # ----- soft-mask prior: P_dom(i, j) -----
            # child distribution: (C)
            p_child = cls_prob[i:i+1]                # (1,C)
            # parent distributions:
            if i == 0:
                p_parent_ext = root_one              # (1,C+1)
            else:
                p_parent = cls_prob[:i]              # (i,C)
                zeros = torch.zeros((i,1), device=x_star.device)
                p_parent_ext = torch.cat([p_parent, zeros], dim=1)  # (i,C+1)
                p_parent_ext = torch.cat([root_one, p_parent_ext], dim=0)  # (i+1,C+1)

            # P_dom = p_parent_ext @ M_cp @ p_child^T
            # M_cp: (C+1,C)
            prior = (p_parent_ext @ self.M_cp @ p_child.t()).squeeze(-1)  # (i+1,)

            if self.use_softmask:
                score = score + torch.log(prior + eps)

            par_logits.append(score)

        # ---- relation logits (用 GT parent 保持训练稳定；推理时再用预测 parent) ----
        # 这里先输出一个 (L,R) 的 logits，其中第 i 个是与 GT parent 的关系
        # 对于 parent=-1 的（ROOT），relation 按你的数据里常见是 "contain"/"meta"，这里仍然算一个 rel loss（你也可 mask 掉）
        y_parent = doc.get("y_parent", [-1]*L)
        rel_logits = []
        for i in range(L):
            p = y_parent[i]
            if p is None or p < 0:
                # ROOT
                parent_vec = root.squeeze(0)
            else:
                parent_vec = h_seq[p]
            feat = torch.cat([h_seq[i], parent_vec], dim=-1)
            rel_logits.append(self.rel_head(feat))
        rel_logits = torch.stack(rel_logits, dim=0)  # (L,R)

        return {
            "cls_logits": cls_logits,
            "par_logits": par_logits,  # list of tensors
            "rel_logits": rel_logits,
        }

## 7. 损失函数与训练目标

- FocalLoss
- compute_losses：对分类/关系等分项计算并聚合


In [16]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        logits: (N,C)
        targets: (N,)
        """
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)
        focal = self.alpha * (1 - pt) ** self.gamma * ce
        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal

In [17]:
def compute_losses(
    out: Dict[str, Any],
    doc: Dict[str, Any],
    num_classes: int,
    num_rel: int,
    cfg: CFG,
    focal_cls: nn.Module,
    focal_rel: nn.Module,
) -> Dict[str, torch.Tensor]:

    # ---- targets ----
    y_cls = torch.tensor(doc["y_cls"], dtype=torch.long, device=device)
    y_parent = doc["y_parent"]
    y_rel = torch.tensor(doc["y_rel"], dtype=torch.long, device=device)

    # ---- logits ----
    cls_logits = out["cls_logits"].to(device)   # (L,C)
    rel_logits = out["rel_logits"].to(device)   # (L,R)
    par_logits_list = out["par_logits"]         # list length L

    L = len(par_logits_list)

    # ========== sanity checks ==========
    # cls range
    cls_min = int(y_cls.min().item())
    cls_max = int(y_cls.max().item())
    if cls_min < 0 or cls_max >= num_classes:
        raise ValueError(f"y_cls out of range: min={cls_min}, max={cls_max}, num_classes={num_classes}")

    # rel range
    rel_min = int(y_rel.min().item())
    rel_max = int(y_rel.max().item())
    if rel_min < 0 or rel_max >= num_rel:
        raise ValueError(f"y_rel out of range: min={rel_min}, max={rel_max}, num_rel={num_rel}")

    # parent range per position
    for i in range(L):
        p = y_parent[i]
        if p is None:
            continue
        if p >= i:  # 注意：parent 必须 < i（只能指向过去）
            raise ValueError(f"y_parent invalid at i={i}: parent={p} but must be < {i}")
        if p < -1:
            raise ValueError(f"y_parent invalid at i={i}: parent={p} (should be -1 or >=0)")

    # ========== losses ==========
    # meta mask：True 表示参与结构监督的单元（非 meta）
    is_meta = torch.tensor(doc.get("is_meta", [False]*L), dtype=torch.bool, device=device)
    struct_mask = ~is_meta  # (L,)

    # Subtask1: class（是否 mask meta 看你后续要不要对齐论文；先不 mask，保证能跑通）
    loss_cls = focal_cls(cls_logits, y_cls)

    # Subtask2: parent（只对非 meta 计算）
    loss_par = 0.0
    denom_par = 0
    for i, logits_i in enumerate(par_logits_list):
        if not bool(struct_mask[i].item()):
            continue

        p = y_parent[i]
        tgt_i = 0 if (p is None or p < 0) else (p + 1)

        if tgt_i < 0 or tgt_i >= logits_i.numel():
            raise ValueError(
                f"parent target out of range at i={i}: tgt={tgt_i}, "
                f"logits_len={logits_i.numel()}, raw_parent={p}"
            )

        tgt = torch.tensor([tgt_i], dtype=torch.long, device=device)
        # logits_i 可能是 shape=(K,) 或 (1,K)。统一压成 (1,K)
        logits_ce = logits_i.to(device).reshape(1, -1)
        loss_par = loss_par + F.cross_entropy(logits_ce, tgt)

        denom_par += 1

    if denom_par == 0:
        loss_par = torch.zeros((), device=device)
    else:
        loss_par = loss_par / denom_par

    # Subtask3: relation（只对非 meta 计算）
    if struct_mask.any():
        loss_rel = focal_rel(rel_logits[struct_mask], y_rel[struct_mask])
    else:
        loss_rel = torch.zeros((), device=device)

    loss = loss_cls + cfg.alpha_parent * loss_par + cfg.alpha_rel * loss_rel
    return {
        "loss": loss,
        "loss_cls": loss_cls.detach(),
        "loss_par": loss_par.detach(),
        "loss_rel": loss_rel.detach(),
    }


## 8. 训练循环

- train_one_epoch：单 epoch 训练
- DataLoader 构建
- 优化器与训练入口


In [18]:
import time

def train_one_epoch(model, loader, optimizer, cfg, focal_cls, focal_rel):
    model.train()
    logs = {"loss": [], "loss_cls": [], "loss_par": [], "loss_rel": []}

    start_time = time.time()
    last_print = start_time

    print(f"\n[Train] start epoch, num_docs = {len(loader)}")

    for step, doc in enumerate(loader):
        t0 = time.time()

        optimizer.zero_grad(set_to_none=True)

        # ===== forward =====
        out = model(doc)

        # ===== loss =====
        losses = compute_losses(out, doc, model.num_classes, model.num_rel,
                                cfg, focal_cls, focal_rel)

        # ===== backward =====
        losses["loss"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # ===== log =====
        for k in logs:
            logs[k].append(float(losses[k].cpu()))

        # ===== progress print =====
        if step == 0 or (step + 1) % 5 == 0:
            now = time.time()
            step_time = now - t0
            elapsed = now - start_time

            print(
                f"  step {step+1:4d}/{len(loader)} | "
                f"step_time={step_time:5.2f}s | "
                f"loss={logs['loss'][-1]:.4f} "
                f"(cls={logs['loss_cls'][-1]:.3f}, "
                f"par={logs['loss_par'][-1]:.3f}, "
                f"rel={logs['loss_rel'][-1]:.3f}) | "
                f"elapsed={elapsed/60:.1f}min"
            )

        # ===== quick sanity check (只在最前面几步) =====
        if step < 2:
            L = len(doc["units"])
            print(f"    sanity: seq_len={L}, "
                  f"avg_parent_candidates={sum(len(x) for x in out['par_logits'])/L:.1f}")

    epoch_time = time.time() - start_time
    print(f"[Train] epoch done in {epoch_time/60:.2f} min")

    return {k: sum(v)/max(len(v),1) for k,v in logs.items()}

def eval_one_epoch(model, loader, cfg, focal_cls, focal_rel):
    model.eval()
    logs = {"loss": [], "loss_cls": [], "loss_par": [], "loss_rel": []}

    for doc in loader:
        out = model(doc)
        losses = compute_losses(out, doc, model.num_classes, model.num_rel, cfg, focal_cls, focal_rel)
        for k in logs:
            logs[k].append(float(losses[k].cpu()))
    return {k: sum(v)/max(len(v),1) for k,v in logs.items()}

In [19]:
# ===== dataset =====
train_ds = HRDHDataset(HRDH_ROOT, split="train", max_len=cfg.max_len)
test_ds  = HRDHDataset(HRDH_ROOT, split="test",  max_len=cfg.max_len)

train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_doc,
    pin_memory=False,
    persistent_workers=False,
)

test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_doc,
    pin_memory=False,
    persistent_workers=False,
)


# ===== M_cp =====
M_cp = np.load("M_cp_hrdh.npy")  # shape (C+1,C)

# ===== model =====
num_classes = len(ID2LABEL_14)
num_rel = len(REL2ID)  # 你的 REL2ID 包含 meta；如果你想只做三类，把 meta 从标签里移除并重映射

model = DSPSModel(
    num_classes=num_classes,
    num_rel=num_rel,
    M_cp=M_cp,
    cfg=cfg,
    use_text=False,      # 如果本地没 sentence-transformers 可先设 False
    use_visual=False,    # 如果显存吃紧可先 False
).to(device)

# ===== optim =====
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
focal_cls = FocalLoss(gamma=cfg.focal_gamma, alpha=cfg.focal_alpha).to(device)
focal_rel = FocalLoss(gamma=cfg.focal_gamma, alpha=cfg.focal_alpha).to(device)

print("ready")

ready


In [20]:
cfg.epochs = 2
cfg.num_workers = 0

model = DSPSModel(
    num_classes=num_classes,
    num_rel=num_rel,
    M_cp=M_cp,
    cfg=cfg,
    use_text=False,
    use_visual=False,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
focal_cls = FocalLoss(gamma=2.0, alpha=0.25).to(device)
focal_rel = FocalLoss(gamma=2.0, alpha=0.25).to(device)

for ep in range(1, cfg.epochs+1):
    tr = train_one_epoch(model, train_loader, optimizer, cfg, focal_cls, focal_rel)
    print(ep, tr)



[Train] start epoch, num_docs = 1000
  step    1/1000 | step_time= 1.23s | loss=6.2914 (cls=0.612, par=5.467, rel=0.212) | elapsed=0.0min
    sanity: seq_len=512, avg_parent_candidates=256.5
    sanity: seq_len=512, avg_parent_candidates=256.5
  step    5/1000 | step_time= 0.86s | loss=3.9795 (cls=0.089, par=3.761, rel=0.129) | elapsed=0.1min
  step   10/1000 | step_time= 0.83s | loss=4.2965 (cls=0.143, par=3.999, rel=0.155) | elapsed=0.1min
  step   15/1000 | step_time= 0.76s | loss=3.5786 (cls=0.100, par=3.391, rel=0.088) | elapsed=0.2min
  step   20/1000 | step_time= 0.79s | loss=3.4521 (cls=0.159, par=3.202, rel=0.092) | elapsed=0.2min
  step   25/1000 | step_time= 0.80s | loss=3.0337 (cls=0.165, par=2.778, rel=0.091) | elapsed=0.3min
  step   30/1000 | step_time= 0.52s | loss=3.0748 (cls=0.153, par=2.833, rel=0.089) | elapsed=0.4min
  step   35/1000 | step_time= 0.30s | loss=2.5898 (cls=0.197, par=2.304, rel=0.088) | elapsed=0.4min
  step   40/1000 | step_time= 0.80s | loss=2.596

KeyboardInterrupt: 

In [None]:
best = 1e9
for ep in range(1, cfg.epochs+1):
    tr = train_one_epoch(model, train_loader, optimizer, cfg, focal_cls, focal_rel)
    te = eval_one_epoch(model, test_loader, cfg, focal_cls, focal_rel)
    print(f"[ep {ep}] train:", tr, " | test:", te)

    if te["loss"] < best:
        best = te["loss"]
        torch.save(model.state_dict(), "dsps_hrdh_best.pt")
        print("saved: dsps_hrdh_best.pt")

## 9. 推理与预测导出

- 单文档推理
- 带关系重组/后处理的推理
- 批量导出 split 预测结果到目录


In [None]:
@torch.no_grad()
def predict_doc(model: DSPSModel, doc: Dict[str, Any]) -> Dict[str, Any]:
    model.eval()
    out = model(doc)

    cls = out["cls_logits"].argmax(dim=-1).cpu().tolist()

    # parent: each i logits over [ROOT]+past
    par = []
    for i, logits_i in enumerate(out["par_logits"]):
        idx = int(torch.argmax(logits_i).cpu())
        p = -1 if idx == 0 else (idx - 1)
        par.append(p)

    # relation: 这里的 rel_logits 是用 GT parent 生成的；
    # 推理时更严格做法：用预测 parent 重算 rel（下面给你重算版本）
    # 先给简版：
    rel = out["rel_logits"].argmax(dim=-1).cpu().tolist()

    return {"pred_cls": cls, "pred_parent": par, "pred_rel": rel}


In [None]:
import json
from collections import defaultdict

@torch.no_grad()
def predict_doc_with_rel_recompute(model, doc, device):
    """
    返回：
      pred_cls: (L,)
      pred_parent: (L,)  -1 表示 ROOT
      pred_rel: (L,)
      also return raw probs/logits if needed
    """
    model.eval()
    out = model(doc)

    cls_logits = out["cls_logits"].to(device)          # (L,C)
    par_logits_list = out["par_logits"]                # list of (i+1,)
    L = cls_logits.shape[0]

    pred_cls = cls_logits.argmax(dim=-1).detach().cpu().tolist()

    # parent decode: candidates = [ROOT] + [0..i-1]
    pred_parent = []
    for i, logits_i in enumerate(par_logits_list):
        idx = int(torch.argmax(logits_i).item())
        p = -1 if idx == 0 else (idx - 1)
        pred_parent.append(p)

    # recompute relation logits using predicted parent
    # need access to h_seq and root inside the model; easiest: re-run minimal pieces here
    # We can reconstruct h_seq by running embeddings+encoder+gru again using model modules
    units = doc["units"]
    doc_id = doc["doc_id"]
    page_images = doc["page_images"]

    # embeddings sum (same as forward)
    x = model.layout_pos_page(units, page_images)
    if model.use_text:
        texts = [u.get("text","") for u in units]
        x = x + model.text_emb(doc_id, texts)
    if model.use_visual:
        x = x + model.vis_emb(units, page_images)
    x = model.fuse_ln(x)                       # (L,d)
    x_star = model.encoder(x.unsqueeze(0)).squeeze(0)  # (L,d)
    root = x_star.mean(dim=0, keepdim=True)    # (1,d)
    h_seq = model.gru(x_star.unsqueeze(0))[0].squeeze(0)  # (L,d)

    rel_logits = []
    for i in range(L):
        p = pred_parent[i]
        # --- safe parent index for relation recompute ---
        # p could be invalid (>=len) during early training; fallback to root for safety
        L = h_seq.size(0)
        if (p is None) or (p < 0) or (p >= L):
            parent_vec = root.squeeze(0)
            # 可选：记录一次非法 parent（不影响逻辑）
            # doc.setdefault("_bad_parent_count", 0)
                    # doc["_bad_parent_count"] += 1
        else:
            parent_vec = h_seq[p]

        feat = torch.cat([h_seq[i], parent_vec], dim=-1)
        rel_logits.append(model.rel_head(feat))

    rel_logits = torch.stack(rel_logits, dim=0)  # (L,R)

    pred_rel = rel_logits.argmax(dim=-1).detach().cpu().tolist()

    return {
        "pred_cls": pred_cls,
        "pred_parent": pred_parent,
        "pred_rel": pred_rel,
        "cls_logits": cls_logits.detach().cpu(),
        "rel_logits": rel_logits.detach().cpu(),
    }


In [None]:
def _id2name(mapping, idx: int) -> str:
    """
    mapping 可以是：
      - list: mapping[idx]
      - dict: mapping.get(idx)
      - None: 返回 idx 字符串
    """
    if mapping is None:
        return str(idx)
    if isinstance(mapping, dict):
        return mapping.get(idx, str(idx))
    if isinstance(mapping, (list, tuple)):
        if 0 <= idx < len(mapping):
            return str(mapping[idx])
        return str(idx)
    return str(idx)


def export_tree_json(doc, pred=None):
    """
    doc: dataset 返回的 doc dict
    pred: None => 导出 GT
          dict => 导出 pred（需要含 pred_cls/pred_parent/pred_rel）
    输出格式：
      {
        doc_id,
        nodes: [
          {id, text, label_id, label_name, is_meta, parent, rel_id, rel_name, page_id, box},
          ...
        ]
      }
    """
    doc_id = doc["doc_id"]
    units = doc["units"]
    L = len(units)

    if pred is None:
        cls_ids = doc["y_cls"]
        parent = doc["y_parent"]
        rel_ids = doc["y_rel"]
    else:
        cls_ids = pred["pred_cls"]
        parent = pred["pred_parent"]
        rel_ids = pred["pred_rel"]

    # 兼容你 notebook 里 ID2LABEL_14 / ID2REL 的类型（list 或 dict）
    label_map = globals().get("ID2LABEL_14", None)
    rel_map = globals().get("ID2REL", None)

    out = {"doc_id": doc_id, "nodes": []}
    is_meta_list = doc.get("is_meta", [False]*L)

    for i in range(L):
        u = units[i]
        is_meta = bool(is_meta_list[i])
        cls_id = int(cls_ids[i])
        rel_id = int(rel_ids[i]) if rel_ids is not None else -1

        out["nodes"].append({
            "id": i,
            "text": u.get("text", ""),
            "label_id": cls_id,
            "label_name": _id2name(label_map, cls_id),
            "is_meta": is_meta,
            "parent": int(parent[i]) if parent[i] is not None else -1,
            "rel_id": rel_id,
            "rel_name": _id2name(rel_map, rel_id),
            "page_id": int(u.get("page_id", 0)),
            "box": [int(x) for x in u.get("box", [0,0,0,0])]
        })
    return out


In [None]:
from tqdm import tqdm
import os

@torch.no_grad()
def export_split_predictions(model, loader, save_dir, device):
    os.makedirs(save_dir, exist_ok=True)

    for doc in tqdm(loader, desc=f"export -> {save_dir}"):
        pred = predict_doc_with_rel_recompute(model, doc, device)

        gt_json = export_tree_json(doc, pred=None)
        pr_json = export_tree_json(doc, pred=pred)

        doc_id = doc["doc_id"].replace("/", "_")
        with open(os.path.join(save_dir, f"{doc_id}.gt.json"), "w", encoding="utf-8") as f:
            json.dump(gt_json, f, ensure_ascii=False, indent=2)
        with open(os.path.join(save_dir, f"{doc_id}.pred.json"), "w", encoding="utf-8") as f:
            json.dump(pr_json, f, ensure_ascii=False, indent=2)

    print("done.")


In [None]:
# 加载最优模型
model.load_state_dict(torch.load("dsps_hrdh_best.pt", map_location=device))

export_split_predictions(model, test_loader, save_dir="exports_hrdh_test", device=device)



## 10. STEDS / TED 评估实现

- TNode 数据结构
- Zhang-Shasha Tree Edit Distance
- compute_steds 与导出目录评估


In [None]:
import os, json, glob
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional

def load_pair_files(export_dir: str) -> List[Tuple[str, str]]:
    gt_files = sorted(glob.glob(os.path.join(export_dir, "*.gt.json")))
    pairs = []
    for gt in gt_files:
        pred = gt.replace(".gt.json", ".pred.json")
        if os.path.exists(pred):
            pairs.append((gt, pred))
    return pairs

pairs = load_pair_files("exports_hrdh_test")
len(pairs), pairs[0] if pairs else None


In [None]:
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

@dataclass
class TNode:
    label: int
    text: str
    oid: int                 # original id，用来排序
    children: List["TNode"]

    def __init__(self, label: int, text: str, oid: int):
        self.label = label
        self.text = text
        self.oid = oid
        self.children = []

def find_effective_parent(i: int) -> int:
    """
    返回 i 的有效父节点：
    - 若 parent 指向越界/非法：视为无父节点（-1），挂到 ROOT
    - 若 parent 指向被排除节点（meta 或不在 kept 集）：沿 parent 链向上跳
    - 增加循环保护，避免死循环
    """
    n = len(parent)

    def _is_valid_idx(x: int) -> bool:
        return isinstance(x, int) and (0 <= x < n)

    p = parent[i] if 0 <= i < n else -1

    # 1) 直接非法或 root
    if (p is None) or (not isinstance(p, int)) or (p < 0):
        return -1
    if not _is_valid_idx(p):
        return -1

    # 2) 向上跳，直到遇到 kept 的父节点或 root/非法
    seen = set()
    while True:
        if p in kept_set:
            return p

        if p in seen:
            # 出现环，按无父节点处理
            return -1
        seen.add(p)

        # 向上跳
        pp = parent[p]
        if (pp is None) or (not isinstance(pp, int)) or (pp < 0):
            return -1
        if not _is_valid_idx(pp):
            return -1
        p = pp

    # 找“跳过 meta 后的最近祖先”
    def find_effective_parent(i: int) -> int:
        p = parent[i]
        seen = set()
        while True:
            if p < 0:
                return -1
            if p in seen:         # 防环保护
                return -1
            seen.add(p)
            if kept(p):
                return p
            # p 是 meta 或被排除节点，继续向上跳
            p = parent[p]

    # 连接
    for i in kept_ids:
        ep = find_effective_parent(i)
        if ep < 0:
            root.children.append(obj[i])
        else:
            obj[ep].children.append(obj[i])

    # children 按原始 reading-order id 排序，确保“有序树”一致
    def sort_rec(x: TNode):
        x.children.sort(key=lambda c: c.oid)
        for ch in x.children:
            sort_rec(ch)

    sort_rec(root)
    return root, len(kept_ids)



In [None]:
def ted_zhang_shasha(t1: TNode, t2: TNode, match_mode: str = "strict") -> int:
    """
    Ordered Tree Edit Distance (Zhang-Shasha)
    Costs: ins=1, del=1
    match_mode:
      - "strict": label + text 完全一致 cost=0，否则 1
      - "label": 只要 label 一致 cost=0，否则 1
    """
    def postorder(root: TNode):
        nodes = []
        def dfs(x: TNode):
            for ch in x.children:
                dfs(ch)
            nodes.append(x)
        dfs(root)
        return nodes

    A = postorder(t1)
    B = postorder(t2)

    idxA = {id(node): i+1 for i, node in enumerate(A)}
    idxB = {id(node): i+1 for i, node in enumerate(B)}

    n, m = len(A), len(B)

    def leftmost_indices(nodes, idx_map):
        l = [0] * (len(nodes) + 1)  # 1-based
        for i, node in enumerate(nodes, start=1):
            cur = node
            while cur.children:
                cur = cur.children[0]
            l[i] = idx_map[id(cur)]
        return l

    l1 = leftmost_indices(A, idxA)
    l2 = leftmost_indices(B, idxB)

    def keyroots(leftmost):
        seen = {}
        for i in range(1, len(leftmost)):
            seen[leftmost[i]] = i
        return sorted(seen.values())

    kr1 = keyroots(l1)
    kr2 = keyroots(l2)

    treedist = [[0] * (m + 1) for _ in range(n + 1)]

    def relabel_cost(i: int, j: int) -> int:
        a = A[i-1]
        b = B[j-1]
        if match_mode == "label":
            return 0 if (a.label == b.label) else 1
        return 0 if (a.label == b.label and a.text == b.text) else 1

    for i in kr1:
        for j in kr2:
            i0 = l1[i]
            j0 = l2[j]
            fd = [[0] * (j - j0 + 2) for _ in range(i - i0 + 2)]

            for di in range(1, i - i0 + 2):
                fd[di][0] = fd[di-1][0] + 1  # delete
            for dj in range(1, j - j0 + 2):
                fd[0][dj] = fd[0][dj-1] + 1  # insert

            for di in range(1, i - i0 + 2):
                for dj in range(1, j - j0 + 2):
                    ii = i0 + di - 1
                    jj = j0 + dj - 1
                    if l1[ii] == i0 and l2[jj] == j0:
                        fd[di][dj] = min(
                            fd[di-1][dj] + 1,
                            fd[di][dj-1] + 1,
                            fd[di-1][dj-1] + relabel_cost(ii, jj)
                        )
                        treedist[ii][jj] = fd[di][dj]
                    else:
                        fd[di][dj] = min(
                            fd[di-1][dj] + 1,
                            fd[di][dj-1] + 1,
                            fd[di-1][dj-1] + treedist[ii][jj]
                        )

    return treedist[n][m]

In [None]:
def compute_steds(gt_js: Dict[str,Any], pr_js: Dict[str,Any], exclude_meta: bool = True, match_mode: str = "strict") -> float:
    gt_root, gt_n = build_tree_from_export(gt_js, exclude_meta=exclude_meta)
    pr_root, pr_n = build_tree_from_export(pr_js, exclude_meta=exclude_meta)
    dist = ted_zhang_shasha(gt_root, pr_root, match_mode=match_mode)
    denom = max(gt_n, pr_n, 1)
    return 1.0 - dist / denom

def compute_aux_metrics(gt_js: Dict[str,Any], pr_js: Dict[str,Any], exclude_meta: bool = True) -> Dict[str,float]:
    gt_nodes = gt_js["nodes"]
    pr_nodes = pr_js["nodes"]
    L = min(len(gt_nodes), len(pr_nodes))

    mask = []
    for i in range(L):
        is_meta = bool(gt_nodes[i].get("is_meta", False))
        mask.append((not is_meta) if exclude_meta else True)

    def acc(key: str) -> float:
        tot = 0
        hit = 0
        for i in range(L):
            if not mask[i]:
                continue
            tot += 1
            hit += int(gt_nodes[i][key] == pr_nodes[i][key])
        return hit / tot if tot else 0.0

    # label_id / parent / rel_id
    return {
        "cls_acc": acc("label_id"),
        "parent_acc": acc("parent"),
        "rel_acc": acc("rel_id"),
        "n_eval": float(sum(mask)),
    }

In [None]:
import numpy as np
from tqdm import tqdm

def eval_export_dir(export_dir: str, exclude_meta: bool = True):
    pairs = load_pair_files(export_dir)
    steds_list = []
    aux_list = []

    for gt_path, pr_path in tqdm(pairs, desc=f"eval {export_dir}"):
        with open(gt_path, "r", encoding="utf-8") as f:
            gt_js = json.load(f)
        with open(pr_path, "r", encoding="utf-8") as f:
            pr_js = json.load(f)

        st = compute_steds(gt_js, pr_js, exclude_meta=exclude_meta)
        aux = compute_aux_metrics(gt_js, pr_js, exclude_meta=exclude_meta)

        steds_list.append(st)
        aux_list.append(aux)

    steds_arr = np.array(steds_list, dtype=float)
    cls_acc = np.array([a["cls_acc"] for a in aux_list], dtype=float)
    par_acc = np.array([a["parent_acc"] for a in aux_list], dtype=float)
    rel_acc = np.array([a["rel_acc"] for a in aux_list], dtype=float)

    report = {
        "num_docs": len(pairs),
        "exclude_meta": exclude_meta,
        "STEDS_mean": float(steds_arr.mean()) if len(steds_arr) else 0.0,
        "STEDS_std": float(steds_arr.std()) if len(steds_arr) else 0.0,
        "STEDS_median": float(np.median(steds_arr)) if len(steds_arr) else 0.0,
        "cls_acc_mean": float(cls_acc.mean()) if len(cls_acc) else 0.0,
        "parent_acc_mean": float(par_acc.mean()) if len(par_acc) else 0.0,
        "rel_acc_mean": float(rel_acc.mean()) if len(rel_acc) else 0.0,
    }
    return report, steds_list, aux_list

report, steds_list, aux_list = eval_export_dir("exports_hrdh_test", exclude_meta=True)
report

## 11. 文本规范化与阅读顺序（Reading Order）辅助

- 文本归一化/匹配
- 多列布局排序与文档级排序

说明：原 Notebook 中此部分包含多段独立工具函数，均按原顺序保留。


In [None]:
import json, os, re, numpy as np
from tqdm import tqdm

# ---------- text normalization ----------
def _norm_text(s: str) -> str:
    s = s.strip()
    s = s.replace("\n", " ")
    s = re.sub(r"\s+", " ", s)
    s = s.replace("- ", "-")
    return s


# ---------- STEDS variants ----------
def steds_original(gt_js, pr_js):
    return compute_steds(gt_js, pr_js, exclude_meta=True)


def steds_label_only(gt_js, pr_js):
    gt_root, gt_n = build_tree_from_export(gt_js, exclude_meta=True)
    pr_root, pr_n = build_tree_from_export(pr_js, exclude_meta=True)

    def wipe_text(x):
        x.text = ""
        for c in x.children:
            wipe_text(c)

    wipe_text(gt_root)
    wipe_text(pr_root)

    dist = ted_zhang_shasha(gt_root, pr_root)
    denom = max(gt_n, pr_n, 1)
    return 1.0 - dist / denom


def steds_text_norm(gt_js, pr_js):
    gt_root, gt_n = build_tree_from_export(gt_js, exclude_meta=True)
    pr_root, pr_n = build_tree_from_export(pr_js, exclude_meta=True)

    def apply_norm(x):
        x.text = _norm_text(x.text)
        for c in x.children:
            apply_norm(c)

    apply_norm(gt_root)
    apply_norm(pr_root)

    dist = ted_zhang_shasha(gt_root, pr_root)
    denom = max(gt_n, pr_n, 1)
    return 1.0 - dist / denom


# ---------- batch evaluation ----------
def eval_steds_variants(export_dir: str):
    pairs = load_pair_files(export_dir)

    res = {
        "original": [],
        "label_only": [],
        "text_norm": [],
    }

    for gt_path, pr_path in tqdm(pairs, desc="STEDS variants"):
        with open(gt_path, "r", encoding="utf-8") as f:
            gt = json.load(f)
        with open(pr_path, "r", encoding="utf-8") as f:
            pr = json.load(f)

        res["original"].append(steds_original(gt, pr))
        res["label_only"].append(steds_label_only(gt, pr))
        res["text_norm"].append(steds_text_norm(gt, pr))

    def summary(arr):
        arr = np.asarray(arr)
        return {
            "mean": float(arr.mean()),
            "median": float(np.median(arr)),
            "std": float(arr.std()),
            "min": float(arr.min()),
            "max": float(arr.max()),
        }

    report = {k: summary(v) for k, v in res.items()}
    return report


# ---------- run ----------
report = eval_steds_variants("exports_hrdh_test")
report



In [None]:
import numpy as np

def kmeans_1d(x, k, iters=30):
    """简单 1D kmeans。x: (n,) float"""
    x = x.astype(np.float32)
    # init: 均匀取 k 个分位点
    qs = np.linspace(0.1, 0.9, k)
    centers = np.quantile(x, qs)
    for _ in range(iters):
        # assign
        d = np.abs(x[:, None] - centers[None, :])
        labels = d.argmin(axis=1)
        new_centers = centers.copy()
        for j in range(k):
            mask = labels == j
            if mask.any():
                new_centers[j] = x[mask].mean()
        if np.allclose(new_centers, centers, atol=1e-4):
            break
        centers = new_centers
    # inertia
    inertia = ((x - centers[labels]) ** 2).sum()
    # sort clusters by center (left->right)
    order = np.argsort(centers)
    remap = np.zeros_like(order)
    remap[order] = np.arange(k)
    labels = remap[labels]
    centers = centers[order]
    return labels, centers, inertia

def choose_k_1d(x, k_min=1, k_max=3, min_cluster_size=5):
    """
    自动选择列数 k：
    - 基于“惯性下降幅度”（elbow-ish）
    - 同时避免产生很小的列
    """
    best = None
    prev_inertia = None
    for k in range(k_min, k_max + 1):
        labels, centers, inertia = kmeans_1d(x, k)
        # 小簇过滤
        ok = True
        for j in range(k):
            if (labels == j).sum() < min_cluster_size and len(x) >= min_cluster_size * k:
                ok = False
                break
        if not ok:
            continue

        if prev_inertia is None:
            score = 0.0
        else:
            # inertia 下降比例越大越好
            score = (prev_inertia - inertia) / max(prev_inertia, 1e-6)

        # 记录：优先更大“下降”，其次更小 inertia
        cand = (score, -inertia, k, labels, centers, inertia)
        if best is None or cand > best:
            best = cand
        prev_inertia = inertia

    if best is None:
        labels, centers, inertia = kmeans_1d(x, 1)
        return 1, labels, centers
    _, _, k, labels, centers, _ = best
    return k, labels, centers


In [None]:
def order_units_multi_column(units, page_w, page_h, max_cols=3):
    """
    units: list of dict with 'box'=[x0,y0,x1,y1]
    返回：按阅读顺序排序后的 index 列表
    """
    n = len(units)
    if n <= 1:
        return list(range(n))

    boxes = np.array([u["box"] for u in units], dtype=np.float32)
    x0, y0, x1, y1 = boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3]
    xc = 0.5 * (x0 + x1)
    yt = y0
    bw = np.clip(x1 - x0, 1.0, None)

    # 识别“跨栏大框”：宽度超过页面的一定比例（标题/大图/大表等），单独处理
    wide = bw > (0.70 * page_w)

    # 对非 wide 的做列聚类
    idx_normal = np.where(~wide)[0]
    idx_wide = np.where(wide)[0]

    ordered = []

    # 先把“宽大框”按 y 排到合适位置：通常它们是跨栏标题/大图，阅读上仍按 y 走
    # 策略：把 wide 与 normal 一起排序时，wide 作为“单独列”插入：这里用更稳的两阶段合并
    if len(idx_normal) > 0:
        x_norm = xc[idx_normal]
        k, labels, centers = choose_k_1d(x_norm, 1, max_cols, min_cluster_size=4)

        # 每一列内部按 y_top 排序
        cols = []
        for c in range(k):
            members = idx_normal[labels == c]
            members = members[np.argsort(yt[members])]
            cols.append(list(members))

        # 列从左到右拼接
        seq_normal = [i for col in cols for i in col]
    else:
        seq_normal = []

    # wide 按 y_top 排序
    seq_wide = list(idx_wide[np.argsort(yt[idx_wide])]) if len(idx_wide) > 0 else []

    # 合并 wide 与 normal：按 y_top 进行稳定插入（wide 的 y 决定位置）
    # 做法：对两序列按 y_top 归并，但当 y 接近时优先 wide（跨栏标题通常应先读）
    def merge_by_y(a, b):
        ia = ib = 0
        out = []
        while ia < len(a) and ib < len(b):
            ya = yt[a[ia]]
            yb = yt[b[ib]]
            if yb <= ya + 2:  # 容忍 2px 抖动，wide 略优先
                out.append(b[ib]); ib += 1
            else:
                out.append(a[ia]); ia += 1
        out.extend(a[ia:])
        out.extend(b[ib:])
        return out

    ordered = merge_by_y(seq_normal, seq_wide)
    return ordered

In [None]:
def sort_doc_units_by_reading_order(doc_units, page_images):
    """
    doc_units: list[dict], each has page_id, box, text, ...
    page_images: dict {page_id: image_path}
    return sorted_units
    """
    # 1) 按页收集
    by_page = {}
    for u in doc_units:
        pid = int(u["page_id"])
        by_page.setdefault(pid, []).append(u)

    # 2) 页按 pid 排序
    sorted_units = []
    for pid in sorted(by_page.keys()):
        # 读 page 尺寸
        from PIL import Image
        img = Image.open(page_images[pid])
        W, H = img.width, img.height

        units_p = by_page[pid]
        order_idx = order_units_multi_column(units_p, W, H, max_cols=3)
        sorted_units.extend([units_p[i] for i in order_idx])

    return sorted_units

In [None]:
import numpy as np
from PIL import Image

def _kmeans_1d(x, k, iters=30):
    x = x.astype(np.float32)
    qs = np.linspace(0.1, 0.9, k)
    centers = np.quantile(x, qs)

    for _ in range(iters):
        d = np.abs(x[:, None] - centers[None, :])
        labels = d.argmin(axis=1)
        new_centers = centers.copy()
        for j in range(k):
            m = labels == j
            if m.any():
                new_centers[j] = x[m].mean()
        if np.allclose(new_centers, centers, atol=1e-4):
            break
        centers = new_centers

    inertia = ((x - centers[labels]) ** 2).sum()

    order = np.argsort(centers)
    remap = np.zeros_like(order)
    remap[order] = np.arange(k)
    labels = remap[labels]
    centers = centers[order]
    return labels, centers, inertia


def _choose_k_1d(x, k_min=1, k_max=3, min_cluster_size=4):
    best = None
    prev_inertia = None

    for k in range(k_min, k_max + 1):
        labels, centers, inertia = _kmeans_1d(x, k)
        ok = True
        if len(x) >= min_cluster_size * k:
            for j in range(k):
                if (labels == j).sum() < min_cluster_size:
                    ok = False
                    break
        if not ok:
            continue

        score = 0.0 if prev_inertia is None else (prev_inertia - inertia) / max(prev_inertia, 1e-6)
        cand = (score, -inertia, k, labels, centers)
        if best is None or cand > best:
            best = cand
        prev_inertia = inertia

    if best is None:
        labels, centers, _ = _kmeans_1d(x, 1)
        return 1, labels, centers
    _, _, k, labels, centers = best
    return k, labels, centers


def order_units_multi_column(units, page_w, page_h, max_cols=3):
    """
    units: list of dict with 'box'=[x0,y0,x1,y1]
    return: list of indices in reading order
    """
    n = len(units)
    if n <= 1:
        return list(range(n))

    boxes = np.array([u["box"] for u in units], dtype=np.float32)
    x0, y0, x1, y1 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    xc = 0.5 * (x0 + x1)
    yt = y0
    bw = np.clip(x1 - x0, 1.0, None)

    # “跨栏大框”识别：标题/大图/大表常见
    wide = bw > (0.70 * page_w)

    idx_normal = np.where(~wide)[0]
    idx_wide = np.where(wide)[0]

    seq_normal = []
    if len(idx_normal) > 0:
        x_norm = xc[idx_normal]
        k, labels, _ = _choose_k_1d(x_norm, 1, max_cols, min_cluster_size=4)

        cols = []
        for c in range(k):
            members = idx_normal[labels == c]
            members = members[np.argsort(yt[members])]
            cols.append(list(members))

        seq_normal = [i for col in cols for i in col]

    seq_wide = list(idx_wide[np.argsort(yt[idx_wide])]) if len(idx_wide) > 0 else []

    # 按 y 归并：wide 在接近同一 y 时略优先（跨栏标题通常应先读）
    def merge_by_y(a, b):
        ia = ib = 0
        out = []
        while ia < len(a) and ib < len(b):
            ya = yt[a[ia]]
            yb = yt[b[ib]]
            if yb <= ya + 2:
                out.append(b[ib]); ib += 1
            else:
                out.append(a[ia]); ia += 1
        out.extend(a[ia:])
        out.extend(b[ib:])
        return out

    return merge_by_y(seq_normal, seq_wide)


def sort_doc_units_by_reading_order(doc_units, page_images, max_cols=3):
    """
    doc_units: list[dict] each has 'page_id','box',...
    page_images: dict[int -> image_path]
    return: new_units_sorted
    """
    by_page = {}
    for u in doc_units:
        pid = int(u["page_id"])
        by_page.setdefault(pid, []).append(u)

    sorted_units = []
    for pid in sorted(by_page.keys()):
        img = Image.open(page_images[pid])
        W, H = img.width, img.height

        units_p = by_page[pid]
        order_idx = order_units_multi_column(units_p, W, H, max_cols=max_cols)
        sorted_units.extend([units_p[i] for i in order_idx])

    return sorted_units


## 12. 诊断、对比评估与实验封装

- check_causal_violation
- eval_export_dir_dual
- run_experiment：实验跑批封装


In [None]:
def check_causal_violation(dataset, n=100):
    cnt = 0
    viol = 0
    for i in range(min(n, len(dataset))):
        d = dataset[i]
        yp = d["y_parent"]
        for j,p in enumerate(yp):
            if p is not None and p >= j:
                viol += 1
            cnt += 1
    return viol / max(cnt,1)

print("causal violation rate:", check_causal_violation(train_ds, n=100))


In [None]:
import numpy as np
import json, os, glob

@torch.no_grad()
def eval_export_dir_dual(export_dir: str, exclude_meta: bool = True,show_progress: bool = True):
    pairs = load_pair_files(export_dir)

    st_strict, st_label = [], []
    cls_acc, par_acc, rel_acc = [], [], []

    for gt_path, pr_path in pairs:
        with open(gt_path, "r", encoding="utf-8") as f:
            gt_js = json.load(f)
        with open(pr_path, "r", encoding="utf-8") as f:
            pr_js = json.load(f)

        st_strict.append(compute_steds(gt_js, pr_js, exclude_meta=exclude_meta, match_mode="strict"))
        st_label.append(compute_steds(gt_js, pr_js, exclude_meta=exclude_meta, match_mode="label"))

        aux = compute_aux_metrics(gt_js, pr_js, exclude_meta=exclude_meta)
        cls_acc.append(aux["cls_acc"])
        par_acc.append(aux["parent_acc"])
        rel_acc.append(aux["rel_acc"])

    def agg(x):
        x = np.array(x, dtype=float)
        return float(x.mean()), float(x.std()), float(np.median(x))

    s_mean, s_std, s_med = agg(st_strict)
    l_mean, l_std, l_med = agg(st_label)

    return {
        "num_docs": len(pairs),
        "STEDS_strict_mean": s_mean,
        "STEDS_strict_std": s_std,
        "STEDS_strict_median": s_med,
        "STEDS_label_mean": l_mean,
        "STEDS_label_std": l_std,
        "STEDS_label_median": l_med,
        "cls_acc_mean": float(np.mean(cls_acc)) if cls_acc else 0.0,
        "parent_acc_mean": float(np.mean(par_acc)) if par_acc else 0.0,
        "rel_acc_mean": float(np.mean(rel_acc)) if rel_acc else 0.0,
    }


In [None]:
import time


def run_experiment(exp_name: str, use_softmask: bool, train_loader, test_loader, M_cp, cfg, seed=42, save_root="ablation_runs"):
    # 固定随机种子
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    exp_dir = os.path.join(save_root, exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    # 建模：只改变 use_softmask
    model = DSPSModel(
        num_classes=len(ID2LABEL_14),
        num_rel=len(REL2ID),
        M_cp=M_cp,
        cfg=cfg,
        use_text=True,
        use_visual= False,
        use_softmask=use_softmask,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    focal_cls = FocalLoss(gamma=cfg.focal_gamma, alpha=cfg.focal_alpha).to(device)
    focal_rel = FocalLoss(gamma=cfg.focal_gamma, alpha=cfg.focal_alpha).to(device)

    best = 1e18
    best_path = os.path.join(exp_dir, "best.pt")

    t0 = time.time()
    for ep in range(1, cfg.epochs + 1):
        tr = train_one_epoch(model, train_loader, optimizer, cfg, focal_cls, focal_rel)
        te = eval_one_epoch(model, test_loader, cfg, focal_cls, focal_rel)
        print(f"[{exp_name}] ep={ep} train={tr} test={te}")

        if te["loss"] < best:
            best = te["loss"]
            torch.save(model.state_dict(), best_path)
            print(f"[{exp_name}] saved -> {best_path}")

    # 导出与评测
    state = torch.load(best_path, map_location=device, weights_only=True)
    model.load_state_dict(state)

    export_dir = os.path.join(exp_dir, "exports_test")
    export_split_predictions(model, test_loader, save_dir=export_dir, device=device)

    print(f"[{exp_name}] export done, start eval...")
    metrics = eval_export_dir_dual(export_dir, exclude_meta=True, show_progress=True)
    metrics.update({
        "exp_name": exp_name,
        "use_softmask": use_softmask,
        "seed": seed,
        "epochs": cfg.epochs,
        "elapsed_sec": round(time.time() - t0, 2),
        "export_dir": export_dir,
    })
    print(f"[{exp_name}] eval done.")


    return metrics


# ======= 跑两组：soft-mask on/off =======
os.makedirs("ablation_runs", exist_ok=True)

res_on = run_experiment(
    exp_name="baseline_softmask_on",
    use_softmask=True,
    train_loader=train_loader,
    test_loader=test_loader,
    M_cp=M_cp,
    cfg=cfg,
    seed=42,
)

res_off = run_experiment(
    exp_name="ablation_softmask_off",
    use_softmask=False,
    train_loader=train_loader,
    test_loader=test_loader,
    M_cp=M_cp,
    cfg=cfg,
    seed=42,
)

for r in [res_on, res_off]:
    print("=" * 60)
    for k, v in r.items():
        print(f"{k:25s}: {v}")



In [None]:
import inspect
cfg.batch_size = 1

def _call_with_compatible_signature(fn, *args, **kwargs):
    """
    Call fn with kwargs filtered to those supported by fn's signature.
    This prevents errors like: got an unexpected keyword argument 'device'.
    """
    sig = inspect.signature(fn)
    accepted = set(sig.parameters.keys())
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    return fn(*args, **filtered)

def train_one_epoch_compat(model, loader, optimizer, *, device=None, cfg=None, scaler=None):
        if focal_cls is None:
            focal_cls = FocalLoss(gamma=getattr(cfg, "focal_gamma", 2.0),
                                  alpha=getattr(cfg, "focal_alpha", 0.25)).to(device)
        if focal_rel is None:
            focal_rel = FocalLoss(gamma=getattr(cfg, "focal_gamma", 2.0),
                                  alpha=getattr(cfg, "focal_alpha", 0.25)).to(device)
    # Try calling with keywords first; incompatible keywords are filtered out.
        return _call_with_compatible_signature(
        train_one_epoch,
        model, loader, optimizer,
        device=device, cfg=cfg, scaler=scaler
    )

def eval_one_epoch_compat(model, loader, *, device=None, cfg=None):
    return _call_with_compatible_signature(
        eval_one_epoch,
        model, loader,
        device=device, cfg=cfg
    )


In [None]:
# =========================
# HRDS 全流程：train -> best ckpt -> export -> eval
# 依赖你 notebook 已有定义（Run All 到此处后再运行本 cell）：
#   HRDHDataset, collate_doc
#   compute_M_cp_from_dataset
#   DSPSModel
#   FocalLoss
#   train_one_epoch
#   export_split_predictions
#   eval_export_dir_dual
# =========================

import os, json, time, shutil
import numpy as np
import torch
from torch.utils.data import DataLoader

# --------- A) 你必须改的路径：HRDS_ROOT ----------
HRDS_ROOT = r"C:\Users\tomra\Desktop\PAPER\final OBJ\HRDS"   # 改成你 HRDoc-Simple 根目录

# --------- B) 实验输出目录 ----------
EXP_NAME = "hrds_dsps_fullrun"
RUN_ROOT = "hrds_runs"
EXP_DIR = os.path.join(RUN_ROOT, EXP_NAME)
os.makedirs(EXP_DIR, exist_ok=True)

# --------- C) 基础对象检查：确保你前面 cell 已经定义完 ----------
required = [
    "HRDHDataset", "collate_doc",
    "compute_M_cp_from_dataset",
    "DSPSModel",
    "FocalLoss",
    "train_one_epoch",
    "export_split_predictions",
    "eval_export_dir_dual",
    "cfg",
]
missing = [n for n in required if n not in globals()]
if missing:
    raise RuntimeError(
        "你还没有运行到包含以下定义的 cell（或变量名不同）：\n"
        + "\n".join(missing)
        + "\n\n请先 Run All 直到模型/数据集/评估函数都定义完，再运行本 cell。"
    )

if not os.path.isdir(HRDS_ROOT):
    raise FileNotFoundError(f"HRDS_ROOT 不存在：{HRDS_ROOT}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[Device]", device)
print("[EXP_DIR]", EXP_DIR)

# --------- D) 固定随机种子（保持复现稳定） ----------
seed = 42
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --------- E) 构建 HRDS 数据集与 dataloader ----------
train_ds = HRDHDataset(HRDS_ROOT, split="train", max_len=cfg.max_len)
test_ds  = HRDHDataset(HRDS_ROOT, split="test",  max_len=cfg.max_len)

train_loader = DataLoader(
    train_ds,
    batch_size=1,         
    shuffle=True,
    num_workers=getattr(cfg, "num_workers", 0),
    collate_fn=collate_doc,
    pin_memory=False,
    persistent_workers=False,
)

test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=getattr(cfg, "num_workers", 0),
    collate_fn=collate_doc,
    pin_memory=False,
    persistent_workers=False,
)

print(f"[HRDS] train docs = {len(train_loader)}, test docs = {len(test_loader)}")

# --------- F) 类别/关系维度（来自你 dataset 定义） ----------
# 你的 dataset cell 里固定 14 类，关系含 meta（4 类）
num_classes = len(ID2LABEL_14)
num_rel = len(REL2ID)
print("[Dims] num_classes =", num_classes, "num_rel =", num_rel)

# --------- G) 计算/加载 M_cp（soft-mask 先验） ----------
MCP_PATH = os.path.join(EXP_DIR, "M_cp_hrds.npy")
if os.path.exists(MCP_PATH):
    M_cp = np.load(MCP_PATH)
    print("[M_cp] loaded:", MCP_PATH, "shape=", M_cp.shape)
else:
    print("[M_cp] computing from HRDS train ...")
    ret = compute_M_cp_from_dataset(train_ds, num_classes=num_classes, pseudo_count=5.0)

    # 兼容：函数可能返回 (M_cp, stats) 或 {"M_cp":..., ...} 或直接 ndarray
    if isinstance(ret, tuple):
        M_cp = ret[0]
        mcp_extra = ret[1:]
    elif isinstance(ret, dict):
        M_cp = ret.get("M_cp", None)
        mcp_extra = {k:v for k,v in ret.items() if k != "M_cp"}
        if M_cp is None:
            raise RuntimeError("compute_M_cp_from_dataset 返回 dict，但不包含键 'M_cp'")
    else:
        M_cp = ret
        mcp_extra = None

    M_cp = np.asarray(M_cp)
    np.save(MCP_PATH, M_cp)
    print("[M_cp] saved:", MCP_PATH, "shape=", M_cp.shape)

    # 可选：把额外统计信息也存下来（不影响训练）
    if mcp_extra is not None:
        extra_path = os.path.join(EXP_DIR, "M_cp_hrds_extra.json")
        try:
            with open(extra_path, "w", encoding="utf-8") as f:
                json.dump(mcp_extra, f, ensure_ascii=False, indent=2, default=str)
            print("[M_cp] extra saved:", extra_path)
        except Exception as e:
            print("[M_cp] extra save skipped:", repr(e))

# --------- H) 建模（按你现有 DSPSModel 接口） ----------
# 这里默认用你当前复现的设置：use_text=False；视觉是否启用取决于你 notebook 中 DSPSModel 的实现
# 若你后续要开视觉或文本，改下面两个开关即可
USE_TEXT = False
USE_VISUAL = False
USE_SOFTMASK = True

model = DSPSModel(
    num_classes=num_classes,
    num_rel=num_rel,
    M_cp=M_cp,
    cfg=cfg,
    use_text=USE_TEXT,
    use_visual=USE_VISUAL,
    use_softmask=USE_SOFTMASK,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=getattr(cfg, "lr", 2e-4),
    weight_decay=getattr(cfg, "weight_decay", 1e-2),
)

focal_cls = FocalLoss(gamma=getattr(cfg, "focal_gamma", 2.0), alpha=getattr(cfg, "focal_alpha", 0.25)).to(device)
focal_rel = FocalLoss(gamma=getattr(cfg, "focal_gamma", 2.0), alpha=getattr(cfg, "focal_alpha", 0.25)).to(device)

# --------- I) 训练 + 每 epoch 验证 + 保存 best ----------
CKPT_BEST = os.path.join(EXP_DIR, "best_hrds.pt")
REPORT_PATH = os.path.join(EXP_DIR, "train_eval_log.json")

history = []
best_score = -1.0
best_epoch = -1

def _extract_macro_strict(metrics_dict):
    """
    兼容 eval_export_dir_dual 的返回格式：
    - 若存在宏严格 STEDS 的 key，就取它
    - 否则退化为 strict_mean 或第一项可用 strict 指标
    """
    if not isinstance(metrics_dict, dict):
        return None

    # 常见候选 key（根据你实现的命名习惯做容错）
    candidates = [
        "steds_strict_macro", "macro_steds_strict", "steds_macro_strict",
        "steds_strict_mean",  "strict_mean",        "steds_strict",
    ]
    for k in candidates:
        if k in metrics_dict and isinstance(metrics_dict[k], (int, float)):
            return float(metrics_dict[k])

    # 兜底：如果 metrics_dict 里有任何包含 "strict" 的数值字段，取第一个
    for k, v in metrics_dict.items():
        if "strict" in k.lower() and isinstance(v, (int, float)):
            return float(v)

    return None

for ep in range(1, cfg.epochs + 1):
    print(f"\n========== [HRDS] Epoch {ep}/{cfg.epochs} ==========")

    # 1) train
    tr_logs = train_one_epoch(model, train_loader, optimizer, cfg, focal_cls, focal_rel)

    # 2) export + eval on test
    export_dir = os.path.join(EXP_DIR, f"export_ep{ep:02d}")
    # 为避免旧结果干扰，每次先清目录
    if os.path.isdir(export_dir):
        shutil.rmtree(export_dir, ignore_errors=True)
    os.makedirs(export_dir, exist_ok=True)

    t0 = time.time()
    export_split_predictions(model, test_loader, save_dir=export_dir, device=device)
    metrics = eval_export_dir_dual(export_dir, exclude_meta=True)
    eval_time = time.time() - t0

    # 3) best selection（以 Macro Strict STEDS 为主）
    score = _extract_macro_strict(metrics)
    if score is None:
        print("[WARN] 无法从 metrics 中提取 strict 指标字段，将不保存 best。metrics=", metrics)
        score = -1.0

    is_best = score > best_score
    if is_best:
        best_score = score
        best_epoch = ep
        torch.save({"model": model.state_dict(), "epoch": ep, "metrics": metrics}, CKPT_BEST)
        print(f"[BEST] updated: epoch={ep}, score={best_score:.6f} -> {CKPT_BEST}")

    row = {
        "epoch": ep,
        "train": tr_logs,
        "metrics": metrics,
        "macro_strict_for_select": score,
        "eval_time_sec": eval_time,
        "is_best": is_best,
    }
    history.append(row)

    with open(REPORT_PATH, "w", encoding="utf-8") as f:
        json.dump(
            {
                "exp_name": EXP_NAME,
                "hrds_root": HRDS_ROOT,
                "cfg": cfg.__dict__ if hasattr(cfg, "__dict__") else str(cfg),
                "use_text": USE_TEXT,
                "use_visual": USE_VISUAL,
                "use_softmask": USE_SOFTMASK,
                "best_epoch": best_epoch,
                "best_score_macro_strict": best_score,
                "best_ckpt": CKPT_BEST,
                "history": history,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    print("[Epoch Summary]")
    print("  train loss(last):", tr_logs["loss"][-1] if isinstance(tr_logs, dict) and "loss" in tr_logs else tr_logs)
    print("  metrics:", metrics)
    print("  eval_time_sec:", eval_time)

print("\n========== HRDS Full Flow DONE ==========")
print("[BEST] epoch =", best_epoch, "macro_strict =", best_score)
print("[BEST CKPT]", CKPT_BEST)
print("[LOG JSON ]", REPORT_PATH)

print("\nPaper reference (Table 2 HRDS best, Document+Semantic+Vision+Soft-mask):")
print("Micro-STEDS ≈ 0.8143")
print("Macro-STEDS ≈ 0.8174")

# --------- J) 用 best ckpt 再导出一次（最终留档） ----------
FINAL_EXPORT = os.path.join(EXP_DIR, "export_best_final")
if os.path.isdir(FINAL_EXPORT):
    shutil.rmtree(FINAL_EXPORT, ignore_errors=True)
os.makedirs(FINAL_EXPORT, exist_ok=True)

if os.path.exists(CKPT_BEST):
    ckpt = torch.load(CKPT_BEST, map_location="cpu")
    model.load_state_dict(ckpt["model"], strict=False)
    model.eval()

export_split_predictions(model, test_loader, save_dir=FINAL_EXPORT, device=device)
final_metrics = eval_export_dir_dual(FINAL_EXPORT, exclude_meta=True)

FINAL_REPORT = os.path.join("final_reports", f"{EXP_NAME}_HRDS_report.json")
os.makedirs("final_reports", exist_ok=True)
with open(FINAL_REPORT, "w", encoding="utf-8") as f:
    json.dump(
        {
            "dataset": "HRDoc-Simple (HRDS)",
            "hrds_root": HRDS_ROOT,
            "best_ckpt": CKPT_BEST,
            "best_epoch": best_epoch,
            "best_score_macro_strict": best_score,
            "final_export_dir": FINAL_EXPORT,
            "final_metrics": final_metrics,
            "log_path": REPORT_PATH,
        },
        f,
        ensure_ascii=False,
        indent=2,
    )

print("\n===== HRDS Final Validation =====")
print("Final export :", FINAL_EXPORT)
print("Final report :", FINAL_REPORT)
print("Final metrics:", final_metrics)


In [None]:
# =========================
# HRDS 全流程：train -> best ckpt -> export -> eval
# 依赖你 notebook 已有定义（Run All 到此处后再运行本 cell）：
#   HRDHDataset, collate_doc
#   compute_M_cp_from_dataset
#   DSPSModel
#   FocalLoss
#   train_one_epoch
#   export_split_predictions
#   eval_export_dir_dual
# =========================

import os, json, time, shutil, random
import numpy as np
import torch
from torch.utils.data import DataLoader

# --------- A) 你必须改的路径：HRDS_ROOT ----------
HRDS_ROOT = r"C:\Users\tomra\Desktop\PAPER\final OBJ\HRDS"   # 改成你 HRDoc-Simple 根目录

# --------- B) 实验输出目录 ----------
EXP_NAME = "hrds_dsps_fullrun"
RUN_ROOT = "hrds_runs"
EXP_DIR = os.path.join(RUN_ROOT, EXP_NAME)
os.makedirs(EXP_DIR, exist_ok=True)

# --------- C) 基础对象检查：确保你前面 cell 已经定义完 ----------
required = [
    "HRDHDataset", "collate_doc",
    "compute_M_cp_from_dataset",
    "DSPSModel",
    "FocalLoss",
    "train_one_epoch",
    "export_split_predictions",
    "eval_export_dir_dual",
    "cfg",
]
missing = [n for n in required if n not in globals()]
if missing:
    raise RuntimeError(
        "你还没有运行到包含以下定义的 cell（或变量名不同）：\n"
        + "\n".join(missing)
        + "\n\n请先 Run All 直到模型/数据集/评估函数都定义完，再运行本 cell。"
    )

if not os.path.isdir(HRDS_ROOT):
    raise FileNotFoundError(f"HRDS_ROOT 不存在：{HRDS_ROOT}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[Device]", device)
print("[EXP_DIR]", EXP_DIR)

# --------- D) 固定随机种子（保持复现稳定） ----------
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --------- E) 构建 HRDS 数据集与 dataloader ----------
train_ds = HRDHDataset(HRDS_ROOT, split="train", max_len=cfg.max_len)
test_ds  = HRDHDataset(HRDS_ROOT, split="test",  max_len=cfg.max_len)

train_loader = DataLoader(
    train_ds,
    batch_size=1,         
    shuffle=True,
    num_workers=getattr(cfg, "num_workers", 0),
    collate_fn=collate_doc,
    pin_memory=False,
    persistent_workers=False,
)

test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=getattr(cfg, "num_workers", 0),
    collate_fn=collate_doc,
    pin_memory=False,
    persistent_workers=False,
)

print(f"[HRDS] train docs = {len(train_loader)}, test docs = {len(test_loader)}")

# --------- F) 类别/关系维度（来自你 dataset 定义） ----------
num_classes = len(ID2LABEL_14)
num_rel = len(REL2ID)
print("[Dims] num_classes =", num_classes, "num_rel =", num_rel)

# --------- G) 计算/加载 M_cp（soft-mask 先验） ----------
MCP_PATH = os.path.join(EXP_DIR, "M_cp_hrds.npy")
if os.path.exists(MCP_PATH):
    M_cp = np.asarray(np.load(MCP_PATH))
    print("[M_cp] loaded:", MCP_PATH, "shape=", M_cp.shape)
else:
    print("[M_cp] computing from HRDS train ...")
    ret = compute_M_cp_from_dataset(train_ds, num_classes=num_classes, pseudo_count=5.0)

    # 兼容：函数可能返回 (M_cp, stats) 或 {"M_cp":..., ...} 或直接 ndarray
    if isinstance(ret, tuple):
        M_cp = ret[0]
        mcp_extra = ret[1:]
    elif isinstance(ret, dict):
        M_cp = ret.get("M_cp", None)
        mcp_extra = {k:v for k,v in ret.items() if k != "M_cp"}
        if M_cp is None:
            raise RuntimeError("compute_M_cp_from_dataset 返回 dict，但不包含键 'M_cp'")
    else:
        M_cp = ret
        mcp_extra = None

    M_cp = np.asarray(M_cp)
    np.save(MCP_PATH, M_cp)
    print("[M_cp] saved:", MCP_PATH, "shape=", M_cp.shape)

    # 可选：把额外统计信息也存下来（不影响训练）
    if mcp_extra is not None:
        extra_path = os.path.join(EXP_DIR, "M_cp_hrds_extra.json")
        try:
            with open(extra_path, "w", encoding="utf-8") as f:
                json.dump(mcp_extra, f, ensure_ascii=False, indent=2, default=str)
            print("[M_cp] extra saved:", extra_path)
        except Exception as e:
            print("[M_cp] extra save skipped:", repr(e))

# --------- H) 建模 ----------
USE_TEXT = False
USE_VISUAL = False
USE_SOFTMASK = True

model = DSPSModel(
    num_classes=num_classes,
    num_rel=num_rel,
    M_cp=M_cp,
    cfg=cfg,
    use_text=USE_TEXT,
    use_visual=USE_VISUAL,
    use_softmask=USE_SOFTMASK,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=getattr(cfg, "lr", 2e-4),
    weight_decay=getattr(cfg, "weight_decay", 1e-2),
)

focal_cls = FocalLoss(gamma=getattr(cfg, "focal_gamma", 2.0), alpha=getattr(cfg, "focal_alpha", 0.25)).to(device)
focal_rel = FocalLoss(gamma=getattr(cfg, "focal_gamma", 2.0), alpha=getattr(cfg, "focal_alpha", 0.25)).to(device)

# ---------------------------------------------------------
# [关键整合修正] 安全推理函数：保证 export_tree_json 不再炸
# ---------------------------------------------------------
def _install_safe_predict_doc_with_rel_recompute():
    import torch

    @torch.no_grad()
    def _safe(model, doc, device):
        out = model(doc)

        # 必需：cls_logits / par_logits
        cls_logits = out["cls_logits"]   # (L,C)
        par_logits = out["par_logits"]   # list-like
        L = int(cls_logits.size(0))

        pred_cls = torch.argmax(cls_logits, dim=-1).detach().cpu().tolist()

        # parents: argmax + clamp
        pred_parent = [-1] * L
        for i in range(L):
            if i == 0:
                pred_parent[i] = -1
                continue
            logits_i = par_logits[i].view(-1)
            p = int(torch.argmax(logits_i).item())
            pred_parent[i] = p

        # clamp parents 防越界/自指/指向未来
        for i in range(L):
            p = pred_parent[i]
            if (p is None) or (not isinstance(p, int)) or (p < 0) or (p >= L) or (p == i) or (p >= i):
                pred_parent[i] = -1

        # rel logits：优先重算；若找不到 hidden seq 则回退 out["rel_logits"]；否则全零
        rel_logits = out.get("rel_logits", None)

        # 尝试找 hidden sequence (L,D) —— 不硬编码键名，按形状挑一个
        h_seq = None
        for k, v in out.items():
            if torch.is_tensor(v) and v.dim() == 2 and v.size(0) == L:
                if k in ("cls_logits", "rel_logits"):
                    continue
                h_seq = v
                break

        root = out.get("root", None)
        if root is None and h_seq is not None:
            root = h_seq[0:1]  # (1,D)

        if (h_seq is not None) and hasattr(model, "rel_head"):
            root_vec = root.squeeze(0) if (torch.is_tensor(root) and root.dim() == 2) else root
            rel_list = []
            for i in range(L):
                p = pred_parent[i]
                parent_vec = root_vec if (p < 0 or p >= L) else h_seq[p]
                feat = torch.cat([h_seq[i], parent_vec], dim=-1)
                rel_list.append(model.rel_head(feat))
            rel_logits = torch.stack(rel_list, dim=0)

        if rel_logits is None:
            R = getattr(model, "num_rel", num_rel)
            rel_logits = torch.zeros((L, R), device=cls_logits.device)

        pred_rel = torch.argmax(rel_logits, dim=-1).detach().cpu().tolist()

        # 关键：键名必须匹配 export_tree_json
        return {"pred_cls": pred_cls, "pred_parent": pred_parent, "pred_rel": pred_rel}

    # 覆写全局函数名，供 export_split_predictions 调用
    globals()["predict_doc_with_rel_recompute"] = _safe


# --------- I) 训练 + 每 epoch 验证 + 保存 best ----------
CKPT_BEST = os.path.join(EXP_DIR, "best_hrds.pt")
REPORT_PATH = os.path.join(EXP_DIR, "train_eval_log.json")

history = []
best_score = -1.0
best_epoch = -1

def _extract_macro_strict(metrics_dict):
    if not isinstance(metrics_dict, dict):
        return None
    candidates = [
        "steds_strict_macro", "macro_steds_strict", "steds_macro_strict",
        "steds_strict_mean",  "strict_mean",        "steds_strict",
    ]
    for k in candidates:
        if k in metrics_dict and isinstance(metrics_dict[k], (int, float)):
            return float(metrics_dict[k])
    for k, v in metrics_dict.items():
        if "strict" in k.lower() and isinstance(v, (int, float)):
            return float(v)
    return None

for ep in range(1, cfg.epochs + 1):
    print(f"\n========== [HRDS] Epoch {ep}/{cfg.epochs} ==========")

    # 1) train
    tr_logs = train_one_epoch(model, train_loader, optimizer, cfg, focal_cls, focal_rel)

    # 2) export + eval on test
    export_dir = os.path.join(EXP_DIR, f"export_ep{ep:02d}")
    if os.path.isdir(export_dir):
        shutil.rmtree(export_dir, ignore_errors=True)
    os.makedirs(export_dir, exist_ok=True)

    # [关键] 每次导出前安装安全推理函数，避免旧定义干扰
    _install_safe_predict_doc_with_rel_recompute()

    t0 = time.time()
    metrics = None
    try:
        export_split_predictions(model, test_loader, save_dir=export_dir, device=device)
        metrics = eval_export_dir_dual(export_dir, exclude_meta=True)
    except Exception as e:
        # 不中断训练：记录错误，继续下一个 epoch
        print("[WARN] export/eval failed at epoch", ep, "error:", repr(e))
        metrics = {"_error": repr(e)}
    eval_time = time.time() - t0

    # 3) best selection（以 Macro Strict STEDS 为主）
    score = _extract_macro_strict(metrics) if isinstance(metrics, dict) else None
    if score is None:
        print("[WARN] 无法从 metrics 中提取 strict 指标字段，本 epoch 不更新 best。metrics=", metrics)
        score = -1.0

    is_best = score > best_score
    if is_best:
        best_score = score
        best_epoch = ep
        torch.save({"model": model.state_dict(), "epoch": ep, "metrics": metrics}, CKPT_BEST)
        print(f"[BEST] updated: epoch={ep}, score={best_score:.6f} -> {CKPT_BEST}")

    row = {
        "epoch": ep,
        "train": tr_logs,
        "metrics": metrics,
        "macro_strict_for_select": score,
        "eval_time_sec": eval_time,
        "is_best": is_best,
    }
    history.append(row)

    with open(REPORT_PATH, "w", encoding="utf-8") as f:
        json.dump(
            {
                "exp_name": EXP_NAME,
                "hrds_root": HRDS_ROOT,
                "cfg": cfg.__dict__ if hasattr(cfg, "__dict__") else str(cfg),
                "use_text": USE_TEXT,
                "use_visual": USE_VISUAL,
                "use_softmask": USE_SOFTMASK,
                "best_epoch": best_epoch,
                "best_score_macro_strict": best_score,
                "best_ckpt": CKPT_BEST,
                "history": history,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    print("[Epoch Summary]")
    # 兼容：loss 可能是 float / list / numpy / tensor
    loss_last = None
    if isinstance(tr_logs, dict) and "loss" in tr_logs:
        v = tr_logs["loss"]
        if isinstance(v, (list, tuple)):
            loss_last = v[-1] if len(v) else None
        elif torch.is_tensor(v):
            loss_last = float(v.detach().cpu().item()) if v.numel() == 1 else float(v.detach().cpu().mean().item())
        else:
            # float / int / numpy scalar
            try:
                loss_last = float(v)
            except Exception:
                loss_last = v
    else:
            loss_last = tr_logs

    print("  train loss(last):", loss_last)

    print("  metrics:", metrics)
    print("  eval_time_sec:", eval_time)

print("\n========== HRDS Full Flow DONE ==========")
print("[BEST] epoch =", best_epoch, "macro_strict =", best_score)
print("[BEST CKPT]", CKPT_BEST)
print("[LOG JSON ]", REPORT_PATH)

print("\nPaper reference (Table 2 HRDS best, Document+Semantic+Vision+Soft-mask):")
print("Micro-STEDS ≈ 0.8143")
print("Macro-STEDS ≈ 0.8174")

# --------- J) 用 best ckpt 再导出一次（最终留档） ----------
FINAL_EXPORT = os.path.join(EXP_DIR, "export_best_final")
if os.path.isdir(FINAL_EXPORT):
    shutil.rmtree(FINAL_EXPORT, ignore_errors=True)
os.makedirs(FINAL_EXPORT, exist_ok=True)

_install_safe_predict_doc_with_rel_recompute()

if os.path.exists(CKPT_BEST):
    ckpt = torch.load(CKPT_BEST, map_location="cpu")
    state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    model.load_state_dict(state, strict=False)
    model.eval()

export_split_predictions(model, test_loader, save_dir=FINAL_EXPORT, device=device)
final_metrics = eval_export_dir_dual(FINAL_EXPORT, exclude_meta=True)

FINAL_REPORT = os.path.join("final_reports", f"{EXP_NAME}_HRDS_report.json")
os.makedirs("final_reports", exist_ok=True)
with open(FINAL_REPORT, "w", encoding="utf-8") as f:
    json.dump(
        {
            "dataset": "HRDoc-Simple (HRDS)",
            "hrds_root": HRDS_ROOT,
            "best_ckpt": CKPT_BEST,
            "best_epoch": best_epoch,
            "best_score_macro_strict": best_score,
            "final_export_dir": FINAL_EXPORT,
            "final_metrics": final_metrics,
            "log_path": REPORT_PATH,
        },
        f,
        ensure_ascii=False,
        indent=2,
    )

print("\n===== HRDS Final Validation =====")
print("Final export :", FINAL_EXPORT)
print("Final report :", FINAL_REPORT)
print("Final metrics:", final_metrics)
