In [1]:
from collections import defaultdict
import pandas as pd
import tmap as tm
from faerun import Faerun

from matplotlib.colors import ListedColormap

from tqdm import tqdm

# Register tqdm's progress-bar variants (e.g. `progress_apply`) on pandas objects.
# NOTE(review): no cell visible in this notebook calls a `progress_*` method — confirm
# whether this registration is still needed.
tqdm.pandas()

# Load data

In [2]:
# Tab-separated fingerprint table (MHFP6, going by the file name). The preview
# below shows columns bit0..bit2047 followed by `cmp_id` and `label`.
FINGERPRINT_TSV = "../data/fingerprints/combined_mhfp6.tsv"

all_data = pd.read_csv(FINGERPRINT_TSV, sep="\t")
all_data.head()

Unnamed: 0,bit0,bit1,bit2,bit3,bit4,bit5,bit6,bit7,bit8,bit9,...,bit2040,bit2041,bit2042,bit2043,bit2044,bit2045,bit2046,bit2047,cmp_id,label
0,53109374,13294028,17313015,13405020,159565048,166773519,112685388,23281340,74257363,41227994,...,86693380,23470166,13223091,13893646,72383491,138619944,10981154,123604138,OOYGSFOGFJDDHP-KMCOLRRFSA-N,acid-fast
1,2376200,75861701,8411880,265132626,17183886,48771247,26537035,264241068,149295683,164407546,...,21431025,1416250,104553653,280401855,71835899,176069654,91289868,137970302,XIPHLJFTBFXVBS-UHFFFAOYSA-N,fungi
2,4718263,19386904,104326033,8518591,15497897,69171867,65532846,11446788,98580395,35586042,...,18682934,1416250,11522298,41181622,96396185,6500986,106638228,43510588,OEFUWWDPRNNMRS-WDPNLORNSA-N,gram-negative
3,4718263,19386904,6003371,8518591,15497897,69171867,101889093,11446788,98580395,21639682,...,18682934,1416250,11522298,41181622,66023973,78748076,106638228,36539738,LBRXTHSVIORIGU-OLROFJLRSA-N,gram-positive
4,26679775,33988699,63689098,10196723,73651484,9407713,94168771,13298204,10766915,141650786,...,7031338,1416250,14178885,41181622,12309677,4412870,10981154,214764498,PHYLUFIYANLQSE-UHFFFAOYSA-N,gram-positive


In [3]:
# Only the compound identifiers and provenance are needed from the bioassay
# table; every other column is skipped at read time.
BIOASSAY_TSV = "../data/processed/combined_bioassay_data.tsv"
BIOASSAY_COLUMNS = ["compound_smiles", "compound_inchikey", "compound_source"]

amr_df = pd.read_csv(BIOASSAY_TSV, sep="\t", usecols=BIOASSAY_COLUMNS)

In [4]:
# Attach SMILES and source to every fingerprint row. The join key is the
# InChIKey (`cmp_id` on the fingerprint side, `compound_inchikey` on the
# bioassay side); the right-hand key is dropped afterwards because it
# duplicates `cmp_id`.
combined_df = all_data.merge(
    amr_df,
    left_on="cmp_id",
    right_on="compound_inchikey",
).drop(columns=["compound_inchikey"])
combined_df.head(2)

