<a href="https://colab.research.google.com/github/ajia90/smilestransformer/blob/main/interpolation_notgt_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports / setup

Notes from decoding experiments:
- mask + target = source → molecules with some errors
- mask + only start tokens → only the start token is produced


---



In [None]:
# Install RDKit. Takes 2-3 minutes
# Installs Miniconda into /usr/local, then uses conda-forge to install Python 3.7
# and RDKit there; the following cell appends that Python 3.7 site-packages
# directory to sys.path so the kernel can import rdkit.
# NOTE(review): the "latest" Miniconda installer is unpinned; newer releases may
# not resolve python=3.7 — consider pinning an installer version.
!wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge python=3.7 
!time conda install -q -y -c conda-forge rdkit 

In [None]:
import sys

# Expose packages that conda installed for Python 3.7 under /usr/local
# (e.g. rdkit) to this kernel by adding their site-packages to the search path.
sys.path += ['/usr/local/lib/python3.7/site-packages/']

# training the model


In [None]:
!conda install pytorch torchvision -c pytorch


In [None]:
#!conda install pytorch torchvision -c pytorch

In [None]:
# Make Google Drive files available to this Colab session under /content/gdrive/.
from google.colab import drive
drive.mount(mountpoint="/content/gdrive/")

Mounted at /content/gdrive/


In [None]:
 # Run the training script in a separate Python process; the script is resolved
 # relative to the current working directory.
 # NOTE(review): presumably this produces the checkpoint loaded further down — confirm.
 !python pretrain_trfm_target_zero.py

# methods/model


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import DrawingOptions
from torch.nn import functional as F


In [None]:
def plot_mols(mols, unit=200, w=120, h=200, fontsize=1.0, n_cols=4, n_rows=3):
    """Draw up to ``n_cols * n_rows`` molecules on a single SVG grid.

    Args:
        mols: iterable of RDKit ``Mol`` objects. ``None`` entries (what
            ``Chem.MolFromSmiles`` returns for an invalid SMILES) are skipped and
            their grid cell is left empty, instead of crashing the RDKit calls.
        unit: spacing in pixels between the top-left corners of neighbouring cells.
        w, h: width and height in pixels of each molecule panel.
        fontsize: value passed to ``drawer.SetFontSize`` for every panel.
        n_cols, n_rows: grid dimensions (defaults reproduce the original 4x3 grid).
            The canvas is ``n_cols * unit`` by ``n_rows * unit`` pixels; molecules
            beyond ``n_cols * n_rows`` are not drawn.

    Returns:
        The finished ``MolDraw2DSVG`` drawer; use ``GetDrawingText()`` for the SVG.
    """
    drawer = Draw.MolDraw2DSVG(n_cols * unit, n_rows * unit, w, h)

    # Fetch the drawing options up front so they apply to every panel.
    opt = drawer.drawOptions()
    opt.padding = 0.1
    opt.legendFontSize = 20

    n_cells = n_cols * n_rows
    for i, mol in enumerate(mols):
        if i >= n_cells:
            break
        if mol is None:
            # Invalid SMILES: leave the cell blank but keep later indices in place.
            continue
        # Cells fill row by row; SetOffset sets the panel's top-left corner.
        row, col = divmod(i, n_cols)
        drawer.SetOffset(int(col * unit), int(row * unit))
        drawer.SetFontSize(fontsize)

        # Work on a copy so the caller's molecule is not given coordinates or kekulized.
        mol = Chem.Mol(mol)
        AllChem.Compute2DCoords(mol)
        Chem.Kekulize(mol)
        # Render the molecule into the SVG, labelled with its position in `mols`.
        drawer.DrawMolecule(mol, legend=str(i))

    # Close the SVG document (writes </svg>).
    drawer.FinishDrawing()
    return drawer

In [None]:
def get_inputs(sm, seq_len=220):
    """Convert one tokenized SMILES string into model input ids and segment ids.

    Args:
        sm: SMILES already split into space-separated tokens (output of ``split``).
        seq_len: total length of the returned lists, including ``<sos>``/``<eos>``.
            Defaults to 220, the length used throughout this notebook.

    Returns:
        ``(ids, seg)``. ``ids`` is ``[sos_index] + token ids + [eos_index]``,
        right-padded with ``pad_index`` to ``seq_len``; unknown tokens map to
        ``unk_index``. ``seg`` is 1 for every non-padding position and
        ``pad_index`` for padding.

    Token lists longer than ``seq_len - 2`` are shortened by keeping the first and
    last halves (the middle is dropped); a message is printed when that happens.

    Raises:
        ValueError: if ``seq_len`` leaves no room for ``<sos>`` and ``<eos>``.
    """
    if seq_len < 2:
        raise ValueError('seq_len must be >= 2, got {}'.format(seq_len))
    max_tokens = seq_len - 2  # room left after <sos> and <eos>
    tokens = sm.split()
    if len(tokens) > max_tokens:
        print('SMILES is too long ({:d})'.format(len(tokens)))
        head = max_tokens // 2
        tail = max_tokens - head
        # Slice the tail by absolute index so tail == 0 yields [] (tokens[-0:] would not).
        tokens = tokens[:head] + tokens[len(tokens) - tail:]
    ids = [sos_index] + [vocab.stoi.get(token, unk_index) for token in tokens] + [eos_index]
    seg = [1] * len(ids)
    padding = [pad_index] * (seq_len - len(ids))
    ids.extend(padding)
    seg.extend(padding)
    return ids, seg

