In [1]:
import os
os.environ['QT_QPA_PLATFORM']='offscreen'
import re
import Bio
from Bio import Seq, SeqIO

In [2]:
import ete3
from ete3 import Tree, faces, TreeStyle, PhyloTree, NodeStyle, TextFace, AttrFace, SeqMotifFace

In [3]:
# Input/output paths: all three files live in one directory and share a stem.
_pks_prefix = '../../../../figshare/pks_phylogeny/KS_domain_cdhit80_prune'
treefile = _pks_prefix + '.mpr.tree'   # tree; internal node labels hold support values
outfile = _pks_prefix + '.mpr.pdf'     # rendered figure
alnfile = _pks_prefix + '.aln'         # FASTA alignment drawn next to the leaves

# Which lineage colour scheme to use (matched against column 3 of the colour
# table) and the minimum support value printed on internal branches.
colorscheme = 'metazoa'
branch_min = 0.95

# Taxonomy dumps and the lineage -> colour table.
nodefile = 'nodes.dmp'
mergedfile = 'merged.dmp'
lineages_colors_file = 'lineage_colors.dmp'

In [4]:
# Load the FASTA alignment as {record id: aligned sequence}. Tree leaf names
# are looked up in this dict when the sequence strips are drawn.
seqDict = {}
for record in SeqIO.parse(alnfile, "fasta"):
    seqDict[record.id] = str(record.seq)

# Every row of an alignment must have the same width; that width (alnLen)
# sizes the sequence strip for every leaf. Fail early on an empty or ragged
# alignment instead of silently using whichever record happened to be last.
aln_widths = {len(seq) for seq in seqDict.values()}
if len(aln_widths) != 1:
    raise ValueError(
        f"{alnfile}: expected a non-empty alignment with equal-length rows, "
        f"found row lengths {sorted(aln_widths)}"
    )
alnLen = aln_widths.pop()

In [5]:
# Build a child -> parent tax id map from nodes.dmp. Fields are split on tab
# and the parent is taken from field 2, presumably because rows use the NCBI
# taxdump "\t|\t" separator (field 1 is the '|'). TODO confirm file format.
taxnodes = {}
with open(nodefile) as fi:
    for line in fi:
        fields = line.rstrip('\n').split('\t')
        node, parent = fields[0], fields[2]
        taxnodes[node] = parent

# merged.dmp maps retired tax ids (field 0) to their replacement (field 2).
# Point each old id at its replacement id, not at the replacement's parent,
# so a lineage walked from an old id still passes through the current id
# (and any lineage colour attached to that id is found).
with open(mergedfile) as fi:
    for line in fi:
        fields = line.rstrip('\n').split('\t')
        node, newnode = fields[0], fields[2]
        taxnodes[node] = newnode

In [6]:
# Read the lineage colour table, keeping only rows of the selected colour
# scheme. Tab-separated columns used here: 0 = legend order, 2 = colour,
# 3 = scheme name, 4 = lineage name, 5 = comma-separated tax ids.
orderDict = {}   # legend order (int) -> lineage name
colorDict = {}   # lineage name -> colour
lineages = {}    # tax id (str) -> lineage name

with open(lineages_colors_file) as fi:
    for line in fi:
        if not line.strip():
            continue  # blank line: nothing to index, would raise IndexError
        fields = line.rstrip('\n').split('\t')
        if fields[3] != colorscheme:
            continue
        order, color, name, taxid_list = fields[0], fields[2], fields[4], fields[5]
        orderDict[int(order)] = name
        colorDict[name] = color

        for taxid in taxid_list.split(','):
            taxid = taxid.strip()  # tolerate "1, 2" style lists
            if taxid:
                lineages[taxid] = name

In [7]:
# Lineage lookup: list of tax ids from a given id up towards the root.
def taxdump(taxid, nodes=None):
    """Return the lineage of ``taxid`` as a list of tax ids, most specific first.

    Follows child -> parent links until the root id '1' (or an empty id) is
    reached; the terminating id itself is not included.

    Args:
        taxid: tax id as a string.
        nodes: child -> parent mapping to walk. Defaults to the module-level
            ``taxnodes`` map built from nodes.dmp / merged.dmp.

    Returns:
        List of tax id strings from ``taxid`` up to (excluding) the root;
        empty when ``taxid`` is '' or '1'.

    Raises:
        KeyError: if an id on the path is missing from ``nodes``.
    """
    if nodes is None:
        nodes = taxnodes
    taxlist = []
    while taxid not in ('', '1'):
        taxlist.append(taxid)
        taxid = nodes[taxid]
    return taxlist

In [8]:
# Parse every leaf label into gene / tax id / species, colour it by lineage,
# and attach the aligned label columns plus the alignment strip.
# Gene labels of the target species (Elysia crispata) are drawn larger,
# bold and red; all other gene labels are black.

#t = PhyloTree(treefile, alignment=alnfile, alg_format="fasta")
t = Tree(treefile, format=1)  # format=1: internal node labels are kept as names
t.ladderize(direction=1)
leafSet = set()

