In [1]:
from itertools import combinations
from collections import Counter, defaultdict

from cogent3 import load_aligned_seqs, PhyloNode
import numpy as np
from sklearn.cluster import SpectralClustering
import networkx as nx

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfl = tf.linalg

tf.executing_eagerly()  # need to check whether this is the default for tensorflow > 2

2021-11-22 12:14:29.408670: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


True

In [3]:
# Turn on memory growth for every visible GPU so that TensorFlow does not
# grab all of the GPU memory up front.
# thanks https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

2021-11-22 12:14:30.835811: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2021-11-22 12:14:30.905203: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-22 12:14:30.905544: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 2080 with Max-Q Design computeCapability: 7.5
coreClock: 1.095GHz coreCount: 46 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 357.69GiB/s
2021-11-22 12:14:30.905570: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-11-22 12:14:30.921009: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-11-22 12:14:30.931978: I tensorflow/s

In [4]:
def get_triples(aln, nuc_order='ACGT', codon_position=None, verbose=False):
    """Count joint nucleotide states for every triple of sequences in ``aln``.

    Parameters
    ----------
    aln
        Alignment object.  It is sliced, filtered with ``no_degenerates()``,
        measured with ``len()`` and split with ``get_sub_alignment(seqs=...)``;
        iterating a sub-alignment must yield one 3-item column per position.
    nuc_order : str
        Characters in the order used to index each axis of the count array.
    codon_position : int, optional
        When truthy, only every third position starting at
        ``codon_position - 1`` is kept.
    verbose : bool
        Print the number of positions left after filtering.

    Returns
    -------
    list
        One ``[names, F]`` entry per combination of three sequences, where
        ``names`` is the tuple of sub-alignment names and ``F`` is a 4x4x4
        ``int32`` array with ``F[a, b, c]`` the number of columns showing
        states ``a``, ``b`` and ``c`` in the three sequences.
    """
    if codon_position:
        aln = aln[codon_position - 1::3]
    aln = aln.no_degenerates()
    if verbose:
        print(f'Got {len(aln)} positions')
    # Counts are stored as int32, so the column total must fit in that type.
    assert len(aln) <= np.iinfo(np.int32).max
    state_index = dict(zip(nuc_order, range(len(nuc_order))))

    def tabulate(members):
        # Build the count table for one combination of sequence indices.
        sub = aln.get_sub_alignment(seqs=members)
        counts = np.zeros([4, 4, 4], dtype=np.int32)
        for column in sub:
            first, second, third = column
            counts[state_index[first], state_index[second], state_index[third]] += 1
        return [tuple(sub.names), counts]

    return [tabulate(members) for members in combinations(range(aln.num_seqs), 3)]

In [57]:
@tf.function()
def transform_P_matrix(Q, t):
    """Return the matrix exponential of a rate matrix built from ``Q``, scaled by ``t``.

    ``Q`` is indexed as two rows of off-diagonal rates.  Each row is extended
    with the negative of its own sum, so every row of the assembled matrix
    sums to zero.  The last two rows of the assembled matrix are the first two
    rows reversed (second row first).

    Parameters
    ----------
    Q : tensor indexed as ``Q[0]`` and ``Q[1]``.
    t : scalar multiplier applied to the assembled matrix before ``expm``.
    """
    upper, lower = Q[0], Q[1]
    # Row 0: diagonal entry first, then the off-diagonal rates.
    row0 = tf.concat([[-tf.reduce_sum(upper)], upper], axis=0)
    # Row 1: diagonal entry sits in the second position.
    row1 = tf.concat([lower[:1], [-tf.reduce_sum(lower)], lower[1:]], axis=0)
    rate_matrix = tf.stack([row0, row1, row1[::-1], row0[::-1]], axis=0)
    return tf.linalg.expm(rate_matrix * t)

@tf.function()
def compose_P(pi, Q, ti, to):
    """Combine three exponentials of the same rate matrix into one matrix.

    With ``Pi = transform_P_matrix(Q, ti)``, ``Piinv = transform_P_matrix(Q, -ti)``
    and ``Po = transform_P_matrix(Q, to)`` the returned value is::

        diag(1 / pi) @ Pi^T @ diag(pi @ Piinv) @ Po

    Parameters
    ----------
    pi : vector of 4 values (it is reshaped to ``(1, 4)`` below).
    Q : rate parameters in the layout expected by ``transform_P_matrix``.
    ti, to : the two time arguments handed to ``transform_P_matrix``.

    NOTE(review): this looks like a Bayes-rule reversal of ``Pi`` (from the
    distribution ``pi`` back to ``pi @ Piinv``) followed by a forward step
    ``Po`` -- confirm against the model description.
    """
    Pi = transform_P_matrix(Q, ti)
    # Exponential for the negated time, used below as the inverse of Pi.
    Piinv = transform_P_matrix(Q, -ti)
    Po = transform_P_matrix(Q, to)
    P = tfl.matmul(tfl.diag(1/pi), Pi, transpose_b=True)
    P = tfl.matmul(P, tfl.diag(tf.matmul(tf.reshape(pi, shape=(1,4)), Piinv)))
    # The (1, 4) row vector above presumably gives the product a leading axis
    # of size one; [0] drops it again.
    P = tfl.matmul(P, Po)[0]
    return P

