In [1]:
# For all examples in this section we use the following imports.
# Note that we are using torch_geometric's DataLoader.
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, global_add_pool

class ExampleNet(torch.nn.Module):
    """Graph regressor: two NNConv layers, sum pooling, then a small MLP head.

    Each NNConv layer's weights are produced from the edge features by a small
    MLP. The head maps the pooled 16-d graph embedding to one scalar per graph.
    """

    def __init__(self, num_node_features, num_edge_features):
        super().__init__()
        conv1_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, num_node_features * 32))
        conv2_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32 * 16))
        self.conv1 = NNConv(num_node_features, 32, conv1_net)
        self.conv2 = NNConv(32, 16, conv2_net)
        self.fc_1 = nn.Linear(16, 32)
        self.out = nn.Linear(32, 1)

    def forward(self, data):
        """Return a ``(num_graphs, 1)`` prediction for a batched graph ``data``."""
        batch, x, edge_index, edge_attr = (
            data.batch, data.x, data.edge_index, data.edge_attr)
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        # Sum node embeddings into one vector per graph in the batch.
        x = global_add_pool(x, batch)
        x = F.relu(self.fc_1(x))
        return self.out(x)


def _squared_l2_norm(model):
    """Return the sum of squared entries over all parameters of ``model``.

    Uses detached tensors, so no autograd graph is built. The result is a 0-d
    tensor on the parameters' device (or ``0`` if there are no parameters).
    """
    return sum((p.detach() ** 2).sum() for p in model.parameters())


def _rescale_parameters(model, factor):
    """Multiply every parameter of ``model`` in place by ``factor``.

    ``factor`` may be a Python number or a 0-d tensor on the parameters' device.
    Parameter objects keep their identity, so an existing optimizer stays valid.
    """
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(factor)


def _train_one_epoch(net, loader, optimizer, device, target_idx, target_norm_sq):
    """Run one optimisation pass over ``loader``; return the per-graph mean MSE.

    After every optimizer step the parameters are rescaled so their squared L2
    norm equals ``target_norm_sq``, i.e. the weight norm is held constant.

    Raises:
        ValueError: if ``loader`` yields no graphs.
    """
    net.train()
    loss_sum = 0.0
    n_graphs = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = net(batch)
        loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
        loss.backward()
        optimizer.step()
        # mse_loss averages within the batch; weight by batch size so the epoch
        # value is a true per-graph mean, comparable across batch/subset sizes.
        loss_sum += loss.item() * batch.num_graphs
        n_graphs += batch.num_graphs
        _rescale_parameters(net, torch.sqrt(target_norm_sq / _squared_l2_norm(net)))
    if n_graphs == 0:
        raise ValueError("training loader yielded no graphs")
    return loss_sum / n_graphs


def _evaluate(net, loader, device, target_idx):
    """Return the per-graph mean MSE of ``net`` on ``loader`` (no gradients).

    Raises:
        ValueError: if ``loader`` yields no graphs.
    """
    net.eval()
    loss_sum = 0.0
    n_graphs = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = net(batch)
            loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
            loss_sum += loss.item() * batch.num_graphs
            n_graphs += batch.num_graphs
    if n_graphs == 0:
        raise ValueError("evaluation loader yielded no graphs")
    return loss_sum / n_graphs


def train(params):
    """Train ``ExampleNet`` on a QM9 subset at a fixed weight norm; save best losses.

    Args:
        params: two-element sequence. ``params[0]`` is the requested subset
            size, rounded down to an even number and split into equal train
            and test halves. ``params[1]`` is the factor applied to the freshly
            initialised weights. The squared L2 norm after that rescaling is
            held constant throughout training.

    Returns:
        ``(train_best, test_best)``: the lowest per-graph mean train and test
        MSE seen over all epochs. The values are also written to
        ``./results/train_size_<size>_alpha_<scale>`` and
        ``./results/test_size_<size>_alpha_<scale>``.

    Raises:
        ValueError: if the rounded subset size is smaller than 2.
    """
    size = int(params[0] / 2) * 2
    init_scale = params[1]
    if size < 2:
        raise ValueError(
            f"params[0] must be >= 2 to form a train/test split, got {params[0]!r}")

    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Smaller subsets get proportionally more epochs (size * epochs ~ 100k).
    epochs = int(100 * 1000 / size)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load QM9 from the current directory and keep the first `size` molecules.
    dset = QM9('.')[:size]
    train_set, test_set = random_split(dset, [size // 2, size // 2])
    trainloader = DataLoader(train_set, batch_size=32, shuffle=True)
    # shuffle=True does not change the test loss but consumes torch RNG state;
    # it is kept so the training shuffle order matches earlier runs.
    testloader = DataLoader(test_set, batch_size=32, shuffle=True)

    qm9_node_feats, qm9_edge_feats = 11, 4
    net = ExampleNet(qm9_node_feats, qm9_edge_feats)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    target_idx = 1  # index position of the polarizability label
    net.to(device)

    _rescale_parameters(net, float(init_scale))
    # Squared weight norm to hold fixed for the rest of training.
    target_norm_sq = _squared_l2_norm(net)

    # Large finite sentinel (not inf) so saved files stay finite if training diverges.
    train_best = 1e10
    test_best = 1e10

    for total_epochs in range(epochs):
        train_avg_loss = _train_one_epoch(
            net, trainloader, optimizer, device, target_idx, target_norm_sq)
        test_avg_loss = _evaluate(net, testloader, device, target_idx)
        if train_avg_loss < train_best:
            train_best = train_avg_loss
        if test_avg_loss < test_best:
            test_best = test_avg_loss
        print(f"Epochs: {total_epochs} | "
              f"epoch avg. loss: {train_avg_loss:.3f} | "
              f"test avg. loss: {test_avg_loss:.3f}")

    # Create the output directory up front so a long run never fails at the end.
    os.makedirs("./results", exist_ok=True)
    np.savetxt("./results/train_size_%d_alpha_%.4f" % (size, init_scale), np.array([train_best]))
    np.savetxt("./results/test_size_%d_alpha_%.4f" % (size, init_scale), np.array([test_best]))
    return train_best, test_best
     


In [None]:
import numpy as np
# Third-party `multiprocess` is used rather than stdlib `multiprocessing`,
# presumably so functions defined in this notebook can be sent to workers.
from multiprocess import Pool

# Sweep grid: 22 log-spaced subset sizes in [10, 1e4] x 21 log-spaced
# initial weight scales in [0.1, 10].
data_sizes = [int(item) for item in 10**np.linspace(1, 4, num=22)]
alphas = 10**np.linspace(-1, 1, num=21)
N_WORKERS = 10  # number of training runs executed in parallel

# One [size, alpha] row per run, alpha-major (the flattened meshgrid order).
xx, yy = np.meshgrid(data_sizes, alphas)
params = list(np.column_stack([xx.reshape(-1), yy.reshape(-1)]))

if __name__ == '__main__':
    # Every worker calls QM9('.') on the same directory. Instantiating it once
    # here first means the download/processing (if the files are not already
    # present) happens in one process, rather than N_WORKERS processes racing
    # to write the same files.
    QM9('.')
    with Pool(N_WORKERS) as p:
        results = p.map(train, params)
    print(results)