In [25]:
# # Fandom Span Identification — Checkpoints & Sanity Checks

# This notebook validates the **span identification (hyperlink detection)** model trained
# on Fandom pages.

# Pipeline validated:
# HTML → char spans → BILOU → token classification → predicted spans

# Repo root:
# `/data/sundeep/Fandom_SI`

# Expected:
# - Model + tokenizer load correctly
# - Predicted spans are readable anchor text
# - Span-level metrics are non-zero


In [26]:
# === CONFIG ===
DOMAIN = "money-heist"
MODEL_DIR = "/data/sundeep/Fandom_SI/experiments/span_id/money-heist/bert-base-uncased/checkpoint-81"
DATA_DIR = "/data/sundeep/Fandom_SI/data/processed_span_identi/money-heist"
MAX_LEN = 256

print("DOMAIN:", DOMAIN)
print("MODEL_DIR:", MODEL_DIR)
print("DATA_DIR:", DATA_DIR)

# EXPECTED OUTPUT:
# DOMAIN: money-heist
# MODEL_DIR: experiments/span_id/money-heist/bert-base-uncased
# DATA_DIR: data/processed_span_identi/money-heist


DOMAIN: money-heist
MODEL_DIR: /data/sundeep/Fandom_SI/experiments/span_id/money-heist/bert-base-uncased/checkpoint-81
DATA_DIR: /data/sundeep/Fandom_SI/data/processed_span_identi/money-heist


In [27]:
from pathlib import Path

p = Path(MODEL_DIR)
assert p.exists(), f"Missing model dir: {p}"

files = {x.name for x in p.glob("*")}
print("Files found:", len(files))
for f in sorted(files):
    print(" -", f)

print("\nChecks:")
print("config.json:", "config.json" in files)
print("model weights:", any(x in files for x in ["pytorch_model.bin", "model.safetensors"]))
print("tokenizer_config.json:", "tokenizer_config.json" in files)
print("special_tokens_map.json:", "special_tokens_map.json" in files)
print("vocab.txt:", "vocab.txt" in files)

# EXPECTED:
# All checks should print True


Files found: 11
 - config.json
 - model.safetensors
 - optimizer.pt
 - rng_state.pth
 - scheduler.pt
 - special_tokens_map.json
 - tokenizer.json
 - tokenizer_config.json
 - trainer_state.json
 - training_args.bin
 - vocab.txt

Checks:
config.json: True
model weights: True
tokenizer_config.json: True
special_tokens_map.json: True
vocab.txt: True


In [28]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

tok = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR)
model.eval()

print("Tokenizer:", tok.__class__.__name__)
print("Num labels:", model.config.num_labels)
print("id2label:", model.config.id2label)

# EXPECTED:
# Tokenizer: BertTokenizerFast
# Num labels: 5
# id2label: {0:'O',1:'B-LINK',2:'I-LINK',3:'L-LINK',4:'U-LINK'}


Tokenizer: BertTokenizerFast
Num labels: 5
id2label: {0: 'O', 1: 'B', 2: 'I', 3: 'L', 4: 'U'}


In [29]:
import json
from pathlib import Path

def read_jsonl(p):
    rows=[]
    with open(p,"r",encoding="utf-8") as f:
        for line in f:
            rows.append(json.loads(line))
    return rows

train = read_jsonl(Path(DATA_DIR)/"train.jsonl")
dev   = read_jsonl(Path(DATA_DIR)/"dev.jsonl")
test  = read_jsonl(Path(DATA_DIR)/"test.jsonl")

def stats(name, rows):
    spans=sum(len(r["spans"]) for r in rows)
    print(f"{name}: docs={len(rows)}, spans={spans}")

stats("train", train)
stats("dev", dev)
stats("test", test)

# EXPECTED (approx):
# train: docs=212, spans~5400
# dev: docs=27, spans~700
# test: docs=28, spans~500


train: docs=212, spans=5419
dev: docs=27, spans=757
test: docs=28, spans=541


In [30]:
import random
import numpy as np

ex = random.choice(dev)
text = ex["text"]

enc = tok(
    text,
    return_offsets_mapping=True,
    truncation=True,
    max_length=MAX_LEN,
    return_tensors="pt"
)

with torch.no_grad():
    logits = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"]
    ).logits

