In [47]:
import os

# Runtime configuration for every JAX computation in this notebook.
DEFAULT_PLATFORM = 'cpu'
# os.cpu_count() returns None when the count cannot be determined; fall back
# to a single device so the XLA flag below is always a well-formed integer.
DEFAULT_CORE_NUM = os.cpu_count() or 1

# XLA reads XLA_FLAGS when the backend is initialised, so the flag must be
# set *before* jax is imported for the forced host device count to apply.
os.environ['XLA_FLAGS'] = '--{:s}={:d}'.format(
    "xla_force_host_platform_device_count",
    DEFAULT_CORE_NUM
)

import jax  # noqa: E402 -- must come after the XLA_FLAGS assignment above

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', DEFAULT_PLATFORM)


## pbc2pbc: normalize periodic-boundary flags to a length-3 boolean array


In [48]:
from jax import numpy as jnp


def _pbc2pbc(pbc):
    result = jnp.zeros(3, dtype=bool)
    return result.at[:].set(pbc)


pbc2pbc = jax.jit(_pbc2pbc)


%timeit  _pbc2pbc([False, True, False])
%timeit  pbc2pbc([False, True, False])


1.45 ms ± 6.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
20.3 µs ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Fractional <---> Cartesian coordinate conversion


In [49]:
import jax
import numpy as np


In [58]:
def _frc2car(f, cell):
    """Calculate Cartesian positions from scaled positions."""
    return jnp.asarray(f) @ jnp.asarray(cell)


frc2car = jax.jit(_frc2car)


# Benchmark inputs: one fractional vector, a batch of two, and a cubic cell
# with edge length 5.
cell = jnp.asarray(5 * np.identity(3))
a = jnp.asarray(np.random.rand(3))
b = jnp.asarray(np.random.rand(2, 3))


%timeit  _frc2car(a, cell), _frc2car(b, cell)
%timeit  frc2car(a, cell), frc2car(b, cell)


264 µs ± 7.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
31 µs ± 18.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [59]:
def _car2frc(v, cell):
    """Calculate scaled positions from Cartesian positions."""
    return jnp.linalg.solve(
        jnp.transpose(cell),
        jnp.transpose(v)
    ).transpose()


car2frc = jax.jit(_car2frc)


# Benchmark inputs: fractional vectors ``a`` (single) and ``b`` (batch of
# two), a cubic cell of edge 5, and their Cartesian counterparts.
cell = jnp.asarray(5 * np.identity(3))
a, b = (jnp.asarray(np.random.rand(*shape)) for shape in ((3,), (2, 3)))
a_, b_ = a @ cell, b @ cell


%timeit _car2frc(a, cell), _car2frc(b, cell)
%timeit  car2frc(a, cell), car2frc(b, cell)


599 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
49.5 µs ± 15.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [60]:
# Round-trip check: converting ``a @ cell`` back with car2frc must recover
# ``a`` (and likewise for the batch ``b``).
rng = np.random.default_rng(0)  # fixed seed so the check is reproducible
a = jnp.asarray(rng.random(3))
b = jnp.asarray(rng.random((2, 3)))
cell = jnp.asarray(np.identity(3) * 5)
a_ = a @ cell
b_ = b @ cell

# Compare the *absolute* error: the signed test ``x - a < tol`` is trivially
# True for any large negative error and so cannot detect a wrong result.
err_a = jnp.abs(car2frc(a_, cell) - a)
err_b = jnp.abs(car2frc(b_, cell) - b)
print(err_a < 1e-15)
print(err_b < 1e-15)
assert bool(jnp.all(err_a < 1e-15)) and bool(jnp.all(err_b < 1e-15))


[ True  True  True]
[[ True  True  True]
 [ True  True  True]]


## cell & pbc class

Return a different jit-compiled minimum-image function depending on the cell and the PBC flags (根据不同的 cell 和 pbc 情况返回不同的 jit 函数).


In [170]:
import itertools
from functools import partial
from typing import Callable

