In [1]:
import os
import pickle

import numpy as np

import warnings
warnings.filterwarnings('ignore')

import torch
torch.set_default_dtype(torch.float64)
import geoopt

from utils import data
from utils.hyperbolic_utils import poincare_translation

import plotly.graph_objects as go

from jupyter_dash import JupyterDash
from dash import html, dcc, Input, Output

# Load data

In [None]:
# Input data folder and root folder of the per-configuration results.
data_path = "./data"
result_path = "./results"

tissue_hierarchy, ppi_networks, labels = data.read_data(data_path)

# Run configuration. NOTE: the insertion order of the keys is significant,
# because `config_name` below is assembled by walking the dict in order.
config = {
    # Geometry of the embedding space (pick exactly one of the alternatives).
    "manifold": geoopt.manifolds.Lorentz(k=1.0, learnable=False),
    #"manifold": geoopt.manifolds.PoincareBallExact(c = 1.0, learnable = False),
    #"manifold": None,
    "embedding_dim": 128,

    # Random-walk sampling settings (names look node2vec-style — confirm in utils.model).
    "p": 1,
    "q": 4,
    "walks_per_node": 1,
    "walk_length": 10,
    "context_size": 10,
    "num_negative_samples": 1,

    # Optimisation settings.
    "num_epochs": 100,
    "test_epochs": 5,
    "batch_size": 64,
    "learning_rate": 0.025,
    "lambda": 0.2,

    # Hardware: first CUDA device when available, otherwise CPU with 16 workers.
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "num_workers": 0 if torch.cuda.is_available() else 16,
}

# Folder name identifying this configuration: for every entry, the key followed
# by the first space-delimited word of the value's string form, with '.' and
# ':' stripped; entries are joined with '_'.
config_name = "_".join(
    [
        f"{key}{str(val).replace('.', '').replace(':', '').split(' ')[0]}"
        for key, val in config.items()
    ]
)

## Load 2D embeddings for visualization

In [3]:
# Read back the cached 2D embeddings written by the "Create and save" section.
# NOTE(review): pickle.load executes arbitrary code from the file; only open
# pickles that this notebook produced itself.
reduced_embeddings_path = os.path.join(result_path, config_name, "reduced_embeddings.pickle")
with open(reduced_embeddings_path, "rb") as pickle_file:
    embedding_dict = pickle.load(pickle_file)

## Create and save 2D embeddings for visualization

In [12]:
from utils import model
from utils.hyperbolic_utils import lorentz_to_poincare, lorentzian_distance_matrix, poincare_distance

import multiprocessing
from tqdm.notebook import tqdm

import umap

import sys
sys.path.append('...') #TODO: set path for Poincaré Maps https://github.com/facebookresearch/PoincareMaps/tree/main
from poincare_maps import PoincareMaps as PMAPS

In [None]:
# Rebuild the model for this configuration, load its saved state from
# `result_path`, and fetch the per-tissue embeddings.
OhmNet = model.TissueSpecificProteinEmbeddings(
    tissue_hierarchy, ppi_networks, labels, config, eval=False
)
OhmNet.load(result_path)

embedding_dict = OhmNet.get_embeddings()

# For every labelled tissue, store one integer label vector per GO function,
# ordered like that tissue's "entrez_ids".
for tissue, tissue_label_dict in labels.items():
    gene_order = embedding_dict[tissue]["entrez_ids"]
    embedding_dict[tissue]["labels"] = {
        function: np.array([label[gene] for gene in gene_order]).astype(int)
        for function, label in tissue_label_dict.items()
    }

# A tissue is a leaf when it has no outgoing edge in the tissue hierarchy.
out_degrees = tissue_hierarchy.out_degree()
for tissue in embedding_dict:
    embedding_dict[tissue]["is_leaf"] = out_degrees[tissue] == 0

# Keep each tissue's PPI network next to its embeddings so edges can be drawn.
for tissue, ppi_network in ppi_networks.items():
    embedding_dict[tissue]["ppi_network"] = ppi_network