pred_ids = logits.argmax(-1)[0].tolist()
offsets = enc["offset_mapping"][0].tolist()

id2label = model.config.id2label
def lab(i):
    return id2label[i] if i in id2label else id2label[str(i)]

print("doc_id:", ex["doc_id"])
print("Gold spans:", len(ex["spans"]))
print("\nPredicted non-O tokens:")

shown=0
for (s,e), pid in zip(offsets, pred_ids):
    if s==0 and e==0: continue
    if lab(pid)!="O":
        print(f"{lab(pid):7} -> {text[s:e]!r}")
        shown+=1
    if shown>=40: break

# EXPECTED:
# B-LINK / I-LINK tokens with readable text like:
# 'Berlin', 'Professor', 'Tokyo'


doc_id: money-heist||159
Gold spans: 72

Predicted non-O tokens:
U       -> 'Denver'
B       -> 'Daniel'
U       -> 'Tokyo'
U       -> 'Moscow'
U       -> 'Benjamín'
U       -> 'Stockholm'
U       -> 'Manila'
U       -> 'Denver'
B       -> 'Miguel'
I       -> 'Fernández'
I       -> 'ani'
L       -> 'lla'
U       -> 'Moscow'
U       -> 'Cincinnati'
U       -> 'Manila'
U       -> 'Vane'
U       -> 'Tokyo'
U       -> 'Stockholm'
U       -> 'Manila'
B       -> 'The'
L       -> 'Professor'
U       -> 'Helsinki'
U       -> 'Tokyo'
U       -> 'Nairobi'
U       -> 'Stockholm'
U       -> 'Rio'
U       -> 'Bogotá'
U       -> 'Palermo'
U       -> 'Berlin'
U       -> 'Oslo'
U       -> 'Lisbon'
B       -> 'Mat'
I       -> 'ías'
I       -> 'Cañ'
L       -> 'o'
U       -> 'Manila'
B       -> 'Miguel'
I       -> 'Fernández'
I       -> 'Tal'
I       -> 'ani'


In [33]:
def bilou_to_spans(tags, offsets):
    """
    Decode B/I/L/U/O OR B-LINK/I-LINK/... into character spans
    """
    spans = []
    cur_start = None
    cur_end = None

    def flush():
        nonlocal cur_start, cur_end
        if cur_start is not None and cur_end is not None and cur_end > cur_start:
            spans.append((cur_start, cur_end))
        cur_start = None
        cur_end = None

    for (s, e), t in zip(offsets, tags):
        if s == 0 and e == 0:
            continue

        t0 = t.split("-", 1)[0]  # <-- IMPORTANT FIX

        if t0 == "U":
            spans.append((s, e))
        elif t0 == "B":
            flush()
            cur_start, cur_end = s, e
        elif t0 == "I":
            if cur_start is None:
                cur_start, cur_end = s, e
            else:
                cur_end = e
        elif t0 == "L":
            if cur_start is None:
                spans.append((s, e))
            else:
                cur_end = e
                flush()
        else:  # O
            flush()

    flush()
    return spans


In [34]:
print("Pred spans:", len(pred_spans))
print("First 10:")
for s,e in pred_spans[:10]:
    print(text[s:e])


Pred spans: 0
First 10:


In [32]:
gold={(s["start"],s["end"]) for s in ex["spans"]}
pred=set(pred_spans)

tp=len(gold & pred)
fp=len(pred - gold)
fn=len(gold - pred)

prec=tp/(tp+fp) if tp+fp else 0
rec=tp/(tp+fn) if tp+fn else 0
f1=2*prec*rec/(prec+rec) if prec+rec else 0

print("TP:",tp,"FP:",fp,"FN:",fn)
print("Precision:",round(prec,3),"Recall:",round(rec,3),"F1:",round(f1,3))

# EXPECTED:
# F1 > 0 for many docs (not necessarily high)


TP: 0 FP: 0 FN: 72
Precision: 0 Recall: 0.0 F1: 0


In [36]:
import json, torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL_DIR = "/data/sundeep/Fandom_SI/experiments/span_id/money-heist/bert-base-uncased/checkpoint-81"  # <-- set your real checkpoint
DATA_DIR  = "/data/sundeep/Fandom_SI/data/processed_span_identi/money-heist"
MAX_LEN   = 256

