# Full Extraction (bucketed)

Bucketed, resumable Tier-3 extraction of named astrophysical objects from full paper text, covering the extraction schema/prompt and SIMBAD name resolution.


## 0) Setup: imports, paths, OpenAI client


In [None]:
import os
import json
import time
import textwrap
from pathlib import Path
import random
import re
from itertools import islice
from collections import defaultdict, Counter

import pandas as pd
import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm

try:
    from openai import OpenAI
except Exception:
    OpenAI = None

load_dotenv()

# -----------------------
# Repo / IO paths (public)
# -----------------------
# If running in a notebook under repo/notebooks/, this resolves to repo root.
REPO_ROOT = Path.cwd().resolve()
if not ((REPO_ROOT / "src").exists() and (REPO_ROOT / "config").exists()):
    # fallback: assume notebook is in repo/notebooks/.
    # `.parent` (unlike `.parents[0]`) does not raise IndexError at the filesystem root.
    REPO_ROOT = REPO_ROOT.parent

DATA_DIR = Path(os.environ.get("ASTRO_DATA_DIR", str(REPO_ROOT / "data")))
OUT_ROOT = Path(os.environ.get("ASTRO_OUT_DIR", str(REPO_ROOT / "outputs")))

PROMPT_VERSION = os.environ.get("PROMPT_VERSION", "full_extraction")
SCHEMA_VERSION = os.environ.get("SCHEMA_VERSION", "2026-01-18")

RUN_DIR = OUT_ROOT / "full_extraction"
REQ_DIR = RUN_DIR / "requests"
OUT_DIR = RUN_DIR / "outputs"
USAGE_DIR = RUN_DIR / "usage"
CACHE_DIR = RUN_DIR / "cache"

for d in (REQ_DIR, OUT_DIR, USAGE_DIR, CACHE_DIR):
    d.mkdir(parents=True, exist_ok=True)

# -----------------------
# Upstream inputs (private)
# -----------------------
# These are not included in the public release. Users must provide paths if they
# want to run extraction end-to-end.

KG_ROOT = Path(os.environ.get("ASTRO_KG_ROOT", str(REPO_ROOT / "data")))
SUMMARIES_PATH = Path(os.environ.get("ASTRO_SUMMARIES_PATH", str(KG_ROOT / "papers_summaries.jsonl")))
ABSTRACTS_PATH = Path(os.environ.get("ASTRO_ABSTRACTS_PATH", str(KG_ROOT / "abstracts_all.jsonl")))

# Full OCR repository (private; not distributed). If absent, extraction cannot run,
# but downstream post-processing can still run if bucket outputs exist.
# Decide "set or not" from the raw environment string: Path("") normalizes to
# Path("."), whose str() is "." (truthy), so testing str(OCR_DIR) would treat an
# unset variable as the current working directory and glob its *.jsonl files.
_ocr_dir_raw = os.environ.get("ASTRO_OCR_DIR", "").strip()
OCR_DIR = Path(_ocr_dir_raw)  # set ASTRO_OCR_DIR to enable OCR-based extraction
OCR_PATHS = []
if _ocr_dir_raw:
    OCR_DIR = OCR_DIR.expanduser().resolve()
    if OCR_DIR.exists():
        OCR_PATHS = sorted(OCR_DIR.glob("*.jsonl"))
    else:
        print(f"[WARN] OCR_DIR was set but does not exist: {OCR_DIR}")
else:
    print("[INFO] OCR_DIR not set. Skipping OCR discovery (expected for public release).")

# -----------------------
# Extraction controls
# -----------------------
MODEL = os.environ.get("ASTRO_LLM_MODEL", "gpt-5-mini")  # logical model name
MODEL_DEPLOYMENT = os.environ.get("ASTRO_LLM_DEPLOYMENT", MODEL)  # Azure/OpenAI deployment name if applicable

MAX_OCR_CHARS = int(os.environ.get("MAX_OCR_CHARS", "583000"))      # cap to keep prompt size manageable
MAX_EVIDENCE_CHARS = int(os.environ.get("MAX_EVIDENCE_CHARS", "300"))

# -----------------------
# Optional OpenAI/Azure client
# -----------------------
client = None
endpoint = os.environ.get("ASTRO_OPENAI_ENDPOINT", "")  # e.g., https://<resource>.openai.azure.com/openai/v1
api_key = os.environ.get("ASTRO_OPENAI_API_KEY", "")    # or ASTROMLAB_API_KEY; keep generic for release

if OpenAI is None:
    print("[INFO] openai package not available. Install `openai` to run extraction calls.")
elif endpoint and api_key:
    client = OpenAI(base_url=endpoint, api_key=api_key)
    print("[INFO] OpenAI client initialized from ASTRO_OPENAI_ENDPOINT / ASTRO_OPENAI_API_KEY.")
else:
    print("[INFO] OpenAI client not initialized. Set ASTRO_OPENAI_ENDPOINT and ASTRO_OPENAI_API_KEY to enable extraction calls.")

print("REPO_ROOT :", REPO_ROOT)
print("DATA_DIR  :", DATA_DIR)
print("OUT_ROOT  :", OUT_ROOT)
print("RUN_DIR   :", RUN_DIR)
print("SUMMARIES :", SUMMARIES_PATH, "| exists =", SUMMARIES_PATH.exists())
print("ABSTRACTS :", ABSTRACTS_PATH, "| exists =", ABSTRACTS_PATH.exists())
print("OCR_DIR   :", OCR_DIR if _ocr_dir_raw else "<unset>")
print("OCR files :", len(OCR_PATHS))
print("MODEL     :", MODEL, "| DEPLOYMENT :", MODEL_DEPLOYMENT)

## 1) Load paper metadata (title/abstract/summary) and build id_to_meta


In [None]:
# 1) Load paper metadata (title/abstract/summary) and build id_to_meta (release-safe)