from ase.geometry import minkowski_reduce
from ase.geometry.cell import complete_cell
from ase.cell import Cell


@jax.jit
def fn_mic_nopbc(v, cell):
    """Minimum-image function for a system with no periodic direction.

    Every vector is already its own minimum image, so ``v`` is returned
    unchanged. ``cell`` is accepted only so the signature matches the other
    ``fn_mic_*`` variants, which are all called as ``f(v, cell)``.
    """
    del cell  # unused; kept for a uniform (v, cell) signature
    return v


@jax.jit
def fn_mic_naive(v, cell):
    """Minimum-image vector(s) by rounding fractional coordinates.

    Each fractional coordinate of ``v`` is shifted into [-0.5, 0.5) and the
    result is mapped back to Cartesian space. This is only safe when all
    directions are periodic and the minimum image is shorter than half the
    smallest cell length; it can fail for non-orthorhombic cells.

    Reference: W. Smith, "The Minimum Image Convention in Non-Cubic MD
    Cells", 1989,
    http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.57.1696.
    """
    scaled = car2frc(v, cell)
    # Subtract the nearest integer lattice translation.
    scaled -= jnp.floor(scaled + 0.5)
    shifted = scaled @ cell
    return shifted


@jax.jit
def wrap_positions(v, cell, pbc=True):
    """Wrap Cartesian vector(s) ``v`` into ``cell`` along periodic directions.

    Fractional coordinates along directions flagged in ``pbc`` are reduced
    modulo 1; the other directions are left untouched.
    """
    lattice = jnp.asarray(cell)
    scaled = car2frc(jnp.asarray(v), lattice)
    periodic = pbc2pbc(pbc)
    wrapped = jnp.where(periodic, scaled % 1.0, scaled)
    return wrapped @ lattice


@jax.jit
def fn_mic_general(v, cell, pbc, hkls, r_op):
    """Minimum-image vector(s) for an arbitrary (possibly skewed) cell.

    ``v`` is first wrapped into the reduced cell ``r_op @ cell`` along the
    periodic directions. Every lattice shift ``hkls @ (r_op @ cell)`` is then
    added, and the shortest candidate is kept.

    Parameters
    ----------
    v : a single vector (1-D) or a stack of vectors (2-D).
    cell : lattice matrix whose rows are the lattice vectors.
    pbc : periodic flags forwarded to ``wrap_positions``.
    hkls : integer shift combinations (stored as floats), one per row.
    r_op : reduction operator applied to ``cell``.

    Raises
    ------
    KeyError
        When the wrapped ``v`` is neither 1-D nor 2-D (other ranks may
        already fail inside ``wrap_positions``).
    """
    v, cell, r_op, hkls = (jnp.asarray(arr, dtype=float)
                           for arr in (v, cell, r_op, hkls))

    reduced = r_op @ cell
    shifts = hkls @ reduced
    wrapped = wrap_positions(v, reduced, pbc=pbc)

    if wrapped.ndim == 1:
        # One candidate image per shift for the single vector.
        images = wrapped + shifts
        nearest = jnp.argmin(jnp.linalg.norm(images, axis=1), axis=0)
        return images[nearest, :]
    if wrapped.ndim == 2:
        # Candidate images indexed as (shift, vector, xyz).
        images = wrapped + shifts[:, None]
        nearest = jnp.argmin(jnp.linalg.norm(images, axis=2), axis=0)
        return images[nearest, jnp.arange(len(wrapped)), :]
    raise KeyError("v.ndim must <= 2.")


