In [9]:
from src.huggingmolecules import GroverConfig, GroverFeaturizer, GroverModel
import torch

# Build the model and load the pre-trained weights in one line:
model = GroverModel.from_pretrained('grover_base')
model.eval()

# Encode (featurize) the batch of two SMILES strings: 
featurizer = GroverFeaturizer.from_pretrained('grover_base')

def get_fingerprint(smiles):
    batch = featurizer([smiles])
    
    emb = model(batch)

    emb = torch.cat([emb[0], emb[1]], dim=-1).squeeze(0)
    
    return emb.detach()

def get_fingerprint_batch(batch):
    batch = featurizer(batch)
    
    emb = model(batch)

    emb = torch.cat([emb[0], emb[1]], dim=-1).squeeze(0)
    
    return np.array(emb.detach())

get_fingerprint_batch(["C[S+]([O-])c1ccc(-c2nc(-c3ccc(F)cc3)c(-c3ccncc3)[nH]2)cc1", "CCCCCC=CCC=CC=CC=CC(SCC(NC(=O)CCC(N)C(=O)O)C(=O)NCC(=O)O)C(O)CCCC(=O)O"])

array([[-0.6496998 , -0.57952565, -0.5719348 , ..., -1.6774462 ,
        -0.27131468,  0.8308551 ],
       [-0.37015235, -0.27983865, -0.49142447, ..., -1.3654205 ,
        -0.02482126,  1.2173467 ]], dtype=float32)

In [None]:
import csv, tqdm
import numpy as np

chem_dict = {}

with open("stitch_molecules.csv", "r") as f:
    reader = csv.reader(f, delimiter=";")
    next(reader)
    lines = list(reader)
    for line in tqdm.tqdm(lines):
        #print('line[{}] = {}'.format(i, line))
        chem_dict[line[0]] = np.array(get_fingerprint(line[0]))
        


In [None]:
import csv, tqdm
import numpy as np

chem_dict = {}

seqs = []
with open("stitch_molecules.csv", "r") as f:
    reader = csv.reader(f, delimiter=";")
    next(reader)
    lines = list(reader)
    for line in lines:
        smiles = line[0]
        seqs.append(smiles)
        
def chunker(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))

for batch in tqdm.tqdm(chunker(seqs, 4), total=len(seqs)/5):
    fingerprints = get_fingerprint_batch(batch)
    
    for i in range(len(batch)):
        chem_dict[batch[i]] = fingerprints[i]

  0%|▏                                                                                     | 121/76106.2 [00:47<7:42:41,  2.74it/s]

In [30]:
np.savez_compressed("graph_dict_stitch.npz", gdict=chem_dict)