Todo:

coding:
- compare with R (phytools)

theory:
- find out if inverting big trees is really a problem for realistic phylogenies

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import hyperiax
from jax.random import PRNGKey, split
import jax
from jax import numpy as jnp
from hyperiax.execution import LevelwiseTreeExecutor, DependencyTreeExecutor
from hyperiax.models import UpLambda, DownLambda
from hyperiax.models.functional import pass_up
from functools import partial

import time

key = PRNGKey(0)

## Phylogenetic mean and variance computed via _phylogenetic independent contrasts_ (PIC) 

We use independent contrasts to estimate the inner nodes, the root, and the variance under a Brownian motion model.

### Inner nodes and root

Inner nodes are estimated according to:

$$
T_I = \frac{\frac{1}{e_{x1}}T_{x1} + \frac{1}{e_{x2}}T_{x2}}{\frac{1}{e_{x1}}+\frac{1}{e_{x2}}}
$$

where $I$ is the intermediate node between the two nodes $x1$ and $x2$, $T$ is the trait value of the corresponding node, and $e$ is the length of the edge leading to the corresponding node.
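
The weighting can be checked on two numbers before running it on a tree. The following is a minimal sketch in plain Python; the helper name `inner_node_estimate` and the example values are assumptions made for illustration and are not part of hyperiax.

```python
def inner_node_estimate(t_x1, t_x2, e_x1, e_x2):
    """Inverse-edge-length weighted mean of two child traits (hypothetical helper)."""
    w1 = 1.0 / e_x1  # weight of child x1
    w2 = 1.0 / e_x2  # weight of child x2
    return (w1 * t_x1 + w2 * t_x2) / (w1 + w2)


# Equal edge lengths reduce the estimate to the plain average of the children.
equal = inner_node_estimate(2.0, 4.0, 1.0, 1.0)

# A shorter edge to x1 pulls the estimate towards T_x1.
closer_to_x1 = inner_node_estimate(2.0, 4.0, 1.0, 3.0)

print(equal, closer_to_x1)
```

With equal edge lengths the result is the midpoint of the two traits; with unequal lengths the child on the shorter edge dominates.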

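To account for the uncertainty in the estimated intermediate nodes, the PIC estimator lengthens the edge of each intermediate node according to:

$$
e'_{I} = e_I + \frac{e_{x1} e_{x2}}{e_{x1}+e_{x2}}
$$

where $e_I$ is the original edge length of the intermediate node $I$, $e'_I$ is its adjusted edge length, and $x1$ and $x2$ are its two child nodes (whose edge lengths have themselves been adjusted if they are inner nodes).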
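
As a small numeric check of the edge length adjustment, here is a sketch in plain Python. The helper name `adjusted_edge_length` and the example values are assumptions made for illustration.

```python
def adjusted_edge_length(e_i, e_x1, e_x2):
    """Edge length of an inner node after the PIC correction (hypothetical helper).

    e_i is the node's own edge length; e_x1 and e_x2 are the (possibly
    already adjusted) edge lengths of its two children.
    """
    return e_i + (e_x1 * e_x2) / (e_x1 + e_x2)


# All edges of length 1: the inner edge grows by 1 * 1 / (1 + 1) = 0.5.
unit = adjusted_edge_length(1.0, 1.0, 1.0)

# Unequal child edges: the added term is 1 * 3 / (1 + 3) = 0.75.
uneven = adjusted_edge_length(1.0, 1.0, 3.0)

print(unit, uneven)
```

The added term is half the harmonic mean of the two child edge lengths, so it is always smaller than the shorter child edge.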

### Variance estimate

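The PIC variance estimate is given by

$$\hat{\sigma}^2_{PIC} = \frac{\sum_{i,j} s^2_{ij}}{n-1}$$

where the sum runs over the $n-1$ pairs of sibling nodes $(x_i, x_j)$ of a binary tree with $n$ leaves, and

$$ s_{ij} = \frac{T_{x_i} - T_{x_j}}{\sqrt{v_i + v_j}}$$

is the standardized contrast, computed from the (given or estimated) trait values and the PIC-adjusted edge lengths $v_i$ and $v_j$ of the two siblings.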
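
The full estimate can be worked through by hand on a three-leaf tree. The sketch below is plain Python and assumes Felsenstein's standardized contrasts, i.e. the trait difference divided by the square root of the summed adjusted edge lengths. The helper names, the toy tree `((A, B), C)` and its values are all made up for illustration.

```python
import math


def standardized_contrast(t_i, t_j, v_i, v_j):
    """Contrast between two siblings, scaled by its expected standard deviation."""
    return (t_i - t_j) / math.sqrt(v_i + v_j)


def fuse_pair(t_i, t_j, v_i, v_j, parent_edge):
    """Return the parent's trait estimate and its adjusted edge length."""
    w_i, w_j = 1.0 / v_i, 1.0 / v_j
    estimate = (w_i * t_i + w_j * t_j) / (w_i + w_j)
    adjusted = parent_edge + (v_i * v_j) / (v_i + v_j)
    return estimate, adjusted


# Toy tree ((A, B), C): A = 1, B = 3 on edges of length 1,
# their parent I on an edge of length 1, and C = 5 on an edge of length 1.5.
s_ab = standardized_contrast(1.0, 3.0, 1.0, 1.0)
t_inner, v_inner = fuse_pair(1.0, 3.0, 1.0, 1.0, parent_edge=1.0)
s_ic = standardized_contrast(t_inner, 5.0, v_inner, 1.5)

n_leaves = 3
sigma2_pic = (s_ab**2 + s_ic**2) / (n_leaves - 1)
print(t_inner, v_inner, sigma2_pic)
```

Here the inner node is estimated as 2 with adjusted edge length 1.5, the two squared contrasts are 2 and 3, and the variance estimate is their mean over the $n-1 = 2$ contrasts.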




# Make tree
Build a random example tree in which the value stored in each node is a quantitative trait.

To illustrate the edge length correction, we initially assume that each edge length is equal to 1.
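
With unit edge lengths in a symmetric binary tree, both children of a node always carry the same adjusted length, so the correction follows a simple recurrence from the leaves upward. The following plain-Python sketch shows this under those assumptions; the function name `adjusted_lengths_symmetric` is hypothetical.

```python
def adjusted_lengths_symmetric(n_fusions, edge=1.0):
    """Adjusted edge length per level of a symmetric binary tree, leaves first.

    Assumes every node has the same raw edge length `edge`, so both children
    of a node always carry the same adjusted length.
    """
    lengths = [edge]  # leaves keep their raw edge length
    for _ in range(n_fusions):
        child = lengths[-1]
        # e' = e + (c * c) / (c + c), which simplifies to e + c / 2
        lengths.append(edge + (child * child) / (child + child))
    return lengths


lengths = adjusted_lengths_symmetric(4)
print(lengths)
```

Each fusion adds half of the children's adjusted length, so the values increase towards 2 but never reach it.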

In [3]:
levels = 7

tree = hyperiax.tree.builders.symmetric_tree(levels, 2)

subkey, key = split(key)

for i, node in enumerate(tree.iter_bfs()):
    node['edge_length'] = 1#+i

for i, leaf in enumerate(tree.iter_leaves()):
    key, subkey = split(key)
    leaf['size'] = 1 #jax.random.randint(subkey, (1,),1,20)
    leaf['independent_contrast'] = 0 #jax.random.randint(subkey, (1,),1,20)
   # leaf['size'] = i +1

In [4]:
if levels < 5:
    tree.plot_tree()

## Compute inner nodes and independent contrasts

In [6]:
### dummy pass_up function (for testing)

up = pass_up('size','edge_length')

def fuse(child_edge_length,child_size, edge_length, **kwargs):

    return {'size':1, 'edge_length':1}


In [88]:
# ### for inner node estimation only (using the IC reweighted edges)

# up = pass_up('size','edge_length')

# def fuse(child_edge_length,child_size, edge_length, **kwargs):

#     corrected_edge_length = edge_length+(child_edge_length[0]*child_edge_length[1])/(child_edge_length[0]+child_edge_length[1])

#     top = 1/child_edge_length[0]*child_size[0]+1/child_edge_length[1]*child_size[1] 
#     bot = 1/child_edge_length[0]              +1/child_edge_length[1] 
#     return {'size':top/bot, 'edge_length':corrected_edge_length}


# upmodel = UpLambda(up, fuse)
# root_exe = DependencyTreeExecutor(upmodel, batch_size=5)
# corrected_tree = root_exe.up(corrected_tree)



In [89]:
### this version includes computation of independent contrasts (to be used for the variance estimate)

# up = pass_up('size', 'independent_contrast','edge_length')

# def fuse(child_edge_length,child_size, edge_length, **kwargs):

#     corrected_edge_length = edge_length+(child_edge_length[0]*child_edge_length[1])/(child_edge_length[0]+child_edge_length[1])

#     # inner node est
#     numerator = 1/child_edge_length[0]*child_size[0]+1/child_edge_length[1]*child_size[1] 
#     denominator = 1/child_edge_length[0]              +1/child_edge_length[1] 
#     node_est = numerator / denominator

#     # standardized independent contrast
#     IC = (child_size[0] - child_size[1]) / jnp.sqrt(child_edge_length[0] + child_edge_length[1])
    
#     return {'size': node_est, 'independent_contrast': IC, 'edge_length': corrected_edge_length}




In [7]:
start = time.time()

upmodel = UpLambda(up, fuse)
root_exe = DependencyTreeExecutor(upmodel, batch_size=5)
estimated_tree = root_exe.up(tree)


end = time.time()
print(end - start)

0.10872292518615723


In [None]:
print(estimated_tree.root)

TreeNode({'edge_length': Array(1.9990234, dtype=float32, weak_type=True), 'size': Array(1., dtype=float32, weak_type=True)}) with 2 children


Compute covariance

In [None]:
# extract the N-1 independent contrasts and compute the variance estimate

independent_contrasts = jnp.array([node['independent_contrast'] for node in estimated_tree.iter_bfs() if node.children])  # do this in a faster way?

# Compute the sum of outer products
var_est_pic = jnp.einsum('ji,jk->ik', independent_contrasts, independent_contrasts) / independent_contrasts.shape[0]

print(var_est_pic)

[[6.814815]]


# Plot estimated inner nodes (trait values)

Here I move the size into the "name" field, because I earlier implemented showing node names in the plot (so the name is replaced with the size).
A future implementation might instead show the trait value through dot size and coloring.

Note also that the root is not included in this plot with names...

In [None]:
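# write each node's trait value into its name so that it shows up in the text plot
for node in estimated_tree.iter_bfs():
    # jnp.ravel handles both scalar and shape-(1,) trait values
    node.name = "   " + str(jnp.ravel(node['size'])[0]) + "   "

if levels < 5:
    estimated_tree.plot_tree_text()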