In [1]:
# import torch
import numpy as np

# define tree and cluster




In [2]:
# tree structure: one row per inner node, one column per taxon
# (+1 marks the taxa of the left child, -1 the taxa of the right child, 0 taxa outside the node)
theta = np.array(
    [
        [1, -1, -1, 0, 0, 0],
        [0, 1, -1, 0, 0, 0],
        [1, 1, 1, -1, -1, -1],
        [0, 0, 0, 1, -1, 0],
        [0, 0, 0, 1, 1, -1],
    ]
)

# define ilr and inverse ilr

In [3]:
# scratch check: np.empty allocates without initializing, so the values displayed below are arbitrary
np.empty(2, dtype=np.float32)

array([1., 1.], dtype=float32)

In [4]:
def get_n_plus_and_n_minus(theta, w):
    """
    Sum the weights on the +1 side and on the -1 side of every row of theta.

    theta: (D-1, D) with entries in {+1, -1, 0}, w: (D, )
    return n_plus: (D-1, ), n_minus: (D-1, )
    """
    theta = np.asarray(theta)
    w = np.asarray(w, dtype=float)
    # builtin float replaces the np.float alias, which NumPy 1.24 removed
    n_plus = np.array([w[row == 1].sum() for row in theta], dtype=float)
    n_minus = np.array([w[row == -1].sum() for row in theta], dtype=float)
    return n_plus, n_minus

In [5]:
def get_psi(theta, n_plus, n_minus):
    """
    Build the contrast matrix psi used by the inverse ilr transform.

    theta: (D-1, D) with entries in {+1, -1, 0}, n_plus: (D-1, ), n_minus: (D-1, )
    return psi: (D-1, D), where for row i
        psi[i, j] = +sqrt(n_plus[i] * n_minus[i] / (n_plus[i] + n_minus[i])) / n_plus[i]   if theta[i, j] == +1
        psi[i, j] = -sqrt(n_plus[i] * n_minus[i] / (n_plus[i] + n_minus[i])) / n_minus[i]  if theta[i, j] == -1
        psi[i, j] = 0                                                                       otherwise
    """
    # builtin float replaces the np.float alias, which NumPy 1.24 removed
    psi = np.zeros(theta.shape, dtype=float)
    for i, row in enumerate(theta):
        plus, minus = row == 1, row == -1
        if not (plus.any() or minus.any()):
            # nothing to fill; also avoids evaluating 0 / 0 for an all-zero row
            continue
        # the scale factor is shared by every entry of the row, so compute it once
        scale = np.sqrt(n_plus[i] * n_minus[i] / (n_plus[i] + n_minus[i]))
        if plus.any():
            psi[i, plus] = 1 / n_plus[i] * scale
        if minus.any():
            psi[i, minus] = -1 / n_minus[i] * scale
    return psi

In [6]:
def log_geometric_mean(y, w):
    """Log of the w-weighted geometric mean of y, reduced along the last axis."""
    weighted_log_total = np.sum(w * np.log(y), axis=-1)
    weight_total = np.sum(w, axis=-1)
    return weighted_log_total / weight_total

In [7]:
def ilr_transform(x, w, theta, n_plus, n_minus):
    """
    Weighted ilr transform of a single composition.

    x: (D, ), w: (D, ), theta: (D-1, D), n_plus: (D-1, ), n_minus: (D-1, )
    return ystar: (D-1, )
    """
    D, = x.shape
    y = x / w

    def weighted_log_mean(idx):
        # log of the w-weighted geometric mean of y over the taxa selected by idx
        return np.sum(w[idx] * np.log(y[idx])) / np.sum(w[idx])

    # log ratio of the geometric means of the +1 group and the -1 group of each row
    # (builtin float replaces the np.float alias, which NumPy 1.24 removed)
    log_gm_ratio = np.empty((D - 1,), dtype=float)
    for i in range(D - 1):
        log_gm_ratio[i] = weighted_log_mean(theta[i] == 1) - weighted_log_mean(theta[i] == -1)

    normalizing_constant = np.sqrt(n_plus * n_minus / (n_plus + n_minus))  # (D-1, )

    ystar = normalizing_constant * log_gm_ratio
    return ystar

