In [1]:
%pip install datasets openai backoff

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import json
import random
import pandas as pd
from datasets import load_dataset
import openai
import time
from tqdm import tqdm
import logging
from typing import List, Dict, Any, Optional, Tuple
import backoff

In [3]:
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("generation.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

In [4]:
# Cell 3: Configuration

from dotenv import load_dotenv
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

SEED = 42
EXPLANATION_MODEL = "gpt-4o-mini"  
DATASET_NAME = "GBaker/MedQA-USMLE-4-options"
SPLITS_TO_PROCESS = ["train"] 
OUTPUT_DIR = "synthetic_medqa_data2"
MAX_SAMPLES_PER_SPLIT = None # Set to a number (e.g., 100) for testing, None to process all

CHECKPOINT_INTERVAL = 10  # Save checkpoint every 10 samples 
BACKOFF_MAX_TRIES = 8    # Maximum number of retries for rate limits
BACKOFF_FACTOR = 2       # Exponential backoff factor

# --- Output Format Configuration ---
PROMPT_FORMAT = """Question: {question}

Options:
{options_formatted}

Choose the best answer and provide a step-by-step explanation for your choice."""

# How to format the 'chosen' and 'rejected' responses
RESPONSE_FORMAT = """{answer_label}. {answer_text}
Explanation: {explanation}"""

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

logger.info(f"Configuration:")
logger.info(f"  Explanation Model: {EXPLANATION_MODEL}")
logger.info(f"  Dataset: {DATASET_NAME}")
logger.info(f"  Splits: {SPLITS_TO_PROCESS}")
logger.info(f"  Output Directory: {OUTPUT_DIR}")
logger.info(f"  Max Samples Per Split: {MAX_SAMPLES_PER_SPLIT}")

2025-04-28 09:59:30,468 - INFO - Configuration:
2025-04-28 09:59:30,469 - INFO -   Explanation Model: gpt-4o-mini
2025-04-28 09:59:30,470 - INFO -   Dataset: GBaker/MedQA-USMLE-4-options
2025-04-28 09:59:30,471 - INFO -   Splits: ['train']
2025-04-28 09:59:30,472 - INFO -   Output Directory: synthetic_medqa_data2
2025-04-28 09:59:30,473 - INFO -   Max Samples Per Split: None


In [6]:
class PreferencePairGenerator:
    def __init__(self, api_key: str, seed: int,
                 explanation_model: str,
                 prompt_format: str, response_format: str):
        self.client = openai.OpenAI(api_key=api_key)
        random.seed(seed)
        self.explanation_model = explanation_model
        self.prompt_format = prompt_format
        self.response_format = response_format
        self.label_to_index = {"A": 0, "B": 1, "C": 2, "D": 3}
        self.index_to_label = {v: k for k, v in self.label_to_index.items()}

    @backoff.on_exception(
        backoff.expo, 
        (openai.RateLimitError, openai.APIConnectionError),
        max_tries=BACKOFF_MAX_TRIES,  # Maximum number of retries
        factor=BACKOFF_FACTOR,     # Exponential backoff factor
        jitter=None   # Add randomness to backoff delays
    )
    def call_openai_api(self, messages: List[Dict[str, str]], model: str,
                        temperature: float = 0.5, max_tokens: int = 500, timeout=60) -> Any:
        """Call OpenAI API with automatic backoff for rate limits"""
        try:
            return self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout
            )
        except Exception as e:
            # Handle non-rate-limit errors that backoff doesn't catch
            if not isinstance(e, (openai.RateLimitError, openai.APIConnectionError)):
                logger.error(f"API Error (non-rate-limit): {e}")
            raise  # Re-raise to let backoff handle it if appropriate

    def get_explanation(self, question: str, options_formatted: str, answer: str, is_correct: bool) -> str:
        if is_correct:
            prompt_detail = (
                "Generate a concise explanation (~20-30 words) highlighting the key clinical reasoning "
                "and evidence from the vignette that justifies why this answer is the most correct choice."
            )
        else:
            prompt_detail = (
                "Pretend this option is correct. Generate a concise explanation (~20-30 words) that appears "
                "plausible but subtly contains flawed reasoning. Focus on relevant clinical details while "
                "avoiding any direct mention that this choice might be incorrect."
            )

        prompt = (
            f"Medical Question Context:\n{question}\n\n"
            f"Options:\n{options_formatted}\n\n"
            f"Answer Choice to Explain: {answer}\n\n"
            f"{prompt_detail}"
        )

        messages = [
            {
                "role": "system",
                "content": (
                    "You are a highly knowledgeable medical expert specializing in clinical reasoning "
                    "explanations for USMLE-style questions. Be clear, concise, and follow instructions carefully."
                )
            },
            {"role": "user", "content": prompt}
        ]

        try:
            response = self.call_openai_api(
                messages=messages,
                model=self.explanation_model,
                temperature=0.6,
                max_tokens=150
            )

            if response:
                explanation = response.choices[0].message.content.strip()
                if not explanation or any(phrase in explanation.lower() for phrase in ["cannot provide", "explanation could not be generated"]):
                    logger.warning("Received potentially invalid explanation. Falling back.")
                    return "[Fallback] Explanation generation failed or returned invalid content."
                return explanation
            else:
                return "[API Error] Explanation could not be generated due to API failure."

        except Exception as e:
            logger.error(f"Error generating explanation: {e}")
            return "[Error] An explanation could not be generated."

    def select_alternative_answer(self, options: List[str], correct_index: int) -> int:
        incorrect_indices = [i for i in range(len(options)) if i != correct_index]
        return random.choice(incorrect_indices) if incorrect_indices else (correct_index + 1) % len(options)

    def process_sample(self, sample: Dict[str, Any], idx: int) -> Optional[Dict[str, Any]]:
        try:
            question = sample.get("question")
            options = sample.get("options")

            if not question or not options:
                logger.warning(f"Sample {idx}: Missing question or options. Skipping.")
                return None

            correct_label = sample.get("answer_idx").upper()
            correct_index = self.label_to_index[correct_label]
            correct_option_text = options[correct_label]

            options_formatted = "\n".join(
                [f"{self.index_to_label[i]}. {opt}" for i, opt in enumerate(options.values())]
            )

            logger.info(f"Sample {idx}: Generating explanation for CORRECT answer ({correct_label})")
            correct_explanation = self.get_explanation(question, options_formatted, correct_option_text, is_correct=True)
            if "[Error]" in correct_explanation or "[API Error]" in correct_explanation:
                logger.error(f"Sample {idx}: Failed to generate explanation for correct answer. Skipping.")
                return None

            alt_index = self.select_alternative_answer(list(options.values()), correct_index)
            alt_label = self.index_to_label[alt_index]
            alt_option_text = list(options.values())[alt_index]

            logger.info(f"Sample {idx}: Generating explanation for ALTERNATIVE answer ({alt_label})")
            alt_explanation = self.get_explanation(question, options_formatted, alt_option_text, is_correct=False)

            prompt_str = self.prompt_format.format(question=question, options_formatted=options_formatted)
            chosen_str = self.response_format.format(answer_label=correct_label, answer_text=correct_option_text, explanation=correct_explanation)
            rejected_str = self.response_format.format(answer_label=alt_label, answer_text=alt_option_text, explanation=alt_explanation)

            return {
                "prompt": prompt_str,
                "chosen": chosen_str,
                "rejected": rejected_str,
                "metadata": {
                    "original_question": question,
                    "options": options,
                    "correct_index": correct_index,
                    "alternative_index": alt_index,
                    "correct_label": correct_label,
                    "alternative_label": alt_label,
                    "correct_explanation_raw": correct_explanation,
                    "alternative_explanation_raw": alt_explanation,
                }
            }
        except Exception as e:
            logger.exception(f"Critical error processing sample {idx}: {e}")
            return None


    def process_dataset(self, dataset_name: str, splits: List[str],
                    output_dir: str, max_samples: Optional[int] = None,
                    checkpoint_interval: int = 5) -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
        """
        Load dataset splits, generate SFT and DPO data, and save results.
        Results are also returned in a dict for optional in-code use.
        
        Args:
            dataset_name: The HuggingFace dataset name
            splits: List of dataset splits to process
            output_dir: Directory to save output files
            max_samples: Maximum number of samples to process per split
            checkpoint_interval: How often to save checkpoint state
            
        Returns:
            A dict mapping each split to its SFT and DPO data.
        """
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        
        # Create checkpoint file for recovery
        checkpoint_file = os.path.join(output_dir, "checkpoint.json")
        
        # Try to load checkpoint
        checkpoint_data = {}
        if os.path.exists(checkpoint_file):
            try:
                with open(checkpoint_file, 'r') as f:
                    checkpoint_data = json.load(f)
                    logger.info(f"Loaded checkpoint: {checkpoint_data}")
            except Exception as e:
                logger.warning(f"Failed to load checkpoint: {e}")
        
        all_results = {}
        
        for split in splits:
            logger.info(f"--- Processing {split} split ---")
            
            # Skip if this split is already completed according to checkpoint
            if checkpoint_data.get(f"{split}_completed", False):
                logger.info(f"Split {split} already completed according to checkpoint. Skipping.")
                
                # Try to load the previously generated data for this split
                try:
                    sft_jsonl_path = os.path.join(output_dir, f"sft_data_{split}.jsonl")
                    dpo_jsonl_path = os.path.join(output_dir, f"dpo_data_{split}.jsonl")
                    
                    sft_data = []
                    with open(sft_jsonl_path, "r", encoding="utf-8") as f:
                        for line in f:
                            sft_data.append(json.loads(line))
                    
                    dpo_data = []
                    with open(dpo_jsonl_path, "r", encoding="utf-8") as f:
                        for line in f:
                            dpo_data.append(json.loads(line))
                    
                    all_results[split] = {
                        "sft": sft_data,
                        "dpo": dpo_data
                    }
                    
                    logger.info(f"Loaded {len(sft_data)} SFT and {len(dpo_data)} DPO entries for {split} from files.")
                except Exception as e:
                    logger.warning(f"Failed to load previous data for completed split {split}: {e}")
                
                continue
                
            # Get the last processed index for this split
            last_idx = checkpoint_data.get(f"{split}_last_idx", -1)
            
            try:
                # Load dataset
                dataset = load_dataset(dataset_name, split=split)
                logger.info(f"Loaded dataset for {split} split.")
                
                # Exclude datapoints over 1024 characters
                dataset = dataset.filter(lambda x: len(x["question"]) <= 1024)
                
                total_samples = len(dataset)
                logger.info(f"Total samples in {split}: {total_samples}")
                
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name} for split {split}: {e}")
                continue
            
            # Prepare lists and file paths
            sft_data = []
            dpo_data = []
            sft_jsonl_path = os.path.join(output_dir, f"sft_data_{split}.jsonl")
            dpo_jsonl_path = os.path.join(output_dir, f"dpo_data_{split}.jsonl")
            
            # Handle resuming from checkpoint
            start_idx = last_idx + 1
            processed_count = 0
            num_to_process = total_samples if max_samples is None else min(max_samples, total_samples)
            
            # Adjust pbar to reflect remaining work
            pbar_desc = f"Generating {split} data" + (f" (max {max_samples})" if max_samples else "")
            remaining_count = num_to_process - start_idx if start_idx > 0 else num_to_process
            pbar_total = None if num_to_process is None else remaining_count
            
            # Check if files already exist and need to be appended to
            append_mode = "a" if start_idx > 0 else "w"
            
            # Load existing data if continuing
            if append_mode == "a":
                try:
                    if os.path.exists(sft_jsonl_path):
                        with open(sft_jsonl_path, "r", encoding="utf-8") as f:
                            for line in f:
                                sft_data.append(json.loads(line))
                    
                    if os.path.exists(dpo_jsonl_path):
                        with open(dpo_jsonl_path, "r", encoding="utf-8") as f:
                            for line in f:
                                dpo_data.append(json.loads(line))
                    
                    logger.info(f"Loaded {len(sft_data)} existing SFT and {len(dpo_data)} DPO entries.")
                    processed_count = len(sft_data)
                except Exception as e:
                    logger.warning(f"Error loading existing data files: {e}")
            
            # Open files once, then write results line by line inside the loop
            try:
                with open(sft_jsonl_path, append_mode, encoding="utf-8") as f_sft, \
                    open(dpo_jsonl_path, append_mode, encoding="utf-8") as f_dpo:
                    
                    # Create a mini-batch processing approach to avoid exhausting resources
                    batch_size = 10  # Process 10 samples at a time
                    
                    # Only iterate over samples we haven't processed yet
                    dataset_slice = dataset.select(range(start_idx, len(dataset)))
                    
                    # Create progress bar for remaining items
                    pbar = tqdm(enumerate(dataset_slice, start=start_idx), 
                            total=pbar_total, 
                            desc=pbar_desc)
                    
                    for idx, sample in pbar:
                        # Check if we've reached the limit
                        if max_samples is not None and processed_count >= max_samples:
                            logger.info(f"Reached max_samples limit ({max_samples}) for {split} split.")
                            break
                        
                        # Actual processing with error handling
                        try:
                            processed_result = self.process_sample(sample, idx)
                            if processed_result:
                                # Create the SFT item
                                sft_item = {
                                    "prompt": processed_result["prompt"],
                                    "response": processed_result["chosen"]
                                }
                                # Create the DPO item
                                dpo_item = {
                                    "prompt": processed_result["prompt"],
                                    "chosen": processed_result["chosen"],
                                    "rejected": processed_result["rejected"],
                                    "metadata": processed_result["metadata"]
                                }
                                
                                # Immediately write SFT item
                                f_sft.write(json.dumps(sft_item, ensure_ascii=False) + "\n")
                                f_sft.flush()
                                # Immediately write DPO item
                                f_dpo.write(json.dumps(dpo_item, ensure_ascii=False) + "\n")
                                f_dpo.flush()
                                
                                # Also store them in memory
                                sft_data.append(sft_item)
                                dpo_data.append(dpo_item)
                                
                                processed_count += 1
                                
                                # Update progress bar description
                                pbar.set_description(f"Processing {split} ({processed_count}/{num_to_process})")
                            else:
                                logger.warning(f"Sample {idx} skipped due to processing errors.")
                                
                            # Save checkpoint at regular intervals
                            if idx % checkpoint_interval == 0:
                                checkpoint_data[f"{split}_last_idx"] = idx
                                with open(checkpoint_file, 'w') as f:
                                    json.dump(checkpoint_data, f)
                                logger.info(f"Saved checkpoint at index {idx}")
                                
                        except Exception as e:
                            logger.exception(f"Error processing sample {idx}: {e}")
                            # Save checkpoint on error to enable recovery
                            checkpoint_data[f"{split}_last_idx"] = idx - 1  # Mark the last successful one
                            with open(checkpoint_file, 'w') as f:
                                json.dump(checkpoint_data, f)
                            # Don't break the loop - continue with next sample
                    
                    # Mark split as completed
                    checkpoint_data[f"{split}_completed"] = True
                    checkpoint_data[f"{split}_last_idx"] = idx
                    with open(checkpoint_file, 'w') as f:
                        json.dump(checkpoint_data, f)
                    logger.info(f"Marked split {split} as completed in checkpoint")
                    
                    logger.info(f"Finished processing {processed_count} samples for {split} split.")
                    logger.info(f"Saved {len(sft_data)} SFT entries to {sft_jsonl_path}")
                    logger.info(f"Saved {len(dpo_data)} DPO entries to {dpo_jsonl_path}")
                    
                    # Optionally show a sample entry for sanity check
                    if dpo_data:
                        logger.info(f"--- Sample {split} DPO entry ---")
                        logger.info(json.dumps(dpo_data[0], indent=2, ensure_ascii=False))
                        logger.info(f"--- Sample {split} SFT entry ---")
                        logger.info(json.dumps(sft_data[0], indent=2, ensure_ascii=False))
                        
            except IOError as e:
                logger.error(f"Failed to write data for {split} split: {e}")
            
            # Store the results for this split in the master dictionary
            all_results[split] = {
                "sft": sft_data,
                "dpo": dpo_data
            }
        
        return all_results

