In [1]:
import pandas as pd
import numpy as np
import random
import os
import json
from datetime import datetime
import copy

# load datasets from huggingface hub
from datasets import load_dataset
from datasets import Dataset, DatasetDict
from datasets import Features, Value, ClassLabel, Sequence

# Lift pandas' display limits so whole frames are shown (no column/row truncation)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

# Remove the column-width limit to view the full text of each cell
pd.set_option("display.max_colwidth", None)

# NOTE(review): `debug` is not referenced in the visible cells — confirm it is still needed
debug = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the "test" split of the original and generic-to-brand (g2b) variants of
# GBaker/MedQA-USMLE-4-options-hf and medmcqa from the hub, converting each one
# straight to a pandas DataFrame.
orig_filtered_medqa = load_dataset(
    "AIM-Harvard/gbaker_medqa_usmle_4_options_hf_original", split="test"
).to_pandas()
g2b_medqa = load_dataset(
    "AIM-Harvard/gbaker_medqa_usmle_4_options_hf_generic_to_brand", split="test"
).to_pandas()

orig_filtered_medmcqa = load_dataset(
    "AIM-Harvard/medmcqa_original", split="test"
).to_pandas()
g2b_medmcqa = load_dataset(
    "AIM-Harvard/medmcqa_generic_to_brand", split="test"
).to_pandas()

# Report (rows, columns) for every frame that was just loaded
for frame_name, frame in (
    ("orig_filtered_medqa", orig_filtered_medqa),
    ("g2b_medqa", g2b_medqa),
    ("orig_filtered_medmcqa", orig_filtered_medmcqa),
    ("g2b_medmcqa", g2b_medmcqa),
):
    print(f"{frame_name} shape: ", frame.shape)

Downloading readme: 100%|██████████| 811/811 [00:00<00:00, 1.14MB/s]
Downloading data: 100%|██████████| 1.81M/1.81M [00:00<00:00, 5.95MB/s]
Downloading data: 100%|██████████| 227k/227k [00:00<00:00, 1.34MB/s]
Downloading data: 100%|██████████| 215k/215k [00:00<00:00, 920kB/s]
Generating train split: 100%|██████████| 3310/3310 [00:00<00:00, 70435.64 examples/s]
Generating validation split: 100%|██████████| 399/399 [00:00<00:00, 124416.57 examples/s]
Generating test split: 100%|██████████| 378/378 [00:00<00:00, 141204.75 examples/s]
Downloading readme: 100%|██████████| 811/811 [00:00<00:00, 2.18MB/s]
Downloading data: 100%|██████████| 1.80M/1.80M [00:00<00:00, 8.84MB/s]
Downloading data: 100%|██████████| 226k/226k [00:00<00:00, 1.25MB/s]
Downloading data: 100%|██████████| 214k/214k [00:00<00:00, 1.13MB/s]
Generating train split: 100%|██████████| 3310/3310 [00:00<00:00, 198025.14 examples/s]
Generating validation split: 100%|██████████| 399/399 [00:00<00:00, 109674.77 examples/s]
Generati

orig_filtered_medqa shape:  (378, 10)
g2b_medqa shape:  (378, 10)
orig_filtered_medmcqa shape:  (457, 13)
g2b_medmcqa shape:  (457, 13)





# Load in Annotated Data

In [14]:
# Read the annotation CSVs for both benchmarks from
# ../pre_filter_datasets/eval_csvs/ (paths are relative to the notebook).
annotated_medmcqa_new = pd.read_csv("../pre_filter_datasets/eval_csvs/annotated_medmcqa_new.csv")
annotated_medqa_new = pd.read_csv("../pre_filter_datasets/eval_csvs/annotated_medqa_new.csv")

# Preview the first two medmcqa annotation rows
annotated_medmcqa_new.head(n=2)

