# In [None]:
# --- AKI (creatinine_t60) LLM score collection: structured outputs + batching + calibration

import os
import json
import math
import time
from typing import List, Dict, Optional

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from openai import OpenAI


# -----------------------
# 0) Load data + columns
# -----------------------
DATA_PATH = "/content/drive/MyDrive/t60_reg_data.csv"  # <-- change to your local path
ID_COLS = ["pat_id"]
TARGET_COL = "creatinine_t60"   # regression target

df = pd.read_csv(DATA_PATH)

# Candidate predictors: every column that is neither an identifier nor the target,
# kept in the CSV's column order.
FEATURE_COLS = [col for col in df.columns if col not in {*ID_COLS, TARGET_COL}]
print("n_features =", len(FEATURE_COLS))


# -----------------------
# 1) Prompt scaffolding
# -----------------------
# Cohort / modeling-task description, inserted verbatim into every scoring prompt.
DATASET_BACKGROUND = "\n".join([
    "We have adult patients who underwent cardiac surgery.",
    "For each patient, we have perioperative and ICU EMR-derived features:",
    "demographics, comorbidities, vitals/hemodynamics, ventilator settings,",
    "fluid balance, labs, procedures, and medication doses.",
    "The modeling goal is sparse linear regression to predict postoperative",
    "serum creatinine at 60 hours (creatinine_t60).",
])

# Human-readable outcome label used in the prompt's task description.
OUTCOME_NAME = "postoperative serum creatinine at 60 hours (creatinine_t60)"

def infer_feature_type(series: pd.Series) -> str:
    """Label a column as "binary", "numeric", or "categorical_or_text".

    The label is only a hint shown to the LLM so it can tell 0/1 flags apart
    from continuous measurements. Missing values are ignored when checking
    whether a numeric column only takes the values 0 and 1.
    """
    types = pd.api.types
    if types.is_bool_dtype(series):
        return "binary"
    if not types.is_numeric_dtype(series):
        return "categorical_or_text"

    observed = series.dropna().unique()
    only_zero_one = len(observed) <= 3 and set(observed) <= {0, 1}
    return "binary" if only_zero_one else "numeric"

def build_prompt_for_batch(
    feature_names: List[str],
    data: Optional[pd.DataFrame] = None,
) -> str:
    """
    Build the user prompt that asks the LLM to score one batch of features.

    Each feature is listed as a JSON object with a batch-local ``id`` (its
    index in ``feature_names``), its ``name``, and a ``type`` hint from
    ``infer_feature_type``.

    Args:
        feature_names: columns to score in this request; each must be a
            column of the frame used for type hints.
        data: frame the columns are read from for the type hints. Defaults to
            the module-level ``df``, matching the original behavior.

    Returns:
        The complete prompt text.
    """
    frame = df if data is None else data

    # Include a tiny bit of metadata that is cheap but useful.
    payload = [
        {"id": i, "name": name, "type": infer_feature_type(frame[name])}
        for i, name in enumerate(feature_names)
    ]
    features_json = json.dumps(payload, indent=2)

    # The final instruction asks for one entry per input feature. Callers map
    # the replies back to features by name/id, so omitted or renamed features
    # would otherwise go unscored.
    prompt = f"""
You are a cardiothoracic ICU clinician and a biostatistician.

Background:
{DATASET_BACKGROUND}

Task:
We are building a sparse linear regression model to predict {OUTCOME_NAME}.
Each input is a feature (column) from the tabular dataset.

For each feature, rate how a priori clinically relevant it is for predicting {OUTCOME_NAME}
in adult postcardiac-surgery ICU patients. Base your judgment on typical knowledge about:
kidney perfusion/hemodynamics, AKI risk factors, nephrotoxic medications, and kidney-related labs.

Scoring rules (importance):
- Use an integer score from 1 to 5:
  1 = very unlikely to be useful
  2 = weak/indirect relevance
  3 = moderately relevant
  4 = clearly relevant
  5 = directly and strongly related to kidney function or AKI mechanisms

Input features (JSON list):
{features_json}

Return a JSON object with key "scores" containing a list of objects.
Each object must include:
- id (same as input id)
- name (copy exactly)
- importance (integer 1..5)
- reason (1–2 concise sentences)
Return exactly one object per input feature: do not skip, merge, or rename features.
""".strip()

    return prompt


