# `TERMITE`
sTrand symmEtric tRiple MIncuTsupErtree

In [1]:
from itertools import combinations
from collections import Counter, defaultdict

from cogent3 import load_aligned_seqs, PhyloNode
import numpy as np
import networkx as nx

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

tf.executing_eagerly()  # need to check whether this is the default for tensorflow > 2

2021-09-24 06:49:25.080842: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


True

In [3]:
# this stops tensorflow from snaffling all of the GPU
# thanks https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

2021-09-24 06:49:27.665873: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2021-09-24 06:49:27.763394: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-09-24 06:49:27.764637: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 2080 with Max-Q Design computeCapability: 7.5
coreClock: 1.095GHz coreCount: 46 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 357.69GiB/s
2021-09-24 06:49:27.764706: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-09-24 06:49:27.792266: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-09-24 06:49:27.809852: I tensorflow/s

## Data import
Reads an alignments and creates a list of 4 x 4 x 4 joint frequencies tensors.

In [4]:
def get_triples(aln, nuc_order='ACGT', codon_position=None, verbose=False):
    if codon_position:
        aln = aln[codon_position - 1::3]
    aln = aln.no_degenerates()
    if verbose:
        print(f'Got {len(aln)} positions')
    assert len(aln) <= np.iinfo(np.int32).max
    triples = []
    nuc_map = {n:i for i, n in enumerate(nuc_order)}
    for triple in combinations(range(aln.num_seqs), 3):
        F = np.zeros([4, 4, 4], dtype=np.int32)
        subaln = aln.get_sub_alignment(seqs=triple)
        for a, b, c in subaln:
            F[nuc_map[a], nuc_map[b], nuc_map[c]] += 1
        triples.append([tuple(subaln.names), F])
    return triples

## Triple fitting functions
Collection of functions for concurrent fitting of many triples on CPUs and GPUs. Model is rooted, continuous-time, and strand-symmetric.

Also some functions for using Akaike-ish weights to build Semple and Steel-ish graphs.

In [14]:
@tf.function()
def transform_P_matrix(params):
    params = tf.exp(params)
    Q0 = tf.concat([[-tf.reduce_sum(params[0])], params[0]],
                   axis=0)
    Q1 = tf.concat([[params[1,0]], [-tf.reduce_sum(params[1])], params[1,1:]],
                   axis=0)
    Q = tf.concat([[Q0], [Q1], [Q1[::-1]], [Q0[::-1]]], axis=0)
    return tf.linalg.expm(Q)

@tf.function()
def transform(params):
    pi = tfb.SoftmaxCentered()(params[0])
    Pa = transform_P_matrix(params[1:3])
    Pm = transform_P_matrix(params[3:5])
    Pb = transform_P_matrix(params[5:7])
    Pc = transform_P_matrix(params[7:9])
    return pi, Pa, Pm, Pb, Pc
    
@tf.function()
def _loss(params_data):
    params, data = params_data
    pi, Pa, Pm, Pb, Pc = transform(params)
    J = tf.einsum('i,ij,ik,ku,kv', pi, Pa, Pm, Pb, Pc)
    loss = tf.reduce_sum(tf.keras.losses.KLDivergence()(J, data))
    return loss
    
@tf.function()
def loss(params, data):
    # could do better managing the variance in the shared matrix case here
    return tf.reduce_sum(tf.vectorized_map(_loss, (params, data)))

@tf.function()
def training_step(parameters, data, optimizer, unscrambler):
    with tf.GradientTape() as tape:
        unscrambled = _unscramble(parameters, unscrambler)
        loss_value = loss(unscrambled, data)
    gradients = tape.gradient(loss_value, parameters)
    return loss_value, gradients

# thanks https://github.com/mlgxmez/thelongrun_notebooks/blob/master/MLE_tutorial.ipynb
def mle_fit(data, loss, parameters, optimizer, unscrambler, steps=500, verbose=False):
    for i in range(steps):
        loss_value, gradients = training_step(parameters, data, optimizer, unscrambler)
        optimizer.apply_gradients([(gradients, parameters)])
        
        if i % 100 == 0:
            if verbose:
                iter_info = f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}"
                print(iter_info)

