In [1]:
import os.path as osp
from math import sqrt

import torch
import torch.nn.functional as F
from rdkit import Chem

from torch_geometric.data import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn.models import AttentiveFP

In [2]:
class GenFeatures(object):
    """Pre-transform that derives AttentiveFP atom and bond features from SMILES.

    Calling an instance on a data object parses ``data.smiles`` with RDKit and
    fills in ``data.x`` (one row per atom), ``data.edge_index`` (both
    directions of every bond) and ``data.edge_attr`` (one row per directed
    edge).

    Values that fall outside the enumerated categories no longer abort the
    transform:

    * unknown element symbols / hybridizations map to the trailing ``'other'``
      slot of their vocabulary,
    * degrees and hydrogen counts beyond the last bucket are clamped into it,
    * bond stereo tags or CIP codes that are not listed leave their one-hot
      block all zero.
    """

    # Number of one-hot buckets for atom degree (0..5) and total Hs (0..4).
    NUM_DEGREES = 6
    NUM_HYDROGENS = 5

    def __init__(self):
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'other'
        ]

        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    @staticmethod
    def _one_hot_with_other(choices, value):
        """One-hot encode ``value`` over ``choices``; unknowns hit the last slot."""
        encoding = [0.] * len(choices)
        try:
            index = choices.index(value)
        except ValueError:
            # The last entry of the vocabulary is the catch-all 'other'.
            index = len(choices) - 1
        encoding[index] = 1.
        return encoding

    @staticmethod
    def _one_hot_bucket(size, value):
        """One-hot encode a count, clamping overflow into the last bucket."""
        encoding = [0.] * size
        encoding[min(value, size - 1)] = 1.
        return encoding

    def _atom_features(self, atom):
        """Return the feature vector (1-D float tensor) of a single atom."""
        symbol = self._one_hot_with_other(self.symbols, atom.GetSymbol())
        degree = self._one_hot_bucket(self.NUM_DEGREES, atom.GetDegree())
        formal_charge = atom.GetFormalCharge()
        radical_electrons = atom.GetNumRadicalElectrons()
        hybridization = self._one_hot_with_other(self.hybridizations,
                                                 atom.GetHybridization())
        aromaticity = 1. if atom.GetIsAromatic() else 0.
        hydrogens = self._one_hot_bucket(self.NUM_HYDROGENS,
                                         atom.GetTotalNumHs())
        chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
        chirality_type = [0.] * 2
        if atom.HasProp('_CIPCode'):
            cip_code = atom.GetProp('_CIPCode')
            # Only 'R' / 'S' have a slot; any other label stays all-zero.
            if cip_code in ('R', 'S'):
                chirality_type[('R', 'S').index(cip_code)] = 1.

        return torch.tensor(
            symbol + degree + [formal_charge] + [radical_electrons] +
            hybridization + [aromaticity] + hydrogens + [chirality] +
            chirality_type, dtype=torch.float)

    def _bond_features(self, bond):
        """Return the feature vector (1-D float tensor) of a single bond."""
        bond_type = bond.GetBondType()
        bond_kinds = (
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
        )
        kind_flags = [1. if bond_type == kind else 0. for kind in bond_kinds]
        conjugation = 1. if bond.GetIsConjugated() else 0.
        ring = 1. if bond.IsInRing() else 0.

        stereo = [0.] * len(self.stereos)
        bond_stereo = bond.GetStereo()
        # Stereo tags outside the listed ones are encoded as all zeros.
        if bond_stereo in self.stereos:
            stereo[self.stereos.index(bond_stereo)] = 1.

        return torch.tensor(kind_flags + [conjugation, ring] + stereo,
                            dtype=torch.float)

    def __call__(self, data):
        """Attach AttentiveFP features (Table 1 of the paper) to ``data``.

        Args:
            data: object with a ``smiles`` attribute; it is modified in place.

        Returns:
            The same ``data`` object with ``x``, ``edge_index`` and
            ``edge_attr`` set.

        Raises:
            ValueError: if RDKit cannot parse ``data.smiles``.
        """
        mol = Chem.MolFromSmiles(data.smiles)
        if mol is None:
            raise ValueError(
                f'RDKit could not parse SMILES string: {data.smiles!r}')

        data.x = torch.stack(
            [self._atom_features(atom) for atom in mol.GetAtoms()], dim=0)

        edge_indices = []
        edge_attrs = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Store each bond in both directions with identical attributes.
            edge_indices += [[begin, end], [end, begin]]
            edge_attr = self._bond_features(bond)
            edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            # Molecules without bonds still need correctly shaped empty tensors.
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 6 + len(self.stereos)),
                                         dtype=torch.float)
        else:
            data.edge_index = torch.tensor(
                edge_indices, dtype=torch.long).t().contiguous()
            data.edge_attr = torch.stack(edge_attrs, dim=0)

        return data

