# Lab 11. logP value prediction from molecular graph with GCN

이번 실습에서는 molecular graph로부터 분자의 특성 중 하나인 logP value를 GCN을 통해 예측해보겠습니다. 이번 실습은 다음과 같은 내용을 포함합니다.

- Custom pytorch dataset을 정의하여 여러 input 또는 여러 output이 존재하는 경우를 다뤄보기
- Graph convolution을 pytorch로 구현하기
- (Gated) skip connection을 pytorch로 구현하기

# 0. Install Rdkit

분자를 molecular graph 형태로 만들어 주고, 분자의 logP 값을 알려주는 rdkit을 설치합니다.

In [4]:
!curl -LO  https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local

import sys
sys.path.append('/usr/local/lib/python3.6/site-packages/')

!conda install -y -c rdkit rdkit 

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 89.8M  100 89.8M    0     0   119M      0 --:--:-- --:--:-- --:--:--  119M
PREFIX=/usr/local
Unpacking payload ...
Collecting package metadata (current_repodata.json): - \ | / done
Solving environment: \ | done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - _libgcc_mutex==0.1=main
    - brotlipy==0.7.0=py38h27cfd23_1003
    - ca-certificates==2020.10.14=0
    - certifi==2020.6.20=pyhd3eb1b0_3
    - cffi==1.14.3=py38h261ae71_2
    - chardet==3.0.4=py38h06a4308_1003
    - conda-package-handling==1.7.2=py38h03888b9_0
    - conda==4.9.2=py38h06a4308_0
    - cryptography==3.2.1=py38h3c74f83_1
    - idna==2.10=py_0
    - ld_impl_linux-64==2.33.1=h53a641e_7
    - libedit==3.1.20191231=h14c3975_1
    - li

In [5]:
import sys
import os
import requests
import subprocess
import shutil
from logging import getLogger, StreamHandler, INFO


logger = getLogger(__name__)
logger.addHandler(StreamHandler())
logger.setLevel(INFO)


def install(
        chunk_size=4096,
        file_name="Miniconda3-latest-Linux-x86_64.sh",
        url_base="https://repo.continuum.io/miniconda/",
        conda_path=os.path.expanduser(os.path.join("~", "miniconda")),
        rdkit_version=None,
        add_python_path=True,
        force=False):
    """install rdkit from miniconda
    ```
    import rdkit_installer
    rdkit_installer.install()
    ```
    """

    python_path = os.path.join(
        conda_path,
        "lib",
        "python{0}.{1}".format(*sys.version_info),
        "site-packages",
    )

    if add_python_path and python_path not in sys.path:
        logger.info("add {} to PYTHONPATH".format(python_path))
        sys.path.append(python_path)

    if os.path.isdir(os.path.join(python_path, "rdkit")):
        logger.info("rdkit is already installed")
        if not force:
            return

        logger.info("force re-install")

    url = url_base + file_name
    python_version = "{0}.{1}.{2}".format(*sys.version_info)

    logger.info("python version: {}".format(python_version))

    if os.path.isdir(conda_path):
        logger.warning("remove current miniconda")
        shutil.rmtree(conda_path)
    elif os.path.isfile(conda_path):
        logger.warning("remove {}".format(conda_path))
        os.remove(conda_path)

    logger.info('fetching installer from {}'.format(url))
    res = requests.get(url, stream=True)
    res.raise_for_status()
    with open(file_name, 'wb') as f:
        for chunk in res.iter_content(chunk_size):
            f.write(chunk)
    logger.info('done')

    logger.info('installing miniconda to {}'.format(conda_path))
    subprocess.check_call(["bash", file_name, "-b", "-p", conda_path])
    logger.info('done')

    logger.info("installing rdkit")
    subprocess.check_call([
        os.path.join(conda_path, "bin", "conda"),
        "install",
        "--yes",
        "-c", "rdkit",
        "python=={}".format(python_version),
        "rdkit" if rdkit_version is None else "rdkit=={}".format(rdkit_version)])
    logger.info("done")

    import rdkit
    logger.info("rdkit-{} installation finished!".format(rdkit.__version__))


