- A general, feasible and extensible framework for 2D image classification.
- Easy to configure (model, hyperparameters)
- Training progress monitoring and visualization
- Weighted sampling / weighted loss / kappa loss / focal loss for imbalanced datasets
- Multiple metrics for model evaluation
- Different learning rate schedulers and warmup support
- Data augmentation
- Multi-GPU support (DDP mode)
- ViT support
- Hyperparameter tuning support (nni)
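As an illustration of the weighted sampling listed above, one common way to handle an imbalanced dataset is to draw each image with a probability inversely proportional to the size of its class. The sketch below assumes that scheme and is not taken from this repository, whose weighting may differ; `inverse_frequency_weights` is a hypothetical helper that works on a list of (path, label) pairs.

```python
from collections import Counter


def inverse_frequency_weights(samples):
    """Return one sampling weight per (path, label) pair: 1 / class count."""
    # Hypothetical helper for illustration only, not part of the repository.
    counts = Counter(label for _, label in samples)
    return [1.0 / counts[label] for _, label in samples]


# Toy dataset: three images of class 0, one image of class 1.
samples = [('a.jpg', 0), ('b.jpg', 0), ('c.jpg', 0), ('d.jpg', 1)]
weights = inverse_frequency_weights(samples)
# Each class now carries the same total weight, so a sampler that draws in
# proportion to these weights sees both classes equally often on average.
```

Weights of this kind are the sort of input that `torch.utils.data.WeightedRandomSampler` expects.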
Requirements:
- pytorch
- torchvision
- torcheval
- timm
- tqdm
- munch
- packaging
- tensorboard
- nni (optional)
To install the dependencies, run:
$ git clone https://github.com/YijinHuang/pytorch-classification.git
$ cd pytorch-classification
$ conda create -n pycls python=3.8
$ conda activate pycls
$ pip install -r requirements.txt
1. Use one of the following two methods to build your dataset:
- Folder-form dataset:
Organize your images as follows:
├── your_data_dir
    ├── train
        ├── class1
            ├── image1.jpg
            ├── image2.jpg
            ├── ...
        ├── class2
            ├── image3.jpg
            ├── image4.jpg
            ├── ...
        ├── class3
            ├── ...
    ├── val
    ├── test
Here, the val and test directories have the same structure as train. Then replace the value of base.data_path in configs/default.yaml with the path to your_data_dir and keep base.data_index as null.
- Dict-form dataset:
Define a dict as follows:
your_data_dict = {
    'train': [
        ('path/to/image1', 0),  # use an int to represent the class of each image (starting from 0)
        ('path/to/image2', 0),
        ('path/to/image3', 1),
        ('path/to/image4', 2),
        ...
    ],
    'test': [
        ('path/to/image5', 0),
        ...
    ],
    'val': [
        ('path/to/image6', 0),
        ...
    ]
}
Then use pickle to save it:
import pickle
# Open the file in a with-block so it is closed after the dump
with open('path/to/pickle/file', 'wb') as f:
    pickle.dump(your_data_dict, f)
Finally, replace the value of base.data_index in configs/default.yaml with 'path/to/pickle/file' and set base.data_path to null.
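If your images are already arranged in the folder layout described above, the dict-form index can be generated from it rather than written by hand. The sketch below is one way to do that and is not part of the repository: `build_data_dict` and `save_data_dict` are hypothetical helper names, class indices are assumed to follow the sorted class folder names found under train, and every file inside a class folder is treated as an image.

```python
import os
import pickle


def build_data_dict(data_dir, splits=('train', 'val', 'test')):
    """Scan data_dir/<split>/<class>/<image> and return the dict-form dataset."""
    # Assumption: class indices follow the sorted class folder names in 'train'.
    train_dir = os.path.join(data_dir, 'train')
    classes = sorted(
        d for d in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}

    data_dict = {}
    for split in splits:
        samples = []
        for name, idx in class_to_idx.items():
            class_dir = os.path.join(data_dir, split, name)
            if not os.path.isdir(class_dir):
                continue  # a split may not contain every class
            for fname in sorted(os.listdir(class_dir)):
                samples.append((os.path.join(class_dir, fname), idx))
        data_dict[split] = samples
    return data_dict


def save_data_dict(data_dict, pickle_path):
    """Write the dict to disk so base.data_index can point at it."""
    with open(pickle_path, 'wb') as f:
        pickle.dump(data_dict, f)
```

Calling `save_data_dict(build_data_dict('your_data_dir'), 'path/to/pickle/file')` would then produce a file suitable for base.data_index.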
2. Update your training configurations and hyperparameters in configs/default.yaml.
3. Start training, replacing x with the ID(s) of the GPU(s) to use:
$ CUDA_VISIBLE_DEVICES=x python main.py
Optional arguments:
-c yaml_file Specify the config file (default: configs/default.yaml)
-p Print configs before training
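The two optional arguments can be pictured with a small argparse parser. This is only a sketch of the command-line interface described above, not the repository's actual code: the destination names `config` and `print_config` and the file name `configs/my_experiment.yaml` are made up for illustration.

```python
import argparse


def build_parser():
    """Parser mirroring the documented -c and -p options (illustrative only)."""
    parser = argparse.ArgumentParser(description='Train an image classifier.')
    parser.add_argument('-c', dest='config', metavar='yaml_file',
                        default='configs/default.yaml',
                        help='config file to use')
    parser.add_argument('-p', dest='print_config', action='store_true',
                        help='print configs before training')
    return parser


# Equivalent to: python main.py -c configs/my_experiment.yaml -p
args = build_parser().parse_args(['-c', 'configs/my_experiment.yaml', '-p'])
# With no arguments, the documented default config path is used.
defaults = build_parser().parse_args([])
```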
4. Monitor your training progress at 127.0.0.1:6006 by running:
$ tensorboard --logdir=/path/to/your/log --port=6006
Tips to use tensorboard on a remote server
We recommend that you learn about nni before using hyperparameter tuning.
Install nni by:
$ pip install nni
Set base.HPO in the config file to True, and update the nni settings in configs/hp_tuning.yaml (see the nni documentation for more details). Then run:
$ nnictl create --config ./configs/hp_tuning.yaml --port 8080
Monitor hyperparameter tuning progress via the web portal URL generated by nni.
This repository contains code for the following papers. Instructions can be found here.
Huang, Y., Lin, L., Cheng, P., Lyu, J., Tam, R. and Tang, X., 2023. Identifying the key components in ResNet-50 for diabetic retinopathy grading from fundus images: a systematic investigation. Diagnostics, 13(10), p.1664. [link]
Huang, Y., Lin, L., Cheng, P., Lyu, J. and Tang, X., 2021. Lesion-based contrastive learning for diabetic retinopathy grading from fundus images. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part II 24 (pp. 113-123). Springer International Publishing. [link]
If you find this repository useful, please cite the paper:
@article{huang2023identifying,
title={Identifying the key components in ResNet-50 for diabetic retinopathy grading from fundus images: a systematic investigation},
author={Huang, Yijin and Lin, Li and Cheng, Pujin and Lyu, Junyan and Tam, Roger and Tang, Xiaoying},
journal={Diagnostics},
volume={13},
number={10},
pages={1664},
year={2023},
publisher={MDPI}
}