# `triple-threat` proof of concept
Proof of concept for a triple-fitting, cherry-picking, tree-building algorithm.

In [None]:
from itertools import combinations
from collections import Counter, defaultdict

from cogent3 import load_aligned_seqs, PhyloNode
import numpy as np

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

# TensorFlow 2.x executes eagerly by default; this call only reports the
# current mode (the notebook cell displays the returned bool).
tf.executing_eagerly()

In [None]:
# Ask TensorFlow to grow its GPU memory use on demand instead of grabbing
# all of every visible GPU's memory up front.
# thanks https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

## Data import
Reads an alignment and creates a list of 4 x 4 x 4 joint-frequency tensors, one for each triple of taxa.

In [None]:
def get_triples(aln, nuc_order='ACGT', codon_position=None, verbose=False):
    """Count nucleotide patterns for every triple of sequences in ``aln``.

    Args:
        aln: alignment object supporting slicing, ``no_degenerates()``,
            ``len()``, ``num_seqs`` and ``get_sub_alignment(seqs=...)``;
            iterating a sub-alignment is expected to yield one 3-character
            column per position.
        nuc_order: the four nucleotide characters, in the order used for the
            axes of each count tensor. Any other character raises KeyError.
        codon_position: if truthy, keep only the positions selected by
            ``aln[codon_position - 1::3]`` (i.e. 1, 2 or 3).
        verbose: print the number of positions that are counted.

    Returns:
        A list with one ``[names, F]`` entry per 3-combination of sequence
        indices (in ``itertools.combinations`` order). ``names`` is the tuple
        of sub-alignment names and ``F`` is an int32 4x4x4 array where
        ``F[i, j, k]`` counts columns reading
        ``nuc_order[i], nuc_order[j], nuc_order[k]``.
    """
    if codon_position:
        aln = aln[codon_position - 1::3]
    aln = aln.no_degenerates()
    n_positions = len(aln)
    if verbose:
        print(f'Got {n_positions} positions')
    # No cell can exceed the number of columns, so int32 counts are safe.
    assert n_positions <= np.iinfo(np.int32).max
    index_of = dict(zip(nuc_order, range(len(nuc_order))))

    def _count_columns(sub):
        # Tally each (first, second, third) column of the 3-sequence alignment.
        counts = np.zeros((4, 4, 4), dtype=np.int32)
        for column in sub:
            first, second, third = column
            counts[index_of[first], index_of[second], index_of[third]] += 1
        return counts

    result = []
    for seq_indices in combinations(range(aln.num_seqs), 3):
        sub = aln.get_sub_alignment(seqs=seq_indices)
        counts = _count_columns(sub)
        result.append([tuple(sub.names), counts])
    return result

## Triple fitting functions
Collection of functions for concurrent fitting of many triples on CPUs and GPUs. Model is rooted, continuous-time, and strand-symmetric.

Also includes functions that use Akaike-style weights to find the pair of taxa most likely to form a cherry.

In [None]:
@tf.function()
def transform_P_matrix(params):
    """Turn a (2, 3) block of log-rates into a 4x4 transition matrix expm(Q).

    Row 0 of ``params`` gives the three off-diagonal rates leaving state 0;
    row 1 gives the rates leaving state 1 (into states 0, 2 and 3). Rows 2
    and 3 of Q are rows 1 and 0 reversed. With the default ACGT ordering
    (complement maps index i to 3 - i), this makes Q strand-symmetric. Every
    row of Q sums to zero.
    """
    rates = tf.exp(params)
    row_0 = tf.concat([[-tf.reduce_sum(rates[0])], rates[0]], axis=0)
    row_1 = tf.concat([rates[1, :1], [-tf.reduce_sum(rates[1])], rates[1, 1:]], axis=0)
    Q = tf.stack([row_0, row_1, row_1[::-1], row_0[::-1]])
    return tf.linalg.expm(Q)

