In [1]:
import json
import copy
import random
from cns_llava.utils import load_variable_json
from cns_llava.utils import save_variable_json

In [2]:
# Pool of user instructions for the DDx task; insert_random_prompt samples one per case.
# The empty strings at the end give a share of cases an empty prompt, so only the case
# one-liner is shown to the model.
differential_prompts = [
    "Provide a differential:",
    "Provide a differential",
    "DDx:",
    "DDx",
    "Give a differential diagnosis",
    "diagnose",
    "what's the differential?",
    "List possible diagnoses",
    "differential diagnosis",
    "What could it be",
    "suggest possible conditions",
    "Provide a list of potential diagnoses.",
    "what are the diagnostic possibilities",
    "Create a differential list",
    "what should be considered in the differential?",
    "generate a differential diagnosis",
    "What's your differential for this case",
    "Differentials to consider",
    "what diagnoses should we rule out",
    # NOTE(review): this asks for a workup, not a differential, but the paired answer is a
    # numbered DDx list — confirm this prompt belongs in the pool.
    "Provide a diagnostic workup.",
    "what's on your differential list?",
    "Give me your top differential diagnoses",
    "possible diagnoses?",
    "whats the ddx",
    "Differential dx",
    "list diagnoses to consider",
    "What are the possibilities",
    "give me a differential",
    "diagnostic considerations?",
    "What's your ddx",
    "potential diagnoses",
    "differential diagnosis list",
    "what could be causing this",
    "provide differential diagnoses",
    "diagnosis possibilities",
    "what's in your differential",
    "give possible explanations",
    "differential dx list",
    "what diagnoses fit",
    "list potential conditions",
    "differential considerations",
    "what's your diagnostic thinking",
    "possible conditions?",
    "provide a ddx",
    "what should I consider",
    "differential possibilities",
    "diagnostic hypotheses",
    "what are the differential diagnoses",
    "list diagnoses",
    "whats your differential",
    "give diagnostic options",
    # The trailing comma matters: without it Python concatenates this literal with the next
    # "" and one of the empty-prompt slots silently disappears.
    "differential ideas?",
    "",
    "",
    "",
    "",
]

def insert_random_prompt(entry):
    """Set the ``prompt`` of an entry's JSON-encoded question to a random differential prompt.

    ``entry["question"]`` is decoded, its ``"prompt"`` key is overwritten with a choice from
    ``differential_prompts`` and the result is re-encoded into ``entry["question"]``.

    The entry is modified in place; the same object is returned so the function can be used
    inside a comprehension.
    """
    question = json.loads(entry["question"])
    question.update(prompt=random.choice(differential_prompts))
    entry.update(question=json.dumps(question))
    return entry

In [3]:
# Re-sample a DDx prompt for every entry of each journal's DDx dataset.
# Writing back is opt-in: every run draws new random prompts, so saving overwrites the
# previous draw in the source file (add_conversations later reads the stored "prompt").
# Set the flag to True only when the prompts should actually be regenerated.
OVERWRITE_DDX_PROMPTS = False

for journal in ["Neurosurgery_Practice", "Operative_Neurosurgery", "Neurosurgery"]:
    ddx_path = f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/dataset_claude_ddx.json"
    data = load_variable_json(ddx_path)
    modified_data = [insert_random_prompt(entry) for entry in data]
    if OVERWRITE_DDX_PROMPTS:
        save_variable_json(modified_data, ddx_path)

In [4]:
def _user_turn(text, with_image):
    """Build a user message holding ``text``, followed by an image placeholder if requested."""
    content = [{"type": "text", "text": text}]
    if with_image:
        content.append({"type": "image"})
    return {"role": "user", "content": content}


def _assistant_turn(text):
    """Build an assistant message holding a single text part."""
    return {"role": "assistant", "content": [{"type": "text", "text": text}]}


def _parse_question(item):
    """Decode ``item["question"]`` after deleting raw newline characters.

    NOTE(review): the newline removal presumably works around raw newlines inside JSON string
    values (which ``json.loads`` rejects by default) — TODO confirm. It also drops any newlines
    that were meant to be part of the text.
    """
    return json.loads(item["question"].replace('\n', ''))


def _ift_conversations(item):
    """One user/assistant pair per Q&A object; only the first user turn carries the image."""
    turns = []
    for index, qa in enumerate(_parse_question(item)):
        turns.append(_user_turn(qa["question"], with_image=(index == 0)))
        turns.append(_assistant_turn(qa["answer"]))
    return turns


def _mc_conversations(item):
    """Single turn pair: stem plus one choice per line, then discussion and ``Answer: <key>``."""
    qa = _parse_question(item)
    try:
        question_text = qa["question_stem"] + "\n\n"
    except (KeyError, TypeError):
        # Show the malformed record before re-raising so it can be located in the dataset.
        print(item)
        raise
    question_text += "".join(f"{option}\n" for option in qa["answer_choices"])
    return [
        _user_turn(question_text.strip(), with_image=True),
        _assistant_turn(f"{qa['discussion']}\nAnswer: {qa['correct_answer']}"),
    ]


