<a href="https://colab.research.google.com/github/ArefMahjoubfar/ArefMahjoubfar/blob/main/Extraction_Task_explainers/Extraction_Task_explainers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
===============================================================================
Project: Interpretation and Explainability for LLM Extraction using SHAP
Author: Aref
Part: 1 of N
===============================================================================

This module implements the foundational building blocks for an end-to-end
experimental pipeline that:

1) Loads the CT-RATE dataset from Hugging Face (CT reports with labeled fields).
2) Asks LLMs to extract specific fields (presence, categorical, quantitative).
3) Evaluates extraction quality with common metrics.
4) Computes uncertainty (sample consistency, token-level probabilities)
   and performs discrimination/calibration analyses (in Part 2).
5) Adds SHAP explanations for token-level influence on confidence (in Part 3).

This file contains:
- Imports, logging utilities, reproducibility helpers.
- Pydantic-based configuration models for fields, models, settings, experiment.
- Dataset loader for CT-RATE.
- Prompt builder for clear, JSON-only extraction prompts.
- Hugging Face Causal LM client with token-level log-probabilities (teacher forcing).
- Output parsing, normalization, numeric extraction, and field evaluation helpers.

Installation (suggested baseline):
    pip install datasets pandas numpy scikit-learn matplotlib seaborn sentence-transformers nltk shap transformers torch pydantic

"""

from __future__ import annotations

# -----------------------------
# Standard library imports
# -----------------------------
import os
import re
import json
import time
import random
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

# -----------------------------
# Third-party imports
# -----------------------------
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field, validator

from datasets import load_dataset

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score
)

# Hugging Face Transformers (for open local models, e.g., LLaMA)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# =============================================================================
# Logging and reproducibility
# =============================================================================

def setup_logging(log_dir: str = "logs", level: int = logging.INFO) -> None:
    """
    Configure root logging to write to a timestamped file and to the console.

    Args:
        log_dir: Directory where log files should be written.
        level: Logging verbosity (e.g., logging.INFO).

    Side effects:
        - Creates the log directory if it does not exist.
        - Opens a new ``run_<YYYYmmdd_HHMMSS>.log`` file inside it.
        - Replaces any handlers already attached to the root logger with a
          file handler and a console handler.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"run_{time.strftime('%Y%m%d_%H%M%S')}.log")

    logging.basicConfig(
        level=level,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=[logging.FileHandler(log_path, encoding="utf-8"), logging.StreamHandler()],
        # basicConfig is a silent no-op when the root logger already has handlers
        # (common in notebooks and test runners); force=True makes the requested
        # handlers take effect instead of leaving an orphaned, empty log file.
        force=True,
    )
    logging.info("Logging to %s", log_path)