@tf.function()
def transform(params):
    """Decode one (triple, rooting) parameter block into model quantities.

    Row 0 goes through SoftmaxCentered to give the root distribution ``pi``.
    Rows 1-2, 3-4, 5-6 and 7-8 each go through ``transform_P_matrix``.

    Returns:
        ``(pi, Pa, Pm, Pb, Pc)``.
    """
    pi = tfb.SoftmaxCentered()(params[0])
    matrices = [transform_P_matrix(params[start:start + 2]) for start in (1, 3, 5, 7)]
    return (pi, *matrices)
    
@tf.function()
def _loss(params_data):
    """KL divergence from the observed joint distribution to one fitted rooted triple.

    Args:
        params_data: ``(params, data)`` for a single (triple, rooting).
            ``params`` is the parameter block decoded by ``transform``.
            ``data`` is the observed 4x4x4 joint distribution (``F / F.sum()``
            in ``fit_triples``): axis 0 is the taxon attached to the root,
            axes 1 and 2 are the cherry taxa.

    Returns:
        Scalar ``KL(data || J) = sum(data*log(data)) - sum(data*log(J))``,
        where ``J`` is the model joint distribution. For ``data = F / N``
        this equals ``-(log-likelihood of F) / N`` plus a constant, which is
        the scale ``check_cherry`` assumes when it multiplies the loss by N.
    """
    params, data = params_data
    pi, Pa, Pm, Pb, Pc = transform(params)
    # J[j, u, v] = sum_{i,k} pi[i] Pa[i,j] Pm[i,k] Pb[k,u] Pc[k,v]
    J = tf.einsum('i,ij,ik,ku,kv', pi, Pa, Pm, Pb, Pc)
    # Floor J so an underflowed model cell cannot give log(0) = -inf.
    J = tf.maximum(J, np.finfo(np.float32).tiny)
    # The observed data is the reference distribution (maximum likelihood).
    # xlogy(0, y) == 0, so empty data cells contribute nothing.
    return tf.reduce_sum(tf.math.xlogy(data, data) - tf.math.xlogy(data, J))
    
@tf.function()
def loss(params, data):
    """Total loss over all (triple, rooting) blocks, vectorised with tf.vectorized_map.

    ``params`` and ``data`` are stacked along axis 0, one entry per block.
    """
    # could do better managing the variance in the shared matrix case here
    per_block = tf.vectorized_map(_loss, (params, data))
    return tf.reduce_sum(per_block)

@tf.function()
def training_step(parameters, data, optimizer, unscrambler):
    """Return ``(total_loss, d total_loss / d parameters)``.

    ``parameters`` is gathered into per-block form with ``_unscramble``
    before the loss is evaluated. ``optimizer`` is accepted but not used
    here; the caller applies the gradients.
    """
    with tf.GradientTape() as tape:
        total = loss(_unscramble(parameters, unscrambler), data)
    return total, tape.gradient(total, parameters)

# thanks https://github.com/mlgxmez/thelongrun_notebooks/blob/master/MLE_tutorial.ipynb
def mle_fit(data, loss, parameters, optimizer, unscrambler, steps=500, verbose=False):
    """Run ``steps`` gradient updates of ``parameters`` in place.

    Each step calls the module-level ``training_step`` and applies the
    gradient with ``optimizer``. The ``loss`` argument is not used; the loss
    comes from ``training_step``. With ``verbose``, a progress line is
    printed every 100 steps.
    """
    for step in range(steps):
        loss_value, grads = training_step(parameters, data, optimizer, unscrambler)
        optimizer.apply_gradients([(grads, parameters)])
        should_report = verbose and step % 100 == 0
        if should_report:
            print(f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}")

def check_cherry(ls, N):
    """Turn the per-rooting losses of one triple into normalised weights.

    Computes Akaike-style weights ``exp(-(N*ls - min(N*ls)))`` and rescales
    them to sum to one, so the rooting with the smallest loss gets the
    largest weight. ``ls`` is expected to be a numpy array.
    """
    scaled = N * ls
    relative = np.exp(scaled.min() - scaled)
    return relative / np.sum(relative)

