In [3]:
#multiprocessing and io
import multiprocessing
import subprocess
import os
import pickle
import argparse

#numerics and vis
import math
import scipy.stats as stats
import numpy as np
import pandas as pd

#trees
import ete3
from ete3 import Tree, TreeStyle, TextFace


root = '/data/tobiassonva/data/eukgen/'
%cd {root}

import sys
sys.path.append(root)



/vf/users/tobiassonva/data/eukgen


In [9]:
#SEQUENCE ALIGNMENTS OPERATIONS


#helper IO functions for basic fasta reading into {id:seq} dict
#id taken as string up until first space character
def fasta_to_dict(fastastring=None, file=None):
    """Parse FASTA text into an ``{id: sequence}`` dict.

    Args:
        fastastring: FASTA formatted text. Only used when ``file`` is None.
        file: path to a FASTA file. Takes precedence over ``fastastring``.

    Returns:
        dict mapping each record id (header text up to the first space,
        without the leading '>') to its sequence joined onto a single line.
        Any of the characters B, J, U, X, Z in a sequence is replaced by 'A'
        and a warning is printed per replaced character.

    Raises:
        ValueError: if neither ``fastastring`` nor ``file`` is given.
    """
    if file is not None:
        with open(file, 'r') as fastafile:
            fastalines = fastafile.read().splitlines()
    elif fastastring is not None:
        fastalines = fastastring.splitlines()
    else:
        raise ValueError('fasta_to_dict needs either fastastring or file')

    #trim everything after first space in line. Avoids pathologic cases of headers such as "NR_XXXX (abc-->cd)"
    #where a '>' inside the description would otherwise be read as a new record
    fastalines = [line.split(' ')[0] for line in fastalines]

    fastas = {}
    for entry in '\n'.join(fastalines).split('>'):
        entry = entry.strip()
        #skip the empty chunk before the first '>' and any whitespace-only chunk
        if not entry:
            continue
        header, *seq_lines = entry.split('\n')
        fastas[header] = ''.join(seq_lines)

    #replace unknown characters with A
    blacklist = set('BJUXZ')
    for key, seq in fastas.items():
        #sorted only to make the order of the warnings deterministic
        for char in sorted(blacklist.intersection(seq)):
            seq = seq.replace(char, 'A')
            print(f'WARNING: {key} Replaced illegal {char} with A')

        fastas[key] = seq

    return fastas

#returns simple single line fasta from {id:seq} dict
def dict_to_fasta(seq_dict, write_file=False):
    """Format an ``{id: seq}`` dict as single-line FASTA text.

    Args:
        seq_dict: mapping of record id to sequence.
        write_file: optional output path. When given, the FASTA text is also
            written to this file.

    Returns:
        The FASTA text, one header line and one sequence line per record,
        records separated by newlines (no trailing newline in the string).
    """
    fasta_str = '\n'.join(f'>{key}\n{value}' for key, value in seq_dict.items())

    if write_file:
        with open(write_file, 'w') as outfile:
            outfile.write(fasta_str)
            #terminate the last record on disk so that concatenating files
            #(e.g. with `cat`) cannot glue a sequence onto the next header
            if fasta_str:
                outfile.write('\n')
        print(f'Wrote {write_file}')

    return fasta_str

#small helper for pkl parsing
def load_pkl(pkl_file):
    """Return the object unpickled from the file at ``pkl_file``."""
    #NOTE: unpickling can execute arbitrary code, only use on trusted files
    with open(pkl_file, mode='rb') as handle:
        return pickle.load(handle)

def dump_pkl(item, pkl_file):
    """Pickle ``item`` into the file at ``pkl_file`` and report the path."""
    with open(pkl_file, mode='wb') as handle:
        pickle.dump(item, handle)

    message = f'Pickled item as {pkl_file}'
    print(message)

#compute the information content of a column string; protein=True for amino acids, protein=False for DNA
def column_entropy(string, protein=True, gaptoken='-'):
    """Information content of one alignment column, floored at zero.

    The value is the uniform entropy (log2(20) for protein, 2 for DNA) minus
    the observed entropy of the non-gap characters, minus a gap penalty equal
    to the uniform entropy times the gap fraction. Frequencies are taken over
    the full column length, gaps included.
    """
    n_chars = len(string)
    uniform_H = math.log2(20) if protein else 2

    #accumulate sum(p * log2(p)) over the distinct non-gap characters
    weighted_logs = 0
    for residue in set(string).difference({gaptoken}):
        freq = string.count(residue)
        weighted_logs += freq/n_chars*math.log2(freq/n_chars)
    observed_H = -weighted_logs

    gap_penalty = uniform_H*(string.count(gaptoken)/n_chars)

    return max(uniform_H - observed_H - gap_penalty, 0)


#columnwise cut based on criteria
def filter_by_entropy(seq_dict, entropy_min, filter_accs=[], gaptoken='-'):
    """Drop alignment columns whose information content is not above a threshold.

    Args:
        seq_dict: ``{id: aligned_sequence}``. Columns are formed with ``zip``,
            so rows longer than the shortest row are truncated.
        entropy_min: a column is kept only if ``column_entropy`` of the scored
            column is strictly greater than this value.
        filter_accs: optional collection of ids. When non-empty, columns are
            scored using only these rows, but the cut is applied to all rows.
            Never mutated.
        gaptoken: gap character passed on to ``column_entropy``.

    Returns:
        dict with the same keys as ``seq_dict`` and the filtered sequences.
        If no column passes, every key maps to an empty string.

    Raises:
        ValueError: if ``filter_accs`` is non-empty but matches no key.
    """
    keys = list(seq_dict.keys())

    #transpose seqs into columns
    cols = [''.join(col) for col in zip(*seq_dict.values())]

    #score only the rows named in filter_accs when given
    if filter_accs:
        scored_rows = [seq for key, seq in seq_dict.items() if key in filter_accs]
        if not scored_rows:
            raise ValueError('None of the filter_accs are present in seq_dict')
        score_cols = [''.join(col) for col in zip(*scored_rows)]
    else:
        score_cols = cols

    #include based on entropy threshold
    kept_cols = [col for col, score_col in zip(cols, score_cols)
                 if column_entropy(score_col, gaptoken=gaptoken) > entropy_min]

    #transposing an empty column list would lose every key, so handle it explicitly
    if not kept_cols:
        return {key: '' for key in keys}

    #transpose back to alignment and return with original keys
    filtered_rows = [''.join(row) for row in zip(*kept_cols)]
    return dict(zip(keys, filtered_rows))

