In [1]:
from itertools import combinations
from collections import Counter

from cogent3 import load_aligned_seqs
import numpy as np

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

tf.executing_eagerly()

2021-09-01 16:50:09.184786: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


True

In [3]:
# thanks https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

2021-09-01 16:50:11.029925: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2021-09-01 16:50:11.079934: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-09-01 16:50:11.080332: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 2080 with Max-Q Design computeCapability: 7.5
coreClock: 1.095GHz coreCount: 46 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 357.69GiB/s
2021-09-01 16:50:11.080360: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-09-01 16:50:11.081820: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-09-01 16:50:11.083156: I tensorflow/s

In [4]:
def get_triples(aln, nuc_order='ACGT', codon_position=3, verbose=False):
    aln = aln[codon_position - 1::3].no_degenerates()
    if verbose:
        print(f'Got {len(aln)} positions')
    assert len(aln) <= np.iinfo(np.int32).max
    triples = []
    nuc_map = {n:i for i, n in enumerate(nuc_order)}
    for triple in combinations(range(aln.num_seqs), 3):
        F = np.zeros([4, 4, 4], dtype=np.int32)
        subaln = aln.get_sub_alignment(seqs=triple)
        for a, b, c in subaln:
            F[nuc_map[a], nuc_map[b], nuc_map[c]] += 1
        triples.append([tuple(subaln.names), F])
    return triples

In [62]:
@tf.function()
def transform_P_matrix(params):
    params = tf.exp(params)
    Q0 = tf.concat([[-tf.reduce_sum(params[0])], params[0]],
                   axis=0)
    Q1 = tf.concat([[params[1,0]], [-tf.reduce_sum(params[1])], params[1,1:]],
                   axis=0)
    Q = tf.concat([[Q0], [Q1], [Q1[::-1]], [Q0[::-1]]], axis=0)
    return tf.linalg.expm(Q)

@tf.function()
def transform(params):
    pi = tfb.SoftmaxCentered()(params[0])
    Pa = transform_P_matrix(params[1:3])
    Pm = transform_P_matrix(params[3:5])
    Pb = transform_P_matrix(params[5:7])
    Pc = transform_P_matrix(params[7:9])
    return pi, Pa, Pm, Pb, Pc
    
@tf.function()
def _loss(params_data):
    params, data = params_data
    pi, Pa, Pm, Pb, Pc = transform(params)
    J = tf.einsum('i,ij,ik,ku,kv', pi, Pa, Pm, Pb, Pc)
    loss = tf.reduce_sum(tf.keras.losses.KLDivergence()(J, data))
    return loss
    
@tf.function()
def loss(params, data):
    return tf.reduce_sum(tf.vectorized_map(_loss, (params, data)))

@tf.function()
def training_step(parameters, data, optimizer, unscrambler):
    with tf.GradientTape() as tape:
        unscrambled = _unscramble(parameters, unscrambler)
        loss_value = loss(unscrambled, data)
    gradients = tape.gradient(loss_value, parameters)
    return loss_value, gradients

# thanks https://github.com/mlgxmez/thelongrun_notebooks/blob/master/MLE_tutorial.ipynb
def mle_fit(data, loss, parameters, optimizer, unscrambler, steps=500, verbose=False):
    for i in range(steps):
        loss_value, gradients = training_step(parameters, data, optimizer, unscrambler)
        optimizer.apply_gradients([(gradients, parameters)])
        
        if i % 100 == 0:
            if verbose:
                iter_info = f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}"
                print(iter_info)

def check_cherry(parameters, data, N):
    ls =  np.array([_loss((parameters[i], data[i])).numpy() for i in range(data.shape[0])])
    ls = N*ls
    delta = ls - ls.min()
    weights = np.exp(-delta)
    return weights/weights.sum()
    
def symmetric_root_fits(triples, learning_rate=0.01, steps=3000, verbose=False):
    normal_initializer = tf.random_normal_initializer()
    K = len(triples)*3
    parameters = tf.Variable(normal_initializer(shape=[K, 9, 3], dtype=tf.float32), name='params')
    
    data = []
    for _, F in triples:
        Ja = (F/F.sum()).astype(np.float32)
        Jb = Ja.transpose([1, 2, 0])
        Jc = Ja.transpose([2, 0, 1])
        data.extend([Ja, Jb, Jc])
    data = tf.stack(data)

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, steps=steps, verbose=verbose)
    
    root_probs = []
    for i, (names, F) in enumerate(triples):
        probs = check_cherry(parameters[3*i:3*(i+1)], data[3*i:3*(i+1)], F.sum())
        root_probs.append(dict(zip(names, probs)))
        
    return root_probs

@tf.function()
def _unscramble(parameters, unscramble):
    unscrambled = []
    for t1 in unscramble:
        unscrambled.append(tf.stack([parameters[i] for i in t1]))
    return tf.stack(unscrambled)

