<a href="https://colab.research.google.com/github/ajia90/smilestransformer/blob/main/interpolationtraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# imports/ setup

In [1]:
# Install RDKit. Takes 2-3 minutes
!wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge python=3.7 
!time conda install -q -y -c conda-forge rdkit 

--2020-10-07 21:37:22--  https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
Resolving repo.continuum.io (repo.continuum.io)... 104.18.201.79, 104.18.200.79, 2606:4700::6812:c94f, ...
Connecting to repo.continuum.io (repo.continuum.io)|104.18.201.79|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh [following]
--2020-10-07 21:37:22--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.131.3, 104.16.130.3, 2606:4700::6810:8303, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.131.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 93052469 (89M) [application/x-sh]
Saving to: ‘Miniconda3-latest-Linux-x86_64.sh’


2020-10-07 21:37:22 (237 MB/s) - ‘Miniconda3-latest-Linux-x86_64.sh’ saved [93052469/93052469]

PREFIX=/usr/local
Unpacking payload ...
C

In [2]:
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')

In [None]:
#!conda install pytorch torchvision -c pytorch

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import DrawingOptions

# methods/model


In [4]:
def plot_mols(mols, unit=200, w=120, h=200, fontsize=1.0):
    drawer = Draw.MolDraw2DSVG(4*unit, 3*unit, w, h)

    # optをとり出しておく
    opt = drawer.drawOptions()
    opt.padding = 0.1
    opt.legendFontSize = 20
    #opt.atomfontSize = 20

    xs = np.array([0,1,2,3,0,1,2,3,0,1,2,3])*unit
    ys = np.array([0,0,0,0,1,1,1,1,2,2,2,2])*unit
    for i, (mol, x, y) in enumerate(zip(mols,xs,ys)):
        # SetOffsetで左上の座標を指定できる
        drawer.SetOffset(int(x), int(y))
        drawer.SetFontSize(fontsize)

        AllChem.Compute2DCoords(mol)
        Chem.Kekulize(mol)
        # 分子をSVGに書く
        drawer.DrawMolecule(mol, legend=str(i))


    # </svg> 書く
    drawer.FinishDrawing()
    return drawer

In [5]:
def get_inputs(sm):
    seq_len = 220
    sm = sm.split()
    if len(sm)>218:
        print('SMILES is too long ({:d})'.format(len(sm)))
        sm = sm[:109]+sm[-109:]
    ids = [vocab.stoi.get(token, unk_index) for token in sm]
    ids = [sos_index] + ids + [eos_index]
    seg = [1]*len(ids)
    padding = [pad_index]*(seq_len - len(ids))
    ids.extend(padding), seg.extend(padding)
    return ids, seg

def get_array(smiles):
    x_id, x_seg = [], []
    for sm in smiles:
        a,b = get_inputs(sm)
        x_id.append(a)
        x_seg.append(b)
    return torch.tensor(x_id), torch.tensor(x_seg)

In [6]:
import torch
from pretrain_trfm import TrfmSeq2seq
from pretrain_rnn import RNNSeq2Seq
#from bert import BERT
from build_vocab import WordVocab
from utils import split

pad_index = 0
unk_index = 1
eos_index = 2
sos_index = 3
mask_index = 4

vocab = WordVocab.load_vocab('vocab.pkl')

trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
trfm.load_state_dict(torch.load('trfm.pkl'))
trfm.eval()
print('Total parameters:', sum(p.numel() for p in trfm.parameters()))

Total parameters: 4245037


# data/interpolation

In [7]:
#read in BBBp data
df = pd.read_csv('BBBP.csv')
print(df.shape)
df.head()

(2050, 4)


Unnamed: 0,num,name,p_np,smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4,5,cloxacillin,1,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...


In [8]:
x_split = [split(sm) for sm in df['smiles'].values]
xid, xseg = get_array(x_split)
X = trfm.encode(torch.t(xid))
print(X.shape)

SMILES is too long (256)
SMILES is too long (239)
SMILES is too long (258)
SMILES is too long (380)
SMILES is too long (332)
There are 2050 molecules. It will take a little time.
(220, 2050, 256)


In [None]:
#crashes, using trfm.encoder instead of trfm.encoder.layers[i]
embedded = trfm.embed(torch.t(xid))  # (T,B,H)
embedded = trfm.pe(embedded) # (T,B,H)
output = embedded
output = trfm.trfm.encoder(output)
if trfm.trfm.encoder.norm:
    output = trfm.trfm.encoder.norm(output) # (T,B,H)
output = output.detach().numpy()

In [10]:
t = torch.t(xid)
t.shape

torch.Size([220, 2050])

In [11]:
target = trfm.embed(t)
target = trfm.pe(target)


In [12]:
decoded = trfm.decode(target, torch.from_numpy(X).float())

There are 2050 molecules. It will take a little time.


In [None]:
s = decoded[:,0,:]

