-
Notifications
You must be signed in to change notification settings - Fork 8
/
eval_rdkit_pkl.py
150 lines (125 loc) · 6.31 KB
/
eval_rdkit_pkl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Evaluate for rdkit mols generated by other methods
import torch
import argparse
import pickle
from datasets.qm9_dataset import QM9Dataset
from datasets.geom_dataset import GeomDrugDataset
from datasets.datasets_config import get_dataset_info
from evaluation import *
def rdmol_process(mols, dataset_info, only_2D=False):
    """Convert RDKit molecules into the tuple format consumed by the EDM metrics.

    Args:
        mols: iterable of RDKit ``Mol`` objects.
        dataset_info: dataset description dict. Only its ``'atom_encoder'``
            mapping (element symbol -> integer atom-type index) is read here.
        only_2D: selects the output layout (see Returns).

    Returns:
        A list with one tuple per input molecule, in input order.

        * ``only_2D`` false: ``(pos, atom_type)``. Here ``pos`` is
          ``mol.GetConformer().GetPositions()``.
        * ``only_2D`` true: ``(None, atom_type, edge_types, fc)``. Here
          ``edge_types`` is a symmetric ``(N, N)`` float tensor of bond codes
          (1 single, 2 double, 3 triple, 4 aromatic, 0 no bond), and ``fc`` is
          a 1-D tensor of per-atom formal charges.

        In both layouts ``atom_type`` is a 1-D tensor of encoded atom types.

    Raises:
        KeyError: if an atom symbol is missing from ``atom_encoder``, or (2D
            only) a bond is not single/double/triple/aromatic.
    """
    from rdkit.Chem.rdchem import BondType as BT

    # Integer codes for bond orders; 0 is reserved for "no bond".
    bond_codes = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
    symbol_to_type = dataset_info['atom_encoder']

    encoded = []
    for molecule in mols:
        types = torch.tensor([symbol_to_type[a.GetSymbol()] for a in molecule.GetAtoms()])

        if not only_2D:
            # 3D layout: coordinates from the molecule's conformer, no bond info.
            encoded.append((molecule.GetConformer().GetPositions(), types))
            continue

        # 2D layout: dense, symmetric bond-code matrix plus formal charges.
        n_atoms = molecule.GetNumAtoms()
        adjacency = torch.zeros((n_atoms, n_atoms))
        for b in molecule.GetBonds():
            i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
            code = bond_codes[b.GetBondType()]
            adjacency[i, j] = code
            adjacency[j, i] = code
        charges = torch.tensor([a.GetFormalCharge() for a in molecule.GetAtoms()])
        encoded.append((None, types, adjacency, charges))
    return encoded
