In [None]:
# Work from the parent directory (presumably the repository root) so the relative
# "data/..." paths used throughout the notebook resolve; autoreload re-imports
# edited project modules automatically.
%cd ..
%load_ext autoreload
%autoreload 2

In [None]:
import ast
import json
import os
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from tqdm import tqdm

from mol_gen_docking.data.pdb_uniprot.target_naming import fetch_uniprot_id_from_pdbid

# Enable tqdm's pandas integration (``progress_apply`` and friends).
tqdm.pandas()

# Targets

In [None]:
def plot_all_cols(df, N_COLS, fig, escape=("smiles", "pdb_id")):
    """Draw a 100-bin histogram of every non-escaped column of ``df`` on ``fig``.

    Panels are laid out row by row on an outer grid with ``N_COLS`` rows
    (despite its name, ``N_COLS`` is the number of *rows*). Every row but the
    last holds ``ceil(n_features / N_COLS)`` panels; the last row's sub-grid is
    given ``n_per_row + (n_features - 1) % N_COLS`` slots.

    Args:
        df: DataFrame whose columns are plotted.
        N_COLS: number of rows of the outer grid.
        fig: matplotlib Figure to draw on.
        escape: column names to skip (identifiers, SMILES strings, ...).
    """
    features = [c for c in df.columns if c not in escape]
    if not features:
        # Nothing to plot (subgridspec would reject ncols=0).
        return
    n_features = len(features)

    outer_grid = fig.add_gridspec(N_COLS, 1, wspace=0.1, hspace=0.5)
    n_cols = (n_features - 1) // N_COLS + 1  # panels per row = ceil(n_features / N_COLS)
    grid = outer_grid[0].subgridspec(1, ncols=n_cols)

    i = 0  # panel index within the current row
    j = 0  # current row
    for col in features:
        # Add to the figure we were given, not whatever pyplot considers current.
        ax = fig.add_subplot(grid[i])
        sns.histplot(df[col], bins=100, ax=ax)
        ax.set_title(col)
        i += 1
        if i == n_cols and j != N_COLS - 1:
            i = 0
            j += 1
            if j == N_COLS - 1:
                grid = outer_grid[j].subgridspec(1, ncols=n_cols + (n_features - 1) % N_COLS)
            else:
                grid = outer_grid[j].subgridspec(1, ncols=n_cols)

# df = pd.read_csv("data/properties.csv", index_col=0)
# fig = plt.figure(figsize=(20, 10))
# plot_all_cols(df, 4, fig, escape=["smiles", "pdb_id"])

In [None]:
DATA_PATH = "data/molgendata"
pockets_file = os.path.join(DATA_PATH, "pockets_info.json")
with open(pockets_file) as f:
    data = json.load(f)
# Keep only each pocket's "metadata" dict; after transposing there is one row per pocket.
data = {pocket: entry["metadata"] for pocket, entry in data.items()}
target_info = pd.DataFrame.from_dict(data).T
# Pockets with no recorded origin are labelled "sair".
target_info["origin"] = target_info["origin"].fillna("sair")

# Product of the "size" entries divided by 1000 (presumably box edges in Angstrom: 1 nm^3 = 1000 A^3).
target_info["volume (nm$^3$)"] = target_info["size"].apply(lambda dims: np.prod(dims) / 1000)

def get_activity_val(row: pd.Series):
    """Return the row's average pIC50, falling back to its average pKd.

    Missing values can be NaN or None (the transposed metadata frame keeps
    object-dtype cells, e.g. JSON nulls), so ``pd.isna`` is used:
    ``np.isnan(None)`` raises TypeError.
    """
    if not pd.isna(row.avg_pIC50):
        return row.avg_pIC50
    return row.avg_pKd

def get_activity_label(row: pd.Series):
    """Return ``"pIC50"`` when the row has an average pIC50, else ``"pKd"``.

    Mirrors ``get_activity_val``; ``pd.isna`` handles both NaN and None
    (``np.isnan(None)`` raises TypeError).
    """
    if not pd.isna(row.avg_pIC50):
        return "pIC50"
    return "pKd"

