# HIPE-2026 Peopleâ€“Place Context (Baseline)
This notebook builds a clean baseline for predicting the two labels per (person, location) pair: `at` and `isAt`.

**What you can do here**
- Load HIPE sandbox train/dev JSONL files
- Build pair-level training examples
- Fine-tune a transformer encoder (XLM-R) with two classification heads
- Evaluate on dev + export predictions in HIPE JSONL format

In [6]:
# If needed, install dependencies.
# (In hosted GPU notebooks, torch is often preinstalled.)

import sys
import subprocess

def _pip_install(pkgs):
    cmd = [sys.executable, '-m', 'pip', 'install', '-q'] + list(pkgs)
    print('Running:', ' '.join(cmd))
    subprocess.check_call(cmd)

# Core ML deps
try:
    import torch  # noqa: F401
except Exception:
    _pip_install(['torch'])

try:
    import transformers  # noqa: F401
except Exception:
    _pip_install(['transformers>=4.35'])

# Utilities (optional but convenient)
for pkg, import_name in [
    ('tqdm', 'tqdm'),
    ('scikit-learn', 'sklearn'),
    ('pandas', 'pandas'),
]:
    try:
        __import__(import_name)
    except Exception:
        _pip_install([pkg])

print('Python:', sys.version)
import torch
import transformers
print('torch:', torch.__version__)
print('transformers:', transformers.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('CUDA device:', torch.cuda.get_device_name(0))


Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
torch: 2.9.0+cu126
transformers: 4.57.3
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB


In [7]:
# Download/prepare the HIPE-2026 *data* repo locally inside this notebook runtime.
# This replaces Google Drive mounting when you just want to pull the JSONL files from GitHub.

from __future__ import annotations

import json
import shutil
import subprocess
import zipfile
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

HIPE_DATA_GITHUB_REPO = 'https://github.com/hipe-eval/HIPE-2026-data'
WORK_DIR = Path.cwd()
CACHE_DIR = WORK_DIR / 'hipe_cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)


def _download_file(url: str, dest: Path) -> None:
    import urllib.request

    with urllib.request.urlopen(url) as r, dest.open('wb') as f:
        shutil.copyfileobj(r, f)


def ensure_repo(repo_url: str, dest_dir: Path) -> Path:
    """Ensure a repo is present locally in this runtime; returns repo directory."""
    if (dest_dir / '.git').exists() or (dest_dir / 'README.md').exists():
        return dest_dir

    if dest_dir.exists():
        # In case a previous partial download exists.
        shutil.rmtree(dest_dir)

    # Prefer git clone when available.
    if shutil.which('git'):
        print('Cloning repo with git...')
        subprocess.check_call(['git', 'clone', '--depth', '1', repo_url, str(dest_dir)])
        return dest_dir

    # Fallback: download GitHub zip.
    print('git not found; downloading zip archive...')
    zip_url = repo_url.rstrip('/') + '/archive/refs/heads/main.zip'
    zip_path = CACHE_DIR / (dest_dir.name + '-main.zip')
    if not zip_path.exists():
        _download_file(zip_url, zip_path)

    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(CACHE_DIR)

    extracted = CACHE_DIR / (dest_dir.name + '-main')
    if extracted.exists():
        extracted.rename(dest_dir)
        return dest_dir

    # Some repos extract as <RepoName>-main
    alt = CACHE_DIR / (repo_url.rstrip('/').split('/')[-1] + '-main')
    if alt.exists():
        alt.rename(dest_dir)
        return dest_dir

    raise FileNotFoundError('Zip extraction did not produce the expected folder.')


HIPE_DATA_REPO_DIR = ensure_repo(HIPE_DATA_GITHUB_REPO, CACHE_DIR / 'HIPE-2026-data')
print('HIPE_DATA_REPO_DIR =', HIPE_DATA_REPO_DIR)


def find_hipe_sandbox_dir() -> Path:
    """Locate the HIPE sandbox folder.

    Priority order:
    1) The downloaded HIPE-2026-data repo
    2) A local workspace clone (if you already have it)
    3) Colab/Drive style paths
    """
    candidates = [
        # Inside downloaded data repo
        HIPE_DATA_REPO_DIR / 'data' / 'sandbox',

        # Common local layouts (if you already have the data repo checked out)
        Path.cwd() / 'HIPE-2026-data-main' / 'data' / 'sandbox',
        Path('HIPE-2026-data-main') / 'data' / 'sandbox',
        Path.cwd() / 'HIPE-2026-data' / 'data' / 'sandbox',
        Path('HIPE-2026-data') / 'data' / 'sandbox',

        # Common Colab Google Drive layouts
        Path('/content/drive/My Drive/Colab Notebooks/HIPE-2026-data-main/data/sandbox'),
        Path('/content/drive/My Drive/Jupyter Notebooks/HIPE-2026-data-main/data/sandbox'),
        Path('/content/drive/My Drive/HIPE-2026-data-main/data/sandbox'),
    ]
    for p in candidates:
        if p.exists() and p.is_dir():
            return p

    # Robust fallback: search the downloaded repo for a directory named 'sandbox'
    for p in HIPE_DATA_REPO_DIR.rglob('sandbox'):
        if p.is_dir() and any(p.glob('*.jsonl')):
            return p

    raise FileNotFoundError('Could not find HIPE sandbox dir from common locations.')


SANDBOX_DIR = find_hipe_sandbox_dir()
print('SANDBOX_DIR =', SANDBOX_DIR)
print('Sandbox JSONLs:', sorted([p.name for p in SANDBOX_DIR.glob('*.jsonl')])[:50])

# --- Loading JSONL + building pair-level examples (used by the model cells below) ---
def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    rows: List[Dict[str, Any]] = []
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return rows

# Labels: treat null as "FALSE" (matches the official scorer's imputation rule)
AT_LABELS = ['FALSE', 'PROBABLE', 'TRUE']
ISAT_LABELS = ['FALSE', 'TRUE']
AT2ID = {k: i for i, k in enumerate(AT_LABELS)}
ID2AT = {i: k for k, i in AT2ID.items()}
ISAT2ID = {k: i for i, k in enumerate(ISAT_LABELS)}
ID2ISAT = {i: k for k, i in ISAT2ID.items()}

def norm_at(v: Optional[str]) -> str:
    return 'FALSE' if v is None else v

def norm_isat(v: Optional[str]) -> str:
    return 'FALSE' if v is None else v

@dataclass(frozen=True)
class PairExample:
    document_id: str
    pair_index: int
    pers_entity_id: str
    loc_entity_id: str
    text: str
    at: str
    isAt: str


def build_input_text(
    doc_text: str,
    pers_mentions: List[str],
    loc_mentions: List[str],
    *,
    max_mentions: int = 2,
    ) -> str:
    pers = ' | '.join([m.replace('\n', ' ').strip() for m in pers_mentions[:max_mentions]])
    loc = ' | '.join([m.replace('\n', ' ').strip() for m in loc_mentions[:max_mentions]])
    ctx = doc_text.replace('\n', ' ').strip()
    return f"Person: {pers}\nLocation: {loc}\nContext: {ctx}"


def iter_pair_examples(docs: List[Dict[str, Any]]) -> Iterable[PairExample]:
    for doc in docs:
        doc_id = doc['document_id']
        doc_text = doc.get('text', '')
        for idx, pair in enumerate(doc.get('sampled_pairs', [])):
            yield PairExample(
                document_id=doc_id,
                pair_index=idx,
                pers_entity_id=pair['pers_entity_id'],
                loc_entity_id=pair['loc_entity_id'],
                text=build_input_text(
                    doc_text,
                    pair.get('pers_mentions_list', []),
                    pair.get('loc_mentions_list', []),
                ),
                at=norm_at(pair.get('at')),
                isAt=norm_isat(pair.get('isAt')),
            )


def _sandbox_lang_from_filename(p: Path) -> str:
    # e.g., en-train.jsonl -> en
    return p.name.split('-', 1)[0]


def list_sandbox_split_files(split: str, langs: Optional[List[str]] = None) -> List[Path]:
    files = sorted(SANDBOX_DIR.glob(f"*-{split}.jsonl"))
    if langs is None:
        return files
    want = set(langs)
    return [p for p in files if _sandbox_lang_from_filename(p) in want]


def load_docs_from_files(paths: List[Path]) -> List[Dict[str, Any]]:
    docs: List[Dict[str, Any]] = []
    for p in paths:
        docs.extend(read_jsonl(p))
    return docs


def label_distribution(examples: List[PairExample]) -> Dict[str, Dict[str, int]]:
    c_at = Counter(ex.at for ex in examples)
    c_isat = Counter(ex.isAt for ex in examples)
    return {
        'at': {k: int(c_at.get(k, 0)) for k in AT_LABELS},
        'isAt': {k: int(c_isat.get(k, 0)) for k in ISAT_LABELS},
    }


# By default: train on *all* sandbox train files; evaluate/export on one dev language.
TRAIN_LANGS = ['en', 'de', 'fr']
DEV_LANG = 'en'

train_files = list_sandbox_split_files('train', TRAIN_LANGS)
dev_files = list_sandbox_split_files('dev', [DEV_LANG])
if len(train_files) == 0:
    raise FileNotFoundError(f"No '*-train.jsonl' files found in sandbox dir: {SANDBOX_DIR}")
if len(dev_files) != 1:
    raise FileNotFoundError(f"Expected exactly one dev file for DEV_LANG={DEV_LANG}, got: {dev_files}")

print('Train files:', [p.name for p in train_files])
print('Dev file:', dev_files[0].name)

train_docs = load_docs_from_files(train_files)
dev_path = dev_files[0]
dev_docs = read_jsonl(dev_path)

train_ex = list(iter_pair_examples(train_docs))
dev_ex = list(iter_pair_examples(dev_docs))

print('Train docs:', len(train_docs), 'Train pairs:', len(train_ex))
print('Dev docs:', len(dev_docs), 'Dev pairs:', len(dev_ex))
print('Train label distribution:', label_distribution(train_ex))
print('Dev label distribution:', label_distribution(dev_ex))


HIPE_DATA_REPO_DIR = /content/hipe_cache/HIPE-2026-data
SANDBOX_DIR = /content/hipe_cache/HIPE-2026-data/data/sandbox
Sandbox JSONLs: ['de-dev.jsonl', 'de-train.jsonl', 'en-dev.jsonl', 'en-train.jsonl', 'fr-dev.jsonl', 'fr-train.jsonl']
Train files: ['de-train.jsonl', 'en-train.jsonl', 'fr-train.jsonl']
Dev file: en-dev.jsonl
Train docs: 461 Train pairs: 6170
Dev docs: 17 Dev pairs: 151
Train label distribution: {'at': {'FALSE': 3669, 'PROBABLE': 1579, 'TRUE': 922}, 'isAt': {'FALSE': 5493, 'TRUE': 677}}
Dev label distribution: {'at': {'FALSE': 68, 'PROBABLE': 54, 'TRUE': 29}, 'isAt': {'FALSE': 133, 'TRUE': 18}}


## Model + training (next)
Next cells will add:
- Tokenization with `xlm-roberta-base`
- A multi-task model with 2 heads:
  - `at`: 3-way classification (`FALSE/PROBABLE/TRUE`)
  - `isAt`: 2-way classification (`FALSE/TRUE`)
- Training loop + dev evaluation
- Export predictions to HIPE JSONL format

In [8]:
import math
from typing import Sequence

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup

MODEL_NAME = 'xlm-roberta-base'
MAX_LEN = 182
BATCH_SIZE = 32
LR = 2e-5
EPOCHS = 6
WARMUP_RATIO = 0.06
SEED = 42
DROPOUT = 0.2

def set_seed(seed: int = 42) -> None:
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print('Tokenizer loaded:', MODEL_NAME)

class PairDataset(Dataset):
    def __init__(self, examples: Sequence[PairExample]):
        self.examples = list(examples)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        ex = self.examples[idx]
        enc = tokenizer(
            ex.text,
            truncation=True,
            padding=False,
            max_length=MAX_LEN,
            return_tensors=None,
        )
        return {
            'input_ids': enc['input_ids'],
            'attention_mask': enc['attention_mask'],
            'label_at': AT2ID[ex.at],
            'label_isAt': ISAT2ID[ex.isAt],
        }

def collate_batch(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    input_ids = [torch.tensor(x['input_ids'], dtype=torch.long) for x in batch]
    attention_mask = [torch.tensor(x['attention_mask'], dtype=torch.long) for x in batch]
    input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)

    label_at = torch.tensor([x['label_at'] for x in batch], dtype=torch.long)
    label_isAt = torch.tensor([x['label_isAt'] for x in batch], dtype=torch.long)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'label_at': label_at,
        'label_isAt': label_isAt,
    }

