# SPR

In [2]:
# Import packages.
import copy
import msprime
import numpy as np
import pandas as pd
import toyplot
import toyplot.svg
import toytree
# Print version numbers of the third-party packages used in this notebook.
# `copy` is part of the standard library and has no `__version__`, so it is
# not listed here.
print('msprime', msprime.__version__)
print('numpy', np.__version__)
print('pandas', pd.__version__)
print('toyplot', toyplot.__version__)
print('toytree', toytree.__version__)

copy 3.6.3
matplotlib 3.6.3
msprime 1.2.0
numpy 1.23.5
pandas 1.5.3


In [23]:
# Initialize a node class.
class Node:
    """
    A single node of a genealogical tree, linked to its parent and children.

    Node types:
        - 0: leaf node
        - 1: coalescent event node
        - 2: visible recombination
        - 3: hidden recombination
    """

    def __init__(self, node_id, age, node_type, parent=None, l_child=None, r_child=None):
        """
        Store the node's id, age, type and links to its neighbouring nodes.

        The branch lengths (parent_dist, l_child_dist, r_child_dist) start as
        None; call init_dists() once the links are in place to fill them in.
        """
        self.node_id, self.age, self.node_type = node_id, age, node_type
        self.parent, self.l_child, self.r_child = parent, l_child, r_child
        self.parent_dist = self.l_child_dist = self.r_child_dist = None

    def __deepcopy__(self, memo):
        """
        Return a deep copy of this node together with every node reachable
        through its parent/child links.
        """
        key = id(self)
        if key in memo:
            return memo[key]
        clone = copy.copy(self)
        # Register the clone before recursing so the parent <-> child cycles
        # resolve to this clone instead of recursing forever.
        memo[key] = clone
        for attr in ('parent', 'l_child', 'r_child'):
            setattr(clone, attr, copy.deepcopy(getattr(self, attr), memo))
        return clone

    def is_leaf(self):
        """
        Return True when this node's type is 0 (leaf), else False.
        """
        return self.node_type == 0

    def dist_to_children(self):
        """
        Set l_child_dist / r_child_dist to this node's age minus each existing
        child's age; missing children leave their distance untouched.
        """
        pairs = (('l_child', 'l_child_dist'), ('r_child', 'r_child_dist'))
        for child_attr, dist_attr in pairs:
            child = getattr(self, child_attr)
            if child is not None:
                setattr(self, dist_attr, self.age - child.age)

    def dist_to_parent(self):
        """
        Set parent_dist to the parent's age minus this node's age; does
        nothing when the node has no parent.
        """
        if self.parent is None:
            return
        self.parent_dist = self.parent.age - self.age

    def init_dists(self):
        """
        Fill in the branch lengths to the parent and then to the children.
        """
        self.dist_to_parent()
        self.dist_to_children()


