# Experiment: Iterating the Pipeline to Recover WAG from Nothing

The goal of this experiment is to determine _whether_ iterating our Pipeline recovers the WAG matrix (../input_data/synthetic_rate_matrices/WAG_FastTree.txt) from simulated data, and _how many_ iterations are required, when starting from the (terrible) estimate given by the uniform rate matrix (../input_data/synthetic_rate_matrices/Q1_uniform_FastTree.txt).
Each iteration thus consists of:
- (Re-)estimating the trees with FastTree, using the current estimate of the rate matrix.
- (Re-)estimating the rate matrix.

We compare this procedure against running the pipeline once, using WAG to estimate the trees (i.e. the "oracle access" case).

Here are the questions that we answer:

**QUESTION 1**: Starting from the uniform rate matrix, does the sequence of estimates Q1_1, Q1_2, Q1_3, ... converge to the WAG matrix (../input_data/synthetic_rate_matrices/WAG_FastTree.txt)?

**ANSWER**: **No**. In fact, the sequence of estimates Q1_1, Q1_2, Q1_3, ... shrinks to 0! My explanation is the following: maximum parsimony bias causes us to underestimate the transition rates. As a consequence, when the trees are re-estimated, their branch lengths come out longer than before. The re-estimated rate matrix then adapts to the longer branch lengths by having even smaller rates. Induction suggests that the process looks like this: Q1_K = Q1_1 x bias ^ (K-1), where bias < 1 is the maximum parsimony multiplicative bias; correspondingly, the branch lengths grow by the inverse factor, tree_K = tree_1 x bias ^ -(K-1). This feedback loop makes the rate estimates shrink to zero. This phenomenon does not depend on the choice of the initial rate matrix used in FastTree: it is an unavoidable consequence of iterating a downward-biased estimator.
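
The feedback loop can be illustrated with a toy calculation. The sketch below is not the pipeline: it assumes a hypothetical multiplicative bias of 0.8 per iteration (a made-up value) and tracks only the overall scale of the rates and of the branch lengths, relative to iteration 1. The function name `simulate_feedback_loop` is likewise an illustration.

```python
def simulate_feedback_loop(bias, n_iterations):
    """Toy model of the loop: returns [(rate_scale, branch_scale), ...] per iteration.

    Assumption: each round, rebuilt trees stretch by 1 / bias and the
    re-estimated rates shrink by bias, so rate x time stays constant.
    """
    rate_scale, branch_scale = 1.0, 1.0  # Iteration 1, in relative units.
    history = [(rate_scale, branch_scale)]
    for _ in range(n_iterations - 1):
        branch_scale /= bias  # Trees rebuilt with smaller rates get longer branches.
        rate_scale *= bias    # Rates re-fit on longer branches shrink again.
        history.append((rate_scale, branch_scale))
    return history


history = simulate_feedback_loop(bias=0.8, n_iterations=5)
for k, (rate_scale, branch_scale) in enumerate(history, start=1):
    print(f"K={k}: rate scale {rate_scale:.4f}, branch scale {branch_scale:.4f}")
```

Under these assumptions the rate scale decays geometrically while the product of rate scale and branch scale stays fixed, which is the pattern described above.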

**QUESTION 2**: Starting from the uniform rate matrix, does the sequence of **normalized** estimates Q1_1_normalized, Q1_2_normalized, Q1_3_normalized, ... converge to the **normalized** WAG matrix (WAG_FastTree_normalized)? (Here Q1_K_normalized is defined as the normalized version of Q1_K, rescaled so that the mutation rate under stationarity is 1; similarly for WAG_FastTree_normalized.)
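
The normalization can be sketched on a toy 3-state matrix. The helpers below (`stationary_distribution`, `normalize_rate_matrix`) are hypothetical stand-ins written with plain Python lists; the notebook itself relies on `Phylo_util.solve_stationery_dist` and NumPy. The sketch assumes the convention that rows of the rate matrix sum to zero, and all numbers are made up. It also checks that rescaling a matrix leaves its normalized version unchanged.

```python
def stationary_distribution(Q, n_iterations=10_000):
    """Stationary distribution of a rate matrix Q (rows sum to 0) by power iteration.

    Uses the uniformized chain P = I + Q / rate_bound, which has the same
    stationary distribution as Q.
    """
    n = len(Q)
    rate_bound = 2.0 * max(-Q[i][i] for i in range(n))
    P = [
        [(1.0 if i == j else 0.0) + Q[i][j] / rate_bound for j in range(n)]
        for i in range(n)
    ]
    pi = [1.0 / n] * n
    for _ in range(n_iterations):
        pi = [sum(pi[i] * P[i][j] for i in range(n)) for j in range(n)]
    return pi


def normalize_rate_matrix(Q):
    """Rescale Q so that the expected mutation rate under stationarity is 1."""
    pi = stationary_distribution(Q)
    mutation_rate = sum(pi[i] * -Q[i][i] for i in range(len(Q)))
    return [[q / mutation_rate for q in row] for row in Q]


# Toy 3-state rate matrix (rows sum to zero); the values are made up.
Q = [
    [-0.3, 0.2, 0.1],
    [0.4, -0.5, 0.1],
    [0.3, 0.3, -0.6],
]
Q_shrunk = [[0.5 * q for q in row] for row in Q]  # Same matrix at half the scale.

Q_norm = normalize_rate_matrix(Q)
Q_shrunk_norm = normalize_rate_matrix(Q_shrunk)
max_difference = max(
    abs(a - b)
    for row_a, row_b in zip(Q_norm, Q_shrunk_norm)
    for a, b in zip(row_a, row_b)
)
print(f"max |difference| after normalization: {max_difference:.2e}")
```

Because normalization removes the overall scale, a sequence of estimates can shrink to zero while the normalized sequence stays put.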

**ANSWER**: **Pretty much yes!** Indeed, the sequence Q1_1_normalized, Q1_2_normalized, Q1_3_normalized, ... gets quite close to WAG_FastTree_normalized. In fact, to spoil the fun: it seems to converge in 1 step (see the next question).

**QUESTION 3**: Starting from the uniform rate matrix, _how fast_ does the sequence of normalized estimates Q1_1_normalized, Q1_2_normalized, Q1_3_normalized, ... converge to the normalized WAG matrix (WAG_FastTree_normalized)?

**ANSWER**: **In 1 step**. Indeed, Q1_1_normalized, Q1_2_normalized, Q1_3_normalized, ... are all almost identical, and close to WAG_FastTree_normalized.

**QUESTION 4**: Okay, so it looks like our pipeline is able to recover normalized WAG (WAG_FastTree_normalized) after 1 step starting from a garbage estimate. Does this mean that the rate matrix used in FastTree does not matter?

**ANSWER**: **Yes**, it looks like the rate matrix used in FastTree barely matters. Using the uniform rate matrix (../input_data/synthetic_rate_matrices/Q1_uniform_FastTree.txt) we obtain the estimate Q1_1_normalized, while using the WAG matrix (../input_data/synthetic_rate_matrices/WAG_FastTree.txt) we obtain the estimate Q1_with_WAG_FastTree_normalized. The two estimates are extremely similar, and both are quite close to the ground truth WAG_FastTree_normalized.
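
One way to make "extremely similar" quantitative is to compare the off-diagonal rates entry by entry. The sketch below is an assumption about how one might do this, not a metric used by the pipeline; the function name and the toy 3-state matrices are made up.

```python
from statistics import median


def off_diagonal_relative_errors(Q_estimate, Q_reference):
    """Relative errors |estimate - reference| / reference over off-diagonal entries.

    Assumes every off-diagonal entry of Q_reference is nonzero.
    """
    n = len(Q_reference)
    return [
        abs(Q_estimate[i][j] - Q_reference[i][j]) / Q_reference[i][j]
        for i in range(n)
        for j in range(n)
        if i != j
    ]


# Made-up matrices standing in for two normalized estimates.
Q_a = [[-1.0, 0.6, 0.4], [0.5, -1.0, 0.5], [0.2, 0.8, -1.0]]
Q_b = [[-1.0, 0.63, 0.37], [0.5, -1.0, 0.5], [0.21, 0.79, -1.0]]

errors = off_diagonal_relative_errors(Q_b, Q_a)
print(f"median relative error: {median(errors):.3f}")
print(f"max relative error:    {max(errors):.3f}")
```

Applied to the 20 x 20 matrices of this experiment, the same summary could be computed for Q1_1_normalized versus Q1_with_WAG_FastTree_normalized, and for each of them versus WAG_FastTree_normalized.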

**QUESTION 5**: Okay, so the rate matrix used in FastTree does not seem to matter much. Is this because the reconstructed phylogenies are very similar?

**ANSWER**: **Yes**. Surprisingly, it seems that reconstructing trees with the uniform rate matrix (../input_data/synthetic_rate_matrices/Q1_uniform_FastTree.txt) or with WAG (../input_data/synthetic_rate_matrices/WAG_FastTree.txt) leads to very similar tree shapes (e.g. compare repetition_1/trees/1h75_1_A.newick and using_WAG_FastTree/trees/1h75_1_A.newick respectively). Both are, however, clearly different from the ground truth (trees_ground_truth/1h75_1_A.newick)!
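
Tree shape is hard to summarize in one number, but the overall branch-length scale (relevant to Question 1) is easy to check. The sketch below is a coarse first look only: it sums the numbers that follow ':' in a Newick string and says nothing about topology. The function name and the toy trees are made up.

```python
import re

# A branch length is the number that follows ':' in Newick notation.
BRANCH_LENGTH_PATTERN = re.compile(r":(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")


def total_branch_length(newick: str) -> float:
    """Sum of all branch lengths in a Newick string."""
    return sum(float(length) for length in BRANCH_LENGTH_PATTERN.findall(newick))


# Made-up trees: the second has the same topology with every branch doubled.
tree_a = "((A:0.1,B:0.2)0.95:0.05,C:0.3,D:0.4);"
tree_b = "((A:0.2,B:0.4)0.95:0.10,C:0.6,D:0.8);"

ratio = total_branch_length(tree_b) / total_branch_length(tree_a)
print(f"branch-length ratio: {ratio:.2f}")
```

For this experiment, one would read the contents of the tree files of the same family from two runs (for example repetition_1/trees/1h75_1_A.newick and using_WAG_FastTree/trees/1h75_1_A.newick) and compare their totals.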

# Global parameters

In [None]:
experiment_rootdir = "WAG_recovery_from_nothing"
n_process = 32

# Imports

In [None]:
import sys
sys.path.append('../')
import os
import time
import logging
import numpy as np
import pandas as pd
import Phylo_util

# Create the experiment directory if it does not exist yet.
os.makedirs(experiment_rootdir, exist_ok=True)

def init_logger():
    logger = logging.getLogger('phylo_correction')
    logger.setLevel(logging.DEBUG)
    fmt_str = "[%(asctime)s] - %(name)s - %(levelname)s - %(message)s"
    formatter = logging.Formatter(fmt_str)

    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setFormatter(formatter)
    logger.addHandler(consoleHandler)

    fileHandler = logging.FileHandler("Phylo-correction.log")
    fileHandler.setFormatter(formatter)
    logger.addHandler(fileHandler)

init_logger()

# First we simulate realistic *ground truth trees*

In [None]:
from src.phylogeny_generation import PhylogenyGenerator
from src.simulation import Simulator
from src.pipeline import Pipeline

a3m_dir = '../test_input_data/a3m_32_families'
pdb_dir = '../test_input_data/pdb_32_families'
ground_truth_tree_dir = f'{experiment_rootdir}/trees_ground_truth'

def simulate_ground_truth_trees():
    ground_truth_phylogeny_generator = PhylogenyGenerator(
        a3m_dir=a3m_dir,
        n_process=n_process,
        expected_number_of_MSAs=32,
        outdir=ground_truth_tree_dir,
        max_seqs=1024,
        max_sites=1024,
        rate_matrix='../input_data/synthetic_rate_matrices/WAG_FastTree.txt',
        use_cached=True,
        max_families=32,
    )
    ground_truth_phylogeny_generator.run()

simulate_ground_truth_trees()

# Simulate MSAs on the ground truth trees using WAG. (We will try to recover WAG from this data!)

In [None]:
a3m_simulated_dir = f'{experiment_rootdir}/a3m_simulated'
contact_simulated_dir = f'{experiment_rootdir}/doesnt_matter'
ancestral_states_simulated_dir = f'{experiment_rootdir}/doesnt_matter_either'

def simulate_ground_truth_MSAs():
    ground_truth_MSA_simulator = Simulator(
        a3m_dir=a3m_dir,
        tree_dir=ground_truth_tree_dir,
        a3m_simulated_dir=a3m_simulated_dir,
        contact_simulated_dir=contact_simulated_dir,
        ancestral_states_simulated_dir=ancestral_states_simulated_dir,
        n_process=n_process,
        expected_number_of_MSAs=32,
        max_families=32,
        simulation_pct_interacting_positions=0.0,  # So that ALL positions evolve under the WAG matrix.
        Q1_ground_truth="../input_data/synthetic_rate_matrices/WAG_matrix.txt",
        Q2_ground_truth="../input_data/synthetic_rate_matrices/Q2_uniform_constrained.txt",  # Doesn't matter
        use_cached=True,
    )
    ground_truth_MSA_simulator.run()

simulate_ground_truth_MSAs()

# Now, starting from a uniform guess for the transition rate matrix, repeatedly (re-)infer the trees and (re-)infer the rate matrix. Will we converge to WAG? How many iterations will we need? Let's find out!

In [None]:
def to_fast_tree_format(rate_matrix: np.ndarray, output_path: str):
    r"""
    Writes out 'rate_matrix' to 'output_path' in FastTree 20 x 21 format, column-stochastic.
    """
    amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
    rate_matrix_df = pd.DataFrame(rate_matrix, index=amino_acids, columns=amino_acids)
    pi = Phylo_util.solve_stationery_dist(rate_matrix)
    # Transpose, then append the stationary distribution as a 21st column ('*').
    rate_matrix_df = rate_matrix_df.transpose()
    rate_matrix_df['*'] = pi
    # Write the header row first, then append the 20 labeled rows below it.
    with open(output_path, "w") as outfile:
        for aa in amino_acids:
            outfile.write(aa + "\t")
        outfile.write("*\n")
    rate_matrix_df.to_csv(output_path, sep="\t", header=False, mode='a')

def iterate_pipeline_starting_from_uniform_rate_matrix():
    rate_matrix_estimates = ["../input_data/synthetic_rate_matrices/Q1_uniform_FastTree.txt"]  # We start with the uniform guess.

    for repetition in range(1, 6):  # Repetitions 1 through 5.
        pipeline = Pipeline(
            outdir=f"{experiment_rootdir}/repetition_{repetition}",
            max_seqs=1024,
            max_sites=1024,
            armstrong_cutoff=8.0,
            rate_matrix=rate_matrix_estimates[-1],  # We use the last estimate to build the trees!
            n_process=n_process,
            expected_number_of_MSAs=32,
            max_families=32,
            a3m_dir=a3m_simulated_dir,
            pdb_dir=pdb_dir,
            use_cached=True,
            num_epochs=2000,
            device='cpu',
            center=0.06,
            step_size=0.1,
            n_steps=50,
            keep_outliers=False,
            max_height=1000.0,
            max_path_height=1000,
            precomputed_contact_dir=None,
            precomputed_tree_dir=None,
            precomputed_maximum_parsimony_dir=None,
        )
        pipeline.run()
        # Write out the (unnormalized) learned rate matrix: Q1_K
        learned_rate_matrix = np.loadtxt(
            os.path.join(pipeline.learnt_rate_matrix_dir, "learned_matrix.txt")
        )
        learned_rate_matrix_path = f'{experiment_rootdir}/Q1_{repetition}'
        to_fast_tree_format(learned_rate_matrix, output_path=learned_rate_matrix_path)
        rate_matrix_estimates.append(learned_rate_matrix_path)

        # Write out the normalized learned rate matrix: Q1_K_normalized
        pi = Phylo_util.solve_stationery_dist(learned_rate_matrix)
        mutation_rate = pi @ -np.diag(learned_rate_matrix)
        normalized_learned_rate_matrix = learned_rate_matrix / mutation_rate
        normalized_learned_rate_matrix_path = f'{experiment_rootdir}/Q1_{repetition}_normalized'
        to_fast_tree_format(normalized_learned_rate_matrix, output_path=normalized_learned_rate_matrix_path)

iterate_pipeline_starting_from_uniform_rate_matrix()
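
To inspect the Q1_K and Q1_K_normalized files written by the loop above, the table can be read back into a matrix. The reader below is a hypothetical helper, not part of the pipeline. It assumes exactly the layout produced by `to_fast_tree_format`: a header of state labels ending in `*`, then one labeled row per state holding the transposed rates followed by the stationary probability. It is demonstrated on a made-up 2-state file.

```python
import os
import tempfile


def read_fast_tree_format(path):
    """Read a file laid out like the output of to_fast_tree_format.

    Returns (rate_matrix, pi) as nested lists, undoing the transpose that
    was applied before writing.
    """
    with open(path) as infile:
        header = infile.readline().rstrip("\n").split("\t")
        states = header[:-1]  # The last header field is "*".
        rows = [line.rstrip("\n").split("\t") for line in infile if line.strip()]
    n = len(states)
    # Each row is: label, n transposed rates, stationary probability.
    transposed = [[float(x) for x in row[1 : n + 1]] for row in rows]
    pi = [float(row[n + 1]) for row in rows]
    rate_matrix = [[transposed[j][i] for j in range(n)] for i in range(n)]
    return rate_matrix, pi


# Round trip on a made-up 2-state file with the same layout.
toy_contents = "A\tR\t*\nA\t-1.0\t3.0\t0.75\nR\t1.0\t-3.0\t0.25\n"
with tempfile.TemporaryDirectory() as tmp_dir:
    toy_path = os.path.join(tmp_dir, "Q_toy")
    with open(toy_path, "w") as outfile:
        outfile.write(toy_contents)
    rate_matrix, pi = read_fast_tree_format(toy_path)

print(rate_matrix)
print(pi)
```

Reading two consecutive estimates this way makes it straightforward to check whether their entries differ by a roughly constant factor, as the explanation for Question 1 predicts.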

# Let's see what we would have gotten in 1 iteration if we had used the WAG matrix to reconstruct the trees (the "oracle access" case).

In [None]:
def run_pipeline_using_WAG_for_FastTree():
    pipeline = Pipeline(
        outdir=f"{experiment_rootdir}/using_WAG_FastTree",
        max_seqs=1024,
        max_sites=1024,
        armstrong_cutoff=8.0,
        rate_matrix='../input_data/synthetic_rate_matrices/WAG_FastTree.txt',  # We use the ground truth in FastTree!
        n_process=n_process,
        expected_number_of_MSAs=32,
        max_families=32,
        a3m_dir=a3m_simulated_dir,
        pdb_dir=pdb_dir,
        use_cached=True,
        num_epochs=2000,
        device='cpu',
        center=0.06,
        step_size=0.1,
        n_steps=50,
        keep_outliers=False,
        max_height=1000.0,
        max_path_height=1000,
        precomputed_contact_dir=None,
        precomputed_tree_dir=None,
        precomputed_maximum_parsimony_dir=None,
    )
    pipeline.run()
    # Write out the learned rate matrix to Q1_with_WAG_FastTree
    learned_rate_matrix = np.loadtxt(
        os.path.join(pipeline.learnt_rate_matrix_dir, "learned_matrix.txt")
    )
    learned_rate_matrix_path = f'{experiment_rootdir}/Q1_with_WAG_FastTree'
    to_fast_tree_format(learned_rate_matrix, output_path=learned_rate_matrix_path)

    # Write out the normalized learned rate matrix: Q1_with_WAG_FastTree_normalized
    pi = Phylo_util.solve_stationery_dist(learned_rate_matrix)
    mutation_rate = pi @ -np.diag(learned_rate_matrix)
    normalized_learned_rate_matrix = learned_rate_matrix / mutation_rate
    normalized_learned_rate_matrix_path = f'{experiment_rootdir}/Q1_with_WAG_FastTree_normalized'
    to_fast_tree_format(normalized_learned_rate_matrix, output_path=normalized_learned_rate_matrix_path)

run_pipeline_using_WAG_for_FastTree()

# Compute normalized WAG matrix (WAG_FastTree_normalized)

In [None]:
def compute_normalized_WAG_matrix():
    WAG_matrix = np.array(pd.read_csv("../input_data/synthetic_rate_matrices/WAG_matrix.txt", sep="\t").iloc[:20, 1:21])
    pi = Phylo_util.solve_stationery_dist(WAG_matrix)
    mutation_rate = pi @ -np.diag(WAG_matrix)
    normalized_WAG_matrix = WAG_matrix / mutation_rate
    normalized_WAG_matrix_path = f'{experiment_rootdir}/WAG_FastTree_normalized'
    to_fast_tree_format(normalized_WAG_matrix, output_path=normalized_WAG_matrix_path)

compute_normalized_WAG_matrix()