这是 SPMM 的官方 GitHub 仓库。SPMM 是一个多模态分子预训练模型,用于协同理解分子结构与性质。 详细内容可参考以下论文: Bidirectional Generation of Structure and Properties Through a Single Molecular Foundation Model. (Nature Communications 2024)
分子结构以 SMILES 表示,我们使用 53 个简单化学性质构建分子的性质向量(PV)。
模型检查点和数据体积较大,未直接包含在本仓库中,可从这里下载。
data/:保存论文实验所使用的数据。(需要你自行创建该目录,并放入上方链接下载的数据。)Pretrain/:保存预训练好的 SPMM 检查点。(需要你自行创建该目录,并放入上方链接下载的检查点。)vocab_bpe_300.txt:SMILES 分词器使用的 SMILES token 表。property_name.txt:53 个化学性质的名称。normalize.pkl:构建 PV 时使用的 53 个化学性质的均值与标准差。calc_property.py:用于计算 53 个化学性质,并根据给定 SMILES 构建 PV。如果你要将 SPMM 预训练用于自定义 PV,请按需修改此文件。SPMM_models.py:SPMM 模型及其预训练代码。SPMM_pretrain.py:用于运行 SPMM 预训练。d_*.py:下游任务脚本。
运行 pip install -r requirements.txt 安装所需依赖。
参数既可以通过命令行传入,也可以直接在脚本中手动修改。
-
预训练
python SPMM_pretrain.py --data_path './data/pretrain.txt' -
PV 到 SMILES 生成
- batched:模型会读取
input_file中分子的 PV,并使用 k-beam search 生成具有这些 PV 的分子。生成结果会写入generated_molecules.txt。python d_pv2smiles_batched.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --input_file './data/pubchem_1k_unseen.txt' --k 2 - single:模型接收一个查询 PV,并使用 k-beam search 生成
n_generate个满足该 PV 的分子。生成结果会写入generated_molecules.txt。你需要先在p2s_input.csv中构建输入 PV,可参考仓库提供的四个示例。python d_pv2smiles_single.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --n_generate 1000 --stochastic True --k 2
- batched:模型会读取
-
SMILES 到 PV 生成
模型会读取
input_file中的查询分子,并生成对应的 PV。python d_smiles2pv.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --input_file './data/pubchem_1k_unseen.txt' -
MoleculeNet + DILI 预测任务
d_regression.py、d_classification.py和d_classification_multilabel.py分别对应回归、二分类和多标签分类任务。python d_regression.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --name 'bace' python d_classification.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --name 'bbbp' python d_classification_multilabel.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --name 'clintox' -
正向/逆向反应预测任务
d_rxn_prediction.py可在 USPTO-480k 和 USPTO-50k 数据集上执行正向反应预测与逆合成预测。例如:正向反应预测,不使用 beam search
python d_rxn_prediction.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --mode 'forward' --n_beam 1例如:逆反应预测,使用 k=3 的 beam search
python d_rxn_prediction.py --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --mode 'retro' --n_beam 3 -
使用分子/溶剂双输入的 FluoDB 荧光性质回归
项目已经可以扩展为支持本地 FluoDB 风格的数据划分,其中每个样本包含
smiles、solvent以及一个目标列。当前仓库已兼容本地FlourDB/目录结构:abs_train.csv,abs_valid.csv,abs_test.csvemi_train.csv,emi_valid.csv,emi_test.csvplqy_train.csv,plqy_valid.csv,plqy_test.csve_train.csv,e_valid.csv,e_test.csv
示例:
python d_fluodb_regression.py --task emi --data_dir './FlourDB' --checkpoint './Pretrain/checkpoint_SPMM.ckpt' --device cuda可选任务包括
abs、emi、plqy和e。脚本会分别编码荧光团与溶剂,再融合两者嵌入,并将最佳检查点保存到./output/FluoDB/。 -
面向已训练 SPMM 回归器的 FluoDB 原子级解释
在训练好目标相关模型后,可以通过原子级 token masking 估计哪些荧光团原子会推动预测性质升高或降低。下面给出发射模型的解释示例:
python d_fluodb_explain.py --targets emi --checkpoint './output/FluoDB/emi_best.pth' --fluorophore 'Cc1ccc(C(=O)c2cc(C(=O)O)cc3c2CCN3c2c(Cl)cccc2Cl)cc1' --solvent 'O' --device cpu若要同时解释四个目标,请先训练四个检查点,然后运行:
python d_fluodb_explain.py --targets abs emi plqy e --model_dir './output/FluoDB' --fluorophore 'Cc1ccc(C(=O)c2cc(C(=O)O)cc3c2CCN3c2c(Cl)cccc2Cl)cc1' --solvent 'O'脚本会将 CSV 归因表和 PNG 原子高亮图写入
./pred/spmm_fluodb_explain/。 -
优先级最高的优化我建议按这个顺序做:
多任务学习:现在 abs/emi/plqy/e 是四个单独模型。FluoDB 里这些性质有相关性,改成一个模型同时预测 4 个目标,通常会比单任务更稳,尤其是数据不大时。
更强的分子-溶剂融合:现在只是把两个 CLS embedding 拼起来,交互很浅。可以加入 cross-attention、bilinear pooling,或者至少加一个小的交互层,让分子和溶剂 token 级别互相注意,而不是只在最后拼向量。
使用 solvent 物化性质:溶剂只用 SMILES 编码可能不够。可以额外加入介电常数、极性、折射率、H-bond donor/acceptor、ET(30) 等溶剂描述符,再和 solvent embedding 拼接,荧光性质预测会更合理。
SMILES 增广:训练时对 fluorophore 做 randomized SMILES augmentation,可以提升泛化。项目里反应任务已经用过类似思路,FluoDB 这边目前没有做。
训练策略优化:可以加 layer-wise learning rate、先 freeze encoder 再 unfreeze、梯度裁剪、AMP、early stopping、Huber loss。当前只有保存 best checkpoint,没有真正 early stopping。
数据划分更严格:如果当前 train/valid/test 是随机切分,指标可能偏乐观。建议做 scaffold split,甚至 solvent split,用来测试模型对新骨架、新溶剂的泛化能力。
集成模型/不确定性:训练 3-5 个不同 seed 的模型做 ensemble,筛选生成分子时会比单模型可靠很多,也能给出预测方差。
针对目标做变换:plqy 是有界性质,直接 MSE 可能不理想;可以尝试 logit 变换或 beta-like 处理。abs/emi 也可以检查异常值,用 Huber loss 比 MSE 更抗离群点。
xbert.py与scheduler中带交叉注意力层的 BERT 代码修改自 ALBEF。- SMILES 增广代码来自 pysmilesutils。
