In [None]:
%matplotlib inline 

In [None]:
from urllib.parse import quote

import numpy as np
import pandas as pd
import polars as pl
import requests
import streamlit as st
from ensmallen import HyperSketchingPy
from grape import Graph
from src.models import LightGBM

In [None]:
# Constants
# GNPS2 web-service endpoints; a SMILES string is appended after `smiles=`.
# (URL_CLASSYFIRE is not used by any cell visible in this notebook.)
URL_CLASSYFIRE: str = "https://structure.gnps2.org/classyfire?smiles="
URL_NP_CLASSIFIER: str = "https://npclassifier.gnps2.org/classify?smiles="

In [None]:
def build_np_classifier_url(compound: str, base_url: "str | None" = None) -> str:
    """Build the NP Classifier request URL for a SMILES string.

    SMILES routinely contain characters that have a special meaning in URLs
    (``#`` for triple bonds starts a URL fragment, ``+`` in charges is decoded
    as a space, plus ``/``, ``[``, ``]``, ``=``, ...). The SMILES is therefore
    percent-encoded before it is appended to the query string.

    :param compound: The SMILES of the compound of interest
    :type compound: str
    :param base_url: URL ending in ``smiles=``; defaults to ``URL_NP_CLASSIFIER``
    :type base_url: str or None
    :return: The full request URL
    :rtype: str
    """
    if base_url is None:
        base_url = URL_NP_CLASSIFIER
    return base_url + quote(compound, safe="")


def classify_with_np_classifier(compound: str, timeout: float = 30.0) -> dict:
    """Submit a compound to the NP Classifier service and return its classification.

    :param compound: The SMILES of the compound of interest
    :type compound: str
    :param timeout: Seconds to wait for the service before giving up
    :type timeout: float
    :return: The decoded JSON response of the service
    :rtype: dict
    :raises requests.HTTPError: If the service answers with an error status
    :raises requests.Timeout: If the service does not answer within ``timeout``

    Example (performs a network request, so it is not written as a doctest)::

        classify_with_np_classifier("CCC")
    """
    # Without a timeout, requests would wait forever on an unresponsive server.
    r = requests.get(build_np_classifier_url(compound), timeout=timeout)
    r.raise_for_status()
    return r.json()

In [None]:
# Trained classifier used below to score (molecule, organism) pair features.
# NOTE(review): the .pkl extension suggests a pickle - only load trusted files.
model = LightGBM.load_model("lightgbm_model_new.pkl")

# Undirected knowledge graph read from two tab-separated files with headers:
# nodes as (name, type) columns and edges as (source, destination, type) columns.
graph = Graph.from_csv(
    name="full_graph_clean",
    directed=False,
    # node list
    node_path="./data/full_graph_clean_nodes.csv",
    node_list_separator="\t",
    node_list_header=True,
    nodes_column_number=0,
    node_list_node_types_column_number=1,
    load_node_list_in_parallel=False,
    # edge list
    edge_path="./data/full_graph_clean_edges.csv",
    edge_list_separator="\t",
    edge_list_header=True,
    sources_column_number=0,
    destinations_column_number=1,
    edge_list_edge_types_column_number=2,
    load_edge_list_in_parallel=False,
)

# Taxonomy node table; its `node` column is used later to keep only known organisms.
species_phylo = pd.read_csv("./data/species/full_wikidata_taxonomy_nodes.csv")

# LOTUS metadata export (gzipped CSV). Empty strings and "NA" are read as nulls,
# and the numeric identifier/count columns get explicit dtypes.
lotus = pl.read_csv(
    "data/molecules/230106_frozen_metadata.csv.gz",
    null_values=["", "NA"],
    infer_schema_length=50000,
    dtypes={
        "structure_cid": pl.UInt32,
        "structure_xlogp": pl.Float32,
        "structure_stereocenters_total": pl.UInt32,
        "structure_stereocenters_unspecified": pl.UInt32,
        "organism_taxonomy_ncbiid": pl.UInt32,
        "organism_taxonomy_ottid": pl.UInt32,
    },
)

