<a href="https://colab.research.google.com/github/ajia90/smilestransformer/blob/main/interpolation_notgt_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# imports/ setup

mask + target = source -> molecules w/ some errors |||
mask + only start tokens -> only start toekn 


---



In [None]:
# Install RDKit. Takes 2-3 minutes
!wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge python=3.7 
!time conda install -q -y -c conda-forge rdkit 

In [2]:
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')

# training the model


In [None]:
!conda install pytorch torchvision -c pytorch


In [None]:
#!conda install pytorch torchvision -c pytorch

In [None]:
from google.colab import drive
drive.mount("/content/gdrive/")

Mounted at /content/gdrive/


In [None]:
 !python pretrain_trfm_target_zero.py

# methods/model


In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import DrawingOptions
from torch.nn import functional as F


In [4]:
def plot_mols(mols, unit=200, w=120, h=200, fontsize=1.0):
    drawer = Draw.MolDraw2DSVG(4*unit, 3*unit, w, h)

    # optをとり出しておく
    opt = drawer.drawOptions()
    opt.padding = 0.1
    opt.legendFontSize = 20
    #opt.atomfontSize = 20

    xs = np.array([0,1,2,3,0,1,2,3,0,1,2,3])*unit
    ys = np.array([0,0,0,0,1,1,1,1,2,2,2,2])*unit
    for i, (mol, x, y) in enumerate(zip(mols,xs,ys)):
        # SetOffsetで左上の座標を指定できる
        drawer.SetOffset(int(x), int(y))
        drawer.SetFontSize(fontsize)

        AllChem.Compute2DCoords(mol)
        Chem.Kekulize(mol)
        # 分子をSVGに書く
        drawer.DrawMolecule(mol, legend=str(i))


    # </svg> 書く
    drawer.FinishDrawing()
    return drawer

In [9]:
def get_inputs(sm):
    seq_len = 220
    sm = sm.split()
    if len(sm)>218:
        print('SMILES is too long ({:d})'.format(len(sm)))
        sm = sm[:109]+sm[-109:]
    ids = [vocab.stoi.get(token, unk_index) for token in sm]
    ids = [sos_index] + ids + [eos_index]
    seg = [1]*len(ids)
    padding = [pad_index]*(seq_len - len(ids))
    ids.extend(padding), seg.extend(padding)
    return ids, seg

def get_array(smiles):
    x_id, x_seg = [], []
    for sm in smiles:
        a,b = get_inputs(sm)
        x_id.append(a)
        x_seg.append(b)
    return torch.tensor(x_id), torch.tensor(x_seg)

In [5]:
import torch
from pretrain_trfm_target_zero import TrfmSeq2seq
from build_vocab import WordVocab
from utils import split

pad_index = 0
unk_index = 1
eos_index = 2
sos_index = 3
mask_index = 4

vocab = WordVocab.load_vocab('vocab.pkl')

# trfm_c= TrfmSeq2seq(len(vocab), 256, len(vocab), 4).cuda()
# trfm_c.load_state_dict(torch.load('trfm_new_12_80000.pkl'))
# trfm_c.eval()

trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
trfm.load_state_dict(torch.load('trfm_notgt_12_90000.pkl'))
trfm.eval()
print('Total parameters:', sum(p.numel() for p in trfm.parameters()))

Total parameters: 4244013


In [6]:
smiles_dict = vocab.stoi

In [None]:
smiles_dict


{'#': 31,
 '(': 7,
 ')': 8,
 '+': 29,
 '-': 30,
 '.': 33,
 '/': 32,
 '1': 11,
 '2': 13,
 '3': 14,
 '4': 16,
 '5': 22,
 '6': 27,
 '7': 35,
 '8': 38,
 '9': 43,
 '<eos>': 2,
 '<mask>': 4,
 '<pad>': 0,
 '<sos>': 3,
 '<unk>': 1,
 '=': 10,
 '@': 17,
 'B': 40,
 'Br': 34,
 'C': 6,
 'Cl': 25,
 'F': 21,
 'H': 20,
 'I': 37,
 'K': 44,
 'N': 12,
 'Na': 39,
 'O': 9,
 'P': 36,
 'S': 23,
 'Se': 42,
 'Si': 41,
 '[': 18,
 '\\': 24,
 ']': 19,
 'c': 5,
 'n': 15,
 'o': 28,
 's': 26}

