In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import hyperiax
from jax.random import PRNGKey, split
import jax
from jax import numpy as jnp
from hyperiax.execution import OrderedExecutor
from hyperiax.models import UpLambdaReducer, DownLambda, UpLambda
from hyperiax.models.functional import pass_up
from hyperiax.tree.topology import symmetric_topology
from hyperiax.tree import HypTree
from hyperiax.plotting import plot_tree_text, plot_tree_2d_scatter
from matplotlib import pyplot as plt
import jax

In [2]:
key = PRNGKey(0)

## Rate likelihood estimation - rough draft

In this notebook, we setup a rough draft of rate estimation of discerete variables 

This notebook follows work of Sergei T. 

The outline of this notebook comes in; 

    1. Creating the tree

    2. Simulating discerete character from root to leafs.  (NOT implemented, i just draft some randomly)

    3. One up pass using likelihood to restimate the the rate
    
    4. Optimizer to converge for correct estimation of the rate



### 1. Creating the Tree

First, we initialize a tree. This creates a tree with a chosen topology. This topology is a "stupid" class, in the sense that it contains no data - and only serves as a representation of the data we intend to work on.

Setting `height=3` and `degree=2`, gives us a tree with 4 layers, where each node has `2` children

In [3]:
topology = symmetric_topology(height=2, degree=2)
plot_tree_text(topology)


       None
   ┌────┴────┐
  None      None   
 ┌─┴──┐    ┌─┴──┐  
None None None None


In [4]:
tree = HypTree(topology,precompute_child_gathers=True)


# Propreties of the tree

# Branch/edge length, assumed to be constant one 
tree.add_property('edge_length', shape=(1,))
tree.data["edge_length"]  = jnp.array([1.0] * tree.size)


for i,node in enumerate(tree.iter_topology_bfs()):
   node.name = str(i)

# Empty properties to fill out later 

# Chacters storing 
tree.add_property('value', shape=(2,))

# plot tree again 
plot_tree_text(tree)
#

   0
 ┌─┴─┐
 1   2 
┌┴┐ ┌┴┐
3 4 5 6


### 2. Simulating node values

Could be a markov chain, I will just draft some random values for now

I have kept the strcutre of the code how to do it proberly 

In [5]:
# Insert some correct thing here. 

#@jax.jit
#def down(noise, edge_length,parent_value, **args):
#    # Insert 
#    return {'value':value}

# Define root variable for our simulation 
#tree.data['value'] = tree.data['value'].at[0].set([0,0])

# Execute the simulation on tree
#downmodel = DownLambda(down_fn=down)
#exe = OrderedExecutor(downmodel)
#exe.down(tree)

### Fake data

In [6]:
# Draw random 0,1 in the size corresponding to the under here
# Random sampler 
tree.data['value'] = tree.data['value'].at[tree.is_leaf].set(jax.random.bernoulli(key, p=0.5, shape=(sum(tree.is_leaf),2)))

# Print over discerete variables 
print(tree.data['value'][tree.is_leaf])


[[0. 1.]
 [1. 0.]
 [1. 0.]
 [1. 0.]]


### 3. Estimating the inner nodes

Here we could mask our internal nodes "value" from our downsimulation, since we would only now the leaf values. 
We keep with this format, and therefore more our leaf data to another variable


# Upward Pass Computation

We estimate the upwards probability from the leafs to the root; 

$$
p(parent)= p(t1)*p(t2) 
$$
where $p(t1)=d1*exp(Q*t1)^T$, and t1 is the branch length 

In [7]:
tree.add_property('estimated_value', shape=(2,))
leaf_data = tree.data['value'][tree.is_leaf]
tree.data['estimated_value'] = tree.data['estimated_value'].at[tree.is_leaf].set(leaf_data)

leaf_edgelength = tree.data['edge_length'][tree.is_leaf]

Now we estimate the probability in the leaf nodes, since we only have data in chacaters in the leafs, and upwards we multiply the edgelengths with probabilies (but I may have misunderstood that part)

Manual estimating here:

In [45]:
# Define rate matrix, which is 2x2 as a function of input of alpha and beta
# as a function 
Q_rate_matrix = lambda alpha, beta: jnp.array([[-alpha, alpha], [beta,-beta]])
prob_estimation = lambda chacter,length : jnp.dot(chacter,jnp.exp(Q_rate_matrix(alpha,beta)*length))


# Define alpha and beta
alpha = 0.5
beta = 0.5


# Print the functions as a test
print(Q_rate_matrix(0.5,1))
print(prob_estimation(jnp.array([1,0]),1))

# JUST A CHECK
# Apply the functions on leaf data and estimate the probability
leaf_data_prob = jax.vmap(prob_estimation, in_axes=0)(leaf_data,leaf_edgelength)
print("all leafs")
print(leaf_data_prob)

# assign the values in the leafs of the tree

[[-0.5  0.5]
 [ 1.  -1. ]]
[0.60653067 1.6487212 ]
all leafs
[[1.6487212  0.60653067]
 [0.60653067 1.6487212 ]
 [0.60653067 1.6487212 ]
 [0.60653067 1.6487212 ]]


# Hyperiax implementation of up




