In [6]:
from pathlib import Path
import pandas as pd
from openai import OpenAI
import matplotlib.pyplot as plt

# --- Config ---
# Endpoint and model name used for every chat-completion request in this notebook.
VLLM_BASE_URL = "http://localhost:8000/v1"
MODEL_NAME = "qwen"
# Input table; a relative path, so it is resolved against the current working directory.
CSV_PATH = Path("data") / "testdata_MIE.csv"

# Fail early and report the absolute path, which makes a wrong working
# directory easy to spot.
if CSV_PATH.exists():
    df = pd.read_csv(CSV_PATH)
else:
    raise FileNotFoundError(f"CSV nicht gefunden: {CSV_PATH.resolve()}")

# NOTE(review): the api_key value looks like a placeholder for a local server
# that does not check it — confirm before pointing this at a hosted endpoint.
client = OpenAI(api_key="local-not-needed", base_url=VLLM_BASE_URL)


def query_template(df: pd.DataFrame, query: str):
    """Build the chat message list sent to the model for one user query.

    The list contains, in order: the system prompt with the code-generation
    rules, a user message listing the column names of ``df``, two worked
    question/answer examples (few-shot), and finally ``query`` itself.

    Args:
        df: DataFrame whose column names are shown to the model.
        query: The user's natural-language question.

    Returns:
        A list of seven ``{"role": ..., "content": ...}`` dicts.
    """
    system_prompt = (
        "You are a Python code generator for pandas data analysis.\n"
        "A pandas DataFrame named df already exists in memory.\n\n"

        "STRICT RULES (MANDATORY):\n"
        "- Output ONLY executable Python code.\n"
        "- NO explanations, NO comments, NO markdown.\n"
        "- NEVER load files (no read_csv, no file paths).\n"
        "- NEVER overwrite or reassign df.\n"
        "- NEVER assign values to dataframe columns (FORBIDDEN: df['X'] = ...).\n"
        "- Use ONLY existing dataframe columns.\n"
        "- The FINAL answer MUST be assigned to a variable named result.\n"
        "- Do NOT use print().\n"
        "- ABSOLUTE RULE: The code MUST NOT contain the word 'import'.\n\n"

        "BOOLEAN INDEXING (MANDATORY):\n"
        "- When combining pandas conditions, use EXACTLY:\n"
        "  (condition_1) & (condition_2)\n"
        "- BOTH conditions MUST be fully enclosed in parentheses.\n"
        "- NEVER use 'and' or 'or' with pandas Series.\n\n"

        "STRING SAFETY RULE (MANDATORY):\n"
        "- BEFORE using .str.* you MUST convert the column using:\n"
        "  df['Column'].astype(str)\n\n"

        "COLUMN VALIDATION (MANDATORY PATTERN):\n"
        "required = {...}\n"
        "missing = required.difference(df.columns)\n"
        "if missing:\n"
        "    raise ValueError(f'Missing required columns: {missing}')\n\n"

        "DATASET SEMANTICS (IMPORTANT):\n"
        "- This dataset has NO concept of secondary diagnoses.\n"
        "- Diagnoses are ONLY rows where Finding == 'Diagnose'.\n"
        "- Diagnosis codes are stored in the Value column.\n"
        "- Patients are identified by MPINumber.\n"
        "- Patient counts MUST use df.loc[mask,'MPINumber'].nunique().\n\n"

        "PLOTTING RULES (MANDATORY):\n"
        "- matplotlib.pyplot is already available as plt. DO NOT import anything.\n"
        "- If a plot is requested, ALWAYS end with:\n"
        "  plt.tight_layout()\n"
        "  plt.show()\n"
        "- You MUST still set result to the data used for the plot (Series/DataFrame/scalar).\n\n"

        "BOX PLOT RULE (MANDATORY):\n"
        "- A box plot MUST visualize a distribution of multiple values.\n"
        "- For 'diagnoses per patient', compute per-patient counts using:\n"
        "  df.loc[mask].groupby('MPINumber').size()\n\n"

        "ABSOLUTE SEMANTIC RULE:\n"
        "- Do NOT invent column names.\n"
        "- Do NOT guess meanings like 'secondary'.\n"
        "- If the user asks for 'secondary diagnoses', treat it as 'Diagnose'.\n"
    )

    # Few-shot 1: counting diagnoses
    count_question = "How many patients got diagnosed with N18?"
    count_answer = (
        "required = {'Finding', 'Value', 'MPINumber'}\n"
        "missing = required.difference(df.columns)\n"
        "if missing:\n"
        "    raise ValueError(f'Missing required columns: {missing}')\n"
        "mask = (df['Finding'] == 'Diagnose') & (df['Value'].astype(str).str.contains('N18', na=False))\n"
        "result = df.loc[mask, 'MPINumber'].nunique()"
    )

    # Few-shot 2: plot diagnoses
    plot_question = "Plot the five most common diagnoses."
    plot_answer = (
        "required = {'Finding', 'Value'}\n"
        "missing = required.difference(df.columns)\n"
        "if missing:\n"
        "    raise ValueError(f'Missing required columns: {missing}')\n"
        "mask = (df['Finding'] == 'Diagnose')\n"
        "top5 = df.loc[mask, 'Value'].astype(str).value_counts().head(5)\n"
        "top5.plot(kind='bar')\n"
        "plt.tight_layout()\n"
        "plt.show()\n"
        "result = top5"
    )

    # (role, text) pairs in the exact order they are sent to the model.
    conversation = [
        ("system", system_prompt),
        # Schema hint: only the column names are shared, not the data.
        ("user", f"df.columns = {list(df.columns)}"),
        ("user", count_question),
        ("assistant", count_answer),
        ("user", plot_question),
        ("assistant", plot_answer),
        # Real user query
        ("user", query),
    ]
    return [{"role": role, "content": text} for role, text in conversation]


