In [1]:
import glob
import plotly
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import Bio
import numpy as np
from sklearn import preprocessing

from Bio.SeqUtils.ProtParam import ProteinAnalysis as PA
from Bio.SeqUtils.IsoelectricPoint import IsoelectricPoint as IP

In [2]:
# Sanity check: print the CUDA devices that numba can see.
# NOTE(review): none of the cells visible here use `cuda` afterwards, and this
# import lives outside the notebook's main import cell -- confirm it is needed.
from numba import cuda
print(cuda.gpus)

<Managed Device 0>


In [3]:
# Map a protein identifier to the table read from its "*_heatmap_logits.csv"
# file. The identifier is the part of the file name before the first
# underscore; the first CSV column becomes the DataFrame index.
llr_dict = {
    csv_path.split('/')[-1].split('_')[0]: pd.read_csv(csv_path, index_col=[0])
    for csv_path in glob.glob(
        '/data/VariantAnalysis/esm1b_preprs/*/*_heatmap_logits.csv'
    )
}

In [14]:
# Map a protein identifier to the table read from its "*_umap.csv" file.
# The identifier is the part of the file name before the first underscore;
# the first CSV column becomes the DataFrame index.
embedding_dict = {
    csv_path.split('/')[-1].split('_')[0]: pd.read_csv(csv_path, index_col=[0])
    for csv_path in glob.glob(
        '/data/VariantAnalysis/vae_model/*/*_umap.csv'
    )
}

In [9]:
# Row order applied when _create_heatmap_trace reindexes a logits table.
# The one-letter codes are written K -> W and then reversed, so the final
# list runs W, F, Y, ..., H, R, K.
# NOTE(review): presumably reversed so that K is drawn at the top of the
# heatmap's y-axis -- confirm against the rendered figure.
aa_list = list('KRHEDNQTSCGAVLIMPYFW')[::-1]

In [10]:
def _create_fig():
    """Create the empty 11-row x 4-column subplot grid used by the notebook.

    Layout described by the specs:
      * a heatmap cell anchored at (row 1, col 1) spanning 4 rows and 4 columns;
      * a fully empty row 6;
      * a scatter cell anchored at (row 7, col 2) spanning 5 rows and 2 columns;
      * a number of default ``{}`` cells around those two anchors.

    NOTE(review): several of the ``{}`` cells sit inside the regions covered
    by the two spans. They are reproduced unchanged here -- confirm whether
    they are intentional.

    Returns:
        The figure returned by ``make_subplots``.
    """
    # Fresh dict objects are created for every cell (no shared instances).
    heatmap_band = [[{"type": "heatmap", "colspan": 4, 'rowspan': 4}, {}, {}, {}]]
    heatmap_band += [[{}, None, None, None] for _ in range(4)]

    empty_row = [[None, None, None, None]]

    scatter_band = [[None, {"type": "scatter", 'rowspan': 5, 'colspan': 2}, {}, {}]]
    scatter_band += [[None, {}, None, None] for _ in range(4)]

    grid_specs = heatmap_band + empty_row + scatter_band

    return make_subplots(
        rows=11,
        cols=4,
        shared_xaxes=False,
        shared_yaxes=False,
        vertical_spacing=.01,
        specs=grid_specs,
    )
    
def _create_heatmap_trace(dms_df, visible=False):
    """Build a ``go.Heatmap`` trace from a table of logits.

    The table's rows are first reindexed to the module-level ``aa_list``
    order. Column labels become the x values, row labels the y values, and
    the cell values the z values. Each cell gets a hover label showing its
    column label, row label and value rounded to two decimals.

    Args:
        dms_df: DataFrame whose index is reindexed with ``aa_list``.
        visible: Passed through to the trace's ``visible`` property.

    Returns:
        The configured ``go.Heatmap`` trace.
    """
    ordered_df = dms_df.reindex(aa_list)

    col_labels = list(ordered_df)
    row_labels = ordered_df.index.tolist()
    values = ordered_df.values

    hover_template = 'x-aa: {}<br />y-aa: {}<br />logit: {}'
    # One list of labels per heatmap row, matching the shape of `values`.
    hovertext = [
        [
            hover_template.format(col_label, row_label,
                                  str(round(float(values[row_idx][col_idx]), 2)))
            for col_idx, col_label in enumerate(col_labels)
        ]
        for row_idx, row_label in enumerate(row_labels)
    ]

    return go.Heatmap(
        z=values,
        x=col_labels,
        y=row_labels,
        hoverinfo='text',
        text=hovertext,
        colorscale='Viridis',
        visible=visible,
        reversescale=True,
        colorbar={'len': 0.25, 'y': .85},
    )

