In [1]:
import argparse
import math, random, sys
import os
import pickle
from functools import partial
from multiprocessing import Pool

import torch
import numpy
from sklearn.utils import shuffle
from tqdm import tqdm

from hgraph import MolGraph, common_atom_vocab, PairVocab
import rdkit

def to_numpy(tensors):
    """Convert the torch tensors in the middle element of a 3-tuple to numpy.

    Args:
        tensors: A 3-item sequence ``(a, b, c)``. ``b[0]`` and ``b[1]`` must
            each be iterable; only their items are inspected.

    Returns:
        A tuple ``(a, (list0, list1), c)`` where ``list0`` / ``list1`` hold the
        items of ``b[0]`` / ``b[1]`` with every ``torch.Tensor`` replaced by
        the result of its ``.numpy()`` call. ``a`` and ``c`` are returned
        untouched.
    """
    def _as_array(item):
        # Exact type match on purpose: anything that is not a plain
        # torch.Tensor (including Tensor subclasses) passes through unchanged.
        if type(item) is torch.Tensor:
            return item.numpy()
        return item

    head, middle, tail = tensors
    first_group = [_as_array(item) for item in middle[0]]
    second_group = [_as_array(item) for item in middle[1]]
    return head, (first_group, second_group), tail

def tensorize(mol_batch, vocab):
    """Tensorize one batch of molecules and convert the result to numpy.

    Args:
        mol_batch: Batch passed straight to ``MolGraph.tensorize``.
        vocab: Vocabulary object passed to ``MolGraph.tensorize`` together
            with ``common_atom_vocab``.

    Returns:
        The output of ``to_numpy`` applied to the tensorized batch, or
        ``None`` if a ``KeyError`` is raised while processing the batch.
    """
    try:
        return to_numpy(MolGraph.tensorize(mol_batch, vocab, common_atom_vocab))
    except KeyError:
        # Best effort: the whole batch is skipped.
        # NOTE(review): presumably raised for entries missing from the
        # vocabulary -- confirm against hgraph.
        return None

def tensorize_pair(mol_batch, vocab):
    """Tensorize a batch of (source, target) molecule pairs.

    Args:
        mol_batch: Iterable of 2-item pairs; the first items form the source
            batch and the second items form the target batch.
        vocab: Vocabulary object passed to ``MolGraph.tensorize``.

    Returns:
        A flat tuple: the numpy-converted source result without its last
        element, followed by the full numpy-converted target result.
    """
    sources, targets = zip(*mol_batch)
    source_np = to_numpy(MolGraph.tensorize(sources, vocab, common_atom_vocab))
    target_np = to_numpy(MolGraph.tensorize(targets, vocab, common_atom_vocab))
    # The last element (the "order" entry) is only kept for the target side.
    source_without_order = source_np[:-1]
    return source_without_order + target_np

def tensorize_cond(mol_batch, vocab):
    """Tensorize a batch of (source, target, condition) triples.

    Args:
        mol_batch: Iterable of 3-item triples ``(x, y, cond)`` where ``cond``
            is a comma-separated string of integers, e.g. ``"1,0,1"``.
        vocab: Vocabulary object passed to ``MolGraph.tensorize``.

    Returns:
        A flat tuple: the numpy-converted source result without its last
        element, the full numpy-converted target result, and finally the
        condition values as a numpy integer array with one row per triple.

    Raises:
        ValueError: If a condition field cannot be parsed by ``int``.
    """
    x, y, cond = zip(*mol_batch)
    # Materialise every row as a list: in Python 3 ``map`` returns a lazy
    # iterator, and numpy.array() over a list of iterators yields a 1-D
    # object array of map objects instead of a numeric matrix.
    cond = numpy.array([[int(v) for v in c.split(',')] for c in cond])
    x = MolGraph.tensorize(x, vocab, common_atom_vocab)
    y = MolGraph.tensorize(y, vocab, common_atom_vocab)
    return to_numpy(x)[:-1] + to_numpy(y) + (cond,)  # no need of order for x

In [2]:
class ARGS:
    """Notebook stand-in for the command-line arguments of the preprocessing script.

    Every setting is a keyword parameter whose default is the value this
    notebook was originally run with, so ``ARGS()`` is unchanged while
    individual settings can be overridden, e.g. ``ARGS(batch_size=16)``.

    Args:
        train: Training-file setting. ``'abc'`` is a placeholder; no visible
            cell reads it.
        vocab: Path to the vocabulary text file.
        mode: Preprocessing mode name.
        ncpu: Number of worker processes handed to ``multiprocessing.Pool``.
        batch_size: Number of molecules per tensorized batch.
    """

    def __init__(self, train='abc', vocab='data/chembl/vocab.txt',
                 mode='single', ncpu=8, batch_size=32):
        self.train = train
        self.vocab = vocab
        self.mode = mode
        self.ncpu = ncpu
        self.batch_size = batch_size