# Initialize a tree class.
class Tree:
    """
    A genealogical tree made of Node objects.

    Nodes are stored in ``self.nodes`` keyed by their ``node_id``. Every
    non-leaf node is expected to have both an ``l_child`` and an ``r_child``
    (init_edges and the Newick export read both). ``self.root`` stores the
    *id* of the root node, not the Node itself.
    """

    # Initialize the tree.
    def __init__(self, left=0.0, right=1.0):
        # Presumably the genomic interval spanned by this tree -- TODO confirm.
        self.left = left
        self.right = right
        # Id (key into self.nodes) of the root node; set by find_root().
        self.root = None
        # Total branch length of the tree; set by init_edges().
        self.length = None
        # Next unused node id; set by init_next_node_id().
        self.next_node_id = None
        # Not read anywhere in this class.
        self.next_rec_id = -1
        # node_id -> Node.
        self.nodes = {}
        # edge index -> {'parent', 'child', 'upper', 'lower', 'length'};
        # rebuilt by init_edges().
        self.edges = {}
        # Sorted ages of all internal nodes; set by init_edges().
        self.upper_bounds = None
        # Node references set by code outside this class; init_next_node_id()
        # reads recoal_node.node_id.
        self.recomb_node = None
        self.recoal_node = None

    def __deepcopy__(self, memo):
        """
        Return a deep copy of the tree that shares no mutable state with the
        original.
        """
        # Avoid infinite loops.
        if id(self) in memo:
            return memo[id(self)]
        # Start from a shallow copy so scalar attributes carry over.
        copied_tree = copy.copy(self)
        memo[id(self)] = copied_tree
        # Deep copy nodes and edges.
        copied_tree.nodes = copy.deepcopy(self.nodes, memo)
        copied_tree.edges = copy.deepcopy(self.edges, memo)
        # The shallow copy would otherwise alias these with the original.
        # Because the same memo is reused, nodes already copied above map to
        # the very same copies, so recomb_node/recoal_node refer to nodes in
        # copied_tree.nodes rather than the original tree's nodes.
        copied_tree.upper_bounds = copy.deepcopy(self.upper_bounds, memo)
        copied_tree.recomb_node = copy.deepcopy(self.recomb_node, memo)
        copied_tree.recoal_node = copy.deepcopy(self.recoal_node, memo)
        return copied_tree

    # Define a method to add a node to the tree.
    def add_node(self, node):
        """
        Add a node to the tree, keyed by its node_id (replacing any existing
        node with the same id).
        """
        self.nodes[node.node_id] = node

    # Define a method to remove a node from the tree.
    def rmv_node(self, node):
        """
        Remove a node from the tree.

        Raises KeyError if no node with ``node.node_id`` is present.
        """
        del self.nodes[node.node_id]

    # Define a method to initialize node distances.
    def init_branch_lengths(self):
        """
        Initialize the parent/child branch lengths of every node in the tree.
        """
        for node in self.nodes.values():
            node.init_dists()

    # Define a method to initialize the edges on a tree.
    def init_edges(self):
        """
        Rebuild all edges of the current tree, plus the sorted internal node
        ages (upper_bounds) and the total branch length (length).

        Requires the children's branch lengths to be set (init_branch_lengths)
        beforehand, because each edge's length is read from l_child_dist /
        r_child_dist.
        """
        # Start from an empty mapping so edges from a previous call (e.g.
        # before nodes were removed) cannot linger at higher indices.
        self.edges = {}
        edge_id = 0
        total_length = 0
        upper_bounds = []
        for node in self.nodes.values():
            if node.is_leaf():
                continue
            # Each internal node's age is the upper bound of an interval.
            upper_bounds.append(node.age)
            # One edge for parent -> left child, one for parent -> right child.
            for child, dist in ((node.l_child, node.l_child_dist),
                                (node.r_child, node.r_child_dist)):
                self.edges[edge_id] = {
                    'parent': node.node_id,
                    'child': child.node_id,
                    'upper': node.age,
                    'lower': child.age,
                    'length': dist,
                }
                edge_id += 1
                total_length += dist
        # Set the tree properties.
        self.upper_bounds = sorted(upper_bounds)
        self.length = total_length

    # Define a method to find the root node.
    def find_root(self):
        """
        Store in ``self.root`` the id of the oldest node (the root).

        Raises ValueError if the tree has no nodes.
        """
        self.root = max(self.nodes, key=lambda k: self.nodes[k].age)

    # Define a method to replace an existing node's child with a new child.
    def replace_child(self, node_id, old_child, new_child):
        """
        Replace whichever child of node ``node_id`` has the same id as
        ``old_child`` with ``new_child``. Empty child slots are skipped.
        """
        node = self.nodes[node_id]
        if node.l_child is not None and node.l_child.node_id == old_child.node_id:
            node.l_child = new_child
        if node.r_child is not None and node.r_child.node_id == old_child.node_id:
            node.r_child = new_child

    # Define a method to replace an existing node on the tree with a new node.
    def replace_node(self, old_node, new_node):
        """
        Remove an old node and add a new node.
        """
        self.rmv_node(old_node)
        self.add_node(new_node)

    # Define a function to set the next node id.
    def init_next_node_id(self):
        """
        Set next_node_id to one more than the largest id among the tree's
        nodes and, if set, recoal_node.

        Raises ValueError if the tree has no nodes.
        """
        last_coal = self.recoal_node
        max_node = max(self.nodes)
        if last_coal is not None:
            self.next_node_id = max([last_coal.node_id, max_node]) + 1
        else:
            self.next_node_id = max_node + 1

    # Define a method to recursively construct the newick information.
    def _to_newick_recursive(self, node):
        """
        Recursively construct the Newick string for the subtree under ``node``.
        Leaves are labelled ``n<node_id>``.
        """
        if node.is_leaf():
            return f'n{node.node_id}'
        l_child_info = self._to_newick_recursive(node.l_child)
        r_child_info = self._to_newick_recursive(node.r_child)
        return '({}:{},{}:{})'.format(l_child_info, node.l_child_dist,
                                      r_child_info, node.r_child_dist)

    # Define a method to export a tree in newick format.
    def to_newick(self):
        """
        Convert the tree to its Newick format, starting at ``self.root``
        (find_root must have been called first).
        """
        nwk = self._to_newick_recursive(self.nodes[self.root])
        # The Newick format ends with a semicolon.
        return nwk + ';'

