# Library imports

In [None]:
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
import itertools
import statistics
from newick import loads
import pprint
import math
import random 
import io
from io import StringIO
import sys
import re
import copy
from tqdm import tqdm
import warnings
from ete3 import Tree

# BioPython
from Bio import SeqIO, Phylo, AlignIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
from Bio.Phylo import draw
from Bio.Align import MultipleSeqAlignment
from Bio.Phylo.BaseTree import Clade

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn theme for any plots drawn in this notebook.
# NOTE(review): 'Hiragino Maru Gothic Pro' is a macOS font; on other systems
# matplotlib will presumably fall back to a default font and warn — confirm.
sns.set(context='notebook', style='white', font='Hiragino Maru Gothic Pro')
# NOTE(review): opens an empty 20x12 figure that nothing visible here draws
# on; looks like a leftover — confirm before removing.
plt.figure(figsize=(20, 12))

# Loading data

In [None]:
# Use sample data
# Multiple sequence alignment (FASTA) shared by every candidate tree below.
input_data = "sample_seq/acc_hcov_pro_align_fix.fasta"

# Candidate trees (Newick text files) to score against the alignment; each
# one is read by change_tree_format and evaluated separately in the
# execution loop.  The file names suggest different models/methods.
tree_check_file = [
    "sample_tree/hcov_pro_jcgap_root.txt",
    "sample_tree/hcov_pro_jccomp_root.txt",
    "sample_tree/hcov_pro_jcpair_root.txt",
    "sample_tree/hcov_pro_DAYHOFF_raxml_root.txt",
    "sample_tree/hcov_pro_JTTG_raxml_root.txt",
    "sample_tree/hcov_pro_poisson_bayes_root.txt",
    "sample_tree/hcov_pro_jttg_bayes_root.txt",
]


In [None]:
# Convert Newick tree format
# Quoted Newick label: everything between a pair of single quotes.
_QUOTED_LABEL = re.compile(r"'(.*?)'")
# Characters that would clash with Newick syntax or with the '@' separator
# used for internal node names; each is replaced by '-'.
_OTU_UNSAFE = str.maketrans({ch: "-" for ch in "();,@"})


def _sanitize_otu(label):
    """Return ``label`` cut at its first space, with unsafe characters -> '-'."""
    return label.split(" ")[0].translate(_OTU_UNSAFE)


def change_tree_format(newick_list):
    """Read Newick files and normalise their quoted OTU labels.

    Each single-quoted label ``'...'`` is replaced by an unquoted name: the
    text up to the first space, with ``( ) ; , @`` turned into ``-``.  Any
    ``'`` characters left over afterwards are removed.

    Args:
        newick_list: iterable of paths to Newick text files.

    Returns:
        List of Newick strings, one per input path, in input order.
    """
    tree_list_new = []
    for newick in newick_list:
        with open(newick, 'r', encoding='utf-8') as file:
            tree_tmp = file.read()

        # Rewrite each quoted label in place.  A global str.replace per label
        # corrupts the tree when one label's text occurs inside a later label
        # (e.g. 'A x' inside 'B A x' would leave "B A").
        tree_tmp = _QUOTED_LABEL.sub(lambda m: _sanitize_otu(m.group(1)), tree_tmp)
        tree_list_new.append(tree_tmp.replace("'", ""))

    return tree_list_new
tree_check_list = change_tree_format(tree_check_file)

# Read the alignment with BioPython
align_seq = AlignIO.read(input_data, "fasta")

# Display number of sequences and sequence length
print(f"Number of sequences: {len(align_seq)}")
print(f"Length: {len(align_seq[0].seq)}")

# Map characters that clash with Newick syntax or with the '@' separator of
# internal node names (plus spaces) to '-'; intended to make OTU ids match
# the tree labels produced by change_tree_format.
otu_char_table = str.maketrans({ch: '-' for ch in ';,@() '})

# Store sequence data for information content calculation
sequences_raw = {}  # OTU name -> list of single-character sites
observed_chars = set()  # every character occurring in the alignment

for record in align_seq:
    otu = record.id.translate(otu_char_table)
    seq = str(record.seq)
    observed_chars.update(seq)
    if otu in sequences_raw:
        # Two records collapse to the same name after sanitising; the later
        # one replaces the earlier, dropping a sequence from the analysis.
        warnings.warn(f"Duplicate OTU name after sanitising: {otu!r}")
    sequences_raw[otu] = list(seq)

