In [1]:
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tqdm
import ray
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from scipy import sparse
import logging
from functools import lru_cache
from mol2vec.features import mol2alt_sentence, sentences2vec
from gensim.models.word2vec import Word2Vec

In [20]:
# CPU budget: passed to ray.init(num_cpus=...) and used by get_fingerprints as
# the number of batches each DataFrame chunk is split into.
# NOTE(review): this cell's execution count (20) is later than ray.init's (3),
# so it was re-run out of order; on a fresh kernel it must run before ray.init.
NUM_CPUS = 16

In [3]:
# Initialise Ray, capping it at NUM_CPUS CPUs.
ray.init(num_cpus=NUM_CPUS)

2021-05-27 14:56:35,118	INFO services.py:1267 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '172.31.43.236',
 'raylet_ip_address': '172.31.43.236',
 'redis_address': '172.31.43.236:6379',
 'object_store_address': '/tmp/ray/session_2021-05-27_14-56-34_232188_4550/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-05-27_14-56-34_232188_4550/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-05-27_14-56-34_232188_4550',
 'metrics_export_port': 44101,
 'node_id': 'eff34b67bf2822056e7733e3df51736e25b64ecd08b5fa27ff2d6f73'}

In [4]:
# Number of row-wise chunks the ligand table is split into by the driver loop;
# one set of output files (suffix _0 .. _{NUM_CHUNKS-1}) is written per chunk.
NUM_CHUNKS = 10

In [5]:
# True:  the driver loop only generates mol2vec embeddings (via get_embeddings).
# False: the driver loop generates Morgan fingerprints and, from the same pass,
#        the mol2vec embeddings as well (via get_fingerprints).
USE_EMBEDDINGS = False

In [6]:
!ls /mnt/efs/AmpC_data/

AmpC_embeddings_0.npy	 AmpC_fingerprints_1.npz  AmpC_scores_2.npy
AmpC_embeddings_1.npy	 AmpC_fingerprints_2.npz  AmpC_scores_3.npy
AmpC_embeddings_2.npy	 AmpC_fingerprints_3.npz  AmpC_scores_4.npy
AmpC_embeddings_3.npy	 AmpC_fingerprints_4.npz  AmpC_scores_5.npy
AmpC_embeddings_4.npy	 AmpC_fingerprints_5.npz  AmpC_scores_6.npy
AmpC_embeddings_5.npy	 AmpC_fingerprints_6.npz  AmpC_scores_7.npy
AmpC_embeddings_6.npy	 AmpC_fingerprints_7.npz  AmpC_scores_8.npy
AmpC_embeddings_7.npy	 AmpC_fingerprints_8.npz  AmpC_scores_9.npy
AmpC_embeddings_8.npy	 AmpC_fingerprints_9.npz  AmpC_screen_table.csv
AmpC_embeddings_9.npy	 AmpC_scores_0.npy	  AmpC_screen_table.csv.zip
AmpC_fingerprints_0.npz  AmpC_scores_1.npy	  AmpC_screen_table_test.csv


In [7]:
# Receptor name; used as the prefix of the input table and of every output file.
RECEPTOR = "AmpC"
# Directory that holds the screen table and receives the generated arrays.
DATA_DIR = "/mnt/efs/AmpC_data"
# CSV read by get_data().
INPUT_DATA = f"{DATA_DIR}/{RECEPTOR}_screen_table.csv"

# Word2Vec model file loaded by get_w2v_model().
MODEL_PATH = "/mnt/efs/mol2vec/examples/models/model_300dim.pkl"
# Value passed to sentences2vec as its `unseen` argument.
UNCOMMON = "UNK"

In [8]:
@lru_cache(maxsize=2)
def get_data():
    """Read the screen table at INPUT_DATA and keep only scored ligands.

    Rows whose "dockscore" value equals the literal string "no_score" are
    removed. The result is memoised by lru_cache, so repeated calls return
    the same DataFrame object without re-reading the CSV.

    Returns:
        pandas.DataFrame: the filtered table (original index preserved).
    """
    raw_table = pd.read_csv(INPUT_DATA)
    has_score = raw_table["dockscore"] != "no_score"
    return raw_table[has_score]

In [9]:
@lru_cache(maxsize=2)
def get_w2v_model():
    """Load the Word2Vec model stored at MODEL_PATH and return it.

    init_sims() is called on the model's word vectors before the model is
    returned. The result is memoised by lru_cache, so the file is loaded
    at most once per kernel session.

    Returns:
        The loaded Word2Vec model.
    """
    model = Word2Vec.load(MODEL_PATH)
    word_vectors = model.wv
    word_vectors.init_sims()
    return model

