
NeaseZ/MARL

MARL

Introduction

This is the official PyTorch implementation of the paper "Deep reinforcement learning framework for thoracic diseases classification via prior knowledge guidance".

Abstract

In this paper, we focus on the diagnosis of thoracic diseases and propose a novel deep reinforcement learning framework that introduces prior knowledge to guide the learning of diagnostic agents. Much like a person's learning process, the model parameters can be updated continuously as more data becomes available. We evaluated our approach on the well-known NIH ChestX-ray14 and CheXpert datasets and achieved competitive results.

Quick start

  1. Clone this repo:
git clone 
cd MARL
  2. Install CUDA, PyTorch, and torchvision.

Make sure their versions are compatible with one another. We tested our models with:

cuda==11, torch==1.9.0, torchvision==0.10.0, python==3.7.3
  3. Prepare the data.

Download the NIH ChestX-ray14 and CheXpert datasets.
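In the NIH ChestX-ray14 release, labels come in a CSV file (`Data_Entry_2017.csv` in the original release) whose `Finding Labels` column lists each image's findings separated by `|`, or `No Finding`. Training with `--num_class 14` implies one output per finding, so these strings have to become 14-dimensional multi-hot targets somewhere. The sketch below shows one way to do that; the class order and the helper names `encode_findings` and `load_targets` are assumptions for illustration and may not match how this repository's loader reads `--dataset_dir`.

```python
import csv

# The 14 NIH ChestX-ray14 findings as spelled in the label file.
# This ordering is an assumption; the repository's loader may use another.
NIH_CLASSES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
    "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema",
    "Fibrosis", "Pleural_Thickening", "Hernia",
]


def encode_findings(finding_labels, classes=NIH_CLASSES):
    """Map a string such as 'Effusion|Mass' to a multi-hot list; 'No Finding' -> all zeros."""
    present = set(finding_labels.split("|"))
    unknown = present - set(classes) - {"No Finding"}
    if unknown:
        raise ValueError(f"unknown labels: {sorted(unknown)}")
    return [1 if name in present else 0 for name in classes]


def load_targets(csv_lines):
    """Read label CSV lines (an open file or any iterable of lines).

    Returns {image file name: multi-hot target}.
    """
    return {
        row["Image Index"]: encode_findings(row["Finding Labels"])
        for row in csv.DictReader(csv_lines)
    }
```

For example, `load_targets(open("/path/to/data/Data_Entry_2017.csv", newline=""))` would return the targets for every image listed in the file. CheXpert uses a different label format (including uncertain labels), so it would need its own conversion.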

  4. Train the model.
python main.py \
--dataset_dir '/path/to/data/' \
--batch-size 64 --print-freq 100 \
--output "path/to/output" \
--world-size 1 --rank 0 --dist-url tcp://127.0.0.1:3717 \
--gamma_pos 0 --gamma_neg 2 --dtgfl \
--lr 1e-4 --optim AdamW --pretrained \
--num_class 14 --img_size 224 --weight-decay 1e-2