def read_jsonl(path: Path):
    """Yield one decoded JSON value per non-blank line of the JSONL file *path*.

    The file is decoded explicitly as UTF-8, because the platform default
    encoding is not guaranteed to be UTF-8 (e.g. on Windows). Lines that are
    empty after stripping are skipped. A malformed line raises
    ``json.JSONDecodeError``.
    """
    with path.open(encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                yield json.loads(line)

def _warn_missing(path: Path, name: str):
    """Print a three-line warning that an optional upstream input is absent.

    Always returns an empty dict; the callers in this notebook ignore it.
    """
    warning_lines = (
        f"[WARN] {name} not found: {path}",
        "       This notebook can still run post-processing steps if you already have bucket outputs,",
        "       but end-to-end extraction requires these metadata files.",
    )
    for text in warning_lines:
        print(text)
    return {}

# In the public release, these upstream metadata files may not exist.
# We therefore degrade gracefully rather than asserting.
print("Loading summaries and abstracts...")

summaries_by_arxiv = {}
abstracts_by_arxiv = {}

if SUMMARIES_PATH.exists():
    summaries_by_arxiv = {
        row.get("arxiv_id"): row.get("summary")
        for row in read_jsonl(SUMMARIES_PATH)
        if row.get("arxiv_id")
    }
else:
    _warn_missing(SUMMARIES_PATH, "Summaries file")

if ABSTRACTS_PATH.exists():
    abstracts_by_arxiv = {
        row.get("arxiv_id"): (row.get("abstract", "") or "")
        for row in read_jsonl(ABSTRACTS_PATH)
        if row.get("arxiv_id")
    }
else:
    _warn_missing(ABSTRACTS_PATH, "Abstracts file")

# NOTE(review): rows are driven by the summaries file, so a paper present only in
# the abstracts file is never included, even though the "some text" check below
# would accept it. Changing this would shift the sorted ID list (and therefore the
# bucket boundaries), so it is intentionally left as-is.
rows = []
for aid, summary in summaries_by_arxiv.items():
    if not aid:
        continue

    s_dict = summary or {}
    title_and_author = s_dict.get("title_and_author", "") or ""

    # Title = first non-blank line of "title_and_author", minus markdown bold/bullets.
    title = ""
    if title_and_author:
        lines = [ln.strip() for ln in title_and_author.splitlines() if ln.strip()]
        if lines:
            title = lines[0].strip("* ").strip()

    # Every other non-empty string field of the summary, joined as paragraphs.
    summary_text_parts = []
    for k, v in s_dict.items():
        if k == "title_and_author":
            continue
        if isinstance(v, str) and v.strip():
            summary_text_parts.append(v.strip())
    summary_text = "\n\n".join(summary_text_parts)

    abstract_text = abstracts_by_arxiv.get(aid, "")

    if not (title.strip() or summary_text.strip() or abstract_text.strip()):
        continue

    rows.append(
        {
            "arxiv_id": aid,
            "title": title,
            "summary_text": summary_text,
            "abstract_text": abstract_text,
        }
    )

all_papers_df = pd.DataFrame(rows)

# Build the lookup straight from `rows`. It has the same keys, values and order
# as iterating all_papers_df, but avoids DataFrame.iterrows(), which constructs
# a Series per row and is very slow at the ~400k-paper scale.
id_to_meta = {
    row["arxiv_id"]: {
        "arxiv_id": row["arxiv_id"],
        "title": row.get("title", ""),
        "summary_text": row.get("summary_text", ""),
        "abstract_text": row.get("abstract_text", ""),
    }
    for row in rows
}

ALL_ARXIV_IDS = sorted(id_to_meta.keys())

print("Usable papers with some text:", len(all_papers_df))
print("Unique arxiv_ids in metadata:", len(ALL_ARXIV_IDS))

try:
    display(all_papers_df.head(3))
except Exception:
    # `display` exists only inside IPython/Jupyter; fall back to plain text.
    print(all_papers_df.head(3).to_string(index=False))

Loading summaries and abstracts...
Usable papers with some text: 408590
Unique arxiv_ids in metadata: 408590


Unnamed: 0,arxiv_id,title,summary_text,abstract_text
0,704.0007,Polymer Quantum Mechanics and its Continuum Limit,Polymer quantum mechanics presents an unconven...,A rather non-standard quantum representation o...
1,704.0009,"The Spitzer c2d Survey of Large, Nearby, Inter...",The Serpens star-forming cloud is one of five ...,We discuss the results from the combined IRAC ...
2,704.0017,Spectroscopic Observations of the Intermediate...,EX Hydrae (EX Hya) is classified as an Interme...,Results from spectroscopic observations of the...


## 2) Define schema + prompt builder


In [None]:
STUDY_MODES = {
    "new_observation",         # authors report newly obtained observations of this object
    "archival_or_reanalysis",  # reanalysis of previously obtained/public data
    "catalog_compilation",     # primarily catalog/survey DB values + crossmatch/statistical compilation
    "simulation_or_theory",    # modeling/simulations/analytic theory; no direct observational analysis of this object here
    "not_applicable",          # object is mentioned only as comparison/calibration/incidental; not analyzed as data subject
    "unknown",                 # truly ambiguous from provided text
}

STUDY_MODE_DEFINITIONS: dict[str, str] = {
    "new_observation": "New observations of this object are presented in this paper (authors obtained new data).",
    "archival_or_reanalysis": "This paper analyzes previously obtained/public data for this object (archive/reanalysis).",
    "catalog_compilation": "This paper primarily uses catalog/survey database values for this object (crossmatch/compilation/statistics).",
    "simulation_or_theory": "This paper discusses this object only in theory/simulation/analytic modeling (no direct observational analysis here).",
    "not_applicable": "Object is only referenced for comparison/calibration/incidental context; not an analyzed data subject here.",
    "unknown": "Cannot determine from the provided text.",
}

# Order used both for the prompt's enum/definition listing and as the
# tie-break priority when several study modes could apply.
STUDY_MODE_PRIORITY = [
    "new_observation",
    "archival_or_reanalysis",
    "catalog_compilation",
    "simulation_or_theory",
    "not_applicable",
    "unknown",
]

# Fail fast if the three study-mode tables drift apart (the same kind of guard
# the roles use). Otherwise a mode could be accepted by the schema/validator but
# be missing from the prompt (or vice versa), or the lookup below would KeyError.
_missing_mode_defs = sorted(STUDY_MODES - set(STUDY_MODE_DEFINITIONS))
_extra_mode_defs = sorted(set(STUDY_MODE_DEFINITIONS) - STUDY_MODES)
if _missing_mode_defs or _extra_mode_defs:
    raise ValueError(
        f"STUDY_MODE_DEFINITIONS mismatch. missing={_missing_mode_defs} extra={_extra_mode_defs}"
    )
if set(STUDY_MODE_PRIORITY) != STUDY_MODES or len(STUDY_MODE_PRIORITY) != len(STUDY_MODES):
    raise ValueError(
        f"STUDY_MODE_PRIORITY must list every study mode exactly once: {STUDY_MODE_PRIORITY}"
    )

study_mode_lines = "\n".join(f"- {m}: {STUDY_MODE_DEFINITIONS[m]}" for m in STUDY_MODE_PRIORITY)
study_mode_enum = ", ".join([f'"{m}"' for m in STUDY_MODE_PRIORITY])
study_mode_priority_str = " > ".join(STUDY_MODE_PRIORITY)


ROLES = {
    "primary_subject",              # the object is a main focus of the paper's results
    "member_of_sample",             # included in a study sample / catalog list used in analysis
    "host_or_counterpart",          # established association / ID / host / counterpart
    "candidate_association",        # tentative counterpart/host/ID (hedged)
    "foreground_or_lens",           # confirmed foreground/lens
    "background_or_lensed_source",  # (optional) explicitly the lensed/background source
    "comparison_or_reference",      # used as benchmark / reference example (not a studied subject here)
    "calibration",                  # used for PSF/flux/astrometric/telluric calibration
    "serendipitous_or_field_source",# incidental neighbor/contaminant in field
    "non_detection_or_upper_limit", # searched/targeted but non-detection/upper limit
    "other",
}


ROLE_DEFINITIONS: dict[str, str] = {
    # High-signal, primary edges
    "primary_subject": (
        "Primary subject of the paper's results: the object the paper mainly analyzes or draws conclusions about "
        "(observational, archival, catalog-based, or theoretical). Often central to the title/abstract/results."
    ),

    "member_of_sample": (
        "Object is part of a study sample/survey used for statistical analysis or population results "
        "(often listed in tables). Not necessarily deeply analyzed individually."
    ),

    # Association edges
    "host_or_counterpart": (
        "Confirmed/established physical association or identification: host galaxy, optical/IR/radio counterpart, "
        "companion star, etc. Language implies it is known/accepted in the paper."
    ),
    "candidate_association": (
        "Proposed/tentative association or identification: candidate counterpart/host/lens/ID. "
        "Use when paper hedges (possible/likely/may be/consistent with) or not confirmed."
    ),

    # Lensing edges (optional split)
    "foreground_or_lens": (
        "Foreground object affecting the observation, especially a confirmed gravitational lens or foreground star/galaxy."
    ),
    "background_or_lensed_source": (
        "Background source being lensed/magnified (the lensed object), if explicitly discussed as the source behind the lens."
    ),

    # Context / methodology edges (lower signal)
    "comparison_or_reference": (
        "Object used mainly as a comparison/benchmark/reference example (not a main analyzed target), "
        "e.g., classic sources used to compare spectra/lightcurves."
    ),
    "calibration": (
        "Object used operationally to calibrate or characterize data (flux/PSF/astrometric/telluric standard, etc.)."
    ),
    "serendipitous_or_field_source": (
        "Incidental object in the field/nearby/contaminating/blended/background source mentioned for context "
        "rather than as a scientific target."
    ),
    "non_detection_or_upper_limit": (
        "Object was searched/targeted for a signal but the result is a non-detection or upper limit."
    ),
    "other": "Does not fit above roles; use sparingly.",
}
_missing_defs = sorted(ROLES - set(ROLE_DEFINITIONS))
_extra_defs = sorted(set(ROLE_DEFINITIONS) - ROLES)
if _missing_defs or _extra_defs:
    raise ValueError(f"ROLE_DEFINITIONS mismatch. missing={_missing_defs} extra={_extra_defs}")

# Highest priority first: when several roles could apply, the prompt asks the
# model to choose the earliest one in this list.
ROLE_PRIORITY = [
    "non_detection_or_upper_limit",
    "primary_subject",
    "member_of_sample",

    "host_or_counterpart",
    "candidate_association",
    "foreground_or_lens",
    "background_or_lensed_source",

    "calibration",
    "serendipitous_or_field_source",
    "comparison_or_reference",
    "other",
]

# Every role must appear exactly once in ROLE_PRIORITY. Otherwise it would be
# silently omitted from the prompt's allowed-role list and priority order, while
# the JSON schema enum and the validator (both built from ROLES) still accept it.
if set(ROLE_PRIORITY) != ROLES or len(ROLE_PRIORITY) != len(ROLES):
    raise ValueError(
        f"ROLE_PRIORITY mismatch. missing={sorted(ROLES - set(ROLE_PRIORITY))} "
        f"extra={sorted(set(ROLE_PRIORITY) - ROLES)} (each role must appear exactly once)"
    )

role_lines = "\n".join(f"- {r}: {ROLE_DEFINITIONS[r]}" for r in ROLE_PRIORITY)
role_enum = ", ".join([f'"{r}"' for r in ROLE_PRIORITY])
priority_str = " > ".join(ROLE_PRIORITY)
EVIDENCE_SOURCES = {"title", "abstract", "body"}

# 2) Define schema + prompt builder (ASTRO OBJECTS ONLY, simplified)
# JSON Schema for one model response: the paper's arxiv_id plus a list of
# objects, each with exactly five keys. The enum lists are the sorted contents
# of ROLES / STUDY_MODES / EVIDENCE_SOURCES defined above.
SCHEMA_OBJ = dict(
    type="object",
    additionalProperties=False,
    required=["arxiv_id", "objects"],
    properties=dict(
        arxiv_id=dict(type="string"),
        objects=dict(
            type="array",
            items=dict(
                type="object",
                additionalProperties=False,
                required=["name", "role", "study_mode", "evidence_span", "evidence_source"],
                properties=dict(
                    name=dict(type="string"),
                    role=dict(type="string", enum=sorted(ROLES)),
                    study_mode=dict(type="string", enum=sorted(STUDY_MODES)),
                    evidence_span=dict(type="string"),
                    evidence_source=dict(type="string", enum=sorted(EVIDENCE_SOURCES)),
                ),
            ),
        ),
    ),
)

# Pretty-printed schema text; SCHEMA_DOC is the copy embedded in the system prompt.
SCHEMA_JSON = json.dumps(SCHEMA_OBJ, indent=2)
SCHEMA_DOC = SCHEMA_JSON


# ---- helpers ----
# Characters whose presence suggests an alias list or qualifier rather than a
# single designation: , / = ; ( ) [ ]
# NOTE(review): "[" and "]" are rejected although the prompt never tells the
# model to avoid them, and SIMBAD uses bracketed author acronyms (e.g. "[HB89] ...").
FORBIDDEN_NAME_CHARS = list(",/=;()[]")

def _check_enum(val, allowed: set[str], field: str, errors: list[str]):
    """Append ``"<field> invalid: <val>"`` to *errors* unless *val* is in *allowed*.

    *val* comes from untrusted model output and may be an unhashable JSON value
    (list/dict). Set membership would then raise TypeError, so such values are
    reported as invalid instead of crashing validation.
    """
    try:
        is_valid = val in allowed
    except TypeError:
        is_valid = False
    if not is_valid:
        errors.append(f"{field} invalid: {val}")

def _violates_single_designation(name: str) -> bool:
    """Return True if *name* appears to contain more than one designation.

    Cheap heuristic behind the prompt's "single designation only" rule. Any
    character from FORBIDDEN_NAME_CHARS, or the separator " and " (matched
    case-insensitively), counts as a violation.
    """
    has_separator_char = any(ch in name for ch in FORBIDDEN_NAME_CHARS)
    return has_separator_char or " and " in name.lower()

def build_prompt_fulltext(meta: dict, full_text: str) -> tuple[str, str]:
    """Build the (system, user) chat messages for full-text object extraction.

    Args:
        meta: Paper metadata. ``arxiv_id`` is required; ``title`` and
            ``abstract_text`` are optional (missing/None are treated as "").
        full_text: OCR text of the paper. None/"" is treated as empty, and the
            text is cut to its first ``MAX_OCR_CHARS`` characters.

    Returns:
        ``(system_msg, user_msg)``. The system message embeds the role and
        study-mode tables and ``SCHEMA_DOC``.
    """
    raw_text = full_text or ""
    # Hard character cap. No truncation marker is appended; the prompt already
    # tells the model the text may be truncated.
    if MAX_OCR_CHARS is None or len(raw_text) <= MAX_OCR_CHARS:
        truncated_text = raw_text
    else:
        truncated_text = raw_text[:MAX_OCR_CHARS]

    # Dedent the templates *before* substituting values. The substituted values
    # (role/study-mode lists, the JSON schema, the OCR text) are multi-line and
    # start at column 0. Dedenting an already-interpolated f-string therefore
    # finds no common margin, and every template line kept its source indentation.
    system_template = textwrap.dedent(
        """
        You are an expert astrophysicist helping to build a paper-object knowledge graph.

        You will be given:
        - Paper title
        - Paper abstract
        - Full OCR'd paper text (may be noisy / truncated)

        TASK:
        Extract ONLY named astrophysical objects intended for SIMBAD-style resolution, e.g.:
        - Galaxies: NGC/IC/M/UGC/ESO/SDSS objects (e.g., "NGC 1275", "M 31", "SDSS J1234+5678")
        - Stars: HD/HIP/TYC/2MASS/Gaia/variable-star designations, etc.
        - Clusters / nebulae / SNR / AGN / radio sources / pulsars with standard designations (PSR J..., 3C..., etc.)

        DO NOT include:
        - Instruments, facilities, surveys, catalogs (e.g., ALMA, SDSS, Gaia)
        - Sky regions / fields / pointings
        - Solar system bodies or features (rings/craters/etc.)
        - Generic classes ("galaxies", "supernovae") unless a specific named object is given
        - Raw coordinates unless explicitly labeled as a named object

        ROLE (choose exactly one):
        Allowed roles (use these exact strings):
        [{role_enum}]

        ROLE DEFINITIONS:
        {role_lines}

        ROLE CHOICE RULES (if multiple could apply, choose the highest priority):
        {priority_str}
        Notes:
        - If an object is a primary target BUT the paper reports only an upper limit/non-detection, use non_detection_or_upper_limit.
        - Use candidate_association when the paper is uncertain (possible/likely/may be/consistent with), even if a counterpart name is given.
        - Use calibration only for operational calibration/PSF/standards (not “we compare to Crab”).
        - Use serendipitous_or_field_source for incidental neighbors/contaminants/blends.

        STUDY MODE (choose exactly one per object):
        Allowed study_mode values (use these exact strings):
        [{study_mode_enum}]

        STUDY MODE DEFINITIONS:
        {study_mode_lines}

        STUDY MODE CHOICE RULES (if multiple could apply, choose highest priority):
        {study_mode_priority_str}
        Notes:
        - primary_subject / non_detection_or_upper_limit should almost never have study_mode=not_applicable.
        - comparison_or_reference / calibration / serendipitous_or_field_source should usually have study_mode=not_applicable.
        - Use unknown only if truly ambiguous from the provided text.



        EVIDENCE:
        For each object, include:
        - name
        - role
        - study_mode
        - evidence_span (<= {max_evidence_chars} chars)
        - evidence_source


        PRECISION RULES:
        - Be conservative: omit passing mentions and long literature lists without new analysis.
        - Prefer objects central to the paper (main targets, analyzed sample members, key counterparts, major comparison objects).
        - Typical object count <= 25; hard cap at 40.
        - Do not invent names. Use text as written (minor whitespace fixes OK).

        SINGLE DESIGNATION ONLY:
        - name MUST be exactly ONE designation string.
        - name must NOT contain parentheses, commas, slashes, semicolons, '=' or multiple aliases joined together.

        OUTPUT FORMAT:
        Return a single JSON object that strictly matches this schema (no extra keys, valid JSON):
        {schema_doc}
        """
    )
    system_msg = system_template.format(
        role_enum=role_enum,
        role_lines=role_lines,
        priority_str=priority_str,
        study_mode_enum=study_mode_enum,
        study_mode_lines=study_mode_lines,
        study_mode_priority_str=study_mode_priority_str,
        max_evidence_chars=MAX_EVIDENCE_CHARS,
        schema_doc=SCHEMA_DOC,
    ).strip()

    # Important: do NOT include prompt_version/schema_version here;
    # the model may echo them into JSON as extra keys.
    user_template = textwrap.dedent(
        """
        Paper arXiv ID: {arxiv_id}

        TITLE:
        {title}

        ABSTRACT:
        {abstract}

        FULL PAPER TEXT (OCR, may be truncated):
        \"\"\"
        {text}
        \"\"\"

        Return only the JSON object described above.
        """
    )
    # Values go through str.format as arguments, so braces inside the OCR text
    # or the abstract are inserted literally and never parsed as placeholders.
    user_msg = user_template.format(
        arxiv_id=meta["arxiv_id"],
        title=(meta.get("title") or "").strip(),
        abstract=(meta.get("abstract_text") or "").strip(),
        text=truncated_text,
    ).strip()

    return system_msg, user_msg


def make_request(arxiv_id: str, sys_msg: str, user_msg: str) -> dict:
    """Assemble one request record (custom_id/method/url/body/metadata) for *arxiv_id*.

    The custom_id is ``"V2-" + arxiv_id``. The body targets ``MODEL`` with a
    JSON-object response format and the given system/user messages.
    """
    chat_messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_msg},
    ]
    request_body = {
        "model": MODEL,
        "response_format": {"type": "json_object"},
        "messages": chat_messages,
    }
    # Version tags live here, outside the prompt, so the model never sees them
    # and cannot echo them back as extra JSON keys.
    request_metadata = {
        "arxiv_id": arxiv_id,
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": MODEL,
    }
    return {
        "custom_id": f"V2-{arxiv_id}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": request_body,
        "metadata": request_metadata,
    }


def validate_record(rec: dict) -> list[str]:
    """Validate one parsed model response and return a list of error strings.

    Checks performed:
    - exact top-level keys (``arxiv_id``, ``objects``) and a non-empty ``arxiv_id``;
    - ``objects`` is a list of at most 40 items;
    - each object has exactly the five schema keys;
    - a non-empty single-designation ``name``;
    - enum membership for ``role`` / ``study_mode`` / ``evidence_source``;
    - a non-empty ``evidence_span`` of at most ``MAX_EVIDENCE_CHARS`` characters;
    - role/study_mode consistency rules.

    An empty list means the record passed. The record is untrusted model
    output, so wrong-typed values (including unhashable lists/dicts in enum
    fields) are reported as errors rather than raised.
    """
    # Tuples fix the error-message order. Iterating a set of str depends on hash
    # randomization, so the order would otherwise differ between runs.
    required_top_keys = ("arxiv_id", "objects")
    required_obj_keys = ("name", "role", "study_mode", "evidence_span", "evidence_source")
    required_obj_set = set(required_obj_keys)

    # Roles that are usually "not_applicable" (context-only)
    context_roles = {"comparison_or_reference", "calibration", "serendipitous_or_field_source"}
    # Roles that should almost never be "not_applicable"
    substantive_roles = {
        "primary_subject", "member_of_sample", "host_or_counterpart", "candidate_association",
        "foreground_or_lens", "background_or_lensed_source", "non_detection_or_upper_limit",
    }

    def as_enum_key(val):
        # Lists/dicts are unhashable and would raise TypeError in set membership
        # tests. str() of any non-str JSON value never equals a valid enum string,
        # and it renders exactly as the f-strings below would, so messages are unchanged.
        return val if isinstance(val, str) else str(val)

    errors: list[str] = []
    if not isinstance(rec, dict):
        return ["record is not an object"]

    # Require exact top-level keys (strict)
    for k in required_top_keys:
        if k not in rec:
            errors.append(f"missing key: {k}")

    extra_top = set(rec.keys()) - set(required_top_keys)
    if extra_top:
        errors.append(f"extra top-level keys: {sorted(extra_top)}")

    arxiv_id = rec.get("arxiv_id")
    if not isinstance(arxiv_id, str) or not arxiv_id.strip():
        errors.append("arxiv_id missing/empty")

    objects = rec.get("objects")
    if not isinstance(objects, list):
        errors.append("objects is not a list")
        return errors

    if len(objects) > 40:
        errors.append("too many objects (>40)")

    for idx, obj in enumerate(objects):
        prefix = f"objects[{idx}]"
        if not isinstance(obj, dict):
            errors.append(f"{prefix} not an object")
            continue

        for k in required_obj_keys:
            if k not in obj:
                errors.append(f"missing key: {prefix}.{k}")

        role = as_enum_key(obj.get("role"))
        mode = as_enum_key(obj.get("study_mode"))

        if role in substantive_roles and mode == "not_applicable":
            errors.append(f"{prefix}.study_mode not_applicable inconsistent with role={role}")

        if role in context_roles and mode not in {"not_applicable", "unknown"}:
            errors.append(f"{prefix}.study_mode should usually be not_applicable for role={role} (got {mode})")

        if role == "non_detection_or_upper_limit" and mode == "simulation_or_theory":
            errors.append(f"{prefix}.study_mode simulation_or_theory inconsistent with non_detection_or_upper_limit")

        extra_obj = set(obj.keys()) - required_obj_set
        if extra_obj:
            errors.append(f"extra keys in {prefix}: {sorted(extra_obj)}")

        name = obj.get("name")
        if not isinstance(name, str) or not name.strip():
            errors.append(f"{prefix}.name missing/empty")
        elif _violates_single_designation(name):
            errors.append(f"{prefix}.name violates single-designation rule")

        _check_enum(role, ROLES, f"{prefix}.role", errors)
        _check_enum(mode, STUDY_MODES, f"{prefix}.study_mode", errors)

        ev_span = obj.get("evidence_span", "") or ""
        if not isinstance(ev_span, str) or not ev_span.strip():
            errors.append(f"{prefix}.evidence_span missing/empty")
        elif len(ev_span) > MAX_EVIDENCE_CHARS:
            errors.append(f"{prefix}.evidence_span too long")

        _check_enum(
            as_enum_key(obj.get("evidence_source")),
            EVIDENCE_SOURCES,
            f"{prefix}.evidence_source",
            errors,
        )

    return errors


