In this notebook, we'll look at Tensorflow's tools for computing a jacobian in the context of an SR algorithm.

We'll initialize some input, a network, and try to compute the jacobian in several different ways.

In [1]:
import tensorflow as tf
import time, timeit

N_WALKERS = 100
DIM = 3
N_PARTICLES = 4

In [2]:
x_input = tf.random.uniform(shape=(N_WALKERS, N_PARTICLES, DIM))

In [3]:
class DeepSetsWavefunction(tf.keras.models.Model):
    """Create a neural network eave function in N dimensions

    Boundary condition, if not supplied, is gaussian in every dimension

    Extends:
        tf.keras.models.Model
    """
    def __init__(self, ndim : int, nparticles: int, mean_subtract : bool, boundary_condition :tf.keras.layers.Layer = None):
        '''Deep Sets wavefunction for symmetric particle wavefunctions

        Implements a deep set network for multiple particles in the same system

        Arguments:
            ndim {int} -- Number of dimensions
            nparticles {int} -- Number of particls

        Keyword Arguments:
            boundary_condition {tf.keras.layers.Layer} -- [description] (default: {None})

        Raises:
            Exception -- [description]
        '''
        tf.keras.models.Model.__init__(self)

        self.ndim = ndim
        if self.ndim < 1 or self.ndim > 3:
           raise Exception("Dimension must be 1, 2, or 3 for DeepSetsWavefunction")

        self.nparticles = nparticles

        self.mean_subtract = mean_subtract


        n_filters_per_layer = 8
        n_layers            = 1
        bias                = True
        activation          = tf.keras.activations.tanh


        self.individual_net = tf.keras.models.Sequential()

        self.individual_net.add(
            tf.keras.layers.Dense(n_filters_per_layer,
                use_bias = bias)
            )

        for l in range(n_layers):
            self.individual_net.add(
                tf.keras.layers.Dense(n_filters_per_layer,
                    use_bias    = bias,
                    activation = activation)
                )


        self.aggregate_net = tf.keras.models.Sequential()

        for l in range(n_layers):
            self.individual_net.add(
                tf.keras.layers.Dense(n_filters_per_layer,
                    use_bias    = bias,
                    activation = activation)
                )
        self.aggregate_net.add(tf.keras.layers.Dense(1,
            use_bias = False))


    @tf.function(experimental_compile=False)
    def call(self, inputs, trainable=None):
        # Mean subtract for all particles:
        if self.nparticles > 1 and self.mean_subtract:
            mean = tf.reduce_mean(inputs, axis=1)
            xinputs = inputs - mean[:,None,:]
        else:
            xinputs = inputs

        x = []
        for p in range(self.nparticles):
            x.append(self.individual_net(xinputs[:,p,:]))

        x = tf.add_n(x)
        x = self.aggregate_net(x)

        # Compute the initial boundary condition, which the network will slowly overcome
        # boundary_condition = tf.math.abs(self.normalization_weight * tf.reduce_sum(xinputs**self.normalization_exponent, axis=(1,2))
        boundary_condition = -1. * tf.reduce_sum(xinputs**2, axis=(1,2))
        boundary_condition = tf.reshape(boundary_condition, [-1,1])


        return x + boundary_condition

    def n_parameters(self):
        return tf.reduce_sum( [ tf.reduce_prod(p.shape) for p in self.trainable_variables ])


In [4]:
wavefunction = DeepSetsWavefunction(ndim=DIM, nparticles=N_PARTICLES, mean_subtract=True)

In [5]:
output = wavefunction(x_input)

Now, we have a compiled wavfunction with a number of parameters

In [6]:
print(wavefunction.summary())

Model: "deep_sets_wavefunction"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential (Sequential)      (100, 8)                  176       
_________________________________________________________________
sequential_1 (Sequential)    (100, 1)                  8         
Total params: 184
Trainable params: 184
Non-trainable params: 0
_________________________________________________________________
None


In [7]:
# In general, we can compute the gradient with respect to the inputs:

with tf.GradientTape() as tape:
    log_psiw = wavefunction(x_input)