@tf.function()
def transform(params):
    """Map one unconstrained parameter vector to the model quantities.

    Layout of ``params`` as read below:

    * ``params[:3]``  -> logistic transform -> ``ptis``, three values in (0, 1)
    * ``params[3:6]`` -> ``SoftmaxCentered`` -> ``pi``
    * ``params[6:9]`` -> ``exp`` -> pairwise distances ``tab, tbc, tac``
    * ``params[9:]``  -> raw rate-matrix parameters ``Q_raw``

    Returns
    -------
    tuple
        ``(pi, Pa, Pb, Pc, tia, tib, tic, toa, tob, toc, Q)`` where ``Pa``,
        ``Pb`` and ``Pc`` come from ``compose_P`` and ``Q`` is the stacked
        two-row rate parameter tensor.
    """
    ptis = tf.exp(params[:3]) / (1 + tf.exp(params[:3]))
    pi = tfb.SoftmaxCentered()(params[3:6])
    ts = tf.exp(params[6:9])
    tab, tbc, tac = [ts[i] for i in range(3)]
    # Soft lower bound at zero for the three lengths derived below.
    max0 = tfb.SoftClip(low=0, hinge_softness=0.05)  # 0.05)  # 0.01)
    # Each "to" length is half of (sum of the two distances involving that
    # tip minus the distance between the other two tips).
    toa = max0((tab + tac - tbc) / 2)
    tob = max0((tab + tbc - tac) / 2)
    toc = max0((tbc + tac - tab) / 2)
    # Each "ti" length is a fraction (ptis) of the corresponding "to" length.
    tia = ptis[0] * toa
    tib = ptis[1] * tob
    tic = ptis[2] * toc
    
    Q_raw = params[9:]
    #Q = tfb.SoftmaxCentered()(Q_raw)
    # Q_raw[1] is repeated so that several entries share one raw value before
    # SoftmaxCentered is applied; the result is then split into two rows.
    Q = tf.concat([[Q_raw[1]], Q_raw, [Q_raw[1]], [Q_raw[1]]], axis=0)
    Q = tfb.SoftmaxCentered()(Q)
    Q = tf.stack([Q[:3], Q[3:]])
    Pa = compose_P(pi, Q, tia, toa)
    Pb = compose_P(pi, Q, tib, tob)
    Pc = compose_P(pi, Q, tic, toc)
    return pi, Pa, Pb, Pc, tia, tib, tic, toa, tob, toc, Q

@tf.function()
def _loss(params_data):
    """Loss for one triple.

    ``params_data`` is a ``(params, data)`` pair: the unconstrained parameter
    vector for the triple and its observed joint frequency array.  The model
    joint distribution is ``J[j, k, l] = sum_i pi[i] * Pa[i, j] * Pb[i, k] * Pc[i, l]``.
    """
    params, data = params_data
    pi, Pa, Pb, Pc, tia, tib, tic, _, _, _, _ = transform(params)
    J = tf.einsum('i,ij,ik,il', pi, Pa, Pb, Pc)
    # NOTE(review): Keras losses are called as (y_true, y_pred), so the model
    # J is passed as y_true and the observed data as y_pred here -- confirm
    # that this direction of the divergence is the intended one.
    loss = tf.reduce_sum(tf.keras.losses.KLDivergence()(J, data)) # + \
# Penalty terms that were tried and are currently disabled:
#        0.01 * (tia + tib + tic)
#        0.01 * tf.reduce_sum(tf.exp(params[:6]))
#        0.01 * (tf.reduce_sum(tf.exp(params[:3])) - 0.99 * tf.reduce_max(tf.exp(params[:3])))
    return loss
    
@tf.function()
def loss(params, data):
    """Total loss: ``_loss`` evaluated on each (params row, data entry) pair, summed."""
    per_triple = tf.vectorized_map(_loss, (params, data))
    return tf.reduce_sum(per_triple)

@tf.function()
def stack_params(Qs, pis_and_tis, ts, sharesies, tix):
    """Assemble one flat parameter vector per triple and stack them.

    Row ``j`` is the concatenation of ``pis_and_tis[j]``, the three entries of
    ``ts`` selected by ``tix[j]``, and the row of ``Qs`` selected by
    ``sharesies[j]``.
    """
    def row(j):
        # Look up the three shared distance parameters for triple j.
        distances = [ts[tix[j, k]] for k in range(3)]
        return tf.concat([pis_and_tis[j], distances, Qs[sharesies[j]]], axis=0)

    return tf.stack([row(j) for j in range(len(sharesies))])

@tf.function()
def training_step(parameters, data, optimizer):
    """Evaluate the loss and its gradients for the current parameters.

    ``parameters`` is ``[Qs, pis_and_tis, ts, sharesies, tix]``; gradients are
    taken with respect to the first three entries only.  ``optimizer`` is
    accepted but not used here.

    Returns ``(loss_value, gradients)``.
    """
    Qs, pis_and_tis, ts, sharesies, tix = parameters
    trainable = parameters[:3]
    with tf.GradientTape() as tape:
        stacked = stack_params(Qs, pis_and_tis, ts, sharesies, tix)
        loss_value = loss(stacked, data)
    return loss_value, tape.gradient(loss_value, trainable)

# thanks https://github.com/mlgxmez/thelongrun_notebooks/blob/master/MLE_tutorial.ipynb
def mle_fit(data, loss, parameters, optimizer, steps=500, verbose=False):
    """Run ``steps`` optimiser updates on the trainable parameters, in place.

    Parameters
    ----------
    data
        Observed data forwarded to ``training_step``.
    loss
        Unused; kept so existing callers keep working.
    parameters : list
        Forwarded whole to ``training_step``; only the first three entries
        are updated by the optimiser.
    optimizer
        Object providing ``apply_gradients`` and ``iterations``.
    steps : int
        Number of updates to perform.
    verbose : bool
        When true, print a progress line every 100 steps.
    """
    trainable = parameters[:3]
    for step in range(steps):
        loss_value, gradients = training_step(parameters, data, optimizer)
        optimizer.apply_gradients(zip(gradients, trainable))

        if verbose and step % 100 == 0:
            # loss_value was computed before this step's update was applied.
            # It is the current loss, not the initial one, so label it as such.
            print(f"Step: {optimizer.iterations.numpy()}, loss: {loss_value.numpy()}")

