In [2]:
# Reload imported modules (such as the conformal_amr package) automatically before
# each cell executes, so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import json
import os

import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import umap
from sklearn.metrics.pairwise import cosine_similarity
from torch import nn
from tqdm.notebook import tqdm

from conformal_amr.data_split.mole_bert_loaders import (
    MoleculeGraphDataset,
    MoleculeDataset,
    DataLoaderMaskingPred,
)
from conformal_amr.models.mole_bert import graphcl, GNN, DiscreteGNN, global_mean_pool
from conformal_amr.models.utils import create_ab_graph

In [4]:
import warnings

warnings.filterwarnings("ignore")

### Load data

In [None]:
# Long-format DRIAMS table (one row per sample/drug label) used throughout the notebook.
long_table_path = "/fs/pool/pool-miranda/Projects/AMR/ConformalAMR/data/Processed/DRIAMS_combined_long_table_multidrug.csv"
driams_long_table = pd.read_csv(long_table_path)
driams_long_table

In [None]:
# Get number of labels and prevalence per species and antibiotic
# For DRIAMS-A only: number of labels ("count") and mean of `response`
# ("mean", i.e. prevalence) for every (species, drug) pair.
driams_a_prev = (
    driams_long_table.query("dataset == 'A'")
    .groupby(["species", "drug"])["response"]
    .agg(count="count", mean="mean")
)
# Per-drug summary for E. coli (selects on the first index level, "species").
driams_a_prev.loc["Escherichia coli"]

### Get SMILES and graph representations of molecules

In [None]:
# Build a representation for every drug in the long table with create_ab_graph.
# Drugs whose conversion raises are reported, collected in `failed_drugs` and
# skipped, so they are absent from all downstream embeddings and plots.
drug_representations = {}
failed_drugs = []
for drug in tqdm(driams_long_table.drug.unique()):
    # unique() already de-duplicates, so no "already seen" check is needed.
    try:
        drug_representations[drug] = create_ab_graph(drug)
    except Exception as err:  # not a bare except: KeyboardInterrupt must still abort the loop
        failed_drugs.append(drug)
        print(f"Failed to get representation for {drug}: {err!r}")

print(f"Got representations for {len(drug_representations)} drugs")

### Load Mole-BERT pretrained model

In [6]:
# 5-layer, 300-dim GIN encoder wrapped in graphcl. The model is switched to eval
# mode and the checkpoint is loaded into the encoder only (model.gnn).
checkpoint_path = "/fs/pool/pool-miranda/Projects/AMR/Mole-BERT/model_gin/Mole-BERT.pth"
gnn = GNN(
    num_layer=5,
    emb_dim=300,
    JK="last",
    drop_ratio=0,
    gnn_type="gin",
)
model = graphcl(gnn)
model.eval()
model.gnn.from_pretrained(checkpoint_path)

In [None]:
# Embed every drug that has a representation with the pretrained model.
# Rows follow the dict's insertion order; shuffle=False keeps that order in the
# batch, so row i of the output belongs to drug_names_ordered[i].
drug_names_ordered = list(drug_representations)
DRIAMS_ab_dataset = MoleculeGraphDataset(
    mol_list_of_lists=[
        drug_representations[antibiotic][2] for antibiotic in drug_names_ordered
    ],
)
DRIAMS_ab_dataloader = DataLoaderMaskingPred(
    DRIAMS_ab_dataset,
    batch_size=len(DRIAMS_ab_dataset),  # a single batch containing all drugs
    mask_rate=0.0,
    mask_edge=0.0,
    shuffle=False,
)
DRIAMS_ab_batch = next(iter(DRIAMS_ab_dataloader))

# Inference only: no_grad avoids building an autograd graph for the forward pass.
# NOTE(review): only model.gnn received pretrained weights (from_pretrained above).
# If the second output of forward_cl passes through further graphcl layers, those
# are randomly initialised. TODO confirm, and consider pooling model.gnn's output
# directly (global_mean_pool is already imported).
with torch.no_grad():
    drug_embeddings = model.forward_cl(
        DRIAMS_ab_batch.x,
        DRIAMS_ab_batch.edge_index,
        DRIAMS_ab_batch.edge_attr,
        DRIAMS_ab_batch.batch,
    )[1]