In [6]:
# Cell 5: Instantiation and Execution
generator = PreferencePairGenerator(
        api_key=OPENAI_API_KEY,
        seed=SEED,
        explanation_model=EXPLANATION_MODEL,
        prompt_format=PROMPT_FORMAT,
        response_format=RESPONSE_FORMAT
)
    
# Run the dataset processing
results = generator.process_dataset(
    dataset_name=DATASET_NAME,
    splits=SPLITS_TO_PROCESS,
    output_dir=OUTPUT_DIR,
    max_samples=MAX_SAMPLES_PER_SPLIT,
    checkpoint_interval=CHECKPOINT_INTERVAL
)

# Summary
total_sft_pairs = sum(len(data["sft"]) for data in results.values())
total_dpo_pairs = sum(len(data["dpo"]) for data in results.values())
logger.info(f"\n--- Generation Complete ---")
logger.info(f"Processed splits: {list(results.keys())}")
logger.info(f"Total SFT pairs generated: {total_sft_pairs}")
logger.info(f"Total DPO pairs generated: {total_dpo_pairs}")
logger.info(f"Data saved in directory: {OUTPUT_DIR}")

2025-04-16 06:19:28,861 - INFO - Loaded checkpoint: {'train_last_idx': 7780}
2025-04-16 06:19:28,861 - INFO - --- Processing train split ---
2025-04-16 06:19:31,222 - INFO - Loaded dataset for train split.
2025-04-16 06:19:31,224 - INFO - Total samples in train: 8800
2025-04-16 06:19:31,332 - INFO - Loaded 7788 existing SFT and 7788 DPO entries.
Generating train data:   0%|          | 0/1019 [00:00<?, ?it/s]2025-04-16 06:19:31,338 - INFO - Sample 7781: Generating explanation for CORRECT answer (D)
2025-04-16 06:19:33,211 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-04-16 06:19:33,217 - INFO - Sample 7781: Generating explanation for ALTERNATIVE answer (C)
2025-04-16 06:19:34,144 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing train (7789/8800):   0%|          | 1/1019 [00:02<47:44,  2.81s/it]2025-04-16 06:19:34,150 - INFO - Sample 7782: Generating explanation for CORRECT answer (D)
2025-