def dimred(embeds):
    """Reduce one embedding matrix to two dimensions for plotting.

    The strategy depends on ``config["manifold"]`` and on the number of
    columns of ``embeds``:

    * Lorentz manifold with 3 columns: mapped to the Poincaré disk with
      ``lorentz_to_poincare``.
    * 2 columns: returned unchanged.
    * No manifold (``None``): Euclidean UMAP.
    * Otherwise: pairwise distances on the manifold are fed to Poincaré Maps.

    Args:
        embeds: 2D array-like with one row per embedded item; it is passed to
            ``torch.Tensor`` / UMAP, so presumably a numpy array — confirm
            against ``get_embeddings``.

    Returns:
        The 2D embedding of the rows of ``embeds``.

    Raises:
        ValueError: if the configured manifold is neither ``None``, Lorentz
            nor PoincareBallExact.
    """
    manifold = config["manifold"]
    # The first whitespace-delimited token of the manifold's string form is
    # used as its identifier ("None" when no manifold is configured).
    manifold_name = str(manifold).split(" ")[0]
    # Decide from the argument itself rather than from a global loop variable,
    # so the function is correct no matter who calls it.
    n_dims = embeds.shape[1]

    if manifold_name == "Lorentz" and n_dims == 3:
        curvature = manifold.k.detach().cpu()
        return lorentz_to_poincare(torch.Tensor(embeds), curvature).detach().cpu().numpy()

    if n_dims == 2:
        # Already planar: nothing to reduce.
        return embeds

    if manifold is None:
        reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, metric="euclidean", verbose=True, output_metric="euclidean")
        return reducer.fit_transform(embeds)

    if manifold_name == "PoincareBallExact(exact)":
        distances = poincare_distance(embeds)
    elif manifold_name == "Lorentz":
        distances = lorentzian_distance_matrix(torch.Tensor(embeds), torch.Tensor(embeds), manifold.k.detach().cpu()).detach().cpu().numpy()
    else:
        # Previously this fell through and failed with an unbound `distances`.
        raise ValueError("Unsupported manifold for dimensionality reduction: " + str(manifold))

    # NOTE(review): `cuda=1` is hard-coded; it presumably requires a GPU even
    # when config["device"] is the CPU — confirm against the Poincaré Maps code.
    return PMAPS.compute_poincare_maps(distances, 2, mode='features', normalize=False, n_pca=0,
                        distlocal='precomputed', k_neighbours=3, sigma=1.0, gamma=2.0,
                        epochs = 4000, batchsize=-1, lr=0.1, burnin=200, lrm=5.00, earlystop=0.0001, cuda=1)

"""model_embeddings = [embedding_dict[tissue]["embeds"] for tissue in embedding_dict.keys()]
with multiprocessing.Pool(30) as p:
   reduced_embeddings = p.map(dimred, model_embeddings)
for i, tissue in enumerate(embedding_dict.keys()):
    embedding_dict[tissue]["2D_embeds"] = reduced_embeddings[i]"""
# Reduce every tissue's embeddings to 2D, one tissue at a time, with a progress bar.
for tissue in tqdm(embedding_dict):
    tissue_entry = embedding_dict[tissue]
    tissue_entry["2D_embeds"] = dimred(tissue_entry["embeds"])

