In [2]:
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
import numpy as np
from itertools import product
import torch

In [3]:
from transition_matrix_eq_approx import transition_matrix, v, w, u, Neff

In [4]:
# Rooted tree given as (parent, child, weight) triples: node 4 is the root,
# 5 and 6 are internal nodes and 0-3 are the leaves. The weight is stored on
# each edge under the 'weight' key.
graph = nx.DiGraph()
tree = [
  (4, 5, 1.0), (4, 6, 4.6),
  (5, 0, 3.2), (5, 1, 2.2),
  (6, 2, 0.1), (6, 3, 2.4),
]
graph.add_weighted_edges_from(tree)


In [7]:
# pos = graphviz_layout(graph, prog="dot")
# nx.draw_networkx(graph, pos, with_labels=True)
# plt.show()

In [8]:
# Sanity check: node 0 only ever appears as a child in `tree`, so it has no successors.
list(graph.successors(0))

[]

In [9]:
# Observed sequence for each leaf (nodes 0-3), one integer state per site.
leaf_sequences = {
  0: [0, 1, 2, 3, 4],
  1: [5, 6, 7, 8, 9],
  2: [10, 11, 12, 13, 14],
  3: [15, 0, 1, 2, 3],
}
for leaf, observed in leaf_sequences.items():
  graph.nodes[leaf]['sequence'] = observed

# Number of sites per sequence and size of the state alphabet.
sequence_length = 5
nucl_size = 16
# States are represented by their integer index 0 .. nucl_size - 1.
nucls = [*range(nucl_size)]

In [10]:
%load_ext autoreload
%autoreload 2

In [12]:
from Felselstein import Felselstein 

In [13]:
# Build the solver from the imported `Felselstein`, passing the tree, the
# number of sites and the state alphabet defined in the cells above.
test = Felselstein(graph, sequence_length, nucls)

In [14]:
# NOTE(review): run() presumably fills 'likelihood_matrix' on the nodes (the
# next cell reads it) - confirm against Felselstein.py.
test.run()

In [15]:
# Inspect the 'likelihood_matrix' stored on node 4, the root of `tree`.
test.graph.nodes[4]['likelihood_matrix']

