In [1]:
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import math

import torch
from torch import tensor
import torch.optim as optim
from torch import autograd
import torch.utils.data as D
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.distributions import Bernoulli

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_curve, auc, roc_curve
import rdkit
from rdkit import Chem

import math
from tqdm import tqdm_notebook as tqdm
from itertools import chain

from IPython.display import clear_output
from ipywidgets import interact
from bokeh.models import ColumnDataSource
from bokeh.layouts import column, row
from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure

import deepchem as dc
from deepchem.molnet import load_tox21

import h5py
import os

output_notebook()

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.manual_seed(123)
torch.cuda.manual_seed_all(123)
np.random.seed(123)

  from ._conv import register_converters as _register_converters


In [17]:
class Model(nn.Module):
    def __init__(self, feat_size, N_Max_atom, parts, T):
        super(Model, self).__init__()
        self.N_Max_atom = N_Max_atom
        self.parts = parts
        self.T = T
        self.feat_size = feat_size
        self.E_atom = nn.Embedding(1 + len(self.parts[0]), feat_size, padding_idx=0)
        self.E_bond = nn.Embedding(1 + len(self.parts[1]), feat_size, padding_idx=0)
        self.E_pair = nn.Embedding(1 + len(self.parts[2]), feat_size, padding_idx=0)
        
        #f_an
        
        self.f_e_an = nn.ModuleList([nn.Linear(3*feat_size, 3*feat_size*2) for _ in range(self.T)])
        self.f_n_an = nn.ModuleList([nn.GRUCell(3*feat_size*2, feat_size) for _ in range(self.T)])
        
        self.f_m_an = nn.Linear(feat_size, feat_size*2)
        self.g_m_an = nn.Linear(feat_size, 1)
        
        self.f_an = nn.Linear(feat_size*2, 1+len(self.parts[0]))
        
        #hv initialize
        self.f_m_init = nn.Linear(feat_size, feat_size*2)
        self.g_m_init = nn.Linear(feat_size, 1)
        
        self.f_init = nn.Linear(feat_size + feat_size*2, feat_size)
        
        #f_ae
        
        self.f_e_ae = nn.ModuleList([nn.Linear(3*feat_size, 3*feat_size*2) for _ in range(self.T)])
        self.f_n_ae = nn.ModuleList([nn.GRUCell(3*feat_size*2, feat_size) for _ in range(self.T)])
        
        self.f_m_ae = nn.Linear(feat_size, feat_size*2)
        self.g_m_ae = nn.Linear(feat_size, 1)
        
        self.f_ae = nn.Linear(feat_size*2, 1)
        
        #f_s
        
        self.f_e_s = nn.ModuleList([nn.Linear(3*feat_size, 3*feat_size*2) for _ in range(self.T)])
        self.f_n_s = nn.ModuleList([nn.GRUCell(3*feat_size*2, feat_size) for _ in range(self.T)])
        
        self.f_s = nn.Linear(feat_size*2, 1+len(self.parts[1]))
        
        
    def forward(self, hdf, start, end):
        #adj_mat = torch.from_numpy(hdf['adjacent'][start:end])[:,0,:]
        node_index = torch.from_numpy(hdf['initmol'][start:end])[:,0].view(-1,1).long()
        #edge_index = torch.from_numpy(hdf['bond'][start:end])[:,0,:].long()
        
        if use_cuda:
            #adj_mat = adj_mat.cuda()
            node_index = node_index.cuda()
            #edge_index = edge_index.cuda()
            
            
        emb_node = self.E_atom(node_index)
        edge_index = torch.zeros(len(emb_node), 1, self.N_Max_atom)
        emb_edge = torch.zeros(len(emb_node), 1, self.N_Max_atom, self.feat_size)
        adj_mat = torch.zeros(len(emb_node), 1, self.N_Max_atom)
        h_v = torch.randn(emb_node.size())
        
        h_v = self.h_v_init(h_v, emb_node, self.f_m_init, self.g_m_init, self.f_init)

        h_G = self.propagation(h_v, emb_edge, adj_mat, 
                               self.f_e_an, self.f_n_an, self.f_m_an, self.g_m_an)
        p_addnode = F.softmax(self.f_an(h_G),dim=-1)
        
        DC = Categorical(p_addnode)
        v_new_index = DC.sample()
        node_index = torch.log(node_index.float())
        node_index = torch.cat((node_index,DC.log_prob(v_new_index).view(-1,1)),1)
        check_node = (v_new_index.view(-1,1) != 0).float()
        
        position = 1
        while torch.sum(check_node) != 0 and position < self.N_Max_atom -1:
            #print('atom position')
            #print(position)
            emb_new_node = self.E_atom(v_new_index.view(-1,1))
            v_new = self.h_v_init(h_v,emb_new_node,self.f_m_init,self.g_m_init,self.f_init)
            
            h_v = torch.cat((h_v,v_new),1)
            
            emb_edge = torch.cat((emb_edge, torch.zeros(emb_edge.size(0),1,emb_edge.size(2),emb_edge.size(3))),1)
            edge_index = torch.cat((edge_index, torch.zeros(len(h_v), 1, self.N_Max_atom)), 1)
            adj_mat = torch.cat((adj_mat, torch.zeros(adj_mat.size(0), 1, adj_mat.size(2))), 1)
            
            h_G = self.propagation(h_v,emb_edge,adj_mat,self.f_e_ae,self.f_n_ae,self.f_m_ae,self.g_m_ae)
            
            p_addedge = F.softmax(self.f_ae(h_G),dim=-1)
            DB = Bernoulli(p_addedge)
            z_t = DB.sample().view(-1,1)
            check_edge = torch.mul(z_t, check_node)            
            check_edge_s = torch.ones(h_v.size(0),h_v.size(1),1)
            limit = 15
            count = 0
            while torch.sum(check_edge) != 0 and torch.sum(check_edge_s) != 0 and count<limit:
                h_u_T = self.prop_only(h_v,emb_edge,adj_mat,self.f_e_s,self.f_n_s)
                h_v_T = torch.cat([h_u_T[:,-1,:].view(len(h_u_T),1,-1)for _ in range(h_u_T.size(1))],1)
                
                p_nodes = F.softmax(self.f_s(torch.cat((h_u_T,h_v_T),2).view(-1,h_u_T.size(2)*2)),dim=1).view(h_u_T.size(0), h_u_T.size(1), -1)
                DC_edge = Categorical(p_nodes)
                edge_new_index = DC_edge.sample().view(h_u_T.size(0), h_u_T.size(1), -1).float()
                check_edge_s_cdt = (edge_new_index == 0).float().view(h_u_T.size(0), h_u_T.size(1), -1)
                edge_new_index =  torch.mul(edge_new_index, check_edge.view(-1, 1, 1))
                edge_new_index =  torch.mul(edge_new_index, check_edge_s)
                
                edge_index = edge_index.clone()
                
                edge_index[:,:,position] = edge_new_index.view(-1,edge_new_index.size(1))
                edge_index[:,position,:position+1] = edge_new_index.view(-1,edge_new_index.size(1))
                emb_edge = self.E_bond(edge_index.long())

                adj_mat = adj_mat.clone()
                adj_sub = edge_new_index != 0
                adj_mat[:,:,position] = adj_sub.long().view(-1,adj_sub.size(1))
                adj_mat[:,position,:position+1] = adj_sub.long().view(-1,adj_sub.size(1))
                

                h_G = self.propagation(h_v,emb_edge,adj_mat,self.f_e_ae,self.f_n_ae,self.f_m_ae,self.g_m_ae)

                p_addedge = F.softmax(self.f_ae(h_G),dim=-1)
                z_t = Bernoulli(p_addedge).sample().view(-1,1)
                check_edge = torch.mul(z_t, check_edge)
                check_edge_s = torch.mul(check_edge_s, check_edge_s_cdt)
                #print('check_edge')
                #print(torch.sum(check_edge))
                #print('check_edge_s')
                #print(torch.sum(check_edge_s))
                count +=1
            position += 1
            h_G = self.propagation(h_v,emb_edge,adj_mat,self.f_e_ae,self.f_n_ae,self.f_m_ae,self.g_m_ae)
            p_addnode = F.softmax(self.f_an(h_G),dim=-1)
            DC = Categorical(p_addnode)
            v_new_index = DC.sample()
            node_index = torch.cat((node_index,DC.log_prob(v_new_index).view(-1,1)),1)
            check_node = (v_new_index.view(-1,1) != 0).float()
            #print('sum of check_node')
            #print(torch.sum(check_node))
        if node_index.size(1) < self.N_Max_atom:
            adjust =torch.zeros(node_index.size(0),self.N_Max_atom - node_index.size(1))
            node_index = torch.cat((node_index.float(),adjust),1)
        else:
            pass
        return node_index.float(), edge_index
    
    def propagation(self, h_v, edge, adj, f_e, f_n, f_m, g_m):
        for T in range(self.T):
            for t in range(len(h_v[0])):
                if t==0:
                    neighbor = torch.mul(h_v, adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    watch = torch.mul(h_v[:,t,:].view(-1,1,h_v.size(2)), adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    m_u_v = f_e[T](torch.cat((neighbor,watch,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_u_v = torch.sum(m_u_v.view(neighbor.size(0), neighbor.size(1),-1),1,keepdim=True)
                    m_v_u = f_e[T](torch.cat((watch,neighbor,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_v_u = torch.sum(m_v_u.view(neighbor.size(0), neighbor.size(1), -1),1,keepdim=True)
                    a_v = m_u_v + m_v_u
                    
                else:
                    neighbor = torch.mul(h_v, adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    watch = torch.mul(h_v[:,t,:].view(-1,1,h_v.size(2)), adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    m_u_v = f_e[T](torch.cat((neighbor,watch,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_u_v = torch.sum(m_u_v.view(neighbor.size(0), neighbor.size(1),-1),1,keepdim=True)
                    m_v_u = f_e[T](torch.cat((watch,neighbor,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_v_u = torch.sum(m_v_u.view(neighbor.size(0), neighbor.size(1), -1),1,keepdim=True)
                    a_t = m_u_v + m_v_u
                    a_v = torch.cat((a_v,a_t),1) 
            
            h_v = f_n[T](a_v.view(-1,a_v.size(2)), h_v.view(-1, h_v.size(2))).view(h_v.size(0), h_v.size(1), -1)
        
        h_v_g = f_m(h_v.view(h_v.size(0)* h_v.size(1), -1))
        g_v = F.sigmoid(g_m(h_v)).view(-1,1)
        h_G = torch.sum(torch.mul(g_v,h_v_g).view(h_v.size(0),h_v.size(1),-1),1,keepdim=False)
        
        return h_G
    
    def prop_only(self, h_v, edge, adj, f_e, f_n):
        for T in range(self.T):
            for t in range(len(h_v[0])):
                if t==0:
                    neighbor = torch.mul(h_v, adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    watch = torch.mul(h_v[:,t,:].view(-1,1,h_v.size(2)), adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    m_u_v = f_e[T](torch.cat((neighbor,watch,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_u_v = torch.sum(m_u_v.view(neighbor.size(0), neighbor.size(1),-1),1,keepdim=True)
                    m_v_u = f_e[T](torch.cat((watch,neighbor,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_v_u = torch.sum(m_v_u.view(neighbor.size(0), neighbor.size(1), -1),1,keepdim=True)
                    a_v = m_u_v + m_v_u
                    
                else:
                    neighbor = torch.mul(h_v, adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    watch = torch.mul(h_v[:,t,:].view(-1,1,h_v.size(2)), adj[:,t,:h_v.size(1)].view(-1, h_v.size(1), 1))
                    m_u_v = f_e[T](torch.cat((neighbor,watch,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_u_v = torch.sum(m_u_v.view(neighbor.size(0), neighbor.size(1),-1),1,keepdim=True)
                    m_v_u = f_e[T](torch.cat((watch,neighbor,edge[:,t,:h_v.size(1)]),2).view(neighbor.size(0)*neighbor.size(1), -1))
                    m_v_u = torch.sum(m_v_u.view(neighbor.size(0), neighbor.size(1), -1),1,keepdim=True)
                    a_t = m_u_v + m_v_u
                    a_v = torch.cat((a_v,a_t),1) 
            
            h_v = f_n[T](a_v.view(-1,a_v.size(2)), h_v.view(-1, h_v.size(2))).view(h_v.size(0), h_v.size(1), -1)
            
        return h_v
    
    def h_v_init(self, h_v,e, f_m, g_m, f_init):
        h_v_g_init = f_m(h_v.view(-1 , h_v.size(2)))
        g_v_init = F.sigmoid(g_m(h_v)).view(-1, 1)
        h_G_init = torch.sum(torch.mul(g_v_init, h_v_g_init).view(h_v.size(0), h_v.size(1), -1) ,1, keepdim=True)
        
        h_v_new = f_init(torch.cat((e, h_G_init), 2).view(-1, e.size(2) + h_G_init.size(2))).view(h_v.size(0), 1, -1)
        
        return h_v_new
    def he_init(self, m):
        if type(m) == nn.Linear:
            out_features = m.weight.size(0)
            in_features = m.weight.size(1)
            m.weight.data = torch.randn(out_features, in_features)* (2 * torch.rsqrt(tensor([out_features]).float() ))

In [3]:
class Featurizer(object):
    def __init__(self, N_Max_atom, parts):
        super(Featurizer, self).__init__()
        self.N_Max_atom = N_Max_atom
        self.parts = parts
    
    def __call__(self, mols, filepass):
        
        self.size = mol_num(mols)
        
        self.f =  h5py.File(filepass, "w")
        
        self.f.create_dataset('adjacent',shape=(self.size, self.N_Max_atom, self.N_Max_atom))
        self.f.create_dataset('atom_pair',shape=(self.size, self.N_Max_atom, self.N_Max_atom))
        self.f.create_dataset('bond',shape=(self.size, self.N_Max_atom, self.N_Max_atom))
        self.f.create_dataset('initmol',shape=(self.size, self.N_Max_atom))
        
        print('...Featurizing init_mol')
        self.init_mol(mols)
        print('...Featurizing adjlist')
        self.adjlist(mols)
        print('...Featurizing atom_pair')
        self.atom_pair(mols)
        print('...Featurizing bond_type')
        self.bond_type(mols)
        
        self.f.flush()
        self.f.close()
        
        print('Done')
        
    def init_mol(self, input_mols):
        ids = self.f['initmol']
        for index_mol, mol in enumerate(input_mols):
            if mol.GetNumAtoms() <= self.N_Max_atom:
                for index_atom , atom in enumerate(mol.GetAtoms()):
                    ids[index_mol, index_atom] = (self.parts[0])[atom.GetSymbol()]
            else:
                pass

    def adjlist(self, input_mols):
        adj = self.f['adjacent']
        for index_mol, mol in enumerate(input_mols):
            if mol.GetNumAtoms() <= self.N_Max_atom:
                for index_atom , atom in enumerate(mol.GetAtoms()):
                    for atom_neighbor in atom.GetNeighbors():
                        adj[index_mol, index_atom, int(atom_neighbor.GetIdx())] = 1
            else:
                pass
    
    def atom_pair(self, mols):
        ids = self.f['atom_pair']
        for index_mol, mol in enumerate(mols):
            if mol.GetNumAtoms() <= self.N_Max_atom:
                for index_atom, atom in enumerate(mol.GetAtoms()):
                    watch_atom = atom.GetSymbol()
                    for index_n, n in enumerate(atom.GetNeighbors()):
                        neighbor = n.GetSymbol()
                        pair = sorted([watch_atom, neighbor])
                        pair = '-'.join(pair)
                        ids[index_mol, index_atom, int(n.GetIdx())] = (self.parts[2])[pair]
        else:
            pass
        
    def bond_type(self, mols):
        ids = self.f['bond']
        for index_mol, mol in enumerate(mols):
            if mol.GetNumAtoms() <= self.N_Max_atom:
                for index_bond, bond in enumerate(mol.GetBonds()):
                        ids[index_mol, bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] = (self.parts[1])[str(bond.GetBondType())]
        else:
            pass
        
        

In [4]:
def roc_score(label, prob):
    y_true = label.tolist()
    y_scores = prob.tolist()
    auc = roc_auc_score(y_true, y_scores)
    return auc
def isnan(tensor):
    return tensor != tensor
def mol_num(mols):
    num = sum([int(mol.GetNumAtoms()) <= N_Max_atom for mol in mols if mol is not None])
    return num
def isnan(tensor):
    return tensor != tensor
def nv(arr):
    v = Variable(torch.from_numpy(arr).float())
    if use_cuda:
        v = v.cuda()
    return v
def calc_lr(epoch_num):
    return 0.000000001 * 1.3 ** epoch_num

In [5]:
def evaluate(classifier, hdf, target):
    num = len(hdf['initmol']) // batch_size
    for i in range(0, num):
        start = i*batch_size
        end = i*batch_size + batch_size
        if i == 0:
            valid_all_hat ,_ = classifier(hdf, start, end)
            valid_all_hat = valid_all_hat.data
        else:
            valid_y_hat ,_ = classifier(hdf, start, end)
            valid_all_hat = torch.cat((valid_all_hat, valid_y_hat.data), 0) 
    valid_y_hat ,_  = classifier(hdf, end, len(hdf['initmol']))
    valid_all_hat = torch.cat((valid_all_hat, valid_y_hat.data),0)
    
    loss = criterion(valid_all_hat, target)
    if use_cuda:
        loss = loss.cpu()
    
    return loss

In [6]:
def atom_set(mols):
    all_set = set()
    for mol in mols:
        for atom in mol.GetAtoms():
            watch_atom = atom.GetSymbol()
            all_set.add(watch_atom)
    
    alllist = sorted(list(all_set))
    index = [i for i in range(1, len(alllist) + 1)]
    pair_dic = dict(zip(alllist, index))
    return pair_dic

def pair_set(mols):
    all_set = set()
    for mol in mols:
        for atom in mol.GetAtoms():
            watch_atom = atom.GetSymbol()
            all_set.add(watch_atom)
            for n in atom.GetNeighbors():
                neighbor = n.GetSymbol()
                pair = sorted([watch_atom, neighbor])
                pair = '-'.join(pair)
                all_set.add(pair)
    
    alllist = sorted(list(all_set))
    index = [i for i in range(1, len(alllist) + 1)]
    pair_dic = dict(zip(alllist, index))
    return pair_dic

def bond_set(mols):
    all_set = set()
    for mol in mols:
        for bond in mol.GetBonds():
            all_set.add(str(bond.GetBondType()))
            
    alllist = sorted(list(all_set))
    index = [i for i in range(1, len(alllist) + 1)]
    pair_dic = dict(zip(alllist, index))
    return pair_dic
def make_neighborlist(input_mols):
    adj_list = []
    
    for index_mol, mol in enumerate(input_mols):
        mol_adj_list = []
        for index_atom , atom in enumerate(mol.GetAtoms()):            
            atom_adj_list = []
            for atom_neighbor in atom.GetNeighbors():
                atom_adj_list.append(int(atom_neighbor.GetIdx()))
            mol_adj_list.append(atom_adj_list)
        adj_list.append(mol_adj_list)
        
    adj_list = np.array(adj_list)
    
    return adj_list

In [7]:
N_Max_atom = 15

tox21_tasks, tox21_datasets, transformers = load_tox21(featurizer='Raw', split='random')
(train_dataset, valid_dataset, test_dataset) = tox21_datasets

train_x = list(train_dataset.X)
train_x = [Chem.AddHs(mol) for mol in train_x if mol is not None]
train_x = [mol for mol in train_x if mol is not None and mol.GetNumAtoms() <= N_Max_atom]

train_label = nv(train_dataset.y)

valid_x = list(valid_dataset.X)
valid_x = [Chem.AddHs(mol) for mol in valid_x if mol is not None]
valid_x = [mol for mol in valid_x if mol is not None and mol.GetNumAtoms() <= N_Max_atom]

valid_label = nv(valid_dataset.y)

test_x = list(test_dataset.X)
test_x = [Chem.AddHs(mol) for mol in test_x if mol is not None]
test_x = [mol for mol in test_x if mol is not None and mol.GetNumAtoms() <= N_Max_atom]

test_label = nv(test_dataset.y)

Loading dataset from disk.
Loading dataset from disk.
Loading dataset from disk.


In [8]:
all_mols = train_x + valid_x + test_x

atom = atom_set(all_mols)
bond = bond_set(all_mols)
pair = pair_set(all_mols)
parts = (atom, bond, pair)

neighborlist = make_neighborlist(all_mols)
atom_degree = [[len(l) for l in l] for l in neighborlist]
atom_degree_flatten = list(chain.from_iterable(atom_degree))
max_degree = max(atom_degree_flatten)
min_degree = min(atom_degree_flatten)

In [9]:
featurizer= Featurizer(N_Max_atom, parts)
if not os.path.exists('./train.hdf5'):
    print('Train set')
    featurizer(train_x,'./train.hdf5')
if not os.path.exists('./valid.hdf5'):
    print('Valid set')
    featurizer(valid_x,'./valid.hdf5')
if not os.path.exists('./test.hdf5'):
    print('Test set')
    featurizer(test_x,'./test.hdf5')

In [18]:
#Settings

batch_size = 8

#Model

classifier = Model(128, N_Max_atom, parts, 2)
torch.save(classifier.state_dict(), './init.pth') 
finder = Model(128, N_Max_atom, parts, 2)
finder.load_state_dict(torch.load('./init.pth'))

#Loss function
criterion = nn.KLDivLoss()

In [19]:
#Verify finder's weight is the same as classifier's weight
classifier.f_ae.weight == finder.f_ae.weight

tensor([[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1, 

Find Learning rate(below cell)

In [20]:
##Optimal learning rates

#Log

find_logs = []

#Plotting

source_find = ColumnDataSource(data=dict(epoch=[], loss=[], lr=[]))
settings = dict(plot_width=480, plot_height=430, min_border=0)
pf = figure(title="Optimal learning rates", x_axis_label="epoch",y_axis_label="Loss", **settings)
pf.line(x='epoch', y='loss', source=source_find)

pf2 = figure(title="Corresponding Learning rate", x_axis_label="epoch", y_axis_label="Learning rate", **settings)
pf2.line(x='epoch', y='lr', source=source_find, color="orange")

tf = show(column(pf,pf2), notebook_handle=True)

#Config
train = h5py.File('./train.hdf5', 'r')
finder.train()
finder_x = train_x[:batch_size]
batch_num = len(finder_x) // batch_size

optimizer_find = optim.Adam(finder.parameters(), lr=calc_lr(50))
scheduler = optim.lr_scheduler.ExponentialLR(optimizer_find, 1.3, last_epoch=-1)

for epoch in tqdm(range(50,80)):
    
    #Training
    
    for batch in tqdm(range(0,batch_num), disable=True):
        
        start = batch*batch_size
        end = batch*batch_size + batch_size
        
        target = torch.from_numpy(train['initmol'][start:end]).view(-1, N_Max_atom).float()
        if use_cuda:
            target = target.cuda()
        target = torch.log(target)
        y_hat, edge = finder(train, start, end)
        
        loss = criterion(y_hat, target)
        loss.backward()
        optimizer_find.step()
        finder.zero_grad()
    
    
    #Logiging 
    if use_cuda:
        loss = loss.cpu()
    new_data_find = {'epoch' : [epoch], 'loss' : [loss.detach().numpy()], 'lr' : [calc_lr(epoch)]}
    new_data_find_df = {'epoch' : epoch, 'loss' : loss.detach().numpy(), 'lr' : calc_lr(epoch)}
    source_find.stream(new_data_find)
    
    find_logs.append(new_data_find_df)
    df_find = pd.DataFrame(find_logs)
    df_find.to_csv("./find_logs.csv", index=False)
    
    scheduler.step()
    
    #Show plots
    
    push_notebook(handle=tf)

clear_output()

find_logs = pd.read_csv('./find_logs.csv')
source_find = ColumnDataSource(find_logs)

tf = show(column(pf,pf2))

print('Done')

Done


Training(below cell)

In [21]:
#Log

train_logs = []
valid_logs = []

#Plotting

source_train = ColumnDataSource(data=dict(epoch=[], loss=[]))
source_test = ColumnDataSource(data=dict(epoch=[], loss=[]))
settings = dict(plot_width=480, plot_height=430, min_border=0)
p = figure(title="KLDiv", x_axis_label="epoch",y_axis_label="Loss", **settings)
p.line(x='epoch', y='loss', source=source_train)
p.line(x='epoch', y='loss', source=source_test, color="orange")

t = show(p, notebook_handle=True)

#Config
train = h5py.File('./train.hdf5', 'r')
valid = h5py.File('./valid.hdf5', 'r')


batch_num = len(train_x) // batch_size

optimizer = optim.Adam(classifier.parameters(), lr=calc_lr(65))

for epoch in tqdm(range(0,100)):
    
    #Training
    
    for batch in tqdm(range(0,batch_num)):
        
        start = batch*batch_size
        end = batch*batch_size + batch_size
        
        target = torch.from_numpy(train['initmol'][start:end]).view(-1, N_Max_atom).float()
        if use_cuda:
            target = target.cuda()
        classifier.train()
        y_hat, edge = classifier(train, start, end)
        
        loss = criterion(y_hat, target)
        loss.backward()
        optimizer.step()
        classifier.zero_grad()
    
    #Train set Evaluation
    
    if use_cuda:
        loss = loss.cpu()
    new_data_train = {'epoch' : [epoch], 'loss' : [loss.detach().numpy()]}
    new_data_train_df = {'epoch' : epoch, 'loss' : loss.detach().numpy()}
    source_train.stream(new_data_train)
    
    train_logs.append(new_data_train_df)
    df_train = pd.DataFrame(train_logs)
    df_train.to_csv("./train_logs.csv", index=False)
    
    #Valid set Evaluation
    
    classifier.eval()
    v_target = torch.from_numpy(valid['initmol'].value).view(-1, N_Max_atom).float()
    if use_cuda:
        v_target = v_target.cuda()
    valid_y_hat, valid_edge = classifier(valid, 0, len(v_target))
    valid_loss = criterion(valid_y_hat, v_target)
    if use_cuda:
        valid_loss = valid_loss.cpu()
        
    valid_new_data = {'epoch' : [epoch], 'loss' : [valid_loss.detach().numpy()]}
    valid_new_data_df = {'epoch' : epoch, 'loss' : valid_loss.detach().numpy()}
    source_test.stream(valid_new_data)
        
    valid_logs.append(valid_new_data_df)
    df_test = pd.DataFrame(valid_logs)
    df_test.to_csv("./valid_logs.csv", index=False)
    
    #Show plots
    
    push_notebook(handle=t)
    
    #Save model
    
    torch.save(classifier.state_dict(), 'epoch{0}.pth' .format(epoch))
        
clear_output()

train_logs = pd.read_csv('./train_logs.csv')
test_logs = pd.read_csv('./valid_logs.csv')

source_train = ColumnDataSource(train_logs)
source_test = ColumnDataSource(test_logs)

t = show(p, notebook_handle=True)

train.flush()
train.close()
valid.flush()
valid.close()
print("Done")

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=86), HTML(value='')))

HBox(children=(IntProgress(value=0, max=86), HTML(value='')))

HBox(children=(IntProgress(value=0, max=86), HTML(value='')))

HBox(children=(IntProgress(value=0, max=86), HTML(value='')))

HBox(children=(IntProgress(value=0, max=86), HTML(value='')))

HBox(children=(IntProgress(value=0, max=86), HTML(value='')))

RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1524586445097/work/aten/src/THC/generated/../THCReduceAll.cuh:339

Show Plot(below cell)

In [22]:
train_logs = pd.read_csv('./train_logs.csv')
test_logs = pd.read_csv('./valid_logs.csv')

source_train = ColumnDataSource(train_logs)
source_test = ColumnDataSource(test_logs)
settings = dict(plot_width=480, plot_height=430, min_border=0)
p = figure(title="Binary Cross Entropy Loss with Sigmoid Layer", x_axis_label="epoch",y_axis_label="Loss", **settings)
p.line(x='epoch', y='loss', source=source_train, legend="Train")
p.line(x='epoch', y='loss', source=source_test, legend="Valid", color="orange")

draw = show(p, notebook_handle=True)