# organism_taxonomy_gbifid: entries starting with "c(" (presumably several ids
# serialised as an R-style vector) become missing; the rest are cast to UInt32.
# A vectorised when/then replaces the per-row Python lambda, which returned str
# values (or NaN) under a declared UInt32 return type. strict=False turns any
# other unparseable value into null instead of aborting the load.
gbif_id = pl.col("organism_taxonomy_gbifid").cast(pl.Utf8)
lotus = lotus.with_columns(
    pl.when(gbif_id.str.starts_with("c("))
    .then(pl.lit(None))
    .otherwise(gbif_id)
    .cast(pl.UInt32, strict=False)
    .alias("organism_taxonomy_gbifid"),
    # Prefixed Wikidata QIDs (e.g. "wd:Q123"), later looked up as graph node names.
    ("wd:" + pl.col("organism_wikidata").str.extract(r"(Q\d+)")).alias("wd_species"),
    ("wd:" + pl.col("structure_wikidata").str.extract(r"(Q\d+)")).alias("wd_molecule"),
)

In [None]:
compound = "CC(C)CCC(C)C(=O)NCCCNC(=N)N"
dct = classify_with_np_classifier(compound)
# isglycoside is a flag, not a level of the hierarchy; tolerate its absence.
dct.pop("isglycoside", None)

# Build the lineage compound -> level 1 -> level 2 -> ..., following the order in
# which the service returns the remaining keys. A level may hold several labels
# (or none): keep the first label and skip empty levels. (A ragged DataFrame plus
# dropna would discard every row, including the compound itself, as soon as one
# level had more labels than another.)
lineage = [compound]
for level_labels in dct.values():
    labels = level_labels if isinstance(level_labels, (list, tuple)) else [level_labels]
    labels = [label for label in labels if pd.notna(label)]
    if labels:
        lineage.append(labels[0])

# We first create the edges dataframe: each element is a subclass of the next one
edges_np_classifier = pd.DataFrame(
    {
        "child": lineage[:-1],
        "parent": lineage[1:],
        "type": "biolink:subclass_of",
    }
)

