In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from src.datasets import Dataset
from src.utils import complete_mappings
from torch.utils.data import Dataset as Dataset_torch
from utils.megan_utils import FEATURIZER_INITIALIZERS
import torch
from src.feat.megan_graph import MeganTrainingSamplesFeaturizer, get_sample_data_path
from src.split.basic_splits import DefaultSplit

from scipy import sparse
from src.datasets.uspto_hard import UsptoHard
from src.datasets.uspto_50k import Uspto50k
from tqdm import tqdm
import os
import pandas as pd
from rdkit import Chem

In [2]:
# key = 'megan-baseline'
key = 'uspto_50k_motif'
data_dir = '/gaozhangyang/experiments/MotifRetro/data'
dataset = Uspto50k(data_dir)
DEFAULT_SPLIT = DefaultSplit()
featurizer = MeganTrainingSamplesFeaturizer(n_jobs=128, max_n_steps=50,
                                            split=DEFAULT_SPLIT, key=key,
                                            action_order='bfs_randat',
                                            use_motif_action=1,
                                            use_motif_feature=0,
                                            vocab_path='/gaozhangyang/chenxran/MotifRetro/data/uspto_hard/uspto_bpe_300.txt')

action_vocab = featurizer.get_actions_vocabulary(dataset.feat_dir)
sample_data = sparse.load_npz(get_sample_data_path(featurizer.dir(dataset.feat_dir)))
# change datatype to int
sample_data = sample_data.astype(np.int32).toarray()

action_frequency = {}
action_vocab['action_tuples']
for i in tqdm(range(sample_data.shape[0])):
    # pass
    action_ind, atom_map1, atom_map2, n_nodes, is_atom_action = \
    sample_data[i, 0], sample_data[i, 1], sample_data[i, 2], sample_data[i, 3], sample_data[i, 4]
    this_action_ind, a1, a2 = action_ind, atom_map1, atom_map2

    action = str(action_vocab['action_tuples'][this_action_ind])
    
    if action not in action_frequency:
        action_frequency[action] = 1
    action_frequency[action] += 1

action_frequency = sorted(action_frequency.items(), key=lambda x: x[1], reverse=True)

100%|██████████| 800256/800256 [00:01<00:00, 597438.48it/s]


In [5]:
action_frequency[20:]

