# In [158]:
import numpy as np
import torch
from torch import nn

class Smiles_To_Emmbedding(nn.Module):
    """Turn a SMILES string into a ``(gamma, zeta)`` embedding matrix.

    The string is split greedily into the longest sub-strings found in the
    vocabulary, one-hot encoded into a ``(len(vocab), zeta)`` matrix, and
    each column is mapped through a bias-free linear "chemical" embedding.
    A bias-free linear positional embedding of the column index is added
    and dropout is applied to the sum.

    Attributes:
        vocab: token list; the entries read from the vocabulary file
            followed by the single-character SMILES symbols.
        max_lenght: length of the longest token in ``vocab``.
        zeta: maximum number of tokens kept per SMILES string.
        gamma: embedding dimension.
        smi2index / index2smi: token -> row index and row index -> token.
    """

    def __init__(self, vocab_path='drug_codes_chembl_freq_1500.txt'):
        """Load the vocabulary and create the embedding layers.

        Args:
            vocab_path: text file with one token per line. The first line
                is skipped and spaces inside every line are removed.

        Raises:
            OSError: if ``vocab_path`` cannot be opened.
        """
        super().__init__()

        # tokenization:
        # ``with`` guarantees the file handle is released.
        with open(vocab_path, 'r') as vocab_file:
            vocab_data = vocab_file.read()
        vocab_data = vocab_data.replace(' ', '')
        self.vocab = vocab_data.split('\n')
        self.vocab.pop(0)  # drop the first line of the file
        self.zeta = 50  # max length of drug representation

        SMILES_CHARS = [' ',
                        '#', '%', '(', ')', '+', '-', '.', '/',
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        '=', '@',
                        'A', 'B', 'C', 'F', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P',
                        'R', 'S', 'T', 'V', 'X', 'Z',
                        '[', '\\', ']',
                        'a', 'b', 'c', 'e', 'g', 'i', 'l', 'n', 'o', 'p', 'r', 's',
                        't', 'u']
        self.vocab = self.vocab + SMILES_CHARS

        # Computed on the full vocabulary so it is defined (>= 1) even when
        # the file contributes no tokens; the appended symbols have length 1.
        self.max_lenght = max(len(token) for token in self.vocab)

        # If a token occurs twice in ``vocab`` the later index wins.
        self.smi2index = {c: i for i, c in enumerate(self.vocab)}
        self.index2smi = {i: c for i, c in enumerate(self.vocab)}

        # encoder layer:
        self.gamma = 50
        dropout_rate = 0.05
        self.chem_embedding = torch.nn.Linear(len(self.vocab), self.gamma, bias=False)
        self.pos_embedding = torch.nn.Linear(self.zeta, self.gamma, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

    def smiles_encoder(self, smiles):
        """One-hot encode ``smiles`` by greedy longest-match tokenization.

        Args:
            smiles: the SMILES string to encode.

        Returns:
            ``numpy`` array of shape ``(len(self.vocab), self.zeta)``;
            column ``j`` holds a single 1 in the row of the ``j``-th token.
            Encoding stops (remaining columns stay zero) after ``zeta``
            tokens or at the first position where no vocabulary entry
            matches.
        """
        X = np.zeros((len(self.vocab), self.zeta))
        n_chars = len(smiles)
        n_lower = 0
        j = 0
        while n_lower < n_chars and j < self.zeta:
            # Try the longest candidate first and shrink until a token fits.
            # The window never reaches past the end of the string, so the
            # empty string is never looked up.
            n_upper = min(n_lower + self.max_lenght, n_chars)
            i = None
            while n_upper > n_lower:
                i = self.smi2index.get(smiles[n_lower:n_upper])
                if i is not None:
                    break
                n_upper -= 1
            if i is None:
                # Not even a single character matched: stop encoding here.
                break
            X[i, j] = 1
            n_lower = n_upper
            j += 1
        return X

    def smiles_decoder(self, X):
        """Rebuild the string from a one-hot matrix made by ``smiles_encoder``.

        Args:
            X: array of shape ``(len(self.vocab), n_columns)``.

        Returns:
            Concatenation of the tokens selected by each column, stopping
            at the first all-zero column.
        """
        pieces = []
        for j, i in enumerate(X.argmax(axis=0)):
            # Stop on an empty column rather than on ``i == 0``: row 0 is a
            # regular vocabulary token, not a padding symbol.
            if not X[:, j].any():
                break
            pieces.append(self.index2smi[int(i)])
        return ''.join(pieces)

    def forward(self, smiles):
        """Embed a single SMILES string.

        Args:
            smiles: the SMILES string to embed.

        Returns:
            Tensor of shape ``(self.gamma, self.zeta)``: chemical plus
            positional embedding, after dropout.
        """
        weight = self.chem_embedding.weight
        M = torch.from_numpy(self.smiles_encoder(smiles)).to(
            device=weight.device, dtype=weight.dtype)
        # ``Linear`` acts on the last dimension, so feed the columns as rows
        # and transpose back: one batched call instead of ``zeta`` calls.
        C = self.chem_embedding(M.t()).t()
        # Row j of the identity is the one-hot position vector of column j.
        I = torch.eye(self.zeta, device=weight.device, dtype=weight.dtype)
        P = self.pos_embedding(I).t()
        E = C + P
        E = self.dropout(E)
        return E
    
# Demo: embed one SMILES string and print the resulting tensor.
Encoder = Smiles_To_Emmbedding()
# Call the module itself rather than ``.forward`` so that any hooks
# registered on it are run by ``nn.Module.__call__``.
X = Encoder('C1CC2=C(C=C(C=C2)C3=NC(=C(S3)CCCOC4=CC=C(C=C4)CN)C(=O)O)/C(=N/NC5=NC6=CC=CC=C6S5)/C1')
print(X)


# Example output (values differ between runs: the weights are randomly
# initialised and dropout is active):
#
# tensor([[-0.0721,  0.0248,  0.1143,  ..., -0.0746,  0.0595,  0.0000],
#         [ 0.0828,  0.1275, -0.1284,  ...,  0.1092,  0.0576,  0.0800],
#         [-0.0469, -0.0000, -0.0524,  ..., -0.0528, -0.0266,  0.0613],
#         ...,
#         [ 0.1255, -0.0769, -0.1521,  ..., -0.1169, -0.0284,  0.0841],
#         [ 0.1447,  0.0871,  0.1154,  ..., -0.0131, -0.1329, -0.0494],
#         [-0.1311,  0.0324, -0.0182,  ..., -0.1411,  0.1170, -0.0777]],
#        grad_fn=<MulBackward0>)
