# OmniMedVQA Data Cleaning for Disease Diagnosis

In [2]:
import pandas as pd

import os, re

from src.data import load_omnimed_dataset
from src.config import OPTION_COLS

  from .autonotebook import tqdm as notebook_tqdm


### Mapping Ground Truth Answers to Options (`gt_label`)

In [3]:
# Recombine all splits
train_df, val_df, test_df = load_omnimed_dataset()
full_df = pd.concat([train_df, val_df, test_df]).reset_index(drop=True)

# Example usage:
row = full_df.iloc[1]
print("Correct option key:", row["gt_label"])
print("Correct answer text:", row[row["gt_label"]])

# Columns relevant for cleaning
cols_to_clean = ['question_id', 'question', 'gt_label', 'option_A', 'option_B', 'option_C', 'option_D']
full_df = full_df[cols_to_clean].copy()

print(full_df.head())
print(f"After mapping: {len(full_df)} samples remain")

Correct option key: option_D
Correct answer text: No Finding
  question_id                                           question  gt_label  \
0  JSIEC_0046  What abnormality is present in this fundus image?  option_C   
1  JSIEC_0047  What abnormality is present in this fundus image?  option_D   
2  JSIEC_0048  Is there any abnormality present in this fundu...  option_A   
3  JSIEC_0049          What is the finding in this fundus image?  option_D   
4  JSIEC_0050  What abnormality is present in this fundus image?  option_B   

                       option_A                    option_B  \
0  Choroidal neovascularization  Central serous retinopathy   
1          Macular degeneration        Diabetic retinopathy   
2                    No Finding          Retinal detachment   
3  Choroidal neovascularization              Optic neuritis   
4                  Macular hole                  No Finding   

                       option_C              option_D  
0                    No Finding  Di

In [4]:
# Function to check for duplicate options in a row
def has_duplicate_options(row, cols):
    opts = [row[col] for col in cols]
    # Drop None/NaN values
    opts = [o for o in opts if pd.notna(o)]
    return len(opts) != len(set(opts))

# Apply to full_df
full_df["duplicate_options"] = full_df.apply(lambda r: has_duplicate_options(r, OPTION_COLS), axis=1)

# How many rows have duplicates?
num_duplicates = full_df["duplicate_options"].sum()
print(f"Number of rows with duplicate options: {num_duplicates}")

Number of rows with duplicate options: 0


### Step 0: Export Answer and Option Distributions

Rerun this cell to see progress. i.e., for part 1 below run all cells and then rerun this one and you should see most labels have been simplified into "No Finding" or "Abnormal (unspecified)" or "Inconclusive".

This should make it easier to keep track of what labels remain to be simplified/modified in the .csv files.

Note: Inconclusive labels only appear in non-ground truth options (so you will see them in all_unique_options.csv only)

In [5]:
# Folder for exploratory outputs
export_dir = os.path.join("data", "exploration_outputs")
os.makedirs(export_dir, exist_ok=True)

# Count of unique gt_answers
answer_counts = full_df.apply(lambda r: r[r["gt_label"]], axis=1).value_counts()
answer_counts_file = os.path.join(export_dir, "gt_answer_counts.csv")
answer_counts.to_csv(answer_counts_file, header=True)
print(f"Exported gt_answer counts to {answer_counts_file}")

# Export all unique options
all_options = pd.concat([full_df[col] for col in OPTION_COLS])

# Count unique options
option_counts = all_options.value_counts()

# Export to CSV
option_counts_file = os.path.join(export_dir, "all_unique_options.csv")
option_counts.to_csv(option_counts_file, header=["count"])
print(f"Exported all unique options to {option_counts_file}")

Exported gt_answer counts to data/exploration_outputs/gt_answer_counts.csv
Exported all unique options to data/exploration_outputs/all_unique_options.csv


In [6]:
print(answer_counts.head(50))  # Top 50 most frequent
print(len(answer_counts))      # Total number of unique answers

Abnormal (unspecified)                                                                                  16083
No Finding                                                                                              12261
Airspace opacity                                                                                         1470
Soft tissue fluid                                                                                        1436
Bone inflammation.                                                                                       1382
Nodule                                                                                                   1128
Soft tissue fluid collection                                                                              993
Liver lesion                                                                                              773
Soft tissue edema                                                                                         750
Renal lesi

In [7]:
# This is to check what kinds of questions are being asked for a given answer
target_answer = None
mask = (
    (full_df['option_A'] == target_answer) | 
    (full_df['option_B'] == target_answer) | 
    (full_df['option_C'] == target_answer) | 
    (full_df['option_D'] == target_answer) 
)

