
LVLM-Aware Multimodal Retrieval for RAG-Based Medical Diagnosis with General-Purpose Models

Project Overview

Clinical decision-making often involves interpreting medical images (e.g., radiology) to make diagnoses. Retrieving relevant visual information from medical literature and hospital records can enhance diagnostic accuracy.

CLARE (Clinical LVLM-Aware Retrieval) is a lightweight LVLM-aware multimodal retriever that learns to return images and texts that guide the LVLM toward correct predictions.

Figure: CLARE architecture and training pipeline.

Environment Setup

Main Environment

git clone <Add here>
cd CLARE
export PYTHONPATH=./
python3.9 -m venv clare_env
source clare_env/bin/activate
pip install -r requirements.txt

Qwen2-VL Environment

Running the qwen2_vl model requires a separate environment:

git clone <Add here>
cd CLARE
export PYTHONPATH=./
python3.9 -m venv clare_qwen_env
source clare_qwen_env/bin/activate
pip install -r requirements_qwen.txt

Index Initialization

Convert Datasets to JSONL Format

Download the required datasets (MIMIC-CXR, PMC-OA, and ROCO; see Dataset Downloads below), then organize each one into JSONL format:

MIMIC-CXR

python preporcess/mimic_cxr_processor.py --path_folder_location /path/to/mimic/dataset --path_save /path/to/save/output

PMC-OA

python preporcess/prepare_pmc_oa_for_embeddings.py --home_pmc_oa_project /path/to/project/PYCHARMPROJECTS/PMC_OA/ --home_pmc_oa_images /path/to/images/ --path_save /path/to/save/output

ROCO

python preporcess/organize_roco_for_embedding.py --home_pmc_oa_project /path/to/project/PYCHARMPROJECTS/PMC_OA/ --home_pmc_oa_images /path/to/images/ --path_save /path/to/save/output

Merge Datasets

Merge the three JSONL files into a single file:

python preporcess/merge_jsonl.py --mimic_file path/to/mimic-cxr.jsonl --roco_file path/to/roco.jsonl --pmc_file path/to/pmc-oa.jsonl --output_file merged_datasets.jsonl
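
A quick sanity check on the merged file (illustrative; assumes merged_datasets.jsonl is in the current directory):

# Count records and pretty-print the first line to confirm the JSONL is well formed
wc -l merged_datasets.jsonl
head -n 1 merged_datasets.jsonl | python -m json.tool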

Create Embeddings Index

The scripts vqarad_embeddings_pmc_encode_img.sh and vqarad_embeddings_pmc_encode_text.sh encode the image and text embeddings, respectively. Before running them, update the --passages and --save_index_path arguments in each script (see the note just below), then build each index:
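
For illustration only, --passages should point at the merged JSONL and --save_index_path at the directory where the index will be written; the exact paths are yours, not values from the repository. The command below simply locates the two arguments in both scripts:

# Show where --passages and --save_index_path are set in each embedding script
grep -n -e '--passages' -e '--save_index_path' run_scripts/create_embeddings/vqarad_embeddings_pmc_encode_img.sh run_scripts/create_embeddings/vqarad_embeddings_pmc_encode_text.sh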

Image Embeddings

source run_scripts/create_embeddings/vqarad_embeddings_pmc_encode_img.sh

Text Embeddings

source run_scripts/create_embeddings/vqarad_embeddings_pmc_encode_text.sh

Training the Reader

The reader is trained on the LlamaFactory platform. Scripts for preparing the data and running the training are provided in lamafactory_scripts/README.md.
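
As a rough sketch (not a command from this repository; the config path is a placeholder for whatever the steps in lamafactory_scripts/README.md produce), LLaMA-Factory training is typically launched through its CLI:

# Placeholder config; use the training config prepared per lamafactory_scripts/README.md
llamafactory-cli train /path/to/reader_sft_config.yaml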

Training the Retriever

Training scripts for each benchmark are available in run_scripts/training_scripts/.

Dataset Downloads

MedMNIST Datasets

For breast, retina, and derma datasets:
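
A minimal sketch for downloading these through the medmnist Python package; the install step and the split loop are illustrative, while the class names are the standard MedMNIST ones:

pip install medmnist
# Download BreastMNIST, RetinaMNIST, and DermaMNIST (all splits) to the default MedMNIST root
python -c "from medmnist import BreastMNIST, RetinaMNIST, DermaMNIST; [cls(split=s, download=True) for cls in (BreastMNIST, RetinaMNIST, DermaMNIST) for s in ('train', 'val', 'test')]"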

PhysioNet Datasets (Approval Required)

VQA Datasets (HuggingFace)
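
A hedged example of pulling a dataset from the HuggingFace Hub with the hub CLI; the repository ID below is a placeholder, so substitute the VQA dataset (e.g. VQA-RAD) you actually need:

pip install -U "huggingface_hub[cli]"
# <dataset_repo_id> is a placeholder for the HuggingFace dataset repository to download
huggingface-cli download <dataset_repo_id> --repo-type dataset --local-dir data/vqa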

After downloading the data, follow the instructions in data/README.md (Coming Soon).

Training Steps

Train the multimodal retriever in two stages:

  1. Text retriever head
  2. Image retriever head

For each script, update:

  • Reader checkpoint path
  • Index paths (text and image)
  • Checkpoint directory

Text Retriever Head Training

source run_scripts/training_retriever_text/<name_benchmark>/<chosen_script>

Image Retriever Head Training

source run_scripts/training_retriever_image/<name_benchmark>/<chosen_script>

Evaluation

Calculate Metrics

Use calculate_metrics.py to compute evaluation metrics from prediction files:

python calculate_metrics.py \
  --prediction_file /path/to/predictions.jsonl \
  --classes "0" "1" "2" "3" "4"

Arguments

  • --prediction_file: Path to JSONL file containing model predictions (required)
  • --classes: List of class labels for classification tasks (required)
