
Control-xl/Medical-Vision-Langauge-Transformer


Medical Vision-Language Transformer (MVLT)

Requirements

torch >= 1.11.0

torchvision >= 0.12.0 (Vision Transformer support from PyTorch)

transformers >= 4.16.0

Prepare for pre-training

To pre-train MVLT with Swin Transformer as the visual backbone, prepare three datasets: RGC, ROCO, and MedICaT. Download the train/test split of RGC from openI and put it in ./dataset/RGC/, so that the split file is located at ./dataset/RGC/RGC_dataset.json.
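Before running the pre-processing and pre-training scripts, it can help to confirm that every input is where the scripts expect it. The sketch below is a hypothetical helper (missing_paths is not part of this repository) that checks the paths named in this README's pre-training steps, including the Swin-S checkpoint described further down. The internal layout of the ROCO and MedICaT folders is not documented here, so those are only checked for existence.

```python
from pathlib import Path

# Paths taken from this README. The ROCO and MedICaT folders are only checked
# for existence because their internal layout is not specified here.
EXPECTED_PATHS = [
    Path("dataset/RGC/RGC_dataset.json"),
    Path("dataset/ROCO"),
    Path("dataset/medicat"),
    Path("checkpoints/swin_small_patch4_window7_224.pth"),
]


def missing_paths(root="."):
    """Return the expected pre-training inputs that do not exist under root."""
    root = Path(root)
    return [str(p) for p in EXPECTED_PATHS if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("All pre-training inputs found.")
```

Run it from the repository root; an empty result only means the files exist, not that their contents are valid.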

RGC

The RGC dataset is hosted on the MedPix website. Due to copyright restrictions, we cannot release the dataset directly. If you need help obtaining it, please email li-control.xu@connect.polyu.hk.

Run:

python preprocess_rgc.py

This pre-processes the RGC dataset and saves the data in .pkl format.
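The README does not say where preprocess_rgc.py writes its .pkl files or what objects they contain, so any path you pass below is a placeholder. As a hedged sketch, the hypothetical helpers load_pickle and describe (not part of this repository) load one pickle with the standard library and print a rough summary of its top-level structure. Only unpickle files you produced yourself, because unpickling can execute arbitrary code.

```python
import pickle
from pathlib import Path


def load_pickle(path):
    """Load one pickled object (for example, a pre-processed RGC split)."""
    with Path(path).open("rb") as f:
        return pickle.load(f)


def describe(obj):
    """Return a short summary of the top-level structure of a loaded object."""
    if isinstance(obj, dict):
        return f"dict with {len(obj)} keys: {list(obj)[:5]}"
    if isinstance(obj, (list, tuple)):
        return f"{type(obj).__name__} of {len(obj)} items"
    return type(obj).__name__


# Usage (replace the placeholder with an actual output file):
# print(describe(load_pickle("path/to/preprocessed_rgc.pkl")))
```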

ROCO

Download ROCO following the ROCO instructions and put the files in ./dataset/ROCO/.

MedICaT

Download MedICaT following the MedICaT instructions and put the files in ./dataset/medicat/.

Preparing Swin Transformer

Download the Swin-S checkpoint (swin_small_patch4_window7_224.pth) from Swin Transformer and put it in ./checkpoints/.
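To confirm that the downloaded checkpoint loads, you can unwrap its parameter dictionary. This is a minimal sketch assuming, as the official Swin Transformer loader does, that the weights are stored under a "model" key; plain state dicts are returned unchanged. The helper name extract_state_dict is hypothetical and is not part of this repository. Loading the file requires PyTorch, so that part is shown as a comment.

```python
def extract_state_dict(checkpoint):
    """Return the parameter dict from a loaded checkpoint object.

    Assumption: official Swin Transformer checkpoints wrap their weights
    under a "model" key; anything else is returned as-is.
    """
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("model"), dict):
        return checkpoint["model"]
    return checkpoint


# Usage (requires PyTorch):
# import torch
# ckpt = torch.load("./checkpoints/swin_small_patch4_window7_224.pth", map_location="cpu")
# state = extract_state_dict(ckpt)
# print(len(state), "entries; first keys:", list(state)[:3])
```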

Pre-train the model:

python run_pretrain_rgc_roco_medicat.py --conv swintransformer --batch 32 --max_length 80 --save_model_name swin-rgc-roco-medicat --epochs 150 --save_freq 50

The pre-trained model will be saved in ./checkpoints/swin-rgc-roco-medicat/

The pre-trained model (Swin-S + [RGC+ROCO+MedICaT]) is available on Google Drive.

You can also use ResNet-101 as the visual backbone by passing --conv resnet101.
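If you switch between the two backbones often, building the command as an argument list avoids shell-quoting mistakes. The sketch below is a hypothetical wrapper (pretrain_argv is not part of this repository) that reproduces the pre-training command from this README and only changes --conv and the output name. The "resnet101-rgc-roco-medicat" name is an assumed naming scheme, and the backbone check only allows the two values this README mentions.

```python
import shlex


def pretrain_argv(backbone="swintransformer", save_model_name=None):
    """Build the pre-training command from this README as an argument list."""
    if backbone not in ("swintransformer", "resnet101"):
        raise ValueError(f"unsupported backbone: {backbone}")
    if save_model_name is None:
        # Hypothetical naming scheme: the README only shows swin-rgc-roco-medicat.
        prefix = "swin" if backbone == "swintransformer" else backbone
        save_model_name = f"{prefix}-rgc-roco-medicat"
    return [
        "python", "run_pretrain_rgc_roco_medicat.py",
        "--conv", backbone,
        "--batch", "32",
        "--max_length", "80",
        "--save_model_name", save_model_name,
        "--epochs", "150",
        "--save_freq", "50",
    ]


if __name__ == "__main__":
    cmd = pretrain_argv("resnet101")
    print(shlex.join(cmd))
    # To launch it: import subprocess; subprocess.run(cmd, check=True)
```

With the default arguments, shlex.join(pretrain_argv()) yields exactly the Swin-S command shown above.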

Fine-tuning

Medical Visual Question Answering (Med-VQA)

SLAKE

Download SLAKE from Google Drive.

Put the files in ./dataset/SLAKE/ and preprocess SLAKE:

python preprocess_VQA.py --dataset SLAKE

Fine-tuning on SLAKE

python run_vqa.py --batch 64 --conv swintransformer --pretrained --pretrained_path ./checkpoints/swin-rgc-roco-medicat --dataset SLAKE --epochs 100 --total_round 10 --lr 2e-5

This runs 10 rounds with different torch seeds. To run a single round instead, use --total_round 1.
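When several rounds are run with different seeds, a common way to report them is the mean and standard deviation across rounds. This README does not describe how run_vqa.py logs per-round accuracy, so the sketch below assumes you have collected one score per round into a list yourself; summarize_rounds is a hypothetical helper, not part of this repository.

```python
import statistics


def summarize_rounds(scores):
    """Return (mean, sample standard deviation) of per-round scores.

    `scores` holds one value per --total_round run, collected by hand.
    """
    if not scores:
        raise ValueError("need at least one score")
    mean = statistics.fmean(scores)
    # stdev needs at least two values; a single round has no spread.
    std = statistics.stdev(scores) if len(scores) > 1 else 0.0
    return mean, std
```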

VQA-RAD

Download VQA-RAD from Google Drive.

Put the files in ./dataset/VQA-RAD/ and preprocess VQA-RAD:

python preprocess_VQA.py --dataset VQA-RAD

Fine-tuning on VQA-RAD

python run_vqa.py --batch 64 --conv swintransformer --pretrained --pretrained_path ./checkpoints/swin-rgc-roco-medicat --dataset VQA-RAD --epochs 100 --total_round 10 --lr 2e-5

Note that this learning rate may not be optimal on every platform. On Windows 11 with PyTorch 2.0, for example, the learning rate can be set to 3e-5.
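The note above can be encoded as a small default when scripting runs across machines. This is a hedged sketch: the README only says 3e-5 "can be set" on Windows 11 with PyTorch 2.0, so treating it as a rule is an assumption, and suggested_lr is a hypothetical helper. platform.system() reports "Windows" on both Windows 10 and 11, so the check cannot distinguish them.

```python
import platform


def suggested_lr(system=None, torch_version=None):
    """Pick a fine-tuning learning rate following the README note.

    Returns 3e-5 on Windows with PyTorch 2.0.x and 2e-5 (the value used in
    the fine-tuning commands) otherwise.
    """
    if system is None:
        system = platform.system()
    if torch_version is None:
        try:
            import torch  # optional: only needed to auto-detect the version
            torch_version = torch.__version__
        except ImportError:
            torch_version = ""
    # Strip local build tags such as "+cu118" and compare major.minor.
    major_minor = str(torch_version).split("+")[0].split(".")[:2]
    if system == "Windows" and major_minor == ["2", "0"]:
        return 3e-5
    return 2e-5


if __name__ == "__main__":
    print(f"--lr {suggested_lr():g}")
```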

Report Generation on MIMIC-CXR and IU X-Ray

Download MIMIC-CXR and IU X-Ray from R2Gen.

Put MIMIC-CXR in ./dataset/mimic_cxr/ and IU X-Ray in ./dataset/iu_xray/.
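A quick sanity check after downloading is to count the examples in each split. This README does not document the file layout, so the sketch assumes each dataset folder contains an annotation.json whose top-level keys are split names mapped to lists of examples (the layout used by the R2Gen release). split_sizes is a hypothetical helper, not part of this repository.

```python
import json
from pathlib import Path


def split_sizes(dataset_dir):
    """Count entries per split in <dataset_dir>/annotation.json.

    Assumes a JSON object such as {"train": [...], "val": [...], "test": [...]}.
    """
    path = Path(dataset_dir) / "annotation.json"
    with path.open(encoding="utf-8") as f:
        annotation = json.load(f)
    return {split: len(items) for split, items in annotation.items()}


# Usage:
# for d in ("./dataset/mimic_cxr", "./dataset/iu_xray"):
#     print(d, split_sizes(d))
```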

python run_report_generation_cxr.py --batch 32 --conv swintransformer --pretrained --pretrained_path ./checkpoints/swin-rgc-roco-medicat --dataset mimic --test_freq 5

Report Generation on RGC

python run_report_generation.py --batch 32 --conv swintransformer --pretrained --pretrained_path ./checkpoints/swin-rgc-roco-medicat --test_freq 5 --beam_search
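The --beam_search flag switches report decoding to beam search. To illustrate what that technique does in general, here is a textbook beam search over a toy next-token table. It is not the decoder inside run_report_generation.py, whose beam size, length handling, and scoring are not documented here; the function and the made-up probabilities are illustrative assumptions. With beam_size=1 it reduces to greedy decoding, which commits to the locally best first token "a", while a wider beam finds the higher-probability sequence starting with "b".

```python
import math


def beam_search(next_log_probs, bos, eos, beam_size=3, max_len=10):
    """Textbook beam search over a next-token scoring function.

    next_log_probs(prefix) returns {token: log_prob} for the token that
    follows the tuple `prefix`. Returns (best_sequence, total_log_prob).
    """
    beams = [((bos,), 0.0)]  # live hypotheses: (prefix, summed log-probability)
    finished = []            # hypotheses that have emitted eos
    for _ in range(max_len):
        candidates = [
            (prefix + (tok,), score + lp)
            for prefix, score in beams
            for tok, lp in next_log_probs(prefix).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates:
            if prefix[-1] == eos:
                finished.append((prefix, score))
            else:
                beams.append((prefix, score))
                if len(beams) == beam_size:
                    break  # keep only the beam_size best live hypotheses
        if not beams:
            break
    # Fall back to unfinished hypotheses if nothing reached eos within max_len.
    return max(finished or beams, key=lambda c: c[1])


# Toy next-token table with made-up probabilities, for illustration only.
TABLE = {
    ("<s>",): {"a": math.log(0.6), "b": math.log(0.4)},
    ("<s>", "a"): {"</s>": math.log(0.5), "c": math.log(0.5)},
    ("<s>", "b"): {"</s>": math.log(0.9), "c": math.log(0.1)},
}


def toy(prefix):
    # Unknown prefixes end the sequence with probability 1.
    return TABLE.get(prefix, {"</s>": 0.0})


if __name__ == "__main__":
    print(beam_search(toy, "<s>", "</s>", beam_size=1))
    print(beam_search(toy, "<s>", "</s>", beam_size=3))
```

Because scores are summed log-probabilities without length normalization, this simple version favors shorter sequences; practical decoders often add a length penalty.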

Image-Text Retrieval on RGC

python run_retrieval.py --batch 32 --conv swintransformer --pretrained --pretrained_path ./checkpoints/swin-rgc-roco-medicat --do_train --do_test --do_rank --epochs 100 --lr 1e-6
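Image-text retrieval is commonly evaluated with Recall@K: the fraction of queries whose correct match appears among the K highest-scoring candidates. This README does not state which metrics --do_rank reports, so the sketch below is a generic implementation rather than this repository's evaluation code. It assumes a square similarity matrix for paired data, where the correct candidate for query i is candidate i.

```python
def recall_at_k(similarity, k):
    """Fraction of queries whose matching candidate ranks in the top k.

    similarity[i][j] is the score between query i and candidate j; the
    correct candidate for query i is assumed to be candidate i.
    """
    hits = 0
    for i, row in enumerate(similarity):
        # Rank candidate indices by descending score and keep the first k.
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        if i in top_k:
            hits += 1
    return hits / len(similarity)
```

For the other retrieval direction (for example, text-to-image instead of image-to-text), pass the transposed matrix, `[list(col) for col in zip(*similarity)]`.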

You can also train the models on the downstream tasks from scratch by omitting the --pretrained argument.
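If you keep the fine-tuning commands in scripts, a from-scratch variant can be derived mechanically. The hypothetical helper below (from_scratch is not part of this repository) drops only the bare --pretrained switch, as described above, and leaves --pretrained_path and its value untouched. Whether the training scripts then ignore --pretrained_path is not stated in this README.

```python
import shlex


def from_scratch(command):
    """Return the same command line with the bare --pretrained flag removed."""
    argv = shlex.split(command)
    # Exact match only, so "--pretrained_path" and its value are kept.
    return shlex.join([arg for arg in argv if arg != "--pretrained"])
```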
