In [1]:
import ast
import os
import sys
from multiprocessing import Pool

import numpy as np
from scipy.spatial.distance import cdist
from ete3 import Tree, TextFace, TreeStyle, NodeStyle

sys.dont_write_bytecode = True

from my_library import neighbor_joining, Metrics
from my_vibe import VIBE, get_bipartitions, get_support

# Allocate resources

- **THREADS** : number of workers; passed to `torch` (via `VIBE`) and also used as the size of the `multiprocessing` pools that build the trees
- **DEVICE** : this will be passed to `torch`; allowed options include "cpu" and "cuda"

If using CUDA, we highly recommend monitoring GPU memory usage while running this code, e.g. with `nvidia-smi -l 1`.

In [2]:
# Resources: THREADS sizes the multiprocessing pools and is handed to VIBE/torch;
# DEVICE is the torch device string ("cpu" or "cuda").
THREADS, DEVICE = 20, 'cuda'

# Define input and output files

- **npz_file** : (input) a `numpy` file containing sequence headers and fixed-size embedding vectors for each sequence
- **colors_json** : (input, optional) json file containing coloring scheme for visualizations
- **vibe_dir** : (output) directory where the trained model, training logs, trees, and tree figure are written
- **exclude_file** : (input, optional) a text file listing sequence headers to exclude from the analysis

If you do not wish to provide an optional file, set it equal to an empty string like so `colors_json = ''`

In [3]:
# The active input/output paths below all live under this dataset directory.
dataset_dir = 'datasets/Cas_1_2'
npz_file    = f'{dataset_dir}/cas_1_2_models/fixedsize/mean_of_residue_tokens.npz'
colors_json = f'{dataset_dir}/cas_protein_color.json'
vibe_dir    = f'{dataset_dir}/vibe/'
# Optional exclude list (currently unset); point it at a text file of headers to prune, e.g.
#exclude_file = 'datasets/phosphatase/exclude.txt'

# Define the tree building function

In [4]:
def _nj(embedding, headers, metric=None):
    """Build a neighbor-joining tree from per-sequence embedding vectors.

    Parameters
    ----------
    embedding : array-like
        One row per sequence (presumably shape (n_sequences, dim) -- TODO confirm).
    headers : sequence of str
        Leaf labels, in the same order as the rows of ``embedding``.
    metric : callable, optional
        ``metric(a, b)`` returning the pairwise distance matrix between the rows of
        ``a`` and ``b``. Defaults to ``Metrics.cosine``; ``Metrics.ts_ss`` is an
        alternative. Pass it explicitly instead of editing this function.

    Returns
    -------
    str
        The tree as produced by ``neighbor_joining``; callers write it straight to a
        ``.newick`` file.
    """
    # Resolved at call time so the default needs nothing beyond the imported Metrics class.
    if metric is None:
        metric = Metrics.cosine
    distmat = metric(embedding, embedding)
    return neighbor_joining(distmat, headers)

# Set up variables

In [6]:
# Output files, all inside vibe_dir. os.path.join avoids the doubled '/' that
# f'{vibe_dir}/...' produces when vibe_dir ends with '/'.
vibe_model        = os.path.join(vibe_dir, 'vibe_model.pt')
newick_reference  = os.path.join(vibe_dir, 'tree_reference.newick')
newick_replicates = os.path.join(vibe_dir, 'tree_replicates.newick')
newick_vibe       = os.path.join(vibe_dir, 'tree_vibe.newick')
pdf_tree_vibe     = os.path.join(vibe_dir, 'tree_vibe.pdf')

# Optional leaf color scheme (name substring -> color, consumed by draw_tree).
# ast.literal_eval parses the same dict literal eval() did, but cannot execute arbitrary
# code from the file; the `with` block also closes the file handle.
if os.path.exists(colors_json):
    with open(colors_json) as colors_fh:
        colors = ast.literal_eval(colors_fh.read())
else:
    colors = {}

