# In [1]:
# ======================= Imports =======================
import os
import sys
import time
import json
import unicodedata
import re
import subprocess
from contextlib import contextmanager
import shutil
import traceback
import tiktoken
import pandas as pd
import numpy as np
import requests
import faiss
from typing import List, Dict, Any
from fastembed import TextEmbedding
from rapidfuzz import fuzz as rfuzz
from tabulate import tabulate
from tqdm import tqdm
from collections import defaultdict
from sentence_transformers import SentenceTransformer

# ======================= Global Constants =======================
# JSON config written by initialize_groq_environment(); when it exists,
# check_and_initialize() loads it instead of prompting interactively.
FLAG_FILE = "Gemini-flash-lite-2.5.flag"

# Pickle checkpoint path per pipeline-state key (the keys used by
# load_state / save_state_checkpoint / main). cleanup() removes all of them
# after a successful run.
TEMP_FILES = dict(
    input='temp_input.pkl',
    combined='temp_combined_results.pkl',
    exact='temp_exact_matches.pkl',
    non_exact='temp_non_exact_matches.pkl',
    final='temp_final_result.pkl',
    responses='responses_backup.pkl',
)

# ======================= Logger Class =======================
class Logger:
    """Minimal console logger that prefixes messages with a local timestamp."""

    def __init__(self):
        # Messages already emitted with once=True. The message itself is
        # stored (not hash(msg)) so two different messages can never collide
        # and silently suppress each other.
        self.printed_messages = set()

    def log(self, msg, once=False):
        """Print ``msg`` as ``'YYYY-mm-dd HH:MM:SS - msg'``.

        Args:
            msg: Message to print; must be hashable when ``once`` is True.
            once: If True, print only the first time this exact message is
                logged through this Logger instance.
        """
        if once:
            if msg in self.printed_messages:
                return
            self.printed_messages.add(msg)
        print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {msg}")


logger = Logger()

