In [19]:
import json
import os
import random
import re
from datetime import datetime

import numpy as np
import pandas as pd

# load datasets from huggingface hub
from datasets import load_dataset

# Debug switch: when True, the dataset-list cell below narrows `hf_datasets`
# down to PubMedQA only, for a quick run.
debug = True

In [20]:
# Medical QA benchmarks used for the OnBrand evaluation:
#   pubmedqa, medmcqa, medqa_4options, usmle_sa_step1/2/3, plus the MMLU subjects
#   anatomy, clinical_knowledge, college_medicine, medical_genetics,
#   professional_medicine and college_biology.
# Plain strings look like Hub dataset ids and the ("mmlu", subject) tuples look
# like (dataset, config) pairs -- NOTE(review): confirm against the loader.
hf_datasets = [
    "bigbio/pubmed_qa",
    "medmcqa",
    "GBaker/MedQA-USMLE-4-options-hf",
    "augtoma/usmle_step_1",
    "augtoma/usmle_step_2",
    "augtoma/usmle_step_3",
    # NOTE(review): clinical_knowledge was noted as coming from
    # hails/mmlu_no_train; confirm that "mmlu" resolves to the intended source.
    *(
        ("mmlu", subject)
        for subject in (
            "anatomy",
            "clinical_knowledge",
            "college_medicine",
            "medical_genetics",
            "professional_medicine",
            "college_biology",
        )
    ),
]

if debug:
    # Debug runs only touch PubMedQA to keep iteration fast.
    hf_datasets = ["bigbio/pubmed_qa"]

In [21]:
def load_keywords(csv_path, keyword_replace=None):
    """Build the drug-name replacement map for the requested direction.

    Args:
        csv_path: Path to the drug-name CSV, read via ``load_drug_map``.
        keyword_replace: ``"brand_to_generic"`` returns the brand -> generic map.
            ``"generic_to_brand"`` returns the generic -> brand map, where each
            brand is chosen by ``load_drug_map``'s seeded ``random.choice``.
            Any other value, including the default ``None``, returns both
            directions merged, so brand AND generic names are keys.

    Returns:
        dict mapping each keyword to its replacement.
    """
    # Load only the direction that was asked for; each load re-reads the CSV.
    if keyword_replace == "brand_to_generic":
        return load_drug_map(csv_path)
    if keyword_replace == "generic_to_brand":
        return load_drug_map(csv_path, reverse_map=True)

    brand_to_generic = load_drug_map(csv_path)
    generic_to_brand = load_drug_map(csv_path, reverse_map=True)
    # Inverting generic_to_brand would just yield brand -> generic again, which
    # brand_to_generic already holds. Merge the two directions directly so the
    # map is keyed on both brand and generic names. If a name appears as both a
    # brand and a generic, its generic -> brand entry wins.
    return {**brand_to_generic, **generic_to_brand}


def load_drug_map(csv_path, reverse_map=False, drug_seed=42):
    """Load a brand/generic drug-name mapping from a CSV file.

    Args:
        csv_path: CSV file with ``brand`` and ``generic`` columns.
        reverse_map: If False, map brand -> generic. If True, map each generic
            to one of its brands, picked with ``random.choice``.
        drug_seed: Value passed to ``random.seed`` before any choice is made,
            so the reverse map is reproducible.

    Returns:
        dict mapping drug names to drug names (both read as strings).

    Note:
        Calls ``random.seed(drug_seed)`` on the global ``random`` module, which
        resets the process-wide RNG state as a side effect.
    """
    random.seed(drug_seed)  # Set the seed for reproducibility of random choices

    # Read cells as text so numeric-looking names stay strings, and drop rows
    # missing either side: a NaN key or value would break the str-based
    # replacement done with these maps.
    drug_table = pd.read_csv(csv_path, dtype=str).dropna(subset=["brand", "generic"])

    if reverse_map:
        # Map generic to randomly chosen brand with fixed seed
        brands_by_generic = drug_table.groupby("generic")["brand"].apply(list)
        return {
            generic: random.choice(brands)
            for generic, brands in brands_by_generic.items()
        }

    # Map brand to generic; when a brand is listed twice, the last row wins.
    return dict(zip(drug_table["brand"], drug_table["generic"]))

