# Finding Kohn Sham Gradients 

This notebook contains some **rough work** to get the values for the gradients in the following tests.

- `test_kohn_sham_iteration_neural_xc_energy_loss_gradient`
- `test_kohn_sham_iteration_neural_xc_density_loss_gradient`
- `test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry`
- `test_kohn_sham_neural_xc_energy_loss_gradient`
- `test_kohn_sham_neural_xc_density_loss_gradient`

Please note:
- The only goal of this notebook is to show how the gradients may be found analytically
- The values obtained match very closely both those found by `jax.grad` and those given in the tests
- It is *very* rough, by no means efficient and does not represent best practices for obtaining gradients. 
- It is confined to the cases covered in the test which means:
  - It is hard-coded to handle exactly 2 electrons, so we can deal with a single density vector and don't keep track of whether or not the density vector is repeated, as would be the case if you had an odd number of electrons.
  - Same behaviour is assumed whether `enforce_reflection_symmetry` is `True` or `False`

The formulas for the gradients of eigenvectors and eigenvalues, which are needed for `scf.solve_interacting_system`, can be found [here](https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf) (see equations 67 and 68 on page 10, section 2.3). These formulas are valid for the return values of `jax.numpy.linalg.eigh` since it returns normalised eigenvectors. 

In [22]:
import jax.numpy as jnp
import jax
import sys
import functools
%reload_ext autoreload
%autoreload 2

In [23]:
import sys
sys.path.append('../')
from jax_dft import scf
from jax.experimental import stax
from jax_dft import np_utils
from jax import tree_util
from jax.config import config
from jax_dft import neural_xc, utils
from jax import random

# Set the default dtype as float64
config.update('jax_enable_x64', True)

