In [1]:
import dendropy
import csv
import baltic as bt
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np

This script generates new minor lineage designations based on the current defining rules. It pulls out all possible minor lineage designations in the annotated tree. It does NOT take old names into account, so when suggesting new lineages, please take these lineages and check what their names will be based on what has already been used.

The input is a tree annotated in nexus format with major lineages and genotypes. The output is an annotated tree and TSV file.

In [2]:
# Thresholds a node must meet to be proposed as a minor lineage (checked in generate_groups).
# min_distance is compared against edge_length * alignment length of the serotype.
min_distance = 25
# min_size is compared against the number of leaf nodes below the node.
min_size = 15

# Names of the tree annotations that hold the genotype and major-lineage labels.
genotype_annotation = "new_genotype"
major_lineage_annotation = "major_lineage"

# Path of the annotated nexus tree handed to dendropy.Tree.get in the running cell.
# It is empty here, so it has to be filled in before running.
annotated_tree = ""

In [3]:
def _find_defining_nodes(tree, annotation_name):
    """Return, in preorder, every non-root node whose annotation differs from its parent's."""
    return [
        node
        for node in tree.preorder_node_iter()
        if node != tree.seed_node
        and node.annotations[annotation_name].value
        != node.parent_node.annotations[annotation_name].value
    ]


def find_genotype_defining(tree, annotation_name=None):
    """Find the nodes at which the genotype annotation changes.

    Parameters
    ----------
    tree : tree object exposing ``preorder_node_iter()`` and ``seed_node``
    annotation_name : str, optional
        Annotation to compare between a node and its parent. Defaults to the
        notebook-level ``genotype_annotation`` setting.

    Returns
    -------
    list
        Non-root nodes, in preorder, whose annotation value differs from the
        value on their parent node.
    """
    if annotation_name is None:
        annotation_name = genotype_annotation
    return _find_defining_nodes(tree, annotation_name)


def find_major_lineage_defining(tree, annotation_name=None):
    """Find the nodes at which the major-lineage annotation changes.

    Parameters
    ----------
    tree : tree object exposing ``preorder_node_iter()`` and ``seed_node``
    annotation_name : str, optional
        Annotation to compare between a node and its parent. Defaults to the
        notebook-level ``major_lineage_annotation`` setting.

    Returns
    -------
    list
        Non-root nodes, in preorder, whose annotation value differs from the
        value on their parent node.
    """
    if annotation_name is None:
        annotation_name = major_lineage_annotation
    return _find_defining_nodes(tree, annotation_name)

In [4]:
def annotate_tree(focal_node, annotation):
    """Write ``annotation`` as the minor lineage of ``focal_node`` and all its descendants.

    The value stored is the annotation wrapped in literal double quotes.
    The subtree is walked with an explicit stack rather than recursion so that
    very deep (ladder-like) trees cannot exceed Python's recursion limit.

    Parameters
    ----------
    focal_node : node exposing ``annotations`` and ``child_nodes()``
    annotation : str
        Minor lineage name to assign.
    """
    quoted = f'"{annotation}"'

    pending = [focal_node]
    while pending:
        node = pending.pop()
        node.annotations["minor_lineage"].value = quoted
        pending.extend(node.child_nodes())

    return

In [5]:
def find_path(node, path):
    """Append ``node`` and each of its ancestors, up to the root, to ``path``.

    The walk stops at the node that has no parent, so the function does not
    depend on a module-level ``tree`` variable, and it is iterative so deep
    trees cannot exceed Python's recursion limit.

    Parameters
    ----------
    node : node exposing ``parent_node``
        Starting node; it becomes the first element added.
    path : list
        List that is extended in place.

    Returns
    -------
    list
        The same ``path`` object, ordered from ``node`` to the root.
    """
    current = node
    while current is not None:
        path.append(current)
        current = current.parent_node

    return path

