In [1]:
import functools
import importlib
import time
from typing import Optional, Sequence, Tuple, Union

from absl import logging
import chex
from ferminet import checkpoint
from ferminet import constants
from ferminet import curvature_tags_and_blocks
from ferminet import envelopes
from ferminet import hamiltonian
from ferminet import loss as qmc_loss_functions
from ferminet import mcmc
from ferminet import networks
from ferminet import pretrain
from ferminet.utils import multi_host
from ferminet.utils import statistics
from ferminet.utils import system
from ferminet.utils import writers
import jax
import jax.numpy as jnp
import kfac_jax
import ml_collections
import numpy as np
import optax
from typing_extensions import Protocol

In [2]:
def init_electrons(
        key,
        molecule: Sequence[system.Atom],
        electrons: Sequence[int],
        batch_size: int,
        init_width: float,
) -> jnp.ndarray:
    """Initializes electron positions around each atom.

  Args:
    key: JAX RNG state.
    molecule: system.Atom objects making up the molecule.
    electrons: tuple of number of alpha and beta electrons.
    batch_size: total number of MCMC configurations to generate across all
      devices.
    init_width: width of (atom-centred) Gaussian used to generate initial
      electron configurations.

  Returns:
    array of (batch_size, (nalpha+nbeta)*ndim) of initial (random) electron
    positions in the initial MCMC configurations and ndim is the dimensionality
    of the space (i.e. typically 3).
  """
    if sum(atom.charge for atom in molecule) != sum(electrons):
        if len(molecule) == 1:
            atomic_spin_configs = [electrons]
        else:
            raise NotImplementedError('No initialization policy yet '
                                      'exists for charged molecules.')
    else:
        atomic_spin_configs = [
            (atom.element.nalpha, atom.element.nbeta) for atom in molecule
        ]
        assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons)
        while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons:
            i = np.random.randint(len(atomic_spin_configs))
            nalpha, nbeta = atomic_spin_configs[i]
            atomic_spin_configs[i] = nbeta, nalpha

    # Assign each electron to an atom initially.
    electron_positions = []
    for i in range(2):
        for j in range(len(molecule)):
            atom_position = jnp.asarray(molecule[j].coords)
            electron_positions.append(
                jnp.tile(atom_position, atomic_spin_configs[j][i]))
    electron_positions = jnp.concatenate(electron_positions)
    # Create a batch of configurations with a Gaussian distribution about each
    # atom.
    key, subkey = jax.random.split(key)
    return (
            electron_positions +
            init_width *
            jax.random.normal(subkey, shape=(batch_size, electron_positions.size)))

In [3]:
writer_manager=None

In [4]:
from ferminet.configs import atom

In [5]:
cfg = atom.get_config()
cfg.system.atom = 'H'
cfg.system.spin_polarisation = None
cfg = atom._adjust_nuclear_charge(cfg)
cfg.batch_size = 128
cfg.pretrain.iterations = 0

In [6]:
num_devices = jax.local_device_count()
num_hosts = jax.device_count() // num_devices
logging.info('Starting QMC with %i XLA devices per host '
                 'across %i hosts.', num_devices, num_hosts)
if cfg.batch_size % (num_devices * num_hosts) != 0:
    raise ValueError('Batch size must be divisible by number of devices, '
                         f'got batch size {cfg.batch_size} for '
                         f'{num_devices * num_hosts} devices.')
host_batch_size = cfg.batch_size // num_hosts  # batch size per host
device_batch_size = host_batch_size // num_devices  # batch size per device
data_shape = (num_devices, device_batch_size)

 # Check if mol is a pyscf molecule and convert to internal representation
if cfg.system.pyscf_mol:
    cfg.update(
        system.pyscf_mol_to_internal_representation(cfg.system.pyscf_mol))

In [7]:
atoms = jnp.stack([jnp.array(atom.coords) for atom in cfg.system.molecule])
charges = jnp.array([atom.charge for atom in cfg.system.molecule])
nspins = cfg.system.electrons

In [8]:
if cfg.debug.deterministic:
    seed = 23
else:
    seed = 1e6 * time.time()
    seed = int(multi_host.broadcast_to_hosts(seed))
key = jax.random.PRNGKey(seed)

In [9]:
if cfg.pretrain.method == 'direct_init' or (
        cfg.pretrain.method == 'hf' and cfg.pretrain.iterations > 0):
    hartree_fock = pretrain.get_hf(
        pyscf_mol=cfg.system.get('pyscf_mol'),
        molecule=cfg.system.molecule,
        nspins=nspins,
        restricted=False,
        basis=cfg.pretrain.basis)
    # broadcast the result of PySCF from host 0 to all other hosts
    hartree_fock.mean_field.mo_coeff = tuple([
        multi_host.broadcast_to_hosts(x)
        for x in hartree_fock.mean_field.mo_coeff
    ])

In [10]:
hf_solution = hartree_fock if cfg.pretrain.method == 'direct_init' else None

In [11]:
if cfg.network.make_feature_layer_fn:
    feature_layer_module, feature_layer_fn = (
        cfg.network.make_feature_layer_fn.rsplit('.', maxsplit=1))
    feature_layer_module = importlib.import_module(feature_layer_module)
    make_feature_layer = getattr(feature_layer_module, feature_layer_fn)
    feature_layer = make_feature_layer(
        charges,
        cfg.system.electrons,
        cfg.system.ndim,
        **cfg.network.make_feature_layer_kwargs)  # type: networks.FeatureLayer