# then the nodes dataframe
nodes_np_classifier = (
    pd.DataFrame(
        {
            "node": pd.concat([edges_np_classifier.child, edges_np_classifier.parent]),
            "type": "biolink:ChemicalEntity",
        }
    )
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
# Inspect the NP Classifier lineage edges (child -> parent, biolink:subclass_of).
edges_np_classifier

# Now we add similarity edges between the query compound and the LOTUS molecules, using FAISS

In [None]:
import faiss
from src.utils import calculate_fingerprint_parallel
from src.utils import calculate_fingerprint

lotus_pd = pd.read_csv(
    "./data/molecules/230106_frozen_metadata.csv.gz", low_memory=False
)

# expand=False returns a Series, so the column is assigned from a Series rather
# than from a one-column DataFrame.
lotus_pd["wd_molecule"] = "wd:" + lotus_pd.structure_wikidata.str.extract(
    r"(Q\d+)", expand=False
)

# One 2D SMILES per Wikidata molecule, sorted by numeric QID (so Q9 < Q10).
# Sorting with a key avoids a temporary `index` variable, a name the FAISS cell
# below reuses for an unrelated object.
wd_pd = (
    lotus_pd[["wd_molecule", "structure_smiles_2D"]]
    .drop_duplicates(subset=["wd_molecule"])
    .sort_values(
        "wd_molecule",
        key=lambda qids: qids.str.extract(r"wd:Q(\d+)", expand=False).astype("int64"),
    )
    .reset_index(drop=True)
)

# Radius-2 fingerprints; presumably one per wd_pd row, in the same order - TODO confirm.
embedding_full = np.array(
    calculate_fingerprint_parallel(wd_pd["structure_smiles_2D"].values, radi=2)
)

In [None]:
# out_array is later matched to wd_pd by position, so both must have one row per molecule.
assert len(embedding_full) == len(wd_pd), "fingerprints and wd_pd are misaligned"

SIZE = len(embedding_full)

# FAISS needs float32; astype also returns a copy that normalize_L2 may modify in place.
embedding = embedding_full.astype(np.float32)
d = embedding.shape[1]

# Inner product on L2-normalised vectors is the cosine similarity.
index = faiss.IndexFlatIP(d)
faiss.normalize_L2(embedding)
index.add(embedding)

embedding_compound = (
    calculate_fingerprint(compound, radi=2).astype(np.float32).reshape(1, -1)
)
faiss.normalize_L2(embedding_compound)

# Exhaustive search (k = SIZE) returns every database vector ranked by score.
# Scatter the scores back so that out_array[i] is the similarity to wd_pd row i.
D, I = index.search(embedding_compound, SIZE)
out_array = np.empty(SIZE, dtype=D.dtype)
out_array[I.ravel()] = D.ravel()
out_array = out_array.astype("float16")

In [None]:
# One-row wide frame: the query compound (row) against every wd_pd molecule (columns).
edges = pd.DataFrame(
    out_array.reshape(1, -1),
    index=[compound],
    columns=wd_pd["wd_molecule"].tolist(),
)

In [None]:
# Unpivot (stack) the 1 x N wide frame into long form: one row per
# (compound, molecule) pair, with the similarity value in the last column.
edges = edges.stack().reset_index()

In [None]:
# Replace the default names produced by stack().reset_index() with meaningful ones.
edges = edges.rename(
    columns={"level_0": "child", "level_1": "parent", 0: "similarity"}
)

In [None]:
# Keep only near-identical structures. .copy() makes the result independent of
# `edges`, so later column edits on edges_filtered do not raise SettingWithCopyWarning.
SIMILARITY_THRESHOLD = 0.96
edges_filtered = edges.loc[edges.similarity > SIMILARITY_THRESHOLD].copy()

In [None]:
# Drop the score and tag the rows as similarity edges. Reassigning (instead of
# in-place edits on a filtered frame) avoids SettingWithCopyWarning, and
# errors="ignore" keeps the cell safe to re-run.
edges_filtered = edges_filtered.drop(columns=["similarity"], errors="ignore").assign(
    type="biolink:similar_to"
)

In [None]:
# Similarity edges plus NP Classifier lineage edges. ignore_index / reset_index give
# clean RangeIndexes instead of the repeated row labels concatenation produces.
edges_to_add = pd.concat([edges_filtered, edges_np_classifier], ignore_index=True)
nodes_to_add = (
    pd.DataFrame(
        {
            "node": pd.concat([edges_to_add.child, edges_to_add.parent]),
            "type": "biolink:ChemicalEntity",
        }
    )
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
# Small undirected graph holding only the query compound's new nodes and edges.
graph_np_classifier = Graph.from_pd(
    edges_df=edges_to_add,
    edge_src_column="child",
    edge_dst_column="parent",
    edge_type_column="type",
    nodes_df=nodes_to_add,
    node_name_column="node",
    node_type_column="type",
    directed=False,
)

In [None]:
# Merge the full knowledge graph with the compound's new nodes and edges
# (the `|` operator; presumably a graph union - confirm against the grape docs).
graph_merged = graph | graph_np_classifier

In [None]:
# Candidate ("unknown") pairs: the query compound against organisms drawn at random
# from LOTUS. maintain_order=True and random_state make the draw reproducible.
lotus_filtered = (
    lotus.select(
        [
            "wd_species",
            "organism_wikidata",
            "organism_name",
            "organism_taxonomy_01domain",
            "organism_taxonomy_02kingdom",
            "organism_taxonomy_03phylum",
            "organism_taxonomy_04class",
            "organism_taxonomy_05order",
            "organism_taxonomy_06family",
            "organism_taxonomy_08genus",
            "organism_taxonomy_09species",
            "organism_taxonomy_gbifid",
            "organism_taxonomy_ncbiid",
            "organism_taxonomy_ottid",
        ]
    )
    .unique(maintain_order=True)
    .to_pandas()
    .sample(3000, random_state=42)
)
lotus_filtered["molecule"] = compound
# Keep only organisms present in the taxonomy node list.
lotus_filtered = lotus_filtered[lotus_filtered.wd_species.isin(species_phylo.node)]
molecules_id = graph_merged.get_node_ids_from_node_names(lotus_filtered.molecule)
species_id = graph_merged.get_node_ids_from_node_names(lotus_filtered.wd_species)

In [None]:
# Sanity check: every candidate row carries the query compound as its molecule.
lotus_filtered.molecule

In [None]:
# Fit 2-hop HyperSketching on the merged graph, without normalisation.
sketching_features = HyperSketchingPy(
    graph=graph_merged,
    hops=2,
    normalize=False,
)
sketching_features.fit()

In [None]:
# Pair features for (query compound, sampled organism) candidates; the sketches
# of the two endpoints are combined by addition.
pair_sketching_unknown = sketching_features.unknown(
    sources=molecules_id.astype(np.uint32),
    destinations=species_id.astype(np.uint32),
    feature_combination="addition",
)

In [None]:
# Known (organism, molecule) occurrences from LOTUS, de-duplicated.
# maintain_order=True keeps the row order stable across runs, so a seeded sample
# drawn from this frame is reproducible.
lotus_filtered = (
    lotus.select(
        [
            "wd_species",
            "wd_molecule",
        ]
    )
    .unique(maintain_order=True)
    .to_pandas()
)

In [None]:
# Positive reference pairs: a seeded sample of known occurrences, restricted to
# organisms present in the taxonomy node list.
lotus_filtered = lotus_filtered.sample(3000, random_state=42)
lotus_filtered = lotus_filtered[lotus_filtered.wd_species.isin(species_phylo.node)]
molecules_id = graph_merged.get_node_ids_from_node_names(lotus_filtered.wd_molecule)
species_id = graph_merged.get_node_ids_from_node_names(lotus_filtered.wd_species)

In [None]:
# Features for the known (molecule, organism) pairs sampled from LOTUS, used as the
# positive reference set. NOTE(review): `.unknown` is also used for these known
# pairs; presumably it just featurises arbitrary pairs - confirm in the ensmallen API.
pair_sketching_positive = sketching_features.unknown(
    destinations=species_id.astype(np.uint32),
    sources=molecules_id.astype(np.uint32),
    feature_combination="addition",
)

In [None]:
# Negative reference pairs: keep only the biolink:in_taxon edges of the merged graph,
# then draw 3000 negative samples from that subgraph (presumably node pairs that are
# not edges of it). NOTE(review): nothing in this call restricts the endpoints to
# (molecule, organism) pairs - confirm that is intended.
pos = graph_merged.filter_from_names(
    edge_type_names_to_keep=["biolink:in_taxon"],
)
neg = pos.sample_negative_graph(
    number_of_negative_samples=3000,
    sample_edge_types=False,
    only_from_same_component=False,
    use_scale_free_distribution=True,
    # Fixed seed; the product form of this constant is not explained here.
    random_state=23391 * (3 + 1),
)

# NOTE(review): for an undirected graph the "directed" id getters may list each
# edge in both directions (about 2x number_of_negative_samples pairs) - confirm.
neg_sources = neg.get_directed_source_node_ids()
neg_destinations = neg.get_directed_destination_node_ids()
sk_negative_features = sketching_features.unknown(
    sources=neg_sources,
    destinations=neg_destinations,
    feature_combination="addition",
)

In [None]:
from sklearn.manifold import TSNE

# Stack the three 2-D feature sets row-wise in the order unknown, negative,
# positive, then embed them in two dimensions.
X = np.vstack((pair_sketching_unknown, sk_negative_features, pair_sketching_positive))
tsne = TSNE(n_components=2, random_state=42, verbose=2)
X_embedded = tsne.fit_transform(X)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Labels follow the row order used to build X: unknown, negative, positive.
df = pd.DataFrame(X_embedded, columns=["x", "y"])
df["label"] = (
    ["unknown"] * pair_sketching_unknown.shape[0]
    + ["negative"] * sk_negative_features.shape[0]
    + ["positive"] * pair_sketching_positive.shape[0]
)

fig, ax = plt.subplots(figsize=(10, 10))
sns.scatterplot(data=df, x="x", y="y", hue="label", ax=ax)
ax.set(
    title="t-SNE of HyperSketching pair features",
    xlabel="t-SNE 1",
    ylabel="t-SNE 2",
);

In [None]:
# Inspect which nodes the query compound is connected to in the merged graph.
graph_merged.get_neighbour_node_names_from_node_name(compound)

In [None]:
# Score distribution for the known (positive) pairs. range=(0, 1) matches the other
# histograms so that the plots are directly comparable.
fig, ax = plt.subplots()
ax.hist(
    model.predict_proba(pair_sketching_positive)[:, 1],
    bins=50,
    range=(0, 1),
)
ax.set(
    title="Known (positive) pairs",
    xlabel="predict_proba, column 1",
    ylabel="count",
);

In [None]:
# Score distribution for the candidate (query compound, organism) pairs.
fig, ax = plt.subplots()
ax.hist(
    model.predict_proba(pair_sketching_unknown)[:, 1],
    bins=50,
    range=(0, 1),
)
ax.set(
    title="Candidate pairs for the query compound",
    xlabel="predict_proba, column 1",
    ylabel="count",
);

In [None]:
# Score distribution for the sampled negative pairs.
fig, ax = plt.subplots()
ax.hist(
    model.predict_proba(sk_negative_features)[:, 1],
    bins=50,
    range=(0, 1),
)
ax.set(
    title="Sampled negative pairs",
    xlabel="predict_proba, column 1",
    ylabel="count",
);

### Let's try to add an edge from the compound "CC(C)CCC(C)C(=O)NCCCNC(=N)N" to a species.

In [None]:
# Lineage edges plus one hand-picked in_taxon edge from the query compound to an
# organism. Reusing `compound` keeps this in sync with the SMILES defined above.
target_species = "wd:Q25834677"
edges_np_classifier_new = pd.concat(
    [
        edges_np_classifier,
        pd.DataFrame(
            {
                "child": [compound],
                "parent": [target_species],
                "type": ["biolink:in_taxon"],
            }
        ),
    ],
    ignore_index=True,
)

In [None]:
nodes_np_classifier_new = (
    pd.DataFrame(
        {
            "node": pd.concat(
                [edges_np_classifier_new.child, edges_np_classifier_new.parent]
            ),
            "type": "biolink:ChemicalEntity",
        }
    )
    .drop_duplicates()
    .reset_index(drop=True)
)
# Organisms are the targets of in_taxon edges. Type them by membership rather than
# by row position, which shifts whenever the number of classification levels changes.
taxon_nodes = edges_np_classifier_new.loc[
    edges_np_classifier_new["type"] == "biolink:in_taxon", "parent"
]
nodes_np_classifier_new.loc[
    nodes_np_classifier_new["node"].isin(taxon_nodes), "type"
] = "biolink:OrganismTaxon"

In [None]:
# Small undirected graph: the compound's lineage plus the added in_taxon edge.
graph_np_classifier_new = Graph.from_pd(
    nodes_df=nodes_np_classifier_new,
    node_name_column="node",
    node_type_column="type",
    edges_df=edges_np_classifier_new,
    edge_src_column="child",
    edge_dst_column="parent",
    edge_type_column="type",
    directed=False,
)

In [None]:
# Merge the full knowledge graph with the lineage and in_taxon edges
# (the `|` operator; presumably a graph union - confirm against the grape docs).
graph_merged_new = graph | graph_np_classifier_new

In [None]:
# Candidate pairs for re-scoring: the query compound against organisms drawn at
# random (seeded, order-preserving unique) from LOTUS. Node ids must come from
# graph_merged_new: the sketches below are fitted on that graph, and ids are not
# guaranteed to coincide with those of graph_merged.
lotus_filtered = (
    lotus.select(
        [
            "wd_species",
            "organism_wikidata",
            "organism_name",
            "organism_taxonomy_01domain",
            "organism_taxonomy_02kingdom",
            "organism_taxonomy_03phylum",
            "organism_taxonomy_04class",
            "organism_taxonomy_05order",
            "organism_taxonomy_06family",
            "organism_taxonomy_08genus",
            "organism_taxonomy_09species",
            "organism_taxonomy_gbifid",
            "organism_taxonomy_ncbiid",
            "organism_taxonomy_ottid",
        ]
    )
    .unique(maintain_order=True)
    .to_pandas()
    .sample(3000, random_state=42)
)
lotus_filtered["molecule"] = compound
# Keep only organisms present in the taxonomy node list.
lotus_filtered = lotus_filtered[lotus_filtered.wd_species.isin(species_phylo.node)]
molecules_id = graph_merged_new.get_node_ids_from_node_names(lotus_filtered.molecule)
species_id = graph_merged_new.get_node_ids_from_node_names(lotus_filtered.wd_species)

In [None]:
# Re-fit 2-hop sketches (no normalisation) on the graph that includes the in_taxon edge.
sketching_features = HyperSketchingPy(
    graph=graph_merged_new,
    hops=2,
    normalize=False,
)
sketching_features.fit()

In [None]:
# Candidate-pair features recomputed with the re-fitted sketches (addition combiner).
pair_sketching_unknown = sketching_features.unknown(
    feature_combination="addition",
    sources=molecules_id.astype(np.uint32),
    destinations=species_id.astype(np.uint32),
)

In [None]:
# Candidate-pair scores after adding the in_taxon edge. Same bins and range as the
# earlier histograms, so the two can be compared directly.
fig, ax = plt.subplots()
ax.hist(
    model.predict_proba(pair_sketching_unknown)[:, 1],
    bins=50,
    range=(0, 1),
)
ax.set(
    title="Candidate pairs after adding the in_taxon edge",
    xlabel="predict_proba, column 1",
    ylabel="count",
);