In [1]:
from pathlib import Path
import json
import pandas as pd

# Resolve all dataset locations against the directory the notebook is run from.
ROOT = Path(".").resolve()
DATA_DIR = ROOT / "dochienet_dataset"
labels_dir = DATA_DIR / "labels"
images_dir = DATA_DIR / "images"
hres_dir = DATA_DIR / "hres_images"

# Split manifests: the "en"/"zh" language split and the "train"/"test" split.
en_zh = json.loads((DATA_DIR / "en_zh_split.json").read_text(encoding="utf-8"))
tt = json.loads((DATA_DIR / "train_test_split.json").read_text(encoding="utf-8"))
# Sets are used for the membership tests in get_lang / get_split.
en_ids, zh_ids = set(en_zh["en"]), set(en_zh["zh"])
train_ids, test_ids = set(tt["train"]), set(tt["test"])

def get_lang(doc_id):
    """Return "en" or "zh" according to the language split, else "unknown".

    Membership in ``en_ids`` is checked first, so an id listed in both sets
    is reported as "en".
    """
    if doc_id in en_ids:
        return "en"
    if doc_id in zh_ids:
        return "zh"
    return "unknown"

def get_split(doc_id):
    """Return "train" or "test" according to the train/test split, else "unknown".

    Membership in ``train_ids`` is checked first, so an id listed in both
    sets is reported as "train".
    """
    if doc_id in train_ids:
        return "train"
    if doc_id in test_ids:
        return "test"
    return "unknown"

def list_pages(doc_id, prefer_hres=True):
    """List the page images of one document, sorted by page number.

    Args:
        doc_id: Name of the document's sub-folder.
        prefer_hres: Look under ``hres_dir`` when true, ``images_dir`` otherwise.

    Returns:
        A list of ``Path`` objects matching ``page*.jpg`` (page1.jpg,
        page2.jpg, ...), ordered by the integer that follows "page".

    Raises:
        ValueError: If a matching file name has a non-numeric suffix after "page".
    """
    if prefer_hres:
        doc_folder = hres_dir / doc_id
    else:
        doc_folder = images_dir / doc_id

    def page_number(path):
        # "page12" -> 12; numeric sort so page10 does not precede page2.
        return int(path.stem.replace("page", ""))

    page_files = list(doc_folder.glob("page*.jpg"))
    page_files.sort(key=page_number)
    return page_files

# Build one index row per document id found in either split.
doc_ids = sorted(list(train_ids | test_ids))
rows = []
for did in doc_ids:
    lp = labels_dir / f"{did}.json"
    pages = list_pages(did, prefer_hres=True)
    rows.append({
        "doc_id": did,
        "split": get_split(did),
        "lang": get_lang(did),
        "label_path": str(lp),   # stored as str; existence of the file is not checked here
        "n_pages": len(pages),
        "page_paths": [str(p) for p in pages],
    })

# Rows are collected first and converted once.
index_df = pd.DataFrame(rows)
# Last expression of the cell: a tuple (preview, split counts, language counts) shown as output.
index_df.head(), index_df["split"].value_counts(), index_df["lang"].value_counts()


(       doc_id  split lang                                         label_path  \
 0  024v9632db  train   en  C:\Users\tomra\Desktop\PAPER\part2\dochienet_d...   
 1  02nk0izsev   test   zh  C:\Users\tomra\Desktop\PAPER\part2\dochienet_d...   
 2  02uch9a7af   test   en  C:\Users\tomra\Desktop\PAPER\part2\dochienet_d...   
 3  03uc9l6o7j  train   en  C:\Users\tomra\Desktop\PAPER\part2\dochienet_d...   
 4  04xbptr4ih  train   en  C:\Users\tomra\Desktop\PAPER\part2\dochienet_d...   
 
    n_pages                                         page_paths  
 0        5  [C:\Users\tomra\Desktop\PAPER\part2\dochienet_...  
 1        6  [C:\Users\tomra\Desktop\PAPER\part2\dochienet_...  
 2       19  [C:\Users\tomra\Desktop\PAPER\part2\dochienet_...  
 3       11  [C:\Users\tomra\Desktop\PAPER\part2\dochienet_...  
 4        8  [C:\Users\tomra\Desktop\PAPER\part2\dochienet_...  ,
 split
 train    1512
 test      161
 Name: count, dtype: int64,
 lang
 en    1110
 zh     563
 Name: count, dtype: int64

In [2]:
def load_label(label_path):
    """Read one label JSON file and return its pages, ordered elements and edges.

    Args:
        label_path: Path (str or Path) of a UTF-8 JSON file with the keys
            "pages" and "contents".

    Returns:
        A tuple ``(pages_meta, contents_sorted, edges)``:
          * pages_meta: the "pages" value as stored in the file.
          * contents_sorted: the "contents" list sorted by each item's
            "order" (items without one sort last).
          * edges: sorted list of unique ``(first, second)`` tuples taken
            from every two-item list found in the items' "linking" lists.
    """
    doc = json.loads(Path(label_path).read_text(encoding="utf-8"))
    pages_meta, contents = doc["pages"], doc["contents"]

    # Reading order; a large sentinel pushes items lacking "order" to the end.
    ordered = sorted(contents, key=lambda item: item.get("order", 10**9))

    # De-duplicate links across all elements; malformed entries are ignored.
    edge_set = {
        (pair[0], pair[1])
        for item in ordered
        for pair in item.get("linking", [])
        if isinstance(pair, list) and len(pair) == 2
    }
    return pages_meta, ordered, sorted(edge_set)

# Quick test: load the labels of the first indexed document and print a summary.
row = index_df.iloc[0]
pages_meta, elements, edges = load_label(row["label_path"])
print("pages:", len(pages_meta), "elements:", len(elements), "edges:", len(edges))
print("first element keys:", elements[0].keys())
print("first 3 edges:", edges[:3])


pages: 5 elements: 93 edges: 93
first element keys: dict_keys(['box', 'text', 'page', 'label', 'linking', 'id', 'order'])
first 3 edges: [(-1, 1), (-1, 2), (-1, 7)]


In [3]:
def build_parent_map(elements, edges, root_id=0):
    """Build a child -> parent mapping from (parent, child) edges.

    Args:
        elements: Iterable of dicts, each with an "id" key.
        edges: Iterable of ``(parent_id, child_id)`` pairs.
        root_id: Id used as the parent of elements that appear in no edge.

    Returns:
        dict mapping child id -> parent id. When a child appears with several
        different parents, the first one seen is kept. Parent ids are stored
        exactly as given in ``edges`` (a -1 parent stays -1). The element
        whose id equals ``root_id`` gets no default entry.
    """
    parent = {}
    # First edge wins: later edges for an already-known child are ignored.
    for parent_id, child_id in edges:
        parent.setdefault(child_id, parent_id)

    # Every remaining element (except the root itself) hangs off the root.
    for child_id in [element["id"] for element in elements]:
        if child_id != root_id and child_id not in parent:
            parent[child_id] = root_id
    return parent

parent_map = build_parent_map(elements, edges, root_id=0)
# Print a few child -> parent examples.
for cid in list(parent_map.keys())[:10]:
    print(cid, "->", parent_map[cid])


1 -> -1
2 -> -1
7 -> -1
8 -> -1
16 -> -1
53 -> -1
66 -> -1
64 -> 0
65 -> 0
92 -> 0


In [4]:
def norm_box_0_1000(box, page_w, page_h):
    """Scale a pixel box to integer coordinates in the closed range [0, 1000].

    Args:
        box: ``(x0, y0, x1, y1)`` in page units.
        page_w: Page width in the same units as the x coordinates.
        page_h: Page height in the same units as the y coordinates.

    Returns:
        ``[nx0, ny0, nx1, ny1]`` where every coordinate is truncated toward
        zero and then clamped to [0, 1000]. No y-axis flip is applied.

    Raises:
        ZeroDivisionError: If ``page_w`` or ``page_h`` is zero.
    """
    x0, y0, x1, y1 = box

    def scale(value, extent):
        # Clamp on BOTH sides: a box that overhangs the page must not yield
        # a coordinate below 0 or above 1000, whichever corner it is.
        return min(1000, max(0, int(1000 * value / extent)))

    return [scale(x0, page_w), scale(y0, page_h), scale(x1, page_w), scale(y1, page_h)]

# test first 5 elements on their pages
def page_size(pages_meta, page_num):
    """Return ``(width, height)`` of page ``page_num``.

    ``pages_meta`` is keyed by "page<N>" and each value holds "width" and
    "height".
    """
    entry = pages_meta["page{}".format(page_num)]
    width = entry["width"]
    height = entry["height"]
    return width, height

# Show id, page, label and the normalised box for the first five elements,
# using the size of the page each element sits on.
for e in elements[:5]:
    w,h = page_size(pages_meta, e["page"])
    print(e["id"], e["page"], e["label"], norm_box_0_1000(e["box"], w, h))


64 5 footer [508, 950, 647, 968]
65 5 page-number [666, 950, 680, 966]
92 6 page-number [318, 951, 336, 965]
93 6 footer [348, 950, 486, 966]
1 1 figure [322, 227, 623, 284]


In [5]:
def parse_document(label_path):
    """Parse one label JSON file into page metadata and normalised elements.

    Args:
        label_path: Path (str or Path) of a UTF-8 JSON file with the keys
            "pages" (``pageX -> {width, height}``) and "contents".

    Returns:
        ``(pages_meta, elements)`` where ``elements`` is a list of dicts in
        reading order (sorted by "order"; items without one sort last), each
        with the keys ``elem_id``, ``page_id``, ``bbox``, ``text``, ``label``,
        ``order`` and ``linking``. ``bbox`` is ``[x0, y0, x1, y1]`` scaled by
        the size of the element's own page and clamped to [0, 1000]; the y
        axis is not flipped.

    Raises:
        KeyError: If a required key ("pages", "contents", "id", "page",
            "box") or the element's page entry is missing.
        ZeroDivisionError: If a page has zero width or height.
    """
    obj = json.loads(Path(label_path).read_text(encoding="utf-8"))
    pages_meta = obj["pages"]
    # Sort by reading order; elements without "order" go to the end.
    contents = sorted(obj["contents"], key=lambda x: x.get("order", 10**9))

    def to_1000(value, extent):
        # Annotated boxes can overhang the page slightly; clamp so the
        # advertised 0-1000 range actually holds for every coordinate.
        return min(1000, max(0, int(1000 * value / extent)))

    elements = []
    for c in contents:
        page = c["page"]
        page_info = pages_meta[f"page{page}"]
        pw, ph = page_info["width"], page_info["height"]

        x0, y0, x1, y1 = c["box"]
        bbox_1000 = [to_1000(x0, pw), to_1000(y0, ph), to_1000(x1, pw), to_1000(y1, ph)]

        elements.append({
            "elem_id": c["id"],
            "page_id": page,
            "bbox": bbox_1000,
            "text": c.get("text", ""),
            "label": c.get("label", ""),
            "order": c.get("order", -1),
            "linking": c.get("linking", []),
        })

    return pages_meta, elements


# Sanity check: parse the first indexed document and print its first five elements.
row = index_df.iloc[0]
pages_meta, elements = parse_document(row["label_path"])
print("elements:", len(elements))
for e in elements[:5]:
    print(e)


elements: 93
{'elem_id': 64, 'page_id': 5, 'bbox': [508, 950, 647, 968], 'text': 'Living Without Food', 'label': 'footer', 'order': 0, 'linking': [[0, 64]]}
{'elem_id': 65, 'page_id': 5, 'bbox': [666, 950, 680, 966], 'text': '3', 'label': 'page-number', 'order': 0, 'linking': [[0, 65]]}
{'elem_id': 92, 'page_id': 6, 'bbox': [318, 951, 336, 965], 'text': '4', 'label': 'page-number', 'order': 0, 'linking': [[0, 92]]}
{'elem_id': 93, 'page_id': 6, 'bbox': [348, 950, 486, 966], 'text': 'Living Without Food', 'label': 'footer', 'order': 0, 'linking': [[0, 93]]}
{'elem_id': 1, 'page_id': 1, 'bbox': [322, 227, 623, 284], 'text': '', 'label': 'figure', 'order': 1, 'linking': [[-1, 1]]}


In [6]:
# Cell G (fixed): build parent map and normalize parent ids

def build_parent_map(elements, root_id=0):
    """Build a child -> parent mapping from the elements' "linking" lists.

    Args:
        elements: List of dicts as produced by ``parse_document``; each must
            contain "elem_id" and may contain "linking", a list of
            ``[parent_id, child_id]`` pairs.
        root_id: Id that stands for the document root.

    Returns:
        dict ``child_id -> parent_id``. The first link seen for a child wins.
        A parent of -1 is stored as ``root_id`` for every linked child (not
        only for ids present in ``elements``), and elements that appear in
        no link are attached to ``root_id``.
    """
    parent_map = {}

    # 1. Extract parent/child pairs from "linking"; malformed entries are skipped.
    for e in elements:
        for pair in e.get("linking", []):
            if isinstance(pair, list) and len(pair) == 2:
                p, c = pair
                if c not in parent_map:
                    # -1 marks "no parent" in the labels; normalise it here so
                    # no entry can keep a negative parent.
                    parent_map[c] = root_id if p == -1 else p

    # 2. Elements without any recorded parent hang off the root.
    for e in elements:
        parent_map.setdefault(e["elem_id"], root_id)

    return parent_map


# ====== sanity check ======
parent_map = build_parent_map(elements)

# No parent id may be negative after normalisation.
bad = [cid for cid, p in parent_map.items() if p < 0]
print("invalid parents (<0):", bad[:10])
print("total elements:", len(elements))
print("total parents:", len(parent_map))

# Print the first few entries for a visual check.
for e in elements[:10]:
    print(f"{e['elem_id']} -> parent {parent_map[e['elem_id']]}")


invalid parents (<0): []
total elements: 93
total parents: 93
64 -> parent 0
65 -> parent 0
92 -> parent 0
93 -> parent 0
1 -> parent 0
2 -> parent 0
3 -> parent 2
4 -> parent 2
5 -> parent 2
6 -> parent 2


In [7]:
# CELL H: build parent classification targets (indices)

ROOT_ID = 0

# 1) Candidate parent list: ROOT followed by every element id, keeping the
#    reading order of `elements`.
elem_ids = [e["elem_id"] for e in elements]
candidate_parents = [ROOT_ID] + elem_ids

# 2) Map id -> index in the candidate list.
# NOTE(review): if an element's own id equals ROOT_ID, its later entry
# overwrites index 0 in this dict — confirm element ids never equal ROOT_ID.
id2idx = {pid: i for i, pid in enumerate(candidate_parents)}

# 3) Supervision target for each element: the candidate index of its parent.
parent_target_idx = []
bad_refs = []

for cid in elem_ids:
    pid = parent_map[cid]          # parent id
    if pid not in id2idx:
        # Parent is not a known candidate: record it and keep a None placeholder
        # so parent_target_idx stays aligned with elem_ids.
        bad_refs.append((cid, pid))
        parent_target_idx.append(None)
    else:
        parent_target_idx.append(id2idx[pid])

print("num elements:", len(elem_ids))
print("num candidates (ROOT + elems):", len(candidate_parents))
print("bad parent refs (should be empty):", bad_refs[:10])

# 4) Print the first 10 samples for verification.
for i in range(min(10, len(elem_ids))):
    cid = elem_ids[i]
    pid = parent_map[cid]
    print(f"child {cid:>3} -> parent {pid:>3} | target_idx={parent_target_idx[i]}")


num elements: 93
num candidates (ROOT + elems): 94
bad parent refs (should be empty): []
child  64 -> parent   0 | target_idx=0
child  65 -> parent   0 | target_idx=0
child  92 -> parent   0 | target_idx=0
child  93 -> parent   0 | target_idx=0
child   1 -> parent   0 | target_idx=0
child   2 -> parent   0 | target_idx=0
child   3 -> parent   2 | target_idx=6
child   4 -> parent   2 | target_idx=6
child   5 -> parent   2 | target_idx=6
child   6 -> parent   2 | target_idx=6


In [8]:
# CELL J: tokenizer + element-wise packing into chunks (max_tokens=512)

from transformers import AutoTokenizer

# ====== parameters you need to change ======
MODEL_NAME = "bert-base-uncased"   # used to get the pipeline running; switch to the GeoLayoutLM tokenizer later
MAX_TOKENS = 512
DOC_ID = "0a51k4mj12"
# ===========================

# 1) load doc sample (relies on the results of the earlier parse_document /
#    build_parent_map cells; the original note also names build_parent_labels,
#    which is not defined in the cells visible here)
label_path = DATA_DIR / "labels" / f"{DOC_ID}.json"
obj = json.loads(label_path.read_text(encoding="utf-8"))
pages_meta = obj["pages"]
# Reading order; elements without "order" sort last.
contents = sorted(obj["contents"], key=lambda x: x.get("order", 10**9))

# 2) normalize boxes to 0-1000 per page (no y-flip)
def box_to_1000(box, page_w, page_h):
    """Scale ``(x0, y0, x1, y1)`` by the page size to integers in [0, 1000].

    Each coordinate is truncated toward zero and clamped to [0, 1000] so a
    box that overhangs the page cannot produce an out-of-range value. The y
    axis is not flipped.

    Raises:
        ZeroDivisionError: If ``page_w`` or ``page_h`` is zero.
    """
    x0, y0, x1, y1 = box
    scaled = (
        int(1000 * x0 / page_w),
        int(1000 * y0 / page_h),
        int(1000 * x1 / page_w),
        int(1000 * y1 / page_h),
    )
    return [min(1000, max(0, v)) for v in scaled]

# Build the element list in reading order; each box is normalised with the
# size of the page the element is on.
elements = []
for c in contents:
    p = c["page"]
    pw, ph = pages_meta[f"page{p}"]["width"], pages_meta[f"page{p}"]["height"]
    elements.append({
        "id": c["id"],
        "page": p,
        "label": c.get("label", ""),
        "text": c.get("text", ""),
        "bbox": box_to_1000(c["box"], pw, ph),
        "linking": c.get("linking", []),
        "order": c.get("order", -1),
    })

# 3) build parent targets (你之前已验证过，这里简化直接复用逻辑)
def build_parent_map(elements, root_id=0):
    """Map each child id to its parent id using the elements' "linking" lists.

    Args:
        elements: List of dicts with an "id" key and an optional "linking"
            list of ``[parent_id, child_id]`` pairs.
        root_id: Parent assigned to elements with no link or a -1 parent.

    Returns:
        dict ``child_id -> parent_id``; the first link seen for a child wins.
        Only ids that belong to ``elements`` have a missing or -1 parent
        replaced by ``root_id``.
    """
    links = {}
    for element in elements:
        for link in element.get("linking", []):
            # Ignore anything that is not a two-item list.
            if not (isinstance(link, list) and len(link) == 2):
                continue
            parent_id, child_id = link
            links.setdefault(child_id, parent_id)

    for element in elements:
        element_id = element["id"]
        # A missing entry and an explicit -1 are both treated as "child of root".
        if links.get(element_id, -1) == -1:
            links[element_id] = root_id
    return links

parent_map = build_parent_map(elements, root_id=0)

# 4) tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# 5) pack elements into chunks
# Elements are appended whole, in reading order. A chunk is closed as soon as
# the next element would push it past MAX_TOKENS, so an element is never split
# across two chunks.
cur = {"input_ids": [], "attention_mask": [], "bbox": [], "page_id": [], "elem_id": [], "inner_pos": []}
cur_len = 0

nonlocal_chunks = []
for e in elements:
    text = e["text"] if e["text"] is not None else ""
    if text.strip() == "":
        # Keep a placeholder token for empty text so the element is not dropped.
        # The tokenizer's own unk/cls tokens may differ between models, so the
        # literal text "[UNK]" is tokenized instead.
        text = "[UNK]"

    enc = tok(
        text,
        add_special_tokens=False,
        return_attention_mask=True,
        return_offsets_mapping=False,
        truncation=False
    )
    ids = enc["input_ids"]
    am  = enc["attention_mask"]

    # Inner-layout position: index of each token inside its element (0..len-1).
    inner = list(range(len(ids)))

    # Token-level box / page / element id: every token repeats its element's values.
    bxs = [e["bbox"]] * len(ids)
    pids = [e["page"]] * len(ids)
    eids = [e["id"]]  * len(ids)

    # If adding this element would exceed MAX_TOKENS, close the current chunk first.
    if cur_len > 0 and cur_len + len(ids) > MAX_TOKENS:
        nonlocal_chunks.append(cur)
        cur = {"input_ids": [], "attention_mask": [], "bbox": [], "page_id": [], "elem_id": [], "inner_pos": []}
        cur_len = 0

    # A single element longer than MAX_TOKENS is hard-truncated (good enough to
    # get the pipeline running; a faithful reproduction should treat this more carefully).
    if len(ids) > MAX_TOKENS:
        ids = ids[:MAX_TOKENS]
        am  = am[:MAX_TOKENS]
        inner = inner[:MAX_TOKENS]
        bxs = bxs[:MAX_TOKENS]
        pids = pids[:MAX_TOKENS]
        eids = eids[:MAX_TOKENS]

    cur["input_ids"].extend(ids)
    cur["attention_mask"].extend(am)
    cur["bbox"].extend(bxs)
    cur["page_id"].extend(pids)
    cur["elem_id"].extend(eids)
    cur["inner_pos"].extend(inner)
    cur_len += len(ids)

# flush last
if cur_len > 0:
    nonlocal_chunks.append(cur)

chunks = nonlocal_chunks

print("DOC_ID:", DOC_ID)
print("num elements:", len(elements))
print("num chunks:", len(chunks))
print("chunk lens (first 10):", [len(c["input_ids"]) for c in chunks[:10]])
# default=0 keeps the summary from raising when the document has no elements.
print("max chunk len:", max((len(c["input_ids"]) for c in chunks), default=0))


  from .autonotebook import tqdm as notebook_tqdm


DOC_ID: 0a51k4mj12
num elements: 192
num chunks: 18
chunk lens (first 10): [476, 408, 243, 276, 279, 506, 449, 475, 504, 494]
max chunk len: 509


In [9]:
# CELL K: build element-level pooling indices per chunk

# 输入假设：
# chunks: list of dicts
#   每个 chunk 包含：
#     - input_ids
#     - elem_id  (token-level element id)
# elements: list of element dicts（按 reading order）

from collections import OrderedDict

def build_pooling_indices(chunks):
    """Find, per chunk, each element's first token position.

    Args:
        chunks: List of dicts; each has "elem_id", the token-level list of
            element ids for that chunk.

    Returns:
        One dict per chunk::

            {
              "elem_ids":        [e1, e2, ...],  # elements in order of first appearance
              "first_token_idx": [i1, i2, ...],  # index of each element's first token
            }
    """
    def first_occurrences(token_elem_ids):
        # dict keeps insertion order, so keys come out in first-seen order.
        first = {}
        for position, elem in enumerate(token_elem_ids):
            first.setdefault(elem, position)
        return first

    result = []
    for chunk in chunks:
        first = first_occurrences(chunk["elem_id"])
        result.append({
            "elem_ids": list(first),
            "first_token_idx": list(first.values()),
        })
    return result


pooling = build_pooling_indices(chunks)

# ====== sanity check ======
print("num chunks:", len(pooling))

for i in range(min(3, len(pooling))):
    p = pooling[i]
    print(f"\nChunk {i}:")
    print("  num elements in chunk:", len(p["elem_ids"]))
    print("  first 5 elem_ids:", p["elem_ids"][:5])
    print("  first 5 token idx:", p["first_token_idx"][:5])

# Check: every pooled token index must be < the chunk length.
# (max() raises ValueError on a chunk with no tokens.)
for i, (p, c) in enumerate(zip(pooling, chunks)):
    assert max(p["first_token_idx"]) < len(c["input_ids"]), f"Index overflow in chunk {i}"

print("\nPooling indices sanity check passed.")


num chunks: 18

Chunk 0:
  num elements in chunk: 40
  first 5 elem_ids: [1, 2, 13, 14, 24]
  first 5 token idx: [0, 14, 16, 30, 32]

Chunk 1:
  num elements in chunk: 8
  first 5 elem_ids: [9, 10, 11, 12, 15]
  first 5 token idx: [0, 93, 123, 234, 275]

Chunk 2:
  num elements in chunk: 4
  first 5 elem_ids: [19, 20, 21, 22]
  first 5 token idx: [0, 153, 161, 192]

Pooling indices sanity check passed.


In [10]:
# CELL L: build document-level element sequence for decoder (with ROOT)

from collections import OrderedDict

def build_document_element_sequence(chunks, pooling):
    """Merge per-chunk pooling info into one document-level element sequence.

    Args:
        chunks: List of chunk dicts; only used to pair each chunk with its
            pooling entry (iteration stops at the shorter of the two lists).
        pooling: Per-chunk dicts with "elem_ids" and "first_token_idx".

    Returns:
        ``(doc_elem_ids, doc_elem_positions)``:
          * doc_elem_ids: element ids, de-duplicated, in order of first appearance.
          * doc_elem_positions: OrderedDict ``elem_id -> (chunk_idx, token_idx)``
            giving the pooling token of each element.
    """
    positions = OrderedDict()
    for c_idx, (_, pool) in enumerate(zip(chunks, pooling)):
        pairs = zip(pool["elem_ids"], pool["first_token_idx"])
        for elem, t_idx in pairs:
            # Only the first position of an element is recorded.
            positions.setdefault(elem, (c_idx, t_idx))
    return list(positions), positions


doc_elem_ids, doc_elem_positions = build_document_element_sequence(chunks, pooling)

# ====== add ROOT ======
ROOT_ID = 0
decoder_elem_ids = [ROOT_ID] + doc_elem_ids   # decoder input order
# ROOT's embedding is supplied by a learnable embedding inside the model;
# here it is only a placeholder entry.

# ====== sanity checks ======
print("num document elements (no root):", len(doc_elem_ids))
print("num decoder elements (with root):", len(decoder_elem_ids))

# Check: the merged ids must match the original elements (compared as sets).
orig_elem_ids = [e["id"] for e in elements]
assert set(doc_elem_ids) == set(orig_elem_ids), "Element ID mismatch after merging chunks"

# Print the first few ids to verify the order.
print("\nFirst 10 decoder element ids:")
print(decoder_elem_ids[:10])

# Print where one element's pooling token comes from.
sample_eid = doc_elem_ids[0]
print(f"\nElement {sample_eid} comes from (chunk_idx, token_idx):",
      doc_elem_positions[sample_eid])

print("\nCELL L sanity check passed.")


num document elements (no root): 192
num decoder elements (with root): 193

First 10 decoder element ids:
[0, 1, 2, 13, 14, 24, 25, 35, 36, 48]

Element 1 comes from (chunk_idx, token_idx): (0, 0)

CELL L sanity check passed.


In [11]:
# CELL O: wrap into an nn.Module (single-document), reuse existing prepared inputs

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

class DHFormerMini(nn.Module):
    """
    Minimal DHFormer-like wrapper for a single document:
      - real text-layout encoder (LayoutLMv3/GeoLayoutLM/etc.)
      - element pooling by first-token
      - root learnable embedding
      - bilinear parent scorer
    Input format:
      chunks: list of dict with token-level fields
              ("input_ids", "attention_mask", "bbox")
      doc_elem_ids: list of element ids (no root)
      doc_elem_positions: dict elem_id -> (chunk_idx, token_idx)
      parent_map: dict child_elem_id -> parent_elem_id (root=0)

    Input tensors are created on the device that holds this module's
    parameters, so the module does not depend on a notebook-level ``device``.
    """
    def __init__(self, encoder: nn.Module, hidden_size: int = 768, root_id: int = 0):
        super().__init__()
        self.encoder = encoder
        self.hidden_size = hidden_size
        self.root_id = root_id

        # Learnable embedding that represents the document root.
        self.root_emb = nn.Parameter(torch.zeros(hidden_size))
        nn.init.normal_(self.root_emb, mean=0.0, std=0.02)

        # score(child, parent) = child^T W parent, one output, no bias.
        self.bilinear = nn.Bilinear(hidden_size, hidden_size, 1, bias=False)

    def encode_chunks(self, chunks):
        """Run the encoder on each chunk; return a list of [num_tokens, hidden] tensors."""
        dev = self.root_emb.device
        outs = []
        for c in chunks:
            # Each chunk is encoded on its own with a batch dimension of 1.
            input_ids = torch.tensor(c["input_ids"], dtype=torch.long, device=dev).unsqueeze(0)
            attention_mask = torch.tensor(c["attention_mask"], dtype=torch.long, device=dev).unsqueeze(0)
            bbox = torch.tensor(c["bbox"], dtype=torch.long, device=dev).unsqueeze(0)

            o = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                return_dict=True
            )
            outs.append(o.last_hidden_state.squeeze(0))
        return outs

    def forward(self, chunks, doc_elem_ids, doc_elem_positions, parent_map):
        """Score every (child, candidate parent) pair and return ``(loss, logits)``.

        Returns:
            loss: cross-entropy between the logits and the parent indices
                derived from ``parent_map``.
            logits: tensor [M, M+1]; row i scores element ``doc_elem_ids[i]``
                against ROOT (column 0) followed by all M elements.

        Raises:
            KeyError: If ``parent_map`` lacks an element or names a parent
                that is neither the root nor in ``doc_elem_ids``.
        """
        dev = self.root_emb.device

        # 1) encoder forward
        # The encoder is forced into eval mode on every call (as in the
        # original); gradients still flow through it when this module is in
        # training mode.
        self.encoder.eval()
        with torch.set_grad_enabled(self.training):
            enc_outs = self.encode_chunks(chunks)

        # 2) element pooling: the hidden state of each element's first token
        elem_emb = {
            eid: enc_outs[chunk_idx][tok_idx]
            for eid, (chunk_idx, tok_idx) in doc_elem_positions.items()
        }

        # 3) decoder input (ROOT + elements)
        decoder_inputs = torch.stack(
            [self.root_emb] + [elem_emb[eid] for eid in doc_elem_ids],
            dim=0
        )  # [M+1, H]

        # 4) build parent targets (indices into ROOT + elements)
        decoder_elem_ids = [self.root_id] + doc_elem_ids
        id2idx = {pid: i for i, pid in enumerate(decoder_elem_ids)}
        parent_target = torch.tensor(
            [id2idx[parent_map[eid]] for eid in doc_elem_ids],
            dtype=torch.long,
            device=dev
        )  # [M]

        # 5) bilinear parent scoring: logits [M, M+1]
        child_emb = decoder_inputs[1:]   # [M, H]
        parent_emb = decoder_inputs      # [M+1, H]

        # nn.Bilinear computes x1^T W x2 with W of shape [out, in1, in2]; with
        # a single output and no bias, all pairwise scores are one matrix
        # product instead of M separate calls on repeated tensors.
        weight = self.bilinear.weight[0]                 # [H, H]
        logits = child_emb @ weight @ parent_emb.t()     # [M, M+1]

        loss = F.cross_entropy(logits, parent_target)
        return loss, logits


# ====== instantiate encoder (use the same one you tested) ======
# NOTE(review): `chunks` used below was built in Cell J with the
# "bert-base-uncased" tokenizer, while the encoder loaded here is LayoutLMv3.
# The token ids therefore come from a different vocabulary than the encoder
# expects — confirm this is only a shape/plumbing smoke test.
MODEL_NAME = "microsoft/layoutlmv3-base"
encoder = AutoModel.from_pretrained(MODEL_NAME).to(device)

model = DHFormerMini(encoder=encoder, hidden_size=768, root_id=0).to(device)
model.eval()

# ====== quick forward test ======
# Untrained scorer, no gradients: only checks that the pipeline runs end to end.
with torch.no_grad():
    loss, logits = model(chunks, doc_elem_ids, doc_elem_positions, parent_map)

print("loss:", float(loss))
print("logits shape:", tuple(logits.shape))




loss: 10.119296073913574
logits shape: (192, 193)


In [12]:
# CELL Q: Dataset + DataLoader (batch_size=1)

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from pathlib import Path
import json

# ====== settings ======
MODEL_NAME = "microsoft/layoutlmv3-base"   # tokenizer kept aligned with the encoder
MAX_TOKENS = 512
USE_HRES = True   # NOTE(review): not read by any code visible in this notebook section
# ======================

# Module-level tokenizer; DocHieNetDocDataset._tokenize_and_chunk reads it as a global.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

class DocHieNetDocDataset(Dataset):
    """Document-level dataset: one item is one fully tokenized document.

    Each item is a dict with the keys ``doc_id``, ``chunks``,
    ``doc_elem_ids``, ``doc_elem_positions`` and ``parent_map``.

    Relies on two notebook globals at call time: ``tokenizer`` (a layout-aware
    tokenizer called with words and boxes) and ``MAX_TOKENS`` (chunk size).
    """

    def __init__(self, index_df, split="train"):
        """Keep only the rows of ``index_df`` whose "split" equals ``split``."""
        self.df = index_df[index_df["split"] == split].reset_index(drop=True)
        self.root_id = 0

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _box_to_1000(box, pw, ph):
        """Scale ``(x0, y0, x1, y1)`` by the page size to integers in [0, 1000].

        Coordinates are clamped because annotated boxes may overhang the page,
        and the layout encoder indexes position embeddings by these values.
        """
        x0, y0, x1, y1 = box
        scaled = (
            int(1000 * x0 / pw),
            int(1000 * y0 / ph),
            int(1000 * x1 / pw),
            int(1000 * y1 / ph),
        )
        return [min(1000, max(0, v)) for v in scaled]

    @staticmethod
    def _new_chunk():
        """Return an empty chunk with all token-level fields."""
        return {"input_ids": [], "attention_mask": [], "bbox": [], "elem_id": [], "page_id": [], "inner_pos": []}

    def _load_elements(self, label_path):
        """Read a label file and return its elements in reading order."""
        obj = json.loads(Path(label_path).read_text(encoding="utf-8"))
        pages_meta = obj["pages"]
        # Elements without "order" sort last.
        contents = sorted(obj["contents"], key=lambda x: x.get("order", 10**9))

        elements = []
        for c in contents:
            p = c["page"]
            pw, ph = pages_meta[f"page{p}"]["width"], pages_meta[f"page{p}"]["height"]
            elements.append({
                "id": c["id"],
                "page": p,
                "text": c.get("text", "") or "",
                "bbox": self._box_to_1000(c["box"], pw, ph),
                "linking": c.get("linking", []),
                "order": c.get("order", -1),
            })
        return elements

    def _build_parent_map(self, elements):
        """Return ``child_id -> parent_id``; first link wins, -1/missing -> root."""
        parent_map = {}
        for e in elements:
            for pair in e.get("linking", []):
                if isinstance(pair, list) and len(pair) == 2:
                    p, c = pair
                    if c not in parent_map:
                        # Normalise the "-1 = no parent" marker for every
                        # linked child, not only for ids present in elements.
                        parent_map[c] = self.root_id if p == -1 else p
        for e in elements:
            parent_map.setdefault(e["id"], self.root_id)
        return parent_map

    @staticmethod
    def _placeholder_tokens(bbox):
        """Return ``(input_ids, attention_mask)`` for an element without usable text.

        Uses the tokenizer's own unknown-token id. The literal text "[UNK]" is
        a BERT convention; a tokenizer with a different vocabulary splits it
        like ordinary text instead of mapping it to its unknown token.
        """
        unk_id = getattr(tokenizer, "unk_token_id", None)
        if unk_id is not None:
            return [unk_id], [1]
        # Tokenizer without an unknown token: fall back to tokenizing the text.
        enc = tokenizer(
            ["[UNK]"],
            boxes=[bbox],
            add_special_tokens=False,
            return_attention_mask=True,
            truncation=False
        )
        return list(enc["input_ids"]), list(enc["attention_mask"])

    def _tokenize_and_chunk(self, elements):
        """Tokenize every element and pack whole elements into chunks of at most MAX_TOKENS."""
        chunks = []
        cur = self._new_chunk()
        cur_len = 0

        def flush():
            nonlocal cur, cur_len
            if cur_len > 0:
                chunks.append(cur)
            cur = self._new_chunk()
            cur_len = 0

        for e in elements:
            # LayoutLMv3TokenizerFast expects pretokenized words when text-only is provided
            words = (e["text"] or "").split()

            ids, am = [], []
            if words:
                # We only have element-level bbox, so we replicate it for each word
                enc = tokenizer(
                    words,
                    boxes=[e["bbox"]] * len(words),
                    add_special_tokens=False,
                    return_attention_mask=True,
                    truncation=False
                )
                ids = list(enc["input_ids"])
                am = list(enc["attention_mask"])

            if not ids:
                # Empty text (or text that produced no tokens): keep one
                # placeholder token so the element is not lost downstream.
                ids, am = self._placeholder_tokens(e["bbox"])

            # inner-layout positions: token position inside this element
            inner = list(range(len(ids)))

            # token-level metadata replicate
            bxs = [e["bbox"]] * len(ids)
            eids = [e["id"]] * len(ids)
            pids = [e["page"]] * len(ids)

            # Close the current chunk if this element would not fit.
            if cur_len > 0 and cur_len + len(ids) > MAX_TOKENS:
                flush()

            # A single element longer than MAX_TOKENS is hard-truncated.
            if len(ids) > MAX_TOKENS:
                ids = ids[:MAX_TOKENS]
                am = am[:MAX_TOKENS]
                inner = inner[:MAX_TOKENS]
                bxs = bxs[:MAX_TOKENS]
                eids = eids[:MAX_TOKENS]
                pids = pids[:MAX_TOKENS]

            cur["input_ids"].extend(ids)
            cur["attention_mask"].extend(am)
            cur["bbox"].extend(bxs)
            cur["elem_id"].extend(eids)
            cur["page_id"].extend(pids)
            cur["inner_pos"].extend(inner)
            cur_len += len(ids)

        flush()
        return chunks

    @staticmethod
    def _build_pooling(chunks):
        """Per chunk, list the elements in first-seen order with their first token index."""
        from collections import OrderedDict
        pooling = []
        for c in chunks:
            seen = OrderedDict()
            for i, eid in enumerate(c["elem_id"]):
                if eid not in seen:
                    seen[eid] = i
            pooling.append({"elem_ids": list(seen.keys()), "first_token_idx": list(seen.values())})
        return pooling

    @staticmethod
    def _merge_doc_positions(chunks, pooling):
        """Merge per-chunk pooling into ``(doc_elem_ids, elem_id -> (chunk_idx, token_idx))``.

        ``chunks`` is accepted for signature compatibility and is not read.
        """
        from collections import OrderedDict
        pos = OrderedDict()
        for chunk_idx, pool in enumerate(pooling):
            for eid, tok_idx in zip(pool["elem_ids"], pool["first_token_idx"]):
                if eid not in pos:
                    pos[eid] = (chunk_idx, tok_idx)
        doc_elem_ids = list(pos.keys())
        return doc_elem_ids, pos

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        doc_id = row["doc_id"]
        label_path = row["label_path"]

        elements = self._load_elements(label_path)
        parent_map = self._build_parent_map(elements)

        chunks = self._tokenize_and_chunk(elements)
        pooling = self._build_pooling(chunks)
        doc_elem_ids, doc_elem_positions = self._merge_doc_positions(chunks, pooling)

        return {
            "doc_id": doc_id,
            "chunks": chunks,
            "doc_elem_ids": doc_elem_ids,
            "doc_elem_positions": doc_elem_positions,
            "parent_map": parent_map,
        }

def collate_fn(batch):
    """Collate for batch_size=1: return the single sample dict unchanged.

    Only the first entry of ``batch`` is returned; any further entries are ignored.
    """
    first_item = batch[0]
    return first_item

# One document per batch; collate_fn hands the sample dict through unchanged.
train_ds = DocHieNetDocDataset(index_df, split="train")
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn)

# quick iterate one batch
# shuffle=True with no seed set here, so the document drawn can differ between runs.
b = next(iter(train_loader))
print("doc_id:", b["doc_id"])
print("num chunks:", len(b["chunks"]))
print("num elements:", len(b["doc_elem_ids"]))


doc_id: u2v6z8fqbg
num chunks: 13
num elements: 105


In [13]:
# CELL U (REPLACE): SSA decoder with cached block-diag attn mask + model_ssa

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

def sinusoidal_position_encoding(pos, dim):
    """Return sinusoidal encodings of shape ``[len(pos), dim]``.

    Args:
        pos: 1-D tensor of positions (any numeric dtype; converted to float).
        dim: Encoding width. Even columns hold ``sin``, odd columns ``cos``.
            Odd widths are supported: the last sine column then has no
            cosine partner.

    Returns:
        Float tensor on the same device as ``pos``.
    """
    pos = pos.float().unsqueeze(1)                                   # [N, 1]
    div_term = torch.exp(
        torch.arange(0, dim, 2, device=pos.device).float() * (-math.log(10000.0) / dim)
    )                                                                # [ceil(dim/2)]
    angles = pos * div_term                                          # [N, ceil(dim/2)]

    pe = torch.zeros(pos.size(0), dim, device=pos.device)
    pe[:, 0::2] = torch.sin(angles)
    # There are only dim // 2 odd columns; slicing avoids the shape mismatch
    # that an odd `dim` would otherwise cause.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe

# ====== KEY SPEED FIX: cache attention masks ======
# Masks are reused across layers and calls; keyed by (seq_len, window, device).
_ATTN_MASK_CACHE = {}

def build_block_diag_attn_mask(seq_len: int, window: int, device):
    """
    bool mask [L, L], True means masked (NOT allowed to attend).

    Positions are grouped into consecutive windows of ``window`` items and a
    position may only attend to positions in its own window.

    Cached by (L, window, device). The full device (type AND index) is part
    of the key so a mask created on one GPU is never handed to a caller
    working on another. ``device`` may be a ``torch.device`` or a string.
    The returned tensor is shared between callers and must not be modified
    in place.

    Raises:
        ValueError: If ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")
    device = torch.device(device)
    key = (seq_len, window, device)
    m = _ATTN_MASK_CACHE.get(key, None)
    if m is not None:
        return m
    win_id = torch.arange(seq_len, device=device) // window
    mask = ~(win_id.unsqueeze(0) == win_id.unsqueeze(1))  # True if different window => masked
    _ATTN_MASK_CACHE[key] = mask
    return mask

class SSALayer(nn.Module):
    """Pre-norm transformer layer whose self-attention is restricted to windows.

    The sequence is cut into consecutive windows of ``window`` positions and
    attention is masked to stay inside a window. With a non-zero ``shift`` the
    sequence is rolled left by ``shift`` before attention and rolled back
    afterwards, which moves the window boundaries. The roll is cyclic, so the
    last window then also contains the first ``shift`` positions of the
    sequence.
    """

    def __init__(self, hidden_size=768, heads=8, ff=2048, dropout=0.1, window=48, shift=0):
        super().__init__()
        self.window = window
        self.shift = shift

        # Attention sub-block.
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, heads, dropout=dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        # Feed-forward sub-block.
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff, hidden_size),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply windowed attention and the feed-forward block to ``x`` of shape [B, L, H]."""
        _, seq_len, _ = x.shape
        shifted = self.shift != 0

        if shifted:
            x = torch.roll(x, shifts=-self.shift, dims=1)

        # Cached mask; True entries are positions that may not attend to each other.
        window_mask = build_block_diag_attn_mask(seq_len, self.window, x.device)

        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, attn_mask=window_mask, need_weights=False)[0]
        x = x + self.drop1(attended)

        if shifted:
            # Undo the roll so positions line up with the input again.
            x = torch.roll(x, shifts=self.shift, dims=1)

        return x + self.drop2(self.ffn(self.norm2(x)))

class SSADecoder(nn.Module):
    """Stack of ``SSALayer`` blocks with alternating window shifts.

    Layers at even positions (0, 2, ...) use no shift; layers at odd
    positions use a shift of ``window // 2``.
    """

    def __init__(self, hidden_size=768, heads=8, ff=2048, dropout=0.1, window=48, layers=2):
        super().__init__()
        half_window = window // 2
        self.layers = nn.ModuleList([
            SSALayer(hidden_size, heads, ff, dropout, window, half_window if idx % 2 else 0)
            for idx in range(layers)
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        out = x
        for block in self.layers:
            out = block(out)
        return out

class DHFormerWithSSA(nn.Module):
    """Parent-prediction model: chunk encoder + windowed element decoder + bilinear scorer.

    Each chunk is encoded independently; one token embedding per element is gathered,
    a learned ROOT embedding is prepended, the sequence is contextualised by an
    ``SSADecoder`` and every element is scored against every candidate parent
    (ROOT + all elements) with a bilinear form.

    Args:
        encoder: Module called as ``encoder(input_ids=, attention_mask=, bbox=,
            return_dict=True)`` whose result exposes ``last_hidden_state``.
        hidden_size: Width of encoder outputs and of all internal embeddings.
        root_id: Identifier used for the virtual ROOT node in ``parent_map``.
        max_inner_pos: Size of the inner-position embedding table; larger
            positions are clamped to the last entry.
        page_pe_dim: Width of the sinusoidal page encoding before projection.
        ssa_layers, ssa_heads, ssa_ff, ssa_dropout, window: ``SSADecoder`` settings.
    """

    def __init__(
        self,
        encoder: nn.Module,
        hidden_size: int = 768,
        root_id: int = 0,
        max_inner_pos: int = 2048,
        page_pe_dim: int = 128,
        ssa_layers: int = 2,
        ssa_heads: int = 8,
        ssa_ff: int = 2048,
        ssa_dropout: float = 0.1,
        window: int = 48,
    ):
        super().__init__()
        self.encoder = encoder
        self.hidden_size = hidden_size
        self.root_id = root_id

        # Learned embedding standing in for the virtual ROOT element.
        self.root_emb = nn.Parameter(torch.zeros(hidden_size))
        nn.init.normal_(self.root_emb, mean=0.0, std=0.02)

        self.inner_emb = nn.Embedding(max_inner_pos, hidden_size)
        self.page_proj = nn.Linear(page_pe_dim, hidden_size)
        self.page_pe_dim = page_pe_dim

        self.elem_decoder = SSADecoder(
            hidden_size=hidden_size,
            heads=ssa_heads,
            ff=ssa_ff,
            dropout=ssa_dropout,
            window=window,
            layers=ssa_layers
        )

        # Kept as nn.Bilinear so existing checkpoints (key "bilinear.weight") still load.
        self.bilinear = nn.Bilinear(hidden_size, hidden_size, 1, bias=False)

    def _encode_one_chunk(self, c):
        """Encode one chunk dict and add page / inner-position embeddings.

        Args:
            c: Dict with list-valued keys ``input_ids``, ``attention_mask``,
                ``bbox``, ``page_id`` and ``inner_pos`` (one entry per token).

        Returns:
            Tensor [T, H] with one row per token of the chunk.
        """
        # Follow the module's own device instead of a notebook-global variable.
        dev = self.root_emb.device

        input_ids = torch.tensor(c["input_ids"], dtype=torch.long, device=dev).unsqueeze(0)
        attention_mask = torch.tensor(c["attention_mask"], dtype=torch.long, device=dev).unsqueeze(0)
        bbox = torch.tensor(c["bbox"], dtype=torch.long, device=dev).unsqueeze(0)
        # Keep coordinates inside the 0..1000 normalised range before they index
        # the encoder's 2-D position embeddings.
        bbox = torch.clamp(bbox, 0, 1000)

        o = self.encoder(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, return_dict=True)
        hidden = o.last_hidden_state.squeeze(0)  # [T,H]

        page_ids = torch.tensor(c["page_id"], dtype=torch.long, device=dev)
        inner_pos = torch.tensor(c["inner_pos"], dtype=torch.long, device=dev)

        pe = sinusoidal_position_encoding(page_ids, self.page_pe_dim)
        page_e = self.page_proj(pe)

        # Positions beyond the embedding table share its last entry.
        inner_pos = torch.clamp(inner_pos, 0, self.inner_emb.num_embeddings - 1)
        inner_e = self.inner_emb(inner_pos)

        return hidden + page_e + inner_e

    def forward(self, chunks, doc_elem_ids, doc_elem_positions, parent_map):
        """Score every element against every candidate parent.

        Args:
            chunks: List of chunk dicts (see ``_encode_one_chunk``).
            doc_elem_ids: Element ids in document order.
            doc_elem_positions: Mapping ``elem_id -> (chunk_idx, tok_idx)`` giving
                the token whose embedding represents the element.
            parent_map: Mapping ``elem_id -> parent_id`` where the parent is
                ``root_id`` or another id in ``doc_elem_ids``.

        Returns:
            ``(loss, logits)``: ``logits`` is [M, M+1] (rows = elements, columns =
            ROOT followed by the elements); ``loss`` is the cross-entropy against
            the ground-truth parent column (zero when there are no elements).

        Raises:
            KeyError: If a parent in ``parent_map`` is neither ``root_id`` nor in
                ``doc_elem_ids``.
        """
        enc_outs = [self._encode_one_chunk(c) for c in chunks]  # list [T,H]

        elem_rows = []
        for eid in doc_elem_ids:
            chunk_idx, tok_idx = doc_elem_positions[eid]
            elem_rows.append(enc_outs[chunk_idx][tok_idx])

        elem_seq = torch.stack([self.root_emb] + elem_rows, dim=0)  # [L,H]
        elem_ctx = self.elem_decoder(elem_seq.unsqueeze(0)).squeeze(0)  # [L,H]

        decoder_elem_ids = [self.root_id] + doc_elem_ids
        id2idx = {pid: i for i, pid in enumerate(decoder_elem_ids)}
        parent_target = torch.tensor([id2idx[parent_map[eid]] for eid in doc_elem_ids],
                                     dtype=torch.long, device=elem_ctx.device)

        child_emb = elem_ctx[1:]   # [M,H]
        parent_emb = elem_ctx      # [L,H]

        # nn.Bilinear with out_features=1 and no bias computes child^T W parent, so
        # all pairs can be scored with two matrix products instead of a Python loop
        # over children that materialises a repeated [L,H] tensor per child.
        pair_weight = self.bilinear.weight[0]  # [H,H]
        logits = child_emb @ pair_weight @ parent_emb.t()  # [M,L]

        if len(doc_elem_ids) == 0:
            # Nothing to classify: return a zero loss that is still part of the graph.
            return logits.sum(), logits

        loss = F.cross_entropy(logits, parent_target)
        return loss, logits

# instantiate SSA model
# Load the pretrained encoder by name and move it to the notebook-global `device`.
MODEL_NAME = "microsoft/layoutlmv3-base"
encoder = AutoModel.from_pretrained(MODEL_NAME).to(device)

# Wrap the encoder in the DHFormerWithSSA head; arguments not listed here keep
# the defaults declared in DHFormerWithSSA.__init__.
model_ssa = DHFormerWithSSA(
    encoder=encoder,
    hidden_size=768,
    root_id=0,
    window=48,
    ssa_layers=2,
    ssa_heads=8,
    ssa_ff=2048
).to(device)

print("model_ssa ready on", device)


model_ssa ready on cuda


In [14]:
# CELL F1: evaluate relation-level Precision/Recall/F1 on test split for model_ssa
import torch
from torch.utils.data import DataLoader

def edges_from_parent_map(doc_elem_ids, parent_map, root_id=0):
    """Return the ground-truth relation set as ``{(parent_id, child_id)}``.

    Edges whose parent is the root are included. ``root_id`` is accepted for
    signature symmetry with ``edges_from_logits`` and is not used here.
    """
    edges = set()
    for child_id in doc_elem_ids:
        edges.add((parent_map[child_id], child_id))
    return edges

def edges_from_logits(doc_elem_ids, logits, root_id=0):
    """Turn parent logits into a set of predicted ``(parent_id, child_id)`` edges.

    Args:
        doc_elem_ids: Element ids, one per row of ``logits``.
        logits: Tensor [M, M+1]; column 0 is ROOT, column k+1 is ``doc_elem_ids[k]``.
        root_id: Id reported for the ROOT column.

    Returns:
        Set with one edge per element, using the arg-max column as the parent.
    """
    candidates = [root_id] + doc_elem_ids
    best_columns = torch.argmax(logits, dim=1).tolist()  # len M
    return {(candidates[col], child) for child, col in zip(doc_elem_ids, best_columns)}

def prf1(gt_edges, pred_edges):
    """Precision, recall, F1 and true-positive count for two edge sets.

    Any ratio with an empty denominator is reported as 0.0.

    Returns:
        Tuple ``(precision, recall, f1, n_true_positive)``.
    """
    n_hit = len(gt_edges & pred_edges)

    if pred_edges:
        p = n_hit / len(pred_edges)
    else:
        p = 0.0

    if gt_edges:
        r = n_hit / len(gt_edges)
    else:
        r = 0.0

    total = p + r
    f1 = (2*p*r/total) if total else 0.0
    return p, r, f1, n_hit

# build test loader
test_ds = DocHieNetDocDataset(index_df, split="test")
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

model_ssa.eval()
root_id = 0

# accumulators (micro-averaged over all documents)
tp_sum = 0
pred_sum = 0
gt_sum = 0
n_docs = 0  # documents actually evaluated; stays 0 if the loader is empty

# optional: limit for quick sanity; set to None to run full test
LIMIT_DOCS = None  # e.g., 50 for quick run

with torch.no_grad():
    for i, batch in enumerate(test_loader, start=1):
        loss, logits = model_ssa(
            batch["chunks"],
            batch["doc_elem_ids"],
            batch["doc_elem_positions"],
            batch["parent_map"]
        )

        gt = edges_from_parent_map(batch["doc_elem_ids"], batch["parent_map"], root_id=root_id)
        pred = edges_from_logits(batch["doc_elem_ids"], logits, root_id=root_id)

        inter = gt & pred
        tp_sum += len(inter)
        pred_sum += len(pred)
        gt_sum += len(gt)
        n_docs = i

        if LIMIT_DOCS is not None and i >= LIMIT_DOCS:
            break

# micro P/R/F1
precision = tp_sum / pred_sum if pred_sum else 0.0
recall = tp_sum / gt_sum if gt_sum else 0.0
f1 = (2*precision*recall/(precision+recall)) if (precision+recall) else 0.0

# Report the real number of processed documents: LIMIT_DOCS may exceed the size
# of the test split, and the loop variable does not exist if the loader is empty.
print(f"Evaluated docs: {n_docs}")
print(f"TP={tp_sum}, Pred={pred_sum}, GT={gt_sum}")
print(f"Micro Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")

Token indices sequence length is longer than the specified maximum sequence length for this model (549 > 512). Running this sequence through the model will result in indexing errors


Evaluated docs: 161
TP=611, Pred=29401, GT=29401
Micro Precision=0.0208, Recall=0.0208, F1=0.0208


In [15]:
# CELL FIX_BBOX: clamp bbox to [0,1000] before feeding into LayoutLMv3 encoder

import types
import torch
import math

# Notebook-global device selector: "cuda" when a CUDA device is available, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

def sinusoidal_position_encoding(pos, dim):
    """Sinusoidal encoding of a 1-D tensor of positions.

    Even output columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-ln(10000) * 2k / dim)``.

    Args:
        pos: 1-D tensor of positions (any numeric dtype), length N.
        dim: Width of the encoding. Odd values are supported; the last column
            is then a sine column without a matching cosine.

    Returns:
        Float tensor of shape [N, dim] on the same device as ``pos``.
    """
    pos = pos.float().unsqueeze(1)  # [N, 1]
    div_term = torch.exp(torch.arange(0, dim, 2, device=pos.device).float() * (-math.log(10000.0) / dim))
    angles = pos * div_term  # [N, ceil(dim / 2)]

    pe = torch.zeros(pos.size(0), dim, device=pos.device)
    pe[:, 0::2] = torch.sin(angles)
    # There are only dim // 2 odd columns, one fewer than the even columns when
    # dim is odd, so drop the surplus frequency instead of failing on the shape.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe

def _encode_one_chunk_clamped(self, c):
    """Replacement for ``_encode_one_chunk`` that clamps bbox to [0, 1000].

    Intended to be bound to a model instance with ``types.MethodType``.

    Args:
        c: Chunk dict with list-valued keys ``input_ids``, ``attention_mask``,
            ``bbox``, ``page_id`` and ``inner_pos`` (one entry per token).

    Returns:
        Tensor [T, H]: encoder hidden states plus page and inner-position embeddings.
    """
    # Create inputs on the device the model's own weights live on, so the method
    # keeps working even if the notebook-global `device` is rebound later.
    dev = self.inner_emb.weight.device

    input_ids = torch.tensor(c["input_ids"], dtype=torch.long, device=dev).unsqueeze(0)
    attention_mask = torch.tensor(c["attention_mask"], dtype=torch.long, device=dev).unsqueeze(0)

    bbox = torch.tensor(c["bbox"], dtype=torch.long, device=dev).unsqueeze(0)
    bbox = torch.clamp(bbox, 0, 1000)  # key fix: guarantee coordinates stay in [0, 1000]

    o = self.encoder(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, return_dict=True)
    hidden = o.last_hidden_state.squeeze(0)  # [T,H]

    # Same page + inner-position embedding injection as DHFormerWithSSA._encode_one_chunk.
    page_ids = torch.tensor(c["page_id"], dtype=torch.long, device=dev)
    inner_pos = torch.tensor(c["inner_pos"], dtype=torch.long, device=dev)

    pe = sinusoidal_position_encoding(page_ids, self.page_pe_dim)
    page_e = self.page_proj(pe)

    # Positions beyond the embedding table share its last entry.
    inner_pos = torch.clamp(inner_pos, 0, self.inner_emb.num_embeddings - 1)
    inner_e = self.inner_emb(inner_pos)

    return hidden + page_e + inner_e

# Bind the patched method to the current model_ssa instance only; the class
# itself and any other instances are left unchanged.
model_ssa._encode_one_chunk = types.MethodType(_encode_one_chunk_clamped, model_ssa)

print("Patched model_ssa: bbox will be clamped to [0,1000] before encoder.")


Patched model_ssa: bbox will be clamped to [0,1000] before encoder.


In [16]:
# CELL CHUNKCAP_FIX: fix LayoutLMv3TokenizerFast usage (needs words list + boxes)

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from pathlib import Path
import json
from collections import OrderedDict

# Checkpoint name used to load the tokenizer below.
MODEL_NAME = "microsoft/layoutlmv3-base"
# Maximum number of tokens packed into one chunk by DocHieNetDocDataset.
MAX_TOKENS = 512
# Per-document chunk caps applied by DocHieNetDocDataset for each split.
MAX_CHUNKS_TRAIN = 32
MAX_CHUNKS_TEST  = 128

# Module-level tokenizer shared by every DocHieNetDocDataset instance.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

class DocHieNetDocDataset(Dataset):
    """Document-level dataset: each item is one document, tokenized and chunked.

    Reads the rows of ``index_df`` belonging to ``split``; every item is built from
    the label JSON referenced by the row's ``label_path``. Relies on the
    module-level ``tokenizer``, ``MAX_TOKENS``, ``MAX_CHUNKS_TRAIN`` and
    ``MAX_CHUNKS_TEST``.

    Args:
        index_df: DataFrame with at least ``split``, ``doc_id`` and ``label_path``.
        split: ``"train"`` or ``"test"`` select the matching chunk cap; any other
            value disables the cap.
    """

    def __init__(self, index_df, split="train"):
        self.df = index_df[index_df["split"] == split].reset_index(drop=True)
        self.root_id = 0
        self.split = split
        self.max_chunks = MAX_CHUNKS_TRAIN if split == "train" else (MAX_CHUNKS_TEST if split == "test" else None)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _box_to_1000(box, pw, ph):
        """Scale ``[x0, y0, x1, y1]`` from page pixels to integers in [0, 1000].

        Coordinates that fall outside the page are clamped, so the result is
        always a valid index range for the encoder's 2-D position embeddings.
        """
        x0, y0, x1, y1 = box
        scaled = (1000 * x0 / pw, 1000 * y0 / ph, 1000 * x1 / pw, 1000 * y1 / ph)
        return [min(1000, max(0, int(v))) for v in scaled]

    def _load_elements(self, label_path):
        """Read a label JSON and return its elements sorted by ``order``.

        Elements without an ``order`` key sort last. Each returned dict has the
        keys ``id``, ``page``, ``text``, ``bbox`` (scaled), ``linking`` and ``order``.
        """
        obj = json.loads(Path(label_path).read_text(encoding="utf-8"))
        pages_meta = obj["pages"]
        contents = sorted(obj["contents"], key=lambda x: x.get("order", 10**9))

        elements = []
        for c in contents:
            p = c["page"]
            page_meta = pages_meta[f"page{p}"]
            pw, ph = page_meta["width"], page_meta["height"]
            elements.append({
                "id": c["id"],
                "page": p,
                "text": (c.get("text", "") or ""),
                "bbox": self._box_to_1000(c["box"], pw, ph),
                "linking": c.get("linking", []),
                "order": c.get("order", -1),
            })
        return elements

    def _build_parent_map(self, elements):
        """Build ``child_id -> parent_id`` from every element's ``linking`` pairs.

        Pairs are read as ``[parent, child]``; the first parent seen for a child
        wins. Elements with no parent, or with parent ``-1``, are attached to ROOT.
        """
        parent_map = {}
        for e in elements:
            for pair in e.get("linking", []):
                if isinstance(pair, list) and len(pair) == 2:
                    p, c = pair
                    if c not in parent_map:
                        parent_map[c] = p
        for e in elements:
            cid = e["id"]
            if cid not in parent_map:
                parent_map[cid] = self.root_id
            if parent_map[cid] == -1:
                parent_map[cid] = self.root_id
        return parent_map

    def _restrict_parent_map(self, doc_elem_ids, parent_map_full):
        """Restrict the parent map to the elements that survived chunk capping.

        A parent that is ``-1``, equal to the child itself, or no longer among the
        kept elements is replaced by ROOT, so every returned parent is either
        ``self.root_id`` or a member of ``doc_elem_ids``.
        """
        kept = set(doc_elem_ids)
        root = self.root_id
        parent_map = {}
        for eid in doc_elem_ids:
            p = parent_map_full.get(eid, root)
            if p == -1 or p == eid or (p != root and p not in kept):
                p = root
            parent_map[eid] = p
        return parent_map

    @staticmethod
    def _new_chunk():
        """Return an empty chunk with one list per token-level field."""
        return {"input_ids": [], "attention_mask": [], "bbox": [], "elem_id": [], "page_id": [], "inner_pos": []}

    def _tokenize_and_chunk(self, elements):
        """
        LayoutLMv3 tokenizer requires:
          - words: List[str]
          - boxes: List[List[int]] (one box per word)
        We'll assign each word the element bbox.

        Elements are packed greedily into chunks of at most ``MAX_TOKENS`` tokens;
        an element is never split across chunks, and a single element longer than
        ``MAX_TOKENS`` is truncated.
        """
        chunks = []
        cur = self._new_chunk()
        cur_len = 0

        def flush():
            nonlocal cur, cur_len
            if cur_len > 0:
                chunks.append(cur)
            cur = self._new_chunk()
            cur_len = 0

        for e in elements:
            text = e["text"].strip()
            words = text.split() if text else []
            if len(words) == 0:
                # Placeholder so that an empty element still yields tokens.
                words = ["[UNK]"]

            boxes = [e["bbox"]] * len(words)

            enc = tokenizer(
                words,
                boxes=boxes,
                add_special_tokens=False,
                return_attention_mask=True,
                truncation=False
            )

            ids = enc["input_ids"]
            am  = enc["attention_mask"]
            bxs = enc["bbox"]  # token-level bboxes aligned by tokenizer

            # inner position within this element (token-level)
            inner = list(range(len(ids)))
            eids = [e["id"]] * len(ids)
            pids = [e["page"]] * len(ids)

            # Start a new chunk rather than splitting the element.
            if cur_len > 0 and cur_len + len(ids) > MAX_TOKENS:
                flush()

            # safeguard: element too long
            if len(ids) > MAX_TOKENS:
                ids = ids[:MAX_TOKENS]; am = am[:MAX_TOKENS]; bxs = bxs[:MAX_TOKENS]
                inner = inner[:MAX_TOKENS]; eids = eids[:MAX_TOKENS]; pids = pids[:MAX_TOKENS]

            cur["input_ids"].extend(ids)
            cur["attention_mask"].extend(am)
            cur["bbox"].extend(bxs)
            cur["elem_id"].extend(eids)
            cur["page_id"].extend(pids)
            cur["inner_pos"].extend(inner)
            cur_len += len(ids)

        flush()
        return chunks

    @staticmethod
    def _build_pooling(chunks):
        """For each chunk, list its element ids and the index of each one's first token."""
        pooling = []
        for c in chunks:
            seen = OrderedDict()
            for i, eid in enumerate(c["elem_id"]):
                if eid not in seen:
                    seen[eid] = i
            pooling.append({"elem_ids": list(seen.keys()), "first_token_idx": list(seen.values())})
        return pooling

    @staticmethod
    def _merge_doc_positions(chunks, pooling):
        """Merge per-chunk pooling into ``elem_id -> (chunk_idx, tok_idx)``.

        The first occurrence of an element wins. Returns the element ids in order
        of first appearance together with the position mapping.
        """
        pos = OrderedDict()
        for chunk_idx, pool in enumerate(pooling):
            for eid, tok_idx in zip(pool["elem_ids"], pool["first_token_idx"]):
                if eid not in pos:
                    pos[eid] = (chunk_idx, tok_idx)
        doc_elem_ids = list(pos.keys())
        return doc_elem_ids, pos

    def __getitem__(self, idx):
        """Return one document as a dict of chunks, element ids, positions and parents."""
        row = self.df.iloc[idx]
        doc_id = row["doc_id"]
        label_path = row["label_path"]

        elements = self._load_elements(label_path)
        parent_map_full = self._build_parent_map(elements)

        chunks = self._tokenize_and_chunk(elements)

        # cap chunks per paper setting
        if self.max_chunks is not None and len(chunks) > self.max_chunks:
            chunks = chunks[:self.max_chunks]

        pooling = self._build_pooling(chunks)
        doc_elem_ids, doc_elem_positions = self._merge_doc_positions(chunks, pooling)

        # Capping can drop an element's parent; without remapping, the model's
        # parent-index lookup would raise KeyError for that element.
        parent_map = self._restrict_parent_map(doc_elem_ids, parent_map_full)

        return {
            "doc_id": doc_id,
            "chunks": chunks,
            "doc_elem_ids": doc_elem_ids,
            "doc_elem_positions": doc_elem_positions,
            "parent_map": parent_map,
            "n_chunks": len(chunks),
            "n_elems": len(doc_elem_ids),
        }

def collate_fn(batch):
    """Unwrap a single-item batch: return its first (only) sample unchanged."""
    sole_sample = batch[0]
    return sole_sample

# Build one dataset per split from the shared document index.
train_ds = DocHieNetDocDataset(index_df, split="train")
test_ds  = DocHieNetDocDataset(index_df, split="test")

# batch_size=1 with collate_fn means every batch is a single document dict.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Smoke test: fetch one document from each loader and report its size.
b = next(iter(train_loader))
print("train sample:", b["doc_id"], "| chunks:", b["n_chunks"], "| elems:", b["n_elems"])
b2 = next(iter(test_loader))
print("test sample: ", b2["doc_id"], "| chunks:", b2["n_chunks"], "| elems:", b2["n_elems"])


Token indices sequence length is longer than the specified maximum sequence length for this model (549 > 512). Running this sequence through the model will result in indexing errors


train sample: tdvwfwxc65 | chunks: 10 | elems: 107
test sample:  02nk0izsev | chunks: 34 | elems: 55


In [17]:
# CLEANUP CELL: run once before training to avoid GPU memory fragmentation

import gc
import torch

# Remove these notebook globals when they exist; names that are absent are skipped.
for name in ["model", "model_pe", "model_dec", "encoder", "model_ssa_old", "model_dec_old"]:
    globals().pop(name, None)

# Collect unreachable objects, then release cached CUDA memory if a GPU is present.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("cleanup done")


cleanup done


In [18]:
# TRAIN CELL (REPLACE): AMP + step timing + epoch timing + ETA
'''
import time
import math
import random
import numpy as np
import torch
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
assert device == "cuda", "你现在不是 CUDA，先确认 GPU 可用并让 model_ssa 在 cuda 上。"

# reproducibility (optional)
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

EPOCHS = 5               # 先用 5 确认速度/流程，没问题再改 100
LR_START = 4e-5
LR_END = 1e-6
WEIGHT_DECAY = 0.01
MAX_GRAD_NORM = 1.0
PRINT_EVERY_STEP = 20
SAVE_PATH = "best_model_ssa.pt"
best_f1 = -1.0

model_ssa.to(device)
model_ssa.train()

optimizer = optim.AdamW(model_ssa.parameters(), lr=LR_START, weight_decay=WEIGHT_DECAY)
gamma = (LR_END / LR_START) ** (1.0 / max(EPOCHS - 1, 1))
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

scaler = torch.cuda.amp.GradScaler(enabled=True)

def micro_f1_on_loader(model, loader, root_id=0):
    model.eval()
    tp_sum = pred_sum = gt_sum = 0
    with torch.no_grad():
        for batch in loader:
            with torch.cuda.amp.autocast(dtype=torch.float16):
                loss, logits = model(
                    batch["chunks"],
                    batch["doc_elem_ids"],
                    batch["doc_elem_positions"],
                    batch["parent_map"]
                )
            parent_candidates = [root_id] + batch["doc_elem_ids"]
            pred_parent_idx = torch.argmax(logits, dim=1).tolist()
            pred_edges = set((parent_candidates[pidx], child_eid)
                             for child_eid, pidx in zip(batch["doc_elem_ids"], pred_parent_idx))
            gt_edges = set((batch["parent_map"][eid], eid) for eid in batch["doc_elem_ids"])
            inter = gt_edges & pred_edges
            tp_sum += len(inter); pred_sum += len(pred_edges); gt_sum += len(gt_edges)
    p = tp_sum / pred_sum if pred_sum else 0.0
    r = tp_sum / gt_sum if gt_sum else 0.0
    f1 = (2*p*r/(p+r)) if (p+r) else 0.0
    model.train()
    return p, r, f1

global_start = time.time()

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()
    step_times = []

    total_loss = 0.0
    steps = 0

    for step, batch in enumerate(train_loader, start=1):
        t0 = time.time()

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(dtype=torch.float16):
            loss, _ = model_ssa(
                batch["chunks"],
                batch["doc_elem_ids"],
                batch["doc_elem_positions"],
                batch["parent_map"]
            )

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model_ssa.parameters(), MAX_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.detach().cpu())
        steps += 1

        dt = time.time() - t0
        step_times.append(dt)

        if step % PRINT_EVERY_STEP == 0:
            avg_dt = sum(step_times[-PRINT_EVERY_STEP:]) / PRINT_EVERY_STEP
            print(f"[epoch {epoch:03d} | step {step:04d}] loss={loss.item():.4f} | {avg_dt:.2f}s/step")

    avg_train_loss = total_loss / max(steps, 1)

    eval_start = time.time()
    p, r, f1 = micro_f1_on_loader(model_ssa, test_loader, root_id=0)
    eval_time = time.time() - eval_start

    # save best
    if f1 > best_f1:
        best_f1 = f1
        torch.save({"epoch": epoch, "model_state_dict": model_ssa.state_dict(), "best_f1": best_f1}, SAVE_PATH)
        tag = " (best)"
    else:
        tag = ""

    epoch_time = time.time() - epoch_start
    elapsed = time.time() - global_start
    avg_epoch = elapsed / epoch
    eta = avg_epoch * (EPOCHS - epoch)
    lr = optimizer.param_groups[0]["lr"]

    print(
        f"\n[epoch {epoch:03d}/{EPOCHS}] lr={lr:.2e} | train_loss={avg_train_loss:.4f} | test_F1={f1:.4f}{tag}\n"
        f"  epoch_time={epoch_time/60:.1f} min | eval_time={eval_time:.1f}s | ETA={eta/60:.1f} min\n"
    )

    scheduler.step()

print("done. best_f1=", best_f1, "saved:", SAVE_PATH)
'''

'\nimport time\nimport math\nimport random\nimport numpy as np\nimport torch\nimport torch.optim as optim\n\ndevice = "cuda" if torch.cuda.is_available() else "cpu"\nassert device == "cuda", "你现在不是 CUDA，先确认 GPU 可用并让 model_ssa 在 cuda 上。"\n\n# reproducibility (optional)\nSEED = 42\nrandom.seed(SEED); np.random.seed(SEED)\ntorch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)\n\nEPOCHS = 5               # 先用 5 确认速度/流程，没问题再改 100\nLR_START = 4e-5\nLR_END = 1e-6\nWEIGHT_DECAY = 0.01\nMAX_GRAD_NORM = 1.0\nPRINT_EVERY_STEP = 20\nSAVE_PATH = "best_model_ssa.pt"\nbest_f1 = -1.0\n\nmodel_ssa.to(device)\nmodel_ssa.train()\n\noptimizer = optim.AdamW(model_ssa.parameters(), lr=LR_START, weight_decay=WEIGHT_DECAY)\ngamma = (LR_END / LR_START) ** (1.0 / max(EPOCHS - 1, 1))\nscheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)\n\nscaler = torch.cuda.amp.GradScaler(enabled=True)\n\ndef micro_f1_on_loader(model, loader, root_id=0):\n    model.eval()\n    tp_sum = pred_sum = gt_sum = 0

In [19]:
# CELL EVALBEST: load best checkpoint and evaluate test micro-F1 (all edges / non-root edges)

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
CKPT_PATH = "best_model_ssa.pt"
ROOT_ID = 0

# 1) load checkpoint
# The checkpoint is read as a dict with "model_state_dict" plus optional "epoch"
# and "best_f1" entries. weights_only=True restricts unpickling to tensors and
# plain Python containers/scalars instead of executing arbitrary pickled code,
# and avoids the warning torch.load emits when the argument is left unset.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=True)
model_ssa.load_state_dict(ckpt["model_state_dict"])
model_ssa.to(device)
model_ssa.eval()

print("Loaded:", CKPT_PATH, "| epoch:", ckpt.get("epoch"), "| best_f1:", ckpt.get("best_f1"))

def micro_prf(model, loader, exclude_root=False, root_id=0):
    """Micro-averaged precision / recall / F1 of predicted parent edges.

    Args:
        model: Callable ``model(chunks, doc_elem_ids, doc_elem_positions, parent_map)``
            returning ``(loss, logits)`` with ``logits`` of shape [M, M+1].
        loader: Iterable of batch dicts with keys ``chunks``, ``doc_elem_ids``,
            ``doc_elem_positions`` and ``parent_map``.
        exclude_root: When True, edges whose parent is ``root_id`` are dropped
            from both the predicted and the ground-truth sets.
        root_id: Id of the virtual ROOT node (column 0 of ``logits``).

    Returns:
        Tuple ``(precision, recall, f1, true_positives, n_predicted, n_ground_truth)``.
    """
    n_correct = n_predicted = n_reference = 0

    with torch.no_grad():
        for batch in loader:
            _, logits = model(
                batch["chunks"],
                batch["doc_elem_ids"],
                batch["doc_elem_positions"],
                batch["parent_map"]
            )
            elem_ids = batch["doc_elem_ids"]
            gt_parents = batch["parent_map"]

            candidates = [root_id] + elem_ids
            best_columns = torch.argmax(logits, dim=1).tolist()

            predicted = {(candidates[col], child) for child, col in zip(elem_ids, best_columns)}
            reference = {(gt_parents[child], child) for child in elem_ids}

            if exclude_root:
                predicted = {edge for edge in predicted if edge[0] != root_id}
                reference = {edge for edge in reference if edge[0] != root_id}

            n_correct += len(predicted & reference)
            n_predicted += len(predicted)
            n_reference += len(reference)

    precision = n_correct / n_predicted if n_predicted else 0.0
    recall = n_correct / n_reference if n_reference else 0.0
    total = precision + recall
    f1 = (2*precision*recall/total) if total else 0.0
    return precision, recall, f1, n_correct, n_predicted, n_reference

# 2) evaluate
# First pass: every predicted / ground-truth edge, including edges to ROOT.
p, r, f1, tp, pred_n, gt_n = micro_prf(model_ssa, test_loader, exclude_root=False, root_id=ROOT_ID)
print(f"[ALL]    TP={tp} Pred={pred_n} GT={gt_n} | P={p:.4f} R={r:.4f} F1={f1:.4f}")

# Second pass: edges whose parent is ROOT are removed from both sets before scoring.
p2, r2, f12, tp2, pred_n2, gt_n2 = micro_prf(model_ssa, test_loader, exclude_root=True, root_id=ROOT_ID)
print(f"[NONROOT] TP={tp2} Pred={pred_n2} GT={gt_n2} | P={p2:.4f} R={r2:.4f} F1={f12:.4f}")


  ckpt = torch.load(CKPT_PATH, map_location=device)


Loaded: best_model_ssa.pt | epoch: 2 | best_f1: 0.18653653826257888
[ALL]    TP=5459 Pred=28997 GT=28997 | P=0.1883 R=0.1883 F1=0.1883
[NONROOT] TP=767 Pred=16604 GT=23646 | P=0.0462 R=0.0324 F1=0.0381


In [20]:
# CELL TEDS_LABEL: semantic-label TEDS via zss (closer to paper than id-sensitive)

import torch
import sys
import subprocess

# Import zss; if it is not installed, install it (unpinned) with pip for the
# interpreter running this kernel and import it again.
try:
    import zss
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "zss"])
    import zss

# Id of the virtual ROOT node used when building and comparing trees below.
ROOT_ID = 0

def build_children_ordered(doc_elem_ids, parent_map, root_id=0):
    """Invert a parent map into ``node_id -> ordered list of child ids``.

    Every id in ``doc_elem_ids`` and the root get an entry (possibly empty).
    Elements missing from ``parent_map`` are attached to the root, and a parent
    that is not itself a known node still receives an entry. Children are sorted
    by their position in ``[root_id] + doc_elem_ids``.
    """
    rank = {}
    for position, node in enumerate([root_id] + doc_elem_ids):
        rank[node] = position

    children = {root_id: []}
    for eid in doc_elem_ids:
        if eid not in children:
            children[eid] = []

    for child in doc_elem_ids:
        parent = parent_map.get(child, root_id)
        children.setdefault(parent, []).append(child)

    for siblings in children.values():
        siblings.sort(key=lambda node: rank.get(node, 10**9))
    return children

def to_zss_tree(children, node_id, label_map):
    """Recursively build a ``zss.Node`` tree rooted at ``node_id``.

    Node labels come from ``label_map`` ("UNK" when missing); children are added
    in the order given by ``children[node_id]``.
    """
    subtree = zss.Node(label_map.get(node_id, "UNK"))
    kid_ids = children.get(node_id, [])
    for kid_id in kid_ids:
        kid_subtree = to_zss_tree(children, kid_id, label_map)
        subtree.addkid(kid_subtree)
    return subtree

def predict_parent_map_from_logits(doc_elem_ids, logits, root_id=0):
    """Map each element to its arg-max parent.

    Column 0 of ``logits`` is the root; column k+1 is ``doc_elem_ids[k]``.

    Returns:
        Dict ``child_id -> predicted_parent_id``.
    """
    candidates = [root_id] + doc_elem_ids
    best_columns = torch.argmax(logits, dim=1).tolist()
    predicted = {}
    for child, col in zip(doc_elem_ids, best_columns):
        predicted[child] = candidates[col]
    return predicted

def compute_semantic_teds_one(doc_elem_ids, gt_parent_map, pred_parent_map, elem_label_map, root_id=0):
    """Tree-edit-distance similarity between ground-truth and predicted trees.

    Both trees are built from the root with nodes labelled by semantic class
    ("ROOT" for the root, "UNK" when an element has no label). Insert and remove
    cost 1; relabel costs 1 only when the labels differ.

    Returns:
        ``1 - edit_distance / (1 + len(doc_elem_ids))``.
    """
    # label map: ROOT + element semantic class
    labels = {root_id: "ROOT"}
    labels.update((eid, elem_label_map.get(eid, "UNK")) for eid in doc_elem_ids)

    trees = []
    for parents in (gt_parent_map, pred_parent_map):
        adjacency = build_children_ordered(doc_elem_ids, parents, root_id=root_id)
        trees.append(to_zss_tree(adjacency, root_id, labels))
    reference_tree, predicted_tree = trees

    edit_distance = zss.distance(
        reference_tree, predicted_tree,
        get_children=zss.Node.get_children,
        insert_cost=lambda node: 1,
        remove_cost=lambda node: 1,
        update_cost=lambda a, b: 0 if a.label == b.label else 1
    )

    # size: include root + nodes
    tree_size = 1 + len(doc_elem_ids)
    return 1.0 - (edit_distance / tree_size)

# ---- Evaluate on test ----
# Computes one semantic-label TEDS score per test document and averages them.
model_ssa.eval()
scores = []

with torch.no_grad():
    for batch in test_loader:
        # The batch does not carry per-element semantic labels, so the label
        # JSON of the document is re-read here (once per document) to build
        # elem_label_map: element id -> "label" field ("UNK" when absent).

        doc_id = batch["doc_id"]
        # locate label path from index_df
        row = index_df[index_df["doc_id"] == doc_id].iloc[0]
        label_path = row["label_path"]

        obj = json.loads(Path(label_path).read_text(encoding="utf-8"))
        elem_label_map = {c["id"]: c.get("label", "UNK") for c in obj["contents"]}

        # Only the logits are needed; the loss returned by the model is ignored.
        _, logits = model_ssa(
            batch["chunks"],
            batch["doc_elem_ids"],
            batch["doc_elem_positions"],
            batch["parent_map"]
        )

        pred_parent_map = predict_parent_map_from_logits(batch["doc_elem_ids"], logits, root_id=ROOT_ID)

        scores.append(
            compute_semantic_teds_one(
                doc_elem_ids=batch["doc_elem_ids"],
                gt_parent_map=batch["parent_map"],
                pred_parent_map=pred_parent_map,
                elem_label_map=elem_label_map,
                root_id=ROOT_ID
            )
        )

# Unweighted mean over documents; max(..., 1) avoids dividing by zero on an empty loader.
avg_teds = sum(scores) / max(len(scores), 1)
print(f"Test docs: {len(scores)}")
print(f"Avg TEDS (semantic-label): {avg_teds:.4f}")


Test docs: 161
Avg TEDS (semantic-label): 0.4589


In [21]:
# CELL PARENTMAP_FIX: after chunk truncation, remap missing parents to ROOT to avoid KeyError

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from pathlib import Path
import json
from collections import OrderedDict

# Checkpoint name used to load the tokenizer below.
MODEL_NAME = "microsoft/layoutlmv3-base"
# Maximum number of tokens packed into one chunk by DocHieNetDocDataset.
MAX_TOKENS = 512

# Per-document chunk caps read by DocHieNetDocDataset; tune to the available GPU memory.
MAX_CHUNKS_TRAIN = 16   # e.g., 16 for RTX 4060 8GB
MAX_CHUNKS_TEST  = 64   # e.g., 64 (can raise later)

# Module-level tokenizer shared by every DocHieNetDocDataset instance.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

class DocHieNetDocDataset(Dataset):
    """Document-level dataset: each item is one document, tokenized and chunked.

    Reads the rows of ``index_df`` belonging to ``split``; every item is built from
    the label JSON referenced by the row's ``label_path``. Relies on the
    module-level ``tokenizer``, ``MAX_TOKENS``, ``MAX_CHUNKS_TRAIN`` and
    ``MAX_CHUNKS_TEST``.

    Args:
        index_df: DataFrame with at least ``split``, ``doc_id`` and ``label_path``.
        split: ``"train"`` or ``"test"`` select the matching chunk cap; any other
            value disables the cap.
    """

    def __init__(self, index_df, split="train"):
        self.df = index_df[index_df["split"] == split].reset_index(drop=True)
        self.root_id = 0
        self.split = split
        if split == "train":
            self.max_chunks = MAX_CHUNKS_TRAIN
        elif split == "test":
            self.max_chunks = MAX_CHUNKS_TEST
        else:
            self.max_chunks = None

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _box_to_1000(box, pw, ph):
        """Scale ``[x0, y0, x1, y1]`` from page pixels to integers in [0, 1000].

        Coordinates that fall outside the page are clamped, so the result is
        always a valid index range for the encoder's 2-D position embeddings.
        """
        x0, y0, x1, y1 = box
        scaled = (1000 * x0 / pw, 1000 * y0 / ph, 1000 * x1 / pw, 1000 * y1 / ph)
        return [min(1000, max(0, int(v))) for v in scaled]

    def _load_elements(self, label_path):
        """Read a label JSON and return its elements sorted by ``order``.

        Elements without an ``order`` key sort last. Each returned dict has the
        keys ``id``, ``page``, ``text``, ``bbox`` (scaled), ``linking`` and ``order``.
        """
        obj = json.loads(Path(label_path).read_text(encoding="utf-8"))
        pages_meta = obj["pages"]
        contents = sorted(obj["contents"], key=lambda x: x.get("order", 10**9))
        elements = []
        for c in contents:
            p = c["page"]
            page_meta = pages_meta[f"page{p}"]
            pw, ph = page_meta["width"], page_meta["height"]
            elements.append({
                "id": c["id"],
                "page": p,
                "text": (c.get("text", "") or ""),
                "bbox": self._box_to_1000(c["box"], pw, ph),
                "linking": c.get("linking", []),
                "order": c.get("order", -1),
            })
        return elements

    def _build_parent_map(self, elements):
        """Build ``child_id -> parent_id`` from every element's ``linking`` pairs.

        Pairs are read as ``[parent, child]``; the first parent seen for a child
        wins. Elements with no parent, or with parent ``-1``, are attached to ROOT.
        """
        parent_map = {}
        for e in elements:
            for pair in e.get("linking", []):
                if isinstance(pair, list) and len(pair) == 2:
                    p, c = pair
                    if c not in parent_map:
                        parent_map[c] = p
        for e in elements:
            cid = e["id"]
            if cid not in parent_map:
                parent_map[cid] = self.root_id
            if parent_map[cid] == -1:
                parent_map[cid] = self.root_id
        return parent_map

    def _restrict_parent_map(self, doc_elem_ids, parent_map_full):
        """Restrict the parent map to the elements that survived chunk capping.

        A parent that is ``-1``, equal to the child itself, or no longer among the
        kept elements is replaced by ROOT, so every returned parent is either
        ``self.root_id`` or a member of ``doc_elem_ids``.
        """
        kept = set(doc_elem_ids)
        root = self.root_id
        parent_map = {}
        for eid in doc_elem_ids:
            p = parent_map_full.get(eid, root)
            if p == -1 or p == eid or (p != root and p not in kept):
                p = root
            parent_map[eid] = p
        return parent_map

    @staticmethod
    def _new_chunk():
        """Return an empty chunk with one list per token-level field."""
        return {"input_ids": [], "attention_mask": [], "bbox": [], "elem_id": [], "page_id": [], "inner_pos": []}

    def _tokenize_and_chunk(self, elements):
        """Tokenize every element and pack the tokens into chunks.

        Each whitespace-separated word is given its element's bbox. Elements are
        packed greedily into chunks of at most ``MAX_TOKENS`` tokens; an element
        is never split across chunks, and a single element longer than
        ``MAX_TOKENS`` is truncated.
        """
        chunks = []
        cur = self._new_chunk()
        cur_len = 0

        def flush():
            nonlocal cur, cur_len
            if cur_len > 0:
                chunks.append(cur)
            cur = self._new_chunk()
            cur_len = 0

        for e in elements:
            text = e["text"].strip()
            words = text.split() if text else []
            if len(words) == 0:
                # Placeholder so that an empty element still yields tokens.
                words = ["[UNK]"]
            boxes = [e["bbox"]] * len(words)

            enc = tokenizer(
                words,
                boxes=boxes,
                add_special_tokens=False,
                return_attention_mask=True,
                truncation=False
            )

            ids = enc["input_ids"]
            am  = enc["attention_mask"]
            bxs = enc["bbox"]  # token-level bboxes aligned by tokenizer

            # Token offset inside the element, plus per-token element / page ids.
            inner = list(range(len(ids)))
            eids = [e["id"]] * len(ids)
            pids = [e["page"]] * len(ids)

            # Start a new chunk rather than splitting the element.
            if cur_len > 0 and cur_len + len(ids) > MAX_TOKENS:
                flush()

            # Safeguard: a single element longer than a whole chunk is truncated.
            if len(ids) > MAX_TOKENS:
                ids = ids[:MAX_TOKENS]; am = am[:MAX_TOKENS]; bxs = bxs[:MAX_TOKENS]
                inner = inner[:MAX_TOKENS]; eids = eids[:MAX_TOKENS]; pids = pids[:MAX_TOKENS]

            cur["input_ids"].extend(ids)
            cur["attention_mask"].extend(am)
            cur["bbox"].extend(bxs)
            cur["elem_id"].extend(eids)
            cur["page_id"].extend(pids)
            cur["inner_pos"].extend(inner)
            cur_len += len(ids)

        flush()
        return chunks

    @staticmethod
    def _build_pooling(chunks):
        """For each chunk, list its element ids and the index of each one's first token."""
        pooling = []
        for c in chunks:
            seen = OrderedDict()
            for i, eid in enumerate(c["elem_id"]):
                if eid not in seen:
                    seen[eid] = i
            pooling.append({"elem_ids": list(seen.keys()), "first_token_idx": list(seen.values())})
        return pooling

    @staticmethod
    def _merge_doc_positions(chunks, pooling):
        """Merge per-chunk pooling into ``elem_id -> (chunk_idx, tok_idx)``.

        The first occurrence of an element wins. Returns the element ids in order
        of first appearance together with the position mapping.
        """
        pos = OrderedDict()
        for chunk_idx, pool in enumerate(pooling):
            for eid, tok_idx in zip(pool["elem_ids"], pool["first_token_idx"]):
                if eid not in pos:
                    pos[eid] = (chunk_idx, tok_idx)
        doc_elem_ids = list(pos.keys())
        return doc_elem_ids, pos

    def __getitem__(self, idx):
        """Return one document as a dict of chunks, element ids, positions and parents."""
        row = self.df.iloc[idx]
        doc_id = row["doc_id"]
        label_path = row["label_path"]

        elements = self._load_elements(label_path)
        parent_map_full = self._build_parent_map(elements)

        chunks = self._tokenize_and_chunk(elements)

        # cap chunks
        if self.max_chunks is not None and len(chunks) > self.max_chunks:
            chunks = chunks[:self.max_chunks]

        pooling = self._build_pooling(chunks)
        doc_elem_ids, doc_elem_positions = self._merge_doc_positions(chunks, pooling)

        # --- FIX: if parent not in kept, remap to ROOT ---
        parent_map = self._restrict_parent_map(doc_elem_ids, parent_map_full)

        return {
            "doc_id": doc_id,
            "chunks": chunks,
            "doc_elem_ids": doc_elem_ids,
            "doc_elem_positions": doc_elem_positions,
            "parent_map": parent_map,
            "n_chunks": len(chunks),
            "n_elems": len(doc_elem_ids),
        }

def collate_fn(batch):
    """Unwrap the single document produced by a ``batch_size=1`` DataLoader.

    Args:
        batch: list of samples gathered by the DataLoader.

    Returns:
        The one sample in ``batch``, unchanged.

    Raises:
        ValueError: if ``batch`` does not contain exactly one sample. Returning
            ``batch[0]`` in that case would silently discard the other samples.
    """
    if len(batch) != 1:
        raise ValueError(
            f"collate_fn expects exactly one document per batch, got {len(batch)}; "
            "use batch_size=1"
        )
    return batch[0]

train_ds = DocHieNetDocDataset(index_df, split="train")
test_ds  = DocHieNetDocDataset(index_df, split="test")
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

# quick sanity: fetch one document and verify every parent is either the root
# or an element that is still present in the (possibly truncated) document
b = next(iter(train_loader))
print("sample:", b["doc_id"], "| chunks:", b["n_chunks"], "| elems:", b["n_elems"])
# Build the set of legal parents once (not once per element) and take the root id
# from the dataset, which is the value __getitem__ remaps dangling parents to.
valid_parents = {train_ds.root_id, *b["doc_elem_ids"]}
n_missing_parents = sum(1 for eid in b["doc_elem_ids"] if b["parent_map"][eid] not in valid_parents)
print("parents missing after fix:", n_missing_parents)


sample: w6praprwst | chunks: 14 | elems: 75
parents missing after fix: 0


In [None]:
# TRAIN CELL (FINAL, REPLACE): AMP + grad accumulation + warmup/decay + step/epoch timing + ETA
# + train-step cap per epoch + sparse eval (TEDS less frequent) + optional chunk-cap patch
#
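# Prerequisites (defined in earlier cells): model_ssa, DocHieNetDocDataset, index_df, collate_fn.
# This cell rebuilds train_loader/test_loader with tighter chunk caps, then trains.
# NOTE: the triple-quote on the next line is never closed inside this cell, so the cell
# is inert as written (running it raises SyntaxError); delete that line to train.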
'''
import time
import math
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import json
import sys
import subprocess

device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    import zss
except ImportError:
    # zss is only needed for the TEDS metric; install it on demand.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "zss"])
    import zss

# ------------------ reproducibility ------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# ------------------ speed knobs (RTX 4060 8GB friendly) ------------------
EPOCHS = 10                       # short run to see the trend; raise to 30~50 once training is stable
MAX_TRAIN_STEPS_PER_EPOCH = 200   # only this many documents per epoch (main speed-up)
PRINT_EVERY_STEP = 20

TRAIN_MAX_CHUNKS = 16             # training cap (paper: 32; 16 suggested for this GPU)
TEST_MAX_CHUNKS  = 64             # test cap (start with 64; can go back to 128 for the final run)

EVAL_EVERY_EPOCH = 1              # compute F1 every epoch
EVAL_TEDS_EVERY  = 5              # compute TEDS only every 5 epochs (main speed-up)

# ------------------ optimization ------------------
LR_START = 4e-5
LR_END = 1e-6
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.05
GRAD_ACCUM_STEPS = 2              # 2 suggested on 8GB; use 1 if GPU memory is tighter
MAX_GRAD_NORM = 1.0
ROOT_ID = 0

SAVE_PATH = "best_model_ssa_fast.pt"
# choose best by: "F1_nonroot" or "TEDS_sem"
BEST_BY = "F1_nonroot"

# ------------------ rebuild loaders with new caps ------------------
train_ds = DocHieNetDocDataset(index_df, split="train")
test_ds  = DocHieNetDocDataset(index_df, split="test")
train_ds.max_chunks = TRAIN_MAX_CHUNKS
test_ds.max_chunks  = TEST_MAX_CHUNKS

train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

print("Loader caps:", f"train_max_chunks={TRAIN_MAX_CHUNKS}", f"test_max_chunks={TEST_MAX_CHUNKS}")
print("Train steps/epoch cap:", MAX_TRAIN_STEPS_PER_EPOCH)

# ------------------ model/optimizer ------------------
model_ssa.to(device)
model_ssa.train()

optimizer = optim.AdamW(model_ssa.parameters(), lr=LR_START, weight_decay=WEIGHT_DECAY)
# Loss scaling is only meaningful for CUDA mixed precision; keep it off on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

# total updates = epochs * ceil(train_steps/accum)
# An epoch stops at the step cap or when the loader is exhausted, whichever comes
# first, so the schedule length must not assume the cap is always reached.
steps_this_epoch = min(MAX_TRAIN_STEPS_PER_EPOCH, len(train_loader))
steps_per_epoch = math.ceil(steps_this_epoch / GRAD_ACCUM_STEPS)
total_updates = steps_per_epoch * EPOCHS
warmup_updates = max(1, int(total_updates * WARMUP_RATIO))

def set_lr(lr):
    """Write ``lr`` into every parameter group of the module-level ``optimizer``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

def lr_schedule(update_idx):
    """Learning rate for optimizer update number ``update_idx``.

    Linear ramp to ``LR_START`` over the first ``warmup_updates`` updates, then a
    geometric decay whose per-update factor is chosen so the rate equals
    ``LR_END`` at ``total_updates``.
    """
    still_warming_up = update_idx <= warmup_updates
    if not still_warming_up:
        decay_span = max(1, total_updates - warmup_updates)
        updates_into_decay = update_idx - warmup_updates
        per_update_factor = (LR_END / LR_START) ** (1.0 / decay_span)
        return LR_START * (per_update_factor ** updates_into_decay)
    return LR_START * (update_idx / warmup_updates)

# ------------------ eval helpers ------------------
def micro_f1(model, loader, exclude_root=False, root_id=0):
    """Micro-averaged precision, recall and F1 over (parent, child) edges.

    For each document the predicted parent of every element is the argmax over
    the logits row, indexed into ``[root_id] + doc_elem_ids``. Predicted and
    ground-truth edge sets are compared and the counts are pooled over all
    documents.

    Args:
        model: callable as ``model(chunks, doc_elem_ids, doc_elem_positions,
            parent_map)`` returning ``(loss, logits)``.
        loader: iterable of single-document batches (dicts).
        exclude_root: drop edges whose parent is ``root_id`` from both sets.
        root_id: id of the virtual root.

    Returns:
        ``(precision, recall, f1)`` as floats; each is 0.0 when undefined.

    Note:
        The model is switched to eval mode for the pass and put back into
        whatever mode it was in on entry, also when an exception is raised.
        (Previously it was forced into train mode on exit.)
    """
    was_training = model.training
    model.eval()
    tp = pred_n = gt_n = 0
    try:
        with torch.no_grad():
            for batch in loader:
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    _, logits = model(
                        batch["chunks"],
                        batch["doc_elem_ids"],
                        batch["doc_elem_positions"],
                        batch["parent_map"]
                    )
                parent_candidates = [root_id] + batch["doc_elem_ids"]
                pred_parent_idx = torch.argmax(logits, dim=1).tolist()

                pred_edges = {(parent_candidates[pidx], child_eid)
                              for child_eid, pidx in zip(batch["doc_elem_ids"], pred_parent_idx)}
                gt_edges = {(batch["parent_map"][eid], eid) for eid in batch["doc_elem_ids"]}

                if exclude_root:
                    pred_edges = {e for e in pred_edges if e[0] != root_id}
                    gt_edges = {e for e in gt_edges if e[0] != root_id}

                tp += len(pred_edges & gt_edges)
                pred_n += len(pred_edges)
                gt_n += len(gt_edges)
    finally:
        model.train(was_training)

    p = tp / pred_n if pred_n else 0.0
    r = tp / gt_n if gt_n else 0.0
    f1 = (2*p*r/(p+r)) if (p+r) else 0.0
    return p, r, f1

def build_children_ordered(doc_elem_ids, parent_map, root_id=0):
    """Turn a child -> parent mapping into a parent -> ordered-children mapping.

    Every id in ``[root_id] + doc_elem_ids`` gets an entry (possibly empty).
    A child without an entry in ``parent_map`` is attached to ``root_id``.
    Each child list is sorted by the child's position in
    ``[root_id] + doc_elem_ids``.
    """
    rank = {}
    for position, node in enumerate([root_id] + doc_elem_ids):
        rank[node] = position

    adjacency = {node: [] for node in [root_id] + doc_elem_ids}
    for node in doc_elem_ids:
        parent = parent_map.get(node, root_id)
        adjacency.setdefault(parent, []).append(node)

    for siblings in adjacency.values():
        siblings.sort(key=lambda node: rank.get(node, 10**9))
    return adjacency

def to_zss_tree(children, node_id, label_map):
    """Recursively build the ``zss.Node`` subtree rooted at ``node_id``.

    Labels come from ``label_map`` ("UNK" when missing); the children of each
    node are taken from ``children`` in the stored order.
    """
    label = label_map.get(node_id, "UNK")
    subtree = zss.Node(label)
    kid_ids = children.get(node_id, [])
    for kid_id in kid_ids:
        kid = to_zss_tree(children, kid_id, label_map)
        subtree.addkid(kid)
    return subtree

def semantic_teds_one(doc_id, doc_elem_ids, gt_parent_map, pred_parent_map):
    """Label-aware tree-edit-distance similarity for one document.

    The element labels are read from the document's label JSON (looked up via
    ``index_df``). Ground-truth and predicted trees are built over the same
    elements, compared with ``zss.distance`` using unit insert/remove costs and
    a relabel cost of 1 when labels differ, and the distance is normalised by
    ``1 + len(doc_elem_ids)``.

    Returns:
        ``1.0 - distance / (1 + len(doc_elem_ids))``.
    """
    matching_rows = index_df[index_df["doc_id"] == doc_id]
    label_path = matching_rows.iloc[0]["label_path"]
    raw_json = Path(label_path).read_text(encoding="utf-8")
    semantic_label = {}
    for item in json.loads(raw_json)["contents"]:
        semantic_label[item["id"]] = item.get("label", "UNK")

    label_map = {ROOT_ID: "ROOT"}
    label_map.update((eid, semantic_label.get(eid, "UNK")) for eid in doc_elem_ids)

    gold_tree = to_zss_tree(
        build_children_ordered(doc_elem_ids, gt_parent_map, root_id=ROOT_ID), ROOT_ID, label_map
    )
    pred_tree = to_zss_tree(
        build_children_ordered(doc_elem_ids, pred_parent_map, root_id=ROOT_ID), ROOT_ID, label_map
    )

    edit_distance = zss.distance(
        gold_tree, pred_tree,
        get_children=zss.Node.get_children,
        insert_cost=lambda node: 1,
        remove_cost=lambda node: 1,
        update_cost=lambda a, b: 0 if a.label == b.label else 1,
    )
    return 1.0 - (edit_distance / (1 + len(doc_elem_ids)))

def eval_semantic_teds(model, loader):
    """Mean ``semantic_teds_one`` score over every document in ``loader``.

    Predicted parents are the argmax over each logits row, indexed into
    ``[ROOT_ID] + doc_elem_ids``.

    Args:
        model: callable as ``model(chunks, doc_elem_ids, doc_elem_positions,
            parent_map)`` returning ``(loss, logits)``.
        loader: iterable of single-document batches (dicts).

    Returns:
        Average score, or 0 when the loader yields nothing.

    Note:
        The model runs in eval mode and is returned to the mode it had on
        entry, also on exceptions. (Previously it was forced into train mode.)
    """
    was_training = model.training
    model.eval()
    scores = []
    try:
        with torch.no_grad():
            for batch in loader:
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    _, logits = model(
                        batch["chunks"],
                        batch["doc_elem_ids"],
                        batch["doc_elem_positions"],
                        batch["parent_map"]
                    )
                parent_candidates = [ROOT_ID] + batch["doc_elem_ids"]
                pred_parent_idx = torch.argmax(logits, dim=1).tolist()
                pred_parent_map = {child: parent_candidates[pidx]
                                   for child, pidx in zip(batch["doc_elem_ids"], pred_parent_idx)}
                scores.append(semantic_teds_one(batch["doc_id"], batch["doc_elem_ids"], batch["parent_map"], pred_parent_map))
    finally:
        model.train(was_training)
    return sum(scores) / max(len(scores), 1)

# ------------------ train loop ------------------
global_start = time.time()
global_update = 0

best_key = -1.0
best_snapshot = None

# An epoch ends at the step cap or when the loader runs out, whichever is first.
steps_this_epoch = min(MAX_TRAIN_STEPS_PER_EPOCH, len(train_loader))

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()
    step_times = []

    model_ssa.train()
    optimizer.zero_grad(set_to_none=True)

    total_loss = 0.0
    steps = 0
    updates = 0

    for step, batch in enumerate(train_loader, start=1):
        t0 = time.time()

        with torch.cuda.amp.autocast(dtype=torch.float16):
            loss, _ = model_ssa(
                batch["chunks"],
                batch["doc_elem_ids"],
                batch["doc_elem_positions"],
                batch["parent_map"]
            )

        # Divide so the accumulated gradient is the mean over the accumulation window.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

        total_loss += float(loss.detach().cpu())
        steps += 1

        dt = time.time() - t0
        step_times.append(dt)

        if step % PRINT_EVERY_STEP == 0:
            avg_dt = sum(step_times[-PRINT_EVERY_STEP:]) / PRINT_EVERY_STEP
            print(f"[epoch {epoch:03d} | step {step:04d}] loss={loss.item():.4f} | {avg_dt:.2f}s/step")

        is_last_step = step >= steps_this_epoch

        # optimizer update: at every full accumulation window, and also on the last
        # step of the epoch so a trailing partial window is applied instead of being
        # wiped by the zero_grad at the start of the next epoch.
        if step % GRAD_ACCUM_STEPS == 0 or is_last_step:
            # Gradients must be unscaled before clipping so the norm is measured
            # in real units.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_ssa.parameters(), MAX_GRAD_NORM)

            global_update += 1
            lr = lr_schedule(global_update)
            set_lr(lr)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            updates += 1

        if is_last_step:
            break

    avg_train_loss = total_loss / max(steps, 1)

    # ------------------ eval ------------------
    eval_start = time.time()

    if (epoch % EVAL_EVERY_EPOCH) == 0:
        _, _, f1_all = micro_f1(model_ssa, test_loader, exclude_root=False, root_id=ROOT_ID)
        _, _, f1_nr  = micro_f1(model_ssa, test_loader, exclude_root=True,  root_id=ROOT_ID)
    else:
        f1_all, f1_nr = float("nan"), float("nan")

    teds_sem = float("nan")
    if (epoch % EVAL_TEDS_EVERY) == 0:
        teds_sem = eval_semantic_teds(model_ssa, test_loader)

    eval_time = time.time() - eval_start

    # ------------------ choose best & save ------------------
    # NOTE(review): the checkpoint is selected on test_loader, so the reported
    # test metrics are not an unbiased estimate; consider a validation split.
    if BEST_BY == "F1_nonroot":
        key = f1_nr
    else:
        key = teds_sem

    tag = ""
    # NaN marks "metric skipped this epoch"; such epochs never become the best.
    if not (isinstance(key, float) and (math.isnan(key))):
        if key > best_key:
            best_key = key
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model_ssa.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    # Metrics are cast to builtin floats so the checkpoint holds only
                    # plain Python numbers (TEDS may arrive as a NumPy scalar - TODO confirm).
                    "best_key": float(best_key),
                    "best_by": BEST_BY,
                    "f1_all": float(f1_all),
                    "f1_nonroot": float(f1_nr),
                    "teds_sem": float(teds_sem),
                    "seed": SEED,
                    "global_update": global_update,
                    "caps": {"train_max_chunks": TRAIN_MAX_CHUNKS, "test_max_chunks": TEST_MAX_CHUNKS},
                    "train_steps_per_epoch_cap": MAX_TRAIN_STEPS_PER_EPOCH,
                    "grad_accum_steps": GRAD_ACCUM_STEPS,
                    "note": "fast timed train for RTX 4060 8GB"
                },
                SAVE_PATH
            )
            tag = " (best)"

    # ------------------ timing / ETA ------------------
    epoch_time = time.time() - epoch_start
    elapsed = time.time() - global_start
    avg_epoch = elapsed / epoch
    eta = avg_epoch * (EPOCHS - epoch)
    lr_now = optimizer.param_groups[0]["lr"]

    # pretty print
    f1_all_str  = f"{f1_all:.4f}" if not (isinstance(f1_all, float) and math.isnan(f1_all)) else "skip"
    f1_nr_str   = f"{f1_nr:.4f}"  if not (isinstance(f1_nr, float) and math.isnan(f1_nr))  else "skip"
    teds_str    = f"{teds_sem:.4f}" if not (isinstance(teds_sem, float) and math.isnan(teds_sem)) else "skip"

    print(
        f"\n[epoch {epoch:03d}/{EPOCHS}] lr={lr_now:.2e} | updates={updates} | train_loss={avg_train_loss:.4f}\n"
        f"  test_F1_all={f1_all_str} | test_F1_nonroot={f1_nr_str} | test_TEDS_sem={teds_str}{tag}\n"
        f"  epoch_time={epoch_time/60:.1f} min | eval_time={eval_time:.1f}s | ETA={eta/60:.1f} min\n"
    )

print("done. best_key=", best_key, "best_by=", BEST_BY, "saved:", SAVE_PATH)


# Checkpoint written by the training loop and the id of the virtual root node.
ROOT_ID = 0
CKPT_PATH = "best_model_ssa_fast.pt"

# zss provides the tree edit distance used by the TEDS metric; install on demand.
try:
    import zss
except ImportError:
    pip_install_zss = [sys.executable, "-m", "pip", "install", "zss"]
    subprocess.check_call(pip_install_zss)
    import zss

# Restore the saved weights and switch to inference mode.
ckpt = torch.load(CKPT_PATH, map_location=device)
saved_weights = ckpt["model_state_dict"]
model_ssa.load_state_dict(saved_weights)
model_ssa.to(device)
model_ssa.eval()

ckpt_epoch = ckpt.get("epoch")
ckpt_best = ckpt.get("best_key")
print("Loaded:", CKPT_PATH, "| epoch:", ckpt_epoch, "| best_key:", ckpt_best)

def micro_prf(model, loader, exclude_root=False, root_id=0):
    """Micro-averaged precision, recall and F1 over (parent, child) edges.

    The predicted parent of each element is the argmax of its logits row,
    indexed into ``[root_id] + doc_elem_ids``. With ``exclude_root`` set, edges
    whose parent is ``root_id`` are removed from both the predicted and the
    ground-truth sets. The model is left in eval mode.

    Returns:
        ``(precision, recall, f1)``; each is 0.0 when its denominator is zero.
    """
    model.eval()
    n_hit = n_pred = n_gold = 0
    with torch.no_grad():
        for batch in loader:
            elem_ids = batch["doc_elem_ids"]
            gold_parent = batch["parent_map"]
            _, logits = model(batch["chunks"], elem_ids, batch["doc_elem_positions"], gold_parent)

            candidates = [root_id] + elem_ids
            chosen = torch.argmax(logits, dim=1).tolist()

            predicted = {(candidates[col], child) for child, col in zip(elem_ids, chosen)}
            gold = {(gold_parent[child], child) for child in elem_ids}
            if exclude_root:
                predicted = {edge for edge in predicted if edge[0] != root_id}
                gold = {edge for edge in gold if edge[0] != root_id}

            n_hit += len(predicted & gold)
            n_pred += len(predicted)
            n_gold += len(gold)

    p = n_hit / n_pred if n_pred else 0.0
    r = n_hit / n_gold if n_gold else 0.0
    f1 = (2*p*r/(p+r)) if (p+r) else 0.0
    return p, r, f1

def build_children_ordered(doc_elem_ids, parent_map, root_id=0):
    """Invert ``parent_map`` into a mapping from each node to its child list.

    All ids in ``[root_id] + doc_elem_ids`` receive an entry; elements missing
    from ``parent_map`` hang off ``root_id``. Child lists are sorted by the
    position of each child in ``[root_id] + doc_elem_ids``.
    """
    nodes = [root_id] + doc_elem_ids
    position = dict((node, idx) for idx, node in enumerate(nodes))

    children = {}
    for node in nodes:
        children.setdefault(node, [])

    for node in doc_elem_ids:
        parent = parent_map.get(node, root_id)
        if parent not in children:
            children[parent] = []
        children[parent].append(node)

    for parent in children:
        children[parent].sort(key=lambda node: position.get(node, 10**9))
    return children

def to_zss_tree(children, node_id, label_map):
    """Build the ``zss.Node`` tree below ``node_id`` by recursion.

    Each node is labelled with ``label_map[node_id]`` ("UNK" if absent) and its
    kids follow the order stored in ``children[node_id]``.
    """
    root = zss.Node(label_map.get(node_id, "UNK"))
    subtrees = [to_zss_tree(children, kid_id, label_map) for kid_id in children.get(node_id, [])]
    for subtree in subtrees:
        root.addkid(subtree)
    return root

def semantic_teds_one(doc_id, doc_elem_ids, gt_parent_map, pred_parent_map):
    """Score one document's predicted hierarchy against its ground truth.

    Labels are read from the document's label JSON, located through
    ``index_df``. Both hierarchies are converted to zss trees over the same
    elements and compared with unit insert/remove costs and a relabel cost of 1
    for differing labels.

    Returns:
        ``1.0 - distance / (1 + len(doc_elem_ids))``.
    """
    doc_rows = index_df[index_df["doc_id"] == doc_id]
    label_file = Path(doc_rows.iloc[0]["label_path"])
    annotation = json.loads(label_file.read_text(encoding="utf-8"))
    label_of = dict((entry["id"], entry.get("label", "UNK")) for entry in annotation["contents"])

    gt_children = build_children_ordered(doc_elem_ids, gt_parent_map, root_id=ROOT_ID)
    pr_children = build_children_ordered(doc_elem_ids, pred_parent_map, root_id=ROOT_ID)

    label_map = {ROOT_ID: "ROOT"}
    for elem_id in doc_elem_ids:
        label_map[elem_id] = label_of.get(elem_id, "UNK")

    gt_root = to_zss_tree(gt_children, ROOT_ID, label_map)
    pr_root = to_zss_tree(pr_children, ROOT_ID, label_map)

    def _unit_cost(node):
        return 1

    def _relabel_cost(left, right):
        if left.label == right.label:
            return 0
        return 1

    distance = zss.distance(
        gt_root, pr_root,
        get_children=zss.Node.get_children,
        insert_cost=_unit_cost,
        remove_cost=_unit_cost,
        update_cost=_relabel_cost,
    )
    n_nodes = 1 + len(doc_elem_ids)
    return 1.0 - (distance / n_nodes)

def eval_semantic_teds(model, loader):
    """Average ``semantic_teds_one`` over all documents yielded by ``loader``.

    Each element's predicted parent is the argmax of its logits row, indexed
    into ``[ROOT_ID] + doc_elem_ids``. The model is left in eval mode.

    Returns:
        Mean per-document score, or 0 when ``loader`` is empty.
    """
    model.eval()
    per_doc = []
    with torch.no_grad():
        for batch in loader:
            elem_ids = batch["doc_elem_ids"]
            gold_parent = batch["parent_map"]
            _, logits = model(batch["chunks"], elem_ids, batch["doc_elem_positions"], gold_parent)

            candidates = [ROOT_ID] + elem_ids
            chosen = torch.argmax(logits, dim=1).tolist()
            pred_parent_map = {}
            for child, col in zip(elem_ids, chosen):
                pred_parent_map[child] = candidates[col]

            score = semantic_teds_one(batch["doc_id"], elem_ids, gold_parent, pred_parent_map)
            per_doc.append(score)
    return sum(per_doc)/max(len(per_doc), 1)

# Final metrics on test_loader
# Edge-level scores with and without root edges, then the tree-level score.
p_all, r_all, f1_all = micro_prf(model_ssa, test_loader, exclude_root=False, root_id=ROOT_ID)
p_nr,  r_nr,  f1_nr  = micro_prf(model_ssa, test_loader, exclude_root=True,  root_id=ROOT_ID)
teds_sem = eval_semantic_teds(model_ssa, test_loader)

report_lines = [
    "\n===== FINAL TEST REPORT =====",
    f"F1_all     : {f1_all:.4f} (P={p_all:.4f}, R={r_all:.4f})",
    f"F1_nonroot : {f1_nr:.4f} (P={p_nr:.4f}, R={r_nr:.4f})",
    f"TEDS_sem   : {teds_sem:.4f}",
]
print("\n".join(report_lines))

In [27]:
try:
    import zss
except ImportError:
    # Imported here because this cell does not import them itself; without this the
    # fallback only works if some earlier cell happened to import both modules.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "zss"])
    import zss

CKPT_PATH = "best_model_ssa_fast.pt"
ROOT_ID = 0

# NOTE(review): torch.load unpickles the file, so only load checkpoints you created
# yourself. The checkpoint also stores metric scalars and metadata next to the weights.
ckpt = torch.load(CKPT_PATH, map_location=device)
model_ssa.load_state_dict(ckpt["model_state_dict"])
model_ssa.to(device)
model_ssa.eval()  # inference mode for the evaluation below

print("Loaded:", CKPT_PATH, "| epoch:", ckpt.get("epoch"), "| best_key:", ckpt.get("best_key"))

def micro_prf(model, loader, exclude_root=False, root_id=0):
    """Pooled precision / recall / F1 of predicted (parent, child) edges.

    Args:
        model: callable as ``model(chunks, doc_elem_ids, doc_elem_positions,
            parent_map)`` returning ``(loss, logits)``.
        loader: iterable of single-document batches (dicts).
        exclude_root: ignore edges whose parent is ``root_id`` on both sides.
        root_id: id of the virtual root, the first parent candidate.

    Returns:
        ``(precision, recall, f1)``; a value is 0.0 when it is undefined.
        The model is left in eval mode.
    """
    model.eval()
    matched = 0
    total_predicted = 0
    total_gold = 0

    def _keep(edge):
        # Root edges are only filtered out when the caller asks for it.
        return not exclude_root or edge[0] != root_id

    with torch.no_grad():
        for batch in loader:
            elem_ids = batch["doc_elem_ids"]
            _, logits = model(
                batch["chunks"],
                elem_ids,
                batch["doc_elem_positions"],
                batch["parent_map"]
            )
            candidates = [root_id] + elem_ids
            best_cols = torch.argmax(logits, dim=1).tolist()

            raw_pred = set((candidates[col], child) for child, col in zip(elem_ids, best_cols))
            raw_gold = set((batch["parent_map"][child], child) for child in elem_ids)
            pred_edges = set(filter(_keep, raw_pred))
            gold_edges = set(filter(_keep, raw_gold))

            matched += len(pred_edges.intersection(gold_edges))
            total_predicted += len(pred_edges)
            total_gold += len(gold_edges)

    p = matched / total_predicted if total_predicted else 0.0
    r = matched / total_gold if total_gold else 0.0
    f1 = (2*p*r/(p+r)) if (p+r) else 0.0
    return p, r, f1

def build_children_ordered(doc_elem_ids, parent_map, root_id=0):
    """Group elements under their parents, preserving document order.

    The result has an entry for every id in ``[root_id] + doc_elem_ids`` (and
    for any other parent named in ``parent_map``). An element with no entry in
    ``parent_map`` becomes a child of ``root_id``. Each list of children is
    sorted by the child's index in ``[root_id] + doc_elem_ids``.
    """
    ordering = {}
    for index, node in enumerate([root_id] + doc_elem_ids):
        ordering[node] = index

    by_parent = {node: [] for node in [root_id] + doc_elem_ids}
    for child in doc_elem_ids:
        by_parent.setdefault(parent_map.get(child, root_id), []).append(child)

    def _doc_position(node):
        return ordering.get(node, 10**9)

    for kids in by_parent.values():
        kids.sort(key=_doc_position)
    return by_parent

def to_zss_tree(children, node_id, label_map):
    """Convert the subtree at ``node_id`` into nested ``zss.Node`` objects.

    A node's label is ``label_map.get(node_id, "UNK")``; its kids are built
    recursively from ``children.get(node_id, [])`` in list order.
    """
    label = label_map.get(node_id, "UNK")
    subtree = zss.Node(label)
    kid_ids = children.get(node_id, [])
    for kid_id in kid_ids:
        kid = to_zss_tree(children, kid_id, label_map)
        subtree.addkid(kid)
    return subtree

def semantic_teds_one(doc_id, doc_elem_ids, gt_parent_map, pred_parent_map):
    """Label-aware tree similarity between ground truth and prediction.

    Element labels come from the document's label JSON (path taken from
    ``index_df``). The two parent maps are turned into zss trees over the same
    elements; the edit distance uses cost 1 for insert and remove, and cost 1
    for a relabel unless the labels match.

    Returns:
        ``1.0 - distance / (1 + len(doc_elem_ids))``.
    """
    matching_rows = index_df[index_df["doc_id"] == doc_id]
    label_path = matching_rows.iloc[0]["label_path"]
    raw_json = Path(label_path).read_text(encoding="utf-8")
    semantic_label = {}
    for item in json.loads(raw_json)["contents"]:
        semantic_label[item["id"]] = item.get("label", "UNK")

    label_map = {ROOT_ID: "ROOT"}
    label_map.update((eid, semantic_label.get(eid, "UNK")) for eid in doc_elem_ids)

    gold_tree = to_zss_tree(
        build_children_ordered(doc_elem_ids, gt_parent_map, root_id=ROOT_ID), ROOT_ID, label_map
    )
    pred_tree = to_zss_tree(
        build_children_ordered(doc_elem_ids, pred_parent_map, root_id=ROOT_ID), ROOT_ID, label_map
    )

    edit_distance = zss.distance(
        gold_tree, pred_tree,
        get_children=zss.Node.get_children,
        insert_cost=lambda node: 1,
        remove_cost=lambda node: 1,
        update_cost=lambda a, b: 0 if a.label == b.label else 1,
    )
    return 1.0 - (edit_distance / (1 + len(doc_elem_ids)))

def eval_semantic_teds(model, loader):
    """Mean per-document ``semantic_teds_one`` score for ``model`` on ``loader``.

    Predicted parents are read off the logits by row-wise argmax into
    ``[ROOT_ID] + doc_elem_ids``. The model is left in eval mode.

    Returns:
        The average score, or 0 if ``loader`` yields no documents.
    """
    model.eval()
    doc_scores = []
    with torch.no_grad():
        for batch in loader:
            elem_ids = batch["doc_elem_ids"]
            gold_parent = batch["parent_map"]
            _, logits = model(
                batch["chunks"],
                elem_ids,
                batch["doc_elem_positions"],
                gold_parent
            )
            candidates = [ROOT_ID] + elem_ids
            best_cols = torch.argmax(logits, dim=1).tolist()
            pred_parent_map = {}
            for child, col in zip(elem_ids, best_cols):
                pred_parent_map[child] = candidates[col]

            doc_scores.append(
                semantic_teds_one(batch["doc_id"], elem_ids, gold_parent, pred_parent_map)
            )
    n_docs = max(len(doc_scores), 1)
    return sum(doc_scores)/n_docs

# Final metrics on test_loader
# Two edge-level passes (with and without root edges) plus the tree-level score.
p_all, r_all, f1_all = micro_prf(model_ssa, test_loader, exclude_root=False, root_id=ROOT_ID)
p_nr,  r_nr,  f1_nr  = micro_prf(model_ssa, test_loader, exclude_root=True,  root_id=ROOT_ID)
teds_sem = eval_semantic_teds(model_ssa, test_loader)

for report_line in (
    "\n===== FINAL TEST REPORT =====",
    f"F1_all     : {f1_all:.4f} (P={p_all:.4f}, R={r_all:.4f})",
    f"F1_nonroot : {f1_nr:.4f} (P={p_nr:.4f}, R={r_nr:.4f})",
    f"TEDS_sem   : {teds_sem:.4f}",
):
    print(report_line)

  ckpt = torch.load(CKPT_PATH, map_location=device)


Loaded: best_model_ssa_fast.pt | epoch: 1 | best_key: 0.04392875901255789

===== FINAL TEST REPORT =====
F1_all     : 0.1810 (P=0.1810, R=0.1810)
F1_nonroot : 0.0369 (P=0.0404, R=0.0339)
TEDS_sem   : 0.3002


In [28]:
# CELL ABLATION: w/o Page Embedding / w/o Inner-layout Embedding on model_ssa, report F1_all/F1_nonroot/TEDS_sem

import math
import json
from pathlib import Path
import torch
import types
import sys
import subprocess

ROOT_ID = 0  # id of the virtual root; always the first parent candidate
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# zss supplies the tree edit distance for the TEDS metric; install it if absent.
try:
    import zss
except ImportError:
    pip_install_zss = [sys.executable, "-m", "pip", "install", "zss"]
    subprocess.check_call(pip_install_zss)
    import zss

# --- (1) patch model_ssa to support toggles ---
# We override _encode_one_chunk to conditionally add page/inner embeddings.
# This assumes your model_ssa has: encoder, inner_emb, page_proj, page_pe_dim.
def sinusoidal_position_encoding(pos, dim):
    """Sinusoidal encoding of integer positions.

    Args:
        pos: 1-D tensor of positions (it is unsqueezed to a column here).
        dim: width of the encoding. Odd values are supported: the last column
            is then a sine column with no cosine partner.

    Returns:
        Float tensor of shape ``[len(pos), dim]`` on ``pos.device`` with sines
        in the even columns and cosines in the odd columns.
    """
    pos = pos.float().unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, device=pos.device).float() * (-math.log(10000.0) / dim))
    angles = pos * div_term  # one column per sine slot: ceil(dim / 2) columns
    pe = torch.zeros(pos.size(0), dim, device=pos.device)
    pe[:, 0::2] = torch.sin(angles)
    # There are only dim // 2 odd columns, so drop the surplus angle column when
    # dim is odd (assigning all of them would fail with a shape mismatch).
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe

# save original method once (if not saved yet)
# The hasattr guard matters when this cell is re-run: _encode_one_chunk has by then
# been replaced further down, and saving again would store the patched method.
if not hasattr(model_ssa, "_encode_one_chunk_orig"):
    model_ssa._encode_one_chunk_orig = model_ssa._encode_one_chunk

def _encode_one_chunk_toggled(self, c):
    """Encode one chunk, optionally adding page and inner-layout embeddings.

    ``c`` supplies ``input_ids``, ``attention_mask`` and ``bbox`` for the
    encoder, plus ``page_id`` / ``inner_pos`` for the two optional embeddings.
    The instance flags ``use_page_emb`` and ``use_inner_emb`` (both treated as
    True when absent) decide which embeddings are added to the encoder's last
    hidden state.

    Returns:
        The per-token hidden states with the batch dimension squeezed out.
    """
    def _long_tensor(values):
        # All inputs are moved to the module-level `device` as int64 tensors.
        return torch.tensor(values, dtype=torch.long, device=device)

    encoder_out = self.encoder(
        input_ids=_long_tensor(c["input_ids"]).unsqueeze(0),
        attention_mask=_long_tensor(c["attention_mask"]).unsqueeze(0),
        bbox=_long_tensor(c["bbox"]).unsqueeze(0),
        return_dict=True,
    )
    hidden = encoder_out.last_hidden_state.squeeze(0)

    # Read both switches up front; a missing attribute means "enabled".
    page_on = getattr(self, "use_page_emb", True)
    inner_on = getattr(self, "use_inner_emb", True)

    if page_on:
        page_code = sinusoidal_position_encoding(_long_tensor(c["page_id"]), self.page_pe_dim)
        hidden = hidden + self.page_proj(page_code)

    if inner_on:
        # Clamp so out-of-range positions reuse the last embedding row instead of
        # raising an index error.
        last_row = self.inner_emb.num_embeddings - 1
        inner_rows = torch.clamp(_long_tensor(c["inner_pos"]), 0, last_row)
        hidden = hidden + self.inner_emb(inner_rows)

    return hidden

# bind patched method
# types.MethodType binds the function to this one object, so `self` inside
# _encode_one_chunk_toggled is model_ssa.
model_ssa._encode_one_chunk = types.MethodType(_encode_one_chunk_toggled, model_ssa)

# --- (2) metrics: F1_all / F1_nonroot / TEDS_semantic ---
def micro_f1(model, loader, exclude_root=False, root_id=0):
    """Micro-averaged precision, recall and F1 over (parent, child) edges.

    Each element's predicted parent is the argmax of its logits row, indexed
    into ``[root_id] + doc_elem_ids``. Edge counts are pooled over all
    documents before the ratios are taken.

    Args:
        model: callable as ``model(chunks, doc_elem_ids, doc_elem_positions,
            parent_map)`` returning ``(loss, logits)``.
        loader: iterable of single-document batches (dicts).
        exclude_root: drop edges whose parent is ``root_id`` from both sets.
        root_id: id of the virtual root.

    Returns:
        ``(precision, recall, f1)`` as floats; each is 0.0 when undefined.

    Note:
        The model is evaluated in eval mode and then returned to the mode it was
        in on entry, also when an exception is raised. (Previously it was
        forced into train mode, which left a model loaded for inference with
        dropout enabled after this call.)
    """
    was_training = model.training
    model.eval()
    tp = pred_n = gt_n = 0
    try:
        with torch.no_grad():
            for batch in loader:
                _, logits = model(
                    batch["chunks"],
                    batch["doc_elem_ids"],
                    batch["doc_elem_positions"],
                    batch["parent_map"]
                )
                parent_candidates = [root_id] + batch["doc_elem_ids"]
                pred_parent_idx = torch.argmax(logits, dim=1).tolist()

                pred_edges = {(parent_candidates[pidx], child_eid)
                              for child_eid, pidx in zip(batch["doc_elem_ids"], pred_parent_idx)}
                gt_edges = {(batch["parent_map"][eid], eid) for eid in batch["doc_elem_ids"]}

                if exclude_root:
                    pred_edges = {e for e in pred_edges if e[0] != root_id}
                    gt_edges = {e for e in gt_edges if e[0] != root_id}

                tp += len(pred_edges & gt_edges)
                pred_n += len(pred_edges)
                gt_n += len(gt_edges)
    finally:
        model.train(was_training)

    p = tp / pred_n if pred_n else 0.0
    r = tp / gt_n if gt_n else 0.0
    f1 = (2*p*r/(p+r)) if (p+r) else 0.0
    return p, r, f1

def build_children_ordered(doc_elem_ids, parent_map, root_id=0):
    """Build a parent -> children mapping from a child -> parent mapping.

    Entries are created for every id in ``[root_id] + doc_elem_ids``; elements
    that ``parent_map`` does not mention are placed under ``root_id``. The
    children of each node are sorted by their index in
    ``[root_id] + doc_elem_ids``.
    """
    nodes = [root_id] + doc_elem_ids
    position = dict((node, idx) for idx, node in enumerate(nodes))

    children = {}
    for node in nodes:
        children.setdefault(node, [])

    for node in doc_elem_ids:
        parent = parent_map.get(node, root_id)
        if parent not in children:
            children[parent] = []
        children[parent].append(node)

    for parent in children:
        children[parent].sort(key=lambda node: position.get(node, 10**9))
    return children

def to_zss_tree(children, node_id, label_map):
    """Recursively assemble the ``zss.Node`` tree rooted at ``node_id``.

    The node label is looked up in ``label_map`` ("UNK" when missing) and the
    kids are added in the order given by ``children``.
    """
    root = zss.Node(label_map.get(node_id, "UNK"))
    subtrees = [to_zss_tree(children, kid_id, label_map) for kid_id in children.get(node_id, [])]
    for subtree in subtrees:
        root.addkid(subtree)
    return root

def semantic_teds_one(doc_id, doc_elem_ids, gt_parent_map, pred_parent_map):
    """Tree-edit-distance similarity of one predicted hierarchy.

    The document's label JSON (found through ``index_df``) provides a label for
    every element. Ground truth and prediction are expanded into zss trees
    rooted at ``ROOT_ID`` and compared with unit insert/remove costs; relabeling
    costs 1 unless both labels are equal.

    Returns:
        ``1.0 - distance / (1 + len(doc_elem_ids))``.
    """
    doc_rows = index_df[index_df["doc_id"] == doc_id]
    label_file = Path(doc_rows.iloc[0]["label_path"])
    annotation = json.loads(label_file.read_text(encoding="utf-8"))
    label_of = dict((entry["id"], entry.get("label", "UNK")) for entry in annotation["contents"])

    gt_children = build_children_ordered(doc_elem_ids, gt_parent_map, root_id=ROOT_ID)
    pr_children = build_children_ordered(doc_elem_ids, pred_parent_map, root_id=ROOT_ID)

    label_map = {ROOT_ID: "ROOT"}
    for elem_id in doc_elem_ids:
        label_map[elem_id] = label_of.get(elem_id, "UNK")

    gt_root = to_zss_tree(gt_children, ROOT_ID, label_map)
    pr_root = to_zss_tree(pr_children, ROOT_ID, label_map)

    def _unit_cost(node):
        return 1

    def _relabel_cost(left, right):
        if left.label == right.label:
            return 0
        return 1

    distance = zss.distance(
        gt_root, pr_root,
        get_children=zss.Node.get_children,
        insert_cost=_unit_cost,
        remove_cost=_unit_cost,
        update_cost=_relabel_cost,
    )
    n_nodes = 1 + len(doc_elem_ids)
    return 1.0 - (distance / n_nodes)

def eval_semantic_teds(model, loader):
    """Mean ``semantic_teds_one`` score over every document in ``loader``.

    Predicted parents are the argmax over each logits row, indexed into
    ``[ROOT_ID] + doc_elem_ids``.

    Args:
        model: callable as ``model(chunks, doc_elem_ids, doc_elem_positions,
            parent_map)`` returning ``(loss, logits)``.
        loader: iterable of single-document batches (dicts).

    Returns:
        Average score, or 0 when the loader yields nothing.

    Note:
        The model runs in eval mode and is returned to the mode it had on
        entry, also on exceptions. (Previously it was forced into train mode.)
    """
    was_training = model.training
    model.eval()
    scores = []
    try:
        with torch.no_grad():
            for batch in loader:
                doc_id = batch["doc_id"]
                _, logits = model(
                    batch["chunks"],
                    batch["doc_elem_ids"],
                    batch["doc_elem_positions"],
                    batch["parent_map"]
                )
                parent_candidates = [ROOT_ID] + batch["doc_elem_ids"]
                pred_parent_idx = torch.argmax(logits, dim=1).tolist()
                pred_parent_map = {child: parent_candidates[pidx]
                                   for child, pidx in zip(batch["doc_elem_ids"], pred_parent_idx)}
                scores.append(semantic_teds_one(doc_id, batch["doc_elem_ids"], batch["parent_map"], pred_parent_map))
    finally:
        model.train(was_training)
    return sum(scores) / max(len(scores), 1)

def run_one(name, use_page, use_inner):
    """Evaluate ``model_ssa`` on ``test_loader`` with the given embedding toggles.

    Args:
        name: label printed with the results and stored in the returned dict.
        use_page: value for ``model_ssa.use_page_emb`` during this run.
        use_inner: value for ``model_ssa.use_inner_emb`` during this run.

    Returns:
        dict with keys ``name``, ``F1_all``, ``F1_nonroot``, ``TEDS_sem``.

    Note:
        The two toggles are put back to their previous state (or removed if they
        did not exist) when the run ends, so an ablation does not silently
        change how later cells evaluate or train the model.
    """
    _unset = object()
    previous = {
        attr: getattr(model_ssa, attr, _unset)
        for attr in ("use_page_emb", "use_inner_emb")
    }
    model_ssa.use_page_emb = use_page
    model_ssa.use_inner_emb = use_inner
    try:
        p_all, r_all, f1_all = micro_f1(model_ssa, test_loader, exclude_root=False, root_id=ROOT_ID)
        p_nr,  r_nr,  f1_nr  = micro_f1(model_ssa, test_loader, exclude_root=True,  root_id=ROOT_ID)
        teds = eval_semantic_teds(model_ssa, test_loader)
    finally:
        for attr, old_value in previous.items():
            if old_value is _unset:
                delattr(model_ssa, attr)
            else:
                setattr(model_ssa, attr, old_value)

    print(f"\n=== {name} ===")
    print(f"use_page_emb={use_page} | use_inner_emb={use_inner}")
    print(f"F1_all     : {f1_all:.4f} (P={p_all:.4f}, R={r_all:.4f})")
    print(f"F1_nonroot : {f1_nr:.4f} (P={p_nr:.4f}, R={r_nr:.4f})")
    print(f"TEDS_sem   : {teds:.4f}")
    return {"name": name, "F1_all": f1_all, "F1_nonroot": f1_nr, "TEDS_sem": teds}

# --- (3) run ablations ---
model_ssa.to(device)

# (display name, use_page_emb, use_inner_emb) for each configuration, in run order.
ablation_plan = [
    ("baseline (PageE+InnerE)", True, True),
    ("w/o PageE", False, True),
    ("w/o InnerE", True, False),
]

res = []
for run_name, page_on, inner_on in ablation_plan:
    res.append(run_one(run_name, page_on, inner_on))

print("\n===== SUMMARY (copy to report) =====")
for r in res:
    summary = f'{r["name"]}: F1_all={r["F1_all"]:.4f} | F1_nonroot={r["F1_nonroot"]:.4f} | TEDS_sem={r["TEDS_sem"]:.4f}'
    print(summary)



=== baseline (PageE+InnerE) ===
use_page_emb=True | use_inner_emb=True
F1_all     : 0.1810 (P=0.1810, R=0.1810)
F1_nonroot : 0.0369 (P=0.0404, R=0.0339)
TEDS_sem   : 0.3002

=== w/o PageE ===
use_page_emb=False | use_inner_emb=True
F1_all     : 0.1896 (P=0.1896, R=0.1896)
F1_nonroot : 0.0000 (P=0.0000, R=0.0000)
TEDS_sem   : 0.6376

=== w/o InnerE ===
use_page_emb=True | use_inner_emb=False
F1_all     : 0.1260 (P=0.1260, R=0.1260)
F1_nonroot : 0.0198 (P=0.0228, R=0.0175)
TEDS_sem   : 0.4121

===== SUMMARY (copy to report) =====
baseline (PageE+InnerE): F1_all=0.1810 | F1_nonroot=0.0369 | TEDS_sem=0.3002
w/o PageE: F1_all=0.1896 | F1_nonroot=0.0000 | TEDS_sem=0.6376
w/o InnerE: F1_all=0.1260 | F1_nonroot=0.0198 | TEDS_sem=0.4121
