Du hast einen Baum, und wir wollen für jedes Edge den Satute-Test anwenden.
Dafür schneiden wir das Edge durch und berechnen dann die Partial Likelihood an den beiden Nodes des Edges.

In [183]:
import numpy as np
from typing import Optional
from functools import cache

In [159]:
from scipy.sparse.linalg import expm

In [160]:
# Substitution rate matrix over four states: every off-diagonal rate is 1 and every
# diagonal entry is -3, so each row sums to zero.
rate_matrix = np.array(
    [
        [-3, 1, 1, 1],
        [1, -3, 1, 1],
        [1, 1, -3, 1],
        [1, 1, 1, -3],
    ]
)

In [175]:
import dataclasses

class Tree:
    """Tree described by a collection of weighted edges.

    Each edge is a ``(left, right, branchlength)`` triple. Constructing the tree
    registers every edge on both of its endpoints through their ``connect``
    method, so each node knows its neighbours and the length of the branch
    leading to them.
    """

    def __init__(self, edges):
        """Connect the endpoints of every edge and index the branch lengths.

        Args:
            edges: iterable of ``(left, right, branchlength)`` triples. The end
                points must be hashable and provide ``connect(other, length)``.
        """
        # Materialise first so a generator argument is not left exhausted for
        # get_edges() after the loop below has consumed it.
        self.edges = list(edges)

        # Branch lengths keyed by (node, node) in both directions; previously this
        # attribute was never created, so get_branch_length() always failed.
        self.branch_lengths = {}
        for left, right, branchlength in self.edges:
            left.connect(right, branchlength)
            right.connect(left, branchlength)
            self.branch_lengths[(left, right)] = branchlength
            self.branch_lengths[(right, left)] = branchlength

    def get_edges(self):
        """Return the list of ``(left, right, branchlength)`` edges."""
        return self.edges

    def get_branch_length(self, edge):
        """Return the branch length of ``edge``.

        Args:
            edge: ``(left, right)`` or ``(left, right, branchlength)``; only the
                first two entries are used, and either orientation is accepted.

        Raises:
            KeyError: if the two nodes are not joined by an edge of this tree.
        """
        left, right = edge[0], edge[1]
        return self.branch_lengths[(left, right)]


class Node:
    """A tree node that records its neighbours and, for leaves, an observed state.

    Nodes hash and compare by identity, so two nodes with the same name are
    still distinct dictionary keys.
    """

    # Plain annotations only: this class is not a dataclass, so the former
    # ``dataclasses.field(...)`` defaults were inert Field objects stored as class
    # attributes. All three attributes are assigned per instance in __init__.
    name: str
    state: Optional[np.ndarray]  # likelihood vector for leaves, None otherwise
    connected: dict  # neighbour Node -> branch length

    def __init__(self, name, state=None, connected=None):
        """Create a node.

        Args:
            name: label of the node.
            state: state vector of the node; ``None`` marks a non-leaf node.
            connected: optional mapping ``neighbour -> branch length``; a falsy
                value (``None`` or an empty dict) yields a fresh empty dict.
        """
        self.name = name
        self.state = state
        self.connected = connected or {}

    def __hash__(self):
        """Hash by object identity."""
        return id(self)

    def __eq__(self, other):
        """Two nodes are equal only if they are the same object."""
        return self is other

    def connect(self, other_node, branchlength):
        """Record ``other_node`` as a neighbour reached via ``branchlength``.

        Only this node is updated; the reverse link has to be added separately.
        """
        self.connected[other_node] = branchlength

    def is_leaf(self):
        """Return True when the node carries a state vector."""
        return self.state is not None

In [176]:
def test(p1, p2):
    pass

In [177]:
# NOTE(review): get_initial_likelihood_vector is only defined in a later cell, so on a
# fresh kernel this cell fails unless that definition has been executed first.

# Leaves carry a state vector; the inner nodes "3" and "4" are created without one.
a1 = Node("A", get_initial_likelihood_vector('A'))
b2 = Node("B", get_initial_likelihood_vector('C'))
u3 = Node("3")
u4 = Node("4")
c5 = Node("C", get_initial_likelihood_vector('A'))
b6 = Node("D", get_initial_likelihood_vector('A'))

# Five edges, all with branch length 0.01.
tree = Tree(
    [
        (a1, u3, 0.01),
        (u3, b2, 0.01),
        (u3, u4, 0.01),
        (u4, c5, 0.01),
        (u4, b6, 0.01),
    ]
)

In [178]:
# Sanity check: node "3" was created without a state vector, so it is not a leaf.
u3.is_leaf()

False

In [184]:
def get_initial_likelihood_vector(state):
    """Return the initial likelihood vector for an observed nucleotide symbol.

    Args:
        state: one of ``A``, ``C``, ``G``, ``T``, ``N`` or ``-``; string symbols
            are matched case-insensitively.

    Returns:
        A new 1-D integer array of length four. The four nucleotides map to an
        indicator vector; ``N`` and ``-`` map to all ones.

    Raises:
        KeyError: if ``state`` is not one of the supported symbols.
    """
    nucleotide_code_vector = {
        "A": [1, 0, 0, 0],
        "C": [0, 1, 0, 0],
        "G": [0, 0, 1, 0],
        "T": [0, 0, 0, 1],
        "N": [1, 1, 1, 1],
        "-": [1, 1, 1, 1],
    }
    # Lower-case symbols are treated like their upper-case counterparts.
    key = state.upper() if isinstance(state, str) else state
    try:
        code = nucleotide_code_vector[key]
    except KeyError:
        supported = ", ".join(nucleotide_code_vector)
        raise KeyError(
            f"unsupported nucleotide state {state!r}; expected one of: {supported}"
        ) from None
    # The former ``.T`` was a no-op on a 1-D array and has been dropped.
    return np.array(code)


