In [11]:
# Atomic numbers 1..118 followed by a catch-all 'ukn' entry in the last slot.
valid_atomic_nums = [*range(1, 119), 'ukn']

# Bond-type names compared against str(bond.GetBondType()); 'ukn' is again the last slot.
valid_bond_types = [
    'SINGLE',
    'DOUBLE',
    'TRIPLE',
    'AROMATIC',
    'ukn',
]

# Feature name -> vocabulary list. The key spelling 'atomomic_nums' is what
# smiles2graph looks up, so it must stay exactly as written.
valid_features_dic = {
    'atomomic_nums': valid_atomic_nums,
    'bond_types': valid_bond_types,
}

def safe_index(l, e):
    """
    Return the index of element e in list l, or the last index if e is absent.

    The vocabularies passed in by this file keep an 'ukn' entry in the last
    position, so an unknown element is mapped onto that slot.

    Args:
        l: Sequence to search; must support ``index`` and ``len``.
        e: Element to look up.

    Returns:
        int: ``l.index(e)`` when e is present, otherwise ``len(l) - 1``
        (which is -1 for an empty sequence).
    """
    try:
        return l.index(e)
    except ValueError:
        # Only "element not found" triggers the fallback; any other error
        # (wrong argument type, a failing __eq__, KeyboardInterrupt, ...) is
        # a real problem and must not be silently turned into an index.
        return len(l) - 1

In [12]:
from rdkit import Chem
import numpy as np

def smiles2graph(smiles):
    """
    Convert a SMILES string into index-encoded graph arrays.

    Args:
        smiles: SMILES string describing a single molecule.

    Returns:
        None if RDKit cannot parse ``smiles``. Otherwise a tuple
        ``(atom_atomic_nums, edge_index, edge_attr)`` of int64 NumPy arrays:

        - atom_atomic_nums, shape [num_atoms]: for each atom, the index of its
          atomic number in ``valid_features_dic['atomomic_nums']``.
        - edge_index, shape [2, num_edges]: graph connectivity in COO format.
          Each bond contributes two directed edges (i -> j and j -> i), so
          num_edges is twice the number of bonds.
        - edge_attr, shape [num_edges]: for each directed edge, the index of
          its bond type in ``valid_features_dic['bond_types']``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit returns None instead of raising when the SMILES is invalid;
        # pass that on so callers can skip the entry.
        return None

    atomic_num_vocab = valid_features_dic['atomomic_nums']
    atom_atomic_nums = np.array(
        [safe_index(atomic_num_vocab, atom.GetAtomicNum()) for atom in mol.GetAtoms()],
        dtype=np.int64,
    )

    bond_type_vocab = valid_features_dic['bond_types']
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = safe_index(bond_type_vocab, str(bond.GetBondType()))

        # add edges in both directions, each carrying the same bond feature
        edges_list.extend([(i, j), (j, i)])
        edge_features_list.extend([edge_feature, edge_feature])

    if edges_list:  # mol has bonds
        # transpose the list of (src, dst) pairs into shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        # 1-D like the populated case, so results from different molecules
        # have the same rank and can be concatenated.
        edge_attr = np.empty((0,), dtype=np.int64)

    return atom_atomic_nums, edge_index, edge_attr

In [13]:
import pandas as pd 

# Read the 'smiles' column of the QM9 table and keep only the first ten entries.
smiles_list = pd.read_csv('data/QM9.csv')['smiles'].values[:10]

# Convert each SMILES in turn; `graph` is left holding the result for the last one.
for smiles in smiles_list:
    graph = smiles2graph(smiles)
    

In [None]:
import os
import lmdb
from tqdm import tqdm
from multiprocessing import Pool

def write_lmdb(smiles_list, outpath='lmdb_file', nthreads=8):
    """
    Convert SMILES strings to graphs in parallel and store them in an LMDB file.

    The database is written to ``<outpath>/.lmdb`` (any existing file at that
    path is replaced). Keys are the ASCII-encoded running index "0", "1", ...
    of the successfully converted molecules; values are the pickled tuples
    returned by ``smiles2graph``. Read them back with ``pickle.loads`` and only
    from files you trust, since unpickling can execute arbitrary code.

    Args:
        smiles_list: Iterable of SMILES strings.
        outpath: Directory that will contain the ``.lmdb`` file; created if missing.
        nthreads: Number of worker processes in the multiprocessing pool.
    """
    import pickle  # local import: only this writer needs serialisation

    os.makedirs(outpath, exist_ok=True)
    output_name = os.path.join(outpath, '.lmdb')
    try:
        os.remove(output_name)
    except FileNotFoundError:
        pass  # nothing to replace on the first run

    env_new = lmdb.open(
        output_name,
        subdir=False,
        readonly=False,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=1,
        map_size=int(100e9),
    )
    # Show a proper progress bar when the input has a known length.
    total = len(smiles_list) if hasattr(smiles_list, '__len__') else None
    try:
        # As a context manager the transaction commits on normal exit and
        # aborts if anything inside raises, so no half-written data is kept.
        with env_new.begin(write=True) as txn_write:
            with Pool(nthreads) as pool:
                i = 0
                for inner_output in tqdm(pool.imap(smiles2graph, smiles_list), total=total):
                    if inner_output is not None:
                        # LMDB stores raw bytes, so the tuple of arrays must
                        # be serialised before it can be written.
                        txn_write.put(
                            f'{i}'.encode("ascii"),
                            pickle.dumps(inner_output, protocol=pickle.HIGHEST_PROTOCOL),
                        )
                        i += 1
            print('{} lines'.format(i))
    finally:
        env_new.close()

# Example run: convert four SMILES strings and write them under 'lmdb_file'
# using a pool of 8 worker processes.
# NOTE(review): the pool workers must be able to import smiles2graph; with the
# 'spawn' start method, functions defined in a notebook cell may not be
# reachable from the workers — confirm on the target platform.
write_lmdb(['CCO', 'CCN', 'CCF', 'CCCl'], outpath='lmdb_file', nthreads=8)

0it [00:00, ?it/s]