def set_seed(seed: int = 42) -> None:
    """
    Seed every random number generator this project relies on.

    The same integer is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's CPU generator, and all CUDA generators.

    Args:
        seed: Integer seed to apply.

    Note:
        - Determinism is not guaranteed across all CUDA kernels.
        - Use this to reduce variance across repeated runs.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# =============================================================================
# Configuration models (Pydantic)
# =============================================================================

class FieldSpec(BaseModel):
    """
    Specification for a single field to extract from a CT report.

    Types:
        - qualitative_presence: binary detection (present/absent)
        - quantitative_value: numeric extraction with tolerance
        - categorical_value: normalized string category matching

    Attributes:
        name: Logical name of the field (e.g., "pulmonary_embolism").
        type: One of {"qualitative_presence","quantitative_value","categorical_value"}.
        target_column: Name of the dataset column containing the label for this field.
        value_tolerance: For quantitative_value, absolute tolerance to accept as correct.
        value_extractor_regex: Optional regex to capture a numeric group for extraction.
        category_map: Optional normalization map (raw -> canonical) for categorical_value.

    Raises:
        pydantic.ValidationError: if ``type`` is not one of the allowed values, or
            if a quantitative_value field has no ``value_tolerance``.
    """
    # NOTE(review): `validator` and `Field(regex=...)` are the pydantic v1 API;
    # confirm the runtime environment pins pydantic<2.
    name: str
    type: str = Field(regex=r"^(qualitative_presence|quantitative_value|categorical_value)$")
    target_column: str
    value_tolerance: Optional[float] = None
    value_extractor_regex: Optional[str] = None
    category_map: Optional[Dict[str, str]] = None

    # always=True is required: without it pydantic skips the validator whenever
    # value_tolerance is simply omitted, which is exactly the case to reject.
    @validator("value_tolerance", always=True)
    def _check_tolerance_for_quantitative(cls, v, values):
        # `values` holds the already-validated earlier fields; "type" is absent
        # from it when its own validation failed, in which case nothing is added.
        if values.get("type") == "quantitative_value" and v is None:
            raise ValueError("quantitative_value requires value_tolerance (absolute tolerance).")
        return v


class ModelSpec(BaseModel):
    """
    Configuration for one model under evaluation and its I/O behavior.

    model_type:
        - open_hf_causal_lm: Local Hugging Face Causal LM (e.g., LLaMA instruct).
        - api_openai: Placeholder for ChatGPT endpoints (implemented later if needed).
        - api_deepseek: Placeholder for DeepSeek R1 (implemented later if needed).

    Attributes:
        name: Short display name for reporting (e.g., "llama3_8b_instruct").
        model_type: One of {"open_hf_causal_lm","api_openai","api_deepseek"};
            any other value is rejected by the regex constraint.
        model_id: HF model id for local models (e.g., "meta-llama/Meta-Llama-3-8B-Instruct").
        max_new_tokens: Upper bound on generation length.
        top_p/top_k: Nucleus/top-k sampling controls.
        supports_logprobs: Whether client can return token-level log-probabilities.
        device: "cuda" or "cpu".
        dtype: "float16" (default) or "float32" for model weights.
    """
    name: str
    model_type: str = Field(regex=r"^(open_hf_causal_lm|api_openai|api_deepseek)$")
    model_id: Optional[str] = None
    max_new_tokens: int = 128
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    supports_logprobs: bool = False
    # The default is computed once, when this class body is executed, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # HFOpenCausalLMClient maps any value other than "float16" to torch.float32.
    dtype: str = "float16"


class SettingsGrid(BaseModel):
    """
    Grid of generation settings used to probe model behavior.

    Every list default is built by a ``default_factory`` so each instance gets
    its own list object.

    Attributes:
        temperature_list: List of temperatures (e.g., [0.0, 0.5, 1.0]).
        verbosity_list: Controls prompt style (["low","high"]).
        n_consistency_samples: Repetitions for sample consistency (phase 2).
        consistency_temperatures: Temperatures used for consistency sampling.
    """
    temperature_list: List[float] = Field(default_factory=lambda: [0.0, 0.5, 1.0])
    # PromptBuilder.prompt_text treats any value other than "low" as the "high" style.
    verbosity_list: List[str] = Field(default_factory=lambda: ["low", "high"])
    n_consistency_samples: int = 15
    consistency_temperatures: List[float] = Field(default_factory=lambda: [0.5, 1.0])


class ExperimentConfig(BaseModel):
    """
    Top-level configuration for a single experiment run.

    ``fields``, ``models`` and ``settings`` have no defaults and must be supplied.

    Attributes:
        dataset_name: Hugging Face dataset path/id for CT-RATE.
        dataset_split: Which split to use (e.g., "train", "validation", "test").
        text_column: Column with CT report text.
        fields: List of FieldSpec describing the extraction targets.
        models: List of ModelSpec for the models we will evaluate.
        settings: SettingsGrid for temperature/verbosity/consistency loops.
        output_dir: Where to write artifacts (CSVs, plots, SHAP exports).
        seed: Global seed for reproducibility.
        human_consistency_file: Optional CSV with human agreement for sample consistency.
        embedding_model: Sentence-embedding model name used in phase 2 (provided later).
    """
    # Passed verbatim to datasets.load_dataset by CTRateDataset.load.
    dataset_name: str = "CT-RATE"  # replace if the HF path differs (e.g., "org/ct-rate")
    dataset_split: str = "train"
    text_column: str = "text"
    fields: List[FieldSpec]
    models: List[ModelSpec]
    settings: SettingsGrid
    output_dir: str = "artifacts"
    seed: int = 42
    # UncertaintyScorer reads the columns sample_id, agree_count and total from this CSV.
    human_consistency_file: Optional[str] = None
    # Handed to SentenceTransformer(...) by UncertaintyScorer.
    embedding_model: str = "intfloat/e5-small-v2"


# =============================================================================
# Dataset loader (CT-RATE)
# =============================================================================

class CTRateDataset:
    """
    Wrapper around the CT-RATE dataset from Hugging Face.

    Expected schema:
        - text_column contains the CT report paragraph (default "text").
        - For each FieldSpec.target_column, a label column exists in the dataset.
          For qualitative_presence, the label is expected to be 0/1.
          For quantitative/categorical, adapt or create appropriate target columns.

    Usage:
        ds = CTRateDataset(cfg)
        ds.load()
        for sample in ds.iter_samples():
            ...  # sample = {"id": <row index>, "text": str, "labels": {field_name: label}}
    """

    def __init__(self, cfg: ExperimentConfig):
        self.cfg = cfg
        self.ds = None  # Hugging Face dataset object, set by load()
        self.df = None  # pandas view of the dataset, set by load() once validated

    def load(self) -> None:
        """
        Load the dataset into memory and enforce presence of required columns.

        Raises:
            ValueError: if the text column or any field's target column is missing.
        """
        logging.info("Loading dataset: %s [%s]", self.cfg.dataset_name, self.cfg.dataset_split)
        self.ds = load_dataset(self.cfg.dataset_name, split=self.cfg.dataset_split)
        df = self.ds.to_pandas()

        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if self.cfg.text_column not in df.columns:
            raise ValueError(f"Missing text column: {self.cfg.text_column}")

        # Report every missing target column at once rather than one per run.
        missing = [f.target_column for f in self.cfg.fields if f.target_column not in df.columns]
        if missing:
            raise ValueError(f"Missing target column(s): {', '.join(missing)}")

        self.df = df
        logging.info("Dataset loaded with shape: %s", df.shape)

    def iter_samples(self):
        """
        Iterate samples as a lightweight generator.

        Yields:
            dict with:
                - id: the DataFrame index label of the row
                - text: CT report paragraph
                - labels: mapping of field_name -> label (from target_column)

        Raises:
            RuntimeError: if called before the dataset has been loaded.
        """
        if self.df is None:
            raise RuntimeError("Dataset not loaded; call load() before iter_samples().")

        text_column = self.cfg.text_column
        targets = [(f.name, f.target_column) for f in self.cfg.fields]
        for i, row in self.df.iterrows():
            yield {
                "id": i,
                "text": row[text_column],
                "labels": {name: row[column] for name, column in targets},
            }


# =============================================================================
# Prompt builder
# =============================================================================

class PromptBuilder:
    """
    Produces the extraction prompts sent to every model.

    Verbosity modes:
        - low: terse instruction, JSON-only answer, no explanations.
        - high: the model may reason internally but must still end with JSON only.
          Any verbosity value other than "low" selects this mode.

    The schema shown to the model is deliberately tiny so that structured
    output is easy to follow.
    """

    @staticmethod
    def make_schema(field: FieldSpec) -> Dict[str, Any]:
        """
        Build the small JSON schema snippet shown to the model for ``field``.

        Returns:
            A dict with "field_name" and "present", plus "value" ("number" or
            "string") for quantitative and categorical fields.

        Raises:
            ValueError: if ``field.type`` is not a supported field type.
        """
        # Type of the "value" entry per field type; None means no "value" key.
        value_kinds = {
            "qualitative_presence": None,
            "quantitative_value": "number",
            "categorical_value": "string",
        }
        if field.type not in value_kinds:
            raise ValueError(f"Unsupported field type: {field.type}")

        schema: Dict[str, Any] = {"field_name": field.name, "present": "boolean"}
        value_kind = value_kinds[field.type]
        if value_kind is not None:
            schema["value"] = value_kind
        return schema

    @staticmethod
    def prompt_text(field: FieldSpec, text: str, verbosity: str = "low") -> str:
        """
        Assemble a prompt that asks the model to output ONLY a JSON object.

        Args:
            field: FieldSpec describing what to extract.
            text: CT report paragraph.
            verbosity: "low" for the concise style; anything else uses the
                "think carefully" style.

        Returns:
            The formatted prompt string, ending with a newline.
        """
        schema_json = json.dumps(PromptBuilder.make_schema(field), indent=2)

        if verbosity == "low":
            style = "Be concise. Do not add explanations."
        else:
            style = "Think carefully. You may reason internally, but return ONLY the final JSON."

        instruction = (
            "You are an information extraction assistant for radiology CT reports. "
            "Extract the requested field from the report. Return ONLY a valid JSON object "
            "matching the given schema, with no additional text."
        )

        # Empty entries produce the blank separator lines; the trailing one
        # yields the final newline.
        lines = [
            instruction,
            style,
            "",
            "Report:",
            f"{text}",
            "",
            f"Field to extract: {field.name}",
            "Output JSON schema:",
            schema_json,
            "Return ONLY JSON.",
            "",
        ]
        return "\n".join(lines)


# =============================================================================
# Model clients
# =============================================================================

class BaseModelClient:
    """
    Common interface shared by all model clients.

    Subclasses are expected to override:
        - generate(): turn a prompt plus decoding settings into text.
        - get_token_logprobs(): optionally return per-token log-probabilities.
    """

    def __init__(self, spec: ModelSpec):
        # Keep the spec so subclasses can read defaults from it.
        self.spec = spec

    def generate(
        self,
        prompt: str,
        temperature: float = 0.0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Produce a text completion for ``prompt``.

        Returns:
            dict with keys:
                - text: generated continuation (ideally JSON per our prompt)
                - raw: full decoded sequence (prompt + continuation), when available

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_token_logprobs(self, text: str) -> Optional[List[float]]:
        """
        Return teacher-forced per-token log probabilities for ``text``.

        The base implementation reports "not supported" by returning None.

        Note:
            This is useful as a proxy for confidence (to be calibrated later).
        """
        return None


class HFOpenCausalLMClient(BaseModelClient):
    """
    Hugging Face Causal LM client (e.g., latest LLaMA Instruct).

    - Loads tokenizer and model weights via ``from_pretrained``.
    - Supports generation with temperature/top_p/top_k controls.
    - Provides token-level log-probabilities for an arbitrary text via teacher forcing.

    Caution:
        - Teacher-forced log-probs reflect how plausible the sequence is under the model,
          not strictly the model's generation-time sampling probabilities (but they correlate).
    """

    def __init__(self, spec: ModelSpec):
        """
        Load the tokenizer and model described by ``spec``.

        Raises:
            ValueError: if ``spec.model_id`` is empty.
        """
        super().__init__(spec)
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not spec.model_id:
            raise ValueError("HF client requires model_id for open_hf_causal_lm")
        logging.info("Loading HF model: %s on device %s", spec.model_id, spec.device)

        self.tokenizer = AutoTokenizer.from_pretrained(spec.model_id)
        # Any dtype string other than "float16" falls back to float32.
        dtype = torch.float16 if spec.dtype == "float16" else torch.float32

        self.model = AutoModelForCausalLM.from_pretrained(
            spec.model_id,
            torch_dtype=dtype,
            device_map="auto" if spec.device == "cuda" else None,
        )
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        temperature: float = 0.0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Generate a continuation for the given prompt.

        Args:
            prompt: Full prompt text.
            temperature: > 0 enables sampling (clamped to at least 0.01);
                otherwise decoding is greedy.
            top_p: Nucleus sampling threshold; falls back to ``spec.top_p``.
            top_k: Top-k sampling cutoff; falls back to ``spec.top_k``.
            max_new_tokens: Generation length cap; falls back to ``spec.max_new_tokens``.
            seed: If given, all RNGs are re-seeded before generating.

        Returns:
            dict with:
                - text: the decoded continuation only (stripped)
                - raw: the decoded full sequence (prompt + continuation)
        """
        if seed is not None:
            set_seed(seed)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "do_sample": do_sample,
            "max_new_tokens": max_new_tokens or self.spec.max_new_tokens,
        }
        if do_sample:
            # Sampling parameters are only meaningful (and only passed) when sampling;
            # handing them to greedy decoding just triggers warnings.
            gen_kwargs["temperature"] = max(0.01, float(temperature))
            # Fall back to the spec defaults, mirroring max_new_tokens above.
            effective_top_p = top_p if top_p is not None else self.spec.top_p
            effective_top_k = top_k if top_k is not None else self.spec.top_k
            if effective_top_p is not None:
                gen_kwargs["top_p"] = effective_top_p
            if effective_top_k is not None:
                gen_kwargs["top_k"] = effective_top_k

        # Many causal LMs ship without a pad token; use EOS explicitly if available.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            gen_kwargs["pad_token_id"] = self.tokenizer.eos_token_id

        output_ids = self.model.generate(**inputs, **gen_kwargs)
        full_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Slice by token count, not by decoded-prompt character length: decoding
        # does not always round-trip the prompt exactly, which would shift the cut.
        continuation = self.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

        return {"text": continuation.strip(), "raw": full_text}

    @torch.no_grad()
    def get_token_logprobs(self, text: str) -> Optional[List[float]]:
        """
        Compute per-token log-probabilities for an arbitrary text using teacher forcing.

        Steps:
            - Tokenize the full text.
            - Run one forward pass to get next-token logits.
            - Compute log-softmax, then gather log-prob of each gold next token.
            - Returns a list of length (len(tokens)-1), each the log-prob for token t given tokens < t.

        Returns:
            List[float] of per-token log-probs, or None on failure.
        """
        try:
            enc = self.tokenizer(text, return_tensors="pt").to(self.model.device)
            # No `labels=`: only the logits are needed, so skip the loss computation.
            outputs = self.model(**enc)
            # Upcast so the softmax over the vocabulary is not done in half precision.
            logits = outputs.logits[:, :-1, :].float()  # predict token t+1 from positions up to t
            targets = enc["input_ids"][:, 1:]           # gold next tokens

            logprobs = torch.log_softmax(logits, dim=-1)                             # [B, T-1, V]
            token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [B, T-1]

            return token_logprobs.flatten().tolist()
        except Exception as e:
            # Deliberate best-effort: callers treat None as "no log-probs available".
            logging.warning("HF get_token_logprobs failed: %s", e)
            return None


