Template for training semantic segmentation models. Main frameworks: PyTorch Lightning and Hydra.
Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
Install dependencies
# clone project
git clone https://github.com/IvanMatoshchuk/semantic-segmentation-template
cd semantic-segmentation-template
uv sync
Add your label classes to data/label_classes.json.
Specify the number of classes in the selected model config (the classes field, as in the unet.yaml example below).
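Because the classes value in the model config has to be kept in sync with the label file by hand, a small check can catch a mismatch before training starts. The helper below is hypothetical (it is not part of the template), and it assumes data/label_classes.json is either a JSON list or a JSON object with one entry per class; adapt it if your file is structured differently.

```python
from __future__ import annotations

import json
from pathlib import Path


def count_label_classes(path: str | Path) -> int:
    """Return the number of classes in a label-classes JSON file.

    Assumption (hypothetical format): the file is a JSON list with one
    entry per class, or a JSON object with one key per class.
    """
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    if isinstance(data, (list, dict)):
        return len(data)
    raise ValueError(f"{path}: expected a JSON list or object, got {type(data).__name__}")


def check_num_classes(path: str | Path, configured_classes: int) -> None:
    """Raise ValueError if the model config's classes value disagrees with the file."""
    found = count_label_classes(path)
    if found != configured_classes:
        raise ValueError(
            f"{path} lists {found} classes, but the model config has classes={configured_classes}"
        )
```

For the example unet.yaml (classes: 9), check_num_classes("data/label_classes.json", 9) would raise unless the file lists exactly nine classes.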
Train a model with the default configuration:
# default
uv run run.py
# train on CPU
uv run run.py trainer.gpus=0
# train on GPU
uv run run.py trainer.gpus=1
# train on multiple GPUs (quoting keeps shells such as zsh from expanding the brackets)
uv run run.py 'trainer.gpus=[0,1,2,3]'
You can override any parameter from the command line:
uv run run.py trainer.max_epochs=20 datamodule.dataset_args.train.crop_size=416 model=unet
You can run a hyperparameter search from the command line:
# this will run 6 experiments one after the other,
# each with a different combination of batch_size and learning rate
uv run run.py -m datamodule.dataloader_args.train.batch_size=32,64,128 optimizer.lr=0.001,0.0005
By design, every run is started from the run.py file. All PyTorch Lightning modules are dynamically instantiated from the module paths specified in the config. Example model config (unet.yaml):
_target_: src.model.segmentation_model.HoneyBeeModel
_recursive_: False
model_cfg:
  _target_: segmentation_models_pytorch.Unet
  encoder_name: efficientnet-b0 # e.g. efficientnet-b0 or timm-mobilenetv3_small_100
  encoder_weights: imagenet
  encoder_depth: 5
  classes: 9
  in_channels: 1
Using this config, we can instantiate the object with the following line:
model = hydra.utils.instantiate(config.model)
This allows you to easily iterate over new models!
Every time you create a new model, just specify its module path and parameters in the appropriate config file.
The whole pipeline managing the instantiation logic is placed in src/train.py.
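To make the _target_ mechanism concrete, here is a simplified, stdlib-only stand-in for hydra.utils.instantiate. It is not Hydra's code, and the names locate and instantiate are this sketch's own: it imports the dotted path in _target_ and calls it with the remaining keys as keyword arguments. With _recursive_: False, as in unet.yaml, nested configs such as model_cfg are handed over unbuilt (real Hydra passes them as DictConfig objects rather than plain dicts), presumably so that HoneyBeeModel can build the Unet itself.

```python
from __future__ import annotations

import importlib
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as 'datetime.timedelta' to the object it names."""
    module_path, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


def instantiate(cfg: dict[str, Any]) -> Any:
    """Simplified stand-in for hydra.utils.instantiate (not Hydra's own code).

    Pops _target_ and calls it with the remaining keys as keyword arguments.
    If _recursive_ is True (Hydra's default), child dicts that carry their
    own _target_ are instantiated first; with _recursive_ False they are
    passed through untouched. Lists and dicts without _target_ are not searched.
    """
    cfg = dict(cfg)  # shallow copy so the caller's config is not modified
    target = locate(cfg.pop("_target_"))
    recursive = cfg.pop("_recursive_", True)
    if recursive:
        cfg = {
            key: instantiate(value) if isinstance(value, dict) and "_target_" in value else value
            for key, value in cfg.items()
        }
    return target(**cfg)
```

For example, instantiate({"_target_": "datetime.timedelta", "days": 2}) returns timedelta(days=2); swapping in another model is only a matter of pointing _target_ at a different class.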
Location: configs/config.yaml
The main project config contains the default training configuration.
It determines how the config is composed when you simply execute python run.py.
It also specifies everything that should not be managed by experiment configurations.
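As a rough mental model of composition, the stdlib-only sketch below mimics three Hydra steps in simplified form. Each group: option entry in the defaults list selects configs/<group>/<option>.yaml (assuming configs/ is the config directory, which matches the configs/config.yaml location), and a group override such as model=unet changes that choice. A value override such as trainer.max_epochs=20 replaces an existing key. With -m, comma-separated values expand into the cartesian product of runs, so the sweep command above gives 3 batch sizes x 2 learning rates = 6 runs. All function names are hypothetical, and quoting, list values, interpolation and merge order are ignored.

```python
from __future__ import annotations

import itertools
from typing import Any


def defaults_to_files(defaults: list[Any], choices: dict[str, str] | None = None,
                      config_dir: str = "configs") -> dict[str, str]:
    """Map each 'group: option' entry to <config_dir>/<group>/<option>.yaml.

    choices mimics group overrides such as model=unet; _self_ and
    'override hydra/...' entries are skipped in this simplified version.
    """
    choices = choices or {}
    files = {}
    for entry in defaults:
        if not isinstance(entry, dict):
            continue  # e.g. "_self_"
        for group, option in entry.items():
            if group.startswith("override "):
                continue
            files[group] = f"{config_dir}/{group}/{choices.get(group, option)}.yaml"
    return files


def parse_value(raw: str) -> Any:
    """Turn an override value into int, float or bool, else leave it a string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    return raw


def apply_override(cfg: dict[str, Any], override: str) -> None:
    """Apply an 'a.b.c=value' override in place; like Hydra, refuse unknown keys."""
    key, sep, raw = override.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {override!r}")
    *parents, leaf = key.split(".")
    node = cfg
    for part in parents:
        node = node[part]  # KeyError if an intermediate key is missing
    if leaf not in node:
        raise KeyError(f"{key} is not in the config; Hydra needs +{key}=... to add it")
    node[leaf] = parse_value(raw)


def expand_sweep(overrides: list[str]) -> list[list[str]]:
    """Expand comma-separated values (as with -m) into one override list per run.

    Values with commas inside brackets, e.g. [0,1,2,3], are not supported here.
    """
    per_key = []
    for override in overrides:
        key, _, raw = override.partition("=")
        per_key.append([f"{key}={value}" for value in raw.split(",")])
    return [list(combo) for combo in itertools.product(*per_key)]
```

In the real template these steps are performed by Hydra itself when run.py starts; the sketch only shows why an unknown key fails and where the number of sweep runs comes from.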
# specify here default training configuration
defaults:
- _self_
- logger: wandb
- callbacks: wandb
- datamodule: batch_datamodule
- model: unet
- trainer: default_trainer
- optimizer: adam
- scheduler: cosinewarm
- loss: dice_with_ce
# enable color logging
- override hydra/hydra_logging: colorlog
- override hydra/job_logging: colorlog
general:
  name: test # name of the run, accessed by loggers
  seed: 123
  work_dir: ${hydra:runtime.cwd}
# print config at the start
print_config: True
# disable python warnings if they annoy you
ignore_warnings: False
# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: False
Inspirations
This template was inspired by:
Useful repositories
- pytorch/hydra-torch - resources for configuring PyTorch classes with Hydra,
- qubvel/segmentation_models.pytorch - PyTorch-based models for semantic segmentation.
This project is licensed under the MIT License.
MIT License
Copyright (c) 2021 ashleve
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.