def get_array(smiles, seq_len=220):
    """Apply ``get_inputs`` to every tokenized SMILES string and stack the results.

    Args:
        smiles: iterable of space-separated token strings.
        seq_len: padded sequence length forwarded to ``get_inputs``.

    Returns:
        ``(ids, seg)`` as two torch tensors, one row per input string.
    """
    x_id, x_seg = [], []
    for sm in smiles:
        ids, seg = get_inputs(sm, seq_len)
        x_id.append(ids)
        x_seg.append(seg)
    return torch.tensor(x_id), torch.tensor(x_seg)

In [None]:
import torch
from pretrain_trfm_target_zero import TrfmSeq2seq
from build_vocab import WordVocab
from utils import split

# Special-token indices of the vocabulary; get_inputs uses pad/unk/eos/sos.
pad_index = 0
unk_index = 1
eos_index = 2
sos_index = 3
mask_index = 4

vocab = WordVocab.load_vocab('vocab.pkl')

# Positional args (input vocab size, 256, output vocab size, 4) — presumably
# hidden size 256 and 4 layers; confirm against TrfmSeq2seq's signature.
trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
# The model stays on CPU here, so load the checkpoint onto CPU too. Without
# map_location, a checkpoint saved from GPU tensors fails to load on a CPU-only runtime.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
trfm.load_state_dict(torch.load('trfm_notgt_12_90000.pkl', map_location='cpu'))
trfm.eval()
print('Total parameters:', sum(p.numel() for p in trfm.parameters()))

Total parameters: 4244013


In [None]:
smiles_dict = vocab.stoi

In [None]:
smiles_dict


# data


In [None]:
#read in BBBp data
# df = pd.read_csv('BBBP.csv')
# print(df.shape)
# df.head()

In [None]:
#sample of chembl25 data
df2 = pd.read_csv('smiles_sample2.csv')
print(df2.shape)

(1052, 1)


In [None]:
df2.head()

Unnamed: 0,canonical_smiles
0,Cc1cc(cn1C)c2csc(N=C(N)N)n2
1,Brc1cccc(Nc2ncnc3ccncc23)c1NCCN4CCOCC4
2,COc1c(O)cc(O)c(C(=N)Cc2ccc(O)cc2)c1O
3,CCOC(=O)c1cc2cc(ccc2[nH]1)C(=O)O
4,C[C@H](NC(=O)OCc1ccccc1)C(=O)N[C@@H](C)C(=O)NN...


In [None]:
x_split = [split(sm) for sm in df2['canonical_smiles'].values]
xid, xseg = get_array(x_split)

In [None]:
xid.shape

torch.Size([1052, 220])

In [None]:
# Group row positions of df2 by SMILES string length (in characters, not tokens).
# The next cell selects the group matching the first molecule's length.
smile_lengths = {}
for k, smiles_str in enumerate(df2['canonical_smiles']):
    smile_lengths.setdefault(len(smiles_str), []).append(k)

In [None]:
# Select every molecule whose SMILES string has the same character length as the
# first molecule (27 characters for this dataset), instead of hard-coding 27.
index = smile_lengths[len(df2['canonical_smiles'][0])]
same_length_smiles = df2.loc[index]
same_length_smiles

Unnamed: 0,canonical_smiles
0,Cc1cc(cn1C)c2csc(N=C(N)N)n2
72,NC(=O)c1ccccc1n2cnc3ccccc23
241,COc1ccc2c(O)cc(OC)c(O)c2c1O
292,Nc1ccc2nc(oc2c1)c3ccc(F)cc3
302,CCOc1cc2c(cn1)[nH]c3ccccc23
304,O=C(c1ccccc1)c2nccc3ccccc23
312,CC(C)(C)C(N)C(=O)N1CCCC1C#N
445,[O-][S+]1N(Sc2ccccc12)C3CC3
547,COCC(C)(CS(=O)(=O)O)N(Cl)Cl
597,CC(C)(C)C(=O)Cn1c[n+](N)cn1


In [None]:
x_split = [split(sm) for sm in same_length_smiles['canonical_smiles'].values]
xid, xseg = get_array(x_split)

# encode


In [None]:
from dataset import Seq2seqDataset
from torch.utils.data import DataLoader
from tqdm import tqdm


In [None]:
# Encode the selected molecules. xid has one row per molecule, so it is
# transposed before being passed to trfm.encode; the printed shape shows the
# result is indexed (seq_len, batch, 256) — presumably (position, molecule, hidden); confirm.
X = trfm.encode(torch.t(xid))
print(X.shape)

(220, 15, 256)


In [None]:
# #trfm.encoder
# dataset = Seq2seqDataset(df2['canonical_smiles'].values, vocab)
# data_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
# output_total = torch.empty(size=(xid.shape[1], xid.shape[0], 256))
# print(output_total.size)
# for b, sm in tqdm(enumerate(data_loader)):
#   # sm = torch.t(sm.cuda()) # (T,B)
#   # output1 = trfm_c(sm) # (T,,V)
#   embedded = trfm.embed(torch.t(sm.cuda()))  # (T,B,H)
#   embedded = trfm.pe(embedded) # (T,B,H)
#   output_total[:,b:b+4,:] = trfm.encoder(embedded)
# #output = output.detach().numpy()

