In [None]:
import os
import io
import pickle
from datetime import datetime
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from ete3 import Tree
from Bio import Phylo

In [None]:
# set to None to use the default value
# EPOCH_OVERRIDE: epoch to analyse in place of data["best_epoch"].
EPOCH_OVERRIDE = None
# TREE_IN_EPOCH_OVERRIDE: index into the epoch's trees (sorted by descending
# log likelihood) at which build_best_tree starts looking for a buildable tree.
TREE_IN_EPOCH_OVERRIDE = None

# True selects the simulated-data results directory and .phy file,
# False selects the toy-data ones.
USE_SIMULATED_DATA = False

In [None]:
# Pick the results directory to search and the matching PHYLIP input file
# in one step, depending on which data set the notebook is configured for.
latest_result_search_dir, original_phy_file = (
    ("results/cellphy_simulated_data", "data/cellphy_simulated_set.phy")
    if USE_SIMULATED_DATA
    else ("results/cellphy_toy_data", "data/cellphy_toy_set.phy")
)


def find_latest_result(search_dir=None):
    """Return the path of the most recent ``results.p`` below a directory.

    Args:
        search_dir: Root of the directory tree to search. When ``None``
            (the default) the module-level ``latest_result_search_dir``
            is used, which keeps existing zero-argument calls working.

    Returns:
        Path of the ``results.p`` file with the largest
        ``os.path.getctime`` value. On ties the file encountered first
        by ``os.walk`` wins.

    Raises:
        FileNotFoundError: If no ``results.p`` exists under ``search_dir``
            (this includes the case where ``search_dir`` itself does not
            exist, because ``os.walk`` then yields nothing).
    """
    if search_dir is None:
        search_dir = latest_result_search_dir

    candidates = [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(search_dir)
        for filename in filenames
        if filename == "results.p"
    ]

    if not candidates:
        raise FileNotFoundError("No results file found in the directory tree.")

    # Compare raw timestamps rather than naive local datetimes: local time is
    # not monotonic across DST changes, epoch seconds are.
    # NOTE: getctime is the creation time on Windows but the last metadata
    # change time on Unix, so "latest" is platform dependent.
    return max(candidates, key=os.path.getctime)


# Resolve the results file to analyse and echo its path so it is visible
# in the notebook output which run was picked.
result_path = find_latest_result()
print(result_path)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point this at results files produced by a trusted run.
with open(result_path, "rb") as f:
    data = pickle.load(f)

with open(original_phy_file, "r") as f:
    phy_file_raw = f.readlines()

# The first line of the file is skipped (presumably the PHYLIP header --
# TODO confirm); every remaining non-blank line contributes its first
# whitespace-delimited token as a label. split() without an argument copes
# with tabs, repeated spaces and leading indentation, and blank lines are
# ignored so they cannot produce empty or newline-only labels.
labels = [line.split()[0] for line in phy_file_raw[1:] if line.strip()]

In [None]:
class Node(ABC):
    """Abstract interface shared by every node of the tree."""

    @abstractmethod
    def has(self, value: str) -> bool:
        """Tell whether ``value`` occurs in the subtree rooted at this node."""

    @abstractmethod
    def find(self, value: str) -> "Node | None":
        """Return the node holding ``value`` in this subtree, or ``None``."""

    @abstractmethod
    def flat_values(self) -> set[str]:
        """Return the set of all values stored in this subtree."""


class Leaf(Node):
    """Terminal node that stores exactly one value."""

    def __init__(self, value):
        self.value = value

    def __str__(self):
        # The string form of a leaf is just its value.
        return str(self.value)

    def has(self, value):
        """Return the result of comparing the stored value with ``value``."""
        return self.value == value

    def find(self, value):
        """Return this leaf when it stores ``value``, otherwise ``None``."""
        return self if self.value == value else None

    def flat_values(self):
        """Return a one-element set containing the stored value."""
        return {self.value}


class Inner(Node):
    """Internal node joining a left and a right subtree.

    ``left_len`` and ``right_len`` are the values written after the ``:``
    of the respective child in the string form of the node.
    """

    def __init__(self, left, right, left_len, right_len):
        self.left, self.right = left, right
        self.left_len, self.right_len = left_len, right_len

    def __str__(self):
        return "({}:{},{}:{})".format(
            self.left, self.left_len, self.right, self.right_len
        )

    def has(self, value):
        """Check the left subtree first and fall back to the right one."""
        in_left = self.left.has(value)
        return in_left if in_left else self.right.has(value)

    def find(self, value):
        """Search the left subtree first and fall back to the right one."""
        hit = self.left.find(value)
        return hit if hit else self.right.find(value)

    def flat_values(self):
        """Return the union of the values of both subtrees."""
        return self.left.flat_values().union(self.right.flat_values())


