diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 0d71e48e..43f125ba 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -11,7 +11,13 @@ from jaxopt import GradientDescent as minimize from pysages.colvars.core import CollectiveVariable -from pysages.utils import gaussian, quaternion_from_euler, quaternion_matrix +from pysages.utils import ( + gaussian, + identity, + quaternion_from_euler, + quaternion_matrix, + row_sum, +) def rotate_pattern_with_quaternions(rot_q, pattern): @@ -25,7 +31,7 @@ def func_to_optimise(Q, modified_pattern, local_pattern): # Main class implementing the GeM CV class Pattern: """ - For determining nearest neighbors, + For determining the nearest neighbors, [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html) neighborlist library is utilized. This requires the user to define the indices of all the atoms in the system and a JAX MD @@ -34,6 +40,7 @@ class Pattern: def __init__( self, + positions, simulation_box, fractional_coords, reference, @@ -42,8 +49,10 @@ def __init__( centre_j_id, standard_deviation, mesh_size, + number_of_added_sites=0, + width_of_switch_func=None, + scale_for_radial_distance=None, ): - self.characteristic_distance = characteristic_distance self.reference = reference self.neighborlist = neighborlist @@ -51,23 +60,36 @@ def __init__( self.centre_j_id = centre_j_id # This is added to handle neighborlists with fractional coordinates # (needed for NPT simulations) - if fractional_coords: - self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box) - else: - self.positions = self.neighborlist.reference_position + # if fractional_coords: + # self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box) + # else: + # self.positions = self.neighborlist.reference_position + self.positions = positions self.centre_j_coords = self.positions[self.centre_j_id] self.standard_deviation = standard_deviation self.mesh_size = mesh_size + # These settings are needed if continuous LoM is to be used + self.number_of_added_sites = number_of_added_sites + if self.number_of_added_sites > 0: + if width_of_switch_func is None: + self.width_of_switch_func = self.standard_deviation / 2 + else: + self.width_of_switch_func = width_of_switch_func + + if scale_for_radial_distance is None: + self.scale_for_radial_distance = 0.9 + else: + self.scale_for_radial_distance = scale_for_radial_distance + + self._neighborhood = [] def comp_pair_distance_squared(self, pos1): - displacement_fn, shift_fn = space.periodic(np.diag(self.simulation_box)) + displacement_fn, _ = space.periodic(np.diag(self.simulation_box)) mic_vector = displacement_fn(self.centre_j_coords, pos1) mic_norm = linalg.norm(mic_vector) return mic_norm, mic_vector def _generate_neighborhood(self): - self._neighborhood = [] - positions_of_all_nbrs = self.positions[self.neighborlist.idx[self.centre_j_id]] distances, mic_vectors = vmap(self.comp_pair_distance_squared)(positions_of_all_nbrs) # remove the same atom from the neighborhood @@ -79,6 +101,14 @@ def _generate_neighborhood(self): ids_of_neighbors = np.argsort(distances)[: len(self.reference)] + n_added_sites = self.number_of_added_sites + if n_added_sites > 0: + ids_of_neighbors_2nd_shell = ids_of_neighbors[-n_added_sites:] + self.shell_distance = self.scale_for_radial_distance * np.mean( + distances[ids_of_neighbors_2nd_shell] + ) + self._neighborhood_distances = distances[ids_of_neighbors] + coordinates = mic_vectors[ids_of_neighbors] + self.centre_j_coords # Step 1: Translate to origin; coordinates = coordinates.at[:].set(coordinates - np.mean(coordinates, axis=0)) @@ -95,9 +125,21 @@ def _generate_neighborhood(self): self._neighbor_coords = np.array([n["coordinates"] for n in self._neighborhood]) self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors] + def _switching_function(self, distance, width): + result = 0.5 * lax.erfc((distance - self.shell_distance) / width) + return result + def compute_score(self, optim_reference): r = self._neighbor_coords - optim_reference - return np.prod(gaussian(1, self.standard_deviation, r)) + std = self.standard_deviation + + if self.number_of_added_sites != 0: + width = self.width_of_switch_func + squared_dist = row_sum(r**2) + x = self._switching_function(self._neighborhood_distances, width) + return np.exp(-np.sum(x * squared_dist) / (2 * (std**2) * np.sum(x))) + + return np.prod(gaussian(1, std * np.sqrt(len(self.reference)), r)) def rotate_reference(self, random_euler_point): # Perform rotation of the reference pattern; @@ -147,22 +189,21 @@ def return_close(_, n): _, indices = lax.scan( lambda _, sites: ( None, - lax.cond(np.sum(sites) == 1, lambda s: s, lambda s: np.zeros_like(s), sites), + lax.cond(np.sum(sites) == 1, identity, np.zeros_like, sites), ), None, close_sites, ) # Return the locations of settled nighbours in the neighborhood; - # Settlled site should have a unique neighbor + # Settled site should have a unique neighbor settled_neighbor_indices = np.where(np.sum(indices, axis=0) >= 1, 1, 0) return settled_neighbor_indices def driver_match(self, number_of_rotations, number_of_opt_steps, num): - self._generate_neighborhood() - """Step2: Scale the reference so that the spread matches - with the current local pattern""" + # STEP 2: + # Scale the reference so that the spread matches with the current local pattern. local_distance = 0.0 reference_distance = 0.0 for n_index, neighbor in enumerate(self._neighborhood): @@ -171,17 +212,18 @@ def driver_match(self, number_of_rotations, number_of_opt_steps, num): self.reference *= np.sqrt(local_distance / reference_distance) - """Step3: mesh-loop -> Define angles in reduced Euler domain, - and for each rotate, resort and score the pattern - - The implementation below follows the article Martelli et al. 2018 - - - (a) Randomly with uniform probability pick a point in the Euler domain, - (b) Rotate the reference - (c) Resort the local pattern and assign the closest reference sites, - (d) Perform the optimisation step (conjugate gradient), - and (e) store the score with (f) the final settled status""" + # STEP 3: + # + # mesh-loop -> Define angles in reduced Euler domain, and for each rotate, + # resort and score the pattern. + # + # The implementation below follows the article Martelli et al. 2018 + # + # (a) Randomly with uniform probability pick a point in the Euler domain, + # (b) Rotate the reference + # (c) Resort the local pattern and assign the closest reference sites, + # (d) Perform the optimisation step (conjugate gradient), and + # (e) store the score with (f) the final settled status def get_all_scores(newkey, euler_point): # b. Rotate the reference pattern @@ -190,8 +232,7 @@ def get_all_scores(newkey, euler_point): # and assign ids to the closest reference sites newkey, newsubkey = random.split(random.PRNGKey(newkey)) reshuffled_reference, random_indices = self.resort(rotated_reference, newsubkey) - # d. Find the best rotation that aligns the settled sites - # in both patterns; + # d. Find the best rotation that aligns the settled sites in both patterns. # Here, ‘optimal’ or ‘best’ is in terms of least squares errors solver = minimize(fun=func_to_optimise, maxiter=number_of_opt_steps) # We are fixing the initial guess for the quaternions; @@ -217,7 +258,7 @@ def get_all_scores(newkey, euler_point): # a. Randomly pick a point in the Euler domain - key, subkey = random.split(random.PRNGKey(num)) + _, subkey = random.split(random.PRNGKey(num)) mesh_size = self.mesh_size grid_dimension = np.pi / mesh_size euler_angles = np.arange( @@ -256,14 +297,13 @@ def get_all_scores(newkey, euler_point): def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params): - if params.fractional_coords: update_neighborlist = neighborlist.update(np.divide(all_positions, np.diag(simulation_box))) else: update_neighborlist = neighborlist.update(all_positions) - """Step1: Move the reference and - local patterns so that their centers coincide with the origin""" + # STEP 1: + # Move the reference and local patterns so that their centers coincide with the origin. reference_positions = params.reference_positions.at[:].set( params.reference_positions - np.mean(params.reference_positions, axis=0) @@ -273,6 +313,7 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params) seed = np.int64(time.process_time() * 1e5) optimal_results = vmap( lambda i: Pattern( + all_positions, params.box, params.fractional_coords, reference_positions, @@ -281,6 +322,9 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params) i, params.standard_deviation, params.mesh_size, + params.number_of_added_sites, + params.width_of_switch_func, + params.scale_for_radial_distance, ).driver_match( params.number_of_rotations, params.number_of_opt_it, @@ -298,14 +342,14 @@ class GeM(CollectiveVariable): an atomic or a molecular site is described in [Martelli2018](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.97.064105). - Given a pattern, the algorithm is returning an average score (from 0 to 1), + Given a pattern, the algorithm returns an average score (from 0 to 1), denoting how closely the atomic neighbors resemble the reference. - For determining nearest neighbors, + For determining the nearest neighbors, [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html) neighborlist library is utilized. This requires the user to define the indices of all the atoms in the system and a JAX MD - neighbor list callable for updating the state. + neighbor list which is callable for updating the state. Matching a neighborhood to the pattern is an optimization process. Based on the number of initial rotations of the reference structure @@ -326,11 +370,11 @@ class GeM(CollectiveVariable): box: JaxArray Definition of the simulation box. number_of_rotations: integer - Number of initial rotated structures for the optimization study. - number_of_opt_it: iteger - Number of iterations for gradient descent. + A number of initial rotated structures for the optimization study. + number_of_opt_it: integer + A number of iterations for gradient descent. standard_deviation: float - Parameter that controls the spread of the Gaussian function. + A parameter that controls the spread of the Gaussian function. mesh_size: integer Defines the size of the angular grid from which we draw random Euler angles. @@ -339,6 +383,17 @@ class GeM(CollectiveVariable): fractional_coords: bool Set to True if NPT simulation is considered and the box size changes; use periodic_general for constructing the neighborlist. + number_of_added_sites: int + Specify the number of additional sites to the main reference for the continuous + calculation (skip if the continuous LoM is not needed). The additional atoms should + already be added to the reference (reference_positions). + In other words, the reference should have elements corresponding to the + original reference and additional coordinates representing the extra atoms. + width_of_switch_func: float + Width of the switching function for the continuous score function. + scale_for_radial_distance: float + Scaling factor for the mean radial distance of added sites + used in the continuous score function calculation. Returns ------- calculate_lom: float @@ -357,6 +412,9 @@ def __init__( mesh_size, nbrs, fractional_coords, + number_of_added_sites=0, + width_of_switch_func=None, + scale_for_radial_distance=None, ): super().__init__(indices, group_length=None) @@ -369,6 +427,10 @@ def __init__( self.mesh_size = mesh_size self.nbrs = nbrs self.fractional_coords = fractional_coords + # The parameters below are only used in the continuous version + self.number_of_added_sites = number_of_added_sites + self.width_of_switch_func = width_of_switch_func + self.scale_for_radial_distance = scale_for_radial_distance @property def function(self): diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 00279a4c..81b04b0d 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -17,5 +17,5 @@ solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity +from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity, row_sum from .transformations import quaternion_from_euler, quaternion_matrix