In [9]:
from __future__ import annotations

import random
import socket
import traceback
from datetime import timedelta
from datetime import datetime
from typing import Any, Callable, Literal, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import tyro
from pathlib import Path
import os
import yaml

import nerfstudio
from nerfstudio.configs.config_utils import convert_markup_to_ansi
from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
from nerfstudio.configs.method_configs import all_methods
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.utils import comms, profiler
from nerfstudio.utils.rich_utils import CONSOLE

In [10]:
# Default value of the ``timeout`` parameter of ``launch`` (30 minutes).
DEFAULT_TIMEOUT = timedelta(minutes=30)

# Global PyTorch backend flag, set once when this cell runs:
# speedup for when input size to model doesn't change (much)
torch.backends.cudnn.benchmark = True  # type: ignore

In [11]:
def launch(
    main_func: Callable,
    num_devices_per_machine: int,
    num_machines: int = 1,
    machine_rank: int = 0,
    dist_url: str = "auto",
    config: Optional[TrainerConfig] = None,
    timeout: timedelta = DEFAULT_TIMEOUT,
    device_type: Literal["cpu", "cuda", "mps"] = "cuda",
) -> None:
    """Run ``main_func`` for a training job in the current process.

    Only the single-process case (``num_machines * num_devices_per_machine == 1``)
    is implemented in this notebook. Larger world sizes are rejected explicitly
    instead of returning without having trained anything.

    Args:
        main_func (Callable): function to run; it is called with the keyword
            arguments ``local_rank``, ``world_size`` and ``config``.
        num_devices_per_machine (int): number of GPUs per machine
        num_machines (int, optional): total number of machines
        machine_rank (int, optional): rank of this machine. Not used by the
            single-process path.
        dist_url (str, optional): url to connect to for distributed jobs. Not used
            by the single-process path.
        config (TrainerConfig, optional): config file specifying training regimen.
            Must not be None.
        timeout (timedelta, optional): timeout of the distributed workers. Not used
            by the single-process path.
        device_type: type of device to use for training. Not used by the
            single-process path.

    Raises:
        AssertionError: if ``config`` is None.
        ValueError: if the computed world size is zero or negative.
        NotImplementedError: if the computed world size is greater than 1.
    """
    assert config is not None
    world_size = num_machines * num_devices_per_machine
    if world_size == 0:
        raise ValueError("world_size cannot be 0")
    if world_size < 0:
        raise ValueError(f"world_size must be positive, got {world_size}")
    if world_size > 1:
        # Previously this case fell through and returned None without running
        # main_func at all; fail loudly so a multi-device config is noticed.
        raise NotImplementedError(
            f"world_size={world_size} requested, but this notebook only implements "
            "single-process training; set num_machines and num_devices_per_machine to 1"
        )

    # uses one process
    try:
        main_func(local_rank=0, world_size=world_size, config=config)
    except KeyboardInterrupt:
        # A manual interrupt is not re-raised: print the stack trace and fall
        # through to the profiler flush below.
        CONSOLE.print(traceback.format_exc())
    finally:
        # Runs whether main_func finished, was interrupted, or raised.
        profiler.flush_profiler(config.logging)

In [12]:
def _set_random_seed(seed) -> None:
    """Set randomness seed in torch and numpy"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


In [13]:
def train_loop(local_rank: int, world_size: int, config: TrainerConfig, global_rank: int = 0):
    """Seed the random generators, build the trainer from ``config`` and train.

    Args:
        local_rank: rank of the current process, forwarded to ``config.setup``
        world_size: total number of gpus available, forwarded to ``config.setup``
        config: config file specifying training regimen
        global_rank: added to ``config.machine.seed`` to form this process's seed
    """
    process_seed = config.machine.seed + global_rank
    _set_random_seed(process_seed)

    # config.setup() builds the trainer object; its own setup() must run before train().
    runner = config.setup(local_rank=local_rank, world_size=world_size)
    runner.setup()
    runner.train()

In [14]:
# Start from the registered "nerfacto" method configuration.
config = all_methods["nerfacto"]

# Label this run with the current time.
config.timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")

# Point both the top-level config and the datamanager at the same dataset.
scene_dir = Path("data/nerfstudio/poster")
config.data = scene_dir
config.pipeline.datamanager.data = scene_dir

# Forward the machine settings stored on the config to launch().
machine_cfg = config.machine
launch(
    main_func=train_loop,
    num_devices_per_machine=machine_cfg.num_devices,
    device_type=machine_cfg.device_type,
    num_machines=machine_cfg.num_machines,
    machine_rank=machine_cfg.machine_rank,
    dist_url=machine_cfg.dist_url,
    config=config,
)

Output()

Output()



Step (% Done)       Train Iter (time)    ETA (time)           
--------------------------------------------------------------
Step (% Done)       Train Iter (time)    ETA (time)                                                  [0m
--------------------------------------------------------------                                       [0m
0 (0.00%)           2 s, 386.210 ms      19 h, 53 m, 6 s                                             [0m
---------------------------------------------------------------------------------------------------- [0m
[6;30;42mViewer at: https://viewer.nerf.studio/versions/23-05-15-1/?websocket_url=ws://localhost:7007         [0m
Step (% Done)       Train Iter (time)    ETA (time)           Train Rays / Sec                       [0m
-----------------------------------------------------------------------------------                  [0m
0 (0.00%)           2 s, 386.210 ms      19 h, 53 m, 6 s                                             [0m
10 (0.03%)      