In [1]:
#!/usr/bin/env python3

Medical Chatbot Dataset Preparation Script
=========================================

This script prepares and mixes MedMCQA and Orca datasets for fine-tuning a medical chatbot.
It handles loading, cleaning, standardizing, and mixing datasets according to specified ratios.

Requirements:
- MedMCQA dataset in CSV format with columns: question, exp, cop, opa, opb, opc, opd, correct_answer
- Orca dataset in JSON/Alpaca format
- Target: 75% MedMCQA + 25% Orca mix
- 10% evaluation split from each dataset (currently a seeded random split; true stratification is not implemented)


In [2]:
import json
import os
import random
import argparse
from typing import Dict, List, Any, Tuple
from collections import defaultdict
import logging

In [15]:
# Configure the root logger once for the whole notebook/script:
# INFO and above, each record prefixed with a timestamp and its level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger shared by the processing code below.
logger = logging.getLogger(__name__)

class DatasetProcessor:
    """Load, standardize, split and mix the MedMCQA and Orca datasets.

    Every sample is converted to the Alpaca-style schema
    ``{"instruction", "input", "output"}``. Each source is split into
    train/eval independently, and the training portions are then mixed
    according to ``radqa_ratio`` (MedMCQA share) and ``orca_ratio``.
    """

    # Letters shown in the prompt, in the same order as the option columns.
    _OPTION_LETTERS = ("A", "B", "C", "D")
    _OPTION_COLUMNS = ("opa", "opb", "opc", "opd")

    def __init__(self, radqa_ratio: float = 0.75, eval_split: float = 0.1, seed: int = 42):
        """
        Initialize the dataset processor.

        Args:
            radqa_ratio: Proportion of MedMCQA samples in the mixed training set
                (0.75 = 75%). The parameter keeps its historical "radqa" name
                for backward compatibility.
            eval_split: Proportion of each dataset held out for evaluation (0.1 = 10%)
            seed: Random seed for reproducibility

        Raises:
            ValueError: If ``radqa_ratio`` or ``eval_split`` is outside [0, 1].
        """
        if not 0.0 <= radqa_ratio <= 1.0:
            raise ValueError(f"radqa_ratio must be within [0, 1], got {radqa_ratio}")
        if not 0.0 <= eval_split <= 1.0:
            raise ValueError(f"eval_split must be within [0, 1], got {eval_split}")

        self.radqa_ratio = radqa_ratio
        self.orca_ratio = 1.0 - radqa_ratio
        self.eval_split = eval_split
        self.seed = seed

        # Kept for backward compatibility: callers may rely on the global seed.
        random.seed(seed)
        # Private generator so shuffling/sampling here is not perturbed by
        # unrelated code that also draws from the global ``random`` state.
        self._rng = random.Random(seed)

        # Unified schema for all datasets
        self.unified_schema = {
            "instruction": "",
            "input": "",
            "output": ""
        }

    def load_medmcqa_dataset(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load and parse MedMCQA dataset from CSV file.

        Expected format: question, exp, cop, opa, opb, opc, opd, correct_answer

        Args:
            file_path: Path to MedMCQA CSV file

        Returns:
            List of row dictionaries with missing (NaN) cells replaced by "".

        Raises:
            Exception: Any error from pandas/IO is logged and re-raised.
        """
        logger.info(f"Loading MedMCQA dataset from {file_path}")

        try:
            import pandas as pd

            df = pd.read_csv(file_path, encoding='utf-8')
            data = df.to_dict('records')

            # Replace NaN cells so downstream text cleaning never sees "nan".
            for sample in data:
                for key, value in sample.items():
                    if pd.isna(value):
                        sample[key] = ""

            logger.info(f"Loaded {len(data)} MedMCQA samples")
            return data

        except Exception as e:
            logger.error(f"Error loading MedMCQA dataset: {e}")
            raise

    def load_orca_dataset(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load and parse Orca dataset from JSON file.

        Expected formats:
        - Alpaca format: {"instruction": ..., "input": ..., "output": ...}
        - OpenOrca format: {"question": ..., "response": ...}

        A top-level JSON object is unwrapped via the first of the keys
        ``data``, ``samples``, ``examples``, ``train`` that is present;
        otherwise the object itself is treated as a single sample.

        Args:
            file_path: Path to Orca JSON file

        Returns:
            List of parsed Orca samples

        Raises:
            Exception: Any IO/JSON error is logged and re-raised.
        """
        logger.info(f"Loading Orca dataset from {file_path}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if isinstance(data, dict):
                for key in ['data', 'samples', 'examples', 'train']:
                    if key in data:
                        data = data[key]
                        break
                else:
                    # No known container key: assume the dict is one sample.
                    data = [data]

            logger.info(f"Loaded {len(data)} Orca samples")
            return data

        except Exception as e:
            logger.error(f"Error loading Orca dataset: {e}")
            raise

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize text data.

        Collapses every run of whitespace (including newlines and carriage
        returns) to a single space and trims both ends. ``None`` and float
        NaN are treated as missing and yield "" instead of the literal
        strings "None"/"nan".

        Args:
            text: Raw text to clean (non-string values are stringified)

        Returns:
            Cleaned text
        """
        if text is None:
            return ""
        if isinstance(text, float) and text != text:  # NaN is the only float != itself
            return ""
        if not isinstance(text, str):
            text = str(text)

        # split() with no arguments splits on any whitespace and drops empties,
        # so this one expression normalizes newlines and strips the ends.
        return ' '.join(text.split())

    def standardize_medmcqa_sample(self, sample: Dict[str, Any]) -> Dict[str, str]:
        """
        Convert MedMCQA sample to unified schema.

        The answer letter is resolved by matching ``correct_answer`` against
        the option texts (case-insensitive, first match wins). If that fails
        and ``correct_answer`` is itself a bare letter A-D, that letter is
        used. When no letter can be determined the answer is emitted without
        one rather than guessing.

        Args:
            sample: Raw MedMCQA sample with columns: question, exp, cop, opa, opb, opc, opd, correct_answer

        Returns:
            Standardized sample in unified schema. If the question or the
            correct answer is missing, all fields are empty so that
            ``filter_valid_samples`` discards the row.
        """
        question = self.clean_text(sample.get('question', ''))
        explanation = self.clean_text(sample.get('exp', ''))
        options = [self.clean_text(sample.get(column, '')) for column in self._OPTION_COLUMNS]
        correct_answer = self.clean_text(sample.get('correct_answer', ''))

        # Without a question or an answer there is nothing to learn from.
        if not question or not correct_answer:
            return {"instruction": "", "input": "", "output": ""}

        options_text = "\n".join(
            f"{letter}) {option}" for letter, option in zip(self._OPTION_LETTERS, options)
        )
        instruction = (
            f"Question: {question}\n\nOptions:\n{options_text}\n\n"
            "Please select the correct answer and provide a brief explanation."
        )

        # 1) Match the answer text against the options (ordered, so duplicate
        #    option texts resolve to the first letter).
        correct_letter = None
        answer_text = correct_answer
        wanted = correct_answer.lower()
        for letter, option in zip(self._OPTION_LETTERS, options):
            if option and option.lower() == wanted:
                correct_letter = letter
                break

        # 2) Fall back to interpreting the answer as a bare option letter.
        if correct_letter is None and correct_answer.upper() in self._OPTION_LETTERS:
            correct_letter = correct_answer.upper()
            letter_index = self._OPTION_LETTERS.index(correct_letter)
            answer_text = options[letter_index] or correct_answer

        if correct_letter is None:
            output = f"The correct answer is {answer_text}."
        else:
            output = f"The correct answer is {correct_letter}) {answer_text}."

        if explanation:
            output = f"{output}\n\nExplanation: {explanation}"

        return {
            "instruction": instruction,
            "input": "",
            "output": output
        }

    def standardize_orca_sample(self, sample: Dict[str, Any]) -> Dict[str, str]:
        """
        Convert Orca sample to unified schema.

        Args:
            sample: Raw Orca sample (Alpaca-style, OpenOrca-style, or a
                prompt/completion-style record)

        Returns:
            Standardized sample in unified schema
        """
        if 'instruction' in sample:
            # Alpaca format
            instruction = sample.get('instruction', '')
            input_text = sample.get('input', '')
            output_text = sample.get('output', sample.get('response', ''))
        elif 'question' in sample:
            # OpenOrca format
            instruction = sample.get('question', '')
            input_text = sample.get('input', '')
            output_text = sample.get('response', sample.get('answer', ''))
        else:
            # Try to infer from available keys
            instruction = sample.get('prompt', sample.get('text', ''))
            input_text = ''
            output_text = sample.get('completion', sample.get('target', ''))

        return {
            "instruction": self.clean_text(instruction),
            "input": self.clean_text(input_text),
            "output": self.clean_text(output_text)
        }

    def filter_valid_samples(self, samples: List[Dict[str, str]], dataset_name: str) -> List[Dict[str, str]]:
        """
        Filter out invalid samples (missing required fields).

        A sample is valid when both ``instruction`` and ``output`` are
        non-blank.

        Args:
            samples: List of standardized samples
            dataset_name: Name of dataset for logging

        Returns:
            List of valid samples
        """
        valid_samples = [
            sample for sample in samples
            if sample.get('instruction', '').strip() and sample.get('output', '').strip()
        ]

        filtered_count = len(samples) - len(valid_samples)
        if filtered_count > 0:
            logger.warning(f"Filtered {filtered_count} invalid samples from {dataset_name}")

        logger.info(f"Valid {dataset_name} samples: {len(valid_samples)}")
        return valid_samples

    def stratified_split(self, samples: List[Dict[str, str]], eval_ratio: float) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """
        Split samples into train/eval.

        Despite the name this is a seeded *random* split: the samples are
        shuffled and the first ``int(len * eval_ratio)`` become the eval set.
        No stratification is performed. The input list is not modified.

        Args:
            samples: List of samples to split
            eval_ratio: Ratio of samples to use for evaluation

        Returns:
            Tuple of (train_samples, eval_samples)
        """
        shuffled_samples = list(samples)
        self._rng.shuffle(shuffled_samples)

        eval_size = int(len(shuffled_samples) * eval_ratio)
        return shuffled_samples[eval_size:], shuffled_samples[:eval_size]

    def _mix_targets(self, medmcqa_count: int, orca_count: int) -> Tuple[int, int]:
        """
        Compute how many training samples to keep from each source.

        Chooses the largest mixture that honors ``radqa_ratio``/``orca_ratio``
        without oversampling either source. If one source is empty or its
        ratio is not positive, the other source is kept in full.

        Args:
            medmcqa_count: Available MedMCQA training samples
            orca_count: Available Orca training samples

        Returns:
            Tuple of (medmcqa_target, orca_target)
        """
        if orca_count <= 0 or self.orca_ratio <= 0:
            return medmcqa_count, 0
        if medmcqa_count <= 0 or self.radqa_ratio <= 0:
            return 0, orca_count

        # The scarcer source (relative to its share) caps the total size.
        total = min(medmcqa_count / self.radqa_ratio, orca_count / self.orca_ratio)
        medmcqa_target = min(medmcqa_count, int(round(total * self.radqa_ratio)))
        orca_target = min(orca_count, int(round(total * self.orca_ratio)))
        return medmcqa_target, orca_target

    def process_datasets(self, medmcqa_path: str, orca_path: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """
        Process both datasets and create train/eval splits.

        Side effect: when the Orca file contains no samples, ``radqa_ratio``
        and ``orca_ratio`` are overwritten with 1.0 / 0.0.

        Args:
            medmcqa_path: Path to MedMCQA dataset
            orca_path: Path to Orca dataset

        Returns:
            Tuple of (train_samples, eval_samples), each shuffled
        """
        logger.info("Starting dataset processing...")

        medmcqa_raw = self.load_medmcqa_dataset(medmcqa_path)
        orca_raw = self.load_orca_dataset(orca_path)

        if not orca_raw:
            logger.warning("Orca dataset is empty - using 100% MedMCQA")
            self.radqa_ratio = 1.0
            self.orca_ratio = 0.0

        logger.info("Standardizing MedMCQA samples...")
        medmcqa_standardized = [self.standardize_medmcqa_sample(sample) for sample in medmcqa_raw]
        medmcqa_valid = self.filter_valid_samples(medmcqa_standardized, "MedMCQA")

        if orca_raw:
            logger.info("Standardizing Orca samples...")
            orca_standardized = [self.standardize_orca_sample(sample) for sample in orca_raw]
            orca_valid = self.filter_valid_samples(orca_standardized, "Orca")
        else:
            logger.info("Skipping Orca processing (empty dataset)")
            orca_valid = []

        # Split each source separately so both are represented in eval.
        medmcqa_train, medmcqa_eval = self.stratified_split(medmcqa_valid, self.eval_split)

        if orca_valid:
            orca_train, orca_eval = self.stratified_split(orca_valid, self.eval_split)
        else:
            orca_train, orca_eval = [], []

        logger.info(f"MedMCQA splits - Train: {len(medmcqa_train)}, Eval: {len(medmcqa_eval)}")
        logger.info(f"Orca splits - Train: {len(orca_train)}, Eval: {len(orca_eval)}")

        # Downsample whichever source is over-represented for the target mix.
        if orca_train:
            medmcqa_target, orca_target = self._mix_targets(len(medmcqa_train), len(orca_train))
            if len(medmcqa_train) > medmcqa_target:
                medmcqa_train = self._rng.sample(medmcqa_train, medmcqa_target)
            if len(orca_train) > orca_target:
                orca_train = self._rng.sample(orca_train, orca_target)

        mixed_train = medmcqa_train + orca_train
        self._rng.shuffle(mixed_train)

        mixed_eval = medmcqa_eval + orca_eval
        self._rng.shuffle(mixed_eval)

        logger.info(f"Final mixed dataset - Train: {len(mixed_train)}, Eval: {len(mixed_eval)}")
        if orca_train:
            logger.info(f"Training set composition - MedMCQA: {len(medmcqa_train)} ({len(medmcqa_train)/len(mixed_train)*100:.1f}%), "
                       f"Orca: {len(orca_train)} ({len(orca_train)/len(mixed_train)*100:.1f}%)")
        else:
            logger.info("Training set composition - MedMCQA: 100%")

        return mixed_train, mixed_eval

    def save_dataset(self, samples: List[Dict[str, str]], output_path: str):
        """
        Save dataset to JSON file.

        The parent directory is created if needed. A bare filename (no
        directory component) is written to the current working directory.

        Args:
            samples: List of samples to save
            output_path: Path to output JSON file
        """
        logger.info(f"Saving {len(samples)} samples to {output_path}")

        # dirname() is "" for a bare filename, and makedirs("") raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(samples, f, indent=2, ensure_ascii=False)

        logger.info(f"Dataset saved successfully to {output_path}")

    def generate_dataset_stats(self, train_samples: List[Dict[str, str]], eval_samples: List[Dict[str, str]]) -> Dict[str, Any]:
        """
        Generate statistics about the processed dataset.

        Lengths are measured in whitespace-separated words. Averages are 0.0
        when there are no samples at all.

        Args:
            train_samples: Training samples
            eval_samples: Evaluation samples

        Returns:
            Dictionary containing dataset statistics
        """
        all_samples = train_samples + eval_samples
        total = len(all_samples)

        instruction_words = sum(len(s['instruction'].split()) for s in all_samples)
        output_words = sum(len(s['output'].split()) for s in all_samples)

        stats = {
            "total_samples": total,
            "train_samples": len(train_samples),
            "eval_samples": len(eval_samples),
            "train_eval_ratio": f"{len(train_samples)}/{len(eval_samples)}",
            "avg_instruction_length": instruction_words / total if total else 0.0,
            "avg_output_length": output_words / total if total else 0.0,
            "samples_with_input": sum(1 for s in all_samples if s['input'].strip()),
            "processing_config": {
                "radqa_ratio": self.radqa_ratio,
                "orca_ratio": self.orca_ratio,
                "eval_split": self.eval_split,
                "seed": self.seed
            }
        }

        return stats

In [17]:
# Google Colab Setup Functions
def setup_colab_environment():
    """Detect Google Colab and, when inside it, install packages and mount Drive.

    Packages are installed with the pip that belongs to the running
    interpreter (``sys.executable -m pip``) so they land in the kernel's
    environment, and a non-zero pip exit status is reported instead of being
    silently ignored. ``datasets`` is explicitly upgraded: the run recorded in
    this notebook failed to load OpenOrca with "Invalid pattern: '**' can only
    be an entire path component", an error reported for older ``datasets``
    releases used with a newer fsspec (NOTE(review): confirm the minimum
    version you need and consider pinning it).

    Returns:
        bool: True when running in Google Colab, False otherwise.
    """
    import subprocess
    import sys

    try:
        import google.colab  # noqa: F401  (imported only to detect Colab)
        IN_COLAB = True
        print("Running in Google Colab environment")
    except ImportError:
        IN_COLAB = False
        print("Running in local environment")

    if IN_COLAB:
        print("Installing required packages...")
        pip_install = [sys.executable, "-m", "pip", "install", "-q"]
        install_commands = (
            pip_install + ["pandas", "scikit-learn", "huggingface_hub"],
            pip_install + ["--upgrade", "datasets"],
        )
        for command in install_commands:
            try:
                exit_code = subprocess.run(command).returncode
            except OSError as e:
                print(f"Could not run pip: {e}")
                continue
            if exit_code != 0:
                print(f"Package installation failed (exit code {exit_code}): {' '.join(command)}")

        # Mount Google Drive.
        try:
            from google.colab import drive
            print("Mounting Google Drive...")
            drive.mount('/content/drive')
            print("Google Drive mounted successfully!")
        except Exception as e:
            print(f"Could not mount Google Drive: {e}")

    return IN_COLAB

# Upload files in Google Colab environment.
def upload_files_colab():
    """Ask the user to upload dataset files through the Colab upload widget.

    Uploaded names are classified by extension: ``.csv`` is taken as the
    MedMCQA file and ``.json`` as the Orca file. If several files share an
    extension, the last one listed wins.

    Returns:
        tuple: ``(medmcqa_path, orca_path, uploaded_files)``; the paths are
        None when no matching file was uploaded. Outside Colab this is
        ``(None, None, [])``.
    """
    try:
        from google.colab import files

        for prompt_line in (
            "Please upload your dataset files:",
            "1. Upload your MedMCQA CSV file (extracted_medmcqa.csv)",
            "2. Upload your Orca JSON file",
        ):
            print(prompt_line)

        uploaded = files.upload()

        uploaded_files = [name for name in uploaded.keys()]
        print(f"Uploaded files: {uploaded_files}")

        # extension -> most recently seen file with that extension
        detected = {'.csv': None, '.json': None}
        labels = (('.csv', 'MedMCQA CSV'), ('.json', 'Orca JSON'))

        for filename in uploaded_files:
            for extension, label in labels:
                if filename.endswith(extension):
                    detected[extension] = filename
                    print(f"Detected {label}: {filename}")
                    break

        return detected['.csv'], detected['.json'], uploaded_files

    except ImportError:
        print("Not running in Google Colab - file upload not available")
        return None, None, []

def _write_empty_orca_file(path: str = '/tmp/empty_orca.json') -> str:
    """Write an empty JSON list to ``path`` (meaning "no Orca data") and return the path."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump([], f)
    return path


def _load_orca_for_colab():
    """Obtain an Orca-style dataset in Colab and return the path of its JSON dump.

    Tries OpenOrca from Hugging Face first; on any failure the user picks a
    fallback (own file, alternative dataset, or no Orca data).

    Returns:
        str path of the JSON file to use, or None when the run must abort
        (a user-supplied file does not exist).
    """
    try:
        from datasets import load_dataset
        print("Loading OpenOrca from Hugging Face.")
        print("Note: This may take a few minutes for the first download")

        print("Using your specified loading method.")
        ds = load_dataset("Open-Orca/OpenOrca")

        orca_dataset = ds['train']
        print(f"Successfully loaded OpenOrca dataset with {len(orca_dataset)} samples")

        # Keep only the first 50K samples so processing stays manageable.
        print("Taking subset of 50,000 samples for efficient processing.")
        subset_size = min(50000, len(orca_dataset))
        orca_raw = orca_dataset.select(range(subset_size)).to_list()

        orca_path = '/tmp/orca_hf.json'
        print("Saving OpenOrca subset to temporary file...")
        with open(orca_path, 'w', encoding='utf-8') as f:
            json.dump(orca_raw, f, indent=2, ensure_ascii=False)
        print(f"Saved {len(orca_raw)} OpenOrca samples to: {orca_path}")
        return orca_path

    except Exception as e:
        print(f"Error loading OpenOrca from Hugging Face: {e}")
        print("Falling back to manual options.")
        print("You can:")
        print("1. Upload an Orca JSON file manually")
        print("2. Use a different Orca dataset (e.g., smaller one)")
        print("3. Skip Orca and use 100% MedMCQA")

    fallback_option = input("Choose option (1/2/3): ").strip()

    if fallback_option == "1":
        orca_path = input("Enter path to Orca JSON file: ").strip()
        if not os.path.exists(orca_path):
            print(f"File not found: {orca_path}")
            return None
        return orca_path

    if fallback_option == "2":
        print("Available smaller Orca-like datasets:")
        print("- teknium/OpenHermes-2.5 (smaller, high quality)")
        print("- microsoft/orca-math-word-problems-200k")
        print("- Or provide your own dataset name")

        alt_dataset = input("Enter alternative dataset name: ").strip()
        if not alt_dataset:
            print("No dataset specified, using 100% MedMCQA")
            return _write_empty_orca_file()

        try:
            print(f"Loading {alt_dataset}...")
            # Re-import here: the import above may be what failed, in which
            # case the name would otherwise be unbound in this branch.
            from datasets import load_dataset
            alt_ds = load_dataset(alt_dataset)
            if 'train' in alt_ds:
                alt_orca_dataset = alt_ds['train']
            else:
                split_name = list(alt_ds.keys())[0]
                alt_orca_dataset = alt_ds[split_name]
                print(f"Using split: {split_name}")

            # Subset.
            subset_size = min(20000, len(alt_orca_dataset))
            orca_raw = alt_orca_dataset.select(range(subset_size)).to_list()

            orca_path = '/tmp/alt_orca.json'
            with open(orca_path, 'w', encoding='utf-8') as f:
                json.dump(orca_raw, f, indent=2, ensure_ascii=False)
            print(f"Loaded {len(orca_raw)} samples from {alt_dataset}")
            return orca_path

        except Exception as e2:
            print(f"Error loading alternative dataset: {e2}")
            print("Falling back to 100% MedMCQA")
            return _write_empty_orca_file()

    if fallback_option == "3":
        orca_path = _write_empty_orca_file()
        print("Using 100% MedMCQA dataset")
        return orca_path

    print("Invalid option selected. Using 100% MedMCQA as fallback.")
    return _write_empty_orca_file()


def _locate_medmcqa_in_drive():
    """Find and sanity-check the MedMCQA CSV in the mounted Google Drive.

    Probes a list of common locations, then falls back to asking the user.
    The file is verified by reading its first row with pandas.

    Returns:
        str path of the CSV, or None when it cannot be found or read.
    """
    print("Loading MedMCQA from Google Drive.")

    # Google Drive paths
    possible_paths = [
        '/content/drive/MyDrive/extracted_medmcqa.csv',
        '/content/drive/My Drive/extracted_medmcqa.csv',
        '/content/drive/MyDrive/medmcqa/extracted_medmcqa.csv',
        '/content/drive/My Drive/medmcqa/extracted_medmcqa.csv',
        '/content/drive/MyDrive/data/extracted_medmcqa.csv',
        '/content/drive/My Drive/data/extracted_medmcqa.csv'
    ]

    medmcqa_path = next((path for path in possible_paths if os.path.exists(path)), None)

    if medmcqa_path is not None:
        print(f"Found MedMCQA at: {medmcqa_path}")
    else:
        print("MedMCQA file not found in common locations.")
        print("Please provide the full path to your MedMCQA CSV file in Google Drive:")
        print("   Example: /content/drive/MyDrive/your_folder/extracted_medmcqa.csv")
        medmcqa_path = input("MedMCQA CSV file path: ").strip()

        if not os.path.exists(medmcqa_path):
            print(f"File not found: {medmcqa_path}")
            print("Tips:")
            print("   1. Make sure Google Drive is mounted")
            print("   2. Check the exact file path and spelling")
            print("   3. Ensure the file is in your Google Drive")
            return None

    # Reading a single row is enough to prove the file parses as CSV.
    try:
        import pandas as pd
        test_df = pd.read_csv(medmcqa_path, nrows=1)
        print(f"MedMCQA file verified: {medmcqa_path}")
        print(f"Columns found: {list(test_df.columns)}")
    except Exception as e:
        print(f"Error reading MedMCQA file: {e}")
        return None

    return medmcqa_path


def _read_run_configuration():
    """Prompt for the mixing configuration.

    Returns:
        tuple ``(medmcqa_ratio, eval_split, seed, output_dir)``, or None when
        a numeric value cannot be parsed (a message is printed).
    """
    print("Configuration Parameters:")
    try:
        medmcqa_ratio = float(input("MedMCQA ratio (0.75): ").strip() or "0.75")
        eval_split = float(input("Evaluation split (0.1): ").strip() or "0.1")
        seed = int(input("Random seed (42): ").strip() or "42")
    except ValueError as e:
        print(f"Invalid configuration value: {e}")
        return None
    output_dir = input("Output directory (./data): ").strip() or "./data"
    return medmcqa_ratio, eval_split, seed, output_dir


# Interactive function for Google Colab execution.
def run_colab_interactive():
    """Run the whole preparation pipeline interactively (Colab or local prompts).

    In Colab, OpenOrca is fetched from Hugging Face (with user-selected
    fallbacks) and MedMCQA is read from Google Drive; locally, both paths are
    prompted for. The processed train/eval sets and a statistics file are
    written to the chosen output directory and, in Colab, offered for download.

    Returns:
        tuple ``(train_path, eval_path, stats_path)`` on success, or None when
        the run is aborted because of missing files or invalid configuration.

    Raises:
        Exception: Errors raised while processing or saving are printed with a
            traceback and re-raised.
    """
    print("Medical Chatbot Dataset Preparation - Google Colab Version")
    print("=" * 60)

    # Setup environment
    IN_COLAB = setup_colab_environment()

    if IN_COLAB:
        print("MedMCQA: Loading from Google Drive")
        print("OpenOrca: Loading from Hugging Face")

        orca_path = _load_orca_for_colab()
        if orca_path is None:
            return

        medmcqa_path = _locate_medmcqa_in_drive()
        if medmcqa_path is None:
            return
    else:
        # Local environment.
        print("Enter dataset file paths:")
        medmcqa_path = input("MedMCQA CSV file path: ").strip()
        orca_path = input("Orca JSON file path: ").strip()

    # Configuration parameters.
    config = _read_run_configuration()
    if config is None:
        return
    medmcqa_ratio, eval_split, seed, output_dir = config

    # Validate the inputs.
    if not os.path.exists(medmcqa_path):
        print(f"MedMCQA file not found: {medmcqa_path}")
        return

    if not os.path.exists(orca_path):
        print(f"Orca file not found: {orca_path}")
        return

    if not 0 < medmcqa_ratio < 1:
        print(f"MedMCQA ratio must be between 0 and 1, got {medmcqa_ratio}")
        return

    if not 0 < eval_split < 1:
        print(f"Eval split must be between 0 and 1, got {eval_split}")
        return

    # Initialize processor
    print("Initializing dataset processor.")
    processor = DatasetProcessor(
        radqa_ratio=medmcqa_ratio,
        eval_split=eval_split,
        seed=seed
    )

    try:
        print("Processing datasets.")
        train_samples, eval_samples = processor.process_datasets(medmcqa_path, orca_path)

        train_path = os.path.join(output_dir, 'train.json')
        eval_path = os.path.join(output_dir, 'eval.json')

        print("Saving datasets.")
        processor.save_dataset(train_samples, train_path)
        processor.save_dataset(eval_samples, eval_path)

        # Statistics.
        stats = processor.generate_dataset_stats(train_samples, eval_samples)
        stats_path = os.path.join(output_dir, 'dataset_stats.json')

        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        # Summary.
        print("\n" + "="*60)
        print("DATASET PREPARATION COMPLETED SUCCESSFULLY!")
        print("="*60)
        print(f"Total samples: {stats['total_samples']}")
        print(f"Training samples: {stats['train_samples']}")
        print(f"Evaluation samples: {stats['eval_samples']}")
        print(f"Average instruction length: {stats['avg_instruction_length']:.1f} words")
        print(f"Average output length: {stats['avg_output_length']:.1f} words")
        print(f"Samples with input field: {stats['samples_with_input']}")
        print(f"Files saved to: {output_dir}")
        print(f"   - train.json: {stats['train_samples']} samples")
        print(f"   - eval.json: {stats['eval_samples']} samples")
        print("   - dataset_stats.json: Processing statistics")

        # Download the files in Colab.
        if IN_COLAB:
            print("Downloading processed files.")
            try:
                from google.colab import files
                files.download(train_path)
                files.download(eval_path)
                files.download(stats_path)
                print("Files downloaded successfully!")
            except Exception as e:
                print(f"Could not download files: {e}")
                print(f"Files are available at: {output_dir}")

        print("Ready for next step: Chain-of-Thought generation.")

        return train_path, eval_path, stats_path

    except Exception as e:
        print(f"Error during dataset preparation: {e}")
        import traceback
        traceback.print_exc()
        raise

# Handles both CLI and Colab.
def main():
    """Entry point: interactive flow for notebooks/no-arg runs, argparse for the CLI.

    The interactive flow is used when running in Google Colab, inside any
    Jupyter/IPython kernel, or when no command-line arguments were given.
    Kernel sessions are detected explicitly because a kernel's ``sys.argv``
    holds its own launch arguments, which would otherwise be handed to
    argparse and rejected.

    Returns:
        The result of ``run_colab_interactive()`` in interactive mode;
        otherwise None (also when a CLI validation check fails, after the
        problem has been logged).

    Raises:
        Exception: Any processing error is logged and re-raised.
    """
    try:
        import sys

        try:
            import google.colab  # noqa: F401  (imported only to detect Colab)
            IN_COLAB = True
        except ImportError:
            IN_COLAB = False

        # ipykernel is loaded in every Jupyter kernel process.
        in_notebook_kernel = 'ipykernel' in sys.modules

        if IN_COLAB or in_notebook_kernel or len(sys.argv) == 1:
            return run_colab_interactive()

        parser = argparse.ArgumentParser(description='Prepare and mix MedMCQA and Orca datasets for medical chatbot fine-tuning')

        parser.add_argument('--medmcqa_path', type=str, required=True,
                           help='Path to MedMCQA dataset CSV file')
        parser.add_argument('--orca_path', type=str, required=True,
                           help='Path to Orca dataset JSON file')
        parser.add_argument('--output_dir', type=str, default='./data',
                           help='Directory to save processed datasets (default: ./data)')
        parser.add_argument('--medmcqa_ratio', type=float, default=0.75,
                           help='Proportion of MedMCQA samples in final dataset (default: 0.75)')
        parser.add_argument('--eval_split', type=float, default=0.1,
                           help='Proportion of data to use for evaluation (default: 0.1)')
        parser.add_argument('--seed', type=int, default=42,
                           help='Random seed for reproducibility (default: 42)')

        args = parser.parse_args()

        # Validate inputs up front; each failure is logged and aborts the run.
        if not os.path.exists(args.medmcqa_path):
            logger.error(f"MedMCQA dataset not found: {args.medmcqa_path}")
            return

        if not os.path.exists(args.orca_path):
            logger.error(f"Orca dataset not found: {args.orca_path}")
            return

        if not 0 < args.medmcqa_ratio < 1:
            logger.error(f"MedMCQA ratio must be between 0 and 1, got {args.medmcqa_ratio}")
            return

        if not 0 < args.eval_split < 1:
            logger.error(f"Eval split must be between 0 and 1, got {args.eval_split}")
            return

        # Processor.
        processor = DatasetProcessor(
            radqa_ratio=args.medmcqa_ratio,
            eval_split=args.eval_split,
            seed=args.seed
        )

        train_samples, eval_samples = processor.process_datasets(args.medmcqa_path, args.orca_path)

        # Processed datasets.
        train_path = os.path.join(args.output_dir, 'train.json')
        eval_path = os.path.join(args.output_dir, 'eval.json')

        processor.save_dataset(train_samples, train_path)
        processor.save_dataset(eval_samples, eval_path)

        # Statistics.
        stats = processor.generate_dataset_stats(train_samples, eval_samples)
        stats_path = os.path.join(args.output_dir, 'dataset_stats.json')

        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        logger.info("Dataset preparation completed successfully!")
        logger.info(f"Statistics saved to {stats_path}")

        # Summary.
        print("\n" + "="*50)
        print("DATASET PREPARATION SUMMARY")
        print("="*50)
        print(f"Total samples: {stats['total_samples']}")
        print(f"Training samples: {stats['train_samples']}")
        print(f"Evaluation samples: {stats['eval_samples']}")
        print(f"Average instruction length: {stats['avg_instruction_length']:.1f} words")
        print(f"Average output length: {stats['avg_output_length']:.1f} words")
        print(f"Samples with input field: {stats['samples_with_input']}")
        print(f"\nFiles saved to: {args.output_dir}")
        print(f"- train.json: {stats['train_samples']} samples")
        print(f"- eval.json: {stats['eval_samples']} samples")
        print("- dataset_stats.json: Processing statistics")

    except Exception as e:
        logger.error(f"Error during dataset preparation: {e}")
        raise


# Entry point: run main() when this file (or notebook cell) is executed
# directly; nothing runs when the code is imported as a module.
if __name__ == "__main__":
    main()

Medical Chatbot Dataset Preparation - Google Colab Version
Running in Google Colab environment
Installing required packages...
Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully!
MedMCQA: Loading from Google Drive
OpenOrca: Loading from Hugging Face
Loading OpenOrca from Hugging Face.
Note: This may take a few minutes for the first download
Using your specified loading method.
Error loading OpenOrca from Hugging Face: Invalid pattern: '**' can only be an entire path component
Falling back to manual options.
You can:
1. Upload an Orca JSON file manually
2. Use a different Orca dataset (e.g., smaller one)
3. Skip Orca and use 100% MedMCQA
Choose option (1/2/3): 3
Using 100% MedMCQA dataset
Loading MedMCQA from Google Drive.
Found MedMCQA at: /content/drive/MyDrive/extracted_medmcqa.csv
MedMCQA file verified: /content/drive/MyDrive/extracted_medmcqa.csv




Saving datasets.

DATASET PREPARATION COMPLETED SUCCESSFULLY!
Total samples: 182822
Training samples: 164540
Evaluation samples: 18282
Average instruction length: 39.1 words
Average output length: 76.0 words
Samples with input field: 0
Files saved to: ./data
   - train.json: 164540 samples
   - eval.json: 18282 samples
   - dataset_stats.json: Processing statistics
Downloading processed files.


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Files downloaded successfully!
Ready for next step: Chain-of-Thought generation.
