In [1]:
import os
import os.path as osp
import math
import scipy.io

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Parameter

from torch_sparse import coalesce
from torch_scatter import scatter_add


import torch_geometric.transforms as T
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data
from torch_geometric.data import DataLoader
from torch_geometric.utils import remove_self_loops
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils import scatter_

import rdkit
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem.rdchem import HybridizationType
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures
from rdkit.Chem.rdchem import BondType as BT



In [2]:
# Hyperparameters for model training
seed = 42            # RNG seed; fixed so repeated runs draw the same random numbers
target = 0           # column of data.y that the model is trained to predict
dim = 64             # base hidden width used when building the network
batch_size = 128
lr = 0.01
epochs = 300
weight_decay = 5e-4

# Seed PyTorch before anything random happens (weight initialisation, dropout,
# shuffled data loading) so that a Restart & Run All reproduces the same run.
torch.manual_seed(seed)

In [3]:
# QM9 Dataset
class QM9(InMemoryDataset):
    """QM9 molecules converted to graph ``Data`` objects with RDKit.

    Every molecule of the raw SDF file becomes one ``Data`` object holding
    atom features ``x``, 3D coordinates ``pos``, a bidirectional bond list
    ``edge_index`` with one-hot bond types ``edge_attr``, the regression
    targets ``y`` read from the accompanying CSV file, and the molecule
    ``name``.
    """

    # Class-level settings used while downloading and featurising the data.
    url = 'https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/qm9.zip'
    # Atom symbol -> position in the one-hot atom-type feature.
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    # RDKit bond type -> position in the one-hot bond-type feature.
    bond_types = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
    # RDKit feature factory; its 'Donor' / 'Acceptor' families are read in
    # get_atom_features().
    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """Initialise the base dataset, then load the processed tensors from disk."""
        super(QM9, self).__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Raw files expected in ``raw_dir``: the SDF structures and the CSV targets."""
        return ['gdb9.sdf', 'gdb9.sdf.csv']

    @property
    def processed_file_names(self):
        """File name of the processed dataset."""
        return 'data.pt'

    def download(self):
        """Download the QM9 archive from ``url``, extract it into ``raw_dir`` and delete the archive."""
        # The previous fallback to `self.processed_url` was removed: that
        # attribute is not defined on this class and the branch could never be
        # taken because rdkit is imported unconditionally by this file.
        file_path = download_url(self.url, self.raw_dir)
        extract_zip(file_path, self.raw_dir)
        os.unlink(file_path)

    @staticmethod
    def get_atom_features(self, atoms, mol_feats):
        """Build the node feature matrix ``x``.

        Declared as a static method that receives the dataset explicitly as
        its first argument; callers pass ``self`` by hand.

        Args:
            self: Object providing the ``types`` lookup table.
            atoms: Iterable of RDKit atoms of one molecule.
            mol_feats: RDKit chemical features of the same molecule; only the
                'Donor' and 'Acceptor' families are used.

        Returns:
            Float tensor with one row per atom: the one-hot atom type followed
            by atomic number, acceptor flag, donor flag, aromatic flag,
            sp / sp2 / sp3 hybridisation flags and the total hydrogen count.
        """
        type_idx = []
        atomic_number = []
        acceptor = []
        donor = []
        aromatic = []
        sp, sp2, sp3 = [], [], []
        num_hs = []
        for atom in atoms:
            type_idx.append(self.types[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            # Donor / acceptor flags start at 0 and are switched on below.
            donor.append(0)
            acceptor.append(0)
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridization = atom.GetHybridization()
            sp.append(1 if hybridization == HybridizationType.SP else 0)
            sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
            num_hs.append(atom.GetTotalNumHs(includeNeighbors=True))

        # Mark every atom that takes part in a donor / acceptor feature.
        for feat in mol_feats:
            family = feat.GetFamily()
            if family == 'Donor':
                for k in feat.GetAtomIds():
                    donor[k] = 1
            elif family == 'Acceptor':
                for k in feat.GetAtomIds():
                    acceptor[k] = 1

        # An explicit integer dtype keeps one_hot() working for an empty atom
        # list, where torch.tensor([]) would otherwise default to float.
        x1 = F.one_hot(torch.tensor(type_idx, dtype=torch.long),
                       num_classes=len(self.types))
        x2 = torch.tensor([atomic_number, acceptor, donor, aromatic, sp, sp2, sp3, num_hs],
                          dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)
        return x

    @staticmethod
    def get_bind_pair(self, bonds):
        """Build the sparse adjacency (``edge_index``) and the edge feature matrix.

        Args:
            self: Object providing the ``bond_types`` lookup table.
            bonds: Iterable of RDKit bonds of one molecule.

        Returns:
            Tuple ``(edge_index, edge_attr)``: a ``[2, 2 * num_bonds]`` long
            tensor and the matching one-hot bond-type float tensor.
        """
        row, col, bond_idx = [], [], []
        for bond in bonds:
            # Bonds are undirected, so each one is stored in both directions.
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            bond_idx += 2 * [self.bond_types[bond.GetBondType()]]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # An explicit integer dtype keeps one_hot() working for molecules
        # without bonds, where torch.tensor([]) would default to float.
        edge_attr = F.one_hot(torch.tensor(bond_idx, dtype=torch.long),
                              num_classes=len(self.bond_types)).to(torch.float)

        return edge_index, edge_attr

    @staticmethod
    def mol2vec(self, mol):
        """Convert one RDKit molecule into ``(edge_index, edge_attr, x)``."""
        atoms = mol.GetAtoms()
        bonds = mol.GetBonds()
        edge_index, edge_attr = self.get_bind_pair(self, bonds)

        x = self.get_atom_features(self, atoms, self.factory.GetFeaturesForMol(mol))

        return edge_index, edge_attr, x

    def process(self):
        """Read the raw SDF / CSV files, build one ``Data`` per molecule and save the collated result."""
        # Regression targets: columns 4..19 of every CSV row after the header.
        # Blank lines are filtered out instead of blindly dropping the last
        # element, so a file without a trailing newline keeps its final row.
        with open(self.raw_paths[1], 'r') as f:
            rows = [line for line in f.read().split('\n')[1:] if line.strip()]
        targets = torch.tensor(
            [[float(value) for value in line.split(',')[4:20]] for line in rows],
            dtype=torch.float)

        data_list = []
        suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False)
        for i, mol in enumerate(suppl):
            if mol is None:
                # Skip records that yield no molecule. `i` still counts them,
                # so targets[i] keeps following the record position
                # (assumes the CSV has one row per SDF record - TODO confirm).
                continue

            text = suppl.GetItemText(i)
            N = mol.GetNumAtoms()

            # Atom positions: the first three fields of lines 4 .. 4+N-1 of
            # the record text.
            pos = text.split('\n')[4:4 + N]
            pos = [[float(x) for x in line.split()[:3]] for line in pos]
            pos = torch.tensor(pos, dtype=torch.float)

            # Graph data => edge_index, edge_attr, x
            edge_index, edge_attr, x = self.mol2vec(self, mol)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)

            # Target data => y, kept as a [1, num_targets] row
            y = targets[i].unsqueeze(0)

            # Molecule name
            name = mol.GetProp('_Name')

            # Assemble the graph object
            data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=y, name=name)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

In [4]:
# uniform Xavier (Glorot) weight initializer
def glorot(tensor):
    """Fill ``tensor`` in place with samples from U(-a, a).

    ``a = sqrt(6 / (tensor.size(-2) + tensor.size(-1)))``. Passing ``None``
    is a no-op. Always returns ``None``.
    """
    if tensor is None:
        return
    fan_sum = tensor.size(-2) + tensor.size(-1)
    bound = math.sqrt(6.0 / fan_sum)
    tensor.data.uniform_(-bound, bound)

def zeros(tensor):
    """Set every element of ``tensor`` to 0 in place; ``None`` is ignored."""
    if tensor is None:
        return
    tensor.data.fill_(0)

In [5]:
class GCNConvLayer(MessagePassing):
    """Graph convolution layer built on ``MessagePassing`` with 'add' aggregation.

    The node features are multiplied by ``weight``, each message is scaled by
    a per-edge coefficient, the messages are summed per node and the optional
    ``bias`` is added.

    The per-edge coefficient is chosen in ``forward``:

    * ``normalize=True``  -> ``norm()`` (labelled "Equation 8" in this notebook)
    * otherwise, ``single_param=True`` -> ``single_parameter()`` ("Equation 7")
    * otherwise the given ``edge_weight`` is used as is (may be ``None``).

    Note that ``single_param`` is only consulted when ``normalize`` is falsy.
    """

    def __init__(self, in_channels, out_channels, improved=False, cached=False,
                 bias=True, normalize=True, single_param=False, **kwargs):
        super(GCNConvLayer, self).__init__(aggr='add', **kwargs)

        self.in_channels, self.out_channels = in_channels, out_channels
        self.improved = improved
        self.cached = cached
        self.normalize = normalize
        self.single_param = single_param

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight (Glorot uniform) and bias (zeros) and drop the cache."""
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def _symmetric_normalize(edge_index, edge_weight, num_nodes):
        """Return ``deg^-1/2[row] * edge_weight * deg^-1/2[col]`` for every edge.

        ``deg`` is the sum of ``edge_weight`` grouped by the first row of
        ``edge_index``; infinite inverse square roots (degree 0) become 0.
        """
        src, dst = edge_index
        degree = scatter_add(edge_weight, src, dim=0, dim_size=num_nodes)
        inv_sqrt = degree.pow(-0.5)
        inv_sqrt[inv_sqrt == float('inf')] = 0
        return inv_sqrt[src] * edge_weight * inv_sqrt[dst]

    # Equation 7
    @staticmethod
    def single_parameter(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=None):
        """Normalise the given edges first, then add the missing self loops.

        The self loops are appended after normalisation and therefore carry
        the plain fill value (1, or 2 when ``improved``).
        """
        if edge_weight is None:
            # Unweighted graph: every edge counts as 1.
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        edge_weight = GCNConvLayer._symmetric_normalize(
            edge_index, edge_weight, num_nodes)

        fill_value = 2 if improved else 1
        return add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

    # Equation 8
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=None):
        """Add the missing self loops first, then normalise all edges.

        Here the self loops take part in the degree computation and are
        normalised together with the other edges.
        """
        if edge_weight is None:
            # Unweighted graph: every edge counts as 1.
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        fill_value = 2 if improved else 1
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        normalized = GCNConvLayer._symmetric_normalize(
            edge_index, edge_weight, num_nodes)
        return edge_index, normalized

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the linear map to ``x`` and propagate along the (normalised) edges.

        Raises:
            RuntimeError: if caching is enabled and ``edge_index`` has a
                different number of edges than the cached graph.
        """
        x = torch.matmul(x, self.weight)

        reuse_cache = self.cached and self.cached_result is not None
        if reuse_cache:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))
        else:
            # Nothing usable is cached: compute the edge coefficients and
            # remember them (they are only reused when `cached` is set).
            self.cached_num_edges = edge_index.size(1)
            if self.normalize:
                edge_index, norm = self.norm(edge_index, x.size(0),
                                             edge_weight, self.improved,
                                             x.dtype)
            elif self.single_param:
                edge_index, norm = self.single_parameter(edge_index, x.size(0),
                                                         edge_weight, self.improved,
                                                         x.dtype)
            else:
                norm = edge_weight

            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """Scale each neighbour feature row by its edge coefficient, if there is one."""
        if norm is None:
            return x_j
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Add the bias (when present) to the aggregated messages."""
        if self.bias is None:
            return aggr_out
        return aggr_out + self.bias

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [6]:
# average node features per group given by `batch`
def global_mean_pool(x, batch, size=None):
    """Mean-reduce the rows of ``x`` that share the same index in ``batch``.

    Args:
        x: Feature tensor reduced along dimension 0.
        batch: Index tensor assigning each row of ``x`` to an output row.
        size: Number of output rows; defaults to ``batch.max() + 1``.
    """
    if size is None:
        size = batch.max().item() + 1
    return scatter_('mean', x, batch, dim=0, dim_size=size)

