In [1]:
import functools
import importlib
import time
from typing import Optional, Sequence, Tuple, Union

from absl import logging
import chex
from ferminet import checkpoint
from ferminet import constants
from ferminet import curvature_tags_and_blocks
from ferminet import envelopes
from ferminet import hamiltonian
from ferminet import loss as qmc_loss_functions
from ferminet import mcmc
from ferminet import networks
from ferminet import pretrain
from ferminet.utils import multi_host
from ferminet.utils import statistics
from ferminet.utils import system
from ferminet.utils import writers
import jax
import jax.numpy as jnp
import kfac_jax
import ml_collections
import numpy as np
import optax
from typing_extensions import Protocol
from ferminet.configs import atom

In [7]:
def evaluate(p1, p2, d1, d2):
    """Return the sum of two matrix products, ``p1 @ d1 + p2 @ d2``.

    Args:
        p1: Left operand of the first matrix product.
        p2: Left operand of the second matrix product.
        d1: Right operand of the first matrix product.
        d2: Right operand of the second matrix product.

    Returns:
        The sum of the two products.
    """
    first_term = p1 @ d1
    second_term = p2 @ d2
    return first_term + second_term

In [8]:
# Lengths of the two parameter/data vector pairs built in the next cell.
ba1, ba2 = 10, 7

In [9]:
# Seed the generator so that "Restart Kernel -> Run All" reproduces the same
# draws (the original unseeded calls gave different numbers on every run).
SEED = 0
rng = np.random.default_rng(SEED)

# Parameters are row vectors and data are column vectors, so each
# ``param @ data`` product in ``evaluate`` has shape (1, 1).
param1 = rng.standard_normal((1, ba1))
param2 = rng.standard_normal((1, ba2))
data1 = rng.standard_normal((ba1, 1))
data2 = rng.standard_normal((ba2, 1))

In [11]:
# Sanity check: (1, ba1) @ (ba1, 1) + (1, ba2) @ (ba2, 1) should be a (1, 1) array.
evaluate(param1, param2, data1, data2).shape

(1, 1)

In [13]:
# Start from the ferminet single-atom config and point it at hydrogen.
cfg = atom.get_config()
cfg.system.atom = 'H'
cfg.system.spin_polarisation = None
# NOTE(review): this calls a private helper of ferminet.configs.atom; presumably
# it re-derives the system fields from cfg.system.atom -- confirm.
cfg = atom._adjust_nuclear_charge(cfg)
cfg.batch_size = 128
cfg.pretrain.iterations = 0  # zero pretraining iterations
# Only the cfg.optim.kfac.* entries are read by the optimizer cells below.

In [15]:
def evaluate_psi(param, data):
    """Evaluate the model as a function of the first parameter/data pair only.

    ``param2`` and ``data2`` are looked up in the notebook globals each time
    the function is called, so the second term is held fixed from the point of
    view of this optimizer.

    Args:
        param: Array used as ``p1`` in ``evaluate``.
        data: Array used as ``d1`` in ``evaluate``.

    Returns:
        ``evaluate(param, param2, data, data2)``.
    """
    return evaluate(param, param2, data, data2)


# Differentiate the two-argument wrapper, not the raw four-argument
# ``evaluate``, so the optimizer sees a ``(params, batch)`` signature.
# ``jax.value_and_grad`` needs a scalar output, so the single-element result of
# ``evaluate`` is squeezed to shape ().
val_and_grad_psi = jax.value_and_grad(
    lambda param, data: jnp.squeeze(evaluate_psi(param, data)),
    argnums=0,
    has_aux=False,
)
# NOTE(review): kfac_jax normally expects the value function to register a loss
# tag (e.g. kfac_jax.register_normal_predictive_distribution); ``evaluate`` does
# not -- confirm before calling optimizer_psi.init/step.
optimizer_psi = kfac_jax.Optimizer(
    val_and_grad_psi,
    l2_reg=cfg.optim.kfac.l2_reg,
    norm_constraint=cfg.optim.kfac.norm_constraint,
    # The value function returns only the value and takes no rng argument.
    value_func_has_aux=False,
    value_func_has_rng=False,
    # TODO: no learning_rate_schedule is defined in this notebook yet.
    #learning_rate_schedule=learning_rate_schedule,
    curvature_ema=cfg.optim.kfac.cov_ema_decay,
    inverse_update_period=cfg.optim.kfac.invert_every,
    min_damping=cfg.optim.kfac.min_damping,
    num_burnin_steps=0,
    register_only_generic=cfg.optim.kfac.register_only_generic,
    estimation_mode='fisher_exact',
    multi_device=True,
    pmap_axis_name=constants.PMAP_AXIS_NAME,
    auto_register_kwargs=dict(
        graph_patterns=curvature_tags_and_blocks.GRAPH_PATTERNS,
    ),
)