In [10]:
def create_fingerprint(smiles, score, i, radius=2, n_bits=8192):
    """Compute the Morgan-fingerprint on-bits and mol2vec sentence of one molecule.

    Args:
        smiles: SMILES string of the molecule.
        score: Score attached to the molecule; anything `float()` accepts.
        i: Position of the molecule in its batch; only used for progress
            logging (every 10000th index is logged) and in error messages.
        radius: Morgan radius, used for both the fingerprint and the sentence.
        n_bits: Length of the folded bit vector.

    Returns:
        tuple: (list of on-bit indices, sentence from mol2alt_sentence, float score).

    Raises:
        ValueError: if RDKit cannot parse `smiles`, or `score` cannot be
            converted to float.
    """
    if i % 10000 == 0:
        # Progress heartbeat. basicConfig is applied here (rather than once at
        # import time) because this function also runs inside Ray workers.
        logging.basicConfig(level=logging.INFO)
        logging.info(i)

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # message that identifies the offending input instead of letting the
        # fingerprint call reject the None argument.
        raise ValueError(f"Could not parse SMILES at batch index {i}: {smiles!r}")

    fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(
        mol,
        radius=radius,
        nBits=n_bits,
        invariants=[],
        fromAtoms=[],
        useChirality=False,
        useBondTypes=True,
        useFeatures=True,
    )
    onbits = list(fp.GetOnBits())

    alt_sentence = mol2alt_sentence(mol, radius=radius)

    return onbits, alt_sentence, float(score)

In [11]:
@ray.remote
def create_fingerprint_batched(batches, radius=2, n_bits=8192):
    """Ray task: run create_fingerprint over a batch of (smiles, score) pairs.

    Args:
        batches: Iterable of (smiles, score) pairs.
        radius: Morgan radius forwarded to create_fingerprint.
        n_bits: Fingerprint length forwarded to create_fingerprint.

    Returns:
        tuple: three parallel lists (on-bit lists, sentences, float scores),
        in the same order as `batches`.
    """
    bits_list = []
    sentence_list = []
    score_list = []
    # `i` is the position inside this batch (not a global row number); it only
    # drives create_fingerprint's progress logging.
    for i, (smiles, score) in enumerate(batches):
        # Forward radius/n_bits: previously they were accepted but ignored, so
        # the defaults of create_fingerprint were always used.
        onbits, alt_sentence, score_value = create_fingerprint(
            smiles, score, i, radius=radius, n_bits=n_bits
        )

        bits_list.append(onbits)
        sentence_list.append(alt_sentence)
        score_list.append(score_value)

    return bits_list, sentence_list, score_list

In [12]:
@ray.remote
def create_mol_sentence(smiles, score, r, i):
    """Ray task: build the mol2vec sentence for a single molecule.

    Args:
        smiles: SMILES string of the molecule.
        score: Score attached to the molecule; anything `float()` accepts.
        r: Radius passed to mol2alt_sentence.
        i: Index of the molecule; only used for progress logging (every
            10000th index is logged).

    Returns:
        (sentence, float score), or None when RDKit cannot parse `smiles`
        (callers are expected to filter out the None entries).
    """
    if i % 10000 == 0:
        logging.basicConfig(level=logging.INFO)
        logging.info(i)

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable SMILES: signal "skip this molecule" to the caller.
        return None

    alt_sentence = mol2alt_sentence(mol, radius=r)

    # Convert to float so the embedding path yields the same score type as
    # create_fingerprint does for the fingerprint path.
    return alt_sentence, float(score)

In [13]:
def flatten(lst):
    """Concatenate the items of every batch in `lst` into one flat list.

    Only one level of nesting is removed; batch order and item order are kept.
    """
    merged = []
    for batch in lst:
        merged.extend(batch)
    return merged