# =============================================================================
# Parsing and evaluation helpers
# =============================================================================

class ExtractedRecord(BaseModel):
    """
    Parsed model output for a single field extraction.

    The shape mirrors the JSON schema produced by PromptBuilder.make_schema.

    Attributes:
        field_name: Name of the field this JSON refers to.
        present: True/False indicating whether the field is present in the paragraph.
        value: Optional numeric/string value (quantitative/categorical).
    """
    field_name: str
    present: bool
    # NOTE(review): with Union[float, str], pydantic may coerce numeric-looking
    # strings to float — confirm this is acceptable for categorical values.
    value: Optional[Union[float, str]] = None


def parse_json_output(raw_text: str) -> Optional[Dict[str, Any]]:
    """
    Parse the first JSON object found in model output.

    Behavior:
        - Scans for each "{" in turn and tries to decode one complete JSON
          value starting there, ignoring any text before or after it.
        - Returns the first value that decodes to a dict.
        - Returns None if no JSON object can be decoded, or if the input is
          not a string.

    Args:
        raw_text: The raw text produced by the model.

    Returns:
        Dict representing the JSON object, or None.
    """
    if not isinstance(raw_text, str):
        return None

    decoder = json.JSONDecoder()
    start = raw_text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete value, so trailing
            # text (even text containing braces) cannot corrupt the parse the way
            # a greedy "{.*}" regex match does.
            obj, _ = decoder.raw_decode(raw_text, start)
        except (ValueError, RecursionError):
            pass  # not valid JSON from this brace; try the next one
        else:
            if isinstance(obj, dict):
                return obj
        start = raw_text.find("{", start + 1)
    return None


def normalize_categorical(value: Optional[str], category_map: Optional[Dict[str, str]]) -> str:
    """
    Normalize categorical string values for robust comparison.

    Strategy:
        - Lowercase + strip whitespace.
        - If category_map has a key equal to the normalized value, return its
          canonical form. An exact key match wins; otherwise keys are compared
          after the same strip + lowercase, so a key such as "Mild" still matches.
        - Otherwise return the lowercased value.

    Args:
        value: Raw predicted or gold categorical label.
        category_map: Optional normalization dictionary (raw -> canonical).

    Returns:
        Normalized string (empty if input is None). Canonical forms are returned
        exactly as stored in category_map.
    """
    if value is None:
        return ""
    v = str(value).strip().lower()
    if not category_map:
        return v
    if v in category_map:
        return category_map[v]
    # Keys written with capitals or padding could never match the lowercased
    # value above, so compare them in normalized form as a fallback.
    for raw_key, canonical in category_map.items():
        if str(raw_key).strip().lower() == v:
            return canonical
    return v


def extract_numeric(text: str, regex: Optional[str] = None) -> Optional[float]:
    """
    Pull a single numeric value out of a text string.

    Behavior:
        - When ``regex`` is given and matches, capture group 1 is converted to
          float and returned.
        - If ``regex`` is absent, does not match, or group 1 cannot be
          converted, the first number-like token in ``text`` is used instead.

    Args:
        text: Source string to search.
        regex: Optional capturing regex (must capture the numeric in group 1).

    Returns:
        float or None if no numeric can be parsed.
    """
    if regex:
        custom_match = re.search(regex, text)
        if custom_match is not None:
            try:
                return float(custom_match.group(1))
            except Exception:
                # Missing/empty group or non-numeric capture: use the generic scan.
                pass

    # Generic scan: optional sign, optional integer part, optional dot, digits.
    generic_match = re.search(r"[-+]?\d*\.?\d+", text)
    if generic_match is None:
        return None
    try:
        return float(generic_match.group(0))
    except Exception:
        return None


def _is_missing_label(value: Any) -> bool:
    """Return True when a gold label is None or a pandas/NumPy missing value (NaN, NA, NaT)."""
    if value is None:
        return True
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # Non-scalar input (e.g. a multi-element array) is not a missing marker.
        return False


def evaluate_field(
    field: FieldSpec,
    gold: Any,
    pred: ExtractedRecord
) -> Tuple[int, float]:
    """
    Evaluate correctness for a single field prediction.

    Returns:
        (correct_binary, crude_score)
        - correct_binary: 1 if correct, else 0.
        - crude_score: a simple [0,1] score used in phase 1 (for ROC-AUC, etc.);
                       phase 2 adds calibrated uncertainty -> probability of correctness.

    Logic by type:
        - qualitative_presence:
            correct if bool(gold) == pred.present.
        - quantitative_value:
            correct if pred.present is True AND gold is numeric AND |pred.value - gold| <= tolerance.
            If gold is missing (None/NaN) and pred.present is False => correct (no value expected).
        - categorical_value:
            correct if pred.present is True AND normalized(pred.value) == normalized(gold).
            If no gold category (None/NaN/empty) and pred.present is False => correct.

    Gold labels read from a pandas row arrive as NaN when missing and often as
    NumPy scalars, so missing values are detected with pd.isna and numeric gold
    values are accepted through float() conversion rather than an isinstance
    check on int/float.
    """
    if field.type == "qualitative_presence":
        # NOTE(review): a NaN gold label is truthy here; confirm presence labels are always 0/1.
        correct = int(bool(gold) == bool(pred.present))
        return correct, float(correct)

    if field.type == "quantitative_value":
        if _is_missing_label(gold):
            # Nothing to extract: correct only when the model also reports absence.
            correct = int(not pred.present)
            return correct, float(correct)
        if not pred.present or pred.value is None:
            return 0, 0.0
        try:
            gold_val = float(gold)
            pred_val = float(pred.value)
        except (TypeError, ValueError):
            return 0, 0.0
        tol = field.value_tolerance or 0.0
        correct = int(abs(pred_val - gold_val) <= tol)
        return correct, float(correct)

    if field.type == "categorical_value":
        if _is_missing_label(gold):
            gold_norm = ""
        else:
            gold_norm = normalize_categorical(str(gold), field.category_map)
        pred_norm = normalize_categorical(str(pred.value), field.category_map) if pred.value is not None else ""

        if gold_norm and pred.present:
            correct = int(gold_norm == pred_norm)
            return correct, float(correct)

        if (not gold_norm) and (not pred.present):
            return 1, 1.0

        return 0, 0.0

    # Unknown type (should never happen due to validation)
    return 0, 0.0