def ask_llm(df: pd.DataFrame, query: str) -> str:
    """Send ``query`` to the model and return its reply as stripped text.

    The prompt is built by ``query_template(df, query)`` and sent through the
    module-level ``client`` with temperature 0.0 and a 700-token limit.

    Args:
        df: DataFrame forwarded to ``query_template``.
        query: The user's natural-language question.

    Returns:
        The content of the first choice with surrounding whitespace removed,
        or ``""`` when the reply has no content.
    """
    request = {
        "model": MODEL_NAME,
        "messages": query_template(df, query),
        "temperature": 0.0,
        "max_tokens": 700,
    }
    completion = client.chat.completions.create(**request)

    reply = completion.choices[0].message.content
    # The content may be None or empty; normalise that to an empty string.
    if not reply:
        return ""
    return reply.strip()


class ExecWithCodeError(RuntimeError):
    """Runtime error that also carries the generated code it relates to."""

    def __init__(self, message: str, code: str) -> None:
        # Keep the code on the instance so a handler can show it next to the message.
        self.code = code
        RuntimeError.__init__(self, message)


def execute_query(query: str):
    """Ask the model for pandas code answering ``query``, run it, and return the result.

    The generated code is executed with ``exec`` in a single namespace that
    contains a copy of the module-level ``df``, ``pd``, ``plt`` and a small
    allow-list of builtins.

    NOTE(security): this runs model-generated code. The reduced ``__builtins__``
    only limits accidental use of unexpected functions; it is NOT a security
    boundary (``pd`` alone can read and write files). Do not expose this to
    untrusted users without real isolation.

    Args:
        query: The user's natural-language question.

    Returns:
        A ``(result, code)`` tuple: the value the generated code assigned to
        ``result`` and the code that was executed.

    Raises:
        ExecWithCodeError: if the code contains ``"import "``, raises while
            executing, or does not define ``result``. The offending code is
            available as ``.code`` on the exception.
    """
    code = ask_llm(df, query)

    # remove markdown fences if any
    if code.startswith("```"):
        code = "\n".join(
            line for line in code.splitlines()
            if not line.strip().startswith("```")
        ).strip()

    # hard block imports
    if "import " in code:
        raise ExecWithCodeError("Import is not allowed (plt is already provided).", code)

    # One dict serves as both globals and locals. With two separate dicts the
    # code runs as if inside a class body, so lambdas and nested functions
    # cannot see `df` or any variable assigned earlier in the generated code.
    namespace = {
        "__builtins__": {
            "len": len, "str": str, "int": int, "float": float,
            "min": min, "max": max, "sum": sum,
            "set": set, "sorted": sorted, "range": range,
            "list": list, "dict": dict, "tuple": tuple, "bool": bool,
            "abs": abs, "round": round, "any": any, "all": all,
            "enumerate": enumerate, "zip": zip,
            # The prompt's mandatory column-validation pattern raises
            # ValueError, so it has to be resolvable inside the sandbox.
            "ValueError": ValueError, "KeyError": KeyError, "TypeError": TypeError,
        },
        "plt": plt,
        "pd": pd,
        # Work on a copy so generated code cannot alter the shared frame and
        # silently change the answers to later queries.
        "df": df.copy(),
    }

    try:
        exec(code, namespace)
    except Exception as e:
        raise ExecWithCodeError(str(e), code) from e

    if "result" not in namespace:
        raise ExecWithCodeError("LLM code did not set variable 'result'.", code)

    return namespace["result"], code