Unnamed: 0,id,question_orig,opa_orig,opb_orig,opc_orig,opd_orig,cop_orig,choice_type_orig,exp_orig,subject_name_orig,topic_name_orig,found_keywords_orig,local_id_orig,Unnamed: 13,question_g2b,opa_g2b,opb_g2b,opc_g2b,opd_g2b,cop_g2b,choice_type_g2b,exp_g2b,subject_name_g2b,topic_name_g2b,found_keywords_g2b,local_id_g2b,Unnamed: 26,keep/drop,comments
0,006acfff-dc8f-4bb5-97b2-e26144c56483,PGE1 analogue is ?,Carboprost,Alprostadil,Epoprostenol,Dinoprostone,-1,single,,Pharmacology,,['carboprost' 'dinoprostone' 'alprostadil' 'epoprostenol'],4101,,PGE1 analogue is ?,hemabate,caverject,flolan,cervidil,-1,single,,Pharmacology,,['carboprost' 'dinoprostone' 'alprostadil' 'epoprostenol'],4101,,keep,
1,024f96d1-8881-4b52-a7f9-58e5b194a0fa,Which of the following cephalosporin is active against Pseudomonas aeruginosa:,Ceftriaxone,Cephalothin,Ceftazidime,Cefotaxime,-1,single,,Unknown,,['cefotaxime' 'ceftazidime' 'cephalothin' 'ceftriaxone'],1162,,Which of the following cephalosporin is active against Pseudomonas aeruginosa:,rocephin,keflin,fortaz,claforan,-1,single,,Unknown,,['cefotaxime' 'ceftazidime' 'cephalothin' 'ceftriaxone'],1162,,keep,


In [15]:
# Collect the ids of every annotated row whose second-to-last column is not
# exactly "keep". For medmcqa that column is "keep/drop" in the preview above;
# NOTE(review): the medqa CSV is assumed to use the same layout — confirm.

# medmcqa: cast the decision column to str before comparing
medmcqa_decisions = annotated_medmcqa_new.iloc[:, -2]
annotated_medmcqa_new.iloc[:, -2] = medmcqa_decisions.astype(str)
medmcqa_not_kept = annotated_medmcqa_new.iloc[:, -2] != "keep"
rows_to_filter = annotated_medmcqa_new.loc[medmcqa_not_kept, "id"].tolist()

# medqa: identical treatment
medqa_decisions = annotated_medqa_new.iloc[:, -2]
annotated_medqa_new.iloc[:, -2] = medqa_decisions.astype(str)
medqa_not_kept = annotated_medqa_new.iloc[:, -2] != "keep"
rows_to_filter_medqa = annotated_medqa_new.loc[medqa_not_kept, "id"].tolist()

print(f"Number of rows to filter in medmcqa: {len(rows_to_filter)}")
print(f"Number of rows to filter in medqa: {len(rows_to_filter_medqa)}")

Number of rows to filter in medmcqa: 82
Number of rows to filter in medqa: 63


In [18]:
# Ids of the annotated rows whose second-to-last column is anything but "keep"
medmcqa_rows_to_filter = annotated_medmcqa_new.loc[
    annotated_medmcqa_new.iloc[:, -2] != "keep", "id"
].tolist()
medqa_rows_to_filter = annotated_medqa_new.loc[
    annotated_medqa_new.iloc[:, -2] != "keep", "id"
].tolist()

# Remove those ids from the original and g2b frames of each benchmark
filtered_orig_filtered_medmcqa = orig_filtered_medmcqa.loc[
    ~orig_filtered_medmcqa["id"].isin(medmcqa_rows_to_filter)
]
filtered_g2b_medmcqa = g2b_medmcqa.loc[
    ~g2b_medmcqa["id"].isin(medmcqa_rows_to_filter)
]

filtered_orig_filtered_medqa = orig_filtered_medqa.loc[
    ~orig_filtered_medqa["id"].isin(medqa_rows_to_filter)
]
filtered_g2b_medqa = g2b_medqa.loc[~g2b_medqa["id"].isin(medqa_rows_to_filter)]

# Row counts that survived the filter ...
print(f"Number of rows in filtered_orig_filtered_medmcqa: {len(filtered_orig_filtered_medmcqa)}")
print(f"Number of rows in filtered_g2b_medmcqa: {len(filtered_g2b_medmcqa)}")
print(f"Number of rows in filtered_orig_filtered_medqa: {len(filtered_orig_filtered_medqa)}")
print(f"Number of rows in filtered_g2b_medqa: {len(filtered_g2b_medqa)}")

