In [27]:
from abc import ABC, abstractmethod
import numpy as np
from scipy.linalg import expm

num_states = 3
# Equal-rates generator: every off-diagonal rate is 1 and each diagonal entry is
# -(num_states - 1), so each row sums to 0. Deriving Q from num_states keeps the
# two in sync if the number of states is changed.
Q = np.ones((num_states, num_states), dtype=int) - num_states * np.eye(num_states, dtype=int)
assert (Q.sum(axis=1) == 0).all(), "rows of a rate matrix must sum to 0"


def pr(cur_state: int, t: float, rate_matrix: "np.ndarray | None" = None) -> np.ndarray:
    """Distribution over end states after time ``t``, starting from ``cur_state``.

    For a continuous-time Markov chain with rate matrix ``Q``, the transition
    matrix is ``P(t) = expm(Q * t)``. Entry ``P(t)[i, j]`` is the probability of
    being in state ``j`` at time ``t`` given state ``i`` at time 0.

    Parameters
    ----------
    cur_state : int
        Starting state, used as a row index into ``P(t)``.
    t : float
        Elapsed time (branch length).
    rate_matrix : np.ndarray, optional
        Square rate matrix. Defaults to the module-level ``Q``, looked up at
        call time so that redefining ``Q`` in a later cell is honoured.

    Returns
    -------
    np.ndarray
        ``P(t)[cur_state]``. It sums to 1 when the rows of the rate matrix
        sum to 0.
    """
    if rate_matrix is None:
        rate_matrix = Q
    # Take the *row* for cur_state. The old ``expm(Q*t) @ one_hot`` picked the
    # column, which only coincides with the row when Q is symmetric.
    return expm(np.asarray(rate_matrix) * t)[cur_state]


class Node(ABC):
    """Abstract tree node that reports a conditional likelihood per state."""

    @abstractmethod
    def likelihood(self, state: int) -> float:
        """Likelihood of the observations below this node, given ``state`` here."""


class Leaf(Node):
    """Tip of the tree with an exactly observed state."""

    def __init__(self, state: int):
        # Observed state at this tip.
        self.state = state

    def likelihood(self, state: int) -> float:
        """Indicator likelihood: 1 if ``state`` is the observed tip state, else 0."""
        if state == self.state:
            return 1
        return 0


class Parent(Node):
    """Internal binary node joining two subtrees via branches of length tL / tR.

    Likelihoods use Felsenstein's pruning algorithm. The vector of conditional
    likelihoods for all states is built once per node from the children's
    vectors, so evaluating a tree is a single pass over its nodes. The previous
    per-state recursion made a number of calls growing as ``num_states ** depth``.
    """

    def __init__(self, tL: float, nodeL: "Node", tR: float, nodeR: "Node"):
        self.tL = tL  # branch length to the left child
        self.nodeL = nodeL
        self.tR = tR  # branch length to the right child
        self.nodeR = nodeR

    @staticmethod
    def _child_likelihoods(node: "Node") -> np.ndarray:
        """Conditional likelihood vector (one entry per state) for ``node``."""
        if isinstance(node, Parent):
            # Reuse the vectorised pass instead of recursing once per state.
            return node._conditional_likelihoods()
        return np.array([node.likelihood(i) for i in range(num_states)], dtype=float)

    def _conditional_likelihoods(self) -> np.ndarray:
        """Vector whose entry ``s`` equals ``likelihood(s)``, computed in one pass."""
        # expm(Q * t)[s, j] = P(child in j | this node in s), so the matrix-vector
        # product sums over the child's state for every parent state s at once.
        left = expm(Q * self.tL) @ self._child_likelihoods(self.nodeL)
        right = expm(Q * self.tR) @ self._child_likelihoods(self.nodeR)
        return left * right

    def likelihood(self, state: int) -> float:
        """Likelihood of the leaf states below this node, given ``state`` here."""
        return self._conditional_likelihoods()[state]

    def treeLikelihood(self, prior: "list[float] | None" = None) -> float:
        """Total likelihood of the tree rooted at this node.

        Parameters
        ----------
        prior : sequence of float, optional
            Distribution over this node's (root) state, of length ``num_states``.
            Defaults to uniform. It is built per call, replacing the old shared
            mutable list default that was frozen at class-definition time.

        Returns
        -------
        float
            ``sum_s prior[s] * likelihood(s)``.

        Raises
        ------
        ValueError
            If ``prior`` does not have exactly ``num_states`` entries.
        """
        if prior is None:
            prior_arr = np.full(num_states, 1 / num_states)
        else:
            prior_arr = np.asarray(prior, dtype=float)
        if prior_arr.shape != (num_states,):
            raise ValueError(
                f"prior must have {num_states} entries, got shape {prior_arr.shape}"
            )
        return float(prior_arr @ self._conditional_likelihoods())

In [28]:
# Tree in Newick form (labels are observed leaf states, numbers after ':' are
# branch lengths):  ((((0:1,1:1):0.5,0:1.5):1,(2:0.5,2:0.5):2):0.5,1:2.5);
parent1 = Parent(tL=1, nodeL=Leaf(0), tR=1, nodeR=Leaf(1))
parent2 = Parent(tL=0.5, nodeL=Leaf(2), tR=0.5, nodeR=Leaf(2))
parent3 = Parent(tL=0.5, nodeL=parent1, tR=1.5, nodeR=Leaf(0))
parent4 = Parent(tL=1, nodeL=parent3, tR=2, nodeR=parent2)
parent5 = Parent(tL=0.5, nodeL=parent4, tR=2.5, nodeR=Leaf(1))  # root

# Likelihood of the whole tree under the default (uniform) root prior.
parent5.treeLikelihood()

0.0015047673351038832