-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_gnn_scannet.py
45 lines (37 loc) · 1.4 KB
/
train_gnn_scannet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os, sys
from pytorch_lightning import Trainer
from pytorch_lightning.logging import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from utils.config import load_config
from unet3d.lightning_model_scannet import Unet3DGNNPartnetLightning
def main(args):
    """Train the Unet3DGNNPartnet model on ScanNet with PyTorch Lightning.

    Args:
        args: Command-line argument list (``sys.argv[1:]`` when run as a
            script), forwarded unchanged to ``utils.config.load_config``.

    Raises:
        ValueError: if ``config.model`` is not a model this script can build.
            Previously an unknown name left ``model`` unbound and the run
            crashed later in ``trainer.fit`` with an ``UnboundLocalError``,
            after the checkpoint directory had already been created.
    """
    config = load_config(args)

    # Validate the requested model up front so an unsupported name does not
    # leave behind empty log/checkpoint directories.
    if config.model != 'Unet3DGNNPartnet':
        raise ValueError(
            "Unsupported model {!r}; this script only trains "
            "'Unet3DGNNPartnet'".format(config.model)
        )

    tb_logger = TensorBoardLogger(config.checkpoint_dir,
                                  name=config.model,
                                  version=config.version)

    # Checkpoints are written to <checkpoint_dir>/<model>/<version>/checkpoints.
    # NOTE(review): assumes config.version is a str. os.path.join rejects
    # non-str parts, and TensorBoardLogger names its directory "version_<n>"
    # for int versions, so the two paths would diverge. Confirm against the
    # config schema.
    checkpoint_dir = os.path.join(config.checkpoint_dir, config.model,
                                  config.version, 'checkpoints')
    checkpoint_callback = ModelCheckpoint(
        filepath=checkpoint_dir,
        # Retain up to 100 checkpoints ranked by the callback's monitored metric.
        save_top_k=100,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)

    model = Unet3DGNNPartnetLightning(config)

    trainer = Trainer(
        checkpoint_callback=checkpoint_callback,
        logger=tb_logger,
        early_stop_callback=False,  # train for the full max_epochs
        gpus=config.gpus,
        distributed_backend=config.distributed_backend,
        num_nodes=1,
        max_epochs=config.max_epochs,
        val_check_interval=config.val_check_interval,
        amp_level=config.amp_level,
        log_save_interval=10,
        fast_dev_run=False,
        # To resume training, add:
        #   resume_from_checkpoint=config.resume_from_checkpoint,
        # Accumulate gradients over 4 batches per optimizer step.
        accumulate_grad_batches=4,
    )
    trainer.fit(model)
if __name__ == '__main__':
    # Strip the program name and hand only the user-supplied arguments
    # to main(), which forwards them to the config loader.
    cli_args = sys.argv[1:]
    main(cli_args)