# Attach the preferred activity value and its unit, then order rows by UniProt id (descending).
target_info = target_info.assign(**{
    "Average labeled activity": target_info.apply(get_activity_val, axis=1),
    "activity unit": target_info.apply(get_activity_label, axis=1),
}).sort_values(by="prot_id", ascending=False)
target_info

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(6, 2))

fig.subplots_adjust(wspace=0.4)
axes = axes.flatten()
# (x column, optional hue column) for each panel.
cols = [("volume (nm$^3$)", "origin"), ("Average labeled activity", "activity unit")]
palette = sns.color_palette("Paired", n_colors=2)[::-1]

for ax, col in zip(axes, cols):
    if len(col) == 1:
        # No hue variable: seaborn would ignore (and warn about) a palette here.
        sns.histplot(target_info, x=col[0], bins=20, ax=ax)
    else:
        sns.histplot(target_info, x=col[0], bins=20, ax=ax, hue=col[1], multiple="stack", palette=palette)


In [None]:
import requests

def get_info(uniprot_id):
    """Fetch a UniProtKB entry and summarise the fields used in this notebook.

    Args:
        uniprot_id: UniProtKB accession (the ``prot_id`` of a pocket).

    Returns:
        A dict with keys, in this order: ``molecular_func`` (names of the
        "Molecular function" keywords, or ``"unk"`` if the entry has no
        keywords), ``uniprot_score`` (annotation score, NaN if absent),
        ``organism``, ``proteinExistence``, ``organism_path`` (lineage list,
        NaN if absent) and ``len_lineage``. Returns ``None`` if the request or
        parsing fails for any reason, so the caller can retry.
    """
    try:
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json"
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            raise ValueError(f"UniProt ID {uniprot_id} not found.")
        data = response.json()
        organism = data.get("organism") or {}

        if "proteinExistence" in data:
            prt_ex = data["proteinExistence"]
        elif "inactiveReason" in data:
            # Inactive (deleted/merged) entries carry a reason instead of an
            # existence level and may lack an annotation score.
            reason = data["inactiveReason"]
            prt_ex = reason.get("deletedReason", reason.get("inactiveReasonType", "5: Uncertain"))
        else:
            prt_ex = "5: Uncertain"

        if "keywords" in data:
            mol_func = [
                kw["name"] for kw in data["keywords"]
                if kw.get("category") == "Molecular function"
            ]
        else:
            mol_func = "unk"

        # Key order matters: the target_info columns are built from it.
        return dict(
            molecular_func=mol_func,
            uniprot_score=data.get("annotationScore", np.nan),
            organism=organism.get("scientificName", "unk"),
            proteinExistence=prt_ex,
            organism_path=organism.get("lineage", np.nan),
            len_lineage=len(organism.get("lineage", [])),
        )
    except Exception:
        # Deliberately best-effort: HTTP/network/parse failures yield None and
        # are retried by the caller.
        return None

In [None]:
# Smoke-test get_info on one target. target_info is indexed by pocket key and
# was re-sorted, so select the first row by position, not by label 0.
get_info(target_info["prot_id"].iloc[0])

In [None]:
uniprot_ids = target_info["prot_id"].unique()
# The context manager shuts the worker processes down once every result has
# been consumed.
with Pool(64) as pool:
    fetched = list(tqdm(pool.imap(get_info, uniprot_ids), total=len(uniprot_ids)))
# imap yields results in input order, so they pair up with uniprot_ids.
infos = dict(zip(uniprot_ids, fetched))

In [None]:
failed = [uid for uid, info in infos.items() if info is None]
print(len(failed))
if failed:
    # Retry only the failures, with fewer workers (16 vs 64); the context
    # manager shuts the pool's processes down afterwards.
    with Pool(16) as pool:
        retried = list(tqdm(pool.imap(get_info, failed), total=len(failed)))
    # imap keeps input order, so results line up with `failed`.
    for uid, info in zip(failed, retried):
        if info is not None:
            infos[uid] = info


