In [None]:
import toml 
import json
import re
import csv
import json

from typing import List, Dict, Callable
from tqdm import tqdm


def load_openai_key(beaker_conf_path="../.beaker.conf"):
    """
    Read the OpenAI API key out of a Beaker TOML configuration file.

    Args:
        beaker_conf_path (str): Location of the ``.beaker.conf`` file. The default
            is resolved relative to the current working directory.

    Returns:
        str: The value stored under ``providers.openai.api_key``.

    Raises:
        KeyError: If the ``providers``/``openai``/``api_key`` entries are missing.
    """
    beaker_config = toml.load(beaker_conf_path)
    print("Beaker config loaded successfully.")
    openai_section = beaker_config['providers']['openai']
    return openai_section['api_key']

def prompt_multiple_questions_template(report_text: str, n: int) -> str:
    """
    Build the MCQ-generation prompt for a single gene set report.

    Args:
        report_text: Report text inserted verbatim into the prompt.
        n: Number of multiple-choice questions to request.

    Returns:
        Prompt asking for a JSON list of objects with ``question``, ``choices``
        (4 strings) and ``correct`` fields. The text starts with a newline and
        every line carries a 4-space indent.
    """
    indent = "    "
    body_lines = [
        f"You are a biomedical researcher. Generate {n} multiple-choice questions from this gene set report:",
        f"{report_text}",
        "",
        "Format the output as a JSON list of objects with:",
        '- "question": string,',
        '- "choices": list of 4 strings,',
        '- "correct": string (the correct answer)',
        "",
    ]
    return "\n" + "\n".join(indent + line for line in body_lines)

In [None]:

class OpenAILLM:
    """
    Thin wrapper around an OpenAI-style chat-completions client.

    ``client`` must expose ``client.chat.completions.create(model=..., messages=...)``
    returning an object whose ``choices[0].message.content`` holds the reply text
    (the shape used by the ``openai`` Python SDK).
    """

    # Markdown code fence such as ```json ... ``` or a bare ``` ... ```; group 1
    # captures the fenced body so JSON wrapped in a fence (possibly with
    # surrounding prose) can still be parsed.
    _FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)

    def __init__(self, client, model="gpt-4o-mini"):
        """
        Args:
            client: Chat-completions client (see class docstring).
            model: Model name passed to every ``create`` call.
        """
        self.client = client
        self.model = model

    def run(self, prompt: str, json_output: bool = True):
        """
        Sends a prompt to the model and returns the response.

        Parameters:
        - prompt: str, the instruction or task for the model.
        - json_output: bool, if True, parse output as JSON, else return raw text.

        Returns:
        - dict or list or str: Parsed JSON if json_output=True, else the reply
          text with surrounding whitespace stripped.

        Raises:
        - json.JSONDecodeError: if json_output=True and neither the reply nor the
          body of a code fence inside it is valid JSON.
        """
        resp = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )

        # The SDK types message.content as optional; treat a missing reply as empty
        # text instead of crashing on None.strip().
        content = (resp.choices[0].message.content or "").strip()
        if not json_output:
            return content

        try:
            return json.loads(content)
        except json.JSONDecodeError:
            match = self._FENCE_RE.search(content)
            if match is None:
                # Nothing to unwrap: surface the original parse error.
                raise
            return json.loads(match.group(1))

def get_llm(backend: str, **kwargs):
    """
    Build an LLM wrapper for the requested backend.

    Args:
        backend: Backend identifier. Only ``"openai"`` is implemented;
            ``"anthropic"`` and ``"hf"`` are recognised but not yet supported.
        **kwargs: Backend options. For ``"openai"``: ``api_key`` (passed to the
            ``OpenAI`` client) and ``model`` (defaults to ``"gpt-4o-mini"``).

    Returns:
        An ``OpenAILLM`` exposing ``run(prompt, json_output=True)``.

    Raises:
        NotImplementedError: For a recognised backend with no implementation yet.
        ValueError: For an unknown backend name.
    """
    if backend == "openai":
        # Imported here so the openai package is only needed when this backend is used.
        from openai import OpenAI
        client = OpenAI(api_key=kwargs.get("api_key"))
        return OpenAILLM(client, model=kwargs.get("model", "gpt-4o-mini"))
    if backend in ("anthropic", "hf"):
        # TODO: add Anthropic and HuggingFace backends.
        # Fail loudly here rather than returning None, which would only break
        # later when `.run` is called on it.
        raise NotImplementedError(f"LLM backend not implemented yet: {backend}")
    raise ValueError(f"Unknown LLM backend: {backend}")


