In [None]:
# Mount Google Drive at /content/drive so the notebook can read and write
# files stored under that path.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import zipfile
import os
import shutil
from tqdm import tqdm

# Archive holding the dataset images and the local directory it is unpacked into.
zip_path = '/content/drive/MyDrive/OmniMedVQA/Images.zip'
extract_dir = '/content/Images'

os.makedirs(extract_dir, exist_ok=True)
# Resolved extraction root; every output file must stay underneath it.
extract_root = os.path.realpath(extract_dir)

with zipfile.ZipFile(zip_path, 'r') as zipf:
    file_list = zipf.namelist()
    for member in tqdm(file_list, desc="Unzipping images"):
        # Skip directories
        if member.endswith('/'):
            continue
        # Remove leading "Images/" from the path
        if member.startswith('Images/'):
            target = member[len('Images/'):]
        else:
            target = member
        target_path = os.path.join(extract_dir, target)
        # Copying through zipf.open() bypasses the member-name sanitisation that
        # ZipFile.extract() performs, so reject absolute paths and ".." segments
        # that would place the file outside extract_dir.
        resolved_target = os.path.realpath(target_path)
        if os.path.commonpath([extract_root, resolved_target]) != extract_root:
            raise ValueError(f"Refusing to extract {member!r}: path escapes {extract_dir}")
        os.makedirs(os.path.dirname(target_path), exist_ok=True)
        with zipf.open(member) as source, open(target_path, "wb") as dest:
            shutil.copyfileobj(source, dest)

Unzipping images: 100%|██████████| 6675/6675 [00:08<00:00, 813.21it/s] 


In [None]:
import pandas as pd

# Read the train and validation CSVs (in that order) and stack them into a
# single frame with a fresh 0..n-1 index.
train_df, val_df = (
    pd.read_csv(csv_file)
    for csv_file in (
        '/content/drive/MyDrive/OmniMedVQA/CSV_Files/train_sampled_with_cot.csv',
        '/content/drive/MyDrive/OmniMedVQA/CSV_Files/val_sampled_with_cot.csv',
    )
)
merged_df = pd.concat([train_df, val_df], ignore_index=True)

In [None]:
# Preview the first rows of the concatenated train + validation frame.
merged_df.head()

Unnamed: 0,dataset,question_id,question_type,question,gt_answer,image_path,option_A,option_B,option_C,option_D,modality_type,modality,abs_image_path,is_clear,medvlm_r1_cot
0,RadImageNet,RadImageNet_45353,Disease Diagnosis,Is there any presence of deviation in this image?,"No, It's normal.",Images/RadImageNet/normal/ankle-normal002887.png,It's difficult to determine if there is an abn...,"There is a possibility of abnormality, but fur...","No, It's normal.","Yes, it shows a significant abnormality.",MR (Mag-netic Resonance Imaging),,/content/Images/RadImageNet/normal/ankle-norma...,True,<think>\nTo determine if there is a deviation ...
1,RadImageNet,RadImageNet_26408,Modality Recognition,Which imaging technique was utilized to captur...,MRI,Images/RadImageNet/normal/mri-abd-normal057351...,Ultrasound,Bone scan,MRI,X-ray,MR (Mag-netic Resonance Imaging),,/content/Images/RadImageNet/normal/mri-abd-nor...,True,<think>\nTo determine the correct imaging tech...
2,RadImageNet,RadImageNet_30245,Disease Diagnosis,What is the observed pathology in this image?,Atfl pathology.,Images/RadImageNet/atfl_pathology/ankle081048.png,Joint dislocation,Cartilage degeneration,Tendon inflammation,Atfl pathology.,MR (Mag-netic Resonance Imaging),,/content/Images/RadImageNet/atfl_pathology/ank...,True,<think>\nThe image appears to be a magnetic re...
3,RadImageNet,RadImageNet_47904,Disease Diagnosis,Is there any irregularity or anomalous finding...,"No, It's normal.",Images/RadImageNet/normal/ankle-normal002815.png,"There is a possibility of abnormality, but fur...",It's difficult to determine if there is an abn...,"Yes, it shows a significant abnormality.","No, It's normal.",MR (Mag-netic Resonance Imaging),,/content/Images/RadImageNet/normal/ankle-norma...,True,<think>\nThe MRI image shows a metallic block ...
4,RadImageNet,RadImageNet_29269,Disease Diagnosis,What can be observed in this image?,ACL pathology.,Images/RadImageNet/acl_pathology/knee086534.png,ACL pathology.,Rheumatoid arthritis,Plantar fasciitis,Shin splints,MR (Mag-netic Resonance Imaging),,/content/Images/RadImageNet/acl_pathology/knee...,True,<think>\nThe image is a magnetic resonance ima...


