In [1]:
import functools
import importlib
import time
from typing import Optional, Sequence, Tuple, Union

from absl import logging
import chex
from ferminet import checkpoint
from ferminet import constants
from ferminet import curvature_tags_and_blocks
from ferminet import envelopes
from ferminet import hamiltonian
from ferminet import loss as qmc_loss_functions
from ferminet import mcmc
from ferminet import networks
from ferminet import pretrain
from ferminet.utils import multi_host
from ferminet.utils import statistics
from ferminet.utils import system
from ferminet.utils import writers
import jax
import jax.numpy as jnp
import kfac_jax
import ml_collections
import numpy as np
import optax
from typing_extensions import Protocol
from ferminet.configs import atom

In [2]:
def init_electrons(
        key,
        molecule: Sequence[system.Atom],
        electrons: Sequence[int],
        batch_size: int,
        init_width: float,
) -> jnp.ndarray:
    """Initializes electron positions around each atom.

  Args:
    key: JAX RNG state.
    molecule: system.Atom objects making up the molecule.
    electrons: tuple of number of alpha and beta electrons.
    batch_size: total number of MCMC configurations to generate across all
      devices.
    init_width: width of (atom-centred) Gaussian used to generate initial
      electron configurations.

  Returns:
    array of (batch_size, (nalpha+nbeta)*ndim) of initial (random) electron
    positions in the initial MCMC configurations and ndim is the dimensionality
    of the space (i.e. typically 3).
  """
    if sum(atom.charge for atom in molecule) != sum(electrons):
        if len(molecule) == 1:
            atomic_spin_configs = [electrons]
        else:
            raise NotImplementedError('No initialization policy yet '
                                      'exists for charged molecules.')
    else:
        atomic_spin_configs = [
            (atom.element.nalpha, atom.element.nbeta) for atom in molecule
        ]
        assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons)
        while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons:
            i = np.random.randint(len(atomic_spin_configs))
            nalpha, nbeta = atomic_spin_configs[i]
            atomic_spin_configs[i] = nbeta, nalpha

    # Assign each electron to an atom initially.
    electron_positions = []
    for i in range(2):
        for j in range(len(molecule)):
            atom_position = jnp.asarray(molecule[j].coords)
            electron_positions.append(
                jnp.tile(atom_position, atomic_spin_configs[j][i]))
    electron_positions = jnp.concatenate(electron_positions)
    # Create a batch of configurations with a Gaussian distribution about each
    # atom.
    key, subkey = jax.random.split(key)
    return (
            electron_positions +
            init_width *
            jax.random.normal(subkey, shape=(batch_size, electron_positions.size)))

In [3]:
OptimizerState = Union[optax.OptState]

In [4]:
OptUpdateResults = Tuple[networks.ParamTree, networks.ParamTree,  # param_psi, param_phi
                         Optional[OptimizerState], Optional[OptimizerState],  # OptState
                         jnp.ndarray,  # loss
                         Optional[qmc_loss_functions.AuxiliaryLossData]]

In [5]:
class OptUpdate(Protocol):

    def __call__(self,
                 params_psi: networks.ParamTree, data_psi: jnp.ndarray,
                 params_phi: networks.ParamTree, data_phi: jnp.ndarray,
                 params_previous: networks.ParamTree,
                 opt_state_psi: optax.OptState, opt_state_phi: optax.OptState,
                 key: chex.PRNGKey) -> OptUpdateResults:
        """Evaluates the loss and gradients and updates the parameters accordingly.

    Args:
      params_psi: network parameters.
      params_phi: network parameters.
      params_previous: network parameters.
      data_psi: electron positions.
      data_phi: electron positions.
      opt_state: optimizer internal state.
      key: RNG state.

    Returns:
      Tuple of (params_psi, params_phi, opt_state, loss, aux_data), where params and opt_state
      are the updated parameters and optimizer state, loss is the evaluated loss
      and aux_data auxiliary data (see AuxiliaryLossData docstring).
    """

In [6]:
StepResults = Tuple[jnp.ndarray, jnp.ndarray,  # data_psi 和 data_phi
                    networks.ParamTree, networks.ParamTree,  # params_psi 和 params_phi
                    Optional[optax.OptState],  Optional[optax.OptState],  # OptState
                    jnp.ndarray, qmc_loss_functions.AuxiliaryLossData,  # loss, aux_data
                    jnp.ndarray, jnp.ndarray]  # pmove_psi, pmove_phi