Unnamed: 0,bit0,bit1,bit2,bit3,bit4,bit5,bit6,bit7,bit8,bit9,...,bit2042,bit2043,bit2044,bit2045,bit2046,bit2047,cmp_id,label,compound_smiles,compound_source
0,53109374,13294028,17313015,13405020,159565048,166773519,112685388,23281340,74257363,41227994,...,13223091,13893646,72383491,138619944,10981154,123604138,OOYGSFOGFJDDHP-KMCOLRRFSA-N,acid-fast,NC[C@H]1O[C@H](O[C@H]2[C@H](O)[C@@H](O[C@H]3O[...,chembl_34
1,2376200,75861701,8411880,265132626,17183886,48771247,26537035,264241068,149295683,164407546,...,104553653,280401855,71835899,176069654,91289868,137970302,XIPHLJFTBFXVBS-UHFFFAOYSA-N,fungi,C=C(C(=O)c1ccc(F)cc1)c1ccc(Cl)cc1Cl,chembl_34


# Group fingerprints by compound SMILES

In [5]:
# Map each SMILES string to a list of (fingerprint, label, source) tuples.
# A list is kept per key so that repeated SMILES are collected rather than
# overwritten.
fingerprint_dict = defaultdict(list)

mol_smiles = combined_df["compound_smiles"]
mol_labels = combined_df["label"]
mol_source = combined_df["compound_source"]
# Everything that is not metadata is a fingerprint column.
mol_fingerprint_df = combined_df.drop(
    columns=["compound_smiles", "cmp_id", "label", "compound_source"]
)

# Walk the rows positionally instead of via `iterrows()`: this avoids building
# a Series per row and does not depend on the index labels being unique.
fingerprint_rows = mol_fingerprint_df.to_numpy()
for smiles_key, fp_row, label_value, source_value in tqdm(
    zip(mol_smiles, fingerprint_rows, mol_labels, mol_source),
    total=mol_fingerprint_df.shape[0],
):
    fingerprint_dict[smiles_key].append((fp_row, label_value, source_value))

100%|██████████| 74202/74202 [00:01<00:00, 53958.90it/s]


# Generate TMAP

In [6]:
# First argument: 2048, matching the number of bit columns (bit0..bit2047).
# Second argument: 128 — presumably the number of prefix trees; confirm against
# the tmap documentation.
lf = tm.LSHForest(2048, 128)

# Parallel lists, one entry per unique SMILES, in `fingerprint_dict` order.
fps, labels, activity = [], [], []
ids = []  # NOTE(review): never filled or read in the visible cells.

for smile, data_info in tqdm(fingerprint_dict.items()):
    # Only the first record stored for a SMILES is used; any further records
    # under the same key are ignored.
    fp, class_label, cmp_source = data_info[0]

    fps.append(tm.VectorUint(fp))
    activity.append(class_label)

    # Three "__"-separated fields: SMILES, source, SMILES again.
    # Single quotes are stripped from the finished label.
    point_label = f"{smile}__{cmp_source}__{smile}"
    labels.append(point_label.replace("'", ""))

100%|██████████| 74202/74202 [00:17<00:00, 4276.08it/s]


In [7]:
# Add every fingerprint to the forest in one call, then build the index.
# NOTE(review): `lf` is created in the previous cell, so re-running this cell
# alone calls `batch_add` on the same forest a second time — re-run both together.
lf.batch_add(fps)
lf.index()

In [8]:
# Layout configuration with no fields overridden.
cfg = tm.LayoutConfiguration()
# Five values are returned. `x`/`y` are used below as scatter coordinates and
# `s`/`t` as the "from"/"to" endpoints of the tree edges; the fifth is discarded.
x, y, s, t, _ = tm.layout_from_lsh_forest(lf, config=cfg)

In [9]:
# Hex colours for the categorical colour map. Seven are listed, while the
# category mapping used for the scatter plot defines four classes.
category_colours = [
    "#2ecc71",
    "#9b59b6",
    "#ecf0f1",
    "#e74c3c",
    "#e67e22",
    "#f1c40f",
    "#95a5a6",
]
custom_cmap = ListedColormap(category_colours, name="custom")

# Visualize TMAP

In [10]:
# Plot object that the scatter layer and tree layer are added to below.
f = Faerun(
    # Background colour override, currently disabled.
    # clear_color="#FFFFFF",
    coords=False,
    view="front",
)

In [11]:
# Integer code for each class label found in the `label` column.
activity_mapper = {
    "acid-fast": 0,
    "fungi": 1,
    "gram-negative": 2,
    "gram-positive": 3,
}

# Encode every point's class; a label missing from the mapper raises KeyError.
encoded_activity = [activity_mapper[class_name] for class_name in activity]
activity_labels, activity_data = Faerun.create_categories(encoded_activity)

In [12]:
# Point coordinates, a single colour series, and the per-point labels.
scatter_payload = {
    "x": x,
    "y": y,
    "c": [activity_data],
    "labels": labels,
}

# Legend entries keyed by the integer codes defined in `activity_mapper`.
strain_legend = [
    (0, "Acid-fast"),
    (1, "Fungi"),
    (2, "Gram-negative"),
    (3, "Gram-positive"),
]

# `selected_labels` names the three "__"-separated fields of each point label.
# NOTE(review): the legend title says "Bacterial strain" although one of the
# classes is "Fungi" — confirm the intended wording.
f.add_scatter(
    "amr",
    scatter_payload,
    selected_labels=["SMILES", "Source", "Smile"],
    shader="smoothCircle",
    colormap=[custom_cmap],
    categorical=[True],
    has_legend=True,
    legend_labels=strain_legend,
    legend_title="Bacterial strain",
)

In [13]:
# Draw the tree edges (`s` -> `t` from the layout step), using the "amr"
# scatter layer as the point helper.
f.add_tree("amrtree", {"from": s, "to": t}, point_helper="amr")

In [None]:
# Inspection only: show the first point label to check the
# "SMILES__source__SMILES" format. This cell has no execution count.
labels[0]

In [14]:
# Render the plot named "amrkg_chemspace" into ../figures/ with the "smiles" template.
# NOTE(review): presumably the output directory must already exist — confirm.
f.plot(
    "amrkg_chemspace",
    template="smiles",
    path="../figures/",
    notebook_height=200,
)