In [11]:
# Define a function to initialize a tree from a msprime simulation.
def init_msp_tree(k, Ne, ploidy, seed=None):
    """
    Returns a Tree object from a msprime simulation.

    k      -- Number of haploid sample chromosomes to simulate.
    Ne     -- Effective population size.
    ploidy -- Population ploidy passed to msprime (sets the coalescent time scale).
    seed   -- Random seed for reproducibility.

    Nodes with time 0 become leaves (type 0); all others become coalescent
    nodes (type 1). Each parent's two children are assigned left/right in the
    order they appear in the edge table.
    """
    # Simulate a tree under the standard coalescent.
    ts = msprime.sim_ancestry(
        samples=[msprime.SampleSet(k, ploidy=1)],
        population_size=Ne,
        ploidy=ploidy,
        random_seed=seed,
        discrete_genome=False,
    )
    # `ts.tables` may build a fresh copy of the full table collection on
    # every access (it does in tskit 0.5.x), so fetch it exactly once.
    tables = ts.tables
    # Initialize the current tree.
    tree = Tree()
    # Create one Node per msprime node: leaves at time 0, ancestors otherwise.
    for node_id, age in enumerate(tables.nodes.time):
        node_type = 0 if age == 0 else 1
        tree.add_node(Node(
            node_id=node_id, age=age, node_type=node_type,
            parent=None, l_child=None, r_child=None,
        ))
    # Group children by parent in a single pass over the edge table,
    # preserving table order within each parent.
    children = {}
    for parent, child in zip(tables.edges.parent, tables.edges.child):
        children.setdefault(int(parent), []).append(int(child))
    # Link every parent to its two children (a binary tree is expected; a
    # parent with other than two children raises ValueError here).
    for parent in sorted(children):
        left_child, right_child = children[parent]
        parent_node = tree.nodes[parent]
        # Update the parent node for the two children.
        tree.nodes[left_child].parent = parent_node
        tree.nodes[right_child].parent = parent_node
        # Update the children nodes for the parent.
        parent_node.l_child = tree.nodes[left_child]
        parent_node.r_child = tree.nodes[right_child]
    # Initialize branch lengths, edges, root and next node id (in this order:
    # edges read the branch lengths computed first).
    tree.init_branch_lengths()
    tree.init_edges()
    tree.find_root()
    tree.init_next_node_id()
    return tree



In [20]:
# Simulate a reference genealogy under the standard coalescent: five haploid
# samples from a population of size 1 with ploidy 2, on a continuous genome,
# with a fixed seed so the drawing below is reproducible.
ts = msprime.sim_ancestry(
    discrete_genome=False,
    random_seed=42,
    ploidy=2,
    population_size=1,
    samples=[msprime.SampleSet(5, ploidy=1)],
)
# Render the tree sequence as ASCII art.
print(ts.draw_text())

1.95┊     8     ┊
    ┊   ┏━┻━━┓  ┊
0.77┊   7    ┃  ┊
    ┊  ┏┻━┓  ┃  ┊
0.16┊  ┃  ┃  6  ┊
    ┊  ┃  ┃ ┏┻┓ ┊
0.09┊  5  ┃ ┃ ┃ ┊
    ┊ ┏┻┓ ┃ ┃ ┃ ┊
0.00┊ 0 1 4 2 3 ┊
    0           1



In [21]:
# Print msprime's own Newick string (with branch lengths) for the first tree
# in `ts`, as a reference to compare against Tree.to_newick() below.
print(ts.first().as_newick(include_branch_lengths=True))

((n2:0.16140311708806793,n3:0.16140311708806793):1.78920588471614894,(n4:0.77003148971445468,(n0:0.09385361654436790,n1:0.09385361654436790):0.67617787317008682):1.18057751208976214);


In [22]:
# Rebuild the same simulation (same sample count, population size, ploidy and
# seed as the msprime cell above) as a Tree and print its Newick string.
tree = init_msp_tree(seed=42, ploidy=2, Ne=1, k=5)
print(Tree.to_newick(tree))

((l2:0.16140311708806793,l3:0.16140311708806793):1.789205884716149,(l4:0.7700314897144547,(l0:0.0938536165443679,l1:0.0938536165443679):0.6761778731700868):1.1805775120897621);
