Skip to content

Commit

Permalink
Implement for convnext trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdawkins committed Sep 8, 2023
1 parent bbd5694 commit 090a103
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion plugins/pytorch/cutler/cutler_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class ConvNextCascadeRCNNTrainer( TrainDetector ):

_Option('_cutler_config_file', 'cutler_config_file', '', str, ''),

_Option('_seed_weights', 'seed_weights', '', str, ''),

_Option('_output_directory', 'output_directory', 'category_models', str, ''),
_Option('_pipeline_template', 'pipeline_template', '', str, '')
]
Expand Down Expand Up @@ -245,7 +247,9 @@ def set_configuration( self, cfg_in ):
os.environ[ "MASTER_PORT" ] = "12345"
init_dist( self._launcher )

if "checkpoint_override" in self.config and self.config["checkpoint_override"]:
if self._seed_weights:
self.original_chkpt_file = self._seed_weights
elif "checkpoint_override" in self.config and self.config["checkpoint_override"]:
self.original_chkpt_file = self.config["checkpoint_override"]
else:
self.original_chkpt_file = self.config["model_checkpoint_file"]
Expand Down
4 changes: 2 additions & 2 deletions tools/viame_train_detector.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1297,8 +1297,8 @@ main( int argc, char* argv[] )
std::map< std::string, std::vector< std::string > > weight_ext =
{
{ ".zip", { "seed_model" } },
{ ".pth", { "backbone" } },
{ ".pt", { "backbone" } },
{ ".pth", { "backbone", "seed_weights" } },
{ ".pt", { "backbone", "seed_weights" } },
{ ".py", { "config" } },
{ ".weights", { "seed_weights" } },
{ ".wt", { "seed_weights" } }
Expand Down

0 comments on commit 090a103

Please sign in to comment.