Skip to content

Commit

Permalink
hessian matrix now working for direct predictions with a patch calcul…
Browse files Browse the repository at this point in the history
…ator
  • Loading branch information
OMalenfantThuot committed Apr 13, 2022
1 parent c61a629 commit af13aa8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
38 changes: 29 additions & 9 deletions mlcalcdriver/calculators/schnetpack_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ def run(
), "Use the PatchSPCalculator for one configuration at a time."
atoms = posinp_to_ase_atoms(posinp[0])

if property == "hessian" and any(self.subgrid == 2):
raise warnings.warn(
"""
The hessian matrix can have some bad values with a grid of
size 2 because the same atom can be copied multiple times
in the buffers of the same subcell. Use a larger grid.
"""
)

init_property, out_name, derivative, wrt = get_derivative_names(
property, self.available_properties
)
Expand All @@ -104,9 +113,12 @@ def run(
at_to_patches = AtomsToPatches(
cutoff=self.cutoff, n_interaction=self.n_interaction, grid=self.subgrid
)
subcells, subcells_main_idx, original_cell_idx = at_to_patches.split_atoms(
atoms
)
(
subcells,
subcells_main_idx,
original_cell_idx,
complete_subcell_copy_idx,
) = at_to_patches.split_atoms(atoms)

# Pass each subcell independantly
results = []
Expand Down Expand Up @@ -176,12 +188,17 @@ def run(
(
hessian_original_cell_idx_0,
hessian_original_cell_idx_1,
) = prepare_hessian_indices(original_cell_idx[i])
) = prepare_hessian_indices(
original_cell_idx[i], complete_subcell_copy_idx[i]
)

(
hessian_subcells_main_idx_0,
hessian_subcells_main_idx_1,
) = prepare_hessian_indices(subcells_main_idx[i])
) = prepare_hessian_indices(
subcells_main_idx[i],
np.arange(0, len(complete_subcell_copy_idx[i])),
)

hessian[hessian_original_cell_idx_0, hessian_original_cell_idx_1] = (
results[i]["hessian"]
Expand Down Expand Up @@ -225,8 +242,11 @@ def _convert_model(self):
self.model = patches_model.to(self.device)


def prepare_hessian_indices(input_idx):
bias = np.tile(np.array([0, 1, 2]), len(input_idx))
hessian_idx = np.repeat(3 * input_idx, 3) + bias
idx_0, idx_1 = np.meshgrid(hessian_idx, hessian_idx, indexing="ij")
def prepare_hessian_indices(input_idx_0, input_idx_1):

bias_0 = np.tile(np.array([0, 1, 2]), len(input_idx_0))
bias_1 = np.tile(np.array([0, 1, 2]), len(input_idx_1))
hessian_idx_0 = np.repeat(3 * input_idx_0, 3) + bias_0
hessian_idx_1 = np.repeat(3 * input_idx_1, 3) + bias_1
idx_0, idx_1 = np.meshgrid(hessian_idx_0, hessian_idx_1, indexing="ij")
return idx_0, idx_1
23 changes: 20 additions & 3 deletions mlcalcdriver/interfaces/atoms_to_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def split_atoms(self, atoms):
raise ValueError("The grid is too fine to use with this buffer.")

# Add initial buffer around the supercell
buffered_atoms = add_initial_buffer(atoms, full_scaled_buffer_length, full_cell)
buffered_atoms, copy_idx = add_initial_buffer(
atoms, full_scaled_buffer_length, full_cell
)

# Define grid indexes
dim0, dim1, dim2 = (
Expand All @@ -108,6 +110,7 @@ def split_atoms(self, atoms):
subcell_as_atoms_list = []
main_subcell_idx_list = []
original_atoms_idx_list = []
complete_subcell_copy_idx_list = []

for i, subcell in enumerate(subcells_idx):
buffered_subcell_min = subcell - grid_scaled_buffer_length
Expand All @@ -123,6 +126,8 @@ def split_atoms(self, atoms):
)
)[0]

complete_subcell_copy_idx = copy_idx[buffered_subcell_atoms_idx]

main_subcell_idx = np.where(
np.all(
np.floor(
Expand All @@ -139,6 +144,7 @@ def split_atoms(self, atoms):
subcell_as_atoms_list.append(buffered_atoms[buffered_subcell_atoms_idx])
main_subcell_idx_list.append(main_subcell_idx)
original_atoms_idx_list.append(buffered_subcell_atoms_idx[main_subcell_idx])
complete_subcell_copy_idx_list.append(complete_subcell_copy_idx)

# Returns:
# 1) a list of atoms instances (subcells)
Expand All @@ -147,7 +153,12 @@ def split_atoms(self, atoms):
# 3) a list of the original index of the atoms
# to map back per atom predicted properties
# to the original configuration.
return subcell_as_atoms_list, main_subcell_idx_list, original_atoms_idx_list
return (
subcell_as_atoms_list,
main_subcell_idx_list,
original_atoms_idx_list,
complete_subcell_copy_idx_list,
)


def add_initial_buffer(atoms, scaled_buffer_length, full_cell):
Expand All @@ -159,6 +170,7 @@ def add_initial_buffer(atoms, scaled_buffer_length, full_cell):
in_buff = in_buff_low - in_buff_high

# Look at all possible permutations
copy_idx = [i for i in range(len(atoms))]
for i in range(init_scaled_positions.shape[0]):
non_zero_dimensions = np.sum(np.absolute(in_buff[i]))
x, y, z = in_buff[i]
Expand All @@ -177,26 +189,31 @@ def add_initial_buffer(atoms, scaled_buffer_length, full_cell):
atoms = copy_atom_with_translation(
atoms, i, (translation * dim).dot(full_cell)
)
copy_idx.append(i)
if non_zero_dimensions >= 2:
if x != 0:
if y != 0:
atoms = copy_atom_with_translation(
atoms, i, np.array([x, y, 0]).dot(full_cell)
)
copy_idx.append(i)
if z != 0:
atoms = copy_atom_with_translation(
atoms, i, np.array([x, 0, z]).dot(full_cell)
)
copy_idx.append(i)
else:
atoms = copy_atom_with_translation(
atoms, i, np.array([0, y, z]).dot(full_cell)
)
copy_idx.append(i)
if non_zero_dimensions == 3:
atoms = copy_atom_with_translation(
atoms, i, np.array([x, y, z]).dot(full_cell)
)
copy_idx.append(i)

return atoms
return atoms, np.array(copy_idx)


def copy_atom_with_translation(atoms, idx, translation):
Expand Down

0 comments on commit af13aa8

Please sign in to comment.