In [7]:
# Graph Neural Network
class Net(nn.Module):
    """Two GCN layers followed by per-graph mean pooling and a linear read-out.

    Args:
        in_channels: Number of input node features. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        hidden_dim: Base hidden width; the first layer outputs
            ``2 * hidden_dim`` channels and the second ``hidden_dim``.
            Defaults to the notebook-level ``dim``.

    Both arguments are optional so that ``Net()`` keeps working as before,
    while the model can also be built without relying on notebook globals.
    """

    def __init__(self, in_channels=None, hidden_dim=None):
        super(Net, self).__init__()
        # Fall back to the notebook globals only when no explicit size is given.
        if in_channels is None:
            in_channels = dataset.num_features
        if hidden_dim is None:
            hidden_dim = dim

        self.conv1 = GCNConvLayer(in_channels, hidden_dim * 2, cached=False,
                            normalize=True, single_param=False)
        self.conv2 = GCNConvLayer(hidden_dim * 2, hidden_dim, cached=False,
                            normalize=True, single_param=False)
        self.lin1 = nn.Linear(hidden_dim, 1)

    def forward(self, data):
        """Return one scalar prediction per graph of the batch as a 1-D tensor.

        Only ``data.x``, ``data.edge_index`` and ``data.batch`` are read.
        """
        x, edge_index = data.x, data.edge_index
        out = F.relu(self.conv1(x, edge_index))
        # Dropout is active only while the module is in training mode.
        out = F.dropout(out, training=self.training)
        out = self.conv2(out, edge_index)
        # Collapse node embeddings to one embedding per graph.
        out = global_mean_pool(out, data.batch)
        out = self.lin1(out)
        return out.view(-1)