## 3) Bucket plan (manifest) + done_id detection


In [None]:
# 3) Bucket plan (manifest) + done_id detection 
#
# Notes:
# - This notebook documents the extraction workflow used to generate the released artifacts.
# - This section supports resumable execution by detecting which arXiv IDs already have outputs.

# Bucket controls (override via env without editing the notebook)
BUCKET_SIZE = int(os.environ.get("BUCKET_SIZE", "5000"))
# range() below raises a cryptic "arg 3 must not be zero" for 0 and silently
# yields no buckets for a negative size, so fail fast with a clear message.
if BUCKET_SIZE <= 0:
    raise ValueError(f"BUCKET_SIZE must be a positive integer, got {BUCKET_SIZE}")
BUCKET_START = int(os.environ.get("BUCKET_START", "0"))
_bucket_end_raw = os.environ.get("BUCKET_END", "").strip()
BUCKET_END = int(_bucket_end_raw) if _bucket_end_raw else None  # set to an int to truncate

# NOTE(review): bucket IDs are positions within the *selected* slice. Runs with
# different BUCKET_START values therefore reuse the same IDs ("0000", ...) for
# different papers, and the manifest file below is overwritten on each run.
selected_ids = ALL_ARXIV_IDS[BUCKET_START:BUCKET_END]
buckets = [selected_ids[i : i + BUCKET_SIZE] for i in range(0, len(selected_ids), BUCKET_SIZE)]

manifest_rows = []
for idx, ids in enumerate(buckets):
    bucket_id = f"{idx:04d}"
    manifest_rows.append(
        {
            "bucket_id": bucket_id,
            "n_papers": len(ids),
            "first_arxiv_id": ids[0] if ids else None,
            "last_arxiv_id": ids[-1] if ids else None,
        }
    )

bucket_manifest_path = RUN_DIR / "bucket_manifest.csv"
pd.DataFrame(manifest_rows).to_csv(bucket_manifest_path, index=False)
print(f"Wrote bucket manifest with {len(manifest_rows)} buckets -> {bucket_manifest_path}")

# bucket_id -> list of arXiv IDs in that bucket (same order as the manifest).
bucket_index = {row["bucket_id"]: buckets[idx] for idx, row in enumerate(manifest_rows)}


def load_done_arxiv_ids(out_dir: Path) -> set[str]:
    """Collect arXiv IDs that already have a successful parsed record in bucket outputs.

    Scans ``out_dir/bucket_*_outputs.jsonl``. A record counts as done when it
    is a JSON object with no truthy ``error`` and a dict-valued ``parsed``
    field. Blank and undecodable lines are skipped, and a missing directory
    yields an empty set.

    The ID is resolved from authoritative fields first, then from the model
    echo as a last resort:
    1. the record's own ``arxiv_id``;
    2. ``metadata.arxiv_id``;
    3. the ``custom_id`` with its ``"V2-"`` prefix removed;
    4. the model-echoed ``parsed.arxiv_id``.
    The echo may be reformatted or wrong, and trusting it first would mark the
    wrong paper as done.
    """
    if not out_dir.exists():
        return set()

    done = set()
    for path in sorted(out_dir.glob("bucket_*_outputs.jsonl")):
        with path.open() as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except ValueError:
                    continue  # best-effort: tolerate partially written lines
                # Valid JSON that is not an object (e.g. a bare list) has no .get.
                if not isinstance(rec, dict) or rec.get("error"):
                    continue
                parsed = rec.get("parsed")
                if not isinstance(parsed, dict):
                    continue

                meta = rec.get("metadata")
                meta_aid = meta.get("arxiv_id") if isinstance(meta, dict) else None
                cid = rec.get("custom_id")
                cid_aid = cid[len("V2-"):] if isinstance(cid, str) and cid.startswith("V2-") else None

                aid = rec.get("arxiv_id") or meta_aid or cid_aid or parsed.get("arxiv_id")
                if aid:
                    done.add(aid)
    return done

# Resume support: every arXiv ID with a successful parsed record in any bucket
# output file is treated as done and excluded from request building.
DONE_ARXIV_IDS = load_done_arxiv_ids(OUT_DIR)
print("Done arxiv_ids from existing outputs:", len(DONE_ARXIV_IDS))
print("Remaining arxiv_ids to process:", len(set(ALL_ARXIV_IDS).difference(DONE_ARXIV_IDS)))

# Release behaviour: the only IDs skipped are those already completed in outputs.
SKIP_ARXIV_IDS = DONE_ARXIV_IDS

## 4) Build requests for one or more buckets (stream OCR JSONL)


In [None]:
# Re-materialize the manifest built in section 3 as a DataFrame, rewrite the
# same CSV, and expose the sorted bucket IDs used by the selection cell below.
bucket_manifest = pd.DataFrame(manifest_rows)
bucket_manifest_path = RUN_DIR / "bucket_manifest.csv"  # RUN_DIR from the release-safe setup cell
bucket_manifest.to_csv(bucket_manifest_path, index=False)

print(f"Wrote bucket manifest with {len(bucket_manifest)} buckets -> {bucket_manifest_path}")

bucket_index = {entry["bucket_id"]: buckets[pos] for pos, entry in enumerate(manifest_rows)}
BUCKET_IDS = sorted(bucket_index)

print("First bucket IDs:", BUCKET_IDS[:5])

Wrote bucket manifest with 82 buckets -> /Users/jinchuli/projects/astro-llm-tools/data/llm_object_extraction_final/full_extraction_v2/bucket_manifest.csv
First bucket IDs: ['0000', '0001', '0002', '0003', '0004']


In [None]:
# 4) Build requests for one or more buckets (stream OCR jsonl) 
#
# Notes:
# - End-to-end request building requires private OCR inputs (OCR_DIR) and metadata (id_to_meta).
# - In the public release, this cell is safe to import/inspect, but it will no-op unless OCR_DIR is set and exists.