# Cache the enriched dictionary so the visualisation can be started without the model.
reduced_embeddings_file = os.path.join(result_path, config_name, "reduced_embeddings.pickle")
with open(reduced_embeddings_file, "wb") as pickle_file:
    pickle.dump(embedding_dict, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

# Plotly Dash

In [4]:
# Dash application object; the callbacks and the layout below are registered on it.
app = JupyterDash(__name__) 

In [5]:
# Dark colour shared by the edges, the marker outlines and the unit circle.
c = "#2b2d42"

@app.callback(
    Output("ppi_embedding", "figure"),
    Input('tissue', 'value'),
    Input('GO_function', 'value'),
    Input('ppi_embedding', 'clickData'),
    prevent_initial_call=True)
def plot_2D(tissue, GO_function, clickData):
    """Build the 2D scatter plot of one tissue's protein embeddings.

    Args:
        tissue: key of ``embedding_dict`` chosen in the tissue dropdown, or
            ``None`` when nothing is selected.
        GO_function: name of the GO function used to colour the nodes, or
            ``None`` for a uniform colour.
        clickData: plotly click payload of the graph. When the clicked node
            belongs to the current tissue the plot is re-centred on it
            (``poincare_translation`` if a manifold is configured, a plain
            shift otherwise).

    Returns:
        A ``plotly.graph_objects.Figure``; an empty figure when no known
        tissue is selected.
    """
    # Dropdown cleared (or a tissue we have no embeddings for): nothing to draw.
    if tissue is None or tissue not in embedding_dict:
        return go.Figure()

    tissue_data = embedding_dict[tissue]
    embeds = tissue_data["2D_embeds"]
    entrez_ids = tissue_data["entrez_ids"]
    hyperbolic = config["manifold"] is not None

    # One-off id -> row lookup; keeps the first occurrence, like list.index,
    # and avoids a linear scan of `entrez_ids` for every edge end point.
    row_of = {}
    for row, entrez_id in enumerate(entrez_ids):
        row_of.setdefault(entrez_id, row)

    # Re-centre on the clicked node. The click payload can be stale (it may
    # come from the previously displayed tissue) or carry no "text" at all,
    # so ids that are not part of this tissue are simply ignored.
    points = clickData.get('points') if clickData else None
    if points:
        selected_row = row_of.get(points[0].get("text"))
        if selected_row is not None:
            center = embeds[selected_row]
            if hyperbolic:
                embeds = poincare_translation(-center, embeds)
            else:
                embeds = embeds - center

    title = tissue + " embeddings"
    labels_by_function = tissue_data.get("labels", {})
    if GO_function is not None and GO_function in labels_by_function:
        label = labels_by_function[GO_function]
        title += " colored by " + GO_function
    else:
        # No (valid) GO function selected: every node gets the neutral colour.
        label = np.zeros(embeds.shape[0],)

    if "ppi_network" in tissue_data:
        ppi_network = tissue_data["ppi_network"]
        edge_x = []
        edge_y = []
        for edge in ppi_network.edges:
            # Node names have the form "<tissue>__<entrez id>".
            row0 = row_of.get(edge[0].split("__")[1])
            row1 = row_of.get(edge[1].split("__")[1])
            if row0 is None or row1 is None:
                continue
            # A trailing None breaks the line so that edges are not chained.
            edge_x += [embeds[row0, 0], embeds[row1, 0], None]
            edge_y += [embeds[row0, 1], embeds[row1, 1], None]
        traces = [go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.03, color=c),
            hoverinfo='none',
            mode='lines')]
        # Marker size grows with the node's degree in the PPI network.
        node_size = [5 + 0.025*ppi_network.degree[tissue + "__" + entrez_id] for entrez_id in entrez_ids]
    else:
        traces = []
        node_size = 5 * np.ones(embeds.shape[0],)

    palette = {0: "#8d99ae", 1: "#d90429"}
    traces.append(go.Scatter(
        x=embeds[:,0], y=embeds[:,1],
        mode='markers',
        hoverinfo='text',
        text = entrez_ids,
        marker=dict(
            color = [palette.get(value) for value in label],
            opacity = 1.0,
            size = node_size,
            line=dict(width=1, color=c),
        )))

    # Hidden axes with a 1:1 aspect ratio so the unit disk is drawn as a circle.
    hidden_axis = dict(visible=False, showgrid=False, zeroline=False, showticklabels=False)
    fig = go.Figure(
        data=traces,
        layout=go.Layout(
            # `title.font` replaces the deprecated `titlefont` alias.
            title={'text' : title, 'x':0.5, 'xanchor': 'center', 'font': {'size': 16}},
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20,l=5,r=5,t=40),
            height=750,
            plot_bgcolor='rgba(0, 0, 0, 0)',
            paper_bgcolor='rgba(0, 0, 0, 0)',
            xaxis=hidden_axis,
            yaxis=dict(hidden_axis, scaleanchor="x", scaleratio=1),
        ),
    )

    if hyperbolic:
        # Boundary of the unit disk.
        fig.add_shape(type="circle", x0=-1, y0=-1, x1=1, y1=1, line_color=c, line_width = 1)

    return fig

In [6]:
@app.callback(
    Output('tissue', 'options'),
    Input('hierarchy_level', 'value'),
    Input('with_go', 'value'),
    Input('with_ppi', 'value'),
    prevent_initial_call=True)
