-
Notifications
You must be signed in to change notification settings - Fork 0
/
grover.py
93 lines (74 loc) · 2.63 KB
/
grover.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
# Author: Yinghao Li
# Modified: April 10th, 2024
# ---------------------------------------
# Description: Run the uncertainty quantification experiments
#              with the GROVER backbone model.
"""
import os
import sys
import logging
from rdkit import RDLogger
from datetime import datetime
from transformers import set_seed, HfArgumentParser
from muben.utils.io import set_logging, set_log_path
from muben.model import GROVER
from muben.dataset import DatasetGrover, CollatorGrover
from muben.args import ArgumentsGrover as ArgumentsGrover, ConfigGrover as ConfigGrover
from muben.train import TrainerGrover
# Module-level logger; the script entry point uses it to record exceptions raised in deploy mode.
logger = logging.getLogger(__name__)
def main(args: ArgumentsGrover):
    """
    Run one experiment with the GROVER backbone from start to finish.

    The configuration is built from ``args``, the train/valid/test datasets
    are prepared, and a ``TrainerGrover`` is initialized and run.

    Parameters
    ----------
    args : ArgumentsGrover
        Parsed arguments used to populate the configuration.

    Returns
    -------
    None
    """
    # --- construct and validate configuration ---
    config = ConfigGrover().from_args(args)
    config = config.get_meta()
    config = config.validate()
    config = config.log()

    # --- prepare dataset ---
    # Each partition is paired with the config attribute naming its subset-ids file.
    partition_specs = (
        ("train", "training_subset_ids_file_name"),
        ("valid", "valid_subset_ids_file_name"),
        ("test", "test_subset_ids_file_name"),
    )
    datasets = {}
    for partition, subset_attr in partition_specs:
        datasets[partition] = DatasetGrover().prepare(
            config=config,
            partition=partition,
            subset_ids_file_name=getattr(config, subset_attr),
        )

    # --- initialize trainer ---
    trainer = TrainerGrover(
        config=config,
        model_class=GROVER,
        training_dataset=datasets["train"],
        valid_dataset=datasets["valid"],
        test_dataset=datasets["test"],
        collate_fn=CollatorGrover(config),
    )
    trainer = trainer.initialize(config=config)

    # --- run training and testing ---
    trainer.run()
if __name__ == "__main__":
    # Start-up timestamp; handed to `set_log_path` when no log path is provided.
    _time = datetime.now().strftime("%m.%d.%y-%H.%M")

    # --- set up arguments ---
    parser = HfArgumentParser(ArgumentsGrover)

    # A single command-line argument may name a JSON/YAML config file;
    # anything else is parsed as regular command-line flags.
    cli_args = sys.argv[1:]
    config_file = cli_args[0] if len(cli_args) == 1 else ""
    if config_file.endswith(".json"):
        (arguments,) = parser.parse_json_file(os.path.abspath(config_file))
    elif config_file.endswith((".yaml", ".yml")):
        (arguments,) = parser.parse_yaml_file(os.path.abspath(config_file))
    else:
        (arguments,) = parser.parse_args_into_dataclasses()

    # --- set up logging ---
    # Fall back to a generated log path when none is set (missing or empty).
    if not getattr(arguments, "log_path", None):
        arguments.log_path = set_log_path(arguments, _time)
    set_logging(log_path=arguments.log_path)

    # suppress rdkit logger
    RDLogger.logger().setLevel(RDLogger.CRITICAL)

    set_seed(arguments.seed)

    # --- run ---
    if not arguments.deploy:
        main(args=arguments)
    else:
        # In deploy mode, exceptions raised by `main` are logged instead of propagating.
        try:
            main(args=arguments)
        except Exception as err:
            logger.exception(err)