@tf.function()
def _unscramble(parameters, unscramble):
    """Assemble one parameter block per (triple, rooting) from shared rows.

    ``unscramble`` is a list of row-index lists. The result stacks
    ``parameters[i]`` for each index in each list, giving a tensor of shape
    ``(len(unscramble), len(index_list), parameters.shape[1])``. The index
    lists must have equal length for ``tf.stack``.
    """
    return tf.stack([tf.stack([parameters[row] for row in rows]) for rows in unscramble])

def fit_triples(triples, learning_rate=0.01, cherries_share_matrices=True, steps=3000, verbose=False):
    """Fit the rooted triple model to every triple under all three rootings.

    Each ``(names, F)`` triple's normalised counts ``F / F.sum()`` are fitted
    three times, once per rooting ``ix`` in
    ``[[0, 1, 2], [1, 2, 0], [2, 0, 1]]``. Taxon ``names[ix[0]]`` attaches to
    the root through Pa. Taxa ``names[ix[1]]`` and ``names[ix[2]]`` form the
    cherry: Pm leads to their ancestor, then Pb and Pc lead to each taxon.
    All rootings of all triples are optimised jointly with Adam.

    Each (triple, rooting) uses a 9-row block of a shared ``[K, 3]``
    variable: 1 row for pi, 2 each for Pa and Pm, and 2 each for the two
    cherry taxa. ``unscrambler`` records which rows form each block.

    Args:
        triples: list of ``(names, F)`` with ``F`` a 4x4x4 array of counts.
        learning_rate: Adam learning rate.
        cherries_share_matrices: if True, the cherry-taxon rows are allocated
            once per unordered cherry pair and reused by every triple and
            rooting with that cherry. If False, every block gets its own rows.
        steps: number of optimisation steps.
        verbose: print progress every 100 steps.

    Returns:
        ``(losses, fits)``, each a list with one entry per triple.
        ``losses[t]`` holds the three per-rooting losses. ``fits[t]`` holds
        three ``[pi, Pa, Pm, Pb, Pc]`` lists of numpy arrays. Index i in both
        is the rooting whose root-attached taxon is ``names[i]``.
    """
    K = 0  # next free row of the shared parameter variable
    cherry_loc = {}  # frozenset(cherry) -> {taxon name: its two rows}
    unscrambler = []  # one 9-row index list per (triple, rooting)
    data = []
    for names, F in triples:
        J = (F/F.sum()).astype(np.float32)
        for ix in [[0, 1, 2], [1, 2, 0], [2, 0, 1]]:
            t1 = list(range(K, K+5)) # rows for pi (1), Pa (2) and Pm (2)
            K += 5
            cherry = [names[ix[1]], names[ix[2]]]
            frozen_cherry = frozenset(cherry)
            # Allocate fresh rows for the cherry taxa unless sharing is on and
            # this cherry pair already has rows.
            if not cherries_share_matrices or frozen_cherry not in cherry_loc:
                new_loc = {cherry[0]: [K,K+1], cherry[1]: [K+2,K+3]}
                K += 4
                cherry_loc[frozen_cherry] = new_loc
            t1.extend(cherry_loc[frozen_cherry][cherry[0]]) # Pb rows (first cherry taxon)
            t1.extend(cherry_loc[frozen_cherry][cherry[1]]) # Pc rows (second cherry taxon)
            unscrambler.append(t1)
        
            # Reorder axes so axis 0 is the root-attached taxon for this rooting.
            data.append(J.transpose(ix))

    normal_initializer = tf.random_normal_initializer()
    parameters = tf.Variable(normal_initializer(shape=[K, 3], dtype=tf.float32), name='params')
    data = tf.stack(data)

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, unscrambler, steps=steps, verbose=verbose)
    
    # Evaluate the fitted blocks and regroup them three rootings per triple.
    parameters = _unscramble(parameters, unscrambler)
    losses = tf.vectorized_map(_loss, (parameters, data)).numpy()
    losses = [losses[i:i+3] for i in range(0, len(losses), 3)]
    fits = [[p.numpy() for p in transform(params)] for params in parameters]
    fits = [fits[i:i+3] for i in range(0, len(fits), 3)]
    return losses, fits