args = ARGS()

In [3]:
def _first_field(line):
    """Return the first whitespace-separated field of a text line."""
    return line.strip("\r\n ").split()[0]

# Vocabulary: every line is split into its whitespace-separated fields.
with open(args.vocab) as f:
    vocab = [row.strip("\r\n ").split() for row in f]
args_vocab = PairVocab(vocab, cuda=False)

# NOTE(review): no visible cell uses this pool (the pool.map call further
# down is commented out), yet it starts args.ncpu worker processes.
pool = Pool(args.ncpu)
random.seed(1)

# Dataset of single molecules: keep only the first field of each line.
# NOTE(review): absolute, machine-specific paths -- adjust before re-running
# (an earlier variant read './data/chembl/all.txt' instead).
with open('/home/quang/working/Theory_of_ML/train_JTNN_full.txt') as f:
    data = list(map(_first_field, f))

# Identifiers, read the same way from a second file.
with open('/home/quang/working/Theory_of_ML/ids_train_JTNN_full.txt') as f:
    ids_data = list(map(_first_field, f))

In [4]:
# Sanity check: show how many molecules and how many ids were read. Both lists
# are sliced in lockstep afterwards, so the two numbers should be equal.
len(data), len(ids_data)

(664079, 664079)

In [5]:
# Size of the random subset processed in this run.
NUM_SAMPLES = 10_000
# Seed for the shuffle below. The earlier random.seed(1) only seeds Python's
# `random` module; sklearn.utils.shuffle draws from NumPy's RNG unless
# random_state is given, so without it the subset changes on every run.
SHUFFLE_SEED = 1

# Shuffle molecules and ids with the same permutation so they stay aligned,
# then keep the first NUM_SAMPLES of each.
data, ids_data = shuffle(data, ids_data, random_state=SHUFFLE_SEED)
data = data[:NUM_SAMPLES]
ids_data = ids_data[:NUM_SAMPLES]

# Cut both lists into consecutive chunks of args.batch_size; batches[i] and
# batches_ids[i] cover the same positions.
batches = [data[i : i + args.batch_size] for i in range(0, len(data), args.batch_size)]
batches_ids = [ids_data[i : i + args.batch_size] for i in range(0, len(ids_data), args.batch_size)]

# Parallel alternative (drops the id bookkeeping):
# func = partial(tensorize, vocab = args_vocab)
# all_data = pool.map(func, batches)

In [6]:
# Tensorize every batch sequentially, keeping the matching id batch alongside.
# Batches for which tensorize() returns None are dropped together with their
# ids, so all_data[i] and all_data_ids[i] always describe the same molecules.
all_data, all_data_ids = [], []

paired_batches = zip(batches, batches_ids)
for mol_batch, id_batch in tqdm(paired_batches, total=len(batches)):
    tensors = tensorize(mol_batch, args_vocab)
    if tensors is None:
        continue
    all_data.append(tensors)
    all_data_ids.append(id_batch)

100%|██████████| 313/313 [02:15<00:00,  2.31it/s]


In [8]:
# Write the tensorized batches and their ids to disk as numbered pickle shards.
if not all_data:
    raise ValueError("no batch was tensorized successfully; nothing to write")

# Aim for roughly 50 batches per shard (a larger run used 1000). Clamp to at
# least one shard: with fewer than 50 batches the integer division gives 0
# and the chunk-size computation below would divide by zero.
num_splits = max(1, len(all_data) // 50)

# Batches per shard, rounded up so that every batch lands in some shard.
le = (len(all_data) + num_splits - 1) // num_splits

# Create the output folders up front so the writes cannot fail on a fresh
# checkout after the expensive tensorize step.
os.makedirs('./train_processed_bms/mol', exist_ok=True)
os.makedirs('./train_processed_bms/ids', exist_ok=True)

# Step through the data by chunk size instead of counting to num_splits: with
# the rounded-up chunk size the final shards could otherwise be empty.
for split_id, st in enumerate(range(0, len(all_data), le)):
    sub_data = all_data[st : st + le]

    with open('./train_processed_bms/mol/tensors-%d.pkl' % split_id, 'wb') as f:
        pickle.dump(sub_data, f, pickle.HIGHEST_PROTOCOL)

    # Same slice of the id list, so shard N of ids matches shard N of tensors.
    sub_data_ids = all_data_ids[st : st + le]

    with open('./train_processed_bms/ids/tensors-%d.pkl' % split_id, 'wb') as f:
        pickle.dump(sub_data_ids, f, pickle.HIGHEST_PROTOCOL)
