Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable continuous GeM computation. #309

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
142 changes: 102 additions & 40 deletions pysages/colvars/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from jaxopt import GradientDescent as minimize

from pysages.colvars.core import CollectiveVariable
from pysages.utils import gaussian, quaternion_from_euler, quaternion_matrix
from pysages.utils import (
gaussian,
identity,
quaternion_from_euler,
quaternion_matrix,
row_sum,
)


def rotate_pattern_with_quaternions(rot_q, pattern):
Expand All @@ -25,7 +31,7 @@ def func_to_optimise(Q, modified_pattern, local_pattern):
# Main class implementing the GeM CV
class Pattern:
"""
For determining nearest neighbors,
For determining the nearest neighbors,
[JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html)
neighborlist library is utilized. This requires the user
to define the indices of all the atoms in the system and a JAX MD
Expand All @@ -34,6 +40,7 @@ class Pattern:

def __init__(
self,
positions,
simulation_box,
fractional_coords,
reference,
Expand All @@ -42,32 +49,47 @@ def __init__(
centre_j_id,
standard_deviation,
mesh_size,
number_of_added_sites=0,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the continuous version of the LoM is to be used, the additional atoms must already be included in `reference` (the coordinates of the reference structure). In other words, `reference` should have $M$ rows corresponding to the original reference plus $M_b$ additional rows holding the coordinates of the extra atoms. `number_of_added_sites` specifies how many extra sites (rows), i.e. $M_b$, were included. The $M_b$ atoms are assumed to lie outside the first shell, so their distances from the central site are larger than those of the other $M$ sites.

width_of_switch_func=None,
scale_for_radial_distance=None,
):

self.characteristic_distance = characteristic_distance
self.reference = reference
self.neighborlist = neighborlist
self.simulation_box = simulation_box
self.centre_j_id = centre_j_id
# This is added to handle neighborlists with fractional coordinates
# (needed for NPT simulations)
if fractional_coords:
self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box)
else:
self.positions = self.neighborlist.reference_position
# if fractional_coords:
# self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box)
# else:
# self.positions = self.neighborlist.reference_position
self.positions = positions
self.centre_j_coords = self.positions[self.centre_j_id]
self.standard_deviation = standard_deviation
self.mesh_size = mesh_size
# These settings are needed if continuous LoM is to be used
self.number_of_added_sites = number_of_added_sites
if self.number_of_added_sites > 0:
if width_of_switch_func is None:
self.width_of_switch_func = self.standard_deviation / 2
else:
self.width_of_switch_func = width_of_switch_func

if scale_for_radial_distance is None:
self.scale_for_radial_distance = 0.9
else:
self.scale_for_radial_distance = scale_for_radial_distance

self._neighborhood = []

def comp_pair_distance_squared(self, pos1):
displacement_fn, shift_fn = space.periodic(np.diag(self.simulation_box))
displacement_fn, _ = space.periodic(np.diag(self.simulation_box))
mic_vector = displacement_fn(self.centre_j_coords, pos1)
mic_norm = linalg.norm(mic_vector)
return mic_norm, mic_vector

def _generate_neighborhood(self):
self._neighborhood = []

positions_of_all_nbrs = self.positions[self.neighborlist.idx[self.centre_j_id]]
distances, mic_vectors = vmap(self.comp_pair_distance_squared)(positions_of_all_nbrs)
# remove the same atom from the neighborhood
Expand All @@ -79,6 +101,14 @@ def _generate_neighborhood(self):

ids_of_neighbors = np.argsort(distances)[: len(self.reference)]

n_added_sites = self.number_of_added_sites
if n_added_sites > 0:
ids_of_neighbors_2nd_shell = ids_of_neighbors[-n_added_sites:]
self.shell_distance = self.scale_for_radial_distance * np.mean(
distances[ids_of_neighbors_2nd_shell]
)
self._neighborhood_distances = distances[ids_of_neighbors]

coordinates = mic_vectors[ids_of_neighbors] + self.centre_j_coords
# Step 1: Translate to origin;
coordinates = coordinates.at[:].set(coordinates - np.mean(coordinates, axis=0))
Expand All @@ -95,9 +125,21 @@ def _generate_neighborhood(self):
self._neighbor_coords = np.array([n["coordinates"] for n in self._neighborhood])
self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors]

def _switching_function(self, distance, width):
result = 0.5 * lax.erfc((distance - self.shell_distance) / width)
return result

def compute_score(self, optim_reference):
r = self._neighbor_coords - optim_reference
return np.prod(gaussian(1, self.standard_deviation, r))
std = self.standard_deviation

if self.number_of_added_sites != 0:
width = self.width_of_switch_func
squared_dist = row_sum(r**2)
x = self._switching_function(self._neighborhood_distances, width)
return np.exp(-np.sum(x * squared_dist) / (2 * (std**2) * np.sum(x)))

return np.prod(gaussian(1, std * np.sqrt(len(self.reference)), r))

def rotate_reference(self, random_euler_point):
# Perform rotation of the reference pattern;
Expand Down Expand Up @@ -147,22 +189,21 @@ def return_close(_, n):
_, indices = lax.scan(
lambda _, sites: (
None,
lax.cond(np.sum(sites) == 1, lambda s: s, lambda s: np.zeros_like(s), sites),
lax.cond(np.sum(sites) == 1, identity, np.zeros_like, sites),
),
None,
close_sites,
)
# Return the locations of settled nighbours in the neighborhood;
# Settlled site should have a unique neighbor
# Settled site should have a unique neighbor
settled_neighbor_indices = np.where(np.sum(indices, axis=0) >= 1, 1, 0)
return settled_neighbor_indices

def driver_match(self, number_of_rotations, number_of_opt_steps, num):

self._generate_neighborhood()

"""Step2: Scale the reference so that the spread matches
with the current local pattern"""
# STEP 2:
# Scale the reference so that the spread matches with the current local pattern.
local_distance = 0.0
reference_distance = 0.0
for n_index, neighbor in enumerate(self._neighborhood):
Expand All @@ -171,17 +212,18 @@ def driver_match(self, number_of_rotations, number_of_opt_steps, num):

self.reference *= np.sqrt(local_distance / reference_distance)

"""Step3: mesh-loop -> Define angles in reduced Euler domain,
and for each rotate, resort and score the pattern

