-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
106 lines (85 loc) · 3.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from models.module import SitsScdModel
import hydra
import wandb
from os.path import isfile, join
from shutil import copyfile
from omegaconf import OmegaConf
from hydra.utils import instantiate
from pytorch_lightning.callbacks import LearningRateMonitor
from lightning_fabric.utilities.rank_zero import _get_rank
# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in hydra:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
# NOTE(review): the resolver is Python's built-in `eval`, so any
# `${eval:...}` expression in a config file or command-line override is
# executed as arbitrary Python code. Only run this script with trusted configs.
OmegaConf.register_new_resolver("eval", eval)
def wandb_init(cfg):
    """Return the wandb run id for this experiment, creating one if needed.

    The id is stored in ``wandb_id.txt`` inside ``cfg.checkpoints.dirpath``.
    If that file exists and holds a non-empty id, the id is reused. Otherwise
    a fresh id is generated and written to the file by rank 0 (or by a
    process whose rank is ``None``).

    Args:
        cfg: Config object exposing ``cfg.checkpoints.dirpath``.

    Returns:
        The wandb run id as a string, without surrounding whitespace.
    """
    id_path = join(cfg.checkpoints.dirpath, "wandb_id.txt")
    if isfile(id_path):
        with open(id_path, "r") as f:
            # Strip the line terminator so a hand-edited file (or one saved
            # with a trailing newline) does not leak "\n" into the run id.
            wandb_id = f.readline().strip()
        if wandb_id:
            return wandb_id
        # An empty id file carries no usable id: fall through and regenerate.

    rank = _get_rank()
    # NOTE(review): every process generates its own id here, and only the
    # writer's id is persisted. Confirm that ranks other than 0 are not
    # expected to share the persisted id on the very first launch.
    wandb_id = wandb.util.generate_id()
    print(f"Generated wandb id: {wandb_id}")
    if rank == 0 or rank is None:
        with open(id_path, "w") as f:
            f.write(str(wandb_id))
    return wandb_id
def load_model(cfg, dict_config, wandb_id, callbacks):
    """Build the trainer and model, resuming from ``last.ckpt`` when present.

    If ``last.ckpt`` exists in ``cfg.checkpoints.dirpath``, the model is
    restored from it and its path is returned so the caller can resume
    training. Otherwise a new model is created from ``cfg.model`` and the
    ``model`` and ``dataset`` sections of ``dict_config`` are attached to the
    logger as the run config.

    Args:
        cfg: Config object with ``checkpoints``, ``logger``, ``model`` and
            ``trainer`` sections.
        dict_config: Plain-container version of the config; its ``"model"``
            and ``"dataset"`` entries are logged for new runs.
        wandb_id: Run id forwarded to the logger.
        callbacks: Callbacks forwarded to the trainer.

    Returns:
        A ``(trainer, model, ckpt_path)`` tuple; ``ckpt_path`` is ``None``
        when no checkpoint was found.
    """
    last_ckpt = join(cfg.checkpoints.dirpath, "last.ckpt")
    resuming = isfile(last_ckpt)

    # The logger is created identically for resumed and new runs.
    logger = instantiate(cfg.logger, id=wandb_id, resume="allow")

    if resuming:
        model = SitsScdModel.load_from_checkpoint(last_ckpt, cfg=cfg.model)
        ckpt_path = last_ckpt
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        ckpt_path = None
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        # NOTE(review): relies on the logger exposing a private `_wandb_init`
        # dict - confirm cfg.logger always targets a logger that has it.
        logger._wandb_init.update({"config": log_dict})
        model = SitsScdModel(cfg.model)

    trainer = instantiate(
        cfg.trainer,
        strategy=cfg.trainer.strategy,
        logger=logger,
        callbacks=callbacks,
    )
    return trainer, model, ckpt_path
def project_init(cfg):
    """Prepare the checkpoint directory and keep a copy of the run config there.

    Prints the current working directory, creates ``cfg.checkpoints.dirpath``
    if it does not exist, and copies ``.hydra/config.yaml`` (looked up
    relative to the current working directory) into it as ``config.yaml``.

    Args:
        cfg: Config object exposing ``cfg.checkpoints.dirpath``.
    """
    print(f"Working directory set to {os.getcwd()}")
    checkpoint_dir = cfg.checkpoints.dirpath
    os.makedirs(checkpoint_dir, exist_ok=True)
    config_copy = join(checkpoint_dir, "config.yaml")
    copyfile(".hydra/config.yaml", config_copy)
def callback_init(cfg):
    """Instantiate the training callbacks described by the config.

    Two checkpoint callbacks are built from ``cfg.checkpoints``: one whose
    ``monitor`` and ``filename`` carry an ``"_out"`` suffix, then one with an
    ``"_in"`` suffix. A progress bar from ``cfg.progress_bar`` and a
    ``LearningRateMonitor`` complete the list.

    Note that ``cfg.checkpoints`` is modified in place: after this call its
    ``monitor`` and ``filename`` entries hold the ``"_in"``-suffixed values.

    Args:
        cfg: Config object with ``checkpoints`` and ``progress_bar`` sections.

    Returns:
        ``[checkpoint_out, checkpoint_in, progress_bar, lr_monitor]``.
    """
    base_monitor = cfg.checkpoints["monitor"]
    base_filename = cfg.checkpoints["filename"]

    checkpoint_callbacks = []
    # Order matters: "_out" first, "_in" last, so the config ends up "_in".
    for suffix in ("_out", "_in"):
        cfg.checkpoints["monitor"] = base_monitor + suffix
        cfg.checkpoints["filename"] = base_filename + suffix
        checkpoint_callbacks.append(instantiate(cfg.checkpoints))

    return [
        *checkpoint_callbacks,
        instantiate(cfg.progress_bar),
        LearningRateMonitor(),
    ]
def init_datamodule(cfg):
    """Instantiate and return the datamodule described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
def hydra_boilerplate(cfg):
    """Run all setup steps and return the objects needed to train or evaluate.

    The steps run in this order: convert the config to a resolved plain
    container, build the callbacks, build the datamodule, prepare the
    checkpoint directory, obtain the wandb run id, then build the trainer
    and model.

    Args:
        cfg: The Hydra/OmegaConf config for the run.

    Returns:
        A ``(trainer, model, datamodule, ckpt_path)`` tuple.
    """
    # Snapshot the resolved config first: callback_init edits cfg.checkpoints.
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
    callback_list = callback_init(cfg)
    data = init_datamodule(cfg)
    project_init(cfg)
    run_id = wandb_init(cfg)
    trainer, model, ckpt_path = load_model(cfg, resolved_cfg, run_id, callback_list)
    return trainer, model, data, ckpt_path
@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    """Train or evaluate the model according to ``cfg.mode``.

    ``"train"`` calls ``trainer.fit`` (resuming from ``ckpt_path`` when one
    was found); ``"eval"`` calls ``trainer.test``.

    Args:
        cfg: Config supplied by the ``@hydra.main`` decorator.

    Raises:
        ValueError: If ``cfg.mode`` is neither ``"train"`` nor ``"eval"``.
    """
    mode = cfg.mode
    # Fail fast: without this check an unknown mode would build every object
    # (directories, run id, model) and then silently do nothing.
    if mode not in ("train", "eval"):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'train' or 'eval'."
        )

    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule
    if mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    else:
        trainer.test(model, datamodule=datamodule)
# Script entry point: main() is wrapped by @hydra.main, which supplies `cfg`,
# so it is called here without arguments.
if __name__ == "__main__":
    main()