In [8]:
def inverse_ilr_transform(ystar, psi, w):
    """
    Map balances back to a composition.

    ystar: (D-1,), psi: (D-1, D), w: (D, )
    return: x: (D,)
    """
    # row vector (1, D-1) times psi (D-1, D) -> (1, D), then drop the leading axis
    projected = np.matmul(ystar[np.newaxis], psi)
    unnormalized_y = np.exp(projected).squeeze(0)

    y = unnormalized_y / np.sum(unnormalized_y)
    reweighted = y * w
    return reweighted / np.sum(reweighted)

# data

In [9]:
# counts (builtin float replaces the np.float alias, which NumPy 1.24 removed)
c = np.array([10, 20, 30, 40, 25, 15], dtype=float)

# relative abundances: counts normalized to sum to 1
x = c / c.sum()
# weights
w = np.array([2.1, 3.2, 3, 5, 10, 2])

In [10]:
# forward ilr transform: group weights -> contrast matrix -> balances
n_plus, n_minus = get_n_plus_and_n_minus(theta=theta, w=w)
psi = get_psi(theta=theta, n_plus=n_plus, n_minus=n_minus)
ystar = ilr_transform(x=x, w=w, theta=theta, n_plus=n_plus, n_minus=n_minus)

In [11]:
ystar

array([-0.62542566, -0.58484527,  1.28946544,  2.12361312, -0.94436926])

In [12]:
x_recon = inverse_ilr_transform(ystar, psi, w)
# invariant: the inverse transform must undo the forward ilr transform
np.testing.assert_allclose(x_recon, x)

In [13]:
x_recon

array([0.07142857, 0.14285714, 0.21428571, 0.28571429, 0.17857143,
       0.10714286])

In [14]:
x

array([0.07142857, 0.14285714, 0.21428571, 0.28571429, 0.17857143,
       0.10714286])

# dynamics

For any **inner** node $i$,
* The probability that $y_i^*$ will change is: $p(i\ breaks) = \left[\prod_{a \in a_i} b_a \right] \cdot b_i$, where $a_i$ are node $i$'s ancestors.
* Conditioning on $y_i^*$ changing, all nodes in $a_i$ are bound to break. The remaining consideration is, for any node $j$ ($j \neq i$) that is not in $a_i$, the probability that node $j$ represents a group, i.e., $j$ does not break but all of node $j$'s ancestors break: $p(j\ represents\ group| i\ breaks) = \prod_{k \in a_j \cap \neg (a_i \cup \{i\})} b_k \cdot (1 - b_j)$, where $\neg U$ denotes the set of nodes that are not in $U$.
* Then $y_i^*$'s change rate will be
$$\frac{dy_i^*}{dt} = p(i\ breaks) \cdot (g_i + \sum_{j\in \neg (a_i \cup \{i\})} p(j\ represents\ group | i\ breaks) \cdot A_{ij} \cdot r_j),$$
where $r_j$ is the relative abundance of node $j$.

### tree data structure setup

In [15]:
class Tree(object):
    """
    Node of a binary tree.

    A node without children is a taxon (leaf) and is labelled by taxon_idx;
    any other node is an inner node labelled by inode_idx. node_idx labels
    every node, and b holds the node's break probability (default 0.5).
    """

    def __init__(self, parent=None, left=None, right=None, node_idx=None, inode_idx=None, taxon_idx=None, b=0.5):
        self.parent = parent
        self.left, self.right = left, right
        self.node_idx = node_idx
        self.inode_idx = inode_idx
        self.taxon_idx = taxon_idx
        self.b = b

    def is_taxon(self):
        """Return True when the node has neither a left nor a right child."""
        return not (self.left is not None or self.right is not None)

    def print_tree(self, level=0):
        """Print this node and its subtree, indenting three spaces per level."""
        print("   " * level + self.name())
        for child in (self.left, self.right):
            if child is not None:
                child.print_tree(level + 1)

    def name(self):
        """Return a label containing node_idx and the taxon or inner-node index."""
        if self.is_taxon():
            return "node {} taxon {}".format(self.node_idx, self.taxon_idx)
        return "node {} inode {}".format(self.node_idx, self.inode_idx)