In [8]:
class MyTransform(object):
    """Reduce ``data.y`` to a single regression target column.

    Args:
        target_index: Column of ``data.y`` to keep. When left as ``None`` the
            notebook-level ``target`` setting is looked up each time the
            transform is applied, which is the original behavior.
    """

    def __init__(self, target_index=None):
        self.target_index = target_index

    def __call__(self, data):
        # Specify target.
        index = target if self.target_index is None else self.target_index
        data.y = data.y[:, index]
        return data


class Complete(object):
    """Transform that replaces a graph's edges by all ordered node pairs without self loops.

    Existing edge attributes are scattered into the new edge list; pairs that
    were not connected before receive all-zero attributes.
    """

    def __call__(self, data):
        n = data.num_nodes
        device = data.edge_index.device

        # Every (source, target) combination, source-major:
        # row = 0,0,...,1,1,...   col = 0,1,...,0,1,...
        node_ids = torch.arange(n, dtype=torch.long, device=device)
        row = node_ids.view(-1, 1).repeat(1, n).view(-1)
        col = node_ids.repeat(n)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            # Position of each original edge inside the flattened n x n grid.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            dense_shape = list(data.edge_attr.size())
            dense_shape[0] = n * n
            edge_attr = data.edge_attr.new_zeros(dense_shape)
            edge_attr[flat_pos] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data