In [None]:
# Write the merged frame to train_1200.csv without the DataFrame index column.
merged_df.to_csv('/content/drive/MyDrive/OmniMedVQA/CSV_Files/train_1200.csv', index=False)

In [None]:
import pandas as pd
import json
import os
import re
from tqdm import tqdm

# Configuration
# Root of the project on the mounted Drive; the other Drive paths hang off it.
PROJECT_DIR = "/content/drive/MyDrive/OmniMedVQA"
CSV_DIR = PROJECT_DIR + "/CSV_Files"
SFT_OUTPUT_DIR = PROJECT_DIR + "/llamafactory_sft_clean"
DPO_OUTPUT_DIR = PROJECT_DIR + "/llamafactory_dpo_clean"
# Directory that relative image paths are joined onto when records are written.
IMAGE_BASE_DIR = "/home/vishnu/data"

# Create both output directories up front; existing directories are left as-is.
for _output_dir in (SFT_OUTPUT_DIR, DPO_OUTPUT_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Split name -> source CSV. Insertion order (train, then mri/ct/xray tests)
# is the order in which the splits are processed.
FILES = {
    "train": f"{CSV_DIR}/train_1200.csv",
    **{
        f"test_{modality}": f"{CSV_DIR}/test_{modality}_sampled_with_cot.csv"
        for modality in ("mri", "ct", "xray")
    },
}

# Split name -> number of rows to sample: 1100 for train, 150 for each test split.
FINAL_SAMPLE_SIZES = {
    split_name: (1100 if split_name == "train" else 150)
    for split_name in FILES
}

# Helper Functions

def get_ground_truth_letter(row):
    """Return the letter (A-D) of the option whose text equals the ground truth.

    Both sides are converted with ``str``, stripped and lower-cased before
    comparison. Options are checked in A, B, C, D order and the first match
    wins.

    Args:
        row: Mapping-like object indexable by 'option_A'..'option_D' and
            'gt_answer'.

    Returns:
        'A', 'B', 'C' or 'D', or None when no option matches.
    """
    wanted = str(row['gt_answer']).strip().lower()
    for letter in "ABCD":
        candidate = str(row['option_' + letter]).strip().lower()
        if candidate == wanted:
            return letter
    return None

def extract_answer_from_cot(text):
    """Pull the answer letter out of the first ``<answer>X</answer>`` tag.

    The tag and the letter are matched case-insensitively, and whitespace
    around the letter is allowed.

    Args:
        text: Model output to scan. Non-string values (e.g. NaN) yield None.

    Returns:
        The upper-cased letter 'A'-'D', or None if no such tag is found.
    """
    if isinstance(text, str):
        found = re.search(r"<answer>\s*([A-D])\s*</answer>", text, re.IGNORECASE)
        if found is not None:
            return found.group(1).upper()
    return None

def create_sft_prompt(row):
    """Build the text prompt for one SFT record (no image placeholder).

    The prompt is the question, the four lettered options, and a fixed
    three-step task description, joined by newlines with no trailing newline.

    Args:
        row: Mapping-like object indexable by 'question' and
            'option_A'..'option_D'.

    Returns:
        The prompt as a single string.
    """
    parts = [f"{row['question']}"]
    parts.extend(f"{letter}) {row['option_' + letter]}" for letter in "ABCD")
    parts.append("Your task:")
    parts.append("1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.")
    parts.append("2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.")
    parts.append("3. No extra information or text outside of these tags.")
    return "\n".join(parts)

def create_dpo_prompt(row):
    """Build the text prompt for one DPO record, led by an ``<image>`` line.

    After the ``<image>`` line come the question, the four lettered options
    and a fixed three-step task description, with no trailing newline.

    Args:
        row: Mapping-like object indexable by 'question' and
            'option_A'..'option_D'.

    Returns:
        The prompt as a single string.
    """
    option_lines = "".join(
        f"{letter}) {row['option_' + letter]}\n" for letter in ("A", "B", "C", "D")
    )
    task_lines = (
        "Your task:\n"
        "1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.\n"
        "2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.\n"
        "3. No extra information or text outside of these tags."
    )
    return "<image>\n" + f"{row['question']}\n" + option_lines + task_lines

def create_generic_reasoning_output(row):
    """Build a templated ``<think>``/``<answer>`` response for a row.

    The letter comes from ``get_ground_truth_letter(row)`` and is interpolated
    as-is, so a row with no matching option produces the text "None".

    Args:
        row: Row passed through to ``get_ground_truth_letter``.

    Returns:
        A two-line string: a one-sentence think block, then the answer tag.
    """
    letter = get_ground_truth_letter(row)
    think_part = f"<think>Based on the image and options, the correct answer is {letter}.</think>"
    answer_part = f"<answer>{letter}</answer>"
    return think_part + "\n" + answer_part

# Main Processing Logic

print("Starting dataset processing...")

for split, csv_path in FILES.items():
    print(f"\n--- Processing split: {split} ---")
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} total samples from {csv_path}")

    # Derive the ground-truth letter and the letter parsed from the stored CoT,
    # then keep only the rows where the two agree.
    df['gt_letter'] = df.apply(get_ground_truth_letter, axis=1)
    df['pred_letter'] = df['medvlm_r1_cot'].apply(extract_answer_from_cot)
    letters_agree = df['gt_letter'] == df['pred_letter']
    correct_df = df[letters_agree].copy()
    print(f"Found {len(correct_df)} samples where MedVLM-R1 was correct.")

    # Sample the configured number of rows, falling back to every agreeing
    # row when there are fewer than requested. The seed is fixed at 42.
    n_samples = FINAL_SAMPLE_SIZES.get(split)
    if n_samples > len(correct_df):
        print(f"Warning: Not enough correct samples for '{split}'. Found {len(correct_df)}, needed {n_samples}. Using all available.")
        n_samples = len(correct_df)
    sampled_df = correct_df.sample(n=n_samples, random_state=42).reset_index(drop=True)
    print(f"Sampled {len(sampled_df)} correct items for the final dataset.")

    # Prefix relative image paths with IMAGE_BASE_DIR; absolute paths pass through.
    sampled_df['image_path'] = [
        img if os.path.isabs(img) else os.path.join(IMAGE_BASE_DIR, img)
        for img in sampled_df['image_path']
    ]

    # SFT: one JSON object per line, with the templated (generic) response as output.
    sft_jsonl_path = os.path.join(SFT_OUTPUT_DIR, f"{split}.jsonl")
    with open(sft_jsonl_path, 'w') as f:
        f.writelines(
            json.dumps({
                "instruction": "You are a medical assistant. Answer based on the image and question.",
                "input": create_sft_prompt(row),
                "output": create_generic_reasoning_output(row),
                "images": [row["image_path"]],
            }) + "\n"
            for _, row in tqdm(sampled_df.iterrows(), total=len(sampled_df), desc=f"Writing SFT {split}")
        )
    print(f"Saved SFT data to {sft_jsonl_path}")

    # DPO: the stored detailed CoT is "chosen", the templated response is "rejected".
    dpo_jsonl_path = os.path.join(DPO_OUTPUT_DIR, f"{split}_dpo.jsonl")
    with open(dpo_jsonl_path, 'w') as f:
        f.writelines(
            json.dumps({
                "prompt": create_dpo_prompt(row),
                "image_path": row["image_path"],
                "chosen": row["medvlm_r1_cot"],
                "rejected": create_generic_reasoning_output(row),
            }) + "\n"
            for _, row in tqdm(sampled_df.iterrows(), total=len(sampled_df), desc=f"Writing DPO {split}")
        )
    print(f"Saved DPO data to {dpo_jsonl_path}")