@tf.function()
def _unscramble(parameters, unscramble):
    unscrambled = []
    for t1 in unscramble:
        unscrambled.append(tf.stack([parameters[i] for i in t1]))
    return tf.stack(unscrambled)

def fit_triples(triples, learning_rate=0.01, cherries_share_matrices=True, steps=3000, verbose=False):
    K = 0
    cherry_loc = {}
    unscrambler = []
    data = []
    for names, F in triples:
        J = (F/F.sum()).astype(np.float32)
        for ix in [[0, 1, 2], [1, 2, 0], [2, 0, 1]]:
            t1 = list(range(K, K+5)) # for pi and two Ps
            K += 5
            cherry = [names[ix[1]], names[ix[2]]]
            frozen_cherry = frozenset(cherry)
            if not cherries_share_matrices or frozen_cherry not in cherry_loc:
                new_loc = {cherry[0]: [K,K+1], cherry[1]: [K+2,K+3]}
                K += 4
                cherry_loc[frozen_cherry] = new_loc
            t1.extend(cherry_loc[frozen_cherry][cherry[0]]) # for the first cherry
            t1.extend(cherry_loc[frozen_cherry][cherry[1]]) # for the second cherry
            unscrambler.append(t1)
        
            data.append(J.transpose(ix))

    normal_initializer = tf.random_normal_initializer()
    parameters = tf.Variable(normal_initializer(shape=[K, 3], dtype=tf.float32), name='params')
    data = tf.stack(data)
    
    if verbose:
        print(f'Fitting {data.shape[0]} triples')
    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, unscrambler, steps=steps, verbose=verbose)
    
    parameters = _unscramble(parameters, unscrambler)
    losses = tf.vectorized_map(_loss, (parameters, data)).numpy()
    losses = [losses[i:i+3] for i in range(0, len(losses), 3)]
    fits = [[p.numpy() for p in transform(params)] for params in parameters]
    fits = [fits[i:i+3] for i in range(0, len(fits), 3)]
    return losses, fits

def cherry_weights(ls, N):
    ls = N*ls
    delta = ls - ls.min()
    weights = np.exp(-delta)
    weights = weights/weights.sum()
    h = (-weights*np.log(weights))[weights != 0.].sum()
    if h < 1:
        return weights
        return ls == ls.min()
    return np.zeros(3, dtype=bool)

def get_edges(triples, losses):
    edges = Counter()
    for losses, (names, F) in zip(losses, triples):
        weights = cherry_weights(losses, F.sum())
        for name, weight in zip(names, weights):
            edges[frozenset(names) - {name}] += weight
    # for edge in edges:
    #   edges[edge] = np.exp(edges[edge])
    return edges

