
A general, feasible, and extensible framework for classification tasks.

YijinHuang/pytorch-classification

Pytorch Classification

  • A general, feasible and extensible framework for 2D image classification.

Features

  • Easy to configure (model, hyperparameters)
  • Training progress monitoring and visualization
  • Weighted sampling / weighted loss / kappa loss / focal loss for imbalanced datasets
  • Multiple metrics for model evaluation
  • Different learning rate schedulers and warmup support
  • Data augmentation
  • Multi-GPU support (DDP mode)
  • Vision Transformer (ViT) support
  • Hyperparameter tuning support (nni)
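Weighted sampling and weighted loss both need a weight per class. The sketch below computes inverse-frequency weights in plain Python, assuming integer labels that start from 0 (as in the dict-form dataset described later). It is an illustration of the general technique, not necessarily the exact weighting this framework applies. The per-sample weights could be passed to torch.utils.data.WeightedRandomSampler, and the per-class weights to the weight argument of torch.nn.CrossEntropyLoss.

```python
from collections import Counter
from typing import List


def class_weights(labels: List[int], num_classes: int) -> List[float]:
    """Inverse-frequency weight per class: N / (num_classes * count_c)."""
    counts = Counter(labels)
    total = len(labels)
    weights = []
    for c in range(num_classes):
        n_c = counts.get(c, 0)
        # Avoid division by zero for classes that never occur in `labels`.
        weights.append(total / (num_classes * n_c) if n_c else 0.0)
    return weights


def sample_weights(labels: List[int], num_classes: int) -> List[float]:
    """Per-sample weight = weight of the sample's class (for weighted sampling)."""
    per_class = class_weights(labels, num_classes)
    return [per_class[y] for y in labels]


if __name__ == "__main__":
    labels = [0, 0, 0, 1]  # class 0 is three times as frequent as class 1
    print(class_weights(labels, 2))  # class 1 gets 3x the weight of class 0
    print(sample_weights(labels, 2))
```

Inverse frequency is the simplest choice; rarer classes are up-weighted so that, in expectation, each class contributes equally.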

Installation

Requirements:

  • pytorch
  • torchvision
  • torcheval
  • timm
  • tqdm
  • munch
  • packaging
  • tensorboard
  • nni (optional)

To clone the repository and install the dependencies, run:

$ git clone https://github.com/YijinHuang/pytorch-classification.git
$ cd pytorch-classification
$ conda create -n pycls python=3.8
$ conda activate pycls
$ pip install -r requirements.txt
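To confirm the environment is ready, a quick check that each required package can be imported may help. The sketch below maps the Requirements list to import names (the pytorch package is imported as torch) and leaves out the optional nni; the helper name missing_modules is hypothetical.

```python
import importlib.util
from typing import List

# Import names for the packages listed under Requirements
# (the "pytorch" package is imported as "torch"; nni is optional and omitted).
REQUIRED = ["torch", "torchvision", "torcheval", "timm", "tqdm",
            "munch", "packaging", "tensorboard"]


def missing_modules(names: List[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

Run it inside the activated pycls environment.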

How to use

Training

1. Use one of the following two methods to build your dataset:

  • Folder-form dataset:

Organize your images as follows:

├── your_data_dir
    ├── train
        ├── class1
            ├── image1.jpg
            ├── image2.jpg
            ├── ...
        ├── class2
            ├── image3.jpg
            ├── image4.jpg
            ├── ...
        ├── class3
        ├── ...
    ├── val
    ├── test

Here, the val and test directories have the same structure as train. Then set base.data_path in configs/default.yaml to the path of your_data_dir and keep base.data_index as null.
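If you already have a folder-form dataset but would like to use the dict form instead, the two layouts can be converted. The sketch below is a hypothetical helper, folder_to_dict, that walks your_data_dir and builds the list of (path, label) pairs per split. Two assumptions are worth stating: class indices are assigned by sorting the class folder names of the train split (the convention torchvision's ImageFolder uses, not necessarily the one this framework uses), and only files with the extensions in IMG_EXTS are kept.

```python
import os
from typing import Dict, List, Tuple

# Assumed image extensions; adjust to your data.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def folder_to_dict(data_dir: str,
                   splits=("train", "val", "test")) -> Dict[str, List[Tuple[str, int]]]:
    """Convert a folder-form dataset into the dict form.

    Class indices come from the sorted class folder names under train/
    (an assumption mirroring torchvision's ImageFolder).
    """
    train_dir = os.path.join(data_dir, "train")
    classes = sorted(d for d in os.listdir(train_dir)
                     if os.path.isdir(os.path.join(train_dir, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}

    result = {}
    for split in splits:
        samples = []
        for name in classes:
            class_dir = os.path.join(data_dir, split, name)
            if not os.path.isdir(class_dir):
                continue  # a split (or the split folder itself) may lack this class
            for fname in sorted(os.listdir(class_dir)):
                if os.path.splitext(fname)[1].lower() in IMG_EXTS:
                    samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
        result[split] = samples
    return result
```

The returned dict can then be pickled as described for the dict-form dataset.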

  • Dict-form dataset:

Define a dict as follows:

your_data_dict = {
    'train': [
        ('path/to/image1', 0),  # use an int to represent the class of each image (starting from 0)
        ('path/to/image2', 0),
        ('path/to/image3', 1),
        ('path/to/image4', 2),
        ...
    ],
    'test': [
        ('path/to/image5', 0),
        ...
    ],
    'val': [
        ('path/to/image6', 0),
        ...
    ]
}

Then use pickle to save it:

import pickle

with open('path/to/pickle/file', 'wb') as f:  # the context manager closes the file
    pickle.dump(your_data_dict, f)

Finally, set base.data_index in configs/default.yaml to 'path/to/pickle/file' and set base.data_path to null.
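Since the pickle file is read as-is, a sanity check before training can catch layout mistakes early. The sketch below, a hypothetical load_and_check helper, verifies the structure described above: the three split keys, and (path, int label) pairs with non-negative labels. The check that train labels are contiguous from 0 is an extra safeguard added here, not a documented requirement.

```python
import pickle
from typing import Dict, List, Tuple


def load_and_check(index_path: str) -> Dict[str, List[Tuple[str, int]]]:
    """Load a dict-form dataset index and check it matches the documented layout."""
    with open(index_path, "rb") as f:
        data = pickle.load(f)

    for split in ("train", "val", "test"):
        if split not in data:
            raise KeyError(f"missing split: {split!r}")
        for item in data[split]:
            path, label = item  # each entry should be a (path, label) pair
            if not isinstance(path, str) or not isinstance(label, int) or label < 0:
                raise ValueError(f"bad entry in {split!r}: {item!r}")

    # Extra safeguard (assumption): train labels should cover 0..K-1 without gaps.
    train_labels = {label for _, label in data["train"]}
    if train_labels and train_labels != set(range(max(train_labels) + 1)):
        raise ValueError(f"train labels are not contiguous from 0: {sorted(train_labels)}")
    return data
```

For example, load_and_check('path/to/pickle/file') raises on the first problem it finds and otherwise returns the loaded dict.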

2. Update your training configurations and hyperparameters in configs/default.yaml.
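The dotted names used in this README (base.data_path, base.data_index, base.HPO) refer to nested keys in the YAML file, for example data_path inside the base section. If you generate configs programmatically after loading the YAML into a nested dict with a parser of your choice, hypothetical helpers like set_dotted and get_dotted below can apply such overrides. Plain dicts are used here for illustration rather than the framework's own config object.

```python
from typing import Any, Dict


def set_dotted(cfg: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set a nested value, e.g. set_dotted(cfg, "base.data_path", "/data")."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})  # create intermediate sections if absent
    node[leaf] = value


def get_dotted(cfg: Dict[str, Any], dotted_key: str) -> Any:
    """Read a nested value by its dotted name; raises KeyError if absent."""
    node = cfg
    for key in dotted_key.split("."):
        node = node[key]
    return node


if __name__ == "__main__":
    cfg = {"base": {"data_path": None, "data_index": None}}
    set_dotted(cfg, "base.data_path", "/path/to/your_data_dir")
    print(get_dotted(cfg, "base.data_path"))
```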

3. Run the following command to start training:

$ CUDA_VISIBLE_DEVICES=x python main.py

Optional arguments:

-c yaml_file      Specify the config file (default: configs/default.yaml)
-p                Print configs before training
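To queue several experiments, the training command above can be launched from Python. The sketch below uses only the documented -c and -p options and CUDA_VISIBLE_DEVICES; the helper names are hypothetical, it assumes it is run from the repository root, and it uses the current interpreter (sys.executable) so the active conda environment is reused.

```python
import os
import subprocess
import sys
from typing import Dict, List, Tuple


def build_train_command(config: str, gpus: str, print_config: bool = False
                        ) -> Tuple[List[str], Dict[str, str]]:
    """Return the argv and environment for one run of main.py."""
    cmd = [sys.executable, "main.py", "-c", config]
    if print_config:
        cmd.append("-p")  # print configs before training
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpus)  # restrict visible GPUs
    return cmd, env


def run_sequentially(configs: List[str], gpus: str = "0") -> None:
    """Train one config after another; stop at the first failing run."""
    for config in configs:
        cmd, env = build_train_command(config, gpus)
        subprocess.run(cmd, env=env, check=True)
```

For example, run_sequentially(["configs/default.yaml"], gpus="0") is equivalent to the shell command above with x set to 0.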

4. Monitor your training progress at 127.0.0.1:6006 by running:

$ tensorboard --logdir=/path/to/your/log --port=6006

Tips for using TensorBoard on a remote server

Hyperparameter tuning (optional)

We recommend learning about nni before using hyperparameter tuning.

Install nni by:

$ pip install nni

Set base.HPO in the config file to True, and update the nni settings in configs/hp_tuning.yaml (see the nni documentation for more details). Then run:

$ nnictl create --config ./configs/hp_tuning.yaml --port 8080

Monitor hyperparameter tuning progress via the web portal URL generated by nni.
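For orientation, an nni trial typically fetches one set of hyperparameters with nni.get_next_parameter(), trains, and reports a metric with nni.report_final_result(). The sketch below shows that loop under stated assumptions: the search-space names are dotted config keys (a hypothetical convention, the real names come from your search space), and train_fn is a hypothetical function that trains with a config and returns a validation metric. It is not this repository's implementation of base.HPO.

```python
import copy
from typing import Any, Callable, Dict

try:
    import nni  # optional dependency, only needed for hyperparameter tuning
except ImportError:
    nni = None


def merge_params(cfg: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of cfg with dotted-key hyperparameters applied."""
    merged = copy.deepcopy(cfg)
    for dotted_key, value in params.items():
        *parents, leaf = dotted_key.split(".")
        node = merged
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return merged


def run_trial(cfg: Dict[str, Any], train_fn: Callable[[Dict[str, Any]], float]) -> float:
    """Fetch hyperparameters from nni (if available), train, and report the metric."""
    params = (nni.get_next_parameter() or {}) if nni is not None else {}
    metric = train_fn(merge_params(cfg, params))
    if nni is not None:
        nni.report_final_result(metric)
    return metric
```

Copying the config before merging keeps the base settings intact, so each trial starts from the same defaults.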

External Tools

Diabetic Retinopathy Detection

This repository contains the code for the following papers. Instructions can be found here.

Huang, Y., Lin, L., Cheng, P., Lyu, J., Tam, R. and Tang, X., 2023. Identifying the key components in ResNet-50 for diabetic retinopathy grading from fundus images: a systematic investigation. Diagnostics, 13(10), p.1664. [link]

Huang, Y., Lin, L., Cheng, P., Lyu, J. and Tang, X., 2021. Lesion-based contrastive learning for diabetic retinopathy grading from fundus images. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part II 24 (pp. 113-123). Springer International Publishing. [link]

Citation

If you find this repository useful, please cite the paper:

@article{huang2023identifying,
  title={Identifying the key components in ResNet-50 for diabetic retinopathy grading from fundus images: a systematic investigation},
  author={Huang, Yijin and Lin, Li and Cheng, Pujin and Lyu, Junyan and Tam, Roger and Tang, Xiaoying},
  journal={Diagnostics},
  volume={13},
  number={10},
  pages={1664},
  year={2023},
  publisher={MDPI}
}