In [None]:

def load_reports(csv_path: str, report_column: str = "Ground Truth", id_column: str = "ID", source: str = "gt"):
    """
    Load non-empty report texts from a CSV file with a header row.

    Args:
        csv_path: Path to the CSV file.
        report_column: Column holding the report text.
        id_column: Column holding the report identifier. A missing value (or a
            missing column) yields ``None`` for that report's id.
        source: Prefix for the generated report keys.

    Returns:
        List of ``(report_key, report_id, report_text)`` tuples. ``report_key`` is
        ``f"{source}_{row_index}"``, where ``row_index`` is the 0-based data-row
        position in the file. Rows with blank text are skipped but still consume
        an index, so keys stay tied to file rows.

    Raises:
        ValueError: If the header row does not contain ``report_column``.
    """
    reports = []
    # 'utf-8-sig' strips a leading byte-order mark if present. With plain 'utf-8'
    # the BOM is glued onto the first header name, silently breaking lookups of
    # that column. Files without a BOM decode identically.
    with open(csv_path, newline='', encoding='utf-8-sig') as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None:
            # Empty file: no header and no rows.
            return reports
        if report_column not in reader.fieldnames:
            # A typo here would otherwise silently produce zero reports.
            raise ValueError(
                f"Column {report_column!r} not found in {csv_path}; "
                f"available columns: {reader.fieldnames}"
            )
        for idx, row in enumerate(reader):
            report_text = row.get(report_column)  # None for short rows
            if report_text and report_text.strip():
                reports.append((f"{source}_{idx}", row.get(id_column), report_text))
    return reports

In [None]:


def generate_questions(
    reports: List[tuple],
    prompt_template: Callable[[str, int], str],
    llm,
    num_questions: int = 20,
    output_path: str = None
) -> List[Dict]:
    """
    Generate MCQs from pre-loaded reports using a given prompt template and LLM.

    Failures are isolated per report. If building the prompt or the LLM call
    raises, or the output has an unexpected shape, a message is printed and that
    report is skipped. Individual non-object question entries are also skipped.

    Args:
        reports: List of ``(report_key, report_id, report_text)`` tuples, as
            returned by ``load_reports``. ``report_key`` is stored on every MCQ.
        prompt_template: Function ``(report_text, num_questions) -> prompt``.
        llm: LLM object with a ``.run(prompt, json_output=True)`` method.
        num_questions: Number of MCQs requested per report.
        output_path: Optional path; if given (not None/empty), the result is
            saved there as JSON.

    Returns:
        List of dicts with keys ``report_id`` (the report key), ``question``,
        ``choices`` and ``correct``.
    """
    multiple_choice_questions = []

    for report_key, _report_id, report_text in tqdm(reports):
        try:
            prompt = prompt_template(report_text, num_questions)
            mcqs = llm.run(prompt, json_output=True)
        except Exception as e:
            # Best-effort: one failing report must not abort the whole batch.
            print(f"[Error] Report {report_key}: {e}")
            continue

        # Also accept the list wrapped as {"questions": [...]} instead of
        # discarding the whole report.
        if isinstance(mcqs, dict) and isinstance(mcqs.get("questions"), list):
            mcqs = mcqs["questions"]

        if not isinstance(mcqs, list):
            print(f"[Warning] Unexpected output format for report {report_key}")
            continue

        for q in mcqs:
            if not isinstance(q, dict):
                print(f"[Warning] Skipping non-object question for report {report_key}")
                continue
            multiple_choice_questions.append({
                "report_id": report_key,
                "question": q.get("question", ""),
                "choices": q.get("choices", []),
                "correct": q.get("correct", "")
            })

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(multiple_choice_questions, f, indent=2, ensure_ascii=False)

    return multiple_choice_questions

In [None]:
# Generate MCQs from the "Ground Truth" column of the benchmark CSV and save them.
src = "gt"
benchmark_csv = "../data/benchmark.csv"

reports = load_reports(benchmark_csv, report_column="Ground Truth", source=src)
llm = get_llm("openai", api_key=load_openai_key())

mcqs = generate_questions(
    reports,
    prompt_multiple_questions_template,
    llm,
    num_questions=20,
    output_path=f"../data/qs_{src}.json",
)

In [None]:
# Preview only the first few questions; the full list is large (20 per report)
# and is saved to disk by the generation cell.
mcqs[:5]