In [9]:
# Root directory of the dataset files.
# NOTE(review): realpath('') resolves to the current working directory, so
# dirname() followed by '..' points two levels above it - confirm that this is
# the intended location.
path = osp.join(
    osp.dirname(osp.realpath('')),
    '..', 'data', 'QM9',
)
# Applied on access, in this order: pick the target column, connect all node
# pairs, then let T.Distance process the resulting edges.
transform = T.Compose([
    MyTransform(),
    Complete(),
    T.Distance(norm=False),
])
dataset = QM9(path, transform=transform).shuffle()

In [10]:
# Normalize targets to mean = 0 and std = 1.
# The untouched targets are kept on the dataset object so that re-running this
# cell recomputes the statistics from the raw values. Without this, a second
# run would normalize already-normalized data, resetting `mean`/`std` to
# roughly 0/1 and silently changing the scale of the MAE reported by test().
# NOTE(review): the statistics are computed over all samples, before the
# train/validation/test split below.
if not hasattr(dataset, '_raw_y'):
    dataset._raw_y = dataset.data.y.clone()
mean = dataset._raw_y.mean(dim=0, keepdim=True)
std = dataset._raw_y.std(dim=0, keepdim=True)
dataset.data.y = (dataset._raw_y - mean) / std
# Keep only the scalar statistics of the selected target column.
mean, std = mean[:, target].item(), std[:, target].item()