if __name__ == "__main__":
  install()

add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH
python version: 3.6.9
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing rdkit
done
rdkit-2020.09.1 installation finished!


분자를 텍스트 형태로 표현한 smiles 파일과 molecular graph를 생성하는데 필요한 vocab.npy 파일을 받습니다.

In [213]:
!curl -o ZINC.smiles https://raw.githubusercontent.com/heartcored98/Standalone-DeepLearning/master/Lec9/ZINC.smiles
!curl -o vocab.npy https://raw.githubusercontent.com/heartcored98/Standalone-DeepLearning/master/Lec9/vocab.npy

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 5374k  100 5374k    0     0  8988k      0 --:--:-- --:--:-- --:--:-- 8988k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   256  100   256    0     0   1340      0 --:--:-- --:--:-- --:--:--  1333


In [258]:
import argparse
import sys
import time
import copy

import numpy as np

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from tqdm import tnrange, tqdm_notebook
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [259]:
paser = argparse.ArgumentParser()
args = paser.parse_args("")
args.seed = 123
args.val_size = 0.1
args.test_size = 0.1
args.shuffle = True

In [260]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [261]:
np.random.seed(args.seed)
torch.manual_seed(args.seed)

if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# 1. Pre-Processing

ZINC.smiles 파일에 text로 표현되어 있는 분자들을 molecular graph 형태로 바꿔줍시다. 이때 node feature matrix는 아래 그림과 같이 각 원자의 symbol, degree 등 화학적 특성을 one-hot vector로 나타낸 형태입니다.
![node feature matrix](https://github.com/SeungsuKim/CH485--AI-and-Chemistry/raw/c85ce8716ac2e351d730543a2d45fd7054014d4f/Assignments/5.%20GCN/Graph_Generating_Process.png)

`read_ZINC_smiles` 함수는 smiles 파일 내의 분자 텍스트의 list와 각 분자들의 실제 logP value list를 return합니다.
`convert_to_graph` 함수는 분자 텍스트의 list를 받아 각 분자들의 `node feature matrix list`와 `adjacency matrix list`를 return 합니다.

In [285]:
def read_ZINC_smiles(file_name, num_mol):
    # f = open(file_name, 'r')
    # contents = f.readlines()

    id_target = np.array(pd.read_excel(file_name))

    smi_list = []
    logP_list = []
    values_list = []

    
    # for i in tqdm_notebook(range(num_mol), desc='Reading Data'):
        # smi = contents[i].strip()
        # m = Chem.MolFromSmiles(smi)
        # smi_list.append(smi)
        # values_list.append(MolLogP(m))

    for i in tqdm_notebook(range(id_target.shape[0]), desc='Reading Data'):
    

        smi = id_target[i, 0]
        smi_list.append(smi)

        value = id_target[i, 1]
        values_list.append(value)

    values_list = np.asarray(values_list).astype(float)

    return smi_list, values_list


In [288]:
DataPath = "/content/drive/MyDrive/0_한동생활/2020-2.5 겨울방학/2_GNN/data/LOHC_MLproject_POSTECH/Dehydrogenation3.xlsx"
# DataPath = "ZINC.smiles"
list_smi, list_logP = read_ZINC_smiles(DataPath, 1465)


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Reading Data', max=1465.0, style=ProgressStyle(descriptio…




In [264]:
vocab = np.load('./vocab.npy')
print("shape", vocab.shape)
print(vocab)
LOHC_vocab = np.array(['O', 'C', '3', '(', '2', '#', 'N', ')', 'c', '1', 'o', '=', '4', '5', 's', 'n', 'l', 'P', 'S', '[', 'H', ']', 'F', 'I', 'B', 'r', '/', '\\', '6', '7', '8'])
print(LOHC_vocab)
Sumup_vocab = np.unique(np.concatenate((vocab, LOHC_vocab), axis=0))
print(Sumup_vocab)
print(list(Sumup_vocab))
# [' ' '#' '(' ')' '+' '-' '/' '1' '2' '3' '4' '5' '6' '7' '8' '=' '@' 'B'
#  'C' 'F' 'H' 'I' 'N' 'O' 'P' 'S' '[' '\\' ']' 'c' 'l' 'n' 'o' 'r' 's']

shape (32,)
[' ' 'c' 'C' '(' ')' 'O' '1' '2' 'N' '=' '[' ']' '@' '3' 'H' 'n' '4' 'F'
 '+' 'S' 'l' 's' '/' 'o' '-' '5' '#' 'B' 'r' '\\' '6' 'I']
['O' 'C' '3' '(' '2' '#' 'N' ')' 'c' '1' 'o' '=' '4' '5' 's' 'n' 'l' 'P'
 'S' '[' 'H' ']' 'F' 'I' 'B' 'r' '/' '\\' '6' '7' '8']
[' ' '#' '(' ')' '+' '-' '/' '1' '2' '3' '4' '5' '6' '7' '8' '=' '@' 'B'
 'C' 'F' 'H' 'I' 'N' 'O' 'P' 'S' '[' '\\' ']' 'c' 'l' 'n' 'o' 'r' 's']
[' ', '#', '(', ')', '+', '-', '/', '1', '2', '3', '4', '5', '6', '7', '8', '=', '@', 'B', 'C', 'F', 'H', 'I', 'N', 'O', 'P', 'S', '[', '\\', ']', 'c', 'l', 'n', 'o', 'r', 's']


In [289]:
def smiles_to_onehot(smi_list):
    def smiles_to_vector(smiles, vocab, max_length):
        while len(smiles) < max_length:
            smiles += " "
        vector = [vocab.index(str(x)) for x in smiles]
        one_hot = np.zeros((len(vocab), max_length), dtype=int)
        for i, elm in enumerate(vector):
            one_hot[elm][i] = 1
        return one_hot

    # vocab = np.load('./vocab.npy')
    vocab = [' ', '#', '(', ')', '+', '-', '/', '1', '2', '3', '4', '5', '6', '7', '8', '=', '@', 'B', 'C', 'F', 'H', 'I', 'N', 'O', 'P', 'S', '[', '\\', ']', 'c', 'l', 'n', 'o', 'r', 's']
    
    smi_total = []

    for i, smi in tqdm_notebook(enumerate(smi_list), desc='Converting to One Hot'):
        smi_onehot = smiles_to_vector(smi, list(vocab), 120)
        smi_total.append(smi_onehot)

    return np.asarray(smi_total)

def convert_to_graph(smiles_list):
    adj = []
    adj_norm = []
    features = []
    maxNumAtoms = 55 # 원래 30인가인데 55로 바꿈, 우리 가지고 있는 데이터에 맞게끔 
    for i in tqdm_notebook(smiles_list, desc='Converting to Graph'):
        # Mol
        iMol = Chem.MolFromSmiles(i.strip())
        #Adj
        iAdjTmp = Chem.rdmolops.GetAdjacencyMatrix(iMol)
        # Feature
        if( iAdjTmp.shape[0] <= maxNumAtoms):
            # Feature-preprocessing
            iFeature = np.zeros((maxNumAtoms, 58))
            iFeatureTmp = []
            for atom in iMol.GetAtoms():
                iFeatureTmp.append( atom_feature(atom) ) ### atom features only
            iFeature[0:len(iFeatureTmp), 0:58] = iFeatureTmp ### 0 padding for feature-set
            features.append(iFeature)

            # Adj-preprocessing
            iAdj = np.zeros((maxNumAtoms, maxNumAtoms))
            iAdj[0:len(iFeatureTmp), 0:len(iFeatureTmp)] = iAdjTmp + np.eye(len(iFeatureTmp))
            adj.append(np.asarray(iAdj))
    features = np.asarray(features)

    return features, adj
    
def atom_feature(atom):
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                      ['C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
                                       'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
                                       'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
                                       'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb']) +
                    one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]) +
                    [atom.GetIsAromatic()])    # (40, 6, 5, 6, 1)

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1] # 초과하는 경우를 대비 | 초과하는 경우 범위 내에서 제일 큰 값을 지정함 
    return list(map(lambda s: x == s, allowable_set))