# data


In [None]:
#read in BBBp data
# df = pd.read_csv('BBBP.csv')
# print(df.shape)
# df.head()

In [7]:
#sample of chembl25 data
df2 = pd.read_csv('smiles_sample2.csv')
print(df2.shape)

(1052, 1)


In [None]:
df2.head()

Unnamed: 0,canonical_smiles
0,Cc1cc(cn1C)c2csc(N=C(N)N)n2
1,Brc1cccc(Nc2ncnc3ccncc23)c1NCCN4CCOCC4
2,COc1c(O)cc(O)c(C(=N)Cc2ccc(O)cc2)c1O
3,CCOC(=O)c1cc2cc(ccc2[nH]1)C(=O)O
4,C[C@H](NC(=O)OCc1ccccc1)C(=O)N[C@@H](C)C(=O)NN...


In [10]:
x_split = [split(sm) for sm in df2['canonical_smiles'].values]
xid, xseg = get_array(x_split)

In [11]:
xid.shape

torch.Size([1052, 220])

# encode


In [12]:
from dataset import Seq2seqDataset
from torch.utils.data import DataLoader
from tqdm import tqdm


In [13]:
#encode function
X = trfm.encode(torch.t(xid))
print(X.shape)

There are 1052 molecules. It will take a little time.
(220, 1052, 256)


In [None]:
output_total.size


<function Tensor.size>

In [None]:
torch.cuda.empty_cache()

In [None]:
# #trfm.encoder
# dataset = Seq2seqDataset(df2['canonical_smiles'].values, vocab)
# data_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
# output_total = torch.empty(size=(xid.shape[1], xid.shape[0], 256))
# print(output_total.size)
# for b, sm in tqdm(enumerate(data_loader)):
#   # sm = torch.t(sm.cuda()) # (T,B)
#   # output1 = trfm_c(sm) # (T,,V)
#   embedded = trfm.embed(torch.t(sm.cuda()))  # (T,B,H)
#   embedded = trfm.pe(embedded) # (T,B,H)
#   output_total[:,b:b+4,:] = trfm.encoder(embedded)
# #output = output.detach().numpy()

# mask


In [16]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


# decode

tgt = xid //source


In [17]:
mask1 = generate_square_subsequent_mask(torch.t(xid).shape[0])

In [18]:
mask1.shape

torch.Size([220, 220])

In [None]:
# #trfm.decoder
# decoded = trfm.decoder(output, output, mask1)
# out = trfm.out(decoded) # (T,B,V)
# out = F.log_softmax(out, dim=2)
# out = out.detach().numpy()

In [23]:
#decode function
hidden = torch.from_numpy(X).float()
decoded = trfm.decode(hidden, hidden)

There are 1052 molecules. It will take a little time.


In [None]:
decoded.shape

(220, 15, 45)

# get smiles from decoded output


In [24]:
_, next_word = torch.max(torch.from_numpy(decoded), dim = 2)
decoded_smiles = torch.t(next_word).detach().numpy()
decoded_smiles.shape

(1052, 220)

In [25]:
y =torch.t(xid).detach().numpy()


In [26]:
#value -> smiles
smiles_molecules = np.empty([decoded_smiles.shape[0],decoded_smiles.shape[1]], dtype=object)
for i in range(decoded_smiles.shape[0]):   
   smiless = [list(smiles_dict.keys())[list(smiles_dict.values()).index(elem)] for elem in decoded_smiles[i]]
   smiles_molecules[i] = smiless

