In [17]:
import pandas as pd
import requests
import json
from typing import Tuple, List
from tqdm import tqdm
import os
import sys
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables
load_dotenv(override=True)
api_key = os.getenv('OPENAI_API_KEY')

# API key validation
if not api_key:
    print("❌ No API key found – check your .env file.")
elif not api_key.startswith("sk-proj-"):
    print("⚠️ API key found, but it doesn't start with 'sk-proj-' – check you're using the correct project key.")
elif api_key.strip() != api_key:
    print("⚠️ API key contains extra spaces or tabs – clean it up in your .env file.")
else:
    print("✅ API key loaded successfully.")

# OpenAI client setup
openai = OpenAI(api_key=api_key)

# ========== CONFIG ==========
INPUT_FILE = "health_dataset_full.csv"
OUTPUT_FILE = "health-data-processed.csv"
CHECKPOINT_FILE = "checkpoint.json"
CHECKPOINT_INTERVAL = 20
GPT_MODEL = "gpt-4o-mini"

# ========== CHECKPOINT HANDLING ==========
def load_checkpoint():
    if os.path.exists(CHECKPOINT_FILE):
        try:
            with open(CHECKPOINT_FILE, 'r') as f:
                checkpoint = json.load(f)
                return checkpoint.get("last_row_index", 0)
        except json.JSONDecodeError:
            print("⚠️ Invalid checkpoint file. Starting from scratch.", file=sys.stderr)
    return 0

def save_checkpoint(last_row_index: int):
    checkpoint = {"last_row_index": last_row_index}
    try:
        with open(CHECKPOINT_FILE, 'w') as f:
            json.dump(checkpoint, f, indent=2)
    except Exception as e:
        print(f"❌ Error saving checkpoint: {e}", file=sys.stderr)

# ========== PROMPT BUILDER ==========
def build_prompt(row) -> str:
    try:
        full_text = row["text"]
        input_part = full_text.split("###Output")[0].replace("###Input :", "").strip()
        output_part = full_text.split("###Output :")[1].strip()
    except Exception:
        input_part = "MISSING"
        output_part = "MISSING"

    return f"""
You are a medical assistant tasked with interpreting patient Q&A records. For each case, your goal is to produce:

1. A clinically accurate and concise summary of both the patient's concern and the recommended medical advice. Your summary should preserve important details (e.g., age, symptoms, timelines, test results, vital signs, red flags), but keep it under **1000 characters**.
2. A list of one or more functional tags that describe the main medical domains involved.

Available Functional Tags:
- cardiology → For cardiac issues like arrhythmias, chest pain, blood pressure, or heart rate anomalies.
- neurology → For neurological concerns such as seizures, strokes, or chronic neurological disorders.
- autoimmune → For systemic or multi-system immune conditions like lupus, MS, or inflammatory syndromes.
- pharmacology → For cases involving medication concerns, drug interactions, or prescribing decisions.
- diagnostic_uncertainty → For unclear or undiagnosed symptoms requiring deeper diagnostic reasoning.
- patient_education → For cases requiring emotional support, plain-language explanation, or health literacy enhancement.

---
### Input Record:
###Input :{input_part}

###Output :{output_part}
---
### Output (structured):
Summary: <Summarize patient description and clinical recommendation in ≤ 1000 characters>
Tags: <Comma-separated list of relevant functional tags from above>
"""

# ========== GPT CALL ==========
def query_gpt(prompt: str) -> str:
    try:
        response = openai.chat.completions.create(
            model=GPT_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"❌ GPT API error: {e}", file=sys.stderr)
        return ""

# ========== RESPONSE PARSER ==========
def extract_interpretation(text: str) -> Tuple[str, List[str]]:
    text = text.strip()
    summary = ""
    tags = []
    try:
        for line in text.splitlines():
            if line.lower().startswith("summary:"):
                summary = line.split(":", 1)[1].strip()
            elif line.lower().startswith("tags:"):
                tags = [t.strip() for t in line.split(":", 1)[1].split(",") if t.strip()]
    except Exception:
        pass
    return summary or "ERROR", tags

