Skip to content

Commit

Permalink
Changes for trainer.py to support the Gradio webui (nerfstudio-projec…
Browse files Browse the repository at this point in the history
…t#3046)

* changes for trainer to support webui

* Update trainer to support webui

* format

* add a seperated shutdown() function to stop training

* typo fix

* get rid of _stop_viewer_server()

* Update trainer.py

* organize import

---------

Co-authored-by: Brent Yi <yibrenth@gmail.com>
  • Loading branch information
KevinXu02 and brentyi committed Apr 19, 2024
1 parent 45d8bb7 commit eddf2d2
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import DefaultDict, Dict, List, Literal, Optional, Tuple, Type, cast

import torch
import viser
from rich import box, style
from rich.panel import Panel
from rich.table import Table
Expand Down Expand Up @@ -137,6 +138,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =

self.viewer_state = None

# used to keep track of the current step
self.step = 0

def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
"""Setup the Trainer by calling other setup functions.
Expand Down Expand Up @@ -233,8 +237,15 @@ def train(self) -> None:
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
num_iterations = self.config.max_num_iterations
step = 0
self.stop_training = False
for step in range(self._start_step, self._start_step + num_iterations):
self.step = step
if self.stop_training:
break
while self.training_state == "paused":
if self.stop_training:
self._after_train()
return
time.sleep(0.01)
with self.train_lock:
with TimeWriter(writer, EventName.ITER_TRAIN_TIME, step=step) as train_t:
Expand Down Expand Up @@ -291,12 +302,26 @@ def train(self) -> None:

writer.write_out_storage()

# save checkpoint at the end of training, and write out any remaining events
self._after_train()

def shutdown(self) -> None:
"""Stop the trainer and stop all associated threads/processes (such as the viewer)."""
self.stop_training = True # tell the training loop to stop
if self.viewer_state is not None:
# stop the viewer
# this condition excludes the case where `viser_server` is either `None` or an
# instance of `viewer_legacy`'s `ViserServer` instead of the upstream one.
if isinstance(self.viewer_state.viser_server, viser.ViserServer):
self.viewer_state.viser_server.stop()

def _after_train(self) -> None:
"""Function to run after training is complete"""
self.training_state = "completed" # used to update the webui state
# save checkpoint at the end of training
self.save_checkpoint(step)

self.save_checkpoint(self.step)
# write out any remaining events (e.g., total train time)
writer.write_out_storage()

table = Table(
title=None,
show_header=False,
Expand All @@ -309,7 +334,7 @@ def train(self) -> None:

# after train end callbacks
for callback in self.callbacks:
callback.run_callback_at_location(step=step, location=TrainingCallbackLocation.AFTER_TRAIN)
callback.run_callback_at_location(step=self.step, location=TrainingCallbackLocation.AFTER_TRAIN)

if not self.config.viewer.quit_on_train_completion:
self._train_complete_viewer()
Expand Down

0 comments on commit eddf2d2

Please sign in to comment.