In [14]:
def get_fingerprints(ligands_df, fp_size=8192, smiles_col="smiles", score_col="score"):
    """Compute Morgan fingerprints and mol2vec sentences for every row, in parallel.

    The frame is split into NUM_CPUS batches, each handled by one
    create_fingerprint_batched Ray task; the per-batch results are then
    concatenated in order.

    Args:
        ligands_df: DataFrame holding at least `smiles_col` and `score_col`.
        fp_size: Fingerprint length; also the column count of the matrix.
        smiles_col: Name of the SMILES column.
        score_col: Name of the score column.

    Returns:
        tuple: (list of sentences, scipy.sparse CSR bool matrix with one row
        per molecule and `fp_size` columns, list of float scores). Row k of
        the matrix, sentence k and score k all refer to the same molecule.
    """
    future_values = []
    for df_chunk in np.array_split(ligands_df, NUM_CPUS):
        # Materialise the pairs so a plain list (not a lazy zip object) is
        # serialised and shipped to the worker.
        pairs = list(zip(df_chunk[smiles_col], df_chunk[score_col]))
        # Keep the worker's bit-vector length in sync with the matrix width.
        future_values.append(create_fingerprint_batched.remote(pairs, n_bits=fp_size))

    values = ray.get(future_values)

    all_bits, alt_sentences, scores = zip(*values)

    all_bits = flatten(all_bits)
    alt_sentences = flatten(alt_sentences)
    scores = flatten(scores)

    row_idx = []
    col_idx = []
    for row, bits in enumerate(all_bits):
        # every on-bit of this molecule lives in the same row ...
        row_idx.extend([row] * len(bits))
        # ... at the column given by the bit index
        col_idx.extend(bits)

    # One row per molecule. Deriving the row count from max(row_idx) would drop
    # trailing molecules that have no on-bits (misaligning rows and scores) and
    # would raise on empty input.
    fingerprint_matrix = sparse.coo_matrix(
        (
            np.ones(len(row_idx), dtype=bool),
            (np.asarray(row_idx, dtype=np.int64), np.asarray(col_idx, dtype=np.int64)),
        ),
        shape=(len(all_bits), fp_size),
    )

    # CSR supports efficient row slicing, which COO does not.
    fingerprint_matrix = sparse.csr_matrix(fingerprint_matrix)

    return alt_sentences, fingerprint_matrix, scores

In [15]:
def get_embeddings(ligands_df, model, radius=1, smiles_col="smiles", score_col="score"):
    """Build mol2vec sentences for every parsable molecule, one Ray task per row.

    Args:
        ligands_df: DataFrame holding at least `smiles_col` and `score_col`.
        model: Unused here; kept so existing callers keep working (the
            sentence-to-vector step is done by the caller).
        radius: Radius passed to create_mol_sentence.
        smiles_col: Name of the SMILES column.
        score_col: Name of the score column.

    Returns:
        tuple: (tuple of sentences, tuple of scores), aligned with each other.
        Molecules for which create_mol_sentence returned None are skipped;
        two empty tuples are returned when nothing is left.
    """
    future_values = [
        create_mol_sentence.remote(smiles=smiles, score=score, r=radius, i=i)
        for i, (smiles, score) in enumerate(zip(ligands_df[smiles_col], ligands_df[score_col]))
    ]

    # create_mol_sentence yields None for SMILES it cannot parse.
    values = [v for v in ray.get(future_values) if v is not None]
    if not values:
        # zip(*[]) cannot be unpacked into two names; return an empty result.
        return (), ()

    mol_sentences, scores = zip(*values)

    return mol_sentences, scores

In [16]:
# Load the screen table (rows with dockscore == "no_score" already removed).
ligands_df = get_data()

  if (await self.run_code(code, result,  async_=asy)):


In [17]:
# Load the Word2Vec model from MODEL_PATH; used below by sentences2vec.
word2vec_model = get_w2v_model()

  word2vec_model.wv.init_sims()


In [18]:
# NUM_CHUNKS = 1