In [None]:
# Every target must be resolved by now; fail with the offending IDs rather than
# an opaque TypeError inside the lambda below.
unresolved = [uid for uid, info in infos.items() if info is None]
assert not unresolved, f"{len(unresolved)} UniProt lookups still failed: {unresolved[:10]}"
# All get_info records share the same keys (in the same order); add one column per key.
for key in next(iter(infos.values())):
    target_info[key] = target_info["prot_id"].apply(lambda pid: infos[pid][key])


In [None]:
# Cache the enriched table. List-valued columns (organism_path, molecular_func)
# are written as their string representation and decoded after reloading.
target_info.to_csv("data/tmp.csv")

In [None]:
# Reload the cached table. No index_col is given, so the index written by
# to_csv comes back as an extra "Unnamed: 0" column.
target_info_df = pd.read_csv("data/tmp.csv")
def decode_org(org_path):
    """Turn a list that was stringified by the CSV round-trip back into a list.

    ``to_csv`` writes list cells as their Python repr (e.g.
    ``"['Eukaryota', 'Metazoa']"``). ``ast.literal_eval`` is the exact inverse,
    whereas swapping quotes for ``json.loads`` breaks on names containing an
    apostrophe (the repr then uses double quotes). Other values, such as the
    ``"unk"`` placeholder or NaN, are returned unchanged.
    """
    if isinstance(org_path, str) and org_path != "unk" and org_path.startswith("["):
        return ast.literal_eval(org_path)
    return org_path

# Turn the stringified list columns back into Python lists.
for list_col in ("organism_path", "molecular_func"):
    target_info_df[list_col] = target_info_df[list_col].apply(decode_org)


In [None]:
# Inspect the reloaded table with its decoded list columns.
target_info_df

In [None]:
# Deepest lineage across all targets; get_lineage reads this global to decide
# how many lineage_<k> columns to build.
lineage_max = target_info_df.len_lineage.max()

def get_lineage(df, idx_forb):
    """Build a de-duplicated, one-row-per-path lineage table for a sunburst plot.

    Args:
        df: frame with ``prot_id``, ``organism``, ``organism_path`` (list of
            taxa, root first) and ``len_lineage`` columns. Not modified.
        idx_forb: positions in each ``organism_path`` to drop; the last entry
            of every path is always kept.

    Returns:
        A frame with ``leaf_reached``, ``lineage_0`` .. ``lineage_{lineage_max}``
        (None past the end of a path, and after the first node that is the
        last taxon of *any* path), ``org_path_hash`` and ``organism_count``
        (number of distinct ``prot_id`` sharing that path).

    Note:
        Reads the module-level ``lineage_max`` to size the lineage columns.
    """
    # Explicit copy: columns are added below, and assigning into a column
    # subset of ``df`` raises SettingWithCopyWarning.
    lineage_df = df[["prot_id", "organism", "organism_path", "len_lineage"]].copy()
    forbidden = set(idx_forb)
    lineage_df["organism_path"] = lineage_df["organism_path"].apply(
        lambda path: [taxon for i, taxon in enumerate(path) if i not in forbidden or i == len(path) - 1]
    )

    # Last taxon of every path. A path is cut once it reaches any of them
    # (presumably so no sunburst leaf also has children).
    leaf_set = {path[-1] for path in lineage_df["organism_path"] if len(path) > 0}
    lineage_df["leaf_reached"] = False
    for k in range(lineage_max + 1):
        col = f"lineage_{k}"
        lineage_df[col] = lineage_df.apply(
            lambda row: None if len(row["organism_path"]) <= k or row["leaf_reached"] else row["organism_path"][k],
            axis=1,
        )
        lineage_df["leaf_reached"] = lineage_df.apply(
            lambda row: row["leaf_reached"] or row[col] in leaf_set, axis=1
        )

    # Join with a separator so two different paths cannot concatenate to the same key.
    lineage_df["org_path_hash"] = lineage_df.apply(
        lambda row: "|".join(
            row[f"lineage_{k}"] for k in range(lineage_max) if row[f"lineage_{k}"] is not None
        ),
        axis=1,
    )

    count_organism = (
        lineage_df.groupby("org_path_hash").prot_id.nunique()
        .to_frame()
        .rename(columns={"prot_id": "organism_count"})
    )

    lineage_df = lineage_df.join(count_organism, on="org_path_hash")
    return lineage_df.drop(["prot_id", "organism_path", "organism", "len_lineage"], axis=1).drop_duplicates()

