# Coevolution Toy Estimation

The goal of this experiment is to determine whether our optimizer is able to recover a 400 x 400 coevolution rate matrix. In this experiment, we let the ground truth 400 x 400 rate matrix be the one obtained when the two sites of a pair evolve independently, each under WAG (i.e. a pair changes one site at a time, at the corresponding WAG rate); we call this matrix WAG_x_WAG (found at `WAG_x_WAG_matrix.txt`). We use 32 protein families for this experiment, each with 1024 sequences and 1024 sites. All sites evolve under coevolution (i.e. they form 512 pairs).

We assume access to the ground truth trees and ancestral states, so we are only testing the efficacy of our optimizer.

**QUESTION 1**: Is our optimizer able to recover the WAG_x_WAG matrix?

**ANSWER**: **Yes!** I still haven't performed a rigorous quantitative analysis, but looking at the estimate (`coevolution_toy_estimation/pipeline_w_ancestral_states/Q2__32_families__None_seqs_None_sites_None_RM_None_angstrom__0.06_center_0.1_step_size_50_n_steps_False_outliers_1000.0_max_height_1000_max_path_height__1000_epochs/learned_matrix.txt`) we see that it is quite similar to the ground truth (`WAG_x_WAG_matrix.txt`). In particular, the optimizer wasn't told which entries of the matrix are 0, and yet it seems to have figured them out very well.


# Global parameters

In [None]:
# Root directory under which every output path of this experiment is built
# (trees, simulated MSAs, contacts, ancestral states, pipeline results).
experiment_rootdir = "coevolution_toy_estimation"
# Forwarded as the `n_process` argument of the phylogeny generator and the
# simulator below; presumably the number of parallel worker processes.
n_process = 32

# Imports

In [None]:
import sys
sys.path.append('../')
import os
import time
import logging
import numpy as np
import pandas as pd
import Phylo_util

# Create the experiment's root output directory. `exist_ok=True` replaces the
# check-then-create pattern (which can race with another process) and makes
# re-running this cell a no-op when the directory is already there.
os.makedirs(experiment_rootdir, exist_ok=True)

def init_logger():
    """Configure the 'phylo_correction' logger for console and file output.

    Sets the logger's level to DEBUG and makes sure it has one handler
    writing to ``sys.stdout`` and one writing to ``Phylo-correction.log``
    (relative to the current working directory), both using the same format.

    The function is idempotent: calling it again (e.g. when the notebook cell
    is re-run) does not attach further handlers, so each message is emitted
    only once per destination.
    """
    logger = logging.getLogger('phylo_correction')
    logger.setLevel(logging.DEBUG)
    fmt_str = "[%(asctime)s] - %(name)s - %(levelname)s - %(message)s"
    formatter = logging.Formatter(fmt_str)

    # FileHandler is a subclass of StreamHandler, so it has to be excluded
    # explicitly when looking for an existing console handler.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setFormatter(formatter)
        logger.addHandler(consoleHandler)

    has_file_handler = any(
        isinstance(handler, logging.FileHandler) for handler in logger.handlers
    )
    if not has_file_handler:
        fileHandler = logging.FileHandler("Phylo-correction.log")
        fileHandler.setFormatter(formatter)
        logger.addHandler(fileHandler)

init_logger()

# First we simulate realistic *ground truth trees*

In [None]:
from src.phylogeny_generation import PhylogenyGenerator
from src.simulation import Simulator
from src.pipeline import Pipeline

a3m_dir = '../test_input_data/a3m_32_families'
pdb_dir = '../test_input_data/pdb_32_families'
ground_truth_tree_dir = f'{experiment_rootdir}/trees_ground_truth'

def simulate_ground_truth_trees():
    """Instantiate a PhylogenyGenerator with this experiment's settings and run it.

    The generator is given the a3m directory, the ground-truth tree output
    directory, the WAG_FastTree rate matrix file, caps of 1024 on sequences
    and sites, 32 families / expected MSAs, and caching enabled.
    """
    generator_settings = dict(
        a3m_dir=a3m_dir,
        outdir=ground_truth_tree_dir,
        rate_matrix='../input_data/synthetic_rate_matrices/WAG_FastTree.txt',
        max_seqs=1024,
        max_sites=1024,
        max_families=32,
        expected_number_of_MSAs=32,
        n_process=n_process,
        use_cached=True,
    )
    PhylogenyGenerator(**generator_settings).run()

simulate_ground_truth_trees()

# Simulate MSAs on the ground truth trees using WAG x WAG. (We will try to recover WAG x WAG from this data!)

In [None]:
# Create WAG x WAG matrix
WAG_x_WAG_matrix_path = f"{experiment_rootdir}/WAG_x_WAG_matrix.txt"