# ... and how many rows each frame lost
print(f"Difference in rows in filtered_orig_filtered_medmcqa: {len(orig_filtered_medmcqa) - len(filtered_orig_filtered_medmcqa)}")
print(f"Difference in rows in filtered_g2b_medmcqa: {len(g2b_medmcqa) - len(filtered_g2b_medmcqa)}")
print(f"Difference in rows in filtered_orig_filtered_medqa: {len(orig_filtered_medqa) - len(filtered_orig_filtered_medqa)}")
print(f"Difference in rows in filtered_g2b_medqa: {len(g2b_medqa) - len(filtered_g2b_medqa)}")

Number of rows in filtered_orig_filtered_medmcqa: 457
Number of rows in filtered_g2b_medmcqa: 457
Number of rows in filtered_orig_filtered_medqa: 378
Number of rows in filtered_g2b_medqa: 378
Difference in rows in filtered_orig_filtered_medmcqa: 82
Difference in rows in filtered_g2b_medmcqa: 82
Difference in rows in filtered_orig_filtered_medqa: 63
Difference in rows in filtered_g2b_medqa: 63


## Number of keywords


In [6]:
# count number of brand keywords and generic keywords
import pandas as pd
from collections import Counter


def load_brand_generic_maps(brand_to_generic_path, generic_to_brand_path):
    """Read the two mapping CSVs and return their keyword sets.

    Args:
        brand_to_generic_path: CSV path; the file must have a "brand" column.
        generic_to_brand_path: CSV path; the file must have a "generic" column.

    Returns:
        A ``(brand_keywords, generic_keywords)`` tuple of sets holding the
        distinct values of the "brand" and "generic" columns respectively.

    Note:
        A later cell in this notebook redefines this name with a different
        return value (the two DataFrames instead of sets).
    """
    # Read both files before touching any column
    brand_table = pd.read_csv(brand_to_generic_path)
    generic_table = pd.read_csv(generic_to_brand_path)

    brand_names = {name for name in brand_table["brand"]}
    generic_names = {name for name in generic_table["generic"]}
    return brand_names, generic_names


def count_keywords(
    merged_datasets,
    split,
    brand_keywords,
    generic_keywords,
    transformations_per_question=3,
):
    """Summarise keyword occurrences for every dataset with a "found_keywords" column.

    Args:
        merged_datasets: Mapping of dataset name -> DataFrame.
        split: Value copied verbatim into the "split" column of the result.
        brand_keywords: Container of brand names (membership-tested with ``in``).
        generic_keywords: Container of generic names (membership-tested with ``in``).
        transformations_per_question: Divisor applied (with floor division) to
            the row count and to every keyword count. The default of 3
            reproduces the previously hard-coded ``// 3``.

    Returns:
        A DataFrame with one row per dataset that has a "found_keywords"
        column, and the columns "dataset", "split", "n_questions", "keywords",
        "total_keyword_length", "brand_keywords_count" and
        "generic_keywords_count". The "keywords" column holds the full
        flattened keyword list (it is NOT divided).

    Raises:
        ValueError: If ``transformations_per_question`` is smaller than 1.

    Note:
        Datasets without a "found_keywords" column are reported via ``print``
        and skipped. A later cell in this notebook redefines this name with a
        three-value return.
    """
    if transformations_per_question < 1:
        raise ValueError("transformations_per_question must be a positive integer")

    results = []
    for dataset_name, df in merged_datasets.items():
        if "found_keywords" not in df.columns:
            print(f"No 'found_keywords' column in {dataset_name}")
            continue

        # Flatten the per-row keyword collections, ignoring missing rows
        all_keywords = [
            keyword
            for sublist in df["found_keywords"].dropna()
            for keyword in sublist
        ]

        # Raw occurrence counts over every row of the dataset
        brand_hits = sum(1 for keyword in all_keywords if keyword in brand_keywords)
        generic_hits = sum(
            1 for keyword in all_keywords if keyword in generic_keywords
        )

        # Floor-divide by the number of transformations to get per-question figures
        results.append(
            {
                "dataset": dataset_name,
                "split": split,
                "n_questions": len(df) // transformations_per_question,
                "keywords": all_keywords,
                "total_keyword_length": len(all_keywords)
                // transformations_per_question,
                "brand_keywords_count": brand_hits // transformations_per_question,
                "generic_keywords_count": generic_hits
                // transformations_per_question,
            }
        )

    return pd.DataFrame(results)