In [10]:
def probability_function(child_estimated_value, child_edge_length, **kwargs):

    # Do probability estimation 
    probs = jax.vmap(prob_estimation, in_axes=0)(child_estimated_value,child_edge_length)

    # Multiply probabilities element-wise and reduce along the first axis
    return {'estimated_value': jnp.prod(probs, axis=1)}

In [11]:

upmodel = UpLambda(up_fn=probability_function)
upmodelexe = OrderedExecutor(upmodel)
res = upmodelexe.up(tree)



In [46]:
#Print the results
tree.data

{'edge_length': Array([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
 'value': Array([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], dtype=float32),
 'estimated_value': Array([[10.610551  ,  5.086161  ],
        [ 1.        ,  1.        ],
        [ 0.36787945,  2.7182817 ],
        [ 0.        ,  1.        ],
        [ 1.        ,  0.        ],
        [ 1.        ,  0.        ],
        [ 1.        ,  0.        ]], dtype=float32)}

In [14]:
# Do manual computation to see the results is true 

leaf_data_prob = jax.vmap(prob_estimation, in_axes=0)(leaf_data,leaf_edgelength)

# Check the first node at see it is the same
print("manual estimation of internal nodes")
n3 = jnp.dot(leaf_data[0],jnp.exp(Q_rate_matrix(alpha,beta)*1))
n4 = jnp.dot(leaf_data[1],jnp.exp(Q_rate_matrix(alpha,beta)*1))
print(n3*n4)

n5 = jnp.dot(leaf_data[2],jnp.exp(Q_rate_matrix(alpha,beta)*1))
n6 = jnp.dot(leaf_data[3],jnp.exp(Q_rate_matrix(alpha,beta)*1))
print(n5*n6)


# Do reshape so it fits with the upwards pass
leaf_data_prob = jax.vmap(prob_estimation, in_axes=0)(leaf_data,leaf_edgelength)
leaf_data_prob = leaf_data_prob.reshape(2,2,2)

# See these are the same
midder = (jnp.prod(leaf_data_prob, axis=1))
print("internal nodes")
print(midder)

# and then do the last to see the value
print("Root") 
print(jnp.dot(jnp.array(midder[0]),jnp.exp(Q_rate_matrix(alpha,beta)*1))*jnp.dot(jnp.array(midder[1]),jnp.exp(Q_rate_matrix(alpha,beta)*1)))

manual estimation of internal nodes
[1. 1.]
[0.36787945 2.7182817 ]
internal nodes
[[1.         1.        ]
 [0.36787945 2.7182817 ]]
Root
[10.610551  5.086161]


## 4. Do the optimization to estimate alpha/beta

NOTE - I may not have recalled the likelihood function correctly

In [49]:
# Use JAX's BFGS optimizer from scipy.optimize

from jax.example_libraries import optimizers
from jax import grad
pi_root = [1,1]
pred_root = tree.data["estimated_value"][tree.is_root]

# Define the negative log-likelihood function
def negative_log_likelihood(params):
    alpha, beta = params
    Q = Q_rate_matrix(alpha, beta)
    # note we do not include edge length since we are in the root
    root_probs = jnp.exp(Q)
    likelihood = jnp.sum(jnp.array(pi_root) * jnp.dot(pred_root, root_probs))

    return -jnp.log(likelihood)


# Initialize optimization parameters
initial_params = jnp.array([0.5, 0.5])
# Use JAX's optimizers for BFGS optimization
from jax.example_libraries import optimizers
import jax.numpy as jnp

# Define the optimization function
@jax.jit
def optimize_step(i, opt_state, pi_root, pred_root):
    params = get_params(opt_state)
    value, grads = jax.value_and_grad(negative_log_likelihood)(params)
    return opt_update(i, grads, opt_state), value

# Initialize the optimizer
opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(initial_params)

# Optimization loop
num_iterations = 100
for i in range(num_iterations):
    # Re-estimate the root value using upmodelexe.up
    upmodelexe.up(tree)
    pred_root = jnp.array(tree.data["estimated_value"][tree.is_root])
    
    opt_state, value = optimize_step(i, opt_state, jnp.array(pi_root), pred_root)
    
    if i % 10 == 0:
        print(f"Iteration {i}, Loss: {value}")

# Get the final optimized parameters
final_params = get_params(opt_state)
estimated_alpha, estimated_beta = final_params

print(f"Estimated alpha: {estimated_alpha}")
print(f"Estimated beta: {estimated_beta}")
print(f"Final negative log-likelihood: {negative_log_likelihood(final_params)}")



Iteration 0, Loss: -3.5667128562927246
Iteration 10, Loss: -3.6169826984405518
Iteration 20, Loss: -3.6757652759552
Iteration 30, Loss: -3.743407726287842
Iteration 40, Loss: -3.81976580619812
Iteration 50, Loss: -3.9042000770568848
Iteration 60, Loss: -3.995724678039551
Iteration 70, Loss: -4.09318733215332
Iteration 80, Loss: -4.195422172546387
Iteration 90, Loss: -4.301346778869629
Estimated alpha: 1.61799955368042
Estimated beta: 1.6179993152618408
Final negative log-likelihood: -4.410018444061279
