In [1]:
import pickle
import numpy as np
import tmap as tm
import pandas as pd
import scipy.stats as ss
from rdkit.Chem import AllChem
from mhfp.encoder import MHFPEncoder
from faerun import Faerun
from collections import Counter
from matplotlib.colors import ListedColormap
from matplotlib import pyplot as plt
from pathlib import Path
from loguru import logger

In [2]:
# Locations of the two pre-processed input tables.
data_dir = Path("../data/processed")
dyes_path = data_dir / "dyes.csv"
coconut_path = data_dir / "coconut.csv"

# Load each table and tag every row with its class label.
dyes_df = pd.read_csv(dyes_path).assign(label="dyes")
coconut_df = pd.read_csv(coconut_path).assign(label="non-dyes")

In [3]:
# Report the raw class sizes before balancing.
print(f"dyes: {len(dyes_df)}, coconut: {len(coconut_df)}")

# Down-sample the non-dye table to the size of the dye table; the fixed
# random_state makes the draw repeatable.
coconut_df = coconut_df.sample(random_state=42, n=len(dyes_df))

# Stack both classes; ignore_index=True already gives a fresh 0..N-1 index.
df = pd.concat([dyes_df, coconut_df], ignore_index=True)
print(f"total after sampling: {len(df)}")

dyes: 26255, coconut: 583991
total after sampling: 52510


In [4]:
# Keep only the rows whose "MW" value is above 100.
# The index is not reset here, so the surviving rows keep their old labels.
df = df.loc[df["MW"].gt(100)]
print(f"total after filtering: {len(df)}")

total after filtering: 52437


In [5]:
# Fingerprint encoder and the LSH forest the fingerprints are indexed in.
# NOTE(review): both are constructed with 1024 as first argument - presumably
# these sizes have to agree; confirm against the mhfp / tmap documentation.
enc = MHFPEncoder(1024)
lf = tm.LSHForest(1024, 64)

# Per-molecule accumulators, filled in row order by the featurisation loop.
fps, hac, c_frac = [], [], []
ring_atom_frac, largest_ring_size = [], []

In [6]:
# Featurise every molecule in `df`: one fingerprint plus four scalar
# descriptors (heavy-atom count, carbon fraction, ring-atom fraction and
# largest ring size), all appended in row order.

# Start from empty accumulators so that re-running this cell does not append
# a second copy of every value.
fps = []
hac = []
c_frac = []
ring_atom_frac = []
largest_ring_size = []

n_rows = len(df)
for pos, (idx, row) in enumerate(df.iterrows()):
    # Progress is based on the 0-based position `pos`. The index label `idx`
    # is not usable for this: `df` was row-filtered without resetting its
    # index, so the labels are no longer contiguous.
    if pos != 0 and pos % 1000 == 0:
        print(100 * pos / n_rows)

    mol = AllChem.MolFromSmiles(row["smiles"])
    if mol is None:
        # RDKit returns None for SMILES it cannot parse. Fail loudly with the
        # offending row; silently skipping it would misalign the descriptor
        # lists with the rows of `df`.
        raise ValueError(
            f"could not parse SMILES at row {idx!r}: {row['smiles']!r}"
        )

    size = mol.GetNumHeavyAtoms()
    n_c = 0
    n_ring_atoms = 0
    for atom in mol.GetAtoms():
        if atom.IsInRing():
            n_ring_atoms += 1
        # Exact match on the element symbol, so e.g. "Cl" is not counted.
        if atom.GetSymbol().lower() == "c":
            n_c += 1

    c_frac.append(n_c / size)
    ring_atom_frac.append(n_ring_atoms / size)

    # Size of the largest ring in the ring set; 0 for acyclic molecules.
    sssr = AllChem.GetSymmSSSR(mol)
    largest_ring_size.append(max((len(ring) for ring in sssr), default=0))

    hac.append(size)
    fps.append(tm.VectorUint(enc.encode_mol(mol)))

1.907050365200145
3.81410073040029
5.721151095600435
7.62820146080058
9.535251826000724
11.44230219120087
13.349352556401014
15.25640292160116
17.163453286801303
19.07050365200145
20.977554017201594
22.88460438240174
24.791654747601886
26.698705112802028
28.605755478002173
30.51280584320232
32.419856208402464
34.326906573602606
36.233956938802756
38.1410073040029
40.04805766920305
41.95510803440319
43.86215839960333
45.76920876480348
47.67625913000362
49.58330949520377
51.49035986040391
53.397410225604055
55.304460590804204
57.211510956004346
59.118561321204496
61.02561168640464
62.93266205160478
64.83971241680493




