In [12]:
import toml 
import json
import re
import csv
import json

from typing import List, Dict, Callable
from tqdm import tqdm


def load_openai_key(beaker_conf_path="../.beaker.conf"):
    """
    Load OpenAI API key from a Beaker configuration file.

    Args:
        beaker_conf_path (str): Path to the .beaker.conf file.

    Returns:
        str: OpenAI API key.
    """
    config = toml.load(beaker_conf_path)
    print("Beaker config loaded successfully.")

    openai_api_key = config['providers']['openai']['api_key']
    #print("Loaded API key (first 10 chars):", openai_api_key[:10] + "...")
    
    return openai_api_key

def prompt_multiple_questions_template(report_text: str, n: int) -> str:
    return f"""
    You are a biomedical researcher. Generate {n} multiple-choice questions from this gene set report:
    {report_text}
    
    Format the output as a JSON list of objects with:
    - "question": string,
    - "choices": list of 4 strings,
    - "correct": string (the correct answer)
    """

In [13]:

class OpenAILLM:
    def __init__(self, client, model="gpt-4o-mini"):
        self.client = client
        self.model = model

    def run(self, prompt: str, json_output: bool = True):
        """
        Sends a prompt to the model and returns the response.
        
        Parameters:
        - prompt: str, the instruction or task for the model.
        - json_output: bool, if True, parse output as JSON, else return raw text.
        
        Returns:
        - dict or str: Parsed JSON if json_output=True, else raw text.
        """
        resp = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}]
        )

        content = resp.choices[0].message.content.strip()
        if json_output:
            try:
                return json.loads(content)
            except Exception:
                content_clean = re.sub(r"^```json\s*|\s*```$", "", content, flags=re.DOTALL).strip()
                return json.loads(content_clean)

def get_llm(backend: str, **kwargs):

    if backend == "openai":
        from openai import OpenAI
        return OpenAILLM(OpenAI(api_key=kwargs.get("api_key")), model=kwargs.get("model", "gpt-4o-mini"))
    elif backend == "anthropic":
        # TODO: Add Anthropic
        pass
    elif backend == "hf":
        # TODO: Add HuggingFace 
        pass
    else:
        raise ValueError(f"Unknown LLM backend: {backend}")


In [17]:

def load_reports(csv_path: str, report_column: str = "Ground Truth", id_column: str = "ID", source: str = "gt"):
    reports = []
    with open(csv_path, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for idx, row in enumerate(reader):
            report_id = row.get(id_column)
            report_text = row.get(report_column)
            if report_text and report_text.strip():
                report_key = f"{source}_{idx}"
                reports.append((report_key, report_id, report_text))
    return reports



def generate_questions(
    reports: List[tuple],
    prompt_template: Callable[[str, int], str],
    llm,
    num_questions: int = 20,
    output_path: str = None
) -> List[Dict]:
    """
    Generate multiple choice questions from pre-loaded reports using a given prompt template and LLM.
    
    Args:
        reports: List of tuples (report_id, report_text).
        prompt_template: Function to generate prompt from text and question count.
        llm: LLM object with a .run(prompt, json_output=True) method.
        num_questions: Number of MCQs per report.
        output_path: Optional path to save the result as JSON.

    Returns:
        List of MCQ dictionaries.
    """
    multiple_choice_questions = []

    for report_key, _ , report_text in tqdm(reports):
        try:
            prompt = prompt_template(report_text, num_questions)
            mcqs = llm.run(prompt, json_output=True)

            if isinstance(mcqs, list):
                for q in mcqs:
                    multiple_choice_questions.append({
                        "report_id": report_key,
                        "question": q.get("question", ""),
                        "choices": q.get("choices", []),
                        "correct": q.get("correct", "")
                    })
            else:
                print(f"[Warning] Unexpected output format for report {report_id}")
        except Exception as e:
            print(f"[Error] Report {report_key}: {e}")
            continue

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(multiple_choice_questions, f, indent=2, ensure_ascii=False)

    return multiple_choice_questions

In [38]:
src="gt"
reports = load_reports("../data/benchmark.csv", report_column="Ground Truth", source=src)
llm = get_llm("openai", api_key=load_openai_key())


Beaker config loaded successfully.


In [19]:
mcqs = generate_questions(
    reports=reports,
    prompt_template=prompt_multiple_questions_template,
    llm=llm,
    num_questions=20,
    output_path=f"../data/qs_{src}.json"
)

100%|███████████████████████████████████████████| 14/14 [04:37<00:00, 19.82s/it]


In [21]:
def answer_prompt_template(question: str, choices: List[str], report: str = None) -> str:
    if report:
        prompt = f"""
        Given the following gene set analysis report:
        {report}
        Answer the following question: {question}
        Choices: {choices}

        Answer with the best choice and a confidence score (0-1) in JSON:
        {{
            "answer": "<chosen_option>",
            "confidence": <float>
        }}
        """
    else:
        prompt = f"""
        Answer the following question: {question}
        Choices: {choices}
        Answer with the best choice and a confidence score (0-1) in JSON:
        {{
            "answer": "<chosen_option>",
            "confidence": <float>
        }}
        """
    return prompt.strip()

In [34]:
report_dict = {r[0]: r[2] for r in reports}  # {report_key: report_text}

In [27]:
import json
src='gt'
file_path = f"../data/qs_{src}.json"

with open(file_path, 'r', encoding='utf-8') as f:
    questions = json.load(f)

len(questions)

266

In [50]:


def evaluate_questions(
    questions: List[Dict],
    llm,
    reports: Dict[str, str] = None,
    with_report: bool = True,
    output_path: str = None,
    source: str = "gt"
) -> List[Dict]:
    """
    Evaluate MCQs by asking an LLM to answer them, optionally using the associated report text.

    Args:
        questions: List of dicts with keys: question, choices, correct, report_id.
        llm: LLM instance with .run(prompt, json_output=True)
        reports: Dict mapping report_id -> report_text
        with_report: Whether to pass the report to the prompt
        output_path: Optional path to save the output
        source: The source label to use in the output metadata

    Returns:
        List of dicts with answer, confidence, question, correct answer, etc.
    """
    results = []
    mode = "with_report" if with_report else "without_report"

    for q in questions:
        report_text = reports.get(q["report_id"], "") if with_report and reports else None
        print(report_text)
        prompt = answer_prompt_template(q["question"], q["choices"], report=report_text)
        
        try:
            response = llm.run(prompt, json_output=True)
            results.append({
                "report_id": q["report_id"],
                "question": q["question"],
                "choices": q["choices"],
                "correct": q["correct"],
                "answer": response.get("answer", None),
                "confidence": response.get("confidence", None),
                "mode": mode,
                "source": source
            })
        except Exception as e:
            print(f"[Error] report {q['report_id']} - {e}")
            results.append({
                "report_id": q["report_id"],
                "question": q["question"],
                "choices": q["choices"],
                "correct": q["correct"],
                "answer": None,
                "confidence": None,
                "mode": mode,
                "source": source,
                "error": str(e)
            })

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

    return results


In [52]:
# Convert reports list to a dictionary: {report_id: report_text}
report_dict = {r[0]: r[2] for r in reports}
wr=False
qs_with_report = evaluate_questions(
    questions=questions[:50],
    llm=llm,
    reports=report_dict,
    with_report=wr,
    output_path=f"../data/answers_{src}_with_report{wr}.json",
    source=src
)

None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None


In [54]:
qs_with_report


[{'report_id': 'gt_0',
  'question': 'What is the main effect of CTNNB1 hotspot mutations on Wnt-β-catenin signaling?',
  'choices': ['Downregulation of Wnt-β-catenin signaling',
   'Upregulation of Wnt-β-catenin signaling',
   'No effect on Wnt-β-catenin signaling',
   'Inhibition of Wnt-β-catenin signaling'],
  'correct': 'Upregulation of Wnt-β-catenin signaling',
  'answer': 'Upregulation of Wnt-β-catenin signaling',
  'confidence': 0.95,
  'mode': 'without_report',
  'source': 'gt'},
 {'report_id': 'gt_0',
  'question': 'Which protein is significantly upregulated in tumors with CTNNB1 hotspot mutations?',
  'choices': ['DKK4', 'LEF1', 'GSK3β', 'APC'],
  'correct': 'LEF1',
  'answer': 'LEF1',
  'confidence': 0.85,
  'mode': 'without_report',
  'source': 'gt'},
 {'report_id': 'gt_0',
  'question': 'What is the role of DKK4 in relation to the Wnt pathway?',
  'choices': ['It activates the Wnt pathway',
   'It inhibits the Wnt pathway',
   'It has no effect on the Wnt pathway',
   'It 