In [7]:
class Step(Protocol):

    def __call__(self,
                 data_psi: jnp.ndarray,
                 data_phi: jnp.ndarray,
                 params_psi: networks.ParamTree,
                 params_phi: networks.ParamTree,
                 params_previous: networks.ParamTree,
                 state_psi: OptimizerState,
                 state_phi: OptimizerState,
                 key: chex.PRNGKey,
                 mcmc_width: jnp.ndarray) -> StepResults:
        """Performs one set of MCMC moves and an optimization step.

    Args:
      data: batch of MCMC configurations.
      params: network parameters.
      state: optimizer internal state.
      key: JAX RNG state.
      mcmc_width: width of MCMC move proposal. See mcmc.make_mcmc_step.

    Returns:
      Tuple of (data, params, state, loss, aux_data, pmove).
        data: Updated MCMC configurations drawn from the network given the
          *input* network parameters.
        params: updated network parameters after the gradient update.
        state: updated optimization state.
        loss: energy of system based on input network parameters averaged over
          the entire set of MCMC configurations.
        aux_data: AuxiliaryLossData object also returned from evaluating the
          loss of the system.
        pmove: probability that a proposed MCMC move was accepted.
    """

In [8]:
def null_update(params_psi: networks.ParamTree, data_psi: jnp.ndarray,
                params_phi: networks.ParamTree, data_phi: jnp.ndarray,
                params_previous: networks.ParamTree,
                opt_state_psi: Optional[optax.OptState],
                opt_state_phi: Optional[optax.OptState],
                key: chex.PRNGKey) -> OptUpdateResults:
    """Performs an identity operation with an OptUpdate interface."""
    del data_psi, data_phi, key
    return params_psi, params_phi, opt_state_psi, opt_state_phi, jnp.zeros(1), None

In [9]:
def make_opt_update_step(evaluate_loss: qmc_loss_functions.LossFn,
                         optimizer_psi: optax.GradientTransformation,
                         optimizer_phi: optax.GradientTransformation,
                         iteration_psi: int, iteration_phi: int) -> OptUpdate:
    """Returns an OptUpdate function for performing a parameter update."""

    # Differentiate wrt parameters (argument 0)
    # loss_and_grad_psi = jax.value_and_grad(evaluate_loss, argnums=0, has_aux=True)
    # loss_and_grad_phi = jax.value_and_grad(evaluate_loss, argnums=1, has_aux=True)

    def opt_update(params_psi: networks.ParamTree, data_psi: jnp.ndarray,
                   params_phi: networks.ParamTree, data_phi: jnp.ndarray,
                   params_previous: networks.ParamTree,
                   opt_state_psi: Optional[optax.OptState],
                   opt_state_phi: Optional[optax.OptState],
                   key: chex.PRNGKey) -> OptUpdateResults:
        """Evaluates the loss and gradients and updates the parameters using optax."""
        # 对loss进行closure操作
        # 先psi下降一步
        evaluate_loss_psi = lambda params, keys, data: \
            evaluate_loss(params, params_phi, params_previous, keys, data, data_phi)
        loss_and_grad_psi = jax.value_and_grad(evaluate_loss_psi, argnums=0, has_aux=True)
        
        for k in range(iteration_psi):
            (loss, aux_data), grad_psi = loss_and_grad_psi(params_psi, key, data_psi)
            grad_psi = constants.pmean(grad_psi)
            updates_psi, opt_state_psi = optimizer_psi.update(grad_psi, opt_state_psi, params_psi)
            params_psi = optax.apply_updates(params_psi, updates_psi)

        # 再对phi上升一步 这里有个负号
        evaluate_loss_phi = lambda params, keys, data: \
            evaluate_loss(params_psi, params, params_previous, keys, data_psi, data)
        loss_and_grad_phi = jax.value_and_grad(evaluate_loss_phi, argnums=0, has_aux=True)
        
        for k in range(iteration_phi):
            (loss, aux_data), grad_phi = loss_and_grad_phi(params_phi, key, data_phi)
            grad_phi = constants.pmean(grad_phi)
            updates_phi, opt_state_phi = optimizer_phi.update(grad_phi, opt_state_phi, params_phi)
            params_phi = optax.apply_updates(params_phi, updates_phi)
            
        #原先使用同一个优化器，所以要对updates乘以-1.updates为dict{list,list,dict,list,list}.
        #现在改成两个优化器，在optax里面scale调整正负1就可以了
        #for key in updates_phi:
            #if type(key) == dict:
                #for subkey in updates_phi[key]:
                    #updates_phi[key][subkey] = -1 * updates_phi[key][subkey]
            #if type(key) == list:
                #updates_phi[key] = -1 * updates_phi[key]
                
        return params_psi, params_phi, opt_state_psi, opt_state_phi, loss, aux_data

    return opt_update