In [None]:
import warnings

# Drop lineage positions 1-5, 7, 8 and 10+, i.e. keep positions 0, 6 and 9
# (plus each path's last entry, which get_lineage always keeps).
dropped_lineage_idx = [1, 2, 3, 4, 5, 7, 8] + list(range(10, lineage_max + 1))
with warnings.catch_warnings():
    # Silence warnings for this call only, not for the rest of the session.
    warnings.simplefilter("ignore")
    lineage_df = get_lineage(target_info_df[~target_info_df.organism_path.isna()], dropped_lineage_idx)

# Use every lineage_<k> column get_lineage produced instead of a hard-coded 15,
# which raised KeyError when lineage_max < 14.
lineage_cols = [c for c in lineage_df.columns if c.startswith("lineage_")]
fig = px.sunburst(lineage_df, path=lineage_cols, values="organism_count", width=600, height=600)
fig.update_layout(
    font=dict(size=20),
)
fig.show()

In [None]:
# Count how many rows of target_info_df list each molecular-function keyword.
fn_count = {}
# Normalise every entry to a list (e.g. the "unk" placeholder becomes ["unk"]).
target_info_df.molecular_func = target_info_df.molecular_func.apply(lambda x: [x] if not isinstance(x, list) else x)
for mol_fns in target_info_df.molecular_func:
    for fn in mol_fns:
        fn_count[fn] = fn_count.get(fn, 0) + 1

# Heuristic taxonomy. Within each row's keyword list, sorted by descending
# global count, link the first keyword that has no child yet to the keyword
# after it. A key is only assigned once, so every parent gets exactly one child.
# The pass is repeated once per possible list length so that later passes can
# add further links.
max_len_func = target_info_df.molecular_func.apply(len).max()
taxonomy: dict[str,str] = {} # Parent -> Child
for iteration in tqdm(range(max_len_func)):
    for mol_fns in target_info_df.molecular_func:
        # Only lists with more than `iteration` keywords take part in this pass.
        if len(mol_fns) < iteration + 1:
            continue
        # Largest parent not already in taxonomy. Reversing a stable ascending
        # sort puts tied counts in reverse list order, so ties can link two
        # keywords to each other in different rows.
        ordered_mol_fns = sorted(mol_fns, key=fn_count.get)[::-1]
        for i in range(len(ordered_mol_fns)-1):
            if not ordered_mol_fns[i] in taxonomy:
                taxonomy[ordered_mol_fns[i]] = ordered_mol_fns[i+1]
                break

len(taxonomy), len(fn_count)

In [None]:
# Per-keyword row counts used as node values below (same dict object as fn_count).
counts = fn_count

In [None]:
# Invert the parent -> child taxonomy; if several parents share a child, the last one wins.
child_to_parent = {}
for parent, child in taxonomy.items():
    child_to_parent[child] = parent

def build_path(leaf):
    """Return the chain of ancestors of ``leaf`` ordered root -> leaf.

    Walks the module-level ``child_to_parent`` map upwards. The taxonomy
    heuristic can link keywords into a cycle (e.g. two equally frequent
    keywords pointing at each other), so the walk stops as soon as a node
    would repeat instead of looping forever.
    """
    path = [leaf]
    seen = {leaf}
    while path[-1] in child_to_parent:
        parent = child_to_parent[path[-1]]
        if parent in seen:
            break
        path.append(parent)
        seen.add(parent)
    return path[::-1]  # reverse to have root -> leaf