In [6]:
def generate_groups(tree, sero):
    """Collect the nodes that qualify as minor-lineage candidates.

    A node is kept when all of the following hold:

    * its major-lineage annotation contains an underscore,
    * ``edge_length * aln_lengths[sero]`` is at least ``min_distance``,
    * it has at least ``min_size`` leaf nodes,
    * it is neither a genotype-defining nor a major-lineage-defining node.

    Nodes without an edge length (``None``, as is usual for the root) are
    skipped instead of raising ``TypeError``.

    Parameters
    ----------
    tree : annotated tree
    sero : str
        Key into ``aln_lengths``.

    Returns
    -------
    meet_threshold : defaultdict(list)
        Major lineage -> qualifying nodes, in preorder.
    distances : dict
        Qualifying node -> ``edge_length * aln_lengths[sero]``.
    paths_to_root : defaultdict(dict)
        Major lineage -> {qualifying node: path from that node to the root}.
    major_lin_nodes : list
        Output of ``find_major_lineage_defining(tree)``.

    Raises
    ------
    KeyError
        If ``sero`` is not a key of ``aln_lengths``.
    """
    genotype_nodes = find_genotype_defining(tree)
    major_lin_nodes = find_major_lineage_defining(tree)
    # Set membership keeps the per-node exclusion test O(1); the list itself is still returned.
    defining_nodes = set(genotype_nodes) | set(major_lin_nodes)

    aln_length = aln_lengths[sero]

    meet_threshold = defaultdict(list)
    distances = {}

    for node in tree.preorder_node_iter():
        major_lineage = node.annotations[major_lineage_annotation].value
        if "_" not in major_lineage:
            continue
        if node.edge_length is None:
            continue

        distance = node.edge_length * aln_length
        # Cheap checks first; counting leaves walks the whole subtree.
        if (
            distance >= min_distance
            and node not in defining_nodes
            and len(node.leaf_nodes()) >= min_size
        ):
            meet_threshold[major_lineage].append(node)
            distances[node] = distance

    paths_to_root = defaultdict(dict)
    for major_lineage, node_list in meet_threshold.items():
        for node in node_list:
            paths_to_root[major_lineage][node] = find_path(node, [])

    return meet_threshold, distances, paths_to_root, major_lin_nodes
    
    

In [7]:
def get_levels(paths_to_root):
    """Compute the nesting level of every candidate node within its major lineage.

    The level of a node is the number of *other* candidate nodes of the same
    major lineage that appear on its path to the root (0 = no candidate
    ancestor).

    Levels are recorded under the major-lineage key of ``paths_to_root`` in a
    single pass, so the function needs neither the node annotations nor any
    notebook-level setting.

    Parameters
    ----------
    paths_to_root : dict
        Major lineage -> {candidate node: path from that node to the root}.

    Returns
    -------
    major_lins_levels : defaultdict(dict)
        Major lineage -> defaultdict(list) mapping level -> nodes at that level.
    node_to_level : dict
        Candidate node -> level.
    """
    node_to_level = {}
    major_lins_levels = defaultdict(dict)

    for major_lineage, minor_lineage_paths in paths_to_root.items():
        levels = defaultdict(list)
        major_lins_levels[major_lineage] = levels
        for key_node, path in minor_lineage_paths.items():
            level = sum(
                1
                for ancestor in path
                if ancestor != key_node and ancestor in minor_lineage_paths
            )
            node_to_level[key_node] = level
            levels[level].append(key_node)

    return major_lins_levels, node_to_level

In [8]:
def get_parents(path, meet_threshold, major_lin_nodes):
    """Return the members of ``path`` that are in ``meet_threshold``.

    ``path`` is scanned in order and the scan ends after the first element
    found in ``major_lin_nodes``; that element is still considered.
    """

    def _until_boundary():
        # Yield path elements up to and including the first major-lineage node.
        for step in path:
            yield step
            if step in major_lin_nodes:
                return

    return {step for step in _until_boundary() if step in meet_threshold}