def pick_cherry(root_probs):
    """Rank candidate cherries by their summed log root-probability.

    Each element of ``root_probs`` maps the three names of a triple to the
    weight of the rooting in which that name is attached to the root, which
    makes the other two names a cherry. Those log-weights are summed per
    unordered pair across triples.

    Returns:
        ``Counter.most_common()`` output: ``(frozenset(pair), total)`` tuples,
        highest total first.
    """
    totals = Counter()
    for outgroup_probs in root_probs:
        trio = frozenset(outgroup_probs)
        for outgroup in trio:
            pair = trio.difference((outgroup,))
            totals[pair] = totals[pair] + np.log(outgroup_probs[outgroup])
    return totals.most_common()

def get_cherries(triples, losses, verbose=False):
    """Return the pair of taxa most likely to form a cherry.

    Args:
        triples: list of ``(names, F)``. ``F.sum()`` is the sample size N
            passed to ``check_cherry``.
        losses: per-triple length-3 losses from ``fit_triples``. Entry i is
            the rooting in which ``names[i]`` attaches to the root.
        verbose: print the summed log-weight of every candidate pair.

    Returns:
        The two names of the best pair as a sorted list. Sorting makes
        merged node names (``'-'.join(...)``) reproducible across runs;
        frozenset iteration order depends on string hash randomisation.

    Raises:
        ValueError: if there are no triples to score.
    """
    root_probs = []
    # Use a distinct loop name so the ``losses`` argument is not shadowed.
    for triple_losses, (names, F) in zip(losses, triples):
        probs = check_cherry(triple_losses, F.sum())
        root_probs.append(dict(zip(names, probs)))

    cherry_llik = pick_cherry(root_probs)
    if not cherry_llik:
        raise ValueError('no triples to pick a cherry from')
    if verbose:
        for pair, ll in cherry_llik:
            n1, n2 = sorted(pair)
            print(f'{n1}, {n2}: {ll}')

    return sorted(cherry_llik[0][0])