In [290]:
# DataPath = "/content/drive/MyDrive/0_한동생활/2020-2.5 겨울방학/2_GNN/data/LOHC_MLproject_POSTECH/Dehydrogenation3.xlsx"

# list_smi, list_logP = read_ZINC_smiles(DataPath, 1465)
list_feature, list_adj = convert_to_graph(list_smi)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Converting to Graph', max=1465.0, style=ProgressStyle(des…




위 코드를 통해 원하는 수 만큼의 분자들의 node feature matrix list인 `list_feature`, adjacency matrix list인 `list_adj`, 그리고 logP value list인 `list_logP`를 얻었습니다.

그 동안의 실습에서는 이미지로부터 이미지의 label을 얻는, 즉 하나의 input에서 하나의 output을 얻는 형태였지만, 이번 실습에서는 두 개의 input, `list_feature`와 `list_adj`로부터 logP value라는 하나의 output을 얻어내야합니다. 이를 위해 custom pytorch dataset을 정의하고 사용해봅시다.

In [291]:
class GCNDataset(Dataset):
    def __init__(self, list_feature, list_adj, list_logP):
        self.list_feature = list_feature
        self.list_adj = list_adj
        self.list_logP = list_logP

    def __len__(self):
        return len(self.list_feature)

    def __getitem__(self, index):
        return self.list_feature[index], self.list_adj[index], self.list_logP[index]


