# `triple-threat` proof of concept
Proof of concept for a triple-fitting, cherry-picking, tree-building algorithm.

In [1]:
from itertools import combinations
from collections import Counter, defaultdict

from cogent3 import load_aligned_seqs, PhyloNode
import numpy as np

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

tf.executing_eagerly()  # need to check whether this is the default for tensorflow > 2

2021-09-02 14:15:42.661246: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


True

In [3]:
# this stops tensorflow from snaffling all of the GPU
# thanks https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

2021-09-02 14:15:44.238394: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2021-09-02 14:15:44.279961: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-09-02 14:15:44.280270: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 2080 with Max-Q Design computeCapability: 7.5
coreClock: 1.095GHz coreCount: 46 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 357.69GiB/s
2021-09-02 14:15:44.280292: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-09-02 14:15:44.281452: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-09-02 14:15:44.282516: I tensorflow/s

## Data import
Reads an alignments and creates a list of 4 x 4 x 4 joint frequencies tensors.

In [4]:
def get_triples(aln, nuc_order='ACGT', codon_position=3, verbose=False):
    aln = aln[codon_position - 1::3].no_degenerates()
    if verbose:
        print(f'Got {len(aln)} positions')
    assert len(aln) <= np.iinfo(np.int32).max
    triples = []
    nuc_map = {n:i for i, n in enumerate(nuc_order)}
    for triple in combinations(range(aln.num_seqs), 3):
        F = np.zeros([4, 4, 4], dtype=np.int32)
        subaln = aln.get_sub_alignment(seqs=triple)
        for a, b, c in subaln:
            F[nuc_map[a], nuc_map[b], nuc_map[c]] += 1
        triples.append([tuple(subaln.names), F])
    return triples

## Triple fitting functions
Collection of functions for concurrent fitting of many triples on CPUs and GPUs. Model is rooted, continuous-time, and strand-symmetric.

Also some functions for using Akaike-ish weights to find the pair that is most probably a cherry.

In [5]:
@tf.function()
def transform_P_matrix(params):
    params = tf.exp(params)
    Q0 = tf.concat([[-tf.reduce_sum(params[0])], params[0]],
                   axis=0)
    Q1 = tf.concat([[params[1,0]], [-tf.reduce_sum(params[1])], params[1,1:]],
                   axis=0)
    Q = tf.concat([[Q0], [Q1], [Q1[::-1]], [Q0[::-1]]], axis=0)
    return tf.linalg.expm(Q)

@tf.function()
def transform(params):
    pi = tfb.SoftmaxCentered()(params[0])
    Pa = transform_P_matrix(params[1:3])
    Pm = transform_P_matrix(params[3:5])
    Pb = transform_P_matrix(params[5:7])
    Pc = transform_P_matrix(params[7:9])
    return pi, Pa, Pm, Pb, Pc
    
@tf.function()
def _loss(params_data):
    params, data = params_data
    pi, Pa, Pm, Pb, Pc = transform(params)
    J = tf.einsum('i,ij,ik,ku,kv', pi, Pa, Pm, Pb, Pc)
    loss = tf.reduce_sum(tf.keras.losses.KLDivergence()(J, data))
    return loss
    
@tf.function()
def loss(params, data):
    # could do better managing the variance in the shared matrix case here
    return tf.reduce_sum(tf.vectorized_map(_loss, (params, data)))

@tf.function()
def training_step(parameters, data, optimizer, unscrambler):
    with tf.GradientTape() as tape:
        unscrambled = _unscramble(parameters, unscrambler)
        loss_value = loss(unscrambled, data)
    gradients = tape.gradient(loss_value, parameters)
    return loss_value, gradients

# thanks https://github.com/mlgxmez/thelongrun_notebooks/blob/master/MLE_tutorial.ipynb
def mle_fit(data, loss, parameters, optimizer, unscrambler, steps=500, verbose=False):
    for i in range(steps):
        loss_value, gradients = training_step(parameters, data, optimizer, unscrambler)
        optimizer.apply_gradients([(gradients, parameters)])
        
        if i % 100 == 0:
            if verbose:
                iter_info = f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}"
                print(iter_info)

def check_cherry(ls, N):
    ls = N*ls
    delta = ls - ls.min()
    weights = np.exp(-delta)
    return weights/weights.sum()

