# HIPE-2026 People–Place Context — Evidence MIL (Alternative Architecture)
This notebook implements a **different** approach than the cross-encoder baselines:
- Encode each document as a set of overlapping **windows** (evidence units)
- Precompute **frozen window embeddings** with `xlm-roberta-base`
- Precompute **frozen entity embeddings** from the person/place mention strings
- For each `(person, location)` pair, use **multi-instance learning (MIL)**: attention selects the most relevant windows
- Predict two labels per pair:
  - `at`: 3-way (`FALSE/PROBABLE/TRUE`)
  - `isAt`: 2-way (`FALSE/TRUE`)
- Export predictions back to HIPE JSONL format by overwriting `sampled_pairs[*].at` and `sampled_pairs[*].isAt`.

Why this is a different architecture: we do **not** re-encode the entire prompt per pair. The heavy transformer encoder is run **per document window**, cached, and pair scoring is done by a lightweight attention+MLP model.

In [1]:
# 1) Install & verify dependencies
import sys
import subprocess
def _pip_install(pkgs):
    cmd = [sys.executable, '-m', 'pip', 'install', '-q'] + list(pkgs)
    print('Running:', ' '.join(cmd))
    subprocess.check_call(cmd)

# Core ML deps
try:
    import torch  # noqa: F401
except Exception:
    _pip_install(['torch'])

try:
    import transformers  # noqa: F401
except Exception:
    _pip_install(['transformers>=4.35'])

# Utilities
for pkg, import_name in [
    ('tqdm', 'tqdm'),
    ('scikit-learn', 'sklearn'),
    ('pandas', 'pandas'),
    ('numpy', 'numpy'),
]:
    try:
        __import__(import_name)
    except Exception:
        _pip_install([pkg])

import numpy as np
import pandas as pd
import torch
import transformers
from tqdm.auto import tqdm
from sklearn.metrics import recall_score, accuracy_score

