In [None]:
# Move the kernel's working directory one level up.
# NOTE(review): presumably so that `notebooks.utils` and the relative `data/...`
# paths used below resolve from the repository root — confirm.
%cd ..
# Load the autoreload extension; mode 2 reloads modules before running code.
%load_ext autoreload
%autoreload 2

In [None]:
from notebooks.utils import *

import pandas as pd
from mol_gen_docking.data.target_naming import fetch_uniprot_id_from_pdbid
from tqdm import tqdm
from multiprocessing import Pool
import plotly.express as px

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Enable tqdm's pandas integration.
tqdm.pandas()

In [None]:
def plot_all_cols(df, N_COLS, fig, escape=("smiles", "pdb_id")):
    """Draw one 100-bin histogram per column of ``df`` onto ``fig``.

    Panels are filled row by row on a grid of ``N_COLS`` rows (despite its
    name, ``N_COLS`` is the number of grid *rows*). Every row is split into
    ``ceil(n_plots / N_COLS)`` equally wide slots, so trailing slots of the
    last used row may stay empty.

    Args:
        df: DataFrame whose columns are plotted.
        N_COLS: Number of rows of the panel grid.
        fig: Matplotlib figure that receives the axes.
        escape: Column names that must not be plotted.

    Returns:
        None. The axes are added to ``fig``; nothing is drawn when every
        column is listed in ``escape``.
    """
    outer_grid = fig.add_gridspec(N_COLS, 1, wspace=0.1, hspace=0.5)

    plotted = [col for col in df.columns if col not in escape]
    if not plotted:
        # A sub-grid with zero columns is invalid, so there is nothing to lay out.
        return

    # Ceiling division: enough slots per row to fit everything in N_COLS rows.
    per_row = (len(plotted) - 1) // N_COLS + 1

    row_grids = {}
    for position, col in enumerate(plotted):
        row, slot = divmod(position, per_row)
        if row not in row_grids:
            row_grids[row] = outer_grid[row].subgridspec(1, per_row)
        # Create the axes on the figure we were given, not on pyplot's
        # "current" figure, which may be a different one.
        ax = fig.add_subplot(row_grids[row][slot])
        sns.histplot(df[col], bins=100, ax=ax)
        ax.set_title(col)

# df = pd.read_csv("data/properties.csv", index_col=0)
# fig = plt.figure(figsize=(20, 10))
# plot_all_cols(df, 4, fig, escape=["smiles", "pdb_id"])

In [None]:
DATA_PATH = "data/mol_orz"

# Load the pocket table without its "center" and "size" columns, and clamp
# the pocket score to the [0, 1] interval.
target_info_df = pd.read_csv(f"{DATA_PATH}/pockets_info.csv")
target_info_df = target_info_df.drop(columns=["center", "size"])
target_info_df["pocket score"] = target_info_df["pocket score"].clip(0, 1)

# Columns left out of the histogram overview.
not_plotted = ["pocket_id", "pdb_id", "mean b-factor of pocket residues", "pocket score"]
cols = [name for name in target_info_df.columns if name not in not_plotted]

# One 20-bin histogram per remaining column on a fixed 3x4 grid; `zip` stops
# at the shorter of the two sequences.
fig, axes = plt.subplots(3, 4, figsize=(10, 5))
axes = axes.flatten()
for ax, col in zip(axes, cols):
    sns.histplot(target_info_df[col], bins=20, ax=ax)

fig.tight_layout()

In [None]:
# Histogram of the `pocket_id` column with 20 bins.
sns.histplot(target_info_df["pocket_id"], bins=20,)

In [None]:
pdb_ids = target_info_df["pdb_id"].unique()

# Resolve every distinct PDB id with 16 worker processes. The `with` block
# terminates the workers when it exits; `del pool` only dropped the name and
# left the processes to be cleaned up whenever the object got collected.
with Pool(16) as pool:
    uniprot_for_pdb = list(
        tqdm(
            pool.imap(fetch_uniprot_id_from_pdbid, pdb_ids),
            total=len(pdb_ids),
        )
    )