In [10]:
def make_loss_step(evaluate_loss: qmc_loss_functions.LossFn) -> OptUpdate:
    """Returns an OptUpdate function for evaluating the loss."""

    def loss_eval(params_psi: networks.ParamTree, data_psi: jnp.ndarray,
                  params_phi: networks.ParamTree, data_phi: jnp.ndarray,
                  params_previous: networks.ParamTree,
                  opt_state_psi: Optional[optax.OptState],
                  opt_state_phi: Optional[optax.OptState],
                  key: chex.PRNGKey) -> OptUpdateResults:
        """Evaluates just the loss and gradients with an OptUpdate interface."""
        loss, aux_data = evaluate_loss(params_psi, params_phi, params_previous, key, data_psi, data_phi)

        return params_psi, params_phi, opt_state_psi, opt_state_phi, loss, aux_data

    return loss_eval

In [11]:
def make_training_step(
        mcmc_step,
        optimizer_step: OptUpdate,
) -> Step:
    """Factory to create traning step for non-KFAC optimizers.

  Args:
    mcmc_step: Callable which performs the set of MCMC steps. See make_mcmc_step
      for creating the callable.
    optimizer_step: OptUpdate callable which evaluates the forward and backward
      passes and updates the parameters and optimizer state, as required.

  Returns:
    step, a callable which performs a set of MCMC steps and then an optimization
    update. See the Step protocol for details.
  """

    # 这个修饰是干啥的
    #@functools.partial(constants.pmap, donate_argnums=(0, 1, 2, 3, 4, 5, 6))
    @constants.pmap
    def step(data_psi: jnp.ndarray, data_phi: jnp.ndarray,
             params_psi: networks.ParamTree, params_phi: networks.ParamTree, params_previous: networks.ParamTree,
             state_psi: Optional[optax.OptState],
             state_phi: Optional[optax.OptState],
             key: chex.PRNGKey, mcmc_width: jnp.ndarray) -> StepResults:
        """A full update iteration (except for KFAC): MCMC steps + optimization."""
        # MCMC loop for psi
        mcmc_key, loss_key = jax.random.split(key, num=2)
        data_psi, pmove_psi = mcmc_step(params_psi, data_psi, mcmc_key, mcmc_width)
        # MCMC loop for phi
        mcmc_key, loss_key = jax.random.split(key, num=2)
        data_phi, pmove_phi = mcmc_step(params_phi, data_phi, mcmc_key, mcmc_width)

        # Optimization step for psi&phi 需要提前在optimizer里调整好内循环的次数
        new_params_psi, new_params_phi, state_psi, state_phi, loss, aux_data = \
            optimizer_step(params_psi, data_psi, params_phi, data_phi, params_previous, state_psi, state_phi, loss_key)

        result: Tuple
        result = [data_psi, data_phi,
                  new_params_psi, new_params_phi,
                  state_psi, state_phi,
                  loss, aux_data, pmove_psi, pmove_phi]
        return result

    return step

In [12]:
writer_manager=None

In [13]:
cfg = atom.get_config()
cfg.system.atom = 'H'
cfg.system.spin_polarisation = None
cfg = atom._adjust_nuclear_charge(cfg)
cfg.batch_size = 128
cfg.pretrain.iterations = 0

In [14]:
num_devices = jax.local_device_count()
num_hosts = jax.device_count() // num_devices
logging.info('Starting QMC with %i XLA devices per host '
                 'across %i hosts.', num_devices, num_hosts)