In [None]:
f = pd.DataFrame(s)
f

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44
0,-45.470173,-37.101559,-29.256117,0.000000,-45.291073,-28.516005,-29.121420,-29.097557,-29.547464,-30.180914,-32.877251,-28.971113,-31.556475,-30.298185,-35.527153,-32.507462,-35.761383,-33.346615,-29.148830,-33.430061,-37.748947,-37.402824,-38.879234,-36.305046,-34.487492,-34.085678,-36.536816,-39.494839,-34.125221,-36.803223,-36.887764,-34.481155,-35.181801,-28.138823,-38.456833,-39.239799,-34.882401,-37.805706,-37.666882,-34.525105,-37.918995,-34.325474,-34.888958,-38.189602,-34.892010
1,-29.301823,-13.799294,-10.297996,-20.810789,-29.578262,-1.241233,-1.390985,-2.580844,-2.887537,-2.597407,-11.433615,-2.689757,-3.348817,-3.133340,-10.602304,-12.134955,-13.403796,-8.030526,-2.693455,-3.214077,-9.602731,-15.141812,-15.247645,-12.991599,-14.733134,-7.043643,-13.489715,-16.591805,-12.896339,-9.764510,-12.891412,-11.201367,-14.309287,-7.874004,-14.161911,-16.818937,-12.525611,-14.127555,-14.955525,-13.179366,-14.709735,-13.407912,-12.651450,-18.167086,-14.560893
2,-28.002546,-14.300258,-7.515660,-16.648066,-28.313601,-2.135618,-2.449620,-5.779906,-3.049522,-3.823579,-13.503217,-1.281487,-4.810882,-4.067439,-9.224874,-13.857028,-12.634272,-8.779547,-6.181548,-4.913281,-9.196618,-12.490766,-14.382434,-13.531489,-14.927424,-0.901370,-14.270227,-15.679343,-14.062472,-12.178649,-12.561431,-12.056256,-14.757603,-5.541435,-11.511449,-17.378422,-15.678811,-13.392316,-15.731251,-12.258845,-15.973686,-15.799025,-14.738399,-18.552027,-12.960588
3,-29.321672,-14.742002,-10.494133,-21.474413,-29.464661,-1.240499,-1.439253,-2.557791,-2.673108,-2.642319,-11.714100,-2.930491,-3.241757,-2.945319,-9.949409,-11.911226,-12.674129,-8.125805,-3.134292,-2.742188,-9.502390,-15.150566,-14.610735,-13.142111,-14.498672,-6.220497,-13.782190,-16.134712,-12.854937,-10.645266,-14.442843,-10.791265,-14.294293,-7.785919,-13.682887,-16.397455,-12.737179,-13.719917,-14.867741,-13.119571,-14.403181,-13.089798,-12.391588,-17.883230,-14.669279
4,-29.810993,-13.486595,-2.099242,-18.828798,-29.946140,-0.935535,-3.040370,-4.980090,-2.823648,-4.027109,-12.270425,-3.031455,-5.842985,-4.155580,-10.499365,-12.039428,-13.849604,-9.992680,-6.797335,-5.321773,-10.742392,-13.902041,-17.262602,-14.406287,-16.202913,-6.778595,-13.991510,-17.975685,-12.098310,-11.173542,-15.073608,-10.648915,-17.584824,-1.276396,-15.358470,-17.951000,-15.336847,-14.650130,-16.492916,-11.419200,-16.456060,-15.114813,-15.330233,-18.195072,-14.035368
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639
216,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639
217,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639
218,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639


In [None]:
torch.from_numpy(s).view(-1, len(vocab)).shape
                    

torch.Size([220, 45])

In [None]:
xid[1].contiguous().view(-1).shape

torch.Size([220])

In [None]:
mol1 = X[:,2,:]
mol2 = X[:,40,:]

In [None]:
mol1.shape

(220, 256)

In [None]:
def linear_interpolation(mol_from, mol_to, steps):
    n = steps + 1
    diff = mol_to - mol_from
    inter = mol_from + 1 / steps * diff
    print(inter.shape)
    for i in range(2,n):
      add = mol_from + i / steps * diff
      inter = np.dstack((inter, add))
     return inter

In [None]:
molecule_morph = linear_interpolation(mol1, mol2, 20)
molecule_morph2 = molecule_morph.reshape((220,20,256))

(220, 256)


