-
Notifications
You must be signed in to change notification settings - Fork 4
/
canonicalize_prod.py
151 lines (121 loc) · 5.11 KB
/
canonicalize_prod.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Canonicalize the product SMILES, and then use substructure matching to infer
the correspondence to the original atom-mapped order. This correspondence is then
used to renumber the reactant atoms.
"""
from rdkit import Chem
import os
import argparse
import pandas as pd
def canonicalize_prod(p):
    """Canonicalize product SMILES ``p`` and number its atoms in canonical order.

    Existing atom maps are discarded by :func:`canonicalize`. The canonical
    molecule's atom with index ``i`` then receives atom-map number ``i + 1``,
    and the resulting mapped SMILES is returned.

    Args:
        p: product SMILES string (atom-mapped or not).

    Returns:
        The mapped SMILES of the canonicalized product.

    Raises:
        ValueError: if ``p`` cannot be parsed by RDKit.
    """
    # ``p`` is expected to be a str, which is immutable, so no defensive copy
    # is needed. canonicalize() does not mutate its argument.
    canon_smiles = canonicalize(p)
    p_mol = Chem.MolFromSmiles(canon_smiles)
    if p_mol is None:
        # canonicalize() returns unparseable input unchanged; fail loudly
        # here instead of with an AttributeError on ``None`` below.
        raise ValueError(f'could not parse product SMILES: {p!r}')
    for atom in p_mol.GetAtoms():
        atom.SetAtomMapNum(atom.GetIdx() + 1)
    return Chem.MolToSmiles(p_mol)
def canonicalize(smiles):
    """Return the RDKit canonical SMILES of ``smiles`` with atom maps removed.

    The molecule is parsed, hydrogens are stripped with ``Chem.RemoveHs``,
    and every atom's ``molAtomMapNumber`` property is cleared before the
    canonical SMILES is written.

    This is best effort: if RDKit raises while parsing, ``'no mol'`` is
    printed and the input is returned unchanged. If parsing simply fails
    (``None``), the input is also returned unchanged.
    """
    try:
        tmp = Chem.MolFromSmiles(smiles)
    except Exception:
        # RDKit raises (rather than returning None) for some inputs,
        # presumably non-string values. A bare ``except:`` would also
        # swallow KeyboardInterrupt/SystemExit, so catch Exception only.
        print('no mol', flush=True)
        return smiles
    if tmp is None:
        return smiles
    tmp = Chem.RemoveHs(tmp)
    for atom in tmp.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(tmp)
def fix_charge(mol):
    """Neutralize a few simple charged groups of ``mol`` in place and return it.

    Handled patterns (e.g. 'COO-', 'CH3O-', '(S=O)O-', '-NH3+', 'NH4+',
    'NH2+', 'S-'):

    * O(-) with bond-order sum 1 and no explicit H becomes OH, unless its
      first neighbour is N (presumably to keep nitro / N-oxide groups
      charged; TODO confirm).
    * N(+) with (bond-order sum, explicit H) of (1, 3), (0, 4) or (2, 2)
      becomes neutral N with one explicit H fewer.
    * S(-) with bond-order sum 1 and no explicit H becomes SH.

    Bond-order sums are truncated with ``int`` (a lone aromatic 1.5 counts
    as 1).
    """
    # N(+) rule table: (bond-order sum, explicit H count) -> H count once neutral.
    ammonium_h_after = {(1, 3): 2, (0, 4): 3, (2, 2): 1}
    for atom in mol.GetAtoms():
        n_hs = atom.GetNumExplicitHs()
        formal = atom.GetFormalCharge()
        valence = int(sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds()))
        element = atom.GetSymbol()

        if element == 'O':
            if (valence == 1 and formal == -1 and n_hs == 0
                    and atom.GetNeighbors()[0].GetSymbol() != 'N'):
                atom.SetFormalCharge(0)
                atom.SetNumExplicitHs(1)
        elif element == 'N':
            if formal == 1 and (valence, n_hs) in ammonium_h_after:
                atom.SetFormalCharge(0)
                atom.SetNumExplicitHs(ammonium_h_after[(valence, n_hs)])
        elif element == 'S':
            if formal == -1 and n_hs == 0 and valence == 1:
                atom.SetNumExplicitHs(1)
                atom.SetFormalCharge(0)
    return mol
def infer_correspondence(p):
    """Map each atom-map number of product ``p`` to its canonical map number.

    ``p`` is parsed as written and is also canonicalized with
    :func:`canonicalize_prod`, which numbers atoms ``idx + 1`` in canonical
    order. A substructure match of the original molecule onto the canonical
    one pairs up atoms, giving ``{original_map_num: canonical_map_num}``.

    If several atoms of ``p`` share a map number (e.g. 0 for unmapped atoms),
    the one with the highest index wins.

    Returns:
        The correspondence dict. It is empty when ``p`` cannot be parsed or
        no substructure match is found.
    """
    orig_mol = Chem.MolFromSmiles(p)
    if orig_mol is None:
        return {}
    canon_mol = Chem.MolFromSmiles(canonicalize_prod(p))
    # Only the first match is needed, so ask RDKit for exactly one instead of
    # enumerating every symmetry-equivalent match. match[i] is the canonical
    # index of original atom i; the tuple is empty when there is no match.
    match = canon_mol.GetSubstructMatch(orig_mol)
    return {
        orig_atom.GetAtomMapNum(): canon_mol.GetAtomWithIdx(canon_idx).GetAtomMapNum()
        for orig_atom, canon_idx in zip(orig_mol.GetAtoms(), match)
    }
def remap_rxn_smi(rxn_smi):
    """Renumber the atom maps of ``rxn_smi`` to follow the canonical product.

    Steps:

    1. The product is canonicalized with :func:`canonicalize_prod` (maps
       1..N in canonical order).
    2. :func:`infer_correspondence` gives old -> new map numbers, and these
       are applied to the reactants.
    3. Reactant atoms left without a usable number receive fresh numbers
       above every number in use, so no map number is duplicated.
    4. Simple charged groups on both sides are neutralized with
       :func:`fix_charge`.

    Args:
        rxn_smi: reaction SMILES of the form ``reactants>>product``.

    Returns:
        ``(new_rxn_smi, correspondence)``, or ``(rxn_smi, None)`` unchanged
        when the reactants fail to parse or have at most one atom.
    """
    r, p = rxn_smi.split(">>")
    canon_mol = Chem.MolFromSmiles(canonicalize_prod(p))
    correspondence = infer_correspondence(p)
    rmol = Chem.MolFromSmiles(r)
    if rmol is None or rmol.GetNumAtoms() <= 1:
        return rxn_smi, None

    # Map numbers now owned by the renumbered product atoms.
    product_maps = set(correspondence.values())
    for atom in rmol.GetAtoms():
        old_num = atom.GetAtomMapNum()
        if old_num != 0 and old_num in correspondence:
            # Map 0 means "unmapped"; it must never be translated, otherwise
            # every unmapped reactant atom would share one product number.
            atom.SetAtomMapNum(correspondence[old_num])
        elif old_num in product_maps:
            # An unmatched reactant atom whose stale number equals a new
            # product number would create a duplicate map; renumber it below.
            atom.SetAtomMapNum(0)

    # Give every unmapped reactant atom a fresh number above all used ones
    # (including product numbers with no reactant counterpart).
    max_amap = max([atom.GetAtomMapNum() for atom in rmol.GetAtoms()]
                   + list(product_maps))
    for atom in rmol.GetAtoms():
        if atom.GetAtomMapNum() == 0:
            max_amap += 1
            atom.SetAtomMapNum(max_amap)

    # fix simple atomic charge, eg. 'COO-', 'CH3O-', '(S=O)O-', '-NH3+', 'NH4+', 'NH2+', 'S-'
    rmol = fix_charge(rmol)
    canon_mol = fix_charge(canon_mol)
    # Round-trip the reactants through SMILES, presumably so RDKit
    # re-perceives the atoms edited by fix_charge.
    rmol = Chem.MolFromSmiles(Chem.MolToSmiles(rmol))
    rxn_smi_new = Chem.MolToSmiles(rmol) + ">>" + Chem.MolToSmiles(canon_mol)
    return rxn_smi_new, correspondence
def main():
    """Canonicalize one split of a USPTO reaction dataset from the command line.

    Reads ``data/<dataset>/raw_<mode>.csv`` (``<dataset>`` lower-cased) and
    passes each reaction through :func:`remap_rxn_smi`. It writes
    ``data/<dataset>/canonicalized_<mode>.csv``, copying the ``id`` column
    (plus ``class`` for USPTO_50k) alongside the remapped reactions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='USPTO_50k',
                        help='dataset: USPTO_50k or USPTO_full')
    parser.add_argument('--mode', type=str, default='train',
                        help='Type of dataset being prepared: train or valid or test')
    args = parser.parse_args()
    args.dataset = args.dataset.lower()

    rxn_col = 'reactants>reagents>production'
    # Only USPTO_50k has a reaction-class column to carry through.
    passthrough_cols = ['id', 'class'] if args.dataset == 'uspto_50k' else ['id']

    datadir = f'data/{args.dataset}/'
    df = pd.read_csv(os.path.join(datadir, f'raw_{args.mode}.csv'))
    print(f"Processing file of size: {len(df)}")

    out_cols = {name: [] for name in passthrough_cols + [rxn_col]}
    for _row_label, row in df.iterrows():
        kept = [row[name] for name in passthrough_cols]
        remapped, _ = remap_rxn_smi(row[rxn_col])
        for name, value in zip(passthrough_cols, kept):
            out_cols[name].append(value)
        out_cols[rxn_col].append(remapped)

    pd.DataFrame.from_dict(out_cols).to_csv(
        os.path.join(datadir, f'canonicalized_{args.mode}.csv'), index=False)


if __name__ == "__main__":
    main()