# Training program for Experiment

## setting path

In [1]:
import sys
import os

sys.path.append(os.path.split(os.getcwd())[0])


# Vocab,datasetのロード

In [2]:
from fast_jtnn import *
from MS_PredictModel import MS_Dataset,MS_Dataset_pickle,dataset_load
import pickle
import torch

VOCAB_FILE = "./MS_vocab.txt"

vocab = [x.strip("\r\n ") for x in open(VOCAB_FILE,"r")]
vocab = Vocab(vocab)

'''
MS_Dataset.QUERY = """select smiles,file_path from massbank where ms_type="MS" and instrument_type="EI-B" and smiles<>'N/A';"""
dataset = MS_Dataset(vocab=vocab,host="localhost",database="chemoinfo",batch_size=20)
'''
train_vali_rate = 0.9

train_dataset, vali_dataset = dataset_load("./massbank.pkl",vocab,20,train_vali_rate)
print("number of train dataset :",len(train_dataset))
print("number of validation dataset :",len(vali_dataset))

('number of train dataset :', 6716)
('number of validation dataset :', 747)


## モデルの作成

In [3]:
from ms_encoder import ms_peak_encoder,ms_peak_encoder_lstm
import torch.nn as nn
import torch
hidden_size = 100
latent_size = 56
depthT = 20
depthG = 3

dec_model = JTNNVAE(vocab, hidden_size, latent_size, depthT, depthG).to('cuda')
print dec_model
enc_model = ms_peak_encoder_lstm(train_dataset.max_spectrum_size,output_size=latent_size,\
                                 hidden_size=50,embedding_size=10,num_rnn_layers=2,bidirectional=True,dropout_rate=0.5).to('cuda')
print enc_model

for param in dec_model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)
load_model = "./vae_model/model.iter-90000"
dec_model.load_state_dict(torch.load(load_model,map_location='cuda'))

print "Model #Params: %dK" % (sum([x.nelement() for x in dec_model.parameters()]) / 1000,)
print "Model #Params: %dK" % (sum([x.nelement() for x in enc_model.parameters()]) / 1000,)



JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(904, 100)
    (outputNN): Sequential(
      (0): Linear(in_features=200, out_features=100, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=200, out_features=100, bias=True)
      (W_r): Linear(in_features=100, out_features=100, bias=False)
      (U_r): Linear(in_features=100, out_features=100, bias=True)
      (W_h): Linear(in_features=200, out_features=100, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(904, 100)
    (W_z): Linear(in_features=200, out_features=100, bias=True)
    (U_r): Linear(in_features=100, out_features=100, bias=False)
    (W_r): Linear(in_features=100, out_features=100, bias=True)
    (W_h): Linear(in_features=200, out_features=100, bias=True)
    (W): Linear(in_features=128, out_features=100, bias=True)
    (U): Linear(in_features=128, out_features=100, bias=True)
    (U_i): Linear(in_features=200, out_features=100, bias=True)
    (W_o): 

RuntimeError: Error(s) in loading state_dict for JTNNVAE:
	size mismatch for jtnn.embedding.weight: copying a param with shape torch.Size([878, 100]) from checkpoint, the shape in current model is torch.Size([904, 100]).
	size mismatch for decoder.embedding.weight: copying a param with shape torch.Size([878, 100]) from checkpoint, the shape in current model is torch.Size([904, 100]).
	size mismatch for decoder.W_o.bias: copying a param with shape torch.Size([878]) from checkpoint, the shape in current model is torch.Size([904]).
	size mismatch for decoder.W_o.weight: copying a param with shape torch.Size([878, 100]) from checkpoint, the shape in current model is torch.Size([904, 100]).

## setting optimizer

In [None]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

optimizer = optim.Adam(enc_model.parameters(), lr=1e-3)
#optimizer = optim.SGD(enc_model.parameters(),lr=100)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()

In [None]:
from MS_PredictModel import ms_peak_encoder,MS_Dataset
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np

pbar = None
train_dataset.batch_size = 20
vali_dataset.batch_size = 10

anneal_iter = 7400

beta = 0
step_beta = 0.002
kl_anneal_iter = 14800
max_beta = 1.0
warmup = 7400

def training(max_epoch = 100):
    global pbar
    global beta
    total_step = 0
    meters = np.zeros(4)
    vali_meters = np.zeros(3)
    with open("log2.csv","w") as f:
        f.write("epoch,iter.,word,topo,assm,vali word,vali topo,vali assm\n")
    for epoch in range(max_epoch):
        print("epoch : ",epoch)
        for batch in train_dataset:
            x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder,x,y = batch
            total_step+=1
            #pbar.update(1)
            x = x.to('cuda')
            y = y.to('cuda')
            
            enc_model.zero_grad()
            dec_model.zero_grad()
            optimizer.zero_grad()
            
            h,kl_loss = enc_model(x,y,sample=True)
            tree_vec = h[:,:h.shape[1]/2]
            mol_vec  = h[:,h.shape[1]/2:]
            _, x_tree_mess = dec_model.jtnn(*x_jtenc_holder)
            word_loss, topo_loss, word_acc, topo_acc = dec_model.decoder(x_batch,tree_vec)
            assm_loss, assm_acc = dec_model.assm(x_batch, x_jtmpn_holder, mol_vec , x_tree_mess)
            total_loss = word_loss+topo_loss+assm_loss+beta*kl_loss
            total_loss.backward()
            optimizer.step()
            del x,y,h
            
            meters = meters + np.array([kl_loss.item(),word_acc * 100, topo_acc * 100, assm_acc * 100])
            if total_step % 200 == 0:
                vali_total = 0
                for batch in vali_dataset:
                    x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder,x,y = batch
                    x = x.to('cuda')
                    y = y.to('cuda')
                    with torch.no_grad():
                        h,_ = enc_model(x,y,training=False)
                        tree_vec = h[:,:h.shape[1]/2]
                        mol_vec  = h[:,h.shape[1]/2:]
                        _, x_tree_mess = dec_model.jtnn(*x_jtenc_holder)
                        word_loss, topo_loss, word_acc, topo_acc = dec_model.decoder(x_batch,tree_vec)
                        assm_loss, assm_acc = dec_model.assm(x_batch, x_jtmpn_holder, mol_vec , x_tree_mess)
                        vali_meters = vali_meters + np.array([word_acc * 100, topo_acc * 100, assm_acc * 100])
                        vali_total += 1    
                    del x,y,h
                    
                meters /= 200
                vali_meters /= vali_total
                print "[%d] , kl_loss %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f vali_Word: %.2f, vali_Topo: %.2f, vali_assm: %.2f, learning rate: %.4f" % \
                    (total_step, meters[0], meters[1], meters[2],meters[3], vali_meters[0],vali_meters[1],vali_meters[2],scheduler.get_lr()[0])
                with open("log2.csv","a") as f:
                    f.write("%d,%d,%.2f,%.2f,%.2f,%.2f,%.2f,%.2f\n" % (epoch,total_step,meters[0], meters[1], meters[2],vali_meters[0],vali_meters[1],vali_meters[2]))
                sys.stdout.flush()
                meters *= 0
                vali_meters *= 0
            if total_step % 200 == 0:
                torch.save(enc_model.state_dict(), "./enc_model" + "/model.iter-" + str(total_step))
            #if total_step % anneal_iter == 0:
                #scheduler.step()
                
            if total_step % kl_anneal_iter == 0 and total_step >= warmup:
                beta = min(max_beta, beta + step_beta)

#import pdb; pdb.set_trace()
try:
    #if pbar is None:
        #pbar = tqdm()
    training(100)
except RuntimeError as e:
    #if pbar is not None:
        #del pbar
    import traceback
    print(traceback.format_exc())
    #import pdb; pdb.set_trace()
    print(e)




In [6]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs

model_path = None
if model_path is not None:
    enc_model.load_state_dict(torch.load(model_path,map_location='cuda'))

train_dataset.batch_size = 20
vali_dataset.batch_size = 10

def evaluation():
    ret = []
    with torch.no_grad():
        for batch in vali_dataset:
            x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder,x,y = batch
            x = x.to('cuda')
            y = y.to('cuda')
            
            h,_ = enc_model(x,y,training=False)
            tree_vec = h[:,:h.shape[1]/2]
            mol_vec  = h[:,h.shape[1]/2:]
            for num in range(h.size()[0]):
                
                true_smiles=x_batch[num].smiles
                predict_smiles = dec_model.decode(tree_vec[num].view(1,latent_size/2),mol_vec[num].view(1,latent_size/2),False)
                
                #smilesの正規化
                true_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(true_smiles),True)
                predict_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(predict_smiles),True)
                
                ret.append((true_smiles,predict_smiles))
    return ret
result = evaluation()
print(len(result))

740


In [7]:
import os
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import ImageDraw,ImageFont
from rdkit.Chem import AllChem
from rdkit import DataStructs

def _re_smiles(smiles1,smiles2):
    #print(smiles1,smiles2)
    smiles1 = Chem.MolToSmiles(Chem.MolFromSmiles(smiles1),True)
    smiles2 = Chem.MolToSmiles(Chem.MolFromSmiles(smiles2),True)
    return smiles1 == smiles2

def is_structural_isomer(smiles1,smiles2):
    def Molecular_formula(smiles):
        atoms = {}
        mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
        for atom in mol.GetAtoms():
            if not atom.GetSymbol() in atoms:
                atoms[atom.GetSymbol()] = 1
            else:
                atoms[atom.GetSymbol()] += 1
        return atoms
    atoms1 = Molecular_formula(smiles1)
    atoms2 = Molecular_formula(smiles2)
    return atoms1 == atoms2
    