In [None]:
pd.DataFrame(molecule_morph2[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,0.077286,0.085127,0.092968,0.100809,0.108650,0.116491,0.124332,0.132173,0.140014,0.147856,0.155697,0.163538,0.171379,0.179220,0.187061,0.194902,0.202743,0.210584,0.218425,0.226266,-0.014650,-0.011636,-0.008622,-0.005608,-0.002594,0.000420,0.003434,0.006448,0.009462,0.012476,0.015491,0.018505,0.021519,0.024533,0.027547,0.030561,0.033575,0.036589,0.039603,0.042617,...,-0.211053,-0.200482,-0.189911,-0.179341,1.620298,1.587653,1.555008,1.522363,1.489718,1.457073,1.424428,1.391783,1.359138,1.326493,1.293848,1.261203,1.228558,1.195913,1.163268,1.130623,1.097978,1.065333,1.032688,1.000043,-0.036918,-0.036767,-0.036616,-0.036465,-0.036314,-0.036164,-0.036013,-0.035862,-0.035711,-0.035560,-0.035410,-0.035259,-0.035108,-0.034957,-0.034806,-0.034656
1,-0.319582,-0.304199,-0.288815,-0.273431,-0.258047,-0.242663,-0.227279,-0.211895,-0.196512,-0.181128,-0.165744,-0.150360,-0.134976,-0.119592,-0.104208,-0.088825,-0.073441,-0.058057,-0.042673,-0.027289,0.002542,0.010421,0.018300,0.026179,0.034058,0.041938,0.049817,0.057696,0.065575,0.073454,0.081334,0.089213,0.097092,0.104971,0.112850,0.120730,0.128609,0.136488,0.144367,0.152246,...,0.391709,0.419254,0.446799,0.474344,-0.188792,-0.172265,-0.155739,-0.139213,-0.122687,-0.106160,-0.089634,-0.073108,-0.056582,-0.040055,-0.023529,-0.007003,0.009524,0.026050,0.042576,0.059102,0.075629,0.092155,0.108681,0.125207,0.008961,0.000886,-0.007190,-0.015266,-0.023342,-0.031418,-0.039494,-0.047570,-0.055646,-0.063722,-0.071798,-0.079874,-0.087950,-0.096026,-0.104102,-0.112178
2,-0.336954,-0.314595,-0.292235,-0.269876,-0.247517,-0.225158,-0.202799,-0.180440,-0.158081,-0.135722,-0.113363,-0.091003,-0.068644,-0.046285,-0.023926,-0.001567,0.020792,0.043151,0.065510,0.087869,0.220183,0.222027,0.223872,0.225716,0.227560,0.229405,0.231249,0.233093,0.234938,0.236782,0.238627,0.240471,0.242315,0.244160,0.246004,0.247848,0.249693,0.251537,0.253381,0.255226,...,-0.726708,-0.778144,-0.829579,-0.881014,-0.186539,-0.233490,-0.280440,-0.327391,-0.374342,-0.421292,-0.468243,-0.515194,-0.562144,-0.609095,-0.656045,-0.702996,-0.749947,-0.796897,-0.843848,-0.890799,-0.937749,-0.984700,-1.031651,-1.078601,0.118115,0.119859,0.121603,0.123347,0.125090,0.126834,0.128578,0.130322,0.132066,0.133809,0.135553,0.137297,0.139041,0.140785,0.142528,0.144272
3,-0.124127,-0.122439,-0.120751,-0.119062,-0.117374,-0.115686,-0.113997,-0.112309,-0.110621,-0.108932,-0.107244,-0.105556,-0.103867,-0.102179,-0.100491,-0.098802,-0.097114,-0.095426,-0.093737,-0.092049,-0.386870,-0.348360,-0.309850,-0.271341,-0.232831,-0.194321,-0.155812,-0.117302,-0.078792,-0.040283,-0.001773,0.036737,0.075246,0.113756,0.152266,0.190775,0.229285,0.267795,0.306304,0.344814,...,0.030888,0.017653,0.004419,-0.008815,-0.190559,-0.200069,-0.209578,-0.219087,-0.228597,-0.238106,-0.247615,-0.257125,-0.266634,-0.276143,-0.285653,-0.295162,-0.304671,-0.314181,-0.323690,-0.333199,-0.342709,-0.352218,-0.361727,-0.371237,-0.001295,-0.000961,-0.000626,-0.000292,0.000043,0.000377,0.000712,0.001046,0.001381,0.001715,0.002050,0.002384,0.002719,0.003053,0.003388,0.003722
4,-0.295057,-0.255147,-0.215237,-0.175328,-0.135418,-0.095508,-0.055599,-0.015689,0.024221,0.064130,0.104040,0.143950,0.183859,0.223769,0.263679,0.303588,0.343498,0.383408,0.423317,0.463227,-0.007891,-0.010445,-0.012998,-0.015552,-0.018106,-0.020659,-0.023213,-0.025766,-0.028320,-0.030874,-0.033427,-0.035981,-0.038534,-0.041088,-0.043642,-0.046195,-0.048749,-0.051302,-0.053856,-0.056410,...,-0.124705,-0.127538,-0.130370,-0.133202,-0.193347,-0.181375,-0.169404,-0.157433,-0.145461,-0.133490,-0.121519,-0.109547,-0.097576,-0.085604,-0.073633,-0.061662,-0.049690,-0.037719,-0.025748,-0.013776,-0.001805,0.010166,0.022138,0.034109,0.025847,0.034657,0.043466,0.052276,0.061085,0.069895,0.078704,0.087514,0.096323,0.105133,0.113943,0.122752,0.131562,0.140371,0.149181,0.157990
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632
216,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632
217,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632
218,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632


In [None]:
# decoded_interpolations = trfm.decode(torch.t(molecule_morph))
# mols = [Chem.MolFromSmiles(sm) for sm in df['smiles'].values[ids]]
# dr = plot_mols(mols, 250, 175, 250, 1.1)
# with open('bbbp_mol.svg', 'w') as f:
#     f.write(dr.GetDrawingText())
# SVG(dr.GetDrawingText())