In [None]:
# Driver: process the ligand table in NUM_CHUNKS pieces and write one set of
# .npy/.npz files per piece into DATA_DIR.
start = time.time()
for i, df_chunk in enumerate(np.array_split(ligands_df, NUM_CHUNKS)):
    # Per-chunk timer, so the "Chunk i took" line reports this chunk only
    # rather than the time elapsed since the whole run started.
    chunk_start = time.time()
    if USE_EMBEDDINGS:
        print("Generating mol2vec embeddings...")
        # NOTE(review): get_embeddings reads the "score" column, whereas the
        # fingerprint branch below reads "dockscore" -- confirm the table
        # really has a "score" column before enabling this branch.
        # `embeddings` holds the mol2vec sentences; `vectors` the actual embeddings.
        embeddings, scores = get_embeddings(ligands_df=df_chunk, model=word2vec_model, radius=1)
        vectors = sentences2vec(sentences=embeddings, model=word2vec_model, unseen=UNCOMMON)

        np.save(f"{DATA_DIR}/{RECEPTOR}_embeddings_{i}.npy", vectors)
        np.save(f"{DATA_DIR}/{RECEPTOR}_embedding_scores_{i}.npy", np.array(scores))
    else:
        print("Generating Morgan Fingerprints...")
        embeddings, fingerprint_matrix, scores = get_fingerprints(ligands_df=df_chunk, score_col="dockscore")

        print("Saving fingerprint matrix...")
        sparse.save_npz(f"{DATA_DIR}/{RECEPTOR}_fingerprints_{i}.npz", fingerprint_matrix)
        np.save(f"{DATA_DIR}/{RECEPTOR}_scores_{i}.npy", np.array(scores))

        print("Saving embeddings...")
        # The sentences come for free from the fingerprint pass, so the
        # embeddings are written here too.
        vectors = sentences2vec(sentences=embeddings, model=word2vec_model, unseen=UNCOMMON)
        np.save(f"{DATA_DIR}/{RECEPTOR}_embeddings_{i}.npy", vectors)

    print(f"Chunk {i} took: {(time.time() - chunk_start)/60} mins")

print(f"Dataset took: {(time.time() - start)/60} mins")

Generating Morgan Fingerprints...