#input helper to read tsv from file, merge singletons
def read_cluster_tsv(cluster_file, split_large=False, max_size=500, batch_single=False, single_cutoff=1):
    """Read a two-column cluster TSV into ``{cluster_acc: [member_acc, ...]}``.

    Args:
        cluster_file: path to a tab separated file with the cluster accession
            in the first column and a member accession in the second.
        split_large: if True, clusters with more than ``max_size`` members are
            split into consecutive batches of at most ``max_size``; each batch
            is keyed by its first member.
        max_size: batch size used when ``split_large`` is True.
        batch_single: if True, all clusters with at most ``single_cutoff``
            members are merged into one cluster keyed by its first member.
        single_cutoff: size limit used when ``batch_single`` is True.

    Returns:
        dict mapping cluster key to list of member accessions. Merging of
        small clusters is applied before splitting of large ones.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    #read TSV and group clusters based on first tsv column
    clusters = {}
    with open(cluster_file, 'r') as infile:
        for line_number, line in enumerate(infile, start=1):
            line = line.strip()

            #tolerate blank lines, e.g. a trailing empty line at the end of the file
            if not line:
                continue

            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(f'{cluster_file} line {line_number}: expected 2 tab separated fields, found {len(fields)}')

            cluster_acc, acc = fields
            clusters.setdefault(cluster_acc, []).append(acc)

    #merge all clusters smaller than cutoff into one
    if batch_single:
        filter_dict = {}
        singles = []
        for key, accs in clusters.items():
            #gather singletons
            if len(accs) <= single_cutoff:
                singles.extend(accs)

            #keep larger clusters
            else:
                filter_dict[key] = accs

        if singles:
            filter_dict[singles[0]] = singles

        clusters = filter_dict

    #split clusters larger than max_size into smaller pieces
    if split_large:
        filter_dict = {}
        for key, accs in clusters.items():
            if len(accs) > max_size:
                #partition large clusters into batches of max_size
                for split in range(0, len(accs), max_size):
                    batch = accs[split:split + max_size]
                    filter_dict[batch[0]] = batch

            #keep smaller clusters
            else:
                filter_dict[key] = accs

        clusters = filter_dict

    return clusters

In [10]:
#TREE OPERATIONS

#quick function for adding phylogenetic annotation to tree labels
def dirty_phyla_add(tree, tax_mapping):
    """Annotate every leaf of ``tree`` in place with taxonomy features.

    Each leaf gets three features via ``leaf.add_feature``:
    ``tax_superkingdom`` and ``tax_class`` (looked up by leaf name in
    ``tax_mapping``) and ``tax_filter``, which is 'Eukaryota' for eukaryotic
    leaves and the class otherwise. Leaves whose name (or a required column)
    is missing get the value 'ERROR' for all three features, so that later
    attribute lookups on the leaves do not fail.

    Args:
        tree: object exposing ``get_leaves()``; leaves expose ``name`` and
            ``add_feature(name, value)``.
        tax_mapping: table indexed by accession and supporting ``.loc[acc]``,
            with 'superkingdom' and 'class' columns.
    """
    for leaf in tree.get_leaves():
        try:
            entry = tax_mapping.loc[leaf.name]
            superkingdom = entry['superkingdom']
            tax_class = entry['class']

        except KeyError:
            #features previously set on lookup failure, kept for backward compatibility
            leaf.add_feature('superkingom', 'ERROR')
            leaf.add_feature('class', 'ERROR')
            superkingdom = 'ERROR'
            tax_class = 'ERROR'

        leaf.add_feature('tax_superkingdom', superkingdom)
        leaf.add_feature('tax_class', tax_class)

        if superkingdom == 'Eukaryota':
            leaf.add_feature('tax_filter', 'Eukaryota')
        else:
            leaf.add_feature('tax_filter', tax_class)


#return closest non self leaf
def get_closest_leaf(leaf):
    """Return ``(nearest_leaf, distance)`` among the other leaves under ``leaf.up``.

    Only leaves descending from the parent of ``leaf`` are considered. On a
    tie the first candidate in ``get_leaves()`` order wins.
    """
    #collect every leaf below the parent except the query itself
    candidates = []
    for other in leaf.up.get_leaves():
        if other != leaf:
            candidates.append(other)

    dists = [leaf.get_distance(other) for other in candidates]
    shortest = min(dists)

    return candidates[dists.index(shortest)], shortest
    
#iteratively merge closest leaf pair until less than N leaves
def crop_leaves_to_size(tree, max_size):
    """Prune ``tree`` in place until it has at most ``max_size`` leaves.

    For every leaf the nearest other leaf below its parent (see
    ``get_closest_leaf``) is cached by name. The loop repeatedly takes the
    leaf with the smallest cached distance, deletes its partner, and
    refreshes the cached partner of the surviving leaf.

    Args:
        tree: tree exposing ``get_leaves()`` and ``get_leaves_by_name()``;
            leaves expose ``name`` and ``delete()``.
        max_size: maximum number of leaves to keep.

    Returns:
        The same ``tree`` object (returned unchanged when it already has
        ``max_size`` leaves or fewer).
    """
    current_size = len(tree.get_leaves())

    if max_size >= current_size:
        print(f'Tree of length {current_size} smaller than {max_size}')
        return tree

    #cache, per leaf name, the name of the closest partner and the distance to it
    leaf_partners = {}
    leaf_partner_dist = {}
    for leaf in tree.get_leaves():
        partner, distance = get_closest_leaf(leaf)
        leaf_partners[leaf.name] = partner.name
        leaf_partner_dist[leaf.name] = distance

    #serial implementation, not ideal as it produces uneven pruning if terminal branch lengths are very even.
    while max_size < current_size:

        #first leaf (in insertion order) holding the smallest cached distance
        min_leaf_A = min(leaf_partner_dist, key=leaf_partner_dist.get)
        min_leaf_B = leaf_partners[min_leaf_A]

        #delete closest leaf
        try:
            tree.get_leaves_by_name(min_leaf_B)[0].delete()

        #cached partner was already deleted in an earlier round: recompute partner, then delete it
        except IndexError:
            new_leaf = tree.get_leaves_by_name(min_leaf_A)[0]
            closest_new_leaf = get_closest_leaf(new_leaf)

            leaf_partner_dist[new_leaf.name] = closest_new_leaf[1]
            leaf_partners[new_leaf.name] = closest_new_leaf[0].name
            min_leaf_B = leaf_partners[min_leaf_A]

            tree.get_leaves_by_name(min_leaf_B)[0].delete()

        #delete the removed partner from dictionaries
        leaf_partner_dist.pop(min_leaf_B, None)
        leaf_partners.pop(min_leaf_B, None)

        #update the surviving leaf with its new closest pair
        new_leaf = tree.get_leaves_by_name(min_leaf_A)[0]
        closest_new_leaf = get_closest_leaf(new_leaf)

        leaf_partner_dist[new_leaf.name] = closest_new_leaf[1]
        leaf_partners[new_leaf.name] = closest_new_leaf[0].name

        current_size -= 1

    return tree


#alias: other functions in this file call the cropping routine as reduce_leaves_to_size
reduce_leaves_to_size = crop_leaves_to_size


#fit inverse gamma distribution to all internal node distances and
#flag those beyond the threshold probabilities
def get_outlier_nodes_by_invgamma(tree, p_low=0, p_high=0.99, only_leaves=False):
    """Return nodes whose branch length is extreme under a fitted inverse gamma.

    An inverse gamma distribution is fitted to the ``dist`` values of the
    selected nodes; the quantiles at ``p_low`` and ``p_high`` give the lower
    and upper cutoffs.

    Args:
        tree: tree exposing ``get_leaves()`` and ``traverse()``; nodes expose ``dist``.
        p_low: lower tail probability for the cutoff.
        p_high: upper tail probability for the cutoff.
        only_leaves: if True only leaves are used, otherwise every node
            yielded by ``tree.traverse()``.

    Returns:
        list of nodes with ``dist`` below the lower or above the upper cutoff,
        in traversal order.
    """
    nodes = list(tree.get_leaves()) if only_leaves else list(tree.traverse())
    dists = np.asarray([node.dist for node in nodes], dtype=float)

    fit_alpha, fit_loc, fit_beta = stats.invgamma.fit(dists)

    cutoff_high = stats.invgamma.ppf(p_high, a=fit_alpha, loc=fit_loc, scale=fit_beta)
    cutoff_low = stats.invgamma.ppf(p_low, a=fit_alpha, loc=fit_loc, scale=fit_beta)

    outlier_nodes = [node for node, dist in zip(nodes, dists) if dist < cutoff_low or dist > cutoff_high]

    #inliers satisfy cutoff_low <= d <= cutoff_high, report the interval in that orientation
    print(f'Identified {len(outlier_nodes)} outlier nodes outside interval {cutoff_low} < d < {cutoff_high}')

    return outlier_nodes

#starting from one leaf with the attribute, traverse upwards until
#the leaves below the ancestor are no longer monophyletic under the given attribute
#repeat for all remaining leaves
#if any clade would have the global root as its parent, reroot and retry to avoid
#false paraphyly created by the tree data structure
def get_paraphyletic_groups(tree, attribute='tax_superkingdom', attr_value='Eukaryota', current_root=False):
    """Return the maximal clades whose leaves all carry ``attr_value``.

    Starting from each leaf where ``getattr(leaf, attribute) == attr_value``,
    the function climbs while every leaf below the parent still matches, and
    records the highest such node. Leaves covered by a recorded node are not
    used as seeds again.

    If any recorded node is a direct child of a parentless node, the tree is
    re-rooted with ``tree.set_outgroup`` on the first child of ``current_root``
    that is not itself a recorded clade, and the search is repeated recursively.

    Args:
        tree: tree to search; may be re-rooted in place.
        attribute: leaf attribute name to compare.
        attr_value: value that marks a leaf as belonging to the group.
        current_root: node to set as outgroup before searching; when falsy the
            current tree root is used and no re-rooting happens up front.

    Returns:
        list of clade nodes (leaves or internal nodes).

    Raises:
        IndexError: if no leaf carries ``attr_value`` (``check_leaves[0]``).
    """
    
    #tree.set_outgroup(tree.get_farthest_leaf()[0])
    
    if current_root:
        tree.set_outgroup(current_root)
    else:
        current_root = tree.get_tree_root()
    
    #get a list of all leaves with an attribute matching the match value provided
    check_leaves = [leaf for leaf in tree.get_leaves() if getattr(leaf, attribute)==attr_value]
    clade_nodes = []
    
    seed_node = check_leaves[0]    

    while check_leaves:
        #assume monophyly
        mono = True
        
        #check all leaves below the parent; if any does not match the value the parent is not monophyletic, break
        #NOTE(review): if every leaf of the tree matches, seed_node climbs to a node without a parent
        #and seed_node.up.get_leaves() would fail here - confirm callers never pass such a tree
        for leaf in seed_node.up.get_leaves():
            if getattr(leaf, attribute)!=attr_value:
                mono = False
                break
        
        #if monophyletic try higher node 
        if mono:
            seed_node = seed_node.up
        
        #else record the node and exclude all of its leaves from the list of leaves to check
        else:
            clade_nodes.append(seed_node)
            check_leaves = [leaf for leaf in check_leaves if leaf not in seed_node.get_leaves()]
            if check_leaves:
                seed_node = check_leaves[0] 
    
    #a clade whose parent has no parent hangs directly off the root
    if [node for node in clade_nodes if node.up.up == None]:
        #print('A tree clade has rooted parent nodes, rerooting')
        
        #get the first non-clade daughter from current root 
        #NOTE(review): raises IndexError if every child of current_root is a recorded clade
        non_clade_daughter = [node for node in current_root.children if node not in clade_nodes][0] 
        
        return get_paraphyletic_groups(tree, attribute=attribute, attr_value=attr_value, current_root=non_clade_daughter)

    else:
        return clade_nodes


#return the entropy of decendant and non decendant leaf labels 
def get_entropy_for_partition(tree, node, attribute='tax_filter', attr_value='Eukaryota'):
    """Score the split of ``tree`` into the leaves below ``node`` and the rest.

    For each side the fraction ``p`` of leaves whose ``attribute`` equals
    ``attr_value`` is turned into the term ``p*log2(1/p)``; a fraction of
    zero contributes 0.

    Args:
        tree: full tree, exposing ``get_leaves()``.
        node: node defining the clade, exposing ``get_leaves()``.
        attribute: leaf attribute name to compare.
        attr_value: label value that is counted.

    Returns:
        tuple ``(label_H, external_label_H)``: the term for the clade and the
        term for all leaves outside the clade. The external term is 0 when no
        labeled leaf lies outside the clade.
    """
    def partial_entropy(fraction):
        #p*log2(1/p), with the p == 0 case defined as 0 instead of dividing by zero
        if fraction == 0:
            return 0
        return fraction*np.log2(1/fraction)

    all_labels = [getattr(leaf, attribute) for leaf in tree.get_leaves()]
    all_label_count = all_labels.count(attr_value)
    tree_width = len(all_labels)

    clade_labels = [getattr(leaf, attribute) for leaf in node.get_leaves()]
    clade_label_count = clade_labels.count(attr_value)
    clade_width = len(clade_labels)

    #calculate label entropy inside the clade
    label_H = partial_entropy(clade_label_count/clade_width)

    #calculate external entropy; if all labels are in the clade it is 0
    #(this also covers node being the root, where no leaf is outside)
    external_label_count = all_label_count - clade_label_count
    if external_label_count == 0:
        external_label_H = 0
    else:
        external_label_H = partial_entropy(external_label_count/(tree_width - clade_width))

    return label_H, external_label_H
    
#assign soft LCA node based on minimizing entropy of the given label outside and inside the clade
#more pessimistic than voting ratio, qualitatively underestimates
def get_soft_LCA_by_relative_entropy(tree, attribute='tax_superkingdom', attr_value='Eukaryota', save_loss=False):
    """Pick the node that best separates leaves labeled ``attr_value`` from the rest.

    If ``get_paraphyletic_groups`` finds a single clade, that clade is returned
    with a loss of 0. Otherwise, for every clade found, the function walks from
    the clade towards the root and scores each node with the sum of the two
    values from ``get_entropy_for_partition`` (plus 0.5 if the node is a leaf).
    The walk up from a clade stops once a node contains all labeled leaves.

    Args:
        tree: tree to search.
        attribute: leaf attribute name to compare.
        attr_value: label value of interest.
        save_loss: if True, every node gets a ``vote_loss`` feature, set to the
            string 'None' first and overwritten with the score where evaluated.

    Returns:
        tuple ``(best_node, lowest_total_H)``. ``best_node`` defaults to the
        tree root and ``lowest_total_H`` to infinity if no node is scored.

    Raises:
        ZeroDivisionError: if no leaf carries ``attr_value`` (``1/all_label_Px``).
    """
    #count vote ratio for each node for one taxa
    tree_width = len(tree.get_leaves())
    root = tree.get_tree_root()
    
    lowest_total_H = float("inf")
    best_node = root
    #vote_label is assigned but not used below
    vote_label = attr_value
    
    all_labels = [getattr(leaf, attribute) for leaf in tree.get_leaves()]
    all_label_count = all_labels.count(attr_value)
    tree_width = len(all_labels)
    
    #all_labels_H is computed but not used below
    all_label_Px = (all_label_count/tree_width)
    all_labels_H = all_label_Px*np.log2(1/all_label_Px)
    
    #for debugging
    if save_loss:
        for node in tree.traverse():
            node.add_feature('vote_loss', 'None')
    
    #check for monophyly
    #NOTE(review): get_paraphyletic_groups may call tree.set_outgroup; `root` was fetched before
    #that call - confirm it still refers to the tree root afterwards
    LCA_groups = get_paraphyletic_groups(tree, attribute=attribute, attr_value=attr_value)
    if len(LCA_groups) == 1:
        print(f'The attribute {attribute} is monophyletic for {attr_value}. Returning LCA node.')
        return (LCA_groups[0], 0)


    #the best partition will be on the path from an LCA node to the root
    tested_nodes = []
    for node in LCA_groups:
        
        #this first count is recomputed inside the loop before it is read
        node_label_count = len([leaf for leaf in node.get_leaves() if getattr(leaf, attribute) == attr_value])

        #ascend until all labeled leaves are descendants of node 
        while node != root:
           
            #skip known nodes
            if node not in tested_nodes:
                
                #calculate internal and external entropy
                label_H, external_label_H =  get_entropy_for_partition(tree, node, attribute=attribute, attr_value=attr_value)
                total_H = label_H + external_label_H
                    
                #penalize leaf LCAs to avoid laddered LCAs when having repeated outgroups
                #not necessarily wanted as spread singletons get their global LCA as soft_LCA
                if node.is_leaf():
                    total_H += 0.5
                
                #update best guess
                if total_H < lowest_total_H:
                    lowest_total_H = total_H
                    best_node = node
                
                if save_loss:
                    node.add_feature('vote_loss', total_H)
            
            #break after calculations if all labeled leaves are descendants
            node_label_count = len([leaf for leaf in node.get_leaves() if getattr(leaf, attribute) == attr_value])
            if node_label_count == all_label_count:
                break
            #print(node_label_count, all_label_count)
            
            #ascend
            tested_nodes.append(node)
            node = node.up
            
            
    #print(f'Best node for {attr_value} has a total H of {lowest_total_H}')
    return (best_node, lowest_total_H)

def get_multiple_soft_LCAs(tree, attribute='tax_filter', attr_value='Eukaryota', min_size=1, min_purity=0, max_entropy=9999):
    """Find soft LCA nodes until every leaf labeled ``attr_value`` is covered.

    Repeatedly calls ``get_soft_LCA_by_relative_entropy``; the labeled leaves
    below each returned node are temporarily relabeled 'SAMPLED' so the next
    round only sees the remaining ones. The original labels are restored
    before the nodes are scored, also when a round raises.

    Args:
        tree: annotated tree.
        attribute: leaf attribute name to compare.
        attr_value: label value of interest.
        min_size: minimum number of leaves below an accepted node.
        min_purity: minimum fraction of leaves below the node carrying ``attr_value``.
        max_entropy: maximum summed value from ``get_entropy_for_partition``.

    Returns:
        list of ``[node, size, purity, entropy]`` for each accepted node, in
        the order the nodes were found. Empty if nothing passes the filters.

    Raises:
        RuntimeError: if a returned node covers no remaining labeled leaf
            (the search could otherwise never finish).
    """
    print(f'Searching for LCA_nodes by checking where {attribute} is {attr_value}')

    soft_LCA_nodes = []
    total_nodes = sum(1 for leaf in tree.get_leaves() if getattr(leaf, attribute) == attr_value)

    try:
        while total_nodes > 0:

            soft_LCA_node, _ = get_soft_LCA_by_relative_entropy(tree, attribute=attribute, attr_value=attr_value)

            #mark labeled included leaves and count how many were newly covered
            newly_sampled = 0
            for leaf in soft_LCA_node.get_leaves():
                if getattr(leaf, attribute) == attr_value:
                    setattr(leaf, attribute, 'SAMPLED')
                    newly_sampled += 1

            if newly_sampled == 0:
                raise RuntimeError(f'Soft LCA search made no progress with {total_nodes} leaves of {attr_value} left')

            total_nodes -= newly_sampled
            soft_LCA_nodes.append(soft_LCA_node)

    finally:
        #reset leaf attributes, also on error, so the tree is never left with temporary labels
        for leaf in tree.get_leaves():
            if getattr(leaf, attribute) == 'SAMPLED':
                setattr(leaf, attribute, attr_value)

    #recalculate entropies
    print(f'\tRecalculating sizes, purities and entropies for all LCA nodes')
    filtered_soft_LCA_nodes = []

    for node in soft_LCA_nodes:
        label_H, external_label_H = get_entropy_for_partition(tree, node, attribute=attribute, attr_value=attr_value)

        soft_LCA_members = [getattr(leaf, attribute) for leaf in node.get_leaves()]
        soft_LCA_size = len(soft_LCA_members)
        soft_LCA_purity = soft_LCA_members.count(attr_value) / soft_LCA_size
        soft_LCA_entropy = label_H + external_label_H

        if soft_LCA_size >= min_size and soft_LCA_purity >= min_purity and soft_LCA_entropy <= max_entropy:
            print(f'\tFound node of size {soft_LCA_size} with purity of {soft_LCA_purity} and entropy {soft_LCA_entropy} as LCA for {attr_value}')
            filtered_soft_LCA_nodes.append([node, soft_LCA_size, soft_LCA_purity, soft_LCA_entropy])
        else:
            print(f'\tRejected node of size {soft_LCA_size} with purity of {soft_LCA_purity} and entropy {soft_LCA_entropy} as LCA for {attr_value}')

    if len(filtered_soft_LCA_nodes) == 0:
        print(f'WARNING: No valid LCA nodes present for {attribute} = {attr_value} under conditions min_size={min_size}, min_purity={min_purity}, max_entropy={max_entropy}') 

    print()

    return filtered_soft_LCA_nodes

In [98]:
#wrapper function for target reclustering which required a temporary accession file to be written
def mmseqs_run_target(target, members, root, seq_DB, threads=1):
    """Recluster the ``members`` of one target with mmseqs.

    Writes the member accessions to the file ``root+target``, then runs the
    mmseqs commands through the shell to build a sub database from ``seq_DB``,
    export it as ``.fasta``, cluster it and write ``.cluster.tsv``. The
    temporary databases and the accession file are removed at the end.
    Returns None.
    """
    #pid-tagged prefix so messages of parallel workers can be told apart
    pid = multiprocessing.current_process().pid
    log_prefix = f'{pid} | {target}:'
    print(f'{log_prefix} Preparing mmseqs data for target {target}\n', end='')

    basename = f'{root+target}'

    #temporary accession file, one accession per line
    with open(basename, 'w') as acc_file:
        acc_lines = [acc+'\n' for acc in members]
        acc_file.writelines(acc_lines)

    shell_commands = (
        #create a temporary seqDB, extract fastas, cluster and calculate cluster.tsv file
        f'mmseqs createsubdb -v 0 --id-mode 1 {basename} {seq_DB} {basename}.DB',
        f'mmseqs convert2fasta -v 0 {basename}.DB {basename}.fasta',
        f'mmseqs cluster -v 0 --remove-tmp-files 1 --threads {threads} -s 7.5 {basename}.DB {basename}.cluster {root}tmp/{target}',
        f'mmseqs createtsv -v 0 {seq_DB} {seq_DB} {basename}.cluster {basename}.cluster.tsv',
        #clean up
        f'mmseqs rmdb {basename}.DB -v 0',
        f'mmseqs rmdb {basename}.cluster -v 0',
        f'rm {basename}',
    )

    for command in shell_commands:
        os.system(command)

    return

#main script for formatting query and target .fasta and .cluster.tsv files
#should be updated to subprocess.run() as well as matching the style of the later processing
def microcosm_prepare_mmseqs(query, query_DB, target_DB, root):
    """Extract the query and target member sequences of one query with mmseqs.

    Works in ``root+query+'/'`` with files prefixed by the query name:
    reads ``<query>.targets`` (cluster TSV), builds a sub database of
    ``query_DB`` from ``<query>.acc`` and exports ``<query>.fasta``, writes all
    target members to ``<query>.members``, builds a sub database of
    ``target_DB`` from it and exports ``<query>.members.fasta``. Temporary
    databases, symlinks in the query directory and the tmp directory are
    removed afterwards. All external commands run through the shell.
    """
    #configure paths and flags
    thread = multiprocessing.current_process().pid
    query_root = root+query+'/'
    basename = query_root+query
    threadID_string = f'{thread} | {query}:'
    os.system(f'mkdir {query_root}/tmp')

    #parse target clusters
    target_clusters = read_cluster_tsv(f'{basename}.targets')

    print(f'{threadID_string} Started \n', end='')
    print(f'{threadID_string} Preparing mmseqs data for query\n', end='')

    #create a new DB for the query sequences and extract them as fasta
    os.system(f"mmseqs createsubdb -v 0 --id-mode 1 --subdb-mode 1 {basename}.acc {query_DB} {basename}.DB")
    os.system(f'mmseqs convert2fasta -v 0 {basename}.DB {basename}.fasta')

    print(f'{threadID_string} Preparing mmseqs data for merged target hits\n', end='')

    #create a subDB and extract sequences, create a new seqDB in order to recreate the header lookup
    #otherwise each query using header info searches the entire original header lookup
    with open(f'{basename}.members', 'w') as outfile:
        for members in target_clusters.values():
            outfile.writelines([member+'\n' for member in members])

    os.system(f"mmseqs createsubdb -v 0 --id-mode 1 --subdb-mode 1 {basename}.members {target_DB} {basename}.members.DB")
    os.system(f"mmseqs convert2fasta -v 0 {basename}.members.DB {basename}.members.fasta")

    #clean
    os.system(f'mmseqs rmdb -v 0 {basename}.DB')
    #NOTE(review): no {basename}.cluster database is created in this function
    os.system(f'mmseqs rmdb -v 0 {basename}.cluster')
    os.system(f'mmseqs rmdb -v 0 {basename}.members.DB')
    #backslash is escaped explicitly; the shell still receives `\;` to terminate -exec
    os.system(f'find {query_root} -maxdepth 1 -type l -exec unlink {{}} \\;')
    #NOTE(review): the *members* glob is resolved against the current working directory,
    #not query_root - confirm which files this is meant to remove
    os.system(f'rm -r *members* {query_root}/tmp')
    

In [12]:
#run command for generating a diversified ensemble with muscle and then extracting the maxcc aln
def microcosm_muscle_ensamble(base_fasta, threads, muscle_reps, muscle5_exe, super5=False):
    """Align ``base_fasta`` as a muscle5 ensemble and extract the maxcc alignment.

    Args:
        base_fasta: path to the input fasta; all outputs use it as prefix.
        threads: thread count passed to muscle and exported as OMP_NUM_THREADS.
        muscle_reps: number of replicates (``-replicates``), or number of
            ``-perturb`` runs when ``super5`` is True.
        muscle5_exe: muscle executable (whitespace separated extra arguments
            are allowed).
        super5: if True, run ``-super5 ... -perm all`` once per replicate and
            merge the per-permutation outputs into one ``.muscle-efa`` file;
            otherwise run a single ``-diversified`` ensemble.

    Returns:
        Path of the extracted alignment, ``f'{base_fasta}.muscle'``. The exit
        status of muscle is not checked; its output goes to
        ``f'{base_fasta}.muscle.log'``.
    """
    os.environ['OMP_NUM_THREADS'] = str(threads)

    exe_args = muscle5_exe.split()
    efa_path = f'{base_fasta}.muscle-efa'
    out_path = f'{base_fasta}.muscle'

    #the context manager closes the log even if a step raises
    with open(f'{base_fasta}.muscle.log', 'a') as logfile:

        def run_muscle(args):
            #argument list instead of a split string, so paths containing spaces survive
            subprocess.run(exe_args + [str(arg) for arg in args], stdout=logfile, stderr=logfile)

        if super5:
            for rep in range(muscle_reps):
                print(f'aligning {rep}')
                run_muscle(['-super5', base_fasta, '-output', f'{base_fasta}.@.super5-tmp',
                            '-perm', 'all', '-perturb', rep, '-threads', threads])

            #a bare filename has no directory part; fall back to the current directory
            aln_dir = os.path.dirname(base_fasta) or '.'
            #only collect temporary files that belong to this input
            tmp_prefix = os.path.basename(base_fasta)+'.'
            efa_files = sorted(name for name in os.listdir(aln_dir)
                               if name.startswith(tmp_prefix) and name.endswith('super5-tmp'))

            with open(efa_path, 'w') as efa_merge:
                for name in efa_files:
                    tmp_path = os.path.join(aln_dir, name)
                    with open(tmp_path, 'r') as efa_in:
                        efa_merge.write(f'<{name}\n')
                        efa_merge.write(efa_in.read())
                    os.remove(tmp_path)

        else:
            #run diversified ensemble
            run_muscle(['-threads', threads, '-diversified', '-replicates', muscle_reps,
                        '-align', base_fasta, '-output', efa_path])

        #extract maximum CC alignment
        run_muscle(['-maxcc', efa_path, '-output', out_path])

    return out_path



#read fasta, align with FAMSA, filter and construct FastTree,
#crop leaves to size and write new fasta with only cropped tree leaf sequences
def microcosm_reduce_size(base_fasta, threads, max_leaf_size, filter_entropy, famsa_exe, fasttree_exe):
    """Reduce ``base_fasta`` to at most ``max_leaf_size`` sequences via a guide tree.

    Steps: align with FAMSA (``.famsa``, backup ``.famsa.b``), filter the
    alignment columns with ``filter_by_entropy``, build a FastTree
    (``.fasttree``), crop the tree with ``crop_leaves_to_size`` and overwrite
    ``base_fasta`` with the ungapped sequences of the remaining leaves. The
    original fasta is kept as ``.uncropped``.

    Args:
        base_fasta: path to the input fasta, used as prefix for all outputs.
        threads: thread count for FAMSA, also exported as OMP_NUM_THREADS.
        max_leaf_size: maximum number of sequences to keep.
        filter_entropy: threshold passed to ``filter_by_entropy``.
        famsa_exe: FAMSA executable.
        fasttree_exe: FastTree executable.

    Returns:
        ``base_fasta`` (the file now holds the reduced set).
    """
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {base_fasta}:'

    #threads for FastTree
    os.environ['OMP_NUM_THREADS'] = str(threads)

    aln_path = f'{base_fasta}.famsa'
    tree_path = f'{base_fasta}.fasttree'

    #NOTE(review): FAMSA and FastTree output both end up in the .famsa.log file, as before;
    #the context manager closes the log even if a step raises
    with open(f'{base_fasta}.famsa.log', 'a') as logfile:

        #align seqs
        print(threadID_string+' Aligning FAMSA')
        subprocess.run(famsa_exe.split() + ['-t', str(threads), base_fasta, aln_path],
                       stdout=logfile, stderr=logfile)

        #backup original alignment
        subprocess.run(['cp', aln_path, f'{aln_path}.b'])

        #filter by entropy
        print(threadID_string+' Filtering by entropy')
        aln = fasta_to_dict(file=aln_path)
        aln_filter = filter_by_entropy(aln, filter_entropy)
        dict_to_fasta(aln_filter, write_file=aln_path)

        #construct a FastTree
        print(threadID_string+' Constructing FastTree')
        subprocess.run(fasttree_exe.split() + ['-gamma', '-out', tree_path, aln_path],
                       stdout=logfile, stderr=logfile)

    #reduce leaves to size, using the cropping function defined in this file
    print(threadID_string+' Cropping Leaves')
    tree = Tree(tree_path)
    crop_leaves_to_size(tree, max_leaf_size)

    #write cropped fasta from the unfiltered alignment with gaps removed
    cropped_leaves = set(tree.get_leaf_names())
    cropped_aln = {key: value.replace('-', '') for key, value in aln.items() if key in cropped_leaves}

    #replace original fasta with cropped version and save backup
    subprocess.run(['cp', base_fasta, f'{base_fasta}.uncropped'])
    dict_to_fasta(cropped_aln, write_file=base_fasta)

    return base_fasta


def microcosm_realign_and_filter(query, root, threads=1, max_euk_leaf_size=200, max_prok_leaf_size=1300,
                                filter_entropy=0.5, muscle_reps_euk=25, muscle_reps_prok=5):
    """Crop (if needed), align and column-filter the two fasta files of a query.

    Handles ``<query>.fasta`` first and ``<query>.members.fasta`` second, both
    located in ``root+query+'/'``. A file with more sequences than its limit is
    reduced with ``microcosm_reduce_size``; each file is then aligned with
    ``microcosm_muscle_ensamble``. Finally both ``.muscle`` alignments are
    copied to ``.muscle.b`` and overwritten with their entropy filtered version.

    Returns:
        Path of the filtered alignment of the first file, ``<query>.fasta.muscle``.
    """
    #external tools
    muscle5_exe = 'muscle'
    fasttree_exe = 'FastTree'
    famsa_exe = '/data/tobiassonva/data/software/FAMSA-2.0.1/famsa'

    #message prefix and input paths
    pid = multiprocessing.current_process().pid
    log_prefix = f'{pid} | {query}:'

    workdir = root+query+'/'
    euk_fasta = workdir+query+'.fasta'
    prok_fasta = workdir+query+'.members.fasta'

    #eukaryotic sequences: crop when above the limit, then align
    euk_size = len(fasta_to_dict(file=euk_fasta))

    if euk_size <= max_euk_leaf_size:
        print(log_prefix+f' There are less than {max_euk_leaf_size} sequences in {euk_fasta} ({euk_size}), no cropping needed')    
    else:
        print(log_prefix+f' There are more than {max_euk_leaf_size} sequences in {euk_fasta} ({euk_size}), will crop to size')    
        euk_fasta = microcosm_reduce_size(euk_fasta, threads, max_euk_leaf_size, filter_entropy, famsa_exe, fasttree_exe)

    print(log_prefix+f' Aligning with muscle5 as ensemble with {muscle_reps_euk} replicates')
    microcosm_muscle_ensamble(euk_fasta, threads, muscle_reps_euk, muscle5_exe, super5=False)

    #prokaryotic sequences: same procedure with their own limit and replicate count
    prok_size = len(fasta_to_dict(file=prok_fasta))

    if prok_size <= max_prok_leaf_size:
        print(log_prefix+f' There are less than {max_prok_leaf_size} sequences in {prok_fasta}({prok_size}), no cropping needed')    
    else:
        print(log_prefix+f' There are more than {max_prok_leaf_size} sequences in {prok_fasta}({prok_size}), will crop to size')    
        prok_fasta = microcosm_reduce_size(prok_fasta, threads, max_prok_leaf_size, filter_entropy, famsa_exe, fasttree_exe)

    print(log_prefix+f' Aligning with muscle5 as ensemble with {muscle_reps_prok} replicates')
    microcosm_muscle_ensamble(prok_fasta, threads, muscle_reps_prok, muscle5_exe, super5=False)

    #filter both alignments by columnwise bitscore
    print(log_prefix+f' Filtering alignments to columnwise bitscore > {filter_entropy}')

    muscle_alns = [f'{euk_fasta}.muscle', f'{prok_fasta}.muscle']

    #backup both original alignments before any of them is overwritten
    for aln_path in muscle_alns:
        subprocess.run(f'cp {aln_path} {aln_path}.b'.split())

    for aln_path in muscle_alns:
        filtered = filter_by_entropy(fasta_to_dict(file=aln_path), filter_entropy)
        dict_to_fasta(filtered, write_file=aln_path)

    return muscle_alns[0]


def microcosm_merge_align_tree(query, root, threads=1, filter_entropy=0.5, muscle_reps=25):
    """Merge the two filtered alignments of a query, realign and build an IQ-TREE.

    Concatenates ``<query>.fasta.muscle`` and ``<query>.members.fasta.muscle``
    (both in ``root+query+'/'``) into ``<query>.merged.fasta``, realigns it
    with ``microcosm_muscle_ensamble``, backs the alignment up as ``.b``,
    overwrites it with its entropy filtered version and runs iqtree2 on it.

    Args:
        query: query name, used as directory and file prefix.
        root: directory containing the query directory (with trailing slash).
        threads: threads for muscle and iqtree2.
        filter_entropy: threshold passed to ``filter_by_entropy``.
        muscle_reps: replicates for the muscle ensemble.

    Returns:
        None. The exit status of the external tools is not checked.
    """
    #configure paths and flags
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {query}:'

    query_root = root+query+'/'
    euk_muscle = query_root+query+'.fasta.muscle'
    prok_muscle = query_root+query+'.members.fasta.muscle'

    merged_fasta = query_root+query+'.merged.fasta'

    muscle5_exe = 'muscle'
    iqtree_exe = 'iqtree2'

    evo_model_params = '-m MFP -mset LG,Q.pfam --cmin 4 --cmax 12'

    print(threadID_string+f' Merging alignments {euk_muscle} and {prok_muscle} to .merged.fasta') 

    #merge aligned and filtered pro and euk
    #each part is terminated with a newline so the last sequence of one file
    #cannot run into the first header of the next
    with open(merged_fasta, 'w') as merged:
        for aln_path in (euk_muscle, prok_muscle):
            with open(aln_path, 'r') as aln_in:
                content = aln_in.read()
            merged.write(content)
            if content and not content.endswith('\n'):
                merged.write('\n')

    #realign merged using muscle
    print(threadID_string+f'  Realigning with muscle5 as ensemble with {muscle_reps} replicates')
    merged_muscle = microcosm_muscle_ensamble(merged_fasta, threads, muscle_reps, muscle5_exe, super5=False)

    #filter alignments by columnwise bitscore
    print(threadID_string+f' Filtering alignment to columnwise bitscore > {filter_entropy}')

    #backup original alignment
    subprocess.run(['cp', merged_muscle, f'{merged_muscle}.b'])

    aln = fasta_to_dict(file=f'{merged_muscle}')
    aln_filter = filter_by_entropy(aln, filter_entropy)
    dict_to_fasta(aln_filter, write_file=f'{merged_muscle}')

    #construct IQtree
    print(threadID_string+f' Constructing IQTree for {merged_muscle}')
    iqtree_command = (iqtree_exe.split() + ['-s', merged_muscle] + evo_model_params.split()
                      + ['--threads', str(threads), '-B', '1000'])

    #NOTE(review): the log path is relative to the working directory, not query_root - confirm intended
    with open(f'{query}.iqtree.log', 'a') as iqtree_logfile:
        subprocess.run(iqtree_command, stdout=iqtree_logfile, stderr=iqtree_logfile)

    return
    
    
    


In [1]:
def microcosm_tree_analysis(query, root, max_tree_leaves=1500, 
                            outlier_inv_gamma_low = 0,
                            outlier_inv_gamma_high = 0.999,
                            prok_min_size = 2, prok_min_purity = 0.5, euk_min_size = 3, euk_min_purity = 0.76):
    """Annotate the consensus tree of a query with taxonomy, outliers and LCA nodes.

    Reads `<root><query>/<query>.merged.fasta.muscle.contree`, attaches taxonomy
    via `dirty_phyla_add`, reduces the tree with `reduce_leaves_to_size`, flags
    outlier nodes with `get_outlier_nodes_by_invgamma`, computes hard
    (`get_paraphyletic_groups`) and soft (`get_multiple_soft_LCAs`) LCAs per
    `tax_filter` value, labels the soft-LCA nodes and writes the annotated tree.

    Parameters
    ----------
    query : str
        Query name; names the sub-directory of `root` and is used to look up the
        title header in the euk header mapping.
    root : str
        Microcosm root path; concatenated directly with `query`.
    max_tree_leaves : int
        Forwarded to `reduce_leaves_to_size`.
    outlier_inv_gamma_low, outlier_inv_gamma_high : float
        Forwarded as `p_low` / `p_high` to `get_outlier_nodes_by_invgamma`.
    prok_min_size, prok_min_purity
        `min_size` / `min_purity` used for every taxon in the per-taxon loop.
    euk_min_size, euk_min_purity
        `min_size` / `min_purity` used for the dedicated 'Eukaryota' call.

    Returns
    -------
    str
        Path of the annotated tree file (`<contree path>.annot`).

    Notes
    -----
    Taxonomy pickles are read from `analysis/core_data/` relative to the current
    working directory.
    """
    
    #configure paths and flags
    thread = multiprocessing.current_process().pid
    threadID_string = f'{thread} | {query}:'
    
    query_root = root+query+'/'
    merged_tree = query_root+query+'.merged.fasta.muscle.contree'
    
    
    #read the consensus tree
    print(threadID_string+f' Reading consensus tree from {merged_tree}')
    
    #load parsed taxonomy data
    #standardize and reformat tax data
    prok_tax = load_pkl('analysis/core_data/prok2111_protein_taxonomy_trimmed.pkl')
    euk_tax = load_pkl('analysis/core_data/euk72_protein_taxonomy.pkl')
    euk_tax.drop(['orgid', 'species'], axis=1, inplace=True)
    tax_merge = pd.concat([euk_tax, prok_tax])
    euk_header = load_pkl('analysis/core_data/euk72_header_mapping.pkl')
    
    #initialize tree
    tree = Tree(merged_tree)
    #the tree lives in <root><query>/, so the query itself is the accession to look up.
    #previously this was merged_tree.split('/')[1], which only equals the query when
    #root is a single relative directory (an absolute root yields a wrong component)
    tree_name = query
    tree_header = euk_header[euk_header.acc==tree_name].header.values
    
    dirty_phyla_add(tree, tax_merge)
    
    #merge leaf pairs until total amount of leaves is smaller than x
    tree = reduce_leaves_to_size(tree, max_tree_leaves)
    
    #calculate deviating branch distances
    print(threadID_string+f' Identifying outlier nodes by branch inverse gamma distribution')
    outlier_nodes =  get_outlier_nodes_by_invgamma(tree, p_low=outlier_inv_gamma_low, p_high=outlier_inv_gamma_high, only_leaves=False)
    #cut_nodes = [node.detach() for node in outlier_nodes]
    
    
    #calculate soft LCA nodes for prok and euk using partition entropy 
    print(threadID_string+f' Evaluationg soft LCAs from {merged_tree}')
    
    #tax_filter is read from every leaf; presumably attached by dirty_phyla_add - TODO confirm
    filter_taxa = set([leaf.tax_filter for leaf in tree.get_leaves()])
    hard_LCA_dict = {}
    soft_LCA_dict = {}

    for tax in filter_taxa:
        hard_LCA_dict[tax] = get_paraphyletic_groups(tree, attribute='tax_filter', attr_value=tax)
        soft_LCA_dict[tax] = get_multiple_soft_LCAs(tree, attribute='tax_filter', attr_value=tax,
                                                   min_size=prok_min_size, min_purity=prok_min_purity)
    
    #'Eukaryota' may be among filter_taxa; exclude it so the count matches the
    #"len(filter_taxa)-1" denominator reported below
    valid_prok_LCAs = [key for key, value in soft_LCA_dict.items()
                       if value != [] and key != 'Eukaryota']

    #euk soft LCAs are (re)computed with the euk-specific thresholds
    soft_LCA_dict['Eukaryota'] = get_multiple_soft_LCAs(tree, attribute='tax_filter', attr_value='Eukaryota',
                                                        min_size=euk_min_size, min_purity=euk_min_purity)

    print(threadID_string+f' Found valid soft LCAs for a total of {len(valid_prok_LCAs)} out of a possible {len(filter_taxa)-1} taxons')
    print(threadID_string+f' Found {len(soft_LCA_dict["Eukaryota"])} valid LCAs for Eukaryota')
    
    #add LCA nodes to list for NODE visualisation
    hard_LCA_nodes = [node for LCA_nodes in hard_LCA_dict.values() for node in LCA_nodes]
    #all first prok LCAs (each soft-LCA entry is indexed as entry[0] = node, entry[3] = value
    #stored as soft_LCA_H; exact entry layout is defined by get_multiple_soft_LCAs)
    soft_LCA_nodes = [nodes[0][0] for tax, nodes in soft_LCA_dict.items()
                      if tax != 'Eukaryota' and nodes != []]
    #add all euk LCAs (extend; the original assignment here discarded the prok nodes)
    soft_LCA_nodes += [node[0] for node in soft_LCA_dict['Eukaryota'] if node != []]

    #add LCA labels for LABEL visualisation
    for node in tree.traverse():
        node.add_feature('soft_LCA', '')
        node.add_feature('soft_LCA_H', '')

    #add the first valid soft_LCA as LCA for prok and all valid soft_LCAs for euk
    for tax, nodes in soft_LCA_dict.items():
        #skip empty
        if nodes == []:
            continue

        #add all euk nodes
        if tax == 'Eukaryota':
            for node in nodes:
                node[0].soft_LCA = tax
                node[0].soft_LCA_H = node[3]

        #add first prok node
        else:
            nodes[0][0].soft_LCA = tax
            nodes[0][0].soft_LCA_H = nodes[0][3]
    
    #--- TREE PRINTING ---
    
    #define overall tree styling
    #NOTE: ts is only consumed by the commented-out render/view calls at the end
    ts = TreeStyle()
    ts.title.add_face(TextFace(tree_header, fsize=8), column=0)
    ts.mode = 'r'
    ts.show_leaf_name = False
    ts.show_branch_length = False
    ts.show_branch_support = False
    #ts.optimal_scale_level = 'full'
    ts.allow_face_overlap = True
    ts.scale = 50

    default_node_style = ete3.NodeStyle()
    default_node_style['size'] = 0
    default_node_style['fgcolor'] = 'Black'

    default_leaf_style = ete3.NodeStyle()
    default_leaf_style['size'] = 0
    default_leaf_style['fgcolor'] = 'Black'

    arc_node_style = ete3.NodeStyle()
    arc_node_style['size'] = 2
    arc_node_style['fgcolor'] = 'Cyan'

    euk_node_style = ete3.NodeStyle()
    euk_node_style['size'] = 2
    euk_node_style['fgcolor'] = 'Green'

    outlier_node_style = ete3.NodeStyle()
    outlier_node_style['size'] = 3
    outlier_node_style['fgcolor'] = 'Red'

    hard_LCA_node_style = ete3.NodeStyle()
    hard_LCA_node_style['size'] = 2
    hard_LCA_node_style['fgcolor'] = 'Gray'

    soft_LCA_node_style = ete3.NodeStyle()
    soft_LCA_node_style['size'] = 4
    soft_LCA_node_style['fgcolor'] = 'Black'

    #NOTE(review): defined but not applied to any node below
    soft_euk_LCA_node_style = ete3.NodeStyle()
    soft_euk_LCA_node_style['size'] = 4
    soft_euk_LCA_node_style['fgcolor'] = 'Gray'

    #set styling for all leaves and internal nodes
    #later set_style calls override earlier ones, so precedence is
    #outlier > soft LCA > hard LCA > superkingdom leaf style > default
    for node in tree.traverse():
        node.set_style(default_node_style)
        node.add_face(TextFace(node.soft_LCA, fsize=4), column=0)
        node.add_face(TextFace(node.soft_LCA_H, fsize=4), column=0)

        if node.is_leaf():
            node.add_face(TextFace(node.name, fsize=4), column=1)
            node.add_face(TextFace(' '+node.tax_class,  fsize=4), column=2)
            node.set_style(default_leaf_style)

            if node.tax_superkingdom == 'Eukaryota':
                node.set_style(euk_node_style)

            elif node.tax_superkingdom == 'Archaea':
                node.set_style(arc_node_style)

        if node in hard_LCA_nodes:
            node.set_style(hard_LCA_node_style)

        if node in soft_LCA_nodes:
            node.set_style(soft_LCA_node_style)

        if node in outlier_nodes:
            node.set_style(outlier_node_style)
            
    #for more consistent visualisation
    tree.ladderize()
    tree.write(features=["name", 'soft_LCA', "tax_filter", "tax_superkingdom"], outfile=merged_tree+'.annot')
    #tree.render(merged_tree+'.annot.pdf', tree_style=ts)
    #view_tree(tree, ts, backup=file+'.pdf')
    
    return merged_tree+'.annot'

In [14]:
#prepare all mmseqs
def microcosm_run(query, root, threads):
    """Run the four microcosm stages for a single query, in order.

    Stages: `microcosm_prepare_mmseqs`, `microcosm_realign_and_filter`,
    `microcosm_merge_align_tree`, `microcosm_tree_analysis`. `threads` is
    forwarded to the realign and merge stages; all other stage settings are
    fixed here. Returns None.
    """

    #stage 1: mmseqs preparation against the euk72 / prok2111 databases
    mmseqs_options = dict(query_DB='euk72/euk72',
                          target_DB='prok2111/prok2111',
                          root=root)
    microcosm_prepare_mmseqs(query, **mmseqs_options)

    #stage 2: per-domain realignment and entropy filtering
    realign_options = dict(threads=threads,
                           max_euk_leaf_size=50,
                           max_prok_leaf_size=250,
                           filter_entropy=0.5,
                           muscle_reps_euk=25,
                           muscle_reps_prok=10)
    microcosm_realign_and_filter(query, root, **realign_options)

    #stage 3: merge both alignments, realign and build the tree
    merge_options = dict(threads=threads,
                         filter_entropy=0.5,
                         muscle_reps=10)
    microcosm_merge_align_tree(query, root, **merge_options)

    #stage 4: annotate the resulting consensus tree
    analysis_options = dict(max_tree_leaves=1500,
                            prok_min_size=2,
                            prok_min_purity=0.5,
                            euk_min_size=3,
                            euk_min_purity=0.76)
    microcosm_tree_analysis(query, root, **analysis_options)
    

In [None]:

#argparse define
def build_arg_parser():
    """Build the command-line parser for a microcosm run.

    Defines three required options: `--query` (str), `--root` (str) and
    `--threads` (int).

    Returns
    -------
    argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(description='Evalute microcosm')
    parser.add_argument('--query', type=str, required=True, help='query root name, maps to folder containing containing .acc and .targets accession files within microcosm')
    parser.add_argument('--root', type=str, required=True, help='path to microcosm root')
    parser.add_argument('--threads', type=int, required=True, help='threads to run')
    return parser


#run main
#arguments are parsed only when executed as a script, so importing this module
#does not consume (or fail on) the host process's sys.argv
if __name__ == '__main__':
    parser = build_arg_parser()
    args = parser.parse_args()
    microcosm_run(args.query, args.root, args.threads)





15047