<a href="https://colab.research.google.com/github/Yexuan-Song/940-project/blob/main/BD_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tree generation

In [None]:
!pip install ngesh
import ngesh
import itertools
import numpy as np
import math
import random
from ete3 import Tree, TreeNode
from typing import Hashable, List, Optional
from collections import Counter
import pandas as pd

def extant(tree: Tree) -> List[TreeNode]:
    """
    Collect the leaves of `tree` that are still alive.

    The `extinct` attribute is not an ETE3 default; it is set by the
    birth-death simulation in this file to tell living lineages from
    dead ones.

    :param tree: The tree whose leaves will be inspected.
    :return: The leaves whose `extinct` attribute is exactly `False`, in the
        order returned by `tree.get_leaves()`.
    """

    alive = []
    for tip in tree.get_leaves():
        # Identity comparison on purpose: only the literal `False` counts
        # as "alive".
        if tip.extinct is False:
            alive.append(tip)
    return alive

def _gen_tree(
    birth: float,
    death: float,
    trans: list,
    name: list,
    max_len:int,
    min_leaves: Optional[int] = None,
    max_time: Optional[float] = None,
    lam: float = 0.0,
    prune: bool = False,
    labels: Optional[str] = "enum",
    seed: Optional[Hashable] = None,
) -> Optional[Tree]:
    """
    Internal function for tree generation.

    Simulates a birth-death process in which every node carries one of two
    location labels (stored both as the node `name` and as a `location`
    feature). Unlike `gen_tree()`, this function does not guarantee that a
    tree is produced: the random sampling may lead to a dead end where all
    leaves are extinct, in which case `None` is returned. No validation is
    performed on the arguments.

    :param birth: The birth rate (lambda) for the generated tree.
    :param death: The death rate (mu) for the generated tree. Must be
        explicitly set to zero for a Yule model (i.e., birth only).
    :param trans: Two probabilities. `trans[0]` is the probability that the
        children of a node located in `name[0]` are also placed in `name[0]`
        (otherwise they are placed in `name[1]`); `trans[1]` is the
        probability that the children of any other node are placed in
        `name[1]` (otherwise in `name[0]`).
    :param name: The two location names; the root is placed in `name[0]`.
    :param max_len: Accepted for signature compatibility with `gen_tree()`;
        not used by this function.
    :param min_leaves: A stopping criterion with the minimum number of
        extant leaves. Defaults to `None`.
    :param max_time: A stopping criterion with the maximum allowed time for
        evolution. Defaults to `None`.
    :param lam: The expectation of the Poisson draw that is added to the
        minimum of two children per birth event. Defaults to zero, meaning
        that every birth produces exactly two children.
    :param prune: Whether extinct leaves should be pruned from the tree
        before it is returned.
    :param labels: Accepted for signature compatibility; not used by this
        function.
    :param seed: An optional seed. When given, it seeds the `random` module
        and, derived from it, the global NumPy generator, so that a run is
        reproducible. When `None`, `random` is reseeded from system entropy
        and the NumPy generator is left untouched.
    :return: The generated tree, or `None` if all lineages went extinct or
        fewer than three extant leaves remain.
    """

    # Initialize the RNGs. Event times come from `random`, while every other
    # draw comes from `np.random`, so both must be seeded for a given `seed`
    # to reproduce the same tree.
    random.seed(seed)
    if seed is not None:
        np.random.seed(random.getrandbits(32))

    # Compute the overall event rate (birth plus death), from which the
    # random expovariate will be drawn. `birth` is normalized to [0..1] so
    # that it can be compared directly with a uniform draw to decide whether
    # an event is a birth or a death.
    event_rate = birth + death
    birth = birth / event_rate

    # Create the root: non-extinct, branch length 0.0, located in name[0].
    tree = Tree()
    tree.dist = 0.0
    tree.extinct = False
    tree.add_feature("location", name[0])
    tree.name = name[0]

    # Iterate until a stopping criterion is met (breaking the loop with a
    # tree) or all leaves go extinct (breaking the loop with `tree` as None).
    # `total_time` accumulates the simulated time since the root.
    total_time = 0.0
    while True:
        leaf_nodes = extant(tree)

        # Waiting time until the next event, given the number of extant
        # leaves and the combined event rate.
        event_time = random.expovariate(len(leaf_nodes) * event_rate)

        # If the event would take place after `max_time`, stop without
        # implementing it.
        total_time += event_time
        if max_time and total_time > max_time:
            break

        # Pick a random extant node and flag it extinct; on a birth event it
        # is replaced by its children, on a death event nothing is added.
        node = np.random.choice(leaf_nodes)
        node.extinct = True
        if np.random.random() <= birth:
            # `own` indexes the parent's location, `other` the opposite one.
            own, other = (0, 1) if node.name == name[0] else (1, 0)

            # With probability trans[own] all children stay in the parent's
            # location, otherwise they are all placed in the other one.
            if np.random.random() <= trans[own]:
                child_location = name[own]
            else:
                child_location = name[other]

            for _ in range(2 + np.random.poisson(lam)):
                child_node = Tree()
                child_node.dist = 0
                child_node.extinct = False
                child_node.name = child_location
                child_node.add_feature("location", child_location)
                node.add_child(child_node)

        # Re-extract the extant leaves (new children included) and extend
        # their branches by `event_time`, capped at `max_time` when given.
        # NOTE(review): the node selected above is already flagged extinct,
        # so its own branch is not lengthened by this `event_time` - confirm
        # that this is intended.
        leaf_nodes = extant(tree)
        for leaf in leaf_nodes:
            new_leaf_dist = leaf.dist + event_time
            leaf.dist = min(new_leaf_dist, (max_time or new_leaf_dist))

        # All lineages died before a stopping criterion was met: report the
        # failure explicitly and let the caller decide whether to retry.
        if not leaf_nodes:
            tree = None
            break

        # Check whether one or both the stopping criteria were reached.
        if min_leaves and len(leaf_nodes) >= min_leaves:
            break

        if max_time and total_time >= max_time:
            break

    # Reject trees that ended with two or fewer extant leaves.
    if tree and len(extant(tree)) <= 2:
        tree = None

    # Prune extinct leaves if requested. ete3's `prune()` takes the list of
    # nodes to keep.
    if prune and tree:
        tree.prune(extant(tree))
    return tree
    
