# Link to GitHub Repo

In [None]:
#https://github.com/zengkaipeng/UAlign

# Modified Graph Util Script

In [11]:
from ogb.utils.features import (
    allowable_features, atom_to_feature_vector, bond_feature_vector_to_dict,
    bond_to_feature_vector, atom_feature_vector_to_dict
)
import torch
import rdkit
from rdkit import Chem
import numpy as np

def smiles2graph(smiles_string, with_amap=False):
    """
    Converts a SMILES string into a graph dictionary of numpy arrays.

    :param smiles_string: SMILES string (str)
    :param with_amap: when True, atoms whose atom-map number is 0 are given
        fresh map numbers (continuing after the largest one present) and a
        mapping ``{atom map number: atom index}`` is returned as well
    :return: ``None`` if RDKit cannot parse the SMILES; otherwise a dict with
        keys ``edge_index`` (2 x E), ``edge_feat`` (E x 3), ``node_feat``
        (N x 9) and ``num_nodes``; when ``with_amap`` is True the result is
        the tuple ``(graph, amap_idx)``
    """
    molecule = Chem.MolFromSmiles(smiles_string)
    if molecule is None:
        # Unparsable SMILES: the caller is expected to check for None.
        return None

    atoms = list(molecule.GetAtoms())
    bonds = list(molecule.GetBonds())

    if with_amap:
        amap_idx = {}
        if atoms:
            # Unmapped atoms (map number 0) receive consecutive new numbers
            # starting right after the current maximum.
            next_amap = max(a.GetAtomMapNum() for a in atoms) + 1
            for a in atoms:
                if a.GetAtomMapNum() == 0:
                    a.SetAtomMapNum(next_amap)
                    next_amap += 1
            amap_idx = {a.GetAtomMapNum(): a.GetIdx() for a in atoms}

    # Node features: one row per atom; an explicit empty shape keeps the
    # column count stable for molecules without atoms.
    num_atom_features = 9
    if atoms:
        x = np.array(
            [atom_to_feature_vector(a) for a in atoms], dtype=np.int64
        )
    else:
        x = np.empty((0, num_atom_features), dtype=np.int64)

    # Edge features: every bond is stored twice, once per direction, and both
    # directions share the same feature vector.
    num_bond_features = 3
    if bonds:
        pairs, pair_features = [], []
        for bond in bonds:
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feature = bond_to_feature_vector(bond)
            pairs.extend(((begin, end), (end, begin)))
            pair_features.extend((feature, feature))
        edge_index = np.array(pairs, dtype=np.int64).T
        edge_attr = np.array(pair_features, dtype=np.int64)
    else:
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    graph = {
        'edge_index': edge_index,
        'edge_feat': edge_attr,
        'node_feat': x,
        'num_nodes': len(x),
    }

    return (graph, amap_idx) if with_amap else graph


# Modified Inference Script

In [None]:

from tqdm import tqdm
import torch
import argparse
import json
import pickle
from model import PretrainModel, PositionalEncoding
from data_utils import fix_seed
from torch.nn import TransformerDecoderLayer, TransformerDecoder
from sparse_backBone import GATBase
from utils.chemistry_parse import clear_map_number, canonical_smiles
from utils.graph_utils import smiles2graph
import pandas as pd
import torch_geometric
from inference_tools import beam_search_one
import time
import os
import rdkit
from rdkit import RDLogger

def make_graph_batch(smi, rxn=None):
    """
    Wraps a single molecule as a one-graph ``torch_geometric`` batch.

    :param smi: SMILES string of the molecule
    :param rxn: optional value; when given, ``node_rxn`` / ``edge_rxn``
        tensors filled with it are attached (one entry per node / edge)
    :return: ``None`` when ``smiles2graph`` returns ``None`` for ``smi``,
        otherwise a ``torch_geometric.data.Data`` object carrying the graph
        tensors plus the batch bookkeeping fields for a batch of size one
    """
    mol_graph = smiles2graph(smi, with_amap=False)
    if mol_graph is None:
        return None

    node_feat = mol_graph['node_feat']
    edge_feat = mol_graph['edge_feat']
    edge_index = mol_graph['edge_index']
    n_atoms, n_edges = node_feat.shape[0], edge_index.shape[1]

    # Only one graph is present, so every node and edge belongs to batch
    # entry 0 and the pointer vectors span the whole range.
    fields = dict(
        x=torch.from_numpy(node_feat),
        num_nodes=n_atoms,
        edge_attr=torch.from_numpy(edge_feat),
        edge_index=torch.from_numpy(edge_index),
        ptr=torch.LongTensor([0, n_atoms]),
        e_ptr=torch.LongTensor([0, n_edges]),
        batch=torch.zeros(n_atoms, dtype=torch.long),
        e_batch=torch.zeros(n_edges, dtype=torch.long),
        batch_mask=torch.ones(1, n_atoms, dtype=torch.bool),
    )

    if rxn is not None:
        fields['node_rxn'] = torch.ones(n_atoms).long() * rxn
        fields['edge_rxn'] = torch.ones(n_edges).long() * rxn

    return torch_geometric.data.Data(**fields)

# Inference configuration. The values mirror the command-line flags of the
# original inference script, hard-coded here for notebook use.
args = argparse.Namespace(
    dim=768,
    n_layer=8,
    heads=12,
    negative_slope=0.2,
    seed=2023,
    device=0,
    checkpoint='model.pth',
    token_ckpt='token.pkl',
    use_class=False,
    max_len=100,
    beams=1,
    product_smiles='product_smiles_test.csv',
    input_class=-1,
    org_output=False,
    output_file='results.csv'
)

