In [1]:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from jax import random
import jax
import jax.numpy as jnp
import torch

In [2]:
import skbio 
import numpy as np

In [12]:
fref_in='../../test_data/super_small_ex/ref.fasta'
freads_in='../../test_data/super_small_ex/seqs.fasta'

In [13]:
def seq_mapping(seq, alphabet):
    # Coding scheme
    # 0:A, 1:C, 2:G, 3:T 4:- (NOT YET:, 5:N)
    mapped = []
    for base in seq:
            mapped.append(alphabet.find(base))
    return np.array(mapped)

def fasta2ref(fref_in, alphabet):
    # Coding scheme
    # 0:A, 1:C, 2:G, 3:T 4:- (NOT YET:, 5:N)
    for seq in skbio.io.read(fref_in, format='fasta'):
        ref = seq_mapping(str(seq), alphabet)
    return ref

def fasta2reads(freads_in, alphabet):
    # Coding scheme
    # 0:A, 1:C, 2:G, 3:T 4:- (NOT YET:, 5:N)
    reads_mapped = []
    for seq in skbio.io.read(freads_in, format='fasta'):
        reads_mapped.append(seq_mapping(str(seq), alphabet))
    return np.array(reads_mapped)

In [14]:

alphabet ='ACGT-'
B = len(alphabet) # size alphabet

# Coding scheme
# 0:A, 1:C, 2:G, 3:T 4:-, 5:N

ref = fasta2ref(fref_in, alphabet)
reads = fasta2reads(freads_in, alphabet) # what about non-unique reads ? 

L=ref.shape[0] # length of genome
N=reads.shape[0] # number of reads

### Simplyfied model with fixed hyperparameters

ref = reference sequence  
$N$ = number of reads  
$K$ = number of clusters/components  
$L$ = genome length (number of positions)
alphabet = {A, C, G, T, -}

Fixed parameters: 
* mutation rate: $1- \gamma$
* error rate: $1-\theta$
* Dirichlet prior: $\alpha = (\alpha_1, ..., \alpha_k)$

Cluster weights ($K$-dim): $\pi | \alpha \sim Dir(\alpha)$  
Cluster assignments ($N$-dim): $z|\pi \sim Categorical(\pi)$  
Cluster centers/haplotypes ($K$x$L$-dim): $h | ref, \gamma \sim Categorical(W) $ 
with $W(l,i)=
\begin{cases} 
\gamma,  \text{ if }i = ref[l] \\
\frac{1-\gamma}{4}, \text{ else. }
\end{cases}$ for $l \in {1, ..., L}$ and $i\in {1,..., |alphabet|}$  
Likelihood of the reads ($N$-dim): $r | z, h, \theta \sim Categorical(E)$ 
with $E(n,l,i)=
\begin{cases} 
\theta,  \text{ if }i = h_{z_n}[l] \\
\frac{1-\theta}{4}, \text{ else. }
\end{cases}$ for $n \in {1, ..., N}$, $l \in {1, ..., L}$ and $i\in {1,..., |alphabet|}$  



In [9]:
K = 10  # Fixed number of components

# fixed constants
gamma = 0.70
theta = 0.99    


L=ref.shape[0] # length of genome
N=reads.shape[0] # number of reads
B = len(alphabet)
# --- Mutation matrix ----
# LxB-dimensional
weight = np.full((L,B),(1-gamma)/(B-1))
np.put_along_axis(weight, ref.reshape(L,1), gamma, axis=1) # is written into weight
# KxLxB dimensional 
ref_gamma_weight = np.array(K*[weight]) # KxLxB 

# --- Error matrix --- 
# NxLxB dimensional 
weight_theta = np.full((N,L,B),(1-theta)/(B-1))


#@config_enumerate
def model(reads): # reads is N x L dimensional 
    
    # hyperparameter
    alpha = np.ones(K)/K
    
    pi = numpyro.sample('pi', dist.Dirichlet(alpha))
    
    genome_axis = numpyro.plate('genome_axis', L, dim=-1)
    with numpyro.plate('haplo_axis', K, dim=-2):
        with genome_axis:
            h = numpyro.sample('h', dist.Categorical(ref_gamma_weight))
            print('h ', h.shape)
    
    with numpyro.plate('read_axis', N,dim=-2):
        z = numpyro.sample('z', dist.Categorical(pi))
        print('z ', h.shape)

        with genome_axis:
            weight_theta = np.full((N,L,B),(1-theta)/(B-1))
            np.put_along_axis(weight_theta, h[z].reshape(N,L,1), theta, axis=2)
            print('weight_theta ', weight_theta.shape)
            numpyro.sample('obs', dist.Categorical(weight_theta), obs=reads) # N x L dimensional 
            
    #print(f"     pi.shape = {pi.shape}")
    #print(f"     h.shape = {h.shape}")
    #print(f"     z.shape = {z.shape}")
    

In [10]:
print('genome lengt = L =', L)
print('number of reads = N =', N) # number of reads
print('lenght of alphabet = B =', B)

genome lengt = L = 7
number of reads = N = 7
lenght of alphabet = B = 5


In [11]:
rng_key = jax.random.PRNGKey(0)

num_warmup, num_samples = 1000, 2000

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
)
mcmc.run(rng_key, reads)

h  (10, 7)
z  (10, 7)
weight_theta  (7, 7, 5)
h  (5, 1, 1)
z  (5, 1, 1)


  mcmc.run(rng_key, reads)
  mcmc.run(rng_key, reads)


InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (10, 1, 1, 1, 1, 1) and (7, 7, 1)