In [27]:
pd.DataFrame(smiles_molecules)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219
0,<sos>,C,c,1,c,c,(,c,n,1,C,],c,2,c,s,c,(,N,=,C,(,N,],N,],n,2,<eos>,],],],],],],],],],],],...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
1,<sos>,Br,c,1,c,c,c,c,(,N,c,2,n,c,n,c,3,c,c,n,c,c,2,3,],c,1,N,C,C,N,4,C,C,O,C,C,4,<eos>,F,...,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F,F
2,<sos>,C,O,c,1,c,(,O,],c,c,(,O,],c,(,C,(,=,N,],C,c,2,c,c,c,(,O,],c,c,2,],c,1,O,<eos>,],],...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
3,<sos>,C,C,O,C,(,=,O,],c,1,c,c,2,c,c,(,c,c,c,2,[,n,H,],1,],C,(,=,O,],O,<eos>,],],],],],],...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
4,<sos>,C,[,C,@,H,],(,N,C,(,=,O,],O,C,c,1,c,c,c,c,c,1,],C,(,=,O,],N,[,C,@,@,H,],(,C,],...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1047,<sos>,C,[,C,@,H,],(,C,C,C,S,(,=,O,],(,=,O,],C,],C,1,=,C,C,[,C,@,H,],2,\,C,(,=,C,\,C,...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
1048,<sos>,C,C,1,=,C,(,C,#,N,],c,2,n,c,(,N,],c,(,C,#,N,],c,(,C,],c,2,/,C,/,1,=,C,/,c,3,o,...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
1049,<sos>,C,c,1,c,c,c,c,c,1,c,2,n,c,n,c,3,c,2,n,c,n,3,C,4,O,[,C,@,H,],(,C,O,],[,C,@,@,H,...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]
1050,<sos>,C,C,(,C,],c,1,c,c,c,(,c,c,1,],c,2,c,c,(,n,c,(,Cl,],c,2,C,#,N,],c,3,c,c,c,4,C,C,...,],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],],]


In [29]:
#put all characters into one continuous string
smiles_formatted = np.empty(decoded_smiles.shape[0], dtype=object)
for i in range(smiles_molecules.shape[0]):
  smile = smiles_molecules[i]
  end = np.where(smile == '<eos>')[0][0]
  smiles_formatted[i] = "".join(smile[1:end])

In [30]:
decoded_final = pd.DataFrame(smiles_formatted)
decoded_final.columns = ['smiles']
decoded_final

Unnamed: 0,smiles
0,Cc1cc(cn1C]c2csc(N=C(N]N]n2
1,Brc1cccc(Nc2ncnc3ccncc23]c1NCCN4CCOCC4
2,COc1c(O]cc(O]c(C(=N]Cc2ccc(O]cc2]c1O
3,CCOC(=O]c1cc2cc(ccc2[nH]1]C(=O]O
4,C[C@H](NC(=O]OCc1ccccc1]C(=O]N[C@@H](C]C(=O]NN...
...,...
1047,C[C@H](CCCS(=O](=O]C]C1=CC[C@H]2\C(=C\C=C/3\C[...
1048,CC1=C(C#N]c2nc(N]c(C#N]c(C]c2/C/1=C/c3oc(cc3]c...
1049,Cc1ccccc1c2ncnc3c2ncn3C4O[C@H](CO][C@@H](O][C@...
1050,CC(C]c1ccc(cc1]c2cc(nc(Cl]c2C#N]c3ccc4CCCCc4c3


In [31]:
df2

Unnamed: 0,canonical_smiles
0,Cc1cc(cn1C)c2csc(N=C(N)N)n2
1,Brc1cccc(Nc2ncnc3ccncc23)c1NCCN4CCOCC4
2,COc1c(O)cc(O)c(C(=N)Cc2ccc(O)cc2)c1O
3,CCOC(=O)c1cc2cc(ccc2[nH]1)C(=O)O
4,C[C@H](NC(=O)OCc1ccccc1)C(=O)N[C@@H](C)C(=O)NN...
...,...
1047,C[C@H](CCCS(=O)(=O)C)C1=CC[C@H]2\C(=C\C=C/3\C[...
1048,CC1=C(C#N)c2nc(N)c(C#N)c(C)c2/C/1=C/c3oc(cc3)c...
1049,Cc1ccccc1c2ncnc3c2ncn3C4O[C@H](CO)[C@@H](O)[C@...
1050,CC(C)c1ccc(cc1)c2cc(nc(Cl)c2C#N)c3ccc4CCCCc4c3
