<a href="https://colab.research.google.com/github/ajia90/smilestransformer/blob/main/interpolationtraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# imports/ setup

In [None]:
# Install RDKit. Takes 2-3 minutes
!wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge python=3.7 
!time conda install -q -y -c conda-forge rdkit 

In [2]:
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')

In [None]:
#!conda install pytorch torchvision -c pytorch

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import DrawingOptions
from torch.nn import functional as F


# methods/model


In [4]:
def plot_mols(mols, unit=200, w=120, h=200, fontsize=1.0):
    drawer = Draw.MolDraw2DSVG(4*unit, 3*unit, w, h)

    # optをとり出しておく
    opt = drawer.drawOptions()
    opt.padding = 0.1
    opt.legendFontSize = 20
    #opt.atomfontSize = 20

    xs = np.array([0,1,2,3,0,1,2,3,0,1,2,3])*unit
    ys = np.array([0,0,0,0,1,1,1,1,2,2,2,2])*unit
    for i, (mol, x, y) in enumerate(zip(mols,xs,ys)):
        # SetOffsetで左上の座標を指定できる
        drawer.SetOffset(int(x), int(y))
        drawer.SetFontSize(fontsize)

        AllChem.Compute2DCoords(mol)
        Chem.Kekulize(mol)
        # 分子をSVGに書く
        drawer.DrawMolecule(mol, legend=str(i))


    # </svg> 書く
    drawer.FinishDrawing()
    return drawer

In [5]:
def get_inputs(sm):
    seq_len = 220
    sm = sm.split()
    if len(sm)>218:
        print('SMILES is too long ({:d})'.format(len(sm)))
        sm = sm[:109]+sm[-109:]
    ids = [vocab.stoi.get(token, unk_index) for token in sm]
    ids = [sos_index] + ids + [eos_index]
    seg = [1]*len(ids)
    padding = [pad_index]*(seq_len - len(ids))
    ids.extend(padding), seg.extend(padding)
    return ids, seg

def get_array(smiles):
    x_id, x_seg = [], []
    for sm in smiles:
        a,b = get_inputs(sm)
        x_id.append(a)
        x_seg.append(b)
    return torch.tensor(x_id), torch.tensor(x_seg)

In [6]:
import torch
from pretrain_trfm import TrfmSeq2seq
from pretrain_rnn import RNNSeq2Seq
#from bert import BERT
from build_vocab import WordVocab
from utils import split

pad_index = 0
unk_index = 1
eos_index = 2
sos_index = 3
mask_index = 4

vocab = WordVocab.load_vocab('vocab.pkl')

trfm_c= TrfmSeq2seq(len(vocab), 256, len(vocab), 4).cuda()
trfm_c.load_state_dict(torch.load('trfm.pkl'))
trfm_c.eval()

trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
trfm.load_state_dict(torch.load('trfm.pkl'))
trfm.eval()
print('Total parameters:', sum(p.numel() for p in trfm.parameters()))

Total parameters: 4245037


In [8]:
smiles_dict = vocab.stoi

In [None]:
smiles_dict


# data


In [None]:
#read in BBBp data
df = pd.read_csv('BBBP.csv')
print(df.shape)
df.head()

(2050, 4)


Unnamed: 0,num,name,p_np,smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4,5,cloxacillin,1,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...


In [7]:
#sample of chembl25 data
df2 = pd.read_csv('smiles_sample2.csv')
print(df2.shape)

(1052, 1)


In [8]:
x_split = [split(sm) for sm in df2.iloc[162:165]['canonical_smiles'].values]
xid, xseg = get_array(x_split)

In [12]:
len(df2.iloc[162]['canonical_smiles'])

51

In [13]:
len(df2.iloc[164]['canonical_smiles'])

51

# full transformer feedforward


In [9]:
from dataset import Seq2seqDataset
from torch.utils.data import DataLoader
from tqdm import tqdm


In [None]:
# dataset = Seq2seqDataset(df['smiles'].values, vocab)
# data_loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=16)



In [4]:
# for b, sm in tqdm(enumerate(data_loader)):
#   sm = torch.t(sm.cuda()) # (T,B)
#   output1 = trfm_c(sm) # (T,,V)



In [10]:
total_out = trfm(torch.t(xid))

In [15]:
total_out.shape

torch.Size([220, 3, 45])