if cfg.batch_size % (num_devices * num_hosts) != 0:
    raise ValueError('Batch size must be divisible by number of devices, '
                         f'got batch size {cfg.batch_size} for '
                         f'{num_devices * num_hosts} devices.')
host_batch_size = cfg.batch_size // num_hosts  # batch size per host
device_batch_size = host_batch_size // num_devices  # batch size per device
data_shape = (num_devices, device_batch_size)

 # Check if mol is a pyscf molecule and convert to internal representation
if cfg.system.pyscf_mol:
    cfg.update(
        system.pyscf_mol_to_internal_representation(cfg.system.pyscf_mol))

In [15]:
atoms = jnp.stack([jnp.array(atom.coords) for atom in cfg.system.molecule])
charges = jnp.array([atom.charge for atom in cfg.system.molecule])
nspins = cfg.system.electrons

In [16]:
if cfg.debug.deterministic:
    seed = 23
else:
    seed = 1e6 * time.time()
    seed = int(multi_host.broadcast_to_hosts(seed))
key = jax.random.PRNGKey(seed)

In [17]:
if cfg.pretrain.method == 'direct_init' or (
        cfg.pretrain.method == 'hf' and cfg.pretrain.iterations > 0):
    hartree_fock = pretrain.get_hf(
        pyscf_mol=cfg.system.get('pyscf_mol'),
        molecule=cfg.system.molecule,
        nspins=nspins,
        restricted=False,
        basis=cfg.pretrain.basis)
    # broadcast the result of PySCF from host 0 to all other hosts
    hartree_fock.mean_field.mo_coeff = tuple([
        multi_host.broadcast_to_hosts(x)
        for x in hartree_fock.mean_field.mo_coeff
    ])

In [18]:
hf_solution = hartree_fock if cfg.pretrain.method == 'direct_init' else None

In [19]:
if cfg.network.make_feature_layer_fn:
    feature_layer_module, feature_layer_fn = (
        cfg.network.make_feature_layer_fn.rsplit('.', maxsplit=1))
    feature_layer_module = importlib.import_module(feature_layer_module)
    make_feature_layer = getattr(feature_layer_module, feature_layer_fn)
    feature_layer = make_feature_layer(
        charges,
        cfg.system.electrons,
        cfg.system.ndim,
        **cfg.network.make_feature_layer_kwargs)  # type: networks.FeatureLayer
else:
    feature_layer = networks.make_ferminet_features(
        charges,
        cfg.system.electrons,
        cfg.system.ndim,
    )

In [20]:
if cfg.network.make_envelope_fn:
    envelope_module, envelope_fn = (
        cfg.network.make_envelope_fn.rsplit('.', maxsplit=1))
    envelope_module = importlib.import_module(envelope_module)
    make_envelope = getattr(envelope_module, envelope_fn)
    envelope = make_envelope(**cfg.network.make_envelope_kwargs)  # type: envelopes.Envelope
else:
    envelope = envelopes.make_isotropic_envelope()

In [21]:
network_init, signed_network, network_options = networks.make_fermi_net(
    atoms,
    nspins,
    charges,
    envelope=envelope,
    feature_layer=feature_layer,
    bias_orbitals=cfg.network.bias_orbitals,
    use_last_layer=cfg.network.use_last_layer,
    hf_solution=hf_solution,
    full_det=cfg.network.full_det,
    ndim=cfg.system.ndim,
    **cfg.network.detnet)
key, subkey = jax.random.split(key)
params_psi = network_init(subkey)
params_psi = kfac_jax.utils.replicate_all_local_devices(params_psi)
key, subkey = jax.random.split(key)
params_phi = network_init(subkey)
params_phi = kfac_jax.utils.replicate_all_local_devices(params_phi)
# Often just need log|psi(x)|.
network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]  # type: networks.LogFermiNetLike
batch_network = jax.vmap(
    network, in_axes=(None, 0), out_axes=0)

In [22]:
# 这里改成对psi和phi都设置存储和读取路径
ckpt_save_path = checkpoint.create_save_path(cfg.log.save_path)
ckpt_restore_path = checkpoint.get_restore_path(cfg.log.restore_path)
ckpt_restore_filename = (
        checkpoint.find_last_checkpoint(ckpt_save_path) or
        checkpoint.find_last_checkpoint(ckpt_restore_path))

