-
Notifications
You must be signed in to change notification settings - Fork 10
/
test_roost_regression.py
142 lines (122 loc) · 3.86 KB
/
test_roost_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import numpy as np
import torch
from sklearn.model_selection import train_test_split as split
from aviary.roost.data import CompositionData, collate_batch
from aviary.roost.model import Roost
from aviary.utils import get_metrics, results_multitask, train_ensemble
def test_roost_regression(df_matbench_phonons):
    """Regression smoke test for Roost.

    Trains a small 2-model Roost ensemble on the matbench phonons dataset
    ("last phdos peak" target) and asserts that ensemble-averaged test-set
    metrics clear loose accuracy thresholds.
    """
    target_name = "last phdos peak"
    task = "regression"
    task_dict = {target_name: task}
    loss_dict = {target_name: "L1"}

    model_name = "roost-reg-test"
    robust = True
    ensemble = 2
    run_id = 1
    epochs = 25
    batch_size = 128
    test_size = 0.2
    data_seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = CompositionData(
        df=df_matbench_phonons,
        elem_embedding="matscholar200",
        task_dict=task_dict,
    )

    all_idx = list(range(len(dataset)))
    print(f"using {test_size} of training set as test set")
    train_idx, test_idx = split(all_idx, random_state=data_seed, test_size=test_size)
    test_set = torch.utils.data.Subset(dataset, test_idx)

    print("No validation set used, using test set for evaluation purposes")
    # NOTE: with this shortcut, take care not to peek at the test set. The
    # only valid model to evaluate is the one from the final epoch, with the
    # epoch count fixed in advance of the experiment.
    val_set = test_set
    train_set = torch.utils.data.Subset(dataset, train_idx[0::1])  # sub-sampling stride 1 = full set

    data_params = {
        "batch_size": batch_size,
        "num_workers": 0,
        "pin_memory": False,
        "shuffle": True,
        "collate_fn": collate_batch,
    }
    setup_params = {
        "optim": "AdamW",
        "learning_rate": 3e-4,
        "weight_decay": 1e-6,
        "momentum": 0.9,
        "device": device,
    }
    restart_params = {"resume": False, "fine_tune": None, "transfer": None}
    model_params = {
        "task_dict": task_dict,
        "robust": robust,
        "n_targets": dataset.n_targets,
        "elem_emb_len": dataset.elem_emb_len,
        "elem_fea_len": 64,
        "n_graph": 3,
        "elem_heads": 2,
        "elem_gate": [256],
        "elem_msg": [256],
        "cry_heads": 2,
        "cry_gate": [256],
        "cry_msg": [256],
        "trunk_hidden": [256, 256],
        "out_hidden": [128, 64],
    }

    train_ensemble(
        model_class=Roost,
        model_name=model_name,
        run_id=run_id,
        ensemble_folds=ensemble,
        epochs=epochs,
        train_set=train_set,
        val_set=val_set,
        log=False,
        data_params=data_params,
        setup_params=setup_params,
        restart_params=restart_params,
        model_params=model_params,
        loss_dict=loss_dict,
    )

    data_params["batch_size"] = 64 * batch_size  # faster model inference
    data_params["shuffle"] = False  # need fixed data order due to ensembling

    results_dict = results_multitask(
        model_class=Roost,
        model_name=model_name,
        run_id=run_id,
        ensemble_folds=ensemble,
        test_set=test_set,
        data_params=data_params,
        robust=robust,
        task_dict=task_dict,
        device=device,
        eval_type="checkpoint",
        save_results=False,
    )

    preds = results_dict[target_name]["preds"]
    targets = results_dict[target_name]["targets"]
    ensemble_preds = np.mean(preds, axis=0)  # average over the ensemble folds
    mae, rmse, r2 = get_metrics(targets, ensemble_preds, task).values()

    assert r2 > 0.7
    assert mae < 150
    assert rmse < 300