def compute_metrics(
    y_true: List[int],
    y_pred: List[int],
    y_score: Optional[List[float]] = None
) -> Dict[str, Any]:
    """
    Compute standard binary classification metrics.

    Metrics:
        - confusion_matrix: TN, FP, FN, TP
        - accuracy, precision, recall, f1
        - sensitivity (== recall)
        - specificity = TN / (TN + FP)
        - roc_auc (if y_score provided and valid)

    Args:
        y_true: Ground truth binary labels (0/1).
        y_pred: Binary predictions (0/1).
        y_score: Optional probability-like scores for ROC-AUC.

    Returns:
        Dict with all computed metrics. ROC-AUC may be None if not computable.
    """
    # labels=[0, 1] forces a 2x2 matrix; without it, data containing a single
    # class yields a 1x1 matrix and the 4-way unpack below raises.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    out = {
        "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)},
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "specificity": float(specificity),
        "sensitivity": float(sensitivity),
        "roc_auc": None
    }

    if y_score is not None:
        from sklearn.metrics import roc_auc_score

        # ROC-AUC is undefined when only one class is present; depending on the
        # scikit-learn version that surfaces as an exception or as NaN.
        try:
            auc = float(roc_auc_score(y_true, y_score))
        except (ValueError, TypeError):
            auc = None
        if auc is not None and np.isnan(auc):
            auc = None
        out["roc_auc"] = auc

    return out


In [None]:
# =============================================================================
# Part 2: Uncertainty scoring
# =============================================================================

from sentence_transformers import SentenceTransformer
from numpy.linalg import norm

import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

class ConsistencyResult(BaseModel):
    """
    Stores computed consistency statistics across multiple generations for the same prompt.

    Instances are produced by UncertaintyScorer.sample_consistency.
    """
    # The stripped generated texts, one per sample.
    samples: List[str]
    # Share of samples equal to the most common whitespace/case-normalized output.
    agreement_fraction: float
    # Mean pairwise cosine similarity between sample embeddings.
    embedding_cosine_mean: float
    # Mean pairwise sentence-level BLEU between samples.
    bleu_mean: float
    # Human agreement ratio looked up by sample id, when a human file was loaded.
    human_agreement_fraction: Optional[float] = None


class UncertaintyScorer:
    """
    Computes:
        - Sample Consistency metrics:
            * Majority answer agreement fraction
            * Mean embedding cosine similarity
            * Mean BLEU score
            * Optional human annotation agreement
        - Token-level probability metrics:
            * Average token probability
            * Minimum token probability
        - Calibration diagnostics:
            * Expected Calibration Error (ECE)
            * Brier Score
            * Calibration plots
    """

    def __init__(self, cfg: ExperimentConfig):
        """
        Load the sentence embedder and, if configured, the human-agreement table.

        The human CSV is read only when ``cfg.human_consistency_file`` is set and
        exists; it must provide the columns sample_id, agree_count and total.
        """
        self.cfg = cfg
        self.embedder = SentenceTransformer(cfg.embedding_model)
        self.bleu_smooth = SmoothingFunction().method1
        self.human_map = None
        if cfg.human_consistency_file and os.path.exists(cfg.human_consistency_file):
            df = pd.read_csv(cfg.human_consistency_file)
            # sample_id -> agree_count / total (None when total is not positive).
            self.human_map = {
                int(r["sample_id"]): float(r["agree_count"]) / float(r["total"]) if r["total"] > 0 else None
                for _, r in df.iterrows()
            }

    @staticmethod
    def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; a small epsilon guards against zero norms."""
        return float(np.dot(a, b) / (norm(a) * norm(b) + 1e-8))

    def sample_consistency(
        self,
        client: BaseModelClient,
        prompt: str,
        n_samples: int = 15,
        temperature: float = 0.7,
        seed_base: int = 1000,
        sample_id: Optional[int] = None,
    ) -> ConsistencyResult:
        """
        Generate multiple outputs for the same prompt and compute agreement metrics.

        Args:
            client: Model client used to generate the samples.
            prompt: Prompt repeated for every sample.
            n_samples: Number of generations (must be >= 1).
            temperature: Sampling temperature for every generation.
            seed_base: Sample i is generated with seed ``seed_base + i``.
            sample_id: Optional key into the human-agreement table.

        Returns:
            ConsistencyResult; pairwise means default to 1.0 for a single sample.

        Raises:
            ValueError: if ``n_samples`` is less than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1")

        gens = []
        for i in range(n_samples):
            res = client.generate(prompt, temperature=temperature, seed=seed_base + i)
            gens.append(res["text"].strip())

        # Majority agreement on case/whitespace-normalized outputs
        norm_gens = [re.sub(r"\s+", " ", g.lower()).strip() for g in gens]
        counts = pd.Series(norm_gens).value_counts()
        agreement_fraction = float(counts.iloc[0] / len(norm_gens))

        pairs = [(i, j) for i in range(len(gens)) for j in range(i + 1, len(gens))]

        # Embedding cosine similarities (mean of all pairs)
        embs = self.embedder.encode(gens, convert_to_numpy=True, normalize_embeddings=True)
        sims = [self.cosine_sim(embs[i], embs[j]) for i, j in pairs]
        emb_cos_mean = float(np.mean(sims)) if sims else 1.0

        # BLEU mean pairwise (whitespace tokenization)
        bleu_scores = [
            sentence_bleu([gens[i].split()], gens[j].split(), smoothing_function=self.bleu_smooth)
            for i, j in pairs
        ]
        bleu_mean = float(np.mean(bleu_scores)) if bleu_scores else 1.0

        human_frac = None
        if self.human_map and sample_id is not None:
            human_frac = self.human_map.get(sample_id)

        return ConsistencyResult(
            samples=gens,
            agreement_fraction=agreement_fraction,
            embedding_cosine_mean=emb_cos_mean,
            bleu_mean=bleu_mean,
            human_agreement_fraction=human_frac
        )

    def token_level_probabilities(self, client: BaseModelClient, text: str) -> Dict[str, float]:
        """
        Compute average and minimum token probabilities for a given text.

        Returns:
            Dict with "avg_token_prob" and "min_token_prob"; both are NaN when
            the client returns no log-probabilities.
        """
        logs = client.get_token_logprobs(text)
        out = {"avg_token_prob": np.nan, "min_token_prob": np.nan}
        if logs is None or len(logs) == 0:
            return out
        probs = np.exp(np.array(logs, dtype=np.float64))
        out["avg_token_prob"] = float(np.mean(probs))
        out["min_token_prob"] = float(np.min(probs))
        return out

    # ---- Calibration helpers ----
    @staticmethod
    def _iter_bin_masks(p: np.ndarray, n_bins: int):
        """Yield a boolean mask for each non-empty equal-width bin of [0, 1]."""
        edges = np.linspace(0, 1, n_bins + 1)
        for i in range(n_bins):
            # Bins are half-open [lo, hi) except the last, which includes 1.0.
            upper = p < edges[i + 1] if i < n_bins - 1 else p <= edges[i + 1]
            mask = (p >= edges[i]) & upper
            if np.any(mask):
                yield mask

    @staticmethod
    def brier_score(y_true: List[int], p_pred: List[float]) -> float:
        """Mean squared difference between predicted probabilities and 0/1 outcomes."""
        y, p = np.array(y_true), np.array(p_pred)
        return float(np.mean((p - y) ** 2))

    @staticmethod
    def expected_calibration_error(y_true: List[int], p_pred: List[float], n_bins: int = 10) -> float:
        """
        Expected Calibration Error over ``n_bins`` equal-width bins.

        Each non-empty bin contributes |accuracy - mean confidence| weighted by
        the fraction of samples that fall into it.
        """
        y, p = np.array(y_true), np.array(p_pred)
        ece = 0.0
        for mask in UncertaintyScorer._iter_bin_masks(p, n_bins):
            acc = np.mean(y[mask])
            conf = np.mean(p[mask])
            ece += np.mean(mask) * abs(acc - conf)
        return float(ece)

    @staticmethod
    def plot_calibration(y_true: List[int], p_pred: List[float], title: str, out_path: str, n_bins: int = 10) -> None:
        """
        Save a calibration plot comparing predicted confidence vs. observed accuracy.

        One point is drawn per non-empty bin; the parent directory of
        ``out_path`` is created when the path has one.
        """
        y, p = np.array(y_true), np.array(p_pred)
        xs, ys = [], []
        for mask in UncertaintyScorer._iter_bin_masks(p, n_bins):
            xs.append(np.mean(p[mask]))
            ys.append(np.mean(y[mask]))
        plt.figure()
        plt.plot([0, 1], [0, 1], 'k--', label="Perfect calibration")
        plt.scatter(xs, ys, c='b')
        plt.title(title)
        plt.xlabel("Predicted probability")
        plt.ylabel("Observed accuracy")
        plt.legend()
        # os.makedirs("") raises, so only create a directory when the path has one.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path)
        plt.close()

    def calibrate_metric_to_prob(self, metric_values: List[float], correct_labels: List[int]) -> callable:
        """
        Fit an isotonic regression mapping from metric to P(correct).

        Returns:
            A function taking a sequence of metric values and returning the
            fitted predictions; inputs outside the fitted range are clipped.
        """
        from sklearn.isotonic import IsotonicRegression
        iso = IsotonicRegression(out_of_bounds="clip")
        iso.fit(np.array(metric_values), np.array(correct_labels))
        return lambda vals: iso.predict(np.array(vals))


