In [None]:
import os
import io
import pickle
from datetime import datetime
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from ete3 import Tree
from Bio import Phylo
from typing import Literal

In [None]:
# Optional overrides -- leave as None to let the notebook choose automatically.
#   EPOCH_OVERRIDE: force this epoch index instead of searching by ELBO.
#   TREE_IN_EPOCH_OVERRIDE: index into the likelihood-sorted trees of the epoch
#     at which to start looking for a buildable tree (None means 0).
EPOCH_OVERRIDE = TREE_IN_EPOCH_OVERRIDE = None

# An epoch is skipped when the fraction of its trees that failed to build is
# strictly greater than this. With 1 this criterion never skips an epoch.
MAX_FAILED_TREES_FRACTION = 1

# Which dataset's inputs/results the notebook operates on.
DATA: Literal["simulated", "toy", "shearwater"] = "shearwater"

In [None]:
# Dataset name -> (results directory to search, original PHYLIP input file).
_DATA_PATHS = {
    "simulated": ("results/cellphy_simulated_data", "data/cellphy_simulated_set.phy"),
    "toy": ("results/cellphy_toy_data", "data/cellphy_toy_set.phy"),
    "shearwater": ("results/cellphy_shearwater_data", "data/cellphy_shearwater_set.phy"),
}

# Fail fast on an unknown DATA value. Without this, the path variables below
# would simply be undefined and a later cell would die with a confusing NameError.
if DATA not in _DATA_PATHS:
    raise ValueError(f"Unknown DATA {DATA!r}; expected one of {sorted(_DATA_PATHS)}")

latest_result_search_dir, original_phy_file = _DATA_PATHS[DATA]


def find_latest_result():
    # Initialize variables to keep track of the latest file and its creation date
    latest_file = None
    latest_date = datetime.min

    # Walk through the directory tree and search for the file
    for dirpath, dirnames, filenames in os.walk(latest_result_search_dir):
        for filename in filenames:
            if filename == "results.p":
                file_path = os.path.join(dirpath, filename)
                creation_date = datetime.fromtimestamp(os.path.getctime(file_path))
                if creation_date > latest_date:
                    latest_file = file_path
                    latest_date = creation_date

    if latest_file is None:
        raise FileNotFoundError("No results file found in the directory tree.")

    return latest_file


# Locate the newest results pickle for the selected dataset and echo its path.
print(result_path := find_latest_result())

In [None]:
# Load the pickled inference results.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only open result files produced by a trusted run of this pipeline.
with open(result_path, "rb") as f:
    data = pickle.load(f)

with open(original_phy_file, "r") as f:
    phy_file_raw = f.readlines()

# Skip the first line (the PHYLIP header); every other line is expected to begin
# with a taxon label followed by its sequence. Split on any whitespace so tab- or
# multi-space-separated files also work, and skip blank lines (e.g. a trailing
# empty line), which would otherwise become a bogus "\n" label and inflate
# len(labels), i.e. the number of taxa assumed by the tree reconstruction.
labels = [line.split()[0] for line in phy_file_raw[1:] if line.strip()]

In [None]:
class Node(ABC):
    """Abstract binary-tree node; the concrete kinds are ``Leaf`` and ``Inner``."""

    @abstractmethod
    def has(self, value: str) -> bool:
        """Whether a leaf labelled ``value`` occurs somewhere in this subtree."""

    @abstractmethod
    def find(self, value: str) -> "Node | None":
        """The leaf labelled ``value`` within this subtree, or None if absent."""

    @abstractmethod
    def leaf_node_count(self) -> int:
        """Number of leaves in this subtree."""


class Leaf(Node):
    """Terminal node carrying a single taxon label."""

    def __init__(self, value):
        # The label; emitted verbatim into the Newick string by __str__.
        self.value = value

    def has(self, value):
        """True iff this leaf's label equals ``value``."""
        return self.value == value

    def find(self, value):
        """This leaf when its label equals ``value``; otherwise None."""
        return self if self.value == value else None

    def leaf_node_count(self):
        """A leaf always counts as exactly one leaf."""
        return 1

    def __str__(self):
        return f"{self.value!s}"