def build_requests_for_bucket(
    bucket_id: str,
    allowed_ids: set[str] | None = None,
    skip_ids: set[str] | None = None,
) -> set[str]:
    """Write ``REQ_DIR/bucket_<bucket_id>_requests.jsonl`` by streaming the OCR JSONL shards.

    Args:
        bucket_id: Key into ``bucket_index`` (e.g. "0000").
        allowed_ids: If non-empty, only these arXiv IDs are considered
            (None or an empty set means no restriction).
        skip_ids: arXiv IDs to leave out, e.g. ones already completed.

    Returns:
        The arXiv IDs for which a request line was written. The set is empty
        when OCR is not configured or nothing is left to build.

    Raises:
        ValueError: if *bucket_id* is not in ``bucket_index``.

    Side effects: overwrites the bucket's request file, and writes (or removes
    a stale) ``RUN_DIR/missing_ocr_bucket_<bucket_id>.txt`` listing target IDs
    that had no OCR record. IDs found in OCR but absent from ``id_to_meta`` are
    dropped and are not listed in that report.
    """
    if bucket_id not in bucket_index:
        raise ValueError(f"Unknown bucket_id: {bucket_id}")

    # OCR is not included in the public release; require OCR_DIR to be set and valid.
    # str(OCR_DIR) is never empty: Path("") normalizes to Path("."), so an unset
    # ASTRO_OCR_DIR surfaces as a path pointing at the working directory. Treat
    # that as "not configured" unless the env var explicitly asks for it.
    ocr_root = Path(OCR_DIR).expanduser()
    ocr_env_set = bool(os.environ.get("ASTRO_OCR_DIR", "").strip())
    points_at_cwd = ocr_root.resolve() == Path.cwd().resolve()
    if (points_at_cwd and not ocr_env_set) or not ocr_root.exists():
        print(f"[WARN] OCR_DIR is not set or does not exist. Cannot build requests for bucket {bucket_id}.")
        return set()

    target_ids = set(bucket_index[bucket_id])
    if allowed_ids:
        target_ids &= set(allowed_ids)
    if skip_ids:
        target_ids -= set(skip_ids)
    if not target_ids:
        print(f"No target_ids to build for bucket {bucket_id}")
        return set()

    req_path = REQ_DIR / f"bucket_{bucket_id}_requests.jsonl"
    req_path.parent.mkdir(parents=True, exist_ok=True)

    missing_ids = set(target_ids)
    found_ids = set()

    with req_path.open("w", encoding="utf-8") as fout:
        for ocr_path in sorted(ocr_root.glob("*.jsonl")):
            if not missing_ids:
                break  # every target already found; skip the remaining shards

            with ocr_path.open(encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except ValueError:
                        continue  # best-effort: skip corrupt OCR lines
                    if not isinstance(obj, dict):
                        continue  # valid JSON but not a record object

                    aid = obj.get("arxiv_id")
                    if aid not in missing_ids:
                        continue

                    meta = id_to_meta.get(aid)
                    if not meta:
                        # Missing metadata: skip (request would be low quality / incomplete)
                        missing_ids.remove(aid)
                        continue

                    sys_msg, user_msg = build_prompt_fulltext(
                        meta,
                        obj.get("ocr_markdown", "") or "",
                    )
                    req = make_request(aid, sys_msg=sys_msg, user_msg=user_msg)
                    fout.write(json.dumps(req) + "\n")

                    found_ids.add(aid)
                    missing_ids.remove(aid)
                    if not missing_ids:
                        break  # no need to read the rest of this shard

    missing_path = RUN_DIR / f"missing_ocr_bucket_{bucket_id}.txt"
    if missing_ids:
        with missing_path.open("w", encoding="utf-8") as m:
            for mid in sorted(missing_ids):
                m.write(mid + "\n")
        print(f"OCR missing for {len(missing_ids)} ids -> {missing_path}")
    else:
        # Remove a report left by an earlier run so it is not mistaken for current state.
        missing_path.unlink(missing_ok=True)

    print(f"Wrote {len(found_ids)} requests to {req_path}")
    return found_ids


# ---- SAFE controls (default does nothing unless explicitly enabled) ----
# Use env overrides so users don't need to edit the notebook.
def parse_env_flag(name: str, default: str = "0") -> bool:
    """Read a boolean flag from environment variable *name*.

    Accepts true/yes/on and false/no/off (case-insensitive), and treats an
    empty value as false. Any other value must be an integer, with non-zero
    meaning true; this matches the earlier ``bool(int(...))`` parsing for
    numeric values.

    Raises:
        ValueError: if the value is neither a recognised word nor an integer.
    """
    raw = os.environ.get(name, default).strip().lower()
    if raw in {"true", "yes", "on"}:
        return True
    if raw in {"", "false", "no", "off"}:
        return False
    return bool(int(raw))


TARGET_BUCKET_ID = os.environ.get("TARGET_BUCKET_ID", "").strip()       # e.g. "0" or "0000"
TARGET_BUCKET_IDS = os.environ.get("TARGET_BUCKET_IDS", "").strip()     # e.g. "0000,0001"
# Parsed leniently: the run cell's hint says BUILD_ALL_BUCKETS=True, which a
# plain int() conversion rejects with ValueError.
BUILD_ALL_BUCKETS = parse_env_flag("BUILD_ALL_BUCKETS")  # default OFF in release

if TARGET_BUCKET_IDS:
    target_bucket_ids = [b.strip().zfill(4) for b in TARGET_BUCKET_IDS.split(",") if b.strip()]
elif TARGET_BUCKET_ID:
    target_bucket_ids = [TARGET_BUCKET_ID.zfill(4)]
elif BUILD_ALL_BUCKETS:
    target_bucket_ids = BUCKET_IDS
else:
    target_bucket_ids = []

print("Target buckets:", target_bucket_ids)

Target buckets: ['0000', '0001', '0002', '0003', '0004', '0005', '0006', '0007', '0008', '0009', '0010', '0011', '0012', '0013', '0014', '0015', '0016', '0017', '0018', '0019', '0020', '0021', '0022', '0023', '0024', '0025', '0026', '0027', '0028', '0029', '0030', '0031', '0032', '0033', '0034', '0035', '0036', '0037', '0038', '0039', '0040', '0041', '0042', '0043', '0044', '0045', '0046', '0047', '0048', '0049', '0050', '0051', '0052', '0053', '0054', '0055', '0056', '0057', '0058', '0059', '0060', '0061', '0062', '0063', '0064', '0065', '0066', '0067', '0068', '0069', '0070', '0071', '0072', '0073', '0074', '0075', '0076', '0077', '0078', '0079', '0080', '0081']


In [None]:
if target_bucket_ids:
    # The allow-list is the same for every bucket, so build it once.
    all_allowed_ids = set(ALL_ARXIV_IDS)
    for bid in target_bucket_ids:
        build_requests_for_bucket(
            bucket_id=bid,
            allowed_ids=all_allowed_ids,
            skip_ids=SKIP_ARXIV_IDS,
        )
else:
    print("No buckets selected. Set TARGET_BUCKET_ID or TARGET_BUCKET_IDS (or BUILD_ALL_BUCKETS=True).")

## 5) Execute bucket (resumable) + usage logging


In [None]:

# 5) Execute bucket (resumable) + usage logging

from concurrent.futures import ThreadPoolExecutor, as_completed

def load_existing_results(path: Path) -> dict[str, dict]:
    """Index previously written output records by ``custom_id``.

    When the same ``custom_id`` appears more than once, the last record wins.
    Returns {} if *path* does not exist. The following lines are skipped, so a
    partially written or corrupted output file never aborts a resume:
    - blank or undecodable lines;
    - JSON values that are not objects;
    - records whose ``custom_id`` is not a non-empty string.
    """
    done: dict[str, dict] = {}
    if not path.exists():
        return done
    with path.open() as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except ValueError:
                continue
            # A valid-JSON but non-object line (e.g. a bare list) has no .get and
            # previously raised AttributeError mid-resume.
            if not isinstance(rec, dict):
                continue
            cid = rec.get("custom_id")
            # An unhashable custom_id (list/dict) would raise TypeError as a key.
            if isinstance(cid, str) and cid:
                done[cid] = rec
    return done

def run_resumable(
    requests_path: Path,
    output_path: Path,
    usage_path: Path,
    max_workers: int = 30,
    max_retries: int = 2,
    retry_delay: float = 3.0,
):
    """Run the pending chat-completion requests of one bucket concurrently.

    Requests are read from ``requests_path`` (one JSON request per line). A
    ``custom_id`` whose latest record in ``output_path`` has no ``error`` and a
    truthy ``parsed`` payload counts as done and is skipped, so the function can
    be re-run after an interruption. Every new result record is appended (and
    flushed) to ``output_path``; when the response carried token usage, a usage
    line is appended to ``usage_path``.

    Args:
        requests_path: JSONL of requests with ``custom_id``, ``body`` (``messages``
            and optional ``response_format``) and optional ``metadata``.
        output_path: JSONL that result records are appended to.
        usage_path: JSONL that per-request token usage is appended to.
        max_workers: Thread-pool size for concurrent API calls.
        max_retries: Extra attempts after a failure whose message mentions a rate limit.
        retry_delay: Base sleep in seconds before a retry; doubled on each attempt.

    Raises:
        FileNotFoundError: If ``requests_path`` does not exist.
    """
    if client is None:
        print("[WARN] OpenAI client not initialized. Skipping LLM execution.")
        print("       Set ASTRO_OPENAI_ENDPOINT and ASTRO_OPENAI_API_KEY to enable this step.")
        return
    if not requests_path.exists():
        raise FileNotFoundError(f"Requests file not found: {requests_path}")

    # Read inside a context manager so the handle is closed deterministically.
    with requests_path.open() as fin:
        all_reqs = [json.loads(line) for line in fin if line.strip()]
    print(f"[Run] Loaded {len(all_reqs)} requests from {requests_path}")

    # "Done" = latest record has no API error and a truthy parsed payload.
    # API failures and unparseable responses are retried on the next call.
    existing = load_existing_results(output_path)
    done_cids = {
        cid
        for cid, rec in existing.items()
        if not rec.get("error") and rec.get("parsed")
    }

    pending = [r for r in all_reqs if r["custom_id"] not in done_cids]
    print(f"[Run] {len(pending)} requests pending ({len(done_cids)} already done)")

    if not pending:
        print("[Run] Nothing to do, all requests already processed.")
        return

    total_prompt_tokens = 0
    total_completion_tokens = 0

    def call_one(req: dict) -> dict:
        """Call the model for one request and return its result record.

        API failures are captured in the record's ``error`` field instead of
        being raised.
        """
        meta = req.get("metadata") or {}
        arxiv_id = meta.get("arxiv_id")
        record = {
            "custom_id": req.get("custom_id"),
            "arxiv_id": arxiv_id,
            "prompt_version": meta.get("prompt_version", PROMPT_VERSION),
            "schema_version": meta.get("schema_version", SCHEMA_VERSION),
            "model": MODEL,
            "raw_response": None,
            "parsed": None,
            "error": None,
            "parse_error": None,
            "schema_errors": [],
            "usage": None,
            "metadata": meta,
        }

        # Only the API call (and reading its first choice) sits inside the try, so
        # local parsing/validation failures are never mistaken for API errors.
        resp = None
        content = None
        last_err = None
        for attempt in range(max_retries + 1):
            try:
                resp = client.chat.completions.create(
                    model=MODEL_DEPLOYMENT,
                    messages=req["body"]["messages"],
                    response_format=req["body"].get("response_format"),
                )
                content = resp.choices[0].message.content
            except Exception as e:  # API boundary: any client failure becomes an error record
                resp = None
                last_err = str(e)
                if "rate limit" in last_err.lower() and attempt < max_retries:
                    # Exponential backoff: retry_delay, 2*retry_delay, 4*retry_delay, ...
                    time.sleep(retry_delay * (2 ** attempt))
                    continue
            break

        if resp is None:
            record["error"] = last_err or "unknown error"
            return record

        record["raw_response"] = content
        try:
            parsed = json.loads(content)
        except (TypeError, ValueError, RecursionError) as e:
            # TypeError covers content=None; ValueError covers malformed JSON.
            parsed = None
            record["parse_error"] = str(e)
        record["parsed"] = parsed

        if parsed:
            try:
                record["schema_errors"] = validate_record(parsed)
            except Exception as e:
                # Keep the paid-for response for debugging, but flag the record as an
                # error so resume logic retries it and post-processing skips it.
                record["error"] = f"validate_record raised: {e!r}"

        usage = getattr(resp, "usage", None)
        if usage is not None:
            record["usage"] = {
                "custom_id": req.get("custom_id"),
                "arxiv_id": arxiv_id,
                "prompt_tokens": getattr(usage, "prompt_tokens", None),
                "completion_tokens": getattr(usage, "completion_tokens", None),
                "total_tokens": getattr(usage, "total_tokens", None),
            }
        return record

    with output_path.open("a") as fout, usage_path.open("a") as uout, ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(call_one, r) for r in pending]
        try:
            for fut in tqdm(as_completed(futures), total=len(futures), desc=f"LLM run {requests_path.name}"):
                res = fut.result()
                fout.write(json.dumps(res) + "\n")
                fout.flush()  # each finished result is durable for a later resume
                usage_rec = res.get("usage")
                if usage_rec:
                    uout.write(json.dumps(usage_rec) + "\n")
                    uout.flush()
                    total_prompt_tokens += usage_rec.get("prompt_tokens", 0) or 0
                    total_completion_tokens += usage_rec.get("completion_tokens", 0) or 0
                elif res.get("error") or res.get("parse_error"):
                    # helpful debugging
                    print("Failure:", res.get("custom_id"), res.get("error"), res.get("parse_error"))
        except BaseException:
            # On Ctrl-C or an unexpected error, cancel queued requests. Otherwise the
            # executor's shutdown would run (and bill) every one of them first.
            executor.shutdown(wait=True, cancel_futures=True)
            raise

    print("[Run] Done. Wrote/updated outputs in", output_path)
    print(f"  Chunk prompt tokens    : {total_prompt_tokens}")
    print(f"  Chunk completion tokens: {total_completion_tokens}")

# Execute every selected bucket that has an entry in bucket_index.
RUN_BUCKET_IDS = target_bucket_ids
for bid in RUN_BUCKET_IDS:
    bucket_entry = bucket_index.get(bid)
    if bucket_entry is not None:
        req_path, out_path, usage_path = (
            REQ_DIR / f"bucket_{bid}_requests.jsonl",
            OUT_DIR / f"bucket_{bid}_outputs.jsonl",
            USAGE_DIR / f"bucket_{bid}_usage.jsonl",
        )
        run_resumable(req_path, out_path, usage_path, max_workers=15)
    else:
        print(f"Bucket {bid} not found, skipping")


## 6) Post-process: consolidate outputs -> validated paper–object mentions (JSONL + summary CSVs)


In [None]:
# ============================================================
# Section 6 (post-extraction consolidation -> validated mentions)
#
# INPUTS (expected to already exist):
#   - OUT_DIR / bucket_*_outputs.jsonl
#       Each line is a record with keys like:
#         custom_id, arxiv_id, parsed (or raw_response), error, parse_error, schema_errors, ...
#   - validate_record(parsed_dict) -> list[str]
#       schema validator (already defined earlier in the notebook)
#
# OUTPUTS (written by this single cell):
#   - RUN_DIR / outputs_latest.jsonl
#       One "latest" record per custom_id (paper), with parsed recovered when possible.
#   - RUN_DIR / outputs_latest_summary.csv
#       Paper-level summary: ok/error counts + schema error counts (fatal vs ignored).
#   - RUN_DIR / paper_object_edges_llm_mentions.jsonl
#       Mention-level edges after PER-OBJECT validation/auto-fixes:
#         (arxiv_id, object_name_norm, role, study_mode, evidence_span, ...)
#   - RUN_DIR / paper_object_edges_llm_mentions_audit.csv
#       Dropped-mention audit log (why each object mention was dropped).
#   - RUN_DIR / unique_object_names_llm.json
#       Unique normalized object names from *valid mentions* (used for SIMBAD resolution).
#   - RUN_DIR / paper_object_mention_counts_llm.csv
#       Per-paper count of valid mentions emitted.
# ============================================================

from __future__ import annotations

import json
import re
import random
import unicodedata
from pathlib import Path
from collections import defaultdict, Counter

import pandas as pd
from tqdm import tqdm

# -----------------------------
# 0) Discover bucket outputs
# -----------------------------
# Every bucket's append-only output file, in deterministic (name) order.
OUTPUT_GLOB = "bucket_*_outputs.jsonl"
out_paths = list(OUT_DIR.glob(OUTPUT_GLOB))
out_paths.sort()
assert len(out_paths) > 0, f"No outputs found in {OUT_DIR} matching {OUTPUT_GLOB}. Did the extraction finish?"
print("Buckets found:", len(out_paths))

# -----------------------------
# 1) Record-level helpers
# -----------------------------
# Treat evidence span length as warning (not a failure).
IGNORABLE_SCHEMA_ERROR_SUFFIXES = (
    "evidence_span too long",  # common exact phrasing
    "evidence_span",           # broad suffix fallback (your choice)
)

def filter_schema_errors(errs: list[str] | None) -> tuple[list[str], list[str]]:
    """Returns (fatal_errors, ignored_errors) at RECORD level."""
    errs = errs or []
    ignored, fatal = [], []
    for e in errs:
        if any(str(e).endswith(sfx) for sfx in IGNORABLE_SCHEMA_ERROR_SUFFIXES):
            ignored.append(e)
        else:
            fatal.append(e)
    return fatal, ignored

def is_ok_record(rec: dict) -> bool:
    """True when ``rec`` has no API error, a parsed payload, and no fatal schema errors.

    Schema errors are classified with filter_schema_errors; ignorable ones do
    not make a record "not OK".
    """
    if rec.get("error") is not None or rec.get("parsed") is None:
        return False
    fatal_errors = filter_schema_errors(rec.get("schema_errors") or [])[0]
    return not fatal_errors

def iter_jsonl(path: Path):
    """Yield each JSON value of a JSON-lines file, skipping blank and unparseable lines."""
    with path.open() as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            try:
                value = json.loads(text)
            except Exception:
                continue
            yield value

_JSON_OBJ_RE = re.compile(r"\{.*\}", re.DOTALL)

def try_salvage_parsed(raw_text: str | None):
    """Recover JSON object if the model wrapped JSON with extra text."""
    if not isinstance(raw_text, str) or not raw_text.strip():
        return None
    m = _JSON_OBJ_RE.search(raw_text)
    if not m:
        return None
    cand = m.group(0).strip()
    try:
        return json.loads(cand)
    except Exception:
        return None

def slim_record(rec: dict) -> dict:
    """Project an output record onto the fields used downstream.

    Anything not listed here (e.g. ``raw_response``, ``usage``,
    ``ignored_schema_errors``) is dropped to keep consolidated files small.
    ``schema_errors`` and ``metadata`` default to ``[]`` / ``{}`` when missing
    or falsy.
    """
    scalar_keys = ("custom_id", "arxiv_id", "prompt_version", "schema_version", "model", "error", "parse_error")
    slim = {key: rec.get(key) for key in scalar_keys}
    slim["schema_errors"] = rec.get("schema_errors") or []
    slim["metadata"] = rec.get("metadata") or {}
    slim["parsed"] = rec.get("parsed")
    return slim

# -----------------------------
# 2) Load buckets -> latest_by_cid (with salvage)
# -----------------------------
latest_by_cid: dict[str, dict] = {}
bucket_line_counts: dict[str, int] = {}
bucket_salvaged_counts: dict[str, int] = {}
n_kept_earlier_usable = 0


def _is_usable_record(r: dict) -> bool:
    """True when ``r`` carries no API error and a parsed payload."""
    return r.get("error") is None and r.get("parsed") is not None


for p in out_paths:
    salvaged_here = 0
    n_lines = 0  # JSON values yielded by iter_jsonl (blank/malformed lines are not counted)
    for rec in iter_jsonl(p):
        n_lines += 1
        cid = rec.get("custom_id")
        if not cid:
            continue

        # Salvage if parse failed but raw_response exists and error==None
        if (rec.get("parsed") is None) and (rec.get("error") is None) and rec.get("raw_response"):
            salvaged = try_salvage_parsed(rec.get("raw_response"))
            if salvaged is not None:
                errs = validate_record(salvaged)  # list[str]
                fatal, ignored = filter_schema_errors(errs)
                if not fatal:
                    rec["parsed"] = salvaged
                    rec["parse_error"] = None
                    rec["schema_errors"] = errs
                    rec["ignored_schema_errors"] = ignored
                    salvaged_here += 1

        slim = slim_record(rec)
        prev = latest_by_cid.get(cid)
        # Output files are append-only, and resumed runs retry records that lacked a
        # parsed payload. A later line can therefore be a failed retry of a record
        # that was parsed (or salvaged above). Never replace usable data with a failure.
        if prev is not None and _is_usable_record(prev) and not _is_usable_record(slim):
            n_kept_earlier_usable += 1
            continue
        latest_by_cid[cid] = slim

    bucket_line_counts[p.name] = n_lines
    bucket_salvaged_counts[p.name] = salvaged_here

