# 学習プログラム

## path通し

In [1]:
import sys
import os

sys.path.append(os.path.split(os.getcwd())[0])


## データプレプロセス

## vocablary作成
すでにvocab.txtが作成済みである場合不用

In [4]:
from fast_jtnn.mol_tree import MolTree
from getpass import getpass,getuser

import mysql
from mysql import connector
import warnings

# SQL server profile
host = "localhost"
user = None
passwd = None
port = 3306
database="chemoinfo"

# 
VOCAB_FILE = "./MS_vocab.txt"

# get massbank data from SQL server
try:
    if not isinstance(user,str):
        user = raw_input("user")
    if not isinstance(passwd,str):
        passwd = getpass()
    connect = connector.connect(host=host,user=user,password=passwd,port=port,database=database)
    cursor = connect.cursor()
    cursor.execute("""select smiles from massbank where ms_type="MS" and instrument_type="EI-B" and smiles<>'N/A'; """)
    smiles_list = cursor.fetchall()
except mysql.connector.Error as e:
    print("Something went wrong: {}".format(e))
    sys.exit(1)
finally:
    if passwd : del passwd
    if connect: connect.close()
    if cursor: cursor.close()

# create vocablary
succes = 0
fault = 0
cset = set()
for one in smiles_list:
    try:
        mol = MolTree(one[0])
    except AttributeError as e:
        warnings.warn("Entered An SMILES that does not meet the rules")
        continue
    for c in mol.nodes:
        cset.add(c.smiles)

# write vocab
with open(VOCAB_FILE,"w") as f:
    for one in cset:
        f.write(one+"\n")

user aisiars
 ··········


! mkdir vae_model/
%run ../fast_molvae/vae_train.py --train processed --vocab ./MS_vocab.txt --save_dir vae_model/

# Vocab,datasetのロード

In [None]:
from fast_jtnn import *
from MS_PredictModel import MS_Dataset

VOCAB_FILE = "./MS_vocab.txt"

vocab = [x.strip("\r\n ") for x in open(VOCAB_FILE,"r")]
vocab = Vocab(vocab)

MS_Dataset.QUERY = """select smiles,file_path from massbank where ms_type="MS" and instrument_type="EI-B" and smiles<>'N/A'; """
dataset = MS_Dataset(vocab=vocab,host="localhost",database="chemoinfo",batch_size=20)

user aisiars
 ··········


 76%|███████▌  | 8481/11200 [23:04<05:35,  8.10it/s]

## モデルの作成

In [3]:
from ms_encoder import ms_peak_encoder
import torch.nn as nn
import torch
hidden_size = 100
latent_size = 56
depthT = 20
depthG = 3

dec_model = JTNNVAE(vocab, hidden_size, latent_size, depthT, depthG).cuda()
print dec_model
enc_model = ms_peak_encoder(dataset.max_spectrum_size,latent_size).cuda()
print enc_model

for param in dec_model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)
load_model = "./vae_model/model.iter-70000"
dec_model.load_state_dict(torch.load(load_model))
print "Model #Params: %dK" % (sum([x.nelement() for x in dec_model.parameters()]) / 1000,)



JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(365, 100)
    (outputNN): Sequential(
      (0): Linear(in_features=200, out_features=100, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=200, out_features=100, bias=True)
      (W_r): Linear(in_features=100, out_features=100, bias=False)
      (U_r): Linear(in_features=100, out_features=100, bias=True)
      (W_h): Linear(in_features=200, out_features=100, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(365, 100)
    (W_z): Linear(in_features=200, out_features=100, bias=True)
    (U_r): Linear(in_features=100, out_features=100, bias=False)
    (W_r): Linear(in_features=100, out_features=100, bias=True)
    (W_h): Linear(in_features=200, out_features=100, bias=True)
    (W): Linear(in_features=128, out_features=100, bias=True)
    (U): Linear(in_features=128, out_features=100, bias=True)
    (U_i): Linear(in_features=200, out_features=100, bias=True)
    (W_o): 

## オプティマイザの設定

In [4]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

optimizer = optim.Adam(enc_model.parameters(), lr=1e-3)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()

In [5]:
from MS_PredictModel import ms_peak_encoder,MS_Dataset
from tqdm import tqdm

if pbar: pbar = tqdm()
def training(max_epoch = 100):
    global pbar
    total_step = 0
    meters = np.zeros(3)
    for epoch in range(max_epoch):
        print("epoch : ",epoch)
        for batch in dataset:
            x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder,x,y = batch
            total_step+=1
            pbar.update(1)
            enc_model.zero_grad()
            h = enc_model(x,y)
            tree_vec = h[:,:h.shape[1]/2]
            mol_vec  = h[:,h.shape[1]/2:]
            _, x_tree_mess = dec_model.jtnn(*x_jtenc_holder)
            word_loss, topo_loss, word_acc, topo_acc = dec_model.decoder(x_batch,tree_vec)
            assm_loss, assm_acc = dec_model.assm(x_batch, x_jtmpn_holder, mol_vec , x_tree_mess)
            total_loss = word_loss+topo_loss+assm_loss
            total_loss.backward()
            optimizer.step()
            
            meters = meters + np.array([word_acc * 100, topo_acc * 100, assm_acc * 100])
            if total_step % 100 == 0:
                meters /= 100
                print "[%d] , Word: %.2f, Topo: %.2f, Assm: %.2f" % (total_step,meters[0], meters[1], meters[2])
                sys.stdout.flush()
                meters *= 0
            if total_step % 100 == 0:
                torch.save(enc_model.state_dict(), "./enc_model" + "/model.iter-" + str(total_step))

#import pdb; pdb.set_trace()
try:
    training(10)
except RuntimeError as e:
    import pdb; pdb.set_trace()
    print(e)



0it [00:00, ?it/s]

('epoch : ', 0)


99it [00:25,  3.66it/s]

('epoch : ', 1)


100it [00:26,  3.44it/s]

[100] , Word: 32.14, Topo: 85.34, Assm: 80.25


198it [00:52,  6.99it/s]

('epoch : ', 2)


200it [00:56,  1.11s/it]

[200] , Word: 34.05, Topo: 86.37, Assm: 83.96


297it [01:19,  1.40it/s]

('epoch : ', 3)


300it [01:23,  1.09s/it]

[300] , Word: 35.90, Topo: 86.84, Assm: 84.67


396it [01:48,  5.93it/s]

('epoch : ', 4)


400it [01:50,  1.70it/s]

[400] , Word: 36.68, Topo: 87.27, Assm: 85.23


495it [02:14,  5.79it/s]

('epoch : ', 5)


500it [02:17,  1.75it/s]

[500] , Word: 36.60, Topo: 87.12, Assm: 84.36


594it [02:39,  6.55it/s]

('epoch : ', 6)


600it [02:41,  4.13it/s]

[600] , Word: 35.90, Topo: 87.16, Assm: 85.57


693it [03:03,  5.22it/s]

('epoch : ', 7)


700it [03:04,  5.68it/s]

[700] , Word: 35.80, Topo: 87.15, Assm: 83.76


792it [03:29,  6.53it/s]

('epoch : ', 8)


800it [03:34,  2.08it/s]

[800] , Word: 35.64, Topo: 87.10, Assm: 85.23


891it [03:56,  4.17it/s]

('epoch : ', 9)


900it [04:00,  4.16it/s]

[900] , Word: 34.55, Topo: 87.87, Assm: 85.00


990it [04:22,  4.73it/s]

In [None]:
from fast_jtnn.datautils import tensorize
import random

random.shuffle(dataset.dataset)
test = [[one[2]] for one in dataset.dataset]
for one in test:
    print(len(one),one[0].smiles)
    tensorize(one,vocab,assm=True)

In [4]:
from rdkit import Chem
from rdkit.Chem import Draw
from fast_jtnn import MolTree

smiles = "CC(=O)c1ccccc1"
#smiles = "CCCCOC(=O)c(c1)c(ccc1)C(=O)OCCCC"
mol = Chem.MolFromSmiles(smiles)
Draw.MolToImage(mol)

mol_tree = MolTree(smiles) # convert smiles to MolTree
mol_tree.recover()
assm = True
if assm:
    mol_tree.assemble()
    for node in mol_tree.nodes:
        print("test1",node.label,node.cands)
        if node.label not in node.cands:
            print("test2",node.label)
            node.cands.append(node.label)

del mol_tree.mol
for node in mol_tree.nodes:
    del node.mol

def set_batch_nodeID(mol_batch, vocab):
    tot = 0
    for mol_tree in mol_batch:
        for node in mol_tree.nodes:
            node.idx = tot
            node.wid = vocab.get_index(node.smiles)
            tot += 1
    
tree_batch = [mol_tree]

#ret = tensorize(mol_batch,vocab,assm)
set_batch_nodeID(tree_batch, vocab)
smiles_batch = [tree.smiles for tree in tree_batch]
jtenc_holder,mess_dict = JTNNEncoder.tensorize(tree_batch)
jtenc_holder = jtenc_holder
mpn_holder = MPN.tensorize(smiles_batch)

cands = []
batch_idx = []
for i,mol_tree in enumerate(tree_batch):
    for node in mol_tree.nodes:
        print(node.smiles,node.is_leaf,len(node.cands))
        #Leaf node's attachment is determined by neighboring node's attachment
        if node.is_leaf or len(node.cands) == 1:
            #print("test")
            continue
        cands.extend( [(cand, mol_tree.nodes, node) for cand in node.cands] )
        batch_idx.extend([i] * len(node.cands))
print([one[0] for one in cands])
jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
batch_idx = torch.LongTensor(batch_idx)
print(len(ret))

('test1', 'C[CH3:5]', ['C[CH3:5]'])
('test1', 'O=[CH2:5]', ['O=[CH2:5]'])
('test1', 'c1cc[c:3]([CH3:5])cc1', ['c1cc[c:3]([CH3:5])cc1'])
('test1', 'c1ccc([CH3:3])cc1', ['c1ccc([CH3:3])cc1'])
('test1', 'C[C:5](=O)[CH3:3]', ['C[C:5](=O)[CH3:3]'])
('CC', True, 1)
('C=O', True, 1)
('CC', False, 1)
('C1=CC=CC=C1', True, 1)
('C', False, 1)
[]


RuntimeError: expected a non-empty list of Tensors

In [None]:
smiles = "O=C(C)c(c1)cccc1"
smiles = "CCCCOC(=O)c(c1)c(ccc1)C(=O)OCCCC"
mol = Chem.MolFromSmiles(smiles)
#mol = Chem.AddHs(mol)
print(Chem.MolToSmiles(mol))
Draw.MolsToGridImage([mol])

In [8]:
del tqdm