print("\nAll splits processed successfully!")

Starting dataset processing...

--- Processing split: train ---
Loaded 1200 total samples from /content/drive/MyDrive/OmniMedVQA/CSV_Files/train_1200.csv
Found 1128 samples where MedVLM-R1 was correct.
Sampled 1100 correct items for the final dataset.


Writing SFT train: 100%|██████████| 1100/1100 [00:00<00:00, 14116.58it/s]


Saved SFT data to /content/drive/MyDrive/OmniMedVQA/llamafactory_sft_clean/train.jsonl


Writing DPO train: 100%|██████████| 1100/1100 [00:00<00:00, 12075.15it/s]


Saved DPO data to /content/drive/MyDrive/OmniMedVQA/llamafactory_dpo_clean/train_dpo.jsonl

--- Processing split: test_mri ---
Loaded 200 total samples from /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_mri_sampled_with_cot.csv
Found 191 samples where MedVLM-R1 was correct.
Sampled 150 correct items for the final dataset.


Writing SFT test_mri: 100%|██████████| 150/150 [00:00<00:00, 8785.72it/s]


Saved SFT data to /content/drive/MyDrive/OmniMedVQA/llamafactory_sft_clean/test_mri.jsonl


Writing DPO test_mri: 100%|██████████| 150/150 [00:00<00:00, 11006.75it/s]