In [16]:
def convert_theta_to_tree(theta):
    """
    Build the binary tree described by theta.

    theta: (n_inode, n_taxa)
    return: (root, node_reference), where node_reference is a list of length
    n_inode + n_taxa that the helper fills with nodes at their node_idx
    """
    n_inode, n_taxa = theta.shape[0], theta.shape[1]
    node_reference = [0 for _ in range(n_inode + n_taxa)]  # placeholders, overwritten by the helper
    root = convert_theta_to_tree_helper(theta, node_reference, None)
    return root, node_reference

def convert_theta_to_tree_helper(theta, node_reference, parent, is_left_child=True):
    """
    Create the left or right child of parent (the root when parent is None),
    build its subtree recursively and register every created node in
    node_reference under its node_idx.

    Inner nodes get node_idx == inode_idx (their row in theta); taxon nodes
    get node_idx == theta.shape[0] + taxon_idx.
    """
    # taxa covered by the node being created
    if parent is not None:
        parent_inode_idx = parent.inode_idx
        assert parent_inode_idx is not None
        sign = +1 if is_left_child else -1
        node_taxa = np.where(theta[parent_inode_idx] == sign)[0]
    else:
        node_taxa = np.arange(theta.shape[1])

    if len(node_taxa) == 1:
        # a single taxon: create a leaf
        taxon_idx = node_taxa[0]
        node_idx = theta.shape[0] + taxon_idx
        child = Tree(parent=parent, taxon_idx=taxon_idx, node_idx=node_idx, b=0)  # taxon nodes cannot break
        node_reference[node_idx] = child
        return child

    def covers_node_taxa(theta_i):
        # True when the non-zero entries of this row are exactly node_taxa
        leaves = np.where(theta_i != 0)[0]
        return len(leaves) == len(node_taxa) and (leaves == node_taxa).all()

    # first row of theta whose leaves match, or -1 when there is none
    inode_idx = next((i for i, theta_i in enumerate(theta) if covers_node_taxa(theta_i)), -1)
    assert inode_idx != -1, "cannot find the child whose leaves should be {}".format(node_taxa)

    child = Tree(parent=parent, inode_idx=inode_idx, node_idx=inode_idx)
    child.left = convert_theta_to_tree_helper(theta, node_reference, child, is_left_child=True)
    child.right = convert_theta_to_tree_helper(theta, node_reference, child, is_left_child=False)
    node_reference[inode_idx] = child
    return child

In [17]:
root, reference = convert_theta_to_tree(theta)

In [18]:
root.print_tree()

node 2 inode 2
   node 0 inode 0
      node 5 taxon 0
      node 1 inode 1
         node 6 taxon 1
         node 7 taxon 2
   node 4 inode 4
      node 3 inode 3
         node 8 taxon 3
         node 9 taxon 4
      node 10 taxon 5


In [19]:
for node in reference:
    print(node.name())

node 0 inode 0
node 1 inode 1
node 2 inode 2
node 3 inode 3
node 4 inode 4
node 5 taxon 0
node 6 taxon 1
node 7 taxon 2
node 8 taxon 3
node 9 taxon 4
node 10 taxon 5


In [20]:
def get_p_i(n_inode, root):
    """
    Break probability of every inner node.

    return p_i: (n_inode, ), filled by get_p_i_helper starting from root
    """
    break_probabilities = np.empty(n_inode)
    get_p_i_helper(break_probabilities, root, 1)
    return break_probabilities

def get_p_i_helper(p_i, node, p_ancestors_break):
    """
    Write p_ancestors_break * node.b into p_i[node.node_idx] for node and,
    recursively, for every inner node below it; taxon nodes are skipped.
    """
    if node.is_taxon():
        return
    # probability that this node and all of its ancestors break
    p_node_breaks = p_ancestors_break * node.b
    p_i[node.node_idx] = p_node_breaks
    for child in (node.left, node.right):
        get_p_i_helper(p_i, child, p_node_breaks)

