# Multimodal Medical Classification: OmniMedVQA Data Exploration

This notebook performs **data exploration** on the OmniMedVQA dataset: loading the data, ensuring schema consistency, and analyzing the distribution of question types and modalities. The notebook is structured to serve as both a **working analysis** and a **report**.

## Environment Setup

### Import Libraries

In [None]:
import os, json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset, Dataset
import torch

### Detect Device

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

## Loading the Dataset

### Collect the JSON Files

We will load the data with `load_dataset` from the Hugging Face `datasets` library, pointing it at the local JSON files. First, collect the paths of those files.

In [None]:
qa_dir = "./data/OmniMedVQA/QA_information/Open-access"
json_files = [os.path.join(qa_dir, f) for f in os.listdir(qa_dir) if f.endswith(".json")]

### Visualize the Schema

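Each QA record is expected to contain the fields `dataset`, `question_id`, `question_type`, `question`, `gt_answer`, `image_path`, `option_*`, and `modality_type`. The cell below prints the columns found in each file and flags any file whose columns differ from the most common schema.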

In [None]:
schema_dict = {}

for f in json_files:
    try:
        df = pd.read_json(f)   # load normally (array of dicts)
        schema_dict[os.path.basename(f)] = set(df.columns)
    except Exception as e:
        print(f"Error reading {f}: {e}")

# Show per-file columns
for fname, cols in schema_dict.items():
    print(f"\n{fname} ({len(cols)} columns):")
    print(sorted(cols))

# Compare schemas against the "reference" (most common set of columns)
from collections import Counter

# Count frequency of each schema
schemas = [tuple(sorted(cols)) for cols in schema_dict.values()]
most_common_schema, _ = Counter(schemas).most_common(1)[0]

print("\nReference schema:", most_common_schema)

for fname, cols in schema_dict.items():
    extra = cols - set(most_common_schema)
    missing = set(most_common_schema) - cols
    if extra or missing:
        print(f"\n⚠️ {fname}")
        if extra:
            print("  Extra columns:", extra)
        if missing:
            print("  Missing columns:", missing)


### Note on Schema Inconsistency

While inspecting the JSON files, we found that **`Chest CT Scan.json` contains a single entry that uses the key `"modality"` instead of `"modality_type"`**.

This would prevent merging all JSON files into a single dataset, because the Hugging Face loader requires consistent column names across files.

Below, we automatically rename `"modality"` to `"modality_type"` wherever it occurs. Note that any affected file is overwritten in place.

In [None]:
# Fix "modality" -> "modality_type" in all JSON files if it exists
import json

for f in json_files:
    with open(f, "r", encoding="utf-8") as file:
        data = json.load(file)
    
    modified = False
    for entry in data:
        if "modality" in entry:
            entry["modality_type"] = entry.pop("modality")
            modified = True
    
    if modified:
        print(f"Fixed schema in {os.path.basename(f)}")
        # Overwrite the original file in place so later cells see the fix
        with open(f, "w", encoding="utf-8") as file:
            json.dump(data, file, indent=2)


### Create a Unified Dataset

At minimum, merge the JSON files in `QA_information/Open-access` into a single dataset.

In [None]:
dataset = load_dataset("json", data_files=json_files, split="train")

## Analyze the Dataset

### Count Samples

- Total number of QA items (length of dataset)
- Number of unique images
- Number of datasets represented
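
The three counts listed above can be sketched as follows. This is a minimal sketch that works on any iterable of dict-like records. The names `summarize_counts` and `toy_records` are hypothetical, the toy values are made up rather than taken from OmniMedVQA, and the code assumes that `image_path` identifies an image uniquely.

```python
def summarize_counts(records):
    """Count QA items, distinct images, and distinct source datasets."""
    records = list(records)
    return {
        "qa_items": len(records),
        # Assumption: two records share an image iff their image_path matches.
        "unique_images": len({r["image_path"] for r in records}),
        "datasets": len({r["dataset"] for r in records}),
    }


# Hypothetical toy records that mimic the schema; not real OmniMedVQA data.
toy_records = [
    {"dataset": "A", "image_path": "A/img_1.png"},
    {"dataset": "A", "image_path": "A/img_1.png"},
    {"dataset": "A", "image_path": "A/img_2.png"},
    {"dataset": "B", "image_path": "B/img_1.png"},
]

summary = summarize_counts(toy_records)
print(summary)
```

In the notebook, `summarize_counts(dataset)` should give the real numbers, since iterating over a `datasets.Dataset` yields one dict per row. For a single column, `len(dataset.unique("image_path"))` is an alternative that avoids walking every row in Python.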

### Class Distribution

- Distribution of `question_type` (e.g., Anatomy Identification, Modality Classification, etc.)
- Distribution of `gt_answer` (which answers dominate, and is there a long tail?)
- Dataset-level distribution (are some datasets overrepresented?)
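
One way to compute the distributions listed above is a frequency table with shares per value. The sketch below uses only the standard library. `value_distribution` and `toy_records` are hypothetical names, and the toy rows are invented for illustration.

```python
from collections import Counter


def value_distribution(records, column, top_k=None):
    """Return (value, count, share) tuples for `column`, most frequent first."""
    counts = Counter(r[column] for r in records)
    total = sum(counts.values())
    return [(value, n, n / total) for value, n in counts.most_common(top_k)]


# Hypothetical toy records that mimic the schema; not real OmniMedVQA data.
toy_records = [
    {"dataset": "A", "question_type": "Modality Classification", "gt_answer": "CT"},
    {"dataset": "A", "question_type": "Modality Classification", "gt_answer": "MRI"},
    {"dataset": "A", "question_type": "Anatomy Identification", "gt_answer": "Lung"},
    {"dataset": "B", "question_type": "Modality Classification", "gt_answer": "CT"},
]

for column in ("question_type", "gt_answer", "dataset"):
    print(f"\n{column}:")
    for value, n, share in value_distribution(toy_records, column, top_k=10):
        print(f"  {value:<25} {n:>5} {share:6.1%}")
```

A long tail would show up as many values that each hold a small share, and an overrepresented source dataset would show up as a single large share in the `dataset` table.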

### Modalities

- Count of unique `modality_type` (CT, MRI, X-ray, Ultrasound, etc.)
- Visualize counts with bar plots or pie charts
- Comment on coverage across modalities (balanced vs skewed)
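
The modality analysis outlined above might look like the sketch below. It counts records per `modality_type`, reports a simple max/min imbalance ratio as one possible way to judge balanced versus skewed coverage, and defines a bar plot helper. All function names and the toy data are hypothetical, and the imbalance ratio is an illustrative choice rather than a measure prescribed by the dataset.

```python
from collections import Counter


def modality_counts(records):
    """Count records per modality_type, most frequent first."""
    return Counter(r["modality_type"] for r in records).most_common()


def imbalance_ratio(counts):
    """Ratio of the largest to the smallest class count (1.0 means balanced)."""
    sizes = [n for _, n in counts]
    return max(sizes) / min(sizes)


def plot_modality_counts(counts):
    """Draw a bar plot of modality counts; matplotlib is imported lazily."""
    import matplotlib.pyplot as plt

    labels = [label for label, _ in counts]
    sizes = [n for _, n in counts]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(labels, sizes)
    ax.set_xlabel("modality_type")
    ax.set_ylabel("QA items")
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    return fig


# Hypothetical toy records; not real OmniMedVQA data.
toy_records = [
    {"modality_type": m}
    for m in ["CT", "CT", "CT", "CT", "MRI", "MRI", "X-ray"]
]

counts = modality_counts(toy_records)
print("Unique modalities:", len(counts))
print("Imbalance ratio (max/min):", imbalance_ratio(counts))
```

In the notebook, calling `plot_modality_counts(modality_counts(dataset))` should render the bar plot for the real data. A ratio close to 1 suggests balanced coverage, while a large ratio indicates that a few modalities dominate.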