assert drug_embeddings.shape[0] == len(drug_names_ordered), (
    "expected exactly one embedding row per drug"
)
amr_drugs_raw = pd.DataFrame(
    drug_embeddings.cpu().numpy(),
    index=pd.Index(drug_names_ordered, name="drug"),
)

# z-score each embedding dimension across drugs. A constant dimension would give
# 0/0 = NaN (breaking cosine similarity and UMAP), so its std is replaced by 1,
# which leaves that dimension at 0.
dim_std = amr_drugs_raw.std().replace(0, 1)
amr_drugs = (amr_drugs_raw - amr_drugs_raw.mean()) / dim_std
amr_drugs

In [8]:
# Save drug embeddings in a format compatible with our AMR models
# Write the drug-indexed embedding table to CSV for use by the AMR models.
embeddings_csv_path = "/fs/pool/pool-miranda/Projects/AMR/ConformalAMR/data/Processed/DRIAMS_Mole-BERT_drug_embeddings.csv"
amr_drugs.to_csv(embeddings_csv_path)

In [9]:
# Hand-curated antimicrobial class per drug name. Lines starting with ">>>" open a
# new class; every following non-empty line is a drug belonging to that class.
# Drug names must match the names used in the DRIAMS table (they index
# the embeddings), so they are kept verbatim even where spelled unusually.
drug_data = """
>>>Penicillins
Piperacillin-Tazobactam
Amoxicillin-Clavulanic acid
Amoxicillin
Penicillin
Ampicillin-Amoxicillin
Oxacillin
Ticarcillin-Clavulanic acid
Ampicillin-Sulbactam
Ampicillin
Piperacillin
Ticarcillin
>>>Cephalosporins
Cefepime
Ceftazidime
Ceftriaxone
Cefpodoxime
Cefuroxime
Cefazolin
Cefixime
Ceftarolin
Ceftobiprole
Cefoxitin
Cefotaxime
>>>Carbapenems
Meropenem
Imipenem
Ertapenem
>>>Monobactams
Aztreonam
>>>Fluoroquinolones
Ciprofloxacin
Levofloxacin
Norfloxacin
Moxifloxacin
Ofloxacin
>>>Macrolides and lincosamides
Clindamycin
Erythromycin
Clarithromycin
Azithromycin
Telithromycin
>>>Aminoglycosides
Amikacin
Tobramycin
Gentamicin
>>>Glycopeptides
Vancomycin
Teicoplanin
>>>Tetracyclines
Tetracycline
Tigecycline
Doxycycline
Minocin
>>>Azoles
Fluconazole
Itraconazole
Voriconazole
Posaconazole
Isavuconazole
>>>Echinocandins
Caspofungin
Micafungin
Anidulafungin
>>>Miscellaneous
Cotrimoxazol
Colistin
Metronidazole
Amphotericin B
5-Fluorocytosine
Fosfomycin-Trometamol
Nitrofurantoin
Linezolid
Daptomycin
Chloramphenicol
Rifamdin
Fusidic acid
Mupirocin
Fosfomycin
Bacitracin
Polymyxin
Novobiocin
"""

# Parse the listing into (drug, class) records. Using a separate list avoids
# reusing the name `drug_classes` for both a list and the final DataFrame.
drug_class_records = []
current_class = None
for line in drug_data.splitlines():
    entry = line.strip()
    if not entry:
        continue
    if entry.startswith(">>>"):
        # Header line: switch the class applied to subsequent drugs.
        current_class = entry[len(">>>"):].strip()
    else:
        drug_class_records.append((entry, current_class))

drug_names = [name for name, _ in drug_class_records]

# One row per drug, indexed by drug name, single column "class".
drug_classes = pd.DataFrame(drug_class_records, columns=["drug", "class"]).set_index(
    "drug"
)
# Duplicate names would make .loc lookups return several rows per drug.
assert drug_classes.index.is_unique, "each drug must be listed under exactly one class"