def fit_triples(triples, t_sharesies=None, Q_sharesies=None, learning_rate=0.01, steps=3000, verbose=False):
    """Fit the triple model to every count table in ``triples``.

    Parameters
    ----------
    triples : sequence of ``(names, F)``
        ``names`` holds the three sequence names of a triple and ``F`` its
        count array (the structure returned by ``get_triples``).
    t_sharesies : sequence, optional
        One hashable group label per triple.  Within a group, each unordered
        pair of names gets a single shared tip-to-tip distance parameter.
        By default every triple forms its own group.
    Q_sharesies : sequence, optional
        One hashable group label per triple.  Triples with the same label
        share one row of rate-matrix parameters.  By default every triple
        gets its own row.
    learning_rate : float
        Learning rate for the Adam optimiser.
    steps : int
        Number of optimiser updates.
    verbose : bool
        Print a summary of the problem size and periodic progress.

    Returns
    -------
    losses : numpy.ndarray
        Final ``_loss`` value for each triple.
    output_parameters : list
        For each triple, the tuple returned by ``transform`` for its fitted
        parameter vector.
    """
    if Q_sharesies is None:
        Q_sharesies = list(range(len(triples)))
    assert len(triples) == len(Q_sharesies), 'triples and Q_sharesies mismatch'
    unique_Q_sharesies = list(set(Q_sharesies))
    # Dict lookup instead of list.index: same mapping, without a linear scan
    # per triple.
    Q_row = {label: row for row, label in enumerate(unique_Q_sharesies)}
    Q_sharesies = [Q_row[label] for label in Q_sharesies]

    if t_sharesies is None:
        t_sharesies = list(range(len(triples)))
    assert len(triples) == len(t_sharesies), 'triples and t_sharesies mismatch'

    normal_initializer = tf.random_normal_initializer()
    Qs = tf.Variable(normal_initializer(shape=(len(unique_Q_sharesies), 2),
                                        dtype=tf.float32))

    pis_and_tis = tf.Variable(normal_initializer(shape=(len(triples), 6),
                                                 dtype=tf.float32))

    # Collect the distinct (group, pair-of-names) combinations; each one gets
    # its own distance parameter in ``ts``.
    group_pairs = defaultdict(set)
    for (triple, _), group in zip(triples, t_sharesies):
        group_pairs[group].update(frozenset((a, b)) for a in triple for b in triple if a != b)
    pairs = [(g, p) for g, pairset in group_pairs.items() for p in pairset]
    pair_index = {pair: i for i, pair in enumerate(pairs)}

    # For each triple, the positions in ``ts`` of its three distances, in the
    # order (0, 1), (1, 2), (2, 0).
    tix = [[pair_index[(group, frozenset((triple[i], triple[(i + 1) % 3])))]
            for i in (0, 1, 2)]
           for (triple, _), group in zip(triples, t_sharesies)]

    length_initializer = tf.random_normal_initializer(mean=-0.7, stddev=0.35)
    ts = tf.Variable(length_initializer(shape=(len(pairs),), dtype=tf.float32))
    tix = tf.constant(tix)

    Q_sharesies = tf.constant(Q_sharesies)

    parameters = [Qs, pis_and_tis, ts, Q_sharesies, tix]

    # Convert counts to joint frequencies.  An all-zero count table divides by
    # zero here and would give NaN frequencies.
    data = tf.constant([F / F.sum() for _, F in triples], dtype=tf.float32)

    if verbose:
        print(f'Fitting {data.shape[0]} triple(s) using {Qs.shape[0]} '
              f'rate matrix(es) and {ts.shape[0]} tip-to-tip distance(s).')
    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, steps=steps, verbose=verbose)

    fitted = stack_params(Qs, pis_and_tis, ts, Q_sharesies, tix)
    output_parameters = [transform(fitted[i]) for i in range(len(fitted))]
    losses = tf.vectorized_map(_loss, (fitted, data)).numpy()
    return losses, output_parameters

In [6]:
def get_edges(triples, lengths, losses):
    """Accumulate pairwise weights from per-triple lengths.

    For each triple, ``lengths`` is divided by its own sum; the share paired
    with a name is then added to the edge joining the *other two* names of
    that triple.  ``losses`` only takes part in the zip (so it can shorten the
    iteration); its values are not used.

    Returns a ``Counter`` keyed by ``frozenset`` pairs of names.
    """
    tally = Counter()
    for branch_lengths, _, triple in zip(lengths, losses, triples):
        names, _counts = triple
        shares = branch_lengths / branch_lengths.sum()
        members = frozenset(names)
        for tip, share in zip(names, shares):
            tally[members - {tip}] += share
    return tally

def edges_to_graph(edges):
    """Build an undirected networkx graph with one weighted edge per entry of ``edges``.

    Each key of ``edges`` supplies the two end points and the mapped value is
    stored as the edge's ``weight`` attribute.
    """
    graph = nx.Graph()
    for pair in edges:
        graph.add_edge(*pair, weight=edges[pair])
    return graph

def pulsar(triples, learning_rate=0.01, steps=3000, verbose=False):
    """Fit all ``triples`` with one shared rate matrix and build a tree from the fit.

    Parameters
    ----------
    triples : sequence of ``(names, F)`` as returned by ``get_triples``.
    learning_rate, steps, verbose : forwarded to ``fit_triples``.

    Returns
    -------
    The tree returned by ``pulsar_tree``.
    """
    # fit_triples has no ``sharesies`` argument; sharing of the rate matrix is
    # requested through ``Q_sharesies``.
    losses, params = fit_triples(triples, learning_rate=learning_rate,
                                 steps=steps, verbose=verbose,
                                 Q_sharesies=[0]*len(triples))
    # Each element of params is the tuple returned by ``transform``; entries
    # 4-6 are (tia, tib, tic), the lengths used to weight the graph edges.
    lengths = [np.array([t.numpy() for t in param[4:7]]) for param in params]
    tree = pulsar_tree(triples, lengths, losses, verbose=verbose)
    return tree

