# Optimising a GCN using Ax

<a target="_blank" href="https://colab.research.google.com/github/chaitjo/geometric-gnn-dojo/blob/main/geometric_gnn_101.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [10]:
import os
import time
import random
import numpy as np

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential

from torch_geometric.datasets import QM9

from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse

import pprint
from os.path import join
from pathlib import Path
import pandas as pd

import plotly.graph_objects as go

import gc
import torch

from ax.storage.json_store.save import save_experiment
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import KFold

from rdkit import Chem  # To extract information of the molecules
from rdkit.Chem import Draw  # To draw the molecules
import random
from functools import partial

In [4]:
# Seed everything

def seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed(0)

In [5]:
# Create a directory for the QM9 dataset
qm9_dir = "qm9_data"
Path(qm9_dir).mkdir(parents=True, exist_ok=True)

In [11]:
# Write a helper function for data prep

def select_target(data, target_index=0):
    data.y = data.y[:, target_index]
    return data
SetTargetTransform = partial(select_target, target_index=0)
#class SetTargetTransformWrapper(object):
    
class SetTarget(object):
    """
    This transform sets the target attribute of the dataset to the specified index.

    Default index is 0 which is the dipole moment.
    """
    def __init__(self, target_index=0):
        self.target_index = target_index
    def __call__(self, data):
        data.y = data.y[:, self.target_index]
        return data
    
class MakeComplete(object):
    """
    This transform makes the graph complete by adding edges between all pairs of nodes.
    """
    def __call__(self, data):
        device = data.edge_index.device

        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)

        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
        col = col.repeat(data.num_nodes)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = data.num_nodes * data.num_nodes
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data

## Dataset

For this example, we will be using the QM9 dataset, which is a dataset of small organic molecules. The dataset contains 134k molecules, each with 19 features.

| Target | Description                                |
|--------|--------------------------------------------|
| 0      | Dipole moment                              |
| 1      | Isotropic polarizability                   |
| 2      | Highest occupied molecular orbital energy  |
| 3      | Lowest unoccupied molecular orbital energy |
| 4      | Gap between HOMO and LUMO                  |
| 5      | Electronic spatial extent                  |
| 6      | Zero point vibrational energy              |
| 7      | Internal energy at 0K                      |
| 8      | Internal energy at 298.15K                 |
| 9      | Enthalpy at 298.15K                        |
| 10     | Free energy at 298.15K                     |
| 11     | Heat capavity at 298.15K                   |
| 12     | Atomization energy at 0K                   |
| 13     | Atomization energy at 298.15K              |
| 14     | Atomization enthalpy at 298.15K            |
| 15     | Atomization free energy at 298.15K         |
| 16     | Rotational constant                        |
| 17     | Rotational constant                        |
| 18     | Rotational constant                        |

In [12]:
# Load the QM9 dataset

# Target 0 - dipole moment
target = 0
# Transforms applied are to set the target and to fully connect the graphs
transform = T.Compose([SetTarget(target_index=target), MakeComplete()])

# Load the dataset
qm9_dataset =QM9(qm9_dir,transform=transform)

# Normalize the targets
mean = qm9_dataset.data.y.mean(dim=0, keepdim=True)
std = qm9_dataset.data.y.std(dim=0, keepdim=True)
qm9_dataset.data.y = (qm9_dataset.data.y - mean) / std
mean, std = mean[:, target].item(), std[:, target].item()

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting qm9_data\raw\qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|██████████| 133885/133885 [03:02<00:00, 733.94it/s] 
Done!


In [15]:
qm9_dataset[0]

Data(x=[5, 11], edge_index=[2, 20], edge_attr=[20, 4], y=[1], pos=[5, 3], z=[5], smiles='[H]C([H])([H])[H]', name='gdb_1', idx=[1])

Let's visualise some of the dataset

In [18]:
#
num_viz = 5
smiles = [qm9_dataset[i].smiles for i in range(num_viz)]
mols = [Chem.MolFromSmiles(smile) for smile in smiles]
dipoles = [qm9_dataset.y for i in range(num_viz)]
legend = [(smile, dipole) for smile, dipole in zip(smiles, dipoles)]
Draw.MolsToGridImage(mols, molsPerRow=5, useSVG=True,legends=legend)

TypeError: No registered converter was able to produce a C++ rvalue of type class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> > from this Python object of type tuple

In [None]:
# Let's create  the train val test splits. We'll use 2000 samples for this demo and split them into 80% train, 10% val and 10% test.

num_samples = 2000
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

