In [1]:
import json
import os
import sys
from multiprocessing import Pool

import numpy as np
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt

from ete3 import Tree, TextFace, TreeStyle, NodeStyle

from bokeh.io import output_file, show, save
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.plotting import figure

# Do not write .pyc bytecode files for modules imported from this point on.
sys.dont_write_bytecode = True
# Print NumPy floats with 6 digits of precision and without scientific notation.
np.set_printoptions(precision=6, suppress=True)

# Allocate resources

- **THREADS** : this variable will be passed to `multiprocessing.Pool`


In [2]:
# Number of worker processes handed to multiprocessing.Pool by the cells below.
THREADS = 2

# Define input and output files

- **OUTPUT_DIR** : (input) directory containing previous results; outputs will also be stored here
- **COLORS_JSON** : (input, optional) JSON file containing the coloring scheme for visualizations

If you do not wish to provide an optional file, set it equal to an empty string like so `COLORS_JSON = ''`

In [3]:
# Select ONE dataset: keep exactly one OUTPUT_DIR / COLORS_JSON pair
# uncommented and leave the other pairs commented out.

# for the phosphatase dataset
OUTPUT_DIR  = 'datasets/phosphatase/phosphatase_models'
COLORS_JSON = 'datasets/phosphatase/phosphatase_colors.json'

# # for the kinase dataset
# OUTPUT_DIR  = 'datasets/protein_kinase/kinase_models'
# COLORS_JSON = 'datasets/protein_kinase/kinase_colors.json'

# # for the radical sam dataset
# OUTPUT_DIR  = 'datasets/radical_sam/radicalsam_models'
# COLORS_JSON = 'datasets/radical_sam/radicalsam_colors.json'

# Set up variables

In [4]:
# Load the sequence headers and sequences stored in OUTPUT_DIR.
# The archive is opened once and closed as soon as both arrays are read.
# NOTE: allow_pickle=True can execute code embedded in a malicious file;
# only load archives that you produced yourself.
sequences_npz = f'{OUTPUT_DIR}/sequences.npz'
with np.load(sequences_npz, allow_pickle=True) as _archive:
    headers   = _archive['headers']
    sequences = _archive['sequences']

# Load every model archive (in sorted file-name order) into a plain dict.
# Zero-dimensional arrays are unwrapped into Python scalars via .item();
# all other arrays are kept as they are. Each key is read exactly once and
# the archive is closed afterwards.
models_dir = f'{OUTPUT_DIR}/models'
models     = []
for _fname in sorted(os.listdir(models_dir)):
    with np.load(f'{models_dir}/{_fname}') as _archive:
        _model = {}
        for _key in _archive.files:
            _value = _archive[_key]
            _model[_key] = _value.item() if _value.ndim == 0 else _value
    models.append(_model)

# Optional coloring scheme. It is parsed as JSON rather than evaluated as
# Python code, so the file cannot run arbitrary code. An empty or missing
# COLORS_JSON path yields an empty mapping.
colors = {}
if os.path.exists(COLORS_JSON):
    with open(COLORS_JSON) as _handle:
        colors = json.load(_handle)

# Write newick files

In [5]:
def _export_newick(newick, newick_file):
    """Write the string `newick` to the path `newick_file`, replacing any existing file."""
    with open(newick_file, mode='w') as out_handle:
        out_handle.write(newick)

# Write one newick file per model into <OUTPUT_DIR>/viz_nj.
out_dir = f'{OUTPUT_DIR}/viz_nj'
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(out_dir, exist_ok=True)
_queue = [(m['viz(nj)'], f'{out_dir}/{m["representation"]}_{m["metric"]}.newick') for m in models]
# The context manager terminates the pool on exit, including when starmap
# raises, so worker processes are never left behind.
with Pool(THREADS) as pool:
    pool.starmap(_export_newick, _queue)

# Draw trees

Draw trees using ete3. Use multiprocessing to generate multiple plots at once.


In [6]:
def _draw_tree(newick, pdf_file, colors):
    """Render a tree in circular layout with color-coded leaves using ete3.

    Parameters
    ----------
    newick : str
        Tree description handed to `ete3.Tree`.
    pdf_file : str
        Output path handed to `Tree.render`.
    colors : dict
        Maps a substring to a color. A leaf whose name contains one of the
        keys gets that color as its background. When several keys match, the
        LAST matching key (in dict iteration order) wins.
    """
    tree = Tree(newick)
    tree.ladderize()

    def _leaf_color(leaf_name):
        # Color of the last key contained in the leaf name, or None.
        matches = [key for key in colors if key in leaf_name]
        return colors[matches[-1]] if matches else None

    # Leaf name -> background color, only for leaves that matched a key.
    leaf_colors = {}
    for leaf in tree.get_leaves():
        color = _leaf_color(leaf.name)
        if color is not None:
            leaf_colors[leaf.name] = color

    for node in tree.traverse():
        style = NodeStyle()
        style["size"] = 0            # hide the node marker
        style["vt_line_width"] = 1
        style["hz_line_width"] = 1
        if node.name in leaf_colors:
            style["bgcolor"] = leaf_colors[node.name]
        node.set_style(style)

    ts = TreeStyle()
    ts.mode = "c"                    # circular layout
    ts.root_opening_factor = .45
    ts.show_branch_support = False
    tree.render(pdf_file, tree_style=ts)

