In [1]:
from typing import Callable, Sequence, Optional, NamedTuple
from glob import glob

import numpy as np
import h5py
import flax
from flax import linen as nn

import jax
from jax import numpy as jnp

import optax

from pyscf import gto, scf
from pyscf.dft import numint

import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'h5py'

# Define input objects

In [None]:
# Short alias for the JAX array type, used in the type annotations below
Array = jax.Array

In [None]:
class Grid(NamedTuple):
    """Integration grid: point coordinates together with their quadrature weights."""
    # Coordinates of the grid points (array layout is not fixed by this notebook)
    coords : Array
    # One weight per grid point; contracted along axis 0 by `integrate`
    weights : Array

def integrate(grid: Grid, vals: Array, axis: int = 0) -> Array:
    """Numerically integrate `vals` over `grid` as a weighted sum.

    Args:
        grid: Grid whose `weights` are the quadrature weights.
        vals: Values sampled on the grid points. (Previously declared as
            ``vals = Array``, which made the array *class* a default value
            instead of a type annotation; it is now a required argument.)
        axis: Axis of `vals` that runs over the grid points.

    Returns:
        `vals` contracted with `grid.weights` along `axis`; that axis is
        removed from the result.
    """
    return jnp.tensordot(grid.weights, vals, axes=(0, axis))

In [None]:
class FunctionalInputs(NamedTuple):
    """Bundle of everything the functional needs for one sample.

    Consumed by `build_input_array` (which unpacks it positionally) and by
    the network, which additionally reads `rho` for the LDA prefactor.
    """
    # Integration grid (coordinates and weights) the density is sampled on
    grid: Grid
    # Density values; presumably one value per grid point — TODO confirm
    rho: Array
    # Density gradient; optional. When None, only `rho` is used as a feature.
    grad_rho: Optional[Array] = None

In [None]:
def build_input_array(inputs: FunctionalInputs):
    """Assemble the per-point feature array fed to the network.

    The features are stacked along a new trailing axis: `rho` alone when no
    gradient is supplied, otherwise `rho` followed by the squared magnitude
    of `grad_rho` (sum of squared components over its last axis).
    """
    _, density, density_gradient = inputs

    if density_gradient is not None:
        # Sum of squares over the last axis, i.e. |grad rho|^2 (no square root is taken).
        squared_gradient = jnp.sum(density_gradient**2, axis=-1)
        columns = [density, squared_gradient]
    else:
        columns = [density]

    return jnp.stack(columns, axis = -1)

# Define the neural network

In [None]:
def lda(rho):
    """LDA exchange energy density per particle for density `rho`.

    Evaluates ``-3 * kF / (4 * pi)`` with the Fermi wavevector
    ``kF = (3 * pi**2 * rho) ** (1/3)``; works element-wise on arrays.
    """
    fermi_prefactor = (3 * jnp.pi**2)**(1/3)
    fermi_wavevector = fermi_prefactor * rho ** (1/3)
    numerator = -3 * fermi_wavevector
    return numerator / (4 * jnp.pi)

class ForwardFeedNN(nn.Module):
    """Feed-forward network that predicts an enhancement factor on top of LDA.

    The network maps the per-point features from `build_input_array` to a
    single value per point, squashes it with a sigmoid and multiplies the
    result by ``2 * lda(rho)``.
    """

    # Widths of the hidden layers; the first entry is also used for the input projection
    layer_widths: Sequence[int]
    # Size of the final Dense layer (plain class attribute, not a configurable field)
    out_features = 1
    # Activation applied after every hidden Dense layer
    activate : Callable[[Array],Array] = jax.nn.gelu
    # Offset added before the logarithm so that zero-valued features stay finite
    squash_offset: float = 1e-4

    @nn.compact
    def __call__(self, inputs: FunctionalInputs):
        """Evaluate the functional for one `FunctionalInputs` sample."""
        features = build_input_array(inputs)

        # Logarithmic squashing of the feature magnitudes, then a tanh-bounded projection.
        hidden = jnp.log(jnp.abs(features)+self.squash_offset)
        hidden = nn.Dense(features = self.layer_widths[0])(hidden)
        hidden = jnp.tanh(hidden)

        # Hidden stack: Dense -> activation -> LayerNorm for every entry of
        # layer_widths (including the first one, which was also used above).
        for width in self.layer_widths:
            hidden = nn.Dense(width)(hidden)
            hidden = self.activate(hidden)
            hidden = nn.LayerNorm()(hidden)

        hidden = nn.Dense(features = self.out_features)(hidden)

        # We have redefined the output to follow the DeepMind paper - multiply with LDA value.
        # The sigmoid lies in (0, 1), so the factor applied to the LDA value lies in (0, 2).
        enhancement = jax.nn.sigmoid(hidden).squeeze()
        return 2 * lda(inputs.rho) * enhancement