matching_questions = full_df.loc[mask, 'question']
print(matching_questions.value_counts().head(50))

Series([], Name: count, dtype: int64)


In [8]:
# This is to check what the original answers were for a given question
target_question_id = "ACRIMA_0088"

# Filter the row by question_id
row = full_df[full_df['question_id'] == target_question_id].iloc[0]

# Show the original options
print("Original options:")
for col in ["option_A", "option_B", "option_C", "option_D"]:
    print(f"{col}: {row[col]}")

# Show which option was the ground truth
print("\nGround truth option key:", row["gt_label"])
print("Ground truth answer text:", row[row["gt_label"]])


Original options:
option_A: Glaucoma positive.
option_B: No Finding
option_C: Conjunctivitis positive.
option_D: None

Ground truth option key: option_A
Ground truth answer text: Glaucoma positive.


In [9]:
rare_answers = answer_counts[answer_counts <= 50]
print(rare_answers)

Soft tissue collection          50
Nevus                           43
Squamous Cell Carcinoma.        39
Malignant dermal.               39
Bil dil.                        38
                                ..
Fibrosis                         2
Preretinal hemorrhage            2
VKH disease                      2
Bietti crystalline dystrophy     1
Fundus neoplasm                  1
Name: count, Length: 63, dtype: int64


### Part 1: Converting "No" answers to No Finding and "Yes" answers to specific labels

In [10]:
# Helper function: Strip quotes, lowercase, remove trailing periods/spaces
def normalize_text(s):
    return str(s).strip().lower().strip('"').rstrip(".")

# Helper function: Apply a single-label mapping
def apply_single_label_mapping(df, option_cols, mapping_csv_path, unified_label):
    mapping_df = pd.read_csv(mapping_csv_path)
    mapping_dict = {normalize_text(raw): unified_label for raw in mapping_df["raw_label"]}

    for col in option_cols:
        df[col] = df[col].apply(lambda x: mapping_dict.get(normalize_text(x), x))
    return df

#### Converting Negative Answers to No Finding

In [11]:
# # Apply No Finding mapping
# full_df = apply_single_label_mapping(
#     full_df,
#     option_cols,
#     "data/label_mappings/no_finding_map.csv",
#     "No Finding"
# )

# Check replacements
for col in OPTION_COLS:
    no_finding_count = (full_df[col] == "No Finding").sum()
    print(f"{col}: {no_finding_count} entries replaced with 'No Finding'")

option_A: 5170 entries replaced with 'No Finding'
option_B: 6282 entries replaced with 'No Finding'
option_C: 1925 entries replaced with 'No Finding'
option_D: 892 entries replaced with 'No Finding'


#### Removing Inconclusive/Uncertain Labels

Only if the ground truth answer is inconclusive/uncertain.

In [12]:
# # Apply Inconclusive mapping
# full_df = apply_single_label_mapping(
#     full_df,
#     OPTION_COLS,
#     "data/label_mappings/inconclusive_map.csv",
#     "Inconclusive"
# )

# Check replacements
for col in OPTION_COLS:
    inconclusive_count = (full_df[col] == "Inconclusive").sum()
    print(f"{col}: {inconclusive_count} entries replaced with 'Inconclusive'")

option_A: 2939 entries replaced with 'Inconclusive'
option_B: 1143 entries replaced with 'Inconclusive'
option_C: 418 entries replaced with 'Inconclusive'
option_D: 52 entries replaced with 'Inconclusive'


#### Converting Generic Yes/Abnormal Answers to Abnormal (unspecified)

In [13]:
# # Apply Abnormal (unspecified) mapping
# full_df = apply_single_label_mapping(
#     full_df,
#     OPTION_COLS,
#     "data/label_mappings/yes_finding_map.csv",
#     "Abnormal (unspecified)"
# )

# Check replacements
for col in OPTION_COLS:
    abnormal_count = (full_df[col] == "Abnormal (unspecified)").sum()
    print(f"{col}: {abnormal_count} entries replaced with 'Abnormal (unspecified)'")

option_A: 7149 entries replaced with 'Abnormal (unspecified)'
option_B: 7090 entries replaced with 'Abnormal (unspecified)'
option_C: 4833 entries replaced with 'Abnormal (unspecified)'
option_D: 3839 entries replaced with 'Abnormal (unspecified)'


#### Checking for Duplicates in Answer Choices

In [14]:
# Apply to full_df
full_df["duplicate_options"] = full_df.apply(lambda r: has_duplicate_options(r, OPTION_COLS), axis=1)

# How many rows have duplicates?
num_duplicates = full_df["duplicate_options"].sum()
print(f"Number of rows with duplicate options: {num_duplicates}")