@tf.function()
def _unscramble(parameters, unscramble):
    unscrambled = []
    for t1 in unscramble:
        unscrambled.append(tf.stack([parameters[i] for i in t1]))
    return tf.stack(unscrambled)

def fit_triples(triples, learning_rate=0.01, cherries_share_matrices=True, steps=3000, verbose=False):
    K = 0
    cherry_loc = {}
    unscrambler = []
    data = []
    for names, F in triples:
        J = (F/F.sum()).astype(np.float32)
        for ix in [[0, 1, 2], [1, 2, 0], [2, 0, 1]]:
            t1 = list(range(K, K+5)) # for pi and two Ps
            K += 5
            cherry = [names[ix[1]], names[ix[2]]]
            frozen_cherry = frozenset(cherry)
            if not cherries_share_matrices or frozen_cherry not in cherry_loc:
                new_loc = {cherry[0]: [K,K+1], cherry[1]: [K+2,K+3]}
                K += 4
                cherry_loc[frozen_cherry] = new_loc
            t1.extend(cherry_loc[frozen_cherry][cherry[0]]) # for the first cherry
            t1.extend(cherry_loc[frozen_cherry][cherry[1]]) # for the second cherry
            unscrambler.append(t1)
        
            data.append(J.transpose(ix))

    normal_initializer = tf.random_normal_initializer()
    parameters = tf.Variable(normal_initializer(shape=[K, 3], dtype=tf.float32), name='params')
    data = tf.stack(data)

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, unscrambler, steps=steps, verbose=verbose)
    
    parameters = _unscramble(parameters, unscrambler)
    losses = tf.vectorized_map(_loss, (parameters, data)).numpy()
    losses = [losses[i:i+3] for i in range(0, len(losses), 3)]
    fits = [[p.numpy() for p in transform(params)] for params in parameters]
    fits = [fits[i:i+3] for i in range(0, len(fits), 3)]
    return losses, fits

def pick_cherry(root_probs):
    cherry_llik = Counter()
    for probs in root_probs:
        names = frozenset(probs.keys())
        for name in names:
            cherry_llik[names - {name}] += np.log(probs[name])
    return cherry_llik.most_common()

def get_cherries(triples, losses, verbose=False):
    root_probs = []
    for losses, (names, F) in zip(losses, triples):
        probs = check_cherry(losses, F.sum())
        root_probs.append(dict(zip(names, probs)))
    
    cherry_llik = pick_cherry(root_probs)
    if verbose:
        for (n1, n2), ll in cherry_llik:
            print(f'{n1}, {n2}: {ll}')
    
    return list(cherry_llik[0][0])

def get_Ps(cherry_names, triples, fits):
    Ps = {}
    ixes = [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
    for (names, _), triple_fit in zip(triples, fits):
        if set(cherry_names) < set(names):
            for name, ix, fit in zip(names, ixes, triple_fit):
                if name not in cherry_names:
                    if names[ix[1]] == cherry_names[0]:
                        return fit[-2], fit[-1]
                    return fit[-1], fit[-2]

## Merge-step functions
Functions that merge pairs of of triples to create merged triples. Concurrently creates joint probability matrices for all merged triples.

In [6]:
def merge_cherries(triples, cherries, learning_rate=0.01, steps=3000, verbose=False):
    keepers = []
    to_be_merged = defaultdict(lambda: [None, None])
    cherry_names, cherry_Ps = cherries
    for names, F in triples:
        num_uncommon = len(set(names) - set(cherry_names))
        if num_uncommon == 3:
            keepers.append((names, F))
        elif num_uncommon == 2:
            uncommon = frozenset(names) - frozenset(cherry_names)
            name = (set(names) - uncommon).pop()
            new_names = sorted(uncommon) + [name]
            ix = [names.index(n) for n in new_names]
            J = F.transpose(ix).astype(np.float32)
            J /= J.sum()
            N = F.sum()  # make no mistake, this is a kludge
            to_be_merged[tuple(new_names[:2])][name == cherry_names[1]] = J

    @tf.function()
    def transform(params):
        return tf.reshape(tfb.SoftmaxCentered()(params), (4, 4, 4))
    
    @tf.function()
    def _loss(params_data):
        params, data = params_data
        J = transform(params)
        twoJs = [J @ cherry_Ps[0], J @ cherry_Ps[1]]
        # could do better managing the variance here
        loss = tf.reduce_sum(tf.keras.losses.KLDivergence()(J, data))
        return loss
    
    @tf.function()
    def loss(params, data):
        return tf.reduce_sum(tf.vectorized_map(_loss, (params, data)))

    @tf.function()
    def training_step(parameters, data, optimizer):
        with tf.GradientTape() as tape:
            loss_value = loss(parameters, data)
        gradients = tape.gradient(loss_value, parameters)
        return loss_value, gradients
    
    def mle_fit(data, loss, parameters, optimizer, steps, verbose):
        for i in range(steps):
            loss_value, gradients = training_step(parameters, data, optimizer)
            optimizer.apply_gradients([(gradients, parameters)])

            if i % 100 == 0:
                if verbose:
                    iter_info = f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}"
                    print(iter_info)
    
    normal_initializer = tf.random_normal_initializer()
    K = len(to_be_merged)
    parameters = tf.Variable(normal_initializer(shape=[K, 63], dtype=tf.float32), name='params')
    data = tf.stack(list(to_be_merged.values()))
    
    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, steps=steps, verbose=verbose)
    
    for names, params in zip(to_be_merged, parameters.numpy()):
        keepers.append((list(names) + ['-'.join(cherry_names)], N*transform(params).numpy()))
        
    return keepers