def get_Ps(cherry_names, triples, fits):
    """Return the fitted transition matrices of the two cherry taxa.

    Uses the first triple that contains both cherry taxa. It takes the
    rooting in which the third taxon attaches to the root, so the cherry
    taxa sit at ``Pb`` and ``Pc`` (the last two entries of the fit).

    Args:
        cherry_names: the two cherry taxon names.
        triples: list of ``(names, F)``, aligned with ``fits``.
        fits: per triple, three ``[pi, Pa, Pm, Pb, Pc]`` lists; index i is
            the rooting in which ``names[i]`` attaches to the root.

    Returns:
        ``(P for cherry_names[0], P for cherry_names[1])``.

    Raises:
        ValueError: if no triple contains both cherry taxa.
    """
    cherry = set(cherry_names)
    ixes = [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
    for (names, _), triple_fit in zip(triples, fits):
        if not cherry < set(names):
            continue
        for name, ix, fit in zip(names, ixes, triple_fit):
            if name in cherry:
                continue
            # fit[-2] belongs to names[ix[1]] and fit[-1] to names[ix[2]].
            if names[ix[1]] == cherry_names[0]:
                return fit[-2], fit[-1]
            return fit[-1], fit[-2]
    raise ValueError(f'no triple contains both cherry taxa {list(cherry_names)}')

## Merge-step functions
Functions that merge pairs of triples to create merged triples. The joint probability tensors for all merged triples are fitted concurrently.

In [None]:
def merge_cherries(triples, cherries, learning_rate=0.01, steps=3000, info_decay=0.9, verbose=False):
    """Replace the two cherry taxa by a single merged node in the set of triples.

    Triples with neither cherry taxon are kept unchanged, and triples with
    both are dropped. Each pair of triples ``(u1, u2, cherry_names[0])`` and
    ``(u1, u2, cherry_names[1])`` is replaced by one fitted triple
    ``(u1, u2, merged)``, where ``merged = '-'.join(cherry_names)``.

    The joint distribution J of ``(u1, u2, merged)`` is fitted so that
    passing the merged node through each cherry's transition matrix
    reproduces both observed joints::

        sum_m J[a, b, m] * cherry_Ps[k][m, v]  ~  observed_k[a, b, v]

    This minimises ``KL(observed_k || model_k)`` summed over k = 0, 1.

    Args:
        triples: list of ``(names, F)`` with three names and a 4x4x4 array of
            (possibly fractional) counts.
        cherries: ``(cherry_names, cherry_Ps)``. ``cherry_Ps[k]`` is the 4x4
            transition matrix leading to ``cherry_names[k]``.
        learning_rate, steps, verbose: Adam settings and progress printing.
        info_decay: multiplier on the total count given to each merged
            triple; the total is the mean count of its two source triples.

    Returns:
        List of ``(names, F)`` triples for the next iteration.

    Raises:
        ValueError: if a pair of other taxa is seen with only one of the two
            cherry taxa.
    """
    cherry_names, cherry_Ps = cherries
    merged_name = '-'.join(cherry_names)
    cherry_set = frozenset(cherry_names)

    keepers = []
    # (u1, u2) -> [observed joint with cherry_names[0], ... with cherry_names[1]]
    to_be_merged = defaultdict(lambda: [None, None])
    # (u1, u2) -> total counts of its source triples, used to rescale the fit
    totals = defaultdict(list)
    for names, F in triples:
        uncommon = frozenset(names) - cherry_set
        if len(uncommon) == 3:
            keepers.append((names, F))
        elif len(uncommon) == 2:
            name = (set(names) - uncommon).pop()
            new_names = sorted(uncommon) + [name]
            ix = [names.index(n) for n in new_names]
            J = F.transpose(ix).astype(np.float32)
            J /= J.sum()
            key = tuple(new_names[:2])
            to_be_merged[key][name == cherry_names[1]] = J
            totals[key].append(F.sum())
        # Triples containing both cherry taxa are dropped.

    if not to_be_merged:
        return keepers
    for key, pair in to_be_merged.items():
        if pair[0] is None or pair[1] is None:
            raise ValueError(
                f'taxa {key} appear with only one of the cherry taxa {list(cherry_names)}')

    P0 = tf.cast(cherry_Ps[0], tf.float32)
    P1 = tf.cast(cherry_Ps[1], tf.float32)
    tiny = float(np.finfo(np.float32).tiny)

    @tf.function()
    def transform(params):
        # 63 unconstrained values -> 64 probabilities -> 4x4x4 joint.
        return tf.reshape(tfb.SoftmaxCentered()(params), (4, 4, 4))

    def _kl(observed, model):
        # KL(observed || model); xlogy(0, .) == 0 so empty cells contribute 0.
        model = tf.maximum(model, tiny)
        return tf.reduce_sum(tf.math.xlogy(observed, observed) - tf.math.xlogy(observed, model))

    @tf.function()
    def _loss(params_data):
        params, data = params_data
        J = transform(params)
        # data[k] is the observed (u1, u2, cherry_names[k]) joint; the merged
        # node reaches cherry_names[k] through cherry_Ps[k].
        model0 = tf.einsum('abm,mv->abv', J, P0)
        model1 = tf.einsum('abm,mv->abv', J, P1)
        # could do better managing the variance here
        return _kl(data[0], model0) + _kl(data[1], model1)

    @tf.function()
    def loss(params, data):
        return tf.reduce_sum(tf.vectorized_map(_loss, (params, data)))

    @tf.function()
    def training_step(parameters, data, optimizer):
        with tf.GradientTape() as tape:
            loss_value = loss(parameters, data)
        gradients = tape.gradient(loss_value, parameters)
        return loss_value, gradients

    def mle_fit(data, loss, parameters, optimizer, steps, verbose):
        for i in range(steps):
            loss_value, gradients = training_step(parameters, data, optimizer)
            optimizer.apply_gradients([(gradients, parameters)])
            if verbose and i % 100 == 0:
                print(f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}")

    normal_initializer = tf.random_normal_initializer()
    K = len(to_be_merged)
    parameters = tf.Variable(normal_initializer(shape=[K, 63], dtype=tf.float32), name='params')
    # Shape (K, 2, 4, 4, 4): the two observed joints for each merged pair.
    data = tf.stack(list(to_be_merged.values()))

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, steps=steps, verbose=verbose)

    for key, params in zip(to_be_merged, parameters.numpy()):
        N = np.mean(totals[key])
        keepers.append((list(key) + [merged_name], info_decay * N * transform(params).numpy()))

    return keepers

