In [1]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.utils import remove_self_loops

In [2]:
# Column of the QM9 target matrix `y` that the model regresses.
target = 0
# Hidden width shared by every layer of the network.
dim = 64

In [3]:
class MyTransform:
    """Per-sample transform that keeps a single regression target.

    The column is chosen by the module-level ``target`` index, looked up each
    time the transform is called. The sample is modified in place and
    returned.
    """

    def __call__(self, data):
        # Slice out the chosen column; this turns the 2-D `y` into a 1-D tensor.
        selected = data.y[:, target]
        data.y = selected
        return data

In [4]:
class Complete:
    """Per-sample transform that connects every ordered pair of distinct nodes.

    Attributes of edges that already existed are carried over; edges that are
    newly introduced get all-zero attributes. The sample is modified in place
    and returned.
    """

    def __call__(self, data):
        n = data.num_nodes
        node_ids = torch.arange(n, dtype=torch.long,
                                device=data.edge_index.device)

        # All n * n ordered pairs in row-major order: position k holds
        # (k // n, k % n).
        src = node_ids.repeat_interleave(n)
        dst = node_ids.repeat(n)
        full_index = torch.stack([src, dst], dim=0)

        full_attr = None
        if data.edge_attr is not None:
            # Row-major position of each original edge inside the n * n grid.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            full_attr = data.edge_attr.new_zeros(
                (n * n, *data.edge_attr.shape[1:]))
            full_attr[flat_pos] = data.edge_attr

        # Drop the (i, i) pairs together with their attribute rows.
        full_index, full_attr = remove_self_loops(full_index, full_attr)
        data.edge_attr = full_attr
        data.edge_index = full_index

        return data

In [5]:
# Root directory handed to QM9, relative to the current working directory.
path = '../data/QM9'
# Script-style alternative; it relies on __file__, which is not defined in a
# notebook kernel:
# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
# Fuller pipeline, currently disabled (fully connect each graph, then apply
# T.Distance). NOTE(review): presumably this widens edge_attr from 4 to 5
# columns, which the edge network in Net must agree with - confirm before
# re-enabling.
# transform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)])
# Active pipeline: only select the regression target column.
transform = T.Compose([MyTransform()])
dataset = QM9(path, transform=transform)

# Variant without any transform:
# dataset = QM9(path)


In [50]:
# Node count of the collated storage behind the dataset (all graphs together,
# not a single sample).
dataset.data.num_nodes

2359210

In [54]:
# First sample as served by `dataset`, i.e. the one built with `transform`.
dataset[0]

Data(x=[5, 11], edge_index=[2, 20], edge_attr=[20, 5], y=[1], pos=[5, 3], idx=[1], name='gdb_1', z=[5])

In [51]:
# Connectivity of that first sample.
dataset[0].edge_index

tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
        [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3]])

In [30]:
# Second sample, for comparison with the first one.
dataset[1]

Data(x=[4, 11], edge_index=[2, 12], edge_attr=[12, 5], y=[1], pos=[4, 3], idx=[1], name='gdb_2', z=[4])

In [52]:
# Same data loaded without any transform, kept for side-by-side comparison
# with `dataset`.
dt = QM9(path)

In [55]:
# First untransformed sample; `y` still carries every target column.
dt[0]

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5])

In [53]:
# Connectivity of the first untransformed sample.
dt[0].edge_index

tensor([[0, 0, 0, 0, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 0, 0, 0]])

In [31]:
# Second untransformed sample.
dt[1]

Data(x=[4, 11], edge_index=[2, 6], edge_attr=[6, 4], y=[1, 19], pos=[4, 3], idx=[1], name='gdb_2', z=[4])

In [7]:
# First row of the collated edge_index (the row Complete multiplies by
# num_nodes).
dataset.data.edge_index[0]

tensor([ 0,  0,  0,  ..., 13, 14, 15])

In [8]:
# Second row of the collated edge_index.
dataset.data.edge_index[1]

tensor([1, 2, 3,  ..., 4, 7, 8])

In [9]:
# Same flat-index formula as in Complete.__call__, but evaluated on the
# collated storage rather than on a single graph.
# NOTE(review): num_nodes here is the dataset-wide total, so these values are
# not per-graph positions - exploratory only.
idx = dataset.data.edge_index[0] * dataset.data.num_nodes + dataset.data.edge_index[1]
idx

tensor([       1,        2,        3,  ..., 30669734, 33028947, 35388158])

In [None]:
# Plain-dict view of the collated storage. This cell has no execution count,
# i.e. it was not run in the saved session.
dataset.data.to_dict()

In [11]:
# Class of the collated storage object.
type(dataset.data)

torch_geometric.data.data.Data

In [12]:
# Shape summary of the collated storage (all graphs concatenated).
dataset.data

Data(x=[2359210, 11], edge_index=[2, 4883516], edge_attr=[4883516, 4], y=[130831, 19], pos=[2359210, 3], idx=[130831], name=[130831], z=[2359210])

In [6]:
# Normalize targets to mean = 0 and std = 1.
# The statistics are always computed from a one-time snapshot of the raw
# targets. Without it, re-running this cell would normalize values that are
# already normalized and silently reset `mean`/`std` to ~0/~1, which breaks
# the rescaling of the MAE back to original units.
if 'y_raw' not in globals():
    y_raw = dataset.data.y.clone()
y_mean = y_raw.mean(dim=0, keepdim=True)
y_std = y_raw.std(dim=0, keepdim=True)
dataset.data.y = (y_raw - y_mean) / y_std
# Plain Python floats for the selected target column.
mean, std = y_mean[:, target].item(), y_std[:, target].item()