def gen_tree(
    birth: float,
    death: float,
    trans: list,
    name: list,
    max_len:int,
    min_leaves: Optional[int] = None,
    max_time: Optional[float] = None,
    lam: float = 0.0,
    prune: bool = False,
    labels: Optional[str] = "enum",
    seed: Optional[Hashable] = None,
    min_len: int = 0,
) -> Tree:
    """
    Generate a random two-location birth-death tree of acceptable size.

    Calls `_gen_tree()` repeatedly until it returns a tree whose number of
    leaves `n` satisfies `min_len <= n < max_len`, giving up after 10000
    attempts.

    :param birth: The birth rate (lambda).
    :param death: The death rate (mu).
    :param trans: The two location-transition probabilities, forwarded to
        `_gen_tree()`.
    :param name: The two location names, forwarded to `_gen_tree()`.
    :param max_len: Exclusive upper bound on the number of leaves.
    :param min_leaves: Optional stopping criterion (minimum number of extant
        leaves), forwarded to `_gen_tree()`.
    :param max_time: Optional stopping criterion (maximum evolution time),
        forwarded to `_gen_tree()`.
    :param lam: Poisson expectation for extra children per birth event.
    :param prune: Whether extinct leaves are pruned from the result.
    :param labels: Forwarded to `_gen_tree()`.
    :param seed: Optional seed. The first attempt receives `seed` itself;
        each retry receives a distinct value derived from it, so that a
        seeded call does not repeat the same failed attempt.
    :param min_len: Inclusive lower bound on the number of leaves. Defaults
        to 0 (no lower bound). Passing `min_len=50` replaces the previous
        practice of uncommenting `len(tree) >= 50` checks in this function.
    :return: A tree satisfying the size bounds.
    :raises ValueError: If neither `min_leaves` nor `max_time` is given, as
        the simulation would then have no stopping criterion.
    :raises RuntimeError: If no valid tree is found within 10000 attempts.
    """
    MAX_ATTEMPTS = 10000

    # Without a stopping criterion `_gen_tree()` can only end by extinction
    # (returning None) or never end at all, so fail early and explicitly.
    if min_leaves is None and max_time is None:
        raise ValueError("At least one of `min_leaves` and `max_time` must be provided.")

    for attempt in range(MAX_ATTEMPTS):
        if seed is None or attempt == 0:
            attempt_seed = seed
        else:
            attempt_seed = f"{seed}_{attempt}"

        tree = _gen_tree(
            birth=birth,
            death=death,
            trans=trans,
            name=name,
            max_len=max_len,
            min_leaves=min_leaves,
            max_time=max_time,
            lam=lam,
            prune=prune,
            labels=labels,
            seed=attempt_seed,
        )

        # `len(tree)` is the number of leaves of an ete3 tree.
        if tree is not None and min_len <= len(tree) < max_len:
            return tree

    raise RuntimeError("Unable to generate a valid tree.")

