In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax

import logging

# Configure logging to show debug messages
# logging.basicConfig(level=logging.DEBUG, force=True)
# logging.disable(logging.CRITICAL)

# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 100


def make_dataset_iterator(batch_size):
  # Dummy dataset, in practice this should be your dataset pipeline
  for _ in range(NUM_BATCHES):
    yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype="int32")


def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
  """Softmax cross entropy loss."""
  # We assume integer labels
  assert logits.ndim == targets.ndim + 1

  # Tell KFAC-JAX this model represents a classifier
  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return - jax.vmap(lambda x, y: x[y])(log_p, targets)


def model_fn(x):
  """A Haiku MLP model function - three hidden layer network with tanh."""
  return hk.nets.MLP(
    output_sizes=(50, 50, 50, NUM_CLASSES),
    with_bias=True,
    activation=jax.nn.tanh,
  )(x)


# The Haiku transformed model
hk_model = hk.without_apply_rng(hk.transform(model_fn))


def loss_fn(model_params, model_batch):
  """The loss function to optimize."""
  x, y = model_batch
  logits = hk_model.apply(model_params, x)
  loss = jnp.mean(softmax_cross_entropy(logits, y))

  # The optimizer assumes that the function you provide has already added
  # the L2 regularizer to its gradients.
  return loss + L2_REG * kfac_jax.utils.inner_product(params, params) / 2.0


# Create the optimizer
optimizer = kfac_jax.Optimizer(
  value_and_grad_func=jax.value_and_grad(loss_fn),
  l2_reg=L2_REG,
  value_func_has_aux=False,
  value_func_has_state=False,
  value_func_has_rng=False,
  use_adaptive_learning_rate=True,
  use_adaptive_momentum=True,
  use_adaptive_damping=True,
  initial_damping=1.0,
  multi_device=False,
)

input_dataset = make_dataset_iterator(128)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
rng, key = jax.random.split(rng)
params = hk_model.init(key, dummy_images)
rng, key = jax.random.split(rng)
opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

# Training loop
for i, batch in enumerate(input_dataset):
  rng, key = jax.random.split(rng)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key, batch=batch, global_step_int=i)
  print(i, stats)

0 {'batch_size': Array(128, dtype=int32), 'damping': Array(1., dtype=float32), 'data_seen': Array(128, dtype=int32), 'learning_rate': Array(1.7462709, dtype=float32), 'loss': Array(2.3646538, dtype=float32), 'momentum': Array(-0., dtype=float32), 'new_loss': Array(nan, dtype=float32, weak_type=True), 'quad_model_change': Array(-1.5149225, dtype=float32), 'rho': Array(nan, dtype=float32, weak_type=True), 'scaled_grad_norm_sq': None, 'step': Array(1, dtype=int32)}
1 {'batch_size': Array(128, dtype=int32), 'damping': Array(1., dtype=float32), 'data_seen': Array(256, dtype=int32), 'learning_rate': Array(0.45567128, dtype=float32), 'loss': Array(0.9914196, dtype=float32), 'momentum': Array(0.16053492, dtype=float32), 'new_loss': Array(nan, dtype=float32, weak_type=True), 'quad_model_change': Array(-0.6257111, dtype=float32), 'rho': Array(nan, dtype=float32, weak_type=True), 'scaled_grad_norm_sq': None, 'step': Array(2, dtype=int32)}
2 {'batch_size': Array(128, dtype=int32), 'damping': Array

In [11]:
rng = jax.random.PRNGKey(42)

In [13]:
rng

Array([ 0, 42], dtype=uint32)

In [14]:
jax.random.split(rng)

Array([[2465931498, 3679230171],
       [ 255383827,  267815257]], dtype=uint32)

# Going towards ferminet

In [14]:
2+2

4

In [15]:
from pyscf import gto
mol = gto.Mole()
mol.build(
    atom = 'H  0 0 1; H 0 0 -1',
    basis = 'sto-3g', unit='bohr')

<pyscf.gto.mole.Mole at 0x123fefd90>

In [50]:
? mol.build

[0;31mSignature:[0m
 [0mmol[0m[0;34m.[0m[0mbuild[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mdump_input[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparse_arg[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mverbose[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0moutput[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_memory[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0matom[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbasis[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0munit[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnucmod[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mecp[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpseudo[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcharge[0m[0

In [57]:
from pyscf import gto
mol = gto.Mole()
mol.build(
    atom = 'H  0 0 0',
    spin = 1,
    basis = 'sto-3g', unit='bohr')


from ferminet import base_2DEG_config
from ferminet import train

# Add H2 molecule
cfg = base_config.default()
cfg.system.pyscf_mol = mol

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

In [None]:
train.train(cfg)

In [54]:
cfg

batch_size: 256
config_module: ferminet.base_config
debug:
  check_nan: false
  deterministic: false
log:
  features: false
  local_energies: false
  restore_path: ''
  save_frequency: 10.0
  save_path: ''
  stats_frequency: 1
  walkers: false
mcmc:
  adapt_frequency: 100
  blocks: 1
  burn_in: 100
  init_means: !!python/tuple []
  init_width: 1.0
  move_width: 0.02
  num_leapfrog_steps: 10
  scale_by_nuclear_distance: false
  steps: 10
  use_hmc: false
network:
  bias_orbitals: false
  complex: false
  determinants: 16
  ferminet:
    electron_nuclear_aux_dims: !!python/tuple []
    hidden_dims: !!python/tuple
    - &id001 !!python/tuple
      - 256
      - 32
    - *id001
    - *id001
    - *id001
    nuclear_embedding_dim: 0
    schnet_electron_electron_convolutions: !!python/tuple []
    schnet_electron_nuclear_convolutions: !!python/tuple []
    separate_spin_channels: false
    use_last_layer: false
  full_det: true
  jastrow: default
  make_envelope_fn: ''
  make_envelope_kwargs

In [11]:
cfg.optim.iterations = 500

In [13]:
cfg.log.save_path = './ferminet-log'

''

In [19]:
type(cfg.system.molecule)

NoneType

In [55]:
cfg.system.pyscf_mol.atom_charges()

array([1], dtype=int32)

In [56]:
cfg.system.pyscf_mol.atom_coords()

array([[0., 0., 1.]])

In [46]:
cfg.system.pyscf_mol.

1

In [62]:
import jax.numpy as jnp

jnp.tanh(2)

Array(0.9640276, dtype=float32, weak_type=True)

In [68]:
arr = jnp.array([0, 0.5, 1., 5, 9.5, 10, 10.5, 11, 15])
(1. + jnp.tanh((arr - 10) / 0.1)) / 2

Array([0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
       4.541874e-05, 5.000000e-01, 9.999546e-01, 1.000000e+00,
       1.000000e+00], dtype=float32)

In [82]:
import jax.random as random

key = random.PRNGKey(42)
key1, key2 = random.split(key)
arr1 = random.uniform(key1, (2,3))
arr2 = jnp.array([0, 1, -1])

In [81]:
arr1

Array([[0.87425196, 0.12079132, 0.5372118 ],
       [0.41176045, 0.6269895 , 0.5899111 ]], dtype=float32)

In [83]:
arr2 * arr1

Array([[ 0.        ,  0.12079132, -0.5372118 ],
       [ 0.        ,  0.6269895 , -0.5899111 ]], dtype=float32)

In [84]:
arr1 * arr2

Array([[ 0.        ,  0.12079132, -0.5372118 ],
       [ 0.        ,  0.6269895 , -0.5899111 ]], dtype=float32)