In [21]:
def get_p_j_given_i(n_node, n_inode, root):
    """
    Table of p(j represents group | i breaks).

    return: (n_inode, n_node), one row per inner node i, filled by
    get_p_j_given_i_helper starting from root
    """
    table = np.empty((n_inode, n_node))
    get_p_j_given_i_helper(table, n_node, root, root)
    return table

def get_p_j_given_i_helper(p_j_given_i, n_node, root, node):
    """
    Fill row node.inode_idx of p_j_given_i, then the rows of all inner nodes
    below node.

    node.b is temporarily forced to 1 (conditioning on "node breaks") and is
    restored before returning. The recursion into the children runs while the
    override is still active, so when a descendant's row is computed all of
    its ancestors are treated as broken.
    """
    # node: inode i (taxon nodes have no row and are skipped)
    if node.is_taxon():
        return
    # remember the real break probability, then condition on this node breaking
    b_copy = node.b
    node.b = 1

    # probabilities for every node j, computed from the root with the override in place
    p_j_given_inode = np.empty(n_node)
    get_p_j_given_inode_helper(p_j_given_inode, root, 1)
    # recurse before restoring node.b so descendants see this node as broken
    get_p_j_given_i_helper(p_j_given_i, n_node, root, node.left)
    get_p_j_given_i_helper(p_j_given_i, n_node, root, node.right)

    inode_idx = node.inode_idx
    p_j_given_i[inode_idx] = p_j_given_inode
    # undo the temporary override
    node.b = b_copy
    
def get_p_j_given_inode_helper(p_j_given_inode, node, p_ancestors_break):
    """
    For node and every node below it, store the probability that all of its
    ancestors break while the node itself does not:
    p_j_given_inode[node_idx] = p_ancestors_break * (1 - b).
    """
    if node is None:
        return
    p_j_given_inode[node.node_idx] = p_ancestors_break * (1 - node.b)
    # the children additionally require this node to break
    p_through_node = p_ancestors_break * node.b
    for child in (node.left, node.right):
        get_p_j_given_inode_helper(p_j_given_inode, child, p_through_node)

In [22]:
# override the default break probability (0.5) of the five inner nodes:
# inner nodes 0 and 2 always break, inner nodes 1, 3 and 4 never break
reference[0].b = 1.0
reference[1].b = 0.0
reference[2].b = 1.0
reference[3].b = 0.0
reference[4].b = 0.0

In [23]:
# inner nodes are the rows of theta; all nodes are the inner nodes plus the taxa (columns)
n_inode = len(theta)
n_node = n_inode + theta.shape[1]
p_i = get_p_i(n_inode=n_inode, root=root)
p_j_given_i = get_p_j_given_i(n_node=n_node, n_inode=n_inode, root=root)

In [24]:
print(p_i)
print(p_j_given_i)

[1. 0. 1. 0. 0.]
[[0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 1. 0. 0. 1. 1. 1.]
 [0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 1.]]


In [25]:
np.random.seed(10)

In [26]:
# random dynamics parameters drawn from the seeded global RNG
# (keep this order: swapping the two draws changes both values)
# g: (n_inode, ) per-inner-node term g_i
g = np.random.randn(n_inode)
# A: (n_inode, n_node) coefficients A_ij, standard normal draws scaled by 10
A = np.random.randn(n_inode, n_node) * 10

In [27]:
def get_inode_relative_abundance(root, x_t, n_inode):
    """
    Relative abundance of every inner node of the tree under root.

    x_t: relative abundances indexed by taxon_idx
    return: (n_inode, ), filled by get_inode_relative_abundance_helper
    """
    inode_abundance = np.empty(n_inode)
    get_inode_relative_abundance_helper(inode_abundance, root, x_t)
    return inode_abundance
    
def get_inode_relative_abundance_helper(r_t_inode, node, x_t):
    """
    Return the relative abundance of node: x_t[taxon_idx] for a taxon, the sum
    of both children for an inner node. Inner-node values are also stored in
    r_t_inode[inode_idx].
    """
    if node.is_taxon():
        return x_t[node.taxon_idx]
    subtree_total = (
        get_inode_relative_abundance_helper(r_t_inode, node.left, x_t)
        + get_inode_relative_abundance_helper(r_t_inode, node.right, x_t)
    )
    r_t_inode[node.inode_idx] = subtree_total
    return subtree_total