# ======================= LLM Client =======================
class LLMClient:
    """Client for an OpenAI-compatible chat-completions endpoint.

    Applies a simple client-side rate limit (a fixed sleep before each request)
    and a token budget. NOTE: ``total_tokens_used`` is never reset by this
    class, so the "per day" budget effectively spans the instance's lifetime.
    """

    def __init__(self, api_key, base_url, model_name="llama3-groq-70b-8192-tool-use-preview",
                 max_tokens_per_day=500000, max_queries_per_minute=30, temperature=0.7,
                 request_timeout=120):
        """
        Args:
            api_key: Bearer token for the Authorization header.
            base_url: Full URL that requests are POSTed to.
            model_name: Model id sent in the payload; also selects the tiktoken
                encoding (falls back to cl100k_base for unknown models).
            max_tokens_per_day: Token budget checked before every request.
            max_queries_per_minute: Determines the sleep before each request;
                a value <= 0 (or None) disables the sleep.
            temperature: Sampling temperature sent in the payload.
            request_timeout: Seconds to wait for the HTTP response (forwarded to
                ``requests.post``); None waits indefinitely.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.max_tokens_per_day = max_tokens_per_day
        self.max_queries_per_minute = max_queries_per_minute
        self.temperature = temperature
        self.request_timeout = request_timeout
        self.total_tokens_used = 0
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        try:
            self.encoder = tiktoken.encoding_for_model(self.model_name)
        except KeyError:
            self.encoder = tiktoken.get_encoding("cl100k_base")

    def query(self, user_input, system_message):
        """Send one system + user chat request and return the reply text.

        Token accounting uses the server-reported ``usage.total_tokens`` when
        present, otherwise the local prompt-size estimate.

        Returns:
            The first choice's message content, or "" when there are no
            choices or the content is null.
        Raises:
            Exception: if the estimated prompt would exceed the token budget.
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if no response arrives within ``request_timeout``.
        """
        estimated = (len(self.encoder.encode(user_input))
                     + len(self.encoder.encode(system_message)))
        if self.total_tokens_used + estimated > self.max_tokens_per_day:
            raise Exception("Token limit exceeded for the day.")

        qpm = self.max_queries_per_minute
        if qpm and qpm > 0:
            # Fixed spacing between calls keeps us under the per-minute limit.
            time.sleep(60 / qpm)

        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user",   "content": user_input}
            ],
            "temperature": self.temperature,
        }
        resp = requests.post(self.base_url, headers=self.headers, json=payload,
                             timeout=self.request_timeout)
        resp.raise_for_status()
        result = resp.json()

        usage = result.get("usage") or {}
        if "total_tokens" in usage:
            self.total_tokens_used += usage["total_tokens"]
        else:
            self.total_tokens_used += estimated

        choices = result.get("choices") or []
        if not choices:
            return ""
        message = choices[0].get("message") or {}
        # Some providers send "content": null; callers expect a string.
        return message.get("content") or ""

# ======================= Prompt Loading =======================
def load_prompts(file_path="system_prompts.json"):
    """Load the system-prompt JSON file used by the pipeline.

    Exits the process with status 1 when the file is missing or is not valid
    JSON, consistent with the other fatal start-up checks in this script.

    Returns:
        The decoded JSON value. Callers read it with ``.get(...)``, so it is
        presumably a dict of prompt name -> prompt text.
    """
    if not os.path.exists(file_path):
        logger.log(f"Error: Prompt file '{file_path}' not found.")
        sys.exit(1)
    try:
        # Explicit UTF-8: prompts may contain non-ASCII text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        logger.log(f"Error: Prompt file '{file_path}' is not valid JSON: {e}")
        sys.exit(1)

prompts = load_prompts()
# Stage prompts passed by main(): I -> process_row, II -> generate_hpo_terms.
# A missing key yields "" rather than an error.
system_message_I, system_message_II = (
    prompts.get(prompt_key, "")
    for prompt_key in ("system_message_I", "system_message_II")
)

# ======================= Environment Setup =======================
llm_client = None  # created by check_and_initialize() / initialize_groq_environment()

def check_and_initialize():
    """Create the global ``llm_client`` from FLAG_FILE or via interactive setup.

    Returns:
        True if an existing, readable config file was used; False if the
        interactive setup (``initialize_groq_environment``) ran instead. That
        includes the case where FLAG_FILE exists but is corrupt or lacks
        ``api_key``.
    """
    global llm_client
    if not os.path.exists(FLAG_FILE):
        initialize_groq_environment()
        return False

    try:
        with open(FLAG_FILE, "r", encoding="utf-8") as f:
            config = json.load(f)
        api_key = config["api_key"]
    except (OSError, ValueError, KeyError, TypeError) as e:
        # Corrupt/incomplete config: re-run setup instead of crashing startup.
        logger.log(f"Config file '{FLAG_FILE}' is unusable ({e}); re-running setup.")
        initialize_groq_environment()
        return False

    llm_client = LLMClient(
        api_key=api_key,
        # Fall back to the same defaults the interactive setup offers; an
        # empty URL or model name can never produce a working client.
        base_url=config.get("base_url") or "https://api.groq.com/openai/v1/chat/completions",
        model_name=config.get("model_name") or "llama3-groq-70b-8192-tool-use-preview",
        max_tokens_per_day=config.get("max_tokens_per_day", 500000),
        max_queries_per_minute=config.get("max_queries_per_minute", 30),
        temperature=config.get("temperature", 0.7)
    )
    return True

def initialize_groq_environment():
    """Interactively collect LLM settings, create ``llm_client`` and persist them.

    The settings are written as JSON to FLAG_FILE so later runs skip these
    prompts. Empty, non-numeric or non-positive numeric answers fall back to
    the advertised defaults.
    NOTE(review): the API key is stored in plain text in FLAG_FILE.
    """
    global llm_client

    def _ask_positive_int(prompt, default):
        # Invalid or non-positive input uses the default instead of crashing
        # (a 0 queries/minute value would otherwise divide by zero later).
        raw = input(prompt).strip()
        try:
            value = int(raw) if raw else default
        except ValueError:
            return default
        return value if value > 0 else default

    api_key = ""
    while not api_key:  # the key is required; keep asking until one is given
        api_key = input("Enter your API key (required): ").strip()
    base_url = (input("Enter base URL (default https://api.groq.com/openai/v1/chat/completions): ").strip()
                or "https://api.groq.com/openai/v1/chat/completions")
    model_name = (input("Enter model name (default llama3-groq-70b-8192-tool-use-preview): ").strip()
                  or "llama3-groq-70b-8192-tool-use-preview")
    max_tokens = _ask_positive_int("Max tokens/day (default 500000): ", 500000)
    max_qpm = _ask_positive_int("Max queries/minute (default 30): ", 30)
    temp_raw = input("Temperature (0.0–1.0, default 0.2): ").strip()
    try:
        temp = float(temp_raw) if temp_raw else 0.2
    except ValueError:
        temp = 0.2  # same default the prompt advertises

    llm_client = LLMClient(api_key, base_url, model_name, max_tokens, max_qpm, temp)
    with open(FLAG_FILE, "w", encoding="utf-8") as f:
        json.dump({
            "api_key": api_key,
            "base_url": base_url,
            "model_name": model_name,
            "max_tokens_per_day": max_tokens,
            "max_queries_per_minute": max_qpm,
            "temperature": temp
        }, f)

# ======================= Subprocess & Logging =======================
@contextmanager
def managed_subprocess(*args, **kwargs):
    """Run ``subprocess.Popen(*args, **kwargs)`` and always reap it on exit.

    On leaving the ``with`` block a still-running child is terminated
    (SIGTERM on POSIX). If it has not exited within a 5 s grace period it is
    killed, so exiting can never hang on a child that ignores SIGTERM. Any
    stdin/stdout/stderr pipes are closed afterwards.

    Yields:
        The ``subprocess.Popen`` object.
    """
    proc = subprocess.Popen(*args, **kwargs)
    try:
        yield proc
    finally:
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        for stream in (proc.stdin, proc.stdout, proc.stderr):
            if stream is not None:
                try:
                    stream.close()
                except OSError:
                    pass  # e.g. BrokenPipeError flushing stdin of a dead child

def timestamped_print(msg):
    """Convenience alias: emit ``msg`` through the module-level ``logger`` (timestamped)."""
    emit = logger.log
    emit(msg)

# ======================= Embeddings & FAISS =======================
# Non-greedy "(...)" matcher used to strip parenthesised qualifiers from labels.
PAT = re.compile(r'\(.*?\)')


def clean_text(txt: str) -> str:
    """Normalise a term label for matching.

    Removes parenthesised spans, turns underscores into spaces, lower-cases
    and trims surrounding whitespace. ``None`` or "" yields "".
    """
    source = txt if txt else ''
    without_parens = PAT.sub('', source)
    spaced = without_parens.replace('_', ' ')
    return spaced.lower().strip()

def initialize_embeddings_model(use_sbert: bool = True, sbert_model: str = 'pritamdeka/SapBERT-mnli-snli-scinli-scitail-mednli-stsb', bge_model: str = 'BAAI/bge-small-en-v1.5'):
    """Load the query-embedding backend.

    Args:
        use_sbert: True loads ``SentenceTransformer(sbert_model)``; False loads
            fastembed ``TextEmbedding(model_name=bge_model)``.
    Returns:
        The loaded model object.
    Exits the process with status 1 if loading fails.
    """
    def _build():
        if not use_sbert:
            return TextEmbedding(model_name=bge_model)
        return SentenceTransformer(sbert_model)

    try:
        return _build()
    except Exception as e:
        logger.log(f"[FATAL] Could not initialize embedding model: {e}")
        sys.exit(1)

def load_vector_db(meta_path: str = 'hpo_meta.json',
                   vec_path:  str = 'hpo_embedded.npz'):
    """Load HPO term metadata together with their precomputed embeddings.

    Args:
        meta_path: JSON file whose top-level object has an ``entries`` list
            (one dict per embedding row, each with ``hp_id``) and a
            ``constants`` dict mapping ``hp_id`` to shared attributes
            (lineage, organ_system, depth, parent/child/descendant counts).
        vec_path: ``.npz`` archive holding the embedding matrix under key ``emb``.

    Returns:
        ``(docs, emb_matrix)``:
        - ``emb_matrix`` is a C-contiguous float32 array.
        - ``docs`` has one dict per (entry, row) pair. When the counts differ,
          a warning is logged and the longer side is truncated, as ``zip`` does.
        - Each doc's ``embedding`` is a view into ``emb_matrix``.

    Exits the process with status 1 when a file is missing or unreadable.
    """
    # ─── Sanity checks ───
    if not os.path.exists(meta_path) or not os.path.exists(vec_path):
        logger.log(f"[FATAL] DB files not found: {meta_path}, {vec_path}")
        sys.exit(1)

    # ─── Load the condensed JSON ───
    try:
        with open(meta_path, 'r', encoding='utf-8') as f:
            combined = json.load(f)
        constants = combined.get('constants', {}) or {}
        entries = combined.get('entries', []) or []
    except Exception as e:
        logger.log(f"[FATAL] Could not load metadata JSON: {e}")
        sys.exit(1)

    # ─── Load the embeddings ───
    try:
        # The context manager closes the NpzFile's underlying file handle;
        # FAISS needs C-contiguous float32 data.
        with np.load(vec_path) as arr:
            emb_matrix = np.ascontiguousarray(arr['emb'], dtype=np.float32)
    except Exception as e:
        logger.log(f"[FATAL] Could not load embedding npz: {e}")
        sys.exit(1)

    # ─── Warn if lengths mismatch ───
    if len(entries) != emb_matrix.shape[0]:
        logger.log("[WARN] Metadata entries count and embedding rows mismatch "
                   f"({len(entries)} vs {emb_matrix.shape[0]})")

    # ─── Reconstruct docs list in the original output format ───
    docs = []
    for entry, vec in zip(entries, emb_matrix):
        hp_id = entry.get('hp_id')
        const = constants.get(hp_id, {})
        docs.append({
            'hp_id':          hp_id,
            'info':           entry.get('info'),
            # read by _collect_metadata_best; None unless the metadata has it
            'definition':     entry.get('definition', const.get('definition')),
            'lineage':        const.get('lineage'),
            'organ_system':   const.get('organ_system'),
            'direction':      entry.get('direction'),
            # preserve these keys even if absent in the new JSON:
            'depth':          const.get('depth'),
            'parent_count':   const.get('parent_count'),
            'child_count':    const.get('child_count'),
            'descendant_count': const.get('descendant_count'),
            'embedding':      vec
        })

    return docs, emb_matrix
def create_faiss_index(emb_matrix: np.ndarray, metric: str = 'cosine'):
    """Build an exact (flat) FAISS index over the rows of ``emb_matrix``.

    Args:
        emb_matrix: 2-D array with one embedding per row.
        metric: 'cosine' L2-normalises the rows and uses an inner-product
            index, so search scores are cosine similarities (higher = closer).
            Any other value uses ``IndexFlatL2``, whose scores are squared
            Euclidean distances (lower = closer).

    FAISS only accepts C-contiguous float32 data. If ``emb_matrix`` already is,
    the cosine normalisation happens in place, and callers holding views of it
    see normalised values. Otherwise a converted copy is normalised.

    Raises:
        ValueError: if ``emb_matrix`` is not 2-D.
    """
    emb = np.ascontiguousarray(emb_matrix, dtype=np.float32)
    if emb.ndim != 2:
        raise ValueError(f"emb_matrix must be 2-D, got shape {emb.shape}")
    dim = emb.shape[1]
    if metric == 'cosine':
        faiss.normalize_L2(emb)
        index = faiss.IndexFlatIP(dim)
    else:
        index = faiss.IndexFlatL2(dim)
    index.add(emb)
    return index

def embed_query(text: str, model, metric: str = 'cosine'):
    """Embed one query string as a ``(1, dim)`` float32 array ready for FAISS.

    Args:
        text: Query text.
        model: A model exposing ``.encode`` (sentence-transformers style) or,
            failing that, ``.embed`` over a list of texts (fastembed style).
        metric: With 'cosine', the vector is L2-normalised so inner-product
            search yields cosine similarity.
    """
    if hasattr(model, 'encode'):
        vec = model.encode(text, convert_to_numpy=True)
    else:
        vec = next(iter(model.embed([text])))
    # normalize_L2 and index.search only accept C-contiguous float32 arrays;
    # some backends return float64.
    vec = np.asarray(vec, dtype=np.float32)
    if vec.ndim == 1:
        vec = vec.reshape(1, -1)
    vec = np.ascontiguousarray(vec)
    if metric == 'cosine':
        faiss.normalize_L2(vec)
    return vec

# Typographic punctuation that NFKD leaves unchanged; map it to ASCII so the
# non-ASCII filter in clean_note does not blank it out.
_ASCII_PUNCT = str.maketrans({
    '\u2018': "'", '\u2019': "'",   # single curly quotes
    '\u201c': '"', '\u201d': '"',   # double curly quotes
    '\u2013': '-', '\u2014': '-',   # en / em dash
})


def clean_note(text: str) -> str:
    """Normalise a clinical note to single-spaced ASCII text.

    Steps:
      1. Repair UTF-8 mojibake (UTF-8 bytes mis-read as Latin-1, e.g.
         "cafÃ©" -> "café"). This is applied only when the round trip is
         lossless, so correctly decoded text such as "café" is kept.
      2. Map curly quotes and en/em dashes to ASCII.
      3. NFKD-decompose and drop combining marks ("é" -> "e", "ï" -> "i").
      4. Replace each remaining run of non-ASCII characters with a space.
      5. Collapse whitespace and trim.

    Empty input returns ""; non-string input is converted with ``str``.
    """
    if not text:
        return ""
    if not isinstance(text, str):
        text = str(text)
    try:
        text = text.encode('latin1').decode('utf-8')
    except UnicodeError:
        pass  # not mojibake (or not Latin-1 representable): keep text as is
    text = text.translate(_ASCII_PUNCT)
    text = unicodedata.normalize("NFKD", text)
    # Drop combining marks instead of turning them into spaces, so
    # "naïve" -> "naive" rather than "nai ve".
    text = ''.join(ch for ch in text if not unicodedata.combining(ch))
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    return re.sub(r'\s+', ' ', text).strip()

# ======================= Phenotype Processing =======================
def _collect_metadata_best(
    phrase: str,
    query_vec: np.ndarray,
    index: faiss.Index,
    docs: List[Dict[str, Any]],
    top_k: int = 500,
    similarity_threshold: float = 0.35,
    min_unique: int = 15,
    max_unique: int = 20
) -> List[Dict[str, Any]]:
    """
    Single‐pass hybrid retrieval: token overlap, threshold, fill to min_unique.
    """
    clean_tokens = set(re.findall(r'\w+', phrase.lower()))
    dists, idxs = index.search(query_vec, top_k)
    sims, indices = dists[0], idxs[0]

    seen_hp = set()
    results = []

    for sim, idx in sorted(zip(sims, indices), key=lambda x: x[0], reverse=True):
        if len(results) >= max_unique:
            break
        doc = docs[idx]
        hp = doc.get('hp_id')
        if not hp or hp in seen_hp:
            continue

        info = doc.get('info', '') or ''
        token_overlap = bool(clean_tokens & set(re.findall(r'\w+', info.lower())))

        # accept if token overlap, or above similarity threshold, or to reach min_unique
        if token_overlap or sim >= similarity_threshold or len(results) < min_unique:
            seen_hp.add(hp)
            results.append({
                'hp_id': hp,
                'phrase': info,
                'definition': doc.get('definition'),
                'organ_system': doc.get('organ_system'),
                'similarity': float(sim)
            })

    return results

def split_exact_nonexact(df: pd.DataFrame, hpo_term_col='HPO_Term'):
    """Split the 'Abnormal' rows of ``df`` by whether an HPO term is already set.

    Returns:
        ``(exact_df, non_exact_df)``: copies of the Abnormal rows whose
        ``hpo_term_col`` is non-null, and null, respectively. Both keep the
        original index.
    Raises:
        KeyError: if 'category' (checked first) or ``hpo_term_col`` is missing.
    """
    for required in ('category', hpo_term_col):
        if required not in df.columns:
            raise KeyError(f"[FATAL] '{required}' missing; columns: {df.columns.tolist()}")

    abnormal = df[df['category'] == 'Abnormal']
    has_term = abnormal[hpo_term_col].notna()
    return abnormal[has_term].copy(), abnormal[~has_term].copy()

def _sentence_token_sets(clinical_note: str):
    """Split ``clinical_note`` on '.' into (sentence, word-token set) pairs, skipping blanks."""
    pairs = []
    for raw in clinical_note.split('.'):
        sentence = raw.strip()
        if sentence:
            pairs.append((sentence, set(re.findall(r'\b\w+\b', sentence.lower()))))
    return pairs


def _best_matching_sentence(phrase: str, sentence_tokens):
    """Return the first sentence sharing the most word tokens with ``phrase`` (None if none overlap)."""
    phrase_tokens = set(re.findall(r'\b\w+\b', phrase.lower()))
    best_sent, best_score = None, 0
    for sentence, tokens in sentence_tokens:
        score = len(phrase_tokens & tokens)
        if score > best_score:
            best_sent, best_score = sentence, score
    return best_sent


def process_findings(findings, clinical_note: str, embeddings_model, index, docs,
                     metric: str = 'cosine',
                     keep_top: int = 15):
    """Attach HPO candidates and source-sentence context to each finding.

    Args:
        findings: Iterable of finding dicts with 'phrase', 'category' and
            optionally 'patient_id'. Non-dict items and empty/missing phrases
            are skipped.
        clinical_note: The (cleaned) note the findings came from.
        embeddings_model, index, docs: Passed to embed_query /
            _collect_metadata_best.
        metric: Embedding metric passed to embed_query.
        keep_top: Number of unique metadata entries to retrieve per phrase.

    Returns:
        DataFrame with columns phrase, category, unique_metadata,
        original_sentence and patient_id (no columns when nothing qualified).
    """
    # Tokenise the note's sentences once, not once per finding.
    sentence_tokens = _sentence_token_sets(clinical_note or '')
    rows = []

    for f in findings:
        if not isinstance(f, dict):
            continue
        phrase = str(f.get('phrase') or '').strip()
        if not phrase:
            continue
        category = f.get('category', '')

        qv = embed_query(phrase, embeddings_model, metric=metric)
        unique_metadata = _collect_metadata_best(
            phrase=phrase,               # literal text for token overlap
            query_vec=qv,                # embedded query vector
            index=index,                 # FAISS index
            docs=docs,                   # list of HPO docs
            top_k=500,                   # FAISS retrieval size
            similarity_threshold=0.35,   # minimum similarity for a semantic match
            min_unique=keep_top,         # ensure at least keep_top entries
            max_unique=keep_top          # cap at keep_top entries
        )

        rows.append({
            'phrase':           phrase,
            'category':         category,
            'unique_metadata':  unique_metadata,
            'original_sentence': _best_matching_sentence(phrase, sentence_tokens),
            'patient_id':       f.get('patient_id')
        })

    return pd.DataFrame(rows)

def clean_and_parse(s: str):
    """Decode the JSON embedded in an LLM reply.

    Parses the span from the first '{' to the last '}' (across newlines), or
    the whole stripped string when no braces are present. Returns the decoded
    value, or None on any failure (including non-string input).
    """
    try:
        braces = re.search(r'\{.*\}', s, flags=re.S)
    except Exception:
        return None
    payload = s.strip() if braces is None else braces.group(0)
    try:
        return json.loads(payload)
    except Exception:
        return None

def extract_findings(response: str) -> list:
    """Return the ``phenotypes`` list from an LLM JSON reply.

    Always returns a list:
      * [] if the reply is empty, has no parseable JSON object, or its
        ``phenotypes`` value is missing, null or not a list/object;
      * a single finding given as a bare object is wrapped in a 1-item list.
    """
    if not response:
        return []
    parsed = clean_and_parse(response)
    if not isinstance(parsed, dict):
        return []
    phenotypes = parsed.get("phenotypes")
    if isinstance(phenotypes, list):
        return phenotypes
    if isinstance(phenotypes, dict):
        return [phenotypes]
    return []

# ======================= Single-Row Processing =======================
def process_row(clinical_note, system_message, embeddings_model, index, embedded_documents):
    """Extract abnormal findings from one note and attach HPO candidates.

    The note is cleaned with clean_note and sent to the LLM with
    ``system_message``. The reply is parsed with increasingly lenient
    strategies, each tried only when the previous one found nothing:
      1. extract_findings (``phenotypes`` of the embedded JSON object);
      2. the whole reply decoded as a JSON list;
      3. every flat ``{...}`` fragment that decodes to an object containing
         both 'phrase' and 'category'.
    Only dict findings whose category is 'Abnormal' are kept.

    Returns:
        The process_findings DataFrame, guaranteed to contain the columns
        phrase, category, unique_metadata, original_sentence and patient_id.
        It is an empty frame with those columns when nothing qualifies.
    """
    expected_columns = ['phrase', 'category', 'unique_metadata', 'original_sentence', 'patient_id']

    note = clean_note(clinical_note)
    reply = llm_client.query(note, system_message)

    found = extract_findings(reply)

    if not found:
        # Strategy 2: some models answer with a bare JSON list.
        try:
            decoded = json.loads(reply)
        except Exception:
            decoded = None
        if isinstance(decoded, list):
            found = decoded

    if not found:
        # Strategy 3: salvage individual flat objects from free text.
        try:
            fragments = re.findall(r'\{[^\}]*\}', reply)
        except Exception:
            fragments = None
        if fragments is not None:
            found = []
            for fragment in fragments:
                try:
                    obj = json.loads(fragment)
                    if 'phrase' in obj and 'category' in obj:
                        found.append(obj)
                except Exception:
                    continue

    abnormal = [item for item in found
                if isinstance(item, dict) and item.get('category') == 'Abnormal']
    if not abnormal:
        return pd.DataFrame(columns=expected_columns)

    frame = process_findings(abnormal, note, embeddings_model, index, embedded_documents)
    for column in expected_columns:
        if column not in frame.columns:
            frame[column] = np.nan
    return frame

# ======================= HPO Term Extraction =======================
def _iter_term_hp(item):
    # Yields (term_text, hp_id) pairs from metadata entries
    if isinstance(item, str):
        try:
            d = json.loads(item)
        except Exception:
            return
    elif isinstance(item, dict):
        d = item
    else:
        return
    if "hp_id" in d:
        term_text = d.get("info") or d.get("label")
        hp = d["hp_id"]
        if hp and term_text:
            yield term_text, hp
        return
    for k, v in d.items():
        if isinstance(v, str) and v.startswith("HP:"):
            yield k, v

def build_cluster_index(metadata_list):
    """Index metadata terms by bag-of-words signature for order-insensitive lookup.

    Each term (via _iter_term_hp) is normalised with clean_text. Terms that
    become empty are skipped. The structure is
    ``{sorted-token signature: {token count: [hp_id, ...]}}``, with hp_ids
    kept in input order (duplicates included). Both levels are defaultdicts.
    """
    buckets = defaultdict(lambda: defaultdict(list))
    for record in metadata_list:
        for label, hpo in _iter_term_hp(record):
            normalised = clean_text(label)
            if not normalised:
                continue
            words = normalised.split()
            buckets[" ".join(sorted(words))][len(words)].append(hpo)
    return buckets

def extract_hpo_term(phrase, metadata_list, cluster_index):
    """Map a phenotype phrase to an HPO id using progressively looser matching.

    Strategies, in order (first hit wins):
      1. ``cluster_index`` lookup by sorted-token signature and token count.
      2. Exact match of the cleaned phrase against a cleaned candidate term.
      3. Same *set* of tokens as a candidate term.
      4. Multi-word phrases only: a candidate term occurring as a whole-word
         substring of the phrase.
      5. Highest rapidfuzz ``token_sort_ratio``, accepted if >= 80.

    Args:
        phrase: Raw phrase text (normalised with clean_text).
        metadata_list: Candidate entries as accepted by _iter_term_hp.
            None, NaN or empty yields None.
        cluster_index: Mapping as built by build_cluster_index.
    Returns:
        The matched hp_id, or None.
    """
    if not metadata_list or (isinstance(metadata_list, float) and pd.isna(metadata_list)):
        return None
    cp = clean_text(phrase)
    if not cp:
        return None
    toks = cp.split()
    sig = " ".join(sorted(toks))
    if sig in cluster_index and len(toks) in cluster_index[sig]:
        return cluster_index[sig][len(toks)][0]

    pairs = []
    for entry in metadata_list:
        for term, hp in _iter_term_hp(entry):
            ct = clean_text(term)
            if ct:
                pairs.append((ct, hp))

    # Exact string match is checked before the token-set match. In the
    # reverse order it could never fire, because an exact match always also
    # has an equal token set (so an earlier, reordered term would win).
    for ct, hp in pairs:
        if ct == cp:
            return hp

    pset = set(toks)
    for ct, hp in pairs:
        if set(ct.split()) == pset:
            return hp

    if len(pset) > 1:
        for ct, hp in pairs:
            if re.search(rf"\b{re.escape(ct)}\b", cp):
                return hp

    best_hp, best_score = None, 0
    for ct, hp in pairs:
        score = rfuzz.token_sort_ratio(cp, ct)
        if score > best_score:
            best_hp, best_score = hp, score
    return best_hp if best_score >= 80 else None

def parse_llm_mapping(resp_text: str, candidate_ids: set) -> "tuple[str | None, str, dict | None]":
    """Extract the HPO id chosen by the LLM from its raw reply.

    Parsing order:
      1. Strict JSON. If that fails, the outermost ``{...}`` span (handles
         ```json fences or prose around the object).
      2. From a JSON object, the id under the first of hpo_id / HPO_ID /
         hp_id / id that holds a string.
      3. Regex scan for ``HP:`` + 6-7 digits (not followed by another digit),
         restricted to ``candidate_ids``.

    Returns:
        ``(hpo_id, reason, parsed_json)``, where ``reason`` is one of:
          "ok"                    id from JSON, present in candidate_ids
          "hp_not_in_candidates"  id from JSON, not a candidate (still returned)
          "null_label"            JSON explicitly says null / "No Candidate Fit"
          "regex_fallback"        candidate id found by the regex scan
          "no_hpo_found"          nothing usable (also for non-string input)
        ``parsed_json`` is the decoded dict when one was used, else None.
    """
    id_keys = ("hpo_id", "HPO_ID", "hp_id", "id")
    text = resp_text if isinstance(resp_text, str) else ""

    # 1. Strict JSON, then an object embedded in surrounding text.
    try:
        js = json.loads(text)
    except ValueError:
        js = None
        embedded = re.search(r"\{.*\}", text, flags=re.S)
        if embedded:
            try:
                js = json.loads(embedded.group(0))
            except ValueError:
                js = None

    # 2. Look for an HPO ID in known keys.
    if isinstance(js, dict):
        candidate = next(
            (js[k].strip().strip('"') for k in id_keys if isinstance(js.get(k), str)),
            None
        )
        if candidate:
            if candidate.lower() in ("null", "no candidate fit"):
                return None, "null_label", js
            if candidate in candidate_ids:
                return candidate, "ok", js
            return candidate, "hp_not_in_candidates", js
        # An explicit JSON null means "no fit"; don't let the regex below pick
        # up an id the model merely mentioned in its explanation.
        if any(k in js and js[k] is None for k in id_keys):
            return None, "null_label", js

    # 3. Regex fallback for HP:NNNNNNN.
    for m in re.findall(r"HP:\d{6,7}(?!\d)", text):
        if m in candidate_ids:
            return m, "regex_fallback", None

    # 4. Nothing found.
    return None, "no_hpo_found", None


def generate_hpo_terms(df_row: pd.DataFrame, system_message: str) -> pd.DataFrame:
    """Ask the LLM to choose an HPO id for one finding, with a local fallback.

    Args:
        df_row: DataFrame whose first row has phrase, category,
            original_sentence and unique_metadata (a list of candidate dicts
            with 'phrase'/'info' and 'hp_id').
        system_message: Stage-II system prompt.

    The LLM receives the lower-cased, de-hyphenated phrase plus the de-duplicated
    candidates as JSON. If no id can be parsed from its reply, extract_hpo_term
    is tried locally on the same candidates.

    Returns:
        One-row DataFrame with:
          HPO_Terms         [{'phrase': <original phrase>, 'HPO_Term': <id or None>}]
          raw_llm_resp      the raw LLM reply
          llm_parse_reason  reason from parse_llm_mapping, or "fallback_local"
    """
    def _is_missing(value):
        return value is None or (isinstance(value, float) and pd.isna(value))

    row = df_row.iloc[0]

    # 1. Extract & normalize inputs (missing cells must not crash .strip()).
    phrase = "" if _is_missing(row['phrase']) else str(row['phrase']).strip()
    normalized = phrase.lower().replace('-', ' ').strip()
    category = row['category']
    original = row['original_sentence']
    if _is_missing(original):
        original = None  # json.dumps would otherwise emit the invalid token NaN

    raw_meta = row['unique_metadata']
    if isinstance(raw_meta, (list, tuple, np.ndarray)):
        metadata_list = list(raw_meta)
    else:
        metadata_list = []  # None / NaN / unexpected type -> no candidates

    # 2. Build the de-duplicated candidate list (non-dict entries are skipped).
    candidates = []
    seen = set()
    for m in metadata_list:
        if not isinstance(m, dict):
            continue
        term = m.get('phrase') or m.get('info')
        hp = m.get('hp_id')
        if term and hp and hp not in seen:
            candidates.append({'term': term, 'id': hp})
            seen.add(hp)
    candidate_ids = {c['id'] for c in candidates}

    # 3. Call the LLM and parse its choice.
    payload = json.dumps({
        'phrase': normalized,
        'category': category,
        'original_sentence': original,
        'candidates': candidates
    })
    resp = llm_client.query(payload, system_message)
    hpo_id, reason, _ = parse_llm_mapping(resp, candidate_ids)

    # 4. Local fallback if the LLM gave nothing usable.
    if not hpo_id:
        cluster_idx = build_cluster_index(metadata_list)
        local_id = extract_hpo_term(normalized, metadata_list, cluster_idx)
        if local_id:
            hpo_id, reason = local_id, "fallback_local"

    # 5. Return unified record.
    return pd.DataFrame([{
        'HPO_Terms': [{'phrase': phrase, 'HPO_Term': hpo_id}],
        'raw_llm_resp': resp,
        'llm_parse_reason': reason
    }])


def standardize_input(df):
    """Alias of validate_input, kept for existing callers."""
    validated = validate_input(df)
    return validated

# ======================= Input Validation & State =======================
def validate_input(df):
    """Validate and normalise the notes table.

    - Requires a 'clinical_note' column.
    - Drops rows with a missing note, casts notes to str and cleans them with
      clean_note.
    - Without a 'patient_id' column, assigns ids 1..n. Otherwise ids are
      converted to int. If they are not all integer-like (e.g. "P001"), they
      are kept as strings so alphanumeric ids remain usable.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    Raises:
        KeyError: if 'clinical_note' is missing.
        ValueError: if 'patient_id' exists but has missing values.
    """
    if 'clinical_note' not in df.columns:
        raise KeyError("Missing required column: 'clinical_note'.")
    df = df.dropna(subset=['clinical_note']).copy()
    # Clean all clinical notes to prevent encoding-related bugs.
    df['clinical_note'] = df['clinical_note'].astype(str).apply(clean_note)
    if 'patient_id' not in df.columns:
        df = df.reset_index(drop=True)
        df['patient_id'] = df.index + 1
        return df

    ids = df['patient_id']
    if ids.isna().any():
        raise ValueError("Column 'patient_id' has missing values; fill them or remove the column.")
    try:
        df['patient_id'] = ids.astype(int)
    except (ValueError, TypeError):
        df['patient_id'] = ids.astype(str)
    return df

def load_state(temp_files):
    """Load every checkpoint listed in ``temp_files`` (``{key: pickle path}``).

    Each key is always present in the result. A missing file, or one that
    fails to unpickle (a warning is printed), gives an empty DataFrame.
    NOTE: unpickling is only safe for trusted files; these are the pipeline's
    own checkpoints.
    """
    state = {}
    for name, location in temp_files.items():
        frame = pd.DataFrame()
        try:
            if os.path.exists(location):
                frame = pd.read_pickle(location)
        except Exception as e:
            print(f"Warning loading '{name}': {e}. Starting fresh for this key.")
        state[name] = frame
    return state

def _iter_result_terms(final_df):
    """Yield (patient_id, category, phrase, hpo_id) for every term in ``final_df['HPO_Terms']``.

    ``patient_id`` falls back to the row index if absent. Non-list HPO_Terms
    cells and non-dict terms are skipped. A missing phrase or id becomes "",
    and a doubled "HP:HP:" prefix is collapsed to "HP:".
    """
    for idx, r in final_df.iterrows():
        pid = r.get('patient_id', idx)
        terms = r.get('HPO_Terms', [])
        if not isinstance(terms, (list, tuple, np.ndarray)):
            continue
        for term in terms:
            if not isinstance(term, dict):
                continue
            ph = str(term.get('phrase') or '').strip()
            cat = term.get('category', '')
            hp = term.get('HPO_Term') or ''
            if isinstance(hp, str):
                hp = hp.replace("HP:HP:", "HP:")
            yield pid, cat, ph, hp


def process_results(final_df):
    """Interactively save (CSV) or display (psql table) the final HPO mappings.

    'save' writes two files:
      * one row per term that has an id (Patient ID, Category, Phenotype
        name, HPO ID) to the chosen file; terms without an id are reported
        and left out;
      * the unmodified ``final_df`` to "<name>_json_raw.csv".
    'display' prints every term, including those without an id.
    """
    if final_df.empty:
        logger.log("No final results to process.")
        return
    choice = input("Save results as CSV or display? (save/display): ").strip().lower()
    if choice == 'save':
        fname = input("Output CSV filename (e.g., results.csv): ").strip()
        if not fname:
            print("Filename cannot be empty. Skipping save.")
            return
        rows = []
        for pid, cat, ph, hp in _iter_result_terms(final_df):
            if hp:
                rows.append({
                    'Patient ID': pid,
                    'Category': cat,
                    'Phenotype name': ph,
                    'HPO ID': hp
                })
            else:
                print(f"Warning: Blank HPO_Term for patient {pid}, phrase '{ph}', category '{cat}' - not included in CSV.")
        if rows:
            pd.DataFrame(rows).to_csv(fname, index=False)
            logger.log(f"Saved tabular results to {fname}")
        else:
            logger.log("No valid HPO terms to save in tabular format.")
        raw_fname = f"{os.path.splitext(fname)[0]}_json_raw.csv"
        final_df.to_csv(raw_fname, index=False)
        logger.log(f"Saved raw JSON results to {raw_fname}")
    elif choice == 'display':
        tbl = [
            {'Case': f"Case {pid}", 'Category': cat, 'Phenotype name': ph, 'HPO ID': hp}
            for pid, cat, ph, hp in _iter_result_terms(final_df)
        ]
        if tbl:
            print(tabulate(pd.DataFrame(tbl), headers='keys', tablefmt='psql'))
        else:
            logger.log("No terms to display.")
    else:
        print("Invalid choice; please enter 'save' or 'display'.")

def cleanup(temp_files, success):
    """Delete checkpoint files after a successful run; keep them otherwise.

    Kept files let the next run resume. Removal errors are logged, not raised.
    """
    if not success:
        logger.log("Pipeline failed. Keeping temporary files for debugging/resume.")
        return
    logger.log("Pipeline succeeded. Cleaning up temporary files...")
    for target in temp_files.values():
        if not os.path.exists(target):
            continue
        try:
            os.remove(target)
        except OSError as e:
            logger.log(f"Error removing temp file {target}: {e}")

def save_state_checkpoint(state, temp_files, keys=('input','combined', 'exact', 'non_exact', 'final')):
    """Pickle the selected non-empty DataFrames of ``state`` to their checkpoint paths.

    Each frame is written to "<path>.tmp" and then moved over the target with
    ``os.replace``. That replaces an existing checkpoint in one step (atomic
    on POSIX), so a crash mid-write never leaves a truncated checkpoint.
    Missing or empty frames, and keys without a configured path, are skipped.
    Write errors are logged, not raised.
    """
    for key in keys:
        df = state.get(key)
        if df is None or df.empty:
            continue
        rel_path = temp_files.get(key)
        if not rel_path:
            logger.log(f"Warning: No path configured for '{key}'. Skipping.")
            continue
        abs_path = os.path.abspath(rel_path)
        tmp_path = abs_path + ".tmp"
        try:
            df.to_pickle(tmp_path)
            os.replace(tmp_path, abs_path)
            logger.log(f"[SAVE] Checkpointed '{key}' ({len(df)} rows) -> {abs_path}", once=True)
        except Exception as e:
            logger.log(f"Error saving '{key}': {e}")
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass  # best effort: a stale .tmp is overwritten on the next save

# ======================= Main Pipeline =======================
def main():
    """Run the HPO extraction pipeline end to end, resuming from checkpoints.

    Stages (progress is checkpointed to TEMP_FILES so an interrupted run resumes):
      1. Initialise the LLM client (config file or interactive setup).
      2. Load existing checkpoints.
      3. Ingest notes (typed in or from a CSV) unless already checkpointed.
      4. Load the embedding model, HPO vector DB, FAISS and cluster indices.
      5. Stage I: LLM finding extraction per patient note -> 'combined'.
      6. Local HPO matching, then split into 'exact' / 'non_exact'.
      7. Stage II: LLM HPO selection for still-unmatched findings.
      8. Group mapped findings per patient -> 'final'.
      9. Save or display the results.
    On error, progress is saved and the checkpoints are kept. They are deleted
    only after a fully successful run.
    """
    # Columns of a findings table; used for well-formed empty frames when
    # stage I produced no findings at all.
    finding_cols = ['phrase', 'category', 'unique_metadata',
                    'original_sentence', 'patient_id', 'HPO_Term']

    # 1) initialize environment
    if not check_and_initialize():
        print("Environment initialized.")
    timestamped_print("Starting HPO extraction pipeline...")
    start_time = time.time()

    # 2) load checkpointed state
    state = load_state(TEMP_FILES)
    for key, df in state.items():
        if not df.empty:
            timestamped_print(f"Loaded checkpoint '{key}' ({len(df)} rows)")
        else:
            timestamped_print(f"No checkpoint found for '{key}', starting fresh.")

    # 3) ingest notes unless a previous run already stored them or their
    #    stage-I results; an interrupted run resumes from 'input' instead of
    #    asking for the notes again.
    if state['input'].empty and state['combined'].empty:
        if input("Manual notes? (yes/no): ").strip().lower() == 'yes':
            notes = []
            while True:
                note = input("Note (or 'done'): ")
                if note.lower() == 'done':
                    break
                notes.append(note)
            df_input = pd.DataFrame({'clinical_note': notes})
        else:
            while True:
                fname = input("CSV filename: ")
                try:
                    raw = pd.read_csv(fname)
                    df_input = standardize_input(raw)
                    break
                except Exception as e:
                    print(f"Error loading CSV: {e}")
        state['input'] = validate_input(df_input)
        save_state_checkpoint(state, TEMP_FILES, keys=['input'])
    elif not state['combined'].empty:
        timestamped_print("Resuming from existing 'combined' checkpoint.")
    else:
        timestamped_print("Resuming from existing 'input' checkpoint.")

    # 4) initialize models and indices
    emb_model    = initialize_embeddings_model(use_sbert=True)
    docs, emb_matrix = load_vector_db(meta_path='hpo_meta.json', vec_path='hpo_embedded.npz')
    index        = create_faiss_index(emb_matrix, metric='cosine')
    cluster_index = build_cluster_index(docs)

    success = False
    try:
        # 5) Stage I. Resumable: patients already present in 'combined' are
        #    skipped, so a partial checkpoint is continued rather than treated
        #    as complete. Once step 6 has checkpointed a split, stage I is
        #    final. (Notes that yielded no findings leave no trace in
        #    'combined' and are re-queried on resume.)
        split_done = not state['exact'].empty or not state['non_exact'].empty
        notes_df = state['input']
        combined = state['combined']
        pending = []
        if not split_done and not notes_df.empty:
            done = (set(combined['patient_id'])
                    if 'patient_id' in combined.columns else set())
            pending = [pid for pid in sorted(notes_df['patient_id'].unique())
                       if pid not in done]

        if pending:
            timestamped_print("Processing clinical notes...")
            total = len(pending)
            for i, pid in enumerate(tqdm(pending, desc="Processing Notes", unit="note")):
                note = notes_df.loc[notes_df['patient_id'] == pid, 'clinical_note'].iloc[0]
                res  = process_row(note, system_message_I, emb_model, index, docs)
                if not res.empty:
                    res['patient_id'] = pid
                    combined = pd.concat([combined, res], ignore_index=True)
                # Keep state current so the exception handler below persists
                # every processed note, not just the last periodic checkpoint.
                state['combined'] = combined
                # checkpoint every 20 notes and after the last one
                if (i + 1) % 20 == 0 or (i + 1) == total:
                    save_state_checkpoint(state, TEMP_FILES, keys=['combined'])
        else:
            timestamped_print("Skipped clinical processing; checkpoint exists.")

        # 6) split into exact vs non-exact
        if state['exact'].empty or state['non_exact'].empty:
            timestamped_print("Splitting exact vs non-exact…")
            df = state['combined'].copy()

            if df.empty:
                # No findings at all: split_exact_nonexact would raise on the
                # missing 'category' column, so use empty, well-formed frames.
                exact_df = pd.DataFrame(columns=finding_cols)
                non_exact_df = pd.DataFrame(columns=finding_cols)
            else:
                # ensure HPO_Term column is present
                if 'HPO_Term' not in df.columns:
                    df['HPO_Term'] = np.nan
                # compute HPO_Term locally for any row still missing one
                df['HPO_Term'] = (
                    df.apply(
                        lambda r: extract_hpo_term(r['phrase'], r['unique_metadata'], cluster_index)
                                if pd.isna(r['HPO_Term']) else r['HPO_Term'],
                        axis=1
                    )
                    .astype(object)
                    .where(lambda x: pd.notna(x), np.nan)
                )
                exact_df, non_exact_df = split_exact_nonexact(df, hpo_term_col='HPO_Term')

            state['exact']     = exact_df
            state['non_exact'] = non_exact_df
            save_state_checkpoint(state, TEMP_FILES, keys=['exact', 'non_exact'])
        else:
            timestamped_print("Skipped splitting; checkpoints exist.")

        # 7) Stage II: LLM HPO selection for Abnormal findings still lacking a term.
        non_ex = state['non_exact'].copy()
        for col in ("llm_parse_reason", "raw_llm_resp"):
            non_ex[col] = non_ex.get(col, pd.Series(dtype="object")).astype("object")

        if {'category', 'HPO_Term'}.issubset(non_ex.columns):
            # Answered rows hold an id or "No Candidate Fit" (never NaN), so
            # they are not re-queried on resume.
            idxs = non_ex[
                (non_ex['category'] == 'Abnormal') & (non_ex['HPO_Term'].isna())
            ].index
        else:
            idxs = pd.Index([])

        if not idxs.empty:
            timestamped_print(f"Generating HPO for {len(idxs)} entries...")
            # Share the frame with state so an exception saves in-loop progress.
            state['non_exact'] = non_ex
            counter = 0
            for i, idx in enumerate(tqdm(idxs, desc="Generating HPO", unit="entry")):
                row_df = non_ex.loc[[idx]]
                out_df = generate_hpo_terms(row_df, system_message_II)
                hp     = (out_df.at[0, 'HPO_Terms'][0]['HPO_Term']
                          if not out_df.empty else None)
                # attach parse reason & raw response
                if 'llm_parse_reason' in out_df.columns:
                    non_ex.at[idx, 'llm_parse_reason'] = out_df.at[0, 'llm_parse_reason']
                if 'raw_llm_resp' in out_df.columns:
                    non_ex.at[idx, 'raw_llm_resp']       = out_df.at[0, 'raw_llm_resp']
                non_ex.at[idx, 'HPO_Term'] = hp or "No Candidate Fit"

                counter += 1
                if counter >= 50 or (i + 1) == len(idxs):
                    save_state_checkpoint(state, TEMP_FILES, keys=['non_exact'])
                    counter = 0
        else:
            timestamped_print("No non-exact entries to process.")

        # 8) compile final results
        if state['final'].empty:
            timestamped_print("Compiling final results...")
            merged = pd.concat([state['exact'], state['non_exact']], ignore_index=True)
            needed = {'patient_id', 'phrase', 'category', 'HPO_Term'}
            if needed.issubset(merged.columns):
                merged = merged.dropna(subset=['HPO_Term'])
            else:
                merged = merged.iloc[0:0]  # nothing mappable (e.g. no findings)
            if not merged.empty:
                grouped = (
                    merged
                      .groupby('patient_id')[['phrase','category','HPO_Term']]
                      .apply(lambda g: g.to_dict('records'))
                      .reset_index(name='HPO_Terms')
                )
                state['final'] = grouped.copy()
            else:
                state['final'] = pd.DataFrame(columns=['patient_id','HPO_Terms'])
            save_state_checkpoint(state, TEMP_FILES, keys=['final'])
        else:
            timestamped_print("Skipped final compilation; checkpoint exists.")

        # 9) output
        process_results(state['final'])
        success = True
        logger.log(f"Pipeline completed in {time.time() - start_time:.2f}s")

    except Exception as e:
        logger.log(f"Pipeline error: {e}. Saving progress and exiting.")
        save_state_checkpoint(state, TEMP_FILES)
        traceback.print_exc()

    finally:
        cleanup(TEMP_FILES, success)

if __name__ == "__main__":
    # Run the pipeline only when this file/cell is executed directly, not on import.
    main()

2025-08-26 02:34:13 - Starting HPO extraction pipeline...
2025-08-26 02:34:13 - No checkpoint found for 'input', starting fresh.
2025-08-26 02:34:13 - No checkpoint found for 'combined', starting fresh.
2025-08-26 02:34:13 - No checkpoint found for 'exact', starting fresh.
2025-08-26 02:34:13 - No checkpoint found for 'non_exact', starting fresh.
2025-08-26 02:34:13 - No checkpoint found for 'final', starting fresh.
2025-08-26 02:34:13 - No checkpoint found for 'responses', starting fresh.


Manual notes? (yes/no):  yes
Note (or 'done'):  A syndrome of brachydactyly (absence of some middle or distal phalanges), aplastic or hypoplastic nails, symphalangism (ankylois of proximal interphalangeal joints), synostosis of some carpal and tarsal bones, craniosynostosis, and dysplastic hip joints is reported in five members of an Italian family. It may represent a previously undescribed autosomal dominant trait.
Note (or 'done'):  Townes-Brocks syndrome (TBS) is an autosomal dominant disorder with multiple malformations and variable expression. Major findings include external ear anomalies, hearing loss, preaxial polydactyly and triphalangeal thumbs, imperforate anus, and renal malformations. Most patients with Townes-Brocks syndrome have normal intelligence, although mental retardation has been noted in a few.
Note (or 'done'):  done


2025-08-26 02:34:40 - [SAVE] Checkpointed 'input' (2 rows) -> /mnt/AI/wgs-database/phenotype_extractors/RAG-HPO/temp_input.pkl
2025-08-26 02:34:48 - Processing clinical notes...


Processing Notes: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.84s/note]


2025-08-26 02:34:56 - [SAVE] Checkpointed 'combined' (13 rows) -> /mnt/AI/wgs-database/phenotype_extractors/RAG-HPO/temp_combined_results.pkl
2025-08-26 02:34:56 - Splitting exact vs non-exact…
2025-08-26 02:34:56 - [SAVE] Checkpointed 'exact' (7 rows) -> /mnt/AI/wgs-database/phenotype_extractors/RAG-HPO/temp_exact_matches.pkl
2025-08-26 02:34:56 - [SAVE] Checkpointed 'non_exact' (6 rows) -> /mnt/AI/wgs-database/phenotype_extractors/RAG-HPO/temp_non_exact_matches.pkl
2025-08-26 02:34:56 - Generating HPO for 6 entries...


Generating HPO: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:16<00:00,  2.74s/entry]

2025-08-26 02:35:13 - Compiling final results...
2025-08-26 02:35:13 - [SAVE] Checkpointed 'final' (2 rows) -> /mnt/AI/wgs-database/phenotype_extractors/RAG-HPO/temp_final_result.pkl





Save results as CSV or display? (save/display):  save
Output CSV filename (e.g., results.csv):  gemini_flash_result.csv


2025-08-26 02:35:39 - Saved tabular results to gemini_flash_result.csv
2025-08-26 02:35:39 - Saved raw JSON results to gemini_flash_result_json_raw.csv
2025-08-26 02:35:39 - Pipeline completed in 85.89s
2025-08-26 02:35:39 - Pipeline succeeded. Cleaning up temporary files...


In [2]:
# ======================= Batch HPO Term Replacement =======================
import csv
import requests
from tqdm import tqdm
import time

def get_ols_term_status(hpo_id, max_retries=3, sleep_seconds=1):
    # Checks OLS for HPO term status and replacement
    if not hpo_id.startswith("HP:"):
        print(f"[SKIP] Invalid HPO ID format: {hpo_id}")
        return False, None, "invalid_format"
    iri = f"http://purl.obolibrary.org/obo/{hpo_id.replace(':', '_')}"
    url = f"https://www.ebi.ac.uk/ols4/api/ontologies/hp/terms?iri={requests.utils.quote(iri)}"
    for attempt in range(1, max_retries + 1):
        try:
            r = requests.get(url, timeout=10)
            r.raise_for_status()
            data = r.json()
            term = data.get("_embedded", {}).get("terms", [{}])[0]
            is_obsolete = term.get("is_obsolete", False)
            replacement_id = term.get("term_replaced_by")
            return is_obsolete, replacement_id, None
        except requests.exceptions.RequestException as e:
            print(f"[WARN] {hpo_id} failed (attempt {attempt}): {e}")
            if attempt < max_retries:
                time.sleep(sleep_seconds * attempt)
            else:
                return False, None, str(e)
    return False, None, "unknown_error"

def process_and_replace_all(infile, outfile):
    # Reads CSV, replaces obsolete HPO terms, writes updated CSV
    failures = []
    replaced_count = 0
    skipped_count = 0
    with open(infile, newline='') as rf:
        reader = list(csv.DictReader(rf))
        fieldnames = reader[0].keys() | {"was_replaced", "original_term"}
        with open(outfile, 'w', newline='') as wf:
            writer = csv.DictWriter(wf, fieldnames=fieldnames)
            writer.writeheader()
            for row in tqdm(reader, desc="Processing HPO terms", unit="row"):
                hpo_field = row.get("hpo_term", "").strip()
                if not hpo_field or hpo_field.lower() == "none":
                    row["original_term"] = hpo_field
                    row["was_replaced"] = "False"
                    skipped_count += 1
                    writer.writerow(row)
                    continue
                hpo_ids = [h.strip() for h in hpo_field.split(",") if h.strip().startswith("HP:")]
                replaced_ids = []
                was_any_replaced = False
                for hpo_id in hpo_ids:
                    is_obs, replacement, error = get_ols_term_status(hpo_id)
                    if error:
                        failures.append({"hpo_id": hpo_id, "error": error, "row": row})
                        replaced_ids.append(hpo_id)
                        continue
                    if is_obs and replacement:
                        replaced_ids.append(replacement)
                        was_any_replaced = True
                        replaced_count += 1
                        print(f"[REPLACED] {hpo_id} → {replacement}")
                    elif is_obs and not replacement:
                        print(f"[OBSOLETE w/o replacement] {hpo_id}")
                        replaced_ids.append(hpo_id)
                    else:
                        replaced_ids.append(hpo_id)
                row["original_term"] = hpo_field
                row["hpo_term"] = ", ".join(replaced_ids)
                row["was_replaced"] = str(was_any_replaced)
                writer.writerow(row)
    print(f"\n✅ Batch complete. Updated file saved to:\n{outfile}")
    print(f"🟢 Terms replaced: {replaced_count}")
    print(f"🟡 Skipped (empty or 'none'): {skipped_count}")
    print(f"🔴 Failed API lookups: {len(failures)}")
    if failures:
        failfile = outfile.replace(".csv", "_failures.csv")
        with open(failfile, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=["hpo_id", "error", "row"])
            writer.writeheader()
            writer.writerows(failures)
        print(f"🗂️ Failures saved to: {failfile}")

if __name__ == "__main__":
    input_file = "gemini_flash_result.csv"
    output_file = "gemini_flash_result_updated.csv"
    process_and_replace_all(input_file, output_file)

Processing HPO terms: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 75730.49row/s]


✅ Batch complete. Updated file saved to:
gemini_flash_result_updated.csv
🟢 Terms replaced: 0
🟡 Skipped (empty or 'none'): 13
🔴 Failed API lookups: 0





In [3]:
import pandas as pd
import re  # [Added] for splitting alt_ids

def enrich_results_with_hpo_info(results_csv: str,
                                 hpo_full_csv: str,
                                 output_csv: str,
                                 hpo_col: str = 'hpo_term') -> None:
    """
    Canonicalise HPO ids in a results CSV and attach their cross-references.

    Reads *results_csv* and *hpo_full_csv* with every column as a string.
    *hpo_full_csv* must provide 'hp_id', 'alt_ids', 'snomedct' and 'umls'.
    Every alt_id found in the results' HPO column is replaced by its canonical
    hp_id; comma-separated cells are rewritten id by id. The 'alt_ids',
    'snomedct' and 'umls' columns are then left-joined on the HPO column, and
    the result is written to *output_csv*. Rows whose HPO cell is not a single
    known hp_id (free text, several ids) get empty cross-reference columns.

    Args:
        results_csv: pipeline output to enrich.
        hpo_full_csv: full HPO term table.
        output_csv: destination path.
        hpo_col: HPO-id column in *results_csv*. An exact name match wins;
            otherwise a case-insensitive match is used (e.g. 'HPO_Term').

    Raises:
        KeyError: if *results_csv* has no column matching *hpo_col*, or
            *hpo_full_csv* lacks one of the required columns.
    """
    # 1. Load the results file, treating all columns as strings
    results_df = pd.read_csv(results_csv, dtype=str).fillna('')

    target = hpo_col.strip().casefold()
    col = hpo_col if hpo_col in results_df.columns else next(
        (c for c in results_df.columns if str(c).strip().casefold() == target),
        None,
    )
    if col is None:
        raise KeyError(
            f"column {hpo_col!r} not found in {results_csv}; "
            f"available columns: {list(results_df.columns)}"
        )

    # 2. Load the full HPO table and keep one info row per canonical id
    hpo_full_df = pd.read_csv(hpo_full_csv, dtype=str).fillna('')
    hpo_info_df = (
        hpo_full_df
        .loc[:, ['hp_id', 'alt_ids', 'snomedct', 'umls']]
        .drop_duplicates(subset=['hp_id'])
    )

    # 3. Build the alt_id -> canonical hp_id mapping. HPO ids contain no
    #    ';', ',', '|' or whitespace, so any of those acts as a separator.
    alt_map = {}
    for hp_id, alt_ids in zip(hpo_info_df['hp_id'], hpo_info_df['alt_ids']):
        if not hp_id or not alt_ids:
            continue
        for alt in re.split(r'[;,|\s]+', alt_ids):
            if alt:
                alt_map[alt] = hp_id

    # 4. Canonicalise the HPO column. Multi-id cells (", "-joined by the
    #    replacement step) are mapped token by token; cells with nothing to
    #    map are kept byte-identical.
    def _canonicalise(cell: str) -> str:
        tokens = [t.strip() for t in cell.split(',')]
        mapped = [alt_map.get(t, t) for t in tokens]
        return ', '.join(mapped) if mapped != tokens else cell

    original_terms = results_df[col].copy()
    results_df[col] = results_df[col].map(_canonicalise)
    num_replaced = (results_df[col] != original_terms).sum()
    print(f"[INFO] Replaced {num_replaced} alt_id entries with canonical hp_id")

    # 5. Left-join the cross-references. Renaming 'hp_id' to the join column
    #    leaves no redundant key to drop, and cannot clash with an existing
    #    'hp_id' column in the results.
    enriched_df = results_df.merge(
        hpo_info_df.rename(columns={'hp_id': col}),
        how='left',
        on=col,
    )

    # 6. Write out the enriched DataFrame
    enriched_df.to_csv(output_csv, index=False)

# Example invocation: enrich the OLS-updated results with alt_ids / SNOMED CT / UMLS columns.
enrich_results_with_hpo_info(
    'gemini_flash_result_updated.csv',           # results_csv
    'hpo_terms_full.csv',                        # hpo_full_csv
    'gemini_flash_result_updated_enriched.csv',  # output_csv
)

KeyError: 'hpo_term'