def partition(list_feature, list_adj, list_logP, args):
    num_total = list_feature.shape[0]
    num_train = int(num_total * (1 - args.test_size - args.val_size))
    num_val = int(num_total * args.val_size)
    num_test = int(num_total * args.test_size)

    feature_train = list_feature[:num_train]
    adj_train = list_adj[:num_train]
    logP_train = list_logP[:num_train]
    feature_val = list_feature[num_train:num_train + num_val]
    adj_val = list_adj[num_train:num_train + num_val]
    logP_val = list_logP[num_train:num_train + num_val]
    feature_test = list_feature[num_total - num_test:]
    adj_test = list_adj[num_total - num_test:]
    logP_test = list_logP[num_total - num_test:]
        
    train_set = GCNDataset(feature_train, adj_train, logP_train)
    val_set = GCNDataset(feature_val, adj_val, logP_val)
    test_set = GCNDataset(feature_test, adj_test, logP_test)

    partition = {
        'train': train_set,
        'val': val_set,
        'test': test_set
    }

    return partition

In [292]:
dict_partition = partition(list_feature, list_adj, list_logP, args)

# 2. Model Construction

Graph Convolution Network, 즉 GCN을 pytorch를 이용하여 구현하여봅시다. 이를 위해 다음과 같은 sub module들을 구현하고 사용합니다.

- **GCNLayer**: node feature matrix와 adjacency matrix의 list를 받아 graph convolution 연산을 수행하는 module 입니다.
- **(Gated)SkipConnection**: ResNet에서 사용되었던 skip connection technique을 구현한 module 입니다.
- **GCNBlock**: node feature matrix와 adjacency matrix의 list를 받아 원하는 갯수의 GCNLayer를 통과시킨 후, (gated)skip connection을 적용하는 module 입니다.
- **ReadOut**: graph structrure에 permutation invariance를 주기 위하여 linear layer를 거친 뒤 batch 별로 summation하는 module 입니다.
- **Predictor**: ReadOut layer로부터의 graph feature vector로부터 logP value를 예측하기 위한 linear layer module 입니다.

위 모듈들을 사용하여 **GCNNet**을 구현해봅시다.

