In [32]:
!pip install datasets openai

30188.34s - pydevd: Sending message related to process being replaced timed-out after 5 seconds




In [None]:
import os
import json
import random
import pandas as pd
from datasets import load_dataset
import openai
import time
from tqdm import tqdm
import logging
from typing import List, Dict, Any, Optional, Tuple

In [34]:
# Set up logging: every record goes both to generation.log and to the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: later cells log JSON dumped with ensure_ascii=False, so
        # the log file must not depend on the platform's default encoding.
        logging.FileHandler("generation.log", encoding="utf-8"),
        logging.StreamHandler()
    ],
    # Without force=True, basicConfig() silently does nothing when the root logger
    # already has handlers (pre-configured environments, or a re-run of this cell
    # after editing the settings above).
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
# Cell 3: Configuration

from dotenv import load_dotenv

# Load variables from a local .env file (if one exists) into the environment.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Fail here with an actionable message instead of later, when the client is built.
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export it in your shell or add it to a .env file "
        "before running this notebook."
    )

SEED = 42
EXPLANATION_MODEL = "gpt-4o-mini"
DATASET_NAME = "GBaker/MedQA-USMLE-4-options"
SPLITS_TO_PROCESS = ["train"]
OUTPUT_DIR = "synthetic_medqa_data"
MAX_SAMPLES_PER_SPLIT = 2 # Set to a number (e.g., 100) for testing, None to process all

# --- Output Format Configuration ---
# Template for the prompt shown to the model being trained.
PROMPT_FORMAT = """Question: {question}

Options:
{options_formatted}

Choose the best answer and provide a step-by-step explanation for your choice."""

# How to format the 'chosen' and 'rejected' responses
RESPONSE_FORMAT = """{answer_label}. {answer_text}
Explanation: {explanation}"""

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Log the effective settings (never the API key) so each run is traceable in generation.log.
logger.info("Configuration:")
logger.info(f"  Explanation Model: {EXPLANATION_MODEL}")
logger.info(f"  Dataset: {DATASET_NAME}")
logger.info(f"  Splits: {SPLITS_TO_PROCESS}")
logger.info(f"  Output Directory: {OUTPUT_DIR}")
logger.info(f"  Max Samples Per Split: {MAX_SAMPLES_PER_SPLIT}")

2025-04-12 19:24:52,110 - INFO - Configuration:
2025-04-12 19:24:52,110 - INFO -   Explanation Model: gpt-4o-mini
2025-04-12 19:24:52,111 - INFO -   Judge Model: gpt-4o-mini
2025-04-12 19:24:52,111 - INFO -   Dataset: GBaker/MedQA-USMLE-4-options
2025-04-12 19:24:52,112 - INFO -   Splits: ['train']
2025-04-12 19:24:52,112 - INFO -   Output Directory: synthetic_medqa_data
2025-04-12 19:24:52,112 - INFO -   Max Samples Per Split: 2