print("Unique custom_ids:", len(latest_by_cid))
print("Total parse salvaged:", sum(bucket_salvaged_counts.values()))
print("Later failed records ignored in favour of an earlier usable one:", n_kept_earlier_usable)

# -----------------------------
# 3) Paper-level summary + sanity checks
# -----------------------------
# One summary row per latest record. Fixing the columns up front keeps the
# frame's schema (and every check below) valid even when no record was loaded.
SUMMARY_COLUMNS = [
    "custom_id",
    "arxiv_id",
    "ok",
    "n_objects",
    "error",
    "parse_error",
    "n_schema_errors_total",
    "n_schema_errors_fatal",
    "n_schema_errors_ignored",
    "prompt_version",
    "schema_version",
    "model",
]

summary_rows = []
for cid, rec in latest_by_cid.items():
    parsed = rec.get("parsed")
    n_obj = None
    if isinstance(parsed, dict) and isinstance(parsed.get("objects"), list):
        n_obj = len(parsed["objects"])

    ok = is_ok_record(rec)
    fatal, ignored = filter_schema_errors(rec.get("schema_errors") or [])

    summary_rows.append(
        {
            "custom_id": cid,
            "arxiv_id": rec.get("arxiv_id"),
            "ok": ok,
            "n_objects": n_obj,
            "error": rec.get("error"),
            "parse_error": rec.get("parse_error"),
            "n_schema_errors_total": len(rec.get("schema_errors") or []),
            "n_schema_errors_fatal": len(fatal),
            "n_schema_errors_ignored": len(ignored),
            "prompt_version": rec.get("prompt_version"),
            "schema_version": rec.get("schema_version"),
            "model": rec.get("model"),
        }
    )

runs_df = pd.DataFrame(summary_rows, columns=SUMMARY_COLUMNS)
# With zero rows "ok" would be object dtype; force bool so boolean indexing is well-defined.
runs_df["ok"] = runs_df["ok"].astype(bool)

total = len(runs_df)
ok_cnt = int(runs_df["ok"].sum())
ok_frac = ok_cnt / total if total else 0.0  # no ZeroDivisionError on an empty run
print(f"Total unique papers (custom_ids): {total:,}")
print(f"OK papers (parsed + record-schema-ok): {ok_cnt:,}  ({ok_frac:.1%})")

# Duplicate arxiv_id among OK papers
ok_df = runs_df[runs_df["ok"]].copy()
dups = ok_df["arxiv_id"][ok_df["arxiv_id"].duplicated()].unique().tolist()
print("Duplicate arxiv_id among OK papers:", len(dups))
assert len(dups) == 0, f"Unexpected duplicates among OK arxiv_ids (showing up to 5): {dups[:5]}"

# Error summaries
err_types = Counter((runs_df["error"].fillna("")).tolist())
parse_err_types = Counter((runs_df["parse_error"].fillna("")).tolist())
print("\nTop `error` values (incl empty = ok-ish):")
for k, v in err_types.most_common(8):
    label = k if k else "<none>"
    print(f"  {label[:80]:80s}  {v:,}")

print("\nTop `parse_error` values (incl empty):")
for k, v in parse_err_types.most_common(8):
    label = k if k else "<none>"
    print(f"  {label[:80]:80s}  {v:,}")

bad_schema = runs_df[runs_df["n_schema_errors_fatal"] > 0]
warn_schema = runs_df[runs_df["n_schema_errors_ignored"] > 0]
print("\nRecords with FATAL schema errors:", len(bad_schema))
print("Records with any IGNORED schema errors:", len(warn_schema))

# Spot-check parsed arxiv_id matches envelope (a non-dict payload counts as a mismatch)
spot = ok_df.sample(min(25, len(ok_df)), random_state=1)["custom_id"].tolist()
mismatches = []
for cid in spot:
    rec = latest_by_cid[cid]
    parsed = rec.get("parsed")
    payload_arxiv_id = parsed.get("arxiv_id") if isinstance(parsed, dict) else None
    if payload_arxiv_id != rec.get("arxiv_id"):
        mismatches.append((cid, rec.get("arxiv_id"), payload_arxiv_id))
print("Spot-check mismatched arxiv_id:", len(mismatches))
if mismatches:
    print(mismatches[:5])
    raise AssertionError("Found mismatched arxiv_id between envelope and parsed payload.")

# Persist latest records + summary
CONSOLIDATED_LATEST_PATH = RUN_DIR / "outputs_latest.jsonl"
with CONSOLIDATED_LATEST_PATH.open("w") as f:
    for cid, rec in latest_by_cid.items():
        f.write(json.dumps(rec) + "\n")
print("Wrote:", CONSOLIDATED_LATEST_PATH)

RUN_SUMMARY_PATH = RUN_DIR / "outputs_latest_summary.csv"
runs_df.to_csv(RUN_SUMMARY_PATH, index=False)
print("Wrote:", RUN_SUMMARY_PATH)

# -----------------------------
# 4) Mention-level flattening with PER-OBJECT validation
# -----------------------------
def norm_object_name(s: str) -> str:
    """Canonicalize an object name: NFKC-normalize, then trim and collapse whitespace runs to one space.

    ``None`` maps to ``""``; other non-string values are passed through ``str``
    first. Zero-width/format characters are not whitespace and are left as-is.
    """
    if s is None:
        return ""
    return " ".join(unicodedata.normalize("NFKC", str(s)).split())

# Object-scoped validator messages look like "objects[<index>].<detail>".
_OBJ_ERR_RE = re.compile(r"^objects\[(\d+)\]\.(.+)$")

def per_object_errors(schema_errors: list[str]) -> dict[int, list[str]]:
    """Group object-scoped validator messages by object index.

    Returns a ``defaultdict(list)`` mapping index -> message tails (the text
    after "objects[i]."), in input order. Messages without that prefix are
    ignored; ``None`` is treated as no messages.
    """
    grouped: dict[int, list[str]] = defaultdict(list)
    matches = (_OBJ_ERR_RE.match(str(err)) for err in schema_errors or [])
    for m in matches:
        if m is not None:
            idx, detail = m.groups()
            grouped[int(idx)].append(detail)
    return grouped

# Object-level ignore policy: same spirit as record-level
IGNORABLE_OBJ_ERROR_SUFFIXES = (
    "evidence_span too long",
    "evidence_span",
)

def split_obj_errors(errs: list[str] | None) -> tuple[list[str], list[str]]:
    errs = errs or []
    ignored, fatal = [], []
    for e in errs:
        if any(str(e).endswith(sfx) for sfx in IGNORABLE_OBJ_ERROR_SUFFIXES):
            ignored.append(e)
        else:
            fatal.append(e)
    return fatal, ignored

# Substrings matched against per-object validator messages.
NAME_MULTI_DESIG_TOKEN = "name violates single-designation rule"
MODE_INCONSIST_TOKEN = "study_mode not_applicable inconsistent with role="

def auto_fix_object(obj: dict, obj_errs: list[str]) -> tuple[dict, list[str]]:
    """Return a shallow copy of ``obj`` with known study_mode inconsistencies coerced, plus notes.

    Both rules look at the ORIGINAL role / study_mode:
      * a message containing MODE_INCONSIST_TOKEN while study_mode is
        "not_applicable" -> study_mode becomes "unknown";
      * role "comparison_or_reference" with an active study_mode
        -> study_mode becomes "not_applicable".
    ``obj`` itself is never mutated. One note string is returned per coercion.
    """
    patched = dict(obj)
    notes: list[str] = []
    role = patched.get("role")
    mode = patched.get("study_mode")

    flagged_inconsistent = any(MODE_INCONSIST_TOKEN in err for err in (obj_errs or []))
    if flagged_inconsistent and mode == "not_applicable":
        patched["study_mode"] = "unknown"
        notes.append("coerced study_mode not_applicable -> unknown (role implies substantive)")

    active_modes = {"new_observation", "archival_or_reanalysis", "catalog_compilation", "simulation_or_theory"}
    if role == "comparison_or_reference" and mode in active_modes:
        patched["study_mode"] = "not_applicable"
        notes.append("coerced study_mode -> not_applicable for comparison_or_reference")

    return patched, notes

MENTIONS_JSONL = RUN_DIR / "paper_object_edges_llm_mentions.jsonl"
MENTIONS_AUDIT_CSV = RUN_DIR / "paper_object_edges_llm_mentions_audit.csv"

# Start clean. The audit CSV is only rewritten when some mention was dropped,
# so a stale one from an earlier run must not survive.
if MENTIONS_JSONL.exists():
    MENTIONS_JSONL.unlink()
if MENTIONS_AUDIT_CSV.exists():
    MENTIONS_AUDIT_CSV.unlink()

n_mentions = 0          # raw objects seen in usable records
n_valid_mentions = 0    # mentions written to MENTIONS_JSONL
n_papers_seen = 0       # records with no API error, a dict payload, an arxiv_id and an objects list
n_papers_emitted = 0    # of those, records with >= 1 written mention

unique_names = set()
per_paper_counts = defaultdict(int)

audit_rows = []
audit_counter = Counter()

# ensure_ascii=False below writes raw Unicode, so pin the encoding instead of
# relying on the platform default (pandas reads JSON back as UTF-8 by default).
with MENTIONS_JSONL.open("w", encoding="utf-8") as f:
    for cid, rec in tqdm(latest_by_cid.items(), total=len(latest_by_cid), desc="Flatten mentions (per-object validate)"):

        parsed = rec.get("parsed")
        if rec.get("error") is not None or parsed is None:
            audit_counter["skip_record_error_or_unparsed"] += 1
            continue

        arxiv_id = rec.get("arxiv_id")
        if not arxiv_id:
            audit_counter["skip_record_missing_arxiv_id"] += 1
            continue

        if not isinstance(parsed, dict):
            # e.g. a top-level JSON array: there is no "objects" field to read.
            audit_counter["skip_record_parsed_not_dict"] += 1
            continue

        objs = parsed.get("objects") or []
        if not isinstance(objs, list):
            audit_counter["skip_record_objects_not_list"] += 1
            continue

        n_papers_seen += 1

        schema_errors = rec.get("schema_errors") or []
        obj_err_map = per_object_errors(schema_errors)

        emitted_any = False

        for i, obj in enumerate(objs):
            n_mentions += 1
            if not isinstance(obj, dict):
                audit_counter["skip_obj_not_dict"] += 1
                continue

            obj_errs = obj_err_map.get(i, [])
            obj2, fix_notes = auto_fix_object(obj, obj_errs)

            name_norm = norm_object_name(obj2.get("name", ""))
            if not name_norm:
                audit_counter["skip_obj_empty_name"] += 1
                continue

            # auto_fix_object resolves "study_mode not_applicable inconsistent with
            # role=..." by rewriting study_mode. Once it has, that validator message
            # no longer describes obj2 and must not count as fatal; otherwise every
            # auto-fixed mention would still be dropped below.
            mode_was_fixed = (
                obj.get("study_mode") == "not_applicable"
                and obj2.get("study_mode") != "not_applicable"
            )
            if mode_was_fixed:
                remaining_errs = [e for e in obj_errs if MODE_INCONSIST_TOKEN not in str(e)]
            else:
                remaining_errs = obj_errs
            fatal0, ignored0 = split_obj_errors(remaining_errs)

            # Multi-designation name: drop mention (unless you later implement splitting)
            if any(NAME_MULTI_DESIG_TOKEN in e for e in fatal0):
                audit_counter["skip_obj_multi_designation_name"] += 1
                audit_rows.append({
                    "custom_id": cid,
                    "arxiv_id": arxiv_id,
                    "obj_index": i,
                    "object_name_norm": name_norm,
                    "status": "dropped",
                    "reason": "multi_designation_name",
                    "fatal_errors": fatal0,
                    "ignored_errors": ignored0,
                    "fix_notes": fix_notes,
                })
                continue

            # Any other fatal errors: drop mention
            if fatal0:
                audit_counter["skip_obj_other_fatal_errors"] += 1
                audit_rows.append({
                    "custom_id": cid,
                    "arxiv_id": arxiv_id,
                    "obj_index": i,
                    "object_name_norm": name_norm,
                    "status": "dropped",
                    "reason": "fatal_schema_errors",
                    "fatal_errors": fatal0,
                    "ignored_errors": ignored0,
                    "fix_notes": fix_notes,
                })
                continue

            row = {
                "custom_id": cid,
                "arxiv_id": arxiv_id,
                "object_name": obj2.get("name", ""),
                "object_name_norm": name_norm,
                "role": obj2.get("role"),
                "study_mode": obj2.get("study_mode"),
                "evidence_source": obj2.get("evidence_source"),
                "evidence_span": obj2.get("evidence_span"),
                "prompt_version": rec.get("prompt_version"),
                "schema_version": rec.get("schema_version"),
                "model": rec.get("model"),
                # audit fields
                "obj_index": i,
                "validator_fatal_errors": fatal0,
                "validator_ignored_errors": ignored0,
                "fix_notes": fix_notes,
            }
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

            n_valid_mentions += 1
            emitted_any = True
            unique_names.add(name_norm)
            per_paper_counts[arxiv_id] += 1

        if emitted_any:
            n_papers_emitted += 1
        else:
            audit_counter["papers_with_zero_valid_mentions"] += 1

print("\nMention flatten summary:")
print("  Papers seen (parsed + no API error):", f"{n_papers_seen:,}")
print("  Papers emitted (>=1 valid mention):", f"{n_papers_emitted:,}")
print("  Total mentions (raw objects):", f"{n_mentions:,}")
print("  Valid mentions written:", f"{n_valid_mentions:,}")
print("  Unique object names:", f"{len(unique_names):,}")
print("  Wrote:", MENTIONS_JSONL)