class Inner(Node):
    """Internal node joining two subtrees, each with its own branch length."""

    def __init__(self, left, right, left_len, right_len):
        self.left, self.right = left, right
        # Branch lengths written after the left/right child in the Newick output.
        self.left_len, self.right_len = left_len, right_len

    def has(self, value):
        """True if either subtree contains a leaf labelled ``value``."""
        return any(child.has(value) for child in (self.left, self.right))

    def find(self, value):
        """First leaf labelled ``value``, searching ``left`` before ``right``."""
        hit = self.left.find(value)
        return hit if hit else self.right.find(value)

    def leaf_node_count(self):
        """Total number of leaves under this node."""
        return sum(child.leaf_node_count() for child in (self.left, self.right))

    def __str__(self):
        # Newick fragment of the form "(left:left_len,right:right_len)".
        parts = (f"{self.left}:{self.left_len}", f"{self.right}:{self.right_len}")
        return "(" + ",".join(parts) + ")"


@dataclass
class BuiltTree:
    """Outcome of reconstructing one sampled tree: its score plus a root or an error."""

    # Log-likelihood recorded for this sample.
    log_likelihood: float
    # Root of the reconstructed tree; None when reconstruction failed.
    root: "Node | None" = None
    # Exception raised while reconstructing, if any.
    error: "Exception | None" = None


def build_tree(jc, left_branches, right_branches, leaf_labels=None) -> Node:
    """Reconstruct the binary tree encoded by one sampled jump chain.

    Args:
        jc: Jump-chain sequence. ``jc[0]`` is ignored; every later element is
            decoded into a set of taxon labels (see the NOTE below).
        left_branches: ``left_branches[k]`` is the branch length given to the
            left child created by the k-th merge (index ``N - 2`` is the root).
        right_branches: Same as ``left_branches`` for the right child.
        leaf_labels: Taxon labels indexed by the integers in ``jc``. Defaults
            to the module-level ``labels`` read from the PHYLIP file.

    Returns:
        The root ``Inner`` node. ``str(root)`` is a Newick string without the
        terminating ";".

    Raises:
        Exception: if the jump chain is inconsistent -- a new block does not
            join exactly two current subtrees, or the number of entries does
            not match the taxon count.
    """
    if leaf_labels is None:
        leaf_labels = labels
    N = len(leaf_labels)

    # NOTE(review): each element of jc[1:] appears to be a bytes-like value whose
    # str() looks like "b'<c>0+<c>3'"; [2:-1] strips the b'...' wrapper and
    # tok[1:] drops a one-character prefix before the integer taxon index.
    # Confirm against the sampler that produces "jump_chain_evolution".
    parsed: list[frozenset[str]] = [
        frozenset(leaf_labels[int(tok[1:])] for tok in str(x)[2:-1].split("+"))
        for x in jc[1:]
    ]

    # The first N entries seed one leaf each (presumably singletons; only one
    # element of each entry is used).
    trees: list[Node] = [Leaf(next(iter(entry))) for entry in parsed[:N]]
    entries_in_last_merge: set[frozenset[str]] = set(parsed[:N])

    # After the seeds, the chain lists one state per merge; the state after
    # merge k has N - 1 - k blocks. A block absent from the previous state is
    # the cluster created by that merge.
    i = N
    entries_in_cur_merge_count = N - 1

    for merge_idx in range(N - 2):
        entries_in_cur_merge: set[frozenset[str]] = set()

        for _ in range(entries_in_cur_merge_count):
            if i >= len(parsed):
                raise Exception("jump chain ended before all merges were processed")
            entry = parsed[i]
            i += 1
            entries_in_cur_merge.add(entry)

            if entry in entries_in_last_merge:
                continue  # block carried over unchanged from the previous state

            # Keep the subtrees in their order within ``trees`` so the
            # left/right branch lengths land on the same side as before.
            trees_to_merge = [
                tree for tree in trees if any(tree.has(label) for label in entry)
            ]
            if len(trees_to_merge) != 2:
                raise Exception("expected 2 trees to merge")

            left, right = trees_to_merge
            trees = [tree for tree in trees if tree not in trees_to_merge]
            trees.append(
                Inner(left, right, left_branches[merge_idx], right_branches[merge_idx])
            )

        entries_in_cur_merge_count -= 1
        entries_in_last_merge = entries_in_cur_merge

    if i != len(parsed):
        raise Exception("not all entries were processed")

    if len(trees) != 2:
        raise Exception("expected 2 trees to merge")

    # The final merge joins the two remaining subtrees into the root.
    return Inner(trees[0], trees[1], left_branches[N - 2], right_branches[N - 2])