Saved DPO data to /content/drive/MyDrive/OmniMedVQA/llamafactory_dpo_clean/test_mri_dpo.jsonl

--- Processing split: test_ct ---
Loaded 200 total samples from /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_ct_sampled_with_cot.csv
Found 157 samples where MedVLM-R1 was correct.
Sampled 150 correct items for the final dataset.


Writing SFT test_ct: 100%|██████████| 150/150 [00:00<00:00, 9761.31it/s]


Saved SFT data to /content/drive/MyDrive/OmniMedVQA/llamafactory_sft_clean/test_ct.jsonl


Writing DPO test_ct: 100%|██████████| 150/150 [00:00<00:00, 12209.78it/s]


Saved DPO data to /content/drive/MyDrive/OmniMedVQA/llamafactory_dpo_clean/test_ct_dpo.jsonl

--- Processing split: test_xray ---
Loaded 200 total samples from /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_xray_sampled_with_cot.csv
Found 157 samples where MedVLM-R1 was correct.
Sampled 150 correct items for the final dataset.


Writing SFT test_xray: 100%|██████████| 150/150 [00:00<00:00, 10563.05it/s]


Saved SFT data to /content/drive/MyDrive/OmniMedVQA/llamafactory_sft_clean/test_xray.jsonl


Writing DPO test_xray: 100%|██████████| 150/150 [00:00<00:00, 10406.67it/s]

Saved DPO data to /content/drive/MyDrive/OmniMedVQA/llamafactory_dpo_clean/test_xray_dpo.jsonl

All splits processed successfully!





In [None]:
import pandas as pd
import json
import os

def get_gt_letter(row):
    """Return 'A'-'D' for the first option whose text matches ``gt_answer``.

    Matching ignores case and surrounding whitespace; values are converted
    with ``str`` first.

    Args:
        row: Mapping-like object indexable by 'option_A'..'option_D' and
            'gt_answer'.

    Returns:
        The matching letter, or None if no option equals the ground truth.
    """
    wanted = str(row['gt_answer']).strip().lower()
    option_columns = ('option_A', 'option_B', 'option_C', 'option_D')
    return next(
        (
            chr(ord('A') + position)
            for position, column in enumerate(option_columns)
            if str(row[column]).strip().lower() == wanted
        ),
        None,
    )

