In [1]:
from src.config import (
    GPT_KEY, CLAUDE_KEY, LLMAPI_KEY,
    GPT_MODEL, CLAUDE_MODEL, LLMAPI_MODEL,
    COMP_DIR, ASSR_DIR,
    MAX_TOKENS, TEMPERATURE
)
from src.llm_clients import GPTClient, ClaudeClient, LlamaAPIClient
from src.schemas import Assertion, AssertionDoc
from src.prompts import get_prompt
from src.json_utils import parse_or_fix
from src.utils import ensure_dir_exists

import pandas as pd
import json
import tqdm.auto as tqdm

In [2]:
# Prepare the output directory (ASSR_DIR) that the assertion files are written to.
# NOTE(review): presumably creates the directory when it is missing — confirm in src.utils.
ensure_dir_exists(ASSR_DIR)

# Load the second-stage output (completion).
def load_completions(tag: str, comp_dir: "str | None" = None) -> pd.DataFrame:
    """Load one model's completion output as a flat table of sentences.

    Reads ``{comp_dir}/{tag}.jsonl``. Every non-blank line must be a JSON
    object holding a ``pmid`` and a ``sentences`` list; each sentence entry
    must hold an ``id`` and a ``resolved`` text.

    Args:
        tag: File stem of the JSONL file to read (the model tag).
        comp_dir: Directory containing the file. Defaults to ``COMP_DIR``.

    Returns:
        A DataFrame with exactly the columns ``pmid``, ``id`` and ``sentence``,
        one row per sentence in file order. The columns are present even when
        the file yields no sentences, so callers can group by ``pmid`` safely.

    Raises:
        FileNotFoundError: If the JSONL file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected keys.
    """
    base_dir = COMP_DIR if comp_dir is None else comp_dir
    rows = []
    with open(f"{base_dir}/{tag}.jsonl", "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            obj = json.loads(line)
            rows.extend(
                {"pmid": obj["pmid"], "id": s["id"], "sentence": s["resolved"]}
                for s in obj["sentences"]
            )
    # Fixing the columns keeps the schema stable for an empty result.
    return pd.DataFrame(rows, columns=["pmid", "id", "sentence"])

# Initialize LLM Client
# One client per model tag; run_assertion looks clients up by these tags and
# also uses the tag as the file stem for its input and output files.
clients = dict(
    gpt4o=GPTClient(model=GPT_MODEL, key=GPT_KEY),
    claude=ClaudeClient(model=CLAUDE_MODEL, key=CLAUDE_KEY),
    llama=LlamaAPIClient(model=LLMAPI_MODEL, key=LLMAPI_KEY),
)

In [3]:
# Build prompt block: fetch the prompt pieces for the "assertion" task.
# NOTE(review): system_prompt is not referenced anywhere else in the visible
# cells — confirm the system prompt reaches the model some other way
# (e.g. inside fewshot or via the client), otherwise it is silently dropped.
system_prompt, fewshot = get_prompt("assertion")

def build_msgs(sent_id: int, sentence: str) -> list[dict]:
    """Assemble the chat messages for one sentence.

    The result is the module-level ``fewshot`` messages followed by a single
    user turn that quotes the sentence together with its id. A new list is
    returned; ``fewshot`` itself is not modified.
    """
    user_turn = {
        "role": "user",
        "content": f"Sentence [{sent_id}]: {sentence}\nReturn JSON:",
    }
    return fewshot + [user_turn]

In [4]:
# Execute assertion triplet extraction
def run_assertion(tag: str):
    """Extract an assertion for every completed sentence of model ``tag``.

    Sentences are loaded with ``load_completions(tag)`` and grouped by
    ``pmid``. Each sentence is sent to ``clients[tag]`` and the raw reply is
    passed through ``parse_or_fix`` with ``Assertion`` as the target class.
    One ``AssertionDoc`` per pmid is written as a JSON line to
    ``{ASSR_DIR}/{tag}.jsonl``.

    A sentence that raises is skipped: the failure is printed and recorded in
    ``{ASSR_DIR}/{tag}.fail.txt`` as ``pmid<TAB>id<TAB>error``, one line per
    failure. Both files are overwritten on every call.

    Args:
        tag: Key into ``clients``; also the file stem for input and output.

    Raises:
        KeyError: If ``tag`` is not a key of ``clients``.
    """
    cli = clients[tag]
    out_path = f"{ASSR_DIR}/{tag}.jsonl"
    fail_path = f"{ASSR_DIR}/{tag}.fail.txt"

    df = load_completions(tag)
    # An input without any sentence may come back without a "pmid" column,
    # which would make groupby raise KeyError; treat it as zero documents.
    grouped = [] if df.empty else df.groupby("pmid")

    with open(out_path, "w", encoding="utf-8") as fw, \
         open(fail_path, "w", encoding="utf-8") as fail_log:

        for pmid, group in tqdm.tqdm(grouped, desc=f"{tag} extraction"):
            assertions = []

            for row in group.itertuples(index=False):
                sent_id = row.id
                msgs = build_msgs(sent_id, row.sentence)
                try:
                    raw = cli.run(msgs, task_id=f"{pmid}-{sent_id}")
                    triple = parse_or_fix(raw, cli, msgs, target_class=Assertion)
                    assertions.append(triple)
                except Exception as err:
                    # Deliberate best-effort: one bad sentence must not abort the run.
                    print(f"[{tag}][{pmid}][{sent_id}] failed → {err}")
                    # Error text can span several lines (validation errors do);
                    # collapse whitespace so each failure stays on one TSV line.
                    reason = " ".join(str(err).split())
                    fail_log.write(f"{pmid}\t{sent_id}\t{reason}\n")
                    continue

            doc = AssertionDoc(pmid=pmid, assertion=assertions)
            # NOTE(review): not every pydantic v2 release accepts `ensure_ascii`
            # in model_dump_json — confirm the installed version / AssertionDoc
            # supports it.
            fw.write(doc.model_dump_json(ensure_ascii=False) + "\n")

    print(f"{tag} extract assertion -> {out_path}")

In [5]:
# Run models
# Same three models, processed in the same order, one after another.
for model_tag in ("gpt4o", "claude", "llama"):
    run_assertion(model_tag)

gpt4o extraction:   0%|          | 0/83 [00:00<?, ?it/s]

gpt4o extract assertion -> assertion/gpt4o.jsonl


claude extraction:   0%|          | 0/88 [00:00<?, ?it/s]

[Retry 1] Invalid JSON: 1 validation error for Assertion
object
  Field required [type=missing, input_value={'id': 5, 'subject': 'sta...he simulated sequences'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.11/v/missing
-> Raw output: {"id": 5, "subject": "statistical models", "predicate": "perform well", "condition": "when there was no or limited purifying selection in the simulated sequences"} 
[Retry 1] Invalid JSON: 1 validation error for Assertion
object
  Field required [type=missing, input_value={'id': 12, 'subject': 'ex...predicate': 'increases'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.11/v/missing
-> Raw output: {"id": 12, "subject": "expression level of the apoptosis-related proteins Caspase 3 and PARP", "predicate": "increases"} 
[Retry 1] Invalid JSON: 1 validation error for Assertion
object
  Field required [type=missing, input_value={'id': 7, 'subject': 'the... the fourth experiment'}, input_typ

llama extraction:   0%|          | 0/73 [00:00<?, ?it/s]

llama extract assertion -> assertion/llama.jsonl