In [74]:
class PreferencePairGenerator:
    """Build SFT and DPO training records from multiple-choice questions.

    For every sample, an OpenAI chat model is asked for (a) an explanation of the
    correct option and (b) a plausible-but-flawed explanation of one randomly
    chosen incorrect option. The two are formatted as the "chosen" and
    "rejected" responses of a preference pair.
    """

    # Prefixes of the placeholder strings get_explanation() returns when it could
    # not obtain a usable explanation. Such text must never reach the output files.
    _FAILURE_MARKERS = ("[Error]", "[API Error]", "[Fallback]")

    def __init__(self, api_key: str, seed: int,
                 explanation_model: str,
                 prompt_format: str, response_format: str):
        """Create the API client and store the formatting templates.

        Args:
            api_key: OpenAI API key passed to the client.
            seed: Seed for the global ``random`` module; it drives the choice of
                the incorrect option used for each rejected response.
            explanation_model: Name of the chat model that writes explanations.
            prompt_format: Template with ``{question}`` and ``{options_formatted}``.
            response_format: Template with ``{answer_label}``, ``{answer_text}``
                and ``{explanation}``.
        """
        self.client = openai.OpenAI(api_key=api_key)
        # NOTE: this seeds the process-wide RNG, not a private one.
        random.seed(seed)
        self.explanation_model = explanation_model
        self.prompt_format = prompt_format
        self.response_format = response_format
        self.label_to_index = {"A": 0, "B": 1, "C": 2, "D": 3}
        self.index_to_label = {v: k for k, v in self.label_to_index.items()}

    @classmethod
    def _is_failed_explanation(cls, explanation: str) -> bool:
        """Return True if the text is one of get_explanation()'s failure placeholders."""
        return any(marker in explanation for marker in cls._FAILURE_MARKERS)

    def call_openai_api(self, messages: List[Dict[str, str]], model: str,
                        temperature: float = 0.5, max_tokens: int = 500,
                        retry_count: int = 3) -> Optional[Any]:
        """Call the chat-completions endpoint, retrying with exponential back-off.

        Args:
            messages: Chat messages to send.
            model: Model name.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            retry_count: Maximum number of attempts.

        Returns:
            The API response object, or None if every attempt raised.
        """
        for attempt in range(retry_count):
            try:
                return self.client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
            except Exception as e:
                logger.warning(f"Attempt {attempt + 1} failed: {e}")
                # Back off only when another attempt follows; sleeping after the
                # last failure would just delay the caller.
                if attempt < retry_count - 1:
                    time.sleep(2 ** attempt)
        logger.error("All retry attempts to call OpenAI API have failed.")
        return None

    def get_explanation(self, question: str, options_formatted: str, answer: str, is_correct: bool) -> str:
        """Ask the explanation model to justify one answer option.

        Args:
            question: The question text.
            options_formatted: All options, already rendered one per line.
            answer: Text of the option to explain.
            is_correct: True to request a sound justification; False to request
                a plausible explanation that subtly contains flawed reasoning.

        Returns:
            The explanation text, or a placeholder starting with ``[Fallback]``,
            ``[API Error]`` or ``[Error]`` when no usable explanation was obtained.
        """
        if is_correct:
            prompt_detail = (
                "Generate a concise explanation (~20-30 words) highlighting the key clinical reasoning "
                "and evidence from the vignette that justifies why this answer is the most correct choice."
            )
        else:
            prompt_detail = (
                "Pretend this option is correct. Generate a concise explanation (~20-30 words) that appears "
                "plausible but subtly contains flawed reasoning. Focus on relevant clinical details while "
                "avoiding any direct mention that this choice might be incorrect."
            )

        prompt = (
            f"Medical Question Context:\n{question}\n\n"
            f"Options:\n{options_formatted}\n\n"
            f"Answer Choice to Explain: {answer}\n\n"
            f"{prompt_detail}"
        )

        messages = [
            {
                "role": "system",
                "content": (
                    "You are a highly knowledgeable medical expert specializing in clinical reasoning "
                    "explanations for USMLE-style questions. Be clear, concise, and follow instructions carefully."
                )
            },
            {"role": "user", "content": prompt}
        ]

        try:
            response = self.call_openai_api(
                messages=messages,
                model=self.explanation_model,
                temperature=0.6,
                max_tokens=150
            )

            if not response:
                return "[API Error] Explanation could not be generated due to API failure."

            # The message content can be None; treat that like an empty reply.
            explanation = (response.choices[0].message.content or "").strip()
            refusal_phrases = ("cannot provide", "explanation could not be generated")
            if not explanation or any(phrase in explanation.lower() for phrase in refusal_phrases):
                logger.warning("Received potentially invalid explanation. Falling back.")
                return "[Fallback] Explanation generation failed or returned invalid content."
            return explanation

        except Exception as e:
            logger.error(f"Error generating explanation: {e}")
            return "[Error] An explanation could not be generated."

    def select_alternative_answer(self, options: List[str], correct_index: int) -> int:
        """Randomly pick the index of an option other than the correct one.

        Args:
            options: Option texts.
            correct_index: Position of the correct option within ``options``.

        Returns:
            An index into ``options`` that differs from ``correct_index``.

        Raises:
            ValueError: If ``options`` holds no option other than the correct one.
        """
        incorrect_indices = [i for i in range(len(options)) if i != correct_index]
        if not incorrect_indices:
            # Returning correct_index here would make chosen and rejected identical.
            raise ValueError("No incorrect option is available to use as the rejected answer.")
        return random.choice(incorrect_indices)

    def process_sample(self, sample: Dict[str, Any], idx: int) -> Optional[Dict[str, Any]]:
        """Turn one dataset sample into a prompt / chosen / rejected record.

        Args:
            sample: Mapping with ``question``, ``options`` (label -> text) and
                ``answer_idx`` (label of the correct option).
            idx: Position of the sample, used only in log messages.

        Returns:
            A dict with ``prompt``, ``chosen``, ``rejected`` and ``metadata``, or
            None if the sample is malformed or either explanation could not be
            generated.
        """
        try:
            question = sample.get("question")
            options = sample.get("options")

            if not question or not options:
                logger.warning(f"Sample {idx}: Missing question or options. Skipping.")
                return None

            raw_label = sample.get("answer_idx")
            if not raw_label:
                logger.warning(f"Sample {idx}: Missing answer_idx. Skipping.")
                return None
            correct_label = str(raw_label).strip().upper()

            # Pair every label with its own text, in A-D order, instead of
            # trusting the iteration order of the options mapping.
            labelled = [
                (label, options[label])
                for label in self.label_to_index
                if options.get(label) is not None
            ]
            labels = [label for label, _ in labelled]
            texts = [text for _, text in labelled]

            if correct_label not in labels:
                logger.warning(f"Sample {idx}: Answer label {correct_label!r} is not among the options. Skipping.")
                return None
            if len(labels) < 2:
                logger.warning(f"Sample {idx}: No incorrect option available for a rejected response. Skipping.")
                return None

            correct_position = labels.index(correct_label)
            correct_index = self.label_to_index[correct_label]
            correct_option_text = texts[correct_position]

            options_formatted = "\n".join(f"{label}. {text}" for label, text in labelled)

            logger.info(f"Sample {idx}: Generating explanation for CORRECT answer ({correct_label})")
            correct_explanation = self.get_explanation(question, options_formatted, correct_option_text, is_correct=True)
            if self._is_failed_explanation(correct_explanation):
                logger.error(f"Sample {idx}: Failed to generate explanation for correct answer. Skipping.")
                return None

            alt_position = self.select_alternative_answer(texts, correct_position)
            alt_label = labels[alt_position]
            alt_index = self.label_to_index[alt_label]
            alt_option_text = texts[alt_position]

            logger.info(f"Sample {idx}: Generating explanation for ALTERNATIVE answer ({alt_label})")
            alt_explanation = self.get_explanation(question, options_formatted, alt_option_text, is_correct=False)
            # A failure placeholder as the rejected text would corrupt the pair just
            # as much as one in the chosen text, so it is checked as well.
            if self._is_failed_explanation(alt_explanation):
                logger.error(f"Sample {idx}: Failed to generate explanation for alternative answer. Skipping.")
                return None

            prompt_str = self.prompt_format.format(question=question, options_formatted=options_formatted)
            chosen_str = self.response_format.format(answer_label=correct_label, answer_text=correct_option_text, explanation=correct_explanation)
            rejected_str = self.response_format.format(answer_label=alt_label, answer_text=alt_option_text, explanation=alt_explanation)

            return {
                "prompt": prompt_str,
                "chosen": chosen_str,
                "rejected": rejected_str,
                "metadata": {
                    "original_question": question,
                    "options": options,
                    "correct_index": correct_index,
                    "alternative_index": alt_index,
                    "correct_label": correct_label,
                    "alternative_label": alt_label,
                    "correct_explanation_raw": correct_explanation,
                    "alternative_explanation_raw": alt_explanation,
                }
            }
        except Exception as e:
            logger.exception(f"Critical error processing sample {idx}: {e}")
            return None

    def process_dataset(self, dataset_name: str, splits: List[str],
                        output_dir: str, max_samples: Optional[int] = None,
                        max_question_chars: Optional[int] = 1024
                        ) -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
        """
        Load dataset splits, generate SFT and DPO data, and save results.
        Results are also returned in a dict for optional in-code use.

        Each record is written to ``sft_data_<split>.jsonl`` / ``dpo_data_<split>.jsonl``
        in ``output_dir`` and flushed immediately, so an interrupted run keeps
        everything generated so far. Existing files with those names are overwritten.

        Args:
            dataset_name: Name passed to ``load_dataset``.
            splits: Split names to process.
            output_dir: Directory for the JSONL files; created if missing.
            max_samples: Maximum number of records to write per split, or None
                for no limit. Skipped samples do not count towards the limit.
            max_question_chars: Samples whose question is longer than this many
                characters are dropped; None disables the filter.

        Returns:
            A dict mapping each split to its SFT and DPO data. A split that fails
            to load is absent from the result.
        """
        all_results = {}

        for split in splits:
            logger.info(f"--- Processing {split} split ---")

            try:
                dataset = load_dataset(dataset_name, split=split)
                logger.info(f"Loaded dataset for {split} split.")

                # Exclude datapoints whose question exceeds the character limit
                if max_question_chars is not None:
                    dataset = dataset.filter(lambda x: len(x["question"]) <= max_question_chars)

                total_samples = len(dataset)
                logger.info(f"Total samples in {split}: {total_samples}")

            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name} for split {split}: {e}")
                continue

            # Prepare lists and file paths
            sft_data = []
            dpo_data = []
            sft_jsonl_path = os.path.join(output_dir, f"sft_data_{split}.jsonl")
            dpo_jsonl_path = os.path.join(output_dir, f"dpo_data_{split}.jsonl")

            processed_count = 0
            skipped_count = 0
            if max_samples is None:
                num_to_process = total_samples
            else:
                num_to_process = max(0, min(max_samples, total_samples))

            pbar_desc = f"Generating {split} data" + (f" (max {max_samples})" if max_samples is not None else "")

            # Open files once, then write results line by line inside the loop
            try:
                os.makedirs(output_dir, exist_ok=True)
                # The bar counts records written (its total is the target number of
                # records), so skipped samples cannot push it past 100%.
                with open(sft_jsonl_path, "w", encoding="utf-8") as f_sft, \
                        open(dpo_jsonl_path, "w", encoding="utf-8") as f_dpo, \
                        tqdm(total=num_to_process, desc=pbar_desc) as pbar:

                    for idx, sample in enumerate(dataset):
                        if max_samples is not None and processed_count >= num_to_process:
                            logger.info(f"Reached max_samples limit ({max_samples}) for {split} split.")
                            break

                        processed_result = self.process_sample(sample, idx)
                        if not processed_result:
                            skipped_count += 1
                            pbar.set_postfix(skipped=skipped_count)
                            logger.warning(f"Sample {idx} skipped due to processing errors.")
                            continue

                        # Create the SFT item
                        sft_item = {
                            "prompt": processed_result["prompt"],
                            "response": processed_result["chosen"]
                        }
                        # Create the DPO item
                        # NOTE(review): the preferred answer is stored under "response",
                        # not "chosen" - confirm this matches what the training code reads.
                        dpo_item = {
                            "prompt": processed_result["prompt"],
                            "response": processed_result["chosen"],
                            "rejected": processed_result["rejected"],
                            "metadata": processed_result["metadata"]
                        }

                        # Immediately write SFT item
                        f_sft.write(json.dumps(sft_item, ensure_ascii=False) + "\n")
                        f_sft.flush()
                        # Immediately write DPO item
                        f_dpo.write(json.dumps(dpo_item, ensure_ascii=False) + "\n")
                        f_dpo.flush()

                        # Also store them in memory
                        sft_data.append(sft_item)
                        dpo_data.append(dpo_item)

                        processed_count += 1
                        pbar.update(1)

                logger.info(f"Finished processing {processed_count} samples for {split} split.")
                logger.info(f"Skipped {skipped_count} samples for {split} split.")
                logger.info(f"Saved {len(sft_data)} SFT entries to {sft_jsonl_path}")
                logger.info(f"Saved {len(dpo_data)} DPO entries to {dpo_jsonl_path}")

                # Optionally show a sample entry for sanity check
                if dpo_data:
                    logger.info(f"--- Sample {split} DPO entry ---")
                    logger.info(json.dumps(dpo_data[0], indent=2, ensure_ascii=False))
                    logger.info(f"--- Sample {split} SFT entry ---")
                    logger.info(json.dumps(sft_data[0], indent=2, ensure_ascii=False))

            except OSError as e:
                logger.error(f"Failed to write data for {split} split: {e}")

            # Store the results for this split in the master dictionary
            # (whatever was generated before a write error is still returned).
            all_results[split] = {
                "sft": sft_data,
                "dpo": dpo_data
            }

        return all_results

In [None]:
# Cell 5: Instantiation and Execution

# Build the generator from the values defined in the configuration cell.
generator = PreferencePairGenerator(
    explanation_model=EXPLANATION_MODEL,
    prompt_format=PROMPT_FORMAT,
    response_format=RESPONSE_FORMAT,
    seed=SEED,
    api_key=OPENAI_API_KEY,
)

# Generate and save the data for every configured split.
results = generator.process_dataset(
    dataset_name=DATASET_NAME,
    splits=SPLITS_TO_PROCESS,
    max_samples=MAX_SAMPLES_PER_SPLIT,
    output_dir=OUTPUT_DIR,
)

# Summary: count the records of each kind across all processed splits.
total_sft_pairs, total_dpo_pairs = (
    sum(len(split_data[kind]) for split_data in results.values())
    for kind in ("sft", "dpo")
)
logger.info("\n--- Generation Complete ---")
logger.info(f"Processed splits: {[*results]}")
logger.info(f"Total SFT pairs generated: {total_sft_pairs}")
logger.info(f"Total DPO pairs generated: {total_dpo_pairs}")
logger.info(f"Data saved in directory: {OUTPUT_DIR}")