# -----------------------------------
# 2) Structured output schema (Pydantic)
# -----------------------------------
# Structured-output schema: call_llm_for_batch passes ScoreBatch as
# `text_format`, so the model reply is parsed/validated into these models.
# NOTE: documented with `#` comments on purpose. Pydantic copies class
# docstrings into the JSON schema "description", which would change the
# schema sent to the model.
class ScoreItem(BaseModel):
    # Batch-local index echoed back from the prompt's input list.
    id: int
    # Feature (column) name; the prompt asks the model to copy it exactly.
    name: str
    # Relevance score; Field(ge=1, le=5) makes validation reject values outside 1..5.
    importance: int = Field(ge=1, le=5)
    # Short free-text justification from the model.
    reason: str

# Top-level reply object: {"scores": [ScoreItem, ...]}.
class ScoreBatch(BaseModel):
    scores: List[ScoreItem]


# -----------------------
# 3) OpenAI client + call
# -----------------------
# The API key is taken from the OPENAI_API_KEY environment variable:
#   export OPENAI_API_KEY="..."
# Never hard-code secrets in a notebook. A key that was ever pasted into
# source/shared output must be treated as leaked and revoked.
client = OpenAI()

def call_llm_for_batch(
    prompt: str,
    model: str = "gpt-4o-2024-08-06",
    max_output_tokens: int = 2000,
) -> ScoreBatch:
    """
    Send one scoring prompt and return the parsed ``ScoreBatch``.

    Uses Responses API structured parsing (``text_format=ScoreBatch``), so
    the reply is validated against the Pydantic schema instead of being
    parsed from free text.

    Args:
        prompt: user prompt text (e.g. from ``build_prompt_for_batch``).
        model: OpenAI model name.
        max_output_tokens: generation cap. Large batches may need more, since
            a reply cut off by this limit is unlikely to match the schema.

    Returns:
        The parsed ``ScoreBatch``.

    Raises:
        ValueError: if the response contains no parsed output.
    """
    resp = client.responses.parse(
        model=model,
        input=[
            {"role": "system", "content": "You output only structured JSON that matches the schema."},
            {"role": "user", "content": prompt},
        ],
        text_format=ScoreBatch,
        temperature=0,
        max_output_tokens=max_output_tokens,
    )
    parsed = resp.output_parsed
    # output_parsed can be None (for instance when the model refuses); fail
    # here with a clear message instead of an AttributeError in the caller.
    if parsed is None:
        raise ValueError(
            f"LLM returned no structured output (status={getattr(resp, 'status', None)!r})"
        )
    return parsed


def _map_batch_scores(parsed: "ScoreBatch", batch: List[str]) -> Dict[str, Dict]:
    """Map parsed LLM score items back onto the features of ``batch``.

    An item is matched by its ``name`` when that exactly equals one of the
    batch's features, otherwise by its ``id``, the batch-local index the
    prompt assigned. Items matching neither are ignored. If a feature gets
    several items, the first one wins.

    Raises:
        ValueError: if any feature of ``batch`` received no score.
    """
    batch_names = set(batch)
    mapped: Dict[str, Dict] = {}
    for item in parsed.scores:
        if item.name in batch_names:
            name = item.name
        elif 0 <= item.id < len(batch):
            name = batch[item.id]
        else:
            continue  # refers to nothing in this batch
        mapped.setdefault(
            name, {"importance": int(item.importance), "reason": item.reason}
        )
    missing = [name for name in batch if name not in mapped]
    if missing:
        raise ValueError(
            f"LLM response is missing scores for {len(missing)} feature(s): {missing}"
        )
    return mapped