# Render one tree PDF per model into <OUTPUT_DIR>/viz_nj.
out_dir = f'{OUTPUT_DIR}/viz_nj'
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(out_dir, exist_ok=True)
_queue = [(m['viz(nj)'], f'{out_dir}/{m["representation"]}_{m["metric"]}.pdf', colors) for m in models]
# The context manager terminates the pool on exit, including when starmap
# raises, so worker processes are never left behind.
with Pool(THREADS) as pool:
    pool.starmap(_draw_tree, _queue)

# Draw scatterplots (static)

Draw static plots using matplotlib. Use multiprocessing to generate multiple plots at once.

In [7]:
def _draw_scatter_static(projection, eps_file, names, colors):
    """Save a static scatterplot of `projection` to `eps_file` using matplotlib.

    Parameters
    ----------
    projection : array
        Point coordinates; `projection.T` is unpacked into the positional
        arguments of `scatter` (presumably an (n, 2) array - confirm with caller).
    eps_file : str
        Output path handed to `savefig`.
    names : sequence of str
        One label per point, used only to choose the fill color.
    colors : dict
        Maps a substring to a color. A point whose label contains one of the
        keys is filled with the color of the FIRST matching key (in dict
        iteration order); points without a match are filled white.
    """
    fill = [next((colors[key] for key in colors if key in name), 'white') for name in names]
    fig, ax = plt.subplots(figsize=(9, 9))
    try:
        ax.scatter(*projection.T, c=fill, edgecolors='black', linewidths=1)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.savefig(eps_file)
    finally:
        # A pool worker handles several plots in sequence; close the figure
        # so that figures do not pile up in the worker's memory.
        plt.close(fig)

# Save one static scatterplot per model into <OUTPUT_DIR>/viz_densmap.
out_dir = f'{OUTPUT_DIR}/viz_densmap'
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(out_dir, exist_ok=True)
_queue = [(m['viz(densmap)'], f'{out_dir}/{m["representation"]}_{m["metric"]}.eps', headers, colors) for m in models]
# The context manager terminates the pool on exit, including when starmap
# raises, so worker processes are never left behind.
with Pool(THREADS) as pool:
    pool.starmap(_draw_scatter_static, _queue)

# Draw scatterplots (interactive)

Draw interactive plots using bokeh. Use multiprocessing to generate multiple plots at once.

In [8]:
def _draw_scatter_interactive(projection, html_file, names, colors):
    """Save an interactive bokeh scatterplot of `projection` to `html_file`.

    Columns 0 and 1 of `projection` are used as x and y. Each point shows its
    entry of `names` in a hover tooltip and is filled with the color of the
    first key of `colors` contained in that name, or white when no key matches.
    """
    fill = []
    for label in names:
        hits = [colors[key] for key in colors if key in label]
        fill.append(hits[0] if hits else 'white')

    columns = {'x': projection[:, 0], 'y': projection[:, 1], 'desc': names, 'fill': fill}
    source = ColumnDataSource(data=columns)
    hover = HoverTool(tooltips=[('desc', '@desc')])

    # NOTE(review): newer Bokeh releases renamed plot_width/plot_height to
    # width/height - confirm against the installed Bokeh version.
    plot = figure(plot_width=1200, plot_height=1200, tools=[hover], title=None)
    plot.circle('x', 'y', size=9, source=source, line_width=1, line_color='black', fill_color='fill')

    # Hide the toolbar and its logo.
    plot.toolbar.logo = None
    plot.toolbar_location = None

    output_file(html_file)
    save(plot)

# Save one interactive scatterplot per model into <OUTPUT_DIR>/viz_densmap.
out_dir = f'{OUTPUT_DIR}/viz_densmap'
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(out_dir, exist_ok=True)
_queue = [(m['viz(densmap)'], f'{out_dir}/{m["representation"]}_{m["metric"]}.html', headers, colors) for m in models]
# The context manager terminates the pool on exit, including when starmap
# raises, so worker processes are never left behind.
with Pool(THREADS) as pool:
    pool.starmap(_draw_scatter_interactive, _queue)