In [None]:
import json

import matplotlib.pyplot as plt
import pandas as pd
import umap as umap
import umap.plot
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D

In [None]:
def create_df(
    fp_json_name: str,
    classes_csv: str = "../data/iphos_multiclass.csv",
) -> pd.DataFrame:
    """Load a fingerprint JSON file and attach the class labels to each molecule.

    The JSON file is read as a mapping whose keys become the DataFrame index
    and whose values become the row contents. Three hard-coded SMILES are
    removed from the index, then the ``family`` and ``y1`` columns from the
    class CSV (indexed by its ``m1`` column) are left-joined on.

    Args:
        fp_json_name: Path to the fingerprint JSON file.
        classes_csv: Path to the CSV holding the ``m1``, ``family`` and ``y1``
            columns. Defaults to the location the notebook originally used.

    Returns:
        A DataFrame with the fingerprint columns followed by ``family`` and
        ``y1``. Rows whose index has no match in the class CSV get NaN labels.
    """
    smiles_to_remove = [
        "CCCCCCCCCCCCCC(OC(COC(CCCCCCCCCCCCC)=O)COCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOC)=O",
        "CC(C)CCC[C@@H](C)[C@H]1CC[C@@]2([H])[C@]3([H])CC=C4C[C@@H](O)CC[C@]4(C)[C@@]3([H])CC[C@]12C",
        "CCCCCCCCCCCCCCCCCCN(C)CCCCCCCCCCCCCCCCCC",
    ]

    # Context manager so the file handle is closed deterministically.
    with open(fp_json_name, encoding="utf-8") as fp_file:
        fp_dict = json.load(fp_file)
    classes = pd.read_csv(classes_csv).set_index("m1")

    df = pd.DataFrame.from_dict(fp_dict, orient="index")

    # errors="ignore" drops whichever of the SMILES are present. The previous
    # try/except around a plain drop() removed *none* of them as soon as a
    # single one was missing from the index.
    df = df.drop(index=smiles_to_remove, errors="ignore")

    return df.join(classes[["family", "y1"]])

In [None]:
# Fingerprint file stems, one inner list per figure row. The plotting cells
# load "<stem>.json" for each entry and draw it at ax[row][column].
embeddings = [
    ["mol2fp_cfp", "mol2fp_expert", "mol2fp_gcn"],
    ["mol2fp_grover", "mol2fp_grover_large"],
    ["mol2fp_MegaMB_base_iphos", "mol2fp_MegaMB_finetuned_iphos"],
]

# Panel titles, laid out in the same row/column positions as `embeddings`.
naming = [
    ["CFP", "Expert", "GCN"],
    ["Grover", "GroverLarge"],
    ["MMB", "MMB-FT"],
]

In [None]:
# UMAP projection of every fingerprint set, coloured by the `family` column.

# Fixed seed so Restart & Run All reproduces the same layout (UMAP is stochastic).
UMAP_SEED = 42
# Non-feature columns that must not reach UMAP. create_df() joins "family" and
# "y1"; "y1_text" is listed only for tolerance (errors="ignore" below) because
# the previous version of this cell tried to drop it and raised KeyError.
LABEL_COLUMNS = ["family", "y1", "y1_text"]

fig, ax = plt.subplots(3, 3, figsize=(12, 12))
# Rows 1 and 2 hold only two embeddings each, so hide their unused third panel.
ax[1, 2].axis("off")
ax[2, 2].axis("off")

for i, group in enumerate(embeddings):
    for j, name in enumerate(group):
        df_classes = create_df(f"{name}.json")

        # Project the fingerprint columns only: leaving "y1" in (as before)
        # leaks a label into the embedding.
        features = df_classes.drop(columns=LABEL_COLUMNS, errors="ignore")
        embedding = umap.UMAP(random_state=UMAP_SEED).fit_transform(features)

        # Colour each point by family; keep the handle for the shared legend.
        scatter = ax[i][j].scatter(
            x=embedding[:, 0], y=embedding[:, 1], c=df_classes.family
        )
        ax[i][j].set_title(naming[i][j])

# One legend for the whole figure, built from the last panel's scatter and
# attached to the current axes; bbox_to_anchor places it outside that axes.
plt.legend(
    *scatter.legend_elements(), loc="upper right", title="Family", bbox_to_anchor=(1.5, 3.5)
)
plt.show()

In [None]:
# UMAP projection of every fingerprint set, coloured by the `y1` activity class.

# Fixed seed so Restart & Run All reproduces the same layout (UMAP is stochastic).
UMAP_SEED = 42
# Non-feature columns that must not reach UMAP. create_df() joins "family" and
# "y1"; "y1_text" is listed only for tolerance (errors="ignore" below) because
# dropping it unconditionally raised KeyError.
LABEL_COLUMNS = ["family", "y1", "y1_text"]
# Legend captions, applied by position to the legend entries.
# NOTE(review): assumes the last panel contains all four y1 classes in
# ascending order -- confirm against the data.
ACTIVITY_CAPTIONS = ["<1,000", "1,000-10,000", "10,000-100,000", ">100,000"]

fig, ax = plt.subplots(3, 3, figsize=(12, 12))
# Rows 1 and 2 hold only two embeddings each, so hide their unused third panel.
ax[1, 2].axis("off")
ax[2, 2].axis("off")

cmap = ListedColormap(["#fff9c9", "#efda6d", "#b64a47", "#754242"])

for i, group in enumerate(embeddings):
    for j, name in enumerate(group):
        df_classes = create_df(f"{name}.json")

        # Project the fingerprint columns only; label columns are excluded.
        features = df_classes.drop(columns=LABEL_COLUMNS, errors="ignore")
        embedding = umap.UMAP(random_state=UMAP_SEED).fit_transform(features)

        # Colour each point by y1; keep the handle for the shared legend.
        scatter = ax[i][j].scatter(
            x=embedding[:, 0], y=embedding[:, 1], c=df_classes.y1, cmap=cmap
        )
        ax[i][j].set_title(naming[i][j])

# One legend for the whole figure, built from the last panel's scatter and
# attached to the current axes; bbox_to_anchor places it outside that axes.
legend = plt.legend(
    *scatter.legend_elements(), loc="upper right", title="RLU_activity", bbox_to_anchor=(1.8, 3.5)
)

# Replace the numeric class labels with readable activity ranges. Indexing by
# position deliberately raises IndexError if fewer than four entries exist,
# rather than silently mislabelling the legend.
legend_texts = legend.get_texts()
for position, caption in enumerate(ACTIVITY_CAPTIONS):
    legend_texts[position].set_text(caption)

plt.show()