# List of characters used.  Sorted so the candidate order (and therefore
# tie-breaking in deside_parent_seq_lgs) is the same on every run: a list
# built from a set follows str-hash order, which changes between runs.
char_list = sorted(observed_chars)


# Infer internal-node sequences & compute mutual information

In [None]:
# Function to read Newick format tree and sequence, and get combinations of (parent, child, child) from leaves
def get_node_val(sequences, tree, df_nodedist):
    """Collapse a Newick string from the tips inward, scoring each internal node.

    Repeatedly finds an innermost clade ``(A,B)``, infers the sequence of the
    internal node ``A@B`` with ``deside_parent_seq_lgs``, scores it against
    its two children with ``sum_mutual_info_normalized``, and then replaces
    the clade in the string with ``A@B`` so the next clade becomes innermost.

    Args:
        sequences: dict of leaf (OTU) name -> per-site characters.  Only these
            leaf sequences are used when inferring each parent.
        tree: Newick string whose leaf names match the keys of ``sequences``.
        df_nodedist: DataFrame indexed by '@'-joined internal node names, with
            one column per leaf (as built by ``get_dist``).

    Returns:
        List of scores, one per internal node, in the order nodes collapse.
    """
    # Innermost parentheses (A,B): '(' ... ')' with no nested parentheses.
    # Raw string: '\(' in a plain literal is an invalid escape sequence.
    clade_re = re.compile(r'\(([^\(\)]+)\)')
    sequences_val = sequences.copy()  # leaves plus inferred internal nodes
    mutual_info = []

    while True:
        # Extract the most terminal pair
        match = clade_re.search(tree)
        if match is None:
            break

        # NOTE(review): assumes a bifurcating clade; any member after the
        # second is ignored (the node name lookup below would then fail).
        members = match.group(1).split(',')
        # Drop ':branch_length' and any 'Inner...' internal-node label suffix.
        leaf_A = members[0].split(':')[0].split('Inner')[0]
        leaf_B = members[1].split(':')[0].split('Inner')[0]

        # Create internal node name (same '@' convention as assign_at_names)
        node_name = leaf_A + '@' + leaf_B

        # Weights between the current internal node and every leaf
        node_dist = df_nodedist.loc[node_name].to_dict()

        # Determine & save parent sequence
        sequences_val[node_name] = deside_parent_seq_lgs(sequences, node_dist)

        # Mutual information between the internal node and its two children
        mutual_info.append(sum_mutual_info_normalized(sequences_val[leaf_A],
                                                      sequences_val[leaf_B],
                                                      sequences_val[node_name]))

        # Collapse exactly the matched span into the new node name
        tree = tree[:match.start()] + node_name + tree[match.end():]

    return mutual_info

In [None]:
# Function to predict internal node sequences
def deside_parent_seq_lgs(sequences, node_dist, chars=None):
    """Infer an internal node's sequence column by column (weighted vote).

    For each alignment column, chooses the candidate character ``c`` that
    minimises the summed ``node_dist[leaf]`` over all leaves whose character
    in that column differs from ``c``.

    Args:
        sequences: dict of leaf name -> indexable of single characters; every
            sequence is assumed to be as long as the first one.
        node_dist: dict of leaf name -> weight of that leaf for this node.
        chars: candidate characters; defaults to the module-level
            ``char_list`` (characters seen in the alignment).

    Returns:
        List with one chosen character per column (``[]`` when ``sequences``
        is empty).  Ties go to the smallest candidate in sorted order, so the
        result does not depend on the order of ``chars``.
    """
    if chars is None:
        chars = char_list  # List of characters appearing in the current dataset
    candidates = sorted(chars)

    if not sequences:
        return []
    n_sites = len(next(iter(sequences.values())))  # Sequence length

    # Optimize for each site
    parent_seq = []
    for i in range(n_sites):
        best_char = ""
        # math.inf rather than a fixed cap: with a cap of 1000000, totals at
        # or above it silently produced "" for the site.
        best_cost = math.inf
        for c in candidates:
            # Total weight of the leaves that disagree with c at this site
            cost = sum(node_dist[leaf] for leaf, seq in sequences.items() if seq[i] != c)
            if cost < best_cost:
                best_cost, best_char = cost, c
        parent_seq.append(best_char)

    return parent_seq



