diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py
index a458cf9c9..33bb3becb 100644
--- a/allenact/algorithms/onpolicy_sync/engine.py
+++ b/allenact/algorithms/onpolicy_sync/engine.py
@@ -1312,7 +1312,7 @@ def _save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_co
         self.checkpoints_queue.put(("eval", model_path))
         self.last_save = self.training_pipeline.total_steps
 
-    def run_pipeline(self):
+    def run_pipeline(self, valid_on_initial_weights: bool = False):
         cur_stage_training_settings = (
             self.training_pipeline.current_stage.training_settings
         )
@@ -1336,6 +1336,16 @@ def run_pipeline(self):
         )
         already_saved_checkpoint = False
 
+        if (
+            valid_on_initial_weights
+            and should_save_checkpoints
+            and self.checkpoints_queue is not None
+        ):
+            if self.worker_id == self.first_local_worker_id:
+                model_path = self.checkpoint_save()
+                if self.checkpoints_queue is not None:
+                    self.checkpoints_queue.put(("eval", model_path))
+
         while True:
             pipeline_stage_changed = self.training_pipeline.before_rollout(
                 train_metrics=self._last_aggregated_train_task_metrics
@@ -1572,7 +1582,10 @@ def run_pipeline(self):
         )
 
     def train(
-        self, checkpoint_file_name: Optional[str] = None, restart_pipeline: bool = False
+        self,
+        checkpoint_file_name: Optional[str] = None,
+        restart_pipeline: bool = False,
+        valid_on_initial_weights: bool = False,
     ):
         assert (
             self.mode == TRAIN_MODE_STR
@@ -1584,7 +1597,7 @@ def train(
             if checkpoint_file_name is not None:
                 self.checkpoint_load(checkpoint_file_name, restart_pipeline)
 
-            self.run_pipeline()
+            self.run_pipeline(valid_on_initial_weights=valid_on_initial_weights)
 
             training_completed_successfully = True
         except KeyboardInterrupt:
diff --git a/allenact/algorithms/onpolicy_sync/runner.py b/allenact/algorithms/onpolicy_sync/runner.py
index 872d09ebf..d2d0c24d8 100644
--- a/allenact/algorithms/onpolicy_sync/runner.py
+++ b/allenact/algorithms/onpolicy_sync/runner.py
@@ -316,6 +316,7 @@ def train_loop(
         id: int = 0,
         checkpoint: Optional[str] = None,
         restart_pipeline: bool = False,
+        valid_on_initial_weights: bool = False,
         *engine_args,
         **engine_kwargs,
     ):
@@ -333,7 +334,9 @@ def train_loop(
         if trainer is not None:
             OnPolicyRunner.init_process("Train", id, to_close_on_termination=trainer)
             trainer.train(
-                checkpoint_file_name=checkpoint, restart_pipeline=restart_pipeline
+                checkpoint_file_name=checkpoint,
+                restart_pipeline=restart_pipeline,
+                valid_on_initial_weights=valid_on_initial_weights,
             )
 
     @staticmethod
@@ -407,6 +410,7 @@ def start_train(
         max_sampler_processes_per_worker: Optional[int] = None,
         save_ckpt_after_every_pipeline_stage: bool = True,
         collect_valid_results: bool = False,
+        valid_on_initial_weights: bool = False,
     ):
         self._initialize_start_train_or_start_test()
 
@@ -457,6 +461,7 @@ def start_train(
                 if model_hash is None
                 else model_hash,
                 first_local_worker_id=worker_ids[0],
+                valid_on_initial_weights=valid_on_initial_weights,
             )
             train: BaseProcess = self.mp_ctx.Process(
                 target=self.train_loop, kwargs=training_kwargs,
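
For reference, the guard added to run_pipeline above reduces to the control flow
below. This is a minimal standalone sketch, not engine code: the function name
maybe_validate_initial_weights and its parameters are hypothetical stand-ins for
the engine attributes used in the diff (self.worker_id, self.first_local_worker_id,
self.checkpoint_save, self.checkpoints_queue).

from queue import Queue
from typing import Callable, Optional


def maybe_validate_initial_weights(
    valid_on_initial_weights: bool,
    should_save_checkpoints: bool,
    checkpoints_queue: Optional[Queue],
    worker_id: int,
    first_local_worker_id: int,
    checkpoint_save: Callable[[], str],  # saves the model, returns its path
) -> None:
    # Only the first local worker saves the untrained model; the resulting
    # checkpoint is queued with an "eval" tag so the validation process
    # scores the initial weights before any training rollouts run.
    if (
        valid_on_initial_weights
        and should_save_checkpoints
        and checkpoints_queue is not None
    ):
        if worker_id == first_local_worker_id:
            model_path = checkpoint_save()
            checkpoints_queue.put(("eval", model_path))
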
"--test_expert", dest="test_expert", @@ -443,6 +452,7 @@ def main(): restart_pipeline=args.restart_pipeline, max_sampler_processes_per_worker=args.max_sampler_processes_per_worker, collect_valid_results=args.collect_valid_results, + valid_on_initial_weights=args.valid_on_initial_weights, ) else: OnPolicyRunner(