In [293]:
class GCNLayer(nn.Module):
    
    def __init__(self, in_dim, out_dim, n_atom, act=None, bn=False):
        super(GCNLayer, self).__init__()
        
        self.use_bn = bn
        self.linear = nn.Linear(in_dim, out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.bn = nn.BatchNorm1d(n_atom)
        self.activation = act
        
    def forward(self, x, adj):
        out = self.linear(x)
        out = torch.matmul(adj, out)
        if self.use_bn:
            out = self.bn(out)
        if self.activation != None:
            out = self.activation(out)
        return out, adj

In [294]:
class SkipConnection(nn.Module):
    
    def __init__(self, in_dim, out_dim):
        super(SkipConnection, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        
    def forward(self, in_x, out_x):
        if (self.in_dim != self.out_dim):
            in_x = self.linear(in_x)
        out = in_x + out_x
        return out

In [295]:
class GatedSkipConnection(nn.Module):
    
    def __init__(self, in_dim, out_dim):
        super(GatedSkipConnection, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.linear_coef_in = nn.Linear(out_dim, out_dim)
        self.linear_coef_out = nn.Linear(out_dim, out_dim)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, in_x, out_x):
        if (self.in_dim != self.out_dim):
            in_x = self.linear(in_x)
        z = self.gate_coefficient(in_x, out_x)
        out = torch.mul(z, out_x) + torch.mul(1.0-z, in_x)
        return out
            
    def gate_coefficient(self, in_x, out_x):
        x1 = self.linear_coef_in(in_x)
        x2 = self.linear_coef_out(out_x)
        return self.sigmoid(x1+x2)

In [296]:
class Attention(nn.Module):
  def __init(self, in_dim, out_dim, num_head):
    super(Attention, self).__init__()

    self.num_head = num_head
    self.atn_dim = out_dim // num_head 

    self.linears = nn.ModuleList()
    self.corelations = nn.ParameterList()

    for i in range(self.num_head):
      self.linears.append(nnLinear(in_dim, self.atn_dim))
      corelation = torch.LoatTensor(self.atn_dim, self.atn_dim)
      nn.init.xavier_uniform_(corelation)
      self.corelations.append(nn.Parameter(corelation))

    self.tanh = nn.Tanh()

  def forwar(self, x, adj):
    heads = list()

    for i in range(self.num_head):
      x_transformed = self.linears[i](x)
      alpha = self.attention_matrix(x_transformed, self.corelations[i], adj)
      x_head = torch.matmul(alpha, x_transformed)
      heads.append(x_head)
    
    ouput = torch.cat(heads, dim=2)

    return output 
  
  def attention_matrix(self, x_transformed, corelation, adj):
    x = torch.einsum('akj,ij->aki', (x_transformed, corelation))
    alpha = torch.matmul(x, torch.transpose(x_transformed, 1, 2))
    alpha = torch.mul(alpha, adj)
    alpha = self.tanh(alpha)
    return alpha 


In [297]:
class GCNBlock(nn.Module):
    
    def __init__(self, n_layer, in_dim, hidden_dim, out_dim, n_atom, bn=True, sc='gsc'):
        super(GCNBlock, self).__init__()
        
        self.layers = nn.ModuleList()
        for i in range(n_layer):
            self.layers.append(GCNLayer(in_dim if i==0 else hidden_dim,
                                        out_dim if i==n_layer-1 else hidden_dim,
                                        n_atom,
                                        nn.ReLU() if i!=n_layer-1 else None,
                                        bn))
        self.relu = nn.ReLU()
        if sc=='gsc':
            self.sc = GatedSkipConnection(in_dim, out_dim)
        elif sc=='sc':
            self.sc = SkipConnection(in_dim, out_dim)
        elif sc=='no':
            self.sc = None
        else:
            assert False, "Wrong sc type."
        
    def forward(self, x, adj):
        residual = x
        for i, layer in enumerate(self.layers):
            out, adj = layer((x if i==0 else out), adj)
        if self.sc != None:
            out = self.sc(residual, out)
        out = self.relu(out)
        return out, adj

In [298]:
class ReadOut(nn.Module):
    
    def __init__(self, in_dim, out_dim, act=None):
        super(ReadOut, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim= out_dim
        
        self.linear = nn.Linear(self.in_dim, 
                                self.out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.activation = act

    def forward(self, x):
        out = self.linear(x)
        out = torch.sum(out, 1)
        if self.activation != None:
            out = self.activation(out)
        return out

In [299]:
class Predictor(nn.Module):
    
    def __init__(self, in_dim, out_dim, act=None):
        super(Predictor, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.linear = nn.Linear(self.in_dim,
                                self.out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.activation = act
        
    def forward(self, x):
        out = self.linear(x)
        if self.activation != None:
            out = self.activation(out)
        return out

In [300]:
class GCNNet(nn.Module):
    
    def __init__(self, args):
        super(GCNNet, self).__init__()
        
        self.blocks = nn.ModuleList()
        for i in range(args.n_block):
            self.blocks.append(GCNBlock(args.n_layer,
                                        args.in_dim if i==0 else args.hidden_dim,
                                        args.hidden_dim,
                                        args.hidden_dim,
                                        args.n_atom,
                                        args.bn,
                                        args.sc))
        self.readout = ReadOut(args.hidden_dim, 
                               args.pred_dim1,
                               act=nn.ReLU())
        self.pred1 = Predictor(args.pred_dim1,
                               args.pred_dim2,
                               act=nn.ReLU())
        self.pred2 = Predictor(args.pred_dim2,
                               args.pred_dim3,
                               act=nn.Tanh())
        self.pred3 = Predictor(args.pred_dim3,
                               args.out_dim)
        
    def forward(self, x, adj):
        for i, block in enumerate(self.blocks):
            out, adj = block((x if i==0 else out), adj)
        out = self.readout(out)
        out = self.pred1(out)
        out = self.pred2(out)
        out = self.pred3(out)
        return out

# 3. Train, Validate, Test and Experiment

In [301]:
def train(net, partition, optimizer, criterion, args):
    trainloader = torch.utils.data.DataLoader(partition['train'], 
                                              batch_size=args.batch_size, 
                                              shuffle=True)
    net.train()

    train_loss = 0.0
    for i, data in enumerate(trainloader):
        optimizer.zero_grad() # [21.01.05 오류 수정] 매 Epoch 마다 .zero_grad()가 실행되는 것을 매 iteration 마다 실행되도록 수정했습니다. 

        # get the inputs
        list_feature, list_adj, list_logP = data
        list_feature = list_feature.cuda().float()
        list_adj = list_adj.cuda().float()
        list_logP = list_logP.cuda().float().view(-1, 1)

        outputs = net(list_feature, list_adj)
        # print("train outputs type : {} | location : {} | shape : {}".format(type(outputs),outputs.device, outputs.shape))
        # print("train list_logP type : {} | location : {} | shape : {}".format(type(list_logP),list_logP.device, list_logP.shape))
        acc = r2_score(list_logP.detach().cpu().numpy(), outputs.detach().cpu().numpy()) # torch.cat(list_preds, dim=0).cpu().numpy()
        loss = criterion(outputs, list_logP)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss = train_loss / len(trainloader)
    return net, train_loss, acc

def validate(net, partition, criterion, args):
    valloader = torch.utils.data.DataLoader(partition['val'], 
                                            batch_size=args.batch_size, 
                                            shuffle=False)
    net.eval()
    val_loss = 0 
    with torch.no_grad():
        for data in valloader:
            list_feature, list_adj, list_logP = data
            list_feature = list_feature.cuda().float()
            list_adj = list_adj.cuda().float()
            list_logP = list_logP.cuda().float().view(-1, 1)
            
            outputs = net(list_feature, list_adj)
            acc = r2_score(list_logP.detach().cpu().numpy(), outputs.detach().cpu().numpy())

            loss = criterion(outputs, list_logP)
            val_loss += loss.item()

        val_loss = val_loss / len(valloader)
    return val_loss, acc

def test(net, partition, args):
    testloader = torch.utils.data.DataLoader(partition['test'], 
                                             batch_size=args.batch_size, 
                                             shuffle=False)
    net.eval()
    with torch.no_grad():
        logP_total = list()
        pred_logP_total = list()
        for data in testloader:
            list_feature, list_adj, list_logP = data
            list_feature = list_feature.cuda().float()
            list_adj = list_adj.cuda().float()
            list_logP = list_logP.cuda().float()
            logP_total += list_logP.tolist()
            list_logP = list_logP.view(-1, 1)
            
            outputs = net(list_feature, list_adj)
            pred_logP_total += outputs.view(-1).tolist()

        mae = mean_absolute_error(logP_total, pred_logP_total)
        std = np.std(np.array(logP_total)-np.array(pred_logP_total))
    
    return mae, std, logP_total, pred_logP_total

In [302]:
def experiment(partition, args):
  
    net = GCNNet(args)
    net.cuda()

    # criterion = nn.MSELoss()
    criterion = nn.L1Loss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.l2)
    elif args.optim == 'RMSprop':
        optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=args.l2)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.l2)
    else:
        raise ValueError('In-valid optimizer choice')
    
    train_losses = []
    val_losses = []
        
    for epoch in range(args.epoch):  # loop over the dataset multiple times
        ts = time.time()
        net, train_loss, train_acc = train(net, partition, optimizer, criterion, args)
        val_loss, val_acc = validate(net, partition, criterion, args)
        te = time.time()
        

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        print('Epoch {}, Acc(train/val): {:2.2f}/{:2.2f}, Loss(train/val) {:2.2f}/{:2.2f}. Took {:2.2f} sec'.format(epoch, train_acc, val_acc, train_loss, val_loss, te-ts))
        
    mae, std, logP_total, pred_logP_total = test(net, partition, args)    
    
    result = {}
    result['train_losses'] = train_losses
    result['val_losses'] = val_losses
    result['mae'] = mae
    result['std'] = std
    result['logP_total'] = logP_total
    result['pred_logP_total'] = pred_logP_total
    return vars(args), result

In [305]:
"""
# 위에서 args 선언 했음
args.batch_size = 32
args.lr = 0.001
args.l2 = 0
args.optim = 'Adam'
args.epoch = 5000

args.n_block = 3
args.n_layer = 1

args.n_atom = 58
args.in_dim = 58
args.hidden_dim = 256
args.pred_dim1 = 192
args.pred_dim2 = 192
args.pred_dim3 = 192
args.out_dim = 1

args.bn = False
args.sc = "no"
args.atn = False

# args.step_size = 10
# args.gamma = 0.1 
"""

# 위에서 args 선언 했음
args.batch_size = 32
args.lr = 0.001
args.l2 = 0
args.optim = 'Adam'
args.epoch = 5000

args.n_block = 3
args.n_layer = 1

args.n_atom = 58
args.in_dim = 58
args.hidden_dim = 256
args.pred_dim1 = 192
args.pred_dim2 = 192
args.pred_dim3 = 192
args.out_dim = 1

args.bn = False
args.sc = "no"
args.atn = False

# args.step_size = 10
# args.gamma = 0.1 

In [280]:

experiment(dict_partition, args)

Epoch 0, Acc(train/val): 0.72/0.62, Loss(train/val) 0.87/0.64. Took 5.43 sec
Epoch 1, Acc(train/val): 0.85/0.88, Loss(train/val) 0.48/0.33. Took 5.38 sec
Epoch 2, Acc(train/val): 0.90/0.91, Loss(train/val) 0.36/0.31. Took 5.38 sec
Epoch 3, Acc(train/val): 0.85/0.93, Loss(train/val) 0.31/0.27. Took 5.44 sec
Epoch 4, Acc(train/val): 0.93/0.93, Loss(train/val) 0.27/0.27. Took 5.37 sec
Epoch 5, Acc(train/val): 0.94/0.95, Loss(train/val) 0.28/0.23. Took 5.38 sec
Epoch 6, Acc(train/val): 0.90/0.95, Loss(train/val) 0.24/0.23. Took 5.42 sec
Epoch 7, Acc(train/val): 0.98/0.97, Loss(train/val) 0.22/0.17. Took 5.35 sec
Epoch 8, Acc(train/val): 0.97/0.96, Loss(train/val) 0.18/0.20. Took 5.32 sec
Epoch 9, Acc(train/val): 0.97/0.98, Loss(train/val) 0.16/0.15. Took 5.28 sec
Epoch 10, Acc(train/val): 0.98/0.97, Loss(train/val) 0.15/0.16. Took 5.25 sec
Epoch 11, Acc(train/val): 0.96/0.97, Loss(train/val) 0.17/0.19. Took 5.28 sec
Epoch 12, Acc(train/val): 0.98/0.99, Loss(train/val) 0.14/0.12. Took 5.26 

KeyboardInterrupt: ignored

In [306]:

experiment(dict_partition, args)

Epoch 0, Acc(train/val): -11.37/-48.89, Loss(train/val) 59.02/62.06. Took 0.42 sec
Epoch 1, Acc(train/val): -7.29/-38.62, Loss(train/val) 49.75/54.90. Took 0.43 sec
Epoch 2, Acc(train/val): -5.46/-29.58, Loss(train/val) 42.56/47.74. Took 0.42 sec
Epoch 3, Acc(train/val): -4.95/-21.76, Loss(train/val) 35.45/40.59. Took 0.44 sec
Epoch 4, Acc(train/val): -2.02/-15.15, Loss(train/val) 28.31/33.46. Took 0.43 sec
Epoch 5, Acc(train/val): -0.38/-9.87, Loss(train/val) 21.42/26.53. Took 0.42 sec
Epoch 6, Acc(train/val): -0.29/-6.53, Loss(train/val) 16.10/21.66. Took 0.45 sec
Epoch 7, Acc(train/val): -0.03/-5.09, Loss(train/val) 14.02/19.80. Took 0.42 sec
Epoch 8, Acc(train/val): -0.05/-4.42, Loss(train/val) 13.64/18.84. Took 0.42 sec
Epoch 9, Acc(train/val): -0.28/-4.07, Loss(train/val) 13.57/18.32. Took 0.42 sec
Epoch 10, Acc(train/val): -0.00/-3.77, Loss(train/val) 13.41/17.90. Took 0.42 sec
Epoch 11, Acc(train/val): -0.13/-3.68, Loss(train/val) 13.47/17.77. Took 0.42 sec
Epoch 12, Acc(train/

KeyboardInterrupt: ignored


# 4. Visualization



In [None]:
import seaborn as sns

In [None]:
def plot_performance(df_result, var1, var2):
  fig, ax = plt.subplots(1,2)
  fig.set_size_inches(10, 5)

  df_mae = df_result.pivot(var1, var2, 'mae')
  df_std = df_result.pivot(var1, var2, 'std')
  df_mae = df_mae[df_mae.columns].astype(float)
  df_std = df_std[df_std.columns].astype(float)

  hm_mae = sns.heatmap(df_mae, ax=ax[0], annot=True, fmt='f', linewidths=0.5, cmap="Y1GnBu")
  hm_std = sns.heatmap(df_std, ax=ax[0], annot=True, fmt='f', linewidths=0.5, cmap="Y1GnBu")


In [None]:
def plot_performance_bar(df_result, var1, var2):
  fig, ax = plt.subplots(1,2)
  fig.set_size_inches(10, 5)

  sns.set_style("darkgrid", {"axes.facecolor": ".9"})
  
  bar_mae = sns.barplot(x=var1, y="mae", hue=var2, data=df_result, ax=ax[0])
  bar_std = sns.barplot(x=var1, y="std", hue=var2, data=df_result, ax=ax[1])

In [None]:
def plot_loss(df_result, var1, var2, ylim):
  def plot(x, ylim=1.0, **kwargs):
    plt.plot(x[0], **kwargs)
    plt.ylim(0.0, ylim)

  sns.set_style("darkgrid", {"axes.facecolor": ".9"})
  g = sns.FacetGrid(df_result, row=var1, col= var2, margin_titles=True)
  g.map(plot, 'list_train_loss', ylim=ylim, label="Train Loss")
  g.map(plot, "list_val_loss", ylim=ylim, color='r', label="Validation Loss")
  plt.legend()
  plt.show()  