def symmetric_root_fits(triples, learning_rate=0.01, cherries_share_matrices=True, steps=3000, verbose=False):
    normal_initializer = tf.random_normal_initializer()

    K = 0
    cherry_loc = {}
    unscrambler = []
    data = []
    for names, F in triples:
        J = (F/F.sum()).astype(np.float32)
        for ix in [[0, 1, 2], [1, 2, 0], [2, 0, 1]]:
            t1 = list(range(K, K+5)) # for pi and two Ps
            K += 5
            cherry = [names[ix[1]], names[ix[2]]]
            frozen_cherry = frozenset(cherry)
            if not cherries_share_matrices or frozen_cherry not in cherry_loc:
                new_loc = {cherry[0]: [K,K+1], cherry[1]: [K+2,K+3]}
                K += 4
                cherry_loc[frozen_cherry] = new_loc
            t1.extend(cherry_loc[frozen_cherry][cherry[0]]) # for the first cherry
            t1.extend(cherry_loc[frozen_cherry][cherry[1]]) # for the second cherry
            unscrambler.append(t1)
        
            data.append(J.transpose(ix))
        
    parameters = tf.Variable(normal_initializer(shape=[K, 3], dtype=tf.float32), name='params')
    data = tf.stack(data)

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, unscrambler, steps=steps, verbose=verbose)
    
    root_probs = []
    unscrambled = _unscramble(parameters, unscrambler)
    for i, (names, F) in enumerate(triples):
        probs = check_cherry(unscrambled[3*i:3*(i+1)], data[3*i:3*(i+1)], F.sum())
        root_probs.append(dict(zip(names, probs)))
        
    return root_probs

def pick_cherry(root_probs):
    cherry_llik = Counter()
    for probs in root_probs:
        names = frozenset(probs.keys())
        for name in names:
            cherry_llik[names - {name}] += np.log(probs[name])
    return cherry_llik.most_common()

In [63]:
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000197102.fa.gz', moltype="dna")
aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000131018.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000179869.fa.gz', moltype="dna")

In [64]:
aln

0,1
,0
Human,ATGGCAACCTCC---AGAGGGGCCTCCCGGTGTCCTCGGGATATCGCCAATGTGATGCAG
Dog,G.A...------...---------------------------------------------
Mouse,...................CAT.T.......CC.A......C...A..............
Opossum,.....C..T...TCA...AT.CTA..AA.A.C...G..T..C..T..T.....A.....A
Platypus,------------...---------------------------------------------


In [66]:
%%time 
triples = get_triples(aln, verbose=True)
root_probs = symmetric_root_fits(triples, cherries_share_matrices=False, verbose=True)

Got 1379 positions
Step: 1, initial loss: 2.8546125888824463
Step: 101, initial loss: 2.2032418251037598
Step: 201, initial loss: 1.4977716207504272
Step: 301, initial loss: 0.8630205392837524
Step: 401, initial loss: 0.6658306121826172
Step: 501, initial loss: 0.5506930947303772
Step: 601, initial loss: 0.5024968981742859
Step: 701, initial loss: 0.4710041880607605
Step: 801, initial loss: 0.45927634835243225
Step: 901, initial loss: 0.43764108419418335
Step: 1001, initial loss: 0.4162985384464264
Step: 1101, initial loss: 0.39257046580314636
Step: 1201, initial loss: 0.36722075939178467
Step: 1301, initial loss: 0.3446645736694336
Step: 1401, initial loss: 0.3331846296787262
Step: 1501, initial loss: 0.31690871715545654
Step: 1601, initial loss: 0.29524195194244385
Step: 1701, initial loss: 0.2867218554019928
Step: 1801, initial loss: 0.2797970771789551
Step: 1901, initial loss: 0.27758467197418213
Step: 2001, initial loss: 0.262422114610672
Step: 2101, initial loss: 0.25830447673797

In [67]:
cherry_llik = pick_cherry(root_probs)

In [68]:
cherry_llik

[(frozenset({'Dog', 'Human'}), -1.1088603734970093),
 (frozenset({'Human', 'Mouse'}), -2.119705319404602),
 (frozenset({'Opossum', 'Platypus'}), -2.7004566192626953),
 (frozenset({'Mouse', 'Platypus'}), -3.5348814129829407),
 (frozenset({'Dog', 'Mouse'}), -11.171769767999649),
 (frozenset({'Dog', 'Opossum'}), -33.75392091440153),
 (frozenset({'Human', 'Opossum'}), -34.93848013877869),
 (frozenset({'Mouse', 'Opossum'}), -36.706530690193176),
 (frozenset({'Human', 'Platypus'}), -46.11605179309845),
 (frozenset({'Dog', 'Platypus'}), -51.074777364730835)]

In [61]:
root_probs

[{'Dog': 0.38561577, 'Human': 0.33816576, 'Mouse': 0.27621853},
 {'Dog': 1.0313155e-07, 'Human': 0.21758206, 'Opossum': 0.78241783},
 {'Dog': 0.10697844, 'Human': 0.20842554, 'Platypus': 0.68459606},
 {'Dog': 0.24954318, 'Mouse': 0.32080266, 'Opossum': 0.4296541},
 {'Dog': 0.21373376, 'Mouse': 0.19986603, 'Platypus': 0.5864002},
 {'Dog': 0.9488019, 'Opossum': 0.034241207, 'Platypus': 0.016956909},
 {'Human': 0.27611402, 'Mouse': 0.17491361, 'Opossum': 0.54897237},
 {'Human': 0.30928385, 'Mouse': 1.8782882e-07, 'Platypus': 0.69071597},
 {'Human': 0.92236674, 'Opossum': 0.037754644, 'Platypus': 0.03987858},
 {'Mouse': 0.72271425, 'Opossum': 0.13628605, 'Platypus': 0.1409997}]