In [1]:
from ipywidgets import interact, interact_manual
import pandas as pd

elements = pd.read_csv("/io/data/external/elementlist.csv", index_col=0)

# Molecule classification

## The dataset

We study the [`ogbg-molhiv` dataset][1].
Each graph represents a molecule: nodes are atoms and edges are chemical bonds. Input node features are 9-dimensional. They encode the atomic number and chirality, plus additional atom properties such as formal charge and whether the atom is part of a ring.

## The task

We want to predict, as accurately as possible, whether a molecule inhibits HIV replication. This is a binary graph classification task: each molecule carries a single 0/1 label.
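OGB scores `ogbg-molhiv` with ROC-AUC (the dataset's `Evaluator` reports `rocauc`) rather than plain accuracy, which matters because active molecules are rare. Below is a minimal standard-library sketch of the metric using its pairwise definition: the probability that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counted as one half. The name `roc_auc` is our own; the double loop is O(P*N), so a library implementation is preferable for real evaluation.

```python
def roc_auc(y_true, scores):
    """Pairwise ROC-AUC: fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, y_true) if y == 1]
    neg = [s for s, y in zip(scores, y_true) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0  # positive correctly ranked above negative
            elif p == n:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

A constant predictor scores 0.5 regardless of class balance, which is why this metric is more informative than accuracy when almost every label is 0.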

[1]: https://ogb.stanford.edu/docs/graphprop/

## Data Loading

In [2]:
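# IMPORTS
from ogb.graphproppred import PygGraphPropPredDataset  # OGB dataset wrapper for PyTorch Geometric
# Since PyG 2.0, DataLoader lives in torch_geometric.loader
# (the old torch_geometric.data.DataLoader is deprecated)
from torch_geometric.loader import DataLoader  # batches graphs into mini-batches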
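The `DataLoader` imported above batches many small graphs by merging them into one large disconnected graph: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph every node came from. The sketch below reproduces that idea with plain Python lists; `collate` and the dict layout are hypothetical stand-ins for PyG's `Data`/`Batch` objects, not its actual code.

```python
def collate(graphs):
    """Merge graphs (dicts with x, edge_index, y) into one disconnected graph, PyG-style."""
    x, src, dst, batch, y = [], [], [], [], []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, g in enumerate(graphs):
        x.extend(g["x"])
        # Shift node indices so edges keep pointing at this graph's own nodes.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(d + offset for d in g["edge_index"][1])
        batch.extend([graph_id] * len(g["x"]))  # node -> graph assignment
        y.append(g["y"])  # one label per graph
        offset += len(g["x"])
    return {"x": x, "edge_index": [src, dst], "batch": batch, "y": y}


g1 = {"x": [[6], [8]], "edge_index": [[0, 1], [1, 0]], "y": 0}
g2 = {"x": [[6], [6], [8]], "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]], "y": 1}
print(collate([g1, g2])["batch"])  # [0, 0, 1, 1, 1]
```

Keeping the graphs disconnected means message passing never mixes atoms from different molecules, and the `batch` vector is what a graph-level pooling step uses to produce one prediction per molecule.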

In [3]:
# Download and process the data under '/io/ogbg/' (OGB stores it in an 'ogbg_molhiv' subfolder)
dataset = PygGraphPropPredDataset(name="ogbg-molhiv", root="/io/ogbg/")

### Features

The features below are defined in OGB's [`features.py`](https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py).

**Atom** (node) features:
- atomic_num : cat (118 vals)
- chirality : cat (4 vals)
- degree : int (0 to 10)
- formal_charge : int (-5 to 5)
- numH : int (0 to 8)
- number_radical_e : int (0 to 4)
- hybridization : cat (5 vals)
- is_aromatic : bool
- is_in_ring : bool

**Bond** (edge) features:
- bond_type : cat (4 vals)
- bond_stereo : cat (6 vals)
- is_conjugated : bool

In [4]:
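el0 = dataset[0]  # first molecular graph in the dataset

# Column 0 of x is the atomic-number feature, stored by OGB as (atomic_num - 1);
# it works as a positional index into `elements` provided the CSV lists the
# elements in order, starting with hydrogen.
symbols = [elements.iloc[a].symbol for a in el0.x[:, 0].tolist()]

print(f"""
First element : {el0}

Attributes : 
    - Node features : {el0.num_node_features}
    - Edge features : {el0.num_edge_features}

elements in el0.x[0] : {symbols}
""")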


First element : Data(edge_attr=[40, 3], edge_index=[2, 40], x=[19, 9], y=[1, 1])

Attributes : 
    - Node features : 9
    - Edge features : 3

elements in el0.x[0] : ['C', 'C', 'C', 'O', 'Cu', 'O', 'C', 'C', 'C', 'C', 'O', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'O']
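`edge_index` has 40 columns for this 19-atom molecule because PyG edges are directed: each chemical bond is stored once per direction, so the molecule has 20 bonds, and `edge_attr` repeats each bond's 3 features for both directions. A minimal sketch of that conversion over plain lists, using a hypothetical `to_edge_index` helper (as far as we can tell, OGB's `smiles2graph` appends the two directions of each bond in the same interleaved order):

```python
def to_edge_index(bonds):
    """Turn undirected bonds [(u, v), ...] into a 2 x 2E edge_index listing both directions."""
    sources, targets = [], []
    for u, v in bonds:
        sources += [u, v]  # u -> v, then v -> u
        targets += [v, u]
    return [sources, targets]


# Heavy-atom skeleton of ethanol: C0-C1, C1-O2.
print(to_edge_index([(0, 1), (1, 2)]))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```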



## Visualization

We draw each molecule as a graph with the project's `plot_mol` helper (from `src.visualization`) and browse the dataset interactively with `ipywidgets`.
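How `plot_mol` draws a molecule is not shown here. As one possible approach, the hypothetical `to_dot` helper below turns element symbols and an `edge_index` into a Graphviz DOT string, keeping a single undirected edge per bond since PyG stores each bond in both directions. It is a sketch of the idea, not the project's implementation.

```python
def to_dot(symbols, edge_index):
    """Render a molecule as a Graphviz DOT string with one undirected edge per bond."""
    lines = ["graph molecule {"]
    for i, sym in enumerate(symbols):
        lines.append(f'  {i} [label="{sym}"];')  # one node per atom, labeled by element
    seen = set()
    for u, v in zip(*edge_index):
        key = (min(u, v), max(u, v))
        if key not in seen:  # skip the reverse copy of each bond
            seen.add(key)
            lines.append(f"  {key[0]} -- {key[1]};")
    lines.append("}")
    return "\n".join(lines)


print(to_dot(["C", "C", "O"], [[0, 1, 1, 2], [1, 0, 2, 1]]))
```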

In [5]:
# IMPORTS
from src.visualization import plot_mol

In [6]:
# A (min, max) tuple makes `interact` build an integer slider, which stays
# responsive for a dataset this size (a dropdown would hold one option per graph).
# The trailing semicolon hides the function repr that `interact` returns.
interact(lambda x: plot_mol(dataset[x]), x=(0, len(dataset) - 1));