[("('add_motif', '*C(=O)OC(C)(C)C')", 853),
 ("('add_motif', '*OC(C)(C)C')", 738),
 ("('change_atom', (-1, 0, 0, 0))", 722),
 ("('change_atom', (1, 0, 0, 0))", 588),
 ("('add_motif', '*OCc1ccccc1')", 566),
 ("('change_bond', (12, 0))", 561),
 ("('add_motif', '*c1ccccc1')", 543),
 ("('add_motif', '*S(=O)=O')", 495),
 ("('add_motif', '*C(=O)OCC')", 489),
 ("('change_bond', (3, 0))", 480),
 ("('add_motif', '*NC(=O)OC(C)(C)C')", 454),
 ("('add_atom', ((1, 0), (8, -1, 0, 0, 0)))", 404),
 ("('add_atom', ((1, 0), (12, 1, 0, 0, 0)))", 291),
 ("('add_motif', '*Cc1ccccc1')", 274),
 ("('add_atom', ((1, 0), (50, 0, 0, 0, 0)))", 254),
 ("('add_motif', '*C(F)(F)F')", 243),
 ("('add_motif', '*[Si](C)(C)C')", 241),
 ("('add_motif', '*C(=O)OCc1ccccc1')", 238),
 ("('add_motif', '*OC')", 216),
 ("('add_motif', '*=C(C)OC(C)=O')", 208),
 ("('add_motif', '*CCCC')", 207),
 ("('add_motif', '*[Sn](CCCC)CCCC')", 205),
 ("('change_atom', (0, 0, 0, 1))", 185),
 ("('add_motif', '*C(C)(C)C')", 184),
 ("('add_motif'

In [18]:
count = 0
table = pd.read_csv(os.path.join(featurizer.dir(dataset.feat_dir), 'metadata.csv'))
with open(os.path.join(featurizer.dir(dataset.feat_dir), 'sanity-check.txt'), 'w') as f:
    for i in tqdm(range(len(table))):
        # print(table.loc[i, 'final_smi'])
        # print(table.loc[i, 'target_smi'])
        final_smi = table.loc[i, 'final_smi']
        target_smi = table.loc[i, 'target_smi']

        final_mol = Chem.MolFromSmiles(final_smi)
        if final_mol is None:
            f.write(f"No. {i} exmaples, final_smiles is invalid.\n")
            f.write(f"final_smi: {final_smi}\n")
            f.write(f"\n")
            count += 1
            continue

        for a in final_mol.GetAtoms():
            a.SetAtomMapNum(0)

        target_mol = Chem.MolFromSmiles(target_smi)
        if target_mol is None:
            f.write(f"No. {i} exmaples, target_smiles is invalid.\n")
            f.write(f"target_smi: {target_smi}\n")
            f.write(f"\n")
            count += 1
            continue

        for a in target_mol.GetAtoms():
            a.SetAtomMapNum(0)  

        final_smi_without_mapping = Chem.MolToSmiles(final_mol)
        target_smi_without_mapping = Chem.MolToSmiles(target_mol)
        if final_smi_without_mapping == target_smi_without_mapping:
            pass
        else:
            f.write(f"No. {i} exmaples TARGET and FINAL not same.\n")
            f.write(f"target_smiles: {target_smi}\n")
            f.write(f" final_smiles: {final_smi}\n")
            f.write(f"target_smiles w/o mapping: {target_smi_without_mapping}\n")
            f.write(f" final_smiles w/o mapping: {final_smi_without_mapping}\n")
            f.write("\n")
            count += 1
    print(count)

100%|██████████| 57659/57659 [00:47<00:00, 1212.56it/s]

4905





In [25]:
key = 'truncation_debug-truncate-top-50-without-atombond'
data_dir = '/gaozhangyang/chenxran/MotifRetro/data'
dataset = UsptoHard(data_dir)
DEFAULT_SPLIT = DefaultSplit()
featurizer = MeganTrainingSamplesFeaturizer(n_jobs=128, max_n_steps=50,
                                            split=DEFAULT_SPLIT, key=key,
                                            action_order='bfs_randat',
                                            use_motif_action=1,
                                            use_motif_feature=0,
                                            vocab_path='/gaozhangyang/chenxran/MotifRetro/data/uspto_hard/uspto_bpe_300.txt')

action_vocab = featurizer.get_actions_vocabulary(dataset.feat_dir)
sample_data = sparse.load_npz(get_sample_data_path(featurizer.dir(dataset.feat_dir)))
# change datatype to int
sample_data = sample_data.astype(np.int32).toarray()

action_frequency = {}
action_vocab['action_tuples']
for i in tqdm(range(sample_data.shape[0])):
    # pass
    action_ind, atom_map1, atom_map2, n_nodes, is_atom_action = \
    sample_data[i, 0], sample_data[i, 1], sample_data[i, 2], sample_data[i, 3], sample_data[i, 4]
    this_action_ind, a1, a2 = action_ind, atom_map1, atom_map2

    action = str(action_vocab['action_tuples'][this_action_ind])
    
    if action not in action_frequency:
        action_frequency[action] = 1
    action_frequency[action] += 1

action_frequency = sorted(action_frequency.items(), key=lambda x: x[1], reverse=True)

100%|██████████| 3203300/3203300 [00:05<00:00, 541735.45it/s]


In [None]:
action_frequency

In [26]:
count = 0
table = pd.read_csv(os.path.join(featurizer.dir(dataset.feat_dir), 'metadata.csv'))
with open(os.path.join(featurizer.dir(dataset.feat_dir), 'sanity-check.txt'), 'w') as f:
    for i in tqdm(range(len(table))):
        # print(table.loc[i, 'final_smi'])
        # print(table.loc[i, 'target_smi'])
        final_smi = table.loc[i, 'final_smi']
        target_smi = table.loc[i, 'target_smi']

        final_mol = Chem.MolFromSmiles(final_smi)
        if final_mol is None:
            f.write(f"No. {i} exmaples, final_smiles is invalid.\n")
            f.write(f"final_smi: {final_smi}\n")
            f.write(f"\n")
            count += 1
            continue

        for a in final_mol.GetAtoms():
            a.SetAtomMapNum(0)

        target_mol = Chem.MolFromSmiles(target_smi)
        if target_mol is None:
            f.write(f"No. {i} exmaples, target_smiles is invalid.\n")
            f.write(f"target_smi: {target_smi}\n")
            f.write(f"\n")
            count += 1
            continue

        for a in target_mol.GetAtoms():
            a.SetAtomMapNum(0)  

        final_smi_without_mapping = Chem.MolToSmiles(final_mol)
        target_smi_without_mapping = Chem.MolToSmiles(target_mol)
        if final_smi_without_mapping == target_smi_without_mapping:
            pass
        else:
            f.write(f"No. {i} exmaples TARGET and FINAL not same.\n")
            f.write(f"target_smiles: {target_smi}\n")
            f.write(f" final_smiles: {final_smi}\n")
            f.write(f"target_smiles w/o mapping: {target_smi_without_mapping}\n")
            f.write(f" final_smiles w/o mapping: {final_smi_without_mapping}\n")
            f.write("\n")
            count += 1
    print(count)

100%|██████████| 57579/57579 [00:46<00:00, 1229.08it/s]

6729





In [1]:
import torch


ckpt = torch.load("/gaozhangyang/experiments/MotifRetro/results/uspto_50k_frag_motif_gate_num_embed_linear2023-01-11 09:20:58.942730/checkpoint.pth")

In [5]:
print(ckpt.keys())

odict_keys(['module.atom_embedding.weight', 'module.atom_embedding.bias', 'module.bond_embedding.weight', 'module.bond_embedding.bias', 'module.degree_embedding.weight', 'module.synthon_embedding.weight', 'module.encoder.MultiHeadGraphConv_1.atoms_att.weight', 'module.encoder.MultiHeadGraphConv_1.atoms_att.bias', 'module.encoder.MultiHeadGraphConv_1.v_layer.weight', 'module.encoder.MultiHeadGraphConv_1.v_layer.bias', 'module.encoder.MultiHeadGraphConv_1.final_att.weight', 'module.encoder.MultiHeadGraphConv_1.final_att.bias', 'module.encoder.MultiHeadGraphConv_1.conv_layer.weight', 'module.encoder.MultiHeadGraphConv_1.conv_layer.bias', 'module.encoder.MultiHeadGraphConv_1.motif_gate.weight', 'module.encoder.MultiHeadGraphConv_1.motif_gate.bias', 'module.encoder.MultiHeadGraphConv_1.num_embed.weight', 'module.encoder.MultiHeadGraphConv_1.num_embed.bias', 'module.encoder.MultiHeadGraphConv_2.atoms_att.weight', 'module.encoder.MultiHeadGraphConv_2.atoms_att.bias', 'module.encoder.MultiHead

In [4]:
new_ckpt = {}
for k,v in ckpt.items():
    new_ckpt[k.replace('module.', '')] = v
torch.save(new_ckpt, 'checkpoint.pth')

In [1]:
# open json /gaozhangyang/experiments/MotifRetro/data/uspto_50k/decompose_edit_path.json
import json

with open("/gaozhangyang/experiments/MotifRetro/data/uspto_50k/decompose_edit_path.json", "r") as f:
    lg_action_paths = json.load(f)

with open("/gaozhangyang/experiments/MotifRetro/data/uspto_50k/feat/uspto_50k_frag/all_actions.json", "r") as f:
    actions = json.load(f)


In [4]:
for k, v in lg_action_paths.items():
    for i, action in v['path'].items():
        action
    print(k, v['path'])
    break

*C(=O)OC(C)(C)C {'0': [{'*C(=O)OC(C)(C)C': {'attach': 0, 'am_smi': '*[C:1](=[O:2])[O:3][C:4]([CH3:5])([CH3:6])[CH3:7]', 'freq': 2416}}]}


In [5]:
actions

[['add_motif', '*C(=O)OC(C)(C)C'],
 ['stop', None],
 ['change_bond', [None, None]],
 ['add_motif', '*Cl'],
 ['add_motif', '*Br'],
 ['add_motif', '*B(O)O'],
 ['change_atom', [0, 0, 1, 1]],
 ['change_atom', [0, 1, 1, 0]],
 ['add_motif', '*O'],
 ['add_motif', '*OS(=O)(=O)C(F)(F)F'],
 ['add_motif', '*=O'],
 ['change_bond', [3, 0]],
 ['change_bond', [1, 0]],
 ['add_motif', '*B1OC(C)(C)C(C)(C)O1'],
 ['change_atom', [0, 2, 1, 0]],
 ['add_motif', '*[Cu]'],
 ['add_motif', '*C(C)(C)C'],
 ['add_motif', '*CC'],
 ['add_motif', '*I'],
 ['change_atom', [1, 0, 0, 0]],
 ['add_motif', 'O=[*+][O-]'],
 ['add_motif', '*C'],
 ['change_bond', [2, 3]],
 ['change_bond', [2, 0]],
 ['add_motif', '*OS(C)(=O)=O'],
 ['add_motif', '*=[N+]=[N-]'],
 ['add_motif', '*Cc1ccccc1'],
 ['add_motif', '*[Sn](CCCC)(CCCC)CCCC'],
 ['change_atom', [0, 0, 0, 0]],
 ['add_motif', '*C(=O)OCc1ccccc1'],
 ['add_motif', '*F'],
 ['add_motif', '*OC(=O)OC(C)(C)C'],
 ['change_atom', [0, 0, 0, 1]],
 ['add_motif', '*[Si](C)(C)C(C)(C)C'],
 ['cha

In [2]:
import torch

for i in range(2):
    bs_atom = torch.load(f"beam-search_atom_{i}.pt")
    bs_bond = torch.load(f"beam-search_bond_{i}.pt")

    valid_atom = torch.load(f"valid_atom_{i}.pt")
    valid_bond = torch.load(f"valid_bond_{i}.pt")

    print(f"step {i}")
    print(bs_atom.shape, valid_atom.shape)
    print(bs_bond.shape, valid_bond.shape)
    print((bs_atom.cpu() != valid_atom.cpu()).sum())
    print((bs_bond.cpu() != valid_bond.cpu()).sum())

step 0
torch.Size([12, 39]) torch.Size([12, 39])
torch.Size([56, 14]) torch.Size([56, 14])
tensor(0)
tensor(0)
step 1
torch.Size([15, 39]) torch.Size([15, 39])
torch.Size([71, 14]) torch.Size([71, 14])
tensor(6)
tensor(8)


In [23]:
id = 60
(bs_bond.cpu() != valid_bond.cpu()).sum(1)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 4, 0, 0, 4, 0])

In [16]:
valid_bond

tensor([[0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0

In [8]:
from rdkit import Chem


mol = Chem.MolFromSmiles("[CH3:1][C:2]([CH3:3])([CH3:4])[O:5][C:6](=[O:7])[O:35][C:33]([O:32][C:29]([CH3:28])([CH3:30])[CH3:31])=[O:34].[CH3:8][C@@H:9]1[CH2:10][C@@H:11]([CH2:12][c:13]2[cH:14][cH:15][c:16](-[c:17]3[cH:18][cH:19][cH:20][cH:21][cH:22]3)[cH:23][cH:24]2)[NH:25][C:26]1=[O:27]")
for atom in mol.GetAtoms():
    atom.SetAtomMapNum(0)
print(Chem.MolToSmilses(mol))
print(Chem.CanonSmiles(Chem.MolToSmiles(mol)))
print(Chem.CanonSmiles("CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.C[C@H]1C[C@H](Cc2ccc(-c3ccccc3)cc2)NC1=O"))

CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.C[C@@H]1C[C@@H](Cc2ccc(-c3ccccc3)cc2)NC1=O
CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.C[C@@H]1C[C@@H](Cc2ccc(-c3ccccc3)cc2)NC1=O
CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.C[C@H]1C[C@H](Cc2ccc(-c3ccccc3)cc2)NC1=O