In [7]:
# Split datasets.
# Fixed contiguous slices: 100 test, 100 validation and 300 training graphs.
# NOTE(review): the slices follow the stored order; nothing shuffles the
# dataset before it is split.
test_dataset, val_dataset, train_dataset = (
    dataset[:100],
    dataset[100:200],
    dataset[200:500],
)

# Only the training loader reshuffles its samples.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)

In [20]:
# Peek at the first training batch only: print its second graph and stop.
for data in train_loader:
    print(data[1])
    break

Data(x=[12, 11], edge_index=[2, 22], edge_attr=[22, 4], y=[1], pos=[12, 3], idx=[1], name='gdb_242', z=[12])


In [13]:

class Net(torch.nn.Module):
    """Message-passing network (NNConv + GRU) with a Set2Set readout.

    Produces one scalar per graph. Layer sizes are read at construction time
    from the module-level ``dataset`` (node/edge feature widths) and ``dim``
    (hidden width).
    """

    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(dataset.num_features, dim)

        # Edge network: turns each edge feature vector into the dim * dim
        # weights NNConv applies along that edge. The input width is read
        # from the dataset instead of being hard-coded, so transforms that
        # change the width of edge_attr keep working.
        edge_nn = Sequential(Linear(dataset.num_edge_features, 128), ReLU(),
                             Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, edge_nn, aggr='mean')
        self.gru = GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
        # lin1 takes 2 * dim inputs, matching the width Set2Set produces.
        self.lin1 = torch.nn.Linear(2 * dim, dim)
        self.lin2 = torch.nn.Linear(dim, 1)

    def forward(self, data):
        """Return a 1-D tensor of predictions for the batch ``data``.

        Reads ``data.x``, ``data.edge_index``, ``data.edge_attr`` and
        ``data.batch``.
        """
        out = F.relu(self.lin0(data.x))
        # The GRU expects a leading dimension on its hidden state.
        h = out.unsqueeze(0)

        # Three rounds of message passing; the same conv and GRU weights are
        # reused in every round.
        for _ in range(3):
            m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        out = self.set2set(out, data.batch)
        out = F.relu(self.lin1(out))
        out = self.lin2(out)
        return out.view(-1)


# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Multiplies the learning rate by 0.7 when the monitored value (lower is
# better) has stopped improving, with a patience of 5 and a floor of 1e-5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.7,
    patience=5,
    min_lr=0.00001,
)


In [14]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``device``. ``epoch``
    is accepted but not used. Returns the MSE averaged over all graphs in the
    training set.
    """
    model.train()
    running_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        batch_loss = F.mse_loss(prediction, batch.y)
        batch_loss.backward()
        # Weight the batch mean by its graph count so the final division
        # yields a per-graph average even when the last batch is smaller.
        running_loss += batch_loss.item() * batch.num_graphs
        optimizer.step()

    return running_loss / len(train_loader.dataset)


In [15]:
def test(loader):
    """Return the mean absolute error of ``model`` over ``loader``.

    Predictions and targets are multiplied by the module-level ``std`` before
    the absolute difference is taken, which undoes the scaling applied to the
    targets. Uses the module-level ``model`` and ``device``.
    """
    model.eval()
    error = 0

    # Evaluation never back-propagates, so skip autograd bookkeeping; this
    # avoids building a graph for every forward pass.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            error += (model(data) * std - data.y * std).abs().sum().item()  # MAE
    return error / len(loader.dataset)


In [16]:
# Model selection on the validation MAE: the test split is only re-scored on
# epochs whose validation error matches or beats the best one seen so far.
# Other epochs keep reporting the previously recorded test MAE.
best_val_error = None
for epoch in range(1, 300 + 1):
    # Learning rate in effect for this epoch, read before the scheduler gets
    # a chance to lower it.
    lr = scheduler.optimizer.param_groups[0]['lr']
    loss = train(epoch)
    val_error = test(val_loader)
    scheduler.step(val_error)

    is_new_best = best_val_error is None or val_error <= best_val_error
    if is_new_best:
        test_error = test(test_loader)
        best_val_error = val_error

    summary = (f'Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, '
               f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}')
    print(summary)

Epoch: 001, LR: 0.001000, Loss: 1.1400501, Val MAE: 1.0571542, Test MAE: 1.0789881
Epoch: 002, LR: 0.001000, Loss: 0.8221679, Val MAE: 0.7890811, Test MAE: 0.8157076
Epoch: 003, LR: 0.001000, Loss: 0.7846252, Val MAE: 0.8291530, Test MAE: 0.8157076
Epoch: 004, LR: 0.001000, Loss: 0.7763700, Val MAE: 0.7938999, Test MAE: 0.8157076
Epoch: 005, LR: 0.001000, Loss: 0.7786045, Val MAE: 0.7970200, Test MAE: 0.8157076
Epoch: 006, LR: 0.001000, Loss: 0.7895770, Val MAE: 0.8062895, Test MAE: 0.8157076
Epoch: 007, LR: 0.001000, Loss: 0.7597871, Val MAE: 0.7783310, Test MAE: 0.7056801
Epoch: 008, LR: 0.001000, Loss: 0.7338634, Val MAE: 0.7906348, Test MAE: 0.7056801
Epoch: 009, LR: 0.001000, Loss: 0.7924408, Val MAE: 0.7396377, Test MAE: 0.7114548
Epoch: 010, LR: 0.001000, Loss: 0.7279657, Val MAE: 0.7304237, Test MAE: 0.6824323
Epoch: 011, LR: 0.001000, Loss: 0.7172971, Val MAE: 0.7110445, Test MAE: 0.6356699
Epoch: 012, LR: 0.001000, Loss: 0.7228842, Val MAE: 0.7094059, Test MAE: 0.6305641
Epoc

KeyboardInterrupt: 