class CellWithPBC:
    """A unit cell together with periodic boundary flags.

    ``fn_mic`` returns a jit-compiled minimum-image-convention (MIC) function
    chosen according to the cell shape and the periodic directions.

    Parameters
    ----------
    cell : passed to ``ase.cell.Cell.new``.
    pbc : bool or length-3 sequence/array of bool, optional
        ``None`` (or an empty sequence) means no periodic direction.
    cell_as_param : bool
        If True, ``fn_mic`` returns ``f(v, cell)`` so the cell is supplied
        (and can be differentiated) at call time; otherwise it returns
        ``f(v)`` with this object's cell bound.
    """

    def __init__(self, cell=None, pbc=None,
                 cell_as_param: bool = False):
        # ``pbc if pbc else False`` raised for NumPy arrays (ambiguous truth
        # value); only None / empty input now maps to "not periodic".
        if pbc is None or np.size(pbc) == 0:
            pbc = False
        self.pbc = pbc2pbc(pbc)
        self._cell_as_param = cell_as_param
        self._cell = Cell.new(cell)

    def __repr__(self):
        """Show cell lengths for orthorhombic cells, else the full matrix."""
        if self._cell.orthorhombic:
            numbers = self._cell.lengths().tolist()
        else:
            numbers = self._cell.tolist()
        return 'Cell({})'.format(numbers)

    @property
    def fn_mic(self) -> Callable:
        """Return the MIC function matching this cell and pbc.

        The callable is ``f(v)`` when ``cell_as_param`` is False and
        ``f(v, cell)`` otherwise.
        """
        # A direction is effectively periodic only if its flag is set AND its
        # cell vector is non-zero (this mirrors ase.geometry.find_mic). The
        # same effective flags must drive every branch, not only ``dim``.
        pbc = (np.asarray(self._cell.any(1), dtype=bool)
               & np.asarray(self.pbc, dtype=bool))
        dim = int(np.sum(pbc))
        cell = jnp.asarray(self._cell.array)
        if dim == 0:
            func = fn_mic_nopbc
        elif dim == 3 and self._cell.orthorhombic:
            # Rounding fractional coordinates is exact for orthorhombic cells.
            func = fn_mic_naive
        else:
            pbc_list = pbc.tolist()
            # Fill zero-length cell vectors so the reduction is well defined.
            cell = complete_cell(self._cell)
            # Only the reduction operator is precomputed; fn_mic_general
            # applies it to whatever cell it receives.
            r_op = jnp.asarray(minkowski_reduce(cell, pbc_list)[1],
                               dtype=float)
            # Shifts of {-1, 0, 1} along periodic directions, {0} otherwise,
            # with (0, 0, 0) prepended.
            ranges = [np.arange(-1 * p, p + 1) for p in pbc_list]
            hkls = [(0, 0, 0)] + list(itertools.product(*ranges))
            hkls = jnp.asarray(hkls, dtype=float)
            func = partial(fn_mic_general, pbc=jnp.asarray(pbc),
                           hkls=hkls, r_op=r_op)
        if not self._cell_as_param:
            return partial(func, cell=jnp.asarray(cell, dtype=float))
        # NOTE(review): with cell_as_param the caller's cell is used as-is
        # (not completed), so cells with zero vectors may fail here.
        return func

In [175]:
from ase.geometry import find_mic


# Cross-check CellWithPBC against ase.geometry.find_mic for fully, partially
# and non-periodic cases, with the cell bound and passed as an argument, and
# check that the cell-as-parameter variant is differentiable w.r.t. the cell.
rng = np.random.default_rng(0)  # fixed seed so any failure is reproducible

for pbc in (True, [True, False, True], False):
    cell = rng.random((3, 3))
    v = rng.random((5, 3))

    mic_1 = find_mic(v, cell, pbc)[0]
    mic_2 = CellWithPBC(cell, pbc, False).fn_mic(v)
    assert np.allclose(mic_1, mic_2)

    f = CellWithPBC(cell, pbc, True).fn_mic
    assert np.allclose(mic_1, f(v, cell))

    def ff(v, cell):
        return jnp.sum(f(v, cell))
    ff = jax.jit(ff)
    # The Jacobian of a scalar w.r.t. the cell must have the cell's shape and
    # be finite; previously the result was computed but never checked.
    jac = jax.jacfwd(ff, argnums=1)(v, cell)
    assert jac.shape == cell.shape
    assert np.all(np.isfinite(np.asarray(jac)))