In [107]:
from itertools import combinations
from collections import Counter

from cogent3 import load_aligned_seqs
import numpy as np

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

tf.executing_eagerly()

2021-09-01 07:08:09.145180: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


True

In [3]:
# thanks https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

2021-09-01 07:08:12.390328: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2021-09-01 07:08:12.480248: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-09-01 07:08:12.480745: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 2080 with Max-Q Design computeCapability: 7.5
coreClock: 1.095GHz coreCount: 46 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 357.69GiB/s
2021-09-01 07:08:12.480778: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-09-01 07:08:12.499367: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-09-01 07:08:12.511928: I tensorflow/s

In [84]:
def get_triples(aln, nuc_order='ACGT', codon_position=3, verbose=False):
    aln = aln[codon_position - 1::3].no_degenerates()
    if verbose:
        print(f'Got {len(aln)} positions')
    assert len(aln) <= np.iinfo(np.int32).max
    triples = []
    nuc_map = {n:i for i, n in enumerate(nuc_order)}
    for triple in combinations(range(aln.num_seqs), 3):
        F = np.zeros([4, 4, 4], dtype=np.int32)
        subaln = aln.get_sub_alignment(seqs=triple)
        for a, b, c in subaln:
            F[nuc_map[a], nuc_map[b], nuc_map[c]] += 1
        triples.append([tuple(subaln.names), F])
    return triples

In [87]:
@tf.function()
def transform_P_matrix(params):
    params = tf.exp(params)
    Q0 = tf.concat([[-tf.reduce_sum(params[0])], params[0]],
                   axis=0)
    Q1 = tf.concat([[params[1,0]], [-tf.reduce_sum(params[1])], params[1,1:]],
                   axis=0)
    Q = tf.concat([[Q0], [Q1], [Q1[::-1]], [Q0[::-1]]], axis=0)
    return tf.linalg.expm(Q)

@tf.function()
def transform(params):
    pi = tfb.SoftmaxCentered()(params[0])
    Pa = transform_P_matrix(params[1:3])
    Pm = transform_P_matrix(params[3:5])
    Pb = transform_P_matrix(params[5:7])
    Pc = transform_P_matrix(params[7:9])
    return pi, Pa, Pm, Pb, Pc
    
@tf.function()
def _loss(params_data):
    params, data = params_data
    pi, Pa, Pm, Pb, Pc = transform(params)
    J = tf.einsum('i,ij,ik,ku,kv', pi, Pa, Pm, Pb, Pc)
    loss = tf.reduce_sum(tf.keras.losses.KLDivergence()(J, data))
    return loss

@tf.function()
def loss(params, data):
    return tf.reduce_sum(tf.vectorized_map(_loss, (params, data)))

@tf.function()
def training_step(parameters, data, optimizer):
    with tf.GradientTape() as tape:
        loss_value = loss(parameters, data)
    gradients = tape.gradient(loss_value, parameters)
    return loss_value, gradients

# thanks https://github.com/mlgxmez/thelongrun_notebooks/blob/master/MLE_tutorial.ipynb
def mle_fit(data, loss, parameters, optimizer, steps=500, verbose=False):
    for i in range(steps):
        loss_value, gradients = training_step(parameters, data, optimizer)
        optimizer.apply_gradients([(gradients, parameters)])
        
        if i % 100 == 0:
            if verbose:
                iter_info = f"Step: {optimizer.iterations.numpy()}, initial loss: {loss_value.numpy()}"
                print(iter_info)

def check_cherry(parameters, data, N):
    ls =  np.array([_loss((parameters[i], data[i])).numpy() for i in range(data.shape[0])])
    ls = N*ls
    delta = ls - ls.min()
    weights = np.exp(-delta)
    return weights/weights.sum()
    
def symmetric_root_fits(triples, learning_rate=0.01, steps=3000, verbose=False):
    normal_initializer = tf.random_normal_initializer()
    K = len(triples)*3
    parameters = tf.Variable(normal_initializer(shape=[K, 9, 3], dtype=tf.float32), name='params')
    
    data = []
    for _, F in triples:
        Ja = (F/F.sum()).astype(np.float32)
        Jb = Ja.transpose([1, 2, 0])
        Jc = Ja.transpose([2, 0, 1])
        data.extend([Ja, Jb, Jc])
    data = tf.stack(data)

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, steps=steps, verbose=verbose)
    
    root_probs = []
    for i, (names, F) in enumerate(triples):
        probs = check_cherry(parameters[3*i:3*(i+1)], data[3*i:3*(i+1)], F.sum())
        root_probs.append(dict(zip(names, probs)))
        
    return root_probs