print("\nTop mention audit counters:")
for k, v in audit_counter.most_common(12):
    print(f"  {k:40s} {v:,}")

# Persist unique names for SIMBAD resolution step. ensure_ascii=False emits raw
# Unicode, so write UTF-8 explicitly rather than the platform default encoding.
UNIQUE_NAMES_PATH = RUN_DIR / "unique_object_names_llm.json"
with UNIQUE_NAMES_PATH.open("w", encoding="utf-8") as f:
    json.dump(sorted(unique_names), f, indent=2, ensure_ascii=False)
print("Wrote:", UNIQUE_NAMES_PATH)

# Per-paper mention counts. Explicit columns keep the schema (and sort_values)
# working when no valid mention was emitted at all.
counts_df = pd.DataFrame(
    list(per_paper_counts.items()),
    columns=["arxiv_id", "n_object_mentions"],
).sort_values("n_object_mentions", ascending=False)

COUNTS_PATH = RUN_DIR / "paper_object_mention_counts_llm.csv"
counts_df.to_csv(COUNTS_PATH, index=False)
print("Wrote:", COUNTS_PATH)

# Dropped mention audit
if audit_rows:
    audit_df = pd.DataFrame(audit_rows)
    audit_df.to_csv(MENTIONS_AUDIT_CSV, index=False)
    print("Wrote audit:", MENTIONS_AUDIT_CSV, "rows:", len(audit_df))
else:
    print("No dropped mentions logged (audit empty).")

counts_df.head(10)


## 7) Join concepts -> aggregate concept-object edges


In [None]:
# ============================================================
# Section 7: Produce concept–object edges (unresolved name space)
# (release-safe: paths + weights loaded from config/table1.yaml)
#
#   concept_object_edges_unresolved_llm.csv.gz
# ============================================================

import os
import numpy as np
import pandas as pd
from pathlib import Path

# ---- Load YAML config (paths + weights) ----
CONFIG_PATH = Path(os.environ.get("CONFIG_PATH", "config/table1.yaml")).resolve()
print(f"[config] Using CONFIG_PATH = {CONFIG_PATH}")
if not CONFIG_PATH.exists():
    raise FileNotFoundError(
        f"Config not found: {CONFIG_PATH}\n"
        "Double check the path, or set env var CONFIG_PATH=/path/to/table1.yaml"
    )

# PyYAML is only needed from here on; fail with an actionable message if absent.
try:
    import yaml  # pip install pyyaml if missing
except Exception as e:
    raise ImportError(
        "Missing dependency: pyyaml (needed to read config/table1.yaml). "
        "Install with: pip install pyyaml"
    ) from e

with CONFIG_PATH.open() as f:
    cfg = yaml.safe_load(f) or {}  # an empty YAML document loads as None

# Missing or null sections behave like empty mappings.
paths_cfg = cfg.get("paths") or {}
weights_cfg = cfg.get("weights") or {}

# out_dir falls back to this run's directory when the setup cell defined RUN_DIR.
_fallback_out_dir = str(RUN_DIR) if "RUN_DIR" in globals() else "outputs"
DATA_DIR_CFG = Path(paths_cfg.get("data_dir", "data"))
OUT_DIR_CFG = Path(paths_cfg.get("out_dir", _fallback_out_dir))

def _resolve_path(base_dir: Path, p: str | None) -> Path | None:
    if p is None:
        return None
    p = str(p).strip()
    if not p:
        return None
    pp = Path(p)
    return pp if pp.is_absolute() else (base_dir / pp).resolve()

# Key input paths from YAML
PAPER_CONCEPT_EDGES_PATH = _resolve_path(
    DATA_DIR_CFG, paths_cfg.get("concept_map", "papers_concepts_mapping.csv")
)
PO_MENTIONS_PATH = _resolve_path(
    OUT_DIR_CFG, paths_cfg.get("po_mentions", "paper_object_edges_llm_mentions.jsonl")
)

print(f"[config] PAPER_CONCEPT_EDGES_PATH = {PAPER_CONCEPT_EDGES_PATH} | exists = {bool(PAPER_CONCEPT_EDGES_PATH) and PAPER_CONCEPT_EDGES_PATH.exists()}")
print(f"[config] PO_MENTIONS_PATH         = {PO_MENTIONS_PATH} | exists = {bool(PO_MENTIONS_PATH) and PO_MENTIONS_PATH.exists()}")

# Both inputs are required; _resolve_path returns None for unset config values.
if not (PAPER_CONCEPT_EDGES_PATH and PAPER_CONCEPT_EDGES_PATH.exists()):
    raise FileNotFoundError(
        f"Missing paper–concept mapping CSV at: {PAPER_CONCEPT_EDGES_PATH}\n"
        "Double check cfg.paths.concept_map and cfg.paths.data_dir in table1.yaml."
    )

if not (PO_MENTIONS_PATH and PO_MENTIONS_PATH.exists()):
    raise FileNotFoundError(
        f"Missing mentions JSONL at: {PO_MENTIONS_PATH}\n"
        "Double check cfg.paths.po_mentions and cfg.paths.out_dir in table1.yaml."
    )

# Weights from YAML (missing/null entries become empty containers)
ROLE_WEIGHT = weights_cfg.get("role_weight") or {}
STUDY_MODE_MULT = weights_cfg.get("study_mode_mult") or {}
CONTEXT_ROLES = set(weights_cfg.get("context_roles") or [])

if not (ROLE_WEIGHT and STUDY_MODE_MULT):
    raise ValueError(
        "weights.role_weight and/or weights.study_mode_mult missing or empty in table1.yaml. "
        "Double check your config."
    )

print(f"[config] Loaded ROLE_WEIGHT keys: {len(ROLE_WEIGHT)}")
print(f"[config] Loaded STUDY_MODE_MULT keys: {len(STUDY_MODE_MULT)}")
print(f"[config] Loaded CONTEXT_ROLES: {sorted(CONTEXT_ROLES)}")

# -----------------------------
# 7.1) Load paper–concept edges
# -----------------------------
# label is read as pandas' nullable "Int32" so rows with a missing label survive
# parsing and are removed by dropna below. A plain "int32" dtype makes read_csv
# raise on the first empty label instead. It is narrowed back to int32 afterwards.
paper_concept_df = pd.read_csv(PAPER_CONCEPT_EDGES_PATH, dtype={"arxiv_id": "string", "label": "Int32"})
expected_cols = {"arxiv_id", "label"}
if set(paper_concept_df.columns) != expected_cols:
    raise ValueError(f"Unexpected columns: {paper_concept_df.columns.tolist()} (expected {sorted(expected_cols)})")

paper_concept_df = paper_concept_df.dropna(subset=["arxiv_id", "label"]).copy()
paper_concept_df["label"] = paper_concept_df["label"].astype("int32")
paper_concept_df["arxiv_id"] = paper_concept_df["arxiv_id"].str.strip()
paper_concept_df = paper_concept_df[paper_concept_df["arxiv_id"].str.len() > 0].copy()

# Every paper–concept link carries unit weight.
paper_concept_df["weight"] = np.float32(1.0)

before = len(paper_concept_df)
paper_concept_df = paper_concept_df.drop_duplicates(subset=["arxiv_id", "label"]).reset_index(drop=True)
print("paper_concept_df:", paper_concept_df.shape, "dropped dups:", before - len(paper_concept_df))

# -----------------------------
# 7.1b) Load mentions + coverage checks
# -----------------------------
# arxiv_id is pinned to the pandas string dtype (presumably so numeric-looking
# ids such as "0704.0001" are not coerced by read_json's dtype inference).
po_mentions = pd.read_json(PO_MENTIONS_PATH, lines=True, dtype={"arxiv_id": "string"})
print("po_mentions:", po_mentions.shape)

pc_papers = set(paper_concept_df["arxiv_id"].unique().tolist())
po_papers = set(po_mentions["arxiv_id"].unique().tolist())

# Coverage report between the concept mapping and the extracted mentions.
for label, value in (
    ("Unique papers in paper_concept_df:", len(pc_papers)),
    ("Unique papers in po_mentions:", len(po_papers)),
    ("Overlap papers:", len(pc_papers & po_papers)),
    ("Papers with objects but no concepts:", len(po_papers - pc_papers)),
    ("Papers with concepts but no objects:", len(pc_papers - po_papers)),
):
    print(label, value)

# -----------------------------
# 7.2) Aggregate paper–object edges (legacy tier3 semantics)
# -----------------------------
# mention_weight = role weight x study-mode multiplier; roles / modes missing from
# the config fall back to 0.75 / 0.50 respectively.
po_mentions = po_mentions.assign(
    role_weight=po_mentions["role"].map(ROLE_WEIGHT).fillna(np.float32(0.75)).astype("float32"),
    mode_mult=po_mentions["study_mode"].map(STUDY_MODE_MULT).fillna(np.float32(0.50)).astype("float32"),
    mention_weight=lambda df: (df["role_weight"] * df["mode_mult"]).astype("float32"),
)

# Mentions whose weight is zero (e.g. a zero multiplier for their study_mode) are dropped.
keep_mask = po_mentions["mention_weight"] > 0
po_mentions_f = po_mentions.loc[keep_mask].copy()
print("Mentions kept (mention_weight>0):", len(po_mentions_f), "/", len(po_mentions))

def _uniq_sorted(xs):
    xs = [x for x in xs if isinstance(x, str) and x]
    return sorted(set(xs))

def _pick_examples(xs, k=3):
    xs = [x for x in xs if isinstance(x, str) and x.strip()]
    xs = sorted(xs, key=len)[:k]
    return xs

# Collapse mentions to one row per (paper, object): mention count, summed weight,
# distinct roles / study modes / evidence sources, and the shortest evidence spans
# (up to _pick_examples' default of 3).
po_agg = po_mentions_f.groupby(["arxiv_id", "object_name_norm"], as_index=False).agg(
    mention_count=("object_name_norm", "size"),
    obj_weight=("mention_weight", "sum"),
    roles=("role", _uniq_sorted),
    study_modes=("study_mode", _uniq_sorted),
    evidence_sources=("evidence_source", _uniq_sorted),
    example_evidence=("evidence_span", _pick_examples),
)
po_agg["obj_weight"] = po_agg["obj_weight"].astype("float32")

print("po_agg:", po_agg.shape)

PO_AGG_PATH = _resolve_path(OUT_DIR_CFG, paths_cfg.get("po_agg", "paper_object_edges_llm_agg.parquet"))
if PO_AGG_PATH is None:
    PO_AGG_PATH = RUN_DIR / "paper_object_edges_llm_agg.parquet"
po_agg.to_parquet(PO_AGG_PATH, index=False)
print("Wrote:", PO_AGG_PATH)

# Legacy tier3 filter: drop (paper, object) rows whose roles are all context roles.
is_substantive = po_agg["roles"].apply(lambda rs: not set(rs).issubset(CONTEXT_ROLES))
po_used = po_agg[is_substantive].copy()
print("paper–object edges kept after context-only exclusion:", len(po_used), "/", len(po_agg))

# -----------------------------
# 7.3) Produce concept–object edges (UNRESOLVED)
# -----------------------------
def build_concept_object_edges(paper_concept: pd.DataFrame, paper_object: pd.DataFrame) -> pd.DataFrame:
    """Aggregate paper-level links into concept–object edges.

    Args:
        paper_concept: columns [arxiv_id, label, weight].
        paper_object: columns [arxiv_id, object_name_norm, obj_weight]; any
            other columns are ignored.

    Returns:
        One row per (label, object_name_norm) linked through at least one shared
        paper, with columns [label, object_name_norm, n_papers, total_weight]:
        n_papers counts distinct arxiv_ids (int32), total_weight sums
        weight * obj_weight (float32). Rows are sorted by label ascending, then
        total_weight descending, with a fresh RangeIndex.
    """
    links = paper_object[["arxiv_id", "object_name_norm", "obj_weight"]]
    joined = paper_concept.merge(links, on="arxiv_id", how="inner")
    concept_w = joined["weight"].astype("float32")
    object_w = joined["obj_weight"].astype("float32")
    joined["edge_weight"] = (concept_w * object_w).astype("float32")

    edges = joined.groupby(["label", "object_name_norm"], as_index=False).agg(
        n_papers=("arxiv_id", "nunique"),
        total_weight=("edge_weight", "sum"),
    )
    edges = edges.astype({"label": "int32", "n_papers": "int32", "total_weight": "float32"})
    ordered = edges.sort_values(["label", "total_weight"], ascending=[True, False])
    return ordered.reset_index(drop=True)

co_unres = build_concept_object_edges(paper_concept_df, po_used)
print("concept_object_edges_unresolved:", co_unres.shape)

# -----------------------------
# 7.3b) Sanity checks: strictly positive weights and one row per (label, object)
# -----------------------------
assert co_unres["total_weight"].min() > 0, "non-positive total_weight"
dups = co_unres.duplicated(subset=["label", "object_name_norm"]).sum()
assert dups == 0, "duplicate (label, object_name_norm) rows"
print(
    f"ok | edges={len(co_unres):,} "
    f"| unique labels={co_unres['label'].nunique():,} "
    f"| unique objects={co_unres['object_name_norm'].nunique():,}"
)

# -----------------------------
# 7.3c) Unresolved deliverable (gzip CSV; renamed from the legacy tier3 file)
# -----------------------------
OUT_UNRES = RUN_DIR / "concept_object_edges_unresolved_llm.csv.gz"
co_unres.to_csv(OUT_UNRES, index=False, compression="gzip")
print("Wrote:", OUT_UNRES)

co_unres.head(10)

## 8) SIMBAD resolution for astro_object + final deliverables


In [None]:
# ============================================================
# Section 8 (fixed): Resolve concept-object "unresolved" names to SIMBAD main_id
# using existing caches + negative-cache failures, with ONE-TIME cleaning.
#
# INPUTS:
#   - RUN_DIR / concept_object_edges_unresolved_llm.csv.gz   (written by Section 7)
#       columns: [label, object_name_norm, n_papers, total_weight]
#   - cache dir files:
#       - simbad_alias_cache.jsonl
#       - simbad_name_resolution_cache_llm_objects_with_otype_and_errors_clean.jsonl
#       - simbad_resolution_failures.jsonl
#
# OUTPUTS (appended):
#   - NAME_CACHE_PATH (append new SIMBAD results)
#   - FAILURES_PATH   (append new no_match failures only)
#
# KEY FIXES:
#   - Create object_name_clean ONCE and aggregate on it (prevents hidden-char duplicates)
#   - Use clean keys consistently for cache lookups and failure short-circuit
#   - Do NOT skip SIMBAD query just because key exists in name2main; only skip if mapped (non-empty main_id)
# ============================================================