In [24]:
def KS_grad_step(state, old_state, weights, grads):
    """Analytic parameter gradients through a single Kohn-Sham iteration.

    The expressions below are those of an XC energy of the form
    ``sum(n * (w * n + b)) * dx``, i.e. an XC potential ``2 * w * n + b``,
    and of a system with exactly two electrons sharing the lowest orbital
    (density ``2 * a**2``).

    Args:
        state: State produced by the Kohn-Sham iteration. Its ``grids``,
            ``density``, ``hartree_potential``, ``external_potential`` and
            ``xc_potential`` fields are read.
        old_state: State the iteration started from; only ``density`` is read.
        weights: Pair ``(w, b)`` of scalar XC parameters.
        grads: Pair ``(dn1_dw, dn1_db)`` with the derivatives of
            ``old_state.density`` w.r.t. ``w`` and ``b``, each a column
            vector with one row per grid point.

    Returns:
        ``((dn2_dw, dn2_db), (dTE_dw, dTE_db))``: derivatives of the new
        density (column vectors) and of the total energy w.r.t. ``w`` and ``b``.
    """
    w, b = weights
    dx = utils.get_dx(state.grids)

    # Densities as column vectors: n1 enters the iteration, n2 leaves it.
    n1 = old_state.density[:, None]
    n2 = state.density[:, None]

    (dn1_dw, dn1_db) = grads

    # XC potential at n1 is 2 * w * n1 + b.
    dvxc1_dw = 2 * n1 + 2 * w * dn1_dw
    dvxc1_db = 1 + 2 * w * dn1_db

    r1 = jnp.expand_dims(state.grids, axis=0)
    r2 = jnp.expand_dims(state.grids, axis=1)

    # Pairwise interaction on the grid; independent of w and b.
    fr1r2 = utils.exponential_coulomb(r1 - r2)

    # Hartree potential is linear in the density: vH = (f * dx) @ n.
    dvH_shared = fr1r2 * dx
    dvH1_dw = dvH_shared @ dn1_dw
    dvH1_db = dvH_shared @ dn1_db

    # Derivative of the full Kohn-Sham potential built from n1.
    dvks1_dw = dvxc1_dw + dvH1_dw
    dvks1_db = dvxc1_db + dvH1_db

    ks_potential = (state.hartree_potential + state.external_potential + state.xc_potential)[:, None]

    A = scf.get_kinetic_matrix(state.grids) + jnp.diag(ks_potential.squeeze())
    eigen_energies, wavefunctions_transpose = jnp.linalg.eigh(A)

    # Lowest eigenpair; Psi0 is kept as a column vector.
    E0 = eigen_energies[0]
    Psi0 = wavefunctions_transpose[:, :1]

    # Eigenvector derivative: dPsi0 = pinv(E0 * I - A) @ dA @ Psi0, and since
    # dA = diag(dv) we have dA @ Psi0 = diag(Psi0) @ dv.
    dPsi0_shared = jnp.linalg.pinv(E0 * jnp.eye(len(n1)) - A) @ jnp.diag(Psi0.squeeze())
    dPsi0_dw = dPsi0_shared @ dvks1_dw
    dPsi0_db = dPsi0_shared @ dvks1_db

    Psi0_abs = jnp.sqrt(Psi0.T @ Psi0)

    # Orbital rescaled by 1 / (|Psi0| * sqrt(dx)).
    a = Psi0 / (Psi0_abs * jnp.sqrt(dx))

    # Jacobian of Psi0 / |Psi0| is I / |Psi0| - Psi0 Psi0^T / |Psi0|**3.
    # (The exponent was 1.5 before; the two agree only because eigh returns
    # unit-norm eigenvectors, so |Psi0| == 1.)
    da_shared = (jnp.eye(len(n1)) / Psi0_abs - (Psi0 @ Psi0.T) / Psi0_abs**3) / jnp.sqrt(dx)

    da_dw = da_shared @ dPsi0_dw
    da_db = da_shared @ dPsi0_db

    # n2 = 2 * a**2  =>  dn2 = 4 * a * da.
    dn2_dw = 4 * a * da_dw
    dn2_db = 4 * a * da_db

    # Eigenvalue derivative dE0 = Psi0^T dA Psi0 = (Psi0 * Psi0)^T dv, counted
    # twice for the two electrons.
    dEeig_dw = (2 * Psi0 * Psi0).T @ dvks1_dw
    dEeig_db = (2 * Psi0 * Psi0).T @ dvks1_db

    # Product rule on sum(v_KS * n2) * dx, the term subtracted from the
    # eigen-energy in the total energy below.
    dPext_dw = (dvks1_dw.T @ n2 + (ks_potential).T @ dn2_dw) * dx
    dPext_db = (dvks1_db.T @ n2 + (ks_potential).T @ dn2_db) * dx

    vH2 = scf.get_hartree_potential(
            density=n2.squeeze(),
            interaction_fn=utils.exponential_coulomb,
            grids=state.grids,
        )[:, None]

    # Shape check left in from the rough work; its output is what is recorded
    # under the cells that call KS_grad. dvH2_dw is only needed for this print.
    dvH2_dw = dvH_shared @ dn2_dw
    print(dvH2_dw.shape, n2.shape, vH2.shape, dn2_dw.shape)

    # Hartree energy is quadratic in n2, so its derivative is (vH2 * dx)^T dn2.
    dEH2_shared = vH2 * dx # (fr1r2 @ n2) * dx**2
    dEH2_dw = dEH2_shared.T @ dn2_dw
    dEH2_db = dEH2_shared.T @ dn2_db

    dEext2_dw = (state.external_potential.T @ dn2_dw) * dx
    dEext2_db = (state.external_potential.T @ dn2_db) * dx

    # Exc = sum(n2 * (w * n2 + b)) * dx: explicit parameter dependence plus
    # the dependence through n2.
    dExc2_dw = (n2.T @ n2  + (2 * w * n2 + b).T @ dn2_dw) * dx
    dExc2_db = (n2.sum() + (2 * w * n2 + b).T @ dn2_db) * dx

    dTE_dw = (dEeig_dw - dPext_dw) + dEH2_dw + dEext2_dw + dExc2_dw
    dTE_db = (dEeig_db - dPext_db) + dEH2_db + dEext2_db + dExc2_db

    return (dn2_dw, dn2_db), (dTE_dw, dTE_db)