def set_tissue_options(hierarchy_level, with_go, with_ppi):
    """List the tissues offered in the tissue dropdown.

    Args:
        hierarchy_level: 'Leaf tissues' keeps tissues without outgoing edges
            in ``tissue_hierarchy``; any other value keeps the remaining ones.
        with_go: when truthy, keep only tissues that have a "labels" entry.
        with_ppi: when truthy, keep only tissues that have a "ppi_network" entry.

    Returns:
        The matching keys of ``embedding_dict``, in the dict's order.
    """
    want_leaf = hierarchy_level == 'Leaf tissues'
    out_degrees = tissue_hierarchy.out_degree()

    def keep(name):
        # Hierarchy level first, then the two optional filters.
        if (out_degrees[name] == 0) != want_leaf:
            return False
        entry = embedding_dict[name]
        if with_go and "labels" not in entry:
            return False
        if with_ppi and "ppi_network" not in entry:
            return False
        return True

    return [name for name in embedding_dict.keys() if keep(name)]

@app.callback(
    Output('tissue', 'value'),
    Input('tissue', 'options'),
    prevent_initial_call=True)
def set_tissue_value(available_options):
    """Select the first available tissue whenever the option list changes.

    Args:
        available_options: current options of the tissue dropdown; may be
            empty or ``None`` when the filters leave no tissue.

    Returns:
        The first option, or ``None`` (no selection) when there is none.
    """
    # Indexing an empty or missing option list used to raise inside the callback.
    if not available_options:
        return None
    return available_options[0]

In [7]:
@app.callback(
    Output('GO_function', 'options'),
    Input('tissue', 'value'),
    prevent_initial_call=True)
def set_GO_options(tissue):
    """Offer the GO functions for which the selected tissue has labels.

    Args:
        tissue: key of ``embedding_dict``, or ``None`` when no tissue is selected.

    Returns:
        The keys of the tissue's "labels" entry, or an empty list when no
        tissue is selected or the tissue has no labels.
    """
    if tissue is None:
        return []
    tissue_entry = embedding_dict[tissue]
    if "labels" not in tissue_entry:
        return []
    return list(tissue_entry["labels"].keys())
        
@app.callback(
    Output('GO_function', 'value'),
    Input('GO_function', 'options'),
    prevent_initial_call=True)
def set_GO_value(available_options):
    """Clear the GO-function selection whenever its option list changes.

    Args:
        available_options: the new options of the GO-function dropdown (unused).

    Returns:
        Always ``None``, i.e. no GO function selected.
    """
    return None

In [8]:
# Page layout: a narrow control column on the left (filters and dropdowns)
# and the embedding plot on the right.
# Style keys are camelCase ('flexDirection', 'textAlign'), as Dash expects.
app.layout = html.Div(style={'display': 'flex', 'flexDirection': 'column'}, children=[

    html.Div(style={'display': 'flex', 'width': '100%'}, children=[
        # Left column: controls feeding the callbacks defined above.
        html.Div(style={'display': 'inline-block', 'width': '20%'}, children=[
            # Empty heading, kept so the controls line up with the plot title.
            html.H2(children='', style={'textAlign': 'center'}),
            dcc.Checklist(id="with_go", options=['Tissues with GO function'], style={"textAlign":"center"}),
            html.Br(),
            dcc.Checklist(id="with_ppi", options=['Tissues with PPI network'], style={"textAlign":"center"}),
            html.Br(),
            dcc.Dropdown(id="hierarchy_level", placeholder="Select the hierarchy level", options=['Leaf tissues', 'Internal tissues']),
            html.Br(),
            dcc.Dropdown(id="tissue", placeholder="Select a tissue"),
            html.Br(),
            dcc.Dropdown(id="GO_function", placeholder="Select a GO function"),
            ]),
        # Right column: title and the interactive embedding plot.
        html.Div(style={'display': 'inline-block', 'width': '80%'}, children=[
            html.H2(children='Multi-layer tissue network embeddings', style={'textAlign': 'center'}),
            dcc.Graph(id = "ppi_embedding")
            ])
        ])
])

In [None]:
# Host and port the Dash server is started on; the host below is a placeholder
# that must be replaced before running this cell.
my_host = 'XXX.XXX.XXX.XXX' #TODO: set host number
my_port = '50000'
# Start serving the app defined above.
app.run_server(host = my_host, port = my_port) 