def _create_scatter_trace(embedding_df, visible=False):
    """Build a ``go.Scatter`` marker trace from an embedding table.

    Points are placed at (``umap1``, ``umap2``) and coloured by the
    ``zeroshot`` column. Every point gets a hover label listing its
    coordinates, name, position, alternate residue and scores, with the
    numeric fields rounded to two decimals.

    Args:
        embedding_df: DataFrame providing the columns ``umap1``, ``umap2``,
            ``name``, ``pos``, ``alt_aa``, ``zeroshot_raw``, ``zeroshot``
            and ``dist_from_WT``.
        visible: Passed through to the trace's ``visible`` property.

    Returns:
        The configured ``go.Scatter`` trace.
    """
    hover_columns = ['umap1', 'umap2', 'name', 'pos', 'alt_aa',
                     'zeroshot_raw', 'zeroshot', 'dist_from_WT']
    hover_template = ('x: {}<br />y: {}<br />name: {}<br />pos: {}'
                      '<br />alt_aa: {}<br />zeroshot_raw: {}'
                      '<br />zeroshot: {}<br />dist_from_WT: {}')

    def _two_dp(value):
        # Round to two decimals and render as text for the hover label.
        return str(round(value, 2))

    hovertext = [
        hover_template.format(_two_dp(u1), _two_dp(u2), name, pos, alt_aa,
                              _two_dp(raw_score), _two_dp(score), _two_dp(dist))
        for u1, u2, name, pos, alt_aa, raw_score, score, dist
        in embedding_df[hover_columns].values
    ]

    marker_style = dict(
        color=embedding_df['zeroshot'].tolist(),
        colorscale='Viridis',
        size=2,
        colorbar={'thickness': 20, 'len': 0.25, 'x': .8, 'y': .2},
    )

    return go.Scatter(
        x=embedding_df['umap1'],
        y=embedding_df['umap2'],
        hoverinfo='text',
        text=hovertext,
        mode='markers',
        visible=visible,
        marker=marker_style,
    )

def _update_layout(fig, buttons=None, template=None, height=1000, width=1000, title='Variant Fx'):
    fig.update_layout(updatemenus=[dict(active=0,buttons=buttons)], height=height, width=width, title_text=title)
    if template:
        for theme in template:
            fig.update_layout(template=theme)

In [16]:
# Start from the empty subplot grid; the cells below add the traces to it.
fig = _create_fig()

In [17]:
# P53 traces are created visible, so P53 is what the figure shows initially.
# add_trace(..., row=, col=) replaces the deprecated append_trace call.
p53_heatmap = _create_heatmap_trace(llr_dict['P53'], visible=True)
fig.add_trace(p53_heatmap, row=1, col=1)  # trace index 0

p53_umap = _create_scatter_trace(embedding_dict['P53'], visible=True)
fig.add_trace(p53_umap, row=7, col=2)  # trace index 1

In [None]:
# RASK traces use the default visible=False, so they are created hidden.
# add_trace(..., row=, col=) replaces the deprecated append_trace call.
rask_heatmap = _create_heatmap_trace(llr_dict['RASK'])
fig.add_trace(rask_heatmap, row=1, col=1)  # trace index 2

rask_umap = _create_scatter_trace(embedding_dict['RASK'])
fig.add_trace(rask_umap, row=7, col=2)  # trace index 3

In [21]:
# Templates handed to _update_layout, which applies them one after another.
# NOTE(review): presumably only the last entry ends up in effect -- confirm
# whether four entries are really needed.
template = ['plotly', 'simple_white', 'plotly', 'simple_white']

# One dropdown button per protein. Each mask has one flag per trace, in the
# order the traces were added: P53 heatmap, P53 scatter, RASK heatmap,
# RASK scatter.
buttons = [
    dict(label=protein_label, method="update",
         args=[{"visible": trace_mask},
               {"title": protein_label,
                "annotations": []}])
    for protein_label, trace_mask in (
        ("P53", [True, True, False, False]),
        ("RASK", [False, False, True, True]),
    )
]
_update_layout(fig, buttons, template)
# plotly.offline.plot(fig, filename='dms_v0.2.1.html') 

In [22]:
# Display the assembled figure.
fig.show()