Skip to content

Commit

Permalink
added options for sparse and jax diagonalization
Browse files Browse the repository at this point in the history
  • Loading branch information
OMalenfantThuot committed May 10, 2024
1 parent 3782fa7 commit 9ad15f7
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 32 deletions.
71 changes: 60 additions & 11 deletions mlcalcdriver/calculators/schnetpack_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import numpy as np
import scipy
import torch
import warnings
from schnetpack import AtomsLoader
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
units=eVA,
md=False,
subgrid=None,
sparse=False,
):
super().__init__(
model_dir=model_dir,
Expand All @@ -55,6 +57,7 @@ def __init__(
)
self.n_interaction = len(self.model.representation.interactions)
self.subgrid = subgrid
self.sparse = sparse
self._convert_model()

@property
Expand Down Expand Up @@ -99,6 +102,11 @@ def run(
predictions : :class:`numpy.ndarray`
Corresponding prediction by the model.
"""
import psutil
import os

pid = os.getpid()
proc = psutil.Process(pid)

# Initial setup
assert (
Expand Down Expand Up @@ -187,9 +195,7 @@ def run(
results.append({out_name: cpu_patch_deriv1})
for key, value in patch_forward_results.items():
del value
del patch_forward_results
del patch_deriv1
print(torch.cuda.max_memory_allocated())
del patch_forward_results, patch_deriv1

if abs(derivative) == 2:
raise NotImplementedError()
Expand All @@ -212,7 +218,22 @@ def run(
predictions["forces"] = forces

elif property == "hessian":
hessian = np.zeros((3 * len(atoms), 3 * len(atoms)), dtype=np.float32)
hess_shape = (3 * len(atoms), 3 * len(atoms))
if self.sparse:
data_lims = [
9 * s.size * cs.size
for (s, cs) in zip(subcells_main_idx, complete_subcell_copy_idx)
]
data_lims.insert(0, 0)
data_lims = np.cumsum(data_lims)
num_data = data_lims[-1]

data = np.zeros(num_data, dtype=np.float32)
row, col = np.zeros(num_data, dtype=np.intc), np.zeros(
num_data, dtype=np.intc
)
else:
hessian = np.zeros(hess_shape, dtype=np.float32)

for i in range(len(results)):
(
Expand All @@ -230,12 +251,40 @@ def run(
np.arange(0, len(complete_subcell_copy_idx[i])),
)

hessian[
hessian_original_cell_idx_0, hessian_original_cell_idx_1
] = results[i]["hessian"].squeeze()[
hessian_subcells_main_idx_0, hessian_subcells_main_idx_1
]
predictions["hessian"] = np.expand_dims(hessian, 0)
if self.sparse:
row[
data_lims[i] : data_lims[i + 1]
] = hessian_original_cell_idx_0.flatten()
col[
data_lims[i] : data_lims[i + 1]
] = hessian_original_cell_idx_1.flatten()
data[data_lims[i] : data_lims[i + 1]] = (
results[i]["hessian"]
.copy()
.squeeze()[
hessian_subcells_main_idx_0, hessian_subcells_main_idx_1
]
.flatten()
)
else:
hessian[
hessian_original_cell_idx_0, hessian_original_cell_idx_1
] = results[i]["hessian"].squeeze()[
hessian_subcells_main_idx_0, hessian_subcells_main_idx_1
]
del hessian_subcells_main_idx_0, hessian_subcells_main_idx_1
del hessian_original_cell_idx_0, hessian_original_cell_idx_1

if self.sparse:
hessian = scipy.sparse.coo_array(
(data, (row, col)), shape=hess_shape, dtype=np.float32
)
hessian.eliminate_zeros()
hessian = hessian.tocsr()
else:
hessian = np.expand_dims(hessian, 0)

predictions["hessian"] = hessian

else:
raise NotImplementedError()
Expand Down Expand Up @@ -278,7 +327,7 @@ def prepare_hessian_indices(input_idx_0, input_idx_1):
hessian_idx_0 = np.repeat(3 * input_idx_0, 3) + bias_0
hessian_idx_1 = np.repeat(3 * input_idx_1, 3) + bias_1
idx_0, idx_1 = np.meshgrid(hessian_idx_0, hessian_idx_1, indexing="ij")
return idx_0, idx_1
return idx_0.astype(np.intc), idx_1.astype(np.intc)


def collect_results(patch_results):
Expand Down
106 changes: 86 additions & 20 deletions mlcalcdriver/workflows/phonon.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ def _create_displacements(self):
)
return structs

def _post_proc(self, job, use_jax=False):
    r"""
    Builds the dynamical matrix from the job results, diagonalizes
    it, and stores the outcome on the instance.

    Parameters
    ----------
    job
        Job whose results are handed to `_compute_dyn_mat`.
    use_jax : bool
        Forwarded unchanged to `_solve_dyn_mat`. Default is `False`.
    """
    self.dyn_mat = self._compute_dyn_mat(job)
    solution = self._solve_dyn_mat(use_jax=use_jax)
    self.energies, self.normal_modes = solution
    # Rescale the stored energies in place by the module-level factor
    # (presumably Hartree -> cm^-1, judging by its name -- TODO confirm).
    self.energies *= HA_TO_CMM1

def _compute_dyn_mat(self, job):
Expand Down Expand Up @@ -308,7 +308,7 @@ def _compute_hessian(self, job):
n_at = len(self.posinp)
if "hessian" in job.results.keys():
h = job.results["hessian"].reshape(3 * n_at, 3 * n_at)
return ((h + h.T) / 2.0).astype(np.float64)
return (h + h.T) / 2.0
else:
warnings.warn(
"The hessian matrix is approximated by a numerical derivative."
Expand All @@ -323,12 +323,26 @@ def _compute_hessian(self, job):
) / (12 * self.translation_amplitudes)
return -(hessian + hessian.T) / 2.0

def _solve_dyn_mat(self, use_jax=False):
    r"""
    Obtains the eigenvalues and eigenvectors from
    the dynamical matrix

    Parameters
    ----------
    use_jax : bool
        If `True`, try to diagonalize with `jax.scipy.linalg.eigh`.
        When jax cannot be imported a warning is emitted and
        `scipy.linalg.eigh` is used instead. Default is `False`.

    Returns
    -------
    tuple
        `(eigs, vecs)` as numpy arrays, whichever backend was used.
        `eigs` holds the signed square roots of the rescaled
        eigenvalues (negative eigenvalues give negative entries).
    """
    jax_eigh = None
    if use_jax:
        # Only the import is guarded, so an unrelated failure inside
        # jax is not mistaken for "jax is not installed".
        try:
            import jax
            import jax.scipy.linalg
        except ModuleNotFoundError:
            warnings.warn("Jax not installed, defaults to basic scipy library.")
        else:
            jax.config.update("jax_traceback_filtering", "off")
            jax_eigh = jax.scipy.linalg.eigh

    if jax_eigh is not None:
        eigs, vecs = jax_eigh(self.dyn_mat)
        # jax returns immutable arrays; take writable numpy copies so the
        # in-place scaling below and the callers see the same types as
        # on the scipy path.
        # NOTE(review): jax may compute in float32 unless x64 is enabled
        # -- confirm the precision is acceptable.
        eigs, vecs = np.array(eigs), np.array(vecs)
    else:
        eigs, vecs = scipy.linalg.eigh(self.dyn_mat)

    eigs *= EV_TO_HA * B_TO_ANG**2 / AMU_TO_EMU
    # Signed square root: keeps the sign of negative eigenvalues.
    eigs = np.sign(eigs) * np.sqrt(np.abs(eigs))
    return eigs, vecs
Expand All @@ -341,7 +355,7 @@ class PhononFromHessian(Phonon):
the time to compute the hessian matrix each time
"""

def __init__(self, posinp, hessian):
def __init__(self, posinp, hessian, sparse=False, sparse_kwargs=None):
r"""
Parameters
----------
Expand All @@ -358,6 +372,9 @@ def __init__(self, posinp, hessian):
finite_difference=False,
low_memory=True,
)
self.sparse = sparse
if self.sparse:
self.sparse_kwargs = sparse_kwargs
self.hessian = hessian

@property
Expand All @@ -366,19 +383,68 @@ def hessian(self):

@hessian.setter
def hessian(self, hessian):
    r"""
    Validates and stores the hessian matrix.

    Parameters
    ----------
    hessian : str, numpy.ndarray or scipy sparse array
        In sparse mode (`self.sparse` is `True`) any scipy sparse
        container is accepted and stored as a `csr_array`.
        Otherwise a numpy array, or a path loadable with `np.load`,
        whose first element has shape `(3 * n_atoms, 3 * n_atoms)`.

    Raises
    ------
    TypeError
        If the value is not a sparse array in sparse mode, or not a
        numpy array otherwise.
    AssertionError
        If a dense hessian does not match the number of atoms.
    """
    if self.sparse:
        # Public predicate instead of the private
        # `scipy.sparse._compressed._cs_matrix` class, which is not a
        # stable API; an explicit raise also survives `python -O`.
        if not scipy.sparse.issparse(hessian):
            raise TypeError("The hessian matrix should be a scipy sparse array.")
        self._hessian = scipy.sparse.csr_array(hessian)
        return

    if isinstance(hessian, str):
        hessian = np.load(hessian)

    if not isinstance(hessian, np.ndarray):
        raise TypeError("The hessian matrix should be a numpy array.")

    # The dense hessian carries a leading axis; the check is on element 0.
    assert hessian[0].shape == (
        3 * len(self.posinp),
        3 * len(self.posinp),
    ), f"The hessian shape {hessian.shape} does not match the number of atoms {len(self.posinp)}"
    self._hessian = hessian

def run(self):
job = Job(posinp=self.posinp, calculator=self.calculator)
job.results["hessian"] = self.hessian
self._post_proc(job)
@property
def sparse(self):
    r"""
    Returns
    -------
    bool
        The stored flag selecting the sparse code path.
    """
    return self._sparse


@sparse.setter
def sparse(self, sparse):
    r"""
    Stores the flag after coercing it to a `bool`.
    """
    flag = bool(sparse)
    self._sparse = flag

def run(self, use_jax=False, sparse_kwargs=None):
    r"""
    Computes the energies and normal modes from the stored hessian.

    Parameters
    ----------
    use_jax : bool
        Forwarded to `_post_proc` on the dense path. Not used when
        the instance is in sparse mode. Default is `False`.
    sparse_kwargs : dict or None
        Keyword arguments handed to `_solve_sparse_hessian` on the
        sparse path. When `None`, the `sparse_kwargs` attribute set at
        construction is used if present; otherwise an empty dict.
        An explicitly passed dict (even an empty one) takes precedence.
    """
    if self.sparse:
        if sparse_kwargs is None:
            # Fall back on the options given to the constructor, which
            # were previously stored but never consulted.
            sparse_kwargs = getattr(self, "sparse_kwargs", None)
        # Hand over a fresh dict so neither the caller's nor the stored
        # options can be mutated downstream.
        self._solve_sparse_hessian(kwargs=dict(sparse_kwargs or {}))
        return

    job = Job(posinp=self.posinp, calculator=self.calculator)
    job.results["hessian"] = self.hessian
    self._post_proc(job, use_jax=use_jax)

def _solve_sparse_hessian(self, kwargs):
    r"""
    Sparse pipeline: builds the dynamical matrix, solves it and
    stores the energies and normal modes on the instance.

    Parameters
    ----------
    kwargs : dict
        Passed through as `kwargs` to `_solve_sparse_dyn_mat`.
    """
    self.dyn_mat = self._compute_sparse_dyn_mat()
    solution = self._solve_sparse_dyn_mat(kwargs=kwargs)
    self.energies, self.normal_modes = solution
    # Rescale the stored energies in place by the module-level factor.
    self.energies *= HA_TO_CMM1

def _compute_sparse_dyn_mat(self):
self.hessian = (self.hessian + self.hessian.T) / 2
masses = np.array(
[atom.mass for atom in self.posinp for _ in range(3)],
)
for i in range(self.hessian.shape[0]):
self.hessian.data[
self.hessian.indptr[i] : self.hessian.indptr[i + 1]
] = self.hessian.data[
self.hessian.indptr[i] : self.hessian.indptr[i + 1]
] / np.sqrt(
masses[i]
* masses[
self.hessian.indices[
self.hessian.indptr[i] : self.hessian.indptr[i + 1]
]
]
)

def _solve_sparse_dyn_mat(self, kwargs=None):
    r"""
    Obtains eigenvalues and eigenvectors of the sparse matrix stored
    in `self.hessian` with `scipy.sparse.linalg.eigsh`.

    Parameters
    ----------
    kwargs : dict or None
        Extra keyword arguments for `eigsh`. `which` defaults to
        `"LM"` and can be overridden here. `None` means no extra
        arguments.

    Returns
    -------
    tuple
        `(eigs, vecs)`, where `eigs` holds the signed square roots of
        the rescaled eigenvalues.
    """
    # Merge instead of passing `which="LM"` next to `**kwargs`, which
    # raised a duplicate-keyword TypeError when the caller set `which`.
    options = {"which": "LM"}
    if kwargs:
        options.update(kwargs)

    eigs, vecs = scipy.sparse.linalg.eigsh(self.hessian, **options)
    eigs *= EV_TO_HA * B_TO_ANG**2 / AMU_TO_EMU
    # Signed square root: keeps the sign of negative eigenvalues.
    eigs = np.sign(eigs) * np.sqrt(np.abs(eigs))
    return eigs, vecs
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
readme = readme_file.read()

requirements = [
"numpy>=1.20,<1.24",
"numpy>=1.20",
"torch>=2.0.1",
"schnetpack==1.0.1",
"ase>=3.22.0",
Expand Down

0 comments on commit 9ad15f7

Please sign in to comment.