tok = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR)
model.eval()

print("Tokenizer class:", tok.__class__.__name__)
print("tok.is_fast:", getattr(tok, "is_fast", False))

ex = json.loads(open(f"{DATA_DIR}/dev.jsonl","r",encoding="utf-8").readline())
text = ex["text"]

enc = tok(text, return_offsets_mapping=True, truncation=True, max_length=MAX_LEN, return_tensors="pt")
offsets = enc["offset_mapping"][0].tolist()

# count how many tokens have real offsets
real = [(s,e) for (s,e) in offsets if not (s==0 and e==0)]
print("Total tokens:", len(offsets))
print("Real-offset tokens:", len(real))
print("First 20 offsets:", offsets[:20])

with torch.no_grad():
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"]).logits
pred_ids = logits.argmax(-1)[0].tolist()

id2label = model.config.id2label
def lab(i):
    if i in id2label: return id2label[i]
    if str(i) in id2label: return id2label[str(i)]
    return str(i)

tags = [lab(i) for i in pred_ids]
print("First 30 tags:", tags[:30])
print("Tag counts:", {t: tags.count(t) for t in sorted(set(tags))})


Tokenizer class: BertTokenizerFast
tok.is_fast: True
Total tokens: 239
Real-offset tokens: 237
First 20 offsets: [[0, 0], [0, 4], [4, 6], [7, 9], [9, 14], [15, 27], [28, 39], [40, 46], [47, 53], [54, 60], [61, 66], [67, 80], [81, 87], [88, 90], [90, 94], [95, 97], [97, 101], [101, 102], [103, 104], [104, 112]]
First 30 tags: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B', 'I', 'I', 'I', 'L', 'O', 'O', 'O', 'B', 'I', 'I', 'L', 'O', 'O', 'O', 'B', 'I']
Tag counts: {'B': 11, 'I': 24, 'L': 13, 'O': 188, 'U': 3}


In [37]:
# --- guaranteed BILOU decoder for tags: B/I/L/U/O (or B-LINK etc.) ---
def bilou_to_spans(tags, offsets):
    spans = []
    cur_start = None
    cur_end = None

    def flush():
        nonlocal cur_start, cur_end
        if cur_start is not None and cur_end is not None and cur_end > cur_start:
            spans.append((cur_start, cur_end))
        cur_start = None
        cur_end = None

    for (s, e), t in zip(offsets, tags):
        if s == 0 and e == 0:   # special tokens
            continue
        if e <= s:
            continue

        t0 = t.split("-", 1)[0]  # handles 'B-LINK' too

        if t0 == "U":
            spans.append((s, e))
            cur_start = None
            cur_end = None
        elif t0 == "B":
            flush()
            cur_start, cur_end = s, e
        elif t0 == "I":
            if cur_start is None:
                cur_start, cur_end = s, e
            else:
                cur_end = e
        elif t0 == "L":
            if cur_start is None:
                spans.append((s, e))  # broken seq, treat as single
            else:
                cur_end = e
                flush()
        else:  # O
            flush()

    flush()
    return spans

# build spans
pred_spans = bilou_to_spans(tags, offsets)

print("Pred spans:", len(pred_spans))
print("First 10:")
for a, b in pred_spans[:10]:
    print(f"[{a}:{b}] -> {text[a:b]!r}")


Pred spans: 18
First 10:
[88:102] -> 'Raquel Murillo'
[114:127] -> 'Laura Murillo'
[139:151] -> 'Paula Vicuña'
[152:159] -> 'Murillo'
[176:190] -> 'Alberto Vicuña'
[234:247] -> 'The Professor'
[279:280] -> '2'
[283:284] -> '3'
[287:288] -> '4'
[302:305] -> 'Kit'


In [38]:
for a,b in pred_spans:
    s = text[a:b]
    if len(s) <= 2 or s.isdigit():
        print("JUNK?", (a,b), repr(s))

JUNK? (279, 280) '2'
JUNK? (283, 284) '3'
JUNK? (287, 288) '4'
JUNK? (501, 503) 'Ra'


In [39]:
# visible char window from offsets
max_visible_char = max(e for (s,e) in offsets if not (s==0 and e==0))