def create_WAG_x_WAG_matrix(
    wag_matrix_path="../input_data/synthetic_rate_matrices/WAG_matrix.txt",
    output_path=None,
):
    """Build the 400 x 400 pair rate matrix from the 20 x 20 WAG matrix and save it.

    States are ordered pairs of amino acids (labels such as ``"AR"``), with
    the first amino acid varying slowest. For states ``(a1, a2)`` and
    ``(a3, a4)`` the entry is:

    * ``WAG[a1, a1] + WAG[a2, a2]`` if both positions are unchanged,
    * ``WAG[a2, a4]`` if only the second position changes,
    * ``WAG[a1, a3]`` if only the first position changes,
    * ``0`` if both positions change.

    This is the Kronecker sum ``WAG (x) I + I (x) WAG``, which is computed
    directly instead of filling the 160,000 entries one by one.

    Args:
        wag_matrix_path: Tab-separated file holding the single-site rate
            matrix, with amino-acid letters as row and column labels.
        output_path: Where the tab-separated pair matrix is written.
            Defaults to the module-level ``WAG_x_WAG_matrix_path``.
    """
    import numpy as np
    import pandas as pd

    if output_path is None:
        output_path = WAG_x_WAG_matrix_path

    # keep_default_na=False so that no label is ever mistaken for a missing value.
    WAG_matrix = pd.read_csv(wag_matrix_path, sep="\t", index_col=0, keep_default_na=False, na_values=[""])

    amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
    pairs_of_aa = [aa1 + aa2 for aa1 in amino_acids for aa2 in amino_acids]

    # Select by label so the result does not depend on the row/column order
    # used in the input file.
    single_site_rates = WAG_matrix.loc[amino_acids, amino_acids].to_numpy(dtype=float)
    identity = np.eye(len(amino_acids))

    # kron(WAG, I): first position changes, second fixed.
    # kron(I, WAG): first position fixed, second changes.
    # Their sum has WAG[a1, a1] + WAG[a2, a2] on the diagonal.
    pair_rates = np.kron(single_site_rates, identity) + np.kron(identity, single_site_rates)

    WAG_x_WAG_matrix = pd.DataFrame(pair_rates, index=pairs_of_aa, columns=pairs_of_aa)
    WAG_x_WAG_matrix.to_csv(output_path, sep="\t")

create_WAG_x_WAG_matrix()

In [None]:
a3m_simulated_dir = f'{experiment_rootdir}/a3m_simulated'
contact_simulated_dir = f'{experiment_rootdir}/contacts_simulated'
ancestral_states_simulated_dir = f'{experiment_rootdir}/ancestral_states_simulated'

def simulate_ground_truth_MSAs():
    """Instantiate a Simulator with this experiment's settings and run it.

    The simulator is given the a3m and ground-truth tree directories, the
    three output directories (simulated a3m, contacts, ancestral states) and
    the WAG x WAG matrix file as the pair-site ground truth.
    """
    simulator_settings = dict(
        # Inputs.
        a3m_dir=a3m_dir,
        tree_dir=ground_truth_tree_dir,
        # Outputs.
        a3m_simulated_dir=a3m_simulated_dir,
        contact_simulated_dir=contact_simulated_dir,
        ancestral_states_simulated_dir=ancestral_states_simulated_dir,
        # Ground-truth models: every position is interacting, so ALL
        # positions evolve under the WAG x WAG matrix and the single-site
        # matrix Q1 does not matter.
        simulation_pct_interacting_positions=1.0,
        Q1_ground_truth="../input_data/synthetic_rate_matrices/WAG_matrix.txt",
        Q2_ground_truth=WAG_x_WAG_matrix_path,
        # Run configuration.
        n_process=n_process,
        expected_number_of_MSAs=32,
        max_families=32,
        use_cached=True,
    )
    Simulator(**simulator_settings).run()

simulate_ground_truth_MSAs()

In [None]:
def run_pipeline_w_ancestral_states():
    """Run the estimation Pipeline on the simulated MSAs with precomputed inputs.

    The pipeline reads the simulated a3m directory and is handed the
    simulated contacts and the simulated ancestral states as precomputed
    inputs, with ``learn_pairwise_model=True``. Results go to
    ``<experiment_rootdir>/pipeline_w_ancestral_states``.
    """
    pipeline = Pipeline(
        outdir=f"{experiment_rootdir}/pipeline_w_ancestral_states",
        max_seqs=None,
        max_sites=None,
        armstrong_cutoff=None,
        rate_matrix=None,
        # Use the notebook-wide setting, as the tree and MSA simulation steps
        # do, instead of a separately hard-coded value.
        n_process=n_process,
        expected_number_of_MSAs=32,
        max_families=32,
        a3m_dir=a3m_simulated_dir,
        pdb_dir=None,
        use_cached=True,
        num_epochs=1000,
        device='cpu',
        center=0.06,
        step_size=0.1,
        n_steps=50,
        keep_outliers=False,
        max_height=1000.0,
        max_path_height=1000,
        precomputed_contact_dir=contact_simulated_dir,
        # NOTE(review): no tree directory is passed even though the experiment
        # assumes ground-truth trees; presumably the precomputed ancestral
        # states directory carries everything the pipeline needs -- confirm.
        precomputed_tree_dir=None,
        precomputed_maximum_parsimony_dir=ancestral_states_simulated_dir,
        learn_pairwise_model=True,
    )
    pipeline.run()

run_pipeline_w_ancestral_states()