# One row per node in `counts`: its root -> node path followed by its count.
rows = [build_path(node) + [n] for node, n in counts.items()]

# Longest row length (path plus the count entry).
max_depth = max(map(len, rows))

# Pad each path with NaN, inserted just before the count, until every row has
# max_depth + 1 entries.
for row in rows:
    missing = max_depth + 1 - len(row)
    if missing > 0:
        row[-1:-1] = [np.nan] * missing

# Column names: level1 .. level{max_depth}, then count.
col_names = [*(f"level{d}" for d in range(1, max_depth + 1)), "count"]

df = pd.DataFrame(rows, columns=col_names)

def get_longest_path(df, value):
    """Return the max number of filled cells among rows whose ``level1`` is ``value``, minus 2.

    Every non-missing cell is counted, including the ``count`` column and any
    ``""`` placeholders.
    """
    group = df.loc[df["level1"] == value]
    filled_per_row = group.notna().sum(axis=1)
    return filled_per_row.max() - 2

# For rows shorter than the longest path under the same root, put "" in the
# first empty level after the last real node. The later hierarchy cell skips ""
# values.
for value in df.level1.unique():
    length = get_longest_path(df, value)
    for i in range(df.shape[0]):
        if df.at[i, "level1"] != value:
            continue
        for depth in range(2, length + 2):
            current = df.at[i, f"level{depth}"]
            previous = df.at[i, f"level{depth-1}"]
            if pd.isna(current) and pd.notna(previous) and previous != "":
                # .at writes into df itself; the former df[col].iloc[i] = ...
                # was chained assignment, a no-op under pandas Copy-on-Write.
                df.at[i, f"level{depth}"] = ""


df


In [None]:
import plotly.graph_objects as go
# Explicit (name, parent, value) hierarchy for px.sunburst built from the level
# columns of `df`: roots from level1, then up to MAX_SUNBURST_LEVEL.
parents = []
children = []
labels = []  # node values (keyword counts), despite the name

for val in df.level1.unique():
    parents.append("")
    children.append(val)
    labels.append(counts[val])

MAX_SUNBURST_LEVEL = 4
for lvl in range(2, MAX_SUNBURST_LEVEL + 1):
    col = f"level{lvl}"
    if col not in df.columns:
        break
    for val in df[col].unique():
        # Skip padding: NaN (checked by value, not identity) and "" placeholders.
        if pd.isna(val) or val == "":
            continue
        parent = df.loc[df[col] == val, f"level{lvl-1}"].iloc[0]
        parents.append(parent)
        children.append(val)
        labels.append(counts[val])


fig = px.sunburst(
    names=children,
    parents=parents,
    values=labels, width=800, height=800
)
fig.update_layout(
    font=dict(size=20),
)

fig.show()

In [None]:
# Two-level view of the molecular-function hierarchy.
# NOTE(review): this cell called an undefined name `t(`; px.sunburst (used
# elsewhere in this notebook) is assumed here -- confirm the intended chart type.
fig = px.sunburst(
    df,
    path=[f"level{i+1}" for i in range(2)],
    values="count",
    width=800,
    height=800,
)
fig.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
# Existence levels look like "<n>: <label>". Anything else (e.g. the reason text
# stored for an inactive entry, or a missing value) is bucketed as "5: Uncertain".
target_info_df["Protein existence"] = target_info_df.proteinExistence.apply(
    lambda x: x if isinstance(x, str) and ":" in x else "5: Uncertain"
)

target_info_df = target_info_df.rename(columns={"uniprot_score": "Annotation Score"})
target_info_df["Annotation Score"] = pd.Categorical(target_info_df["Annotation Score"], [1., 2., 3., 4., 5.])