print(args)

# Validate the configuration up front with explicit exceptions: `assert` is
# stripped under `python -O`, and the tokenizer path must be checked before
# it is opened, not afterwards.
if args.token_ckpt == '':
    raise ValueError('Missing Tokenizer Information')
if args.use_class and args.input_class == -1:
    raise ValueError('require reaction class!')

# Use the CPU when CUDA is unavailable or a negative device index is given.
if not torch.cuda.is_available() or args.device < 0:
    device = torch.device('cpu')
else:
    device = torch.device(f'cuda:{args.device}')

fix_seed(args.seed)

# NOTE(review): pickle.load can execute arbitrary code; only load tokenizer
# files that come from a trusted source.
with open(args.token_ckpt, 'rb') as Fin:
    tokenizer = pickle.load(Fin)

# Graph encoder; a class count is passed only when reaction classes are used.
GNN = GATBase(
    num_layers=args.n_layer, dropout=0.1, embedding_dim=args.dim,
    num_heads=args.heads, negative_slope=args.negative_slope,
    n_class=11 if args.use_class else None
)

# Sequence decoder built from standard torch transformer decoder layers.
decode_layer = TransformerDecoderLayer(
    d_model=args.dim, nhead=args.heads, batch_first=True,
    dim_feedforward=args.dim * 2, dropout=0.1
)
Decoder = TransformerDecoder(decode_layer, args.n_layer)
Pos_env = PositionalEncoding(args.dim, 0.1, maxlen=2000)

model = PretrainModel(
    token_size=tokenizer.get_token_size(), encoder=GNN,
    decoder=Decoder, d_model=args.dim, pos_enc=Pos_env
).to(device)

if args.checkpoint != '':
    print(f'[INFO] Loading model weight in {args.checkpoint}')
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    weight = torch.load(args.checkpoint, map_location=device)
    # strict=False tolerates key mismatches, so report them instead of
    # silently running with partially initialised weights.
    load_report = model.load_state_dict(weight, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(
            '[WARN] checkpoint/model key mismatch:',
            f'missing={list(load_report.missing_keys)}',
            f'unexpected={list(load_report.unexpected_keys)}'
        )

# Put dropout layers into inference mode explicitly rather than relying on
# the search routine to do so.
model.eval()

print('[INFO] padding index', tokenizer.token2idx['<PAD>'])
if args.use_class:
    start_token, rxn_class = f'<RXN>_{args.input_class}', args.input_class
else:
    start_token, rxn_class = '<CLS>', None

# Silence RDKit warnings below CRITICAL while SMILES are parsed.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# The input CSV has no header; the product SMILES are in the first column.
smiles_df = pd.read_csv(args.product_smiles, header=None)
product_smiles_list = smiles_df[0].tolist()

results = []
try:
    # Inference only: gradients are never needed.
    with torch.no_grad():
        for product_smiles in tqdm(product_smiles_list, desc="Processing SMILES"):
            prd = canonical_smiles(product_smiles)
            g_ip = make_graph_batch(prd, rxn_class)
            if g_ip is None:
                # Keep one output row per input row, with empty predictions.
                print(f"Invalid SMILES string: {product_smiles}")
                results.append([product_smiles, "", "", args.input_class])
                continue
            g_ip = g_ip.to(device)

            preds, probs = beam_search_one(
                model, tokenizer, g_ip, device, max_len=args.max_len,
                size=args.beams, begin_token=start_token, end_token='<END>',
                pen_para=0, validate=not args.org_output
            )

            results.append([product_smiles, preds, probs, args.input_class])
finally:
    # Always persist what has been computed so far, so that an error or an
    # interrupt late in a long run does not discard finished predictions.
    # Any exception still propagates after the file is written.
    results_df = pd.DataFrame(results)
    results_df.to_csv(args.output_file, header=False, index=False)

print('[RESULT]')
print(json.dumps(results, indent=4))


Namespace(dim=768, n_layer=8, heads=12, negative_slope=0.2, seed=2023, device=0, checkpoint='model.pth', token_ckpt='token.pkl', use_class=False, max_len=100, beams=1, product_smiles='product_smiles_test.csv', input_class=-1, org_output=False, output_file='results.csv')
[INFO] Loading model weight in model.pth
[INFO] padding index 265


Processing SMILES:  99%|████████████████████████████████████████████████████████▋| 7648/7690 [1:38:43<00:36,  1.16it/s]

# Command

In [None]:
#!python inference_one.py --dim 768 --n_layer 8 --heads 12 --device 0 --checkpoint model.pth --token_ckpt token.pkl --negative_slope 0.2 --max_len 100 --beams 1 --product_smiles product_smiles_test.csv

# Code for Processing Result CSV

In [None]:
import pandas as pd

# Load the raw inference output; the file was written without a header row.
df = pd.read_csv('results.csv', header=None)

# Column 1 holds the predictions as the text form of a Python list.
# Strip the surrounding square brackets and drop the quote characters so
# only the plain prediction text remains.
second_column = df[1].str.strip("[]").str.replace("'", "")

# Write the cleaned predictions alone, one per line, without header or index.
second_column.to_csv('processed_results.csv', header=False, index=False)