# By default, this essentially SUMS over the dimension of log_psiw
print(log_psiw.shape)
grads = tape.gradient(log_psiw, wavefunction.trainable_variables)

(100, 1)


In [8]:
for g in grads:
    print(g.shape)

(3, 8)
(8,)
(8, 8)
(8,)
(8, 8)
(8,)
(8, 1)


We can also compute the jacobian, instead of the gradient, which in this case is the gradient but for only one walker at a time:

In [9]:

with tf.GradientTape() as tape:
    log_psiw = wavefunction(x_input)

# By default, this essentially SUMS over the dimension of log_psiw
print(log_psiw.shape)
jac = tape.jacobian(log_psiw, wavefunction.trainable_variables)

(100, 1)


In [10]:
for j in jac:
    print(j.shape)

(100, 1, 3, 8)
(100, 1, 8)
(100, 1, 8, 8)
(100, 1, 8)
(100, 1, 8, 8)
(100, 1, 8)
(100, 1, 8, 1)


We can verify this, when average, matches the total gradient: 

In [11]:
for j, g in zip(jac, grads):
#     print(j)
#     print(g)
#     print(tf.reduce_sum(j, axis=(0,1)))
    diff = tf.abs(g - tf.reduce_sum(j, axis=(0,1)))
    print(tf.reduce_max(diff))


tf.Tensor(8.1956387e-07, shape=(), dtype=float32)
tf.Tensor(3.0517578e-05, shape=(), dtype=float32)
tf.Tensor(8.6426735e-07, shape=(), dtype=float32)
tf.Tensor(4.5776367e-05, shape=(), dtype=float32)
tf.Tensor(8.34465e-07, shape=(), dtype=float32)
tf.Tensor(7.6293945e-05, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)


For every layer, the jacobian in this case is quite well matching the gradients if you sum over the input dimension.

It seems like we should be able to compute this per walker too:

In [12]:
with tf.GradientTape(persistent=True) as tape:
    log_psiw = wavefunction(x_input)

    split = tf.split(log_psiw, (1, N_WALKERS-1))
    
# print(split)
# By default, this essentially SUMS over the dimension of log_psiw
grad = tape.gradient(split[1], wavefunction.trainable_variables)


In [13]:
@tf.function
def jacobian_comp(inputs, _wavefunction):

    with tf.GradientTape() as tape:
        log_psiw = _wavefunction(inputs)

    # By default, this essentially SUMS over the dimension of log_psiw
    jac = tape.jacobian(log_psiw, _wavefunction.trainable_variables)

    return jac

In [14]:
@tf.function
def jacobian_grad(inputs, _wavefunction):
    
    n_walkers = inputs.shape[0]
    
    with tf.GradientTape(persistent=True) as tape:
        log_psiw = _wavefunction(inputs)

        split = tf.split(log_psiw, n_walkers)

    # print(split)
    # By default, this essentially SUMS over the dimension of log_psiw
    grad = [tape.gradient(s, _wavefunction.trainable_variables) for s in split]

    jac = []
    for i, l in enumerate(_wavefunction.trainable_variables):
        temp = tf.stack([g[i] for g in grad])
        temp = tf.reshape(temp,  log_psiw.shape + l.shape)
        jac.append(temp)
    
    return jac
                             

In [15]:
start = time.time()
jc = jacobian_comp(x_input, wavefunction)
print("Compilation time: ", time.time() - start)

Compilation time:  1.5968520641326904


In [16]:
start = time.time()
jg = jacobian_grad(x_input, wavefunction)
print("Compilation time: ", time.time() - start)

Compilation time:  13.892029047012329


In [17]:
%timeit jacobian_comp(x_input, wavefunction)

1.54 ms ± 41.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [18]:
%timeit jacobian_grad(x_input, wavefunction)

3.82 ms ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
print(jg[0].shape) 
print(jc[0].shape)

(100, 1, 3, 8)
(100, 1, 3, 8)


IN the end, the split version gives correct results and comparative performance (on my CPU) but has a dramatic compile time.  The main reason to pursue it is the reduced memory usage, which doesn't scale as N_WALKERS^2.  Tensorflow also has the option for limiting the number of parallel iterations in the jacobian calculation, which saves memory significantly.