# New Section

# mask


In [None]:
def subsequent_mask(size):
    """Mask out subsequent positions.

    Returns a boolean tensor of shape ``(1, size, size)`` that is True where
    column ``j <= row i`` (position ``i`` may attend to ``j``) and False above
    the diagonal.
    """
    strictly_upper = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    return torch.from_numpy(strictly_upper) == 0

def make_std_mask(tgt, pad):
    """Create a mask to hide padding and future words.

    Args:
        tgt: integer tensor of token ids; the last dimension is the sequence.
        pad: the padding id to hide.

    Returns:
        Boolean tensor, True where attention is allowed. For ``tgt`` of shape
        ``(B, T)`` the result has shape ``(B, T, T)``.
    """
    tgt_mask = (tgt != pad).unsqueeze(-2)
    # No ``Variable`` wrapper: it was never imported here (NameError on call), and
    # since PyTorch 0.4 plain tensors combine directly.
    return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)

def generate_square_subsequent_mask(sz):
    """Return an additive ``(sz, sz)`` float causal mask.

    Entry ``(i, j)`` is 0.0 when ``j <= i`` and ``-inf`` otherwise, the same
    layout as ``torch.nn.Transformer.generate_square_subsequent_mask``.
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


# decode

Here the decoder target is the source itself: `tgt = xid`.


In [None]:
mask1 = generate_square_subsequent_mask(torch.t(xid).shape[0])

In [None]:
mask1.shape

torch.Size([220, 220])

In [None]:
# #trfm.decoder
# decoded = trfm.decoder(output, output, mask1)
# out = trfm.out(decoded) # (T,B,V)
# out = F.log_softmax(out, dim=2)
# out = out.detach().numpy()

In [None]:
# Decode the encoder output. trfm.encode returned a NumPy array, so convert it
# back to a float tensor first; get_smiles (below) reads decoded_output as
# (seq_len, batch, vocab) scores — presumably; confirm against trfm.decode.
hidden = torch.from_numpy(X).float()
decoded_output = trfm.decode(hidden)

In [None]:
# Shape of the decoder output produced by the decode cell above.
decoded_output.shape

(220, 15, 45)

# Get SMILES strings from the decoded output


In [None]:
def get_smiles(decoded, token_to_id=None):
    """Greedy-decode model output scores into SMILES strings.

    Args:
        decoded: scores indexed ``(T, B, V)`` (position, batch item, vocabulary
            entry), as a NumPy array or torch tensor. The highest-scoring token
            is taken at every position.
        token_to_id: mapping token -> vocabulary index. Defaults to the
            notebook's ``smiles_dict`` (``vocab.stoi``).

    Returns:
        1-D object array of ``B`` strings. For each sequence the first token
        (expected to be ``<sos>``) is dropped and decoding stops at the first
        ``<eos>``. If no ``<eos>`` was generated, all remaining tokens are kept
        instead of raising IndexError.
    """
    if token_to_id is None:
        token_to_id = smiles_dict
    # Invert the vocabulary once, keeping the first token for any repeated index
    # (as the old list.index lookup did), instead of rescanning it per token.
    id_to_token = {}
    for token, idx in token_to_id.items():
        id_to_token.setdefault(idx, token)

    scores = torch.as_tensor(decoded)
    _, next_word = torch.max(scores, dim=2)                 # (T, B) best token ids
    token_ids = torch.t(next_word).detach().cpu().numpy()   # (B, T)

    smiles_formatted = np.empty(token_ids.shape[0], dtype=object)
    for i, row in enumerate(token_ids):
        tokens = [id_to_token[int(idx)] for idx in row]
        end = tokens.index('<eos>') if '<eos>' in tokens else len(tokens)
        # Join all characters between <sos> and <eos> into one continuous string.
        smiles_formatted[i] = "".join(tokens[1:end])
    return smiles_formatted


In [None]:
molecules = get_smiles(decoded_output)

In [None]:
decoded_final = pd.DataFrame(molecules)
decoded_final.columns = ['smiles']
decoded_final

Unnamed: 0,smiles
0,Cc1cc(cn1C]c2csc(N=C(N]N]n2
1,NC(=O]c1ccccc1n2cnc3ccccc23
2,COc1ccc2c(O]cc(OC]c(O]c2c1O
3,Nc1ccc2nc(oc2c1]c3ccc(F]cc3
4,CCOc1cc2c(cn1][nH]c3ccccc23
5,O=C(c1ccccc1]c2nccc3ccccc23
6,CC(C](C]C(N]C(=O]N1CCCC1CsN
7,[O-][S+]1N(Sc2ccccc12]C3CC3
8,COCC(C](CS(=O](=O]O]N(Cl]Cl
9,CC(C](C]C(=O]Cn1c[n+](N]cn1


In [None]:
same_length_smiles

Unnamed: 0,canonical_smiles
0,Cc1cc(cn1C)c2csc(N=C(N)N)n2
72,NC(=O)c1ccccc1n2cnc3ccccc23
241,COc1ccc2c(O)cc(OC)c(O)c2c1O
292,Nc1ccc2nc(oc2c1)c3ccc(F)cc3
302,CCOc1cc2c(cn1)[nH]c3ccccc23
304,O=C(c1ccccc1)c2nccc3ccccc23
312,CC(C)(C)C(N)C(=O)N1CCCC1C#N
445,[O-][S+]1N(Sc2ccccc12)C3CC3
547,COCC(C)(CS(=O)(=O)O)N(Cl)Cl
597,CC(C)(C)C(=O)Cn1c[n+](N)cn1


# Latent-space interpolation

In [None]:
# Interpolation endpoints: the encodings of the molecules at batch positions 2
# and 4 of X (X is indexed [position, molecule, feature]).
mol1 = X[:,2,:]
mol2 = X[:,4,:]

In [None]:
pd.DataFrame(mol1)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,-12.803356,6.788694,-10.489622,12.800719,11.396008,-16.653835,-10.381500,15.116764,16.915638,13.419433,-10.767895,17.192476,6.702651,-10.065237,11.506454,-10.859803,7.387299,6.052010,7.680329,-8.606855,-5.248012,16.344646,20.032137,-10.601822,8.502974,-11.174334,8.914734,10.782484,-9.883665,14.092573,-11.691484,13.981385,-15.001251,-11.642458,-10.048145,-14.281707,-19.514595,-8.681421,10.998167,-13.071776,...,11.958583,7.086503,13.198001,11.144255,-16.295580,2.736654,-21.444077,7.731144,9.608370,6.606700,-6.972273,-13.686513,5.498168,-7.370670,8.776616,-7.374544,13.985360,20.543451,-8.946177,10.945930,-14.920374,-12.754510,17.935326,8.527763,14.728496,8.610201,10.785258,-12.887712,-9.423468,11.900187,16.516445,15.008853,-12.285117,-10.342769,11.530216,8.959604,-8.799786,-19.456743,10.484785,-11.954103
1,-7.171934,12.894155,-6.956240,13.445880,10.280861,-18.052626,-9.009795,13.896951,10.553954,12.122952,-11.747062,14.757994,7.955058,-5.426693,11.354470,-11.005962,9.548174,17.575621,13.425718,-4.693926,-9.172297,14.858472,14.905234,-8.985991,17.287115,-16.719383,11.464545,17.324808,-12.554738,6.478510,-12.891435,13.546384,-10.408455,-13.729339,-10.306595,-16.611689,-14.422675,-10.022675,10.065098,-9.065392,...,14.028344,12.822985,11.130803,15.694044,-10.451122,7.146226,-15.941079,20.725443,8.187132,10.426474,-7.089782,-11.465517,8.652777,-14.210089,13.807182,-10.710464,11.051229,11.723392,-15.604832,7.924263,-18.417179,-11.880120,15.846006,12.132806,16.762739,8.001766,10.492572,-14.532620,-10.379836,14.415862,11.199519,6.282033,-2.507975,-12.077810,14.125642,10.766519,-10.315540,-11.110530,8.790992,-13.761786
2,-3.733560,3.552344,-3.153018,10.228532,11.000668,-22.695541,-8.748476,-0.993121,5.356827,12.752072,0.923853,11.871262,7.504683,-3.072999,21.432720,-13.965893,11.593856,10.241639,-2.543969,-5.210904,-2.770575,11.419711,14.857699,-10.589686,18.010054,-9.416499,6.754814,7.372495,-15.803809,2.658587,-10.706658,15.024889,-8.881316,-5.217547,-4.376481,-15.760055,-15.450557,-6.683535,4.823577,-7.091381,...,10.934430,9.243549,10.325013,12.932347,-12.203330,1.482648,-10.394934,11.563151,2.451394,10.669425,-6.380447,-15.834220,5.581871,-11.168127,14.762830,-4.767666,18.347437,18.986738,-14.449496,7.022230,-8.700366,-11.327778,22.037586,4.257614,12.840907,6.140222,7.802189,-3.340954,-11.017865,18.163439,9.747244,18.909496,-21.263847,-11.465522,16.230421,10.453148,-11.355357,-13.052875,2.451615,-9.004667
3,-12.993629,9.920752,-10.073252,11.693156,11.925569,-15.973969,-14.395805,11.260496,9.712104,9.080920,-7.415376,18.453592,10.381047,-11.305900,18.434790,-9.911259,8.591755,10.196766,3.539748,-10.358308,-10.343792,10.155920,16.897808,-10.718766,13.568964,-16.216614,14.179276,8.830261,-11.757533,10.777471,-12.280403,13.168190,-21.757603,-12.878099,-9.086560,-16.427111,-12.873577,-10.346504,9.619030,-11.457694,...,8.473235,13.196728,9.989080,16.064129,-13.858126,10.243100,-12.403658,12.002872,8.761887,10.781141,-8.463038,-13.653571,10.961841,-12.988787,8.173595,-13.073828,16.325399,21.101540,-8.439468,8.182730,-12.058236,-11.824684,19.251747,7.742739,14.400626,12.213207,11.357842,-14.650983,-12.382051,20.587200,12.111123,10.046788,-10.130484,-9.326023,9.567728,9.109931,-12.911180,-15.172270,4.914798,-14.727250
4,-2.052634,7.614048,-8.316536,7.232498,10.274465,-16.210451,-11.984109,7.944471,6.866225,6.320078,-7.578903,16.291029,6.315854,-17.337070,7.556752,-13.337531,8.805182,21.527012,4.447323,-17.209784,-11.968374,10.084806,11.644914,-15.237505,14.985333,-3.359427,10.955883,9.009671,-5.977899,15.582176,-13.125157,3.190409,3.013995,-10.861408,-13.020073,-10.423079,-9.198493,-9.878391,13.703934,-3.373550,...,6.628816,10.604772,13.552889,14.679960,-12.714017,4.202743,-14.330726,4.553133,4.833189,18.172934,-10.023796,-9.464222,14.485874,-6.969497,16.491737,-7.841741,13.925351,22.341173,-11.064845,19.749674,-8.044341,-10.156638,13.681525,9.902145,13.295949,8.097250,12.290469,-9.309617,-6.822398,18.089333,14.334571,21.217852,-8.506188,-10.930363,10.943995,11.109279,-17.145422,-13.341661,10.672636,-4.747087
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,-18.022541,9.226926,-8.685872,10.563972,6.266322,-13.745155,-1.346036,5.251029,10.593820,9.954638,-5.256190,14.102864,13.777992,-6.672281,10.794866,-9.430584,6.901839,15.570704,9.360172,-13.093697,-10.610111,13.203014,11.131260,-8.162650,10.315321,-15.557427,9.085590,16.325455,-10.918001,11.522655,-15.329686,12.713851,-12.958396,-8.278859,-11.089897,-13.806515,-12.628786,-4.353388,5.570451,-6.950884,...,12.469085,13.550295,11.566945,16.862892,-17.016634,10.757463,-10.848796,25.147917,11.985706,5.138355,-5.744225,-1.033285,7.704782,-7.123579,20.856144,-11.000607,7.140047,1.333593,-4.933283,11.205585,-11.290542,-14.877491,16.698662,11.098980,20.250498,6.949159,10.536427,-15.032090,-10.773134,14.402801,14.288861,9.936523,0.936095,-8.345315,12.497780,15.603442,-8.559468,-9.343353,8.905419,-17.291924
216,-18.022541,9.226926,-8.685872,10.563972,6.266322,-13.745155,-1.346036,5.251029,10.593820,9.954638,-5.256190,14.102864,13.777992,-6.672281,10.794866,-9.430584,6.901839,15.570704,9.360172,-13.093697,-10.610111,13.203014,11.131260,-8.162650,10.315321,-15.557427,9.085590,16.325455,-10.918001,11.522655,-15.329686,12.713851,-12.958396,-8.278859,-11.089897,-13.806515,-12.628786,-4.353388,5.570451,-6.950884,...,12.469085,13.550295,11.566945,16.862892,-17.016634,10.757463,-10.848796,25.147917,11.985706,5.138355,-5.744225,-1.033285,7.704782,-7.123579,20.856144,-11.000607,7.140047,1.333593,-4.933283,11.205585,-11.290542,-14.877491,16.698662,11.098980,20.250498,6.949159,10.536427,-15.032090,-10.773134,14.402801,14.288861,9.936523,0.936095,-8.345315,12.497780,15.603442,-8.559468,-9.343353,8.905419,-17.291924
217,-18.022541,9.226926,-8.685872,10.563972,6.266322,-13.745155,-1.346036,5.251029,10.593820,9.954638,-5.256190,14.102864,13.777992,-6.672281,10.794866,-9.430584,6.901839,15.570704,9.360172,-13.093697,-10.610111,13.203014,11.131260,-8.162650,10.315321,-15.557427,9.085590,16.325455,-10.918001,11.522655,-15.329686,12.713851,-12.958396,-8.278859,-11.089897,-13.806515,-12.628786,-4.353388,5.570451,-6.950884,...,12.469085,13.550295,11.566945,16.862892,-17.016634,10.757463,-10.848796,25.147917,11.985706,5.138355,-5.744225,-1.033285,7.704782,-7.123579,20.856144,-11.000607,7.140047,1.333593,-4.933283,11.205585,-11.290542,-14.877491,16.698662,11.098980,20.250498,6.949159,10.536427,-15.032090,-10.773134,14.402801,14.288861,9.936523,0.936095,-8.345315,12.497780,15.603442,-8.559468,-9.343353,8.905419,-17.291924
218,-18.022541,9.226926,-8.685872,10.563972,6.266322,-13.745155,-1.346036,5.251029,10.593820,9.954638,-5.256190,14.102864,13.777992,-6.672281,10.794866,-9.430584,6.901839,15.570704,9.360172,-13.093697,-10.610111,13.203014,11.131260,-8.162650,10.315321,-15.557427,9.085590,16.325455,-10.918001,11.522655,-15.329686,12.713851,-12.958396,-8.278859,-11.089897,-13.806515,-12.628786,-4.353388,5.570451,-6.950884,...,12.469085,13.550295,11.566945,16.862892,-17.016634,10.757463,-10.848796,25.147917,11.985706,5.138355,-5.744225,-1.033285,7.704782,-7.123579,20.856144,-11.000607,7.140047,1.333593,-4.933283,11.205585,-11.290542,-14.877491,16.698662,11.098980,20.250498,6.949159,10.536427,-15.032090,-10.773134,14.402801,14.288861,9.936523,0.936095,-8.345315,12.497780,15.603442,-8.559468,-9.343353,8.905419,-17.291924


In [None]:
mol2

array([[-12.42792  ,   6.432371 ,  -9.693922 , ..., -18.254427 ,
         11.178756 , -11.443842 ],
       [ -6.91101  ,  12.991957 ,  -7.1016583, ..., -11.750151 ,
          8.857696 , -13.414981 ],
       [ -6.91101  ,  12.991957 ,  -7.1016583, ..., -11.750151 ,
          8.857696 , -13.414981 ],
       ...,
       [-17.960728 ,   8.825605 ,  -8.451591 , ...,  -9.556772 ,
          8.408287 , -18.088741 ],
       [-17.960728 ,   8.825605 ,  -8.451591 , ...,  -9.556772 ,
          8.408287 , -18.088741 ],
       [-17.960728 ,   8.825605 ,  -8.451591 , ...,  -9.556772 ,
          8.408287 , -18.088741 ]], dtype=float32)

In [None]:
def linear_interpolation(mol_from, mol_to, steps):
    """Linearly interpolate between two encoder outputs.

    Args:
        mol_from, mol_to: arrays of identical shape, e.g. ``(T, H)`` per-position
            embeddings of one molecule (``(220, 256)`` in this notebook).
        steps: number of interpolation frames; must be >= 1.

    Returns:
        Array with a new axis 1 of length ``steps`` (``(T, steps, H)`` for 2-D
        inputs), keeping the inputs' floating dtype. Frame ``k`` equals
        ``mol_from + ((steps - k) / steps) * (mol_to - mol_from)``, so frame 0 is
        ``mol_to`` and the last frame is one step from ``mol_from``. ``mol_from``
        itself is not included.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError('steps must be >= 1, got {}'.format(steps))
    mol_from = np.asarray(mol_from)
    mol_to = np.asarray(mol_to)
    diff = mol_to - mol_from
    # Descending fractions reproduce the original frame order (it prepended each
    # new frame). Matching diff's float dtype keeps float32 inputs in float32,
    # exactly as scalar-times-array arithmetic did.
    # NOTE(review): frames run mol_to -> mol_from despite the argument names; confirm intended.
    frac_dtype = diff.dtype if np.issubdtype(diff.dtype, np.floating) else np.float64
    fractions = (np.arange(steps, 0, -1) / steps).astype(frac_dtype)
    frac_shape = (1, steps) + (1,) * (diff.ndim - 1)
    # Broadcast all frames at once instead of repeatedly hstack-ing (quadratic copying).
    return np.expand_dims(mol_from, 1) + fractions.reshape(frac_shape) * np.expand_dims(diff, 1)