def symmetric_root_fits(triples, learning_rate=0.01, steps=3000, verbose=False):
    normal_initializer = tf.random_normal_initializer()
    K = len(triples)*3
    parameters = tf.Variable(normal_initializer(shape=[K, 9, 3], dtype=tf.float32), name='params')
    
    data = []
    for _, F in triples:
        Ja = (F/F.sum()).astype(np.float32)
        Jb = Ja.transpose([1, 2, 0])
        Jc = Ja.transpose([2, 0, 1])
        data.extend([Ja, Jb, Jc])
    data = tf.stack(data)

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
    mle_fit(data, loss, parameters, optimizer, steps=steps, verbose=verbose)
    
    root_probs = []
    for i, (names, F) in enumerate(triples):
        probs = check_cherry(parameters[3*i:3*(i+1)], data[3*i:3*(i+1)], F.sum())
        root_probs.append(dict(zip(names, probs)))
        
    return root_probs

def pick_cherry(root_probs):
    cherry_llik = Counter()
    for probs in root_probs:
        names = frozenset(probs.keys())
        for name in names:
            cherry_llik[names - {name}] += np.log(probs[name])
    return cherry_llik.most_common()

In [102]:
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000197102.fa.gz', moltype="dna")
# aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000131018.fa.gz', moltype="dna")
aln = load_aligned_seqs('/home/ben/Data/pentads/ENSG00000179869.fa.gz', moltype="dna")

In [103]:
aln

0,1
,0
Human,ATGGGGCATGCCGGGTGCCAGTTCAAAGCCCTGCTGTGGAAGAATTGGCTCTGCAGACTC
Dog,------------------------------------------------------------
Mouse,........C..T...C........C..........C........C...A.T.........
Opossum,------------------------------------------------------------
Platypus,.......T.ATT..CAA.........G..T...T.C..................C.C..G


In [104]:
%%time 
triples = get_triples(aln, verbose=True)
root_probs = symmetric_root_fits(triples, verbose=True)

Got 4428 positions
Step: 1, initial loss: 1.3090912103652954
Step: 101, initial loss: 1.0768522024154663
Step: 201, initial loss: 0.6612101793289185
Step: 301, initial loss: 0.31701719760894775
Step: 401, initial loss: 0.2674697935581207
Step: 501, initial loss: 0.23480693995952606
Step: 601, initial loss: 0.20158718526363373
Step: 701, initial loss: 0.17683981359004974
Step: 801, initial loss: 0.16765709221363068
Step: 901, initial loss: 0.15469048917293549
Step: 1001, initial loss: 0.14041869342327118
Step: 1101, initial loss: 0.1260424703359604
Step: 1201, initial loss: 0.12195804715156555
Step: 1301, initial loss: 0.1146974191069603
Step: 1401, initial loss: 0.11000671982765198
Step: 1501, initial loss: 0.09930513054132462
Step: 1601, initial loss: 0.08042018860578537
Step: 1701, initial loss: 0.0787866860628128
Step: 1801, initial loss: 0.07334928214550018
Step: 1901, initial loss: 0.07255487889051437
Step: 2001, initial loss: 0.07245475053787231
Step: 2101, initial loss: 0.072297

In [105]:
cherry_llik = pick_cherry(root_probs)

In [106]:
cherry_llik

[(frozenset({'Dog', 'Human'}), -2.720725655555725),
 (frozenset({'Opossum', 'Platypus'}), -2.739441156387329),
 (frozenset({'Dog', 'Mouse'}), -3.0095942616462708),
 (frozenset({'Human', 'Platypus'}), -3.0251817107200623),
 (frozenset({'Human', 'Mouse'}), -3.269514501094818),
 (frozenset({'Dog', 'Platypus'}), -3.314244031906128),
 (frozenset({'Dog', 'Opossum'}), -3.3645277619361877),
 (frozenset({'Human', 'Opossum'}), -3.4927477836608887),
 (frozenset({'Mouse', 'Platypus'}), -3.7081881761550903),
 (frozenset({'Mouse', 'Opossum'}), -31.82746696472168)]

In [96]:
root_probs

[{'Dog': 0.33617315, 'Human': 0.33344793, 'Mouse': 0.33037892},
 {'Dog': 7.589255e-15, 'Human': 6.7333067e-15, 'Opossum': 1.0},
 {'Dog': 9.1874886e-14, 'Human': 7.206145e-14, 'Platypus': 1.0},
 {'Dog': 1.977101e-09, 'Mouse': 0.9999422, 'Opossum': 5.7866357e-05},
 {'Dog': 4.8960076e-09, 'Mouse': 4.0729464e-09, 'Platypus': 1.0},
 {'Dog': 0.38544434, 'Opossum': 0.2873022, 'Platypus': 0.32725346},
 {'Human': 0.26378027, 'Mouse': 0.2738001, 'Opossum': 0.46241966},
 {'Human': 0.2969577, 'Mouse': 2.5067038e-07, 'Platypus': 0.70304203},
 {'Human': 0.33773392, 'Opossum': 0.33536422, 'Platypus': 0.32690188},
 {'Mouse': 0.51515216, 'Opossum': 0.4848476, 'Platypus': 2.1749203e-07}]