In [2]:
from datasets import load_dataset
import os

In [3]:
os.makedirs("data/raw", exist_ok=True)
os.makedirs("data/processed", exist_ok=True)

In [4]:
#MCQ
mmlu_pro_medicine = load_dataset("openlifescienceai/mmlu_professional_medicine")
mmlu_anatomy = load_dataset("openlifescienceai/mmlu_anatomy")
mmlu_clinical_knowledge = load_dataset("openlifescienceai/mmlu_clinical_knowledge")
pubmedqa = load_dataset("openlifescienceai/pubmedqa")

#Text Summarization
cord19_summarization = load_dataset("medalpaca/medical_meadow_cord19")

#QA
wikidoc_qa = load_dataset("medalpaca/medical_meadow_wikidoc_patient_information")


# Dataset Inspection

Let's take a look at the structure of one of our datasets.

In [5]:
# Display basic information about mmlu_pro_medicine dataset
print("Dataset structure:")
print(mmlu_pro_medicine)

# View a sample from the dataset
print("\nSample from 'dev' split:")
sample = mmlu_pro_medicine['dev'][0]
print(sample)

Dataset structure:
DatasetDict({
    test: Dataset({
        features: ['subject_name', 'data', 'id'],
        num_rows: 272
    })
    validation: Dataset({
        features: ['subject_name', 'data', 'id'],
        num_rows: 31
    })
    dev: Dataset({
        features: ['subject_name', 'data', 'id'],
        num_rows: 5
    })
})