def _ddx_conversations(item, rng):
    """Single turn pair: prompt and one-liner in random order, then the numbered ddx list."""
    qa = _parse_question(item)
    # Both draws happen unconditionally and in this order, so a given RNG state yields the
    # same sequence of random draws regardless of the prompt's content.
    prompt_first = rng.choice([True, False])
    newlines = "\n" * rng.choice([1, 2])
    parts = [qa['prompt'], qa['one-liner']] if prompt_first else [qa['one-liner'], qa['prompt']]
    # An empty prompt (e.g. the "" entries in the prompt pool) would otherwise leave dangling
    # newlines before or after the one-liner, so only non-empty parts are joined.
    user_text = newlines.join(part for part in parts if part)
    assistant_text = "\n".join(
        f"{rank}. {diagnosis}" for rank, diagnosis in enumerate(qa['ddx'], start=1)
    )
    return [_user_turn(user_text, with_image=True), _assistant_turn(assistant_text)]


def add_conversations(items, mode, rng=None):
    """Return deep copies of ``items`` with a ``mode`` tag and a chat-style ``conversations`` list.

    Each item's ``question`` field is a JSON string whose layout depends on ``mode``:

    * ``"ift"``: a list of ``{"question", "answer"}`` objects. Each becomes a user/assistant
      pair, and the image placeholder is attached to the first user turn only.
    * ``"mc"``: one object with ``question_stem``, ``answer_choices``, ``discussion`` and
      ``correct_answer``. The user turn lists the stem and choices; the assistant turn gives
      the discussion followed by ``Answer: <correct_answer>``.
    * ``"ddx"``: one object with ``prompt``, ``one-liner`` and ``ddx``. Prompt and one-liner
      are joined in random order with one or two newlines (an empty prompt is omitted). The
      assistant turn is the numbered ``ddx`` list.

    Args:
        items: Records to convert. They are deep-copied, so the caller's list is unchanged.
        mode: One of ``"ift"``, ``"mc"`` or ``"ddx"``.
        rng: Optional object with a ``choice`` method (e.g. ``random.Random(seed)``) used for
            the DDx order/newline draws. Defaults to the module-level ``random`` functions.

    Returns:
        The converted copies, in input order.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
        KeyError, TypeError, json.JSONDecodeError: If a record does not match the layout
            expected for ``mode`` (for a malformed MC stem the record is printed first).
    """
    chooser = random if rng is None else rng
    builders = {
        "ift": _ift_conversations,
        "mc": _mc_conversations,
        "ddx": lambda item: _ddx_conversations(item, chooser),
    }
    try:
        build = builders[mode]
    except KeyError:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(builders)}") from None

    modified_items = copy.deepcopy(items)
    for item in modified_items:
        item["mode"] = mode
        item["conversations"] = build(item)
    return modified_items

In [5]:
# Convert each per-mode Claude dataset (IFT, MC, DDx) into chat-style conversations and
# concatenate them, in that mode order, into one combined file per journal.
for journal in ["Neurosurgery_Practice", "Operative_Neurosurgery", "Neurosurgery"]:
    journal_dir = f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}"
    full_journal_dataset = [
        conversation_item
        for mode in ("ift", "mc", "ddx")
        for conversation_item in add_conversations(
            load_variable_json(f"{journal_dir}/dataset_claude_{mode}.json"), mode=mode
        )
    ]
    save_variable_json(full_journal_dataset, f"{journal_dir}/full_journal_dataset_claude.json")

In [6]:
import pandas as pd

# Split the combined Claude dataset into train/val/test using the paper assignment already
# fixed by the existing (GPT) split files, so both sources share the same paper-level split.
for journal in ["Neurosurgery_Practice", "Operative_Neurosurgery", "Neurosurgery"]:
    journal_dir = f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}"
    # Sets give O(1) membership tests; `x in ndarray` scans the whole array for every entry.
    split_papers = {
        split: set(pd.read_json(f"{journal_dir}/full_journal_dataset_{split}.json")["paper"])
        for split in ("train", "val", "test")
    }

    data = load_variable_json(f"{journal_dir}/full_journal_dataset_claude.json")
    claude_splits = {split: [] for split in split_papers}

    for entry in data:
        # Splits are checked in train -> val -> test order; the first match wins.
        for split, papers in split_papers.items():
            if entry["paper"] in papers:
                claude_splits[split].append(entry)
                break
        else:
            raise RuntimeError(
                f"{journal}: paper {entry['paper']!r} is not in the train, val or test split"
            )

    for split, entries in claude_splits.items():
        save_variable_json(entries, f"{journal_dir}/full_journal_dataset_claude_{split}.json")