## Tree building algorithm
Functions for using cherry-picking to build trees.

In [None]:
def update_tree(forest, cherries):
    """Join the two cherry subtrees of ``forest`` under a new node, in place.

    ``cherries`` is ``(cherry_names, cherry_Ps)``. Each named subtree is
    removed from ``forest``, given its matching matrix as attribute ``P``
    and appended to a new ``PhyloNode`` named ``'-'.join(cherry_names)``.
    That node is then stored in ``forest`` under its name. Raises KeyError
    if a cherry name is not in ``forest``.
    """
    names, matrices = cherries
    parent = PhyloNode('-'.join(names))
    for taxon, matrix in zip(names, matrices):
        subtree = forest.pop(taxon)
        subtree.P = matrix
        parent.append(subtree)
    forest[parent.name] = parent

def triple_threat(triples, cherries_share_matrices=False,
                  learning_rate=0.01, steps=3000, info_decay=0.9, verbose=False):
    """Build a rooted tree by repeatedly picking and merging cherries.

    Each iteration does the following:
      * fits all triples under all rootings;
      * picks the most probable cherry and estimates its two transition
        matrices;
      * joins the two taxa under a new node in ``forest``;
      * replaces them by that node in the triples via ``merge_cherries``.

    The loop stops once a single triple remains. The root distribution and
    the two root edges then come from that triple's rooting whose
    root-attached taxon is the one not merged in the final step.

    Args:
        triples: iterable of ``(names, F)`` as produced by ``get_triples``.
        cherries_share_matrices: if True, the cherry matrices come from the
            search fit itself. If False, the cherry is refitted with shared
            matrices over the triples that contain it.
        learning_rate, steps, verbose: passed to the fitting functions.
        info_decay: passed to ``merge_cherries``.

    Returns:
        The root ``PhyloNode``. It carries ``pi``, and every child carries
        ``P`` (the transition matrix of the edge above it).

    Raises:
        ValueError: if ``triples`` is empty.
    """
    triples = list(triples)
    if not triples:
        raise ValueError('need at least one triple (three taxa)')
    forest = {n: PhyloNode(n) for names, _ in triples for n in names}
    while True:
        if verbose:
            print('Looking for cherries')
        losses, fits = fit_triples(triples, cherries_share_matrices=cherries_share_matrices,
                                   learning_rate=learning_rate, steps=steps, verbose=verbose)
        cherry_names = get_cherries(triples, losses, verbose=verbose)
        if verbose:
            print('Fitting cherries')
        if not cherries_share_matrices:
            triples_for_cherry = [t for t in triples if set(cherry_names) < set(t[0])]
            _, fits = fit_triples(triples_for_cherry, cherries_share_matrices=True,
                                  learning_rate=learning_rate, steps=steps, verbose=verbose)
            cherry_Ps = get_Ps(cherry_names, triples_for_cherry, fits)
        else:
            cherry_Ps = get_Ps(cherry_names, triples, fits)
        cherries = cherry_names, cherry_Ps
        update_tree(forest, cherries)
        if verbose:
            for tree in forest.values():
                print(tree)
        if len(triples) == 1:
            break
        if verbose:
            print('Merging cherries')
        triples = merge_cherries(triples, cherries, learning_rate=learning_rate, steps=steps,
                                 info_decay=info_decay, verbose=verbose)

    # One triple remains. ``fits[0]`` holds its three rootings, where index i
    # is the rooting with final_names[i] attached to the root. The forest now
    # has two subtrees: the remaining taxon from that triple and the
    # just-merged cherry node.
    final_names = list(triples[0][0])
    outgroup = next(name for name in forest if name in final_names)
    root_fit = fits[0][final_names.index(outgroup)]
    tree = PhyloNode('-'.join(forest.keys()))
    # pi, Pa and Pm must all come from the same rooting.
    tree.pi = root_fit[0]
    Pa, Pm = root_fit[1], root_fit[2]
    for name, child in forest.items():
        child.P = Pa if name in final_names else Pm
        tree.append(child)
    return tree