from __future__ import annotations

import sys
import json
import time
import re
import unicodedata
import warnings
from pathlib import Path

import pandas as pd
from tqdm import tqdm
from astroquery.simbad import Simbad
from astroquery.exceptions import NoResultsWarning



# Silence spammy "No results" warnings; we handle no-match explicitly
warnings.simplefilter("ignore", NoResultsWarning)

# -----------------------------
# Paths (cache dir)
# -----------------------------
# -----------------------------
# Config-driven paths (release-safe)
# -----------------------------
import os
from pathlib import Path

CONFIG_PATH = Path(os.environ.get("CONFIG_PATH", "config/table1.yaml")).resolve()
print(f"[config] Using CONFIG_PATH = {CONFIG_PATH}")
if not CONFIG_PATH.exists():
    raise FileNotFoundError(
        f"Config not found: {CONFIG_PATH}\n"
        "Double check the path, or set env var CONFIG_PATH=/path/to/table1.yaml"
    )

try:
    import yaml
except Exception as e:
    raise ImportError("Missing dependency: pyyaml. Install with: pip install pyyaml") from e

with CONFIG_PATH.open() as f:
    cfg = yaml.safe_load(f) or {}

paths_cfg = cfg.get("paths", {}) or {}

DATA_DIR_CFG = Path(paths_cfg.get("data_dir", "data"))
OUT_DIR_CFG  = Path(paths_cfg.get("out_dir", str(RUN_DIR if "RUN_DIR" in globals() else "outputs")))

def _resolve_path(base_dir: Path, p):
    if p is None:
        return None
    p = str(p).strip()
    if not p:
        return None
    pp = Path(p)
    return pp if pp.is_absolute() else (base_dir / pp)

# SIMBAD caches (from YAML)
NAME_CACHE_PATH  = _resolve_path(DATA_DIR_CFG, paths_cfg.get("simbad_name_cache"))
ALIAS_CACHE_PATH = _resolve_path(DATA_DIR_CFG, paths_cfg.get("simbad_alias_cache"))  # may be null
if NAME_CACHE_PATH is None:
    raise ValueError("paths.simbad_name_cache is missing/null in table1.yaml")

CACHE_DIR = NAME_CACHE_PATH.parent
FAILURES_PATH = CACHE_DIR / "simbad_resolution_failures.jsonl"

print(f"[paths] DATA_DIR_CFG       = {DATA_DIR_CFG.resolve()}")
print(f"[paths] OUT_DIR_CFG        = {OUT_DIR_CFG.resolve()}")
print(f"[paths] NAME_CACHE_PATH    = {NAME_CACHE_PATH.resolve()} | exists = {NAME_CACHE_PATH.exists()}")
print(f"[paths] ALIAS_CACHE_PATH   = {ALIAS_CACHE_PATH.resolve() if ALIAS_CACHE_PATH else None} | exists = {ALIAS_CACHE_PATH.exists() if ALIAS_CACHE_PATH else False}")
print(f"[paths] FAILURES_PATH      = {FAILURES_PATH.resolve()} | exists = {FAILURES_PATH.exists()}")

# Ensure cache dir exists
CACHE_DIR.mkdir(parents=True, exist_ok=True)
if not FAILURES_PATH.exists():
    FAILURES_PATH.write_text("")

# -----------------------------
# Load unresolved concept-object edges (no tier naming)
# -----------------------------
# Section 7 now writes: RUN_DIR / concept_object_edges_unresolved_llm.csv.gz
RUN_DIR = Path(os.environ.get("ASTRO_OUT_DIR", str(OUT_DIR_CFG))) / "full_extraction_"
co_path = RUN_DIR / "concept_object_edges_unresolved_llm.csv.gz"

print(f"[paths] co_path            = {co_path.resolve()} | exists = {co_path.exists()}")

if not co_path.exists():
    raise FileNotFoundError(
        f"Missing unresolved concept-object edges: {co_path}\n"
        "Double check where Section 7 wrote the file, and update the path here accordingly."
    )

# -----------------------------
# Helpers
# -----------------------------
def read_jsonl(path: Path):
    """Yield one parsed JSON value per non-blank line of *path*.

    A missing file yields nothing. Lines that are not valid JSON (for example a
    truncated last line left by an interrupted append) are skipped.

    The file is decoded as UTF-8 explicitly. The caches are appended with
    ``json.dumps(..., ensure_ascii=False)``, so object names may contain
    non-ASCII characters. Relying on the platform default encoding (e.g. cp1252
    on Windows) would mis-decode them or fail.
    """
    if not path.exists():
        return
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                # Corrupt/partial line: skip it, keep reading the rest of the cache.
                continue

def alias_norm(s: str) -> str:
    """Lowercase *s* after collapsing every whitespace run to one space and trimming.

    Non-string inputs are converted with ``str()`` first.
    """
    text = s if isinstance(s, str) else str(s)
    collapsed = " ".join(text.split())
    return collapsed.lower()

# Conservative cleaning: the ONLY characters removed are C0 controls (U+0000-U+001F),
# DEL and C1 controls (U+007F-U+009F). Everything else is left for NFKC + whitespace folding.
_CTRL_RE = re.compile(r"[\u0000-\u001F\u007F-\u009F]")

def safe_clean_name(s: str) -> str:
    """Return a minimally cleaned object name.

    Steps, in order:
      1. NFKC-normalize (e.g. full-width letters -> ASCII, NBSP -> space);
      2. delete control characters matched by ``_CTRL_RE``. Note that tab and
         newline are control characters, so they are deleted, not turned into spaces;
      3. collapse whitespace runs to a single space and trim the ends.

    ``None`` maps to ``""``; other non-strings are converted with ``str()``.
    Case is preserved.
    """
    if s is None:
        return ""
    normalized = unicodedata.normalize("NFKC", str(s))
    without_ctrl = _CTRL_RE.sub("", normalized)
    return " ".join(without_ctrl.split())

def key_norm(s: str) -> str:
    """Canonical lookup key used for ALL cache dict/set lookups in this section.

    Applies ``safe_clean_name`` (NFKC, strip control chars, fold whitespace) and
    then ``alias_norm`` (fold whitespace, lowercase).
    """
    cleaned = safe_clean_name(s)
    return alias_norm(cleaned)

def load_no_match_failures(path: Path) -> tuple[set[str], set[str]]:
    """Read the negative cache and return ``(query_names, alias_keys)``.

    Only records whose ``status`` is ``"no_match"`` are used (errors are not
    negative-cached). Callers skip a name iff it exactly matches either:
      - a stored ``query_name`` after ``safe_clean_name``, or
      - a stored ``alias_norm`` after ``key_norm``.
    Both sets are empty when *path* does not exist.
    """
    query_names: set[str] = set()
    alias_keys: set[str] = set()
    if path.exists():
        for rec in read_jsonl(path):
            if rec.get("status") != "no_match":
                continue
            qn, an = rec.get("query_name"), rec.get("alias_norm")
            if isinstance(qn, str) and qn:
                query_names.add(safe_clean_name(qn))  # compared against obj_clean
            if isinstance(an, str) and an:
                alias_keys.add(key_norm(an))          # compared against the lookup key
    return query_names, alias_keys

def to_text(x):
    """Coerce *x* to ``str``; ``bytes`` are decoded as UTF-8 and ``None`` is passed through.

    Invalid UTF-8 byte sequences become U+FFFD instead of raising.
    """
    if isinstance(x, bytes):
        try:
            return x.decode("utf-8", errors="replace")
        except Exception:
            return str(x)
    return None if x is None else str(x)

# -----------------------------
# Load caches
# -----------------------------
# alias -> main_id from the optional alias cache. Every listed alias of an object
# (or its main_id when no alias list is stored) maps to main_id, keyed by key_norm().
alias2main: dict[str, str] = {}
if ALIAS_CACHE_PATH and ALIAS_CACHE_PATH.exists():
    for row in read_jsonl(ALIAS_CACHE_PATH):
        main_id = row.get("main_id")
        if not isinstance(main_id, str) or not main_id.strip():
            continue
        aliases = row.get("aliases") or [main_id]
        if not isinstance(aliases, list):
            aliases = [main_id]
        for a in aliases:
            if not isinstance(a, str) or not a.strip():
                continue
            k = key_norm(a)
            if k:
                alias2main[k] = main_id
    print("Loaded", len(alias2main), "alias→main_id entries from simbad_alias_cache")
else:
    print("[INFO] alias cache not provided/found; proceeding without it.")

# -----------------------------
# Load name→main_id cache and main_id→otype in ONE pass over the file
# -----------------------------
name2main: dict[str, str | None] = {}
otype_by_main: dict[str, str | None] = {}
if NAME_CACHE_PATH.exists():
    for row in read_jsonl(NAME_CACHE_PATH):
        main_id = row.get("main_id")
        has_main = isinstance(main_id, str) and bool(main_id.strip())

        # main_id -> otype. A row without an otype never erases one already known.
        if has_main:
            otype = row.get("otype")
            if otype is not None or main_id not in otype_by_main:
                otype_by_main[main_id] = otype

        # Prefer alias_norm field, else fallback to query_name
        k = row.get("alias_norm")
        if not isinstance(k, str) or not k.strip():
            qn = row.get("query_name")
            if isinstance(qn, str) and qn.strip():
                k = qn
            else:
                continue
        k = key_norm(k)

        # The cache is append-only, so later rows normally win. A negative/errored row
        # (main_id None/empty) must not downgrade a key that an earlier row resolved.
        prev = name2main.get(k)
        if has_main or not (isinstance(prev, str) and prev.strip()):
            name2main[k] = main_id
    print("Loaded", len(name2main), "name→main_id entries from", NAME_CACHE_PATH.name)
else:
    print("[WARN] name-resolution cache not found:", NAME_CACHE_PATH)

no_match_query_names, no_match_alias_norms = load_no_match_failures(FAILURES_PATH)
print("Loaded no_match failures:",
      "query_name:", len(no_match_query_names),
      "| alias_norm:", len(no_match_alias_norms))

# -----------------------------
# Load unresolved concept-object edges written by Section 7
# -----------------------------
assert co_path.exists(), f"Missing: {co_path}"

_co_dtypes = {"label": "int32", "object_name_norm": "string", "n_papers": "int32", "total_weight": "float32"}
concept_object_edges = pd.read_csv(co_path, dtype=_co_dtypes, low_memory=False)

required_cols = {"label", "object_name_norm", "n_papers", "total_weight"}
_has_required_cols = required_cols.issubset(concept_object_edges.columns)
assert _has_required_cols, f"Unexpected columns: {concept_object_edges.columns.tolist()}"

# -----------------------------
# Clean names exactly ONCE; every later aggregation/lookup keys on object_name_clean
# -----------------------------
_raw_names = concept_object_edges["object_name_norm"].astype("string").fillna("")
concept_object_edges["object_name_clean"] = _raw_names.map(safe_clean_name)

# Names that clean down to "" (missing, or only controls/whitespace) are dropped
_nonempty_clean = concept_object_edges["object_name_clean"].str.len() > 0
concept_object_edges = concept_object_edges[_nonempty_clean].copy()

print("Unique object_name_norm (raw):", concept_object_edges["object_name_norm"].nunique())
print("Unique object_name_clean     :", concept_object_edges["object_name_clean"].nunique())

# Per clean name: n_papers summed over all labels, most frequent first
_freq = concept_object_edges.groupby("object_name_clean", as_index=False)["n_papers"].sum()
_freq = _freq.rename(columns={"n_papers": "total_n_papers"})
object_freq = _freq.sort_values("total_n_papers", ascending=False).reset_index(drop=True)
print("object_freq shape:", object_freq.shape)
print(f"Total unique CLEAN object names (tier3 unresolved): {len(object_freq):,}")

# Lookup key per clean name, computed once for the coverage and query loops
object_freq["object_key"] = object_freq["object_name_clean"].map(key_norm)

# -----------------------------
# Coverage from existing caches; known no_match failures are short-circuited (CLEAN names)
# -----------------------------
mapped: dict[str, str] = {}     # clean_name -> main_id
unmapped: list[str] = []        # clean_name
skipped_due_to_failures = 0

_clean_names = object_freq["object_name_clean"].tolist()
_clean_keys = object_freq["object_key"].tolist()
for obj_clean, k in zip(_clean_names, _clean_keys):
    # exact-match skip for known no_match failures
    if obj_clean in no_match_query_names or k in no_match_alias_norms:
        skipped_due_to_failures += 1
        continue
    # alias cache first, then name cache; only a non-empty string counts as mapped
    main_id = alias2main.get(k) or name2main.get(k)
    if not (isinstance(main_id, str) and main_id.strip()):
        unmapped.append(obj_clean)
        continue
    mapped[obj_clean] = main_id

_n_clean_total = max(1, len(object_freq))
print("Skipped due to known no_match failures:", skipped_due_to_failures)
print(f"Mapped via existing SIMBAD data: {len(mapped):,}")
print(f"Unmapped                       : {len(unmapped):,}")
print(f"Coverage (mapped / total clean): {len(mapped) / _n_clean_total:.3%}")

# Unmapped names, highest total_n_papers first (this is the SIMBAD query order)
unmapped_set = set(unmapped)
_is_unmapped = object_freq["object_name_clean"].isin(unmapped_set)
unmapped_df = (
    object_freq.loc[_is_unmapped]
    .sort_values("total_n_papers", ascending=False)
    .reset_index(drop=True)
)
print("Unmapped objects with frequencies:", unmapped_df.shape)
print(unmapped_df[["object_name_clean", "total_n_papers"]].head(20))