In [8]:
import os
import json
from collections import Counter

# Check for duplicate questions in the produced files

def check_duplicates(output_dir: str, splits: List[str]) -> None:
    for split in splits:
        sft_file = os.path.join(output_dir, f"sft_data_{split}.jsonl")
        if not os.path.exists(sft_file):
            logger.warning(f"File {sft_file} does not exist. Skipping.")
            continue

        questions = []
        with open(sft_file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                questions.append(data["prompt"])  # Assuming "prompt" contains the question

        # Count occurrences of each question
        question_counts = Counter(questions)
        duplicates = {q: count for q, count in question_counts.items() if count > 1}

        if duplicates:
            logger.info(f"Found {len(duplicates)} duplicate questions in {split} split:")
            for question, count in duplicates.items():
                logger.info(f"  - {question[:100]}... (repeated {count} times)")
        else:
            logger.info(f"No duplicate questions found in {split} split.")

# Run the duplicate check
check_duplicates(OUTPUT_DIR, SPLITS_TO_PROCESS)

2025-04-16 09:40:02,963 - INFO - No duplicate questions found in train split.


In [8]:
generator = PreferencePairGenerator(
        api_key=OPENAI_API_KEY,
        seed=SEED,
        explanation_model=EXPLANATION_MODEL,
        prompt_format=PROMPT_FORMAT,
        response_format=RESPONSE_FORMAT
)

# Process the test split
test_results = generator.process_dataset(
    dataset_name=DATASET_NAME,
    splits=["test"],  # Specify the test split
    output_dir=OUTPUT_DIR,
    max_samples=MAX_SAMPLES_PER_SPLIT,
    checkpoint_interval=CHECKPOINT_INTERVAL
)

# Summary for the test split
total_test_sft_pairs = len(test_results["test"]["sft"]) if "test" in test_results else 0
total_test_dpo_pairs = len(test_results["test"]["dpo"]) if "test" in test_results else 0
logger.info(f"\n--- Test Split Processing Complete ---")
logger.info(f"Total SFT pairs generated for test split: {total_test_sft_pairs}")
logger.info(f"Total DPO pairs generated for test split: {total_test_dpo_pairs}")
logger.info(f"Test data saved in directory: {OUTPUT_DIR}")

2025-04-19 19:03:43,858 - INFO - Loaded checkpoint: {'train_last_idx': 0}
2025-04-19 19:03:43,858 - INFO - --- Processing test split ---
2025-04-19 19:03:46,222 - INFO - Loaded dataset for test split.


Filter:   0%|          | 0/1273 [00:00<?, ? examples/s]

2025-04-19 19:03:46,314 - INFO - Total samples in test: 1078
Generating test data:   0%|          | 0/1078 [00:00<?, ?it/s]2025-04-19 19:03:46,324 - INFO - Sample 0: Generating explanation for CORRECT answer (B)
2025-04-19 19:03:48,214 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-04-19 19:03:48,223 - INFO - Sample 0: Generating explanation for ALTERNATIVE answer (D)
2025-04-19 19:03:48,882 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing test (1/1078):   0%|          | 0/1078 [00:02<?, ?it/s]2025-04-19 19:03:48,889 - INFO - Saved checkpoint at index 0
Processing test (1/1078):   0%|          | 1/1078 [00:02<46:03,  2.57s/it]2025-04-19 19:03:48,892 - INFO - Sample 1: Generating explanation for CORRECT answer (D)
2025-04-19 19:03:49,745 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-04-19 19:03:49,748 - INFO - Sample 1: Generating explanation 

In [13]:
def calculate_length_metrics_from_file(file_path: str) -> Dict[str, Any]:
    if not os.path.exists(file_path):
        logger.warning(f"File {file_path} does not exist.")
        return {}

    prompts = []
    responses = []

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            prompts.append(len(data["prompt"]))
            responses.append(len(data["response"]))

    metrics = {
        "prompt": {
            "min": np.min(prompts),
            "max": np.max(prompts),
            "mean": np.mean(prompts),
            "median": np.median(prompts)
        },
        "response": {
            "min": np.min(responses),
            "max": np.max(responses),
            "mean": np.mean(responses),
            "median": np.median(responses)
        }
    }
    return metrics

# Calculate metrics for the test split from file
train_sft_file = os.path.join(OUTPUT_DIR, "sft_data_train.jsonl")
test_sft_file = os.path.join(OUTPUT_DIR, "sft_data_test.jsonl")
metrics_from_file_train = calculate_length_metrics_from_file(test_sft_file)
metrics_from_file_test = calculate_length_metrics_from_file(train_sft_file)


logger.info(f"Prompt and Response Length Metrics for Train Split (from file):")
logger.info(f"Prompt - Min: {metrics_from_file_train['prompt']['min']}, Max: {metrics_from_file_train['prompt']['max']}, "
            f"Mean: {metrics_from_file_train['prompt']['mean']:.2f}, Median: {metrics_from_file_train['prompt']['median']}")
logger.info(f"Response - Min: {metrics_from_file_train['response']['min']}, Max: {metrics_from_file_train['response']['max']}, "
            f"Mean: {metrics_from_file_train['response']['mean']:.2f}, Median: {metrics_from_file_train['response']['median']}")

logger.info(f"Prompt and Response Length Metrics for Test Split (from file):")
logger.info(f"Prompt - Min: {metrics_from_file_test['prompt']['min']}, Max: {metrics_from_file_test['prompt']['max']}, "
            f"Mean: {metrics_from_file_test['prompt']['mean']:.2f}, Median: {metrics_from_file_test['prompt']['median']}")
logger.info(f"Response - Min: {metrics_from_file_test['response']['min']}, Max: {metrics_from_file_test['response']['max']}, "
            f"Mean: {metrics_from_file_test['response']['mean']:.2f}, Median: {metrics_from_file_test['response']['median']}")


2025-04-19 19:41:56,894 - INFO - Prompt and Response Length Metrics for Train Split (from file):
2025-04-19 19:41:56,895 - INFO - Prompt - Min: 310, Max: 1474, Mean: 879.71, Median: 887.5
2025-04-19 19:41:56,895 - INFO - Response - Min: 152, Max: 441, Mean: 245.88, Median: 242.0
2025-04-19 19:41:56,895 - INFO - Prompt and Response Length Metrics for Test Split (from file):
2025-04-19 19:41:56,896 - INFO - Prompt - Min: 263, Max: 1843, Mean: 871.39, Median: 871.0
2025-04-19 19:41:56,897 - INFO - Response - Min: 148, Max: 558, Mean: 245.37, Median: 242.0


In [15]:
import json
from datasets import load_dataset

# Filepath for the test dataset
test_data_file = "/Users/aravadikesh/Documents/GitHub/MedQA_DPO/data/synthetic_medqa_data/sft_data_test.jsonl"

# Load the test dataset
def load_test_data(filepath):
    test_data = []
    with open(filepath, "r", encoding="utf-8") as file:
        for line in file:
            data = json.loads(line.strip())
            test_data.append(data)
    return test_data

# Extract the answer label (e.g., A, B, C, D) from the response
def extract_answer_label(response):
    return response.strip()[0].upper()  # Extract the first character (e.g., A, B, C, D)

# Compare answers using substring matching for questions and ignoring explanations
def compare_answers(test_data, reference_data):
    mismatches = []
    for test_entry in test_data:
        test_prompt = test_entry["prompt"]
        test_response = extract_answer_label(test_entry["response"])

        # Find the corresponding question in the reference dataset using substring matching
        reference_entry = next((ref for ref in reference_data if ref["question"] in test_prompt), None)
        if reference_entry:
            reference_response = reference_entry["answer_idx"]  # Directly use the answer_idx from the reference
            if test_response != reference_response:
                mismatches.append({
                    "question": reference_entry["question"],
                    "test_prompt": test_prompt,
                    "test_response": test_response,
                    "reference_response": reference_response
                })
        else:
            mismatches.append({
                "test_prompt": test_prompt,
                "test_response": test_response,
                "reference_response": "Not Found in Reference Data"
            })
    return mismatches

# Main function
def main():
    # Load the test dataset
    test_data = load_test_data(test_data_file)

    # Load the reference dataset (GBaker test set)
    dataset = load_dataset('GBaker/MedQA-USMLE-4-options', split='test')

    # Prepare reference data
    reference_data = []
    for entry in dataset:
        question = entry["question"]
        answer_idx = entry["answer_idx"]  # This is already the correct answer label (e.g., 'A', 'B', etc.)
        reference_data.append({
            "question": question,
            "answer_idx": answer_idx
        })

    # Compare answers
    mismatches = compare_answers(test_data, reference_data)

    # Print results
    print(f"Total questions in test data: {len(test_data)}")
    print(f"Total mismatches: {len(mismatches)}")
    if mismatches:
        print("Mismatches:")
        for mismatch in mismatches[:10]:  # Show up to 10 mismatches
            print(f"Reference Question: {mismatch['question']}")
            print(f"Test Prompt: {mismatch['test_prompt']}")
            print(f"Test Response: {mismatch['test_response']}")
            print(f"Reference Response: {mismatch['reference_response']}")
            print()

if __name__ == "__main__":
    main()

Total questions in test data: 1078
Total mismatches: 0


In [2]:
import json

# Paths
input_file = "gemma3_data/gemma3_dpo_train_data.jsonl"
scored_file = "gemma3_data/gemma3_dpo_scored_data.jsonl"
skipped_output_file = "gemma3_data/skipped_entries.jsonl"

# Load original input
with open(input_file, 'r', encoding='utf-8') as f:
    original_lines = f.readlines()

# Load scored output
with open(scored_file, 'r', encoding='utf-8') as f:
    scored_data = [json.loads(line) for line in f]

# Build a set of prompts that were successfully scored
scored_prompts = set(entry["prompt"] for entry in scored_data)

# Now find skipped entries
skipped_lines = []
for line in original_lines:
    try:
        entry = json.loads(line)
        if entry.get("prompt") not in scored_prompts:
            skipped_lines.append(line)
    except json.JSONDecodeError:
        skipped_lines.append(line)  # Also treat invalid JSON as skipped

# Save the skipped entries
with open(skipped_output_file, 'w', encoding='utf-8') as f:
    for line in skipped_lines:
        f.write(line)

print(f"Done! Found and saved {len(skipped_lines)} skipped entries to {skipped_output_file}")


Done! Found and saved 59 skipped entries to gemma3_data/skipped_entries.jsonl