for n in t.get_leaves():
    leafSet.add(n.name)

    # Leaf labels are "<gene>-<taxid>-<Genus_species...>". maxsplit=2 keeps
    # any '-' inside the species/strain part instead of truncating it.
    genename, taxid, speciesname = n.name.split("-", 2)
    speciesname = speciesname.replace("_", " ")

    # Walk up the taxonomy; the first (most specific) tax id with an assigned
    # lineage decides the colour. Unassigned leaves stay black with lineage ''.
    linname = ''
    lincolor = '#000000'
    for tid in taxdump(taxid):
        if tid in lineages:
            linname = lineages[tid]
            lincolor = colorDict[linname]
            break

    n.add_features(lineage=linname, gene=genename, species=speciesname, taxid=taxid)

    # Lineage column: a tiny (fsize=1) label on a background of the lineage colour.
    linF = AttrFace("lineage", fgcolor=lincolor, fsize=1)
    linF.background.color = lincolor
    linF.margin_top = linF.margin_bottom = linF.margin_left = 10

    speciesF = AttrFace("species", fsize=10, fgcolor=lincolor, fstyle="italic")
    speciesF.margin_right = speciesF.margin_left = 10
    taxidF = AttrFace("taxid", fsize=10, fgcolor=lincolor, fstyle="normal")
    taxidF.margin_right = taxidF.margin_left = 10

    if speciesname == 'Elysia crispata':
        geneF = AttrFace("gene", fsize=12, fgcolor="red")
        # fstyle only selects normal/italic; bold weight is TextFace's `bold` flag.
        geneF.bold = True
    else:
        geneF = AttrFace("gene", fsize=10, fgcolor="black")
    geneF.margin_right = geneF.margin_left = 5

    # Gene name next to the branch tip; the other labels aligned in columns.
    n.add_face(speciesF, 0, position='aligned')
    n.add_face(geneF, 0, position='branch-right')
    n.add_face(taxidF, 1, position='aligned')
    n.add_face(linF, 2, position='aligned')

    # One motif spanning the whole alignment, drawn as a compact strip.
    # A fresh list per leaf so faces never share mutable motif state.
    my_motifs = [[0, alnLen, "compactseq", 2, 10, None, None, None]]
    seqF = SeqMotifFace(seq=seqDict[n.name], motifs=my_motifs, gap_format="blank")
    seqF.margin_right = seqF.margin_left = 5
    n.add_face(seqF, 3, "aligned")
    

In [9]:
# Style branches by lineage and label well-supported internal branches.
#  * A node whose leaves all belong to one lineage gets its branch drawn in
#    that lineage's colour; mixed clades keep the default line colour.
#  * Leaves (or clades) with no assigned lineage ('') are drawn black rather
#    than raising KeyError on colorDict.
#  * Internal node labels that parse as numbers >= branch_min are printed
#    under the branch; other labels (empty, "n..." names, non-numeric text)
#    are skipped instead of crashing.

UNASSIGNED_COLOR = '#000000'


def _lineage_color(lineage):
    """Colour for a lineage name; black for '' or any name not in colorDict."""
    return colorDict.get(lineage, UNASSIGNED_COLOR)


def _branch_style(color=None):
    """Fresh NodeStyle: hidden node dot, 2px lines, optional line colour."""
    node_style = NodeStyle()
    node_style["size"] = 0
    node_style["hz_line_width"] = 2
    node_style["vt_line_width"] = 2
    if color is not None:
        node_style["vt_line_color"] = color
        node_style["hz_line_color"] = color
    return node_style


def _support_value(label):
    """Parse an internal-node label as a float support value, or None."""
    try:
        return float(label)
    except ValueError:
        return None


# Set of leaf lineages under every node, built bottom-up in a single pass
# (postorder visits children before their parent).
clade_lineages = {}
for node in t.traverse("postorder"):
    if node.is_leaf():
        clade_lineages[node] = {node.lineage}
    else:
        clade_lineages[node] = set().union(*(clade_lineages[c] for c in node.children))

# The root branch is never coloured.
t.img_style = _branch_style()

for n in t.iter_descendants("postorder"):
    if n.is_leaf():
        n.img_style = _branch_style(_lineage_color(n.lineage))
        continue

    lineage_set = clade_lineages[n]
    clade_color = _lineage_color(next(iter(lineage_set))) if len(lineage_set) == 1 else None
    n.img_style = _branch_style(clade_color)

    support = _support_value(n.name)
    if support is not None and support >= branch_min:
        n.add_features(bootstrap=n.name)
        label_color = clade_color if clade_color is not None else UNASSIGNED_COLOR
        supF = AttrFace("bootstrap", fgcolor=label_color, fsize=8)
        supF.margin_right = supF.margin_left = 3
        n.add_face(supF, 0, position='branch-bottom')
    
    

In [10]:
# Tree-level style plus a colour legend in the title area.
ts = TreeStyle()
ts.show_leaf_name = False      # leaf labels come from the custom faces instead
#ts.show_branch_support = True
ts.draw_guiding_lines = True   # guide lines from leaf tips to the aligned columns

ts.title.add_face(TextFace("Taxonomy:", fsize=10), column=0)
# One entry per lineage, in the table's order column (orders < 1 are skipped,
# as before). Iterating the real keys lists every lineage even when orders are
# 1-based or non-contiguous, instead of assuming they run 1..len(colorDict)-1.
for order in sorted(o for o in orderDict if o >= 1):
    lineage_name = orderDict[order]
    ts.title.add_face(TextFace(lineage_name, fsize=10, fgcolor=colorDict[lineage_name]), column=0)

In [None]:
# render image on notebook or save to file
# Draws the annotated tree `t` with the legend/tree style `ts` into `outfile`
# (a .pdf path here). The output format presumably follows the file extension
# — see ete3 TreeNode.render. The offscreen QT_QPA_PLATFORM set in the first
# cell lets this run without a display.
t.render(outfile, tree_style=ts)
# To show the figure inside the notebook instead of writing a file:
#t.render("%%inline", tree_style=ts)