# process_mlrs.py
"""
This script processes ML relaxations and sets them up for the next step.
- Reads final energy and structure for each relaxation
- Filters out anomalies
- Groups together all configurations for one adsorbate-surface system
- Sorts configs by lowest energy first
The following files are saved out:
- cache_sorted_byE.pkl: dict going from the system ID (bulk, surface, adsorbate)
to a list of configs and their relaxed structures, sorted by lowest energy first.
This is later used by write_top_k_vasp.py.
- anomalies_by_sid.pkl: dict going from integer sid to boolean representing
whether it was an anomaly. Anomalies are already excluded from cache_sorted_byE.pkl
and this file is only used for extra analyses.
- errors_by_sid.pkl: dict going from sid to a (system ID, config ID, error message)
tuple for every relaxation that could not be processed.
"""
import argparse
import multiprocessing as mp
import os
import pickle
from collections import defaultdict
import numpy as np
from ase.io import read
from ocdata.test.flag_anomaly import DetectTrajAnomaly
from tqdm import tqdm
# Cutoff multipliers handed to DetectTrajAnomaly in process_mlrs as its
# surface_change_cutoff_multiplier and desorption_cutoff_multiplier arguments.
SURFACE_CHANGE_CUTOFF_MULTIPLIER = 1.5
DESORPTION_CUTOFF_MULTIPLIER = 1.5
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with the
        attributes ``ml_trajs_path`` (required), ``outdir``, ``workers``,
        ``fmax``, ``metadata`` and ``surface_dir``.
    """
    cli = argparse.ArgumentParser(
        description="Process ml relaxations and group them by adsorbate-surface system"
    )

    # One (flag, keyword-arguments) entry per option, registered in order.
    option_table = (
        (
            "--ml-trajs-path",
            {
                "type": str,
                "required": True,
                "help": "ML relaxation trajectories folder path",
            },
        ),
        (
            "--outdir",
            {"type": str, "default": "cache", "help": "Output directory path"},
        ),
        (
            "--workers",
            {
                "type": int,
                "default": 80,
                "help": "Number of workers for multiprocessing",
            },
        ),
        ("--fmax", {"type": float, "default": 0.02}),
        (
            "--metadata",
            {"type": str, "help": "Path to mapping of sid to metadata"},
        ),
        (
            "--surface-dir",
            {"type": str, "help": "Path to surface DFT outputs"},
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def min_diff(atoms_init, atoms_final):
    """Return per-atom displacements between two structures, honoring PBC.

    The raw displacement ``atoms_final.positions - atoms_init.positions`` is
    expressed in fractional coordinates of ``atoms_init``'s cell. Along every
    periodic axis it is wrapped into the interval (-0.5, 0.5] so the shortest
    periodic image is used. Components along non-periodic axes are returned
    unchanged.

    Args:
        atoms_init: Reference structure; must expose ``positions``, ``pbc``
            and ``get_cell(complete=True)``.
        atoms_final: Structure to compare; must expose ``positions`` with the
            same shape as ``atoms_init.positions``.

    Returns:
        Array of Cartesian displacement vectors, one row per atom.
    """
    cell = atoms_init.get_cell(complete=True)
    displacement = atoms_final.positions - atoms_init.positions
    # Solve cell.T @ f = d to express each displacement in fractional units.
    fractional = np.linalg.solve(cell.T, displacement.T).T

    for axis, periodic in enumerate(atoms_init.pbc):
        if not periodic:
            # No periodic image exists along this axis, so leave it untouched.
            continue
        # Yes, we need to do it twice.
        # See the scaled_positions.py test.
        fractional[:, axis] %= 1.0
        fractional[:, axis] %= 1.0
        # Map [0, 1) onto (-0.5, 0.5]: pick the nearest periodic image.
        fractional[fractional[:, axis] > 0.5, axis] -= 1

    return np.matmul(fractional, cell)
def process_mlrs(arg):
    """Load one ML relaxation, pick its relaxed frame and flag anomalies.

    Reads the module-level globals ``args`` (for ``ml_trajs_path``), ``fmax``
    and ``SURFACE_DIR``, which are assigned in the ``__main__`` block.

    Args:
        arg: ``(sid, metadata)`` pair, where ``metadata`` provides the keys
            ``"system_id"`` and ``"config_id"``.

    Returns:
        list: ``[sid, system_id, config_id, energy, atoms, anomaly, error_msg]``.
        If the trajectory cannot be parsed or the clean-surface file is
        missing, ``energy`` and ``atoms`` are ``None``, ``anomaly`` is ``True``
        and ``error_msg`` describes the problem. Otherwise ``error_msg`` is
        ``None``.

    Raises:
        AssertionError: If the initial frame carries no non-zero tags, or if
            the initial adslab and the clean surface are not ordered
            consistently.
    """
    # for each ML trajectory, run anomaly detection and get relaxed energy
    sid, metadata = arg
    system_id = metadata["system_id"]
    adslab_idx = metadata["config_id"]

    try:
        traj = read(f"{args.ml_trajs_path}/{sid}.traj", ":")
        init_atoms, final_atoms = traj[0], traj[-1]
        if fmax:
            # Use the first frame whose largest per-atom force norm is within
            # fmax; fall back to the last frame if none qualifies.
            for atoms in traj:
                forces = atoms.get_forces()
                _fmax = max(np.sqrt((forces**2).sum(axis=1)))
                if _fmax <= fmax:
                    final_atoms = atoms
                    break
        final_energy = final_atoms.get_potential_energy()
    except Exception:
        # Deliberately broad: any failure while reading this trajectory is
        # reported as an error record. KeyboardInterrupt/SystemExit are not
        # caught, so a worker can still be interrupted.
        error_msg = f"Error parsing traj: {sid}.traj"
        return [sid, system_id, adslab_idx, None, None, True, error_msg]

    surface_id = system_id + "_surface.traj"
    dft_slab_path = os.path.join(SURFACE_DIR, system_id, surface_id)
    if not os.path.isfile(dft_slab_path):
        error_msg = f"Surface {surface_id} unavailable."
        return [sid, system_id, adslab_idx, None, None, True, error_msg]
    slab_traj = read(dft_slab_path, ":")

    tags = init_atoms.get_tags()
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not sum(tags) > 0:  # make sure tag info exists
        raise AssertionError(f"No tag information found for sid {sid}")

    # Verify adslab and slab are ordered consistently before anomaly detection
    # This checks that the positions of the initial adslab and clean surface
    # are approximately equivalent.
    # Presumably tag 2 marks adsorbate atoms; they are dropped before comparing.
    # NOTE(review): this sums signed displacement components, so offsets of
    # opposite sign can cancel each other out.
    diff = (min_diff(init_atoms[tags != 2], slab_traj[0])).sum()
    # ML trajectories are saved out after 1 optimization step, so some movement
    # is expected. A cushion of 0.5A is used based off the maximum difference
    # previously measured for sample trajectories.
    if not abs(diff) < 0.5:
        raise AssertionError(
            f"Adslab and clean surface positions differ for sid {sid}: {diff}"
        )

    detector = DetectTrajAnomaly(
        init_atoms,
        final_atoms,
        atoms_tag=tags,
        final_slab_atoms=slab_traj[-1],
        surface_change_cutoff_multiplier=SURFACE_CHANGE_CUTOFF_MULTIPLIER,
        desorption_cutoff_multiplier=DESORPTION_CUTOFF_MULTIPLIER,
    )
    anom = (
        detector.is_adsorbate_dissociated()
        or detector.is_adsorbate_desorbed()
        or detector.has_surface_changed()
    )
    return [sid, system_id, adslab_idx, final_energy, final_atoms, anom, None]
if __name__ == "__main__":
    # Script entry point: run process_mlrs over every sid in the metadata
    # mapping, drop errored and anomalous relaxations, group the rest by
    # system, sort each group by predicted energy and pickle the results.
    args = parse_args()
    os.makedirs(args.outdir, exist_ok=True)

    # process_mlrs reads ``args``, ``fmax`` and ``SURFACE_DIR`` as module-level
    # globals. NOTE(review): they are only assigned here, so worker processes
    # see them only if they inherit this module's state (e.g. the fork start
    # method) -- confirm before running on platforms that spawn workers.
    fmax = args.fmax
    METADATA = args.metadata
    SURFACE_DIR = args.surface_dir

    # NOTE: unpickling can execute arbitrary code; only load trusted metadata.
    with open(METADATA, "rb") as metadata_file:
        metadata_by_sid = pickle.load(metadata_file)

    mp_args = list(metadata_by_sid.items())
    print("Processing ML trajectories...")
    # The context manager releases the worker processes once every result has
    # been collected.
    with mp.Pool(args.workers) as pool:
        results = list(tqdm(pool.imap(process_mlrs, mp_args), total=len(mp_args)))

    # process each individual trajectory
    grouped_configs = defaultdict(list)
    anomalies = {}
    errored_sysids = {}
    for result in tqdm(results):
        sid, system, adslab_idx, predE, mlrs, anomaly, error_msg = result
        if predE is None or mlrs is None:
            errored_sysids[sid] = (system, adslab_idx, error_msg)
            continue
        anomalies[sid] = anomaly
        if not anomaly:
            grouped_configs[system].append((adslab_idx, predE, mlrs))

    # group configs by system and sort
    sorted_grouped_configs = {}
    for system, lst in grouped_configs.items():
        # Lowest predicted energy first; keep only (config id, structure).
        sorted_lst = sorted(lst, key=lambda x: x[1])
        sorted_grouped_configs[system] = [(x[0], x[2]) for x in sorted_lst]

    with open(f"{args.outdir}/cache_sorted_byE.pkl", "wb") as out_file:
        pickle.dump(sorted_grouped_configs, out_file)
    with open(f"{args.outdir}/anomalies_by_sid.pkl", "wb") as out_file:
        pickle.dump(anomalies, out_file)
    with open(f"{args.outdir}/errors_by_sid.pkl", "wb") as out_file:
        pickle.dump(errored_sysids, out_file)