The implementation below follows the article Martelli et al. 2018


(a) Randomly with uniform probability pick a point in the Euler domain,
(b) Rotate the reference
(c) Resort the local pattern and assign the closest reference sites,
(d) Perform the optimisation step (conjugate gradient),
and (e) store the score with (f) the final settled status"""
# STEP 3:
#
# mesh-loop -> Define angles in reduced Euler domain, and for each rotate,
# resort and score the pattern.
#
# The implementation below follows the article Martelli et al. 2018
#
# (a) Randomly with uniform probability pick a point in the Euler domain,
# (b) Rotate the reference
# (c) Resort the local pattern and assign the closest reference sites,
# (d) Perform the optimisation step (conjugate gradient), and
# (e) store the score with (f) the final settled status

def get_all_scores(newkey, euler_point):
# b. Rotate the reference pattern
Expand All @@ -190,8 +232,7 @@ def get_all_scores(newkey, euler_point):
# and assign ids to the closest reference sites
newkey, newsubkey = random.split(random.PRNGKey(newkey))
reshuffled_reference, random_indices = self.resort(rotated_reference, newsubkey)
# d. Find the best rotation that aligns the settled sites
# in both patterns;
# d. Find the best rotation that aligns the settled sites in both patterns.
# Here, ‘optimal’ or ‘best’ is in terms of least squares errors
solver = minimize(fun=func_to_optimise, maxiter=number_of_opt_steps)
# We are fixing the initial guess for the quaternions;
Expand All @@ -217,7 +258,7 @@ def get_all_scores(newkey, euler_point):

# a. Randomly pick a point in the Euler domain

key, subkey = random.split(random.PRNGKey(num))
_, subkey = random.split(random.PRNGKey(num))
mesh_size = self.mesh_size
grid_dimension = np.pi / mesh_size
euler_angles = np.arange(
Expand Down Expand Up @@ -256,14 +297,13 @@ def get_all_scores(newkey, euler_point):


def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params):

if params.fractional_coords:
update_neighborlist = neighborlist.update(np.divide(all_positions, np.diag(simulation_box)))
else:
update_neighborlist = neighborlist.update(all_positions)

"""Step1: Move the reference and
local patterns so that their centers coincide with the origin"""
# STEP 1:
# Move the reference and local patterns so that their centers coincide with the origin.

reference_positions = params.reference_positions.at[:].set(
params.reference_positions - np.mean(params.reference_positions, axis=0)
Expand All @@ -273,6 +313,7 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params)
seed = np.int64(time.process_time() * 1e5)
optimal_results = vmap(
lambda i: Pattern(
all_positions,
params.box,
params.fractional_coords,
reference_positions,
Expand All @@ -281,6 +322,9 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params)
i,
params.standard_deviation,
params.mesh_size,
params.number_of_added_sites,
params.width_of_switch_func,
params.scale_for_radial_distance,
).driver_match(
params.number_of_rotations,
params.number_of_opt_it,
Expand All @@ -298,14 +342,14 @@ class GeM(CollectiveVariable):
an atomic or a molecular site is described in
[Martelli2018](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.97.064105).

Given a pattern, the algorithm is returning an average score (from 0 to 1),
Given a pattern, the algorithm returns an average score (from 0 to 1),
denoting how closely the atomic neighbors resemble the reference.

For determining nearest neighbors,
For determining the nearest neighbors,
[JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html)
neighborlist library is utilized. This requires the user
to define the indices of all the atoms in the system and a JAX MD
neighbor list callable for updating the state.
neighbor list which is callable for updating the state.

Matching a neighborhood to the pattern is an optimization process.
Based on the number of initial rotations of the reference structure
Expand All @@ -326,11 +370,11 @@ class GeM(CollectiveVariable):
box: JaxArray
Definition of the simulation box.
number_of_rotations: integer
Number of initial rotated structures for the optimization study.
number_of_opt_it: iteger
Number of iterations for gradient descent.
A number of initial rotated structures for the optimization study.
number_of_opt_it: integer
A number of iterations for gradient descent.
standard_deviation: float
Parameter that controls the spread of the Gaussian function.
A parameter that controls the spread of the Gaussian function.
mesh_size: integer
Defines the size of the angular grid from which we draw
random Euler angles.
Expand All @@ -339,6 +383,17 @@ class GeM(CollectiveVariable):
fractional_coords: bool
Set to True if NPT simulation is considered and the box size
changes; use periodic_general for constructing the neighborlist.
number_of_added_sites: int
Specify the number of additional sites to the main reference for the continuous
calculation (skip if the continuous LoM is not needed). The additional atoms should
already be added to the reference (reference_positions).
In other words, the reference should have elements corresponding to the
original reference and additional coordinates representing the extra atoms.
width_of_switch_func: float
Width of the switching function for the continuous score function.
scale_for_radial_distance: float
Scaling factor for the mean radial distance of added sites
used in the continuous score function calculation.
Returns
-------
calculate_lom: float
Expand All @@ -357,6 +412,9 @@ def __init__(
mesh_size,
nbrs,
fractional_coords,
number_of_added_sites=0,
width_of_switch_func=None,
scale_for_radial_distance=None,
):
super().__init__(indices, group_length=None)

Expand All @@ -369,6 +427,10 @@ def __init__(
self.mesh_size = mesh_size
self.nbrs = nbrs
self.fractional_coords = fractional_coords
# The parameters below are only used in the continuous version
self.number_of_added_sites = number_of_added_sites
self.width_of_switch_func = width_of_switch_func
self.scale_for_radial_distance = scale_for_radial_distance

@property
def function(self):
Expand Down
2 changes: 1 addition & 1 deletion pysages/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
solve_pos_def,
try_import,
)
from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity
from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity, row_sum
from .transformations import quaternion_from_euler, quaternion_matrix