# Some examples
## Example 1
Fit a rooted phylogeny of 5 mammals.

In [None]:
# Alternative input alignments (local paths), kept for reference:
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000197102.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000131018.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000179869.fa.gz', moltype="dna")
# Load the alignment in 'brca1.fasta' as DNA.
aln = load_aligned_seqs('brca1.fasta', moltype='dna')

In [None]:
# Restrict to sequences similar to the Human sequence, with min_similarity=0.83
# (NOTE(review): confirm the similarity measure used by cogent3's get_similar).
subaln = aln.get_similar(aln.take_seqs(['Human']).seqs[0],
                      min_similarity=0.83)
# Display the reduced alignment.
subaln

### All at once
First, run `triple-threat` all the way through. The run time is not fantastic for five taxa, but the algorithm fits non-stationary models to every edge and is expected to scale well.

In [None]:
%%time
# Build the 4x4x4 count tensors from third codon positions only, then run the
# full cherry-picking loop to get a rooted tree.
triples = get_triples(subaln, codon_position=3, verbose=True)
tree = triple_threat(triples, cherries_share_matrices=False, info_decay=0.9, verbose=True)

In [None]:
# Draw the fitted tree as a 600x600 radial figure.
fig = tree.get_figure("radial", width=600, height=600)
fig.show()

## Example 2
### A single iteration
Now run through a single iteration of the algorithm.
#### Fit triples
This step fits rooted, strand-symmetric, continuous-time models to every taxon triple.

In [None]:
%%time 
# Build triples from all positions of the full alignment and fit the three
# rootings of each, with separate (unshared) cherry matrices.
triples = get_triples(aln, verbose=True)
losses, fits = fit_triples(triples, cherries_share_matrices=False, verbose=True)

### Find the most probable cherry
This step uses model-selection calculations to find the pair of taxa most likely to form a cherry. It then fits that cherry simultaneously across all of the triples that contain it, to estimate its rate matrices.

In [None]:
# Pick the most probable cherry, then refit only the triples that contain both
# of its taxa, this time with shared cherry matrices.
cherry_names = get_cherries(triples, losses, verbose=True)
triples_for_cherry = [t for t in triples if set(cherry_names) < set(t[0])]
_, fits = fit_triples(triples_for_cherry, cherries_share_matrices=True, verbose=True)

This step extracts the simultaneously fitted transition probability matrices for the cherry taxa.

In [None]:
# Transition matrices for cherry_names[0] and cherry_names[1], in that order.
cherry_Ps = get_Ps(cherry_names, triples_for_cherry, fits)

Test a single agglomeration step.

In [None]:
# Seed the forest with one leaf node per taxon. update_tree looks up both
# cherry taxa in it, so starting from an empty dict raises KeyError.
forest = {n: PhyloNode(n) for names, _ in triples for n in names}
cherries = cherry_names, cherry_Ps
update_tree(forest, cherries)

These were the cherry taxa and their corresponding transition probability matrices.

In [None]:
# (cherry_names, cherry_Ps)
cherries

In [None]:
# Remaining subtrees, keyed by node name.
forest

These were the data that were fitted simultaneously to find the above rate matrices for Dog and Human.

In [None]:
# The triples that contain both cherry taxa.
triples_for_cherry

### Merge the cherries
Merge the cherries by fitting the triples between the unmerged taxa and the new merged node, using the probability matrices that we fitted above.

In [None]:
%%time
triples = merge_cherries(triples, cherries, verbose=True)

At the end of the first iteration, the triples now consist of unmerged taxa and the newly created merged node (called "Dog-Human").

In [None]:
# Triples for the next iteration.
triples