@dataclass
class BuiltTree:
    """Outcome of one attempt to build a tree.

    Created by ``build_trees_for_epoch`` with either ``root`` set (the build
    succeeded) or ``error`` set (``build_tree`` raised).
    """

    # Log likelihood stored with the attempt; build_trees_for_epoch sorts on it.
    log_likelihood: float
    # Root of the built tree; left as None when the build failed.
    root: Node | None = None
    # Exception raised while building; left as None on success.
    error: Exception | None = None


def build_tree(jc, left_branches, right_branches) -> Node:
    """Build a binary tree of ``Leaf``/``Inner`` nodes from a jump chain.

    Every element of ``jc`` after the first is parsed into a set of leaf
    labels. Repeated sets are ignored. A one-label set adds a ``Leaf``; a
    larger set merges the two existing trees that contain any of its labels
    into a new ``Inner`` node. If two trees are left at the end they are
    joined by one more ``Inner`` node.

    Args:
        jc: Jump chain; the first element is skipped.
        left_branches: Indexable by merge number; value stored as the
            ``left_len`` of the ``Inner`` node created by that merge.
        right_branches: Same as ``left_branches`` for ``right_len``.

    Returns:
        The root ``Node`` of the tree (its ``str()`` is the Newick-style text).

    Raises:
        ValueError: If a merge entry does not touch exactly two existing
            trees, or if the number of trees left at the end is neither
            one nor two. ``ValueError`` is a subclass of ``Exception``, so
            callers catching ``Exception`` keep working.
    """
    # str(item)[2:-1] drops the first two and the last character of the
    # item's string form (presumably the b'...' wrapper of a bytes repr --
    # TODO confirm). Each "+"-separated token then loses its first character
    # and the rest is used as an integer index into the module-level `labels`.
    parsed: list[frozenset[str]] = [
        frozenset(labels[int(token[1:])] for token in str(item)[2:-1].split("+"))
        for item in jc[1:]
    ]

    # dict.fromkeys keeps only the first occurrence of each set, in order.
    entries: list[frozenset[str]] = list(dict.fromkeys(parsed))

    trees: list[Node] = []
    merge_index = 0

    for entry in entries:
        if len(entry) == 1:
            (label,) = entry
            trees.append(Leaf(label))
            continue

        # Split the current forest into the trees touched by this entry and
        # the ones that are left alone.
        touched: list[Node] = []
        untouched: list[Node] = []
        for tree in trees:
            bucket = touched if any(tree.has(label) for label in entry) else untouched
            bucket.append(tree)

        if len(touched) != 2:
            raise ValueError(
                f"expected exactly 2 trees to merge but found {len(touched)} "
                f"(entry={set(entry)}, "
                f"trees_to_merge={[tree.flat_values() for tree in touched]}, "
                f"trees={[tree.flat_values() for tree in trees]})"
            )

        merged = Inner(
            touched[0],
            touched[1],
            left_branches[merge_index],
            right_branches[merge_index],
        )
        # The merged node goes to the end, after all untouched trees.
        trees = untouched + [merged]
        merge_index += 1

    if len(trees) == 1:
        return trees[0]
    if len(trees) == 2:
        # Two trees remain: join them with the next pair of branch values.
        return Inner(
            trees[0], trees[1], left_branches[merge_index], right_branches[merge_index]
        )
    raise ValueError(
        f"cannot join {len(trees)} remaining trees into a single root "
        f"(trees={[tree.flat_values() for tree in trees]})"
    )


def build_trees_for_epoch(epoch: int) -> list[BuiltTree]:
    """Build every tree of ``epoch`` and return them, most likely first.

    A tree that cannot be built is still represented in the result: its
    ``BuiltTree`` carries the raised exception in ``error`` and no ``root``.
    """
    likelihoods = data["log_lik_R"][epoch]
    jump_chains = data["jump_chain_evolution"][epoch]
    # Transposed so that the first index selects the tree, like the lists above.
    left_lengths = np.transpose(data["left_branches"][epoch])
    right_lengths = np.transpose(data["right_branches"][epoch])

    def attempt(idx: int) -> BuiltTree:
        # Any failure is recorded on the result instead of aborting the epoch.
        try:
            built_root = build_tree(
                jump_chains[idx], left_lengths[idx], right_lengths[idx]
            )
        except Exception as err:
            return BuiltTree(log_likelihood=likelihoods[idx], error=err)
        return BuiltTree(log_likelihood=likelihoods[idx], root=built_root)

    built = [attempt(idx) for idx in range(len(likelihoods))]
    built.sort(key=lambda item: item.log_likelihood, reverse=True)
    return built