def _write_json_atomic(obj, path: str) -> None:
    """Write ``obj`` as JSON to ``path`` through a temp file + ``os.replace``.

    If the write is interrupted, the previous cache stays intact. A truncated
    file would otherwise fail to load when resuming.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def get_llm_importance_scores(
    feature_names: List[str],
    batch_size: Optional[int] = None,
    model: str = "gpt-4o-2024-08-06",
    sleep_s: float = 0.25,
    max_retries: int = 6,
    cache_path: str = "aki_llm_scores_raw.json",
) -> Dict[str, Dict]:
    """
    Score features with the LLM in batches, caching after every batch.

    Args:
        feature_names: columns to score.
        batch_size: features per LLM call. Defaults to
            ceil(sqrt(len(feature_names))), with a minimum of 1.
        model: OpenAI model name.
        sleep_s: pause after each successful batch.
        max_retries: retries per batch (exponential backoff) before giving up.
        cache_path: JSON cache. Features already in it are not re-scored, so
            an interrupted run can resume.

    Returns:
      dict[feature_name] = {"importance": int(1..5), "reason": str, "batch": int}
      "batch" identifies which features were scored together in one call.
      Ids stay unique across resumed runs. Entries from an older cache may
      lack this key. Cached entries for features not in ``feature_names``
      are returned unchanged.

    Raises:
        ValueError: if ``batch_size`` < 1.
        RuntimeError: if a batch still fails after ``max_retries`` retries.
    """
    # default: ceil(sqrt(p)) like the paper's heuristic (at least 1, so an
    # empty feature list does not produce a zero range step)
    if batch_size is None:
        batch_size = max(1, int(math.ceil(math.sqrt(len(feature_names)))))
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # load cache if exists
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            results = json.load(f)
        print(f"[cache] loaded {len(results)} scores from {cache_path}")
    else:
        results = {}

    remaining = [f for f in feature_names if f not in results]
    n_batches = math.ceil(len(remaining) / batch_size)
    print(f"Remaining features to score: {len(remaining)} (batch_size={batch_size})")

    # Continue numbering after batches already in the cache.
    next_batch_id = 1 + max(
        (v["batch"] for v in results.values() if isinstance(v.get("batch"), int)),
        default=-1,
    )

    for batch_no, start in enumerate(range(0, len(remaining), batch_size)):
        batch = remaining[start : start + batch_size]
        prompt = build_prompt_for_batch(batch)

        # retry loop (simple exponential backoff); API errors, unparseable
        # replies and incomplete replies are all retried
        attempt = 0
        while True:
            try:
                scored = _map_batch_scores(call_llm_for_batch(prompt, model=model), batch)
                break
            except Exception as e:
                attempt += 1
                if attempt > max_retries:
                    raise RuntimeError(f"Failed after {max_retries} retries. Last error: {e}") from e
                backoff = (2 ** attempt) * 0.5
                print(f"[retry {attempt}] error={e} | sleeping {backoff:.1f}s")
                time.sleep(backoff)

        batch_id = next_batch_id + batch_no
        for name, entry in scored.items():
            entry["batch"] = batch_id
            results[name] = entry

        # Kept outside the retry loop so a cache I/O error is not mistaken
        # for an LLM failure.
        _write_json_atomic(results, cache_path)
        print(f"[ok] scored batch {batch_no + 1} / {n_batches}")

        time.sleep(sleep_s)

    return results


# -----------------------
# 4) batch calibration (paper-style idea)
# -----------------------
def calibrate_batch_scales(
    feature_names: List[str],
    raw_scores: Dict[str, Dict],
    batch_size: int,
    model: str = "gpt-4o-2024-08-06",
) -> Dict[str, float]:
    """
    Produce per-feature calibrated scores in [0.1, 1.0] (useful for penalty factors).

    Idea:
      - group features by the batches they were scored in
      - take the max-scored feature per batch
      - ask the LLM to rescore just those maxima together (global comparison)
      - weight each batch by its normalized rescored maximum
      - multiply raw scores by their batch weight and min-max rescale to [0.1, 1.0]

    Grouping: if every ``raw_scores[name]`` for ``feature_names`` has an int
    "batch" tag, features are grouped by that tag. This stays correct after a
    resumed scoring run. Otherwise ``feature_names`` is cut into consecutive
    chunks of ``batch_size``. Those chunks only match the real scoring batches
    if everything was scored in one uninterrupted run with that batch size.

    A batch maximum the LLM fails to rescore gets weight 1.

    Returns:
        dict[feature_name] -> float in [0.1, 1.0], in ``feature_names`` order.
        Returns ``{}`` when ``feature_names`` is empty. If all combined scores
        are equal, every feature gets 0.1.

    Raises:
        KeyError: if a feature has no entry in ``raw_scores``.
    """
    if not feature_names:
        return {}

    missing = [nm for nm in feature_names if nm not in raw_scores]
    if missing:
        raise KeyError(f"No raw score for {len(missing)} feature(s): {missing}")

    # recover the batches
    if all(isinstance(raw_scores[nm].get("batch"), int) for nm in feature_names):
        groups: Dict[int, List[str]] = {}
        for nm in feature_names:
            groups.setdefault(raw_scores[nm]["batch"], []).append(nm)
        batches = list(groups.values())
    else:
        batches = [feature_names[i:i + batch_size] for i in range(0, len(feature_names), batch_size)]

    # pick max feature per batch (first one wins on ties)
    max_feats = [max(b, key=lambda nm: raw_scores[nm]["importance"]) for b in batches]

    # rescore maxima globally
    parsed = call_llm_for_batch(build_prompt_for_batch(max_feats), model=model)

    # Map replies back by exact name, else by batch-local id (index into max_feats).
    max_set = set(max_feats)
    rescored: Dict[str, int] = {}
    for it in parsed.scores:
        if it.name in max_set:
            key = it.name
        elif 0 <= it.id < len(max_feats):
            key = max_feats[it.id]
        else:
            continue
        rescored.setdefault(key, it.importance)

    # batch weights from rescored maxima
    s = np.array([rescored.get(f, 1) for f in max_feats], dtype=float)
    w = s / (s.sum() + 1e-12)

    # weighted concat
    calibrated = {nm: raw_scores[nm]["importance"] * w[bi] for bi, b in enumerate(batches) for nm in b}

    # rescale to [0.1, 1.0]
    vals = np.array(list(calibrated.values()), dtype=float)
    vmin, vmax = vals.min(), vals.max()
    span = vmax - vmin
    return {
        nm: 0.1 + 0.9 * ((calibrated[nm] - vmin) / span if span > 0 else 0.0)
        for nm in feature_names
    }


# -----------------------
# 5) Run + export to aki_weights.csv format
# -----------------------
MODEL = "gpt-4o-2024-08-06"

raw = get_llm_importance_scores(
    FEATURE_COLS,
    batch_size=None,
    model=MODEL,
    cache_path="aki_llm_scores_raw.json",
)

# The cache can hold entries for columns that are no longer in the data, and
# an LLM reply may have left a feature unscored. Export only the current
# features, in column order, and stop here with a clear message if any are
# missing, instead of failing later inside calibration.
unscored = [c for c in FEATURE_COLS if c not in raw]
if unscored:
    raise RuntimeError(f"{len(unscored)} feature(s) have no LLM score: {unscored}")
raw = {c: raw[c] for c in FEATURE_COLS}

# Save raw weights (1..5)
raw_df = pd.DataFrame({
    "value": list(raw.keys()),
    "importance": [raw[k]["importance"] for k in raw],
    "reason": [raw[k]["reason"] for k in raw],
})
raw_df.to_csv("aki_weights_raw.csv", index=False)

# Optional calibrated weights in [0.1, 1.0].
# Same formula as get_llm_importance_scores' default batch size (ceil(sqrt(p))).
batch_size_used = int(math.ceil(math.sqrt(len(FEATURE_COLS))))
cal = calibrate_batch_scales(FEATURE_COLS, raw, batch_size=batch_size_used, model=MODEL)

cal_df = pd.DataFrame({
    "value": list(cal.keys()),
    "importance": [cal[k] for k in cal],   # now continuous in [0.1, 1.0]
})
cal_df.to_csv("aki_weights.csv", index=False)