In [16]:
def evaluate_phi(param, data):
    """Evaluate the model as a function of the second parameter/data pair only.

    ``param1`` and ``data1`` are looked up in the notebook globals each time
    the function is called, so rebinding ``param1`` in a later cell changes the
    value returned here.

    Args:
        param: Array used as ``p2`` in ``evaluate``.
        data: Array used as ``d2`` in ``evaluate``.

    Returns:
        ``evaluate(param1, param, data1, data)``.
    """
    return evaluate(param1, param, data1, data)


# Differentiate the two-argument wrapper. Using raw ``evaluate`` with
# ``argnums=0`` would take the gradient w.r.t. ``p1`` rather than the parameter
# this optimizer is meant to update. ``jax.value_and_grad`` needs a scalar
# output, so the single-element result of ``evaluate`` is squeezed to shape ().
val_and_grad_phi = jax.value_and_grad(
    lambda param, data: jnp.squeeze(evaluate_phi(param, data)),
    argnums=0,
    has_aux=False,
)
# NOTE(review): kfac_jax normally expects the value function to register a loss
# tag (e.g. kfac_jax.register_normal_predictive_distribution); ``evaluate`` does
# not -- confirm before calling optimizer_phi.init/step.
optimizer_phi = kfac_jax.Optimizer(
    val_and_grad_phi,
    l2_reg=cfg.optim.kfac.l2_reg,
    norm_constraint=cfg.optim.kfac.norm_constraint,
    # The value function returns only the value and takes no rng argument.
    value_func_has_aux=False,
    value_func_has_rng=False,
    # TODO: no learning_rate_schedule is defined in this notebook yet.
    #learning_rate_schedule=learning_rate_schedule,
    curvature_ema=cfg.optim.kfac.cov_ema_decay,
    inverse_update_period=cfg.optim.kfac.invert_every,
    min_damping=cfg.optim.kfac.min_damping,
    num_burnin_steps=0,
    register_only_generic=cfg.optim.kfac.register_only_generic,
    estimation_mode='fisher_exact',
    multi_device=True,
    pmap_axis_name=constants.PMAP_AXIS_NAME,
    auto_register_kwargs=dict(
        graph_patterns=curvature_tags_and_blocks.GRAPH_PATTERNS,
    ),
)

In [17]:
# Show the current value of param1 before it is rescaled in the cells below.
param1

array([[-1.23114835,  0.7966739 ,  0.64556191, -0.68944566, -2.23406556,
         0.68504187,  1.01346511, -0.92029596, -0.53956469, -2.01436797]])

In [19]:
# Baseline value of evaluate_phi with the unscaled param1.
evaluate_phi(param2, data2)

array([[-3.68795359]])

In [20]:
# Rebind the global param1 to 10x its value. evaluate_phi reads param1 when it
# is called, so the next evaluation changes. Not idempotent: re-running this
# cell scales param1 again.
param1 = 10*param1

In [21]:
# Same call as before the rescaling; the result differs because evaluate_phi
# picks up the rebound global param1.
evaluate_phi(param2, data2)

array([[-11.81567361]])

In [22]:
# Undo the 10x scaling. Multiplying by 0.1 is not guaranteed to restore the
# original floats bit-for-bit; the result below matches to printed precision.
param1 = 0.1*param1

In [23]:
# After scaling param1 back, the displayed value matches the baseline again.
evaluate_phi(param2, data2)

array([[-3.68795359]])