## Dirichlet Mixture Prior for NeuroAlign

#### Problems: 
- The true probability distribution of a column is only poorly estimated by the relative frequencies $\frac{n_i}{n}$ where $n_i$ is the count of amino acid $i$ and $n=\sum_i n_i$, especially if the number of sequences is low (which is the case when training NeuroAlign due to hardware constraints). 
- If each column predicts its own amino acid distribution, beam search (common in natural language processing) cannot be applied during autoregressive inference with the model, since the output is not expected to concentrate on a single class (i.e. a single amino acid).

#### Remedy:

Model the amino acid distribution of each column by a mixture of Dirichlet distributions (Brown et al., 1993, "Using Dirichlet mixture priors to derive hidden Markov models for protein families"), i.e. a distribution of the form:

$$q_1 P_1 + \dots + q_k P_k$$

where $P_j$ are Dirichlet densities and $q_j$ are mixture coefficients that sum to 1.
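
Written out, the mixture density can be stated as follows. The symbols $\rho$, $\vec{p}$ (a probability vector over the amino acid alphabet) and $\alpha_{j,i}$ (the $i$-th parameter of component $j$) are names chosen here for illustration and do not appear elsewhere in this notebook:

$$\rho(\vec{p}) = \sum_{j=1}^{k} q_j P_j(\vec{p}), \qquad P_j(\vec{p}) = \frac{\Gamma\left(\sum_i \alpha_{j,i}\right)}{\prod_i \Gamma(\alpha_{j,i})} \prod_i p_i^{\alpha_{j,i} - 1}, \qquad \alpha_{j,i} > 0, \quad \sum_{j=1}^{k} q_j = 1$$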

Intuitively, this models $k$ "fundamental" amino acid distributions under the assumption that each observed distribution is similar to one of them. This addresses the first problem, because the prior regularizes estimates that are based on few sequences: each mixture component contributes pseudocounts that are combined with the observed counts. It also addresses the second one:

Instead of a distribution over the amino acid alphabet, let the model predict a distribution over the $k$ mixture components. If $k$ is sufficiently large, the assumption that every observed count vector is well represented by exactly one of the Dirichlet distributions should be reasonable. The model output can then be interpreted as a vote for one of $k$ classes, and beam search can be implemented by evaluating the $x$ most likely fundamental distributions for the next predicted column at each autoregressive step.
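
The beam search idea can be illustrated with a minimal sketch in plain Python. It assumes a standard beam search in which every hypothesis is a sequence of component indices with an accumulated log-probability. The function name `beam_step` and the toy probabilities are hypothetical, and in NeuroAlign the per-hypothesis probabilities would come from the model rather than being hard-coded:

```python
import heapq
import math

def beam_step(beams, component_probs, x):
    """Expand every hypothesis by every component and keep the x best.

    beams: list of (log_score, [component indices chosen so far])
    component_probs: one list of k probabilities per hypothesis, standing in
        for the model output for the next column (assumed, not NeuroAlign code)
    """
    candidates = []
    for (log_score, path), probs in zip(beams, component_probs):
        for j, p in enumerate(probs):
            if p > 0.0:  # skip impossible components to avoid log(0)
                candidates.append((log_score + math.log(p), path + [j]))
    # keep the x candidates with the highest accumulated log-probability
    return heapq.nlargest(x, candidates, key=lambda c: c[0])

# toy example with k = 3 components and beam width x = 2
beams = [(0.0, [])]
beams = beam_step(beams, [[0.1, 0.6, 0.3]], x=2)
beams = beam_step(beams, [[0.5, 0.4, 0.1], [0.05, 0.05, 0.9]], x=2)
for log_score, path in beams:
    print(path, math.exp(log_score))
```

Because each step only compares $k$ discrete classes per hypothesis, the search space stays finite, which is what makes the component formulation compatible with beam search.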

In [5]:
# uncomment below when using the wolke tf 2.3 image 
#import sys
#!{sys.executable} -m pip install tensorflow_probability==0.11.0

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers


The Dirichlet mixture prior is a trainable layer with (alphabet_size + 1) * $k$ weights: $k$ Dirichlet parameter vectors of length alphabet_size plus $k$ mixture coefficients. It takes as input a batch of count vectors corresponding to the columns of an MSA and the number of sequences of the alignment. It outputs a discrete probability distribution over the $k$ fundamental amino acid distributions for each count vector in the batch.
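
To make the output of the layer concrete, here is a dependency-free sketch of the same posterior, $P(j \mid \text{counts}) \propto q_j \, P(\text{counts} \mid \alpha_j)$, where $P(\text{counts} \mid \alpha_j)$ is the Dirichlet-multinomial likelihood. The function names and the two toy components over a 3-letter alphabet are assumptions made for illustration only:

```python
import math

def log_dirichlet_multinomial(counts, alpha):
    """log P(counts | alpha) under a Dirichlet-multinomial distribution."""
    n = sum(counts)
    a = sum(alpha)
    # multinomial coefficient n! / prod_i n_i!
    log_p = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    # Gamma(a) / Gamma(n + a)
    log_p += math.lgamma(a) - math.lgamma(n + a)
    # prod_i Gamma(n_i + alpha_i) / Gamma(alpha_i)
    log_p += sum(math.lgamma(c + al) - math.lgamma(al)
                 for c, al in zip(counts, alpha))
    return log_p

def component_posterior(counts, alphas, q):
    """P(j | counts), proportional to q_j * P(counts | alpha_j)."""
    log_w = [math.log(qj) + log_dirichlet_multinomial(counts, aj)
             for qj, aj in zip(q, alphas)]
    m = max(log_w)  # subtract the maximum for numerical stability
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    return [wj / total for wj in w]

# two hypothetical components: one favours the first letter, one the last
alphas = [[5.0, 1.0, 1.0], [1.0, 1.0, 5.0]]
q = [0.5, 0.5]
posterior = component_posterior([6, 1, 1], alphas, q)
print(posterior)
```

Working in log space is a design choice of this sketch: the Gamma function terms overflow quickly for larger counts, whereas `math.lgamma` stays well behaved.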

It can be used as follows: In the NeuroAlign model, every time a raw count vector of amino acids appears (inputs and outputs) it is replaced by the output of the DirichletMixturePrior. The weights of the prior itself can:

- be pretrained and frozen.
- be jointly trained with the model. If $L$ is the model loss, then a combined loss such as $L' = \lambda L + (1-\lambda) CE(prior, truth)$ with $\lambda \in [0, 1]$ can probably be used.
- maybe be found online for some choices of $k$.
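
The joint-training option can be sketched as follows. This is only an illustration of the weighted sum, under the assumption that "truth" is a target distribution over the $k$ components (for example a one-hot vector) and that $CE$ is the usual cross-entropy. The function names are hypothetical, and a real implementation would use TensorFlow ops so that gradients can flow:

```python
import math

def cross_entropy(target, predicted, eps=1e-12):
    """CE = -sum_j target_j * log(predicted_j), clipped to avoid log(0)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, predicted))

def combined_loss(model_loss, prior_probs, true_probs, lam):
    """L' = lam * L + (1 - lam) * CE(prior, truth) with lam in [0, 1]."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return lam * model_loss + (1.0 - lam) * cross_entropy(true_probs, prior_probs)

# toy values: model loss 2.0, prior undecided between two components
loss = combined_loss(2.0, prior_probs=[0.5, 0.5], true_probs=[1.0, 0.0], lam=0.25)
print(loss)
```

With `lam = 1` the prior term vanishes and only the model loss is optimized, while `lam = 0` trains the prior alone.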

In [None]:
import time

alphabet_size = 5

class DirichletMixturePrior(layers.Layer):
    def __init__(self, k):
        super(DirichletMixturePrior, self).__init__()
        # unconstrained weights; softplus/softmax are applied in call() so that
        # gradients reach the variables during training
        self.alpha_kernel = self.add_weight(shape=(1, k, alphabet_size),
                                        name="alpha", initializer="uniform", trainable=True)
        self.mixture_kernel = self.add_weight(shape=(1, k),
                                        name="mixture_coeff", initializer="uniform", trainable=True)


    # in: n x alphabet_size count vectors and the number of sequences
    # out: n x k posterior probability distribution P(k | count)
    def call(self, counts, total_count):
        # make the dtypes explicit so that they match the weights
        counts = tf.cast(counts, self.alpha_kernel.dtype)
        total_count = tf.cast(total_count, self.alpha_kernel.dtype)
        # Dirichlet parameters > 0
        alpha = tf.nn.softplus(self.alpha_kernel)
        # mixture coefficients that sum to 1
        mixture_coeff = tf.nn.softmax(self.mixture_kernel)
        dist = tfp.distributions.DirichletMultinomial(total_count, alpha)
        probs = dist.prob(tf.expand_dims(counts, 1)) #P(count | p_k)
        mix_probs = mixture_coeff * probs
        return mix_probs / tf.reduce_sum(mix_probs, axis=-1, keepdims=True)



dirichlet_mixture = DirichletMixturePrior(k = 10)

# count vectors from a MSA
n = 7
samples = 8
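# tf.random.categorical expects logits (unnormalized log-probabilities), so take the log
draws = tf.random.categorical(tf.math.log([[0.1, 0.3, 0.2, 0.2, 0.2] for _ in range(n)]), samples)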
counts = np.zeros((n, alphabet_size))
for i in range(alphabet_size):
    counts[:,i] = np.sum(draws == i, axis=-1)

start = time.time()
count_probs = dirichlet_mixture(counts, samples)
end = time.time()
print(end - start)

for i,(p,c) in enumerate(zip(count_probs, counts)):
    print("count vector: ", i,  c)
    print("prob: ", i,  p)