In [29]:
# 写一个函数，入参是reactants SMILES
# 输出是predicted SMILES, mol_graph, diff graph


In [30]:
# 把入参的SMILES处理成model可以接收的形式：（batch size = 1）的dataloader
# 同时，输出src的element, mask, bond, aroma, charge
from dataset import TransformerDataset
from preprocess import molecule
from torch.utils.data import DataLoader
from rdkit import Chem

def map_atoms_in_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        return None

    atom_map = {}
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(atom.GetIdx() + 1)
        atom_map[atom.GetIdx()] = atom.GetIdx() + 1

    mapped_smiles = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
    return mapped_smiles

def process_smiles(smiles, map_atoms=False):
    if map_atoms:
        smiles = map_atoms_in_smiles(smiles)
    reactant_mols = [Chem.MolFromSmiles(item) for item in smiles.split(".")]
    reactant_len = Chem.MolFromSmiles(smiles).GetNumAtoms()

    reactant_features = molecule(reactant_mols, reactant_len)

    element = reactant_features['element']
    mask = reactant_features['mask']
    bond = reactant_features['bond']
    aroma = reactant_features['aroma']
    charge = reactant_features['charge']

    input_data = {}
    for key in reactant_features:
        if key in ["element", "reactant"]:
            input_data[key] = reactant_features[key]
        else:
            input_data['src_'+key] = reactant_features[key]

    print('reactant')
    print(reactant_features['reactant'])
    print('src mask')
    print(reactant_features['mask'])

    data = [input_data]
    full_dataset = TransformerDataset(False, data)

    data_loader = DataLoader(full_dataset,
                             batch_size=1,
                             num_workers=4, collate_fn=TransformerDataset.collate_fn)

    return data_loader, element, mask, bond, aroma, charge


In [31]:
# SM = "[CH2:23]1[O:24][CH2:25][CH2:26][CH2:27]1.[F:1][c:2]1[c:3]([N+:10](=[O:11])[O-:12])[cH:4][c:5]([F:9])[c:6]([F:8])[cH:7]1.[H-:22].[NH2:13][c:14]1[s:15][cH:16][cH:17][c:18]1[C:19]#[N:20].[Na+:21]"
# data_loader, element, mask, bond, aroma, charge = process_smiles(SM)

In [32]:
# 利用result = map(result2mol, arg_list) ； 输出为： mol, smile, check(smile), mol_graph
from utils import result2mol
from model import *

def init_model(save_path, checkpoint):
    state_dict = {}
    map_location = {'cuda:%d' % 0: 'cuda:%d' % 0}
    checkpoint = torch.load(os.path.join(save_path, checkpoint), map_location=map_location)
    for key in checkpoint['model_state_dict']:
        if key in state_dict:
            state_dict[key] += checkpoint['model_state_dict'][key]
        else:
            state_dict[key] = checkpoint['model_state_dict'][key]

    model = MoleculeVAE(None, 100, 192, 6).to(0) # TODO
    # model.load_state_dict(state_dict)

    return model

def predict(data_loader,
            save_path='./CKPT/no_reactant_mask/', checkpoint="epoch-7-loss-2.3548229463048114", temperature=1):

    model = init_model(save_path, checkpoint)

    for data in data_loader: # 只有1个
        data_gpu = {}
        for key in data:
            data_gpu[key] = data[key].to(0)

        predicted_dict = model('sample', data_gpu, temperature)

        element = data['element']
        src_mask = data['src_mask']
        pred_bond = predicted_dict['bond'].cpu()
        pred_aroma, pred_charge = predicted_dict['aroma'].cpu(), predicted_dict['charge'].cpu()

        arg_list = [(element[j], src_mask[j], pred_bond[j], pred_aroma[j], pred_charge[j], None) for j in
                range(1)]

        res = map(result2mol, arg_list)
        res = list(res)

        for item in res:
            mol, smile, valid, mol_graph = item[0], item[1], item[2], item[3]

        return mol, smile, valid, mol_graph, pred_bond, pred_aroma, pred_charge

In [33]:
# pred_mol, pred_smile, pred_valid, pred_mol_graph, pred_bond, pred_aroma, pred_charge = predict(data_loader)

In [34]:
# 根据result的element, mask, bond, aroma, charge 以及 src的element, mask, bond, aroma, charge
# 调用get_diff_adj，获得diff graph
from utils import get_diff_adj