tensor([[3.1379e-26, 5.5197e-16, 1.7704e-19, 3.8581e-28, 2.1139e-20, 1.4442e-25,
         1.6006e-19, 5.7140e-29, 9.7806e-23, 2.5996e-21, 1.7616e-18, 3.6165e-23,
         1.5151e-15, 9.6295e-18, 2.0195e-24, 3.0969e-29],
        [1.0652e-25, 3.4257e-16, 7.2174e-19, 9.3463e-28, 8.5255e-23, 3.0289e-26,
         6.6213e-20, 2.9564e-26, 4.8086e-24, 3.1534e-19, 2.4631e-18, 8.4204e-27,
         9.3299e-20, 1.7387e-18, 1.5806e-23, 2.4595e-26],
        [4.5621e-26, 1.0065e-14, 2.8741e-18, 5.4823e-27, 3.7304e-20, 3.2253e-23,
         3.6964e-19, 7.5509e-26, 6.6174e-22, 1.8446e-18, 1.6994e-18, 1.5582e-23,
         3.5976e-18, 3.6833e-17, 4.5547e-21, 1.5139e-28],
        [3.9084e-29, 1.8027e-18, 5.8238e-23, 4.2581e-31, 8.7350e-27, 3.1920e-28,
         7.2404e-24, 3.3465e-28, 8.6866e-28, 2.5654e-23, 5.6803e-22, 1.5248e-27,
         1.8443e-20, 3.3418e-19, 4.1809e-26, 1.9098e-30],
        [5.5474e-28, 7.0832e-21, 1.1156e-22, 5.7498e-28, 3.3947e-23, 4.8774e-28,
         4.2200e-23, 2.6570e-31, 9.1334

## TEST SUCCESSFUL

In [None]:
# nucl_to_ind_dict = {'A': 0, 'T': 1, 'G': 2, 'C':3}
# ind_to_nucl_dict = {value: key for key, value in nucl_to_ind_dict.items()}
# nucls = list(nucl_to_ind_dict.keys())


def initial_probs(size=None):
  """Return a uniform prior distribution over the states.

  Args:
    size: number of states. When omitted, the notebook-level `nucl_size`
      is used, which keeps the original zero-argument call working.

  Returns:
    A 1-D tensor of length `size` whose entries are all `1 / size`.
  """
  if size is None:
    size = nucl_size
  return torch.ones(size) / size

def likelihood_prob(l_nucl, r_nucl, l_time, r_time):
  """Likelihood of observing `l_nucl` and `r_nucl` at one site of two sequences.

  For every state in the global `nucls`, multiplies the prior from
  `initial_probs()` by `T_M[state, l_nucl] * l_time` and by
  `T_M[state, r_nucl] * r_time`, and adds the products up.

  NOTE(review): `T_M` is read from the notebook namespace and is not assigned
  in any of the visible cells; it must exist before this is called.

  Args:
    l_nucl: state index observed in the left sequence.
    r_nucl: state index observed in the right sequence.
    l_time: factor applied to the left `T_M` entry.
    r_time: factor applied to the right `T_M` entry.

  Returns:
    A tensor of shape (1,) holding the summed likelihood.
  """
  prior = initial_probs()
  total = torch.zeros(1)

  for ancestor in nucls:
    # Left-to-right product, kept in the same order as a single expression.
    left_term = prior[ancestor] * T_M[ancestor, l_nucl] * l_time
    total += left_term * T_M[ancestor, r_nucl] * r_time

  return total



In [None]:
def full_log_likelihood(l_seq, r_seq, l_time, r_time):
  """Sum the per-site log-likelihoods of two sequences.

  Sites are paired with `zip`, so iteration stops at the end of the shorter
  sequence. Each pair is scored with `likelihood_prob`.

  Args:
    l_seq: left sequence of state indices.
    r_seq: right sequence of state indices.
    l_time: passed through to `likelihood_prob` for the left sequence.
    r_time: passed through to `likelihood_prob` for the right sequence.

  Returns:
    A tensor of shape (1,) with the accumulated log-likelihood.
  """
  total = torch.zeros(1)
  for left_state, right_state in zip(l_seq, r_seq):
    site_likelihood = likelihood_prob(left_state, right_state, l_time, r_time)
    total += torch.log(site_likelihood)
  return total

In [None]:
# Two length-4 sequences that differ only at index 2 (12 vs 11).
full_log_likelihood([1, 15, 12, 4], [1, 15, 11, 4], 0.005, 0.015)

In [None]:
# Same left sequence; the right one now differs at indices 2 and 3.
full_log_likelihood([1, 15, 12, 4], [1, 15, 11, 3], 0.005, 0.015)

In [None]:
# Give every node an all-zero (site x state) table to be filled in later.
for node_id in graph.nodes:
  graph.nodes[node_id]['likelihood_matrix'] = torch.zeros(sequence_length, len(nucls))

In [None]:
def Felselstein_step(graph, k_node, a_nucl, position_nucl, t_matrix=None, alphabet=None):
  """Fill one entry of a node's likelihood table.

  Writes `graph.nodes[k_node]['likelihood_matrix'][position_nucl, a_nucl]`:

  * Leaf (no successors): 1 if the node's 'sequence' has `a_nucl` at
    `position_nucl`, otherwise 0.
  * Internal node with exactly two successors: the sum over all pairs of
    child states (l, r) of
      t[a, l] * l_time * L_left[pos, l] * t[a, r] * r_time * L_right[pos, r].
    That double sum factorizes into (sum over l) * (sum over r), so it is
    computed with two passes over the alphabet instead of one pass over
    every pair. Rounding may differ in the last bits from the pairwise sum.

  Args:
    graph: directed tree whose nodes carry 'likelihood_matrix' (and leaves
      'sequence') and whose edges carry 'weight'.
    k_node: node whose table entry is updated.
    a_nucl: state index (column) being filled.
    position_nucl: site index (row) being filled.
    t_matrix: matrix indexed as `t_matrix[from_state, to_state]`. Defaults
      to the notebook-level `T_M`.
    alphabet: iterable of state indices to sum over. Defaults to the
      notebook-level `nucls`.

  Raises:
    ValueError: if an internal node does not have exactly two successors.
  """
  if t_matrix is None:
    t_matrix = T_M
  if alphabet is None:
    alphabet = nucls

  matrix = graph.nodes[k_node]['likelihood_matrix']
  children = list(graph.successors(k_node))

  if not children:
    # Leaf: indicator of the observed state at this site.
    observed = graph.nodes[k_node]['sequence'][position_nucl]
    matrix[position_nucl, a_nucl] = 1 if observed == a_nucl else 0
    return

  if len(children) != 2:
    raise ValueError(
      f"node {k_node!r} has {len(children)} successors; exactly 2 are required"
    )

  l_successor, r_successor = children
  l_time = graph.edges[k_node, l_successor]['weight']
  r_time = graph.edges[k_node, r_successor]['weight']
  l_matrix = graph.nodes[l_successor]['likelihood_matrix']
  r_matrix = graph.nodes[r_successor]['likelihood_matrix']

  l_sum = 0
  r_sum = 0
  for nucl in alphabet:
    l_sum += t_matrix[a_nucl, nucl] * l_time * l_matrix[position_nucl, nucl]
    r_sum += t_matrix[a_nucl, nucl] * r_time * r_matrix[position_nucl, nucl]

  matrix[position_nucl, a_nucl] = l_sum * r_sum

def Felselstein(graph):
  """Fill every node's 'likelihood_matrix' bottom-up with `Felselstein_step`.

  The sequence length is taken from the first leaf's 'sequence'; the root is
  the first node without predecessors. Nodes are visited in reversed BFS
  layers from the root, so each node is processed after its successors.
  For every site, layer, node and state in the global `nucls`, one
  `Felselstein_step` call writes one table entry in place.

  Nodes that do not yet have a 'likelihood_matrix' get a zero
  (sequence_length x len(nucls)) tensor, so this function no longer depends
  on a separate initialization cell having been run first. Existing tables
  are left as they are.

  NOTE: the notebook also imports a `Felselstein` name from the
  `Felselstein` module; running this definition rebinds that name.

  Args:
    graph: directed tree; leaves carry 'sequence', edges carry 'weight'.

  Raises:
    IndexError: if the graph has no leaf or no root.
  """
  leaves = [node for node in graph.nodes if not list(graph.successors(node))]
  sequence_length = len(graph.nodes[leaves[0]]['sequence'])
  roots = [node for node in graph.nodes if not list(graph.predecessors(node))]
  head = roots[0]

  # Make the function self-contained: create any missing tables.
  for _, data in graph.nodes(data=True):
    if 'likelihood_matrix' not in data:
      data['likelihood_matrix'] = torch.zeros((sequence_length, len(nucls)))

  # Deepest layer first, so successors are ready before their parent.
  nodes_order = list(nx.bfs_layers(graph, head))[::-1]

  for seq_pos in range(sequence_length):
    for layer in nodes_order:
      for node in layer:
        for nucl in nucls:
          Felselstein_step(graph, node, nucl, seq_pos)

In [None]:
# Run the pruning pass defined in this notebook; it updates each node's
# 'likelihood_matrix' in place. At this point `Felselstein` is the function
# defined above, not the name imported from the Felselstein module.
Felselstein(graph)

In [None]:
# Show the BFS layers starting from node 4, the root of `tree`.
list(nx.bfs_layers(graph, 4))

In [None]:
# Inspect the table of internal node 5 (parent of leaves 0 and 1 in `tree`).
graph.nodes[5]['likelihood_matrix']

In [None]:
def v(ind, nucl):
  """Placeholder: not implemented yet, always returns None.

  NOTE: this rebinds the name `v` that the notebook imports from
  `transition_matrix_eq_approx`; `e` calls it as `v(i, seq[i])`.
  """
  return None

def w(ind_i, ind_j, nucl_i, nucl_j):
  """Placeholder: not implemented yet, always returns None.

  NOTE: this rebinds the name `w` that the notebook imports from
  `transition_matrix_eq_approx`; `e` calls it as `w(i, j, seq[i], seq[j])`.
  """
  return None


def g(x):
  ''' Formula 15: g(x) = 2x / (1 - exp(-2x)).

  The expression is 0/0 at x = 0, where its limit is 1; that value is
  returned explicitly instead of NaN. `1 - exp(-2x)` is evaluated as
  `-expm1(-2x)` so that small |x| does not lose precision to cancellation.

  Args:
    x: a number or array-like of numbers.

  Returns:
    A NumPy scalar for scalar input, otherwise an array of the same shape.
  '''
  x = np.asarray(x)
  # Silence the warnings from the 0/0 at x == 0 (replaced below) and from
  # exp overflow at very negative x (the quotient still tends to 0).
  with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
    ratio = 2.0 * x / -np.expm1(-2.0 * x)
  # `[()]` unwraps a 0-d result to a scalar and leaves real arrays untouched.
  return np.where(x == 0, 1.0, ratio)[()]

def e(seq):
  ''' Formula 30: sum of v over positions plus w over position pairs.

  Adds `v(i, seq[i])` for every position i and `w(i, j, seq[i], seq[j])`
  for every pair i < j, accumulating in position order.

  Args:
    seq: indexable sequence of states.

  Returns:
    The accumulated sum (0 for an empty sequence).
  '''
  total = 0
  length = len(seq)
  for i in range(length):
    state_i = seq[i]
    total += v(i, state_i)

    for j in range(i + 1, length):
      total += w(i, j, state_i, seq[j])

  return total


def rate(nucl_1, nucl_2):
  """Placeholder: not implemented yet, always returns None.

  `p_rate` calls it as `rate(seq[ind], nucl)` and multiplies the result, so
  it has to return a number before `p_rate` can work.
  """
  return None

def p_rate(seq, ind, nucl):
  ''' Formula 29: rate(seq[ind], nucl) * g((e(seq_a) - e(seq)) / 2).

  `seq_a` is `seq` with the state at position `ind` replaced by `nucl`.

  Args:
    seq: sequence of states (e.g. a string, or a list of state indices).
    ind: position to mutate.
    nucl: replacement state.

  Returns:
    The product of `rate(seq[ind], nucl)` and `g` of half the change in `e`.
  '''
  try:
    # Works when `nucl` is the same sequence type as `seq` (e.g. str + str).
    seq_a = seq[:ind] + nucl + seq[ind + 1:]
  except TypeError:
    # `nucl` is a single element (e.g. an int state in a list of ints):
    # substitute it on a copy instead of concatenating.
    seq_a = list(seq)
    seq_a[ind] = nucl
  return rate(seq[ind], nucl) * g((e(seq_a) - e(seq)) / 2)


def transition_probability(seq_successors, seq_predecessors, length):
  ''' Formula 28 (unfinished draft).

  Walks the two sequences position by position and, in the lines visible
  here, only updates `log_left_product` at positions where they differ.

  NOTE(review): in the visible lines `log_right_product` and the loop index
  `i` are never used and nothing is returned, so callers get None.
  '''
  log_left_product = 0
  log_right_product = 0

  for i, (nucl_successor, nucl_predecessor) in enumerate(zip(seq_successors, seq_predecessors)):
    if nucl_successor != nucl_predecessor:
      # FIXME: np.log() is called without an argument and raises TypeError;
      # the intended operand is missing.
      log_left_product += np.log() + np.log(length)