def build_trees_for_epoch(epoch: int) -> list[BuiltTree]:
    """Reconstruct every sampled tree stored for ``epoch``.

    Failures are recorded rather than raised: a sample whose jump chain cannot
    be decoded becomes a ``BuiltTree`` with ``root=None`` and ``error`` set.

    Returns:
        One ``BuiltTree`` per sample, sorted by log-likelihood, highest first
        (stable for ties). Samples with a NaN log-likelihood are placed last;
        left in, NaN would make the sort order arbitrary.
    """
    log_lik_list = data["log_lik_R"][epoch]
    jc_list = data["jump_chain_evolution"][epoch]
    # NOTE(review): the branch arrays look like they are stored (branch, sample);
    # transposing yields one row of branch lengths per sample -- confirm.
    left_branches_list = np.transpose(data["left_branches"][epoch])
    right_branches_list = np.transpose(data["right_branches"][epoch])

    trees = []

    for i in range(len(log_lik_list)):
        log_lik = log_lik_list[i]
        try:
            root = build_tree(jc_list[i], left_branches_list[i], right_branches_list[i])
        except Exception as e:  # deliberate best-effort: record and keep going
            trees.append(BuiltTree(log_likelihood=log_lik, error=e))
        else:
            trees.append(BuiltTree(log_likelihood=log_lik, root=root))

    def _rank(built):
        ll = built.log_likelihood
        # ll == ll is False only for NaN, so NaN samples rank below every number.
        return (bool(ll == ll), ll)

    return sorted(trees, key=_rank, reverse=True)

In [None]:
def find_best_epoch() -> int:
    """Return the epoch to analyse: the highest-ELBO epoch that passes sanity checks.

    ``EPOCH_OVERRIDE`` short-circuits the search. Otherwise epochs are tried in
    descending ELBO order (NaN ELBOs last), and the first one returned is one that
      * has at least one tree that could be built, and
      * has a failed-tree fraction not exceeding ``MAX_FAILED_TREES_FRACTION``.

    Raises:
        Exception: if no epoch satisfies the conditions.
    """
    if EPOCH_OVERRIDE is not None:
        return EPOCH_OVERRIDE

    elbos = np.asarray(data["elbos"], dtype=float)
    # argsort()[::-1] alone would put NaN (sorted last ascending) *first*; move
    # NaN epochs to the end so a diverged epoch is never preferred.
    descending = np.argsort(elbos)[::-1]
    is_nan = np.isnan(elbos[descending])
    sorted_epochs = np.concatenate([descending[~is_nan], descending[is_nan]])

    for epoch in sorted_epochs:
        built_trees = build_trees_for_epoch(epoch)
        failed_count = sum(tree.root is None for tree in built_trees)

        # With no buildable tree, build_best_tree would fail for this epoch
        # (and an empty epoch would otherwise divide by zero below).
        if failed_count == len(built_trees):
            print(f"Skipping epoch {epoch} because no tree could be built")
            continue

        failed_fraction = failed_count / len(built_trees)

        if failed_fraction > MAX_FAILED_TREES_FRACTION:
            print(
                f"Skipping epoch {epoch} because {failed_fraction} of the trees failed to build"
            )
            continue

        return int(epoch)

    raise Exception("No epoch found that satisfies the conditions")


# Pick the epoch to analyse and report it.
print("Best epoch:", best_epoch := find_best_epoch())