In [None]:
# Function to assign names connected by @ to internal nodes
def assign_at_names(node):
    """Rename every internal node to the '@'-join of its children's names.

    Works bottom-up and in place: leaves keep their names, and each internal
    node becomes ``"@".join(<child names in child order>)``.  For example,
    ``((A,B),C)`` yields internal names ``A@B`` and ``A@B@C``.

    Returns:
        The (possibly updated) name of ``node``.
    """
    if not node.is_leaf():
        node.name = "@".join(assign_at_names(child) for child in node.children)
    return node.name
    
# Function to calculate the maximum distance between two leaf nodes in the phylogenetic tree (L_max)
def get_max_leaf_distance(tree):
    """Return L_max, the largest distance between two distinct leaves.

    Every ordered pair of distinct leaves is measured with
    ``tree.get_distance``.  The result is never below 0: with fewer than two
    leaves, or when no distance exceeds 0, it is 0.
    """
    leaves = tree.get_leaves()
    pairwise = (
        tree.get_distance(first, second)
        for first in leaves
        for second in leaves
        if first != second
    )
    # Seeding with 0 keeps the "never below 0" floor.
    return max(itertools.chain((0,), pairwise))
    
# Function to calculate distances between internal nodes and leaves, storing in a dataframe
def get_dist(newick_str):
    """Build the per-leaf weight table for every internal node of a tree.

    Parses ``newick_str`` with ete3 (``format=1``) and names internal nodes
    with ``assign_at_names``.  For each (internal node, leaf) pair it takes
    the tree distance and converts it to a weight ``L_max - distance``, where
    ``L_max`` is the largest leaf-to-leaf distance (``get_max_leaf_distance``).
    Leaves closer to a node therefore get larger weights.

    Returns:
        DataFrame indexed by '@'-joined internal node names, with one column
        per leaf name.
    """
    tree = Tree(newick_str, format=1)
    assign_at_names(tree)
    leaves = tree.get_leaves()
    leaf_names = [leaf.name for leaf in leaves]

    # tree.get_distance already walks the path through the two nodes' common
    # ancestor, so no explicit MRCA lookup is needed.
    data = {
        node.name: [tree.get_distance(node, leaf) for leaf in leaves]
        for node in tree.traverse()
        if not node.is_leaf()
    }

    # Rows = internal nodes, columns = leaves
    df = pd.DataFrame(data, index=leaf_names).transpose()

    # Convert distances to weights
    L_max = get_max_leaf_distance(tree)
    return L_max - df

In [None]:
# Function to calculate the amount of information in (parent, child, child)
def _entropy_bits(counts, n):
    """Shannon entropy in bits of a frequency table whose counts sum to ``n``."""
    entropy = 0.0
    for count in counts.values():
        prob = count / n
        entropy -= prob * math.log2(prob)
    return entropy


def _mutual_info_bits(a, b, n):
    """MI(A; B) in bits from paired symbols, with probabilities over ``n``."""
    joint = Counter(zip(a, b))
    marg_a = Counter(a)
    marg_b = Counter(b)
    info = 0.0
    for (sym_a, sym_b), count in joint.items():
        p_ab = count / n
        p_a = marg_a[sym_a] / n
        p_b = marg_b[sym_b] / n
        info += p_ab * math.log2(p_ab / (p_a * p_b))
    return info


def sum_mutual_info_normalized(c1, c2, p):
    """Score a (child, child, parent) triple: MI(X;Z)/S(Z) + MI(Y;Z)/S(Z).

    ``X`` = ``c1``, ``Y`` = ``c2``, ``Z`` = ``p``; each is a sequence of
    per-site symbols, and probabilities use ``n = len(p)``.  When ``S(Z)``
    is 0 (``p`` empty or constant) the score is 0.
    """
    n = len(p)
    s_z = _entropy_bits(Counter(p), n)
    if not s_z > 0:
        return 0
    return _mutual_info_bits(c1, p, n) / s_z + _mutual_info_bits(c2, p, n) / s_z

# Execution

In [None]:
for tree_file, tree in zip(tree_check_file, tree_check_list):
    print(tree_file)  # which candidate tree is being scored

    df_nodedist = get_dist(tree)
    res = get_node_val(sequences_raw, tree, df_nodedist)

    # Each score is MI(child_1; parent)/S + MI(child_2; parent)/S, so halving
    # the mean gives the average per parent-child edge.
    res_mean = np.mean(res) / 2

    print(round(res_mean, 4))