## Tree building algorithm
Functions for using cherry-picking to build trees.

In [42]:
def update_tree(forest, cherries):
    new_node = PhyloNode('-'.join(cherries[0]))
    for name, P in zip(*cherries):
        child = forest[name]
        child.P = P
        new_node.append(child)
        del forest[name]
    forest[new_node.name] = new_node

def triple_threat(triples, cherries_share_matrices=False, learning_rate=0.01, steps=3000, verbose=False):
    triples = list(triples)
    forest = {n: PhyloNode(n) for names, _ in triples for n in names}
    while True:
        if verbose:
            print('Looking for cherries')
        losses, fits = fit_triples(triples, cherries_share_matrices=cherries_share_matrices,
                                   learning_rate=learning_rate, steps=steps, verbose=verbose)
        cherry_names = get_cherries(triples, losses, verbose=verbose)
        if verbose:
            print('Fitting cherries')
        if not cherries_share_matrices:
            triples_for_cherry = [t for t in triples if set(cherry_names) < set(t[0])]
            _, fits = fit_triples(triples_for_cherry, cherries_share_matrices=True,
                                  learning_rate=learning_rate, steps=steps, verbose=verbose)
            cherry_Ps = get_Ps(cherry_names, triples_for_cherry, fits)
        else:
            cherry_Ps = get_Ps(cherry_names, triples, fits)
        cherries = cherry_names, cherry_Ps
        update_tree(forest, cherries)
        if verbose:
            for tree in forest.values():
                print(tree)
        if len(triples) == 1:
            break
        if verbose:
            print('Merging cherries')
        triples = merge_cherries(triples, cherries, learning_rate=learning_rate, steps=steps, verbose=verbose)
    tree = PhyloNode('-'.join(forest.keys()))
    tree.pi = fits[0][0][0]
    for name, child in forest.items():
        if name in triples[0][0]:
            fit = fits[0][triples[0][0].index(name)]
            Pa = fit[1]
            Pm = fit[2]
    for name, child in forest.items():
        child.P = Pa if name in triples[0][0] else Pm
        tree.append(child)
    return tree

# Some examples
## Example 1
Fit a rooted phylogeny of 5 mammals.

In [None]:
aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000197102.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000131018.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000179869.fa.gz', moltype="dna")

In [44]:
aln

0,1
,0
Platypus,ATGTCGGAG------GGCGGCGGCGGCGAGGACGGCTCGTCGGGCCTGGAGGTGTCGGCG
Dog,.........CCGGGC........................G...................C
Human,.........CCCGGG........................G.C..AT....A........C
Mouse,.........CCG...---.....................G.C..................
Opossum,---------......---------------------------------------------


### All at once
First run `triple-threat` all the way through. Run time is not fantastic for five taxa, but it also fits non-stationary models to all edges, and is expected to scale well.

In [47]:
%%time
triples = get_triples(aln, verbose=True)
tree = triple_threat(triples, verbose=True)

