In [None]:
import os
import json
import pandas as pd
from PIL import Image, ImageStat
import shutil
import random

# Seed Python's global RNG (the pandas sampling further down passes its own random_state).
random.seed(42)

# Load all JSONs into DataFrame
qa_dir = 'OmniMedVQA/QA_information/Open-access'
all_samples = []
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sorting the
# names keeps the row order of full_df stable, which the seeded .sample() calls that
# build the splits depend on for reproducibility.
for fname in sorted(os.listdir(qa_dir)):
    if fname.endswith('.json'):
        # Read as UTF-8 explicitly instead of relying on the platform default encoding.
        with open(os.path.join(qa_dir, fname), 'r', encoding='utf-8') as f:
            # assumes each file holds a JSON list of per-sample dicts -- TODO confirm
            all_samples.extend(json.load(f))
full_df = pd.DataFrame(all_samples)

# Filter for clear images
def is_clear_image(img_path, min_size=128, min_var=10):
    """Return True if the image at *img_path* is large and varied enough to keep.

    Args:
        img_path: Path of the image file to inspect.
        min_size: Minimum allowed length, in pixels, of the image's shorter side.
        min_var: Minimum allowed per-band pixel variance; the weakest band decides.

    Returns:
        False when the file cannot be opened or analysed, when its shorter side
        is below ``min_size``, or when any (non-alpha) band has a variance below
        ``min_var``; True otherwise.
    """
    try:
        # The context manager releases the underlying file handle; this function
        # is applied to every row of the dataset, so a leak here adds up quickly.
        with Image.open(img_path) as img:
            if min(img.size) < min_size:
                return False
            # A fully opaque alpha band has zero variance and would make every
            # RGBA/LA image look "flat"; judge the colour/luminance bands only.
            if img.mode in ('RGBA', 'LA'):
                img = img.convert(img.mode[:-1])
            stat = ImageStat.Stat(img)
            return min(stat.var) >= min_var
    except Exception:
        # Deliberately best-effort: any unreadable or corrupt file is simply
        # treated as "not clear" so that the filtering pass never aborts.
        return False

# Resolve every relative image path against the dataset root, run the clarity
# check on each resolved path, and keep only the rows that pass it.
dataset_root = 'OmniMedVQA'
abs_paths = full_df['image_path'].apply(
    lambda rel_path: os.path.join(dataset_root, rel_path)
)
full_df['abs_image_path'] = abs_paths
clear_flags = full_df['abs_image_path'].apply(is_clear_image)
full_df['is_clear'] = clear_flags
full_df = full_df[full_df['is_clear']].copy()

# Filter for meaningful VQA pairs
def is_good_vqa(row, min_q=10, min_a=1):
    """Return True if *row* holds a usable question/answer pair.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with the keys
            ``'question'`` and ``'gt_answer'``.
        min_q: Minimum length of the stripped question text.
        min_a: Minimum length of the stripped answer text.

    Returns:
        False when either field is missing (None or NaN), when the stripped
        question or answer is shorter than its minimum, or when the answer is
        one of the non-informative placeholders; True otherwise.
    """
    for key in ('question', 'gt_answer'):
        value = row[key]
        # Missing values would otherwise be stringified to 'nan' / 'None' and
        # could pass as a real answer. NaN is the only float unequal to itself.
        if value is None or (isinstance(value, float) and value != value):
            return False

    q = str(row['question']).strip()
    a = str(row['gt_answer']).strip()
    if len(q) < min_q or len(a) < min_a:
        return False
    return a.lower() not in ('unclear', 'unknown', 'no finding', 'none')

# Keep only the rows whose question/answer pair passes is_good_vqa.
full_df = full_df[full_df.apply(is_good_vqa, axis=1)].copy()

# Filter by modality (MRI, CT, X-ray only; exclude OCT)
# na=False: a row without a modality must yield False rather than NaN, otherwise
# the mask cannot be negated with ~ or used for boolean indexing.
oct_mask = full_df['modality_type'].str.contains('OCT', case=False, na=False)
full_df = full_df[~oct_mask].copy()