def simulate(root, g, A, p_i, p_j_given_i, N, n_days, psi=None, w=None):
    """
    Simulate n_days of latent balances, relative abundances and observed counts.

    root: root of the tree
    g: (D-1, ), A: (D-1, 2D-1), p_i: (D-1, ), p_j_given_i: (D-1, 2D-1)
    N: read depth scale; each day's total count is Poisson(exp(Normal(log N, 0.5)))
    n_days: number of simulated days
    psi: (D-1, D), w: (D, ); when omitted, the module-level ``psi`` / ``w`` are
        used, which is what this function originally read implicitly
    return Y: (n_days, D-1), X: (n_days, D), C: (n_days, D)

    Random numbers come from the global numpy RNG, in the order: initial
    balances, then per day log-depth, depth, counts.
    """
    # backward compatibility: fall back to the notebook-level globals
    if psi is None:
        psi = globals()["psi"]
    if w is None:
        w = globals()["w"]

    # the element-wise product does not change between days, so compute it once
    coupling = A * p_j_given_i

    def transition_step(ystar_t, x_t):
        """
        ystar_t: (D-1, ), x_t: (D, ) composition corresponding to ystar_t
        return ystar_tp1: (D-1,)
        """
        r_t_inode = get_inode_relative_abundance(root, x_t, len(ystar_t))
        # abundance of every node, inner nodes first and taxa after them
        r_t = np.concatenate([r_t_inode, x_t])
        return ystar_t + p_i * (g + coupling.dot(r_t))

    def emission_step(x_t):
        """
        x_t: (D, )
        return c_t: (D, ) multinomial counts as floats
        """
        logN = np.random.normal(loc=np.log(N), scale=0.5)
        N_t = np.random.poisson(np.exp(logN))
        return np.random.multinomial(N_t, x_t).astype(float)

    n_inode, _ = A.shape
    y_star = np.random.randn(n_inode) * 5

    Y = []
    X = []
    C = []
    for _ in range(n_days):
        # one inverse transform per day, shared by the emission and the transition
        x_t = inverse_ilr_transform(y_star, psi, w)
        c_t = emission_step(x_t)
        Y.append(y_star)
        X.append(x_t)
        C.append(c_t)
        y_star = transition_step(y_star, x_t)

    return np.array(Y), np.array(X), np.array(C)

In [28]:
N, n_days = 10000, 10
simulate(root, g, A, p_i, p_j_given_i, N, n_days)

(array([[ -1.16091128,  -2.5086445 ,   5.64392577,  -3.48905015,
          -0.40561092],
        [  2.35687703,  -2.5086445 ,   1.29389737,  -3.48905015,
          -0.40561092],
        [  4.80635792,  -2.5086445 ,  -3.00435898,  -3.48905015,
          -0.40561092],
        [  6.39837318,  -2.5086445 ,  -7.85695902,  -3.48905015,
          -0.40561092],
        [  6.80911115,  -2.5086445 , -13.97729704,  -3.48905015,
          -0.40561092],
        [  6.4843133 ,  -2.5086445 , -20.91582922,  -3.48905015,
          -0.40561092],
        [  6.07388166,  -2.5086445 , -27.94988416,  -3.48905015,
          -0.40561092],
        [  5.65962595,  -2.5086445 , -34.98819712,  -3.48905015,
          -0.40561092],
        [  5.24521675,  -2.5086445 , -42.02668048,  -3.48905015,
          -0.40561092],
        [  4.83080141,  -2.5086445 , -49.06517063,  -3.48905015,
          -0.40561092]]),
 array([[6.84654353e-02, 9.93787361e-02, 6.99552328e-01, 8.05569279e-03,
         1.08913901e-01, 1.56339070

In [29]:
def get_nodes_x(x):
    """Collapse x into three totals: x[0], x[1] + x[2], and x[3] + x[4] + x[5]."""
    first = x[0]
    middle = x[1:3].sum()
    last = x[3:6].sum()
    return np.array([first, middle, last])