print('Python:', sys.version)
print('torch:', torch.__version__)
print('transformers:', transformers.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('CUDA device:', torch.cuda.get_device_name(0))

Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
torch: 2.9.0+cu126
transformers: 4.57.3
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB


In [2]:
# 2) Download/cache HIPE-2026 data repo + locate sandbox directory (same approach as test.ipynb)
from __future__ import annotations
import json
import shutil
import subprocess
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Iterable, Tuple
HIPE_DATA_GITHUB_REPO = 'https://github.com/hipe-eval/HIPE-2026-data'
WORK_DIR = Path.cwd()
CACHE_DIR = WORK_DIR / 'hipe_cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

def _download_file(url: str, dest: Path) -> None:
    import urllib.request
    with urllib.request.urlopen(url) as r, dest.open('wb') as f:
        shutil.copyfileobj(r, f)

def ensure_repo(repo_url: str, dest_dir: Path) -> Path:
    if (dest_dir / '.git').exists() or (dest_dir / 'README.md').exists():
        return dest_dir
    if dest_dir.exists():
        shutil.rmtree(dest_dir)
    if shutil.which('git'):
        print('Cloning repo with git...')
        subprocess.check_call(['git', 'clone', '--depth', '1', repo_url, str(dest_dir)])
        return dest_dir
    print('git not found; downloading zip archive...')
    zip_url = repo_url.rstrip('/') + '/archive/refs/heads/main.zip'
    zip_path = CACHE_DIR / (dest_dir.name + '-main.zip')
    if not zip_path.exists():
        _download_file(zip_url, zip_path)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(CACHE_DIR)
    extracted = CACHE_DIR / (dest_dir.name + '-main')
    if extracted.exists():
        extracted.rename(dest_dir)
        return dest_dir
    alt = CACHE_DIR / (repo_url.rstrip('/').split('/')[-1] + '-main')
    if alt.exists():
        alt.rename(dest_dir)
        return dest_dir
    raise FileNotFoundError('Zip extraction did not produce the expected folder.')

HIPE_DATA_REPO_DIR = ensure_repo(HIPE_DATA_GITHUB_REPO, CACHE_DIR / 'HIPE-2026-data')
print('HIPE_DATA_REPO_DIR =', HIPE_DATA_REPO_DIR)

def find_hipe_sandbox_dir() -> Path:
    candidates = [
        HIPE_DATA_REPO_DIR / 'data' / 'sandbox',
        Path.cwd() / 'HIPE-2026-data-main' / 'data' / 'sandbox',
        Path('HIPE-2026-data-main') / 'data' / 'sandbox',
        Path.cwd() / 'HIPE-2026-data' / 'data' / 'sandbox',
        Path('HIPE-2026-data') / 'data' / 'sandbox',
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            return p
    for p in HIPE_DATA_REPO_DIR.rglob('sandbox'):
        if p.is_dir() and any(p.glob('*.jsonl')):
            return p
    raise FileNotFoundError('Could not find HIPE sandbox dir from common locations.')

SANDBOX_DIR = find_hipe_sandbox_dir()
print('SANDBOX_DIR =', SANDBOX_DIR)
print('Sandbox JSONLs:', sorted([p.name for p in SANDBOX_DIR.glob('*.jsonl')])[:50])

HIPE_DATA_REPO_DIR = /content/hipe_cache/HIPE-2026-data
SANDBOX_DIR = /content/hipe_cache/HIPE-2026-data/data/sandbox
Sandbox JSONLs: ['de-dev.jsonl', 'de-train.jsonl', 'en-dev.jsonl', 'en-train.jsonl', 'fr-dev.jsonl', 'fr-train.jsonl']


In [None]:
# 3) Load JSONL splits + build pair examples + EDA
from dataclasses import dataclass
from collections import Counter, defaultdict
from transformers import AutoTokenizer, AutoModel
import math
import random
from typing import Any, Dict, List, Optional, Tuple


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    rows: List[Dict[str, Any]] = []
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def _sandbox_lang_from_filename(p: Path) -> str:
    return p.name.split('-', 1)[0]


def list_sandbox_split_files(split: str, langs: Optional[List[str]] = None) -> List[Path]:
    files = sorted(SANDBOX_DIR.glob(f"*-{split}.jsonl"))
    if langs is None:
        return files
    want = set(langs)
    return [p for p in files if _sandbox_lang_from_filename(p) in want]


def load_docs_from_files(paths: List[Path]) -> List[Dict[str, Any]]:
    docs: List[Dict[str, Any]] = []
    for p in paths:
        docs.extend(read_jsonl(p))
    return docs


# Labels (baseline-compatible): null -> FALSE
AT_LABELS = ['FALSE', 'PROBABLE', 'TRUE']
ISAT_LABELS = ['FALSE', 'TRUE']
AT2ID = {k: i for i, k in enumerate(AT_LABELS)}
ID2AT = {i: k for k, i in AT2ID.items()}
ISAT2ID = {k: i for i, k in enumerate(ISAT_LABELS)}
ID2ISAT = {i: k for k, i in ISAT2ID.items()}


def norm_at(v: Optional[str]) -> str:
    return 'FALSE' if v is None else str(v)

def norm_isat(v: Optional[str]) -> str:
    return 'FALSE' if v is None else str(v)


@dataclass(frozen=True)
class PairExampleMIL:
    document_id: str
    pair_index: int
    language: str
    date: str
    doc_text: str
    pers_entity_id: str
    loc_entity_id: str
    pers_mentions: List[str]
    loc_mentions: List[str]
    pers_mention: str
    loc_mention: str
    at: str
    isAt: str


def _normalize_mentions_list(xs: Any) -> List[str]:
    if not isinstance(xs, list):
        return []
    out: List[str] = []
    seen = set()
    for x in xs:
        if not isinstance(x, str):
            continue
        s = x.strip()
        if not s:
            continue
        if s in seen:
            continue
        seen.add(s)
        out.append(s)
    return out


def iter_pair_examples_mil(docs: List[Dict[str, Any]]) -> List[PairExampleMIL]:
    out: List[PairExampleMIL] = []
    for doc in docs:
        doc_id = str(doc.get('document_id') or doc.get('doc_id') or doc.get('id') or '')
        if not doc_id:
            continue
        lang = str(doc.get('language') or '')
        date = str(doc.get('date') or doc.get('publication_date') or '')
        text = str(doc.get('text') or '')
        pairs = doc.get('sampled_pairs')
        if not text or not isinstance(pairs, list):
            continue
        for idx, pair in enumerate(pairs):
            if not isinstance(pair, dict):
                continue
            pers_id = str(pair.get('pers_entity_id') or '')
            loc_id = str(pair.get('loc_entity_id') or '')
            if not pers_id or not loc_id:
                continue

            pers_ms = _normalize_mentions_list(pair.get('pers_mentions_list'))
            loc_ms = _normalize_mentions_list(pair.get('loc_mentions_list'))
            pers_m = pers_ms[0] if pers_ms else ''
            loc_m = loc_ms[0] if loc_ms else ''

            out.append(PairExampleMIL(
                document_id=doc_id,
                pair_index=idx,
                language=lang,
                date=date,
                doc_text=text,
                pers_entity_id=pers_id,
                loc_entity_id=loc_id,
                pers_mentions=pers_ms,
                loc_mentions=loc_ms,
                pers_mention=pers_m,
                loc_mention=loc_m,
                at=norm_at(pair.get('at')),
                isAt=norm_isat(pair.get('isAt')),
            ))
    return out


# Config: multilingual train, single-language dev
TRAIN_LANGS = ['en', 'de', 'fr']
DEV_LANG = 'en'
train_files = list_sandbox_split_files('train', TRAIN_LANGS)
dev_files = list_sandbox_split_files('dev', [DEV_LANG])
if len(train_files) == 0:
    raise FileNotFoundError(f"No '*-train.jsonl' files found in {SANDBOX_DIR}")
if len(dev_files) != 1:
    raise FileNotFoundError(f"Expected exactly one dev file for DEV_LANG={DEV_LANG}, got: {dev_files}")
print('Train files:', [p.name for p in train_files])
print('Dev file:', dev_files[0].name)

train_docs = load_docs_from_files(train_files)
dev_docs = read_jsonl(dev_files[0])
train_ex = iter_pair_examples_mil(train_docs)
dev_ex = iter_pair_examples_mil(dev_docs)
print('Train docs:', len(train_docs), 'Train pairs:', len(train_ex))
print('Dev docs:', len(dev_docs), 'Dev pairs:', len(dev_ex))


def label_distribution(examples: List[PairExampleMIL]) -> Dict[str, Dict[str, int]]:
    c_at = Counter(ex.at for ex in examples)
    c_is = Counter(ex.isAt for ex in examples)
    return {
        'at': {k: int(c_at.get(k, 0)) for k in AT_LABELS},
        'isAt': {k: int(c_is.get(k, 0)) for k in ISAT_LABELS},
    }

print('Train label distribution:', label_distribution(train_ex))
print('Dev label distribution:', label_distribution(dev_ex))


# Per-language counts (quick EDA)
def counts_by_lang(examples: List[PairExampleMIL]) -> Dict[str, int]:
    c = Counter(ex.language for ex in examples)
    return dict(sorted(c.items(), key=lambda x: (-x[1], x[0])))

print('Train pairs by language:', counts_by_lang(train_ex))
print('Dev pairs by language:', counts_by_lang(dev_ex))


Train files: ['de-train.jsonl', 'en-train.jsonl', 'fr-train.jsonl']
Dev file: en-dev.jsonl
Train docs: 461 Train pairs: 6170
Dev docs: 17 Dev pairs: 151
Train label distribution: {'at': {'FALSE': 3669, 'PROBABLE': 1579, 'TRUE': 922}, 'isAt': {'FALSE': 5493, 'TRUE': 677}}
Dev label distribution: {'at': {'FALSE': 68, 'PROBABLE': 54, 'TRUE': 29}, 'isAt': {'FALSE': 133, 'TRUE': 18}}
Train pairs by language: {'fr': 4450, 'de': 1224, 'en': 496}
Dev pairs by language: {'en': 151}


In [None]:
# 4) Windowing + frozen embedding caches (document windows + entity strings)
from typing import Callable
from transformers import AutoTokenizer, AutoModel
MIL_CACHE_DIR = WORK_DIR / 'mil_cache'
MIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)

MODEL_NAME = 'xlm-roberta-base'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
encoder = AutoModel.from_pretrained(MODEL_NAME).to(device)
encoder.eval()
for p in encoder.parameters():
    p.requires_grad_(False)


def set_seed(seed: int = 13) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(13)


def chunk_text_by_tokens(text: str, *, max_tokens: int = 256, stride: int = 128) -> List[str]:
    text = (text or '').replace('\n', ' ').strip()
    if not text:
        return ['']
    enc = tokenizer(text, add_special_tokens=False, return_attention_mask=False, return_tensors=None)
    ids = enc['input_ids']
    if len(ids) <= max_tokens:
        return [text]
    windows: List[str] = []
    start = 0
    while start < len(ids):
        end = min(len(ids), start + max_tokens)
        chunk_ids = ids[start:end]
        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True)
        chunk_text = chunk_text.strip()
        if chunk_text:
            windows.append(chunk_text)
        if end == len(ids):
            break
        start = max(0, end - stride)
    return windows if windows else [text[:1000]]