In [22]:
# Drug-name table shared by both mapping directions. The path is relative, so
# this assumes the notebook runs from its own directory -- NOTE(review): confirm.
DRUG_NAMES_CSV = "../lm_eval/tasks/drug_names.csv"

brand_to_generic_map = load_keywords(DRUG_NAMES_CSV, keyword_replace="brand_to_generic")

generic_to_brand_map = load_keywords(DRUG_NAMES_CSV, keyword_replace="generic_to_brand")

In [23]:
def process_pubmedqa(dataset, brand_to_generic_map, generic_to_brand_map, output_dir):
    """Create brand->generic and generic->brand variants of a PubMedQA dataset.

    Each variant rewrites every string in ``example["CONTEXTS"]`` and the
    ``example["QUESTION"]`` string using the matching map. The variants are
    saved under ``output_dir/brand_to_generic`` and
    ``output_dir/generic_to_brand``.

    Args:
        dataset: Object exposing ``map(fn)`` and ``save_to_disk(path)`` --
            presumably a Hugging Face ``Dataset``/``DatasetDict`` (confirm with
            caller). Examples need a ``CONTEXTS`` iterable of str and a
            ``QUESTION`` str.
        brand_to_generic_map: dict of brand name -> replacement text.
        generic_to_brand_map: dict of generic name -> replacement text.
        output_dir: Directory under which both variants are saved.

    Returns:
        Tuple ``(brand_to_generic_dataset, generic_to_brand_dataset)``.
    """

    def build_replacer(replacement_map):
        """Return a function that applies ``replacement_map`` to a string in one pass.

        Keys only match as whole tokens, never inside longer words. Longer keys
        win over their prefixes. Replacement text is never re-scanned, so
        ``{"a": "b", "b": "c"}`` turns "a" into "b", not "c".
        """
        # Empty or non-string entries cannot be substituted; an empty key would
        # otherwise match at every position.
        usable = {
            key: value
            for key, value in replacement_map.items()
            if isinstance(key, str) and key and isinstance(value, str)
        }
        if not usable:
            return lambda text: text
        # Regex alternation is leftmost-first, so list longer names first
        # ("Motrin IB" before "Motrin").
        alternation = "|".join(
            re.escape(key) for key in sorted(usable, key=len, reverse=True)
        )
        # Lookarounds instead of \b so names ending in punctuation still match.
        pattern = re.compile(r"(?<!\w)(?:" + alternation + r")(?!\w)")
        return lambda text: pattern.sub(lambda match: usable[match.group(0)], text)

    def make_entry_modifier(replace):
        """Return a ``dataset.map`` callback that rewrites CONTEXTS and QUESTION."""

        def modify_entry(example):
            example["CONTEXTS"] = [replace(ctx) for ctx in example["CONTEXTS"]]
            example["QUESTION"] = replace(example["QUESTION"])
            return example

        return modify_entry

    # Apply transformations (each pattern is compiled once, not per example)
    modified_dataset_brand_to_generic = dataset.map(
        make_entry_modifier(build_replacer(brand_to_generic_map))
    )
    modified_dataset_generic_to_brand = dataset.map(
        make_entry_modifier(build_replacer(generic_to_brand_map))
    )

    # Save the modified datasets
    output_path_brand_to_generic = os.path.join(output_dir, "brand_to_generic")
    output_path_generic_to_brand = os.path.join(output_dir, "generic_to_brand")

    modified_dataset_brand_to_generic.save_to_disk(output_path_brand_to_generic)
    modified_dataset_generic_to_brand.save_to_disk(output_path_generic_to_brand)

    return modified_dataset_brand_to_generic, modified_dataset_generic_to_brand