# Sequence headers pruned from every tree as rogue taxa. Currently none; to read them from a
# text file (one header per line) use e.g.
#   [line.strip() for line in open(exclude_file) if not line.isspace()]
exclude = []

# Read headers and embeddings through a single, properly closed handle.
# allow_pickle=True is presumably required because headers is stored as an object array --
# only use it on npz files you trust.
with np.load(npz_file, allow_pickle=True) as npz:
    headers   = npz['headers']
    embedding = npz['embedding'].astype(np.float32)
accessions = np.array([h.split()[0] for h in headers], dtype=object)  # first whitespace token

# Drop excluded sequences consistently from all three arrays. The set gives O(1) lookups;
# dtype=bool keeps the mask valid for boolean indexing even when headers is empty.
excluded   = set(exclude)
mask       = np.array([h not in excluded for h in headers], dtype=bool)
headers    = headers[mask]
accessions = accessions[mask]
embedding  = embedding[mask]

os.makedirs(vibe_dir, exist_ok=True)

# Reference tree built from the full (non-resampled) embedding.
with open(newick_reference, 'w') as w:
    w.write(_nj(embedding, headers))

# Train VAE

In [8]:
# Build and train the VAE. warm_up / cool_down are written as iteration counts divided by
# max_iter (1000 and 2000 iterations), so derive them from one max_iter instead of
# repeating the literal 15000 three times; editing max_iter can no longer desync them.
# NOTE(review): encoder_layers[0] / decoder_layers[-1] (1280) presumably must equal the
# embedding width, embedding.shape[1] -- confirm.
max_iter = 15000

vibe = VIBE(
    encoder_layers = [1280, 640, 640, 320],
    latent_dim     = 320,
    decoder_layers = [320, 640, 640, 1280],
    tse_weight     = 0.1,
    max_iter       = max_iter,
    warm_up        = 1000 / max_iter,
    cool_down      = 2000 / max_iter,
    kld_annealing  = True,
    start_beta     = 0.001,
    stop_beta      = 0.1,
    n_cycle        = 12,
    threads        = THREADS,
    device         = DEVICE,
    seed           = 420,
    log_dir        = vibe_dir)

vibe.fit(embedding)
vibe.dump(vibe_model)  # persisted so the resampling step can reload it with VIBE.load

====> Epoch: 500 | loss: 1.3885 | MSE: 1.3347 | TSE: 0.0710 | KLD: 46.7394 | beta: 0.0010                0      
====> Epoch: 1000 | loss: 0.0700 | MSE: 0.0457 | TSE: 0.0005 | KLD: 24.2322 | beta: 0.0010      
====> Epoch: 1500 | loss: 0.1447 | MSE: 0.0498 | TSE: 0.0006 | KLD: 15.1569 | beta: 0.0063      
====> Epoch: 2000 | loss: 0.3370 | MSE: 0.0939 | TSE: 0.0015 | KLD: 9.9445 | beta: 0.0244       
====> Epoch: 2500 | loss: 0.1130 | MSE: 0.0325 | TSE: 0.0004 | KLD: 12.8707 | beta: 0.0063      
====> Epoch: 3000 | loss: 0.3096 | MSE: 0.0700 | TSE: 0.0011 | KLD: 9.8090 | beta: 0.0244       
====> Epoch: 3500 | loss: 0.1016 | MSE: 0.0244 | TSE: 0.0003 | KLD: 12.3487 | beta: 0.0063      
====> Epoch: 4000 | loss: 0.2960 | MSE: 0.0602 | TSE: 0.0009 | KLD: 9.6495 | beta: 0.0244       
====> Epoch: 4500 | loss: 0.0972 | MSE: 0.0254 | TSE: 0.0003 | KLD: 11.4705 | beta: 0.0063      
====> Epoch: 5000 | loss: 0.2935 | MSE: 0.0608 | TSE: 0.0009 | KLD: 9.5228 | beta: 0.0244       
====> Epoch: 5

# Resample VAE for replicate trees

In [9]:
# Reload the trained model from disk, so this step also works in a fresh kernel.
vibe = VIBE.load(vibe_model)