# Keyword sets: both mapping paths currently point at the same CSV
brand_to_generic_path = "../RxNorm_eval/filtered_keywords.csv"
generic_to_brand_path = "../RxNorm_eval/filtered_keywords.csv"
brand_keywords, generic_keywords = load_brand_generic_maps(
    brand_to_generic_path=brand_to_generic_path,
    generic_to_brand_path=generic_to_brand_path,
)

# NOTE(review): merge_all_datasets, pre_filtered_df_dir and split are not
# defined in the visible cells — they must already exist in the kernel.
merged_datasets = merge_all_datasets(pre_filtered_df_dir, split)

# One summary row per merged dataset
keywords_df = count_keywords(
    merged_datasets=merged_datasets,
    split=split,
    brand_keywords=brand_keywords,
    generic_keywords=generic_keywords,
)

# Persist the summary next to the keyword mapping
keywords_df.to_csv(f"../RxNorm_eval/keywords_count_{split}.csv", index=False)

# Render the summary as the cell output
keywords_df


Merged Dataset: medmcqa
Split: test
Contains 1617 rows.
539 transformations.

Merged Dataset: GBaker/MedQA-USMLE-4-options-hf
Split: test
Contains 1323 rows.
441 transformations.


Unnamed: 0,dataset,split,n_questions,keywords,total_keyword_length,brand_keywords_count,generic_keywords_count
0,medmcqa,test,539,"[danazol, levofloxacin, moxifloxacin, ciprofloxacin, ofloxacin, pegvisomant, fulvestrant, vigabatrin, cabergoline, ziprasidone, aripiprazole, clozapine, quetiapine, pentazocine, methadone, tetracycline, dapsone, methotrexate, amikacin, ampicillin, phenytoin, ipratropium, oxytocin, glucagon, epinephrine, glutaraldehyde, nystatin, griseofulvin, trimipramine, doxepin, amitriptyline, desipramine, glucagon, ritonavir, indinavir, ciclesonide, digoxin, acetazolamide, flutamide, mifepristone, metyrapone, ketamine, propofol, glucagon, paba, docetaxel, paclitaxel, papaverine, atropine, nevirapine, methotrexate, sulfasalazine, chloroquine, ambenonium, edrophonium, eptifibatide, abciximab, clopidogrel, tirofiban, chlorpropamide, tolbutamide, nefazodone, furosemide, warfarin, oxytocin, metyrapone, oxytocin, theophylline, ciprofloxacin, procarbazine, topotecan, repaglinide, glucagon, dibucaine, piroxicam, sulindac, chlorhexidine, octreotide, bisoprolol, acebutolol, esmolol, pindolol, erythromycin, tetracycline, procainamide, metoclopramide, dapsone, cycloserine, hydralazine, penicillamine, doripenem, carbidopa, bromocriptine, levodopa, oseltamivir, cisplatin, epinephrine, glucagon, nevirapine, saquinavir, ...]",1002,0,1000
1,GBaker/MedQA-USMLE-4-options-hf,test,441,"[lamivudine, prednisone, cisplatin, metformin, lamivudine, dolutegravir, ritonavir, emtricitabine, efavirenz, metformin, fenofibrate, lisinopril, simvastatin, hydrochlorothiazide, hydrochlorothiazide, ibuprofen, ramipril, donepezil, metformin, acetazolamide, metformin, metoprolol, verapamil, enalapril, chloroquine, salbutamol, diphenhydramine, metformin, ibuprofen, naltrexone, fomepizole, metformin, oseltamivir, lisinopril, amlodipine, hydrochlorothiazide, oxybutynin, bethanechol, metoclopramide, galantamine, metronidazole, morphine, atropine, scopolamine, thyroxine, phentolamine, prazosin, octreotide, lisinopril, atenolol, morphine, cefotaxime, lisinopril, ceftriaxone, atorvastatin, metformin, hydrochlorothiazide, ibuprofen, lisinopril, ibuprofen, naltrexone, ipratropium, fexofenadine, itraconazole, doxycycline, griseofulvin, oxymetazoline, morphine, sildenafil, metoprolol, cilostazol, diphenhydramine, vasopressin, pantoprazole, bupropion, simvastatin, citalopram, levofloxacin, salmeterol, prednisone, metformin, oxycodone, spironolactone, hydrochlorothiazide, metformin, ibuprofen, hydrochlorothiazide, propranolol, azathioprine, cyclosporine, prednisone, enalapril, phenoxybenzamine, atenolol, propranolol, epinephrine, thyroxine, indomethacin, ampicillin, mepolizumab, ...]",950,0,947