def create_prompt(row):
    """Build the question-plus-options prompt text (no image placeholder).

    Args:
        row: Mapping-like object indexable by 'question' and
            'option_A'..'option_D'.

    Returns:
        The question, the four lettered options and a fixed three-step task
        description, newline-separated with no trailing newline.
    """
    header = f"{row['question']}\n"
    options = (
        f"A) {row['option_A']}\n"
        f"B) {row['option_B']}\n"
        f"C) {row['option_C']}\n"
        f"D) {row['option_D']}\n"
    )
    task = "\n".join([
        "Your task:",
        "1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.",
        "2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.",
        "3. No extra information or text outside of these tags.",
    ])
    return header + options + task

def create_generic_output(gt_letter):
    """Build a templated ``<think>``/``<answer>`` response for a given letter.

    Args:
        gt_letter: Value interpolated into both tags (formatted as-is).

    Returns:
        A two-line string: a one-sentence think block, then the answer tag.
    """
    think_part = f"<think>Based on the image and options, the correct answer is {gt_letter}.</think>"
    answer_part = f"<answer>{gt_letter}</answer>"
    return "\n".join((think_part, answer_part))

def convert_csv_to_sft_jsonl(csv_path, jsonl_path, image_base_dir='/home/vishnu/data'):
    """Convert a question CSV into a JSONL file of SFT records.

    Each row whose ground-truth answer matches one of its options becomes one
    JSON object per line with the keys ``instruction``, ``input``, ``output``
    and ``images``. Rows with no matching option are skipped.

    Args:
        csv_path: Path of the CSV to read. It must provide the columns used by
            ``get_gt_letter`` and ``create_prompt``, plus 'image_path'.
        jsonl_path: Path of the JSONL file to write (overwritten if present).
        image_base_dir: Directory that each row's 'image_path' is joined onto.
            Defaults to the location that was previously hard-coded here, so
            existing two-argument calls behave as before.

    Returns:
        None. The result is the file written at ``jsonl_path``.
    """
    df = pd.read_csv(csv_path)
    # Explicit encoding keeps the output independent of the platform default.
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for _, row in df.iterrows():
            gt_letter = get_gt_letter(row)
            if gt_letter is None:
                continue  # skip if no match
            # Drop leading slashes so os.path.join does not discard image_base_dir.
            rel_path = row['image_path'].lstrip('/')
            image_path = os.path.join(image_base_dir, rel_path)
            entry = {
                "instruction": "You are a medical assistant. Answer based on the image and question.",
                "input": create_prompt(row),
                "output": create_generic_output(gt_letter),
                "images": [image_path]
            }
            f.write(json.dumps(entry) + "\n")

# Source CSV directory and destination directory for the SFT JSONL files.
input_dir = "/content/drive/MyDrive/OmniMedVQA/CSV_Files"
output_dir = "/content/drive/MyDrive/OmniMedVQA/SFT_JSONL_Files"
os.makedirs(output_dir, exist_ok=True)

# Test CSVs to convert, in mri / ct / xray order.
test_files = [f"test_{modality}.csv" for modality in ("mri", "ct", "xray")]

# Write one JSONL per CSV, named after the CSV with the extension swapped.
for fname in test_files:
    csv_path, jsonl_path = (
        os.path.join(input_dir, fname),
        os.path.join(output_dir, fname.replace('.csv', '.jsonl')),
    )
    convert_csv_to_sft_jsonl(csv_path, jsonl_path)
    print(f"Converted {csv_path} to {jsonl_path}")


Converted /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_mri.csv to /content/drive/MyDrive/OmniMedVQA/SFT_JSONL_Files/test_mri.jsonl
Converted /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_ct.csv to /content/drive/MyDrive/OmniMedVQA/SFT_JSONL_Files/test_ct.jsonl
Converted /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_xray.csv to /content/drive/MyDrive/OmniMedVQA/SFT_JSONL_Files/test_xray.jsonl


In [None]:
import pandas as pd
import json
import os