gold = set()
for s in ex["spans"]:
    a=int(s["start"]); b=int(s["end"])
    if 0 <= a < b <= max_visible_char:
        gold.add((a,b))

pred = set((a,b) for (a,b) in pred_spans if b <= max_visible_char)

tp = len(gold & pred)
fp = len(pred - gold)
fn = len(gold - pred)

prec = tp/(tp+fp) if tp+fp else 0
rec  = tp/(tp+fn) if tp+fn else 0
f1   = 2*prec*rec/(prec+rec) if prec+rec else 0

print("Visible window:", max_visible_char)
print("Gold spans (visible):", len(gold))
print("Pred spans (visible):", len(pred))
print("TP FP FN:", tp, fp, fn)
print("P/R/F1:", round(prec,3), round(rec,3), round(f1,3))


Visible window: 999
Gold spans (visible): 15
Pred spans (visible): 18
TP FP FN: 10 8 5
P/R/F1: 0.556 0.667 0.606


In [40]:
import numpy as np
import torch

MAX_LEN = 256

# ---- build gold token labels (BILOU per token) ----
label_map = {"O":0, "B":1, "I":2, "L":3, "U":4}   # must match training label ids order
id2lab_simple = {v:k for k,v in label_map.items()}

def spans_to_token_ids(offsets, gold_spans, ignore_id=-100):
    """
    offsets: list of [start,end] for each token
    gold_spans: list of dicts with {"start":int,"end":int} or tuples
    returns: list[int] token label ids aligned to offsets
    """
    # convert gold spans to tuples
    gs = []
    for s in gold_spans:
        if isinstance(s, dict):
            gs.append((int(s["start"]), int(s["end"])))
        else:
            gs.append((int(s[0]), int(s[1])))

    labels = []
    for (s,e) in offsets:
        if s==0 and e==0:
            labels.append(ignore_id)
            continue

        # check overlap with any gold span (token-level)
        inside = None
        for (a,b) in gs:
            if s >= a and e <= b and b > a:
                inside = (a,b)
                break

        if inside is None:
            labels.append(label_map["O"])
        else:
            a,b = inside
            # BILOU for token inside this span
            if s == a and e == b:
                labels.append(label_map["U"])
            elif s == a:
                labels.append(label_map["B"])
            elif e == b:
                labels.append(label_map["L"])
            else:
                labels.append(label_map["I"])
    return labels

# ---- encode ----
text = ex["text"]
enc = tok(text, return_offsets_mapping=True, truncation=True, max_length=MAX_LEN, return_tensors="pt")
offsets = enc["offset_mapping"][0].tolist()
attn = enc["attention_mask"][0].tolist()

# gold token labels
gold_ids = spans_to_token_ids(offsets, ex["spans"], ignore_id=-100)

# pred token labels
with torch.no_grad():
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"]).logits
pred_ids = logits.argmax(-1)[0].tolist()

# mask: only real tokens (attn=1) and not specials (offset != (0,0))
mask = [(attn[i]==1 and not (offsets[i][0]==0 and offsets[i][1]==0) and gold_ids[i]!=-100) for i in range(len(offsets))]

gold = np.array([gold_ids[i] for i,m in enumerate(mask) if m], dtype=int)
pred = np.array([pred_ids[i] for i,m in enumerate(mask) if m], dtype=int)

# token accuracy
acc = float((gold==pred).mean()) if len(gold) else 0.0

# token F1 for "LINK" tokens vs O (binary)
gold_pos = (gold != label_map["O"])
pred_pos = (pred != label_map["O"])

tp = int((gold_pos & pred_pos).sum())
fp = int((~gold_pos & pred_pos).sum())
fn = int((gold_pos & ~pred_pos).sum())

prec = tp/(tp+fp) if tp+fp else 0.0
rec  = tp/(tp+fn) if tp+fn else 0.0
f1   = 2*prec*rec/(prec+rec) if prec+rec else 0.0

print("Token-level accuracy:", round(acc,4))
print("Token-level LINK P/R/F1:", round(prec,4), round(rec,4), round(f1,4))
print("Counts TP/FP/FN:", tp, fp, fn)


Token-level accuracy: 0.9494
Token-level LINK P/R/F1: 0.8824 0.9 0.8911
Counts TP/FP/FN: 45 6 5