def normalised_cut(edges, verbose=False):
    """Split the tips named in ``edges`` into two groups by spectral clustering.

    A symmetric affinity matrix is filled from ``edges`` (keyed by
    ``frozenset`` pairs of tip names) and handed to ``SpectralClustering``
    with two clusters.  No ``random_state`` is set.

    Returns
    -------
    tuple of two lists
        Tips whose cluster label is truthy, then the remaining tips.
    """
    tips = np.unique([tip for pair in edges.keys() for tip in pair])
    num_tips = len(tips)
    affinity = np.zeros((num_tips, num_tips))
    # Fill the strict lower triangle, then mirror it.
    for i in range(num_tips):
        for j in range(i):
            affinity[i, j] = edges[frozenset((tips[i], tips[j]))]
    affinity += affinity.T
    clusterer = SpectralClustering(2, affinity='precomputed',  # random_state=0,
                                   assign_labels='discretize')
    in_first = clusterer.fit_predict(affinity).astype(bool)
    partition = list(tips[in_first]), list(tips[np.logical_not(in_first)])
    if verbose:
        graph = edges_to_graph(edges)
        cut_value = nx.cut_size(graph, partition[0], weight='weight')
        print(f'Cut value: {cut_value}, Partition:\n{partition}')
    return partition

def pulsar_tree(triples, lengths, losses, verbose=False):
    """Recursively build a tree by repeatedly bipartitioning the tips.

    The edge weights from ``get_edges`` are split in two by
    ``normalised_cut``.  A part with at most one tip becomes a leaf, a part
    with two tips becomes a node with two leaves, and a larger part is
    resolved by recursing on the triples whose names all lie inside it.

    Parameters
    ----------
    triples, lengths, losses : equal-length sequences describing each triple.
    verbose : bool
        Print the edge weights and forward the flag to ``normalised_cut``.

    Returns
    -------
    PhyloNode
        Node whose children correspond to the two parts of the partition.
    """
    assert len(triples) == len(lengths)
    assert len(triples) == len(losses)

    edges = get_edges(triples, lengths, losses)
    if verbose:
        print('Graph:')
        for pair in edges:
            print(pair, edges[pair])
    partition = normalised_cut(edges, verbose)
    assert len(partition) == 2, 'polytomy detected. bailing'

    root = PhyloNode()
    for part in partition:
        size = len(part)
        if size <= 1:
            root.append(PhyloNode(part.pop()))
        elif size == 2:
            cherry = PhyloNode()
            for tip in part:
                cherry.append(PhyloNode(tip))
            root.append(cherry)
        else:
            members = set(part)
            # Keep only the triples that lie entirely inside this part.
            kept = [(lens, ls, (names, F))
                    for lens, ls, (names, F) in zip(lengths, losses, triples)
                    if set(names) <= members]
            sub_lengths = [entry[0] for entry in kept]
            sub_losses = [entry[1] for entry in kept]
            sub_triples = [entry[2] for entry in kept]
            root.append(pulsar_tree(sub_triples, sub_lengths, sub_losses, verbose=verbose))
    return root

In [29]:
%%time
# Fit every triple of one alignment (third codon positions only), sharing one
# rate matrix and one set of tip-to-tip distances across all of its triples.
# NOTE(review): the alignment path is absolute and specific to one machine.
aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000131018.fa.gz', moltype="dna")
triples = get_triples(aln, codon_position=3, verbose=True)
losses, params = fit_triples(triples, verbose=True, steps=3000,
                             Q_sharesies=[0]*len(triples), t_sharesies = [0]*len(triples))

Got 1379 positions
Fitting 10 triple(s) using 1 rate matrix(es) and 10 tip-to-tip distance(s).
Step: 1, initial loss: 0.28244325518608093
Step: 101, initial loss: 0.06167846918106079
Step: 201, initial loss: 0.03899339959025383
Step: 301, initial loss: 0.03469279780983925
Step: 401, initial loss: 0.03297993168234825
Step: 501, initial loss: 0.03209149092435837
Step: 601, initial loss: 0.03158606216311455
Step: 701, initial loss: 0.031270310282707214
Step: 801, initial loss: 0.031056631356477737
Step: 901, initial loss: 0.03090176172554493
Step: 1001, initial loss: 0.03078506886959076
Step: 1101, initial loss: 0.030695773661136627
Step: 1201, initial loss: 0.030626222491264343
Step: 1301, initial loss: 0.03057095594704151
Step: 1401, initial loss: 0.030526183545589447
Step: 1501, initial loss: 0.030489183962345123
Step: 1601, initial loss: 0.030458297580480576
Step: 1701, initial loss: 0.030432214960455894
Step: 1801, initial loss: 0.030410021543502808
Step: 1901, initial loss: 0.030390

In [62]:
losses

array([0.00247569, 0.00580734, 0.00419697, 0.00364184, 0.00372384,
       0.00199067, 0.00160424, 0.00357931, 0.00287118, 0.00297137])

In [30]:
losses

array([0.00173156, 0.00489512, 0.00330998, 0.00374211, 0.00403919,
       0.00181286, 0.00157304, 0.00399779, 0.00246676, 0.00270869],
      dtype=float32)

In [67]:
list(zip((n for n, _ in triples), (np.exp(p[:3]) for p in params)))

[(('Dog', 'Human', 'Mouse'),
  array([0.07298953, 0.10161132, 0.05906086], dtype=float32)),
 (('Dog', 'Human', 'Opossum'),
  array([0.09708241, 0.08377995, 0.6660794 ], dtype=float32)),
 (('Dog', 'Human', 'Platypus'),
  array([0.09662799, 0.05149724, 0.7859325 ], dtype=float32)),
 (('Dog', 'Mouse', 'Opossum'),
  array([0.09756454, 0.12715098, 0.5615496 ], dtype=float32)),
 (('Dog', 'Mouse', 'Platypus'),
  array([0.22691889, 0.07538477, 0.6418923 ], dtype=float32)),
 (('Dog', 'Opossum', 'Platypus'),
  array([0.3081079 , 0.39696702, 0.44309965], dtype=float32)),
 (('Human', 'Mouse', 'Opossum'),
  array([0.13214006, 0.07260554, 0.5675322 ], dtype=float32)),
 (('Human', 'Mouse', 'Platypus'),
  array([0.16070251, 0.08125931, 0.57675904], dtype=float32)),
 (('Human', 'Opossum', 'Platypus'),
  array([0.1439011 , 0.3845076 , 0.37645704], dtype=float32)),
 (('Mouse', 'Opossum', 'Platypus'),
  array([0.17595707, 0.3732416 , 0.38440546], dtype=float32))]

