Skip to content

Commit

Permalink
Merge pull request #432 from ReactionMechanismGenerator/scan_fixes
Browse files Browse the repository at this point in the history
Misc. scan fixes
  • Loading branch information
alongd committed Oct 21, 2020
2 parents 4577962 + 2c82231 commit 20e68d2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
29 changes: 29 additions & 0 deletions arc/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,8 @@ def plot_2d_rotor_scan(results: dict,
if len(results['scans']) != 2:
raise InputError(f'results must represent a 2D rotor, got {len(results["scans"])}D')

results['directed_scan'] = clean_scan_results(results['directed_scan'])

# phis0 and phis1 correspond to columns and rows in energies, respectively
phis0 = np.array(sorted(list(set([float(key[0]) for key in results['directed_scan'].keys()]))), np.float64)
phis1 = np.array(sorted(list(set([float(key[1]) for key in results['directed_scan'].keys()]))), np.float64)
Expand Down Expand Up @@ -1258,3 +1260,30 @@ def save_nd_rotor_yaml(results, path):
elif key == 'xyz' and not isinstance(val, str):
modified_results['directed_scan'][dihedral_tuple][key] = xyz_to_str(val)
save_yaml_file(path=path, content=modified_results)


def clean_scan_results(results: dict) -> dict:
"""
Filter noise of high energy points if the value distribution is such that removing the top 10% points
results in values which are significantly lower. Useful for scanning methods which occasionally give
extremely high energies by mistake.
Args:
results (dict): The directed snan results dictionary. Keys are dihedral tuples, values are energies.
Returns:
dict: A filtered results dictionary.
"""
results_ = results.copy()
for val in results_.values():
val['energy'] = float(val['energy'])
min_val = min([val['energy'] for val in results_.values()])
for val in results_.values():
val['energy'] = val['energy'] - min_val
max_val = max([val['energy'] for val in results_.values()])
cut_down_values = [val['energy'] for val in results_.values() if val['energy'] < 0.9 * max_val]
max_cut_down_values = max(cut_down_values)
if max_cut_down_values < 0.5 * max_val:
# filter high values
results_ = {key: val for key, val in results_.items() if val['energy'] < 0.5 * max_val}
return results_
22 changes: 22 additions & 0 deletions arc/plotterTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,28 @@ def test_log_bde_report(self):
"""
self.assertEqual(content, expected_content)

def test_clean_scan_results(self):
"""Test the clean_scan_results function"""
correct_results = {(1, 1): {'energy': 0},
(1, 2): {'energy': 7},
(1, 3): {'energy': 4.5},
(1, 4): {'energy': 5}}

results_1 = {(1, 1): {'energy': -2},
(1, 2): {'energy': '5'},
(1, 3): {'energy': 2.5},
(1, 4): {'energy': 3}}
filtered_results_1 = plotter.clean_scan_results(results_1)
self.assertEqual(filtered_results_1, correct_results)

results_2 = {(1, 1): {'energy': '-2'},
(1, 2): {'energy': 5},
(1, 3): {'energy': 2.5},
(1, 4): {'energy': 3},
(1, 5): {'energy': 1100}}
filtered_results_2 = plotter.clean_scan_results(results_2)
self.assertEqual(filtered_results_2, correct_results)

@classmethod
def tearDownClass(cls):
"""A function that is run ONCE after all unit tests in this class."""
Expand Down

0 comments on commit 20e68d2

Please sign in to comment.