# ========== MAIN PROCESSING ==========
def main():
    try:
        df = pd.read_csv(INPUT_FILE)
    except FileNotFoundError:
        print(f"❌ Error: {INPUT_FILE} not found.", file=sys.stderr)
        return
    except pd.errors.EmptyDataError:
        print(f"❌ Error: {INPUT_FILE} is empty.", file=sys.stderr)
        return

    # Load checkpoint
    last_row_index = load_checkpoint()
    rows_to_process = df.iloc[last_row_index:]
    print(f"🔄 Starting from row {last_row_index}, total rows: {len(rows_to_process)}", flush=True)

    processed_count = 0
    for idx, row in tqdm(rows_to_process.iterrows(),
                         total=len(rows_to_process),
                         desc="Processing rows",
                         unit="row",
                         dynamic_ncols=True,
                         mininterval=0.1):

        try:
            prompt = build_prompt(row)
            response = query_gpt(prompt)
            interpretation, tags = extract_interpretation(response)
        except Exception as e:
            print(f"❌ Row {idx} failed: {e}", file=sys.stderr)
            interpretation, tags = "ERROR", []

        df.at[idx, "interpretation"] = interpretation
        df.at[idx, "functional_tags"] = ", ".join(tags)

        processed_count += 1

        if processed_count % CHECKPOINT_INTERVAL == 0:
            header = not os.path.exists(OUTPUT_FILE)
            df.iloc[last_row_index:last_row_index + processed_count].to_csv(
                OUTPUT_FILE, mode="a", index=False, header=header
            )
            save_checkpoint(idx)
            last_row_index = idx + 1
            processed_count = 0

    # Save remaining rows
    if processed_count > 0:
        header = not os.path.exists(OUTPUT_FILE)
        df.iloc[last_row_index:].to_csv(
            OUTPUT_FILE, mode="a", index=False, header=header
        )
        save_checkpoint(idx)

    print(f"✅ Done. Processed data saved to {OUTPUT_FILE}")

# ========== RUN ==========
if __name__ == "__main__":
    main()

✅ API key loaded successfully.
🔄 Starting from row 1002, total rows: 1


Processing rows: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.05s/row]

✅ Done. Processed data saved to health-data-processed.csv





In [26]:
def filter_exact_error_rows(input_csv, output_clean_csv, output_error_csv):
    df = pd.read_csv(input_csv)

    # Strip spaces and filter exact matches for "ERROR"
    error_rows = df[df['interpretation'].str.strip().eq("ERROR")]
    clean_rows = df[~df['interpretation'].str.strip().eq("ERROR")]

    # Save filtered data
    error_rows.to_csv(output_error_csv, index=False)
    clean_rows.to_csv(output_clean_csv, index=False)

    print(f"Saved {len(error_rows)} error rows to '{output_error_csv}'")
    print(f"Saved {len(clean_rows)} clean rows to '{output_clean_csv}'")

In [27]:
filter_exact_error_rows(
    input_csv="health-data-processed.csv",
    output_clean_csv="health-data-clean.csv",
    output_error_csv="health-data-errors.csv"
)

Saved 18 error rows to 'health-data-errors.csv'
Saved 986 clean rows to 'health-data-clean.csv'


In [28]:
clean = pd.read_csv("health-data-clean.csv")
clean.head()

Unnamed: 0,text,interpretation,functional_tags
0,"###Input :My daughter ( F, 18 y/o, 5'5', 165lb...",An 18-year-old female has been feeling poorly ...,"cardiology, diagnostic_uncertainty"
1,###Input :Im a 37 y.o. transgender man with pr...,A 37-year-old transgender man with pre-diabete...,"cardiology, pharmacology, patient_education, d..."
2,###Input :Male 35 physically active no issues ...,"A 35-year-old physically active male, previous...","diagnostic_uncertainty, cardiology, neurology,..."
3,###Input :Appreciate it al labs have come back...,The patient reports that all lab tests returne...,diagnostic_uncertainty
4,"###Input :32F, 130-140lbs, have asthma and his...",A 32-year-old female with asthma and a history...,"cardiology, patient_education"


In [29]:
print(clean["interpretation"][20])

The patient's son has adult-onset Still's disease (AOSD), previously known as systemic-onset juvenile idiopathic arthritis (SoJIA) when under 16. His symptoms include a rash that resolves with fever, muscle aches, joint pain, and enlarged spleen and liver, with elevated CRP, ESR, ASO, and ferritin levels. Liver enzymes are also elevated due to swelling. He is currently on high-dose prednisone and biologic injections and was hospitalized for three weeks. The clinician emphasizes the need for aggressive early treatment for better chances of remission and suggests sharing information about SoJIA symptoms with the patient's mother for better understanding.