# Split datasets (fixed hold-out split, not cross validation):
# first 10000 samples -> test, next 10000 -> validation, the rest -> training.
test_dataset, val_dataset, train_dataset = (
    dataset[:10000],
    dataset[10000:20000],
    dataset[20000:],
)

# Build the data loaders; only the training loader reshuffles its samples.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [11]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# First sample of the dataset moved to the chosen device; the train/test
# functions below bind their own local `data` and do not read this one.
data = dataset[0].to(device)

In [12]:
# Initialize the model and the optimizer
model = Net().to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [13]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Accepted for the caller's convenience; not used inside.

    Returns:
        The MSE loss averaged over all graphs of the training set.
    """
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        loss = F.mse_loss(prediction, batch.y)
        loss.backward()
        # The batch loss is a mean, so weight it by the number of graphs.
        total_loss += loss.item() * batch.num_graphs
        optimizer.step()

    return total_loss / len(train_loader.dataset)

def test(loader):
    """Evaluate the model on ``loader`` and return the mean absolute error.

    Predictions and targets are multiplied by ``std`` before the absolute
    differences are summed, so the error is reported on the scale the targets
    had before they were divided by ``std``.

    Args:
        loader: Data loader whose batches provide ``y`` and whose ``dataset``
            supports ``len()``.

    Returns:
        Sum of absolute errors divided by the number of samples in the
        loader's dataset.
    """
    model.eval()
    error = 0

    # Evaluation needs no gradients; disabling autograd avoids building a
    # computation graph for every batch.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            error += (model(data) * std - data.y * std).abs().sum().item()  # MAE
    return error / len(loader.dataset)

In [14]:
# Train for `epochs` epochs. The test set is evaluated only in epochs whose
# validation error matches or beats the best seen so far, so the printed
# test MAE always belongs to the best validation epoch up to that point.
best_val_error = None
for epoch in range(epochs):
    loss = train(epoch)
    val_error = test(val_loader)

    if best_val_error is None or val_error <= best_val_error:
        best_val_error = val_error
        test_error = test(test_loader)

    print(
        'Epoch: {:03d}, Loss: {:.7f}, Validation MAE: {:.7f}, Test MAE: {:.7f}'.format(
            epoch + 1, loss, val_error, test_error
        )
    )

Epoch: 001, Loss: 0.6967202, Validation MAE: 0.8832881, Test MAE: 0.8829607
Epoch: 002, Loss: 0.6569940, Validation MAE: 0.8855286, Test MAE: 0.8829607
Epoch: 003, Loss: 0.6504361, Validation MAE: 0.8596573, Test MAE: 0.8557926
Epoch: 004, Loss: 0.6488518, Validation MAE: 0.8699367, Test MAE: 0.8557926
Epoch: 005, Loss: 0.6462197, Validation MAE: 0.8593770, Test MAE: 0.8622034
Epoch: 006, Loss: 0.6436653, Validation MAE: 0.8694602, Test MAE: 0.8622034
Epoch: 007, Loss: 0.6423671, Validation MAE: 0.8707799, Test MAE: 0.8622034
Epoch: 008, Loss: 0.6421476, Validation MAE: 0.8775520, Test MAE: 0.8622034
Epoch: 009, Loss: 0.6386914, Validation MAE: 0.8555285, Test MAE: 0.8539425
Epoch: 010, Loss: 0.6405800, Validation MAE: 0.8521609, Test MAE: 0.8515646
Epoch: 011, Loss: 0.6414883, Validation MAE: 0.8628033, Test MAE: 0.8515646
Epoch: 012, Loss: 0.6382103, Validation MAE: 0.8584975, Test MAE: 0.8515646
Epoch: 013, Loss: 0.6390986, Validation MAE: 0.9089658, Test MAE: 0.8515646
Epoch: 014, 

KeyboardInterrupt: 