Got 4406 positions
Looking for cherries
Step: 1, initial loss: 2.0050997734069824
Step: 101, initial loss: 1.699263334274292
Step: 201, initial loss: 1.1660088300704956
Step: 301, initial loss: 0.4601401388645172
Step: 401, initial loss: 0.34876924753189087
Step: 501, initial loss: 0.2980298101902008
Step: 601, initial loss: 0.27511167526245117
Step: 701, initial loss: 0.265436589717865
Step: 801, initial loss: 0.25224578380584717
Step: 901, initial loss: 0.2409302294254303
Step: 1001, initial loss: 0.23519524931907654
Step: 1101, initial loss: 0.2128434181213379
Step: 1201, initial loss: 0.20457565784454346
Step: 1301, initial loss: 0.18996067345142365
Step: 1401, initial loss: 0.17771275341510773
Step: 1501, initial loss: 0.16422908008098602
Step: 1601, initial loss: 0.14813873171806335
Step: 1701, initial loss: 0.1435757577419281
Step: 1801, initial loss: 0.14247237145900726
Step: 1901, initial loss: 0.1379510760307312
Step: 2001, initial loss: 0.1373562216758728
Step: 2101, initial

Mouse, Dog-Human: -0.7382480502128601
Opossum, Platypus: -1.7115052938461304
Opossum, Dog-Human: -2.054489850997925
Mouse, Opossum: -37.92973393201828
Mouse, Platypus: -41.13019585609436
Dog-Human, Platypus: -42.67369222640991
Fitting cherries
Step: 1, initial loss: 0.37203001976013184
Step: 101, initial loss: 0.31469833850860596
Step: 201, initial loss: 0.21630287170410156
Step: 301, initial loss: 0.07224678993225098
Step: 401, initial loss: 0.05269111320376396
Step: 501, initial loss: 0.04841533303260803
Step: 601, initial loss: 0.046584099531173706
Step: 701, initial loss: 0.04116344824433327
Step: 801, initial loss: 0.034670811146497726
Step: 901, initial loss: 0.031730230897665024
Step: 1001, initial loss: 0.031393181532621384
Step: 1101, initial loss: 0.031220776960253716
Step: 1201, initial loss: 0.031097780913114548
Step: 1301, initial loss: 0.03100006654858589
Step: 1401, initial loss: 0.03091934323310852
Step: 1501, initial loss: 0.030855314806103706
Step: 1601, initial loss:

Step: 1201, initial loss: 0.022450005635619164
Step: 1301, initial loss: 0.016299208626151085
Step: 1401, initial loss: 0.012470549903810024
Step: 1501, initial loss: 0.011936123482882977
Step: 1601, initial loss: 0.011813578195869923
Step: 1701, initial loss: 0.01175751257687807
Step: 1801, initial loss: 0.011725084856152534
Step: 1901, initial loss: 0.0117051862180233
Step: 2001, initial loss: 0.011692282743752003
Step: 2101, initial loss: 0.011683222837746143
Step: 2201, initial loss: 0.011676255613565445
Step: 2301, initial loss: 0.011670678853988647
Step: 2401, initial loss: 0.011665689758956432
Step: 2501, initial loss: 0.011660991236567497
Step: 2601, initial loss: 0.011653145775198936
Step: 2701, initial loss: 0.011438918299973011
Step: 2801, initial loss: 0.010594790801405907
Step: 2901, initial loss: 0.010543818585574627
Platypus;
(Opossum,(Mouse,(Dog,Human)Dog-Human)Mouse-Dog-Human)Opossum-Mouse-Dog-Human;
CPU times: user 10min 8s, sys: 1min 26s, total: 11min 35s
Wall time: 

In [48]:
print(tree.ascii_art())

          /-Platypus
-Platypus-Opossum-Mouse-Dog-Human
         |          /-Opossum
          \Opossum-Mouse-Dog-Human
                   |          /-Mouse
                    \Mouse-Dog-Human
                             |          /-Dog
                              \Dog-Human
                                        \-Human


## Example 2
### A single iteration
Now run through a single iteration of the algorithm.
#### Fit triples
This step fits rooted, strand-symmetric, continuous-time models to every taxon triple.

In [10]:
%%time 
triples = get_triples(aln, verbose=True)
losses, fits = fit_triples(triples, cherries_share_matrices=False, verbose=True)