In [31]:
list(zip((n for n, _ in triples), losses, [np.array([p.numpy() for p in param[4:10]]) for param in params]))

[(('Dog', 'Human', 'Mouse'),
  0.0017315601,
  array([0.0609278 , 0.05544734, 0.12454414, 0.13344057, 0.16606377,
         0.61740303], dtype=float32)),
 (('Dog', 'Human', 'Opossum'),
  0.0048951195,
  array([0.00089988, 0.0055109 , 0.3468333 , 0.12890965, 0.17078088,
         0.34949952], dtype=float32)),
 (('Dog', 'Human', 'Platypus'),
  0.0033099754,
  array([0.05474074, 0.00290636, 0.4005302 , 0.14595523, 0.15327966,
         0.40244585], dtype=float32)),
 (('Dog', 'Mouse', 'Opossum'),
  0.0037421086,
  array([0.00113644, 0.05251506, 0.3040527 , 0.17119642, 0.57771105,
         0.30498797], dtype=float32)),
 (('Dog', 'Mouse', 'Platypus'),
  0.0040391926,
  array([0.05881398, 0.03780889, 0.38075227, 0.16513251, 0.5839924 ,
         0.38237575], dtype=float32)),
 (('Dog', 'Opossum', 'Platypus'),
  0.001812863,
  array([0.19560938, 0.16787796, 0.23364855, 0.30799335, 0.16828553,
         0.23815155], dtype=float32)),
 (('Human', 'Mouse', 'Opossum'),
  0.0015730357,
  array([0.02075593

In [78]:
params

<tf.Tensor: shape=(10, 11), dtype=float32, numpy=
array([[-2.708005  , -2.3047087 , -2.538702  , -2.4713752 , -1.947004  ,
        -0.49924543,  0.05634893, -0.41402954,  0.08538109,  1.3506602 ,
        -0.95855683],
       [-4.7744646 , -4.5075455 , -1.6724694 , -1.4267347 , -1.2865919 ,
        -0.57767147,  0.19750692, -1.0690533 , -5.043477  ,  1.3506602 ,
        -0.95855683],
       [-2.6475873 , -2.6633735 , -0.5653625 , -2.617783  , -2.0681472 ,
        -0.968321  ,  0.1028522 , -0.3440413 ,  0.14183377,  1.3506602 ,
        -0.95855683],
       [-2.438525  , -2.336017  , -0.5599119 , -2.4999914 , -0.65879494,
        -1.6509225 ,  0.06806601, -0.25119123,  0.20840725,  1.3506602 ,
        -0.95855683],
       [-1.444854  , -2.6757386 , -0.40964496, -3.0720534 , -0.7538406 ,
        -1.6699716 ,  0.08764298, -0.11357285,  0.34374702,  1.3506602 ,
        -0.95855683],
       [-0.9719933 , -0.88674045, -0.7620446 , -1.2795044 , -3.358952  ,
        -1.8986369 ,  0.10729649, -0.

In [32]:
# Entries 4-6 of each transform() result are (tia, tib, tic); use them as the
# per-triple lengths when building the tree.
lengths = [np.array([p.numpy() for p in param[4:7]]) for param in params]
tree = pulsar_tree(triples, lengths, losses, verbose=True)
print(tree.ascii_art())

Graph:
frozenset({'Mouse', 'Human'}) 1.8607659339904785
frozenset({'Mouse', 'Dog'}) 1.8777560740709305
frozenset({'Dog', 'Human'}) 2.3729870915412903
frozenset({'Opossum', 'Human'}) 0.536705068545416
frozenset({'Opossum', 'Dog'}) 0.5536943003535271
frozenset({'Platypus', 'Human'}) 0.5978275388479233
frozenset({'Platypus', 'Dog'}) 0.3666835343465209
frozenset({'Mouse', 'Opossum'}) 0.4507262341212481
frozenset({'Mouse', 'Platypus'}) 0.4644346237182617
frozenset({'Opossum', 'Platypus'}) 0.9184193313121796
Cut value: 2.970071299932897, Partition:
(['Opossum', 'Platypus'], ['Dog', 'Human', 'Mouse'])
Graph:
frozenset({'Mouse', 'Human'}) 0.25289714336395264
frozenset({'Mouse', 'Dog'}) 0.230149045586586
frozenset({'Dog', 'Human'}) 0.5169537663459778
Cut value: 0.48304618895053864, Partition:
(['Mouse'], ['Dog', 'Human'])
                    /-Opossum
          /--------|
         |          \-Platypus
---------|
         |          /-Mouse
          \--------|
                   |          /-D

In [187]:
%%time
# Build a tree for the brca1 sequences selected by get_similar() against the
# Human sequence (min_similarity=0.84), using third codon positions.
# NOTE(review): the recorded output below has a different summary line from
# the one the current fit_triples prints, so it appears to come from an
# earlier version of the code.
aln = load_aligned_seqs('brca1.fasta', moltype='dna')
subaln = aln.get_similar(aln.take_seqs(['Human']).seqs[0], min_similarity=0.84)
triples = get_triples(subaln, codon_position=3, verbose=True)
tree = pulsar(triples, verbose=True)
print(tree.ascii_art())

Got 881 positions
Fitting 455 triples using 1 rate matrix(es)
Step: 1, initial loss: 148.9008026123047
Step: 101, initial loss: 40.13922882080078
Step: 201, initial loss: 23.3443603515625
Step: 301, initial loss: 19.115922927856445
Step: 401, initial loss: 16.462268829345703
Step: 501, initial loss: 14.36288833618164
Step: 601, initial loss: 12.7217378616333
Step: 701, initial loss: 11.5491304397583
Step: 801, initial loss: 10.769973754882812
Step: 901, initial loss: 10.237008094787598
Step: 1001, initial loss: 9.82467269897461
Step: 1101, initial loss: 9.46207046508789
Step: 1201, initial loss: 9.078824996948242
Step: 1301, initial loss: 8.74043083190918
Step: 1401, initial loss: 8.47661304473877
Step: 1501, initial loss: 8.271474838256836
Step: 1601, initial loss: 8.111785888671875
Step: 1701, initial loss: 7.9773149490356445
Step: 1801, initial loss: 7.854578495025635
Step: 1901, initial loss: 7.747161865234375
Step: 2001, initial loss: 7.646888256072998
Step: 2101, initial loss: 7.

In [385]:
%%time
# Fit all triples of the brca1 sequences selected by get_similar() against
# the Human sequence (min_similarity=0.80), with a single shared rate matrix.
aln = load_aligned_seqs('brca1.fasta', moltype='dna')
subaln = aln.get_similar(aln.take_seqs(['Human']).seqs[0], min_similarity=0.80)
triples = get_triples(subaln, codon_position=3, verbose=True)
# fit_triples takes Q_sharesies (not sharesies) for rate-matrix sharing.
losses, params = fit_triples(triples, verbose=True, steps=3000, Q_sharesies=[0]*len(triples))

Got 745 positions
Fitting 7140 triples using 1 rate matrix(es)
Step: 1, initial loss: 932.1404418945312
Step: 101, initial loss: 193.00250244140625
Step: 201, initial loss: 86.87216186523438
Step: 301, initial loss: 59.58234786987305
Step: 401, initial loss: 48.13275909423828
Step: 501, initial loss: 42.375404357910156
Step: 601, initial loss: 39.149967193603516
Step: 701, initial loss: 37.17272186279297
Step: 801, initial loss: 35.861366271972656
Step: 901, initial loss: 34.932430267333984
Step: 1001, initial loss: 34.23923110961914
Step: 1101, initial loss: 33.700767517089844
Step: 1201, initial loss: 33.269264221191406
Step: 1301, initial loss: 32.914608001708984
Step: 1401, initial loss: 32.616668701171875
Step: 1501, initial loss: 32.36134338378906
Step: 1601, initial loss: 32.13839340209961
Step: 1701, initial loss: 31.940216064453125
Step: 1801, initial loss: 31.761077880859375
Step: 1901, initial loss: 31.59670639038086
Step: 2001, initial loss: 31.443927764892578
Step: 2101, i

In [386]:
list(zip((n for n, _ in triples), losses, [np.array([p.numpy() for p in param[4:10]]) for param in params]))

[(('FlyingFox', 'FreeTaile', 'LittleBro'),
  0.005206179,
  array([0.01996232, 0.00444928, 0.0006153 , 0.10571226, 0.01687876,
         0.04955609], dtype=float32)),
 (('FlyingFox', 'FreeTaile', 'TombBat'),
  0.0059564696,
  array([0.02457512, 0.00343366, 0.00521923, 0.10511307, 0.01736984,
         0.01263802], dtype=float32)),
 (('FlyingFox', 'FreeTaile', 'LeafNose'),
  0.0047125067,
  array([0.00923002, 0.0102211 , 0.00889685, 0.07330912, 0.04733199,
         0.06368802], dtype=float32)),
 (('FlyingFox', 'FreeTaile', 'Horse'),
  0.004242083,
  array([0.01965397, 0.0075755 , 0.01195962, 0.09485808, 0.02642714,
         0.07972658], dtype=float32)),
 (('FlyingFox', 'FreeTaile', 'Rhino'),
  0.0029441984,
  array([0.01914574, 0.03133254, 0.03313847, 0.07314131, 0.04749844,
         0.07142801], dtype=float32)),
 (('FlyingFox', 'FreeTaile', 'Pangolin'),
  0.0036895326,
  array([0.01842042, 0.01126666, 0.07119681, 0.07517575, 0.04548229,
         0.08318905], dtype=float32)),
 (('FlyingFo

In [387]:
# Earlier choices of per-triple lengths, kept for reference:
#lengths = [np.exp(p[:3]) for p in params]
#lengths = [np.exp(p[:3]) == np.exp(p[:3]).max() for p in params]
#lengths = [np.exp(p[:3])/np.exp(p[:3]).sum() for p in params]
#h = np.array([-(p*np.log(p)).sum() for p in lengths])
#lengths = np.diag(1/(h + 1)) @ lengths
# Current choice: entries 4-6 of each transform() result, i.e. (tia, tib, tic).
lengths = [np.array([p.numpy() for p in param[4:7]]) for param in params]
tree = pulsar_tree(triples, lengths, losses, verbose=False)
print(tree.ascii_art())

                              /-Galago
                    /--------|
                   |          \-TreeShrew
                   |
                   |                    /-Orangutan
          /--------|          /--------|
         |         |         |         |          /-Human
         |         |         |          \--------|
         |         |         |                   |          /-Chimpanzee
         |          \--------|                    \--------|
         |                   |                              \-Gorilla
         |                   |
         |                   |          /-HowlerMon
         |                    \--------|
         |                              \-Rhesus
         |
         |                                        /-Llama
         |                              /--------|
         |                             |          \-Pig
         |                    /--------|
         |                   |         |          /-Cow
         |     

In [246]:
# NOTE(review): this cell treats each element of params as a flat vector of
# log-parameters.  fit_triples as defined in this notebook returns the tuple
# produced by transform() for each triple, so p[:3] would be (pi, Pa, Pb)
# here; the cell appears to predate that change.
lengths = [np.exp(p[:3]) for p in params]
#lengths = [np.exp(p[:3]) == np.exp(p[:3]).max() for p in params]
# Overrides the assignment above: normalise by the sum over the first six entries.
lengths = [np.exp(p[:3])/np.exp(p[:6]).sum() for p in params]
#h = np.array([-(p*np.log(p)).sum() for p in lengths])
#lengths = np.diag(1/(h + 1)) @ lengths
tree = pulsar_tree(triples, lengths, losses, verbose=False)
print(tree.ascii_art())

                              /-HowlerMon
                    /--------|
                   |          \-Sloth
          /--------|
         |         |          /-Pangolin
         |          \--------|
         |                   |          /-FlyingLem
         |                    \--------|
         |                              \-Rhesus
         |
         |                                        /-Orangutan
---------|                              /--------|
         |                             |         |          /-Chimpanzee
         |                    /--------|          \--------|
         |                   |         |                    \-Human
         |          /--------|         |
         |         |         |          \-Rhino
         |         |         |
         |         |          \-Horse
          \--------|
                   |                    /-Llama
                   |          /--------|
                   |         |          \-SpermWhale
       

In [157]:
# NOTE(review): h is only assigned in commented-out lines elsewhere in this
# notebook, so this cell depends on leftover kernel state.
sorted(list(zip(h, (n for n, _ in triples))))

[(0.16873865, ('FlyingLem', 'Human', 'Chimpanzee')),
 (0.18088952, ('SpermWhale', 'Human', 'Chimpanzee')),
 (0.29110888, ('Rhino', 'Orangutan', 'Gorilla')),
 (0.3259487, ('Rhino', 'Orangutan', 'Human')),
 (0.33282077, ('HowlerMon', 'Human', 'Chimpanzee')),
 (0.3472516, ('Rhesus', 'Human', 'Chimpanzee')),
 (0.4210934, ('HowlerMon', 'Orangutan', 'Human')),
 (0.48457676, ('HowlerMon', 'Orangutan', 'Chimpanzee')),
 (0.5511679, ('Rhesus', 'Orangutan', 'Human')),
 (0.630245, ('HowlerMon', 'Orangutan', 'Gorilla')),
 (0.6435783, ('Llama', 'SpermWhale', 'Rhesus')),
 (0.6758433, ('HumpbackW', 'Gorilla', 'Human')),
 (0.6785302, ('FlyingLem', 'Gorilla', 'Human')),
 (0.7086155, ('SpermWhale', 'Gorilla', 'Human')),
 (0.7768129, ('Llama', 'SpermWhale', 'HairyArma')),
 (0.7871436, ('Horse', 'Gorilla', 'Human')),
 (0.7897128, ('HowlerMon', 'Gorilla', 'Human')),
 (0.83272696, ('Horse', 'Rhino', 'HowlerMon')),
 (0.85429287, ('Llama', 'HumpbackW', 'HowlerMon')),
 (0.8665911, ('Llama', 'HowlerMon', 'Chimpa

In [258]:
# Cluster the triples on the tail of their parameter vectors and score the
# clustering with the Davies-Bouldin index.
# NOTE(review): p[6:] assumes each element of params is a flat vector;
# fit_triples as defined in this notebook returns transform() tuples instead,
# so this cell appears to predate that change.
from sklearn.cluster import KMeans, AffinityPropagation
from sklearn.metrics import davies_bouldin_score
X = [p[6:] for p in params]
fit = AffinityPropagation().fit(X)
davies_bouldin_score(X, fit.labels_)



0.7439558932161066

In [259]:
sharesies = fit.labels_

In [260]:
sharesies

array([16,  0, 16, 16, 16, 16, 16, 16, 16, 16,  0, 13, 13,  2,  2,  2,  2,
       15, 15,  6,  6,  6,  6, 13,  6,  0,  1,  0,  2,  2, 12,  2,  2,  2,
        4, 12,  9, 12, 15, 12,  2,  2,  2,  2, 13,  2,  2,  2, 15,  6,  2,
       12,  2, 13,  2,  2, 12, 12, 12, 12,  2,  2, 15, 16, 13, 13, 13, 13,
        2, 15,  1,  5,  1,  1, 12, 12, 14, 11, 11,  6, 12, 17,  3,  6, 15,
        8,  6, 15, 12, 15, 13,  2,  2,  2,  2,  6,  2,  2,  2,  2,  2,  2,
        6,  1, 16,  2,  2,  2,  2,  2,  2,  2,  2,  6,  3, 16,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 15,  6,
        2,  2, 10,  2, 15, 16, 13, 13, 13, 13,  2, 15,  1,  1,  4,  4,  2,
       15,  2, 11,  4,  2, 15, 17,  3,  2, 15,  8,  2, 15, 10, 15, 13,  0,
        4,  6, 15,  6,  6,  6,  6,  6,  6, 15,  3,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6, 15,  6,  6,  6,  6,  6,  6, 15, 15,  6,  6, 15,
        6,  6, 12, 12,  5, 16, 16, 16, 15, 15,  5, 16,  5,  5, 15, 15, 14,
       11,  7, 15, 15, 17

In [423]:
# Reference tree for comparison: neighbour joining on pairwise distances
# estimated from third codon positions, rooted at the midpoint.
aln = load_aligned_seqs('brca1.fasta', moltype='dna')
subaln = aln.get_similar(aln.take_seqs(['Human']).seqs[0], min_similarity=0.84)
from cogent3.evolve import distance
from cogent3.phylo import nj
# HKY85 is the model used below, so it has to be imported alongside JC69.
from cogent3.evolve.models import JC69, HKY85
d = distance.EstimateDistances(subaln[2::3], submodel=HKY85())
d.run(show_progress=False)
mytree = nj.nj(d.get_pairwise_distances())
mytree = mytree.root_at_midpoint()
print(mytree.ascii_art())
type(d)

   0%|                                                                  |00:00<?

                              /-HairyArma
                    /edge.1--|
                   |          \-Sloth
                   |
                   |                              /-Llama
          /edge.7--|                    /edge.3--|
         |         |                   |         |          /-HumpbackW
         |         |          /edge.4--|          \edge.2--|
         |         |         |         |                    \-SpermWhale
         |         |         |         |
         |          \edge.6--|          \-Pangolin
-root----|                   |
         |                   |          /-Horse
         |                    \edge.5--|
         |                              \-Rhino
         |
         |          /-FlyingLem
         |         |
          \edge.0.2|          /-HowlerMon
                   |         |
                    \edge.8--|          /-Rhesus
                             |         |
                              \edge.9--|          /-Orangutan
    

cogent3.evolve.distance.EstimateDistances

In [278]:
type(d.get_pairwise_distances())

cogent3.evolve.fast_distance.DistanceMatrix

In [279]:
from cogent3.evolve.fast_distance import DistanceMatrix

In [281]:
DistanceMatrix(pd.DataFrame())

names,0,1
0,2.0,3.0
1,4.0,2.0


In [282]:
nj.nj?

In [7]:
from cogent3.app import io

In [8]:
# Open the data store of filtered alignments and create the loader used to
# read individual members; the last line displays the store's summary.
dstore = io.get_data_store("../data/horse_pig_bats-filtered.tinydb")
loader = io.load_db()
dstore.describe

record type,number
completed,878
incomplete,122
logs,1


In [None]:
%%time
# Fit the triples of every alignment in the data store in one optimisation.
# Each alignment gets its own group label, so its triples share one rate
# matrix and one set of tip-to-tip distances, and nothing is shared between
# alignments.
num_alns = len(dstore)
all_triples = []
Q_sharesies = []
t_sharesies = []
for share, aln_name in enumerate(dstore):
#    if share in (150, 217, 475, 619, 627):
#        continue
    aln = loader(aln_name)
    triples = get_triples(aln, codon_position=3, verbose=False)
    all_triples.extend(triples)
    # One label per triple, equal to the index of the alignment it came from.
    Q_sharesies.extend([share]*len(triples))
    t_sharesies.extend([share]*len(triples))
all_losses, all_parameters = fit_triples(all_triples, steps=4000, verbose=True,
                                         Q_sharesies=Q_sharesies, t_sharesies=t_sharesies)

Fitting 3512 triple(s) using 878 rate matrix(es) and 5268 tip-to-tip distance(s).


In [45]:
def get_trees(all_triples, all_lengths, all_losses, triples_per_alignment=4):
    """Build one tree per consecutive group of triples.

    The three inputs are parallel sequences in which the triples of one
    alignment sit next to each other.  A group whose lengths contain any NaN
    is skipped.

    Parameters
    ----------
    all_triples, all_lengths, all_losses : parallel sequences, one entry per triple.
    triples_per_alignment : int
        Number of consecutive entries that belong to one alignment.  The
        default of 4 is the number of triples among four sequences.

    Returns
    -------
    list
        The trees returned by ``pulsar_tree`` for the groups that were kept.
    """
    trees = []
    for start in range(0, len(all_losses), triples_per_alignment):
        stop = start + triples_per_alignment
        if np.isnan(all_lengths[start:stop]).any():
            # The fit produced NaN for this alignment; no tree can be built.
            continue
        tree = pulsar_tree(all_triples[start:stop], all_lengths[start:stop], all_losses[start:stop])
        trees.append(tree)
    return trees

In [50]:
# Entries 4-9 of each transform() result: (tia, tib, tic, toa, tob, toc).
# NOTE(review): get_edges divides by the sum of all six values but only pairs
# the first three with names, whereas other cells pass param[4:7] -- confirm
# that this normalisation is intended.
all_lengths = [np.array([p.numpy() for p in param[4:10]]) for param in all_parameters]

In [51]:
# Print i // 4 for every triple whose fitted lengths contain NaN; with four
# triples per alignment this is the index of the alignment it came from.
for i, lengths in enumerate(all_lengths):
    if np.isnan(lengths).any():
        print(i//4)

17
17
17
17
25
25
25
25
27
27
27
27
94
94
94
94
104
104
104
104
109
109
109
109
113
113
113
113
123
123
123
123
129
129
129
129
165
165
165
165
183
183
183
183
191
191
191
191
205
205
205
205
217
217
217
217
223
223
223
223
232
232
232
232
251
251
251
251
262
262
262
262
271
271
271
271
293
293
293
293
328
328
328
328
329
329
329
329
343
343
343
343
391
391
391
391
404
404
404
404
425
425
425
425
437
437
437
437
460
460
460
460
472
472
472
472
475
475
475
475
478
478
478
478
495
495
495
495
512
512
512
512
517
517
517
517
530
530
530
530
558
558
558
558
560
560
560
560
567
567
567
567
575
575
575
575
588
588
588
588
599
599
599
599
616
616
616
616
619
619
619
619
627
627
627
627
657
657
657
657
669
669
669
669
692
692
692
692
704
704
704
704
717
717
717
717
728
728
728
728
744
744
744
744
762
762
762
762
781
781
781
781
794
794
794
794
851
851
851
851
867
867
867
867
868
868
868
868
871
871
871
871


In [52]:
trees = get_trees(all_triples, all_lengths, all_losses)

In [53]:
len(trees)

820

In [54]:
# For every tree, look at the nodes that share a parent with
# 'Greater horseshoe bat' in the unrooted tree and tally those named
# Microbat, Pig or Horse.
ghb_siblings = Counter()
ordered_siblings = []
for tree in trees:
    tree = tree.unrooted()
    for sibling in tree.get_node_matching_name('Greater horseshoe bat').parent.children:
        if sibling.name in ('Microbat', 'Pig', 'Horse'):
            ghb_siblings[sibling.name] += 1
            # Same information in tree order (one entry per match).
            ordered_siblings.append(sibling.name)

In [55]:
ghb_siblings

Counter({'Microbat': 527, 'Pig': 147, 'Horse': 146})

In [56]:
ghb_siblings['Microbat']/len(trees)

0.6426829268292683