Sample from 'dev' split:
{'subject_name': 'professional_medicine', 'data': {'Correct Answer': 'Phenoxybenzamine', 'Correct Option': 'D', 'Options': {'A': 'Labetalol', 'B': 'A loading dose of potassium chloride', 'C': 'Nifedipine', 'D': 'Phenoxybenzamine'}, 'Question': 'A 42-year-old man comes to the office for preoperative evaluation prior to undergoing adrenalectomy scheduled in 2 weeks. One month ago, he received care in the emergency department for pain over his right flank following a motor vehicle collision. At that time, blood pressure was 160/100 mm Hg and CT scan of the abdomen showed an incidental 10-cm left adrenal mass. Results of laboratory stu

# Save Datasets to CSV (Limited to 500 samples)

We'll convert each dataset to CSV format and save a maximum of 500 samples to the raw data folder.

In [6]:
import pandas as pd
import random
import json

def save_limited_dataset_to_csv(dataset, dataset_name, max_samples=500):
    # Collect all samples from the dataset across splits
    all_samples = []
    for split_name, split_data in dataset.items():
        df = split_data.to_pandas()
        all_samples.append(df)
    
    # Combine all splits
    combined_df = pd.concat(all_samples, ignore_index=True)
    print(f"Total samples in {dataset_name}: {len(combined_df)}")
    
    # Randomly select max_samples or all if fewer
    if len(combined_df) > max_samples:
        sampled_df = combined_df.sample(max_samples, random_state=42)
    else:
        sampled_df = combined_df
    
    # Save to CSV in raw data folder
    output_path = f"data/raw/{dataset_name}.csv"
    sampled_df.to_csv(output_path, index=False)
    
    print(f"Saved {len(sampled_df)} samples from {dataset_name} to {output_path}")
    
    return sampled_df

In [7]:
# Save medical datasets (limited to 500 samples each)
mmlu_med_samples = save_limited_dataset_to_csv(mmlu_pro_medicine, "mmlu_prof_med")
mmlu_anat_samples = save_limited_dataset_to_csv(mmlu_anatomy, "mmlu_anatomy")
mmlu_clin_samples = save_limited_dataset_to_csv(mmlu_clinical_knowledge, "mmlu_clinical_knowledge")
pubmedqa_samples = save_limited_dataset_to_csv(pubmedqa, "pubmedqa")

Total samples in mmlu_prof_med: 308
Saved 308 samples from mmlu_prof_med to data/raw/mmlu_prof_med.csv
Total samples in mmlu_anatomy: 154
Saved 154 samples from mmlu_anatomy to data/raw/mmlu_anatomy.csv
Total samples in mmlu_clinical_knowledge: 299
Saved 299 samples from mmlu_clinical_knowledge to data/raw/mmlu_clinical_knowledge.csv
Total samples in pubmedqa: 1000
Saved 500 samples from pubmedqa to data/raw/pubmedqa.csv


In [8]:
# Save the summarization and QA datasets (limited to 500 samples each)
cord19_samples = save_limited_dataset_to_csv(cord19_summarization, "cord19_summarization")
wikidoc_samples = save_limited_dataset_to_csv(wikidoc_qa, "wikidoc_qa")

Total samples in cord19_summarization: 821007
Saved 500 samples from cord19_summarization to data/raw/cord19_summarization.csv
Total samples in wikidoc_qa: 5942
Saved 500 samples from wikidoc_qa to data/raw/wikidoc_qa.csv
Total samples in wikidoc_qa: 5942
Saved 500 samples from wikidoc_qa to data/raw/wikidoc_qa.csv


# Combine 600 Random Samples from Medical MCQ Datasets

Let's create a combined dataset with 600 randomly selected samples from the three medical MCQ datasets.

In [9]:
# Load the saved CSV files
med_df = pd.read_csv('data/raw/mmlu_prof_med.csv')
anat_df = pd.read_csv('data/raw/mmlu_anatomy.csv')
clin_df = pd.read_csv('data/raw/mmlu_clinical_knowledge.csv')

# Check the number of samples in each dataset
print(f"Professional Medicine dataset: {len(med_df)} samples")
print(f"Anatomy dataset: {len(anat_df)} samples")
print(f"Clinical Knowledge dataset: {len(clin_df)} samples")

# Calculate how many samples to take from each dataset
# We'll aim for an approximately equal distribution
samples_per_dataset = 600 // 3

# If there are fewer samples in any dataset, adjust accordingly
med_samples = min(samples_per_dataset, len(med_df))
anat_samples = min(samples_per_dataset, len(anat_df))
clin_samples = 600 - med_samples - anat_samples  # Make sure we get exactly 600 total

print(f"\nSampling: {med_samples} from Medicine, {anat_samples} from Anatomy, {clin_samples} from Clinical Knowledge")

# Randomly sample from each dataset
med_subset = med_df.sample(med_samples, random_state=42)
anat_subset = anat_df.sample(anat_samples, random_state=42)
clin_subset = clin_df.sample(clin_samples, random_state=42)

# Add a source column to track where each question came from
med_subset['source'] = 'Professional Medicine'
anat_subset['source'] = 'Anatomy'
clin_subset['source'] = 'Clinical Knowledge'

Professional Medicine dataset: 308 samples
Anatomy dataset: 154 samples
Clinical Knowledge dataset: 299 samples

Sampling: 200 from Medicine, 154 from Anatomy, 246 from Clinical Knowledge


# Format data and extract required columns

Now we'll transform the dataset by:
1. Removing the subject_name column
2. Creating input column (combining Question and Options)
3. Extracting the original 'Correct Answer' and 'Correct Option' fields
4. Keeping source column for reference but dropping id column

In [10]:
def format_mcq_data(df):
    formatted_rows = []
    
    for _, row in df.iterrows():
        data_dict = eval(row['data']) if isinstance(row['data'], str) else row['data']
        
        # Extract question and options
        question = data_dict.get('Question', '')
        options = data_dict.get('Options', {})
        
        # Format options as A. option, B. option, etc.
        formatted_options = '\n'.join([f"{k}. {v}" for k, v in options.items()])
        
        # Create input by combining question and options
        input_text = f"{question}\n\n{formatted_options}"
        
        # Get the correct answer and option
        correct_answer = data_dict.get('Correct Answer', '')
        correct_option = data_dict.get('Correct Option', '')
        
        formatted_rows.append({
            'input': input_text,
            'Correct Answer': correct_answer,
            'Correct Option': correct_option,
            'source': row['source']
        })
    
    return pd.DataFrame(formatted_rows)

# Format each subset
med_formatted = format_mcq_data(med_subset)
anat_formatted = format_mcq_data(anat_subset)
clin_formatted = format_mcq_data(clin_subset)

# Combine the formatted datasets
combined_df = pd.concat([med_formatted, anat_formatted, clin_formatted], ignore_index=True)

# Shuffle the combined dataset
combined_df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Save the combined dataset
output_path = "data/processed/combined_medical_mcq_600.csv"
combined_df.to_csv(output_path, index=False)

print(f"\nCombined dataset with {len(combined_df)} samples saved to {output_path}")
print(f"Dataset distribution:")
print(combined_df['source'].value_counts())


Combined dataset with 600 samples saved to data/processed/combined_medical_mcq_600.csv
Dataset distribution:
source
Clinical Knowledge       246
Professional Medicine    200
Anatomy                  154
Name: count, dtype: int64


In [11]:
# Check the structure of the combined dataset
print("Combined dataset columns:")
print(combined_df.columns.tolist())
print("\nSample from combined dataset:")
print(combined_df.iloc[0])

Combined dataset columns:
['input', 'Correct Answer', 'Correct Option', 'source']

Sample from combined dataset:
input             A 50-year-old female presents to the office wi...
Correct Answer                                         fibromyalgia
Correct Option                                                    B
source                                        Professional Medicine
Name: 0, dtype: object


In [12]:
# Sample 600 records from wikidoc_qa dataset and save to processed folder
import pandas as pd
import os

# Check if we already have the wikidoc_samples dataframe in memory
if 'wikidoc_samples' not in globals() or len(wikidoc_samples) < 600:
    # Load the CSV file if we need to
    wikidoc_df = pd.read_csv('data/raw/wikidoc_qa.csv')
    print(f"Total records in wikidoc_qa dataset: {len(wikidoc_df)}")
    
    # Sample 600 records randomly if the dataset is large enough
    if len(wikidoc_df) >= 600:
        wikidoc_samples = wikidoc_df.sample(n=600, random_state=42)
    else:
        # If dataset is smaller than 600, use all available records
        wikidoc_samples = wikidoc_df
        print(f"Warning: Only {len(wikidoc_df)} records available in dataset, using all records.")

# Ensure we have complete input-output pairs
print(f"Number of samples with both input and output: {wikidoc_samples[['input', 'output']].notna().all(axis=1).sum()}")
    
# Remove any rows with missing input or output (if any)
wikidoc_complete = wikidoc_samples.dropna(subset=['input', 'output'])
print(f"Complete records after dropping rows with missing values: {len(wikidoc_complete)}")

# If we have more than 600 complete records, sample exactly 600
if len(wikidoc_complete) > 600:
    wikidoc_complete = wikidoc_complete.sample(n=600, random_state=42)
    
# Save the sampled data to the processed folder
processed_path = "data/processed/wikidoc_qa_600.csv"
wikidoc_complete.to_csv(processed_path, index=False)

print(f"Saved {len(wikidoc_complete)} input-output pairs to {processed_path}")
print(f"Sample of the data:")
wikidoc_complete[['input', 'output']].head(3)

Total records in wikidoc_qa dataset: 500
Number of samples with both input and output: 500
Complete records after dropping rows with missing values: 500
Saved 500 input-output pairs to data/processed/wikidoc_qa_600.csv
Sample of the data:


Unnamed: 0,input,output
0,Where to find Medical Care for Gonadoblastoma?,Medical care for gonadoblastoma can be found h...
1,What to expect if I have Craniosynostosis (Ou...,How well a person does depends on how many sut...
2,When to seek urgent medical care when I have S...,Call for an appointment with your health care ...


In [13]:
# Verify the wikidoc dataset structure and contents
print(f"Dataset shape: {wikidoc_complete.shape}")
print("\nColumn names:")
for col in wikidoc_complete.columns:
    print(f"- {col}")
    
# Count unique questions and answers
print(f"\nNumber of unique input questions: {wikidoc_complete['input'].nunique()}")
print(f"Number of unique output answers: {wikidoc_complete['output'].nunique()}")

# Display example input and output
print("\nSample input question:")
print(wikidoc_complete['input'].iloc[0])

print("\nSample output answer:")
print(wikidoc_complete['output'].iloc[0])

Dataset shape: (500, 3)

Column names:
- input
- output
- instruction

Number of unique input questions: 499
Number of unique output answers: 499

Sample input question:
Where to find Medical Care for Gonadoblastoma?

Sample output answer:
Medical care for gonadoblastoma can be found here.


In [14]:
# Sample 600 records from cord19_summarization dataset and save to processed folder\n
import pandas as pd


# Check if we already have the cord19_samples dataframe in memory\n
if 'cord19_samples' not in globals() or len(cord19_samples) < 600:
    # Load the CSV file if we need to\n
    cord19_df = pd.read_csv('data/raw/cord19_summarization.csv')
    # Sample 600 records randomly if the dataset is large enough\n
    if len(cord19_df) >= 600:
        cord19_samples = cord19_df.sample(n=600, random_state=42)
    else:
        # If dataset is smaller than 600, use all available records\n
        cord19_samples = cord19_df
        print(f"Warning: Only {len(cord19_df)} records available in dataset, using all records.")

# Save the sampled data to the processed folder\n
processed_path = "data/processed/cord19_summarization_600.csv"
cord19_samples.to_csv(processed_path, index=False)

print(f"Saved {len(cord19_samples)} records to {processed_path}")
print(f"Sample of the data:")
cord19_samples.head()

Saved 500 records to data/processed/cord19_summarization_600.csv
Sample of the data:

Saved 500 records to data/processed/cord19_summarization_600.csv
Sample of the data:


Unnamed: 0,output,input,instruction
0,183. Decrease in Invasive Pneumococcal Disease...,BACKGROUND: During the 2020 SARS-CoV-2 pandemi...,Please summerize the given abstract to a title
1,CD10 in the developing human kidney: immunorea...,CD10 was first identified in tumor cells of ac...,Please summerize the given abstract to a title
2,Long-term bone and lung consequences associate...,The most severe sequelae after rehabilitation ...,Please summerize the given abstract to a title
3,RNA structure interactions and ribonucleoprote...,"In one more years, we will ‘celebrate’ an exac...",Please summerize the given abstract to a title
4,Organization of remote rehabilitation of mosco...,The tense epidemic situation in the Russian Fe...,Please summerize the given abstract to a title


# Process PubMedQA Dataset

Extract information from the PubMedQA dataset, restructure it, and save to processed folder.

In [15]:
# Process PubMedQA dataset
import pandas as pd
import json
import ast

# Load the PubMedQA dataset from raw data if not already in memory
if 'pubmedqa_samples' not in globals():
    pubmed_df = pd.read_csv('data/raw/pubmedqa.csv')
else:
    pubmed_df = pd.DataFrame(pubmedqa_samples)

# Function to extract data from the 'data' column
def extract_pubmedqa_data(data_str):
    try:
        # Convert string to dictionary if it's a string
        if isinstance(data_str, str):
            data_dict = ast.literal_eval(data_str)
        else:
            data_dict = data_str
            
        # Extract the fields
        context = ' '.join(data_dict.get('Context', [])) if isinstance(data_dict.get('Context'), list) else str(data_dict.get('Context', ''))
        question = data_dict.get('Question', '')
        options = data_dict.get('Options', {})
        correct_option = data_dict.get('Correct Option', '')
        correct_answer = data_dict.get('Correct Answer', '')
        long_answer = data_dict.get('Long Answer', '')
        
        return context, question, options, correct_option, correct_answer, long_answer
    except Exception as e:
        print(f"Error processing data: {e}")
        return '', '', {}, '', '', ''

# Apply the extraction function to each row
results = pubmed_df['data'].apply(extract_pubmedqa_data)

# Create a new DataFrame with the extracted data
processed_df = pd.DataFrame({
    'context': [r[0] for r in results],
    'question': [r[1] for r in results],
    'options': [r[2] for r in results],
    'correct_option': [r[3] for r in results],
    'correct_answer': [r[4] for r in results],
    'long_answer': [r[5] for r in results]
})

# Display the first few rows to verify the extraction
print("Processed PubMedQA dataset:")
print(f"Shape: {processed_df.shape}")
processed_df.head(2)

Processed PubMedQA dataset:
Shape: (500, 6)


Unnamed: 0,context,question,options,correct_option,correct_answer,long_answer
0,['Mossy fibers are the sole excitatory project...,Do mossy fibers release GABA?,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A,yes,We have thus provided compelling evidence that...
1,"[""Research on stroke survivors' driving safety...",Department of Transportation vs self-reported ...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",B,no,"In our population of stroke survivors, self-re..."


In [16]:
# Save the processed dataset to the processed folder
processed_path = "data/processed/pubmedqa_processed.csv"
processed_df.to_csv(processed_path, index=False)
print(f"Processed PubMedQA dataset saved to {processed_path}")

# Sample 600 records for consistency with other datasets (if needed)
if len(processed_df) > 600:
    sample_df = processed_df.sample(n=600, random_state=42)
    sample_path = "data/processed/pubmedqa_600.csv"
    sample_df.to_csv(sample_path, index=False)
    print(f"Random sample of 600 records saved to {sample_path}")

Processed PubMedQA dataset saved to data/processed/pubmedqa_processed.csv



In [17]:
# Verify the dataset structure and contents\n
print(f"Dataset shape: {cord19_samples.shape}")
print("\nColumn names:")
for col in cord19_samples.columns:
    print(f"- {col}")
print("\nSample output title:")
print(cord19_samples['output'].iloc[0])

Dataset shape: (500, 3)

Column names:
- output
- input
- instruction

Sample output title:
183. Decrease in Invasive Pneumococcal Disease in 7 United States Children’s Hospitals during the COVID-19 Pandemic