In [7]:
def sort_by_mode_and_custom_id(data):
    """Return a new list of ``data`` sorted by ``mode``, then by the number ending ``custom_id``.

    ``custom_id`` is split on ``"_"`` and its last piece is compared as an integer, so
    ``..._10`` sorts after ``..._9``. A ``custom_id`` whose last piece is not an integer
    raises ``ValueError``.

    Missing keys are tolerated instead of crashing. The previous defaults (``0`` for mode,
    ``''`` for custom_id) raised on mixed int/str comparison and on ``int('')``.
    - Entries without ``mode`` sort as mode ``""``, i.e. before any named mode.
    - Entries without ``custom_id`` sort after the numbered entries of the same mode, keeping
      their relative order (the sort is stable).
    """
    def _sort_key(entry):
        mode = entry.get("mode", "")
        custom_id = entry.get("custom_id")
        if custom_id is None:
            return (mode, 1, 0)
        return (mode, 0, int(custom_id.split("_")[-1]))

    return sorted(data, key=_sort_key)

In [8]:
# Rewrite every dataset file in place with its entries ordered by mode, then by the numeric
# suffix of custom_id.
for journal in ["Neurosurgery_Practice", "Operative_Neurosurgery", "Neurosurgery"]:
    journal_dir = f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}"
    for typ in ("", "_train", "_val", "_test", "_claude", "_claude_train", "_claude_val", "_claude_test"):
        dataset_path = f"{journal_dir}/full_journal_dataset{typ}.json"
        save_variable_json(sort_by_mode_and_custom_id(load_variable_json(dataset_path)), dataset_path)

In [9]:
def add_source(data, source_label):
    """Set the ``'source'`` key of every entry in ``data`` to ``source_label``.

    The entries are modified in place. The same list object is returned so calls can be
    chained or concatenated directly.
    """
    for record in data:
        record.update(source=source_label)
    return data

# For each journal and split, concatenate the GPT and Claude datasets (GPT entries first) into
# a "_both" file, labelling every entry's "source" as "gpt" or "claude".
for journal in ["Neurosurgery_Practice", "Operative_Neurosurgery", "Neurosurgery"]:
    dataset_prefix = f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset"
    for typ in ("", "_train", "_val", "_test"):
        tagged_gpt = add_source(load_variable_json(f"{dataset_prefix}{typ}.json"), "gpt")
        tagged_claude = add_source(load_variable_json(f"{dataset_prefix}_claude{typ}.json"), "claude")
        save_variable_json(tagged_gpt + tagged_claude, f"{dataset_prefix}_both{typ}.json")


In [13]:
# Sanity check for the "_train" split: the merged "_both" file must hold exactly the GPT plus
# Claude entries. `journal` is set explicitly instead of relying on the value left over from
# the last loop that ran (the merge loop ends on "Neurosurgery"), so this cell also works
# when run out of order.
journal = "Neurosurgery"
typ="_train"
gpt_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset{typ}.json")
other_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_claude{typ}.json")
total_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_both{typ}.json")
assert len(total_data) == len(gpt_data) + len(other_data), "merged dataset size mismatch"
len(gpt_data), len(other_data), len(total_data)

(102604, 101183, 203787)

In [14]:
# Sanity check for the "_val" split: the merged "_both" file must hold exactly the GPT plus
# Claude entries. `journal` is set explicitly instead of relying on the value left over from
# the last loop that ran (the merge loop ends on "Neurosurgery"), so this cell also works
# when run out of order.
journal = "Neurosurgery"
typ="_val"
gpt_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset{typ}.json")
other_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_claude{typ}.json")
total_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_both{typ}.json")
assert len(total_data) == len(gpt_data) + len(other_data), "merged dataset size mismatch"
len(gpt_data), len(other_data), len(total_data)

(2792, 2763, 5555)

In [15]:
# Sanity check for the "_test" split: the merged "_both" file must hold exactly the GPT plus
# Claude entries. `journal` is set explicitly instead of relying on the value left over from
# the last loop that ran (the merge loop ends on "Neurosurgery"), so this cell also works
# when run out of order.
journal = "Neurosurgery"
typ="_test"
gpt_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset{typ}.json")
other_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_claude{typ}.json")
total_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_both{typ}.json")
assert len(total_data) == len(gpt_data) + len(other_data), "merged dataset size mismatch"
len(gpt_data), len(other_data), len(total_data)

(3094, 3053, 6147)

In [16]:
# Sanity check for the unsplit dataset: the merged "_both" file must hold exactly the GPT plus
# Claude entries. `journal` is set explicitly instead of relying on the value left over from
# the last loop that ran (the merge loop ends on "Neurosurgery"), so this cell also works
# when run out of order.
journal = "Neurosurgery"
typ=""
gpt_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset{typ}.json")
other_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_claude{typ}.json")
total_data = load_variable_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_both{typ}.json")
assert len(total_data) == len(gpt_data) + len(other_data), "merged dataset size mismatch"
len(gpt_data), len(other_data), len(total_data)

(108490, 106999, 215489)