In [1]:
# Load the precomputed uniprot-id -> embedding lookup from disk.
import pickle
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('data/uniprotid_to_emb.pkl', 'rb') as pkl_file:
    embeddings = pickle.load(pkl_file)
# Sanity check: show the shape of a single entry.
embeddings['P00533'].shape

torch.Size([1, 768])

In [2]:
import umap
from sklearn.preprocessing import StandardScaler

# Collect the per-protein embeddings into parallel lists: entry i of
# embeddings_list belongs to uniprot_ids[i].
uniprot_ids = list(embeddings)
embeddings_list = [
    # detach()/cpu() keep the conversion working even if a tensor still tracks
    # gradients or lives on an accelerator; flatten() makes each vector 1-D.
    embeddings[uniprot_id].detach().cpu().numpy().flatten()
    for uniprot_id in uniprot_ids
]

# Standardize every embedding dimension before the projection.
scaler = StandardScaler()
scaled_embeddings = scaler.fit_transform(embeddings_list)

# UMAP is stochastic: pin the seed so a fresh kernel reproduces the same layout.
# (With a fixed random_state umap-learn gives up parallelism for determinism.)
UMAP_RANDOM_STATE = 42
reducer = umap.UMAP(n_components=3, random_state=UMAP_RANDOM_STATE)
embeddings_umap = reducer.fit_transform(scaled_embeddings)

# Map each uniprot id to its 3-component UMAP coordinates.
umap_dict = dict(zip(uniprot_ids, embeddings_umap))
umap_dict['P00533']

array([5.416897 , 5.5928226, 4.584672 ], dtype=float32)

In [13]:
import pandas as pd
# Read the table with the uniprot_id, ligand_smiles and pKi/pKd/pIC50/pEC50 columns.
data_path = 'data/data_human_reduced.parquet'
df = pd.read_parquet(data_path)
df.head()

Unnamed: 0,uniprot_id,ligand_smiles,pKi,pKd,pIC50,pEC50
0,P00533,*#Cc1cnc2c(Cl)cc(NC(C3=CN(C4CC4)NN3)c3ccc(F)nc...,,,7.021819,
1,P52333,*.C=CC(=O)N1CC(Nc2ncnc3[nH]cc(Cl)c23)CCC1C.S,,,7.811072,
2,Q02750,*.CC(O)CONC(=O)c1c2c(c(=O)n(C)c1Nc1ccc(I)cc1F)...,,,5.999957,
3,P34896,*.CCC(C)C(NC(=O)Cc1csc(-c2ccc(Cl)cc2Cl)n1)C(=O...,,,3.69897,
4,P34897,*.CCC(C)C(NC(=O)Cc1csc(-c2ccc(Cl)cc2Cl)n1)C(=O...,,,3.69897,


In [14]:
# Group the table by ligand: every SMILES string maps to the list of uniprot ids
# it appears with, one entry per row and in row order.
ligand_to_uniprot_ids = {}
for smiles, protein_id in zip(df['ligand_smiles'], df['uniprot_id']):
    ligand_to_uniprot_ids.setdefault(smiles, []).append(protein_id)

In [15]:
# Keep only ligands whose uniprot-id list has more than MIN_IDS_PER_LIGAND
# entries. The list holds one entry per table row, so a uniprot id that occurs
# in several rows is counted once per row.
MIN_IDS_PER_LIGAND = 300
ligand_to_uniprot_ids = {
    ligand: uniprot_ids
    for ligand, uniprot_ids in ligand_to_uniprot_ids.items()
    if len(uniprot_ids) > MIN_IDS_PER_LIGAND
}

In [40]:
import numpy as np
import plotly.graph_objects as go

def plot_and_save(ligand, df, umap_dict, param, min_points=6):
    """Build a 3-D UMAP scatter of the proteins measured against one ligand.

    Despite its name, this function writes nothing to disk: it returns the
    figure serialized as JSON and leaves saving to the caller.

    Args:
        ligand: Value matched against ``df['ligand_smiles']`` to select rows.
        df: DataFrame with ``ligand_smiles``, ``uniprot_id`` and ``param``
            columns.
        umap_dict: Mapping from uniprot id to an array with at least three
            coordinates (used as x, y and z).
        param: Name of the column whose values colour the markers.
        min_points: Minimum number of proteins needed to build a figure.

    Returns:
        The Plotly figure as a JSON string, or ``None`` when fewer than
        ``min_points`` proteins have a non-null ``param`` value.

    Raises:
        KeyError: If a selected uniprot id is missing from ``umap_dict``.
    """
    # Select the ligand's rows from ``df`` itself rather than from a
    # notebook-level lookup, so the result depends only on the arguments.
    ligand_rows = df.loc[df['ligand_smiles'] == ligand, ['uniprot_id', param]]

    # One point per protein: discard rows without a measurement first, then
    # keep the first remaining row of each uniprot id. A protein that occurs in
    # several rows is therefore plotted once, using its first non-null value.
    measured = (
        ligand_rows
        .dropna(subset=[param])
        .drop_duplicates(subset='uniprot_id', keep='first')
    )

    if len(measured) < min_points:
        return None

    hover_texts = measured['uniprot_id'].tolist()
    colors = measured[param].tolist()
    coords = np.array([umap_dict[protein_id] for protein_id in hover_texts])

    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(
            size=5,
            color=colors,
            colorscale='Viridis',
            opacity=0.8,
            # Label the colour scale with the plotted measure so figures for
            # different params can be told apart.
            colorbar=dict(title=param)
        ),
        text=hover_texts,
        hoverinfo='text'
    )])
    fig.update_layout(scene=dict(
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        zaxis_title='UMAP 3'
    ))
    return fig.to_json()

In [41]:
import os
from tqdm import tqdm

def _plot_filename(ligand):
    """Return ``ligand`` with path separators percent-encoded for use as a file name."""
    # SMILES strings can contain '/' and '\\' (bond-direction markers). Left
    # as-is, open() would treat them as directory separators. Names without a
    # separator are returned unchanged.
    name = ligand
    for sep in filter(None, (os.sep, os.altsep)):
        name = name.replace(sep, f'%{ord(sep):02X}')
    return name


PARAMS = ['pKi', 'pKd', 'pIC50', 'pEC50']

# The output directories do not depend on the ligand, so create them once.
for param in PARAMS:
    os.makedirs(f'plots/{param}', exist_ok=True)

# Write one JSON figure per (ligand, param) pair for which a figure was built.
for ligand in tqdm(ligand_to_uniprot_ids):
    ligand_file = _plot_filename(ligand)
    for param in PARAMS:
        fig_json = plot_and_save(ligand, df, umap_dict, param)
        if fig_json:
            with open(f'plots/{param}/{ligand_file}.json', 'w') as f:
                f.write(fig_json)

100%|██████████| 75/75 [01:02<00:00,  1.20it/s]