In [23]:
logging.info('No checkpoint found. Training new model.')
key, subkey = jax.random.split(key)
# make sure data on each host is initialized differently
subkey = jax.random.fold_in(subkey, jax.process_index())
data_psi = init_electrons(
    subkey,
    cfg.system.molecule,
    cfg.system.electrons,
    batch_size=host_batch_size,
    init_width=cfg.mcmc.init_width)
data_psi = jnp.reshape(data_psi, data_shape + data_psi.shape[1:])
data_psi = kfac_jax.utils.broadcast_all_local_devices(data_psi)
key, subkey = jax.random.split(key)
# make sure data on each host is initialized differently
subkey = jax.random.fold_in(subkey, jax.process_index())
data_phi = init_electrons(
    subkey,
    cfg.system.molecule,
    cfg.system.electrons,
    batch_size=host_batch_size,
    init_width=cfg.mcmc.init_width)
data_phi = jnp.reshape(data_phi, data_shape + data_phi.shape[1:])
data_phi = kfac_jax.utils.broadcast_all_local_devices(data_phi)
t_init = 0
opt_state_ckpt = None
mcmc_width_ckpt = None

In [24]:
train_schema = ['step', 'energy', 'ewmean', 'ewvar', 'pmove_psi', 'pmove_phi']

# Initialisation done. We now want to have different PRNG streams on each
# device. Shard the key over devices
sharded_key = kfac_jax.utils.make_different_rng_key_on_all_devices(key)

# Pretraining to match Hartree-Fock

if (t_init == 0 and cfg.pretrain.method == 'hf' and
        cfg.pretrain.iterations > 0):
    orbitals = functools.partial(
        networks.fermi_net_orbitals,
        atoms=atoms,
        nspins=cfg.system.electrons,
        options=network_options,
    )
    batch_orbitals = jax.vmap(
        lambda params, data: orbitals(params, data)[0],
        in_axes=(None, 0),
        out_axes=0)
    sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
    params_psi, data_psi = pretrain.pretrain_hartree_fock(
        params=params_psi,
        data=data_psi,
        batch_network=batch_network,
        batch_orbitals=batch_orbitals,
        network_options=network_options,
        sharded_key=subkeys,
        atoms=atoms,
        electrons=cfg.system.electrons,
        scf_approx=hartree_fock,
        iterations=cfg.pretrain.iterations)
    sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
    params_phi, data_phi = pretrain.pretrain_hartree_fock(
        params=params_phi,
        data=data_phi,
        batch_network=batch_network,
        batch_orbitals=batch_orbitals,
        network_options=network_options,
        sharded_key=subkeys,
        atoms=atoms,
        electrons=cfg.system.electrons,
        scf_approx=hartree_fock,
        iterations=cfg.pretrain.iterations)

In [25]:
atoms_to_mcmc = atoms if cfg.mcmc.scale_by_nuclear_distance else None
mcmc_step = mcmc.make_mcmc_step(
    batch_network,
    device_batch_size,
    steps=cfg.mcmc.steps,
    atoms=atoms_to_mcmc,
    one_electron_moves=cfg.mcmc.one_electron,
)

In [26]:
local_energy = hamiltonian.local_energy(
    f=signed_network,
    atoms=atoms,
    charges=charges,
    nspins=nspins,
    use_scan=False)

In [27]:
evaluate_loss = qmc_loss_functions.make_loss(
    network,
    local_energy,
    clip_local_energy=cfg.optim.clip_el)

In [28]:
def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
    return cfg.optim.lr.rate * jnp.power(
        (1.0 / (1.0 + (t_ / cfg.optim.lr.delay))), cfg.optim.lr.decay)

In [29]:
#对两个优化器可以定义不同的schedule
optimizer_psi = optax.chain(
    optax.scale_by_adam(**cfg.optim.adam),
    optax.scale_by_schedule(learning_rate_schedule),
    optax.scale(-1.))
optimizer_phi = optax.chain(
    optax.scale_by_adam(**cfg.optim.adam),
    optax.scale_by_schedule(learning_rate_schedule),
    optax.scale(1.))