## List of keywords that appeared


In [7]:
import pandas as pd


def load_brand_generic_maps(brand_to_generic_path, generic_to_brand_path):
    """Read the two mapping CSVs and return them as DataFrames.

    Args:
        brand_to_generic_path: Path of the first CSV to read.
        generic_to_brand_path: Path of the second CSV to read.

    Returns:
        A ``(brand_to_generic_df, generic_to_brand_df)`` tuple, in argument order.

    Note:
        This shadows the earlier definition of the same name in this
        notebook, which returned keyword sets instead of DataFrames.
    """
    return tuple(
        pd.read_csv(csv_path)
        for csv_path in (brand_to_generic_path, generic_to_brand_path)
    )


def count_keywords(
    merged_datasets,
    split,
    brand_keywords,
    generic_keywords,
    transformations_per_question=3,
):
    """Summarise keyword occurrences and collect the distinct keywords per dataset.

    Args:
        merged_datasets: Mapping of dataset name -> DataFrame.
        split: Value copied verbatim into the "split" column of the result.
        brand_keywords: Collection of brand names (intersected with the found
            keywords and membership-tested with ``in``).
        generic_keywords: Collection of generic names (same usage).
        transformations_per_question: Divisor applied (with floor division) to
            the row count and to every keyword count. The default of 3
            reproduces the previously hard-coded ``// 3``.

    Returns:
        A 3-tuple ``(summary_df, brand_by_dataset, generic_by_dataset)``:

        * ``summary_df`` has one row per dataset with a "found_keywords"
          column and the columns "dataset", "split", "n_questions",
          "unique_brand_keywords", "unique_generic_keywords",
          "total_keyword_length", "brand_keywords_count" and
          "generic_keywords_count".
        * ``brand_by_dataset`` / ``generic_by_dataset`` map each such dataset
          name to the set of distinct brand / generic keywords found in it.

    Raises:
        ValueError: If ``transformations_per_question`` is smaller than 1.

    Note:
        Datasets without a "found_keywords" column are reported via ``print``
        and skipped. This shadows the earlier definition of the same name in
        this notebook, which returned only a DataFrame.
    """
    if transformations_per_question < 1:
        raise ValueError("transformations_per_question must be a positive integer")

    results = []
    unique_brand_keywords_per_dataset = {}
    unique_generic_keywords_per_dataset = {}

    for dataset_name, df in merged_datasets.items():
        if "found_keywords" not in df.columns:
            print(f"No 'found_keywords' column in {dataset_name}")
            continue

        # Flatten the per-row keyword collections, ignoring missing rows
        all_keywords = [
            keyword
            for sublist in df["found_keywords"].dropna()
            for keyword in sublist
        ]
        unique_keywords = set(all_keywords)

        # Distinct keywords of each kind, kept for the caller
        unique_brand_keywords = unique_keywords.intersection(brand_keywords)
        unique_generic_keywords = unique_keywords.intersection(generic_keywords)
        unique_brand_keywords_per_dataset[dataset_name] = unique_brand_keywords
        unique_generic_keywords_per_dataset[dataset_name] = unique_generic_keywords

        # Raw occurrence counts over every row of the dataset
        brand_hits = sum(1 for keyword in all_keywords if keyword in brand_keywords)
        generic_hits = sum(
            1 for keyword in all_keywords if keyword in generic_keywords
        )

        # Floor-divide by the number of transformations to get per-question figures
        results.append(
            {
                "dataset": dataset_name,
                "split": split,
                "n_questions": len(df) // transformations_per_question,
                "unique_brand_keywords": list(unique_brand_keywords),
                "unique_generic_keywords": list(unique_generic_keywords),
                "total_keyword_length": len(all_keywords)
                // transformations_per_question,
                "brand_keywords_count": brand_hits // transformations_per_question,
                "generic_keywords_count": generic_hits
                // transformations_per_question,
            }
        )

    return (
        pd.DataFrame(results),
        unique_brand_keywords_per_dataset,
        unique_generic_keywords_per_dataset,
    )