def get_Ps(cherry_names, triples, fits):
    Ps = {}
    ixes = [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
    for (names, _), triple_fit in zip(triples, fits):
        if set(cherry_names) < set(names):
            for name, ix, fit in zip(names, ixes, triple_fit):
                if name not in cherry_names:
                    if names[ix[1]] == cherry_names[0]:
                        return fit[-2], fit[-1]
                    return fit[-1], fit[-2]

## Tree building algorithm
Where the magic happens.

In [15]:
def termite(triples, learning_rate=0.01, steps=3000, verbose=False):
    losses, _ = fit_triples(triples, cherries_share_matrices=False,
                               learning_rate=learning_rate, steps=steps, verbose=verbose)
    tree = termite_tree(triples, losses, verbose=verbose)
    return tree
    
def termite_tree(triples, losses, verbose=False):
    edges = get_edges(triples, losses)
    if verbose:
        print('Graph:')
        for edge, weight in edges.items():
            print(edge, weight)
    G = nx.Graph()
    for edge, weight in edges.items():
        G.add_edge(*edge, weight=weight)
    cut_value, partition = nx.stoer_wagner(G)
    if verbose:
        print(f'Cut value: {cut_value}, Partition:\n{partition}')
    assert len(partition) == 2, 'polytomy detected. bailing'
    this_node = PhyloNode()
    for part in partition:
        if len(part) <= 1:
            this_node.append(PhyloNode(part.pop()))
            continue
        elif len(part) == 2:
            child = PhyloNode()
            for grandchild in part:
                child.append(PhyloNode(grandchild))
            this_node.append(child)
            continue
    
        part = set(part)
        part_losses = []
        part_triples = []
        for losses, (names, F) in zip(losses, triples):
            if set(names) <= part:
                part_losses.append(losses)
                part_triples.append((names, F))
        this_node.append(termite_tree(part_triples, part_losses, verbose=verbose))
    return this_node

# Some examples
## Example 1
Fit a rooted phylogeny of 5 mammals.

In [18]:
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000197102.fa.gz', moltype="dna")
aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000131018.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000179869.fa.gz', moltype="dna")
# aln = load_aligned_seqs('brca1.fasta', moltype='dna')

In [8]:
subaln = aln.get_similar(aln.take_seqs(['Human']).seqs[0],
                      min_similarity=0.84)
subaln

0,1
,0
HowlerMon,TGTGGCACAAATACTCATGCCAGCTCATTACAGCATGAGAACAGCAGTTTGTTACTCACT
Horse,.............................G....................A.........
Rhino,........G....................G..................G.A.........
Pangolin,..................................................A.........
Llama,.........G........................................A.........
SpermWhale,.........G............................A...........A.........
HumpbackW,.........G......................A.....A...........A.........
FlyingLem,..................................................A....G....
Rhesus,..........................................---...............


### All at once
First run `termite` all the way through.

In [19]:
%%time
triples = get_triples(aln, codon_position=2, verbose=True)
tree = termite(triples, verbose=True)
print(tree.ascii_art())

Got 1379 positions
Fitting 30 triples
Step: 1, initial loss: 9.581822395324707
Step: 101, initial loss: 7.584673881530762
Step: 201, initial loss: 4.693903923034668
Step: 301, initial loss: 3.0212626457214355
Step: 401, initial loss: 2.5468015670776367
Step: 501, initial loss: 2.294861316680908
Step: 601, initial loss: 2.175772190093994
Step: 701, initial loss: 2.1215875148773193
Step: 801, initial loss: 2.088082790374756
Step: 901, initial loss: 2.0649378299713135
Step: 1001, initial loss: 2.0482425689697266
Step: 1101, initial loss: 2.035581588745117
Step: 1201, initial loss: 2.025559425354004
Step: 1301, initial loss: 2.017711639404297
Step: 1401, initial loss: 2.0114567279815674
Step: 1501, initial loss: 2.0062649250030518
Step: 1601, initial loss: 2.0014472007751465
Step: 1701, initial loss: 1.9963480234146118
Step: 1801, initial loss: 1.9901407957077026
Step: 1901, initial loss: 1.9713897705078125
Step: 2001, initial loss: 1.9658265113830566
Step: 2101, initial loss: 1.9628350734

## Example 2
### A single iteration
Now run through a single iteration of the algorithm.
#### Fit triples
Fits rooted, strand-symmetric, continuous-time models to every taxon triple.

In [78]:
losses, fits = fit_triples(triples, cherries_share_matrices=False, verbose=True, steps=4000)

Fitting 1365 triples
Step: 1, initial loss: 462.26214599609375
Step: 101, initial loss: 302.33087158203125
Step: 201, initial loss: 151.7705841064453
Step: 301, initial loss: 84.62445068359375
Step: 401, initial loss: 62.533382415771484
Step: 501, initial loss: 53.457027435302734
Step: 601, initial loss: 48.88383483886719
Step: 701, initial loss: 46.14476013183594
Step: 801, initial loss: 44.29399871826172
Step: 901, initial loss: 42.8905029296875
Step: 1001, initial loss: 41.79762268066406
Step: 1101, initial loss: 40.92477035522461
Step: 1201, initial loss: 40.17105484008789
Step: 1301, initial loss: 39.56158447265625
Step: 1401, initial loss: 39.026336669921875
Step: 1501, initial loss: 38.559486389160156
Step: 1601, initial loss: 38.22875213623047
Step: 1701, initial loss: 37.95553207397461
Step: 1801, initial loss: 37.676185607910156
Step: 1901, initial loss: 37.369667053222656
Step: 2001, initial loss: 37.12908172607422
Step: 2101, initial loss: 36.96723937988281
Step: 2201, init

#### Create $S_\mathcal{T}\left/E^\text{max}_\mathcal{T}\right.$
Creates the edges in Semple and Steel's $S_\mathcal{T}\left/E^\text{max}_\mathcal{T}\right.$ graph, at least as I understand it.

In [79]:
edges = get_edges(triples, losses)
edges

0.9362246
1.0755355
1.8557614e-07
0.47296438
1.71997e-07
1.04344345e-07
1.1742688e-07
1.8921176e-07
1.0714965
0.9157308
0.93431044
1.0985845
1.0963657
1.0665685
1.0979221
1.0982212
1.088633
0.6891866
0.6928891
0.6930749
1.5602834e-06
0.6918924
1.6335871e-06
1.0976589
1.6262458e-06
1.0983827
0.65957546
1.0876217
1.0030837
1.0893022
1.0946157
1.0882509
1.0880258
1.0846765
1.0954083
1.0918016
1.0884489
1.0985942
0.6911012
3.876891e-07
1.0935391
1.096978
4.663095e-07
3.945222e-07
8.1851056e-07
0.6819309
1.0947587
0.9944037
6.39513e-07
1.0949429
2.5384923e-05
1.0976357
1.0984776
0.059883654
0.69271946
1.0951114
1.0836622
5.8947e-07
5.9149374e-07
1.089884
1.0947607
1.0984529
5.788365e-07
0.6675526
1.0986028
0.00035077735
1.0964332
1.0240635
1.0985948
1.0979124
1.7391195e-18
2.0987137e-07
1.0605047
1.0950999
1.2002204e-06
1.0985401
1.09132
1.2364619e-18
1.0984294
1.0956256
1.0985408
1.0858507
1.0967926
1.0866812
1.0984645
1.0209215
1.0975864
0.6929233
1.0954068
4.613786e-07
1.0459065
1.09704


Counter({frozenset({'Pangolin', 'Rhino'}): 1,
         frozenset({'Horse', 'Pangolin'}): 7,
         frozenset({'Horse', 'Rhino'}): 7,
         frozenset({'Llama', 'Rhino'}): 1,
         frozenset({'Horse', 'Llama'}): 0,
         frozenset({'Rhino', 'SpermWhale'}): 3,
         frozenset({'Horse', 'SpermWhale'}): 6,
         frozenset({'HumpbackW', 'Rhino'}): 4,
         frozenset({'Horse', 'HumpbackW'}): 6,
         frozenset({'FlyingLem', 'Rhino'}): 2,
         frozenset({'FlyingLem', 'Horse'}): 1,
         frozenset({'HowlerMon', 'Rhino'}): 0,
         frozenset({'Horse', 'HowlerMon'}): 0,
         frozenset({'Rhesus', 'Rhino'}): 1,
         frozenset({'Horse', 'Rhesus'}): 1,
         frozenset({'Orangutan', 'Rhino'}): 1,
         frozenset({'Horse', 'Orangutan'}): 1,
         frozenset({'Gorilla', 'Rhino'}): 1,
         frozenset({'Gorilla', 'Horse'}): 2,
         frozenset({'Human', 'Rhino'}): 2,
         frozenset({'Horse', 'Human'}): 1,
         frozenset({'Chimpanzee', 'Rhino'})

#### Find the root by partitioning on the minimum cut
Perform the minimum cut to partition our tip set into two, one either side of the root.

In [80]:
G = nx.Graph()
for edge, weight in edges.items():
    G.add_edge(*edge, weight=weight)
cut_value, partition = nx.stoer_wagner(G)
print(f'Cut value: {cut_value}, Partition:\n{partition}')

Cut value: 16, Partition:
(['Orangutan'], ['Human', 'Rhesus', 'Gorilla', 'Horse', 'Pangolin', 'Sloth', 'Llama', 'FlyingLem', 'HowlerMon', 'SpermWhale', 'Chimpanzee', 'Rhino', 'HairyArma', 'HumpbackW'])


In [37]:
w = np.ones(3)/3
(-w*np.log(w)).sum()

1.0986122886681096

In [55]:
nx.stoer_wagner?

The rest is left as an exercise for the reader (or just look in `termite` above) - the algorithm continues recursively.