In [134]:
from ipywidgets import interact, interact_manual

# Molecules classification

## The dataset

We study the [`ogbg-molhiv` dataset][1].
Each graph represents a molecule, where nodes are atoms, and edges are chemical bonds. Input node features are 9-dimensional, containing atomic number and chirality, as well as other additional atom features such as formal charge and whether the atom is in the ring or not.

## The task

We want to predict whether a molecule inhibits HIV virus replication or not, as accurately as possible,

[1]: https://ogb.stanford.edu/docs/graphprop/

## Data Loading

In [61]:
# IMPORTS
from ogb.graphproppred import PygGraphPropPredDataset # Dataset package
from torch_geometric.data import DataLoader # Utility to load data

import pandas as pd

In [73]:
# Download and process data at './dataset/ogbg_molhiv/'
dataset = PygGraphPropPredDataset(name = "ogbg-molhiv", root = '/io/ogbg/')

elements = pd.read_csv("/io/data/external/elementlist.csv", index_col=0)

In [70]:
print(el0.x[1,0].item())

5


In [74]:
el0 = dataset[0]
print(f"""
First element : {el0}

Attributes : 
    - Node features : {el0.num_node_features}
    - Edge features : {el0.num_edge_features}

elements in el0.x[0] : {[elements.iloc[el0.x[i,0].item()].symbol for i in range(el0.num_nodes)]}
""")


First element : Data(edge_attr=[40, 3], edge_index=[2, 40], x=[19, 9], y=[1, 1])

Attributes : 
    - Node features : 9
    - Edge features : 3

elements in el0.x[0] : ['C', 'C', 'C', 'O', 'Cu', 'O', 'C', 'C', 'C', 'C', 'O', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'O']



### Features

From the [doc](https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py).

**Atom** (node) features:
- atomic_num : cat (118 vals)
- chirality : cat (4 vals)
- degree : int (0 to 10)
- formal_charge : int (-5 to 5)
- numH : int (0 to 8)
- number_radical_e : int (0 to 4)
- possible hybridization : cat (5 vals)
- is_aromatic : bool
- is_in_ring : bool

**Bond** (edge) features:
- bond_type : cat (4 vals)
- bond_stereo : cat (6 vals)
- is_conjugated : bool

## Visualization

We will use a graph visualization library to represent the molecules.

In [115]:
# IMPORTS
import networkx as nx  # Graph manipulation library
from torch_geometric.utils import to_networkx # Conversion function

import matplotlib.pyplot as plt
import torch

In [138]:
def plot_mol(torch_graph):
    fig, ax = plt.subplots(dpi=120)

    G = to_networkx(
        torch_graph,
        to_undirected=True
    )
    
    pos = nx.kamada_kawai_layout(G)


    atoms = torch_graph.x[:,0]
    single_bonds = torch.where(torch_graph.edge_attr[:,0] == 0)[0]
    double_bonds = torch.where(torch_graph.edge_attr[:,0] == 1)[0]
    triple_bonds = torch.where(torch_graph.edge_attr[:,0] == 2)[0]
    aromatic = torch.where(torch_graph.edge_attr[:,0] == 3)[0]

    nx.draw_networkx(
        G, pos,
        node_color = atoms,  # color coded by atomic number
        cmap="prism",
        node_size=10 * atoms,    # size give by atomic number
        with_labels=False,
        width = torch_graph.edge_attr[:,0] + 1 
    )

    offset = 0.05
    nx.draw_networkx_labels(
        G,
        {
            node : (p[0] - offset, p[1] + 1.5 * offset)
            for node, p in pos.items()
        },
        labels={
            i: elements.iloc[torch_graph.x[i,0].item()].symbol
            for i in range(torch_graph.num_nodes)
        },
    )

    fig.tight_layout()
    plt.show()

In [141]:
interact(lambda x: plot_mol(dataset[x]), x=range(len(dataset)))

interactive(children=(Dropdown(description='x', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,…

<function __main__.<lambda>(x)>