def filter_and_save_mappings(
    dataset_name,
    unique_brand_keywords,
    unique_generic_keywords,
    brand_to_generic_df,
    generic_to_brand_df,
):
    """Write the mapping rows relevant to one dataset to two CSV files.

    Rows of ``brand_to_generic_df`` whose "brand" value is in
    ``unique_brand_keywords`` and rows of ``generic_to_brand_df`` whose
    "generic" value is in ``unique_generic_keywords`` are saved (without the
    index) under ``../RxNorm_eval/``. Any "/" in ``dataset_name`` is replaced
    by "_" in the file names.

    Returns:
        None. The function is called for its file-writing side effect.
    """
    # Boolean selectors for the rows to keep in each mapping
    keep_brand_rows = brand_to_generic_df["brand"].isin(unique_brand_keywords)
    keep_generic_rows = generic_to_brand_df["generic"].isin(unique_generic_keywords)

    # "/" would otherwise be read as a directory separator in the file name
    dataset_name = dataset_name.replace("/", "_")

    brand_to_generic_df.loc[keep_brand_rows].to_csv(
        f"../RxNorm_eval/filtered_brand_to_generic_{dataset_name}.csv", index=False
    )
    generic_to_brand_df.loc[keep_generic_rows].to_csv(
        f"../RxNorm_eval/filtered_generic_to_brand_{dataset_name}.csv", index=False
    )


# Mapping tables: both paths currently point at the same CSV
brand_to_generic_path = "../RxNorm_eval/filtered_keywords.csv"
generic_to_brand_path = "../RxNorm_eval/filtered_keywords.csv"
brand_to_generic_df, generic_to_brand_df = load_brand_generic_maps(
    brand_to_generic_path=brand_to_generic_path,
    generic_to_brand_path=generic_to_brand_path,
)

# NOTE(review): merge_all_datasets, pre_filtered_df_dir and split are not
# defined in the visible cells — they must already exist in the kernel.
merged_datasets = merge_all_datasets(pre_filtered_df_dir, split)

# Distinct brand / generic names taken from the mapping tables
brand_keywords = set(brand_to_generic_df["brand"])
generic_keywords = set(generic_to_brand_df["generic"])

# Summary frame plus the per-dataset sets of distinct keywords
(
    keywords_df,
    unique_brand_keywords_per_dataset,
    unique_generic_keywords_per_dataset,
) = count_keywords(
    merged_datasets=merged_datasets,
    split=split,
    brand_keywords=brand_keywords,
    generic_keywords=generic_keywords,
)

# Persist the summary
keywords_df.to_csv(f"../RxNorm_eval/keywords_count_{split}.csv", index=False)

# For every dataset, save the mapping rows restricted to its own keywords
for dataset_name, unique_brand_keywords in unique_brand_keywords_per_dataset.items():
    unique_generic_keywords = unique_generic_keywords_per_dataset[dataset_name]
    filter_and_save_mappings(
        dataset_name=dataset_name,
        unique_brand_keywords=unique_brand_keywords,
        unique_generic_keywords=unique_generic_keywords,
        brand_to_generic_df=brand_to_generic_df,
        generic_to_brand_df=generic_to_brand_df,
    )

# Render the summary as the cell output
keywords_df


Merged Dataset: medmcqa
Split: test
Contains 1617 rows.
539 transformations.

Merged Dataset: GBaker/MedQA-USMLE-4-options-hf
Split: test
Contains 1323 rows.
441 transformations.