@torch.no_grad()
def encode_texts(texts: List[str], *, batch_size: int = 32, max_length: int = 256) -> torch.Tensor:
    vecs: List[torch.Tensor] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        tok = tokenizer(
            batch,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt',
        )
        tok = {k: v.to(device) for k, v in tok.items()}
        out = encoder(**tok).last_hidden_state
        cls = out[:, 0, :].detach().cpu()
        vecs.append(cls)
    return torch.cat(vecs, dim=0)


def build_doc_index(docs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    idx: Dict[str, Dict[str, Any]] = {}
    for doc in docs:
        doc_id = str(doc.get('document_id') or doc.get('doc_id') or doc.get('id') or '')
        if not doc_id:
            continue
        idx[doc_id] = doc
    return idx


train_doc_by_id = build_doc_index(train_docs)
dev_doc_by_id = build_doc_index(dev_docs)


def precompute_doc_windows_and_embeds(
    *,
    split_name: str,
    docs_by_id: Dict[str, Dict[str, Any]],
    max_tokens: int = 256,
    stride: int = 128,
    max_length: int = 256,
    recompute: bool = False,
) -> Tuple[Dict[str, List[str]], Dict[str, torch.Tensor]]:
    win_path = MIL_CACHE_DIR / f'{split_name}_windows.pt'
    emb_path = MIL_CACHE_DIR / f'{split_name}_winemb.pt'
    if (not recompute) and win_path.exists() and emb_path.exists():
        print(f'Loading cached windows/embeddings for {split_name}...')
        doc_windows = torch.load(win_path)
        doc_embeds = torch.load(emb_path)
        return doc_windows, doc_embeds

    doc_windows: Dict[str, List[str]] = {}
    doc_embeds: Dict[str, torch.Tensor] = {}
    for doc_id, doc in tqdm(list(docs_by_id.items()), desc=f'Build windows [{split_name}]'):
        text = str(doc.get('text') or '')
        windows = chunk_text_by_tokens(text, max_tokens=max_tokens, stride=stride)
        doc_windows[doc_id] = windows
    for doc_id, windows in tqdm(list(doc_windows.items()), desc=f'Encode windows [{split_name}]'):
        embeds = encode_texts(windows, batch_size=32, max_length=max_length)
        doc_embeds[doc_id] = embeds  # [W, H] on CPU
    torch.save(doc_windows, win_path)
    torch.save(doc_embeds, emb_path)
    print('Saved cache:', win_path.name, 'and', emb_path.name)
    return doc_windows, doc_embeds


DOC_MAX_TOKENS = 256
DOC_STRIDE = 128
WIN_MAX_LENGTH = 256
RECOMPUTE_EMBEDS = False

train_doc_windows, train_doc_embeds = precompute_doc_windows_and_embeds(
    split_name='train',
    docs_by_id=train_doc_by_id,
    max_tokens=DOC_MAX_TOKENS,
    stride=DOC_STRIDE,
    max_length=WIN_MAX_LENGTH,
    recompute=RECOMPUTE_EMBEDS,
)
dev_doc_windows, dev_doc_embeds = precompute_doc_windows_and_embeds(
    split_name='dev',
    docs_by_id=dev_doc_by_id,
    max_tokens=DOC_MAX_TOKENS,
    stride=DOC_STRIDE,
    max_length=WIN_MAX_LENGTH,
    recompute=RECOMPUTE_EMBEDS,
)


# Entity embedding cache (encode multiple mention variants)
def build_entity_text(ex: PairExampleMIL) -> Tuple[str, str]:
    # Using multiple variants helps robustness to OCR/newlines.
    p_ms = [m.strip().replace('\n', ' ') for m in (ex.pers_mentions or []) if isinstance(m, str) and m.strip()]
    l_ms = [m.strip().replace('\n', ' ') for m in (ex.loc_mentions or []) if isinstance(m, str) and m.strip()]
    p = ' ; '.join(p_ms[:3]) if p_ms else (ex.pers_mention.strip() if ex.pers_mention else '')
    l = ' ; '.join(l_ms[:3]) if l_ms else (ex.loc_mention.strip() if ex.loc_mention else '')
    return f'Person: {p}', f'Location: {l}'


@torch.no_grad()
def precompute_entity_embeds(examples: List[PairExampleMIL], *, max_length: int = 48) -> Dict[str, torch.Tensor]:
    # key by entity id; store [H] on CPU
    ent_texts: Dict[str, str] = {}
    for ex in examples:
        p_text, l_text = build_entity_text(ex)
        if ex.pers_entity_id and ex.pers_entity_id not in ent_texts:
            ent_texts[ex.pers_entity_id] = p_text
        if ex.loc_entity_id and ex.loc_entity_id not in ent_texts:
            ent_texts[ex.loc_entity_id] = l_text
    ids = list(ent_texts.keys())
    texts = [ent_texts[i] for i in ids]
    vec = encode_texts(texts, batch_size=64, max_length=max_length)
    return {k: vec[i] for i, k in enumerate(ids)}


ent_cache_path = MIL_CACHE_DIR / 'entity_embeds.pt'
if ent_cache_path.exists() and (not RECOMPUTE_EMBEDS):
    entity_embeds = torch.load(ent_cache_path)
    print('Loaded entity embedding cache:', ent_cache_path.name, 'n=', len(entity_embeds))
else:
    entity_embeds = precompute_entity_embeds(train_ex + dev_ex, max_length=48)
    torch.save(entity_embeds, ent_cache_path)
    print('Saved entity embedding cache:', ent_cache_path.name, 'n=', len(entity_embeds))


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Build windows [train]:   0%|          | 0/461 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (550 > 512). Running this sequence through the model will result in indexing errors


Encode windows [train]:   0%|          | 0/461 [00:00<?, ?it/s]

Saved cache: train_windows.pt and train_winemb.pt


Build windows [dev]:   0%|          | 0/17 [00:00<?, ?it/s]

Encode windows [dev]:   0%|          | 0/17 [00:00<?, ?it/s]

Saved cache: dev_windows.pt and dev_winemb.pt
Saved entity embedding cache: entity_embeds.pt n= 5785


In [None]:
# 5) Build MIL instances (candidate windows per pair) + DataLoaders
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from typing import NamedTuple
import re


class MILInstance(NamedTuple):
    document_id: str
    pair_index: int
    pers_entity_id: str
    loc_entity_id: str
    cand_win_idx: List[int]
    y_at: int
    y_isat: int


_ws_re = re.compile(r"\s+")
_nonword_re = re.compile(r"[^\w\s]")

def _norm_for_match(s: str) -> str:
    # Robust to OCR/newlines/punctuation differences.
    s = (s or '').replace('\n', ' ').casefold()
    s = _ws_re.sub(' ', s).strip()
    s = _nonword_re.sub(' ', s)
    s = _ws_re.sub(' ', s).strip()
    return s


def _norm_mentions(ms: List[str]) -> List[str]:
    out: List[str] = []
    seen = set()
    for m in ms or []:
        if not isinstance(m, str):
            continue
        nm = _norm_for_match(m)
        if not nm or nm in seen:
            continue
        seen.add(nm)
        out.append(nm)
    # Prefer longer mentions first (more specific)
    out.sort(key=len, reverse=True)
    return out


def _evenly_spaced_indices(n: int, k: int) -> List[int]:
    if n <= 0:
        return []
    if n <= k:
        return list(range(n))
    # include start, middle, end-ish
    xs = np.linspace(0, n - 1, num=k)
    idx = sorted({int(round(x)) for x in xs})
    if len(idx) < k:
        # fill deterministically
        for i in range(n):
            if i not in idx:
                idx.append(i)
            if len(idx) == k:
                break
        idx.sort()
    return idx[:k]


def find_candidate_windows(
    windows: List[str],
    pers_mentions: List[str],
    loc_mentions: List[str],
    *,
    max_keep: int = 32,
) -> List[int]:
    # Score each window by lexical matches against ANY mention surface form.
    p_ms = _norm_mentions(pers_mentions)
    l_ms = _norm_mentions(loc_mentions)

    scores: List[Tuple[int, int]] = []  # (score, idx)
    for i, w in enumerate(windows):
        wl = _norm_for_match(w)
        sp = 0
        sl = 0
        for m in p_ms[:10]:
            if m and m in wl:
                sp = 1
                break
        for m in l_ms[:10]:
            if m and m in wl:
                sl = 1
                break
        s = 0
        if sp and sl:
            s = 3
        elif sp or sl:
            s = 1
        if s > 0:
            scores.append((s, i))

    if not scores:
        # Fallback: avoid always selecting the beginning of the document.
        return _evenly_spaced_indices(len(windows), max_keep)

    # Prefer windows that contain BOTH entities, then partial matches.
    scores.sort(key=lambda x: (x[0], -x[1]), reverse=True)
    top = [i for _, i in scores[:max_keep]]
    top = sorted(set(top))
    return top


def build_instances(
    examples: List[PairExampleMIL],
    doc_windows: Dict[str, List[str]],
) -> List[MILInstance]:
    out: List[MILInstance] = []
    for ex in tqdm(examples, desc='Build pair instances'):
        wins = doc_windows.get(ex.document_id)
        if not wins:
            continue
        cand = find_candidate_windows(wins, ex.pers_mentions, ex.loc_mentions)
        if ex.at not in AT2ID or ex.isAt not in ISAT2ID:
            continue
        out.append(MILInstance(
            document_id=ex.document_id,
            pair_index=ex.pair_index,
            pers_entity_id=ex.pers_entity_id,
            loc_entity_id=ex.loc_entity_id,
            cand_win_idx=cand,
            y_at=AT2ID[ex.at],
            y_isat=ISAT2ID[ex.isAt],
        ))
    return out


train_inst = build_instances(train_ex, train_doc_windows)
dev_inst = build_instances(dev_ex, dev_doc_windows)
print('MIL train instances:', len(train_inst), 'dev instances:', len(dev_inst))


class MILDataset(Dataset):
    def __init__(self, instances: List[MILInstance]):
        self.instances = instances

    def __len__(self) -> int:
        return len(self.instances)

    def __getitem__(self, i: int) -> MILInstance:
        return self.instances[i]


def mil_collate(batch: List[MILInstance]) -> Dict[str, Any]:
    # Build padded window-embedding tensors per example
    doc_ids = [b.document_id for b in batch]
    p_ids = [b.pers_entity_id for b in batch]
    l_ids = [b.loc_entity_id for b in batch]
    y_at = torch.tensor([b.y_at for b in batch], dtype=torch.long)
    y_is = torch.tensor([b.y_isat for b in batch], dtype=torch.long)

    # Window embeddings
    embeds_list: List[torch.Tensor] = []
    max_w = 1
    H = int(encoder.config.hidden_size)
    for b in batch:
        if b.document_id in train_doc_embeds:
            e = train_doc_embeds[b.document_id]
        elif b.document_id in dev_doc_embeds:
            e = dev_doc_embeds[b.document_id]
        else:
            e = torch.zeros((1, H), dtype=torch.float32)
        embeds_list.append(e)
        max_w = max(max_w, e.shape[0])

    win_emb = torch.zeros((len(batch), max_w, H), dtype=torch.float32)
    win_valid = torch.zeros((len(batch), max_w), dtype=torch.bool)
    win_cand = torch.zeros((len(batch), max_w), dtype=torch.bool)
    for i, (b, e) in enumerate(zip(batch, embeds_list)):
        w = e.shape[0]
        win_emb[i, :w] = e
        win_valid[i, :w] = True
        for j in b.cand_win_idx:
            if j < w:
                win_cand[i, j] = True
        if not win_cand[i, :w].any():
            win_cand[i, :w] = True

    # Entity embeddings
    zero = torch.zeros((H,), dtype=torch.float32)
    p_emb = torch.stack([entity_embeds.get(pid, zero) for pid in p_ids], dim=0).float()
    l_emb = torch.stack([entity_embeds.get(lid, zero) for lid in l_ids], dim=0).float()
    return {
        'doc_ids': doc_ids,
        'pair_index': [b.pair_index for b in batch],
        'p_emb': p_emb,
        'l_emb': l_emb,
        'win_emb': win_emb,
        'win_valid': win_valid,
        'win_cand': win_cand,
        'y_at': y_at,
        'y_isAt': y_is,
    }


# Weighted sampling to fight label imbalance (balance joint (at,isAt) combos)
combo_counts = Counter((b.y_at, b.y_isat) for b in train_inst)
weights = np.array([1.0 / combo_counts[(b.y_at, b.y_isat)] for b in train_inst], dtype=np.float64)
weights = weights / weights.mean()
sampler = WeightedRandomSampler(weights=torch.tensor(weights, dtype=torch.double), num_samples=len(train_inst), replacement=True)

BATCH_SIZE = 64
train_loader = DataLoader(MILDataset(train_inst), batch_size=BATCH_SIZE, sampler=sampler, shuffle=False, collate_fn=mil_collate)
dev_loader = DataLoader(MILDataset(dev_inst), batch_size=BATCH_SIZE, shuffle=False, collate_fn=mil_collate)
print('Train batches:', len(train_loader), 'Dev batches:', len(dev_loader))
print('Train combo counts:', dict(sorted(combo_counts.items())))


Build pair instances:   0%|          | 0/6170 [00:00<?, ?it/s]

Build pair instances:   0%|          | 0/151 [00:00<?, ?it/s]

MIL train instances: 6170 dev instances: 151
Train batches: 97 Dev batches: 3


In [None]:
# 6) MIL model: attention over windows conditioned on (person, location)
import torch.nn as nn
import torch.nn.functional as F


class EvidenceMIL(nn.Module):
    def __init__(self, hidden_size: int, *, dropout: float = 0.2):
        super().__init__()
        self.hidden = hidden_size
        self.dropout = nn.Dropout(dropout)

        # Pair -> query
        self.pair_mlp = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
        )

        # Attention projection for windows
        self.win_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.attn_temp = nn.Parameter(torch.tensor(1.0))

        # Shared trunk before heads
        # features: p,l,attn_pool,max_pool,|p-l|,p*l
        feat_dim = hidden_size * 6
        self.trunk = nn.Sequential(
            nn.Linear(feat_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.head_at = nn.Linear(hidden_size, 3)
        self.head_isat = nn.Linear(hidden_size, 2)

    def forward(
        self,
        p_emb: torch.Tensor,      # [B,H]
        l_emb: torch.Tensor,      # [B,H]
        win_emb: torch.Tensor,    # [B,W,H]
        win_valid: torch.Tensor,  # [B,W] bool
        win_cand: torch.Tensor,   # [B,W] bool
    ) -> Dict[str, torch.Tensor]:
        B, W, H = win_emb.shape
        p = self.dropout(p_emb)
        l = self.dropout(l_emb)

        # query uses interactions but not pooled evidence
        x = torch.cat([p, l, torch.abs(p - l), p * l], dim=-1)
        q = self.pair_mlp(x)  # [B,H]

        # window scores: (Wproj(v_w)) dot q
        v = self.win_proj(win_emb)  # [B,W,H]
        scores = torch.einsum('bwh,bh->bw', v, q)

        # Mask out invalid/padded windows and non-candidate windows
        mask = win_valid & win_cand
        scores = scores / (torch.clamp(self.attn_temp, min=0.2, max=5.0))
        scores = scores.masked_fill(~mask, -1e9)

        attn = torch.softmax(scores, dim=-1)  # [B,W]
        v_attn = torch.einsum('bw,bwh->bh', attn, win_emb)

        # Max pool over candidate windows (complementary to attention)
        win_emb_masked = win_emb.masked_fill(~mask.unsqueeze(-1), float('-inf'))
        v_max = win_emb_masked.max(dim=1).values
        v_max = torch.nan_to_num(v_max, nan=0.0, neginf=0.0, posinf=0.0)

        feats = torch.cat([p, l, v_attn, v_max, torch.abs(p - l), p * l], dim=-1)
        h = self.trunk(feats)
        return {
            'logits_at': self.head_at(h),
            'logits_isAt': self.head_isat(h),
            'attn': attn,
        }


def inv_sqrt_class_weights(y: np.ndarray, num_classes: int) -> torch.Tensor:
    counts = np.bincount(y, minlength=num_classes).astype(np.float32)
    w = 1.0 / np.sqrt(np.maximum(counts, 1.0))
    w = w * (num_classes / w.sum())
    return torch.tensor(w, dtype=torch.float32)


# Class weights (computed from training labels)
y_at_train = np.array([b.y_at for b in train_inst], dtype=np.int64)
y_is_train = np.array([b.y_isat for b in train_inst], dtype=np.int64)
w_at = inv_sqrt_class_weights(y_at_train, 3).to(device)
w_is = inv_sqrt_class_weights(y_is_train, 2).to(device)
print('Class weights at:', w_at.detach().cpu().numpy())
print('Class weights isAt:', w_is.detach().cpu().numpy())

mil_model = EvidenceMIL(hidden_size=int(encoder.config.hidden_size), dropout=0.25).to(device)
print('MIL model params:', sum(p.numel() for p in mil_model.parameters())/1e6, 'M')


Class weights at: [0.6638366 1.0119148 1.3242487]
Class weights isAt: [0.5196881 1.480312 ]
MIL model params: 6.494213 M


In [7]:
# 7) Training + evaluation (macro recall) + early stopping
def macro_recall(y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]) -> float:
    return float(recall_score(y_true, y_pred, average='macro', labels=labels, zero_division=0))

@torch.no_grad()
def evaluate_mil(model: EvidenceMIL, loader: DataLoader) -> Dict[str, float]:
    model.eval()
    ys_at, ps_at = [], []
    ys_is, ps_is = [], []
    total_loss = 0.0
    n_batches = 0
    for batch in loader:
        p = batch['p_emb'].to(device)
        l = batch['l_emb'].to(device)
        win_emb = batch['win_emb'].to(device)
        win_valid = batch['win_valid'].to(device)
        win_cand = batch['win_cand'].to(device)
        y_at = batch['y_at'].to(device)
        y_is = batch['y_isAt'].to(device)
        out = model(p, l, win_emb, win_valid, win_cand)
        loss_at = F.cross_entropy(out['logits_at'], y_at, weight=w_at)
        loss_is = F.cross_entropy(out['logits_isAt'], y_is, weight=w_is)
        loss = loss_at + loss_is
        total_loss += float(loss.item())
        n_batches += 1
        pa = out['logits_at'].argmax(dim=-1).detach().cpu().numpy()
        pi = out['logits_isAt'].argmax(dim=-1).detach().cpu().numpy()
        ys_at.append(y_at.detach().cpu().numpy())
        ps_at.append(pa)
        ys_is.append(y_is.detach().cpu().numpy())
        ps_is.append(pi)
    y_at_all = np.concatenate(ys_at)
    p_at_all = np.concatenate(ps_at)
    y_is_all = np.concatenate(ys_is)
    p_is_all = np.concatenate(ps_is)
    mr_at = macro_recall(y_at_all, p_at_all, labels=list(range(3)))
    mr_is = macro_recall(y_is_all, p_is_all, labels=list(range(2)))
    acc_at = float(accuracy_score(y_at_all, p_at_all))
    acc_is = float(accuracy_score(y_is_all, p_is_all))
    return {
        'val_loss': total_loss / max(1, n_batches),
        'acc_at': acc_at,
        'acc_isAt': acc_is,
        'macro_recall_at': mr_at,
        'macro_recall_isAt': mr_is,
        'avg_macro_recall': 0.5 * (mr_at + mr_is),
    }

LR = 2e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 30
PATIENCE = 5
optimizer = torch.optim.AdamW(mil_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

best = {'avg_macro_recall': -1.0, 'epoch': -1}
best_path = MIL_CACHE_DIR / 'best_mil_model.pt'
bad_epochs = 0
history: List[Dict[str, float]] = []

for epoch in range(1, EPOCHS + 1):
    mil_model.train()
    total = 0.0
    n = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}')
    for batch in pbar:
        p = batch['p_emb'].to(device)
        l = batch['l_emb'].to(device)
        win_emb = batch['win_emb'].to(device)
        win_valid = batch['win_valid'].to(device)
        win_cand = batch['win_cand'].to(device)
        y_at = batch['y_at'].to(device)
        y_is = batch['y_isAt'].to(device)
        optimizer.zero_grad(set_to_none=True)
        out = mil_model(p, l, win_emb, win_valid, win_cand)
        loss_at = F.cross_entropy(out['logits_at'], y_at, weight=w_at)
        loss_is = F.cross_entropy(out['logits_isAt'], y_is, weight=w_is)
        loss = loss_at + loss_is
        loss.backward()
        torch.nn.utils.clip_grad_norm_(mil_model.parameters(), 1.0)
        optimizer.step()
        total += float(loss.item())
        n += 1
        if n % 10 == 0:
            pbar.set_postfix(train_loss=total / max(1, n))
    metrics = evaluate_mil(mil_model, dev_loader)
    row = {'epoch': epoch, 'train_loss': total / max(1, n), **metrics}
    history.append(row)
    print(row)
    if metrics['avg_macro_recall'] > best['avg_macro_recall']:
        best = {'avg_macro_recall': metrics['avg_macro_recall'], 'epoch': epoch}
        torch.save({'model': mil_model.state_dict()}, best_path)
        print('Saved best ->', best_path)
        bad_epochs = 0
    else:
        bad_epochs += 1
        if bad_epochs >= PATIENCE:
            print(f'Early stopping: no improvement for {PATIENCE} epochs')
            break

print('Best:', best)
if best_path.exists():
    mil_model.load_state_dict(torch.load(best_path, map_location=device)['model'])
    print('Loaded best checkpoint')

df = pd.DataFrame(history)
if len(df):
    display(df)

Epoch 1/30:   0%|          | 0/97 [00:00<?, ?it/s]

{'epoch': 1, 'train_loss': 2.229541878110355, 'val_loss': 1.7157322963078816, 'acc_at': 0.4503311258278146, 'acc_isAt': 0.8807947019867549, 'macro_recall_at': 0.3333333333333333, 'macro_recall_isAt': 0.5, 'avg_macro_recall': 0.41666666666666663}
Saved best -> /content/mil_cache/best_mil_model.pt


Epoch 2/30:   0%|          | 0/97 [00:00<?, ?it/s]

{'epoch': 2, 'train_loss': 1.6378914437343164, 'val_loss': 1.6631971995035808, 'acc_at': 0.4503311258278146, 'acc_isAt': 0.8807947019867549, 'macro_recall_at': 0.3333333333333333, 'macro_recall_isAt': 0.5, 'avg_macro_recall': 0.41666666666666663}


Epoch 3/30:   0%|          | 0/97 [00:00<?, ?it/s]

{'epoch': 3, 'train_loss': 1.6447333380119087, 'val_loss': 1.6629846890767415, 'acc_at': 0.4503311258278146, 'acc_isAt': 0.8807947019867549, 'macro_recall_at': 0.3333333333333333, 'macro_recall_isAt': 0.5, 'avg_macro_recall': 0.41666666666666663}


Epoch 4/30:   0%|          | 0/97 [00:00<?, ?it/s]

{'epoch': 4, 'train_loss': 1.6260318866710073, 'val_loss': 1.6739999453226726, 'acc_at': 0.4503311258278146, 'acc_isAt': 0.8807947019867549, 'macro_recall_at': 0.3333333333333333, 'macro_recall_isAt': 0.5, 'avg_macro_recall': 0.41666666666666663}


Epoch 5/30:   0%|          | 0/97 [00:00<?, ?it/s]

{'epoch': 5, 'train_loss': 1.6386383120546635, 'val_loss': 1.6816920439402263, 'acc_at': 0.4503311258278146, 'acc_isAt': 0.8807947019867549, 'macro_recall_at': 0.3333333333333333, 'macro_recall_isAt': 0.5, 'avg_macro_recall': 0.41666666666666663}


Epoch 6/30:   0%|          | 0/97 [00:00<?, ?it/s]

{'epoch': 6, 'train_loss': 1.6281058284425245, 'val_loss': 1.6890898545583088, 'acc_at': 0.4503311258278146, 'acc_isAt': 0.8807947019867549, 'macro_recall_at': 0.3333333333333333, 'macro_recall_isAt': 0.5, 'avg_macro_recall': 0.41666666666666663}
Early stopping: no improvement for 5 epochs
Best: {'avg_macro_recall': 0.41666666666666663, 'epoch': 1}
Loaded best checkpoint


Unnamed: 0,epoch,train_loss,val_loss,acc_at,acc_isAt,macro_recall_at,macro_recall_isAt,avg_macro_recall
0,1,2.229542,1.715732,0.450331,0.880795,0.333333,0.5,0.416667
1,2,1.637891,1.663197,0.450331,0.880795,0.333333,0.5,0.416667
2,3,1.644733,1.662985,0.450331,0.880795,0.333333,0.5,0.416667
3,4,1.626032,1.674,0.450331,0.880795,0.333333,0.5,0.416667
4,5,1.638638,1.681692,0.450331,0.880795,0.333333,0.5,0.416667
5,6,1.628106,1.68909,0.450331,0.880795,0.333333,0.5,0.416667


In [8]:
# 8) Predict + export back to HIPE JSONL
@torch.no_grad()
def predict_mil(model: EvidenceMIL, loader: DataLoader) -> Dict[Tuple[str, int], Tuple[str, str]]:
    model.eval()
    out_map: Dict[Tuple[str, int], Tuple[str, str]] = {}
    for batch in tqdm(loader, desc='Predict'):
        p = batch['p_emb'].to(device)
        l = batch['l_emb'].to(device)
        win_emb = batch['win_emb'].to(device)
        win_valid = batch['win_valid'].to(device)
        win_cand = batch['win_cand'].to(device)
        out = model(p, l, win_emb, win_valid, win_cand)
        pa = out['logits_at'].argmax(dim=-1).detach().cpu().numpy()
        pi = out['logits_isAt'].argmax(dim=-1).detach().cpu().numpy()
        for doc_id, pair_idx, a, i_ in zip(batch['doc_ids'], batch['pair_index'], pa, pi):
            out_map[(str(doc_id), int(pair_idx))] = (ID2AT[int(a)], ID2ISAT[int(i_)])
    return out_map

pred_map = predict_mil(mil_model, dev_loader)
print('Predictions for pairs:', len(pred_map))

def write_predictions_jsonl(
    docs: List[Dict[str, Any]],
    pred_by_pair: Dict[Tuple[str, int], Tuple[str, str]],
    out_path: Path,
) -> None:
    with out_path.open('w', encoding='utf-8') as f:
        for doc in docs:
            doc_id = str(doc.get('document_id') or doc.get('doc_id') or doc.get('id') or '')
            new_doc = dict(doc)
            new_pairs = []
            for idx, pair in enumerate(doc.get('sampled_pairs', [])):
                new_pair = dict(pair)
                at_pred, isat_pred = pred_by_pair.get((doc_id, idx), ('FALSE', 'FALSE'))
                new_pair['at'] = at_pred
                new_pair['isAt'] = isat_pred
                new_pairs.append(new_pair)
            new_doc['sampled_pairs'] = new_pairs
            f.write(json.dumps(new_doc, ensure_ascii=False) + '\n')

dev_lang = DEV_LANG
out_path = Path.cwd() / f'{dev_lang}_dev_predictions_mil.jsonl'
write_predictions_jsonl(dev_docs, pred_map, out_path)
print('Wrote:', out_path)

Predict:   0%|          | 0/3 [00:00<?, ?it/s]

Predictions for pairs: 151
Wrote: /content/en_dev_predictions_mil.jsonl


In [9]:
# 9) Sanity checks on exported JSONL
reloaded = read_jsonl(out_path)
print('Reloaded docs:', len(reloaded), 'expected:', len(dev_docs))
assert len(reloaded) == len(dev_docs)
n_pairs = 0
for d in reloaded[:5]:
    doc_id = d.get('document_id')
    pairs = d.get('sampled_pairs', [])
    print('doc_id:', doc_id, '| pairs:', len(pairs))
    for j, p in enumerate(pairs[:3]):
        print('  ', j, p.get('at'), p.get('isAt'))
    n_pairs += len(pairs)
print('Total pairs in file (first 5 docs counted separately):', n_pairs)

# Validate that every sampled_pairs entry has at/isAt strings
missing = 0
for d in reloaded:
    for p in d.get('sampled_pairs', []):
        if ('at' not in p) or ('isAt' not in p):
            missing += 1
print('Pairs missing at/isAt:', missing)
assert missing == 0

Reloaded docs: 17 expected: 17
doc_id: sn84026272-1800-12-10-a-i0004 | pairs: 4
   0 FALSE FALSE
   1 FALSE FALSE
   2 FALSE FALSE
doc_id: sn88085488-1910-09-23-a-i0001 | pairs: 10
   0 FALSE FALSE
   1 FALSE FALSE
   2 FALSE FALSE
doc_id: sn85042404-1880-12-21-a-i0002 | pairs: 1
   0 FALSE FALSE
doc_id: sn82014385-1810-04-04-a-i0003 | pairs: 16
   0 FALSE FALSE
   1 FALSE FALSE
   2 FALSE FALSE
doc_id: sn83026170-1820-01-15-a-i0002 | pairs: 6
   0 FALSE FALSE
   1 FALSE FALSE
   2 FALSE FALSE
Total pairs in file (first 5 docs counted separately): 37
Pairs missing at/isAt: 0