Got 4406 positions


2021-09-02 14:15:54.529725: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-09-02 14:15:54.551956: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2599990000 Hz
2021-09-02 14:15:54.552824: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55b8b2e3ac20 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-09-02 14:15:54.552856: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2021-09-02 14:15:54.631071: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2

Step: 1, initial loss: 2.0066494941711426
Step: 101, initial loss: 1.6990165710449219
Step: 201, initial loss: 1.1723039150238037
Step: 301, initial loss: 0.4476985037326813
Step: 401, initial loss: 0.3442932963371277
Step: 501, initial loss: 0.32133156061172485
Step: 601, initial loss: 0.3021972179412842
Step: 701, initial loss: 0.2701646387577057
Step: 801, initial loss: 0.25689423084259033
Step: 901, initial loss: 0.24404454231262207
Step: 1001, initial loss: 0.23771673440933228
Step: 1101, initial loss: 0.21686285734176636
Step: 1201, initial loss: 0.20941677689552307
Step: 1301, initial loss: 0.18028579652309418
Step: 1401, initial loss: 0.16758763790130615
Step: 1501, initial loss: 0.16270193457603455
Step: 1601, initial loss: 0.15450730919837952
Step: 1701, initial loss: 0.15144827961921692
Step: 1801, initial loss: 0.14939221739768982
Step: 1901, initial loss: 0.14767670631408691
Step: 2001, initial loss: 0.14716917276382446
Step: 2101, initial loss: 0.14671730995178223
Step: 2

### Find the most probably cherry
This step finds the pair of taxa that are most likely to be a cherries using model selection calculations, then simultaneously fits that cherry across all of the triples of which it is a part to estimate its rate matrices.

In [11]:
cherry_names = get_cherries(triples, losses, verbose=True)
triples_for_cherry = [t for t in triples if set(cherry_names) < set(t[0])]
_, fits = fit_triples(triples_for_cherry, cherries_share_matrices=True, verbose=True)

Dog, Human: -1.0707451105117798
Mouse, Dog: -1.8164032697677612
Mouse, Human: -1.8881000876426697
Opossum, Platypus: -2.5792858600616455
Mouse, Platypus: -46.93198621273041
Opossum, Human: -51.83248162269592
Opossum, Dog: -54.830931663513184
Mouse, Opossum: -80.58363389968872
Human, Platypus: -91.7210224866867
Dog, Platypus: -97.26261615753174
Step: 1, initial loss: 0.7077893018722534
Step: 101, initial loss: 0.6105999946594238
Step: 201, initial loss: 0.39786848425865173
Step: 301, initial loss: 0.1679067462682724
Step: 401, initial loss: 0.13688445091247559
Step: 501, initial loss: 0.13185706734657288
Step: 601, initial loss: 0.11637300997972488
Step: 701, initial loss: 0.09264581650495529
Step: 801, initial loss: 0.09011288732290268
Step: 901, initial loss: 0.08946087956428528
Step: 1001, initial loss: 0.08910869061946869
Step: 1101, initial loss: 0.08884220570325851
Step: 1201, initial loss: 0.08711621910333633
Step: 1301, initial loss: 0.07376112043857574
Step: 1401, initial loss:

This step extracts the simultaneously fitted transition probability matrices for the cherry taxa.

In [12]:
cherry_Ps = get_Ps(cherry_names, triples_for_cherry, fits)

Test a single agglomeration step.

In [13]:
forest = {}
cherries = cherry_names, cherry_Ps
update_tree(forest, cherries)

These were the cherry taxa and their corresponding transition probability matrices.

In [45]:
cherries

(['Dog', 'Human'],
 (array([[0.70033926, 0.0410389 , 0.23727772, 0.02134408],
         [0.00668178, 0.90446866, 0.02488895, 0.06396052],
         [0.06396053, 0.02488895, 0.9044687 , 0.00668178],
         [0.02134408, 0.23727769, 0.0410389 , 0.70033926]], dtype=float32),
  array([[0.7839454 , 0.02744941, 0.1780567 , 0.01054858],
         [0.00591483, 0.9303335 , 0.00940431, 0.05434741],
         [0.0543474 , 0.00940431, 0.93033355, 0.00591483],
         [0.01054858, 0.1780567 , 0.02744941, 0.7839453 ]], dtype=float32)))