Tree encoding

In [None]:
def rescale_tree(tre, target_avg_length):
    """
    Rescale every branch of a tree in place so that the mean branch length
    becomes `target_avg_length`.

    The mean is taken over the `dist` of all nodes yielded by
    `tre.traverse()`, root included.

    :param tre: ete3.Tree, tree whose branch lengths are rescaled in place
    :param target_avg_length: float, the average branch length to which we want to rescale the tree
    :return: float, resc_factor (original mean branch length divided by
        `target_avg_length`); every `dist` has been divided by it
    :raises ValueError: if the rescale factor is zero or not finite (e.g. all
        branch lengths are zero, the tree yields no nodes, or
        `target_avg_length` is zero); the tree is left untouched in that case
    """

    # Materialize the traversal once; it is needed for both the mean and the
    # in-place update.
    nodes = list(tre.traverse("levelorder"))

    all_bl_mean = np.mean([node.dist for node in nodes])

    resc_factor = all_bl_mean/target_avg_length

    # Dividing by a zero or non-finite factor would silently turn every
    # branch length into NaN/inf/0, so refuse before mutating anything.
    if resc_factor == 0 or not np.isfinite(resc_factor):
        raise ValueError(
            "Cannot rescale tree: mean branch length is %r and target average length is %r."
            % (all_bl_mean, target_avg_length)
        )

    for node in nodes:
        node.dist = node.dist/resc_factor

    return resc_factor

def add_dist_to_root(tre):
    """
    Annotate every node with its cumulative distance from the root.

    Nodes are visited in preorder, so a parent always receives its
    `dist_to_root` feature before any of its children needs it.

    :param tre: ete3.Tree, tree whose nodes receive the dist_to_root feature
    :return: void, modifies the original tree
    """

    for current in tre.traverse("preorder"):
        if current.is_root():
            depth = 0
        else:
            # Leaves and internal nodes are handled alike: parent's depth
            # plus the length of the branch leading to this node.
            depth = current.up.dist_to_root + current.dist
        current.add_feature("dist_to_root", depth)
    return None