def build_best_tree() -> Node | None:
    """Return the root of the most likely tree that could be built.

    The epoch is ``EPOCH_OVERRIDE`` when that is not ``None``, otherwise
    ``data["best_epoch"]``. The trees of that epoch are tried in order of
    descending log likelihood, starting at index ``TREE_IN_EPOCH_OVERRIDE``
    (0 when that is ``None``). The chosen epoch and tree are printed.

    Returns:
        The root of the first tree that was built successfully, or ``None``
        if no tree could be built.
    """
    # Explicit None checks: `EPOCH_OVERRIDE or ...` would treat an override
    # of epoch 0 as "not set" and silently fall back to the best epoch.
    epoch = EPOCH_OVERRIDE if EPOCH_OVERRIDE is not None else data["best_epoch"]
    first_tree = TREE_IN_EPOCH_OVERRIDE if TREE_IN_EPOCH_OVERRIDE is not None else 0

    print("Using epoch", epoch)

    built_trees = build_trees_for_epoch(epoch)

    for i, tree in enumerate(built_trees[first_tree:]):
        if tree.root is not None:
            print(
                f"Built tree {i+first_tree} with log likelihood {tree.log_likelihood}"
            )
            return tree.root

    return None

In [None]:
def plot_tree(tree: Node, root_name: str):
    """Draw ``tree`` after re-rooting it with ``root_name`` as outgroup.

    The string form of ``tree`` is parsed as Newick by Bio.Phylo and the
    parsed tree is then drawn with matplotlib.
    """
    newick_handle = io.StringIO(str(tree))
    parsed = Phylo.read(newick_handle, "newick")  # type: ignore
    parsed.root_with_outgroup(root_name)
    Phylo.draw(parsed)  # type: ignore
    plt.show()


# Build the tree selected by the override settings / best epoch.
root = build_best_tree()

# build_best_tree returns None when no tree could be built; stop here
# instead of failing later inside the plotting code.
if root is None:
    raise Exception("No tree could be built")

# "Healthy" is the label passed to root_with_outgroup when re-rooting the plot.
plot_tree(root, "Healthy")

In [None]:
def get_rf_distance(t1: Node, t2: Node):
    """Return the first element of ete3's Robinson-Foulds comparison of two trees.

    Both nodes are converted to Newick text (their string form plus ``;``)
    and compared as unrooted trees.
    """
    first, second = (Tree(str(node) + ";") for node in (t1, t2))
    comparison = first.robinson_foulds(second, unrooted_trees=True)
    # Per the ete3 docs the first element of the result is the RF distance.
    return comparison[0]


def get_rf_distance_rel_to_first_tree_in_epoch(epoch: int):
    """Compare every built tree of ``epoch`` with the most likely built one.

    Returns:
        A ``(dists, failed)`` pair. ``dists`` holds the RF distance between
        the first successfully built tree and each later successfully built
        tree, in likelihood order. ``failed`` counts the trees that could
        not be built.
    """
    attempts = build_trees_for_epoch(epoch)
    roots = [attempt.root for attempt in attempts if attempt.root is not None]
    failed = len(attempts) - len(roots)

    # roots[1:] is empty when fewer than two trees were built, so roots[0]
    # is never evaluated in that case.
    dists = [get_rf_distance(roots[0], other) for other in roots[1:]]

    return dists, failed


# Explicit None check: `EPOCH_OVERRIDE or ...` would ignore an override of
# epoch 0 because 0 is falsy.
epoch = EPOCH_OVERRIDE if EPOCH_OVERRIDE is not None else data["best_epoch"]

rf_dists, failed = get_rf_distance_rel_to_first_tree_in_epoch(epoch)
# rf_dists is empty when fewer than two trees could be built; default=None
# reports that case instead of raising ValueError from max().
max_rf_dist = max(rf_dists, default=None)

print("RF distances:", rf_dists)
print("Max RF distance:", max_rf_dist)
print("Failed:", failed)