# Map existence levels 1..5 onto the coolwarm colormap.
coolwarm = sns.color_palette("coolwarm", as_cmap=True)
palette = {
    k: coolwarm((int(k.split(":")[0]) - 1) / 4) for k in target_info_df["Protein existence"].unique()
}

sns.histplot(
    data=target_info_df.sort_values("Protein existence", ascending=True),
    x="Annotation Score",
    hue="Protein existence",
    multiple="stack",
    ax=ax,
    palette=palette,
    binwidth=.2,
)
ax.set_xlabel("")
sns.move_legend(ax, "upper left", ncol=1, fontsize=9)
ax.tick_params(axis="x", rotation=90)

fig.savefig("target_mol_fn.png", dpi=300, bbox_inches="tight")

In [None]:
# Distinct keyword combinations. Cells hold lists (unhashable), so convert them
# to tuples before unique().
target_info_df.molecular_func.apply(lambda fns: tuple(fns) if isinstance(fns, list) else fns).unique().tolist()

# Prompts

In [None]:
from mol_gen_docking.data.pydantic_dataset import read_jsonl
from pathlib import Path

def load(path:str):
    """Read a JSONL prompt file and return the ``meta`` of each record's first conversation."""
    records = read_jsonl(Path(path))
    metas = []
    for record in records:
        metas.append(record.conversations[0].meta)
    return metas

data_dir = "data/molgendata"
with open(os.path.join(data_dir, "pockets_info.json")) as fh:
    pockets = json.load(fh)
with open(os.path.join(data_dir, "docking_targets.json")) as fh:
    docking_targets = json.load(fh)

# Collect prompt metadata from every *.jsonl directly in data_dir and one level
# below it, keyed by the file name up to its first ".".
data = {}
for entry in os.listdir(data_dir):
    entry_path = os.path.join(data_dir, entry)
    if entry.endswith(".jsonl"):
        data[entry.split(".")[0]] = load(Path(entry_path))
    if not os.path.isdir(entry_path):
        continue
    for name in os.listdir(entry_path):
        if name.endswith(".jsonl"):
            data[name.split(".")[0]] = load(Path(os.path.join(entry_path, name)))


In [None]:
def get_df(data_d):
    """Flatten prompt-metadata records into one row per property.

    Args:
        data_d: records accepted by ``pd.DataFrame`` (e.g. a list of dicts)
            whose ``properties``, ``objectives`` and ``target`` fields are
            equal-length lists per record.

    Returns:
        The exploded frame with a fresh RangeIndex, without the
        ``docking_metadata`` column, and with a boolean ``is_docking`` column
        telling whether each property is in the module-level ``docking_targets``.
    """
    df = pd.DataFrame(data_d)
    # errors="ignore": records without docking metadata must not raise KeyError.
    df = df.drop(columns=["docking_metadata"], errors="ignore")
    df = df.explode(["properties", "objectives", "target"]).reset_index(drop=True)
    df["is_docking"] = df["properties"].apply(lambda prop: prop in docking_targets)

    return df


In [None]:
# Train-split prompt metadata, exploded by get_df to one row per property.
df = get_df(data["train_prompts"])
def transf(x):
    """Map any pocket id (a key of the module-level ``pockets``) to ``"docking"``; return other properties unchanged."""
    return "docking" if x in pockets else x

# Collapse every pocket target into the single reward type "docking".
df = df.assign(reward_type=df["properties"].apply(transf))

df

In [None]:
# One bar per reward type. Passing hue=x is what makes seaborn apply the
# palette (without a hue variable it ignores `palette` and warns).
ax = sns.histplot(
    df,
    x="reward_type",
    hue="reward_type",
    palette=sns.color_palette("colorblind"),
    legend=False,
)
ax.tick_params(axis="x", labelrotation=90);

In [None]:
# Share of each reward type for every number of properties per prompt.
ax = sns.histplot(
    data=df,
    x="n_props",
    hue="reward_type",
    multiple="fill",
    alpha=1,
    palette=sns.color_palette("colorblind"),
)
plt.xticks(rotation=90)