Number of rows with duplicate options: 0


In [15]:
def reduce_duplicates(row, option_cols):
    # Save GT text
    gt_key = row["gt_label"]
    gt_text = row[gt_key]

    # Collect original options
    opts = [row[col] for col in option_cols]

    # Deduplicate while preserving order
    seen = {}
    for i, opt in enumerate(opts):
        if opt not in seen:
            seen[opt] = i
    unique_opts = list(seen.keys())

    # Ensure GT text is in the unique options
    if gt_text not in unique_opts:
        unique_opts.append(gt_text)

    # Pad/truncate back to fixed length
    while len(unique_opts) < len(option_cols):
        unique_opts.append(None)
    unique_opts = unique_opts[:len(option_cols)]

    # Assign back options
    for col, val in zip(option_cols, unique_opts):
        row[col] = val

    # Recompute gt_label so it matches the correct column
    for col in option_cols:
        if row[col] == gt_text:
            row["gt_label"] = col
            break

    return row

full_df = full_df.apply(reduce_duplicates, option_cols=OPTION_COLS, axis=1)

# Verify no invalid GT
def is_gt_valid(row):
    gt_col = row["gt_label"]
    return (row[gt_col] is not None)

num_invalid = (~full_df.apply(is_gt_valid, axis=1)).sum()
print("Number of rows with invalid ground truth:", num_invalid)

Number of rows with invalid ground truth: 0


### Part 2: Removing punctuation/grammar/useless words in answer choices

In [16]:
# --- Full grammar_trim with simplified features ---
def grammar_trim(s):
    if not isinstance(s, str):
        return s
    
    s = s.strip()

    # --- Special cases for abnormalities ---
    if re.search(r"significant abnormality", s, flags=re.I):
        return "Abnormality"
    if re.search(r"minor abnormality|insignificant", s, flags=re.I):
        return "Slight abnormality"
    if re.search(r"possible abnormality|potential abnormality", s, flags=re.I):
        return "Possible abnormality"
    if re.search(r"multiple abnormalities|abnormalities are seen", s, flags=re.I):
        return "Abnormalities present"
    # --- Preserve 'Inconclusive' as-is ---
    if re.search(r"inconclusive", s, flags=re.I):
        return "Inconclusive"
    if re.search(r"inconclusive|does not provide enough information|too low quality|too blurry|cannot determine", s, flags=re.I):
        return "Can't tell abnormality"

    # --- Remove redundant phrases ---
    redundant_phrases = [
        r"^no,? the image .*",           
        r"^yes,? the image .*",          
        r"lungs will be affected",       
        r"\bthe abnormality shown in this image is\b",
        r"\bthis image indicates\b",
        r"\bthis image shows\b",
        r"\bthis image displays\b",
        r"\bthe diagnosis is\b",
        r"\bthe findings in this image are\b",
        r"\bthere is a\b",
        r"\bthere are\b",
    ]
    for pat in redundant_phrases:
        s = re.sub(pat, "", s, flags=re.I).strip()

    # --- Remove trailing "in the provided image" ---
    s = re.sub(r"\s*,?\s*in the provided image\.?$", "", s, flags=re.I)

    # --- Head patterns ---
    head_patterns = [
        r"^the head appears (.+)",
        r"^the head is (.+)",
    ]
    for pat in head_patterns:
        match = re.match(pat, s, flags=re.I)
        if match:
            s = match.group(1).strip()
            break

    # --- Tail patterns ---
    tail_patterns = [
        r"^yes, the tail appears to be (.+)",
        r"^no, the tail appears to be (.+)",
    ]
    for pat in tail_patterns:
        match = re.match(pat, s, flags=re.I)
        if match:
            s = match.group(1).strip()
            break

    # --- Vacuole patterns ---
    vacuole_patterns = [
        r"^yes, the vacuole (is|appears to be) (.+)",
        r"^no, the vacuole (is|shows signs of|appears to be) (.+)"
    ]
    for pat in vacuole_patterns:
        match = re.match(pat, s, flags=re.I)
        if match:
            s = match.group(2).strip()
            break

    # --- Lungs patterns ---
    lungs_patterns = [
        r"^(.+?)\.?\s*lungs will be affected\.?",  
        r"^the lungs (in the image )?(are|appear|show|showing) (.+)", 
        r"^the image (shows|depicts|displays) lungs? (.+)"
    ]
    for pat in lungs_patterns:
        match = re.match(pat, s, flags=re.I)
        if match:
            desc = match.groups()[-1].strip()
            # Special cases for explicit lung disease
            if re.search(r"fibrotic|stiff", desc, flags=re.I):
                s = "Lung disease"
            else:
                s = re.sub(
                    r"^(signs of |showing |show |appears to be |appears |in the image )",
                    "",
                    desc,
                    flags=re.I
                ).strip()
            break

    # --- Retina / Fundus image filters ---
    retina_patterns = [
        r"^The presence of multiple abnormalities is evident in this image\.$",
        r"^In this image, there are no apparent abnormalities\.?.*$",
        r"^The abnormalities in this image are consistent with (.+)\.$",
        r"^There is a conspicuous abnormality in (.+?) in this image\.$",
        r"^This image displays (.+)$",
        r"^There is a clear indication of (.+?) in this image\.$",
        r"^The image shows (.+)$",
        r"^This image shows (.+)$",
        r"^The anomalies visible in this image suggest (.+)$",
        r"^The anomalies in this image indicate the presence of (.+)$",
        r"^The abnormality in this image is consistent with (.+)$",
    ]
    for pat in retina_patterns:
        match = re.match(pat, s, flags=re.I)
        if match:
            if "multiple abnormalities" in s.lower():
                s = "Abnormalities present"
            elif "no apparent abnormalities" in s.lower():
                s = "No Abnormalities"
            else:
                s = match.group(1).strip()
            break

    # --- Acronyms ---
    s = re.sub(
        r"\((?:proliferative diabetic retinopathy|age-related macular degeneration)\)",
        lambda m: m.group(0)[1:-1],
        s,
        flags=re.I
    )

    # --- Remove trailing generic image phrases ---
    s = re.sub(
        r"\s*(,?\s*(in|on)( the)? (provided )?image)\s*\.?$",
        "",
        s,
        flags=re.I
    )

    # --- General cleanup ---
    s = re.sub(r"\s+", " ", s).strip()
    s = s.rstrip(".")
    if s:
        s = s[0].upper() + s[1:]
    return s


