Skip to content

Commit

Permalink
Minor update aiida-phonopy
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Aug 30, 2020
1 parent 8898b76 commit b69bc71
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 29 deletions.
61 changes: 33 additions & 28 deletions aiida_phonopy/workflows/iter_ha.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def collect_dataset(number_of_steps_for_fitting,

d, f, energies = _extract_dataset_from_db(forces_in_db, ph_info_in_db)

displacements, forces, included = _create_dataset(
displacements, forces, included = create_dataset(
d, f, energies,
number_of_steps_for_fitting.value,
include_ratio.value,
linear_decay.value)
max_items=number_of_steps_for_fitting.value,
ratio=include_ratio.value,
linear_decay=linear_decay.value)
dataset = ArrayData()
dataset.set_array('forces', forces)
dataset.set_array('displacements', displacements)
Expand All @@ -134,6 +134,26 @@ def collect_dataset(number_of_steps_for_fitting,
return {'dataset': dataset, 'supercell_energies': supercell_energies}


def create_dataset(displacements, forces, energies,
max_items=None, ratio=None, linear_decay=False):
included = _choose_snapshots_by_linear_decay(
displacements, forces, max_items=max_items, linear_decay=linear_decay)

# Remove snapshots that have high energies when include_ratio is given.
if energies is not None and ratio is not None:
if 0 < ratio and ratio < 1:
included = _remove_high_energy_snapshots(energies, included, ratio)

_displacements, _forces, _energies = _include_snapshots(
displacements, forces, energies, included)

# Concatenate the data
d = np.concatenate(_displacements, axis=0)
f = np.concatenate(_forces, axis=0)

return d, f, included


def _extract_dataset_from_db(forces_in_db, ph_info_in_db):
nitems = len(forces_in_db)
displacements = []
Expand All @@ -154,28 +174,9 @@ def _extract_dataset_from_db(forces_in_db, ph_info_in_db):
return displacements, forces, energies


def _create_dataset(displacements, forces, energies,
max_items, ratio, linear_decay):
included = _choose_snapshots_by_linear_decay(
displacements, forces, max_items, linear_decay=linear_decay)

# Remove snapshots that have high energies when include_ratio is given.
if energies is not None and ratio is not None:
if 0 < ratio and ratio < 1:
included = _remove_high_energy_snapshots(energies, included, ratio)

_displacements, _forces, _energies = _include_snapshots(
displacements, forces, energies, included)

# Concatenate the data
d = np.concatenate(_displacements, axis=0)
f = np.concatenate(_forces, axis=0)

return d, f, included


def _choose_snapshots_by_linear_decay(displacements, forces, max_items,
linear_decay=True):
def _choose_snapshots_by_linear_decay(displacements, forces,
max_items=None,
linear_decay=False):
"""Choose snapshots by linear_decay
With linear_decay=True, numbers of snapshots to be taken
Expand All @@ -194,11 +195,15 @@ def _choose_snapshots_by_linear_decay(displacements, forces, max_items,
assert len(forces) == len(displacements)

nitems = len(forces)
if max_items is None:
_max_items = nitems
else:
_max_items = max_items

if linear_decay:
ratios = (np.arange(max_items, dtype=float) + 1) / max_items
ratios = (np.arange(_max_items, dtype=float) + 1) / _max_items
else:
ratios = np.ones(max_items, dtype=int)
ratios = np.ones(_max_items, dtype=int)
ratios = ratios[-nitems:]
included = []

Expand Down
7 changes: 6 additions & 1 deletion aiida_phonopy/workflows/phonopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,12 @@ def initialize(self):
self.ctx.supercells[label] = return_vals['supercell']

def run_force_and_nac_calculations(self):
self.report('run force calculations')
self._run_force_calculations()
self._run_nac_calculation()

def _run_force_calculations(self):
# Forces
self.report('run force calculations')
for key in self.ctx.supercells:
builder = get_calcjob_builder(self.ctx.supercells[key],
self.inputs.calculator_settings,
Expand All @@ -218,7 +221,9 @@ def run_force_and_nac_calculations(self):
self.report('{} pk = {}'.format(label, future.pk))
self.to_context(**{label: future})

def _run_nac_calculation(self):
# Born charges and dielectric constant
self.report('run nac calculation')
if self.is_nac():
self.report('calculate born charges and dielectric constant')
builder = get_calcjob_builder(self.ctx.primitive,
Expand Down

0 comments on commit b69bc71

Please sign in to comment.