In [9]:
def find_siblings(major_lins_levels, node_to_level, paths_to_root, meet_threshold, major_lin_nodes):
    """Drop candidate nodes that have no sibling and re-level the remaining ones.

    Within a major lineage, a node is dropped when

    * it is the only node at its level, or
    * no other node at the same level has an ancestor set (see ``get_parents``)
      sharing exactly ``level`` members with its own.

    Each surviving node's level is then reduced by the number of dropped nodes
    on its path to the root.

    Ancestor sets are computed once per node instead of once per node pair, and
    removal lookups use sets. The lineage of each node is taken from
    ``major_lins_levels`` itself.

    Parameters
    ----------
    major_lins_levels : dict
        Major lineage -> {level: [nodes]} as returned by ``get_levels``.
    node_to_level : dict
        Node -> level as returned by ``get_levels``.
    paths_to_root : dict
        Major lineage -> {node: path from that node to the root}.
    meet_threshold : dict
        Major lineage -> candidate nodes.
    major_lin_nodes : iterable
        Major-lineage-defining nodes; ancestor collection stops at them.

    Returns
    -------
    new_levels : dict
        Surviving node -> adjusted level.
    new_major_lins_levels : defaultdict(dict)
        Major lineage -> defaultdict(list) mapping adjusted level -> nodes.
    """
    boundary_nodes = set(major_lin_nodes)
    removed = defaultdict(set)
    lineage_of = {}

    for major_lin, level_dict in major_lins_levels.items():
        for level, node_list in level_dict.items():
            for node in node_list:
                lineage_of[node] = major_lin

            if len(node_list) <= 1:
                # A lone node at its level has no sibling; an empty level adds nothing.
                removed[major_lin].update(node_list)
                continue

            candidates = set(meet_threshold[major_lin])
            ancestor_sets = {
                node: get_parents(paths_to_root[major_lin][node], candidates, boundary_nodes)
                for node in node_list
            }
            for node in node_list:
                own = ancestor_sets[node]
                has_sibling = any(
                    other != node and len(own & ancestor_sets[other]) == level
                    for other in node_list
                )
                if not has_sibling:
                    removed[major_lin].add(node)

    # Count, for every candidate, how many dropped nodes sit above it.
    levels_to_remove = {}
    for major_lineage, minor_lineage_paths in paths_to_root.items():
        dropped = removed[major_lineage]
        for key_node, path in minor_lineage_paths.items():
            levels_to_remove[key_node] = sum(
                1 for ancestor in path if ancestor != key_node and ancestor in dropped
            )

    new_levels = {}
    new_major_lins_levels = defaultdict(dict)

    for major_lineage in major_lins_levels.keys():
        new_major_lins_levels[major_lineage] = defaultdict(list)

    for node, level in node_to_level.items():
        major_lineage = lineage_of[node]
        if node not in removed[major_lineage]:
            new_level = level - levels_to_remove[node]
            new_levels[node] = new_level
            new_major_lins_levels[major_lineage][new_level].append(node)

    return new_levels, new_major_lins_levels

In [10]:
def name_lineages(major_lins_levels, paths_to_root):
    """Assign a hierarchical name to every node in ``major_lins_levels``.

    Levels are handled in ascending order within each major lineage. A node's
    name is ``<base>.<n>``, where ``base`` is the name of the first
    already-named node on its path to the root (or the major lineage itself if
    there is none) and ``n`` counts the names issued under that base so far.

    Returns a dict mapping node -> name.
    """
    annotations = {}
    issued = defaultdict(int)

    for major_lineage, level_dict in major_lins_levels.items():
        for level in sorted(level_dict):
            for node in level_dict[level]:
                # Closest named ancestor wins; fall back to the major lineage.
                base = next(
                    (
                        annotations[step]
                        for step in paths_to_root[major_lineage][node]
                        if step in annotations
                    ),
                    major_lineage,
                )
                issued[base] += 1
                annotations[node] = f"{base}.{issued[base]}"

    return annotations

