Skip to content

Commit

Permalink
Merge pull request #243 from Nota-NetsPresso/241-add-model-name-field
Browse files Browse the repository at this point in the history
Add model name field
  • Loading branch information
Hyoung-Kyu Song authored Nov 24, 2023
2 parents 549beb7 + 6721e04 commit e004900
Show file tree
Hide file tree
Showing 31 changed files with 370 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Add a gpu option in `train_with_config` (only single-GPU supported) by `@deepkyu` in [PR 219](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/219)
- Support augmentation for classification task: cutmix, mixup by `@illian01` in [PR 221](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/221)
- Add model: MixNet by `@illian01` in [PR 229](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/229)
- Add `model.name` to get the exact nickname of the model by `@deepkyu` in [PR 243](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/243/)
- Add transforms: RandomErasing and TrivialAugmentationWide by `@illian01` in [PR 246](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/246)

## Bug Fixes:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: efficientformer_l1
checkpoint: ./weights/efficientformer/efficientformer_l1_1000d.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: detection
name: efficientformer_l1
checkpoint: ./weights/efficientformer/efficientformer_l1_1000d.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: efficientformer_l1
checkpoint: ./weights/efficientformer/efficientformer_l1_1000d.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/mixnet/mixnet-l-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: mixnet_l
checkpoint: ./weights/mixnet/mixnet_l.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/mixnet/mixnet-l-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: mixnet_l
checkpoint: ./weights/mixnet/mixnet_l.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/mixnet/mixnet-m-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: mixnet_m
checkpoint: ./weights/mixnet/mixnet_m.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/mixnet/mixnet-m-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: mixnet_m
checkpoint: ./weights/mixnet/mixnet_m.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/mixnet/mixnet-s-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: mixnet_s
checkpoint: ./weights/mixnet/mixnet_s.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/mixnet/mixnet-s-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: mixnet_s
checkpoint: ./weights/mixnet/mixnet_s.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
model:
task: classification
name: mobilenet_v3_small
checkpoint: ./weights/mobilenetv3/mobilenet_v3_small.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
freeze_backbone: False
architecture:
full: ~ # auto
backbone:
name: mobilenetv3_small
name: mobilenetv3
params: ~
stage_params:
-
Expand Down
3 changes: 2 additions & 1 deletion config/model/mobilenetv3/mobilenetv3-small-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
model:
task: segmentation
name: mobilenet_v3_small
checkpoint: ./weights/mobilenetv3/mobilenet_v3_small.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
freeze_backbone: False
architecture:
full: ~ # auto
backbone:
name: mobilenetv3_small
name: mobilenetv3
params: ~
stage_params:
-
Expand Down
1 change: 1 addition & 0 deletions config/model/mobilevit/mobilevit-s-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: mobilevit_s
checkpoint: ./weights/mobilevit/mobilevit_s.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/pidnet/pidnet-s-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: pidnet_s
checkpoint: ./weights/pidnet/pidnet_s.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
3 changes: 2 additions & 1 deletion config/model/resnet/resnet50-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
model:
task: classification
name: resnet50
checkpoint: ./weights/resnet/resnet50.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
freeze_backbone: False
architecture:
full: ~ # auto
backbone:
name: resnet50
name: resnet
params:
block: bottleneck
norm_layer: batch_norm
Expand Down
3 changes: 2 additions & 1 deletion config/model/resnet/resnet50-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: resnet50
checkpoint: ./weights/resnet/resnet50.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand All @@ -8,7 +9,7 @@ model:
full:
name: ~ # auto
backbone:
name: resnet50
name: resnet
params:
block: bottleneck
norm_layer: batch_norm
Expand Down
1 change: 1 addition & 0 deletions config/model/segformer/segformer-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: segformer
checkpoint: ./weights/segformer/segformer.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/segformer/segformer-segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: segmentation
name: segformer
checkpoint: ./weights/segformer/segformer.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/vit/vit-classification.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: classification
name: vit_tiny
checkpoint: ./weights/vit/vit-tiny.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
1 change: 1 addition & 0 deletions config/model/yolox/yolox-detection.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model:
task: detection
name: yolox_s
checkpoint: ./weights/yolox/yolox_s.pth
fx_model_checkpoint: ~
resume_optimizer_checkpoint: ~
Expand Down
2 changes: 1 addition & 1 deletion demo/gradio_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def launch_gradio(args):
task_choices = gr.Radio(label="Task: ", value='classification', choices=SUPPORTING_TASK_LIST)
with gr.Column(scale=1):
phase_choices = gr.Radio(label="Phase: ", value='train', choices=['train', 'valid'])
model_choices = gr.Radio(label="Model: ", value='resnet50', choices=SUPPORTING_MODEL_LIST)
model_choices = gr.Radio(label="Model: ", value='resnet', choices=SUPPORTING_MODEL_LIST)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
config_input = gr.Code(label="Augmentation configuration", value=args.config.read_text(), language='yaml', lines=30)
Expand Down
8 changes: 8 additions & 0 deletions src/netspresso_trainer/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,22 @@
from .logging import LoggingConfig
from .model import (
ClassificationEfficientFormerModelConfig,
ClassificationMixNetLargeModelConfig,
ClassificationMixNetMediumModelConfig,
ClassificationMixNetSmallModelConfig,
ClassificationMobileNetV3ModelConfig,
ClassificationMobileViTModelConfig,
ClassificationResNetModelConfig,
ClassificationSegFormerModelConfig,
ClassificationViTModelConfig,
DetectionEfficientFormerModelConfig,
DetectionYoloXModelConfig,
ModelConfig,
PIDNetModelConfig,
SegmentationEfficientFormerModelConfig,
SegmentationMixNetLargeModelConfig,
SegmentationMixNetMediumModelConfig,
SegmentationMixNetSmallModelConfig,
SegmentationMobileNetV3ModelConfig,
SegmentationResNetModelConfig,
SegmentationSegFormerModelConfig,
Expand All @@ -65,6 +72,7 @@
'detection': DetectionScheduleConfig
}


@dataclass
class TrainerConfig:
task: str = field(default=MISSING, metadata={"omegaconf_ignore": True})
Expand Down
Loading

0 comments on commit e004900

Please sign in to comment.