### Data loading

In [None]:
def load_file(path):
    """Read one HDF5 sample file and return ``(inputs, exc)``.

    The datasets 'coords', 'weights', 'rho', 'grad_rho' and 'exc_pbe' are
    copied into NumPy arrays while the file is open. The first four are
    wrapped into a `FunctionalInputs`; 'exc_pbe' is returned as the target.
    """
    dataset_names = ('coords', 'weights', 'rho', 'grad_rho', 'exc_pbe')

    with h5py.File(path, 'r') as h5_file:
        # np.array(...) materialises each dataset before the file is closed.
        arrays = {name: np.array(h5_file[name]) for name in dataset_names}

    sample_grid = Grid(arrays['coords'], arrays['weights'])
    sample_inputs = FunctionalInputs(sample_grid, arrays['rho'], arrays['grad_rho'])

    return sample_inputs, arrays['exc_pbe']

def load_data(folder_path):
    """Load every ``*.h5`` file directly inside `folder_path`.

    Args:
        folder_path: Directory containing the sample files.

    Returns:
        A list with one `load_file` result per file, ordered by file path.
        An empty list is returned when no file matches.
    """
    # glob() yields matches in arbitrary, filesystem-dependent order. Sorting
    # makes the sample order (and any positional split of the result)
    # reproducible across runs and machines.
    paths = sorted(glob(folder_path + '/*.h5'))
    return [load_file(path) for path in paths]

def separate_data(data):
    """Transpose a sequence of tuples into parallel groups.

    For ``[(a1, b1), (a2, b2)]`` the returned (lazy) zip iterator yields
    ``(a1, a2)`` and then ``(b1, b2)``.
    """
    transposed = zip(*data)
    return transposed

def divide_into_batches(data, batch_size):
    """Split `data` into consecutive slices of at most `batch_size` items.

    The last slice is shorter when ``len(data)`` is not a multiple of
    `batch_size`.
    """
    chunks = []
    for start in range(0, len(data), batch_size):
        stop = start + batch_size
        chunks.append(data[start:stop])
    return chunks

In [None]:
# Load every sample file from the data folder as a list of (inputs, exc) pairs.
# NOTE(review): this is an absolute, machine-specific path; the cell fails on any
# other machine. Consider a configurable data directory.
data = load_data('/Users/corneliussalonis/SNSP_23/MLtut/cornelius_data')

In [None]:
len(data) # number of (inputs, exc) samples loaded; 19 files at the time of writing

In [None]:
# Train/test split: the first 16 loaded samples are used for training and the
# remaining ones are held out for testing.
# NOTE(review): the split is positional, so it depends on the order in which
# load_data returned the files.
train_data = data[:16]
test_data = data[16:]

In [None]:
# Group the training samples into mini-batches of (at most) 4 samples each
batch_size = 4
batches = divide_into_batches(train_data, batch_size)

In [None]:
# Number of mini-batches per epoch
len(batches)

In [None]:
# Unzip the first mini-batch into a tuple of inputs and a tuple of targets.
# NOTE: despite the `test_` prefix these samples come from train_data (batches[0]),
# not from the held-out test_data; they are only used to try out the functions below.
test_batch, test_targets = separate_data(batches[0])

In [None]:
# Pick the first sample of that batch: its inputs and its target exc
test_inputs, test_exc = test_batch[0], test_targets[0]

In [None]:
# Display the sample's FunctionalInputs (grid, rho, grad_rho)
test_inputs

In [None]:
# Build the feature array for the sample and check its shape
# (the last axis indexes the stacked features)
x = build_input_array(test_inputs)
x.shape

In [None]:
# Instantiate the network with layer_widths=(128, 128) and initialise its
# parameters from a fixed PRNG seed (42), using one sample as the example input
fxc = ForwardFeedNN(layer_widths=(128,128))
key = jax.random.PRNGKey(42)
params = fxc.init(key, test_inputs)

In [None]:
# Evaluate the (still untrained) network on the sample and check the output shape
exc = fxc.apply(params, test_inputs)
exc.shape