# diff, adj_src, adj_pred = get_diff_adj(element, mask, bond, aroma, charge,
#              pred_bond.squeeze(), pred_aroma.squeeze(), pred_charge.squeeze())



In [35]:
# diff.sum()

In [36]:
# bond

In [37]:
# pred_bond.numpy()

In [38]:
# adj_src

In [39]:
# adj_pred

# 代碼

In [40]:
def pred_from_smiles(smiles):
    dl, element, mask, bond, aroma, charge = process_smiles(smiles, False)

    pred_mol, pred_smile, pred_valid, pred_mol_graph, pred_bond, pred_aroma, pred_charge = predict(dl)

    diff, adj_src, adj_pred = get_diff_adj(element, mask, bond, aroma, charge,
             pred_bond.squeeze(), pred_aroma.squeeze(), pred_charge.squeeze())

    return pred_smile, diff, adj_src, adj_pred


In [51]:
pred_smile, diff, adj_src, adj_pred = pred_from_smiles("[CH2:1]([CH3:2])[n:3]1[cH:4][c:5]([C:22](=[O:23])[OH:24])[c:6](=[O:21])[c:7]2[cH:8][c:9]([F:20])[c:10](-[c:13]3[cH:14][cH:15][c:16]([NH2:19])[cH:17][cH:18]3)[cH:11][c:12]12")

# [OH:1][c:2]1[n:3][cH:4][c:5]([C:6](=[O:7])[CH2:15][CH:16]([CH3:17])[CH3:18])[cH:12][cH:13]1

reactant
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
src mask
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
reactant ************************************************************
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')
src_mask ************************************************************
tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False]], device='cuda:0')
smile ********************
[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12


[12:41:48] SMILES Parse Error: syntax error while parsing: [C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12
[12:41:48] SMILES Parse Error: Failed parsing SMILES '[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12' for input: '[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12'


smile ********************
[CH2:1]([CH3:2])[n:3]1[cH:4][c:5]([C:22](=[O:23])[OH:24])[c:6](=[O:21])[c:7]2[cH:8][c:9]([F:20])[c:10]([c:13]3[cH:14][cH:15][c:16]([NH2:19])[cH:17][cH:18]3)[cH:11][c:12]12
smile ********************
[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12


[12:41:48] SMILES Parse Error: syntax error while parsing: [C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12
[12:41:48] SMILES Parse Error: Failed parsing SMILES '[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12' for input: '[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12'


In [52]:
pred_smile
# [CH3:14][NH2:15].[N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[Cl:13].[OH2:16]
# [N+:1](=[O:2])([O-:3])[c:4]1[cH:5][c:6]([C:7](=[O:8])[OH:9])[cH:10][cH:11][c:12]1[NH:15][CH3:14]

'[C-2:1]([CH-2:2])[n-2:3]1[c-2:4][c-2:5]([C+5:22](=[O-2:23])[O-2:24])[c-2:6](=[O-2:21])[c-2:7]2[c-2:8][c-2:9]([f-2:20])[c:10]([c-2:13]3[c-2:14][C-2:15][c+5:16]([n-2:19])[c-6:17][c+5:18]3)[c-2:11][c+5:12]12'

In [53]:
diff

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0

## 转化为文字解释

In [43]:
def get_reaction_info(elements, diff):
    element_symbols = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Ni", "Co", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"]

    reaction_info = []

    for i in range(len(diff)):
        for j in range(i+1, len(diff)):
            if diff[i][j] == 1:
                reaction_info.append(f"Formation of bond between atom {i+1} ({element_symbols[elements[i]]} element) and atom {j+1} ({element_symbols[elements[j]]} element)")
            elif diff[i][j] == -1:
                reaction_info.append(f"Breaking of bond between atom {i+1} ({element_symbols[elements[i]]} element) and atom {j+1} ({element_symbols[elements[j]]} element)")

    return reaction_info

# Example inputs
elements = [6, 8, 1, 1, 6, 1]  # C, O, H, H, C, H
diff = [
    [0, 1, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0]
]

reaction_info = get_reaction_info(elements, diff)
for info in reaction_info:
    print(info)


Formation of bond between atom 1 (N element) and atom 2 (F element)
Formation of bond between atom 1 (N element) and atom 5 (N element)
