In [None]:
# Change to the parent directory so that the relative paths used below
# ("tests/chembl_30_smiles.txt", "tests/ChEMBL_result.tsv") and the
# `tests.test_canonicalization` import resolve. Assumes the notebook is
# opened from a subdirectory of the repository root — TODO confirm.
%cd ..

In [None]:
import csv
import multiprocessing
import queue
from collections import namedtuple
from time import ctime

from networkx.algorithms.components import is_connected

from tests.test_canonicalization import test_invariance
from tucan.io import graph_from_smiles

In [None]:
# Number of worker processes. One core is left for the main process, which
# runs the producer, but at least one worker is always started: with zero
# consumers the producer would block forever once the bounded
# `molecule_queue` is full (e.g. on a single-core machine).
NUMBER_OF_PROCESSES = max(1, multiprocessing.cpu_count() - 1)
MAX_QUEUE_SIZE = 10

# A parsed molecule handed from the producer to a worker.
Molecule = namedtuple("Molecule", ["chembl_id", "graph"])
# One row of the result report: ID, status string (e.g. "skipped"/"failed"),
# and a human-readable detail message.
TestResult = namedtuple("TestResult", ["chembl_id", "status", "info"])

molecule_queue = multiprocessing.Queue(MAX_QUEUE_SIZE)    # cap queue size to limit memory consumption
result_queue = multiprocessing.Queue()    # unbounded; drained by the main process


def produce_molecules(molecules, results, path="tests/chembl_30_smiles.txt"):
    """Parse the SMILES dump and feed connected molecule graphs to the workers.

    The first line of ``path`` is a header and is skipped. Every other
    non-blank line must contain exactly two tab-separated fields: a ChEMBL ID
    and a SMILES string.

    Args:
        molecules: queue that receives one ``Molecule`` per testable entry,
            always followed by the ``"DONE"`` sentinel (even if reading fails).
        results: queue that receives a ``TestResult`` with status
            ``"skipped"`` for every entry that will not be tested.
        path: tab-separated input file.

    Raises:
        OSError: if ``path`` cannot be opened or read. The sentinel is still sent.
    """
    try:
        with open(path, "r") as chembl:
            chembl.readline()    # skip header
            for line in chembl:
                line = line.rstrip("\r\n")    # don't pass the newline into the SMILES parser
                if not line:    # tolerate blank lines, e.g. at the end of the file
                    continue
                fields = line.split("\t")
                if len(fields) != 2:
                    print(f"Cannot process malformed line: {line!r}")
                    results.put(TestResult(fields[0], "skipped", "malformed line"))
                    continue
                chembl_id, smiles = fields
                try:
                    graph = graph_from_smiles(smiles)
                    # is_connected raises for an empty graph, so keep it inside the try.
                    connected = is_connected(graph)
                except Exception as e:
                    print(f"Cannot process {chembl_id} due to unexpected exception: {e}")
                    # Store the message rather than the exception object: it is
                    # pickled through the queue, and not every exception unpickles.
                    results.put(TestResult(chembl_id, "skipped", str(e)))
                    continue
                if not connected:    # only test connected graphs for now
                    results.put(TestResult(chembl_id, "skipped", "disconnected graph"))
                    continue
                molecules.put(Molecule(chembl_id, graph))
    finally:
        # Always send the sentinel so the workers terminate instead of blocking forever.
        molecules.put("DONE")


def consume_molecules(molecules, results, pid):
    """Worker loop: run the invariance test on each molecule until ``"DONE"``.

    Args:
        molecules: queue of ``Molecule`` items terminated by the ``"DONE"``
            sentinel. The sentinel is put back on exit so sibling workers
            sharing the queue also stop.
        results: queue that receives a ``TestResult`` for every molecule whose
            test failed (``"failed"``) or raised an unexpected exception
            (``"error"``).
        pid: worker index, used only in progress messages.
    """
    n_tested = 0
    try:
        # Read from the queue passed in, not a module-level global, so the
        # worker only depends on its arguments.
        for m in iter(molecules.get, "DONE"):
            try:
                test_invariance(m.graph)
            except AssertionError as e:
                print(f"Invariance test failed with {m.chembl_id}: {e}")
                results.put(TestResult(m.chembl_id, "failed", str(e)))
            except Exception as e:
                # Keep the worker alive. If every worker died, the producer
                # would block forever on the bounded molecule queue.
                print(f"Invariance test errored with {m.chembl_id}: {e}")
                results.put(TestResult(m.chembl_id, "error", str(e)))
            n_tested += 1
            if n_tested % 1000 == 0:
                print(f"{ctime()}: process {pid} tested {n_tested} molecules.")
    finally:
        molecules.put("DONE")    # tell other processes we're done

# Fan the invariance tests out to worker processes. This (main) process
# parses the input and feeds the workers through `molecule_queue`.
print(f"{ctime()}: distributing tests of {NUMBER_OF_PROCESSES} processes.")
processes = [
    multiprocessing.Process(
        target=consume_molecules, args=(molecule_queue, result_queue, pid)
    )
    for pid in range(NUMBER_OF_PROCESSES)
]
for p in processes:
    p.start()

produce_molecules(molecule_queue, result_queue)

# Drain `result_queue` *while* the workers are still running. A process that
# has put items on a multiprocessing queue does not exit until they have been
# flushed into the underlying pipe. Joining before reading can therefore
# deadlock once the pipe is full, which is likely here because the producer
# also puts a "skipped" row for every rejected molecule. `Queue.empty()` is
# documented as unreliable, so poll with a timeout instead.
results = []
while any(p.is_alive() for p in processes):
    try:
        results.append(result_queue.get(timeout=0.1))
    except queue.Empty:
        pass

for p in processes:
    p.join()    # wait for process to finish
    p.close()

# All workers have exited and flushed their results. Collect whatever is still
# buffered, including the "skipped" rows put by this process.
while True:
    try:
        results.append(result_queue.get(timeout=1))
    except queue.Empty:
        break
molecule_queue.close()
result_queue.close()

# csv.writer quotes any field that contains a tab or newline (e.g. multi-line
# assertion messages in `info`), keeping the TSV parseable.
with open("tests/ChEMBL_result.tsv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="\t", lineterminator="\n")
    writer.writerow(["ChEMBL_ID", "status", "info"])
    writer.writerows((r.chembl_id, r.status, r.info) for r in results)