In [None]:
# =============================================================================
# Part 3: SHAP Explainer
# =============================================================================

import shap
import warnings

class SHAPExplainer:
    """
    Provides tools to apply SHAP to our extraction pipeline.

    Main modes:
        1) Token-level influence on a confidence score:
            - Wrap a model call and scoring function into f(texts) -> confidences.
            - Use shap.maskers.Text with shap.Explainer to see token contributions.
        2) Selecting "interesting" cases for explanation:
            - Based on calibrated probability-of-correctness (near decision boundary).
    """

    def __init__(self):
        # Quiet SHAP through the standard logging registry; this avoids relying
        # on a `shap.logger` attribute on the package.
        logging.getLogger("shap").setLevel(logging.ERROR)
        # Restrict the filter to warnings issued from SHAP modules so that
        # UserWarnings from unrelated libraries stay visible.
        warnings.filterwarnings("ignore", category=UserWarning, module=r"shap(\.|$)")

    @staticmethod
    def make_text_to_confidence_fn(
        client: BaseModelClient,
        field: FieldSpec,
        verbosity: str,
        temperature_for_score: float,
        score_extractor: callable,
    ) -> callable:
        """
        Create a function f(texts) -> confidences for SHAP.

        The returned function re-runs the model on each input text, extracts a
        scalar confidence using the provided score_extractor, and clamps it to
        [0, 1].

        Args:
            client: Model client; its generate() result must expose res["text"].
            field: Field specification passed to PromptBuilder.prompt_text.
            verbosity: Prompt verbosity passed to PromptBuilder.prompt_text.
            temperature_for_score: Temperature used for every scoring call.
            score_extractor: Maps the generated text to a numeric confidence.

        Returns:
            Callable taking an iterable of texts and returning a float64 array
            with one clamped score per text.
        """

        def f(texts: List[str]) -> np.ndarray:
            scores = []
            for t in texts:
                prompt = PromptBuilder.prompt_text(field, t, verbosity=verbosity)
                res = client.generate(prompt, temperature=temperature_for_score)
                score = score_extractor(res["text"])
                scores.append(float(np.clip(score, 0.0, 1.0)))
            return np.array(scores, dtype=np.float64)

        return f

    def explain_token_influence(
        self,
        client: BaseModelClient,
        field: FieldSpec,
        text: str,
        verbosity: str,
        temperature_for_score: float,
        score_extractor: callable,
        max_evals: int = 256,
        seed: int = 123,
        plot: bool = True,
        out_path: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Attribute each token's contribution to the confidence score for a
        single text sample using SHAP with a text masker.

        Args:
            client: Model client (must implement generate()).
            field: FieldSpec for which we are explaining the extraction.
            text: The CT report paragraph to explain.
            verbosity: Prompt verbosity ("low"/"high").
            temperature_for_score: Generation temperature for scoring run.
            score_extractor: Callable that maps model output -> confidence in [0,1].
            max_evals: Maximum number of model evaluations SHAP may spend.
            seed: Seed applied to NumPy's global RNG before explaining.
            plot: Whether to render SHAP's text plot.
            out_path: Optional file path. SHAP's text plot is HTML, so the
                rendered markup is written to this path as UTF-8 text; when
                omitted the plot is displayed inline instead.

        Returns:
            Dict with key "shap_values" holding the SHAP explanation for the
            single example.
        """
        np.random.seed(seed)

        f = self.make_text_to_confidence_fn(client, field, verbosity,
                                            temperature_for_score, score_extractor)
        masker = shap.maskers.Text(tokenizer=None)  # SHAP's built-in simple tokenizer

        # A masker has to be paired with the generic shap.Explainer;
        # KernelExplainer expects a background dataset as its second argument
        # and cannot consume a Text masker.
        explainer = shap.Explainer(f, masker)
        shap_values = explainer([text], max_evals=max_evals, silent=True)

        if plot:
            if out_path:
                # shap.plots.text produces HTML rather than a matplotlib figure,
                # so plt.savefig would only store a blank canvas. Persist the
                # markup itself.
                html = shap.plots.text(shap_values, display=False)
                out_dir = os.path.dirname(out_path)
                if out_dir:  # os.makedirs("") raises for bare file names
                    os.makedirs(out_dir, exist_ok=True)
                with open(out_path, "w", encoding="utf-8") as fh:
                    fh.write(html)
            else:
                shap.plots.text(shap_values)

        return {"shap_values": shap_values}

    @staticmethod
    def select_cases_for_explanation(
        p_correct: List[float],
        ids: List[Any],
        low_thresh: float = 0.4,
        high_thresh: float = 0.6,
        top_k: int = 20
    ) -> List[Any]:
        """
        Select cases with predicted probability-of-correctness near the decision
        boundary (e.g., between 0.4 and 0.6) for closer inspection.

        Args:
            p_correct: Predicted probability of correctness per sample.
            ids: Sample identifiers aligned with ``p_correct``.
            low_thresh: Inclusive lower bound of the selection band.
            high_thresh: Inclusive upper bound of the selection band.
            top_k: Maximum number of ids to return.

        Returns:
            Up to ``top_k`` sample IDs inside the band, ordered by increasing
            distance of their probability from 0.5.
        """
        idxs = [i for i, p in enumerate(p_correct) if low_thresh <= p <= high_thresh]
        idxs_sorted = sorted(idxs, key=lambda i: abs(p_correct[i] - 0.5))
        return [ids[i] for i in idxs_sorted[:top_k]]


In [None]:
# =============================================================================
# Part 3: SHAP Explainer
# =============================================================================

import shap
import warnings

class SHAPExplainer:
    """
    Provides tools to apply SHAP to our extraction pipeline.

    Main modes:
        1) Token-level influence on a confidence score:
            - Wrap a model call and scoring function into f(texts) -> confidences.
            - Use shap.maskers.Text with shap.Explainer to see token contributions.
        2) Selecting "interesting" cases for explanation:
            - Based on calibrated probability-of-correctness (near decision boundary).
    """

    def __init__(self):
        # Quiet SHAP through the standard logging registry; this avoids relying
        # on a `shap.logger` attribute on the package.
        logging.getLogger("shap").setLevel(logging.ERROR)
        # Restrict the filter to warnings issued from SHAP modules so that
        # UserWarnings from unrelated libraries stay visible.
        warnings.filterwarnings("ignore", category=UserWarning, module=r"shap(\.|$)")

    @staticmethod
    def make_text_to_confidence_fn(
        client: BaseModelClient,
        field: FieldSpec,
        verbosity: str,
        temperature_for_score: float,
        score_extractor: callable,
    ) -> callable:
        """
        Create a function f(texts) -> confidences for SHAP.

        The returned function re-runs the model on each input text, extracts a
        scalar confidence using the provided score_extractor, and clamps it to
        [0, 1].

        Args:
            client: Model client; its generate() result must expose res["text"].
            field: Field specification passed to PromptBuilder.prompt_text.
            verbosity: Prompt verbosity passed to PromptBuilder.prompt_text.
            temperature_for_score: Temperature used for every scoring call.
            score_extractor: Maps the generated text to a numeric confidence.

        Returns:
            Callable taking an iterable of texts and returning a float64 array
            with one clamped score per text.
        """

        def f(texts: List[str]) -> np.ndarray:
            scores = []
            for t in texts:
                prompt = PromptBuilder.prompt_text(field, t, verbosity=verbosity)
                res = client.generate(prompt, temperature=temperature_for_score)
                score = score_extractor(res["text"])
                scores.append(float(np.clip(score, 0.0, 1.0)))
            return np.array(scores, dtype=np.float64)

        return f

    def explain_token_influence(
        self,
        client: BaseModelClient,
        field: FieldSpec,
        text: str,
        verbosity: str,
        temperature_for_score: float,
        score_extractor: callable,
        max_evals: int = 256,
        seed: int = 123,
        plot: bool = True,
        out_path: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Attribute each token's contribution to the confidence score for a
        single text sample using SHAP with a text masker.

        Args:
            client: Model client (must implement generate()).
            field: FieldSpec for which we are explaining the extraction.
            text: The CT report paragraph to explain.
            verbosity: Prompt verbosity ("low"/"high").
            temperature_for_score: Generation temperature for scoring run.
            score_extractor: Callable that maps model output -> confidence in [0,1].
            max_evals: Maximum number of model evaluations SHAP may spend.
            seed: Seed applied to NumPy's global RNG before explaining.
            plot: Whether to render SHAP's text plot.
            out_path: Optional file path. SHAP's text plot is HTML, so the
                rendered markup is written to this path as UTF-8 text; when
                omitted the plot is displayed inline instead.

        Returns:
            Dict with key "shap_values" holding the SHAP explanation for the
            single example.
        """
        np.random.seed(seed)

        f = self.make_text_to_confidence_fn(client, field, verbosity,
                                            temperature_for_score, score_extractor)
        masker = shap.maskers.Text(tokenizer=None)  # SHAP's built-in simple tokenizer

        # A masker has to be paired with the generic shap.Explainer;
        # KernelExplainer expects a background dataset as its second argument
        # and cannot consume a Text masker.
        explainer = shap.Explainer(f, masker)
        shap_values = explainer([text], max_evals=max_evals, silent=True)

        if plot:
            if out_path:
                # shap.plots.text produces HTML rather than a matplotlib figure,
                # so plt.savefig would only store a blank canvas. Persist the
                # markup itself.
                html = shap.plots.text(shap_values, display=False)
                out_dir = os.path.dirname(out_path)
                if out_dir:  # os.makedirs("") raises for bare file names
                    os.makedirs(out_dir, exist_ok=True)
                with open(out_path, "w", encoding="utf-8") as fh:
                    fh.write(html)
            else:
                shap.plots.text(shap_values)

        return {"shap_values": shap_values}

    @staticmethod
    def select_cases_for_explanation(
        p_correct: List[float],
        ids: List[Any],
        low_thresh: float = 0.4,
        high_thresh: float = 0.6,
        top_k: int = 20
    ) -> List[Any]:
        """
        Select cases with predicted probability-of-correctness near the decision
        boundary (e.g., between 0.4 and 0.6) for closer inspection.

        Args:
            p_correct: Predicted probability of correctness per sample.
            ids: Sample identifiers aligned with ``p_correct``.
            low_thresh: Inclusive lower bound of the selection band.
            high_thresh: Inclusive upper bound of the selection band.
            top_k: Maximum number of ids to return.

        Returns:
            Up to ``top_k`` sample IDs inside the band, ordered by increasing
            distance of their probability from 0.5.
        """
        idxs = [i for i, p in enumerate(p_correct) if low_thresh <= p <= high_thresh]
        idxs_sorted = sorted(idxs, key=lambda i: abs(p_correct[i] - 0.5))
        return [ids[i] for i in idxs_sorted[:top_k]]