# -----------------------------
# SIMBAD client used by query_simbad_with_meta(); results feed the NAME cache + failures
# -----------------------------
# 30 s timeout per query; also request the object type ("otype") column.
_simbad = Simbad()
_simbad.TIMEOUT = 30
_simbad.add_votable_fields("otype")
def query_simbad_with_meta(query_name: str) -> dict:
    """Query SIMBAD for one name and return a flat, JSON-serializable result.

    Returns a dict with keys:
      status:  'ok' | 'no_match' | 'error'
      main_id: str | None  (SIMBAD main identifier, whitespace-collapsed)
      otype:   str | None
      error:   str | None  (exception text truncated to 300 chars, or "empty_MAIN_ID")

    Column names are matched case-insensitively. Older astroquery returns
    upper-case columns (MAIN_ID, OTYPE / OTYPE_S), while the TAP-based
    astroquery >= 0.4.8 returns lower-case ones (main_id, otype). Matching only
    the upper-case names silently drops otype on newer versions.
    """
    try:
        t = _simbad.query_object(query_name)
        # No match: older astroquery returns None, newer returns an empty table.
        if t is None or len(t) == 0:
            return {"status": "no_match", "main_id": None, "otype": None, "error": None}
        row = t[0]
        cols_by_upper = {c.upper(): c for c in row.colnames}

        # Fall back to the first column if no MAIN_ID-like column is present.
        main_col = cols_by_upper.get("MAIN_ID", row.colnames[0])
        raw_main = row[main_col]
        main_id = " ".join(to_text(raw_main).split()) if raw_main is not None else None

        otype_col = cols_by_upper.get("OTYPE") or cols_by_upper.get("OTYPE_S")
        raw_otype = row[otype_col] if otype_col is not None else None
        otype = to_text(raw_otype) if raw_otype is not None else None

        if not main_id:
            return {"status": "error", "main_id": None, "otype": None, "error": "empty_MAIN_ID"}
        return {"status": "ok", "main_id": main_id, "otype": otype, "error": None}
    except Exception as e:
        # Best effort: network/timeout/parse failures are reported as status 'error'
        # instead of aborting the long-running loop.
        return {"status": "error", "main_id": None, "otype": None, "error": str(e)[:300]}
MAX_TO_RESOLVE = 200_000   # upper bound on names sent to SIMBAD in one run
SLEEP_SECONDS  = 0.05      # light throttle between queries (0 disables)
to_resolve = unmapped_df["object_name_clean"].head(MAX_TO_RESOLVE).tolist()
print("Attempting SIMBAD queries for:", len(to_resolve), "CLEAN names (top by total_n_papers)")
ok_cnt = no_match_cnt = err_cnt = 0
new_rows = 0
new_failure_rows = 0
# Optional quick visibility (first few queries)
print("First 20 to_resolve:", to_resolve[:20])
# Both caches are JSONL appended with ensure_ascii=False, so open them as UTF-8
# explicitly rather than relying on the platform default encoding.
with NAME_CACHE_PATH.open("a", encoding="utf-8") as fout, FAILURES_PATH.open("a", encoding="utf-8") as ff:
    for obj_clean in tqdm(to_resolve, desc="Resolving via SIMBAD"):
        if not obj_clean:
            continue
        k = key_norm(obj_clean)
        # Skip known no_match failures (exact match). This also catches other clean
        # names that share this key and already failed earlier in this loop.
        if (obj_clean in no_match_query_names) or (k in no_match_alias_norms):
            continue
        # Skip only if ALREADY RESOLVED (non-empty main_id). Do NOT skip negative/None placeholders.
        existing_main = alias2main.get(k) or name2main.get(k)
        if isinstance(existing_main, str) and existing_main.strip():
            continue
        res = query_simbad_with_meta(obj_clean)
        if res["status"] == "ok":
            ok_cnt += 1
        elif res["status"] == "no_match":
            no_match_cnt += 1
        else:
            err_cnt += 1
        record = {
            "query_name": obj_clean,
            "alias_norm": k,
            "status": res["status"],
            "main_id": res["main_id"],
            "otype": res["otype"],
            "error": res["error"],
        }
        # Update in-memory cache (may be None; that's OK)
        name2main[k] = res["main_id"]
        # Keep main_id -> otype in sync for Section 8b (noReg filter, object catalog).
        # Without this, objects first resolved in this run would have no otype there.
        # A missing otype never overwrites one already known.
        if res["status"] == "ok" and (res["otype"] is not None or res["main_id"] not in otype_by_main):
            otype_by_main[res["main_id"]] = res["otype"]
        fout.write(json.dumps(record, ensure_ascii=False) + "\n")
        new_rows += 1
        # Only write failures when it's a true no_match (errors stay retryable)
        if res["status"] == "no_match":
            fail_rec = {
                "query_name": obj_clean,
                "alias_norm": k,
                "status": "no_match",
                "main_id": None,
                "otype": None,
                "error": None,
            }
            ff.write(json.dumps(fail_rec, ensure_ascii=False) + "\n")
            no_match_query_names.add(obj_clean)
            no_match_alias_norms.add(k)
            new_failure_rows += 1
        # Light throttle
        if SLEEP_SECONDS:
            time.sleep(SLEEP_SECONDS)
print("SIMBAD query summary:")
print("  ok      :", ok_cnt)
print("  no_match:", no_match_cnt)
print("  error   :", err_cnt)
print("  rows appended to NAME_CACHE_PATH:", new_rows)
print("  new failures appended:", new_failure_rows)
# Optional: recompute coverage over CLEAN names
mapped2 = 0
for obj_clean, k in zip(object_freq["object_name_clean"].tolist(), object_freq["object_key"].tolist()):
    main_id = alias2main.get(k) or name2main.get(k)
    if isinstance(main_id, str) and main_id.strip():
        mapped2 += 1
print(f"Post-query coverage (clean): {mapped2}/{len(object_freq)} = {mapped2 / max(1,len(object_freq)):.3%}")


Loaded 176014 alias→main_id entries from simbad_alias_cache
Loaded 518655 name→main_id entries from simbad_name_resolution_cache_llm_objects_with_otype_and_errors_clean.jsonl
Loaded no_match failures: query_name: 355519 | alias_norm: 355527
Unique object_name_norm (raw): 339495
Unique object_name_clean     : 339494
object_freq shape: (339494, 2)
Total unique CLEAN object names (tier3 unresolved): 339,494
Skipped due to known no_match failures: 155958
Mapped via existing SIMBAD data: 140,991
Unmapped                       : 42,545
Coverage (mapped / total clean): 41.530%
Unmapped objects with frequencies: (42545, 3)
   object_name_clean  total_n_papers
0     PN G035.9-01.1              17
1          UDS-10246              10
2          UDS 90845              10
3          UDS 46645              10
4            Mrk 945              10
5            Mrk0004              10
6          UDS 408.0              10
7          UDS 37091              10
8            Mrk0007              10
9      

Resolving via SIMBAD: 100%|██████████| 42545/42545 [2:29:29<00:00,  4.74it/s]  


SIMBAD query summary:
  ok      : 17855
  no_match: 24657
  error   : 1
  rows appended to NAME_CACHE_PATH: 42513
  new failures appended: 24657
Post-query coverage (clean): 158874/339494 = 46.797%


In [None]:
# ============================================================
# Section 8b: Materialize resolved concept–object edges (SIMBAD main_id)
#
# INPUTS:
#   - RUN_DIR / concept_object_edges_unresolved_llm.csv.gz
#       columns: [label, object_name_norm, n_papers, total_weight]
#   - In-memory state left by Section 8 (must run in the same kernel):
#       alias2main    (key -> main_id)          [optional]
#       name2main     (key -> main_id or None)
#       otype_by_main (main_id -> otype)        [optional; enables noReg + catalog]
#
# OUTPUTS:
#   - RUN_DIR / concept_object_edges_resolved_simbad_only_llm.csv.gz
#   - RUN_DIR / concept_object_edges_resolved_simbad_only_llm_noReg.csv.gz  (optional)
#   - RUN_DIR / simbad_object_catalog_llm.csv.gz                             (optional)
# ============================================================

import pandas as pd
import numpy as np

# ---- Preconditions: Section 8 must already have populated these names ----
_have_name2main = "name2main" in globals() and isinstance(name2main, dict)
if not _have_name2main:
    raise RuntimeError("name2main not found. Run Section 8 (SIMBAD resolution) before running Section 8b.")

_have_helpers = "key_norm" in globals() and "safe_clean_name" in globals()
if not _have_helpers:
    raise RuntimeError("Missing helpers (key_norm/safe_clean_name). Run Section 8 before running Section 8b.")

# alias2main is optional (may be empty if alias cache not provided)
_have_alias2main = "alias2main" in globals() and isinstance(alias2main, dict)
if not _have_alias2main:
    alias2main = {}
    print("[INFO] alias2main not available; proceeding with name2main only.")

# 1) Load unresolved concept–object edges
co_unres_path = RUN_DIR / "concept_object_edges_unresolved_llm.csv.gz"
if not co_unres_path.exists():
    raise FileNotFoundError(f"Missing unresolved edges file: {co_unres_path}")

_co_dtypes_8b = {"label": "int32", "object_name_norm": "string", "n_papers": "int32", "total_weight": "float32"}
co = pd.read_csv(co_unres_path, dtype=_co_dtypes_8b, low_memory=False)

# 2) Same cleaning + key as Section 8, so lookups hit the same cache keys
co["object_name_clean"] = co["object_name_norm"].astype("string").fillna("").map(safe_clean_name)
co = co.loc[co["object_name_clean"].str.len() > 0].copy()
co["object_key"] = co["object_name_clean"].map(key_norm)

# 3) Map to SIMBAD main_id (object_id)
def map_main_id(k: str) -> str | None:
    """Return the SIMBAD main_id for lookup key *k*, or None when unresolved.

    ``alias2main`` is consulted before ``name2main``. Only a non-blank string
    counts as a hit, so None/empty placeholders fall through.
    """
    for lookup in (alias2main, name2main):
        candidate = lookup.get(k)
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    return None

# 3) Map each row's key to a SIMBAD main_id (object_id); <NA> when unresolved
co["object_id"] = co["object_key"].map(map_main_id).astype("string")

n_total_rows = len(co)
n_resolved_rows = int(co["object_id"].notna().sum())
_resolved_share = n_resolved_rows / max(1, n_total_rows)

print("Rows total:", f"{n_total_rows:,}")
print("Rows with resolved object_id:", f"{n_resolved_rows:,}", f"({_resolved_share:.2%})")

# 4) Keep SIMBAD-resolved rows only
co_res = co.dropna(subset=["object_id"]).copy()

# 5) Collapse aliases of the same object within a label into one (label, object_id) edge
_agg_spec = dict(
    n_papers=("n_papers", "sum"),
    total_weight=("total_weight", "sum"),
    raw_name_count=("object_name_clean", "nunique"),
)
co_res_agg = co_res.groupby(["label", "object_id"], as_index=False).agg(**_agg_spec)
co_res_agg = co_res_agg.astype({"n_papers": "int32", "total_weight": "float32", "raw_name_count": "int16"})

print("Resolved aggregated edges:", co_res_agg.shape)
print("Unique objects:", co_res_agg["object_id"].nunique())

# 6) Write resolved edges (renamed; no tier wording)
OUT_RES = RUN_DIR / "concept_object_edges_resolved_simbad_only_llm.csv.gz"
co_res_agg.to_csv(OUT_RES, compression="gzip", index=False)
print("Wrote:", OUT_RES)
# 7) Optional: noReg filtering if otype_by_main exists
def is_region_like(main_id: str, otype: str | None) -> bool:
    s = (main_id or "").upper()
    t = (otype or "").upper()
    # conservative: exclude SIMBAD object types containing REGION/FIELD
    if "REG" in t or "FIELD" in t:
        return True
    # also exclude main IDs that are clearly NAME Field/Region
    if s.startswith("NAME ") and ("FIELD" in s or "REGION" in s):
        return True
    return False

# Both optional outputs below need a non-empty otype_by_main from Section 8
_have_otypes = "otype_by_main" in globals() and isinstance(otype_by_main, dict) and len(otype_by_main) > 0

# 7) Optional: noReg edges (drop objects that look like sky regions/fields)
if not _have_otypes:
    print("[SKIP] otype_by_main not available; skipping noReg output")
else:
    _otype_of = otype_by_main.get
    mask_bad = co_res_agg["object_id"].map(lambda oid: is_region_like(oid, _otype_of(oid))).astype(bool)
    co_res_noreg = co_res_agg[~mask_bad].copy()

    OUT_NOREG = RUN_DIR / "concept_object_edges_resolved_simbad_only_llm_noReg.csv.gz"
    co_res_noreg.to_csv(OUT_NOREG, compression="gzip", index=False)
    print("Wrote:", OUT_NOREG, "| dropped:", int(mask_bad.sum()))

# 8) Optional: object catalog (object_id -> otype), one row per resolved object, sorted by id
if not _have_otypes:
    print("[SKIP] otype_by_main not available; skipping object catalog")
else:
    obj_ids = sorted(set(co_res_agg["object_id"].astype("string").tolist()))
    obj_df = pd.DataFrame({
        "object_id": obj_ids,
        "otype": list(map(otype_by_main.get, obj_ids)),
    })

    OUT_OBJ = RUN_DIR / "simbad_object_catalog_llm.csv.gz"
    obj_df.to_csv(OUT_OBJ, compression="gzip", index=False)
    print("Wrote:", OUT_OBJ, "| rows:", len(obj_df))

Rows total: 6,816,113
Rows with resolved object_id: 4,258,978 (62.48%)
Resolved aggregated edges: (3731041, 5)
Unique objects: 103214
Wrote: /Users/jinchuli/projects/astro-llm-tools/data/llm_object_extraction_final/full_extraction_v2/concept_object_edges_tier3_resolved_simbad_only_llm_v2.csv.gz
Wrote: /Users/jinchuli/projects/astro-llm-tools/data/llm_object_extraction_final/full_extraction_v2/concept_object_edges_tier3_resolved_simbad_only_llm_v2_noReg.csv.gz | dropped: 6097
Wrote: /Users/jinchuli/projects/astro-llm-tools/data/llm_object_extraction_final/full_extraction_v2/simbad_object_catalog_llm_v2.csv.gz | rows: 103214


Author note: the output above reports 103,214 objects rather than 100,560 because our SIMBAD resolution cache contains some objects that are not present in the final run. These objects are discarded later, in the matrix-building phase, so the effective object count is 100,560.