In [25]:
def KS_grad(grids, num_electrons, enforce_reflection_symmetry, iterations, 
            alpha=0.5,
            alpha_decay=0.9,
            num_mixing_iterations=2,
            density_mse_converge_tolerance=-1.,
            solve_init_density=False,
            target_energy=2.):
    """Analytic loss gradients w.r.t. the two XC parameters ``(w, b)``.

    Runs ``iterations`` Kohn-Sham iterations with density mixing, carrying the
    derivatives of the density along via ``KS_grad_step``, and evaluates the
    gradients of an L1 density loss and a squared-error energy loss.

    Args:
        grids: 1D grid the system is discretised on.
        num_electrons: Number of electrons (``KS_grad_step`` assumes 2).
        enforce_reflection_symmetry: Forwarded to ``scf.kohn_sham_iteration``.
        iterations: Number of Kohn-Sham iterations; must be at least 1. With
            exactly 1 iteration no density mixing is applied.
        alpha: Initial density-mixing factor.
        alpha_decay: Factor ``alpha`` is multiplied by after every iteration.
        num_mixing_iterations: Number of most recent density differences that
            are averaged in the mixing step.
        density_mse_converge_tolerance: Once the mean squared accumulated
            density difference drops below this value the remaining
            iterations are skipped.
        solve_init_density: If True, replace the initial density by the first
            return value of ``scf.solve_interacting_system``.
        target_energy: Target of the energy loss. This used to be read from a
            notebook-level global of the same name; the default is the value
            the notebook assigns to it.

    Returns:
        ``((dLn_dw, dLn_db), (dLE_dw, dLE_db), (w, b))``: density-loss
        gradients, energy-loss gradients and the parameter values used.

    Raises:
        ValueError: If ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        raise ValueError(f'iterations must be at least 1, got {iterations}')

    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = _create_testing_initial_state(
        grids, num_electrons, utils.exponential_coulomb)

    if solve_init_density:
        initial_state = initial_state._replace(density=scf.solve_interacting_system(
            external_potential=initial_state.external_potential,
            num_electrons=num_electrons,
            grids=grids
        )[0])

    target_density = (
        utils.gaussian(grids=grids, centre=-0.5, sigma=1.)
        + utils.gaussian(grids=grids, centre=0.5, sigma=1.))
    spec, flatten_init_params = np_utils.flatten(init_params)

    w, b = flatten_init_params

    # The parameters never change, so the XC functional is built only once.
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_init_params))

    # The starting density does not depend on the parameters.
    dn_dw, dn_db = (jnp.zeros([len(initial_state.density), 1]) for _ in range(2))

    dx = utils.get_dx(grids)

    state = initial_state
    converged = False
    states = []

    # History of the raw (unmixed) updates of the density and its derivatives.
    density_diffs = []
    dn_dw_diffs = []
    dn_db_diffs = []

    for step in range(iterations):

        if converged:
            states.append(state)
            continue

        old_state = state
        state = scf.kohn_sham_iteration(
            state=old_state,
            num_electrons=num_electrons,
            xc_energy_density_fn=xc_fn,
            interaction_fn=utils.exponential_coulomb,
            enforce_reflection_symmetry=enforce_reflection_symmetry)

        dn_dw_before, dn_db_before = (dn_dw, dn_db)

        (dn_dw, dn_db), (dE_dw, dE_db) = KS_grad_step(state, old_state, (w, b), (dn_dw, dn_db))

        # A single iteration is compared against a bare kohn_sham_iteration
        # call, so no mixing is applied.
        if iterations == 1:
            states.append(state)
            break

        density_diffs.append(state.density - old_state.density)
        dn_dw_diffs.append(dn_dw - dn_dw_before)
        dn_db_diffs.append(dn_db - dn_db_before)

        if jnp.mean(jnp.square(jnp.stack(density_diffs))) < density_mse_converge_tolerance:
            print('Converged at iter', step)
            converged = True

        state = state._replace(converged=converged)

        # Density mixing; the derivatives are mixed in exactly the same way
        # because mixing is linear in the density.
        state = state._replace(
            density=old_state.density
            + alpha * jnp.mean(jnp.stack(density_diffs[-num_mixing_iterations:]), axis=0)
        )
        dn_dw = dn_dw_before + alpha * jnp.mean(jnp.stack(dn_dw_diffs[-num_mixing_iterations:]), axis=0)
        dn_db = dn_db_before + alpha * jnp.mean(jnp.stack(dn_db_diffs[-num_mixing_iterations:]), axis=0)

        states.append(state)
        alpha *= alpha_decay

    final_state = states[-1]

    # Density loss: sum(|n - n_target|) * dx.
    dL_dn = (jnp.sign(final_state.density - target_density) * dx)[:, None]
    dLn_dw = dL_dn.T @ dn_dw
    dLn_db = dL_dn.T @ dn_db

    # Energy loss: (E - E_target)**2, using the energy derivatives of the last
    # executed iteration.
    dL_dE = 2 * (final_state.total_energy - target_energy)
    dLE_dw = dL_dE * dE_dw
    dLE_db = dL_dE * dE_db
    return (dLn_dw.squeeze(), dLn_db.squeeze()), (dLE_dw.squeeze(), dLE_db.squeeze()), (w, b)

In [26]:
def _create_testing_initial_state(grids, num_electrons, interaction_fn):
    """Build the starting `scf.KohnShamState` used by the gradient checks.

    The system has two nuclei of charge 1 at -0.5 and 0.5, and the initial
    density is `num_electrons` times a Gaussian centred at 0 with sigma 1.

    Args:
        grids: 1D grid the state is defined on.
        num_electrons: Number of electrons; scales the initial density.
        interaction_fn: Interaction used to build the external potential.

    Returns:
        A `scf.KohnShamState` whose total energy is set to infinity.
    """
    nuclei_positions = jnp.array([-0.5, 0.5])
    charges = jnp.array([1, 1])

    starting_density = num_electrons * utils.gaussian(
        grids=grids, centre=0., sigma=1.)
    potential_from_nuclei = utils.get_atomic_chain_potential(
        grids=grids,
        locations=nuclei_positions,
        nuclear_charges=charges,
        interaction_fn=interaction_fn)

    return scf.KohnShamState(
        density=starting_density,
        # The initial energy is a placeholder: its actual value is not used
        # in the Kohn-Sham calculation.
        total_energy=jnp.inf,
        locations=nuclei_positions,
        nuclear_charges=charges,
        external_potential=potential_from_nuclei,
        grids=grids,
        num_electrons=num_electrons)

## Single iteration

In [63]:
# Shared inputs for the single-iteration gradient checks.
_num_electrons = 2
_grids = jnp.linspace(-5, 5, 101)

# Loss targets: a fixed energy and a density made of two unit Gaussians.
target_energy = 2.
target_density = (
    utils.gaussian(grids=_grids, centre=-0.5, sigma=1.)
    + utils.gaussian(grids=_grids, centre=0.5, sigma=1.))

# One dense layer as the XC functional; its parameters are flattened so that
# jax.grad can differentiate w.r.t. a single vector.
init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
    stax.serial(stax.Dense(1)))
init_params = init_fn(rng=random.PRNGKey(0))
spec, flatten_init_params = np_utils.flatten(init_params)

# State every single-iteration loss starts from.
initial_state = _create_testing_initial_state(
    _grids, _num_electrons, utils.exponential_coulomb)

def _single_iteration_density_loss(
        flatten_params, initial_state, target_density,
        enforce_reflection_symmetry):
    """L1 density loss after one Kohn-Sham iteration.

    Shared implementation of `loss_n` and `loss_n_sym`, which differ only in
    the value of `enforce_reflection_symmetry`.

    Args:
        flatten_params: Flattened XC parameters (the differentiated argument).
        initial_state: State the iteration starts from.
        target_density: Density the result is compared with.
        enforce_reflection_symmetry: Forwarded to `scf.kohn_sham_iteration`.

    Returns:
        `sum(|density - target_density|) * dx` on the notebook's `_grids`.
    """
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=_num_electrons,
        xc_energy_density_fn=tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params)),
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    return jnp.sum(jnp.abs(state.density - target_density)) * utils.get_dx(
        _grids)


def loss_n(flatten_params, initial_state, target_density):
    """Single-iteration density loss without reflection symmetry."""
    return _single_iteration_density_loss(
        flatten_params, initial_state, target_density,
        enforce_reflection_symmetry=False)


def loss_n_sym(flatten_params, initial_state, target_density):
    """Single-iteration density loss with reflection symmetry enforced."""
    return _single_iteration_density_loss(
        flatten_params, initial_state, target_density,
        enforce_reflection_symmetry=True)


# jax.grad reference values the analytic gradients are compared against.
grad_fn_n = jax.grad(loss_n)

params_grad_n = grad_fn_n(
    flatten_init_params,
    initial_state=initial_state,
    target_density=target_density)

grad_fn_n_sym = jax.grad(loss_n_sym)

params_grad_n_sym = grad_fn_n_sym(
    flatten_init_params,
    initial_state=initial_state,
    target_density=target_density)


def loss_E(flatten_params, initial_state, target_energy):
    """Squared energy error after one Kohn-Sham iteration.

    The iteration is run with reflection symmetry enforced.

    Args:
        flatten_params: Flattened XC parameters (the differentiated argument).
        initial_state: State the iteration starts from.
        target_energy: Energy the resulting total energy is compared with.

    Returns:
        `(total_energy - target_energy) ** 2`.
    """
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params))
    new_state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=_num_electrons,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    energy_error = new_state.total_energy - target_energy
    return energy_error ** 2


# jax.grad reference value for the single-iteration energy loss.
grad_fn_E = jax.grad(loss_E)
params_grad_E = grad_fn_E(
    flatten_init_params,
    initial_state=initial_state,
    target_energy=target_energy)

In [64]:
x = KS_grad(grids=_grids, num_electrons=_num_electrons, enforce_reflection_symmetry=True, iterations=1)

(101, 1) (101, 1) (101, 1) (101, 1)


In [65]:
[[i.shape for i in y] for y in x]

[[(), ()], [(), ()], [(), ()]]

In [66]:
x[0], x[1]

((DeviceArray(-1.3413697, dtype=float64),
  DeviceArray(-1.51196672e-15, dtype=float64)),
 (DeviceArray(-8.54995173, dtype=float64),
  DeviceArray(-14.75419501, dtype=float64)))

In [67]:
flatten_init_params

array([-0.27235165,  0.01030468])

In [68]:
params_grad_n_sym

DeviceArray([-1.34136970e+00,  2.26901831e-15], dtype=float64)

In [69]:
# Compare once, fail on a mismatch, and show the result as the cell output.
density_grad_sym_matches = jnp.allclose(params_grad_n_sym, jnp.array(x[0]))
assert density_grad_sym_matches
density_grad_sym_matches

DeviceArray(True, dtype=bool)

In [70]:
params_grad_E

DeviceArray([ -8.54995173, -14.75419501], dtype=float64)

In [71]:
# Compare once, fail on a mismatch, and show the result as the cell output.
energy_grad_matches = jnp.allclose(params_grad_E, jnp.array(x[1]))
assert energy_grad_matches
energy_grad_matches

DeviceArray(True, dtype=bool)

In [72]:
# Same check as above, this time without enforcing reflection symmetry.
x = KS_grad(
    grids=_grids,
    num_electrons=_num_electrons,
    enforce_reflection_symmetry=False,
    iterations=1)

density_grad_matches = jnp.allclose(params_grad_n, jnp.array(x[0]))
assert density_grad_matches
density_grad_matches

(101, 1) (101, 1) (101, 1) (101, 1)


DeviceArray(True, dtype=bool)

## Several iterations - Energy and density

In [73]:
# Shared inputs for the multi-iteration (full SCF loop) gradient checks.
_num_electrons = 2
_grids = jnp.linspace(-5, 5, 101)
_locations = jnp.array([-0.5, 0.5])
_nuclear_charges = jnp.array([1, 1])

# Loss targets.
target_energy = 2.
target_density = (
    utils.gaussian(grids=_grids, centre=-0.5, sigma=1.)
    + utils.gaussian(grids=_grids, centre=0.5, sigma=1.))

# One dense layer as the XC functional. The parameters are flattened twice:
# `spec`/`flatten_init_params` are used by the density loss and
# `spec2`/`flatten_init_params2` by the energy loss.
init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
    stax.serial(stax.Dense(1)))
init_params = init_fn(rng=random.PRNGKey(0))
spec, flatten_init_params = np_utils.flatten(init_params)
spec2, flatten_init_params2 = np_utils.flatten(init_params)



def loss_E(flatten_params, target_energy):
    """Squared energy error of the final state of a 3-iteration SCF run.

    Args:
        flatten_params: Flattened XC parameters (the differentiated argument).
        target_energy: Energy the final total energy is compared with.

    Returns:
        `(final total energy - target_energy) ** 2`.
    """
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec2, flatten_params))
    scf_states = scf.kohn_sham(
        locations=_locations,
        nuclear_charges=_nuclear_charges,
        num_electrons=_num_electrons,
        num_iterations=3,
        grids=_grids,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb)
    energy_error = scf.get_final_state(scf_states).total_energy - target_energy
    return energy_error ** 2


# jax.grad reference value for the multi-iteration energy loss.
grad_fn_E = jax.grad(loss_E)
params_grad_E = grad_fn_E(flatten_init_params2, target_energy=target_energy)


def loss_n(flatten_params, target_density):
    """L1 density error of the final state of a 3-iteration SCF run.

    The run uses a convergence tolerance of -1.

    Args:
        flatten_params: Flattened XC parameters (the differentiated argument).
        target_density: Density the final density is compared with.

    Returns:
        `sum(|final density - target_density|) * dx` on `_grids`.
    """
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params))
    scf_states = scf.kohn_sham(
        locations=_locations,
        nuclear_charges=_nuclear_charges,
        num_electrons=_num_electrons,
        num_iterations=3,
        grids=_grids,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        density_mse_converge_tolerance=-1.)
    final_density = scf.get_final_state(scf_states).density
    absolute_error = jnp.abs(final_density - target_density)
    return jnp.sum(absolute_error) * utils.get_dx(_grids)


# jax.grad reference value for the multi-iteration density loss.
grad_fn_n = jax.grad(loss_n)
params_grad_n = grad_fn_n(flatten_init_params, target_density=target_density)


In [74]:
x = KS_grad(grids=_grids, 
            num_electrons=_num_electrons, 
            enforce_reflection_symmetry=False, 
            iterations=3, 
            solve_init_density=True)

(101, 1) (101, 1) (101, 1) (101, 1)
(101, 1) (101, 1) (101, 1) (101, 1)
(101, 1) (101, 1) (101, 1) (101, 1)


In [75]:
x

((DeviceArray(-1.59671362, dtype=float64),
  DeviceArray(-1.07372625e-13, dtype=float64)),
 (DeviceArray(-8.57162696, dtype=float64),
  DeviceArray(-14.75474883, dtype=float64)),
 (-0.27235164784460814, 0.010304675877677812))

In [76]:
init_params, params_grad_n

([(DeviceArray([[-0.27235165]], dtype=float64),
   DeviceArray([0.01030468], dtype=float64))],
 DeviceArray([-1.59671362e+00, -7.32053307e-16], dtype=float64))

In [77]:
# Compare once, fail on a mismatch, and show the result as the cell output.
scf_density_grad_matches = jnp.allclose(params_grad_n, jnp.stack(x[0]))
assert scf_density_grad_matches
scf_density_grad_matches

DeviceArray(True, dtype=bool)

In [78]:
# Compare once, fail on a mismatch, and show the result as the cell output.
scf_energy_grad_matches = jnp.allclose(params_grad_E, jnp.stack(x[1]))
assert scf_energy_grad_matches
scf_energy_grad_matches

DeviceArray(True, dtype=bool)

In [79]:
params_grad_E

DeviceArray([ -8.57162696, -14.75474883], dtype=float64)