# Tutorial: Exploring coalescent variation across the genome

## 1. Introduction

**1.1 Background**

As large-scale genomic datasets become easier to generate, our ability to retrace evolutionary history from DNA sequences has never been greater. Traditional phylogenetic studies relied on one or a few molecular markers, but it is now commonplace to survey phylogenetic patterns across hundreds to thousands of genetic loci. This near genome-wide view of evolutionary history has major consequences for the field of phylogenetics, because phylogenetic history often varies across the genome. In other words, a phylogenetic tree for one region of the genome may differ substantially in topology from a tree based on genetic data from a different region. A major objective in evolutionary and phylogenetic research is to understand the processes that lead to such discordance and to distinguish between neutral and non-neutral evolutionary dynamics.

**Empirical example demonstrating how the frequency of competing phylogenies can vary across chromosomes and between chromosomes (e.g. sex vs. autosomes)**

![example.png](img/example.png)

Coalescent theory provides a population genetic framework that describes the expected differences in inheritance patterns between genomic loci based on neutral processes alone. Coalescent models describe the probability that any two individuals in a population share a common ancestor at a given time in the past, and can easily be extended to multiple populations or species. Explicit demographic parameters, which we will explore in detail, change the probability of a coalescent event (i.e. two lineages sharing a common ancestor in the previous generation). Coalescent models can therefore be used to generate null models under neutral evolution, which can then be compared to empirical observations (i.e. phylogenomic data). In this computer practical, we use a coalescent-based simulator to simulate tree sequences across a genomic region and demonstrate how the expected variation in coalescent histories changes depending on several demographic properties.

**1.2 Practical overview**

This computer practical is written as a Jupyter Notebook (NB) (see the GitHub [README](https://github.com/MozesBlom/tutorials/tree/main/2023_Phy_Eco_Evol) for more details). In short, Jupyter NBs contain text cells in Markdown format (such as the current cell) and code cells, most often in Python 3. The aim of this practical is not to provide an exhaustive introduction to Python, Markdown or Jupyter, but to introduce phylogenomic inference: how to account for genome-wide variation in inheritance histories (coalescent variation) and what may cause such variation. Therefore, all code needed to run this practical is already in place, and we will only make minor changes to the code to demonstrate the phylogenetic consequences of changing specific model parameters.

**1.3 Requirements**

To run this practical, the following software is needed:
* Python 3+
    * [msprime](https://msprime.readthedocs.io/en/stable/)

## 2. Getting Started

**2.1 Importing modules**

The Python modules listed under *1.3 Requirements* need to be imported before proceeding with the rest of the NB. If you are running on a Google Colab instance (see [GitHub](https://github.com/MozesBlom/tutorials/tree/main/2023_Phy_Eco_Evol) for further details), the modules always need to be installed first. If you are running on your personal desktop, in JupyterLab for example, this may not be needed if a conda environment with the correct modules installed is already loaded and [selected](https://github.com/jupyterlab/jupyterlab-desktop/blob/master/user-guide.md). NOTE: if you go the latter route, make sure that JupyterLab itself and all dependencies are installed in the relevant conda environment. Otherwise, the environment will not appear in the JupyterLab Desktop environment list and cannot be selected. In all other use cases, first download the required Python modules:

In [None]:
# Download the necessary Python modules. If already present in Conda environment, then skip this step
%pip install msprime 

In [None]:
# Import the modules
import msprime
from IPython.display import SVG

## 3. Simulating ancestry with msprime

msprime is a backward-in-time coalescent simulator that can be used to simulate trees and sequences under explicit demographic scenarios. Its simulations are governed by the probability that any two individuals (their haplotypes, to be precise) share a common ancestor at a given time in the past. In a single population under neutral conditions, this probability is primarily determined by the effective population size. Between populations, the probability of coalescence is further determined by the rate of migration between populations and the time of divergence. Although msprime is typically used in a population context, the same principles apply to phylogenetic history between species, which makes it a suitable tool for exploring the determinants of coalescent variation across genomes.
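To make the backward-in-time idea concrete, here is a minimal sketch of a single-population coalescent that uses only the Python standard library. It is not how msprime is implemented, and the function name `simulate_coalescence_times` and its arguments are made up for illustration. It assumes a constant diploid population of size `Ne`, so that `k` lineages wait an exponentially distributed time with rate `k(k-1)/2` divided by `2 * Ne` (per generation) until the next pair coalesces.

```python
import random

def simulate_coalescence_times(n_lineages, Ne, rng):
    """Return the times (in generations) of successive coalescent events.

    Hypothetical helper for illustration: standard coalescent in a single,
    constant-size diploid population (2 * Ne haplotypes).
    """
    times = []
    t = 0.0
    k = n_lineages
    while k > 1:
        pairs = k * (k - 1) / 2      # number of lineage pairs that could coalesce
        rate = pairs / (2 * Ne)      # coalescence rate per generation
        t += rng.expovariate(rate)   # waiting time until the next event
        times.append(t)
        k -= 1                       # two lineages merge into one
    return times

rng = random.Random(1)
replicates = 20000
tmrca = [simulate_coalescence_times(4, 100, rng)[-1] for _ in range(replicates)]
mean_tmrca = sum(tmrca) / replicates
print(f"Mean TMRCA over {replicates} replicates: {mean_tmrca:.1f} generations")
```

For four lineages and `Ne = 100`, the theoretical expectation of the time to the most recent common ancestor is `4 * Ne * (1 - 1/4) = 300` generations, so the printed mean should land close to that value.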

### 3.1 The Tree Sequence class

In [None]:
# Let's run a simple simulation using msprime
ts = msprime.sim_ancestry(samples=2, random_seed=1)
SVG(ts.draw_svg())

**QUESTION**: We asked for a simulation with 2 individuals. Why does the resulting phylogeny/tree include four tips?

**ANSWER**: ...

In [None]:
# Let's first dive deeper into the tree sequence that msprime generated
ts

msprime didn't simply generate an image of a phylogeny, but instead produced a summary of the demographic simulation, which in this case describes a single phylogeny and all associated meta-information: for example, the number of individuals (`Individuals`), branches (`Edges`), migration events (`Migrations`), etc. Importantly, note that in this specific model we simulated a tree sequence for a single site only (`Sequence Length` = 1). `Sequence Length` can be interpreted as the number of contiguous base pairs on a single contig or chromosome. Let's run a simulation for a longer sequence, say 10000 base pairs (bp).

In [None]:
ts = msprime.sim_ancestry(
                samples=2,
                random_seed=1,
                sequence_length=10000)
SVG(ts.draw_svg())

In [None]:
ts

Besides the `Sequence Length`, nothing really changed: we still see only one tree and the topology is exactly the same. This is because we haven't simulated any recombination events yet. **Without recombination, all sites along a genomic stretch share the same evolutionary history, and we therefore do not expect coalescent history to vary along the sequence**! Recombination can be included in our simulation by specifying a `recombination_rate`: the probability of a recombination event per genomic unit (bp) per generation.

In [None]:
ts = msprime.sim_ancestry(
                samples=2,
                random_seed=1,
                sequence_length=10000,
                recombination_rate = 1e-4)
SVG(ts.draw_svg())

This is already more biologically meaningful: recombination has introduced breakpoints along the chromosome, and the coalescent history on each side of a breakpoint is different. Note that, as you can see, this does not always mean that the topology (the order of diversification) changes. In some instances only the node heights differ, such as the time to the most recent common ancestor (which may be difficult to see in some trees). Now let's have another look at the `Tree Sequence` Python class:

In [None]:
ts

By invoking recombination, the numbers of `Trees`, `Edges` and `Nodes` have all changed, but all this information is still captured in a single `Tree Sequence` object. The `Tree Sequence` class is a compact data structure that greatly reduces the computational footprint when studying phylogenies. For example, consider that we only simulated a 10000 bp sequence here, but many genomes are billions of bp in size and studies can now include thousands of individuals! The `Tree Sequence` class doesn't store a separate phylogeny for every site in a chromosome. Instead, it notes the (recombination) breakpoints along the sequence where trees differ, and records only the branches and node heights that change between breakpoints.
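The idea of storing trees per interval rather than per site can be sketched with a toy lookup table. This is far simpler than the actual tree sequence encoding (which additionally shares branches between neighbouring trees), and the breakpoints, Newick strings and the function `tree_at` below are invented for illustration.

```python
from bisect import bisect_right

# Hypothetical toy data: three local trees (Newick strings) along a 10000 bp
# sequence, each valid from its left breakpoint up to the next breakpoint.
breakpoints = [0, 2500, 7200]   # left coordinate of each interval
local_trees = [
    "((A,B),(C,D));",
    "(((A,B),C),D);",
    "((A,B),(C,D));",
]
sequence_length = 10000

def tree_at(position):
    """Return the local tree covering a 0-based position."""
    if not 0 <= position < sequence_length:
        raise ValueError("position outside the simulated sequence")
    # Index of the last breakpoint that is <= position
    index = bisect_right(breakpoints, position) - 1
    return local_trees[index]

print(tree_at(100))    # falls in the first interval
print(tree_at(2500))   # a breakpoint starts a new interval
print(len(local_trees), "trees stored instead of", sequence_length)
```

Looking up the tree for any site then only requires a binary search over the breakpoints, no matter how long the sequence is.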

Besides the computational advantages of the `Tree Sequence` class, there is also an important biological interpretation. Evolutionary history along a chromosome (coalescent history) only varies between sites if a **recombination** event has occurred between them. Otherwise, adjacent sites in a given sequence share the same evolutionary history. Most phylogenetic methods are designed to reconstruct a single evolutionary history for a given sequence alignment and thus assume that the aligned segment is free from recombination.
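As a back-of-the-envelope reference for your simulations, the expected number of recombination events that fall on the local trees of a sample can be approximated under the standard neutral coalescent. At any site the expected total branch length of the tree is `4 * Ne` times the harmonic sum `1 + 1/2 + ... + 1/(n-1)` generations, and recombination hits those branches at the given rate per bp per generation. The helper name and the input values below (four haplotypes, `Ne = 1`) are assumptions chosen for illustration, and the number of distinct trees that a simulator reports can differ from this expectation, since not every event changes the tree.

```python
def expected_recombination_events(n_haplotypes, Ne, recombination_rate, sequence_length):
    """Rough expectation of recombination events in the history of a sample."""
    # Population-scaled recombination rate for the whole sequence
    rho = 4 * Ne * recombination_rate * sequence_length
    # Harmonic sum 1 + 1/2 + ... + 1/(n - 1)
    harmonic = sum(1 / i for i in range(1, n_haplotypes))
    return rho * harmonic

for rate in (1e-5, 1e-4, 1e-3):
    events = expected_recombination_events(4, 1, rate, 10000)
    print(f"recombination_rate={rate:g}: about {events:.1f} events expected")
```

The expectation scales linearly with the recombination rate, the sequence length and the effective population size.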

**EXERCISE**: Explore the role of recombination for modeling coalescent variation by changing the frequency of recombination in the code below. What happens when you increase the frequency and what happens if you decrease the frequency?

**ANSWER**: ...

In [None]:
ts = msprime.sim_ancestry(
                samples=2,
                random_seed=1,
                sequence_length=10000,
                recombination_rate = 1e-4)

print("The total number of trees simulated by changing the recombination frequency: " + str(ts.num_trees))

### 3.2 Effective population size

In the first section, we simulated the coalescent history for two individuals in a single population of constant size. However, as mentioned before, population size is one of the major determinants of the probability of coalescence. Consider the following:

Assume an isolated island population of constant size (100 individuals) with random mating: What is the probability that two individuals share a common ancestor in the previous generation? What if it is a much bigger island with a larger constant population size over time (n = 1000)?
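One way to reason about this is a Wright-Fisher sketch in which each of two sampled haplotypes picks its parental haplotype uniformly at random from the `2N` haplotypes of the previous generation. The code below is an illustration under that assumption (diploid, random mating, constant size); both function names are hypothetical.

```python
import random

def coalescence_probability(n_individuals):
    """Chance that two haplotypes pick the same parental haplotype one generation back."""
    return 1 / (2 * n_individuals)

def estimate_by_sampling(n_individuals, trials, rng):
    """Monte Carlo check: draw a random parent haplotype for each of two lineages."""
    n_haplotypes = 2 * n_individuals
    hits = sum(
        rng.randrange(n_haplotypes) == rng.randrange(n_haplotypes)
        for _ in range(trials)
    )
    return hits / trials

rng = random.Random(42)
for n in (100, 1000):
    p = coalescence_probability(n)
    estimate = estimate_by_sampling(n, 200_000, rng)
    print(f"N={n}: exact {p:.4f}, simulated {estimate:.4f}, mean wait {1 / p:.0f} generations")
```

Because the waiting time until coalescence is geometric with this per-generation probability, its mean is `2N` generations for a pair of haplotypes.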

Let's explore how differences in (effective) population size influence our coalescent history using msprime.

In [None]:
# Let's run two simulations and only vary Ne using a new parameter
ts1 = msprime.sim_ancestry(
                samples=2,
                random_seed=1,
                sequence_length=10000,
                population_size=10,
                recombination_rate = 1e-5)

ts2 = msprime.sim_ancestry(
                samples=2,
                random_seed=1,
                sequence_length=10000,
                population_size=100,
                recombination_rate = 1e-5)

In [None]:
SVG(ts1.draw_svg())

In [None]:
SVG(ts2.draw_svg())

Comparing the trees between both simulations isn't straightforward because the scale on the y-axis isn't the same. When invoking the `population_size` argument the branch lengths are automatically scaled to units of generations. We can then compare the differences in coalescent times by looking at the first tree for example.

In [None]:
tree1 = ts1.first()
print("Total branch length:", tree1.total_branch_length)
print("Time at root:", ts1.tables.nodes.time[tree1.root])

In [None]:
tree2 = ts2.first()
print("Total branch length:", tree2.total_branch_length)
print("Time at root:", ts2.tables.nodes.time[tree2.root])

**QUESTION**: What has happened? Under what circumstances do you expect a longer wait time until the Most Recent Common Ancestor of all individuals?

**ANSWER**: ...

### 3.3 A multi-population/species model

Until now, we have only modeled coalescent variation across a chromosome within a single population/species. We have seen that, even under neutral circumstances, phylogenies can vary in both divergence times and topology. However, when employing phylogenetic approaches, we are often interested in describing the relationships between species. We can set up similar models in msprime and use them to demonstrate how gene trees (phylogenies for a given gene or genomic region) can differ from the underlying species tree.

To run a multi-population model in msprime, we first need to describe the demographic model. Let's simulate the following:

- 4 species
- 2 individuals for each species
- constant population size over time
- no migration/gene flow between species
- equal time intervals between successive splits

In [None]:
# To do so in msprime, we need to describe a demography object which then can be passed to `sim_ancestry`
Ne = 100
dem = msprime.Demography()
dem.add_population(name="human", initial_size=Ne)
dem.add_population(name="chimp", initial_size=Ne)
dem.add_population(name="gorilla", initial_size=Ne)
dem.add_population(name="mouse", initial_size=Ne)

# Table with basic demography model of extant species
dem

What we have created is a demographic model of the extant species only. A simulation based on this model could never finish, since a simulation only stops when ALL sampled lineages across all species have coalesced. In this model that can never happen, because no coalescent events are possible between species. So let's specify the divergence times and ancestral populations:

In [None]:
# Add ancestral populations and the divergence times when they split off
div_t = 1000
dem.add_population(name="AncestralPopulation1", initial_size=Ne)
dem.add_population_split(time=div_t, derived=["human","chimp"], ancestral="AncestralPopulation1")
dem.add_population(name="AncestralPopulation2", initial_size=Ne)
dem.add_population_split(time=div_t*2, derived=["AncestralPopulation1","gorilla"], ancestral="AncestralPopulation2")
dem.add_population(name="AncestralPopulation3", initial_size=Ne)
dem.add_population_split(time=div_t*3, derived=["AncestralPopulation2","mouse"], ancestral="AncestralPopulation3")
dem

Note that the ancestral populations have now been included and that we added three events, each corresponding to a population split at a given divergence time.

We can now run a simulation for this specific demographic model. Let's first simulate a single tree:
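Read backwards in time, each population split moves all lineages from the derived populations into the ancestral population at the specified time. The toy function below mimics that bookkeeping for the three splits defined above. It is not part of msprime: the `splits` list and `population_at` are hypothetical names used only to illustrate how a lineage changes population as we go further back.

```python
# Hypothetical, simplified description of the splits defined above:
# (time in generations, derived populations, ancestral population).
div_t = 1000
splits = [
    (div_t, ["human", "chimp"], "AncestralPopulation1"),
    (div_t * 2, ["AncestralPopulation1", "gorilla"], "AncestralPopulation2"),
    (div_t * 3, ["AncestralPopulation2", "mouse"], "AncestralPopulation3"),
]

def population_at(sample_population, time):
    """Return the population holding a lineage sampled in `sample_population`
    at `time` generations before the present."""
    current = sample_population
    # Walk through the splits from the most recent to the oldest
    for split_time, derived, ancestral in sorted(splits, key=lambda s: s[0]):
        if time >= split_time and current in derived:
            current = ancestral
    return current

for t in (0, 1500, 2500, 3500):
    print(t, population_at("human", t), population_at("mouse", t))
```

Lineages sampled from different species can only coalesce once they sit in the same ancestral population, which is why the divergence times set a lower bound on between-species coalescence times.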

In [None]:
sample_size = 2

ts = msprime.sim_ancestry(
                samples={"human" : sample_size, "chimp" : sample_size, "gorilla" : sample_size, "mouse" : sample_size, "AncestralPopulation1" : 0, "AncestralPopulation2" : 0, "AncestralPopulation3" : 0},
                demography=dem,
                random_seed=1)
display(SVG(ts.first().draw(width=500, height=400)))

Since we are now working with a larger number of species and individuals, the phylogeny has expanded, but it is not informative with regard to which tip belongs to which species. However, each node now carries a population/species label, which can be used to colour the nodes by species:

In [None]:
colour_map = {0:"red", 1:"blue", 2:"green", 3:"orange", 4:"purple", 5:"black", 6:"yellow", 7:"pink", 8:"brown", 9:"gray"}
node_colours = {u.id: colour_map[u.population] for u in ts.nodes()}
# The code below will only work in a Jupyter notebook with SVG output enabled.
display(SVG(ts.first().draw(node_colours=node_colours, width=500, height=400)))

Under the present model specifications, we have simulated a single phylogeny where we sampled two individuals per species and reconstructed the coalescent history of a single locus. In this simulation, all haplotypes belonging to the same species (coloured nodes here) share an MRCA and the gene tree is identical to the species tree. However, we have already seen that effective population size has a strong effect on the probability that two haplotypes share a common ancestor in the previous generation, and here we will demonstrate that this can lead to mismatches between gene trees and the species tree.
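For intuition, the simplest case has a closed-form answer: with three species and one haplotype sampled per species, a gene tree can only disagree with the species tree if the two sister lineages fail to coalesce along the internal branch of the species tree. The sketch below computes that probability. It is a simplification of the four-species, two-samples-per-species model used in this practical, and the function name is made up for illustration.

```python
import math

def discordance_probability(Ne, internal_branch_generations):
    """Probability that a three-species gene tree (one haplotype per species)
    differs from the species tree under incomplete lineage sorting alone."""
    # Probability that the two sister lineages fail to coalesce within the
    # internal branch, with time measured in units of 2 * Ne generations.
    p_no_coalescence = math.exp(-internal_branch_generations / (2 * Ne))
    # In the ancestral population all three pairings are equally likely,
    # and two of the three disagree with the species tree.
    return (2 / 3) * p_no_coalescence

for Ne in (100, 1000, 10000):
    p = discordance_probability(Ne, internal_branch_generations=1000)
    print(f"Ne={Ne}: P(discordant gene tree) = {p:.3f}")
```

The probability depends only on the ratio between the internal branch length and `2 * Ne`, so it approaches 2/3 when populations are large relative to the time between splits.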

**EXERCISE**: Increase the effective population size in the following model and simulate ancestry. Try a few values, what happens?

**ANSWER**: ...

In [None]:
# Model parameters
Ne = 10000
div_t = 1000
sample_size = 2

# ------------------ #

dem = msprime.Demography()
dem.add_population(name="human", initial_size=Ne)
dem.add_population(name="chimp", initial_size=Ne)
dem.add_population(name="gorilla", initial_size=Ne)
dem.add_population(name="mouse", initial_size=Ne)

dem.add_population(name="AncestralPopulation1", initial_size=Ne)
dem.add_population_split(time=div_t, derived=["human","chimp"], ancestral="AncestralPopulation1")
dem.add_population(name="AncestralPopulation2", initial_size=Ne)
dem.add_population_split(time=div_t*2, derived=["AncestralPopulation1","gorilla"], ancestral="AncestralPopulation2")
dem.add_population(name="AncestralPopulation3", initial_size=Ne)
dem.add_population_split(time=div_t*3, derived=["AncestralPopulation2","mouse"], ancestral="AncestralPopulation3")

ts = msprime.sim_ancestry(
                samples={"human" : sample_size, "chimp" : sample_size, "gorilla" : sample_size, "mouse" : sample_size, "AncestralPopulation1" : 0, "AncestralPopulation2" : 0, "AncestralPopulation3" : 0},
                demography=dem,
                random_seed=1)

colour_map = {0:"red", 1:"blue", 2:"green", 3:"orange", 4:"purple", 5:"black", 6:"yellow", 7:"pink", 8:"brown", 9:"gray"}
node_colours = {u.id: colour_map[u.population] for u in ts.nodes()}
# The code below will only work in a Jupyter notebook with SVG output enabled.
display(SVG(ts.first().draw(node_colours=node_colours, width=500, height=400)))

The process you have simulated above is commonly known as **deep coalescence** or **incomplete lineage sorting** (ILS). In such cases, gene trees do not follow the species tree, because a coalescence event between two haplotypes of the same species predates a species split. In the above example you have only changed the effective population size, but the probability of ILS also depends on other parameters.

**QUESTION**: How could we change the model above, without changing the effective population size parameter, so that all haplotypes belonging to the same species share an MRCA?

In [None]:
# Model parameters
Ne = 10000
div_t = 1000
sample_size = 2

# ------------------ #

dem = msprime.Demography()
dem.add_population(name="human", initial_size=Ne)
dem.add_population(name="chimp", initial_size=Ne)
dem.add_population(name="gorilla", initial_size=Ne)
dem.add_population(name="mouse", initial_size=Ne)

dem.add_population(name="AncestralPopulation1", initial_size=Ne)
dem.add_population_split(time=div_t, derived=["human","chimp"], ancestral="AncestralPopulation1")
dem.add_population(name="AncestralPopulation2", initial_size=Ne)
dem.add_population_split(time=div_t*2, derived=["AncestralPopulation1","gorilla"], ancestral="AncestralPopulation2")
dem.add_population(name="AncestralPopulation3", initial_size=Ne)
dem.add_population_split(time=div_t*3, derived=["AncestralPopulation2","mouse"], ancestral="AncestralPopulation3")

ts = msprime.sim_ancestry(
                samples={"human" : sample_size, "chimp" : sample_size, "gorilla" : sample_size, "mouse" : sample_size, "AncestralPopulation1" : 0, "AncestralPopulation2" : 0, "AncestralPopulation3" : 0},
                demography=dem,
                random_seed=1)

colour_map = {0:"red", 1:"blue", 2:"green", 3:"orange", 4:"purple", 5:"black", 6:"yellow", 7:"pink", 8:"brown", 9:"gray"}
node_colours = {u.id: colour_map[u.population] for u in ts.nodes()}
# The code below will only work in a Jupyter notebook with SVG output enabled.
display(SVG(ts.first().draw(node_colours=node_colours, width=500, height=400)))

Until now we have simulated only a single phylogeny, because we did not specify a recombination rate. Let's include recombination and investigate how coalescent histories can vary across a chromosome. To reduce simulation time, we decrease the effective population size and the divergence time.

In [None]:
# Model parameters
Ne = 10
div_t = 10
sample_size = 1

# ------------------ #

dem = msprime.Demography()
dem.add_population(name="human", initial_size=Ne)
dem.add_population(name="chimp", initial_size=Ne)
dem.add_population(name="gorilla", initial_size=Ne)
dem.add_population(name="mouse", initial_size=Ne)

dem.add_population(name="AncestralPopulation1", initial_size=Ne)
dem.add_population_split(time=div_t, derived=["human","chimp"], ancestral="AncestralPopulation1")
dem.add_population(name="AncestralPopulation2", initial_size=Ne)
dem.add_population_split(time=div_t*2, derived=["AncestralPopulation1","gorilla"], ancestral="AncestralPopulation2")
dem.add_population(name="AncestralPopulation3", initial_size=Ne)
dem.add_population_split(time=div_t*3, derived=["AncestralPopulation2","mouse"], ancestral="AncestralPopulation3")

ts = msprime.sim_ancestry(
                samples={"human" : sample_size, "chimp" : sample_size, "gorilla" : sample_size, "mouse" : sample_size, "AncestralPopulation1" : 0, "AncestralPopulation2" : 0, "AncestralPopulation3" : 0},
                demography=dem,
                sequence_length=10000,
                recombination_rate = 1e-4,
                random_seed=1)

colour_map = {0:"red", 1:"blue", 2:"green", 3:"orange", 4:"purple", 5:"black", 6:"yellow", 7:"pink", 8:"brown", 9:"gray"}
node_colours = {u.id: colour_map[u.population] for u in ts.nodes()}
# The code below will only work in a Jupyter notebook with SVG output enabled.
SVG(ts.draw_svg())

In [None]:
# If you want to look at the first tree only, coloured by species association:
display(SVG(ts.first().draw(node_colours=node_colours, width=500, height=400)))

**SUMMARY**: With the above simulations, we have demonstrated that effective population size and divergence time are important determinants of the probability of incomplete lineage sorting. We have also seen that coalescent history can vary across a chromosome, since recombination breaks it into distinct units that can each have a different evolutionary history, even under neutral circumstances (i.e. without selection, introgression, etc.). Modern-day phylogenomic studies need to account for this variation in coalescent history and determine which processes have led to the discordance. For example, have a look at the distribution of topologies in the empirical example introduced above:

**Empirical example demonstrating how the frequency of competing phylogenies can vary across chromosomes and between chromosomes (e.g. sex vs. autosomes)**

![example.png](img/example.png)

When comparing autosomes with the sex chromosomes, there is much stronger support for a single topology (i.e. less ILS) on the sex chromosome, which is the Z-chromosome in birds. Could you come up with a neutral explanation why that may be? Hint: Birds are diploid organisms and the Z-chromosome is equivalent to the X-chromosome in humans.

## 4. BONUS EXERCISE: Adding introgression

Genome-wide variation in coalescent histories can be driven solely by a neutral process such as incomplete lineage sorting, but evolutionary studies increasingly report hybridization and introgression between closely related taxa across the Tree of Life. Introgression can substantially skew the phylogenetic signal and result in an unevenness in the relative frequency of topologies observed across a genome. In the figure above, for example, the gene trees that are incongruent with the species tree (BBAA) do not occur in equal proportions: there is an excess of topologies supporting a close relationship between *Cicinnurus* and *Parotia* (ABBA) on the autosomes, which is likely the outcome of past hybridization. We can demonstrate this phenomenon using msprime by including a new parameter that describes the rate of gene flow between lineages (a migration matrix).
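Under incomplete lineage sorting alone, the two topologies that disagree with the species tree are expected to occur equally often, so an imbalance between them is a simple signal to look for. The sketch below computes a normalised difference between the two discordant topology counts, a tree-count analogue of the ABBA-BABA contrast. The function name and the counts are invented for illustration and are not taken from the figure.

```python
def discordant_imbalance(n_abba, n_baba):
    """Normalised difference between the two discordant topology counts.

    Values near 0 are what incomplete lineage sorting alone predicts;
    values far from 0 point to an excess of one discordant topology.
    """
    total = n_abba + n_baba
    if total == 0:
        raise ValueError("no discordant gene trees to compare")
    return (n_abba - n_baba) / total

# Made-up gene tree counts for illustration only.
counts = {"BBAA": 700, "ABBA": 220, "BABA": 80}
imbalance = discordant_imbalance(counts["ABBA"], counts["BABA"])
print(f"Imbalance among discordant topologies: {imbalance:.3f}")
```

In practice one would also need to assess whether the deviation from zero is larger than expected by chance, which this sketch does not attempt.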

### 4.1 Evaluating introgression with TWISST

Without going into too much detail, we will use an approach called *'Topology Weighting'* ([TWISST](https://github.com/simonhmartin/twisst/tree/master)) to summarise and visualise the distribution of topologies across a chromosome. Normally TWISST is used as a standalone Python script, but here we will directly incorporate the Python code underlying TWISST. If need be, we first download the Python dependencies needed for TWISST.

In [None]:
# If needed, download ete3, numpy and matplotlib. If already present in Conda environment, then skip this step
%pip install ete3 numpy matplotlib
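Before loading the full TWISST code, the core idea of topology weighting can be sketched in a few lines: sample one tip per group in every possible way, record which group-level topology each resulting subtree shows, and report the fraction of subtrees per topology. The code below is a simplified illustration and not TWISST's actual algorithm. It uses rooted trees written as nested tuples, and all function names and the example tree are assumptions made for this sketch.

```python
from itertools import product

def leaf_sets(tree):
    """Return (leaves, clades) for a tree written as nested tuples of leaf names."""
    if isinstance(tree, str):
        return frozenset([tree]), set()
    leaves, clades = frozenset(), set()
    for child in tree:
        child_leaves, child_clades = leaf_sets(child)
        leaves |= child_leaves
        clades |= child_clades
    clades.add(leaves)
    return leaves, clades

def subtree_topology(clades, chosen, group_of):
    """Rooted topology of the tree restricted to one chosen leaf per group."""
    chosen = frozenset(chosen)
    restricted = {frozenset(group_of[leaf] for leaf in clade & chosen) for clade in clades}
    # Keep informative clades only: more than one group, but not all of them.
    return frozenset(c for c in restricted if 1 < len(c) < len(chosen))

def topology_weights(tree, groups):
    """Fraction of one-leaf-per-group subtrees that show each topology."""
    group_of = {leaf: name for name, leaves in groups.items() for leaf in leaves}
    _, clades = leaf_sets(tree)
    counts = {}
    combos = list(product(*groups.values()))
    for combo in combos:
        topo = subtree_topology(clades, combo, group_of)
        counts[topo] = counts.get(topo, 0) + 1
    return {topo: n / len(combos) for topo, n in counts.items()}

# Hypothetical example: two haplotypes each for groups A and B, one for C.
groups = {"A": ["a1", "a2"], "B": ["b1", "b2"], "C": ["c1"]}
tree = ((("a1", "a2"), "b1"), ("b2", "c1"))
weights = topology_weights(tree, groups)
for topo, weight in weights.items():
    label = " | ".join("+".join(sorted(clade)) for clade in topo)
    print(f"clades {label}: weight {weight}")
```

In this example half of the subtrees group A with B and the other half group B with C, so neither topology receives full weight. Enumerating every combination quickly becomes expensive for larger samples, which is one reason to rely on an optimised implementation.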

**The following cell contains all the code needed for running TWISST. Simply execute it so we can use the TWISST functions in the subsequent section:**

In [None]:
import argparse
import itertools
import sys
import gzip
import ete3
import random
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from collections import deque

np.seterr(divide='ignore', invalid="ignore")

verbose = False
##############################################################################################################################

def sample(things, n = None, replace = False):
    if n is None: n = len(things)
    if not replace: return random.sample(things,n)
    else: return [random.choice(things) for i in range(n)]

def randomComboGen(lists):
    while True: yield tuple(random.choice(l) for l in lists)

def readTree(newick_tree):
    try:
        if newick_tree[0] == "[": return ete3.Tree(newick_tree[newick_tree.index("]")+1:])
        else: return ete3.Tree(newick_tree)
    except:
        return None

def asciiTrees(trees, nColumns = 5):
    treeLines = [tree.get_ascii().split("\n") for tree in trees]
    maxLines = max(map(len,treeLines))
    for tl in treeLines:
        #add lines if needed
        tl += [""]*(maxLines - len(tl))
        #add spaces to each line to make all even
        lineLengths = map(len,tl)
        maxLen = max(lineLengths)
        for i in range(len(tl)): tl[i] += "".join([" "]*(maxLen-len(tl[i])))
    #now join lines that will be on the same row and print
    treeLinesChunked = [treeLines[x:(x+nColumns)] for x in range(0,len(trees),nColumns)]
    zippedLinesChunked = [zip(*chunk) for chunk in treeLinesChunked]
    return "\n\n".join(["\n".join(["    ".join(l) for l in chunk]) for chunk in zippedLinesChunked])


def getPrunedCopy(tree, leavesToKeep, preserve_branch_length):
    pruned = tree.copy("newick")
    ##prune function was too slow for big trees
    ## speeding up by first deleting all other leaves
    for leaf in pruned.iter_leaves():
        if leaf.name not in leavesToKeep: leaf.delete(preserve_branch_length=preserve_branch_length)
    #and then prune to fix the root (not sure why this is necessary, but it is)
    #but at least it's faster than pruning the full tree
    pruned.prune(leavesToKeep, preserve_branch_length = preserve_branch_length)
    return pruned

class NodeChain(deque):
    def __init__(self, nodeList, dists=None):
        super(NodeChain, self).__init__(nodeList)
        if dists is None: self.dists = None
        else:
            assert len(dists) == len(self)-1, "incorrect number of internode distances"
            self.dists = deque(dists)
        self._set_ = None
    
    def addNode(self, name, dist=0):
        self.append(name)
        if self.dists is not None: self.dists.append(dist)
    
    def addNodeLeft(self, name, dist=0):
        self.appendleft(name)
        if self.dists is not None: self.dists.appendleft(dist)
    
    def addNodeChain(self, chainToAdd, joinDist=0):
        self.extend(chainToAdd)
        if self.dists is not None:
            assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances"
            self.dists.append(joinDist)
            self.dists.extend(chainToAdd.dists)
    
    def addNodeChainLeft(self, chainToAdd, joinDist=0):
        self.extendleft(chainToAdd)
        if self.dists is not None:
            assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances"
            self.dists.appendleft(joinDist)
            self.dists.extendleft(chainToAdd.dists)
    
    def chopLeft(self):
        self.popleft()
        if self.dists is not None: self.dists.popleft()
    
    def chop(self):
        self.pop()
        if self.dists is not None: self.dists.pop()
    
    def fuseLeft(self, chainToFuse):
        new = NodeChain(self, self.dists)
        assert new[0] == chainToFuse[0], "No common nodes"
        i = 1
        while new[1] == chainToFuse[i]:
            new.chopLeft()
            i += 1
        m = len(chainToFuse)
        while i < m:
            new.addNodeLeft(chainToFuse[i], chainToFuse.dists[i-1] if self.dists is not None else None)
            i += 1
        return new
    
    def simplifyToEnds(self, newDist=None):
        if self.dists is not None:
            if not newDist: newDist = sum(self.dists)
            self.dists.clear()
        leftNode = self.popleft()
        rightNode = self.pop()
        self.clear()
        self.append(leftNode)
        self.append(rightNode)
        if self.dists is not None:
            self.dists.append(newDist)
    
    def setSet(self):
        self._set_ = set(self)


##simpler version that only collapses monophyletic clades
#def getChainsToLeaves(node, collapseDict = None):
    #children = node.get_children()
    #if children == []:
        #node.add_feature("weight", 1)
        #return [NodeChain(node)]
    #chains = list(itertools.chain(*[getChainsToLeaves(child, collapseDict) for child in children]))
    #if (collapseDict and sum([len(chain) for chain in chains]) == len(chains) and
        #len(set([collapseDict[chain[0].name] for chain in chains])) == 1):
        ##all chains are a leaf from same group, so we collapse
        #newWeight = sum([chain[0].weight for chain in chains])
        #newDist = node.dist + sum([chain[0].dist * chain[0].weight * 1. for chain in chains]) / newWeight
        #chains[0][0].dist = newDist
        #chains[0][0].weight = newWeight
        #chains = [chains[0]]
    #else:
        #for chain in chains:
            #chain.addNodeLeft(node, dist=chain[0].dist)
    
    #return chains


def getChainsToLeaves(node, collapseDict = None, preserveDists = False):
    children = node.get_children()
    if children == []:
        #if it has no children it is a leaf, so just record a weight for the node and return it as a new 1-node chain
        chain = NodeChain([node], dists = [] if preserveDists else None)
        setattr(chain, "weight", 1)
        return [chain]
    #otherwise get chains for all children
    childrenChains = [getChainsToLeaves(child, collapseDict, preserveDists) for child in children]
    #now we have the chains from all children, we need to add the current node
    for childChains in childrenChains:
        for chain in childChains: chain.addNodeLeft(node, dist=chain[0].dist if preserveDists else None)
    
    #if collapsing, check groups for each node
    if collapseDict:
        nodeGroupsAll = np.array([collapseDict[chain[-1].name] for childChains in childrenChains for chain in childChains])
        nodeGroups = list(set(nodeGroupsAll))
        nGroups = len(nodeGroups)
        
        if (nGroups == 1 and len(nodeGroupsAll) > 1):
            #all leaves are from same group, so collapse to one chain
            #we can also preserve distances when doing this type of collapsing
            #first list all chains
            chains = [chain for childChains in childrenChains for chain in childChains]
            newWeight = sum([chain.weight for chain in chains])
            if preserveDists:
                newDist = sum([sum(chain.dists) * chain.weight * 1. for chain in chains]) / newWeight
                chains[0].simplifyToEnds(newDist = newDist)
            else:
                chains[0].simplifyToEnds()
            chains[0].weight = newWeight
            chains = [chains[0]]
        
        elif (nGroups == 2 and len(nodeGroupsAll) > 2 and preserveDists==False):
            #all chains end in a leaf from one of two groups, so we can simplify.
            #first list all chains
            chains = [chain for childChains in childrenChains for chain in childChains]
            #Start by getting index of each chain for each group
            indices = [(nodeGroupsAll == group).nonzero()[0] for group in nodeGroups]
            #the new weight for each chain we keep will be the total node weight of all from each group 
            newWeights = [sum([chains[i].weight for i in idx]) for idx in indices]
            #now reduce to just a chain for each group 
            chains = [chains[idx[0]] for idx in indices]
            for j in range(nGroups):
                chains[j].simplifyToEnds()
                chains[j].weight = newWeights[j]
        
        #if we couldn't simply collapse completely, we might still be able to merge down a side branch
        #Side branches are child chains ending in a single leaf
        #If there is a lower level child branch that is itself a side branch, we can merge to it
        elif (preserveDists == False and len(childrenChains) == 2 and
            ((len(childrenChains[0]) == 1 and len(childrenChains[1]) > 1) or
            (len(childrenChains[1]) == 1 and len(childrenChains[0]) > 1))):
            chains,sideChain = (childrenChains[1],childrenChains[0][0]) if len(childrenChains[0]) == 1 else (childrenChains[0],childrenChains[1][0])
            #now check if any main chain is suitable (it should have length 3, be the only such chain, and end in the correct group)
            targets = (np.array([len(chain) for chain in chains]) == 3).nonzero()[0]
            if len(targets) == 1 and collapseDict[chains[targets[0]][-1].name] == collapseDict[sideChain[-1].name]:
                #we have found a suitable chain to merge to
                targetChain = chains[targets[0]]
                newWeight = targetChain.weight + sideChain.weight
                targetChain.simplifyToEnds()
                targetChain.weight = newWeight
            else:
                #if we didn't find a suitable match, just add side chain
                chains.append(sideChain)
        else:
            #if there was no side chain, just list all chains
            chains = [chain for childChains in childrenChains for chain in childChains]
    #otherwise we are not collapsing, so just list all chains
    else:
        #chains = list(itertools.chain(*[getChainsToLeaves(child, collapseDict) for child in children]))
        chains = [chain for childChains in childrenChains for chain in childChains]
    #return the (possibly collapsed) chains for this node
    
    return chains

#version for tree sequence tree format from msprime and tsinfer
def getChainsToLeaves_ts(tree, node=None, collapseDict = None):
    if node is None: node = tree.root
    children = tree.children(node)
    if children == ():
        #if it has no children it is a leaf
        #if it's in the collapseDict or there is no collapseDict
        #just record a weight for the node and return it as a new 1-node chain
        if collapseDict is None or node in collapseDict:
            chain = NodeChain([node])
            setattr(chain, "weight", 1)
            return [chain]
        else:
            return []
    #otherwise get chains for all children
    childrenChains = [getChainsToLeaves_ts(tree, child, collapseDict) for child in children]
    #now we have the chains from all children, we need to add the current node
    for childChains in childrenChains:
        for chain in childChains: chain.addNodeLeft(node)
    
    #if collapsing, check groups for each node
    if collapseDict:
        nodeGroupsAll = np.array([collapseDict[chain[-1]] for childChains in childrenChains for chain in childChains])
        nodeGroups = list(set(nodeGroupsAll))
        nGroups = len(nodeGroups)
        
        if (nGroups == 1 and len(nodeGroupsAll) > 1) or (nGroups == 2 and len(nodeGroupsAll) > 2):
            #all chains end in a leaf from one or two groups, so we can simplify.
            #first list all chains
            chains = [chain for childChains in childrenChains for chain in childChains]
            #Start by getting index of each chain for each group
            indices = [(nodeGroupsAll == group).nonzero()[0] for group in nodeGroups]
            #the new weight for each chain we keep will be the total node weight of all from each group 
            newWeights = [sum([chains[i].weight for i in idx]) for idx in indices]
            #now reduce to just a chain for each group 
            chains = [chains[idx[0]] for idx in indices]
            for j in range(nGroups):
                chains[j].simplifyToEnds()
                chains[j].weight = newWeights[j]
        
        #if we couldn't simply collapse completely, we might still be able to merge down a side branch
        #Side branches are child chains ending in a single leaf
        #If there is a lower level child branch that is itself a side branch, we can merge to it
        elif (len(childrenChains) == 2 and
            ((len(childrenChains[0]) == 1 and len(childrenChains[1]) > 1) or
            (len(childrenChains[1]) == 1 and len(childrenChains[0]) > 1))):
            chains,sideChain = (childrenChains[1],childrenChains[0][0]) if len(childrenChains[0]) == 1 else (childrenChains[0],childrenChains[1][0])
            #now check if any main chain is suitable (it should have length 3, be the only such chain, and end in the correct group)
            targets = (np.array([len(chain) for chain in chains]) == 3).nonzero()[0]
            if len(targets) == 1 and collapseDict[chains[targets[0]][-1]] == collapseDict[sideChain[-1]]:
                #we have found a suitable internal chain to merge to
                targetChain = chains[targets[0]]
                newWeight = targetChain.weight + sideChain.weight
                targetChain.simplifyToEnds()
                targetChain.weight = newWeight
            else:
                #if we didn't find a suitable match, just add side chain
                chains.append(sideChain)
        else:
            #if there was no side chain, just list all chains
            chains = [chain for childChains in childrenChains for chain in childChains]
    #otherwise we are not collapsing, so just list all chains
    else:
        chains = [chain for childChains in childrenChains for chain in childChains]
    #return the (possibly collapsed) chains for this node
    
    return chains


def makeRootLeafChainDict(tree, collapseDict = None, preserveDists=False, treeFormat = "ete3"):
    if treeFormat == "ts":
        chains = getChainsToLeaves_ts(tree, collapseDict=collapseDict)
        return dict([(chain[-1],chain) for chain in chains])
    else:
        chains = getChainsToLeaves(tree, collapseDict=collapseDict, preserveDists=preserveDists)
        return dict([(chain[-1].name,chain) for chain in chains])

def makeLeafLeafChainDict(rootLeafChainDict, pairs):
    leafLeafChainDict = defaultdict(defaultdict)
    
    for pair in pairs:
        #get the leaf leaf chain by removing the unshared ancestors and joining root leaf chains end to end
        leafLeafChainDict[pair[0]][pair[1]] = rootLeafChainDict[pair[0]].fuseLeft(rootLeafChainDict[pair[1]])
    
    return leafLeafChainDict


def checkDisjointChains(leafLeafChains, pairsOfPairs, samples=None):
    if not samples:
        return [leafLeafChains[pairs[0][0]][pairs[0][1]]._set_.isdisjoint(leafLeafChains[pairs[1][0]][pairs[1][1]]._set_) for pairs in pairsOfPairs]
    else:
        return [leafLeafChains[samples[pairs[0][0]]][samples[pairs[0][1]]]._set_.isdisjoint(leafLeafChains[samples[pairs[1][0]]][samples[pairs[1][1]]]._set_) for pairs in pairsOfPairs]


def pairsDisjoint(pair1,pair2):
    return pair1[0] != pair2[0] and pair1[0] != pair2[1] and pair1[1] != pair2[0] and pair1[1] != pair2[1]

def makeTopoDict(taxonNames, topos=None, outgroup = None):
    output = {}
    output["topos"] = allTopos(taxonNames, []) if topos is None else topos
    if outgroup:
        for topo in output["topos"]: topo.set_outgroup(outgroup)
    output["n"] = len(output["topos"])
    pairs = list(itertools.combinations(taxonNames,2))
    pairsNumeric = list(itertools.combinations(range(len(taxonNames)),2))
    output["pairsOfPairs"] = [y for y in itertools.combinations(pairs,2) if pairsDisjoint(y[0],y[1])]
    output["pairsOfPairsNumeric"] = [y for y in itertools.combinations(pairsNumeric,2) if pairsDisjoint(y[0],y[1])]
    output["chainsDisjoint"] = []
    for tree in output["topos"]:
        rootLeafChains = makeRootLeafChainDict(tree)
        leafLeafChains = makeLeafLeafChainDict(rootLeafChains, pairs)
        for pair in pairs: leafLeafChains[pair[0]][pair[1]].setSet()
        output["chainsDisjoint"].append(checkDisjointChains(leafLeafChains, output["pairsOfPairs"]))
    return output

def makeGroupDict(groups, names=None):
    groupDict = {}
    for x in range(len(groups)):
        for y in groups[x]: groupDict[y] = x if not names else names[x]
    return groupDict

#Main weighting function that uses "chains" to check topologies and simplifies while generating chains
def weightTree(tree, taxa, taxonDict=None, pairs=None, topoDict=None, nIts=None,
                     getDists=False, simplify=True, abortCutoff=None, treeFormat="ete3", verbose=True,
                     taxonNames=None, outgroup=None):
    
    nTaxa = len(taxa)
    
    if not taxonDict: taxonDict = makeGroupDict(taxa)
    
    if pairs is None:
        pairs = [pair for taxPair in itertools.combinations(taxa,2) for pair in itertools.product(*taxPair)]
    
    rootLeafChains = makeRootLeafChainDict(tree, collapseDict=taxonDict if simplify else None, preserveDists=getDists, treeFormat=treeFormat)
    leavesRetained = rootLeafChains.keys()
    leavesRetainedSet = set(leavesRetained)
    leafWeights = dict([(ind, rootLeafChains[ind].weight) for ind in leavesRetained])
    _pairs = [pair for pair in pairs if pair[0] in leavesRetainedSet and pair[1] in leavesRetainedSet]
    leafLeafChains = makeLeafLeafChainDict(rootLeafChains, pairs=_pairs)
    #make a set for each chain so that disjointness between chains can be checked quickly
    for pair in _pairs: leafLeafChains[pair[0]][pair[1]].setSet()
    
    if topoDict is None:
        if taxonNames is None: taxonNames = [str(x) for x in range(len(taxa))]
        topoDict = makeTopoDict(taxonNames, outgroup=outgroup)
    
    _taxa = [[ind for ind in taxon if ind in leavesRetainedSet] for taxon in taxa]
    
    if getDists:
        assert taxonNames is not None, "taxonNames required for recording pairwise distances"
        dists = np.zeros([nTaxa, nTaxa, topoDict["n"]])
    
    #we make a generator object for all combos
    nCombos = np.prod([len(t) for t in _taxa])
    #if not specified, assume all combinations must be considered
    if nIts is None: nIts = nCombos
    #if doing all combos, we use an exhaustive combo generator
    if nIts >= nCombos:
        if verbose: sys.stderr.write("Complete weighting for {} combinations\n".format(nCombos))
        #unless there are too many combos, in which case we abort
        if abortCutoff and nCombos > abortCutoff:
            if verbose: sys.stderr.write("Aborting\n")
            return None
        comboGenerator = itertools.product(*_taxa)
    #if we are doing a subset, then use a random combo generator, but make sure simplify was false
    else:
        #sys.stderr.write("Approximate weighting with {} combinations\n".format(nIts))
        assert not simplify, "Tree simplification should be turned off when considering only a subset of combinations."
        comboGenerator = randomComboGen(_taxa)
    
    #initialise counts array
    counts = [0]*topoDict["n"]
    i=0
    for combo in comboGenerator:
        i += 1
        chainsDisjoint = checkDisjointChains(leafLeafChains, topoDict["pairsOfPairsNumeric"], samples=combo)
        try: x = topoDict["chainsDisjoint"].index(chainsDisjoint)
        except ValueError:
            #this combination matches none of the expected topologies, so skip it
            if i == nIts: break
            continue
        comboWeight = np.prod([leafWeights[ind] for ind in combo]) 
        counts[x] += comboWeight
        
        #get pairwise dists if necessary
        if getDists:
            comboPairs = [(combo[pair[0]], combo[pair[1]],) for pairs in topoDict["pairsOfPairsNumeric"] for pair in pairs]
            currentDists = np.zeros([nTaxa,nTaxa])
            for comboPair in comboPairs:
                taxPair = (taxonNames.index(taxonDict[comboPair[0]]), taxonNames.index(taxonDict[comboPair[1]]))
                currentDists[taxPair[0],taxPair[1]] = currentDists[taxPair[1],taxPair[0]] = sum(leafLeafChains[comboPair[0]][comboPair[1]].dists)
            dists[:,:,x] += currentDists*comboWeight
        
        if i == nIts: break
    
    #np.nan is used because the np.NaN alias was removed in NumPy 2.0
    meanDists = dists/counts if getDists else np.nan
    return {"topos":topoDict["topos"], "weights":counts, "dists":meanDists}


def weightTrees(trees, taxa=None, taxonDict=None, pairs=None, topoDict=None, nIts=None,
                     getDists=False, simplify=True, abortCutoff=None, treeFormat="ete3", verbose=True,
                     taxonNames=None, outgroup=None):
    
    if taxa is None:
        assert(treeFormat=="ts"), "Taxa must be specified as a list of lists."
        if taxonNames is None: taxonNames = [str(pop.id) for pop in trees.populations()]
        taxa = [[s for s in trees.samples() if str(trees.get_population(s)) == t] for t in taxonNames]
        
    if topoDict is None:
        if taxonNames is None: taxonNames = [str(x) for x in range(len(taxa))]
        topoDict = makeTopoDict(taxonNames, outgroup=outgroup)
    
    if not taxonDict: taxonDict = makeGroupDict(taxa, names=taxonNames)
    
    if pairs is None:
        pairs = [pair for taxPair in itertools.combinations(taxa,2) for pair in itertools.product(*taxPair)]
    
    _trees_ = trees.trees() if treeFormat=="ts" else trees
    
    #taxonNames is passed on because weightTree requires it when getDists is True
    allTreeData = [weightTree(tree, taxa, taxonDict=taxonDict, pairs=pairs, topoDict=topoDict, nIts=nIts, getDists=getDists, simplify=simplify, abortCutoff=abortCutoff, treeFormat=treeFormat, verbose=verbose, taxonNames=taxonNames) for tree in _trees_]
    
    output = {"topos":allTreeData[0]["topos"]}
    output["dists"] = np.array([x["dists"] for x in allTreeData])
    output["weights"] = np.array([x["weights"] for x in allTreeData])
    output["weights_norm"] = np.apply_along_axis(lambda x: x/x.sum(), 1, output["weights"])
    
    return output


def summary(weightsData):
    if "weights_norm" not in weightsData:
        weights = np.apply_along_axis(lambda x: x/x.sum(), 1, weightsData["weights"])
    else:
        weights =weightsData["weights_norm"]
    meanWeights = weights.mean(axis=0)
    for i in range(len(meanWeights)):
        print("Topo", i+1)
        print(weightsData["topos"][i].get_ascii())
        print(round(meanWeights[i],3))
        print("\n\n")

def listToNwk(t):
    t = str(t)
    t = t.replace("[","(")
    t = t.replace("]",")")
    t = t.replace("'","")
    t += ";"
    return(t)

def allTopos(branches, _topos=None, _topo_IDs=None):
    if _topos is None or _topo_IDs is None:
        _topos = []
        _topo_IDs = set([])
    assert 4 <= len(branches) <= 8, "Please specify between 4 and 8 unique taxon names."
    #print("topos contains", len(_topos), "topologies.")
    #print("current tree is:", branches)
    for x in range(len(branches)-1):
        for y in range(x+1,len(branches)):
            #print("Joining branch", x, branches[x], "with branch", y, branches[y])
            new_branches = list(branches)
            new_branches[x] = [new_branches[x],new_branches.pop(y)]
            #print("New tree is:", new_branches)
            if len(new_branches) == 3:
                #print("Tree has three branches, so appending to topos.")
                #now check that the topo doesn't match a topology already in trees, and if not add it
                t = ete3.Tree(listToNwk(new_branches))
                ID = t.get_topology_id()
                if ID not in _topo_IDs:
                    _topos.append(t)
                    _topo_IDs.add(ID)
            else:
                #print("Tree still unresolved, so re-calling function.")
                _topos = allTopos(new_branches, _topos, _topo_IDs)
    #print(_topo_IDs)
    return(_topos)

def writeWeights(weightsFile, weightsData, include_topologies=True, include_header=True):
    nTopos = len(weightsData["topos"])
    if include_topologies:
        for x in range(nTopos): weightsFile.write("#topo" + str(x+1) + " " + weightsData["topos"][x].write(format = 9) + "\n") 
    if include_header:
        weightsFile.write("\t".join(["topo" + str(x+1) for x in range(nTopos)]) + "\n")
    #write weights
    weightsFile.write("\n".join(["\t".join(row) for row in weightsData["weights"].astype(str)]) + "\n")


def writeTsWindowData(filename, ts):
    #open the path given by the filename argument (not a file literally called "filename")
    with open(filename, "wt") as dataFile:
        dataFile.write("chrom\tstart\tend\n")
        dataFile.write("\n".join(["\t".join(["chr1", str(tree.interval[0]), str(tree.interval[1])]) for tree in ts.trees()]) + "\n")

#############################################################################################################################################

### 4.2 Simulating introgression with msprime (using a legacy version that is compatible with TWISST)
Unfortunately, TWISST has not yet been updated to work seamlessly with the most recent version of msprime (v1.0+), so **we will use the legacy msprime API** to simulate ancestry. Let's first simulate a four-population model WITHOUT introgression.

In [None]:
pop_n = 10
pop_Ne = 1000

t_01 = 1000
t_02 = 5000
t_03 = 10000

########################################################################################################################

population_configurations = [msprime.PopulationConfiguration(sample_size=pop_n, initial_size=pop_Ne),
                             msprime.PopulationConfiguration(sample_size=pop_n, initial_size=pop_Ne),
                             msprime.PopulationConfiguration(sample_size=pop_n, initial_size=pop_Ne),
                             msprime.PopulationConfiguration(sample_size=pop_n, initial_size=pop_Ne)]

demographic_events = [msprime.MassMigration(time=t_01, source=1, destination=0, proportion=1.0), # first merge
                      msprime.MassMigration(time=t_02, source=2, destination=0, proportion=1.0), # next merge
                      msprime.MassMigration(time=t_03, source=3, destination=0, proportion=1.0)] # final merge

ts = msprime.simulate(population_configurations = population_configurations,
                      demographic_events = demographic_events,
                      length = 50000,
                      recombination_rate = 5e-7)

ts.num_trees

As you may have noticed, the model specification syntax is a little different, but it creates a simulation similar to those we have seen before. The output is also very similar: it is an object of the `TreeSequence` class. We will now use the `weightTrees` function of TWISST to summarise the resulting topologies.

In [None]:
weightsData = weightTrees(ts, treeFormat="ts", outgroup = "3", verbose=False)
summary(weightsData)
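The weights summarised above are counts of sample combinations: for every tree, one sample is drawn from each population and the resulting subtree is matched to a topology. The following is a minimal sketch of that bookkeeping in plain Python, not the TWISST implementation itself. The sample labels, the helper names `count_combinations` and `normalise`, and the raw counts are all made up for illustration.

```python
from itertools import product

# Hypothetical sample labels: 2 samples in each of 4 populations
taxa = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"], ["d1", "d2"]]

def count_combinations(taxa):
    """Number of ways to pick exactly one sample from every population."""
    n = 1
    for taxon in taxa:
        n *= len(taxon)
    return n

def normalise(row):
    """Convert the raw topology counts for one tree into proportions."""
    total = sum(row)
    return [count / total for count in row]

n_combos = count_combinations(taxa)
# Same number as enumerating every combination explicitly
assert n_combos == len(list(product(*taxa)))

# Made-up counts for the three topologies at a single tree
raw_counts = [12, 3, 1]
proportions = normalise(raw_counts)
print(n_combos, proportions)
```

With 10 samples in each of four populations, as in the simulation above, the same calculation gives 10**4 combinations per tree.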

**EXERCISE**: Can you think of a few ways to increase or decrease the unevenness among the observed topology frequencies? Give it a try by rerunning the last two code cells.

**ANSWER**: ...

The `weightTrees` function in TWISST rooted the possible topologies on one specific outgroup (a population, not a sample!) and then calculated how often each topology was observed among the simulated trees. In other words, a rooted four-taxon phylogeny with a fixed outgroup has three possible topologies, and we calculated the relative frequency of each. We can also visualise how the topology weights vary along the simulated sequence:

In [None]:
#extract mid positions on chromosome from tree sequence file
position = [(tree.interval[0] + tree.interval[1])/2 for tree in ts.trees()]

#normalise weights by dividing by the number of sample combinations per tree
#(one sample from each of 4 populations with 10 samples each: 10**4 = 10000)
weights = weightsData["weights"]/10000

#create a plot with all three topology weights
for i in range(3): 
    plt.plot(position, weights[:,i], label='topo'+str(i+1))
    
plt.legend()
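The plot above shows the raw weights of each tree, which can look jagged when there are many short trees. One simple way to smooth the curves is a running mean over neighbouring trees. This is a minimal sketch using only plain Python; the function name `running_mean` and the window size are illustrative choices, and the average is not weighted by the genomic span of each tree.

```python
def running_mean(values, window=5):
    """Centred moving average; the window shrinks at both edges."""
    half = window // 2
    smoothed = []
    for i in range(len(values)):
        lo = max(0, i - half)
        hi = min(len(values), i + half + 1)
        smoothed.append(sum(values[lo:hi]) / (hi - lo))
    return smoothed

# Made-up weights of one topology across seven consecutive trees
example = [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0]
smoothed_example = running_mean(example, window=3)
print(smoothed_example)
```

To apply it to the simulated data, something like `plt.plot(position, running_mean(list(weights[:, i]), window=11))` could replace the plotting line in the loop above.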

**Let's include introgression in the model**. To add introgression to the simulation, we need to specify a migration matrix and add migration rate changes to the `demographic_events` parameter of the model. Let's simulate an equal rate of gene flow in both directions between population 1 and population 2. Looking backwards in time, migration occurs from the present until `t_01` and stops once population 0 and population 1 have merged:

In [None]:
migration_matrix = [[0,    0,    0,    0],
                    [0,    0,    1e-4, 0],
                    [0,    1e-4, 0,    0],
                    [0,    0,    0,    0]]


t_01 = 1000
t_02 = 5000
t_03 = 10000

demographic_events = [msprime.MassMigration(time=t_01, source=1, destination=0, proportion=1.0), # first merge
                      msprime.MigrationRateChange(time=t_01, rate=0, matrix_index=(2, 1)), # mig stop after merge
                      msprime.MigrationRateChange(time=t_01, rate=0, matrix_index=(1, 2)), # migration is equal in both directions
                      msprime.MassMigration(time=t_02, source=2, destination=0, proportion=1.0), #next merge
                      msprime.MassMigration(time=t_03, source=3, destination=0, proportion=1.0)] #final merge

ts = msprime.simulate(population_configurations = population_configurations,
                      migration_matrix = migration_matrix,
                      demographic_events = demographic_events,
                      length = 50000,
                      recombination_rate = 5e-7)

ts.num_trees
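In the legacy msprime API, entry `[j][k]` of the migration matrix is the per-generation rate at which lineages move from population `j` to population `k` backwards in time, and the diagonal must be zero. The sketch below builds the same symmetric matrix with a small helper and gives a rough feel for the strength of the migration. The helper name `symmetric_migration_matrix` is hypothetical, and the probability calculation is an approximation that treats migration as a continuous-time process acting on a single lineage.

```python
import math

def symmetric_migration_matrix(n_pops, pop_a, pop_b, rate):
    """Square matrix with equal migration rates between pop_a and pop_b only."""
    matrix = [[0.0] * n_pops for _ in range(n_pops)]
    matrix[pop_a][pop_b] = rate
    matrix[pop_b][pop_a] = rate
    return matrix

rate = 1e-4
t_stop = 1000  # generations before migration is switched off (t_01 above)

matrix = symmetric_migration_matrix(4, 1, 2, rate)
for row in matrix:
    print(row)

# Approximate chance that a single lineage migrates at least once before t_stop
p_migrated = 1 - math.exp(-rate * t_stop)
print(round(p_migrated, 3))
```

Raising `rate` or pushing `t_stop` further back in time increases this probability, which is one way to explore how strongly introgression shifts the topology weights.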

Before moving on to calculating the observed proportion of each topology, how do you think the above migration scenario has impacted the frequency of each topology?

**Answer**:


In [None]:
weightsData = weightTrees(ts, treeFormat="ts", outgroup = "3", verbose=False)
summary(weightsData)

In [None]:
#extract mid positions on chromosome from tree sequence file
position = [(tree.interval[0] + tree.interval[1])/2 for tree in ts.trees()]

#normalise weights by dividing by the number of sample combinations per tree
#(one sample from each of 4 populations with 10 samples each: 10**4 = 10000)
weights = weightsData["weights"]/10000

#create a plot with all three topology weights
for i in range(3): 
    plt.plot(position, weights[:,i], label='topo'+str(i+1))

plt.legend()

**SUMMARY**: Similar to neutral processes such as incomplete lineage sorting (ILS), introgression can lead to widespread heterogeneity in coalescent histories across a genome. Reconstructing a species tree is already challenging for groups in which substantial ILS is expected given their diversification history, and becomes harder still when introgression also plays a role. In fact, it has been shown that in some groups the most common gene tree is incongruent with the species tree.
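As a point of comparison for the simulated weights, the amount of discordance expected from ILS alone can be calculated for the simplest case. Under the standard three-taxon coalescent result, with one sample per population and no migration, each of the two discordant topologies has probability exp(-T)/3, where T is the length of the internal branch in coalescent units of 2Ne generations. The sketch below applies this to the split times used in the simulations; the function name `gene_tree_probs` is hypothetical, and the result is only a rough guide because TWISST averages over many samples per population.

```python
import math

def gene_tree_probs(t_inner, t_outer, Ne):
    """Three-taxon gene tree probabilities under ILS only.

    t_inner, t_outer: the two split times in generations (t_inner < t_outer)
    Ne: diploid effective size of the ancestral population between the splits
    """
    T = (t_outer - t_inner) / (2 * Ne)  # internal branch in coalescent units
    p_discordant_each = math.exp(-T) / 3
    p_concordant = 1 - 2 * p_discordant_each
    return p_concordant, p_discordant_each

# Split times and population size assumed from the simulation above
p_conc, p_disc = gene_tree_probs(t_inner=1000, t_outer=5000, Ne=1000)
print(round(p_conc, 4), round(p_disc, 4))

# A shorter internal branch (or a larger Ne) gives more even topology frequencies
p_conc_short, p_disc_short = gene_tree_probs(t_inner=1000, t_outer=1500, Ne=1000)
print(round(p_conc_short, 4), round(p_disc_short, 4))
```

An excess of one discordant topology over the other, beyond what this symmetric expectation allows, is the kind of signal that points towards introgression.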

Finally, in the above model the `recombination_rate` is constant across the chromosomal segment. However, this often does not match biological reality: recombination is highly heterogeneous within and across chromosomes, and is another important factor determining the likelihood that a foreign haplotype block becomes fixed in a population. We leave such simulations for another exercise! :)