In [10]:
def get_closest_matches(query, embeddings, drug_structures):
    """Return the drug whose embedding is most cosine-similar to ``query``.

    Args:
        query: Drug name; must be a row label of ``embeddings``.
        embeddings: DataFrame of drug embeddings indexed by drug name.
        drug_structures: Mapping from drug name to a sequence whose element
            ``[2]`` is returned for the best match (e.g. ``drug_representations``).

    Returns:
        Tuple ``(closest_drug, closest_structure)``. The query itself is excluded
        from the candidates, and ties go to the first candidate in index order.

    Raises:
        KeyError: if ``query`` is not in ``embeddings``.
    """
    candidates = embeddings.loc[embeddings.index != query]
    query_vector = embeddings.loc[query].values.reshape(1, -1)

    similarities = cosine_similarity(query_vector, candidates.values)[0]
    best_drug = candidates.index[similarities.argmax()]

    return best_drug, drug_structures[best_drug][2]

In [None]:
query = "Amoxicillin-Clavulanic acid"
print("Query:")
# NOTE(review): [2][1] displays the second entry of the query's structure list,
# whereas the closest-match cell displays entry [0]; confirm which is intended.
drug_representations[query][2][1]

In [None]:
# Nearest neighbour of the query in the standardized embedding space.
name, struct = get_closest_matches(
    query=query,
    embeddings=amr_drugs,
    drug_structures=drug_representations,
)
print("Closest match:", name)
struct[0]

In [13]:
# 2-D UMAP projection of the drug embeddings, annotated with drug class.
# A fixed random_state makes the layout (and the saved figure) reproducible;
# umap-learn may run single-threaded when a seed is set.
UMAP_SEED = 42
dim_red = umap.UMAP(
    n_components=2, n_neighbors=15, metric="cosine", random_state=UMAP_SEED
)
amr_drugs_umap = dim_red.fit_transform(amr_drugs)

amr_drugs_red_df = pd.DataFrame(
    amr_drugs_umap, columns=["UMAP-1", "UMAP-2"], index=amr_drugs.index
)

# Fail with a readable message (instead of a bare KeyError) if an embedded drug
# has no entry in the hand-curated class table.
unclassified = amr_drugs_red_df.index.difference(drug_classes.index)
assert unclassified.empty, f"No class assigned for: {list(unclassified)}"
amr_drugs_red_df["class"] = drug_classes.loc[amr_drugs_red_df.index, "class"].to_numpy()

# A combination drug joins two alphabetic names with "-" (e.g. "Amoxicillin-Clavulanic
# acid"). Splitting on every "-" wrongly flagged names such as "5-Fluorocytosine".
amr_drugs_red_df["multi_drug"] = amr_drugs_red_df.index.str.contains(r"[A-Za-z]-[A-Za-z]")

In [None]:
# Scatter the UMAP coordinates coloured by drug class, label every point with its
# drug name, and save the figure as a PDF.
sns.set_context("paper")
fig, ax = plt.subplots(figsize=(10, 10))
sns.scatterplot(
    data=amr_drugs_red_df,
    x="UMAP-1",
    y="UMAP-2",
    hue="class",
    palette="tab20",
    alpha=0.8,
    s=200,
    ax=ax,
)

# Drug-name labels; a thin white outline keeps the text legible over the markers.
for drug, x_pos, y_pos in zip(
    amr_drugs_red_df.index, amr_drugs_red_df["UMAP-1"], amr_drugs_red_df["UMAP-2"]
):
    label = ax.text(x_pos, y_pos, drug, fontsize=10, ha="center", va="center")
    label.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="w")])

# Remove top/right spines, then add grid and axis labels.
sns.despine(fig=fig)
ax.grid(True)
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")

# Legend placed outside the axes on the right; bbox_inches="tight" keeps it in the PDF.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
fig.savefig(
    "/fs/pool/pool-miranda/Projects/AMR/ConformalAMR/figures/DRIAMS_Mole-BERT_drug_embeddings.pdf",
    dpi=400,
    bbox_inches="tight",
)
plt.show()