# Draw 500 resamples; each one is treated as an embedding matrix for one replicate tree
# (presumably rows aligned with `headers` -- confirm against VIBE.resample).
samples = vibe.resample(500)

# Build one neighbor-joining tree per resample, in parallel.
nj_args = [(sample, headers) for sample in samples]
with Pool(THREADS) as pool:
    output = pool.starmap(_nj, nj_args)

# One replicate tree per line.
newick_text = '\n'.join(output)
with open(newick_replicates, 'w') as w:
    w.write(newick_text)

# Perform VIBE check

In [None]:
# `support` is computed in the VIBE-check cell that follows; evaluating it here raised
# NameError under Restart & Run All. Inspect `support` after that cell has run instead.

In [7]:
# Parse the reference NJ tree and extract its bipartitions (paired with reference_nodes).
reference_tree = Tree(newick_reference)
reference_bits, reference_nodes = get_bipartitions(reference_tree, headers)

# The replicates file holds ';'-terminated Newick strings: split on ';' and re-append the
# terminator. The `with` block closes the file instead of leaving it to the GC.
with open(newick_replicates) as replicates_fh:
    replicate_newicks = [f'{i};' for i in replicates_fh.read().split(';') if i.strip() != '']
# With zero replicates there is nothing to average support over.
assert replicate_newicks, f'no replicate trees found in {newick_replicates}'
replicate_trees = ((reference_bits, i, headers) for i in map(Tree, replicate_newicks))

with Pool(THREADS) as pool:
    support_bits = pool.starmap(get_support, replicate_trees)

# Percent of replicates recovering each reference bipartition, floored to an int.
# Floor-dividing the summed count keeps this exact: the former (100 * mean).astype(int)
# turned e.g. 145/500 into 28.999... and truncated it to 28 instead of 29.
# NOTE(review): exact only if get_support returns 0/1 (or bool) flags -- confirm.
n_replicates = len(support_bits)
support = (100 * np.array(support_bits).sum(0) // n_replicates).astype(int)

for percentage, node in zip(support, reference_nodes):
    node.support = percentage

# Every node not in reference_nodes gets support 0. Membership is checked by identity via a
# set of ids: O(1) per node instead of scanning the list for every node.
scored_ids = {id(node) for node in reference_nodes}
for node in reference_tree.traverse():
    if id(node) not in scored_ids:
        node.support = 0

with open(newick_vibe, 'w') as w:
    w.write(reference_tree.write())

# Draw VIBE tree

In [8]:
def draw_tree(newick, pdf_file, colors):
    """Render a tree in circular mode with branch-support labels and colored leaves.

    A leaf whose name contains a key of ``colors`` gets that key's value as its
    background color; if several keys match, the last matching key in ``colors``
    iteration order wins.

    Parameters
    ----------
    newick : str
        Newick input for ``ete3.Tree`` (called below with a path to a ``.newick`` file).
    pdf_file : str
        Output path handed to ``Tree.render``.
    colors : dict
        Mapping of name substring -> background color; may be empty.
    """
    t = Tree(newick)
    t.ladderize()

    def leaf_color(name):
        # Last matching key wins.
        matches = [key for key in colors if key in name]
        return colors[matches[-1]] if matches else None

    leaf_colors = {}
    for leaf in t.get_leaves():
        color = leaf_color(leaf.name)
        if color is not None:
            leaf_colors[leaf.name] = color

    for node in t.traverse():
        nstyle = NodeStyle()
        if node.name in leaf_colors:
            nstyle['bgcolor'] = leaf_colors[node.name]
        nstyle["size"] = 0
        nstyle["vt_line_width"] = 1
        nstyle["hz_line_width"] = 1
        # Attach once, after every field is filled in (the earlier extra call was redundant).
        node.set_style(nstyle)

    ts = TreeStyle()
    ts.mode = "c"
    ts.root_opening_factor = .45
    ts.show_branch_support = True
    t.render(pdf_file, tree_style=ts)

draw_tree(newick_vibe, pdf_tree_vibe, colors)