def build_best_tree(epoch: int | None = None) -> tuple[Node, int, float]:
    """Return the highest-likelihood tree of an epoch that could actually be built.

    Trees are scanned in descending log-likelihood order, starting at index
    ``TREE_IN_EPOCH_OVERRIDE`` (0 when unset); the first one with a root wins.

    Args:
        epoch: Epoch to search. Defaults to the module-level ``best_epoch``.

    Returns:
        ``(root, tree_num, log_likelihood)``, where ``tree_num`` is the tree's
        index in the likelihood-sorted list for the epoch.

    Raises:
        Exception: if no tree at or after the start index could be built.
    """
    if epoch is None:
        epoch = best_epoch

    first_tree = TREE_IN_EPOCH_OVERRIDE or 0

    built_trees = build_trees_for_epoch(epoch)

    for offset, tree in enumerate(built_trees[first_tree:]):
        if tree.root is not None:
            tree_num = first_tree + offset
            print(f"Built tree {tree_num} with log likelihood {tree.log_likelihood}")
            return tree.root, tree_num, tree.log_likelihood

    raise Exception("No tree could be built")

In [None]:
def plot_tree(tree: Node, root_name: str, file_name: str, out_dir: str | None = None):
    """Draw ``tree`` rooted on the leaf ``root_name``, save it as an image, and show it.

    Args:
        tree: Root of the tree to draw.
        root_name: Leaf label used as the outgroup when rooting the drawing.
        file_name: File name of the saved image.
        out_dir: Directory for the image; defaults to the directory that holds
            the loaded results file (``result_path``).
    """
    if out_dir is None:
        out_dir = os.path.dirname(result_path)

    # ~0.15 in per leaf keeps labels legible; the floor keeps small trees from
    # producing a near-zero-height, unreadable figure.
    height = max(tree.leaf_node_count() * 0.15, 2.0)
    fig, ax = plt.subplots(figsize=(10, height))

    # Newick requires a terminating ";"; str(tree) does not include one.
    phylo_tree = Phylo.read(io.StringIO(str(tree) + ";"), "newick")  # type: ignore
    phylo_tree.root_with_outgroup(root_name)
    Phylo.draw(phylo_tree, axes=ax, do_show=False)

    fig.savefig(os.path.join(out_dir, file_name))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# Reconstruct the best tree of the chosen epoch, then save/show it rooted on "Healthy".
root, root_num, root_ll = build_best_tree()

plot_tree(
    tree=root,
    root_name="Healthy",
    file_name=f"tree-epoch-{best_epoch}-num-{root_num}-ll-[{round(root_ll)}].png",
)

In [None]:
def get_rf_distance(t1: Node, t2: Node):
    """Robinson-Foulds distance between two trees, ignoring their rooting.

    Both trees are serialized to Newick, parsed with ete3 and compared with
    ``unrooted_trees=True``; only the raw RF value (the first element of
    ete3's result tuple) is returned.
    """
    newick_a, newick_b = (f"{t};" for t in (t1, t2))
    ete_a, ete_b = Tree(newick_a), Tree(newick_b)
    rf, *_ = ete_a.robinson_foulds(ete_b, unrooted_trees=True)
    return rf


def get_rf_dists_in_epoch(epoch: int):
    """Robinson-Foulds distances of an epoch's trees to its best buildable tree.

    The reference is the first successfully built tree in likelihood order
    (i.e. the highest-likelihood buildable one); every later buildable tree
    contributes one distance, in that same order.

    Returns:
        ``(dists, failed)`` -- the list of RF distances and the number of trees
        that could not be built.
    """
    built = build_trees_for_epoch(epoch)
    successes = [candidate.root for candidate in built if candidate.root is not None]
    failed = len(built) - len(successes)

    if not successes:
        return [], failed

    reference, *others = successes
    dists = [get_rf_distance(reference, other) for other in others]
    return dists, failed


rf_dists, failed = get_rf_dists_in_epoch(best_epoch)

print("RF distances:", rf_dists)
print("Failed:", failed)

# The maximum is only defined when at least one pairwise distance was computed.
if rf_dists:
    print("Max RF distance:", max(rf_dists))