[2m[36m(pid=4776)[0m INFO:root:0
[2m[36m(pid=4811)[0m INFO:root:0
[2m[36m(pid=4835)[0m INFO:root:0
[2m[36m(pid=4777)[0m INFO:root:0
[2m[36m(pid=4778)[0m INFO:root:0
[2m[36m(pid=4772)[0m INFO:root:0
[2m[36m(pid=4774)[0m INFO:root:0
[2m[36m(pid=4775)[0m INFO:root:0
[2m[36m(pid=4812)[0m INFO:root:0
[2m[36m(pid=4810)[0m INFO:root:0
[2m[36m(pid=4806)[0m INFO:root:0
[2m[36m(pid=4790)[0m INFO:root:0
[2m[36m(pid=4829)[0m INFO:root:0
[2m[36m(pid=4828)[0m INFO:root:0
[2m[36m(pid=4781)[0m INFO:root:0
[2m[36m(pid=4787)[0m INFO:root:0
[2m[36m(pid=4776)[0m INFO:root:10000
[2m[36m(pid=4811)[0m INFO:root:10000
[2m[36m(pid=4835)[0m INFO:root:10000
[2m[36m(pid=4777)[0m INFO:root:10000
[2m[36m(pid=4778)[0m INFO:root:10000
[2m[36m(pid=4772)[0m INFO:root:10000
[2m[36m(pid=4775)[0m INFO:root:10000
[2m[36m(pid=4774)[0m INFO:root:10000
[2m[36m(pid=4812)[0m INFO:root:10000
[2m[36m(pid=4810)[0m INFO:root:10000
[2m[36m(pid=4806)[0m 

[2m[36m(pid=4781)[0m INFO:root:120000
[2m[36m(pid=4787)[0m INFO:root:120000
[2m[36m(pid=4775)[0m INFO:root:130000
[2m[36m(pid=4776)[0m INFO:root:130000
[2m[36m(pid=4778)[0m INFO:root:130000
[2m[36m(pid=4811)[0m INFO:root:130000
[2m[36m(pid=4772)[0m INFO:root:130000
[2m[36m(pid=4829)[0m INFO:root:130000
[2m[36m(pid=4777)[0m INFO:root:130000
[2m[36m(pid=4810)[0m INFO:root:130000
[2m[36m(pid=4790)[0m INFO:root:130000
[2m[36m(pid=4774)[0m INFO:root:130000
[2m[36m(pid=4835)[0m INFO:root:130000
[2m[36m(pid=4828)[0m INFO:root:130000
[2m[36m(pid=4812)[0m INFO:root:130000
[2m[36m(pid=4806)[0m INFO:root:130000
[2m[36m(pid=4781)[0m INFO:root:130000
[2m[36m(pid=4787)[0m INFO:root:130000
[2m[36m(pid=4776)[0m INFO:root:140000
[2m[36m(pid=4775)[0m INFO:root:140000
[2m[36m(pid=4778)[0m INFO:root:140000
[2m[36m(pid=4811)[0m INFO:root:140000
[2m[36m(pid=4772)[0m INFO:root:140000
[2m[36m(pid=4829)[0m INFO:root:140000
[2m[36m(pid=47

[2m[36m(pid=4774)[0m INFO:root:250000
[2m[36m(pid=4777)[0m INFO:root:250000
[2m[36m(pid=4810)[0m INFO:root:250000
[2m[36m(pid=4811)[0m INFO:root:250000
[2m[36m(pid=4828)[0m INFO:root:250000
[2m[36m(pid=4812)[0m INFO:root:250000
[2m[36m(pid=4787)[0m INFO:root:250000
[2m[36m(pid=4781)[0m INFO:root:250000
[2m[36m(pid=4835)[0m INFO:root:250000
[2m[36m(pid=4806)[0m INFO:root:250000
[2m[36m(pid=4778)[0m INFO:root:260000
[2m[36m(pid=4775)[0m INFO:root:260000
[2m[36m(pid=4776)[0m INFO:root:260000
[2m[36m(pid=4790)[0m INFO:root:260000
[2m[36m(pid=4772)[0m INFO:root:260000
[2m[36m(pid=4774)[0m INFO:root:260000
[2m[36m(pid=4829)[0m INFO:root:260000
[2m[36m(pid=4777)[0m INFO:root:260000
[2m[36m(pid=4810)[0m INFO:root:260000
[2m[36m(pid=4811)[0m INFO:root:260000
[2m[36m(pid=4828)[0m INFO:root:260000
[2m[36m(pid=4812)[0m INFO:root:260000
[2m[36m(pid=4787)[0m INFO:root:260000
[2m[36m(pid=4781)[0m INFO:root:260000
[2m[36m(pid=48

[2m[36m(pid=4835)[0m INFO:root:370000
[2m[36m(pid=4806)[0m INFO:root:370000
[2m[36m(pid=4775)[0m INFO:root:380000
[2m[36m(pid=4790)[0m INFO:root:380000
[2m[36m(pid=4772)[0m INFO:root:380000
[2m[36m(pid=4776)[0m INFO:root:380000
[2m[36m(pid=4774)[0m INFO:root:380000
[2m[36m(pid=4829)[0m INFO:root:380000
[2m[36m(pid=4777)[0m INFO:root:380000
[2m[36m(pid=4811)[0m INFO:root:380000
[2m[36m(pid=4778)[0m INFO:root:380000
[2m[36m(pid=4828)[0m INFO:root:380000
[2m[36m(pid=4810)[0m INFO:root:380000
[2m[36m(pid=4787)[0m INFO:root:380000
[2m[36m(pid=4812)[0m INFO:root:380000
[2m[36m(pid=4781)[0m INFO:root:380000
[2m[36m(pid=4835)[0m INFO:root:380000
[2m[36m(pid=4806)[0m INFO:root:380000
[2m[36m(pid=4775)[0m INFO:root:390000
[2m[36m(pid=4790)[0m INFO:root:390000
[2m[36m(pid=4772)[0m INFO:root:390000
[2m[36m(pid=4776)[0m INFO:root:390000
[2m[36m(pid=4829)[0m INFO:root:390000
[2m[36m(pid=4774)[0m INFO:root:390000
[2m[36m(pid=48

[2m[36m(pid=4778)[0m INFO:root:500000
[2m[36m(pid=4777)[0m INFO:root:500000
[2m[36m(pid=4812)[0m INFO:root:500000
[2m[36m(pid=4787)[0m INFO:root:500000
[2m[36m(pid=4828)[0m INFO:root:500000
[2m[36m(pid=4810)[0m INFO:root:500000
[2m[36m(pid=4781)[0m INFO:root:500000
[2m[36m(pid=4835)[0m INFO:root:500000
[2m[36m(pid=4806)[0m INFO:root:500000
[2m[36m(pid=4790)[0m INFO:root:510000
[2m[36m(pid=4775)[0m INFO:root:510000
[2m[36m(pid=4774)[0m INFO:root:510000
[2m[36m(pid=4811)[0m INFO:root:510000
[2m[36m(pid=4772)[0m INFO:root:510000
[2m[36m(pid=4829)[0m INFO:root:510000
[2m[36m(pid=4776)[0m INFO:root:510000
[2m[36m(pid=4777)[0m INFO:root:510000
[2m[36m(pid=4778)[0m INFO:root:510000
[2m[36m(pid=4787)[0m INFO:root:510000
[2m[36m(pid=4812)[0m INFO:root:510000
[2m[36m(pid=4828)[0m INFO:root:510000
[2m[36m(pid=4810)[0m INFO:root:510000
[2m[36m(pid=4781)[0m INFO:root:510000
[2m[36m(pid=4835)[0m INFO:root:510000
[2m[36m(pid=48