class MultiTaskXLMR(nn.Module):
    def __init__(self, model_name: str = MODEL_NAME, dropout: float = 0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.head_at = nn.Linear(hidden, len(AT_LABELS))
        self.head_isAt = nn.Linear(hidden, len(ISAT_LABELS))

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # XLM-R has no pooler by default; use CLS token representation.
        cls = out.last_hidden_state[:, 0, :]
        x = self.dropout(cls)
        return {
            'logits_at': self.head_at(x),
            'logits_isAt': self.head_isAt(x),
        }

model = MultiTaskXLMR(MODEL_NAME, dropout=DROPOUT).to(device)
print('Model loaded:', MODEL_NAME)
print('DROPOUT =', DROPOUT)
print('EPOCHS =', EPOCHS)


Using device: cuda
Tokenizer loaded: xlm-roberta-base
Model loaded: xlm-roberta-base
DROPOUT = 0.2
EPOCHS = 6


In [None]:
from typing import Tuple
from collections import Counter
import copy
import random
import pandas as pd
from sklearn.metrics import accuracy_score, recall_score
from tqdm.auto import tqdm
from torch.utils.data import WeightedRandomSampler

@torch.no_grad()
def predict_batches(model: nn.Module, loader: DataLoader) -> Tuple[List[int], List[int]]:
    model.eval()
    preds_at: List[int] = []
    preds_isAt: List[int] = []
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        out = model(input_ids=input_ids, attention_mask=attention_mask)
        preds_at.extend(out['logits_at'].argmax(dim=-1).tolist())
        preds_isAt.extend(out['logits_isAt'].argmax(dim=-1).tolist())
    return preds_at, preds_isAt

def evaluate_simple(model: nn.Module, loader: DataLoader) -> Dict[str, float]:
    y_at: List[int] = []
    y_isAt: List[int] = []
    for batch in loader:
        y_at.extend(batch['label_at'].tolist())
        y_isAt.extend(batch['label_isAt'].tolist())

    p_at, p_isAt = predict_batches(model, loader)

    acc_at = accuracy_score(y_at, p_at)
    acc_isAt = accuracy_score(y_isAt, p_isAt)

    # Macro recall across classes (handles class imbalance better than accuracy).
    # Force label set so missing classes don't crash; zero_division=0 avoids warnings.
    rec_at_macro = recall_score(
        y_at,
        p_at,
        average='macro',
        labels=list(range(len(AT_LABELS))),
        zero_division=0,
    )
    rec_isAt_macro = recall_score(
        y_isAt,
        p_isAt,
        average='macro',
        labels=list(range(len(ISAT_LABELS))),
        zero_division=0,
    )

    return {
        'acc_at': float(acc_at),
        'acc_isAt': float(acc_isAt),
        'acc_avg': float(0.5 * (acc_at + acc_isAt)),
        'recall_at_macro': float(rec_at_macro),
        'recall_isAt_macro': float(rec_isAt_macro),
        'recall_avg_macro': float(0.5 * (rec_at_macro + rec_isAt_macro)),
    }

@torch.no_grad()
def compute_val_loss(
    model: nn.Module,
    loader: DataLoader,
    *,
    criterion_at: nn.Module,
    criterion_isAt: nn.Module,
) -> float:
    model.eval()
    total = 0.0
    batches = 0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        label_at = batch['label_at'].to(device)
        label_isAt = batch['label_isAt'].to(device)

        out = model(input_ids=input_ids, attention_mask=attention_mask)
        loss_at = criterion_at(out['logits_at'], label_at)
        loss_isAt = criterion_isAt(out['logits_isAt'], label_isAt)
        loss = loss_at + loss_isAt
        total += float(loss.item())
        batches += 1
    return float(total / max(1, batches))

def parse_prompt_text(text: str) -> Tuple[List[str], List[str], str]:
    """Extract mentions + context from our formatted prompt string."""
    pers_mentions: List[str] = []
    loc_mentions: List[str] = []
    ctx = ""
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith('Person:'):
            pers_mentions = [m.strip() for m in line[len('Person:'):].split('|') if m.strip()]
        elif line.startswith('Location:'):
            loc_mentions = [m.strip() for m in line[len('Location:'):].split('|') if m.strip()]
        elif line.startswith('Context:'):
            ctx = line[len('Context:'):].strip()
    return pers_mentions, loc_mentions, ctx

def crop_context(ctx: str, pers_mentions: List[str], loc_mentions: List[str], *, window_chars: int = 900) -> str:
    if len(ctx) <= window_chars:
        return ctx
    # Try to center around first occurrence of a person/location mention
    needles = [m for m in (pers_mentions[:1] + loc_mentions[:1]) if m]
    centers: List[int] = []
    for needle in needles:
        pos = ctx.lower().find(needle.lower())
        if pos >= 0:
            centers.append(pos)
    if centers:
        center = min(centers)
        start = max(0, center - window_chars // 2)
    else:
        start = random.randint(0, max(0, len(ctx) - window_chars))
    end = min(len(ctx), start + window_chars)
    snippet = ctx[start:end].strip()
    return snippet

def augment_prompt_text(text: str, *, window_chars: int = 900, drop_prob: float = 0.35, shuffle_mentions: bool = True) -> str:
    pers_mentions, loc_mentions, ctx = parse_prompt_text(text)
    # Mention dropout (keep at least 1)
    def _drop(ms: List[str]) -> List[str]:
        if len(ms) <= 1:
            return ms
        kept = [m for m in ms if random.random() > drop_prob]
        return kept if kept else [random.choice(ms)]
    pers_mentions = _drop(pers_mentions)
    loc_mentions = _drop(loc_mentions)
    if shuffle_mentions:
        random.shuffle(pers_mentions)
        random.shuffle(loc_mentions)
    ctx = crop_context(ctx, pers_mentions, loc_mentions, window_chars=window_chars)
    pers = ' | '.join(pers_mentions)
    loc = ' | '.join(loc_mentions)
    return f"Person: {pers}\nLocation: {loc}\nContext: {ctx}"

class AugmentedPairDataset(Dataset):
    def __init__(self, examples: Sequence[PairExample], *, augment: bool, aug_prob: float, window_chars: int):
        self.examples = list(examples)
        self.augment = bool(augment)
        self.aug_prob = float(aug_prob)
        self.window_chars = int(window_chars)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        ex = self.examples[idx]
        text = ex.text
        if self.augment and random.random() < self.aug_prob:
            text = augment_prompt_text(text, window_chars=self.window_chars)
        enc = tokenizer(
            text,
            truncation=True,
            padding=False,
            max_length=MAX_LEN,
            return_tensors=None,
        )
        return {
            'input_ids': enc['input_ids'],
            'attention_mask': enc['attention_mask'],
            'label_at': AT2ID[ex.at],
            'label_isAt': ISAT2ID[ex.isAt],
        }

def make_train_sampler(examples: List[PairExample]) -> WeightedRandomSampler:
    """Oversample rarer labels (both tasks) using a per-example weight."""
    at_counts = Counter(ex.at for ex in examples)
    isat_counts = Counter(ex.isAt for ex in examples)
    # Inverse-frequency weights; sqrt dampens extreme oversampling.
    inv_at = {k: (len(examples) / max(1, at_counts.get(k, 0))) for k in AT_LABELS}
    inv_isat = {k: (len(examples) / max(1, isat_counts.get(k, 0))) for k in ISAT_LABELS}
    weights = []
    for ex in examples:
        w = (inv_at[ex.at] * inv_isat[ex.isAt]) ** 0.5
        weights.append(float(w))
    w_t = torch.tensor(weights, dtype=torch.double)
    return WeightedRandomSampler(weights=w_t, num_samples=len(examples), replacement=True)

# --- Augmentation toggles (train only) ---
USE_AUGMENTATION = True
AUG_PROB = 0.50
AUG_CONTEXT_WINDOW_CHARS = 900

train_ds = AugmentedPairDataset(train_ex, augment=USE_AUGMENTATION, aug_prob=AUG_PROB, window_chars=AUG_CONTEXT_WINDOW_CHARS)
dev_ds = AugmentedPairDataset(dev_ex, augment=False, aug_prob=0.0, window_chars=AUG_CONTEXT_WINDOW_CHARS)

# --- Imbalance + regularization toggles ---
USE_OVERSAMPLING = True
WEIGHT_DECAY = 0.01

if USE_OVERSAMPLING:
    sampler = make_train_sampler(train_ex)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, shuffle=False, collate_fn=collate_batch)
    print('Train sampling: WeightedRandomSampler (oversampling enabled)')
else:
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    print('Train sampling: shuffle=True')

dev_loader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

criterion_at = nn.CrossEntropyLoss()
criterion_isAt = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

num_steps = EPOCHS * len(train_loader)
warmup_steps = int(WARMUP_RATIO * num_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_steps)

# --- Early stopping / best checkpoint ---
EARLY_STOPPING = True
EARLY_STOP_PATIENCE = 4
EARLY_STOP_MIN_DELTA = 0.0
BEST_MODEL_METRIC = 'recall_avg_macro'  # or 'val_loss' if you prefer
_best_state = None
_best_epoch = None
_best_metric_val = None
_best_val_loss = None
_no_improve = 0

print('Augmentation:', USE_AUGMENTATION, '| AUG_PROB:', AUG_PROB, '| ctx_window_chars:', AUG_CONTEXT_WINDOW_CHARS)
print('Oversampling:', USE_OVERSAMPLING, '| weight_decay:', WEIGHT_DECAY)
print('Early stopping:', EARLY_STOPPING, '| patience:', EARLY_STOP_PATIENCE, '| best metric:', BEST_MODEL_METRIC)
print('Train batches:', len(train_loader), 'Dev batches:', len(dev_loader))
print('Total steps:', num_steps, 'Warmup:', warmup_steps)

history: List[Dict[str, float]] = []

# Training loop
for epoch in range(1, EPOCHS + 1):
    model.train()
    running = 0.0

    pbar = tqdm(
        enumerate(train_loader, start=1),
        total=len(train_loader),
        desc=f"Epoch {epoch}/{EPOCHS}",
        leave=True,
    )
    for step, batch in pbar:
        optimizer.zero_grad(set_to_none=True)

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        label_at = batch['label_at'].to(device)
        label_isAt = batch['label_isAt'].to(device)

        out = model(input_ids=input_ids, attention_mask=attention_mask)
        loss_at = criterion_at(out['logits_at'], label_at)
        loss_isAt = criterion_isAt(out['logits_isAt'], label_isAt)
        loss = loss_at + loss_isAt

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        running += float(loss.item())
        if step % 5 == 0 or step == len(train_loader):
            pbar.set_postfix(
                train_loss=f"{running / step:.4f}",
                lr=f"{optimizer.param_groups[0]['lr']:.2e}",
            )

    metrics = evaluate_simple(model, dev_loader)
    val_loss = compute_val_loss(model, dev_loader, criterion_at=criterion_at, criterion_isAt=criterion_isAt)
    lr_now = float(optimizer.param_groups[0]['lr'])
    row = {
        'epoch': float(epoch),
        'train_loss': float(running / max(1, len(train_loader))),
        'val_loss': float(val_loss),
        'lr': lr_now,
        **metrics,
    }
    history.append(row)

    # Track best checkpoint (defaults to macro recall avg)
    metric_val = float(row[BEST_MODEL_METRIC])
    if _best_metric_val is None or metric_val > _best_metric_val:
        _best_metric_val = metric_val
        _best_epoch = epoch
        _best_state = copy.deepcopy(model.state_dict())

    # Early stop based on validation loss (more stable than recall)
    if _best_val_loss is None or val_loss < (_best_val_loss - EARLY_STOP_MIN_DELTA):
        _best_val_loss = float(val_loss)
        _no_improve = 0
    else:
        _no_improve += 1

    df = pd.DataFrame(history)
    # Pretty print with stable column order
    cols = [
        'epoch',
        'train_loss',
        'val_loss',
        'lr',
        'acc_at',
        'acc_isAt',
        'acc_avg',
        'recall_at_macro',
        'recall_isAt_macro',
        'recall_avg_macro',
    ]
    cols = [c for c in cols if c in df.columns]
    df_show = df[cols].copy()
    if 'lr' in df_show.columns:
        df_show['lr'] = df_show['lr'].map(lambda x: f"{float(x):.2e}")
    print('\nProgress so far:')
    print(df_show.round(4).to_string(index=False))

    if EARLY_STOPPING and _no_improve >= EARLY_STOP_PATIENCE:
        print(f"\nEarly stopping: val_loss did not improve for {EARLY_STOP_PATIENCE} epoch(s).")
        break

# Restore best checkpoint (by BEST_MODEL_METRIC) and report
if _best_state is not None:
    model.load_state_dict(_best_state)
    best_metrics = evaluate_simple(model, dev_loader)
    best_val_loss = compute_val_loss(model, dev_loader, criterion_at=criterion_at, criterion_isAt=criterion_isAt)
    print(f"\nRestored best model from epoch {_best_epoch} (best {BEST_MODEL_METRIC}={_best_metric_val:.4f}).")
    print('Best model dev metrics:', {**best_metrics, 'val_loss': float(best_val_loss)})


Train sampling: WeightedRandomSampler (oversampling enabled)
Augmentation: True | AUG_PROB: 0.5 | ctx_window_chars: 900
Oversampling: True | weight_decay: 0.01
Early stopping: True | patience: 2 | best metric: recall_avg_macro
Train batches: 193 Dev batches: 5
Total steps: 1158 Warmup: 69


Epoch 1/6:   0%|          | 0/193 [00:00<?, ?it/s]


Progress so far:
 epoch  train_loss  val_loss       lr  acc_at  acc_isAt  acc_avg  recall_at_macro  recall_isAt_macro  recall_avg_macro
   1.0      1.7679    2.0002 1.77e-05  0.2053    0.3377   0.2715             0.33             0.4559            0.3929


Epoch 2/6:   0%|          | 0/193 [00:00<?, ?it/s]


Progress so far:
 epoch  train_loss  val_loss       lr  acc_at  acc_isAt  acc_avg  recall_at_macro  recall_isAt_macro  recall_avg_macro
   1.0      1.7679    2.0002 1.77e-05  0.2053    0.3377   0.2715           0.3300             0.4559            0.3929
   2.0      1.5160    1.7575 1.42e-05  0.4437    0.6887   0.5662           0.4753             0.5831            0.5292


Epoch 3/6:   0%|          | 0/193 [00:00<?, ?it/s]


Progress so far:
 epoch  train_loss  val_loss       lr  acc_at  acc_isAt  acc_avg  recall_at_macro  recall_isAt_macro  recall_avg_macro
   1.0      1.7679    2.0002 1.77e-05  0.2053    0.3377   0.2715           0.3300             0.4559            0.3929
   2.0      1.5160    1.7575 1.42e-05  0.4437    0.6887   0.5662           0.4753             0.5831            0.5292
   3.0      1.2859    2.1292 1.06e-05  0.3510    0.5629   0.4570           0.4039             0.5838            0.4938


Epoch 4/6:   0%|          | 0/193 [00:00<?, ?it/s]


Progress so far:
 epoch  train_loss  val_loss       lr  acc_at  acc_isAt  acc_avg  recall_at_macro  recall_isAt_macro  recall_avg_macro
   1.0      1.7679    2.0002 1.77e-05  0.2053    0.3377   0.2715           0.3300             0.4559            0.3929
   2.0      1.5160    1.7575 1.42e-05  0.4437    0.6887   0.5662           0.4753             0.5831            0.5292
   3.0      1.2859    2.1292 1.06e-05  0.3510    0.5629   0.4570           0.4039             0.5838            0.4938
   4.0      1.0950    2.0999 7.09e-06  0.4437    0.6887   0.5662           0.4581             0.6552            0.5566

Early stopping: val_loss did not improve for 2 epoch(s).

Restored best model from epoch 4 (best recall_avg_macro=0.5566).
Best model dev metrics: {'acc_at': 0.44370860927152317, 'acc_isAt': 0.6887417218543046, 'acc_avg': 0.5662251655629139, 'recall_at_macro': 0.4580547416923347, 'recall_isAt_macro': 0.6551796157059315, 'recall_avg_macro': 0.5566171786991331, 'val_loss': 2.0998814582

In [10]:
def write_predictions_jsonl(
    docs: List[Dict[str, Any]],
    examples: List[PairExample],
    pred_at_ids: List[int],
    pred_isat_ids: List[int],
    out_path: Path,
) -> None:
    """Write a HIPE-format JSONL file.

    We preserve the original doc fields, and overwrite each pair's `at` and `isAt` with predictions.
    """
    assert len(examples) == len(pred_at_ids) == len(pred_isat_ids)

    # Group predictions by doc_id and pair_index
    by_doc: Dict[str, Dict[int, Tuple[str, str]]] = {}
    for ex, a_id, i_id in zip(examples, pred_at_ids, pred_isat_ids):
        by_doc.setdefault(ex.document_id, {})[ex.pair_index] = (ID2AT[int(a_id)], ID2ISAT[int(i_id)])

    with out_path.open('w', encoding='utf-8') as f:
        for doc in docs:
            doc_id = doc['document_id']
            new_doc = dict(doc)
            new_pairs = []
            for idx, pair in enumerate(doc.get('sampled_pairs', [])):
                new_pair = dict(pair)
                at_pred, isat_pred = by_doc.get(doc_id, {}).get(idx, ('FALSE', 'FALSE'))
                new_pair['at'] = at_pred
                new_pair['isAt'] = isat_pred
                new_pairs.append(new_pair)
            new_doc['sampled_pairs'] = new_pairs
            f.write(json.dumps(new_doc, ensure_ascii=False) + '\n')

# Export dev predictions (for DEV_LANG)
p_at_dev, p_isAt_dev = predict_batches(model, dev_loader)
dev_lang = globals().get('DEV_LANG', 'dev')
output_path = Path.cwd() / f'{dev_lang}_dev_predictions.jsonl'
write_predictions_jsonl(dev_docs, dev_ex, p_at_dev, p_isAt_dev, output_path)
print('Wrote:', output_path)


Wrote: /content/en_dev_predictions.jsonl