# `imap` yields results in input order, so ids and results pair up positionally.
infos = dict(zip(pdb_ids, uniprot_for_pdb))


In [None]:
# Look up each row's PDB id in `infos` (a missing id raises KeyError).
target_info_df["uniprot_id"] = target_info_df["pdb_id"].map(lambda pdb_id: infos[pdb_id])

In [None]:
import requests

def get_info(uniprot_id):
    """Fetch a UniProtKB entry and extract the fields used in this notebook.

    Args:
        uniprot_id: UniProt accession interpolated into the REST URL.

    Returns:
        dict with the keys
        ``molecular_func`` (name of the last keyword whose category is
        "Molecular function", or ``"unknown"`` when there is none),
        ``uniprot_score`` (the entry's ``annotationScore``),
        ``organism`` (scientific name), ``proteinExistence``,
        ``organism_path`` (the organism lineage list) and ``len_lineage``
        (length of that list).

    Raises:
        ValueError: If the request does not answer with HTTP 200.
    """
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json"
    response = requests.get(url, timeout=10)
    if response.status_code != 200:
        raise ValueError(f"UniProt ID {uniprot_id} not found.")
    data = response.json()

    organism = data["organism"]
    lineage = organism["lineage"]
    out = dict(
        molecular_func="unknown",
        uniprot_score=data["annotationScore"],
        organism=organism["scientificName"],
        proteinExistence=data["proteinExistence"],
        organism_path=lineage,
        len_lineage=len(lineage),
    )

    # Tolerate payloads without a "keywords" field instead of raising KeyError;
    # such entries keep the "unknown" default.
    for kyw in data.get("keywords", []):
        if kyw.get("category") == "Molecular function":
            out["molecular_func"] = kyw["name"]
    return out

In [None]:
# Try `get_info` on the first UniProt id of the table. The id is read from
# `target_info_df` because `uniprot_ids` is only defined in a later cell, so
# referencing it here raised NameError on a fresh top-to-bottom run.
get_info(target_info_df["uniprot_id"].iloc[0])

In [None]:
uniprot_ids = target_info_df["uniprot_id"].unique()

# Query UniProt for every distinct id with 16 worker processes. The `with`
# block terminates the workers on exit; `del pool` alone never closed the pool.
with Pool(16) as pool:
    uniprot_records = list(
        tqdm(
            pool.imap(get_info, uniprot_ids),
            total=len(uniprot_ids),
        )
    )

# NOTE: this rebinds `infos`, which an earlier cell used for the
# PDB id -> UniProt id mapping; from here on it maps UniProt id -> info dict.
infos = dict(zip(uniprot_ids, uniprot_records))


In [None]:
# The first record in `infos` decides which fields become columns.
first_record = list(infos.values())[0]
for k in list(first_record.keys()):
    target_info_df[k] = target_info_df["uniprot_id"].apply(lambda uid: infos[uid][k])

In [None]:
# Rearrange the lineage paths
import numpy as np

# Largest `len_lineage` value in the table; `get_lineage` reads this global.
lineage_max = target_info_df["len_lineage"].max()

def _truncate_at_leaf(path, leaves, n_levels):
    """Spread ``path`` over ``n_levels`` slots, blanking every slot after the first leaf name."""
    levels = []
    leaf_reached = False
    for k in range(n_levels):
        if leaf_reached or k >= len(path):
            levels.append(None)
        else:
            levels.append(path[k])
            leaf_reached = path[k] in leaves
    return levels, leaf_reached