In [30]:
opt_state_psi = jax.pmap(optimizer_psi.init)(params_psi)
opt_state_psi = opt_state_ckpt or opt_state_psi  # avoid overwriting ckpted state
opt_state_phi = jax.pmap(optimizer_phi.init)(params_phi)
opt_state_phi = opt_state_ckpt or opt_state_phi  # avoid overwriting ckpted state

# 定义内循环步数
k_psi = cfg.optim.iterations_psi
k_phi = cfg.optim.iterations_phi

# 这段注意一下
step = make_training_step(
    mcmc_step=mcmc_step,
    optimizer_step=make_opt_update_step(evaluate_loss, optimizer_psi, optimizer_phi, k_psi, k_phi))

In [31]:
if mcmc_width_ckpt is not None:
    mcmc_width = kfac_jax.utils.replicate_all_local_devices(mcmc_width_ckpt[0])
else:
    mcmc_width = kfac_jax.utils.replicate_all_local_devices(
        jnp.asarray(cfg.mcmc.move_width))
pmoves = np.zeros(cfg.mcmc.adapt_frequency)

In [32]:
if t_init == 0:
    logging.info('Burning in MCMC chain for %d steps', cfg.mcmc.burn_in)

    burn_in_step = make_training_step(
        mcmc_step=mcmc_step, optimizer_step=null_update)

In [33]:
cfg.optim.iterations = 10

In [34]:
params_previous = params_psi

In [35]:
cfg.mcmc.burn_in = 1

In [36]:
for t in range(cfg.mcmc.burn_in):
    sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
    data_psi, data_phi, params_psi, params_phi, *_ = burn_in_step(
        data_psi,
        data_phi,
        params_psi,
        params_phi,
        params_previous,
        state_psi=None,
        state_phi=None,
        key=subkeys,
        mcmc_width=mcmc_width)
logging.info('Completed burn-in MCMC steps')

sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
ptotal_energy = constants.pmap(evaluate_loss)
initial_energy, _ = ptotal_energy(params_psi, params_psi, params_previous, \
                                  subkeys, data_psi, data_phi)
logging.info('Initial energy: %03.4f E_h', initial_energy[0])

time_of_last_ckpt = time.time()
weighted_stats = None

In [37]:
initial_energy

Array([0.04710469], dtype=float32)

In [38]:
if cfg.optim.optimizer == 'none' and opt_state_ckpt is not None:
 # If opt_state_ckpt is None, then we're restarting from a previous inference
 # run (most likely due to preemption) and so should continue from the last
 # iteration in the checkpoint. Otherwise, starting an inference run from a
 # training run.
    logging.info('No optimizer provided. Assuming inference run.')
    logging.info('Setting initial iteration to 0.')
    t_init = 0

In [41]:
import time

In [42]:
# Main training loop
for t in range(t_init, cfg.optim.iterations):
    start = time.time()
    sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
    data_psi, data_phi, params_psi, params_phi, \
    opt_state_psi, opt_state_phi, loss, unused_aux_data, \
    pmove_psi, pmove_phi = step(
        data_psi, data_phi,
        params_psi, params_phi, params_previous,
        opt_state_psi, opt_state_phi,
        subkeys,
        mcmc_width)
    print(loss, pmove_psi, pmove_phi)
    end = time.time()
    print('run time is ' + str(end))

[6.72421] [0.98046875] [1.]
run time is 1686732724.3779762
[-0.11365134] [0.9828125] [0.99921876]
run time is 1686732724.56288
[74063.055] [0.97734374] [0.9976563]
run time is 1686732724.790601
[-44.3219] [0.94687504] [0.94453126]
run time is 1686732724.96032
[-18.900875] [0.93828124] [0.9632813]
run time is 1686732725.146911
[358.43237] [0.9640625] [0.9765625]
run time is 1686732725.3169038
[1871.159] [0.9484375] [0.975]
run time is 1686732725.513228
[256.26468] [0.9484375] [0.97265625]
run time is 1686732725.693111
[-93.36137] [0.94687504] [0.97578126]
run time is 1686732725.932626
[54.462776] [0.97812504] [0.9742188]
run time is 1686732726.190337


In [40]:
pmove_phi

Array([0.996875], dtype=float32)