In [None]:
# Shape of the target values, for comparison with the network output shape
test_exc.shape

## The cost function:

In [None]:
@jax.jit
def cost_one_input(params, inputs, target_exc):
    """Integrated squared error of the network prediction for one sample.

    Evaluates the global model `fxc` with `params` on `inputs`, squares the
    point-wise difference to `target_exc` and integrates it over
    `inputs.grid`.
    """
    predicted_exc = fxc.apply(params, inputs)
    squared_error = (predicted_exc - target_exc)**2
    return integrate(inputs.grid, squared_error)

@jax.value_and_grad
def cost_batch(params, inputs: Sequence[FunctionalInputs], target_excs: Sequence[Array]):
    """Mean of `cost_one_input` over a batch of samples.

    Because of the `jax.value_and_grad` decorator, calling this function
    returns ``(cost, gradient)``, where the gradient is taken with respect
    to the first argument, `params`.
    """
    n_samples = len(inputs)
    running_mean = 0.0

    # Each sample contributes cost / n_samples, so the sum is the batch mean.
    for sample, target in zip(inputs, target_excs):
        running_mean += cost_one_input(params, sample, target) / n_samples

    return running_mean


The cost function is the squared error between the predicted and target values of $\epsilon_{xc}$, integrated over space (numerically, a weighted sum over the grid points):

$$ \texttt{cost} = \int d^3 x \; (\epsilon ^{NN} _{xc} - \epsilon ^{\text{target}} _{xc}) ^ 2$$

and then averaged over the samples in the batch.

In [None]:
# Sanity check: cost of the current parameters on a single sample
cost_one_input(params, test_inputs, test_exc)

In [None]:
# cost_batch is wrapped in jax.value_and_grad, so it returns the batch cost
# together with its gradient with respect to params
value, grad = cost_batch(params, test_batch, test_targets)

In [None]:
# Batch cost before any training step
value

# The optimization loop:

In [None]:
# Adam optimiser with a learning rate of 1e-3; its state is initialised
# from the current parameters
optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(params)

In [None]:
def update_params(params, batch, opt_state):
    """Perform a single optimisation step on one batch (optax update pattern).

    Args:
        params: Current model parameters.
        batch: Sequence of (inputs, target) pairs.
        opt_state: Current state of the global optimiser `optim`.

    Returns:
        ``(new_params, batch_cost, new_opt_state)``.
    """
    batch_inputs, batch_targets = separate_data(batch)
    batch_cost, gradients = cost_batch(params, batch_inputs, batch_targets)

    param_updates, new_opt_state = optim.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, param_updates)

    return new_params, batch_cost, new_opt_state

We now run the optimizer and store the mean loss of each epoch in the `losses` list for plotting later. After the loop has run, the variable `params` holds the trained parameters, and `fxc.apply(params, inputs)` should approximate the target functional (the `exc_pbe` values used for training) on the given inputs.

In [None]:
# Mean loss of every epoch, collected for plotting afterwards
losses = []
# Scratch buffer with one slot per batch; overwritten in every epoch
epoch_losses = np.zeros(len(batches))

# Train for 100 epochs (full passes over the training batches).
# NOTE(review): this cell is not idempotent; re-running it continues training
# from the current `params` / `opt_state` instead of starting from scratch.
for n in range(100):

    for i, batch in enumerate(batches):
        # One optimiser step per batch; keep the updated parameters and optimiser state
        params, value, opt_state = update_params(params, batch, opt_state)
        epoch_losses[i] = value

    # Epoch loss = unweighted mean of the per-batch losses
    loss = np.mean(epoch_losses)
    losses.append(loss)

    print(f"Epoch: {n+1:3} | Loss: {loss:.4e}")

In [None]:
# Training curve: mean loss per epoch on a logarithmic y-axis.
# Labels and a title let the figure stand on its own; the trailing ';'
# suppresses the repr of the last call in the cell output.
fig, ax = plt.subplots()
ax.semilogy(losses)
ax.set(xlabel="Epoch", ylabel="Mean training loss", title="Training loss per epoch");

In [None]:
# FIXME(review): this statement cannot run as written. `names` is read on the
# right-hand side before it is assigned here, and `.split('/')` returns a list,
# which is then called with `()`. It presumably was meant to derive a short name
# per data file from its path — confirm the intent and rewrite.
names  = [names.split('-')[0].split('/')()(load_data)]