if __name__ == "__main__":
from rdkit import RDLogger
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# random.seed(seed)
# Ignore info output by RDKit
RDLogger.DisableLog('rdApp.error')
RDLogger.DisableLog('rdApp.warning')
parser = argparse.ArgumentParser()
parser.add_argument('--pkl_path', type=str, default='../generated_samples/qm9_gschnet.pkl')
parser.add_argument('--dataset_name', type=str, default='qm9', help="'qm9', 'Geom_Drugs'")
parser.add_argument('--type', type=str, default='3D', help="'3D', '2D', 'both'")
parser.add_argument('--sub_geometry', type=eval, default=False, help='Substructure Geometry Evaluation.')
parser.add_argument('--root_path', type=str, default='data/', help='Data path')
args, unparsed_args = parser.parse_known_args()
root_path = args.root_path
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Dataset
if args.dataset_name == 'qm9':
dataset_root_path = root_path + 'QM9'
dataset = QM9Dataset(dataset_root_path)
dataset_info = get_dataset_info('qm9_with_h')
elif args.dataset_name == 'Geom_Drugs':
dataset_root_path = root_path + 'geom'
dataset = GeomDrugDataset(dataset_root_path, 'data_geom_drug_1.pt')
dataset_info = get_dataset_info('geom_with_h_1')
else:
raise ValueError("Invalid dataset name!")
# Split dataset
split_idx = dataset.get_idx_split()
train_ds = dataset.index_select(split_idx['train'])
test_ds = dataset.index_select(split_idx['test'])
train_mols = [train_ds[i].rdmol for i in range(len(train_ds))]
test_mols = [test_ds[i].rdmol for i in range(len(test_ds))]
# Build Evaluation metrics
EDM_metric = get_edm_metric(dataset_info, train_mols)
EDM_metric_2D = get_2D_edm_metric(dataset_info, train_mols)
mose_metric = get_moses_metrics(test_mols, n_jobs=32, device=device)
if args.sub_geometry:
sub_geo_mmd_metric = get_sub_geometry_metric(test_mols, dataset_info, dataset_root_path)
# Read pickles
with open(args.pkl_path, 'rb') as f:
samples = pickle.load(f)
print(args)
if args.type == '3D' or args.type == 'both':
# convert samples to processed mols
processed_mols = rdmol_process(samples, dataset_info, False)
# EDM stability evaluation metrics
stability_res, rdkit_res, sample_rdmols = EDM_metric(processed_mols)
print('Number of molecules: %d' % len(sample_rdmols))
print("Metric-3D || atom stability: %.4f, mol stability: %.4f, validity: %.4f, complete: %.4f," % (
stability_res['atom_stable'], stability_res['mol_stable'], rdkit_res['Validity'], rdkit_res['Complete']))
# Mose evaluation metrics
mose_res = mose_metric(sample_rdmols)
print("Metric-3D || FCD: %.4f" % (mose_res['FCD']))
# 3D geometry
if args.sub_geometry:
if args.type == 'both':
sub_geo_mmd_res = sub_geo_mmd_metric(samples)
else:
sub_geo_mmd_res = sub_geo_mmd_metric(sample_rdmols)
print("Metric-Align || Bond Length MMD: %.4f, Bond Angle MMD: %.4f, Dihedral Angle MMD: %.6f" % (
sub_geo_mmd_res['bond_length_mean'], sub_geo_mmd_res['bond_angle_mean'],
sub_geo_mmd_res['dihedral_angle_mean']))
# ## bond length
# bond_length_str = ''
# for sym in dataset_info['top_bond_sym']:
# bond_length_str += f"{sym}: %.4f " % sub_geo_mmd_res[sym]
# print(bond_length_str)
# ## bond angle
# bond_angle_str = ''
# for sym in dataset_info['top_angle_sym']:
# bond_angle_str += f'{sym}: %.4f ' % sub_geo_mmd_res[sym]
# print(bond_angle_str)
# ## dihedral angle
# dihedral_angle_str = ''
# for sym in dataset_info['top_dihedral_sym']:
# dihedral_angle_str += f'{sym}: %.6f ' % sub_geo_mmd_res[sym]
# print(dihedral_angle_str)
if args.type == '2D' or args.type == 'both':
# convert samples to processed mols
processed_mols = rdmol_process(samples, dataset_info, True)
stability_res, rdkit_res, complete_rdmols = EDM_metric_2D(processed_mols)
print("Metric-2D || atom stability: %.4f, mol stability: %.4f, validity: %.4f, complete: %.4f,"
" valid & unique: %.4f, valid & unique & novelty: %.4f" % (stability_res['atom_stable'], stability_res['mol_stable'],
rdkit_res['Validity'], rdkit_res['Complete'], rdkit_res['Unique'], rdkit_res['Novelty']))
mose_res = mose_metric(complete_rdmols)
print("Metric-2D || FCD: %.4f, SNN: %.4f, Frag: %.4f, Scaf: %.4f, IntDiv: %.4f" % (mose_res['FCD'],
mose_res['SNN'], mose_res['Frag'], mose_res['Scaf'], mose_res['IntDiv']))