def analyze_result(smiles_list,log_path="evaluation.csv",path="image_list"):
    if not os.path.exists(path):
        os.mkdir(path)
        
    with open(log_path,"w") as f:
        print "Number of data: %d"% (len(smiles_list))
        f.write("Number of data,%d\n" % (len(smiles_list)))
        
        true_list = [[i,one[0]] for i,one in enumerate(smiles_list) if _re_smiles(one[0],one[1])]
        print "Number of matching: %d" % (len(true_list))
        f.write("Number of matching: %d\n" % (len(true_list)))
        print(true_list)
    
        true_list = [[i,one[0]] for i,one in enumerate(smiles_list) if is_structural_isomer(one[0],one[1]) and [i,one[0]] not in true_list]
        print "Number of matching: %d" % (len(true_list))
        f.write("Number of matching: %d\n" % (len(true_list)))
        print(true_list)
        
        f.write("true,predict,ECFP-Tanimoto score,MACCS-Tanimoto score\n")
        
    for i,smiles in enumerate(smiles_list):
        true_mol = Chem.MolFromSmiles(smiles[0])
        predict_mol = Chem.MolFromSmiles(smiles[1])
        
        true_fingerprint = AllChem.GetMorganFingerprint(true_mol,2)
        predict_fingerprint = AllChem.GetMorganFingerprint(predict_mol,2)
        ECFP_score = DataStructs.TanimotoSimilarity(true_fingerprint,predict_fingerprint)
        
        true_fingerprint = AllChem.GetMACCSKeysFingerprint(true_mol)
        predict_fingerprint = AllChem.GetMACCSKeysFingerprint(predict_mol)
        MACCS_score = DataStructs.TanimotoSimilarity(true_fingerprint,predict_fingerprint)
        
        with open(log_path,"a") as f:
            f.write(smiles[0]+","+smiles[1]+","+str(ECFP_score)+","+str(MACCS_score)+"\n")
        
        image = Draw.MolsToImage([true_mol,predict_mol])
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default()
        font.size=40
        draw.text((0,0),str(ECFP_score)+","+str(MACCS_score),(0, 0, 0),font=font)
        image.save(os.path.join(path,"%d.png" % i))
    
analyze_result(result)

Number of data: 740
Number of matching: 11
[[45, 'CCC(C)OC(C)=O'], [234, 'CC(=O)Oc1ccc(C)cc1'], [252, 'CCCCCCCCCCCCC'], [262, 'CCC(=O)CC'], [351, 'CC(=O)c1cccc2ccccc12'], [368, 'CCCC(=O)OCCC(C)C'], [471, 'Cc1ccccc1O'], [480, 'OCc1ccccc1'], [542, 'OC1CC2CCC1C2'], [582, 'CCCCCCCCCCCCCCC(=O)O[Si](C)(C)C'], [628, 'CCOC(C)=O']]
Number of matching: 31
[[20, 'Cc1cccc(O)c1C'], [54, 'Cc1ccc(O)c(C)c1'], [137, 'CCCCCCCOC(=O)CC'], [144, 'CCC(C)(C)c1ccccc1'], [163, 'CC(Cl)Cl'], [196, 'CCCCCC(=O)OCC(C)C'], [200, 'CCCCOC(=O)CCC'], [228, 'CC(=O)OC(C)(C)C'], [277, 'CC12CCC(C1)C(C)(C)C2=O'], [300, 'CC(C)=CCCC1(C)OC1CO'], [326, 'COc1cccc(O)c1'], [340, 'CCc1cccc(O)c1'], [355, 'C=Cc1ccc(O)c(OC)c1'], [370, 'CCCC(C)C(=O)O'], [396, 'CC(C)CCCC(C)C1CCC2(C)C3=C(CCC12C)C1(C)CCC(O)C(C)(C)C1CC3'], [418, 'CCCC=CCOC(CCCCC)OCc1ccccc1'], [429, 'Cc1cc(C)cc(O)c1'], [435, 'Cc1ccc(N)cc1'], [447, 'CCC(C)C(=O)OCCc1ccccc1'], [455, 'C=C(C)COC(OCC(C)(C)OCC(=C)C)C(C)C'], [478, 'CC(=O)c1ccc(C)cc1'], [485, 'COC(=O)c1ccccc1'], [495

In [11]:
!cp -f evaluation.csv ./image_list
!zip -r image.zip ./image_list

updating: image_list/0.png (deflated 10%)
updating: image_list/1.png (deflated 11%)
updating: image_list/10.png (deflated 7%)
updating: image_list/100.png (deflated 12%)
updating: image_list/101.png (deflated 10%)
updating: image_list/102.png (deflated 8%)
updating: image_list/103.png (deflated 11%)
updating: image_list/104.png (deflated 10%)
updating: image_list/105.png (deflated 6%)
updating: image_list/106.png (deflated 6%)
updating: image_list/107.png (deflated 11%)
updating: image_list/108.png (deflated 7%)
updating: image_list/109.png (deflated 12%)
updating: image_list/11.png (deflated 8%)
updating: image_list/110.png (deflated 12%)
updating: image_list/111.png (deflated 7%)
updating: image_list/112.png (deflated 7%)
updating: image_list/113.png (deflated 9%)
updating: image_list/114.png (deflated 11%)
updating: image_list/115.png (deflated 7%)
updating: image_list/116.png (deflated 5%)
updating: image_list/117.png (deflated 23%)
updating: image_list/118.png (deflated 10%)
updat