66.74676278200508
68.65381314720521
70.56086351240536
72.46791387760551
74.37496424280565
76.2820146080058
78.18906497320594
80.0961153384061
82.00316570360623
83.91021606880638
85.81726643400653
87.72431679920666
89.63136716440681
91.53841752960696
93.4454678948071
95.35251826000724
97.2595686252074
99.16661899040754


In [7]:
# Cache locations for the descriptor lists and the serialised LSH forest.
tmp_dir = Path("../tmp")
tmp_dir.mkdir(parents=True, exist_ok=True)  # make sure the directory exists
cache_path = tmp_dir / "props.pickle"
lf_path = tmp_dir / "lf.dat"
# True: always rebuild the index and overwrite the cache.
# False: reuse the cached index and descriptors when both files exist.
force_write = True

if force_write or not (cache_path.exists() and lf_path.exists()):
    # Add the fingerprints and build the index. This happens only on this
    # branch, so a valid cache actually avoids the indexing work.
    lf.batch_add(fps)
    lf.index()

    # Persist the descriptor lists and the index side by side.
    with open(cache_path, "wb") as fh:
        pickle.dump(
            (hac, c_frac, ring_atom_frac, largest_ring_size),
            fh,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    lf.store(str(lf_path))
else:
    # Restore the index and the descriptor lists from the cache.
    lf.restore(str(lf_path))
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # were written by this notebook.
    with open(cache_path, "rb") as fh:
        hac, c_frac, ring_atom_frac, largest_ring_size = pickle.load(fh)

In [8]:
# Rank-based version of the carbon fraction: ranks of the max-normalised
# values, divided by the number of molecules.
c_frac_normalised = np.array(c_frac) / max(c_frac)
c_frak_ranked = ss.rankdata(c_frac_normalised) / len(c_frac)

# Layout settings passed to tmap.
cfg = tm.LayoutConfiguration()
cfg.k = 20
cfg.node_size = 1 / 26
cfg.mmm_repeats = 2
cfg.sl_extra_scaling_steps = 5
cfg.sl_scaling_type = tm.RelativeToAvgLength

# x / y feed the scatter plot and s / t the tree's "from" / "to" arrays;
# the fifth return value is not used.
x, y, s, t, _ = tm.layout_from_lsh_forest(lf, cfg)

In [9]:
# Turn the "label" column into the legend labels and per-point category data
# that are passed to the scatter plot.
type_labels, type_data = Faerun.create_categories(df.loc[:, "label"])

In [10]:
# Categorical palette for the "Type" series: the stock "tab10" colours with
# the entry at index 7 replaced by a custom dark colour.
#
# `plt.cm.get_cmap` is deprecated and removed in newer Matplotlib releases, so
# the palette is fetched with `plt.get_cmap`. A fresh ListedColormap is built
# from the edited colour list instead of assigning to `.colors` on the
# colormap object that was fetched.
colors = list(plt.get_cmap("tab10").colors)
colors[7] = (0.17, 0.24, 0.31)
tab_10 = ListedColormap(colors, name="tab10_custom")

  tab_10 = plt.cm.get_cmap("tab10")


In [11]:
# Per-point colour channels, one per selectable series. The `categorical`,
# `colormap` and `series_title` arguments below are given in this same order.
color_series = [type_data, hac, c_frak_ranked, ring_atom_frac, largest_ring_size]
series_titles = ["Type", "HAC", "C Frac", "Ring Atom Frac", "Largest Ring Size"]
scatter_data = {"x": x, "y": y, "c": color_series, "labels": df["smiles"]}

f = Faerun(view="front", coords=False)
f.add_scatter(
    "dye_atlas",
    scatter_data,
    shader="smoothCircle",
    point_scale=2.0,
    max_point_size=20,
    # Only the first (categorical) series gets explicit legend labels.
    legend_labels=[type_labels],
    categorical=[True, False, False, False, False],
    colormap=[tab_10, "rainbow", "rainbow", "rainbow", "Blues"],
    series_title=series_titles,
    has_legend=True,
)

In [12]:
# Attach the tree (its from / to arrays come from the layout step) to the
# "dye_atlas" scatter, then render the plot with the "smiles" template.
tree_edges = {"from": s, "to": t}
f.add_tree("dye_atlas_tree", tree_edges, point_helper="dye_atlas")
f.plot(template="smiles")