In [None]:
# 15 interpolation frames between the two endpoint encodings, stacked along axis 1.
molecule_morph = linear_interpolation(mol1, mol2, 15)


In [None]:
molecule_morph.shape

(220, 15, 256)

In [None]:
pd.DataFrame(molecule_morph[:,0,:])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,-12.427920,6.432371,-9.693922,12.578071,10.420417,-17.103424,-10.762917,15.542607,16.088902,13.000470,-10.402634,17.765070,7.182338,-9.473957,12.575037,-10.837489,7.515432,5.679412,7.010657,-9.304878,-6.119547,16.853880,18.325703,-10.664460,8.062894,-10.471768,9.198726,11.070743,-10.197888,13.423920,-12.388330,14.586823,-14.702550,-11.658175,-10.057137,-14.722287,-18.257900,-7.406022,10.667712,-12.702645,...,11.978279,7.403729,13.160180,12.168163,-16.148537,2.155151,-20.920147,7.445014,8.961248,6.627284,-6.768361,-13.680278,4.551894,-8.815456,8.572799,-7.477885,13.741497,19.850870,-9.167773,10.732714,-14.797009,-12.827584,18.623983,8.725233,14.787131,8.476263,11.078612,-13.108059,-9.881663,12.143747,16.843666,16.628878,-13.203073,-10.731888,11.985181,9.490840,-7.670404,-18.254427,11.178756,-11.443842
1,-6.911010,12.991957,-7.101658,13.296404,10.591055,-18.216427,-8.699280,13.920757,10.483831,12.048428,-11.380270,15.271732,8.185200,-5.515406,11.261868,-11.359967,9.849767,17.511894,13.280849,-4.329025,-9.132667,14.749804,14.935214,-8.907741,17.118822,-16.321247,11.099267,17.165821,-12.365690,6.182334,-13.162038,13.392587,-10.149905,-13.615577,-10.721411,-16.930084,-14.601992,-9.987946,10.197196,-8.750405,...,14.146240,12.593237,11.519047,15.484993,-10.360149,6.273999,-15.862475,21.459740,8.719436,9.956595,-6.675954,-11.548562,8.886376,-14.609131,13.933117,-10.304875,11.014430,12.185185,-16.001669,7.804715,-18.505932,-12.066866,15.454466,12.598481,16.850451,8.026176,10.165191,-14.245400,-10.389374,14.214512,10.806207,6.589170,-2.158624,-12.064212,13.793745,10.844497,-10.182081,-11.750151,8.857696,-13.414981
2,-6.911010,12.991957,-7.101658,13.296404,10.591055,-18.216427,-8.699280,13.920757,10.483831,12.048428,-11.380270,15.271732,8.185200,-5.515406,11.261868,-11.359967,9.849767,17.511894,13.280849,-4.329025,-9.132667,14.749804,14.935214,-8.907741,17.118822,-16.321247,11.099267,17.165821,-12.365690,6.182334,-13.162038,13.392587,-10.149905,-13.615576,-10.721411,-16.930084,-14.601992,-9.987946,10.197196,-8.750405,...,14.146240,12.593237,11.519047,15.484993,-10.360149,6.273999,-15.862475,21.459740,8.719436,9.956595,-6.675954,-11.548562,8.886376,-14.609131,13.933117,-10.304875,11.014430,12.185185,-16.001669,7.804715,-18.505932,-12.066866,15.454466,12.598481,16.850451,8.026176,10.165191,-14.245400,-10.389374,14.214512,10.806207,6.589170,-2.158625,-12.064212,13.793745,10.844497,-10.182081,-11.750151,8.857696,-13.414981
3,-4.566346,3.984156,-3.076006,10.503679,11.754541,-22.753874,-10.068980,-1.938288,4.456532,9.681273,0.931535,12.668999,8.286695,-3.741024,19.480713,-12.848519,11.897975,12.649598,-0.958904,-5.657074,-5.011619,10.185978,14.490394,-9.513254,16.739712,-11.655971,6.830252,7.905119,-17.254978,3.704541,-12.014209,15.307922,-7.880445,-4.506474,-5.556416,-16.867044,-14.696795,-5.888762,3.724631,-5.973329,...,10.063710,9.687161,11.063781,13.162334,-12.638499,0.176764,-8.689936,11.891586,5.411841,9.416130,-6.778010,-15.830599,5.428782,-11.426514,16.438061,-5.244295,18.627645,18.437132,-13.724153,7.002278,-8.789222,-12.852459,19.198105,3.076798,14.331836,7.561663,7.804978,-4.738629,-10.226971,19.365215,9.610475,18.307184,-20.426558,-13.395200,16.684660,13.279683,-11.307617,-12.941139,3.774168,-10.210317
4,-13.289097,9.757739,-9.777774,11.837010,12.354814,-16.664152,-13.896273,10.598447,8.821618,7.547210,-6.848146,18.393154,10.814766,-11.680892,18.615292,-9.909272,8.400716,10.584118,4.701674,-10.550056,-10.402273,9.060270,16.912054,-10.182013,13.073063,-16.532406,14.251570,10.076437,-12.229506,11.174375,-12.448223,12.911856,-21.927326,-13.167303,-9.828756,-16.788433,-12.652192,-10.309620,10.011616,-11.512445,...,8.294155,12.367299,10.075584,16.636560,-14.741016,10.268962,-11.453566,12.078461,9.679300,11.092757,-7.508410,-13.510096,10.498106,-13.934389,9.727357,-13.340666,16.315681,21.093719,-8.756892,7.937389,-11.812800,-11.859784,18.243290,8.339580,13.699884,12.179472,10.942401,-14.223453,-11.998565,21.545065,11.878815,10.383742,-9.410940,-9.072204,9.619243,9.969243,-12.770279,-15.420297,5.681416,-14.790644
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,-17.960728,8.825605,-8.451591,10.011883,7.580081,-13.836320,-1.768318,5.644078,10.726879,9.712966,-2.787933,14.166553,14.691217,-5.927428,9.523998,-8.367174,6.386056,16.390455,8.981689,-11.366660,-11.389125,11.818145,12.527308,-8.134904,8.993979,-15.484956,7.520059,17.776957,-10.794101,11.728086,-15.670280,12.288410,-13.856984,-8.197856,-10.940022,-14.216541,-12.694590,-2.724284,6.278480,-7.919415,...,11.229818,13.356666,10.620025,17.040356,-16.974243,11.471745,-10.028069,27.158773,12.810237,5.099986,-3.838223,-1.287942,7.179879,-7.596642,21.878201,-10.727372,7.357180,0.752633,-5.128458,11.319994,-12.747209,-14.177285,15.967305,11.464613,20.219870,8.364977,10.787473,-15.103135,-10.738522,14.913680,14.296676,9.505918,2.035478,-8.029973,12.981397,16.991152,-7.894284,-9.556772,8.408287,-18.088741
216,-17.960728,8.825605,-8.451591,10.011883,7.580081,-13.836320,-1.768318,5.644078,10.726879,9.712966,-2.787933,14.166553,14.691217,-5.927428,9.523998,-8.367174,6.386056,16.390455,8.981689,-11.366660,-11.389125,11.818145,12.527308,-8.134904,8.993979,-15.484956,7.520059,17.776957,-10.794101,11.728086,-15.670280,12.288410,-13.856984,-8.197856,-10.940022,-14.216541,-12.694590,-2.724284,6.278480,-7.919415,...,11.229818,13.356666,10.620025,17.040356,-16.974243,11.471745,-10.028069,27.158773,12.810237,5.099986,-3.838223,-1.287942,7.179879,-7.596642,21.878201,-10.727372,7.357180,0.752633,-5.128458,11.319994,-12.747209,-14.177285,15.967305,11.464613,20.219870,8.364977,10.787473,-15.103135,-10.738522,14.913680,14.296676,9.505918,2.035478,-8.029973,12.981397,16.991152,-7.894284,-9.556772,8.408287,-18.088741
217,-17.960728,8.825605,-8.451591,10.011883,7.580081,-13.836320,-1.768318,5.644078,10.726879,9.712966,-2.787933,14.166553,14.691217,-5.927428,9.523998,-8.367174,6.386056,16.390455,8.981689,-11.366660,-11.389125,11.818145,12.527308,-8.134904,8.993979,-15.484956,7.520059,17.776957,-10.794101,11.728086,-15.670280,12.288410,-13.856984,-8.197856,-10.940022,-14.216541,-12.694590,-2.724284,6.278480,-7.919415,...,11.229818,13.356666,10.620025,17.040356,-16.974243,11.471745,-10.028069,27.158773,12.810237,5.099986,-3.838223,-1.287942,7.179879,-7.596642,21.878201,-10.727372,7.357180,0.752633,-5.128458,11.319994,-12.747209,-14.177285,15.967305,11.464613,20.219870,8.364977,10.787473,-15.103135,-10.738522,14.913680,14.296676,9.505918,2.035478,-8.029973,12.981397,16.991152,-7.894284,-9.556772,8.408287,-18.088741
218,-17.960728,8.825605,-8.451591,10.011883,7.580081,-13.836320,-1.768318,5.644078,10.726879,9.712966,-2.787933,14.166553,14.691217,-5.927428,9.523998,-8.367174,6.386056,16.390455,8.981689,-11.366660,-11.389125,11.818145,12.527308,-8.134904,8.993979,-15.484956,7.520059,17.776957,-10.794101,11.728086,-15.670280,12.288410,-13.856984,-8.197856,-10.940022,-14.216541,-12.694590,-2.724284,6.278480,-7.919415,...,11.229818,13.356666,10.620025,17.040356,-16.974243,11.471745,-10.028069,27.158773,12.810237,5.099986,-3.838223,-1.287942,7.179879,-7.596642,21.878201,-10.727372,7.357180,0.752633,-5.128458,11.319994,-12.747209,-14.177285,15.967305,11.464613,20.219870,8.364977,10.787473,-15.103135,-10.738522,14.913680,14.296676,9.505918,2.035478,-8.029973,12.981397,16.991152,-7.894284,-9.556772,8.408287,-18.088741