def get_gt_letter(row):
    """Find which option (A-D) carries the ground-truth answer text.

    Values are converted with ``str``, stripped and lower-cased before being
    compared; options are tried in A, B, C, D order.

    Args:
        row: Mapping-like object indexable by 'option_A'..'option_D' and
            'gt_answer'.

    Returns:
        The letter of the first matching option, or None when none match.
    """
    def _normalise(value):
        return str(value).strip().lower()

    expected = _normalise(row['gt_answer'])
    for letter, column in zip("ABCD", ('option_A', 'option_B', 'option_C', 'option_D')):
        if _normalise(row[column]) == expected:
            return letter
    return None

def create_prompt(row):
    """Build the prompt text with a leading ``<image>`` placeholder line.

    Args:
        row: Mapping-like object indexable by 'question' and
            'option_A'..'option_D'.

    Returns:
        ``<image>``, the question, the four lettered options and a fixed
        three-step task description, newline-separated with no trailing
        newline.
    """
    segments = ["<image>", f"{row['question']}"]
    for letter in "ABCD":
        segments.append(f"{letter}) {row['option_' + letter]}")
    segments += [
        "Your task:",
        "1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.",
        "2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.",
        "3. No extra information or text outside of these tags.",
    ]
    return "\n".join(segments)

def convert_csv_to_jsonl(csv_path, jsonl_path, image_base_dir='/home/vishnu/data'):
    """Convert a question CSV into a JSONL file of prompt records.

    Each row whose ground-truth answer matches one of its options becomes one
    JSON object per line with the keys ``prompt``, ``image_path`` and
    ``gt_letter``. Rows with no matching option are skipped.

    Args:
        csv_path: Path of the CSV to read. It must provide the columns used by
            ``get_gt_letter`` and ``create_prompt``, plus 'image_path'.
        jsonl_path: Path of the JSONL file to write (overwritten if present).
        image_base_dir: Directory that each row's 'image_path' is joined onto.
            Defaults to the location that was previously hard-coded here, so
            existing two-argument calls behave as before.

    Returns:
        None. The result is the file written at ``jsonl_path``.
    """
    df = pd.read_csv(csv_path)
    # Explicit encoding keeps the output independent of the platform default.
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for _, row in df.iterrows():
            gt_letter = get_gt_letter(row)
            if gt_letter is None:
                continue  # skip if no match
            # Drop leading slashes so os.path.join does not discard image_base_dir.
            rel_path = row['image_path'].lstrip('/')
            image_path = os.path.join(image_base_dir, rel_path)
            prompt = create_prompt(row)
            entry = {
                "prompt": prompt,
                "image_path": image_path,
                "gt_letter": gt_letter
            }
            f.write(json.dumps(entry) + "\n")

# Source CSV directory and destination directory for the DPO JSONL files.
input_dir = "/content/drive/MyDrive/OmniMedVQA/CSV_Files"
output_dir = "/content/drive/MyDrive/OmniMedVQA/DPO_JSONL_Files"
os.makedirs(output_dir, exist_ok=True)

# Test CSVs to convert, in mri / ct / xray order.
test_files = [f"test_{modality}.csv" for modality in ("mri", "ct", "xray")]

# Write one JSONL per CSV, named after the CSV with the extension swapped.
for fname in test_files:
    csv_path, jsonl_path = (
        os.path.join(input_dir, fname),
        os.path.join(output_dir, fname.replace('.csv', '.jsonl')),
    )
    convert_csv_to_jsonl(csv_path, jsonl_path)
    print(f"Converted {csv_path} to {jsonl_path}")


Converted /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_mri.csv to /content/drive/MyDrive/OmniMedVQA/DPO_JSONL_Files/test_mri.jsonl
Converted /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_ct.csv to /content/drive/MyDrive/OmniMedVQA/DPO_JSONL_Files/test_ct.jsonl
Converted /content/drive/MyDrive/OmniMedVQA/CSV_Files/test_xray.csv to /content/drive/MyDrive/OmniMedVQA/DPO_JSONL_Files/test_xray.jsonl