TARGET_AVG_BL = 1
def encode_into_most_recent(tree_input, sampling_proba, location, trans):
    """Rescales a tree so that its mean branch length is TARGET_AVG_BL, then
    encodes it into a fixed-length vector.

    The encoding alternates tip and internal-node distances; a distance is
    reported as positive when the reference ancestor is located in
    `location[0]` and as its negative otherwise. The vector is zero-padded to
    399 entries (trees with fewer than 200 leaves) or 999 entries (larger
    trees), three extra columns are appended (a zero column and two copies of
    `sampling_proba`), and the columns are reordered so that internal-node
    entries come first and tip entries second.

    :param tree_input: ete3.Tree whose nodes carry a `location` feature; it is
        copied, the original is not modified
    :param sampling_proba: float, value between 0 and 1, presumed sampling probability value
    :param location: list of location names; only `location[0]` is compared
        against each node's `location`
    :param trans: list of transition probabilities, returned as a one-row
        DataFrame
    :return: tuple `(result, rescale_factor, trans)`: `result` is a one-row
        pd.DataFrame holding the reordered encoding, `rescale_factor` is the
        value returned by `rescale_tree`, and `trans` is the one-row
        pd.DataFrame built from the `trans` argument
    :raises ValueError: if the encoding is longer than the padded length
    """

    def get_not_visited_anc(leaf):
        # Climb until reaching a node that still has an unexplored child
        # (visited fewer than children-1 times), or fall off the root.
        while getattr(leaf, "visited", 0) >= len(leaf.children)-1:
            leaf = leaf.up
            if leaf is None:
                break
        return leaf

    def get_deepest_not_visited_tip(anc):
        # Iterating an ete3 node yields its leaves; pick the unvisited one
        # farthest below `anc`.
        max_dist = -1
        tip = None
        for leaf in anc:
            if leaf.visited == 0:
                distance_leaf = getattr(leaf, "dist_to_root") - getattr(anc, "dist_to_root")
                if distance_leaf > max_dist:
                    max_dist = distance_leaf
                    tip = leaf
        return tip

    def get_dist_to_root(anc):
        # The sign carries the location of `anc`.
        dist_to_root = getattr(anc, "dist_to_root")
        if anc.location == location[0]:
            return dist_to_root
        return -dist_to_root

    def get_dist_to_anc(feuille, anc):
        # The sign carries the location of the ancestor, not of the tip.
        dist_to_anc = getattr(feuille, "dist_to_root") - getattr(anc, "dist_to_root")
        if anc.location == location[0]:
            return dist_to_anc
        return -dist_to_anc

    def name_tree(tre):
        # Replace every node name by its level-order index.
        for i, node in enumerate(tre.traverse('levelorder')):
            node.name = i
        return None

    def encode(anc):
        # Iterative form of the traversal: emit the deepest unvisited tip
        # below the current ancestor, then the next ancestor that still has
        # unexplored children, until the walk climbs past the root. A loop
        # avoids nesting one generator per tip.
        while True:
            leaf = get_deepest_not_visited_tip(anc)
            yield get_dist_to_anc(leaf, anc)

            leaf.visited += 1
            anc = get_not_visited_anc(leaf)

            if anc is None:
                return
            anc.visited += 1
            yield get_dist_to_root(anc)

    def complete_coding(encoding, max_length):
        # Zero-pad the encoding in place up to `max_length` entries.
        missing = max_length - len(encoding)
        if missing < 0:
            raise ValueError(
                "Tree encoding has %d entries, more than the maximum of %d."
                % (len(encoding), max_length)
            )
        encoding.extend([0] * missing)
        return encoding

    def refactor_to_final_shape(result_v, sampling_p, maxl):
        def reshape_coor(max_length):
            # Even positions hold tips, odd positions internal nodes; the
            # three appended columns are spliced into the two groups.
            tips_coor = np.arange(0, max_length, 2)
            tips_coor = np.insert(tips_coor, -1, max_length + 1)

            int_nodes_coor = np.arange(1, max_length - 1, 2)
            int_nodes_coor = np.insert(int_nodes_coor, 0, max_length)
            int_nodes_coor = np.insert(int_nodes_coor, -1, max_length + 2)

            order_coor = np.append(int_nodes_coor, tips_coor)

            return order_coor

        reshape_coordinates = reshape_coor(maxl)

        # add a zero column and the sampling probability (twice); the labels
        # are 1000/'1001'/'1002' for maxl == 999 and 400/'401'/'402' for
        # maxl == 399
        result_v.loc[:, maxl + 1] = 0
        result_v[str(maxl + 2)] = sampling_p
        result_v[str(maxl + 3)] = sampling_p

        # reorder the columns
        result_v = result_v.iloc[:,reshape_coordinates]

        return result_v

    # local copy of input tree
    tree = tree_input.copy()

    # `len(tree)` is the number of leaves
    if len(tree) < 200:
        max_len = 399
    else:
        max_len = 999

    # rescale branch lengths
    rescale_factor = rescale_tree(tree, target_avg_length=TARGET_AVG_BL)

    # set all nodes to non visited:
    for node in tree.traverse():
        setattr(node, "visited", 0)

    name_tree(tree)

    add_dist_to_root(tree)

    tree_embedding = list(encode(tree))

    tree_embedding = complete_coding(tree_embedding, max_len)

    # one row, one column per encoding entry
    result = pd.DataFrame(tree_embedding, columns=[0])
    result = result.T

    # one row, one column per transition probability
    trans = pd.DataFrame(trans,columns=[0])
    trans = trans.T

    # refactor to final shape: add sampling probability, put features in order
    result = refactor_to_final_shape(result, sampling_proba, max_len)

    return result, rescale_factor,trans

Data generation (50000 trees)

In [None]:
average=1

# Samples are collected in lists and stacked once at the end; stacking inside
# the loop would copy the whole training set on every iteration.
X_samples = []
Y_samples = []

# Each encoded tree is a single row of 402 values (trees are capped below 200
# leaves, so the 399-entry encoding plus three extra columns is used), which
# is reshaped to (1, 201, 2).
# NOTE(review): encode_into_most_recent orders the row as all internal-node
# columns followed by all tip columns, so this default (C-order) reshape
# pairs consecutive entries of the row rather than (internal, tip) pairs -
# confirm this is the intended channel layout.

# First sample: fixed parameters. `gen_tree` takes (birth, death, trans,
# name, ...); it has no latency argument.
t = gen_tree(1.0,0.5,[0.7,0.7],["A","B"],max_time = 5.0,max_len=200)
r, rescale_factor, trans = encode_into_most_recent(t,sampling_proba=0.9,location = ["A","B"],trans = [0.7,0.7])
X_samples.append(np.reshape(r.values,(1,201,2)))
Y_samples.append(trans.values)

# Remaining samples: random transition probabilities and rates.
for i in range(49999):
  print(i)
  tr = [np.random.random(),np.random.random()]
  t = gen_tree(birth = np.random.uniform(0.9,1.0),death = np.random.uniform(0.5,0.6),trans=tr,name=["A","B"],max_time = 7.0,max_len=200)
  r, rescale_factor, trans = encode_into_most_recent(t,sampling_proba=0.9,location = ["A","B"],trans=tr)
  X = np.reshape(r.values,(1,201,2))
  Y = trans.values
  X_samples.append(X)
  Y_samples.append(Y)