In [7]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Free-text question box, pre-filled with an example query.
query_input = widgets.Textarea(
    layout=widgets.Layout(height="80px", width="95%"),
    description="Query:",
    placeholder="Enter your query...",
    value="How many patients got diagnosed with H36? Consider that each patient can have multiple findings.",
)

# Trigger button, HTML status line, and a bordered area for results and errors.
run_button = widgets.Button(button_style="primary", description="Run")
status = widgets.HTML(value="")
output = widgets.Output(layout=widgets.Layout(width="95%", border="1px solid #ddd"))

def on_run_clicked(_):
    """Run the query from the text box and show the result or the error.

    Sets ``status`` to "Running…", then to "Done." or "Error." (or a prompt to
    enter a query when the box is empty). Everything displayed or printed
    goes into the ``output`` widget, which is cleared first.

    Args:
        _: The clicked button, passed by ``on_click``; unused.
    """
    status.value = "<b>Running…</b>"
    q = query_input.value.strip()

    with output:
        clear_output(wait=True)

        if not q:
            status.value = "<span style='color:#b00'><b>Please enter a query.</b></span>"
            return

        try:
            res, _code = execute_query(q)
            display(res)
            status.value = "<span style='color:#0a0'><b>Done.</b></span>"
        except ExecWithCodeError as e:
            status.value = "<span style='color:#b00'><b>Error.</b></span>"
            print("Fehler bei der Ausführung:")
            print(e)
            print("\n--- Generated code (debug) ---")
            print(e.code)
        except Exception as e:
            # Failures outside the generated code (e.g. the model request
            # itself) are not ExecWithCodeError. Without this branch the
            # status would stay on "Running…" after such a failure.
            status.value = "<span style='color:#b00'><b>Error.</b></span>"
            print("Fehler bei der Ausführung:")
            print(f"{type(e).__name__}: {e}")

# Connect the button to its handler, then render the UI: the query box and
# button side by side, with the status line and the output area below.
run_button.on_click(on_run_clicked)

display(
    widgets.VBox(
        children=[
            widgets.HBox(children=[query_input, run_button]),
            status,
            output,
        ]
    )
)


VBox(children=(HBox(children=(Textarea(value='How many patients got diagnosed with H36? Consider that each pat…

In [8]:
import os, sys
# Diagnostics: the kernel's working directory (relative paths such as the CSV
# path are resolved against it) and the interpreter running this kernel.
print(f"cwd: {os.getcwd()}")
print(f"python: {sys.executable}")


cwd: J:\dev\LLMchatbot_MIE2024
python: C:\Python314\python.exe