In [16]:
pd.DataFrame(total_out.detach().numpy()[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44
0,-44.175671,-35.747681,-31.292555,0.000000,-44.260025,-28.635603,-27.625763,-27.809988,-28.519175,-28.752401,-29.250795,-28.165989,-30.510796,-29.545069,-30.112347,-36.913124,-30.936882,-33.528564,-33.910561,-37.071022,-35.486027,-32.370548,-33.820328,-27.984274,-36.689098,-36.346581,-36.033531,-35.381283,-37.492249,-39.098457,-38.080399,-36.582424,-37.780548,-29.739103,-38.661522,-36.839684,-34.670715,-38.427593,-36.948112,-35.530842,-37.118153,-36.615639,-37.581856,-39.887928,-35.878986
1,-26.323692,-13.302274,-7.293172,-19.012615,-26.589888,-1.471507,-1.359549,-2.515324,-2.585663,-2.875618,-3.296388,-3.311698,-2.544843,-2.919650,-3.226220,-10.517247,-3.679514,-11.390262,-12.023293,-12.094478,-12.414701,-12.701138,-4.783476,-3.860903,-14.093339,-13.104815,-11.625642,-10.066242,-12.079480,-16.725693,-12.065263,-13.024291,-13.863189,-12.979401,-16.238556,-13.675254,-13.637882,-16.865936,-17.264095,-16.858646,-16.097912,-19.203560,-17.509773,-19.802118,-18.621206
2,-26.443110,-13.190294,-7.166296,-19.076300,-26.692677,-1.437387,-1.409159,-2.741089,-2.418569,-2.664074,-3.333209,-2.945692,-2.768120,-2.960833,-3.286055,-10.643085,-3.662803,-11.407366,-11.825718,-11.924688,-12.261599,-13.015224,-4.669723,-3.977208,-14.344002,-12.985704,-11.321815,-9.884907,-11.794058,-16.802887,-12.270813,-13.028002,-14.110308,-12.847872,-16.060928,-13.464298,-13.658043,-16.585810,-17.179873,-16.712868,-15.931524,-19.095335,-17.293543,-19.711664,-18.669382
3,-26.020302,-13.246327,-7.860716,-18.837280,-26.236397,-1.394491,-1.447143,-2.436429,-2.588636,-2.943488,-3.322718,-3.214290,-2.608548,-2.942546,-3.267622,-10.746580,-3.552000,-11.660261,-12.198205,-12.307819,-12.542967,-12.218102,-4.574267,-3.849207,-13.934746,-12.512344,-11.337343,-9.734991,-11.933450,-16.836124,-11.647294,-12.769719,-13.639711,-12.731513,-16.087334,-13.386945,-13.528852,-16.680016,-17.045216,-16.693459,-15.938904,-19.177122,-17.486834,-19.564457,-18.387510
4,-28.262392,-13.449371,-7.368211,-19.396257,-28.552042,-2.002810,-1.905679,-2.529511,-2.453239,-3.022075,-3.146697,-1.036299,-3.846676,-2.931501,-4.268562,-10.834663,-5.534667,-12.020775,-11.973896,-11.905186,-11.829760,-13.425762,-7.567640,-4.574835,-14.525647,-13.413580,-11.965914,-12.140483,-12.955153,-16.853287,-12.491102,-14.111275,-14.485508,-12.011520,-14.653417,-15.343867,-14.097477,-16.393097,-17.412333,-16.047474,-14.896681,-19.095831,-16.953480,-21.015871,-18.256979
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465
216,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465
217,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465
218,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465


# encode


In [17]:
#encode function
X = trfm.encode(torch.t(xid))
print(X.shape)

(220, 3, 256)


In [18]:
#trfm.encoder
embedded = trfm.embed(torch.t(xid))  # (T,B,H)
embedded = trfm.pe(embedded) # (T,B,H)
output = embedded

output = trfm.trfm.encoder(output)
output = output.detach().numpy()

In [21]:
pd.DataFrame(X[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,0.454708,0.080336,-0.294366,-0.401369,0.541710,-0.200889,-0.120335,-0.339083,-0.101097,0.229046,-0.445155,0.211485,-0.205020,0.664492,0.185645,0.124312,-0.342530,0.481596,-0.401595,0.014376,-0.054348,0.089179,-0.282574,0.331085,-0.352750,-0.113576,-0.100597,0.064156,-0.499652,0.099948,-0.229785,0.066797,0.098899,0.519214,-0.268636,0.056397,0.265923,0.103134,-0.056445,-0.672909,...,-0.113477,0.296749,0.091493,0.069218,0.420627,-0.165765,-0.225396,-0.282655,0.151135,0.404711,0.158811,-0.216490,0.163378,0.818042,0.486945,-0.419192,-1.384934,0.327678,0.130529,0.120827,0.691080,-0.539669,0.176293,0.265167,0.047067,-0.310800,0.179857,0.472129,-0.207352,1.645416,-0.911975,-0.592123,0.113224,0.239112,-0.312228,-0.003359,0.727808,0.533852,-0.238266,0.362855
1,0.096572,-0.001399,-0.232750,-0.283242,0.021962,-0.218613,-0.264428,-0.028616,-0.034521,-0.319740,0.305263,-0.816265,-0.025122,0.711028,0.244499,-0.129888,-0.127948,1.198038,-0.405273,0.161967,0.332475,-0.063572,0.022592,-0.272858,-0.161523,-0.104455,-0.216542,-0.123286,-0.167681,0.080028,-0.038022,0.402875,-0.376032,-0.135759,-0.039112,0.117376,0.142446,-0.088147,0.126099,-0.279504,...,0.098807,0.126360,-0.193077,-0.771844,0.716417,-0.108638,-0.015464,-0.251625,0.073495,0.079394,0.098538,-0.198694,0.115210,0.105099,0.478798,-0.179246,-0.827921,-0.047870,-0.150012,0.229416,0.403532,-0.294955,0.391178,0.202953,0.164371,0.157130,0.139870,0.296469,-0.395560,0.206402,-0.063842,-0.341909,-0.115843,-0.436179,-0.151443,-0.169579,0.763682,0.237843,-0.163331,0.149988
2,0.117096,0.402802,-0.449836,-1.033805,0.141712,-0.429342,-0.210812,0.092066,0.030922,-0.494427,-0.194613,-0.962170,0.532097,0.497838,0.464924,0.430694,-0.519575,1.004808,-0.223635,0.191375,0.379482,0.059067,-0.208432,-0.311391,-0.048323,-0.691355,0.362867,0.221461,-0.134918,-0.084183,-0.267636,-0.506540,-0.445274,0.004556,-0.130335,0.296921,0.682433,-0.023733,0.222944,-0.106098,...,0.210007,-0.452477,-0.569182,0.196491,0.207027,-0.149190,0.215945,-0.337185,0.269829,0.191722,0.071451,0.518645,0.037460,1.434995,0.210121,0.002721,-1.948874,0.446429,0.234299,-0.216470,0.712723,-0.390150,0.351521,0.362909,0.075845,0.150345,-0.316661,0.166582,0.113761,0.486705,-0.725282,-0.683606,-0.000458,-0.592681,0.008754,0.490660,0.705642,0.129619,-0.426470,1.047946
3,0.029140,-0.173007,-0.035160,-0.253242,-0.220172,-0.618837,-0.222439,-0.138516,0.254560,-0.002792,-0.028122,-1.108267,0.247467,0.479325,0.552668,0.180563,-0.321407,0.698443,0.229832,0.119159,0.009356,-0.025854,-0.358923,-0.043755,0.202075,-0.116351,-0.042463,-0.148999,-0.328305,0.169367,-0.125129,0.359951,-0.215271,-0.244864,0.147397,0.004064,-0.136148,-0.413386,-0.032845,0.079964,...,-0.118971,-0.145324,0.046045,0.262360,0.599527,0.141921,0.136837,-0.180848,0.158193,-0.039137,-0.004125,-0.569593,0.065075,0.372793,0.388905,-0.272986,-1.629066,0.253706,-0.129571,0.460231,-0.103492,-0.438113,0.415432,0.500873,0.056621,-0.060449,0.105304,-0.336540,0.185622,0.791139,-0.396957,-0.408002,-0.150851,-0.661160,-0.203430,0.175949,0.287032,0.128584,-0.277056,0.458057
4,-0.041524,0.086308,-0.118732,-0.367404,-0.090567,-0.214584,-0.500232,0.134418,0.158793,-0.141809,0.021213,-1.180993,0.028826,0.643279,0.413703,0.084214,-0.289736,0.601917,0.252522,-0.038898,0.239144,0.058708,-0.392334,-0.219020,-0.063629,-0.201278,-0.090049,0.271942,-0.386986,0.222466,-0.368982,0.250737,0.078833,-0.249764,-0.346872,0.433003,0.330628,-0.396010,-0.177195,-0.057273,...,0.417420,-0.092456,-0.293656,-0.069984,0.142354,-0.337069,0.168362,0.015887,0.266458,0.341685,0.336643,-0.354652,0.230625,0.644350,0.491924,-0.525562,-0.753347,0.495169,-0.035378,0.199521,1.168490,-0.349137,-0.029784,0.438653,0.065312,0.125178,0.156485,0.150127,0.354485,1.445903,-0.369432,-0.173093,0.153058,0.235338,-0.211106,0.016361,0.811819,-0.050305,-0.641572,0.190057
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409
216,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409
217,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409
218,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409


In [20]:
pd.DataFrame(output[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,0.454708,0.080336,-0.294366,-0.401369,0.541710,-0.200889,-0.120335,-0.339083,-0.101097,0.229046,-0.445155,0.211485,-0.205020,0.664492,0.185645,0.124312,-0.342530,0.481596,-0.401595,0.014376,-0.054348,0.089179,-0.282574,0.331085,-0.352750,-0.113576,-0.100597,0.064156,-0.499652,0.099948,-0.229785,0.066797,0.098899,0.519214,-0.268636,0.056397,0.265923,0.103134,-0.056445,-0.672909,...,-0.113477,0.296749,0.091493,0.069218,0.420627,-0.165765,-0.225396,-0.282655,0.151135,0.404711,0.158811,-0.216490,0.163378,0.818042,0.486945,-0.419192,-1.384934,0.327678,0.130529,0.120827,0.691080,-0.539669,0.176293,0.265167,0.047067,-0.310800,0.179857,0.472129,-0.207352,1.645416,-0.911975,-0.592123,0.113224,0.239112,-0.312228,-0.003359,0.727808,0.533852,-0.238266,0.362855
1,0.096572,-0.001399,-0.232750,-0.283242,0.021962,-0.218613,-0.264428,-0.028616,-0.034521,-0.319740,0.305263,-0.816265,-0.025122,0.711028,0.244499,-0.129888,-0.127948,1.198038,-0.405273,0.161967,0.332475,-0.063572,0.022592,-0.272858,-0.161523,-0.104455,-0.216542,-0.123286,-0.167681,0.080028,-0.038022,0.402875,-0.376032,-0.135759,-0.039112,0.117376,0.142446,-0.088147,0.126099,-0.279504,...,0.098807,0.126360,-0.193077,-0.771844,0.716417,-0.108638,-0.015464,-0.251625,0.073495,0.079394,0.098538,-0.198694,0.115210,0.105099,0.478798,-0.179246,-0.827921,-0.047870,-0.150012,0.229416,0.403532,-0.294955,0.391178,0.202953,0.164371,0.157130,0.139870,0.296469,-0.395560,0.206402,-0.063842,-0.341909,-0.115843,-0.436179,-0.151443,-0.169579,0.763682,0.237843,-0.163331,0.149988
2,0.117096,0.402802,-0.449836,-1.033805,0.141712,-0.429342,-0.210812,0.092066,0.030922,-0.494427,-0.194613,-0.962170,0.532097,0.497838,0.464924,0.430694,-0.519575,1.004808,-0.223635,0.191375,0.379482,0.059067,-0.208432,-0.311391,-0.048323,-0.691355,0.362867,0.221461,-0.134918,-0.084183,-0.267636,-0.506540,-0.445274,0.004556,-0.130335,0.296921,0.682433,-0.023733,0.222944,-0.106098,...,0.210007,-0.452477,-0.569182,0.196491,0.207027,-0.149190,0.215945,-0.337185,0.269829,0.191722,0.071451,0.518645,0.037460,1.434995,0.210121,0.002721,-1.948874,0.446429,0.234299,-0.216470,0.712723,-0.390150,0.351521,0.362909,0.075845,0.150345,-0.316661,0.166582,0.113761,0.486705,-0.725282,-0.683606,-0.000458,-0.592681,0.008754,0.490660,0.705642,0.129619,-0.426470,1.047946
3,0.029140,-0.173007,-0.035160,-0.253242,-0.220172,-0.618837,-0.222439,-0.138516,0.254560,-0.002792,-0.028122,-1.108267,0.247467,0.479325,0.552668,0.180563,-0.321407,0.698443,0.229832,0.119159,0.009356,-0.025854,-0.358923,-0.043755,0.202075,-0.116351,-0.042463,-0.148999,-0.328305,0.169367,-0.125129,0.359951,-0.215271,-0.244864,0.147397,0.004064,-0.136148,-0.413386,-0.032845,0.079964,...,-0.118971,-0.145324,0.046045,0.262360,0.599527,0.141921,0.136837,-0.180848,0.158193,-0.039137,-0.004125,-0.569593,0.065075,0.372793,0.388905,-0.272986,-1.629066,0.253706,-0.129571,0.460231,-0.103492,-0.438113,0.415432,0.500873,0.056621,-0.060449,0.105304,-0.336540,0.185622,0.791139,-0.396957,-0.408002,-0.150851,-0.661160,-0.203430,0.175949,0.287032,0.128584,-0.277056,0.458057
4,-0.041524,0.086308,-0.118732,-0.367404,-0.090567,-0.214584,-0.500232,0.134418,0.158793,-0.141809,0.021213,-1.180993,0.028826,0.643279,0.413703,0.084214,-0.289736,0.601917,0.252522,-0.038898,0.239144,0.058708,-0.392334,-0.219020,-0.063629,-0.201278,-0.090049,0.271942,-0.386986,0.222466,-0.368982,0.250737,0.078833,-0.249764,-0.346872,0.433003,0.330628,-0.396010,-0.177195,-0.057273,...,0.417420,-0.092456,-0.293656,-0.069984,0.142354,-0.337069,0.168362,0.015887,0.266458,0.341685,0.336643,-0.354652,0.230625,0.644350,0.491924,-0.525562,-0.753347,0.495169,-0.035378,0.199521,1.168490,-0.349137,-0.029784,0.438653,0.065312,0.125178,0.156485,0.150127,0.354485,1.445903,-0.369432,-0.173093,0.153058,0.235338,-0.211106,0.016361,0.811819,-0.050305,-0.641572,0.190057
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409
216,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409
217,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409
218,0.140701,0.509276,0.050154,-0.100147,0.028786,-0.170915,0.029878,-0.045099,-0.107430,0.201016,-0.183036,1.054294,0.152804,0.227688,0.049680,0.094654,-0.193559,0.233261,-0.055860,-0.096291,-0.042198,-0.079879,-0.262906,-0.237771,-0.157021,0.032711,-0.103149,0.190604,-0.207629,-0.105669,-0.090645,0.200079,-2.918440,-0.034615,-0.373975,0.228271,0.055789,0.100242,0.030708,-0.088818,...,0.126009,0.404953,-0.316168,0.195166,0.210035,-0.079260,-0.116465,0.038719,0.509658,0.142061,-0.122091,0.001951,0.101457,-0.038431,-0.017664,-0.230765,-0.866622,0.132721,0.100462,0.137835,0.280374,-0.371303,0.002539,-0.045281,0.083558,-0.071286,-0.077155,0.297261,-0.057134,0.583708,-0.007375,0.185255,-0.110206,-0.086687,0.104460,0.020271,0.463251,0.441999,-0.022306,0.280409


In [22]:

# for b, sm in tqdm(enumerate(data_loader)):
#   embedded = trfm_c.embed(torch.t(sm.cuda()))  # (T,B,H)
#   embedded = trfm_c.pe(embedded) # (T,B,H)
#   output = embedded
#   output = trfm_c.trfm.encoder(output)
#   if trfm_c.trfm.encoder.norm:
#       output = trfm_c.trfm.encoder.norm(output) # (T,B,H)
#   output = output.cpu().detach().numpy()


# decode

In [23]:
t = torch.t(xid)
t.shape


torch.Size([220, 3])

In [24]:
target = trfm.embed(t)
target = trfm.pe(target)


In [25]:
#trfm.decoder
decoded5 = trfm.trfm.decoder(target, torch.from_numpy(X).float())
decoded1 = decoded5.detach().numpy()

In [29]:
out = torch.from_numpy(decoded1).float()
out = trfm.out(out) # (T,B,V)
out = F.log_softmax(out, dim=2)
out = out.detach().numpy()

In [30]:
pd.DataFrame(out[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44
0,-44.175671,-35.747681,-31.292555,0.000000,-44.260025,-28.635603,-27.625763,-27.809988,-28.519175,-28.752401,-29.250795,-28.165989,-30.510796,-29.545069,-30.112347,-36.913124,-30.936882,-33.528564,-33.910561,-37.071022,-35.486027,-32.370548,-33.820328,-27.984274,-36.689098,-36.346581,-36.033531,-35.381283,-37.492249,-39.098457,-38.080399,-36.582424,-37.780548,-29.739103,-38.661522,-36.839684,-34.670715,-38.427593,-36.948112,-35.530842,-37.118153,-36.615639,-37.581856,-39.887928,-35.878986
1,-26.323692,-13.302274,-7.293172,-19.012615,-26.589888,-1.471507,-1.359549,-2.515324,-2.585663,-2.875618,-3.296388,-3.311698,-2.544843,-2.919650,-3.226220,-10.517247,-3.679514,-11.390262,-12.023293,-12.094478,-12.414701,-12.701138,-4.783476,-3.860903,-14.093339,-13.104815,-11.625642,-10.066242,-12.079480,-16.725693,-12.065263,-13.024291,-13.863189,-12.979401,-16.238556,-13.675254,-13.637882,-16.865936,-17.264095,-16.858646,-16.097912,-19.203560,-17.509773,-19.802118,-18.621206
2,-26.443110,-13.190294,-7.166296,-19.076300,-26.692677,-1.437387,-1.409159,-2.741089,-2.418569,-2.664074,-3.333209,-2.945692,-2.768120,-2.960833,-3.286055,-10.643085,-3.662803,-11.407366,-11.825718,-11.924688,-12.261599,-13.015224,-4.669723,-3.977208,-14.344002,-12.985704,-11.321815,-9.884907,-11.794058,-16.802887,-12.270813,-13.028002,-14.110308,-12.847872,-16.060928,-13.464298,-13.658043,-16.585810,-17.179873,-16.712868,-15.931524,-19.095335,-17.293543,-19.711664,-18.669382
3,-26.020302,-13.246327,-7.860716,-18.837280,-26.236397,-1.394491,-1.447143,-2.436429,-2.588636,-2.943488,-3.322718,-3.214290,-2.608548,-2.942546,-3.267622,-10.746580,-3.552000,-11.660261,-12.198205,-12.307819,-12.542967,-12.218102,-4.574267,-3.849207,-13.934746,-12.512344,-11.337343,-9.734991,-11.933450,-16.836124,-11.647294,-12.769719,-13.639711,-12.731513,-16.087334,-13.386945,-13.528852,-16.680016,-17.045216,-16.693459,-15.938904,-19.177122,-17.486834,-19.564457,-18.387510
4,-28.262392,-13.449371,-7.368211,-19.396257,-28.552042,-2.002810,-1.905679,-2.529511,-2.453239,-3.022075,-3.146697,-1.036299,-3.846676,-2.931501,-4.268562,-10.834663,-5.534667,-12.020775,-11.973896,-11.905186,-11.829760,-13.425762,-7.567640,-4.574835,-14.525647,-13.413580,-11.965914,-12.140483,-12.955153,-16.853287,-12.491102,-14.111275,-14.485508,-12.011520,-14.653417,-15.343867,-14.097477,-16.393097,-17.412333,-16.047474,-14.896681,-19.095831,-16.953480,-21.015871,-18.256979
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465
216,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465
217,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465
218,-26.232956,-11.488969,-1.134130,-19.844969,-26.488903,-2.460382,-1.946802,-6.698808,-2.872375,-3.096838,-4.274427,-1.552228,-4.789965,-2.375774,-4.601269,-10.634704,-6.965468,-11.971305,-13.711863,-10.445302,-12.029331,-13.400041,-9.318451,-4.707634,-13.675148,-10.889783,-12.075483,-12.415662,-12.305228,-17.010628,-14.783117,-13.518417,-14.937438,-11.410779,-13.976059,-14.822598,-15.958387,-15.960261,-16.742914,-15.558978,-15.172908,-19.216215,-16.665447,-19.531206,-17.109465


In [31]:
#decode function
decoded = trfm.decode(torch.from_numpy(X).float(), target)

In [None]:
decoded.shape

(220, 3, 45)

In [None]:
loss = F.nll_loss(out.view(-1, len(vocab)), torch.t(xid).contiguous().view(-1), ignore_index=0)

In [None]:
loss

tensor(2.0111, grad_fn=<NllLossBackward>)

# get smiles from decoded output


In [32]:
_, next_word = torch.max(torch.from_numpy(decoded), dim = 2)
decoded_smiles = torch.t(next_word).detach().numpy()
decoded_smiles.shape

(3, 220)

In [33]:
y =torch.t(xid).detach().numpy()


In [35]:
#value -> smiles
smiles_molecules = np.empty([decoded_smiles.shape[0],decoded_smiles.shape[1]], dtype=object)
for i in range(decoded_smiles.shape[0]):   
   smiless = [list(smiles_dict.keys())[list(smiles_dict.values()).index(elem)] for elem in decoded_smiles[i]]
   smiles_molecules[i] = smiless

In [36]:
pd.DataFrame(smiles_molecules)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219
0,<sos>,C,C,c,1,c,c,c,c,c,1,C,C,C,C,C,C,C,C,C,C,C,C,C,C,C,C,C,C,C,c,C,c,c,c,c,c,C,C,C,...,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>
1,<sos>,c,c,1,c,c,c,c,c,c,1,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,...,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>
2,<sos>,c,c,1,c,c,c,c,c,C,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,c,C,c,c,c,c,c,1,c,c,C,c,c,c,c,...,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>,<eos>


In [37]:
#put all characters into one continuous string
smiles_formatted = np.empty(decoded_smiles.shape[0], dtype=object)
for i in range(smiles_molecules.shape[0]):
  smile = smiles_molecules[i]
  end = np.where(smile == '<eos>')[0][0]
  smiles_formatted[i] = "".join(smile[1:end])

In [38]:
decoded_final = pd.DataFrame(smiles_formatted)
decoded_final

Unnamed: 0,0
0,CCc1ccccc1CCCCCCCCCCCCCCCCCCCcCcccccCCCCCCCCCC...
1,cc1cccccc1cccccccccccccccccccccccccccccccccccc...
2,cc1cccccCccccccccccccccccCccccc1ccCcccccCCCccC...


In [40]:
for i in range(3):
  print(i,len(decoded_final.iloc[i][0]))

0 52
1 65
2 51


In [None]:
plot_mols(smiles_formatted, unit=200, w=120, h=200, fontsize=1.0)


In [None]:
s = decoded[:,0,:]

In [None]:
f = pd.DataFrame(s)
f

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44
0,-45.470173,-37.101559,-29.256117,0.000000,-45.291073,-28.516005,-29.121420,-29.097557,-29.547464,-30.180914,-32.877251,-28.971113,-31.556475,-30.298185,-35.527153,-32.507462,-35.761383,-33.346615,-29.148830,-33.430061,-37.748947,-37.402824,-38.879234,-36.305046,-34.487492,-34.085678,-36.536816,-39.494839,-34.125221,-36.803223,-36.887764,-34.481155,-35.181801,-28.138823,-38.456833,-39.239799,-34.882401,-37.805706,-37.666882,-34.525105,-37.918995,-34.325474,-34.888958,-38.189602,-34.892010
1,-29.301823,-13.799294,-10.297996,-20.810789,-29.578262,-1.241233,-1.390985,-2.580844,-2.887537,-2.597407,-11.433615,-2.689757,-3.348817,-3.133340,-10.602304,-12.134955,-13.403796,-8.030526,-2.693455,-3.214077,-9.602731,-15.141812,-15.247645,-12.991599,-14.733134,-7.043643,-13.489715,-16.591805,-12.896339,-9.764510,-12.891412,-11.201367,-14.309287,-7.874004,-14.161911,-16.818937,-12.525611,-14.127555,-14.955525,-13.179366,-14.709735,-13.407912,-12.651450,-18.167086,-14.560893
2,-28.002546,-14.300258,-7.515660,-16.648066,-28.313601,-2.135618,-2.449620,-5.779906,-3.049522,-3.823579,-13.503217,-1.281487,-4.810882,-4.067439,-9.224874,-13.857028,-12.634272,-8.779547,-6.181548,-4.913281,-9.196618,-12.490766,-14.382434,-13.531489,-14.927424,-0.901370,-14.270227,-15.679343,-14.062472,-12.178649,-12.561431,-12.056256,-14.757603,-5.541435,-11.511449,-17.378422,-15.678811,-13.392316,-15.731251,-12.258845,-15.973686,-15.799025,-14.738399,-18.552027,-12.960588
3,-29.321672,-14.742002,-10.494133,-21.474413,-29.464661,-1.240499,-1.439253,-2.557791,-2.673108,-2.642319,-11.714100,-2.930491,-3.241757,-2.945319,-9.949409,-11.911226,-12.674129,-8.125805,-3.134292,-2.742188,-9.502390,-15.150566,-14.610735,-13.142111,-14.498672,-6.220497,-13.782190,-16.134712,-12.854937,-10.645266,-14.442843,-10.791265,-14.294293,-7.785919,-13.682887,-16.397455,-12.737179,-13.719917,-14.867741,-13.119571,-14.403181,-13.089798,-12.391588,-17.883230,-14.669279
4,-29.810993,-13.486595,-2.099242,-18.828798,-29.946140,-0.935535,-3.040370,-4.980090,-2.823648,-4.027109,-12.270425,-3.031455,-5.842985,-4.155580,-10.499365,-12.039428,-13.849604,-9.992680,-6.797335,-5.321773,-10.742392,-13.902041,-17.262602,-14.406287,-16.202913,-6.778595,-13.991510,-17.975685,-12.098310,-11.173542,-15.073608,-10.648915,-17.584824,-1.276396,-15.358470,-17.951000,-15.336847,-14.650130,-16.492916,-11.419200,-16.456060,-15.114813,-15.330233,-18.195072,-14.035368
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639
216,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639
217,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639
218,-24.232878,-12.585644,-0.830986,-14.553202,-24.596067,-4.898775,-4.376544,-8.502015,-5.587920,-5.466045,-13.074730,-2.987265,-6.579686,-5.436175,-8.467040,-12.829607,-11.095387,-11.870573,-8.540427,-5.627752,-14.002762,-15.134090,-13.715795,-14.223989,-14.374601,-0.892552,-13.978781,-14.728750,-14.302713,-12.033865,-10.474875,-11.713814,-15.120043,-2.716974,-8.644224,-15.116562,-15.661121,-11.041181,-15.388418,-10.108124,-14.954815,-15.181656,-14.064907,-15.308775,-11.246639


# interpolate


In [None]:
mol1 = X[:,0,:]
mol2 = X[:,2,:]

In [None]:
mol1

array([[ 0.4547085 ,  0.08033638, -0.29436618, ...,  0.53385216,
        -0.23826611,  0.36285454],
       [ 0.09657217, -0.00139888, -0.23275046, ...,  0.23784275,
        -0.16333099,  0.14998849],
       [ 0.1170961 ,  0.40280172, -0.44983634, ...,  0.12961872,
        -0.42646986,  1.0479465 ],
       ...,
       [ 0.14070055,  0.5092763 ,  0.05015361, ...,  0.4419992 ,
        -0.02230594,  0.2804095 ],
       [ 0.14070055,  0.5092763 ,  0.05015361, ...,  0.4419992 ,
        -0.02230594,  0.2804095 ],
       [ 0.14070055,  0.5092763 ,  0.05015361, ...,  0.4419992 ,
        -0.02230594,  0.2804095 ]], dtype=float32)

In [None]:
def linear_interpolation(mol_from, mol_to, steps):
    n = steps + 1
    diff = mol_to - mol_from
    inter = mol_from + 1 / steps * diff
    print(inter.shape)
    for i in range(2,n):
      add = mol_from + i / steps * diff
      inter = np.dstack((inter, add))
    return inter

In [None]:
molecule_morph = linear_interpolation(mol1, mol2, 20)
molecule_morph2 = molecule_morph.reshape((220,20,256))

(220, 256)


In [None]:
molecule_morph2.shape

(220, 20, 256)

In [None]:
pd.DataFrame(molecule_morph2[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,0.077286,0.085127,0.092968,0.100809,0.108650,0.116491,0.124332,0.132173,0.140014,0.147856,0.155697,0.163538,0.171379,0.179220,0.187061,0.194902,0.202743,0.210584,0.218425,0.226266,-0.014650,-0.011636,-0.008622,-0.005608,-0.002594,0.000420,0.003434,0.006448,0.009462,0.012476,0.015491,0.018505,0.021519,0.024533,0.027547,0.030561,0.033575,0.036589,0.039603,0.042617,...,-0.211053,-0.200482,-0.189911,-0.179341,1.620298,1.587653,1.555008,1.522363,1.489718,1.457073,1.424428,1.391783,1.359138,1.326493,1.293848,1.261203,1.228558,1.195913,1.163268,1.130623,1.097978,1.065333,1.032688,1.000043,-0.036918,-0.036767,-0.036616,-0.036465,-0.036314,-0.036164,-0.036013,-0.035862,-0.035711,-0.035560,-0.035410,-0.035259,-0.035108,-0.034957,-0.034806,-0.034656
1,-0.319582,-0.304199,-0.288815,-0.273431,-0.258047,-0.242663,-0.227279,-0.211895,-0.196512,-0.181128,-0.165744,-0.150360,-0.134976,-0.119592,-0.104208,-0.088825,-0.073441,-0.058057,-0.042673,-0.027289,0.002542,0.010421,0.018300,0.026179,0.034058,0.041938,0.049817,0.057696,0.065575,0.073454,0.081334,0.089213,0.097092,0.104971,0.112850,0.120730,0.128609,0.136488,0.144367,0.152246,...,0.391709,0.419254,0.446799,0.474344,-0.188792,-0.172265,-0.155739,-0.139213,-0.122687,-0.106160,-0.089634,-0.073108,-0.056582,-0.040055,-0.023529,-0.007003,0.009524,0.026050,0.042576,0.059102,0.075629,0.092155,0.108681,0.125207,0.008961,0.000886,-0.007190,-0.015266,-0.023342,-0.031418,-0.039494,-0.047570,-0.055646,-0.063722,-0.071798,-0.079874,-0.087950,-0.096026,-0.104102,-0.112178
2,-0.336954,-0.314595,-0.292235,-0.269876,-0.247517,-0.225158,-0.202799,-0.180440,-0.158081,-0.135722,-0.113363,-0.091003,-0.068644,-0.046285,-0.023926,-0.001567,0.020792,0.043151,0.065510,0.087869,0.220183,0.222027,0.223872,0.225716,0.227560,0.229405,0.231249,0.233093,0.234938,0.236782,0.238627,0.240471,0.242315,0.244160,0.246004,0.247848,0.249693,0.251537,0.253381,0.255226,...,-0.726708,-0.778144,-0.829579,-0.881014,-0.186539,-0.233490,-0.280440,-0.327391,-0.374342,-0.421292,-0.468243,-0.515194,-0.562144,-0.609095,-0.656045,-0.702996,-0.749947,-0.796897,-0.843848,-0.890799,-0.937749,-0.984700,-1.031651,-1.078601,0.118115,0.119859,0.121603,0.123347,0.125090,0.126834,0.128578,0.130322,0.132066,0.133809,0.135553,0.137297,0.139041,0.140785,0.142528,0.144272
3,-0.124127,-0.122439,-0.120751,-0.119062,-0.117374,-0.115686,-0.113997,-0.112309,-0.110621,-0.108932,-0.107244,-0.105556,-0.103867,-0.102179,-0.100491,-0.098802,-0.097114,-0.095426,-0.093737,-0.092049,-0.386870,-0.348360,-0.309850,-0.271341,-0.232831,-0.194321,-0.155812,-0.117302,-0.078792,-0.040283,-0.001773,0.036737,0.075246,0.113756,0.152266,0.190775,0.229285,0.267795,0.306304,0.344814,...,0.030888,0.017653,0.004419,-0.008815,-0.190559,-0.200069,-0.209578,-0.219087,-0.228597,-0.238106,-0.247615,-0.257125,-0.266634,-0.276143,-0.285653,-0.295162,-0.304671,-0.314181,-0.323690,-0.333199,-0.342709,-0.352218,-0.361727,-0.371237,-0.001295,-0.000961,-0.000626,-0.000292,0.000043,0.000377,0.000712,0.001046,0.001381,0.001715,0.002050,0.002384,0.002719,0.003053,0.003388,0.003722
4,-0.295057,-0.255147,-0.215237,-0.175328,-0.135418,-0.095508,-0.055599,-0.015689,0.024221,0.064130,0.104040,0.143950,0.183859,0.223769,0.263679,0.303588,0.343498,0.383408,0.423317,0.463227,-0.007891,-0.010445,-0.012998,-0.015552,-0.018106,-0.020659,-0.023213,-0.025766,-0.028320,-0.030874,-0.033427,-0.035981,-0.038534,-0.041088,-0.043642,-0.046195,-0.048749,-0.051302,-0.053856,-0.056410,...,-0.124705,-0.127538,-0.130370,-0.133202,-0.193347,-0.181375,-0.169404,-0.157433,-0.145461,-0.133490,-0.121519,-0.109547,-0.097576,-0.085604,-0.073633,-0.061662,-0.049690,-0.037719,-0.025748,-0.013776,-0.001805,0.010166,0.022138,0.034109,0.025847,0.034657,0.043466,0.052276,0.061085,0.069895,0.078704,0.087514,0.096323,0.105133,0.113943,0.122752,0.131562,0.140371,0.149181,0.157990
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632
216,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632
217,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632
218,0.112164,0.111278,0.110393,0.109507,0.108622,0.107737,0.106851,0.105966,0.105080,0.104195,0.103309,0.102424,0.101539,0.100653,0.099768,0.098882,0.097997,0.097111,0.096226,0.095341,0.426668,0.405090,0.383512,0.361934,0.340356,0.318777,0.297199,0.275621,0.254043,0.232465,0.210887,0.189309,0.167731,0.146152,0.124574,0.102996,0.081418,0.059840,0.038262,0.016684,...,0.001500,0.000638,-0.000224,-0.001085,1.636176,1.599683,1.563190,1.526697,1.490204,1.453711,1.417218,1.380725,1.344232,1.307739,1.271246,1.234753,1.198260,1.161767,1.125274,1.088782,1.052289,1.015796,0.979303,0.942810,0.053620,0.054621,0.055621,0.056622,0.057623,0.058624,0.059625,0.060625,0.061626,0.062627,0.063628,0.064628,0.065629,0.066630,0.067631,0.068632


In [None]:
X.shape

(220, 3, 256)

In [None]:
tgt_part = X[:,0,:]
target = np.zeros([220,20,256])
for i in range(20):
  target[:,i,:] = tgt_part

In [None]:
decoded_interpolations = trfm.decode(torch.from_numpy(molecule_morph2), torch.from_numpy(target).float())


In [None]:
_, next_word = torch.max(torch.from_numpy(decoded_interpolations), dim = 2)
decoded_interpolations = torch.t(next_word).detach().numpy()

In [None]:
smiles_molecules = np.empty([decoded_interpolations.shape[0],decoded_interpolations.shape[1]], dtype=object)

for i in range(decoded_interpolations.shape[0]):
#  smiles_molecules[i] = map(lambda elem: list(smiles_dict.keys())[list(smiles_dict.values()).index(elem)], decoded_smiles[i])
   
   smiless = [list(smiles_dict.keys())[list(smiles_dict.values()).index(elem)] for elem in decoded_interpolations[i]]
   smiles_molecules[i] = smiless

In [None]:
#put all characters into one continuous string
smiles_formatted = np.empty(decoded_interpolations.shape[0], dtype=object)
for i in range(smiles_molecules.shape[0]):
  print(i)
  smile = smiles_molecules[i]
  print(smile)
#  end = np.where(smile == '<eos>')[0][0]
  smiles_formatted[i] = "".join(smile)

In [None]:
nn = NearestNeighbors(metric='euclidean').fit(X_reduced)
xs = np.linspace(-24, 18, 12)
ys = np.linspace(25, -40, 12)
ids = []
pts = []
for x,y in zip(xs, ys):
    _, result = nn.kneighbors([[x, y]], n_neighbors=1)
    ids.append(result[0, 0])
    pts.append(X_reduced[result[0, 0]])
pts = np.array(pts)
mols = [Chem.MolFromSmiles(sm) for sm in df['smiles'].values[ids]]
dr = plot_mols(mols, 250, 175, 250, 1.1)
with open('bbbp_mol.svg', 'w') as f:
    f.write(dr.GetDrawingText())
SVG(dr.GetDrawingText())

NameError: ignored