# --- Apply to a copy of full_df ---
option_cols = ["option_A", "option_B", "option_C", "option_D"]
trimmed_df = full_df.copy()  # <-- new variable

# Keep a list of mappings for export
mappings = []

for col in option_cols:
    def apply_trim(x):
        if isinstance(x, str):
            cleaned = grammar_trim(x)
            if cleaned != x:
                mappings.append((x, cleaned))
            return cleaned
        return x
    trimmed_df[col] = trimmed_df[col].apply(apply_trim)

# Export mapping CSV
mapping_df = pd.DataFrame(mappings, columns=["raw_label", "clean_label"])
mapping_df.to_csv("data/label_mappings/grammar_trim.csv", index=False)
print(f"Exported grammar_trim mapping for {len(mapping_df)} changed rows to data/label_mappings/grammar_trim.csv")

Exported grammar_trim mapping for 19698 changed rows to data/label_mappings/grammar_trim.csv


### Part 3 Remove rare or problematic labels / options, combine similar labels together, decide threshold for what labels show up too little once all other cleaning is done

In [19]:
trimmed_df["gt_answer"] = trimmed_df.apply(lambda r: r[r["gt_label"]], axis=1)
trimmed_df = trimmed_df.drop(columns=["duplicate_options"])
trimmed_df.head()

Unnamed: 0,question_id,question,gt_label,option_A,option_B,option_C,option_D,gt_answer
0,JSIEC_0046,What abnormality is present in this fundus image?,option_C,Choroidal neovascularization,Central serous retinopathy,No Finding,Diabetic retinopathy,No Finding
1,JSIEC_0047,What abnormality is present in this fundus image?,option_D,Macular degeneration,Diabetic retinopathy,Cataracts,No Finding,No Finding
2,JSIEC_0048,Is there any abnormality present in this fundu...,option_A,No Finding,Retinal detachment,Macular degeneration,Diabetic retinopathy,No Finding
3,JSIEC_0049,What is the finding in this fundus image?,option_D,Choroidal neovascularization,Optic neuritis,Glaucoma,No Finding,No Finding
4,JSIEC_0050,What abnormality is present in this fundus image?,option_B,Macular hole,No Finding,Choroidal neovascularization,Diabetic retinopathy,No Finding


Remove rows where the gt_answer occurrence count is less than 50

In [26]:
#Remove rows where the gt_answer value count is less than 50
value_counts = trimmed_df["gt_answer"].value_counts()
valid_answers = value_counts[value_counts >= 50].index
refined_df = trimmed_df[trimmed_df["gt_answer"].isin(valid_answers)].reset_index(drop=True)
print(f"After removing rare answers: {len(refined_df)} samples remain")

After removing rare answers: 54711 samples remain


In [28]:
final_df = refined_df.copy()