## Model - Infinite DPM - Chinese Restaurant Process Mixture Model (CRPMM)

#### Dirichlet process mixture model in which the number of clusters is learned from the data.

ref = reference sequence  
$N$ = number of reads  
$K$ = number of clusters/components  
$L$ = genome length (number of positions)  
alphabet = {A, C, G, T, -}


no-mutation rate: $\gamma \sim Beta(a,b)$   
no-error rate: $\theta \sim Beta(c,d)$   
Cluster weights ($K$-dim): $\pi | \alpha \sim Dir(\alpha)$  
Cluster assignments ($N$-dim): $z|\pi \sim Categorical(\pi)$  
Cluster centers/haplotypes ($K \times L$-dim): $h | ref, \gamma \sim Categorical(W)$ 
with $W(l,i)=
\begin{cases} 
\gamma, & \text{if } i = ref[l] \\
\frac{1-\gamma}{4}, & \text{otherwise}
\end{cases}$ for $l \in \{1, ..., L\}$ and $i \in \{1, ..., |alphabet|\}$  
Likelihood of the reads ($N \times L$-dim): $r | z, h, \theta \sim Categorical(E)$ 
with $E(n,l,i)=
\begin{cases} 
\theta, & \text{if } i = h_{z_n}[l] \\
\frac{1-\theta}{4}, & \text{otherwise}
\end{cases}$ for $n \in \{1, ..., N\}$, $l \in \{1, ..., L\}$ and $i \in \{1, ..., |alphabet|\}$  


In [1]:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, Predictive

from jax import random
import jax
import jax.numpy as jnp

import arviz as az
import matplotlib.pyplot as plt

In [2]:
# Minimal example: a genome of length one over a two-symbol alphabet.
alphabet = '01'
reference = jnp.array([0])

# Eight reads, one position each, stored as a column (one row per read).
reads = jnp.array([0, 1, 1, 1, 0, 1, 0, 1]).reshape(-1, 1)

cluster_num = 5

# Packed in the order the model unpacks it:
# (reference, reads, number of alphabet symbols).
input_data = (reference, reads, len(alphabet))

In [3]:
# Use the following as inspiration
# https://forum.pyro.ai/t/variational-inference-for-dirichlet-process-clustering/98/2 

def model_infiniteCRPMM(input_data):
    """NumPyro model: haplotype mixture with Chinese-restaurant-process assignments.

    Reads are visited one at a time. For each read a cluster index is drawn
    from a categorical whose weights are the current cluster sizes followed by
    ``alpha0`` (the weight of opening a new cluster), normalised to sum to one.
    A haplotype is sampled the first time a cluster index is seen and reused
    afterwards; the read is then observed under a per-position categorical
    built from that haplotype and ``error_rate``.

    Args:
        input_data: Tuple ``(reference, read_data, alphabet_length)``.
            ``read_data`` is indexed as ``[read, position]``; ``reference``
            must be reshapeable to ``(genome_length, 1)``; ``alphabet_length``
            is the number of symbols and must be greater than one (it is used
            as ``alphabet_length - 1`` in a denominator).

    Sample sites:
        ``mutation_rate``, ``error_rate``: ``Beta(1, 1)``.
        ``cluster_assignments<n>``: cluster index of read ``n``.
        ``haplotypes<k>``: haplotype of cluster ``k``, created on first use.
        ``obs<n>``: read ``n``, observed.

    Notes:
        * ``mutation_rate`` and ``error_rate`` are the values written at the
          reference / haplotype symbol, i.e. they act as the probability of
          *matching* it (presumably -- this relies on ``custom_put_along_axis``
          behaving like ``numpy.put_along_axis``; that helper is defined
          outside this function).
        * NOTE(review): the loop branches in Python on the sampled assignment
          and converts it with ``int()``. Both require a concrete value, so
          the model is expected to fail when the sampled site is a JAX tracer
          (e.g. inside a jitted inference kernel) -- confirm. A truncated,
          fixed-size formulation would avoid this.
        * The ``print`` calls are debugging output kept from the original.
    """
    reference, read_data, alphabet_length = input_data

    # parameters
    read_count = read_data.shape[0]
    genome_length = read_data.shape[1]
    matrix_shape = (genome_length, alphabet_length)
    alpha0 = 0.1  # fixed CRP concentration
    # TODO: optionally sample the concentration instead of fixing it, e.g.
    # from a LogNormal(0, 2) prior (it must stay strictly positive).

    haplotypes = {}  # cluster index -> sampled haplotype (filled lazily)
    crp_counts = []  # crp_counts[k] = number of reads assigned to cluster k so far

    # define rates
    mutation_rate = numpyro.sample('mutation_rate', dist.Beta(1, 1))
    error_rate = numpyro.sample('error_rate', dist.Beta(1, 1))

    # Per-position categorical over the alphabet for haplotypes:
    # mutation_rate at the reference symbol, the remainder spread evenly.
    mutation_rate_matrix = jnp.full(matrix_shape, (1 - mutation_rate) / (alphabet_length - 1))
    mutation_rate_matrix = custom_put_along_axis(mutation_rate_matrix, reference.reshape(genome_length, 1), mutation_rate, axis=1)

    # Off-haplotype error probabilities do not depend on the read, so build
    # this matrix once instead of once per loop iteration.
    base_error_matrix = jnp.full(matrix_shape, (1 - error_rate) / (alphabet_length - 1))

    for n in range(read_count):
        print('----')
        print('read number ', n)
        print('crp_counts ', crp_counts)
        # sample from a CRP: existing clusters weighted by size, new one by alpha0
        weights = jnp.array(crp_counts + [alpha0])
        weights /= weights.sum()
        print('weights ', weights)

        cluster_assignments = numpyro.sample("cluster_assignments"+str(n), dist.Categorical(weights))
        print('cluster_assignments', cluster_assignments)

        # Convert once; the plain int is used as list index, dict key and
        # site-name suffix so that all three always agree.
        cluster = int(cluster_assignments)

        if cluster >= len(crp_counts):
            # new cluster
            crp_counts.append(1)
        else:
            # cluster already exists
            crp_counts[cluster] += 1

        # lazily sample the haplotype of a cluster the first time it is used
        if cluster not in haplotypes:
            haplotypes[cluster] = numpyro.sample("haplotypes"+str(cluster), dist.Categorical(mutation_rate_matrix))
        haplotype = haplotypes[cluster]

        print('shape haplotypes[int(cluster_assignments)] ', haplotype.shape)
        print('error_rate ', error_rate)
        print('shape error_rate_matrix', base_error_matrix.shape)
        print('before ', type(base_error_matrix))
        print('haplotypes[int(cluster_assignments)] ', haplotype)
        # error_rate at the haplotype's symbol, base value everywhere else
        error_rate_matrix = custom_put_along_axis(base_error_matrix, haplotype.reshape(genome_length, 1), error_rate, axis=1)

        print('after ', type(error_rate_matrix))

        numpyro.sample("obs"+str(n), dist.Categorical(error_rate_matrix), obs=read_data[n])
        

In [None]:
# Sampler settings.
num_warmup = 2000
num_samples = 20000
rng_key = random.PRNGKey(0)

model = model_infiniteCRPMM

# NUTS kernel over the model, wrapped in DiscreteHMCGibbs before it is handed
# to MCMC. Open question from the original notebook: how many chains to run
# (currently 2).
kernel = NUTS(model)
mcmc = MCMC(DiscreteHMCGibbs(kernel), num_warmup=num_warmup, num_samples=num_samples, num_chains=2)
mcmc.run(rng_key, input_data)