# X_train: one (201, 2) input per tree; Y_train: one row of two transition
# probabilities per tree.
X_train = np.vstack(X_samples)
Y_train = np.vstack(Y_samples)

  

Check shape

In [None]:
# Sanity check on the training inputs: one initial sample plus 49999 generated
# in the loop, each reshaped to (1, 201, 2), so the expected shape is
# (50000, 201, 2).
X_train.shape

Neural network

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential 
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv1D, MaxPool1D, GlobalAveragePooling1D, Reshape,MaxPool2D,GlobalAveragePooling2D
from tensorflow.keras.metrics import MeanAbsolutePercentageError

# Convolutional feature extractor over the (201, 2) tree encoding, followed
# by a small fully connected head that outputs the two transition
# probabilities.
model = Sequential([
    # CNN
    Conv1D(50, 3, input_shape=(201, 2), activation='relu'),
    Conv1D(50, 10, activation='relu'),
    MaxPool1D(pool_size=10),
    Conv1D(80, 10, activation='relu'),
    GlobalAveragePooling1D(),
    # FFNN
    Dense(64, activation='elu'),
    Dense(32, activation='elu'),
    Dense(8, activation='elu'),
    Dense(2, activation='linear'),
])

# 'mse' is also tracked as a metric because the plotting cell reads
# his.history['mse'].
model.compile(loss='mse', optimizer="adam", metrics=['mse'])
model.summary()

his = model.fit(X_train, Y_train, batch_size=50, epochs=50)

Plot model loss and model mse

In [None]:
import matplotlib.pyplot as plt

# One figure per tracked quantity of the training history; the history key
# doubles as the y-axis label.
for history_key, plot_title in (('loss', 'model loss'), ('mse', 'model mse')):
    plt.plot(his.history[history_key])
    plt.title(plot_title)
    plt.ylabel(history_key)
    plt.xlabel('epoch')
    plt.legend(['train'], loc='upper left')
    plt.show()

Generate 100 test trees with the same transmission probabilities. To ensure that the tree sizes are between 50 and 200, uncomment lines 240, 245 and 258 in the tree generation code (the `len(tree) >= 50` conditions in `gen_tree`).

In [None]:
average=1

# All 100 test trees share the same pair of transmission probabilities.
tr = [0.3,0.5]
test_inputs = []
test_targets = []

for i in range(100):
  t = gen_tree(1.0,0.5,tr,["A","B"],max_time = 10.0,max_len=200)
  print(len(t))
  r, rescale_factor, trans = encode_into_most_recent(t,sampling_proba=0.9,location = ["A","B"],trans = tr)
  # Same layout as the training data: one (201, 2) input and one row of two
  # targets per tree.
  X = np.reshape(r.values,(1,201,2))
  Y = trans.values
  test_inputs.append(X)
  test_targets.append(Y)

X_1 = np.vstack(test_inputs)
Y_1 = np.vstack(test_targets)

Find mean prediction

In [None]:
# `data` was never assigned anywhere in the notebook. It is indexed as one
# row per test tree with two columns, which matches the model's predictions
# for the test inputs.
data = model.predict(X_1)

# Average over every prediction, so that the number of summed rows and the
# divisor always agree.
e_1 = data[:, 0].sum()
print(e_1/len(data))

e_2 = data[:, 1].sum()
print(e_2/len(data))

Produce bar plots

In [None]:
# Count how many predictions fall inside increasingly wide windows around the
# transmission probabilities used for the test trees (0.3 and 0.5).
# `data` is expected to hold one prediction row per test tree.
# Each window is (low_0, high_0, low_1, high_1), bounds exclusive.
windows = [
    (0.25, 0.35, 0.45, 0.55),
    (0.2, 0.4, 0.4, 0.6),
    (0.15, 0.45, 0.35, 0.65),
    (0.1, 0.5, 0.3, 0.7),
]

counts = []
for low_0, high_0, low_1, high_1 in windows:
  # Iterate over every row rather than a hard-coded range, so the last
  # prediction is not skipped.
  hits = sum(
      1 for row in data
      if low_0 < row[0] < high_0 and low_1 < row[1] < high_1
  )
  print(hits)
  counts.append(hits)

e_3, e_4, e_5, e_6 = counts


fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
interval = ['5%','10%','15%','20%']
number = [e_3,e_4,e_5,e_6]
ax.bar(interval,number)
plt.show()