In [11]:
def sort_outputs(major_lins_levels, sero, paths_to_root, distances, tree=None):
    """Name the minor lineages, annotate the tree and write the per-lineage CSV.

    Steps:

    1. Name every node in ``major_lins_levels`` with ``name_lineages``.
    2. Give every tree node whose ``minor_lineage`` annotation is the empty
       string the value of its major-lineage annotation.
    3. For each named node, annotate its whole subtree with the name, write
       ``name,distance`` to ``../phylo_outputs/minor_lineages/{sero}_minor_lins.csv``
       and print the name, leaf count and scaled edge length.

    Parameters
    ----------
    major_lins_levels : dict
        Major lineage -> {level: [nodes]}.
    sero : str
        Serotype; used in the output file name and as key into ``aln_lengths``.
    paths_to_root : dict
        Major lineage -> {node: path from that node to the root}.
    distances : dict
        Node -> value written to the ``mutation_number`` column.
    tree : optional
        Tree to annotate. When omitted, the module-level ``tree`` variable is
        used, which is what this function previously always relied on.

    Returns
    -------
    tree
        The annotated tree.
    annotation_set : set
        All minor lineage names that were assigned.

    Raises
    ------
    KeyError
        If ``tree`` is omitted and no module-level ``tree`` exists.
    """
    if tree is None:
        # Backward-compatible fallback for callers that rely on the notebook global.
        tree = globals()["tree"]

    annotations = name_lineages(major_lins_levels, paths_to_root)

    for node in tree.preorder_node_iter():
        minor = node.annotations['minor_lineage']
        if minor.value == "":
            # Use the configured annotation name rather than a hard-coded key.
            minor.value = node.annotations[major_lineage_annotation].value

    annotation_set = set()
    with open(f"../phylo_outputs/minor_lineages/{sero}_minor_lins.csv", 'w') as fw:
        fw.write("name,mutation_number\n")
        for node, annotation in annotations.items():
            annotate_tree(node, annotation)

            fw.write(f"{annotation},{distances[node]}\n")

            print(annotation, len(node.leaf_nodes()), node.edge_length*aln_lengths[sero])
            annotation_set.add(annotation)

    return tree, annotation_set

In [12]:
# Alignment length per serotype. Edge lengths are multiplied by this value
# before being compared with min_distance (presumably turning per-site branch
# lengths into mutation counts; confirm against how the trees were built).
aln_lengths = {
    "DENV1": 10179,
    "DENV2": 10176,
    "DENV3": 10173,
    "DENV4": 10164,
}

## running

In [None]:
# Run the whole minor-lineage pipeline once per serotype and write the annotated tree.
for sero in ["DENV1", "DENV2", "DENV3", "DENV4"]:

    # NOTE(review): `annotated_tree` is a single path, so the same file is loaded for
    # every serotype in this loop — confirm whether it should vary with `sero`.
    # `tree` must stay a module-level name: sort_outputs reads it as a global.
    tree = dendropy.Tree.get(path=annotated_tree, schema="nexus")

    # Candidate nodes, their scaled branch lengths and their paths to the root.
    meet_threshold_dict, distances, paths_to_root, major_lin_nodes = generate_groups(tree, sero)
    # Nesting level of each candidate inside its major lineage.
    major_lins_levels, node_to_level = get_levels(paths_to_root)
        
    # Drop candidates without a sibling and re-level the remaining ones.
    new_levels, new_major_lins_levels = find_siblings(major_lins_levels, node_to_level, paths_to_root, meet_threshold_dict, major_lin_nodes)
    
    # Name the lineages, annotate the tree in place and write the per-lineage CSV.
    tree, annotation_set = sort_outputs(new_major_lins_levels, sero, paths_to_root, distances)
   
    tree.write(path=f"../phylo_outputs/minor_lineages/{sero}_annotated.tree", schema="nexus")

## annotations

In [15]:
# Write one taxon -> minor lineage TSV per serotype from the annotated trees.
for sero in ["DENV1", "DENV2", "DENV3", "DENV4"]:

    # NOTE(review): this reads "{sero}_annotated_minor.tree" from the working directory,
    # whereas the running cell writes "../phylo_outputs/minor_lineages/{sero}_annotated.tree".
    # Confirm which file is intended.
    # The tree is loaded before the output is opened so that a missing or unreadable
    # tree does not leave a truncated, header-only TSV behind.
    tree = dendropy.Tree.get(path=f"{sero}_annotated_minor.tree", schema="nexus")

    # The context manager closes the file even if writing a row fails.
    with open(f"{sero}_annotations.tsv", 'w') as new:
        new.write("taxon\tminor_lineage\n")
        for tip in tree.leaf_node_iter():
            new.write(f"{tip.taxon.label}\t{tip.annotations['minor_lineage'].value}\n")