Unnamed: 0,dataset,split,n_questions,unique_brand_keywords,unique_generic_keywords,total_keyword_length,brand_keywords_count,generic_keywords_count
0,medmcqa,test,539,[],"[methylprednisolone, flecainide, terazosin, carboplatin, etoposide, rifampin, bran, chlorhexidine, ibutilide, ibuprofen, tizanidine, griseofulvin, quinine, lenalidomide, bupivacaine, metoclopramide, etonogestrel, propofol, ketorolac, naratriptan, exenatide, celiprolol, cimetidine, clomipramine, fomepizole, amiodarone, paracetamol, linezolid, methoxyflurane, dapsone, ciprofloxacin, tacrolimus, selegiline, mebendazole, zileuton, procainamide, podophyllin, metoprolol, etomidate, oxybutynin, topotecan, dipivefrine, suxamethonium, chlorambucil, phenoxybenzamine, carbamazepine, dicloxacillin, ofloxacin, trihexyphenidyl, sufentanil, cycloserine, dinoprostone, chloroquine, capecitabine, nitisinone, desflurane, benzocaine, gatifloxacin, simvastatin, metyrapone, prasugrel, physostigmine, miltefosine, ciclesonide, abatacept, risedronate, omalizumab, mesalazine, digoxin, piroxicam, sertraline, glutaraldehyde, tropicamide, ticlopidine, cisplatin, cetuximab, erythromycin, levamisole, olanzapine, sumatriptan, zafirlukast, sulfasalazine, omeprazole, triprolidine, ampicillin, levodopa, primidone, benazepril, daptomycin, foscarnet, fondaparinux, econazole, flumazenil, tamoxifen, moxifloxacin, cephalothin, ticagrelor, prednisolone, triptorelin, zonisamide, ...]",1002,0,1000
1,GBaker/MedQA-USMLE-4-options-hf,test,441,[],"[methylprednisolone, alendronate, epinephrine, aldesleukin, moxifloxacin, hydromorphone, carboplatin, ribavirin, oxycodone, nifedipine, etoposide, lansoprazole, rifampin, prednisolone, furosemide, calamine, atenolol, paroxetine, gabapentin, chlorhexidine, metronidazole, ibuprofen, enoxaparin, isoflurane, tolvaptan, scopolamine, glucagon, dantrolene, griseofulvin, methenamine, diphenhydramine, oxytocin, anakinra, liraglutide, chlorthalidone, dihydroergotamine, pramlintide, modafinil, isotretinoin, metoclopramide, miglitol, deferoxamine, fluorometholone, hydrochlorothiazide, lovastatin, estriol, pantoprazole, propofol, azithromycin, dexamethasone, ramelteon, methylphenidate, doxycycline, dobutamine, tolterodine, diethylstilbestrol, ketorolac, amlodipine, edta, temazepam, rasburicase, exenatide, atorvastatin, flucytosine, rosuvastatin, dextroamphetamine, cimetidine, clomipramine, isoniazid, vincristine, demeclocycline, fomepizole, amiodarone, naltrexone, lorazepam, fludrocortisone, atropine, naloxone, maraviroc, albendazole, dapsone, infliximab, ciprofloxacin, selegiline, sevoflurane, mebendazole, etanercept, carvedilol, erythropoietin, thalidomide, rivastigmine, lamivudine, indomethacin, baclofen, tetrahydrobiopterin, nitroprusside, fluoxetine, zileuton, triamterene, labetalol, ...]",950,0,947


In [8]:
# Print the full unique_generic_keywords list held in the keywords_df row with index label 1
print(keywords_df["unique_generic_keywords"][1])