modality = full_df['modality_type']
# Match "MR"/"MRI" and "CT" as whole tokens. A plain case-insensitive substring
# search for 'CT' also hits unrelated names such as "Infrared Reflectance Imaging"
# (refle-ct-ance); the word boundary additionally keeps "OCT" out of the CT subset.
is_mr = modality.str.contains(r'\bMRI?\b', case=False, regex=True, na=False)
is_ct = modality.str.contains(r'\bCT\b', case=False, regex=True, na=False)
is_xray = modality.str.contains('X-Ray', case=False, na=False)

# The MRI subset is additionally restricted to rows of the RadImageNet dataset.
mri_df = full_df[is_mr & (full_df['dataset'] == 'RadImageNet')]
ct_df = full_df[is_ct]
xray_df = full_df[is_xray]

# Sample splits according to your plan
def _take(frame, count):
    """Draw up to *count* rows from *frame* using the fixed seed 42."""
    return frame.sample(n=min(count, len(frame)), random_state=42)


def _drop_images(frame, drawn):
    """Return the rows of *frame* whose image_path does not occur in *drawn*."""
    return frame[~frame['image_path'].isin(drawn['image_path'])]


# MRI: train, then val, then test, each drawn from the rows whose images have not
# been used by an earlier MRI subset.
train_mri = _take(mri_df, 2000)
remaining_mri = _drop_images(mri_df, train_mri)
val_mri = _take(remaining_mri, 300)
remaining_mri = _drop_images(remaining_mri, val_mri)
test_mri = _take(remaining_mri, 300)

# CT and X-ray only contribute a test subset each.
test_ct = _take(ct_df, 300)
test_xray = _take(xray_df, 300)

# Save splits as CSVs
# Each subset is written to the working directory without its index column.
for csv_name, split_df in [
    ('train.csv', train_mri),
    ('val.csv', val_mri),
    ('test_mri.csv', test_mri),
    ('test_ct.csv', test_ct),
    ('test_xray.csv', test_xray),
]:
    split_df.to_csv(csv_name, index=False)

# Copy needed images to new folder
output_root = 'OmniMedVQA_Subset_MedVLMR1'

# Unique relative image paths referenced by any of the five subsets.
needed_images = {
    image_path
    for split_df in (train_mri, val_mri, test_mri, test_ct, test_xray)
    for image_path in split_df['image_path']
}

# Mirror each image under output_root, keeping its path relative to the dataset root.
for rel_path in needed_images:
    src = os.path.join(dataset_root, rel_path)
    dst = os.path.join(output_root, rel_path)
    # The destination directory is created before the source is checked.
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    if not os.path.exists(src):
        print(f"Missing: {src}")
        continue
    shutil.copy2(src, dst)

# Verify all images exist
# For every subset, report the image paths that are absent below output_root.
split_frames = {
    'train.csv': train_mri,
    'val.csv': val_mri,
    'test_mri.csv': test_mri,
    'test_ct.csv': test_ct,
    'test_xray.csv': test_xray,
}
for split, df in split_frames.items():
    missing = [
        rel_path
        for rel_path in df['image_path']
        if not os.path.exists(os.path.join(output_root, rel_path))
    ]
    if not missing:
        print(f"{split}: All images present!")
        continue
    print(f"{split}: {len(missing)} missing images")
    print("Examples:", missing[:5])

train.csv: All images present!
val.csv: All images present!
test_mri.csv: All images present!
test_ct.csv: All images present!
test_xray.csv: All images present!


In [9]:
train_mri['modality_type'].value_counts()

modality_type
MR (Mag-netic Resonance Imaging)    2000
Name: count, dtype: int64

In [10]:
val_mri['modality_type'].value_counts()

modality_type
MR (Mag-netic Resonance Imaging)    300
Name: count, dtype: int64

In [11]:
test_mri['modality_type'].value_counts()

modality_type
MR (Mag-netic Resonance Imaging)    300
Name: count, dtype: int64

In [12]:
test_ct['modality_type'].value_counts()

modality_type
CT(Computed Tomography)    300
Name: count, dtype: int64

In [13]:
test_xray['modality_type'].value_counts()

modality_type
X-Ray    300
Name: count, dtype: int64