# 学習プログラム

## path通し

In [1]:
import sys
import os

# Put the parent of the current working directory on the import path so that
# modules located one level above this notebook can be imported.
parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)


## vocabulary作成
すでに語彙ファイル（MS_vocab.txt）が作成済みである場合は不要

In [2]:
from fast_jtnn.mol_tree import MolTree
from getpass import getpass,getuser

import mysql
from mysql import connector
import warnings

# SQL server profile.  ``user`` and ``passwd`` are asked for interactively
# when they are left as None.
host = "localhost"
user = None
passwd = None
port = 3306
database = "chemoinfo"

# Output path of the generated vocabulary (one node SMILES per line).
VOCAB_FILE = "./MS_vocab.txt"

# Get MassBank data from the SQL server.
# Both handles are bound to None up front so that the ``finally`` clause can
# run safely even when ``connector.connect`` (or the credential prompt) raises
# before they are assigned; otherwise a NameError would mask the real error.
connect = None
cursor = None
try:
    if not isinstance(user, str):
        user = raw_input("user")
    if not isinstance(passwd, str):
        passwd = getpass()
    connect = connector.connect(host=host, user=user, password=passwd, port=port, database=database)
    cursor = connect.cursor()
    cursor.execute("""select smiles from massbank where ms_type="MS" and instrument_type="EI-B"; """)
    smiles_list = cursor.fetchall()
except mysql.connector.Error as e:
    print("Something went wrong: {}".format(e))
    sys.exit(1)
finally:
    # Release the cursor first, then the connection it belongs to.
    if cursor is not None:
        cursor.close()
    if connect is not None:
        connect.close()

# Create the vocabulary: collect the SMILES string of every node of every
# molecule tree that can be built from the fetched rows.
succes = 0  # number of rows for which MolTree construction succeeded
fault = 0   # number of rows skipped because MolTree raised AttributeError
cset = set()
for one in smiles_list:
    try:
        # Each fetched row is a 1-tuple holding the SMILES string.
        mol = MolTree(one[0])
    except AttributeError:
        warnings.warn("Entered An SMILES that does not meet the rules")
        fault += 1
        continue
    succes += 1
    for c in mol.nodes:
        cset.add(c.smiles)

# Write the vocabulary, one entry per line.
with open(VOCAB_FILE, "w") as f:
    for one in cset:
        f.write(one + "\n")

user aisiars
 ··········




# Vocabのロード

In [3]:
from fast_jtnn import *

# Read the vocabulary file line by line, dropping line endings and spaces.
# The ``with`` block guarantees the file handle is closed after reading.
with open(VOCAB_FILE, "r") as vocab_file:
    vocab = [line.strip("\r\n ") for line in vocab_file]
# Wrap the plain list of entries in the Vocab helper used by the models.
vocab = Vocab(vocab)

## モデルの作成

In [4]:
from ms_encoder import ms_peak_encoder
import torch.nn as nn
# Hyper-parameters passed to the JTNNVAE constructor.
hidden_size = 45
latent_size = 56
depthT = 20
depthG = 3

# Build both models and move them to the GPU.  ``print(obj)`` with a single
# argument behaves the same under Python 2 and Python 3, matching the
# function-call form of print already used in the other cells.
dec_model = JTNNVAE(vocab, hidden_size, latent_size, depthT, depthG).cuda()
print(dec_model)
enc_model = ms_peak_encoder(174, latent_size).cuda()
print(enc_model)

# Initialise dec_model: 1-D parameters are set to zero, all others are
# drawn from a Xavier-normal distribution.
for param in dec_model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

# Explicit floor division keeps the count an integer regardless of the
# interpreter's division semantics.
print("Model #Params: %dK" % (sum(x.nelement() for x in dec_model.parameters()) // 1000,))



JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(365, 45)
    (outputNN): Sequential(
      (0): Linear(in_features=90, out_features=45, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=90, out_features=45, bias=True)
      (W_r): Linear(in_features=45, out_features=45, bias=False)
      (U_r): Linear(in_features=45, out_features=45, bias=True)
      (W_h): Linear(in_features=90, out_features=45, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(365, 45)
    (W_z): Linear(in_features=90, out_features=45, bias=True)
    (U_r): Linear(in_features=45, out_features=45, bias=False)
    (W_r): Linear(in_features=45, out_features=45, bias=True)
    (W_h): Linear(in_features=90, out_features=45, bias=True)
    (W): Linear(in_features=73, out_features=45, bias=True)
    (U): Linear(in_features=73, out_features=45, bias=True)
    (U_i): Linear(in_features=90, out_features=45, bias=True)
    (W_o): Linear(in_features=45, out

## オプティマイザの設定

In [5]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

# Adam is built from the encoder's parameters only, so optimizer.step()
# updates enc_model and leaves dec_model untouched.
optimizer = optim.Adam(enc_model.parameters(), lr=1e-3)
# Exponential learning-rate decay with a factor of 0.9 per scheduler step.
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# Advance the schedule once up front.
# NOTE(review): the training loop in this file never calls scheduler.step()
# again -- confirm whether decay during training is intended.
scheduler.step()

In [6]:
from MS_PredictModel import ms_peak_encoder,MS_Dataset

# Counts every processed batch across all epochs.
total_step = 0
# Override the dataset query so that only 40 EI-B MS records are loaded.
MS_Dataset.QUERY = """select smiles,file_path from massbank where ms_type="MS" and instrument_type="EI-B" limit 40; """
dataset = MS_Dataset(vocab=vocab, host="localhost", database="chemoinfo", batch_size=4)
for epoch in range(10):
    # Format explicitly: print("a", b) would print a tuple under Python 2.
    print("epoch : {}".format(epoch))
    for batch in dataset:
        x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder, x, y = batch
        total_step += 1
        enc_model.zero_grad()
        h = enc_model(x, y)
        # Split the encoder output along dim 1 into two halves.  Floor
        # division keeps the slice index an integer under any division
        # semantics (true division would yield a float and break slicing).
        half = h.shape[1] // 2
        tree_vec = h[:, :half]   # first half: fed to the tree decoder
        mol_vec = h[:, half:]    # second half: fed to the assembly step
        _, x_tree_mess = dec_model.jtnn(*x_jtenc_holder)
        word_loss, topo_loss, word_acc, topo_acc = dec_model.decoder(x_batch, tree_vec)
        assm_loss, assm_acc = dec_model.assm(x_batch, x_jtmpn_holder, mol_vec, x_tree_mess)
        total_loss = word_loss + topo_loss + assm_loss
        total_loss.backward()
        # ``optimizer`` was built from enc_model.parameters(), so only the
        # encoder's weights are updated here.
        optimizer.step()

user aisiars
 ··········


100%|██████████| 40/40 [00:06<00:00,  5.75it/s]

success 36,fault 4
('epoch : ', 0)





('epoch : ', 1)
('epoch : ', 2)
('epoch : ', 3)
('epoch : ', 4)
('epoch : ', 5)
('epoch : ', 6)
('epoch : ', 7)
('epoch : ', 8)
('epoch : ', 9)


In [7]:
# Inspect the shape of ``h``, the encoder output left over from the last
# batch processed by the training loop.
h.shape

torch.Size([4, 56])