['methylprednisolone', 'alendronate', 'epinephrine', 'aldesleukin', 'moxifloxacin', 'hydromorphone', 'carboplatin', 'ribavirin', 'oxycodone', 'nifedipine', 'etoposide', 'lansoprazole', 'rifampin', 'prednisolone', 'furosemide', 'calamine', 'atenolol', 'paroxetine', 'gabapentin', 'chlorhexidine', 'metronidazole', 'ibuprofen', 'enoxaparin', 'isoflurane', 'tolvaptan', 'scopolamine', 'glucagon', 'dantrolene', 'griseofulvin', 'methenamine', 'diphenhydramine', 'oxytocin', 'anakinra', 'liraglutide', 'chlorthalidone', 'dihydroergotamine', 'pramlintide', 'modafinil', 'isotretinoin', 'metoclopramide', 'miglitol', 'deferoxamine', 'fluorometholone', 'hydrochlorothiazide', 'lovastatin', 'estriol', 'pantoprazole', 'propofol', 'azithromycin', 'dexamethasone', 'ramelteon', 'methylphenidate', 'doxycycline', 'dobutamine', 'tolterodine', 'diethylstilbestrol', 'ketorolac', 'amlodipine', 'edta', 'temazepam', 'rasburicase', 'exenatide', 'atorvastatin', 'flucytosine', 'rosuvastatin', 'dextroamphetamine', 'cim

In [None]:
# load g2b medmcqa

In [None]:
# processing etc

In [None]:
# Per-dataset lists of local_id values to remove from the merged datasets.
# Keys are dataset names; every list is currently empty.
drop_local_ids = {
    "augtoma/usmle_step_1": [],
    "augtoma/usmle_step_2": [],
    "augtoma/usmle_step_3": [],
    "bigbio/pubmed_qa": [],
    "GBaker/MedQA-USMLE-4-options-hf": [],
    "medmcqa": [],
    "hails/mmlu_no_train/anatomy": [],
    "hails/mmlu_no_train/clinical_knowledge": [],
    "hails/mmlu_no_train/college_biology": [],
    "hails/mmlu_no_train/college_medicine": [],
    "hails/mmlu_no_train/medical_genetics": [],
    "hails/mmlu_no_train/professional_medicine": [],
}

# Apply the drops; names that are absent from merged_datasets are ignored
for dataset_name, local_ids in drop_local_ids.items():
    if dataset_name not in merged_datasets:
        continue
    current_frame = merged_datasets[dataset_name]
    rows_to_keep = ~current_frame["local_id"].isin(local_ids)
    merged_datasets[dataset_name] = current_frame.loc[rows_to_keep]

In [None]:
# Now write out the datasets back to independent files per transformation
for dataset_name, dataset in merged_datasets.items():
    for transformation in dataset["transformation_type"].unique():
        dataset_filtered = dataset[dataset["transformation_type"] == transformation]
        if len(dataset_filtered) == 0:
            print(
                f"Skipping dataset {dataset_name}_{transformation} as it has no rows."
            )
            continue
        dataset_filtered.drop(columns=["transformation_type"], inplace=True)
        dataset_filtered.reset_index(drop=True, inplace=True)
        dataset_filtered_path = os.path.join(
            output_dir,
            f"{dataset_name.replace('/', '_')}",
            f"{split}",
            f"{dataset_name.replace('/', '_')}_{transformation}",
            f"{split}.parquet",
        )
        if not os.path.exists(os.path.dirname(dataset_filtered_path)):
            os.makedirs(os.path.dirname(dataset_filtered_path))
        dataset_filtered.to_parquet(
            dataset_filtered_path,
        )
        print(
            f"Writing out dataset: {dataset_name}_{transformation} to {output_dir}/{dataset_name.replace('/', '_')}/{split}"
        )

Writing out dataset: medmcqa_original_filtered to ../datasets/medmcqa/test
Writing out dataset: medmcqa_brand_to_generic_filtered to ../datasets/medmcqa/test
Writing out dataset: medmcqa_generic_to_brand_filtered to ../datasets/medmcqa/test
Writing out dataset: bigbio/pubmed_qa_original_filtered to ../datasets/bigbio_pubmed_qa/test
Writing out dataset: bigbio/pubmed_qa_brand_to_generic_filtered to ../datasets/bigbio_pubmed_qa/test
Writing out dataset: bigbio/pubmed_qa_generic_to_brand_filtered to ../datasets/bigbio_pubmed_qa/test
Writing out dataset: GBaker/MedQA-USMLE-4-options-hf_original_filtered to ../datasets/GBaker_MedQA-USMLE-4-options-hf/test
Writing out dataset: GBaker/MedQA-USMLE-4-options-hf_brand_to_generic_filtered to ../datasets/GBaker_MedQA-USMLE-4-options-hf/test
Writing out dataset: GBaker/MedQA-USMLE-4-options-hf_generic_to_brand_filtered to ../datasets/GBaker_MedQA-USMLE-4-options-hf/test
Writing out dataset: augtoma/usmle_step_1_original_filtered to ../datasets/augt

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset_filtered.drop(columns=["transformation_type"], inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset_filtered.drop(columns=["transformation_type"], inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset_filtered.drop(columns=["transformation_type"], inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing