In [None]:
# default_exp spe2vec

# SPE2Vec

> Learn word2vec embeddings of SPE tokens and use them to embed SMILES.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import os
import gensim


class Corpus(object):
    '''
    Re-iterable stream of tokenized SMILES that is read lazily from disk.
    Every iteration re-opens the input, so the corpus can be consumed several times.
    
    *infile*: A file that stores SMILES line-by-line, or (with `isdir=True`) a directory
    whose files each store SMILES line-by-line.
    
    *tokenizer*: SPE tokenizer; must provide `tokenize(smi, dropout=...)` returning a
    space-separated string of tokens.
    
    *isdir*: set to True when *infile* is a directory, default = False
    
    *dropout*: SPE dropout, default = 0
    '''
    def __init__(self, infile, tokenizer, isdir=False, dropout=0):
        self.infile = infile
        self.tokenizer = tokenizer
        self.dropout = dropout
        self.isdir = isdir

    def _paths(self):
        '''List the file(s) to read, in a reproducible order.'''
        if self.isdir:
            # os.listdir returns entries in arbitrary order; sort so that repeated
            # passes (and repeated runs) see the SMILES in the same sequence.
            return [os.path.join(self.infile, fname)
                    for fname in sorted(os.listdir(self.infile))]
        return [self.infile]

    def __iter__(self):
        '''Yield one list of tokens per line of the input file(s).'''
        for path in self._paths():
            # The context manager closes each file as soon as it is exhausted
            # (or when the generator is closed early) instead of leaking the handle.
            with open(path) as handle:
                for smi in handle:
                    # The line is handed to the tokenizer exactly as read.
                    yield self.tokenizer.tokenize(smi, dropout=self.dropout).split(' ')
    
def learn_spe2vec(corpus, outfile=None, 
                  vector_size=100, window=10, min_count=10, n_jobs = 1, method = 'skip-gram', 
                  **kwargs):
    '''
    Train a spe2vec model.
    
    *corpus*: an instance of `Class Corpus()`
    
    *outfile*: str, name of the spe2vec model file. When omitted the model is not saved.
    
    *vector_size*: dimensions of embedding.
    
    *window*: number of tokens considered as context
    
    *min_count*: number of occurrences a token should have to be considered in training
    
    *n_jobs*: number of cpu cores used for training.
    
    *method*: modeling method, choose from ['cbow', 'skip-gram'] (case-insensitive)
    
    More training parameter can be found https://radimrehurek.com/gensim/models/word2vec.html#gensim.models.word2vec.Word2Vec
    
    Returns the trained `gensim.models.Word2Vec` model.
    
    Raises `ValueError` if *method* is not one of the supported options.
    '''
    
    # gensim's `sg` flag: 1 selects skip-gram, 0 selects CBOW.
    training_modes = {'skip-gram': 1, 'cbow': 0}
    mode_key = method.lower()
    if mode_key not in training_modes:
        raise ValueError("Invalid option,  choose from ['cbow', 'skip-gram']")
    sg = training_modes[mode_key]
    
    # gensim 4 renamed the embedding-size argument from `size` to `vector_size`;
    # pick the spelling understood by the installed release so both work.
    gensim_major = int(gensim.__version__.split('.')[0])
    size_keyword = 'vector_size' if gensim_major >= 4 else 'size'
    
    model = gensim.models.Word2Vec(corpus, window=window, min_count=min_count, workers=n_jobs, sg=sg,
                                   **{size_keyword: vector_size}, **kwargs)
    
    if outfile:
        model.save(outfile)
        
    return model

def load_spe2vec(model_path):
    '''
    Load a saved spe2vec model.
    
    *model_path*: str, path of the spe2vec model file.
    
    Returns whatever `gensim.models.Word2Vec.load` produces for that path.
    '''
    word2vec_cls = gensim.models.Word2Vec
    return word2vec_cls.load(model_path)

In [None]:
#export

import numpy as np

class SPE2Vec(object):
    '''
    Embed SMILES with a trained spe2vec model.
    
    *model_path*: str, path of a saved spe2vec model file.
    
    *tokenizer*: SPE tokenizer; must provide `tokenize(smi, dropout)` returning a
    space-separated string of tokens.
    '''
    def __init__(self, model_path, tokenizer):
        self.model = gensim.models.Word2Vec.load(model_path)
        self.tokenizer = tokenizer
        
        # gensim >= 4 exposes the vocabulary as `wv.key_to_index`; earlier releases
        # use `wv.vocab`. Probe the new name first because the old one no longer
        # works on gensim 4.
        keyed_vectors = self.model.wv
        if hasattr(keyed_vectors, 'key_to_index'):
            vocabulary = keyed_vectors.key_to_index
        else:
            vocabulary = keyed_vectors.vocab
        self.token_keys = set(vocabulary.keys())
        
        # Vector used for unknown tokens: simply the average of all known token vectors.
        self.unknown = np.mean([keyed_vectors[word] for word in vocabulary], axis=0)
    
    def tokenize(self, smi, dropout=0):
        '''
        tokenize SMILES into substructure tokens.
        
        Returns the tokenizer's output unchanged.
        '''
        return self.tokenizer.tokenize(smi, dropout)
    
    def _known_vectors(self, smi, dropout):
        '''Embeddings of the tokens of `smi` that are in the model vocabulary, in token order.'''
        tokens = self.tokenizer.tokenize(smi, dropout).split(' ')
        return [self.model.wv[tok] for tok in tokens if tok in self.token_keys]
    
    def smiles2vec(self, smi, dropout=0, mode = 'average'):
        '''
        Generate a vector for a SMILES. The vector is constructed in one of four modes: ['average', 'sum', 'avg_pool', 'sum_pool']
        
        `average`: average the embedding of all tokens
        
        `sum`: sum the embedding of all tokens
        
        `avg_pool`: concatenation of average, max pooling and min pooling 
        
        `sum_pool`: concatenation of sum, max pooling and min pooling 
        
        The Unknown token will be skipped
        
        Raises `ValueError` if *mode* is not one of the four options.
        
        NOTE: when none of the tokens is in the model vocabulary the numpy reductions
        run on an empty list, so no meaningful vector is produced.
        '''
        
        if mode not in ['average', 'sum', 'avg_pool', 'sum_pool']:
            raise ValueError("Invalid option,  choose from ['average', 'sum', 'avg_pool', 'sum_pool']")
        
        # Look the embeddings up once and reuse them for every reduction below.
        known = self._known_vectors(smi, dropout)
        
        if mode == 'average':
            return np.mean(known, axis=0)
        
        if mode == 'sum':
            return np.sum(known, axis=0)
        
        # Pooling modes: [mean or sum | element-wise max | element-wise min]
        if mode == 'avg_pool':
            head = np.mean(known, axis=0)
        else:
            head = np.sum(known, axis=0)
        tok_max = np.amax(known, axis=0)
        tok_min = np.amin(known, axis=0)
        return np.concatenate((head, tok_max, tok_min))
        
    def spe2vec(self, smi, dropout=0, skip_unknown=False):
        '''
        Generate a list of vectors (np.array). Each vector is spe vector of each token.
        
        The unknown token will be represented by the mean of all token vectors from the model,
        or dropped from the list when *skip_unknown* is True.
        '''
        
        if skip_unknown:
            return self._known_vectors(smi, dropout)
        
        tokens = self.tokenizer.tokenize(smi, dropout).split(' ')
        return [self.model.wv[tok] if tok in self.token_keys else self.unknown for tok in tokens]

In [None]:
#hide
# `file` is the toy SMILES file used to build the demo corpus; `filedir` is the directory that contains it.
file = '../experiments/data/smiles_toy.smi'
filedir = '../experiments/data/'

In [None]:
#hide
import codecs
from SmilesPE.tokenizer import *
# Build the SPE tokenizer shared by the demo cells from the file opened here
# (presumably the SPE vocabulary — confirm against SmilesPE).
# NOTE(review): the handle is never closed explicitly — confirm SPE_Tokenizer has fully read it before closing.
spe_vob= codecs.open('../SPE_ChEMBL.txt')
spe = SPE_Tokenizer(spe_vob)

In [None]:
#hide
%%time
# Train with the default settings; no `outfile` is passed, so the model only lives in memory.
corpus = Corpus(file, tokenizer = spe) # a memory-friendly iterator
model = learn_spe2vec(corpus)

CPU times: user 2.81 s, sys: 19.6 ms, total: 2.83 s
Wall time: 1.05 s


In [None]:
#hide
# Rebind `model` to the model saved on disk and print its summary.
model = load_spe2vec('../experiments/results/spe_model.bin')
print(model)

Word2Vec(vocab=3114, size=100, alpha=0.025)


In [None]:
#hide
# Wrap the saved model and the tokenizer in the SPE2Vec helper used by the remaining cells.
s = SPE2Vec('../experiments/results/spe_model.bin', spe)

In [None]:
#hide
# Sanity check: the wrapper returns the tokenizer's space-separated token string.
s.tokenize('c1ccccc1C')

'c1ccccc1 C'

In [None]:
#hide
# Element-wise comparison: for a SMILES that is a single in-vocabulary token,
# the default 'average' embedding should equal that token's own vector.
s.smiles2vec('c1ccccc1') == model.wv['c1ccccc1']

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])

In [None]:
# Split a SMILES into SPE substructure tokens (returned as one space-separated string).
s.tokenize('c1ccccc1C')

'c1ccccc1 C'

In [None]:
# One embedding vector per token, returned as a list of arrays.
s.spe2vec('c1ccccc1C')

[array([ 0.00324177, -0.18124679,  0.1894573 ,  0.29736474, -0.14143717,
        -0.03290153, -0.31891045,  0.16373567, -0.12413523, -0.08658446,
        -0.23956653,  0.05335753,  0.18146366, -0.17212407, -0.17879114,
        -0.01039552, -0.00274071,  0.01653983,  0.08432296, -0.15634526,
         0.29629305, -0.16786121,  0.06479991,  0.34462902, -0.11052489,
        -0.13513446,  0.16418819, -0.21508686, -0.01842665, -0.15818536,
        -0.05421342,  0.2041645 ,  0.14783993, -0.00653112, -0.19034739,
        -0.11876111,  0.12208337, -0.0743893 ,  0.03400969,  0.04422404,
        -0.10224582,  0.34490895,  0.12326851, -0.08695894, -0.08150315,
         0.09907438,  0.28797793,  0.15912676,  0.15228626, -0.164707  ,
         0.33839643, -0.04265443, -0.11858924,  0.10059267, -0.24335982,
        -0.02948368,  0.53029126,  0.2448303 ,  0.11335112,  0.01153868,
        -0.01010862, -0.06406022, -0.01338368, -0.18424016,  0.03580371,
         0.18463984,  0.15326728, -0.15144381, -0.0

In [None]:
# Molecule-level embedding: mean of the in-vocabulary token vectors.
s.smiles2vec('c1ccccc1C', mode = 'average')

array([-1.34022804e-02,  1.61780491e-02,  1.71152934e-01,  2.72531897e-01,
       -1.82708904e-01, -1.00375563e-01, -3.14150989e-01,  1.54515356e-03,
       -5.33701554e-02,  3.77645455e-02, -8.03763866e-02,  3.61803323e-02,
        2.03489780e-01, -1.41125008e-01, -6.87211454e-02,  3.95707041e-02,
        1.86435529e-04,  2.40851752e-03,  1.79549366e-01, -2.34853595e-01,
        2.54714727e-01, -1.01204760e-01,  1.46192573e-02,  3.20846856e-01,
       -1.26948640e-01, -3.60341370e-02, -1.99805945e-02, -6.57920390e-02,
        1.12212896e-01, -9.79810804e-02, -6.81295246e-02,  3.15792769e-01,
        9.70366374e-02,  2.58648433e-02, -1.16440386e-01, -2.44033992e-01,
        1.81054398e-02, -7.37853795e-02,  2.92290300e-02, -6.44483045e-02,
        4.18638885e-02,  1.02528289e-01,  9.31825042e-02, -7.51460642e-02,
       -1.67078704e-01, -5.22097796e-02,  2.24865556e-01, -4.42353934e-02,
        1.09328270e-01, -1.35819301e-01,  2.80320883e-01,  2.93976273e-02,
       -1.31033853e-01, -

In [None]:
# Molecule-level embedding: element-wise sum of the in-vocabulary token vectors.
s.smiles2vec('c1ccccc1C', mode = 'sum')

array([-2.6804561e-02,  3.2356098e-02,  3.4230587e-01,  5.4506379e-01,
       -3.6541781e-01, -2.0075113e-01, -6.2830198e-01,  3.0903071e-03,
       -1.0674031e-01,  7.5529091e-02, -1.6075277e-01,  7.2360665e-02,
        4.0697956e-01, -2.8225002e-01, -1.3744229e-01,  7.9141408e-02,
        3.7287106e-04,  4.8170350e-03,  3.5909873e-01, -4.6970719e-01,
        5.0942945e-01, -2.0240952e-01,  2.9238515e-02,  6.4169371e-01,
       -2.5389728e-01, -7.2068274e-02, -3.9961189e-02, -1.3158408e-01,
        2.2442579e-01, -1.9596216e-01, -1.3625905e-01,  6.3158554e-01,
        1.9407327e-01,  5.1729687e-02, -2.3288077e-01, -4.8806798e-01,
        3.6210880e-02, -1.4757076e-01,  5.8458060e-02, -1.2889661e-01,
        8.3727777e-02,  2.0505658e-01,  1.8636501e-01, -1.5029213e-01,
       -3.3415741e-01, -1.0441956e-01,  4.4973111e-01, -8.8470787e-02,
        2.1865654e-01, -2.7163860e-01,  5.6064177e-01,  5.8795255e-02,
       -2.6206771e-01, -1.0030853e-01,  6.5598294e-02, -1.1914876e-01,
      

In [None]:
# 'avg_pool' concatenates the mean, max-pooled and min-pooled vectors,
# so the result is three times the embedding size.
s.smiles2vec('c1ccccc1C', mode = 'avg_pool').shape

(300,)

In [None]:
#hide
# '[dum]' comes out of the tokenizer as its own token (presumably absent from the model vocabulary — verify).
s.tokenize('c1ccccc1[dum]')

'c1ccccc1 [dum]'

In [None]:
#hide
# Default behaviour: a token missing from the model vocabulary is represented by the mean vector.
s.spe2vec('c1ccccc1[dum]')

[array([ 0.00324177, -0.18124679,  0.1894573 ,  0.29736474, -0.14143717,
        -0.03290153, -0.31891045,  0.16373567, -0.12413523, -0.08658446,
        -0.23956653,  0.05335753,  0.18146366, -0.17212407, -0.17879114,
        -0.01039552, -0.00274071,  0.01653983,  0.08432296, -0.15634526,
         0.29629305, -0.16786121,  0.06479991,  0.34462902, -0.11052489,
        -0.13513446,  0.16418819, -0.21508686, -0.01842665, -0.15818536,
        -0.05421342,  0.2041645 ,  0.14783993, -0.00653112, -0.19034739,
        -0.11876111,  0.12208337, -0.0743893 ,  0.03400969,  0.04422404,
        -0.10224582,  0.34490895,  0.12326851, -0.08695894, -0.08150315,
         0.09907438,  0.28797793,  0.15912676,  0.15228626, -0.164707  ,
         0.33839643, -0.04265443, -0.11858924,  0.10059267, -0.24335982,
        -0.02948368,  0.53029126,  0.2448303 ,  0.11335112,  0.01153868,
        -0.01010862, -0.06406022, -0.01338368, -0.18424016,  0.03580371,
         0.18463984,  0.15326728, -0.15144381, -0.0

In [None]:
#hide
# skip_unknown=True drops tokens missing from the model vocabulary instead of substituting the mean vector.
s.spe2vec('c1ccccc1[dum]',skip_unknown=True)

[array([ 0.00324177, -0.18124679,  0.1894573 ,  0.29736474, -0.14143717,
        -0.03290153, -0.31891045,  0.16373567, -0.12413523, -0.08658446,
        -0.23956653,  0.05335753,  0.18146366, -0.17212407, -0.17879114,
        -0.01039552, -0.00274071,  0.01653983,  0.08432296, -0.15634526,
         0.29629305, -0.16786121,  0.06479991,  0.34462902, -0.11052489,
        -0.13513446,  0.16418819, -0.21508686, -0.01842665, -0.15818536,
        -0.05421342,  0.2041645 ,  0.14783993, -0.00653112, -0.19034739,
        -0.11876111,  0.12208337, -0.0743893 ,  0.03400969,  0.04422404,
        -0.10224582,  0.34490895,  0.12326851, -0.08695894, -0.08150315,
         0.09907438,  0.28797793,  0.15912676,  0.15228626, -0.164707  ,
         0.33839643, -0.04265443, -0.11858924,  0.10059267, -0.24335982,
        -0.02948368,  0.53029126,  0.2448303 ,  0.11335112,  0.01153868,
        -0.01010862, -0.06406022, -0.01338368, -0.18424016,  0.03580371,
         0.18463984,  0.15326728, -0.15144381, -0.0