@cache
def partial_likelihood(tree, node, coming_from):
    """Partial likelihood vector at ``node`` for the part of the tree away from ``coming_from``.

    For a leaf this is the leaf's own state vector. For any other node it is the
    element-wise product, over every neighbour except ``coming_from``, of
    ``expm(rate_matrix * branch_length) @ partial_likelihood(neighbour)``.

    Args:
        tree: not used in the computation; it only takes part in the cache key.
        node: node at which the vector is evaluated.
        coming_from: neighbour to exclude (the other end of the cut edge), or
            ``None`` to include every neighbour.

    Returns:
        A 1-D array. For leaves the stored ``node.state`` object itself is
        returned, so callers must not modify the result in place.

    Note:
        Results are memoised per ``(tree, node, coming_from)``. The cache is not
        invalidated when the global ``rate_matrix``, a branch length or a leaf
        state changes; call ``partial_likelihood.cache_clear()`` after such edits.
    """
    if node.is_leaf():
        return node.state

    # Start from the all-ones vector (neutral element of the product) instead of
    # the scalar 1: a non-leaf node with no further neighbours then still yields a
    # vector, which the caller's matrix product requires.
    results = np.ones(np.shape(rate_matrix)[0])
    for child in node.connected:
        # Use ``==`` rather than calling ``__eq__`` directly: the raw dunder may
        # return NotImplemented (truthy), which would skip every neighbour.
        if child == coming_from:
            continue
        # Branch length as recorded on the neighbour's side of the edge.
        transition = expm(rate_matrix * child.connected[node])
        results = results * (transition @ partial_likelihood(tree, child, node))
    return results

In [185]:


# Cut every edge in turn and evaluate the partial likelihood at each of its two end
# nodes, each time looking away from the other end of the cut edge.
for left, right, branchlength in tree.get_edges():
    likelihood_left = partial_likelihood(tree, left, right)
    likelihood_right = partial_likelihood(tree, right, left)

    print(left.name, right.name, likelihood_left, likelihood_right)

    test(likelihood_left, likelihood_right)

A 3 [1 0 0 0] [8.96302554e-03 9.05534948e-03 9.14558596e-05 9.14558596e-05]
3 B [8.87459032e-01 9.14558596e-05 9.14558596e-05 9.14558596e-05] [0 1 0 0]
3 4 [9.51436495e-03 9.51436495e-03 9.60917551e-05 9.60917551e-05] [9.42048985e-01 9.60917551e-05 9.60917551e-05 9.60917551e-05]
4 C [9.05534948e-03 9.14558596e-05 2.75198971e-06 2.75198971e-06] [1 0 0 0]
4 D [9.05534948e-03 9.14558596e-05 2.75198971e-06 2.75198971e-06] [1 0 0 0]


In [220]:
# Scratch value: the state vector for 'C' multiplied by expm(rate_matrix * 0.01).
# NOTE(review): `e` is not used again in the visible part of this cell.
e = expm(rate_matrix*0.01) @ get_initial_likelihood_vector('C')

def name_nodes_by_level_order(tree):
    """Rename every non-leaf node to ``Node<k>*`` in level-order, counting from 1.

    The nodes are renamed in place; leaves keep their names. The same ``tree``
    object is returned for convenience.
    """
    inner_nodes = (
        candidate
        for candidate in tree.traverse("levelorder")
        if not candidate.is_leaf()
    )
    for index, inner in enumerate(inner_nodes, start=1):
        inner.name = f"Node{index}*"
    return tree

#a1, b2, u3, u4, c5, b6 = [Node("A", get_initial_likelihood_vector('A')), 
#                          Node("B", get_initial_likelihood_vector('C')),
#                          Node("3"),
#                          Node("4"), Node("C", get_initial_likelihood_vector('A')), Node("D", get_initial_likelihood_vector('A'))]
# tree = Tree([(a1, u3, 0.01), (u3, b2, 0.01), (u3, u4, 0.01), (u4, c5, 0.01), (u4, b6, 0.01)])
# from ete3 import Tree
# t = Tree('((((H,K)D,(F,I)G)B,E)A,((L,(N,Q)O)J,(P,S)M)C);', format=1)
# for node in t.traverse("postorder"):
# Do some analysis on node
#  print(node.name)

# NOTE(review): Tree(...) is called with a Newick string and format=1 and the result is
# used through .traverse()/.children, which the notebook-local Tree class does not
# offer. Presumably this is ete3's Tree (its import above is commented out) — confirm
# and import it explicitly under a distinct name.
t = name_nodes_by_level_order(Tree('((A,B),(C,D));', format=1))

# One notebook Node per leaf of the parsed tree.
# NOTE(review): every leaf receives the vector for 'A', regardless of its name.
node_list = [
    Node(leaf.name, get_initial_likelihood_vector('A'))
    for leaf in t.traverse("levelorder")
    if leaf.is_leaf()
]

# Print every parent/child pair in level-order.
for node in t.traverse("levelorder"):
    for child_node in node.children:
        print(node.name, child_node.name)




Node1* Node2*
Node1* Node3*
Node2* A
Node2* B
Node3* C
Node3* D
