From eddf2d21b5f568eb3370426b3d95e2501788752c Mon Sep 17 00:00:00 2001 From: Congrong Xu <50019703+KevinXu02@users.noreply.github.com> Date: Thu, 18 Apr 2024 17:37:50 -0700 Subject: [PATCH] Changes for trainer.py to support the Gradio webui (#3046) * changes for trainer to support webui * Update trainer to support webui * format * add a seperated shutdown() function to stop training * typo fix * get rid of _stop_viewer_server() * Update trainer.py * organize import --------- Co-authored-by: Brent Yi --- nerfstudio/engine/trainer.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/nerfstudio/engine/trainer.py b/nerfstudio/engine/trainer.py index 9b4d6d8a15..0634f62e53 100644 --- a/nerfstudio/engine/trainer.py +++ b/nerfstudio/engine/trainer.py @@ -28,6 +28,7 @@ from typing import DefaultDict, Dict, List, Literal, Optional, Tuple, Type, cast import torch +import viser from rich import box, style from rich.panel import Panel from rich.table import Table @@ -137,6 +138,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int = self.viewer_state = None + # used to keep track of the current step + self.step = 0 + def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None: """Setup the Trainer by calling other setup functions. @@ -233,8 +237,15 @@ def train(self) -> None: with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME): num_iterations = self.config.max_num_iterations step = 0 + self.stop_training = False for step in range(self._start_step, self._start_step + num_iterations): + self.step = step + if self.stop_training: + break while self.training_state == "paused": + if self.stop_training: + self._after_train() + return time.sleep(0.01) with self.train_lock: with TimeWriter(writer, EventName.ITER_TRAIN_TIME, step=step) as train_t: @@ -291,12 +302,26 @@ def train(self) -> None: writer.write_out_storage() + # save checkpoint at the end of training, and write out any remaining events + self._after_train() + + def shutdown(self) -> None: + """Stop the trainer and stop all associated threads/processes (such as the viewer).""" + self.stop_training = True # tell the training loop to stop + if self.viewer_state is not None: + # stop the viewer + # this condition excludes the case where `viser_server` is either `None` or an + # instance of `viewer_legacy`'s `ViserServer` instead of the upstream one. + if isinstance(self.viewer_state.viser_server, viser.ViserServer): + self.viewer_state.viser_server.stop() + + def _after_train(self) -> None: + """Function to run after training is complete""" + self.training_state = "completed" # used to update the webui state # save checkpoint at the end of training - self.save_checkpoint(step) - + self.save_checkpoint(self.step) # write out any remaining events (e.g., total train time) writer.write_out_storage() - table = Table( title=None, show_header=False, @@ -309,7 +334,7 @@ def train(self) -> None: # after train end callbacks for callback in self.callbacks: - callback.run_callback_at_location(step=step, location=TrainingCallbackLocation.AFTER_TRAIN) + callback.run_callback_at_location(step=self.step, location=TrainingCallbackLocation.AFTER_TRAIN) if not self.config.viewer.quit_on_train_completion: self._train_complete_viewer()