In [None]:
# =============================================================================
# Part 4: ExperimentRunner orchestration
# =============================================================================

import pathlib
import re

def safe_roc_auc(y_true: List[int], scores: List[float]) -> Optional[float]:
    """
    Safe wrapper for roc_auc_score.

    Args:
        y_true: Binary labels.
        scores: Scores aligned with ``y_true``; higher should mean "more positive".

    Returns:
        The ROC AUC as a float, or None when it is undefined: fewer than two
        classes present, scores with (near-)zero variance, or scikit-learn
        rejecting the input with a ValueError.
    """
    y = np.array(y_true, dtype=int)
    s = np.array(scores, dtype=float)
    if len(np.unique(y)) < 2 or np.allclose(np.std(s), 0.0):
        return None

    # Imported only once the inputs are known to be non-degenerate, so the
    # cheap "undefined" answers above do not require scikit-learn.
    from sklearn.metrics import roc_auc_score
    try:
        return float(roc_auc_score(y, s))
    except ValueError as exc:
        # Input problems (NaNs, length mismatch, non-binary labels) surface as
        # ValueError. Anything else is a genuine bug and should propagate
        # instead of being silently turned into None.
        logging.warning("roc_auc_score could not be computed: %s", exc)
        return None


class ExperimentRunner:
    """
    Coordinates all phases:
        1) Phase 1: Extraction & evaluation (per-sample + summary metrics)
        2) Phase 2: Uncertainty computation (sample consistency & token-level probs),
           plus discrimination & calibration.
        3) Phase 3: SHAP explanation on selected borderline-confidence cases.

    NOTE(review): run() below implements Phases 1 and 2 only; the SHAPExplainer
    stored in self.shapx is constructed but not invoked in this class.
    """

    def __init__(self, cfg: ExperimentConfig):
        """
        Build the dataset wrapper, uncertainty scorer and SHAP helper from cfg
        and make sure cfg.output_dir exists.
        """
        self.cfg = cfg
        self.ds = CTRateDataset(cfg)
        self.uncertainty = UncertaintyScorer(cfg)
        self.shapx = SHAPExplainer()
        os.makedirs(cfg.output_dir, exist_ok=True)

    def run(self):
        """
        Execute the experiment for every (model, temperature, verbosity)
        combination in the config and write four CSV files to cfg.output_dir:
        phase1_per_sample.csv, phase1_summary.csv,
        phase2_uncert_per_sample.csv and phase2_uncert_summary.csv.
        Calibration plots are written under <output_dir>/calibration_plots.
        """
        set_seed(self.cfg.seed)
        self.ds.load()

        # Accumulators for later export (shared across all models/settings)
        rows_eval = []             # Phase 1: per-sample
        rows_summary = []          # Phase 1: per-field summary
        rows_uncert = []           # Phase 2: per-sample uncertainty values
        rows_uncert_summary = []   # Phase 2: per-field/metric summary

        # Loop over models and settings
        for model_spec in self.cfg.models:
            logging.info(f"Starting model: {model_spec.name}")
            client = HFOpenCausalLMClient(model_spec)  # For now: HF client only

            for temperature in self.cfg.settings.temperature_list:
                for verbosity in self.cfg.settings.verbosity_list:
                    logging.info(f"Settings: T={temperature}, Verbosity={verbosity}")

                    # Per-field containers for Phase 1 metrics; reset for every
                    # (model, temperature, verbosity) combination.
                    y_true_map, y_pred_map, y_score_map, ids_map = {}, {}, {}, {}
                    metric_store = {}

                    for field in self.cfg.fields:
                        y_true_map[field.name] = []
                        y_pred_map[field.name] = []
                        y_score_map[field.name] = []
                        ids_map[field.name] = []
                        metric_store[field.name] = {"tlp": []}

                    # ===== Phase 1: One extraction per sample =====
                    for sample in self.ds.iter_samples():
                        sid, text, labels = sample["id"], sample["text"], sample["labels"]

                        for field in self.cfg.fields:
                            prompt = PromptBuilder.prompt_text(field, text, verbosity)
                            # Per-sample seed; assumes sid is numeric so that
                            # `seed + sid` is defined — TODO confirm.
                            res = client.generate(prompt, temperature=temperature,
                                                  top_p=model_spec.top_p, top_k=model_spec.top_k,
                                                  max_new_tokens=model_spec.max_new_tokens,
                                                  seed=self.cfg.seed + sid)
                            raw = res["text"]
                            # A falsy parse result is replaced by an empty dict,
                            # which yields present=False / value=None below.
                            parsed = parse_json_output(raw) or {}
                            pred = ExtractedRecord(field_name=parsed.get("field_name", field.name),
                                                   present=bool(parsed.get("present", False)),
                                                   value=parsed.get("value", None))

                            gold = labels[field.name]
                            correct, score = evaluate_field(field, gold, pred)
                            pred_bin = int(bool(pred.present)) if field.type == "qualitative_presence" else int(correct)

                            # NOTE(review): for non-presence fields both y_true and
                            # y_pred receive int(correct), so the two lists are
                            # identical — confirm this is intended before relying
                            # on the compute_metrics output for those fields.
                            y_true_map[field.name].append(int(bool(gold)) if field.type == "qualitative_presence" else int(correct))
                            y_pred_map[field.name].append(pred_bin)
                            y_score_map[field.name].append(score)
                            ids_map[field.name].append(sid)

                            # Token-level probs
                            tlp = self.uncertainty.token_level_probabilities(client, raw)
                            metric_store[field.name]["tlp"].append({"sample_id": sid,
                                                                    "avg_token_prob": tlp["avg_token_prob"],
                                                                    "min_token_prob": tlp["min_token_prob"]})

                            rows_eval.append({
                                "model": model_spec.name,
                                "temperature": temperature,
                                "verbosity": verbosity,
                                "sample_id": sid,
                                "field": field.name,
                                "gold": gold,
                                "pred_present": pred.present,
                                "pred_value": pred.value,
                                "correct": correct,
                                "score": score,
                                "raw_output": raw
                            })

                    # Summarize Phase 1
                    for field in self.cfg.fields:
                        metrics = compute_metrics(y_true_map[field.name],
                                                  y_pred_map[field.name],
                                                  y_score_map[field.name])
                        rows_summary.append({"model": model_spec.name,
                                             "temperature": temperature,
                                             "verbosity": verbosity,
                                             "field": field.name,
                                             **metrics})

                    # ===== Phase 2: Sample consistency per field/sample =====
                    for field in self.cfg.fields:
                        ids_subset = ids_map[field.name][:100]  # cap for cost
                        for t_cons in self.cfg.settings.consistency_temperatures:
                            for sid in ids_subset:
                                # Label-based lookup; presumably sample ids are the
                                # DataFrame index — verify against CTRateDataset.
                                text = self.ds.df.loc[sid, self.cfg.text_column]
                                prompt = PromptBuilder.prompt_text(field, text, verbosity)
                                cres = self.uncertainty.sample_consistency(client, prompt,
                                                                           n_samples=self.cfg.settings.n_consistency_samples,
                                                                           temperature=t_cons,
                                                                           seed_base=self.cfg.seed + 999,
                                                                           sample_id=sid)
                                rows_uncert.append({
                                    "model": model_spec.name,
                                    "temperature": temperature,
                                    "verbosity": verbosity,
                                    "field": field.name,
                                    "sample_id": sid,
                                    "consistency_temp": t_cons,
                                    "agreement_fraction": cres.agreement_fraction,
                                    "embedding_cosine_mean": cres.embedding_cosine_mean,
                                    "bleu_mean": cres.bleu_mean,
                                    "human_agreement_fraction": cres.human_agreement_fraction
                                })

                    # ===== Phase 2: Discrimination + calibration =====
                    # Rebuilt from every row accumulated so far, then filtered
                    # down to the current model/settings below.
                    df_eval = pd.DataFrame(rows_eval)
                    # NOTE(review): df_unc is built but not used anywhere else in
                    # this method; the consistency metrics are not part of the
                    # discrimination/calibration summary.
                    df_unc = pd.DataFrame(rows_uncert)

                    for field in self.cfg.fields:
                        df_field = df_eval[(df_eval["model"] == model_spec.name) &
                                           (df_eval["temperature"] == temperature) &
                                           (df_eval["verbosity"] == verbosity) &
                                           (df_eval["field"] == field.name)].copy()
                        # Attach token-level probs
                        df_tlp = pd.DataFrame(metric_store[field.name]["tlp"])
                        df_field = df_field.merge(df_tlp, on="sample_id", how="left")

                        # Check each uncertainty metric
                        for metric_name in ["avg_token_prob", "min_token_prob"]:
                            dfm = df_field.dropna(subset=[metric_name, "correct"])
                            if dfm.empty: continue
                            y_true = dfm["correct"].astype(int).tolist()
                            vals = dfm[metric_name].astype(float).tolist()
                            auc = safe_roc_auc(y_true, vals)
                            # NOTE(review): the calibrator is fitted and then applied
                            # to the same values, so the ECE/Brier numbers below are
                            # in-sample estimates.
                            calibrator = self.uncertainty.calibrate_metric_to_prob(vals, y_true)
                            p_hat = calibrator(vals)
                            ece = self.uncertainty.expected_calibration_error(y_true, p_hat)
                            brier = self.uncertainty.brier_score(y_true, p_hat)

                            # model_spec.name and field.name are embedded verbatim in
                            # the file name — assumes they contain no path
                            # separators; TODO confirm.
                            plot_path = os.path.join(self.cfg.output_dir, "calibration_plots",
                                                     f"{model_spec.name}_{field.name}_{metric_name}_T{temperature}_V{verbosity}.png")
                            self.uncertainty.plot_calibration(y_true, p_hat, f"{metric_name} calibration", plot_path)

                            rows_uncert_summary.append({
                                "model": model_spec.name, "temperature": temperature, "verbosity": verbosity,
                                "field": field.name, "metric": metric_name, "consistency_temp": None,
                                "roc_auc": auc, "ece": ece, "brier": brier, "n": len(y_true),
                                "calibration_plot": plot_path
                            })

        # ===== Save all outputs =====
        outdir = pathlib.Path(self.cfg.output_dir)
        outdir.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(rows_eval).to_csv(outdir / "phase1_per_sample.csv", index=False)
        pd.DataFrame(rows_summary).to_csv(outdir / "phase1_summary.csv", index=False)
        pd.DataFrame(rows_uncert).to_csv(outdir / "phase2_uncert_per_sample.csv", index=False)
        pd.DataFrame(rows_uncert_summary).to_csv(outdir / "phase2_uncert_summary.csv", index=False)
        logging.info("Run complete — artifacts in %s", self.cfg.output_dir)