def get_lineage(df, idx_forb, max_depth=None):
    """Build a de-duplicated table of truncated organism lineages with counts.

    Each ``organism_path`` is first thinned by dropping the positions listed
    in ``idx_forb`` (the last position of a path is always kept). The thinned
    path is then written to ``lineage_0 .. lineage_{depth}`` columns and cut
    right after the first name that is the last element of *any* thinned
    path, so that no row continues below another row's leaf.

    Args:
        df: DataFrame with the columns ``pdb_id``, ``organism``,
            ``organism_path`` (list of lineage names) and ``len_lineage``.
            It is not modified.
        idx_forb: Positions to remove from every lineage path.
        max_depth: Number of lineage levels used for ``org_path_hash``;
            ``max_depth + 1`` lineage columns are created. Defaults to the
            module-level ``lineage_max``.

    Returns:
        DataFrame with the columns ``leaf_reached``, ``lineage_0`` ..
        ``lineage_{depth}`` (``None`` past the end of a path),
        ``org_path_hash`` (concatenated names of the first ``depth`` levels)
        and ``organism_count`` (number of distinct ``pdb_id`` values sharing
        that hash), with duplicate rows removed.
    """
    depth = lineage_max if max_depth is None else max_depth
    n_levels = depth + 1
    forbidden = set(idx_forb)

    paths = [
        [name for i, name in enumerate(path) if i not in forbidden or i == len(path) - 1]
        for path in df["organism_path"]
    ]
    # Set membership instead of a list scan per row and level; empty paths
    # simply contribute no leaf instead of raising IndexError.
    leaves = {path[-1] for path in paths if path}

    truncated = [_truncate_at_leaf(path, leaves, n_levels) for path in paths]
    level_rows = [levels for levels, _ in truncated]
    hashes = [
        "".join(name for name in levels[:depth] if name is not None)
        for levels in level_rows
    ]

    # Build a fresh frame (object dtype keeps the None markers) rather than
    # assigning columns onto a slice of `df`.
    lineage_df = pd.DataFrame(
        level_rows,
        index=df.index,
        columns=[f"lineage_{k}" for k in range(n_levels)],
        dtype=object,
    )
    lineage_df.insert(0, "leaf_reached", [reached for _, reached in truncated])
    lineage_df["org_path_hash"] = hashes

    count_organism = (
        pd.DataFrame({"org_path_hash": hashes, "pdb_id": df["pdb_id"].tolist()})
        .groupby("org_path_hash")["pdb_id"]
        .nunique()
    )
    lineage_df["organism_count"] = lineage_df["org_path_hash"].map(count_organism)

    return lineage_df.drop_duplicates()

In [None]:
# Positions removed from every lineage path before plotting (the last element
# of a path is always kept by `get_lineage`) — presumably to keep the sunburst
# shallow; confirm the choice of ranks.
lineage_df = get_lineage(target_info_df, [1,2,3,5,6,7,8,] + list(range(10,lineage_max+1)))

# Use whichever lineage columns `get_lineage` produced instead of a fixed
# `range(15)`, which fails when fewer than 15 levels exist.
lineage_cols = [c for c in lineage_df.columns if c.startswith("lineage_")]
fig = px.sunburst(lineage_df, path=lineage_cols, values='organism_count', width=600, height=600)
fig.update_layout(uniformtext=dict(minsize=20, ))
fig.show()

In [None]:
# Order the molecular-function categories by descending frequency and store
# the column as an ordered categorical so the plot follows that order.
func_counts = target_info_df.molecular_func.value_counts()
order = func_counts.sort_values()[::-1].index.tolist()
target_info_df["molecular_func"] = pd.Categorical(target_info_df["molecular_func"], order)

# The renamed column doubles as the legend title of the hue.
target_info_df.rename(columns={"uniprot_score": "Annotation Score"}, inplace=True)

fig, ax = plt.subplots(1, 1, figsize=(4, 10))
sns.histplot(
    data=target_info_df,
    y="molecular_func",
    hue="Annotation Score",
    multiple="stack",
    ax=ax,
    palette=sns.color_palette("coolwarm_r", as_cmap=True),
)
# ax.tick_params(axis="x", rotation=90)
ax.set_ylabel("")
plt.savefig("target_mol_fn.png", dpi=300, bbox_inches="tight")