else:
    feature_layer = networks.make_ferminet_features(
        charges,
        cfg.system.electrons,
        cfg.system.ndim,
    )

In [12]:
if cfg.network.make_envelope_fn:
    envelope_module, envelope_fn = (
        cfg.network.make_envelope_fn.rsplit('.', maxsplit=1))
    envelope_module = importlib.import_module(envelope_module)
    make_envelope = getattr(envelope_module, envelope_fn)
    envelope = make_envelope(**cfg.network.make_envelope_kwargs)  # type: envelopes.Envelope
else:
    envelope = envelopes.make_isotropic_envelope()

In [13]:
network_init, signed_network, network_options = networks.make_fermi_net(
    atoms,
    nspins,
    charges,
    envelope=envelope,
    feature_layer=feature_layer,
    bias_orbitals=cfg.network.bias_orbitals,
    use_last_layer=cfg.network.use_last_layer,
    hf_solution=hf_solution,
    full_det=cfg.network.full_det,
    ndim=cfg.system.ndim,
    **cfg.network.detnet)
key, subkey = jax.random.split(key)
params_psi = network_init(subkey)
params_psi = kfac_jax.utils.replicate_all_local_devices(params_psi)
key, subkey = jax.random.split(key)
params_phi = network_init(subkey)
params_phi = kfac_jax.utils.replicate_all_local_devices(params_phi)
# Often just need log|psi(x)|.
network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]  # type: networks.LogFermiNetLike
batch_network = jax.vmap(
    network, in_axes=(None, 0), out_axes=0)

In [14]:
params_psi.keys()

dict_keys(['double', 'envelope', 'input', 'orbital', 'single'])

In [15]:
lenofp = []
for h in params_psi.keys():
    lenofp.append(h)
    lenofp.append(len(params_psi[h]))
lenofp

['double', 3, 'envelope', 1, 'input', 0, 'orbital', 1, 'single', 4]

In [16]:
logging.info('No checkpoint found. Training new model.')
key, subkey = jax.random.split(key)
# make sure data on each host is initialized differently
subkey = jax.random.fold_in(subkey, jax.process_index())
data_psi = init_electrons(
    subkey,
    cfg.system.molecule,
    cfg.system.electrons,
    batch_size=host_batch_size,
    init_width=cfg.mcmc.init_width)
data_psi = jnp.reshape(data_psi, data_shape + data_psi.shape[1:])
data_psi = kfac_jax.utils.broadcast_all_local_devices(data_psi)
key, subkey = jax.random.split(key)
# make sure data on each host is initialized differently
subkey = jax.random.fold_in(subkey, jax.process_index())
data_phi = init_electrons(
    subkey,
    cfg.system.molecule,
    cfg.system.electrons,
    batch_size=host_batch_size,
    init_width=cfg.mcmc.init_width)
data_phi = jnp.reshape(data_phi, data_shape + data_phi.shape[1:])
data_phi = kfac_jax.utils.broadcast_all_local_devices(data_phi)
t_init = 0
opt_state_ckpt = None
mcmc_width_ckpt = None

In [17]:
data_psi.shape #=(1, batch, ndim*ne)

(1, 128, 3)

In [18]:
params_previous = params_psi

In [19]:
local_energy = hamiltonian.local_energy(
    f=signed_network,
    atoms=atoms,
    charges=charges,
    nspins=nspins,
    use_scan=False)

In [20]:
evaluate_loss = qmc_loss_functions.make_loss(
    network,
    local_energy,
    clip_local_energy=cfg.optim.clip_el)

In [21]:
print(len(params_psi))
print(len(params_phi))
print(data_psi.shape)
print(data_phi.shape)

5
5
(1, 128, 3)
(1, 128, 3)


In [27]:
sharded_key = kfac_jax.utils.make_different_rng_key_on_all_devices(key)
sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)

ptotal_energy = constants.pmap(evaluate_loss)
initial_energy, _ = ptotal_energy(params_psi, params_phi, params_previous, subkeys, data_psi, data_phi)

In [28]:
print(initial_energy)

[0.10496587]


In [23]:
def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
    return cfg.optim.lr.rate * jnp.power(
        (1.0 / (1.0 + (t_ / cfg.optim.lr.delay))), cfg.optim.lr.decay)

In [24]:
optimizer = optax.chain(
    optax.scale_by_adam(**cfg.optim.adam),
    optax.scale_by_schedule(learning_rate_schedule),
    optax.scale(-1.))

In [25]:
opt_state_psi = jax.pmap(optimizer.init)(params_psi)
opt_state_psi = opt_state_ckpt or opt_state_psi  # avoid overwriting ckpted state
opt_state_phi = jax.pmap(optimizer.init)(params_phi)
opt_state_phi = opt_state_ckpt or opt_state_phi  # avoid overwriting ckpted state

In [29]:
        evaluate_loss_psi = lambda params, keys, data: \
            ptotal_energy(params, params_phi, params_previous, keys, data, data_phi)
        loss_and_grad_psi = jax.value_and_grad(evaluate_loss_psi, argnums=0, has_aux=True)
        (loss, aux_data), grad_psi = loss_and_grad_psi(params_psi, key, data_psi)
        grad_psi = constants.pmean(grad_psi)
        updates_psi, opt_state_psi = optimizer.update(grad_psi, opt_state_psi, params_psi)
        params_psi = optax.apply_updates(params_psi, updates_psi)

ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (53 of them) had size 1, e.g. axis 0 of argument params_psi['double'][0]['b'] of type float32[1,32];
  * one axis had size 2: axis 0 of argument key of type uint32[2]