In [None]:
#!/usr/bin/env python3

In [8]:
import json
import os
import random
import argparse
from typing import Dict, List, Any, Tuple
from collections import defaultdict
import logging

In [None]:
# Root configuration for plain script runs. When the root logger already has handlers
# (as can happen inside Jupyter/Colab), basicConfig is a no-op, the root level stays at
# WARNING and every logger.info(...) below would be dropped silently.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

logger = logging.getLogger(__name__)
# Set the level on this logger directly so INFO messages are emitted regardless of the root level.
logger.setLevel(logging.INFO)
# Attach a handler only once, so re-running this cell does not duplicate every message.
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger.addHandler(_handler)
    # This logger prints on its own; don't also pass records to a root handler
    # (e.g. the one basicConfig may have just installed), which would print them twice.
    logger.propagate = False

In [14]:
# Loads, standardizes, mixes and splits the MedMCQA and Orca datasets into one
# instruction/input/output format. Configured by radqa_ratio (MedMCQA share of the
# training mix), eval_split and seed.
class DatasetProcessor:
    """Build a mixed MedMCQA + Orca instruction-tuning dataset with train/eval splits.

    Attributes:
        radqa_ratio: Target share of MedMCQA samples in the training mix. (The name is
            historical; it is the MedMCQA share.)
        orca_ratio: Target share of Orca samples, initialised to ``1.0 - radqa_ratio``.
        eval_split: Fraction of each dataset held out for evaluation.
        seed: Seed for every random draw this processor makes.
        unified_schema: Template of the keys every standardized sample carries.
    """

    # Roots walked by ``find_files_in_drive`` when no explicit ``search_paths`` are given.
    DRIVE_SEARCH_PATHS = (
        '/content/drive/MyDrive',
        '/content/drive/My Drive',
        '/content/drive/MyDrive/Colab Notebooks',
        '/content/drive/My Drive/Colab Notebooks',
        '/content/drive/MyDrive/data',
        '/content/drive/My Drive/data',
    )

    # Same logger object as a module-level ``logging.getLogger(__name__)``; bound on the class
    # so the methods work even if the notebook cell that defines ``logger`` was not run.
    _logger = logging.getLogger(__name__)

    def __init__(self, radqa_ratio: float = 0.75, eval_split: float = 0.1, seed: int = 42):
        """
        Args:
            radqa_ratio: MedMCQA share of the training mix (0..1).
            eval_split: Fraction of each dataset used for evaluation.
            seed: Seed for shuffling, splitting and sampling.
        """
        self.radqa_ratio = radqa_ratio
        self.orca_ratio = 1.0 - radqa_ratio
        self.eval_split = eval_split
        self.seed = seed
        # Private RNG for all draws made by this processor, so results depend only on
        # ``seed`` and not on other code using the global ``random`` module in between.
        self._rng = random.Random(seed)
        # Kept for backward compatibility: constructing a processor has always seeded the global RNG.
        random.seed(seed)

        # Schema for unified output format.
        self.unified_schema = {
            "instruction": "",
            "input": "",
            "output": ""
        }

    def find_files_in_drive(self, file_patterns: List[Tuple[str, str]],
                            search_paths: Tuple[str, ...] = DRIVE_SEARCH_PATHS) -> Dict[str, str]:
        """Locate dataset files by case-insensitive file-name matching.

        Args:
            file_patterns: Ordered ``(name, pattern)`` pairs. A file matches a pattern when the
                pattern is a case-insensitive substring of its file name; each file counts only
                for the first pattern it matches. For the same ``name``, a match on an earlier
                pair wins over a match on a later pair (e.g. an exact file name listed before a
                looser prefix), regardless of the order in which files are walked.
            search_paths: Root directories to walk. Missing roots are skipped, and a directory
                reached again through an overlapping root or a symlink is not walked twice.

        Returns:
            Mapping from ``name`` to the path of its best match; names with no match are absent.
        """
        best: Dict[str, Tuple[int, str]] = {}
        visited_dirs = set()

        self._logger.info("Searching for files in Google Drive.")

        for search_path in search_paths:
            if not os.path.exists(search_path):
                continue

            self._logger.info(f"Searching in: {search_path}")

            for root, dirs, files in os.walk(search_path):
                real_root = os.path.realpath(root)
                if real_root in visited_dirs:
                    # Already covered through another root: don't descend into it again.
                    dirs[:] = []
                    continue
                visited_dirs.add(real_root)

                for file in files:
                    file_lower = file.lower()
                    for priority, (pattern_name, pattern) in enumerate(file_patterns):
                        if pattern.lower() not in file_lower:
                            continue
                        current = best.get(pattern_name)
                        if current is None or priority < current[0]:
                            file_path = os.path.join(root, file)
                            best[pattern_name] = (priority, file_path)
                            self._logger.info(f"Found {pattern_name}: {file_path}")
                        break

        return {name: path for name, (_, path) in best.items()}

    def load_medmcqa_dataset(self, file_path: str) -> List[Dict[str, Any]]:
        """Read the MedMCQA CSV into a list of row dicts, replacing missing cells with ``""``.

        Raises:
            Whatever ``pandas.read_csv`` raises (logged, then re-raised).
        """
        self._logger.info(f"Loading MedMCQA dataset from {file_path}")

        try:
            import pandas as pd

            df = pd.read_csv(file_path, encoding='utf-8')

            self._logger.info(f"CSV shape: {df.shape}")
            self._logger.info(f"CSV columns: {list(df.columns)}")

            # One dict per row, keyed by column name.
            data = df.to_dict('records')

            # Empty CSV cells arrive as NaN; blank them so they never reach the text as "nan".
            for sample in data:
                for key, value in sample.items():
                    if pd.isna(value):
                        sample[key] = ""

            self._logger.info(f"Loaded {len(data)} MedMCQA samples.")
            return data

        except Exception as e:
            self._logger.error(f"Error loading MedMCQA dataset: {e}")
            raise

    def load_orca_dataset(self, file_path: str) -> List[Dict[str, Any]]:
        """Load Orca samples from a JSON file or, failing that, a JSON Lines file.

        A top-level JSON object is unwrapped from its first ``data``/``samples``/``examples``/
        ``train`` key; an object with none of those keys is treated as a single sample.

        Raises:
            Whatever reading or parsing raises (logged, then re-raised). A file that is neither
            valid JSON nor valid JSON Lines raises the original ``json.JSONDecodeError``.
        """
        self._logger.info(f"Loading Orca dataset from {file_path}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_text = f.read()

            try:
                data = json.loads(raw_text)
            except json.JSONDecodeError as json_error:
                # Not a single JSON document; try JSON Lines (one record per non-blank line).
                try:
                    data = [json.loads(line) for line in raw_text.splitlines() if line.strip()]
                except json.JSONDecodeError:
                    raise json_error

            if isinstance(data, dict):
                for key in ['data', 'samples', 'examples', 'train']:
                    if key in data:
                        data = data[key]
                        break
                else:
                    data = [data]

            self._logger.info(f"Loaded {len(data)} Orca samples.")
            return data

        except Exception as e:
            self._logger.error(f"Error loading Orca dataset: {e}")
            raise

    def clean_text(self, text: Any) -> str:
        """Normalize a field value to a single-line string.

        ``None`` and float NaN become ``""``; other non-strings go through ``str``. Every run of
        whitespace (spaces, tabs, carriage returns, newlines) collapses to one space, and the
        result has no leading/trailing whitespace.
        """
        # NaN is the only float that is not equal to itself.
        if text is None or (isinstance(text, float) and text != text):
            return ""
        if not isinstance(text, str):
            text = str(text)
        # str.split() with no argument splits on any whitespace, newlines included.
        return ' '.join(text.split())

    def standardize_medmcqa_sample(self, sample: Dict[str, Any]) -> Dict[str, str]:
        """Convert one MedMCQA row into the unified instruction/input/output schema.

        Reads ``question``, ``exp`` (explanation), the options ``opa``..``opd`` and
        ``correct_answer``, which is expected to hold the text of the correct option. The answer
        letter is the first option whose text equals ``correct_answer`` case-insensitively.
        NOTE(review): a ``cop`` column, if present, is not used; its index base is not known here.

        Raises:
            ValueError: If ``correct_answer`` is empty or matches none of the options. Such rows
                used to be labelled "A" regardless of their content.
        """
        question = self.clean_text(sample.get('question', ''))
        explanation = self.clean_text(sample.get('exp', ''))
        options = [
            (letter, self.clean_text(sample.get(column, '')))
            for letter, column in (("A", "opa"), ("B", "opb"), ("C", "opc"), ("D", "opd"))
        ]
        correct_answer = self.clean_text(sample.get('correct_answer', ''))

        options_text = "\n".join(f"{letter}) {text}" for letter, text in options)
        instruction = (
            f"Question: {question}\n\nOptions:\n{options_text}\n\n"
            "Please select the correct answer and provide a brief explanation."
        )

        if not correct_answer:
            raise ValueError("Sample has no correct_answer.")
        wanted = correct_answer.lower()
        # First match wins, so duplicated option texts resolve to the earliest letter.
        correct_letter = next((letter for letter, text in options if text.lower() == wanted), None)
        if correct_letter is None:
            raise ValueError(f"correct_answer {correct_answer!r} does not match any option.")

        output = f"The correct answer is {correct_letter}) {correct_answer}."
        if explanation:
            output += f"\n\nExplanation: {explanation}"

        return {
            "instruction": instruction,
            "input": "",
            "output": output
        }

    def standardize_orca_sample(self, sample: Dict[str, Any]) -> Dict[str, str]:
        """Convert one Orca record into the unified schema.

        Supports Alpaca-style keys (``instruction``/``input``/``output`` or ``response``),
        OpenOrca-style keys (``question``/``response`` or ``answer``) and, otherwise,
        ``prompt``/``text`` paired with ``completion``/``target``.
        """
        if 'instruction' in sample:
            # Alpaca format.
            instruction = sample.get('instruction', '')
            input_text = sample.get('input', '')
            output_text = sample.get('output', sample.get('response', ''))
        elif 'question' in sample:
            # OpenOrca format.
            instruction = sample.get('question', '')
            input_text = sample.get('input', '')
            output_text = sample.get('response', sample.get('answer', ''))
        else:
            # Infer from available keys.
            instruction = sample.get('prompt', sample.get('text', ''))
            input_text = ''
            output_text = sample.get('completion', sample.get('target', ''))

        return {
            "instruction": self.clean_text(instruction),
            "input": self.clean_text(input_text),
            "output": self.clean_text(output_text)
        }

    def filter_valid_samples(self, samples: List[Dict[str, str]], dataset_name: str) -> List[Dict[str, str]]:
        """Keep only samples with a non-blank ``instruction`` and ``output``; log how many were dropped."""
        valid_samples = [
            sample for sample in samples
            if sample.get('instruction', '').strip() and sample.get('output', '').strip()
        ]

        filtered_count = len(samples) - len(valid_samples)
        if filtered_count > 0:
            self._logger.warning(f"Filtered {filtered_count} invalid samples from {dataset_name}")

        self._logger.info(f"Valid {dataset_name} samples: {len(valid_samples)}")
        return valid_samples

    def stratified_split(self, samples: List[Dict[str, str]], eval_ratio: float) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Shuffle ``samples`` and split off ``int(len(samples) * eval_ratio)`` of them for evaluation.

        Despite the name this is a plain random split; nothing is stratified. The input list
        is not modified.

        Returns:
            ``(train_samples, eval_samples)``.
        """
        shuffled_samples = samples.copy()
        self._rng.shuffle(shuffled_samples)

        eval_size = int(len(shuffled_samples) * eval_ratio)
        return shuffled_samples[eval_size:], shuffled_samples[:eval_size]

    def _standardize_all(self, raw_samples: List[Any], standardize, dataset_name: str) -> List[Dict[str, str]]:
        """Apply ``standardize`` to each raw sample, logging and skipping samples that raise."""
        standardized = []
        for i, sample in enumerate(raw_samples):
            try:
                standardized.append(standardize(sample))
            except Exception as e:
                self._logger.warning(f"Error standardizing {dataset_name} sample {i}: {e}")
                self._logger.warning(f"Sample data: {sample}")
        return standardized

    def _mix_targets(self, n_medmcqa: int, n_orca: int) -> Tuple[int, int]:
        """Largest (MedMCQA, Orca) training counts in the proportion ``radqa_ratio : orca_ratio``.

        Picks the biggest total T with T * radqa_ratio <= n_medmcqa and T * orca_ratio <= n_orca,
        so the limiting dataset is used in full and only the other one is downsampled. A
        dataset whose ratio is <= 0 gets zero samples.
        """
        limits = []
        if self.radqa_ratio > 0:
            limits.append(n_medmcqa / self.radqa_ratio)
        if self.orca_ratio > 0:
            limits.append(n_orca / self.orca_ratio)
        if not limits:
            return 0, 0
        total = min(limits)
        # round() absorbs float error (e.g. 53.33.. * 0.75 -> 39.999..); min/max keep counts feasible.
        medmcqa_target = max(0, min(n_medmcqa, int(round(total * self.radqa_ratio))))
        orca_target = max(0, min(n_orca, int(round(total * self.orca_ratio))))
        return medmcqa_target, orca_target

    def process_datasets(self, medmcqa_path: str, orca_path: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Load, standardize, split and mix both datasets.

        Each dataset is split into train/eval independently using ``eval_split``. The training
        portions are then downsampled so their mix matches ``radqa_ratio : orca_ratio`` while
        keeping as many samples as possible. The eval portions are concatenated unchanged. If
        the Orca file is missing, unreadable or empty, ``radqa_ratio``/``orca_ratio`` are reset
        to 1.0/0.0 and only MedMCQA is used.

        Args:
            medmcqa_path: Path to the MedMCQA CSV.
            orca_path: Path to the Orca JSON/JSONL file, or a falsy value to skip Orca.

        Returns:
            ``(train_samples, eval_samples)``, each shuffled.

        Raises:
            ValueError: If MedMCQA is empty or none of its samples can be standardized.
        """
        self._logger.info("Starting dataset processing.")

        medmcqa_raw = self.load_medmcqa_dataset(medmcqa_path)

        # Orca is optional: any load failure degrades to a MedMCQA-only dataset.
        orca_raw = []
        if orca_path and os.path.exists(orca_path):
            try:
                orca_raw = self.load_orca_dataset(orca_path)
            except Exception as e:
                self._logger.warning(f"Could not load Orca dataset: {e}")
                orca_raw = []

        if not orca_raw:
            self._logger.warning("The Orca dataset is empty, using 100% MedMCQA.")
            self.radqa_ratio = 1.0
            self.orca_ratio = 0.0

        self._logger.info("Standardizing MedMCQA samples.")
        if not medmcqa_raw:
            self._logger.error("No MedMCQA samples loaded.")
            raise ValueError("MedMCQA dataset is empty.")

        medmcqa_standardized = self._standardize_all(medmcqa_raw, self.standardize_medmcqa_sample, "MedMCQA")
        if not medmcqa_standardized:
            self._logger.error("No MedMCQA samples could be standardized.")
            raise ValueError("All MedMCQA samples failed standardization.")

        medmcqa_valid = self.filter_valid_samples(medmcqa_standardized, "MedMCQA")

        if orca_raw:
            self._logger.info("Standardizing Orca samples.")
            orca_standardized = self._standardize_all(orca_raw, self.standardize_orca_sample, "Orca")
            orca_valid = self.filter_valid_samples(orca_standardized, "Orca")
        else:
            self._logger.info("Skipping the Orca processing.")
            orca_valid = []

        medmcqa_train, medmcqa_eval = self.stratified_split(medmcqa_valid, self.eval_split)
        if orca_valid:
            orca_train, orca_eval = self.stratified_split(orca_valid, self.eval_split)
        else:
            orca_train, orca_eval = [], []

        self._logger.info(f"MedMCQA splits. Train: {len(medmcqa_train)}, Eval: {len(medmcqa_eval)}")
        self._logger.info(f"Orca splits. Train: {len(orca_train)}, Eval: {len(orca_eval)}")

        # Downsample the training sets so the mix actually matches the requested ratio.
        if orca_train:
            medmcqa_train_target, orca_train_target = self._mix_targets(len(medmcqa_train), len(orca_train))
            if len(medmcqa_train) > medmcqa_train_target:
                medmcqa_train = self._rng.sample(medmcqa_train, medmcqa_train_target)
            if len(orca_train) > orca_train_target:
                orca_train = self._rng.sample(orca_train, orca_train_target)

        mixed_train = medmcqa_train + orca_train
        self._rng.shuffle(mixed_train)

        mixed_eval = medmcqa_eval + orca_eval
        self._rng.shuffle(mixed_eval)

        self._logger.info(f"Final mixed dataset. Train: {len(mixed_train)}, Eval: {len(mixed_eval)}")
        if orca_train:
            self._logger.info(f"Training set composition. MedMCQA: {len(medmcqa_train)} ({len(medmcqa_train)/len(mixed_train)*100:.1f}%), "
                              f"Orca: {len(orca_train)} ({len(orca_train)/len(mixed_train)*100:.1f}%)")
        else:
            self._logger.info("Training set composition. MedMCQA: 100%")

        return mixed_train, mixed_eval

    def save_dataset(self, samples: List[Dict[str, str]], output_path: str):
        """Write ``samples`` to ``output_path`` as indented UTF-8 JSON, creating parent directories."""
        self._logger.info(f"Saving {len(samples)} samples to {output_path}")

        parent_dir = os.path.dirname(output_path)
        # A bare filename has no parent component, and os.makedirs("") would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(samples, f, indent=2, ensure_ascii=False)

        self._logger.info(f"Dataset saved successfully to {output_path}")

    def generate_dataset_stats(self, train_samples: List[Dict[str, str]], eval_samples: List[Dict[str, str]]) -> Dict[str, Any]:
        """Summarize split sizes, average word counts and the processing configuration.

        When both splits are empty, the averages are 0.0 and an ``error`` key is added.
        """
        all_samples = train_samples + eval_samples
        total_samples = len(all_samples)
        processing_config = {
            "radqa_ratio": self.radqa_ratio,
            "orca_ratio": self.orca_ratio,
            "eval_split": self.eval_split,
            "seed": self.seed
        }

        if total_samples == 0:
            self._logger.warning("No samples found in datasets.")
            return {
                "total_samples": 0,
                "train_samples": 0,
                "eval_samples": 0,
                "train_eval_ratio": "0/0",
                "avg_instruction_length": 0.0,
                "avg_output_length": 0.0,
                "samples_with_input": 0,
                "processing_config": processing_config,
                "error": "No valid samples found"
            }

        # Lengths are whitespace-separated word counts.
        avg_instruction_length = sum(len(s['instruction'].split()) for s in all_samples) / total_samples
        avg_output_length = sum(len(s['output'].split()) for s in all_samples) / total_samples
        samples_with_input = sum(1 for s in all_samples if s['input'].strip())

        return {
            "total_samples": total_samples,
            "train_samples": len(train_samples),
            "eval_samples": len(eval_samples),
            "train_eval_ratio": f"{len(train_samples)}/{len(eval_samples)}",
            "avg_instruction_length": avg_instruction_length,
            "avg_output_length": avg_output_length,
            "samples_with_input": samples_with_input,
            "processing_config": processing_config
        }


def setup_colab_environment():
    """Detect Google Colab and, when running there, install packages and mount Google Drive.

    Returns:
        True when running in Colab and Drive was mounted at ``/content/drive``; False when not
        running in Colab or when mounting Drive failed.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to detect the Colab runtime
        in_colab = True
        print("Running in Google Colab environment.")
    except ImportError:
        in_colab = False
        print("Running in local environment.")

    if not in_colab:
        return False

    import subprocess
    import sys

    print("Installing the required packages.")
    # `sys.executable -m pip` installs into the interpreter running this kernel (as `%pip` does);
    # a bare `pip` found on PATH may belong to a different Python.
    pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "pandas", "scikit-learn"],
        check=False,
    )
    if pip_result.returncode != 0:
        print(f"pip install exited with code {pip_result.returncode}; continuing with installed packages.")

    # Google Drive.
    try:
        from google.colab import drive
        print("Mounting Google Drive.")
        drive.mount('/content/drive')
        print("Google Drive mounted successfully.")
    except Exception as e:
        print(f"Could not mount Google Drive: {e}")
        return False

    return True


def _prompt_number(prompt: str, default: str, parse):
    """Ask for a number on stdin; an empty answer uses ``default``.

    Returns the parsed value, or ``None`` (after printing a message) when ``parse`` rejects
    the text, so the caller can stop cleanly instead of crashing with a traceback.
    """
    raw_value = input(prompt).strip() or default
    try:
        return parse(raw_value)
    except ValueError:
        print(f"Could not parse {raw_value!r} as {parse.__name__}.")
        return None


def _locate_input_files():
    """Find the MedMCQA and Orca files in Drive, asking the user for any that were not found.

    Returns:
        ``(medmcqa_path, orca_path)``; ``orca_path`` is ``None`` when the user skips it.
    """
    # Only needed for its file-search method.
    temp_processor = DatasetProcessor()

    # For each name, earlier entries are preferred (exact file name before a looser prefix).
    file_patterns = [
        ("medmcqa", "extracted_medmcqa 2.csv"),
        ("medmcqa", "extracted_medmcqa"),
        ("orca", "orca_subset.json"),
        ("orca", "orca_subset"),
    ]

    found_files = temp_processor.find_files_in_drive(file_patterns)

    if "medmcqa" in found_files:
        medmcqa_path = found_files["medmcqa"]
        print(f"Found MedMCQA file: {medmcqa_path}")
    else:
        print("MedMCQA file not found.")
        print("Provide the path to the MedMCQA file:")
        print("Example: /content/drive/MyDrive/your_folder/medmcqa.csv")
        medmcqa_path = input("MedMCQA CSV path: ").strip()

    if "orca" in found_files:
        orca_path = found_files["orca"]
        print(f"✓ Found Orca file: {orca_path}")
    else:
        print("Orca file not found automatically.")
        print("Provide the path to the Orca file or press Enter to skip:")
        print("Example: /content/drive/MyDrive/your_folder/orca.json")
        orca_path = input("Orca JSON path (optional): ").strip() or None

    return medmcqa_path, orca_path


def _preview_input_files(medmcqa_path: str, orca_path) -> None:
    """Print the shape/columns of the MedMCQA CSV and the structure of the Orca JSON (best effort)."""
    print("\n" + "="*50)
    print("File Preview")
    print("="*50)

    try:
        import pandas as pd
        df_preview = pd.read_csv(medmcqa_path, nrows=2)
        print("MedMCQA file preview:")
        print(f"Shape: {df_preview.shape}")
        print(f"Columns: {list(df_preview.columns)}")
        print(f"Sample row: {df_preview.iloc[0].to_dict()}")
    except Exception as e:
        print(f"Could not preview MedMCQA file: {e}")

    if not orca_path:
        return
    try:
        # Same encoding the loader uses; the platform default is not guaranteed to be UTF-8.
        with open(orca_path, 'r', encoding='utf-8') as f:
            orca_preview = json.load(f)
        if isinstance(orca_preview, list):
            print("Orca file preview:")
            print(f"Type: List with {len(orca_preview)} items")
            if orca_preview:
                print(f"Sample item keys: {list(orca_preview[0].keys())}")
        elif isinstance(orca_preview, dict):
            print("Orca file preview:")
            print(f"Type: Object with keys {list(orca_preview.keys())}")
    except Exception as e:
        print(f"Could not preview Orca file: {e}")


def run_with_drive_files():
    """Interactive Colab pipeline that builds the train/eval datasets.

    Steps: locate the inputs in Drive, preview them, ask for the mixing configuration,
    build the train/eval splits, and save them plus a stats file to the output directory.

    Returns:
        ``(train_path, eval_path, stats_path)`` on success, or ``None`` when not running in
        Colab, when the MedMCQA file is missing, or when a configuration value is invalid.

    Raises:
        Any exception from dataset processing or saving (after printing its traceback).
    """
    print("Medical Chatbot Dataset.")
    print("=" * 70)

    # Setting up the environment.
    in_colab = setup_colab_environment()

    if not in_colab:
        print("This version requires Google Colab with Google Drive.")
        return

    medmcqa_path, orca_path = _locate_input_files()

    # Verify that the files exist.
    if not os.path.exists(medmcqa_path):
        print(f"MedMCQA file not found: {medmcqa_path}")
        print("Please check the path.")
        return

    if orca_path and not os.path.exists(orca_path):
        print(f"Orca file not found: {orca_path}")
        print("Continuing with MedMCQA only.")
        orca_path = None

    _preview_input_files(medmcqa_path, orca_path)

    # Configuration.
    print("\n" + "="*50)
    print("Configuration")
    print("="*50)

    medmcqa_ratio = _prompt_number("MedMCQA ratio 0.75 for 75% MedMCQA, 25% Orca: ", "0.75", float)
    if medmcqa_ratio is None:
        return
    eval_split = _prompt_number("Evaluation split 0.1 for 10%: ", "0.1", float)
    if eval_split is None:
        return
    seed = _prompt_number("Random seed 42: ", "42", int)
    if seed is None:
        return
    output_dir = input("Output directory (./data): ").strip() or "./data"

    # Validate inputs (NaN fails both comparisons and is rejected too).
    if not 0 < medmcqa_ratio <= 1:
        print(f"MedMCQA ratio must be between 0 and 1, got {medmcqa_ratio}")
        return

    if not 0 < eval_split < 1:
        print(f"Eval split must be between 0 and 1, got {eval_split}")
        return

    print("\n Initializing dataset processor.")
    processor = DatasetProcessor(
        radqa_ratio=medmcqa_ratio,
        eval_split=eval_split,
        seed=seed
    )

    try:
        print("Processing datasets.")
        train_samples, eval_samples = processor.process_datasets(medmcqa_path, orca_path)

        train_path = os.path.join(output_dir, 'train.json')
        eval_path = os.path.join(output_dir, 'eval.json')

        print("Saving datasets.")
        processor.save_dataset(train_samples, train_path)
        processor.save_dataset(eval_samples, eval_path)

        # Statistics.
        stats = processor.generate_dataset_stats(train_samples, eval_samples)
        stats_path = os.path.join(output_dir, 'dataset_stats.json')

        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        # Summary.
        print("   Results")
        print(f"Total samples: {stats['total_samples']:,}")
        print(f"Training samples: {stats['train_samples']:,}")
        print(f"Evaluation samples: {stats['eval_samples']:,}")
        print(f"Average instruction length: {stats['avg_instruction_length']:.1f} words")
        print(f"Average output length: {stats['avg_output_length']:.1f} words")
        print(f"Files saved to: {output_dir}")
        print(f"   - train.json: {stats['train_samples']:,} samples")
        print(f"   - eval.json: {stats['eval_samples']:,} samples")
        print("   - dataset_stats.json: Processing statistics")

        # Download the files (browser download; only available inside Colab).
        print("\nDownloading.")
        try:
            from google.colab import files
            files.download(train_path)
            files.download(eval_path)
            files.download(stats_path)
            print("Files downloaded.")
        except Exception as e:
            print(f"Could not download files automatically: {e}")
            print(f"Files are available at: {output_dir}")

        print("\n Ready for next step: Fine-tuning your model.")
        return train_path, eval_path, stats_path

    except Exception as e:
        print(f"Error during dataset preparation: {e}")
        import traceback
        traceback.print_exc()
        raise


def main():
    """Entry point: run the interactive Drive pipeline and hand back its result.

    Any exception is logged through the module logger and then re-raised unchanged.
    """
    try:
        outcome = run_with_drive_files()
    except Exception as exc:
        logger.error(f"Error during dataset preparation: {exc}")
        raise
    return outcome


# A notebook kernel also runs with __name__ == "__main__", so executing this cell
# starts the interactive pipeline right away.
if __name__ == "__main__":
    main()

Medical Chatbot Dataset Preparation - Google Drive Integration
Running in Google Colab environment
Installing required packages...
Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully!
✓ Found MedMCQA file: /content/drive/My Drive/extracted_medmcqa 2.csv
✓ Found Orca file: /content/drive/My Drive/orca_subset.json

FILE PREVIEW
MedMCQA file preview:
Shape: (2, 8)
Columns: ['question', 'exp', 'cop', 'opa', 'opb', 'opc', 'opd', 'correct_answer']
Sample row: {'question': 'Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma', 'exp': 'Chronic urethral obstruction because of urinary calculi, prostatic hyperophy, tumors, normal pregnancy, tumors, uterine prolapse or functional disorders cause hydronephrosis which by definition is used to describe dilatation of renal pelvis and calculus associate

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✅ Files downloaded successfully!

🚀 Ready for next step: Fine-tuning your model!


In [16]:
import json
import os

# run_with_drive_files() writes train.json into its output directory ("./data" by default);
# update this path if a different output directory was entered there.
TRAIN_PATH = os.path.join("data", "train.json")

# The file is written as UTF-8 with ensure_ascii=False, so read it back as UTF-8 explicitly
# rather than relying on the platform's default encoding.
with open(TRAIN_PATH, 'r', encoding='utf-8') as f:
    df1 = json.load(f)

display(df1[:5])

[{'instruction': 'Question: Mode of action of Fluoxetine is -\n\nOptions:\nA) GABA inhibition\nB) Adrenergic neuron blocking agent\nC) Inhibition axonal uptake of 5HT\nD) Alpha adrenergic stimulation\n\nPlease select the correct answer and provide a brief explanation.',
  'input': '',
  'output': "The correct answer is C) Inhibition axonal uptake of 5HT.\n\nExplanation: Ans. is 'c' Inhibit axonal uptake of 5HT Fluoxetine is a tricyclic antidepressant (like imipramine, Amitriptyline)While typically TCA inhibit uptake of both NA and 5HT by neurons.Fluoxetine is selective serotonin reuptake inhibitor (SSRI)Therefore it is devoid of following side effectsAnticholinergic *Sedation*Hypotension*Other SSRIsFluvoxamine*Paroxetine*Also remember mechanism of action of these atypical TCAsTianeptine - It increase rather than inhibit 5-HT uptake *Mianserin - does not inhibit either NA or 5-HT uptake, it blocks presynaptic alpha-2 receptors, increases release and turnover of NA in brain *"},
 {'instr

In [17]:
import pandas as pd

# Tabular view of the first five training records (columns: instruction, input, output).
first_five = df1[:5]
display(pd.DataFrame.from_records(first_five))

Unnamed: 0,instruction,input,output
0,Question: Mode of action of Fluoxetine is -\n\...,,The correct answer is C) Inhibition axonal upt...
1,Question: In sipple syndrome [MEN II ) all are...,,The correct answer is B) Pituitary hyperplasia...
2,Question: Poor man's Iron source is:-\n\nOptio...,,The correct answer is D) Jaggery.\n\nExplanati...
3,Question: An i.v. bolus dose of thiopentone le...,,The correct answer is D) Redistributed from br...
4,Question: The Transorbital view is carried out...,,The correct answer is D) Internal auditory mea...