In [3]:
# Seed PyTorch's global RNG before shuffling: the train/val/test split below is
# taken positionally from the shuffled dataset, so without a fixed seed every
# "Restart & Run All" produces a different split and incomparable metrics.
torch.manual_seed(12345)
dataset = MoleculeNet('./ESOL', name='ESOL', pre_transform=GenFeatures()).shuffle()

In [4]:
# Hold out one tenth of the (already shuffled) dataset for validation, the
# next tenth for testing, and train on everything that remains.
N = len(dataset) // 10
val_dataset, test_dataset, train_dataset = (
    dataset[:N],
    dataset[N:2 * N],
    dataset[2 * N:],
)

In [5]:
# Only the training batches are reshuffled each epoch; validation and test
# batches are served in a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=200, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=200, shuffle=False)

In [6]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [9]:
# Peek at the target value `y` of the first validation example.
val_dataset[0].y

tensor([[-5.6700]])

In [9]:
# Peek at the first validation example to check the generated tensor shapes.
val_dataset[0]

Data(edge_attr=[6, 10], edge_index=[2, 6], smiles="ClCC#N", x=[4, 39], y=[1, 1])

In [13]:
# ESOL is a single-target regression task (every `y` has shape [1, 1]), so the
# network needs exactly one output channel. `num_classes` is a classification
# helper and is not a reliable source for this: depending on the PyG version it
# may count distinct target values for floating-point labels.
model = AttentiveFP(
    in_channels=val_dataset.num_node_features,
    hidden_channels=200,
    out_channels=1,
    edge_dim=10,  # width of the bond feature vectors produced by GenFeatures
    num_layers=2,
    num_timesteps=2,
    dropout=0.2).to(device)

In [14]:
# Adam with learning rate 10^-2.5 and weight decay 10^-5.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=10**-2.5,
    weight_decay=10**-5,
)

In [15]:
def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``.

    Returns:
        float: root of the graph-weighted mean of the per-batch MSE losses
        observed during the epoch.
    """
    # Make sure dropout is active regardless of what ran before this call
    # (e.g. an evaluation pass that switched the model to eval mode).
    model.train()

    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size so that a smaller final
        # batch does not skew the epoch average.
        total_loss += loss.item() * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)

In [16]:
@torch.no_grad()
def test(loader):
    mse = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())
    return float(torch.cat(mse, dim=0).mean().sqrt())

In [17]:
# Train for 200 epochs, reporting the training, validation and test RMSE after
# every epoch.
for epoch in range(1, 201):
    train_rmse, val_rmse, test_rmse = (
        train(),
        test(val_loader),
        test(test_loader),
    )
    print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')

Epoch: 001, Loss: 3.2639 Val: 2.4595 Test: 2.5414
Epoch: 002, Loss: 2.4051 Val: 1.7341 Test: 1.7510
Epoch: 003, Loss: 1.8597 Val: 1.7916 Test: 1.7675
Epoch: 004, Loss: 1.8131 Val: 1.6843 Test: 1.7599
Epoch: 005, Loss: 1.7144 Val: 1.6425 Test: 1.6321
Epoch: 006, Loss: 1.6561 Val: 1.5783 Test: 1.5231
Epoch: 007, Loss: 1.5744 Val: 1.4925 Test: 1.5245
Epoch: 008, Loss: 1.4497 Val: 1.2012 Test: 1.2388
Epoch: 009, Loss: 1.2594 Val: 1.0138 Test: 1.1322
Epoch: 010, Loss: 1.1547 Val: 1.0295 Test: 0.9961
Epoch: 011, Loss: 1.1126 Val: 1.0019 Test: 0.9567
Epoch: 012, Loss: 1.1113 Val: 1.0398 Test: 1.0505
Epoch: 013, Loss: 1.0992 Val: 0.9127 Test: 0.9240
Epoch: 014, Loss: 1.0456 Val: 1.0245 Test: 1.0030
Epoch: 015, Loss: 1.0229 Val: 0.9596 Test: 0.9990
Epoch: 016, Loss: 0.9633 Val: 0.9566 Test: 0.9076
Epoch: 017, Loss: 0.9547 Val: 0.8932 Test: 0.8839
Epoch: 018, Loss: 0.9290 Val: 0.8758 Test: 0.8364
Epoch: 019, Loss: 0.9149 Val: 0.9197 Test: 0.9234
Epoch: 020, Loss: 0.8841 Val: 0.7916 Test: 0.9062
