# LapJAX Tutorial
(updated on Oct-12-2023)

This tutorial aims to give you a quick understanding of LapJAX. We will cover the following topics:
1. How to use LapJAX to accelerate your code simply by changing the import statement.
2. How fast LapJAX is compared to standard methods.
3. How to build your custom operators in LapJAX.


### A Quick Start
To use LapJAX, simply replace `jax` with `lapjax` in your import statements. For example, change
```python
import jax
import jax.numpy as jnp
from jax import vmap
```
to
```python
import lapjax
import lapjax.numpy as jnp
from lapjax import vmap
```
Without any further change, this code runs *exactly the same* as when you use `jax`. This ensures that LapJAX is compatible with your code in most situations where the Laplacian is not needed.

```python
# WARN: if you are running on an Ampere-architecture GPU, e.g. A100/RTX 3090,
# make sure to disable the TF32 option.
# Otherwise, you will lose numerical precision.
import os
os.environ['NVIDIA_TF32_OVERRIDE'] = "0"

# import jax
# import jax.numpy as jnp
import lapjax as jax
import lapjax.numpy as jnp
```



### Case 1: A standard MLP model
Below we consider a standard MLP network. Assume you have written the model using `jax` as follows. All functions remain unchanged when you use `lapjax`.


```python
# define the hyperparameters
input_dim = 64
hidden_dim = 256
hidden_layer = 4
layer_dims = [input_dim,] + [hidden_dim,] * hidden_layer + [1,]

# define init function
def init_params(key):
    params = []
    left_dim = input_dim
    for right_dim in layer_dims:
        key, subkey = jax.random.split(key)
        params.append(jax.random.normal(subkey, (left_dim, right_dim)) * 0.1)
        left_dim = right_dim
    return params

# Define the network
def MLP(params, x):
    for param in params:
        # use lapjax.numpy to construct the function
        # This function can take both jax.ndarray and lapjax.LapTuple as input
        x = jnp.matmul(x, param)
        x = jnp.tanh(x)
    return x.reshape(-1)

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
params = init_params(key)
```



Now we construct [pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) that compute the Laplacian of the model in two ways.

```python
# compute the Laplacian through the standard jax method (trace of the Hessian)
def get_laplacian_function_orig(func):
    def lap(data):
        # assume data.shape = (N,)
        # compute the Hessian, then take its trace
        hess = jax.hessian(func)(data)
        return jnp.trace(hess, axis1=-1, axis2=-2)
    return lap

# compute the Laplacian through lapjax
def get_laplacian_function_lapjax(func):
    # ATTENTION: only a few changes are needed here!
    from lapjax import LapTuple
    def lap(data):
        input_laptuple = LapTuple(data, is_input=True)
        output_laptuple = func(input_laptuple)
        # a LapTuple carries value, grad and lap
        return output_laptuple.lap
    return lap

# get laplacian functions
lap_original = get_laplacian_function_orig(lambda x: MLP(params, x))
lap_lapjax = get_laplacian_function_lapjax(lambda x: MLP(params, x))

# note that the input and output of `lap_lapjax` are both `jax.numpy.ndarray`,
# so we can use `jax.vmap` to process a batch of data. `jax.jit` also works well.
vmap_lap_orignal = jax.jit(jax.vmap(lap_original))
vmap_lap_lapjax = jax.jit(jax.vmap(lap_lapjax))
```
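To see what a `LapTuple` carries through the network, here is a minimal NumPy sketch of forward-Laplacian propagation for an MLP like the one above. It illustrates the idea under our own assumptions and is not lapjax's implementation: the state is a triple `(value, grad, lap)` where `grad[i]` is the derivative with respect to input coordinate `i`, and the helper names are hypothetical. A matmul is linear, so it acts on all three parts; for the element-wise `tanh`, the chain rule gives `lap_out = tanh'(v) * lap + tanh''(v) * sum_i grad[i]**2`. The result is checked against a central finite-difference Laplacian.

```python
import numpy as np

def forward_laplacian_mlp(params, x):
    """Hypothetical helper: propagate (value, grad, lap) through matmul + tanh layers.

    x has shape (n,); grad has shape (n, width); lap has shape (width,).
    """
    n = x.shape[0]
    value = x
    grad = np.eye(n)      # d x / d x_i is the identity
    lap = np.zeros(n)     # the Laplacian of each input coordinate is 0
    for w in params:
        # linear layer: the same matmul acts on value, grad and lap
        value, grad, lap = value @ w, grad @ w, lap @ w
        # element-wise tanh: chain rule, using the pre-activation grad
        t = np.tanh(value)
        d1 = 1.0 - t ** 2          # tanh'(value)
        d2 = -2.0 * t * d1         # tanh''(value)
        lap = d1 * lap + d2 * np.sum(grad ** 2, axis=0)
        grad = grad * d1
        value = t
    return value, grad, lap

def finite_difference_laplacian(f, x, h=1e-4):
    """Sum of central second differences along each input coordinate."""
    total = 0.0
    for i in range(x.shape[0]):
        e = np.zeros_like(x)
        e[i] = h
        total = total + (f(x + e) - 2.0 * f(x) + f(x - e)) / h ** 2
    return total

rng = np.random.default_rng(0)
dims = [8, 16, 16, 1]
params = [rng.normal(size=(a, b)) * 0.3 for a, b in zip(dims[:-1], dims[1:])]
x = rng.normal(size=dims[0])

value, grad, lap = forward_laplacian_mlp(params, x)
mlp = lambda z: forward_laplacian_mlp(params, z)[0]
print(lap, finite_difference_laplacian(mlp, x))
```

Every layer pushes a `grad` array with one row per input coordinate through the matmul. When many of those rows are known to be zero, as in Case 2, they can be skipped, and this is where the sparsity gains discussed in this tutorial come from.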

We now test the precision of the lapjax method.

```python
batch_size = 1280

key, subkey = jax.random.split(key)
data = jax.random.normal(subkey, (batch_size, input_dim))

orig_results = vmap_lap_orignal(data)
lapjax_results = vmap_lap_lapjax(data)

# the maximum difference is at the level of float32 round-off error
print(jnp.max(jnp.abs(orig_results - lapjax_results)))
```

```
3.8146973e-06
```


We now test the acceleration of the lapjax method. Note that a fully connected network has no derivative sparsity to exploit, so the speedup is modest, roughly 2× (about 1.7× in the run below).

```python
import time

def compute_time(vfunc, key, batch_size, input_dim, iterations=100):
    key, subkey = jax.random.split(key)
    data_pool = jax.random.normal(subkey, (iterations * 2, batch_size, input_dim))

    data_pool1, data_pool2 = jnp.split(data_pool, 2)

    # warm up so that JIT compilation is excluded from the timing
    for data in data_pool1:
        val = vfunc(data)
    val.block_until_ready()

    start_time = time.time()
    for data in data_pool2:
        val = vfunc(data)
    # JAX dispatches asynchronously: wait for the last result before stopping the clock
    val.block_until_ready()
    end_time = time.time()
    return val, end_time - start_time

val, duration = compute_time(vmap_lap_orignal, key, batch_size, input_dim)
print('time of hessian-trace:', duration)

val, duration = compute_time(vmap_lap_lapjax, key, batch_size, input_dim)
# without sparsity, forward laplacian is still somewhat faster than hessian-trace
print('time of forward laplacian:', duration)
```

```
time of hessian-trace: 1.0552332401275635
time of forward laplacian: 0.6265256404876709
```


### Case 2: A Slater-determinant-based model
Below we consider a Slater-determinant-based model, of the kind commonly used to represent many-electron wave functions. Before the determinant, each electron's features depend only on that electron's own coordinates, so the derivatives are very sparse and the acceleration from lapjax is significant.
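The sparsity claim can be checked with a small NumPy sketch. Here a hypothetical per-electron map (a single `tanh` layer with shared weights, standing in for the MLP in the next cell) is differentiated by finite differences. The block of the Jacobian that maps electron `b`'s coordinates to electron `a`'s features is zero whenever `a != b`, so only `n_elec` of the `n_elec**2` blocks are non-zero.

```python
import numpy as np

n_elec, input_dim, feat_dim = 4, 3, 5
rng = np.random.default_rng(0)
w = rng.normal(size=(input_dim, feat_dim))

def per_electron_features(x):
    """Hypothetical stand-in: the same tanh layer applied to each electron."""
    return np.tanh(x.reshape(n_elec, input_dim) @ w)  # shape (n_elec, feat_dim)

def jacobian_fd(f, x, h=1e-6):
    """Central-difference Jacobian, shape (x.size, *f(x).shape)."""
    rows = []
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        rows.append((f(x + e) - f(x - e)) / (2 * h))
    return np.stack(rows)

x = rng.normal(size=n_elec * input_dim)
jac = jacobian_fd(per_electron_features, x)           # (n_elec*input_dim, n_elec, feat_dim)
jac = jac.reshape(n_elec, input_dim, n_elec, feat_dim)

# block_norm[b, a]: size of d(features of electron a) / d(coordinates of electron b)
block_norm = np.sqrt((jac ** 2).sum(axis=(1, 3)))
print(np.round(block_norm, 3))
```

After the determinant every output depends on every electron, but in the layers before it this block structure is the kind of sparsity that lapjax can exploit.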

```python
# CASE 2: a Slater-determinant-like wave function
# In this case, we can leverage the derivative sparsity
# to achieve a substantial speed-up.

# construct input params
key = jax.random.PRNGKey(42)

n_elec = 16  # number of electrons
input_dim = 3  # dimension of each electron's position
hidden_dim = 256
hidden_layer = 2
layer_dims = [input_dim,] + [hidden_dim,] * hidden_layer + [n_elec,]

def init_params(key):
    params = []
    left_dim = input_dim
    for right_dim in layer_dims:
        key, subkey = jax.random.split(key)
        params.append(jax.random.normal(subkey, (left_dim, right_dim)) * 0.1)
        left_dim = right_dim
    return params

# construct the wave function (returns ln|psi|)
def slater_determinants(params, x):

    # x.shape = (n_elec * input_dim,)
    x = x.reshape(n_elec, input_dim)
    for param in params:
        # each electron is processed by the same MLP
        x = jnp.matmul(x, param)
        x = jnp.tanh(x)

    # x is now an (n_elec, n_elec) matrix; shift it by the identity
    x = x + jnp.eye(x.shape[0])

    # slogdet returns (sign, ln|det|); keep ln|det|
    _, lnpsi = jnp.linalg.slogdet(x)
    return lnpsi

key, subkey = jax.random.split(key)
params_sd = init_params(key)
```

Similarly, we construct pure functions that compute the kinetic energy in two ways. In variational Monte Carlo (VMC), we need the local kinetic energy, which is defined as:
$$
E_k = \frac{-0.5 \times \nabla^2 \psi(\mathbf x)}{\psi (\mathbf x)} = -0.5 \times \nabla^2 \ln \psi (\mathbf x) - 0.5 \times (\nabla \ln \psi (\mathbf x))^2.
$$
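The second equality follows from $\nabla^2 \psi / \psi = \nabla^2 \ln \psi + (\nabla \ln \psi)^2$. It can be checked numerically with finite differences. The sketch below uses the test function $\psi(\mathbf x) = (1 + x_0^2)\, e^{-\|\mathbf x\|^2/2}$, an arbitrary choice made only because it is smooth and positive.

```python
import numpy as np

def psi(x):
    # hypothetical smooth, positive test function
    return (1.0 + x[0] ** 2) * np.exp(-0.5 * np.sum(x ** 2))

def ln_psi(x):
    return np.log(psi(x))

def fd_grad(f, x, h=1e-5):
    """Central-difference gradient."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2 * h)
    return g

def fd_laplacian(f, x, h=1e-4):
    """Sum of central second differences along each coordinate."""
    total = 0.0
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        total += (f(x + e) - 2.0 * f(x) + f(x - e)) / h ** 2
    return total

x = np.array([0.3, -0.7, 1.1])
# left-hand side: -0.5 * lap(psi) / psi
ek_direct = -0.5 * fd_laplacian(psi, x) / psi(x)
# right-hand side: -0.5 * lap(ln psi) - 0.5 * |grad ln psi|^2
g = fd_grad(ln_psi, x)
ek_log = -0.5 * fd_laplacian(ln_psi, x) - 0.5 * np.sum(g ** 2)
print(ek_direct, ek_log)
```

This is why the lapjax version below only needs `lap` and `grad` of $\ln \psi$, both of which the output `LapTuple` already holds.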

```python
def get_kinetic_function_orig(func):
    def kinetic(data):
        # func returns ln|psi|, so this is -0.5 * (lap ln psi + |grad ln psi|^2)
        grad = jax.grad(func)(data)
        hess = jax.hessian(func)(data)
        return -0.5 * (jnp.trace(hess) + jnp.sum(grad ** 2))

    return kinetic

def get_kinetic_function_lapjax(func):
    from lapjax import LapTuple
    def kinetic(data):
        input_laptuple = LapTuple(data, is_input=True)
        output_laptuple = func(input_laptuple)

        # A LapTuple stores both the gradient and the Laplacian,
        # so we do not need to compute the gradient again.
        return -0.5 * output_laptuple.lap - 0.5 * jnp.sum(output_laptuple.grad ** 2)

    return kinetic

# get kinetic functions
ke_original = get_kinetic_function_orig(lambda x: slater_determinants(params_sd, x))
ke_lapjax = get_kinetic_function_lapjax(lambda x: slater_determinants(params_sd, x))

vmap_ke_orignal = jax.jit(jax.vmap(ke_original))
vmap_ke_lapjax = jax.jit(jax.vmap(ke_lapjax))
```

We now test the precision of the lapjax method.

```python
# CASE 2: precision test
batch_size = 128

key, subkey = jax.random.split(key)
data = jax.random.normal(subkey, (batch_size, input_dim * n_elec))

orig_results = vmap_ke_orignal(data)
lapjax_results = vmap_ke_lapjax(data)

# the maximum difference is at the level of float32 round-off error
print(jnp.max(jnp.abs(orig_results - lapjax_results)))
```

```
4.7683716e-07
```


We now test the acceleration of the lapjax method. Here the improvement is significant: about 5.6× in the run below.

```python
val, duration = compute_time(vmap_ke_orignal,
                             key, batch_size, input_dim * n_elec)
print('time of hessian-trace:', duration)

val, duration = compute_time(vmap_ke_lapjax,
                             key, batch_size, input_dim * n_elec)
# thanks to sparsity, forward laplacian is several times faster here
print('time of forward laplacian:', duration)
```

```
time of hessian-trace: 0.6880514621734619
time of forward laplacian: 0.12378358840942383
```


### Case 3: Customizing Operators
Some operators you need may not have been wrapped by lapjax yet. In that case, you can easily wrap the `jax` operators yourself. Suppose we want to use `jax.numpy.isnan` in our model; it does not yet support a `LapTuple` input.

```python
from lapjax import LapTuple
import lapjax.numpy as jnp

# eye / eye gives 1 on the diagonal and 0/0 = nan elsewhere
lap = LapTuple(jnp.eye(4), is_input=True) / jnp.eye(4)
try:
    print(jnp.isnan(lap))
except Exception as e:  # should show the unwrapped-function error
    print(e)
```

```
Lapjax encounters unwrapped function 'isnan'.
Please consider using other functions or wrap it yourself.
You can refer to README for more information about customized wrap.
```


To wrap a `jax` function `f`, we need to:
1. specify which class `f` belongs to. For instance, `jax.numpy.exp2` belongs to `FElement`, as the operator is applied to each element of the input.
2. bind `f` to the corresponding class, and write a customized function for `LapTuple` inputs (only if needed).

For `jax.numpy.isnan`, the function only tests whether each input entry is nan, so its output carries no gradient or Laplacian (equivalently, both are 0). When the output of a function `f` contains only arrays with zero gradient and Laplacian, `f` belongs to the `FConstruction` class, which is bound with `FType.CONSTRUCTION`.
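To make these categories concrete, here is a toy NumPy sketch of how a `(value, grad, lap)` triple could be propagated for three of them. The `Toy` container and the three `apply_*` helpers are hypothetical illustrations, not lapjax's API. A construction-type function yields zero gradient and Laplacian, a linear function is applied unchanged to every gradient row and to the Laplacian, and an element-wise function uses the chain rule.

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class Toy:
    value: np.ndarray   # shape (m,)
    grad: np.ndarray    # shape (n_inputs, m)
    lap: np.ndarray     # shape (m,)

def apply_construction(f, t):
    """f has zero derivatives (e.g. isnan): gradient and Laplacian are 0."""
    out = f(t.value)
    n = t.grad.shape[0]
    return Toy(out, np.zeros((n,) + out.shape), np.zeros(out.shape))

def apply_linear(f, t):
    """f is linear (e.g. sum): apply it to value, each grad row and lap."""
    return Toy(f(t.value), np.stack([f(g) for g in t.grad]), f(t.lap))

def apply_element(f, df, d2f, t):
    """f acts element-wise (e.g. exp): chain rule for grad and lap."""
    d1, d2 = df(t.value), d2f(t.value)
    lap = d1 * t.lap + d2 * np.sum(t.grad ** 2, axis=0)
    return Toy(f(t.value), t.grad * d1, lap)

# input triple for x itself: grad is the identity, lap is zero
x = np.array([0.5, -1.0, 2.0])
t = Toy(x, np.eye(3), np.zeros(3))

e = apply_element(np.exp, np.exp, np.exp, t)              # exp(x_i)
s = apply_linear(lambda v: np.sum(v, keepdims=True), e)   # sum_i exp(x_i)
c = apply_construction(np.isnan, t)                       # no derivatives
print(s.value, s.grad.ravel(), s.lap)
```

Note that the linear rule carries a dense `grad` of shape `(n_inputs, m)` through `f` even when most of its rows are zero.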

```python
from lapjax import custom_wrap, FType
custom_wrap(jnp.isnan, FType.CONSTRUCTION)

jnp.isnan(lap)
```

```
Successfully bind function 'isnan' to FType.CONSTRUCTION.

DeviceArray([[False,  True,  True,  True],
             [ True, False,  True,  True],
             [ True,  True, False,  True],
             [ True,  True,  True, False]], dtype=bool)
```

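Similarly, we can wrap `jax.numpy.isinf` and `jax.numpy.isfinite` as above. Next, consider `jax.numpy.nansum`: it is linear in its input, so we can simply bind it to `FType.LINEAR`. The function `f` below first fails because `nansum` is unwrapped, and works after wrapping: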

```python
f = lambda x: jnp.log(jnp.nansum(jnp.exp(x), axis=-1))
try:
    f(lap)
except Exception as e:
    print(e)
print("\nNow we wrap it.\n")
from lapjax import custom_wrap, FType
custom_wrap(jnp.nansum, FType.LINEAR)
output = f(lap)
output.value, output.grad.shape
```

```
You can customize the function yourself and bind to `CUSTOMIZED`.

Lapjax encounters unwrapped function 'nansum'.
Please consider using other functions or wrap it yourself.
You can refer to README for more information about customized wrap.

Now we wrap it.

Successfully bind function 'nansum' to FType.LINEAR.

(DeviceArray([1., 1., 1., 1.], dtype=float32), (16, 4))
```

You should have noticed the warning above suggesting that you customize the function and bind it to `CUSTOMIZED`. This is because `lapjax` exploits the sparsity of arrays for acceleration, and binding a function to the `FLinear` class without curating the sparsity of the output LapTuples loses that sparsity. In this case, consider curating the sparsity of the output LapTuples yourself and binding the function to the `CUSTOMIZED` class. For example, below is a careful customization of `jax.numpy.nansum`.

```python
def cst_nansum(*args, **kwargs):    # same inputs as jnp.nansum
    array: LapTuple = args[0]
    # Use already-wrapped functions to compose nansum.
    # Note that we wrapped isnan earlier.
    array = jnp.where(jnp.isnan(array), 0, array)  # mask nan to 0
    args = (array,) + args[1:]

    return jnp.sum(*args, **kwargs)

custom_wrap(jnp.nansum, FType.CUSTOMIZED, cst_f=cst_nansum, overwrite=True)
output = f(lap)
output.value, output.grad.shape
```

```
Successfully bind function 'nansum' to FType.CUSTOMIZED.

(DeviceArray([1., 1., 1., 1.], dtype=float32), (4, 4))
```

Comparing the two outputs above, the gradient in the second is smaller: shape `(4, 4)` instead of `(16, 4)`. This is because `lapjax.numpy.sum` curates the sparsity so that the output LapTuple only stores the gradient entries that are truly non-zero. Let's compare the efficiency of the two bindings.

```python
input_dim = 100  # the input is reshaped to (input_dim, input_dim)
batch_size = 16

# bind to `FType.LINEAR`
custom_wrap(jnp.nansum, FType.LINEAR, overwrite=True)
def f(x):
    # dividing by (1 - I) puts inf (or nan) on the diagonal
    sq = x.reshape(input_dim, input_dim) / (1 - jnp.eye(input_dim))
    return jnp.mean(jnp.log(jnp.nansum(jnp.exp(sq), axis=-1)))
lap_original = get_laplacian_function_lapjax(f)
vmap_lap_orignal = jax.jit(jax.vmap(lap_original))

key, subkey = jax.random.split(key)
val, duration = compute_time(vmap_lap_orignal,
                             subkey, batch_size, input_dim ** 2)
print('time of no sparsity wrap:', duration)

# bind to `FType.CUSTOMIZED`
custom_wrap(jnp.nansum, FType.CUSTOMIZED, cst_f=cst_nansum, overwrite=True)
lap_customized = get_laplacian_function_lapjax(f)
vmap_lap_orignal = jax.jit(jax.vmap(lap_customized))

key, subkey = jax.random.split(subkey)
val, duration = compute_time(vmap_lap_orignal,
                             key, batch_size, input_dim ** 2)
print('time of sparsity wrap:', duration)
```

```
Successfully bind function 'nansum' to FType.LINEAR.
time of no sparsity wrap: 0.9739506244659424
Successfully bind function 'nansum' to FType.CUSTOMIZED.
time of sparsity wrap: 0.012285947799682617
```
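The gradient shapes `(16, 4)` and `(4, 4)` seen earlier can be reproduced with a short NumPy sketch; it illustrates the counting argument only and does not reflect lapjax internals. For a 4×4 input, the gradient of each row sum with respect to the 16 input entries is non-zero for only the 4 entries in that row, so storing those 4 values per output is enough.

```python
import numpy as np

n = 4
# gradient of the flattened input with respect to itself: shape (16, 4, 4)
grad_in = np.eye(n * n).reshape(n * n, n, n)

# a row-wise sum is linear, so apply it to every gradient slice
dense_grad = grad_in.sum(axis=-1)          # shape (16, 4): one column per output
nonzeros_per_output = np.count_nonzero(dense_grad, axis=0)
print(dense_grad.shape, nonzeros_per_output)

# keep only the truly non-zero entries: output j depends on inputs j*n .. j*n + n - 1
sparse_grad = np.stack([dense_grad[j * n:(j + 1) * n, j] for j in range(n)], axis=1)
print(sparse_grad.shape)
```

In the benchmark above, each of the 100 row sums of the 100×100 input depends on only 100 of the 10,000 input entries, so a dense gradient stores 100 times more values than needed, which is consistent with the large timing gap.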


### Summary
Always keep in mind that the major acceleration of `lapjax` comes from sparsity. Thus, try to write your (minimal) Laplacian computation with existing wrapped functions, and when you wrap a function yourself, prefer `FType.ELEMENT`, `FType.CONSTRUCTION`, `FType.MERGING`, and `FType.CUSTOMIZED` whenever possible.

If you think some functions are commonly used and should be wrapped in `lapjax`, please feel free to contact us or open an issue. Enjoy!