In [14]:
forest

{'Dog-Human': Tree("(Dog,Human)Dog-Human;")}

These were the data that were fitted simultaneously to find the above rate matrics for Dog and Human.

In [15]:
triples_for_cherry

[[('Dog', 'Human', 'Mouse'),
  array([[[ 283,   33,  137,   19],
          [  18,   22,    6,    2],
          [  44,   16,  113,    7],
          [   4,    5,    4,   18]],
  
         [[  22,   13,   15,    5],
          [  14,  841,   27,  111],
          [  11,   16,   32,    5],
          [  11,  122,   14,  127]],
  
         [[  76,   14,  100,    9],
          [   4,   22,   13,    9],
          [ 121,   40, 1019,   21],
          [   0,   13,   12,   11]],
  
         [[   9,    5,    3,    3],
          [   7,  108,   12,   76],
          [   4,    4,    7,   10],
          [  23,  172,   36,  326]]], dtype=int32)],
 [('Dog', 'Human', 'Opossum'),
  array([[[338,  19,  71,  44],
          [ 21,   4,   6,  17],
          [ 89,  10,  58,  23],
          [ 12,   1,   3,  15]],
  
         [[ 29,   6,   4,  16],
          [ 89, 476,  39, 389],
          [ 20,  10,  17,  17],
          [ 19,  62,   8, 185]],
  
         [[113,  10,  55,  21],
          [ 13,  17,   3,  15],
       

### Merge the cherries
Merge the cherries by fitting the triples between the unmerged taxa and the new merged node, using the probability matrices that we fitted above.

In [17]:
%%time
triples = merge_cherries(triples, cherries, verbose=True)





Step: 1, initial loss: 0.18694207072257996
Step: 101, initial loss: 0.05630214512348175
Step: 201, initial loss: 0.016560954973101616
Step: 301, initial loss: 0.011870048940181732
Step: 401, initial loss: 0.010795547626912594
Step: 501, initial loss: 0.010174106806516647
Step: 601, initial loss: 0.009626062586903572
Step: 701, initial loss: 0.009043915197253227
Step: 801, initial loss: 0.008381888270378113
Step: 901, initial loss: 0.007614144589751959
Step: 1001, initial loss: 0.006727422121912241
Step: 1101, initial loss: 0.005726528353989124
Step: 1201, initial loss: 0.004647366236895323
Step: 1301, initial loss: 0.0035685193724930286
Step: 1401, initial loss: 0.0026003215461969376
Step: 1501, initial loss: 0.001836904906667769
Step: 1601, initial loss: 0.001307002967223525
Step: 1701, initial loss: 0.0009760017856024206
Step: 1801, initial loss: 0.0007846312946639955
Step: 1901, initial loss: 0.0006791851483285427
Step: 2001, initial loss: 0.0006221724906936288
Step: 2101, initial l

At the end of the first iteration, the triples now consist of unmerged taxa and the newly created merged node (called "Dog-Human").

In [18]:
triples

[(('Mouse', 'Opossum', 'Platypus'),
  array([[[276,  15,  80,  24],
          [ 11,  12,   3,   9],
          [ 50,   3,  62,   3],
          [ 25,  12,   8,  58]],
  
         [[ 90,  37,  20,  30],
          [ 37, 353,  18, 159],
          [ 10,  12,  24,   9],
          [ 28, 181,  12, 426]],
  
         [[315,  15, 166,  25],
          [  9,  15,   9,  18],
          [150,  17, 631,  23],
          [ 22,  25,  18,  92]],
  
         [[ 37,  11,  12,  21],
          [ 10,  56,   4,  50],
          [  4,   6,  10,   4],
          [ 32,  81,  11, 410]]], dtype=int32)),
 (['Mouse', 'Opossum', 'Dog-Human'],
  array([[[271.39813  ,  17.798874 ,  94.73732  ,  10.987216 ],
          [ 11.867694 ,  14.324279 ,   5.5005774,   2.9197848],
          [ 44.02035  ,   2.1981473,  70.13179  ,   2.260109 ],
          [ 42.59225  ,  15.920035 ,  19.829195 ,  24.439869 ]],
  
         [[ 35.510742 ,  90.71925  ,  31.583857 ,  19.548342 ],
          [ 10.520011 , 467.30704  ,  19.237043 ,  71.135506 ]