In [None]:
decoded_interpolation = trfm.decode(torch.from_numpy(molecule_morph).float())


In [None]:
decoded_interpolation.shape

(220, 15, 45)

In [None]:
interpolation_smiles = get_smiles(decoded_interpolation)
interpolations = pd.DataFrame(interpolation_smiles)
interpolations.columns = ['canonical_smiles']
interpolations

[[None None None ... None None None]
 [None None None ... None None None]
 [None None None ... None None None]
 ...
 [None None None ... None None None]
 [None None None ... None None None]
 [None None None ... None None None]]


Unnamed: 0,canonical_smiles
0,CCOc1cc2c(cn1][nH]c3ccccc23
1,CCOc1cc2c(cn1][nH]c3ccccc23
2,CCOc1cc2c(cn1][nH]c3ccccc23
3,CCOc1cc2c(cn1][nH]c3ccccc23
4,CCOc1cc2c(cn1][nH]c3ccccc23
5,COOc1cc2c(On1][O]]c3Occcc23
6,COOcscc2c(O]s][O]]c3Occcc]O
7,COOcscc2c(O]s][O]]c(Occcc]O
8,COOsscc2c(O]sc(O]]c(Occ]c]O
9,COO1ccc2c(O]cc(O]]c(O]c]c]O


# plot molecules

In [None]:
from IPython.display import SVG


In [None]:

# Parse the decoded strings. Chem.MolFromSmiles returns None for invalid SMILES,
# and passing None to the RDKit drawing calls raises ArgumentError, so report and
# drop those first. Legends in the plot number the drawn (valid) molecules in order.
parsed = [Chem.MolFromSmiles(sm) for sm in interpolations['canonical_smiles'].values]
invalid_rows = [i for i, mol in enumerate(parsed) if mol is None]
if invalid_rows:
    print('Skipping {} invalid SMILES at rows: {}'.format(len(invalid_rows), invalid_rows))
mols = [mol for mol in parsed if mol is not None]
dr = plot_mols(mols, 250, 175, 250, 1.1)
with open('bbbp_mol.svg', 'w') as f:
    f.write(dr.GetDrawingText())
SVG(dr.GetDrawingText())

ArgumentError: ignored

In [None]:
dr

<rdkit.Chem.Draw.rdMolDraw2D.MolDraw2DSVG at 0x7f3138f4df80>