# $\beta$ Variational Autoencoders to Disentangle Multi-channel Spiking Data

I already have quite a few general notes on $\beta$ Variational Autoencoders with TensorFlow Probability in the IntracranialNeurophysDL repository [here](https://github.com/SachsLab/IntracranialNeurophysDL/blob/master/notebooks/05_04_betaVAE_TFP.ipynb). This notebook provides $\beta$-VAE component implementations that are better suited to our macaque PFC data.

Here we define a series of model-builder functions. Each function takes `params`, a dictionary of hyperparameters, and `inputs` containing one or more Keras tensors, and each returns the model outputs and other intermediate variables that need to be tracked.

We have generic functions to create the graphs for f- and z-encoders; we have a function to create the first part of the decoder graph; and a function to complete the decoder graph.

We also provide an end-to-end model to show how to use these components.

These components are exported to our indl library. Our separate data analysis notebooks import this module, and possibly others (e.g., LFADS) to build models for analyzing our data. We don't do significant data analysis here.

As much as possible, we try to make the functions generic enough that we can use parameters to switch between different $\beta$-VAE implementations.

We identify 4 different VAE models for consideration:
* [disentangled sequential autoencoders (DSAE)](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/disentangled_vae.py) ([Li and Mandt, ICML 2018](https://arxiv.org/pdf/1803.02991.pdf)) - "Full" model.
* Same as above - "Factorized" model
* [FHVAE](https://github.com/wnhsu/ScalableFHVAE) ([Hsu and Glass](https://arxiv.org/pdf/1804.03201.pdf))
* [LFADS (latest: AutoLFADS)](https://github.com/snel-repo/lfads-cd/) ([Keshtkaran, ..., Pandarinath, 2021](https://www.biorxiv.org/content/10.1101/2021.01.13.426570v1))

The table below summarizes some differences between the models. It may be incorrect and needs to be updated, so please do not rely on it.

|              | LFADS                     | DSAE full                       | DSAE factorized | FHVAE    |
| :---         | :---                      | :---                            | :---            | :---     |
| f RNN        | Bidir. GRU                | Bidir. LSTM                     | --              | LSTM x2  |
| f prior      | $\mathcal{N}(0,\kappa I)$ | $\mathcal{N}(\mu_z,\sigma_z I)$ | --              | $\mathcal{N}(\mu_2,0.5^2I)$ |
| z RNN        | A: Bidir. GRU, B: GRU     | Bidir. LSTM -> RNN              | MLP             | LSTM x2  |
| z RNN input  | A: x; B: (A(x), fac)      | concat(x, tile(f))              | $x_t$           | concat(input, tile(f)) |
| z prior      | LearnableAutoRegressive1Prior | LSTM(0)                         | --              | $\mathcal{N}(0,I)$ |
| Decoder RNN  | GRU                       | ??                              | ??              | LSTM x2  |
| RNN input0   | 0 / z                     | ??                              | ??              | concat(f, z) |
| RNN state0   | f                         | ??                              | ??              | 0  |
| RNN output   | -MLP-> fac -MLP-> rates   | ??                              | ??              | (x_mu, x_logvar) |
| Decoder loss | -log(p spike\|Poisson(rates)) | ??                          | ??              | sparse sce with logits |
| Learning rate| 1e-2 decay 0.95 every 6   | ??                              | ??              | ?? |


## Hyperparameters

We separate our hyperparameters into non-tunable 'arguments' and tunable 'parameters'. This helps with the hyperparameter optimization framework.
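A minimal sketch of this split, assuming plain dictionaries. The key names are taken from elsewhere in this notebook, but which of them are actually fixed or tunable, the default values, and the `build_params` helper are all hypothetical.

```python
# Hypothetical split: fixed 'arguments' vs. tunable 'parameters'.
FIXED_ARGS = {"gen_cell_type": "Complex", "ext_input_dim": 0}
TUNABLE_DEFAULTS = {"gen_dim": 100, "factors_dim": 10, "co_dim": 4}


def build_params(trial_overrides):
    """Merge fixed arguments with tunable parameters for one optimization trial."""
    unknown = set(trial_overrides) - set(TUNABLE_DEFAULTS)
    if unknown:
        # Only tunable parameters may be changed by the optimizer.
        raise KeyError(f"Not tunable: {sorted(unknown)}")
    return {**FIXED_ARGS, **TUNABLE_DEFAULTS, **trial_overrides}


# One trial proposes a new generator size; everything else keeps its default.
params = build_params({"gen_dim": 64})
```

Keeping the tunable set explicit means a hyperparameter optimizer can only touch the values it is meant to search over.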


## Prepare inputs

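This step performs the following:

* Apply dropout.
* Split the f_encoder inputs off from the other inputs to prevent acausal modeling (optional).
* Apply coordinated dropout (optional).
* Apply a CV mask (not implemented yet).
* Pass the result through a Dense layer to obtain input factors (optional).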
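The coordinated dropout step can be illustrated with a short NumPy sketch. This reflects my understanding of the technique (drop random input elements, rescale the kept ones, and score the reconstruction only on the dropped elements); it is not the indl layer's implementation, and the function name, dropout rate, and simulated counts are assumptions.

```python
import numpy as np


def coordinated_dropout(x, rate, rng):
    """Return (network input, loss mask) for one training step."""
    keep = rng.random(x.shape) >= rate
    # Kept elements are rescaled so the expected input is unchanged.
    x_in = np.where(keep, x / (1.0 - rate), 0.0)
    # The reconstruction loss is evaluated only where the input was dropped.
    return x_in, ~keep


rng = np.random.default_rng(0)
# Simulated binned spike counts: (trials, timesteps, channels).
x = rng.poisson(1.0, size=(16, 246, 36)).astype(float)
x_in, loss_mask = coordinated_dropout(x, rate=0.3, rng=rng)
```

Because the network never sees the elements it is scored on, it cannot lower the loss by copying its input through.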


In [6]:
test_prepare_inputs(n_times=246)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 246, 36)]    0                                            
__________________________________________________________________________________________________
dropout (Dropout)               (None, 246, 36)      0           input_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_strided_slice_1 (Te [(None, 246, 36)]    0           dropout[0][0]                    
__________________________________________________________________________________________________
coordinated_dropout (Coordinate ((None, 246, 36), (N 0           tf_op_layer_strided_slice_1[0][0]
______________________________________________________________________________________________

## *f*-Encoder

Transform the full sequence of "features" (`inputs` or `ReadIn(inputs)`) through (1) an RNN and then (2) an affine layer to yield the parameters of the latent posterior distribution:
$$q(f | x_{1:T})$$
This distribution is a multivariate normal, optionally with off-diagonal covariance elements.


The model loss includes the KL divergence between the static latent posterior and a prior. The prior is a learnable multivariate normal with diagonal covariance. It is initialized with a mean of 0 and a standard deviation of 1, and both are trainable by default.
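For the diagonal case, this KL term has a closed form. The sketch below computes it in plain Python, assuming a diagonal posterior; the posterior values and the `beta` weight are hypothetical, and the notebook's models would obtain this term from TensorFlow Probability rather than from code like this.

```python
import math


def kl_diag_normal(q_loc, q_scale, p_loc, p_scale):
    """KL(q || p) between diagonal Gaussians, summed over latent dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(q_loc, q_scale, p_loc, p_scale):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
    return total


# Prior at its initial values: mean 0, stddev 1.
prior_loc, prior_scale = [0.0, 0.0], [1.0, 1.0]
# Hypothetical posterior parameters for a single trial.
post_loc, post_scale = [1.0, 0.0], [1.0, 1.0]

kl_f = kl_diag_normal(post_loc, post_scale, prior_loc, prior_scale)
beta = 4.0  # hypothetical weight on the KL term
weighted_kl = beta * kl_f
```

The term is zero when posterior and prior coincide and grows as the posterior mean or scale moves away from the prior.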



In [69]:
test_create_f_encoder()

Model: "f_encoder_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, 36)]   0                                            
__________________________________________________________________________________________________
f_rnn_0 (Bidirectional)         (None, 256)          126720      input_1[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 256)          0           f_rnn_0[0][0]                    
__________________________________________________________________________________________________
f_scale (Dense)                 (None, 10)           2570        dropout[0][0]                    
____________________________________________________________________________________

## *z*-Encoder

$q(z_t | x_{1:T})$

I have also seen this called the "Dynamic Encoder", or in LFADS the "Controller Input" encoder.

The *z*-Encoder varies quite a bit between the different Disentangling/$\beta$ Variational Autoencoder implementations. Indeed, in some formulations it isn't used at all, such as the LFADS model without inferred controller input.

* The inputs are the original data sequences ($x_t$).
* Unlike the *f*-encoder, here we output full sequences.
* The output sequences parameterize a multivariate normal distribution **at each timestep**.
* The encoder itself has as its first layer
    * an RNN (LSTM, GRU), often bidirectional, or
    * a simple MLP as in the DSAE Factorized model.
* If the first layer is an RNN, then there is usually a second, forward-only RNN layer.

### Extra Details - DSAE Full

* The inputs are concatenated with a tiled sample from $q(f)$.
* We've added the option to instead concatenate the tiled sample onto the inputs to the second RNN.
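The tiling and concatenation can be sketched with NumPy. The shapes (16 trials, 246 timesteps, 36 channels, 10 static latents) follow the test outputs in this notebook; the random arrays are stand-ins for real data and for a real sample from $q(f)$.

```python
import numpy as np

batch, n_times, n_channels, f_dim = 16, 246, 36, 10
rng = np.random.default_rng(0)

x = rng.normal(size=(batch, n_times, n_channels))  # stand-in for input sequences
f_sample = rng.normal(size=(batch, f_dim))         # stand-in for a sample from q(f)

# Repeat the single static sample at every timestep: (batch, n_times, f_dim).
f_tiled = np.tile(f_sample[:, np.newaxis, :], (1, n_times, 1))

# Concatenate along the feature axis to form the z-encoder input.
z_enc_input = np.concatenate([x, f_tiled], axis=-1)
```

With the alternative option, the same `f_tiled` would instead be concatenated onto the first RNN's output sequence before it enters the second RNN.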

### Extra Details - LFADS

* Like its f-Encoder, the RNN cells are GRU cells with clipping.
* The secondary RNN input is the output from the primary RNN concatenated with the **decoder RNN's previous-step output, transformed through the factors Dense layer**.

Because the LFADS secondary RNN is so complicated, it is integrated into the decoder RNN itself in a "complex cell". The complex cell includes the z2-cell, making the z2 outputs variational in $q(z_t)$, sampling $q(z_t)$ for the inputs to the generative RNN cell, passing the output of the generative RNN step through a Dense to-factors layer, and finally using that output as one of the inputs to the z2 cell. If `params['gen_cell_type']` is `"Complex"`, then we assume that LFADS is being used and we thus skip the second RNN in `create_z_encoder` and we skip making the latents variational in `make_z_variational`.
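The order of operations inside one complex-cell step can be sketched as a toy NumPy loop. This only illustrates the data flow described above. The recurrent cells are replaced by simple tanh updates with random weights, all dimensions are made up, the generator state starts at zero for simplicity, and none of this reflects the actual `ComplexCell` code.

```python
import numpy as np

rng = np.random.default_rng(0)
enc_dim, con_dim, co_dim, gen_dim, factors_dim = 8, 6, 2, 12, 4


def dense(in_dim, out_dim):
    """Random linear map; a stand-in for a trained Dense layer."""
    w = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda v: v @ w


con_update = dense(enc_dim + factors_dim + con_dim, con_dim)  # toy z2 cell
gen_update = dense(co_dim + gen_dim, gen_dim)                 # toy generator cell
to_loc, to_log_scale = dense(con_dim, co_dim), dense(con_dim, co_dim)
to_factors = dense(gen_dim, factors_dim)


def complex_step(z1_t, con_state, gen_state, factors_prev):
    # 1. The z2 cell sees the z1-encoder output and the previous step's factors.
    con_in = np.concatenate([z1_t, factors_prev, con_state], axis=-1)
    con_state = np.tanh(con_update(con_in))
    # 2. Make the z2 output variational: parameters of q(z_t).
    loc, scale = to_loc(con_state), np.exp(to_log_scale(con_state))
    # 3. Sample z_t as the input to the generative cell.
    z_t = loc + scale * rng.standard_normal(loc.shape)
    # 4. One step of the generative cell.
    gen_state = np.tanh(gen_update(np.concatenate([z_t, gen_state], axis=-1)))
    # 5. The to-factors layer; its output feeds the next step's z2 cell.
    return con_state, gen_state, to_factors(gen_state)


batch, n_times = 3, 5
z1 = rng.normal(size=(batch, n_times, enc_dim))  # stand-in for z1-encoder output
con_state = np.zeros((batch, con_dim))
gen_state = np.zeros((batch, gen_dim))
factors = np.zeros((batch, factors_dim))
all_factors = []
for t in range(n_times):
    con_state, gen_state, factors = complex_step(z1[:, t], con_state, gen_state, factors)
    all_factors.append(factors)
all_factors = np.stack(all_factors, axis=1)  # (batch, n_times, factors_dim)
```

The feedback from step 5 into step 1 of the next timestep is why the second RNN cannot be expressed as an ordinary stacked layer.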


In [None]:
# TODO: Rework this
# TODO: Compare to LFADS' prior on enc_z.
def sample_dynamic_prior(self, timesteps, samples=1, batches=1, fixed=False):
    """
    Samples from self.dynamic_prior_cell `timesteps` times.
    On each step, the previous (sample, state) is fed back into the cell
    (zero_state used for 0th step).

    The cell returns a multivariate normal diagonal distribution for each timestep.
    We collect each timestep-dist's params (loc and scale), then use them to create
    the return value: a single MVN diag dist that has a dimension for timesteps.

    The cell returns a full dist for each timestep so that we can 'sample' it.
    If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
    to doing a generative RNN (init state = zeros, return_sequences=True) then passing
    those values through a pair of Dense layers to parameterize a single MVNDiag.

    :param timesteps: Number of timesteps to sample for each sequence.
    :param samples: Number of samples to draw from the latent distribution.
    :param batches: Number of sequences to sample.
    :param fixed: Boolean for whether or not to share the same random
        sample across all sequences in batch.
    """
    if fixed:
        sample_batch_size = 1
    else:
        sample_batch_size = batches

    sample, state = self.dynamic_prior_cell.zero_state([samples, sample_batch_size])
    locs = []
    scale_diags = []
    sample_list = []
    for _ in range(timesteps):
        dist, state = self.dynamic_prior_cell(sample, state)
        sample = dist.sample()
        locs.append(dist.parameters["loc"])
        scale_diags.append(dist.parameters["scale_diag"])
        sample_list.append(sample)

    sample = tf.stack(sample_list, axis=2)
    loc = tf.stack(locs, axis=2)
    scale_diag = tf.stack(scale_diags, axis=2)

    if fixed:  # tile along the batch axis
        sample = sample + tf.zeros([batches, 1, 1])

    return sample, tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
    # TODO: Move 1 of the batch dims into event dims

In [49]:
test_create_z_encoder()

Model: "z_encoder_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, 36)]        0         
_________________________________________________________________
z_rnn_1 (Bidirectional)      (None, None, 32)          5088      
Total params: 5,088
Trainable params: 5,088
Non-trainable params: 0
_________________________________________________________________
<class 'tensorflow.python.framework.ops.EagerTensor'> (16, 246, 32)


In [None]:
dynamic_prior_cell = LearnableMultivariateNormalDiagCell(3, 4, cell_type='gru')
sample, state = dynamic_prior_cell.zero_state([1, 1])
locs = []
scale_diags = []
sample_list = []
for _ in range(161):
    dist, state = dynamic_prior_cell(sample, state)
    sample = dist.sample()
    locs.append(dist.parameters["loc"])
    scale_diags.append(dist.parameters["scale_diag"])
    sample_list.append(sample)

## Generator (Decoder part 1)

$p(x_t | z_t, f)$

The generator is an RNN that outputs full sequences from the encoded latents, which comprise a single-timestep latent vector (*f*) and optionally a low-dimensional sequence ($z_t$). Note that these latents are distributions, so they must be sampled to obtain the initial state and/or the inputs to the generative RNN.
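A minimal NumPy sketch of drawing those samples, assuming diagonal Gaussian posteriors. The latent sizes and parameter values are hypothetical, and how the samples are wired into the generator (initial state versus per-step input) depends on the model variant.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_times, f_dim, z_dim = 16, 246, 10, 4


def sample_diag_normal(loc, scale, rng):
    """Reparameterized sample: loc + scale * eps, with eps ~ N(0, I)."""
    return loc + scale * rng.standard_normal(loc.shape)


# Hypothetical posterior parameters, as the encoders might produce them.
f_loc, f_scale = np.zeros((batch, f_dim)), np.ones((batch, f_dim))
z_loc, z_scale = np.zeros((batch, n_times, z_dim)), np.ones((batch, n_times, z_dim))

f_sample = sample_diag_normal(f_loc, f_scale, rng)  # one vector per trial
z_sample = sample_diag_normal(z_loc, z_scale, rng)  # one vector per timestep
```

Here `f_sample` could set the generative RNN's initial state and `z_sample` could supply its per-step inputs.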

The generative RNN outputs a sequence. This sequence is typically transformed through a Dense layer to yield the "factors". However, in LFADS the factors are fed back to the z2_encoder step-by-step, and this cannot be accomplished with a normal sequential layer connection. Instead, LFADS includes the Dense layer inside a "ComplexCell". To be consistent with the LFADS implementation, we need to include this Dense layer in the other `create_generator_` functions.

In [None]:
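import tensorflow as tf
from indl.model.lfads.complex_cell import ComplexCell

tfkl = tf.keras.layers


def create_generator_LFADS(params, gen_input, con_hidden_state_dim):
    """
    Work in progress: build the LFADS generator RNN around a ComplexCell.

    :param params: dict with 'gen_dim', 'factors_dim', 'co_dim' and 'ext_input_dim'.
    :param gen_input: input tensor for the generator RNN.
    :param con_hidden_state_dim: number of units in the controller GRU.

    ComplexCell positional arguments, in order:
    units_gen,
    units_con,
    factors_dim,
    co_dim,
    ext_input_dim,
    inject_ext_input_to_gen,
    """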
    
    # TODO: Sample/Mean from $q(f)$. This will replace the first element in generator init_states
    #  TODO: need a custom function for sample-during-train-mean-during-test. See nn.dropout for inspiration.
    # TODO: Sample from $q(z_t)$, and optionally concat with ext_input, to build generator inputs.
    
    
    # TODO: continue generator from lfads-cd/lfadslite.py start at 495
    custom_cell = ComplexCell(
        params['gen_dim'],  # Units in generator GRU
        con_hidden_state_dim,  # Units in controller GRU
        params['factors_dim'],
        params['co_dim'],
        params['ext_input_dim'],
        True,
    )
    generator = tfkl.RNN(custom_cell, return_sequences=True,
                         # recurrent_regularizer=tf.keras.regularizers.l2(l=gen_l2_reg),
                         name='gen_rnn')
    init_states = generator.get_initial_state(gen_input)
    
    
    gen_output = generator(gen_input, initial_state=init_states)
    factors = gen_output[-1]
    return factors

## Probabilistic Reconstructed Input (Decoder part 2)

The factors are passed through a Dense layer and the outputs are the same dimensionality as the inputs, but instead of reconstructing the inputs, they parameterize a distribution representing the inputs. This distribution can be Gaussian or Poisson, with the latter being more appropriate for (binned) spike counts.
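For the Poisson case, the reconstruction loss is the negative log-likelihood of the observed spike counts under the predicted rates. A small plain-Python sketch follows, assuming strictly positive rates in units of expected spikes per bin (for example, after an exponential on the Dense output). The counts and rates are made up.

```python
import math


def poisson_nll(counts, rates):
    """Sum over bins of -log Poisson(count | rate)."""
    return sum(
        r - k * math.log(r) + math.lgamma(k + 1)  # lgamma(k + 1) == log(k!)
        for k, r in zip(counts, rates)
    )


counts = [0, 2, 1]       # observed spikes in three bins
rates = [0.5, 1.5, 1.0]  # predicted expected spikes per bin
nll = poisson_nll(counts, rates)
```

Unlike a Gaussian, the Poisson distribution assigns probability only to non-negative integers, which matches binned spike counts.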