# =============================================================================
# Part 5: main() entry point
# =============================================================================

# Script entry point: builds a minimal one-field / one-model example and runs it.
if __name__ == "__main__":
    setup_logging()

    # Minimal example config; replace target_column names with real CT-RATE columns
    fields = [
        FieldSpec(
            name="pulmonary_embolism",
            type="qualitative_presence",
            target_column="pulmonary_embolism"
        )
    ]

    # NOTE(review): ExperimentRunner.run() reads model_spec.top_p, .top_k and
    # .max_new_tokens, while this spec passes temperature/max_tokens — confirm
    # the keyword names against the ModelSpec definition.
    models = [
        ModelSpec(
            name="llama3_8b_instruct",
            model_type="open_hf",  # presumably selects the Hugging Face backend — TODO confirm
            model_path="meta-llama/Meta-Llama-3-8B-Instruct",
            temperature=0.0,
            max_tokens=512
        )
    ]

    # Define your dataset paths
    dataset_paths = [
        "data/ct_rate/train.csv",
        "data/ct_rate/valid.csv"
    ]

    # Example runner configuration
    # NOTE(review): ExperimentRunner.__init__ in this file takes a single
    # `cfg: ExperimentConfig` argument, so this keyword-style call does not
    # match that signature. It likely needs to become
    # ExperimentRunner(ExperimentConfig(...)) — verify the ExperimentConfig
    # fields before changing it.
    runner = ExperimentRunner(
        fields=fields,
        models=models,
        dataset_paths=dataset_paths,
        output_dir="results/ct_rate",
        verbosity="info",
        uncertainty_estimation=False
    )

    runner.run()

    logging.info("Experiment completed successfully.")


In [None]:
"""
api_clients.py
--------------
API client wrappers for OpenAI, Fireworks, and Groq LLMs.

Reads API keys from a `.env` file using `python-dotenv`.

Usage:
    from api_clients import OpenAIClient, FireworksClient, GroqClient
    openai_client = OpenAIClient(model="gpt-4o-mini")
    resp = openai_client.generate("Say hello")
"""

import os
from typing import Optional, Dict, Any, List
from dotenv import load_dotenv

# Load environment variables from .env
# .env should contain:
#   OPENAI_API_KEY=...
#   FIREWORKS_API_KEY=...
#   GROQ_API_KEY=...
load_dotenv()

# Optional: install these packages if not present
# pip install openai fireworks-ai groq python-dotenv

# ----------------------------
# Base class
# ----------------------------
class BaseAPIClient:
    """Minimal shared interface for the hosted-LLM client wrappers."""

    def __init__(self, model: str):
        # Model identifier kept on the instance for use by generate().
        self.model = model

    def generate(self, prompt: str, temperature: float = 0.0, max_tokens: int = 512, **kwargs) -> Dict[str, Any]:
        """
        Produce a completion for ``prompt``.

        The base class provides no implementation and always raises
        NotImplementedError; concrete clients override this method.
        """
        raise NotImplementedError


# ----------------------------
# OpenAI API client
# ----------------------------
class OpenAIClient(BaseAPIClient):
    """
    Chat-completion client for the OpenAI API.

    Works with both SDK generations: the client-object interface of
    ``openai>=1.0`` (``OpenAI().chat.completions.create``) and the legacy
    module-level ``openai.ChatCompletion.create`` of older releases.
    """

    def __init__(self, model: str):
        """
        Args:
            model: Name of the chat model to query.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If OPENAI_API_KEY is not set in the environment.
        """
        super().__init__(model)
        try:
            import openai
        except ImportError as exc:
            raise ImportError("Please install the openai package: pip install openai") from exc
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY not found in environment.")
        self.openai = openai
        self.openai.api_key = api_key
        # openai>=1.0 removed openai.ChatCompletion and exposes an OpenAI
        # client class instead; fall back to the legacy module API otherwise.
        self._client = openai.OpenAI(api_key=api_key) if hasattr(openai, "OpenAI") else None

    def generate(
        self,
        prompt: str,
        temperature: float = 0.0,
        max_tokens: int = 512,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: User message content.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.
            **kwargs: Extra parameters forwarded unchanged to the API call.

        Returns:
            Dict with "text" (reply content, "" if the API returned none) and
            "raw" (the unmodified SDK response).
        """
        messages = [{"role": "user", "content": prompt}]
        if self._client is not None:
            resp = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            # v1 responses are typed objects; message is not subscriptable.
            content = resp.choices[0].message.content
        else:
            resp = self.openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            content = resp.choices[0].message["content"]
        # The API may return no content; keep "text" a string for callers that
        # parse it.
        return {"text": content or "", "raw": resp}


# ----------------------------
# Fireworks AI client
# ----------------------------
class FireworksClient(BaseAPIClient):
    """Text-completion client for Fireworks AI (``fireworks.client``)."""

    def __init__(self, model: str):
        """
        Args:
            model: Fireworks model identifier.

        Raises:
            ImportError: If the ``fireworks-ai`` package is not installed.
            ValueError: If FIREWORKS_API_KEY is not set in the environment.
        """
        super().__init__(model)
        try:
            import fireworks.client
        except ImportError as exc:
            raise ImportError("Please install the fireworks-ai package: pip install fireworks-ai") from exc
        api_key = os.getenv("FIREWORKS_API_KEY")
        if not api_key:
            raise ValueError("FIREWORKS_API_KEY not found in environment.")
        self.fireworks = fireworks.client
        self.fireworks.api_key = api_key

    def generate(
        self,
        prompt: str,
        temperature: float = 0.0,
        max_tokens: int = 512,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Request a raw text completion for ``prompt``.

        Args:
            prompt: Prompt text sent as-is (no chat template is applied here).
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.
            **kwargs: Extra parameters forwarded unchanged to the API call.

        Returns:
            Dict with "text" (first choice's text) and "raw" (the unmodified
            SDK response).
        """
        resp = self.fireworks.Completion.create(
            model=self.model,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )
        # The SDK documents attribute access on the response object
        # (completion.choices[0].text); plain dict responses are still
        # accepted for backward compatibility.
        choices = resp["choices"] if isinstance(resp, dict) else resp.choices
        first_choice = choices[0]
        content = first_choice["text"] if isinstance(first_choice, dict) else first_choice.text
        return {"text": content, "raw": resp}


# ----------------------------
# Groq client
# ----------------------------
class GroqClient(BaseAPIClient):
    """Chat-completion client backed by the Groq Python SDK."""

    def __init__(self, model: str):
        """
        Create the SDK client from GROQ_API_KEY.

        Raises:
            ImportError: If the ``groq`` package is not installed.
            ValueError: If GROQ_API_KEY is not set in the environment.
        """
        super().__init__(model)
        try:
            from groq import Groq
        except ImportError:
            raise ImportError("Please install the groq package: pip install groq")
        key = os.getenv("GROQ_API_KEY")
        if not key:
            raise ValueError("GROQ_API_KEY not found in environment.")
        self.client = Groq(api_key=key)

    def generate(
        self,
        prompt: str,
        temperature: float = 0.0,
        max_tokens: int = 512,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Send ``prompt`` as a single user message and return the reply as
        {"text": <message content>, "raw": <SDK response>}.
        """
        request = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Unpacking both mappings keeps the duplicate-keyword TypeError if a
        # caller passes one of the fixed parameters again via kwargs.
        response = self.client.chat.completions.create(**request, **kwargs)
        reply = response.choices[0].message
        return {"text": reply.content, "raw": response}


In [None]:
Example `.env` file (replace the placeholders with your own keys; do not commit this file):
OPENAI_API_KEY=sk-...
FIREWORKS_API_KEY=fk-...
GROQ_API_KEY=gk-...


In [None]:
# Example usage in the experiment
from api_clients import OpenAIClient, FireworksClient, GroqClient

# One client per provider; each constructor reads its API key from the environment.
openai_client = OpenAIClient(
    model="gpt-4o-mini",
)
fw_client = FireworksClient(
    model="accounts/fireworks/models/llama-v3-70b-instruct",
)
groq_client = GroqClient(
    model="mixtral-8x7b-32768",
)

# Smoke test: send one prompt through the OpenAI client and show the reply text.
openai_reply = openai_client.generate("Hello from OpenAI")
print(openai_reply["text"])


In [None]:
"""
config_settings.py
------------------
Centralized configuration for experiment runs.
Keeps dataset paths, model parameters, and other settings in one place.

Import this module in your main script:
    from config_settings import FIELDS, MODELS, DATASET_PATHS, OUTPUT_DIR
"""

import os
from pathlib import Path

# -----------------------
# Project structure setup
# -----------------------
# Directory that contains this module; all paths below are anchored to it.
BASE_DIR = Path(__file__).resolve().parent

# -----------------------
# Dataset settings
# -----------------------
# Input CSV files (train and validation splits under data/ct_rate).
_CT_RATE_DATA_DIR = BASE_DIR.joinpath("data", "ct_rate")
DATASET_PATHS = [_CT_RATE_DATA_DIR / split_file for split_file in ("train.csv", "valid.csv")]

# -----------------------
# Output settings
# -----------------------
OUTPUT_DIR = BASE_DIR.joinpath("results", "ct_rate")

# -----------------------
# Experiment parameters
# -----------------------
TEMPERATURE = 0.0
MAX_TOKENS = 512
VERBOSITY = "info"
UNCERTAINTY_ESTIMATION = False

# -----------------------
# Field and model specs
# -----------------------
from experiment_specs import FieldSpec, ModelSpec
# If FieldSpec & ModelSpec are in another module, update this import path accordingly

# The single field extracted in this example configuration.
_PULMONARY_EMBOLISM_FIELD = FieldSpec(
    name="pulmonary_embolism",
    type="qualitative_presence",
    target_column="pulmonary_embolism",
)
FIELDS = [_PULMONARY_EMBOLISM_FIELD]

# The single model evaluated in this example configuration; sampling
# parameters come from the module-level constants.
_LLAMA3_8B_INSTRUCT = ModelSpec(
    name="llama3_8b_instruct",
    model_type="open_hf",  # change if using API clients: "openai", "fireworks", "groq"
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
)
MODELS = [_LLAMA3_8B_INSTRUCT]

# -----------------------
# Utility function
# -----------------------
def get_config_summary() -> str:
    """Return a five-line, human-readable summary of the module-level settings."""
    model_names = [spec.name for spec in MODELS]
    summary_lines = [
        f"Datasets: {DATASET_PATHS}",
        f"Output dir: {OUTPUT_DIR}",
        f"Models: {model_names}",
        f"Temperature: {TEMPERATURE}, Max tokens: {MAX_TOKENS}",
        f"Verbosity: {VERBOSITY}, Uncertainty: {UNCERTAINTY_ESTIMATION}",
    ]
    # Lines are newline-separated with no trailing newline.
    return "\n".join(summary_lines)


In [None]:
# RUN your Experiment
from config_settings import FIELDS, MODELS, DATASET_PATHS, OUTPUT_DIR, VERBOSITY, UNCERTAINTY_ESTIMATION
from utils import setup_logging
from experiment_runner import ExperimentRunner
import logging

# Script entry point: run the experiment with the values from config_settings.
if __name__ == "__main__":
    setup_logging()

    # NOTE(review): confirm that ExperimentRunner accepts these keyword
    # arguments — TODO verify against its __init__ signature (it may instead
    # expect a single configuration object).
    runner = ExperimentRunner(
        fields=FIELDS,
        models=MODELS,
        dataset_paths=DATASET_PATHS,
        output_dir=OUTPUT_DIR,
        verbosity=VERBOSITY,
        uncertainty_estimation=UNCERTAINTY_ESTIMATION
    )

    runner.run()
    logging.info("Experiment completed successfully.")
