<a href="https://colab.research.google.com/github/WhiteLabGx/home/blob/master/bioml_seminar_unsolved.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Geometric Deep Learning for Molecules
#### Sergei Grudinin, Ilia Igashov, Margot Selosse

This tutorial starts with an introduction to the PyTorch Geometric library. It then gives a basic description of graph-learning architectures, including convolution and attention operations. The first examples cover binary classification of 3D protein structures. Next, we apply the presented architectures to a regression task: predicting properties of small molecules in the QM9 dataset. Finally, we introduce more advanced architectures, specifically constructed to be rotation- and translation-equivariant, for property prediction on 3D molecular graphs.

# Contents

- Prerequisites
- PyTorch Geometric and NetworkX basics
- Graph notations
- Message Passing
- Graph Convolutional Network (GCN)
- Graph Attention Network (GAT)
- Graph Classification on PROTEINS dataset
- Graph Regression with QM9 dataset
- SchNet and Equivariance
- Further reading

# Prerequisites

We will be using PyTorch Geometric and NetworkX – popular frameworks for working with graphs. 



In [None]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
!pip install torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
!pip install ase

!mkdir data


If you are going to run the notebook locally, you may need to install some additional packages (if you do not have them installed yet):

* `numpy`
* `matplotlib`
* `torch`
* `tqdm`
* `networkx`

# PyTorch Geometric and NetworkX basics

We will use the [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) (PyG) framework as the main tool for working with graphs. It provides convenient functionality for operating on:
- [graph structures](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html)
- [graph-learning models](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html)
- [common graph datasets](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html)

In addition, the [NetworkX](https://networkx.org/) (NX) library is useful for graph visualization.

Let's create a simple graph with four nodes and four undirected edges, where each node $v_i$ has an associated feature value, its index $i$. Each undirected edge is stored as two directed entries in `edge_index`, one per direction:

<img align="middle" src="https://www.researchgate.net/profile/Panayiota-Poirazi/publication/293945308/figure/fig1/AS:669375778529289@1536603030400/A-simple-graph-consisting-of-4-nodes-and-4-edges-The-degree-of-each-node-is.ppm" width="200"/>

In [None]:
import torch
import networkx as nx

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

In [None]:
# node features – indices:
x = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)

# edges of the graph:
edge_index = torch.tensor(
    [
        [0, 0, 0, 1, 1, 2, 3, 3],
        [1, 2, 3, 0, 3, 0, 0, 1],
    ], 
    dtype=torch.long
)

# PyG graph:
pg_graph = Data(x=x, edge_index=edge_index)

Let's look at the created PyG graph object:

In [None]:
pg_graph

Useful graph attributes:

In [None]:
print('Number of nodes in the graph:', pg_graph.num_nodes)
print('Number of edges in the graph:', pg_graph.num_edges)
print('Number of node features:', pg_graph.num_node_features)

Note that `edge_index`, i.e. the tensor defining the source and target nodes of all edges, **is not a list of index tuples**. If you want to write your indices this way, you should transpose the tensor and call `contiguous` on it before passing it to the data constructor:

In [None]:
edge_index = torch.tensor(
    [
        [0, 1],
        [0, 2],
        [0, 3],
        [1, 0],
        [1, 3], 
        [2, 0],
        [3, 0],
        [3, 1],
    ], 
    dtype=torch.long
)

edge_index.t().contiguous()

Let's transform the PyG graph to the NX format and draw the resulting graph:

In [None]:
# transform:
nx_graph = to_networkx(pg_graph)

# draw:
nx.draw(
    nx_graph, 
    font_size=10,
    width=0.5, 
    with_labels=True,
    labels={i: f'v{pg_graph.x[i][0]}' for i in range(pg_graph.num_nodes)},
)

Now let's create a second simple graph with two nodes and one edge:

In [None]:
second_pg_graph = Data(
    x=torch.tensor([[10], [11]], dtype=torch.long),
    edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
)

second_nx_graph = to_networkx(second_pg_graph)

nx.draw(
    second_nx_graph, 
    font_size=10,
    width=0.5, 
    with_labels=True,
    labels={i: f'v{second_pg_graph.x[i][0]}' for i in range(second_pg_graph.num_nodes)},
)

Let's put both graphs into a `DataLoader` with `batch_size=2`:

In [None]:
loader = DataLoader([pg_graph, second_pg_graph], batch_size=2)

Let's check how many batches the loader is going to generate:

In [None]:
len(loader)

Let's see what the loader yields:

In [None]:
for batch in loader:
    print(batch)

The batch object is very similar to the `Data` object, but it has an additional attribute `ptr` that defines the ranges of node indices belonging to the different graphs in the batch:

In [None]:
batch.ptr

This means that nodes with indices $0\leq{i}<4$ belong to the first graph in the batch, and nodes with indices $4\leq{i}<6$ belong to the second graph. To check this, let's take a look at the attribute `x` of the batch:

In [None]:
batch.x

So a batch of size $B$ can be thought of as one big graph that contains $B$ disjoint components corresponding to the graphs included in this batch:

In [None]:
nx_batch = to_networkx(batch)

nx.draw(
    nx_batch, 
    font_size=10,
    width=0.5, 
    with_labels=True,
    labels={i: f'v{batch.x[i][0]}' for i in range(batch.num_nodes)},
)
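
To make the batching idea concrete, here is a minimal, dependency-free sketch of how two graphs can be merged into one disconnected graph. The function `collate` and the dictionary layout are hypothetical names chosen for this illustration, not the PyG API; PyG's own collation handles many more cases.

```python
def collate(graphs):
    """Merge graphs into one big disconnected graph.

    Each graph is a dict with 'x' (list of node features) and
    'edges' (list of (source, target) pairs using local node indices).
    """
    x, sources, targets, batch_vector, ptr = [], [], [], [], [0]
    for graph_id, graph in enumerate(graphs):
        offset = ptr[-1]  # number of nodes already placed in the batch
        x.extend(graph['x'])
        # Shift local node indices so they stay unique inside the batch.
        sources.extend(s + offset for s, _ in graph['edges'])
        targets.extend(t + offset for _, t in graph['edges'])
        # Remember which graph every node came from.
        batch_vector.extend([graph_id] * len(graph['x']))
        ptr.append(offset + len(graph['x']))
    return {'x': x, 'edge_index': [sources, targets],
            'batch': batch_vector, 'ptr': ptr}


first = {'x': [0, 1, 2, 3],
         'edges': [(0, 1), (0, 2), (0, 3), (1, 0), (1, 3), (2, 0), (3, 0), (3, 1)]}
second = {'x': [10, 11], 'edges': [(0, 1), (1, 0)]}

merged = collate([first, second])
print(merged['ptr'])    # [0, 4, 6]
print(merged['batch'])  # [0, 0, 0, 0, 1, 1]
```

The edge of the second graph ends up connecting nodes 4 and 5 of the merged graph, so the two components never share an edge.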

# Graph notations

Let's consider an *undirected graph* $G=(V, E)$, where $V$ is the set of nodes and $E$ is the set of edges.
For a graph node $u\in V$, we define its neighborhood as 

$$
N(u)=\{v\in{V}\ |\ (u,v)\in E\}.
$$

The adjacency matrix $\boldsymbol{A}$ of graph $G$ is a square $|V|\times|V|$ symmetric matrix where each entry relates to an edge between the corresponding nodes. In the case of a *weighted* graph, where each edge $(v_i,v_j)\in E$ has weight $w_{ij}\in\mathbb{R}$, the corresponding entry $a_{ij}$ of the adjacency matrix equals this weight:

$$
\forall {i,j}\in\{1,\dots,|V|\}\ \ \ \ a_{ij}=
\begin{cases}
w_{ij},\ &\text{if}\ (v_i,v_j)\in{E},\\
0,\ &\text{if}\ (v_i,v_j)\notin{E}.
\end{cases}
$$


In the case of an *unweighted* graph, the adjacency matrix is binary:

$$
\forall {i,j}\in\{1,\dots,|V|\}\ \ \ \ a_{ij}=
\begin{cases}
1,\ &\text{if}\ (v_i,v_j)\in{E},\\
0,\ &\text{if}\ (v_i,v_j)\notin{E}.
\end{cases}
$$
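
As a small illustration of these definitions, the sketch below builds the adjacency matrix and a neighborhood $N(u)$ for the four-node example graph used earlier. It uses plain Python lists and 0-based node indices; the helper names `adjacency_matrix` and `neighborhood` are made up for this example.

```python
def adjacency_matrix(num_nodes, edges, weights=None):
    """Return a dense |V| x |V| adjacency matrix as nested lists."""
    matrix = [[0] * num_nodes for _ in range(num_nodes)]
    for k, (i, j) in enumerate(edges):
        w = 1 if weights is None else weights[k]
        matrix[i][j] = w
        matrix[j][i] = w  # undirected graph, so the matrix is symmetric
    return matrix


def neighborhood(matrix, u):
    """N(u): all nodes v that share an edge with u."""
    return {v for v, a_uv in enumerate(matrix[u]) if a_uv != 0}


edges = [(0, 1), (0, 2), (0, 3), (1, 3)]
A = adjacency_matrix(4, edges)
for row in A:
    print(row)
# [0, 1, 1, 1]
# [1, 0, 0, 1]
# [1, 0, 0, 0]
# [1, 1, 0, 0]
print(neighborhood(A, 0))  # {1, 2, 3}
```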

# Message Passing

Graph Neural Networks (GNNs) rely on a more generic framework referred to as "Message Passing" that proceeds as follows.

Assume that each node $v_i$ of the input graph $G$ has an associated vector of features $\boldsymbol{z}_i^{0}$ of size $d^0$. Consider $K$ message-passing layers. 

For $k \in \{1,\ldots,K\}$:

* For each node $v_i$ and each of its neighbours $v_j$ (as well as $v_i$ itself), we build a message $\color{red}{\boldsymbol{m}_{ij}^{k}}$ with some differentiable function $\phi$:

$$
\color{red}{\boldsymbol{m}_{ij}^{k} \leftarrow \phi(\boldsymbol{z}_i^{k-1},\boldsymbol{z}_j^{k-1})},\tag{1}
$$


* We aggregate the messages in $\color{green}{\boldsymbol{h}_i^{k}}$, a vector of size $d^{k-1}$, using some differentiable and permutation invariant function AGGR:

$$
\color{green}{\boldsymbol{h}_i^{k} \leftarrow \text{AGGR}(\color{red}{\boldsymbol{m}_{ij}^{k}}, \forall v_j \in N(v_i) \cup \{v_i\})},\tag{2}
$$

* We build a new embedding $\boldsymbol{z}^{k}_i$ for each node $v_i$ with $\boldsymbol{W}^{k}$ a matrix of size $d^{k}\times d^{k-1}$:

$$
\boldsymbol{z}^{k}_i \leftarrow \sigma(\boldsymbol{W}^{k} \color{green}{\boldsymbol{h}_i^{k}}).\tag{3}
$$
    

Finally, we set the embedding of node $i$ as $\boldsymbol{z}_i=\boldsymbol{z}^{K}_i$.

<div>
<img src="https://miro.medium.com/max/1400/1*fPzRm3Flq3dQErn7LEG_Ig.png" width="900"/>
</div>

Note that the user has to choose:

* the number of layers $K$,
* the function $\phi$,
* the AGGR function, which is an aggregation function (e.g. max, sum, mean),
* the dimensions $d^{k}$ for $k \geq 1$,
* $\sigma$, which is a non-linear function (e.g. ReLU).
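
The three steps (1)-(3) can be sketched in plain Python as follows. This is only an illustration under simple assumptions: features are lists of floats, the message function just copies the neighbour's features, and all names (`message_passing_layer`, `copy_neighbor`, `relu`) are hypothetical.

```python
def message_passing_layer(z, neighbors, phi, aggr, weight, sigma):
    """One layer: build messages (1), aggregate (2), update (3).

    z         -- list of node feature vectors (lists of floats)
    neighbors -- neighbors[i] is the list of node indices in N(v_i)
    weight    -- matrix W as a list of rows, shape d_out x d_in
    """
    new_z = []
    for i, z_i in enumerate(z):
        # (1) one message per node in N(v_i) plus v_i itself
        messages = [phi(z_i, z[j]) for j in neighbors[i] + [i]]
        # (2) permutation-invariant aggregation, applied per coordinate
        h_i = [aggr(coords) for coords in zip(*messages)]
        # (3) linear map followed by the non-linearity
        new_z.append([sigma(sum(w * h for w, h in zip(row, h_i))) for row in weight])
    return new_z


def relu(value):
    return max(0.0, value)


def copy_neighbor(z_i, z_j):
    # The message is simply the neighbour's feature vector.
    return z_j


z0 = [[1.0], [2.0], [3.0], [4.0]]
neighbors = [[1, 2, 3], [0, 3], [0], [0, 1]]
W = [[0.5]]

z1 = message_passing_layer(z0, neighbors, copy_neighbor, sum, W, relu)
print(z1)  # [[5.0], [3.5], [2.0], [3.5]]
```

Swapping `sum` for `max` changes only the AGGR step; listing the neighbours in a different order leaves the result unchanged, which is what permutation invariance means here.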

## Graph Convolutional Network

* [original paper](https://arxiv.org/abs/1609.02907)
* [original code](https://github.com/tkipf/gcn)

The Graph Convolutional Network (GCN) is a graph neural network that implements the Message Passing framework such that:

$$
\boldsymbol{z}'_i = \sigma\left[
\color{green}{\sum_{v_j\in{N}(v_i)\cup\{v_i\}}}
\color{red}{
\frac{1}{\sqrt{\text{deg}(v_i)}\sqrt{\text{deg}(v_j)}} 
}
\boldsymbol{\Theta}\boldsymbol{z}_j\right]
,\tag{4}
$$

where $\boldsymbol{\Theta}$ is a weight matrix, and $\text{deg}(v)$ is the degree of the node $v$:

$$
\text{deg}(v) = \sum_{u\in{V}}\mathbb{I}\{(u,v)\in{E}\}.\tag{5}
$$

**Note:** for simplicity, we dropped indices $k$ corresponding to the layer's number. Instead, we use prime `'` as an indication of the updated embeddings. We assume that each node $v_i$ has an embedding vector $\boldsymbol{z}_i\in\mathbb{R}^d$ before applying the convolution layer, and gets an updated embedding vector $\boldsymbol{z}'_i\in\mathbb{R}^{d'}$ after applying the convolution layer.
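
To see how the normalization in formula (4) behaves, here is a small dense sketch in plain Python with ReLU as $\sigma$. It assumes that degrees are counted after self-loops are added, i.e. $\text{deg}(v)=|N(v)|+1$, as in the original GCN paper; the function name `gcn_layer` and the toy inputs are illustrative only.

```python
import math


def gcn_layer(z, neighbors, theta):
    """Apply formula (4) with ReLU as sigma.

    Assumption: deg(v) counts the self-loop, i.e. deg(v) = |N(v)| + 1.
    """
    deg = [len(n) + 1 for n in neighbors]
    # Theta z_j for every node, computed once
    transformed = [[sum(t * f for t, f in zip(row, z_j)) for row in theta]
                   for z_j in z]
    out = []
    for i in range(len(z)):
        acc = [0.0] * len(theta)
        for j in neighbors[i] + [i]:
            norm = 1.0 / (math.sqrt(deg[i]) * math.sqrt(deg[j]))
            acc = [a + norm * t for a, t in zip(acc, transformed[j])]
        out.append([max(0.0, a) for a in acc])
    return out


neighbors = [[1, 2, 3], [0, 3], [0], [0, 1]]
z = [[1.0], [1.0], [1.0], [1.0]]
theta = [[1.0]]  # identity, so the output shows the normalization alone
result = gcn_layer(z, neighbors, theta)
print(result)
```

With all-ones features and an identity $\boldsymbol{\Theta}$, each output is just the sum of the normalization coefficients around that node, so high-degree neighbours contribute less per edge.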

## Graph Attention Network

* [original paper](https://arxiv.org/abs/1710.10903)
* [original code](https://github.com/PetarV-/GAT)


The Graph Attention Network (GAT) is a graph neural network that implements the Message Passing framework such that:

$$
\boldsymbol{z}'_i = \sigma\left[
\color{green}{\sum_{v_j\in{N}(v_i)\cup\{v_i\}}}
\color{red}{
\alpha_{ij} 
}
\boldsymbol{\Theta}\boldsymbol{z}_j\right]
,\tag{6}
$$

where attention coefficients $\alpha_{ij}$ are computed as follows,

$$
\alpha_{ij}=\frac{
\exp\big(\text{LeakyReLU}\big(\boldsymbol{a}^{\text{T}}[\boldsymbol{\Theta}\boldsymbol{z}_i||\boldsymbol{\Theta}\boldsymbol{z}_j]\big)\big)
}{
\sum_{v_m\in{N(v_i)}\cup\{v_i\}}\exp\big(\text{LeakyReLU}\big(\boldsymbol{a}^{\text{T}}[\boldsymbol{\Theta}\boldsymbol{z}_i||\boldsymbol{\Theta}\boldsymbol{z}_m]\big)\big)
}.\tag{7}
$$

Here, $||$ represents concatenation, and $\boldsymbol{a}\in\mathbb{R}^{2d'}$ is a vector of learnable parameters.

**Note:** for simplicity, we dropped indices $k$ corresponding to the layer's number. Instead, we use prime `'` as an indication of the updated embeddings. We assume that each node $v_i$ has an embedding vector $\boldsymbol{z}_i\in\mathbb{R}^d$ before applying the convolution layer, and gets an updated embedding vector $\boldsymbol{z}'_i\in\mathbb{R}^{d'}$ after applying the convolution layer.

<div>
<img src="https://miro.medium.com/max/1036/1*3D844_twutCaunYMPuo-Sw.png" width="400"/>
</div>
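
The attention coefficients of formula (7) are a softmax over attention scores. The sketch below computes them for one node in plain Python. It normalizes over $N(v_i)\cup\{v_i\}$, the same set that formula (6) sums over, so that the coefficients add up to one. The LeakyReLU slope of 0.2 and all names are assumptions of this example.

```python
import math


def leaky_relu(value, slope=0.2):
    return value if value > 0 else slope * value


def attention_coefficients(i, neighbors, transformed, a):
    """Return {j: alpha_ij} for j in N(v_i) and v_i itself, following formula (7).

    transformed -- list of vectors Theta z_j
    a           -- attention vector of length 2 * d'
    """
    scores = {}
    for j in neighbors[i] + [i]:
        concat = transformed[i] + transformed[j]  # [Theta z_i || Theta z_j]
        scores[j] = leaky_relu(sum(a_k * c_k for a_k, c_k in zip(a, concat)))
    # softmax over the neighborhood; subtracting the max keeps exp() stable
    top = max(scores.values())
    exp_scores = {j: math.exp(s - top) for j, s in scores.items()}
    total = sum(exp_scores.values())
    return {j: e / total for j, e in exp_scores.items()}


neighbors = [[1, 2, 3], [0, 3], [0], [0, 1]]
transformed = [[0.0], [1.0], [2.0], [3.0]]
a = [1.0, 1.0]

alpha = attention_coefficients(0, neighbors, transformed, a)
print(alpha)
print(sum(alpha.values()))
```

In this toy setup the score grows with the neighbour's feature value, so node 3 receives the largest coefficient and node 0 (the self-loop) the smallest.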

To stabilize the learning process of self-attention, we use *multi-head attention*: $L$ independent attention mechanisms, or “heads”, each compute output features with their own parameters $\boldsymbol{\Theta}^{(l)}$ and $\boldsymbol{a}^{(l)}$. We then average these output feature representations:

$$
\boldsymbol{z}'_i = \sigma\left[
\frac{1}{L}\sum_{l=1}^{L}\sum_{v_j\in{N}(v_i)\cup\{v_i\}}
\alpha_{ij}^{(l)} 
\boldsymbol{\Theta}^{(l)}\boldsymbol{z}_j\right].\tag{8}
$$
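
As a minimal sketch of combining heads by averaging (which is what the `message()` template in Task 2.1 does with `.mean(dim=1)`), the function below averages per-head node embeddings before the non-linearity. The name `average_heads` and the toy numbers are illustrative. The original GAT paper also describes concatenating head outputs in hidden layers, which is not shown here.

```python
def average_heads(head_outputs):
    """Average per-head node embeddings.

    head_outputs[l][i] is the embedding of node i produced by head l.
    """
    num_heads = len(head_outputs)
    num_nodes = len(head_outputs[0])
    averaged = []
    for i in range(num_nodes):
        per_head = [head_outputs[l][i] for l in range(num_heads)]
        # average each coordinate across heads
        averaged.append([sum(coords) / num_heads for coords in zip(*per_head)])
    return averaged


heads = [
    [[1.0, 0.0], [2.0, 2.0]],  # head 1: embeddings of two nodes
    [[3.0, 4.0], [0.0, 2.0]],  # head 2
]
print(average_heads(heads))  # [[2.0, 2.0], [1.0, 2.0]]
```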

# Graph classification: PROTEINS dataset

PROTEINS is a dataset of proteins that are classified as enzymes or non-enzymes. Nodes represent the amino acids and two nodes are connected by an edge if they are less than 6Å apart.

This dataset can be obtained from PyTorch Geometric: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html

In [None]:
from torch_geometric.datasets import TUDataset

In [None]:
dataset = TUDataset(root='data/PROTEINS', name='PROTEINS').shuffle()

In [None]:
print(f'Number of graphs: {len(dataset)}')
print(f'Number of classes: {dataset.num_classes}')
print(f'Number of node features: {dataset.num_features}')

Let's visualize several graphs using NetworkX:

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

from torch_geometric.utils import to_networkx

COLORS = [
    '#C3EFFC',
    '#FCC4C3',
    '#FCF9C3',
    '#E1FCC3',
    '#C1EFBF',
    '#BFC9EF',
    '#CCBFEF',
    '#EBBFEF',
    '#CCB4C4',
    '#EEEEEE',
]

def draw_colored_graph(nx_graph, colors, labels, ax=None):
    if ax is None:
        plt.figure(figsize=(20, 12))
    nx.draw(
        nx_graph, 
        node_color=colors,  
        font_size=10, 
        width=0.2, 
        with_labels=True,
        labels=labels,
        ax=ax
    )

In [None]:
n_examples = 5

fig, ax = plt.subplots(nrows=1, ncols=n_examples, figsize=(5*n_examples, 5))

for i, rand_ix in enumerate(np.random.choice(dataset.indices(), n_examples)):
    curr_ax = ax[i]
    curr_ax.set_title(f'Graph with id={rand_ix}')
    
    pg_graph = dataset[rand_ix]
    nx_graph = to_networkx(pg_graph, to_undirected=True)
    colors = [COLORS[np.argmax(features)] for features in pg_graph.x]
    
    draw_colored_graph(nx_graph, colors, labels={}, ax=curr_ax)

## 1. Graph classification with GCN

First, let's create data loaders for training, validation and testing. For that, we will use [PyG DataLoader](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader) that combines input graphs into batches. One batch is represented as a single graph with multiple disconnected components.

In [None]:
from torch_geometric.loader import DataLoader

batch_size = 4
loader = DataLoader(dataset, batch_size=batch_size)

In [None]:
# Let's pick the first batch and visualize it.
# We will see that it is a simple graph with multiple disconnected components.
# Each component corresponds to a graph in the initial dataset.
# Here for illustration we label and color nodes according to these components

batch = next(iter(loader))
nx_graph = to_networkx(batch, to_undirected=True)

mask = np.concatenate([
    np.ones(batch.ptr[i+1] - batch.ptr[i], dtype=int) * i
    for i in range(batch_size) 
])
colors = [COLORS[graph_idx] for graph_idx in mask]
labels = dict(zip(range(len(mask)), mask))


draw_colored_graph(nx_graph, colors=colors, labels=labels)

In [None]:
batch_size = 32
data_size = len(dataset)

train_loader = DataLoader(dataset[:int(data_size * 0.8)], batch_size=batch_size)
val_loader = DataLoader(dataset[int(data_size * 0.8):int(data_size * 0.9)], batch_size=batch_size)
test_loader = DataLoader(dataset[int(data_size * 0.9):], batch_size=batch_size)

### Task 1.1 – Implement GCN Layer

PyTorch Geometric provides the [Message Passing interface](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) that contains three main functions:
* `message()`, which defines how the message $\color{red}{\boldsymbol{m}_{ij}^{k}}$ is built,
* `aggregate()`, which defines how the messages are aggregated into $\color{green}{\boldsymbol{h}_{i}^{k}}$,
* `propagate()`, which calls the `message()` and `aggregate()` functions.

We will implement the GCN layer according to formula (4) on top of this interface.
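
To make the roles of these functions concrete, here is a dependency-free sketch of what a `propagate()` call conceptually does with `aggr='add'`: gather features along edges, build one message per edge, and add each message to its target node. The function below is a simplified stand-in written for this illustration and is not PyG's actual implementation; it assumes messages have the same length as the node features.

```python
def propagate(edge_index, x, message, num_nodes):
    """Gather source features, build messages, scatter-add them to targets."""
    sources, targets = edge_index
    out = [[0.0] * len(x[0]) for _ in range(num_nodes)]
    for j, i in zip(sources, targets):
        m = message(x[i], x[j])  # x_i: target features, x_j: source features
        out[i] = [o + v for o, v in zip(out[i], m)]  # aggr='add'
    return out


edge_index = [
    [0, 0, 0, 1, 1, 2, 3, 3],  # sources j
    [1, 2, 3, 0, 3, 0, 0, 1],  # targets i
]
x = [[0.0], [1.0], [2.0], [3.0]]

result = propagate(edge_index, x, lambda x_i, x_j: x_j, num_nodes=4)
print(result)  # [[6.0], [3.0], [0.0], [1.0]]
```

Node 0 receives the features of nodes 1, 2 and 3, hence the value 6. No self-loops are present in this `edge_index`, so a node's own feature is not included unless self-loops are added first.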

In [None]:
import torch

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [None]:
class GCN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__(aggr='add')
         
        # Create a learnable linear parameter Theta used in formula (4)
        # YOUR CODE HERE
        self.lin = ...

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Linearly transform node feature matrix.
        # YOUR CODE HERE:
        x = ...

        # Compute normalization.
        # Hint: function `degree` from torch_geometric.utils can be useful here
        # YOUR CODE HERE (~5 lines):
        norm = ...

        # Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Normalize node features.
        return norm.view(-1, 1) * x_j

In [None]:
# Check:
out_channels = 10
gcn_layer = GCN(in_channels=dataset.num_node_features, out_channels=out_channels)

rand_idx = np.random.randint(0, len(dataset))
graph = dataset[rand_idx]
output = gcn_layer(graph.x, graph.edge_index)

if output.shape == torch.Size([graph.num_nodes, out_channels]): 
    print('Good job!')
else:
    print('Error: layer should output a 2-dimensional tensor of shape (N, out_channels)')

### Task 1.2 – Implement GNN with GCN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn


class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0):
        super(GNN, self).__init__()
        
        self.dropout = dropout

        # Create a sequence of GCN layers with non-linearities (ReLU)
        # Hint: consider pyg_nn.Sequential, an extension of the torch.nn.Sequential
        # YOUR CODE HERE:
        self.convs = ...

        # Create post-message-passing linear transformations and aggregations
        # Hint: try a couple of linear layers with non-linearities (ReLU) and dropout
        # YOUR CODE HERE:
        self.post_mp = ...

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Apply GCN layers with non-linearities (ReLU)
        # YOUR CODE HERE:
        x = ...
        
        # Aggregate node embeddings tensor of shape (N, hidden_dim) 
        # to get the graph embedding tensor of shape (hidden_dim).
        # Hint 1: pyg_nn.global_max_pool can be useful
        # Hint 2: keep in mind that the data is batched
        # YOUR CODE HERE:
        x = ...
        
        # Apply post-message-passing transformations
        # YOUR CODE HERE:
        x = ...

        return x

In [None]:
# Check:
hidden_dim = 10
gnn = GNN(
    input_dim=dataset.num_node_features, 
    hidden_dim=hidden_dim, 
    output_dim=dataset.num_classes, 
)

batch = next(iter(loader))
output = gnn(batch)

if output.shape == torch.Size([loader.batch_size, dataset.num_classes]): 
    print('Good job!')
else:
    print('Error: GNN should output a 2-dimensional tensor of shape (batch_size, num_classes)')
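
The check above expects one output row per graph, not per node, which is why the node embeddings have to be pooled graph by graph. The sketch below shows the idea of a per-graph max pooling driven by the `batch` vector, in plain Python. The name `max_pool_per_graph` is hypothetical and the code only mirrors the concept behind the pooling hint, not the PyG implementation.

```python
def max_pool_per_graph(x, batch):
    """Coordinate-wise maximum over the nodes of each graph in a batch.

    x     -- list of node embeddings, shape N x hidden_dim
    batch -- batch[n] is the index of the graph that node n belongs to
    """
    num_graphs = max(batch) + 1
    pooled = [None] * num_graphs
    for features, graph_id in zip(x, batch):
        if pooled[graph_id] is None:
            pooled[graph_id] = list(features)  # first node of this graph
        else:
            pooled[graph_id] = [max(p, f) for p, f in zip(pooled[graph_id], features)]
    return pooled


x = [[1.0, 5.0], [3.0, 2.0], [0.0, 0.0], [4.0, 1.0], [2.0, 7.0], [9.0, 0.0]]
batch = [0, 0, 0, 0, 1, 1]

print(max_pool_per_graph(x, batch))  # [[4.0, 5.0], [9.0, 7.0]]
```

Six node embeddings are reduced to two graph embeddings, so the result has shape (number of graphs, hidden_dim).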

### Task 1.3 – Train GNN with GCN to predict classes of protein graphs

In [None]:
from tqdm import tqdm

def cross_entropy_loss(x, labels):
    return F.cross_entropy(x, labels)


def train(model, optimizer, train_loader, val_loader, epochs):
    train_loss = []
    val_accuracy = []

    for epoch in tqdm(range(epochs)):
        batch_train_loss = []
        batch_val_accuracy = []

        model.train()
        for batch in train_loader:
            # Get logits from the model
            # YOUR CODE HERE:
            logits = ...
            
            # Get ground-truth labels
            # YOUR CODE HERE:
            labels = ...

            # Calculate loss
            # YOUR CODE HERE:
            loss = ...
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            batch_train_loss.append(loss.item())

        train_loss.append(np.mean(batch_train_loss))

        model.eval()
        for batch in val_loader:
            # Get predictions from the model
            # YOUR CODE HERE:
            pred = ...

            # Get ground-truth labels
            # YOUR CODE HERE:
            labels = ...
            
            batch_val_accuracy.append(np.mean((labels == pred).numpy()))

        val_accuracy.append(np.mean(batch_val_accuracy))
        
    return model, train_loss, val_accuracy


def plot_progress(train_loss, val_accuracy):
    fig, (train_ax, val_ax) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

    train_ax.plot(train_loss)
    train_ax.set_title('Train loss')
    train_ax.set_xlabel('Epoch')

    val_ax.plot(val_accuracy)
    val_ax.set_title('Val accuracy')
    val_ax.set_xlabel('Epoch')

    plt.show()
    
def evaluate(model, loader):
    model.eval()

    predictions = np.array([])
    labels = np.array([])

    for batch in loader:

        # Get predicted labels
        # YOUR CODE HERE:
        pred = ...

        # Get ground-truth labels
        # YOUR CODE HERE:
        true = ...

        predictions = np.append(predictions, pred)
        labels = np.append(labels, true)

    return np.mean(predictions == labels)
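
For reference, the two quantities tracked above can be worked out by hand on a toy batch. The sketch below computes the per-example cross-entropy from raw logits and the accuracy of the arg-max predictions in plain Python. The logits and labels are made-up numbers; note that `F.cross_entropy` averages the per-example values over the batch by default.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class for one example."""
    top = max(logits)  # subtract the max for numerical stability
    log_sum_exp = top + math.log(sum(math.exp(v - top) for v in logits))
    return log_sum_exp - logits[label]


def accuracy(batch_logits, labels):
    """Share of examples whose arg-max class equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in batch_logits]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


batch_logits = [[2.0, 0.0], [0.0, 0.0], [-1.0, 1.0]]
labels = [0, 1, 0]

losses = [cross_entropy(row, y) for row, y in zip(batch_logits, labels)]
print(losses)
print(sum(losses) / len(losses))       # mean loss over the batch
print(accuracy(batch_logits, labels))
```

A confident correct prediction (first example) gives a small loss, an undecided one gives $\log 2$, and a confident wrong prediction (third example) gives the largest loss.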

In [None]:
import torch.optim as optim

# Create GNN
# YOUR CODE HERE:
model = ...
optimizer = ...

Training:

In [None]:
epochs = 50
model, train_loss, val_accuracy = train(model, optimizer, train_loader, val_loader, epochs)
plot_progress(train_loss, val_accuracy)

Evaluation:

In [None]:
accuracy = evaluate(model, test_loader)

print('Accuracy:', accuracy)
if accuracy >= 0.7:
    print('Good job!')
else:
    print('Try better!')

## 2. Graph classification with GAT

Let's now implement a Graph Attention layer and perform the same graph classification procedure on the PROTEINS dataset with a new GNN that contains GAT layers.

### Task 2.1 – Implement GAT Layer

We will implement a multi-head GAT layer according to formula (8) on top of the PyG Message-Passing interface.

In [None]:
import torch_geometric.utils as pyg_utils

class GAT(MessagePassing):
    def __init__(self, in_channels, out_channels, num_heads):
        super(GAT, self).__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        
        # Create a learnable linear parameter Theta used in formula (6)
        # Hint: keep in mind that we have multiple independent heads
        # YOUR CODE HERE:
        self.lin = ...
        
        # Create a learnable attention vector that is used in formula (7)
        # YOUR CODE HERE:
        self.att = ...
        
        # Initialization of the attention vector
        nn.init.xavier_uniform_(self.att)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        # Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        # Apply linear transformation to the node feature matrix.
        # YOUR CODE HERE:
        x = ...

        # Start propagating messages.
        return self.propagate(edge_index=edge_index, x=x)

    def message(self, x_i, x_j, index):
        # Constructs messages to node i for each edge (j, i).
        # x_i – feature vectors of target nodes corresponding to i-th index, i.e. x[edge_index[1]]
        # x_j – feature vectors of source nodes corresponding to j-th index, i.e. x[edge_index[0]]
        # index – target (i-th) nodes indices, i.e. edge_index[1]

        # Compute the attention coefficients alpha as described in equation (7).
        # Keep track of the extra head dimension when reshaping tensors!
        # Hint: function pyg_utils.softmax can be useful here
        # YOUR CODE HERE (~6 lines):
        alpha = ...
        
        return (alpha * x_j).mean(dim=1)

In [None]:
# Check:
out_channels = 10
num_heads = 3
gat_layer = GAT(in_channels=dataset.num_node_features, out_channels=out_channels, num_heads=num_heads)

rand_idx = np.random.randint(0, len(dataset))
graph = dataset[rand_idx]
output = gat_layer(graph.x, graph.edge_index)

if output.shape == torch.Size([graph.num_nodes, out_channels]): 
    print('Good job!')
else:
    print('Error: layer should output a 2-dimensional tensor of shape (N, out_channels)')

### Task 2.2 – Implement GNN with GAT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn


class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, num_layers, dropout=0.0):
        super(GNN, self).__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout

        # Create a sequence of GAT layers
        # YOUR CODE HERE:
        self.convs = ...

        # Create post-message-passing linear transformations and aggregations
        # Hint: try a couple of linear layers with non-linearities (ReLU) and dropout
        # YOUR CODE HERE:
        self.post_mp = ...

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Apply GAT layers with non-linearities (ReLU)
        # YOUR CODE HERE:
        x = ...
        
        # Aggregate node embeddings tensor of shape (N, hidden_dim) 
        # to get the graph embedding tensor of shape (hidden_dim).
        # Hint 1: pyg_nn.global_max_pool can be useful
        # Hint 2: keep in mind that the data is batched
        # YOUR CODE HERE:
        x = ...
        
        # Apply post-message-passing transformations
        # YOUR CODE HERE:
        x = ...

        return x

In [None]:
# Check:
hidden_dim = 10
num_layers = 3
num_heads = 3

gnn = GNN(
    input_dim=dataset.num_node_features, 
    hidden_dim=hidden_dim, 
    output_dim=dataset.num_classes, 
    num_heads=num_heads, 
    num_layers=num_layers
)

batch = next(iter(loader))
output = gnn(batch)

if output.shape == torch.Size([loader.batch_size, dataset.num_classes]): 
    print('Good job!')
else:
    print('Error: GNN should output a 2-dimensional tensor of shape (batch_size, num_classes)')

### Task 2.3 – Train GNN with GAT to predict classes of protein graphs

In [None]:
import torch.optim as optim

# Create GNN
# YOUR CODE HERE:
model = ...
optimizer = ...

Training:

In [None]:
epochs = 50
model, train_loss, val_accuracy = train(model, optimizer, train_loader, val_loader, epochs)
plot_progress(train_loss, val_accuracy)

Evaluation:

In [None]:
accuracy = evaluate(model, test_loader)

print('Accuracy:', accuracy)
if accuracy >= 0.7:
    print('Good job!')
else:
    print('Try better!')

# Graph regression: QM9 dataset

QM9 is a [molecular dataset](https://www.nature.com/articles/sdata201422) that has become a standard benchmark for chemical property prediction in machine learning. It consists of small molecules (up to 29 atoms per molecule). Each atom has positional coordinates in 3D space, a one-hot encoding vector that defines the type of atom (H, C, N, O, F) and an integer value with the atom charge. For each molecule, the authors of the dataset provide computed minimum-energy geometries, the corresponding harmonic frequencies, dipole moments, and polarizabilities, along with energies, enthalpies, and free energies of atomization. Any of these values can be considered as a target in the graph regression problem. In this seminar, we will predict one of them, the energy of the highest occupied molecular orbital $\epsilon_{\text{HOMO}}$.

Here are examples of some molecules from QM9, rendered with [PyMOL](https://pymol.org/2/):

<img src="https://i.ibb.co/qDv7Xy0/qm9-examples-v2.jpg" alt="qm9-examples-v2" border="0">

In [None]:
from torch_geometric.datasets import QM9

In [None]:
# In this seminar we will consider only a small part of this dataset
# (in total it contains ~134k molecules)
dataset = QM9(root='data/QM9')[:10000]

In [None]:
print(f'Number of graphs: {len(dataset)}')
print(f'Number of regression targets: {dataset.num_classes}')
print(f'Number of node features: {dataset.num_features}')

In [None]:
n_examples = 5
atom_types = ['H', 'C', 'N', 'O', 'F']

fig, ax = plt.subplots(nrows=1, ncols=n_examples, figsize=(5*n_examples, 5))

for i, rand_ix in enumerate(np.random.choice(dataset.indices(), n_examples)):
    curr_ax = ax[i]
    curr_ax.set_title(f'Graph with id={rand_ix}')
    
    pg_graph = dataset[rand_ix]
    nx_graph = to_networkx(pg_graph, to_undirected=True)
    # The first five node features are the one-hot atom type (H, C, N, O, F)
    colors = [COLORS[np.argmax(features[:5])] for features in pg_graph.x]
    labels = {
        node_ix: atom_types[np.argmax(features[:5])]
        for node_ix, features in enumerate(pg_graph.x)
    }
    
    draw_colored_graph(nx_graph, colors, labels=labels, ax=curr_ax)

In [None]:
from torch_geometric.loader import DataLoader

batch_size = 32
data_size = len(dataset)

train_loader = DataLoader(dataset[:int(data_size * 0.8)], batch_size=batch_size)
val_loader = DataLoader(dataset[int(data_size * 0.8):int(data_size * 0.9)], batch_size=batch_size)
test_loader = DataLoader(dataset[int(data_size * 0.9):], batch_size=batch_size)

### Task 3.1 – Implement GNN for graph regression

Architecture is up to you. Feel free to use GAT or GCN as well as [any other](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers) layers.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn

class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, num_heads=1, dropout=0.0):
        super(GNN, self).__init__()
        # YOUR CODE HERE
        raise NotImplementedError

    def forward(self, data):
        # YOUR CODE HERE
        raise NotImplementedError

In [None]:
# Check:
hidden_dim = 10
num_layers = 3

gnn = GNN(
    input_dim=dataset.num_node_features, 
    hidden_dim=hidden_dim, 
    output_dim=dataset.num_classes, 
    num_layers=num_layers
)

graph = next(iter(train_loader))

output = gnn(graph)

if output.shape == torch.Size([train_loader.batch_size, dataset.num_classes]): 
    print('Good job!')
else:
    print('Error: GNN should output a 2-dimensional tensor of shape (batch_size, num_classes)')

### Task 3.2 – Train GNN

Keep in mind that we are now going to predict a real value instead of a class, so we need to change the loss function and the way we evaluate the model.
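
As a reminder of what the new metric measures, the mean squared error averages the squared differences between predictions and targets. Here is a tiny worked example in plain Python; the numbers are made up and are not QM9 values.

```python
def mean_squared_error(predictions, targets):
    """Average of squared differences between predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


predictions = [1.0, 2.0, 4.0]
targets = [1.0, 3.0, 2.0]

# squared errors: 0, 1 and 4, so the mean is 5 / 3
print(mean_squared_error(predictions, targets))
```

Unlike accuracy, lower is better here, and large individual errors dominate the average because they are squared.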

In [None]:
import torch.nn.functional as F

from tqdm import tqdm

def mse_loss(predictions, targets):
    # YOUR CODE HERE:
    return ...


# Wrapper for different architectures
def model_forward(model, batch):
    if model.__class__.__name__ == 'GNN':
        return model(batch)
    if model.__class__.__name__ == 'SchNet':
        return model(batch.z, batch.pos, batch.batch)
    raise Exception('Unknown model')


def train_graph_regression(model, optimizer, train_loader, val_loader, target_ix, epochs):
    train_mse = []
    val_mse = []

    for epoch in range(epochs):
        batch_train_mse = []
        batch_val_mse = []

        model.train()
        for batch in tqdm(train_loader, desc=f'Epoch {epoch} train'):
            # Get model predictions
            # YOUR CODE HERE:
            predictions = ...
            
            # Get ground-truth values
            # YOUR CODE HERE:
            targets = ...
            
            # Calculate loss
            # YOUR CODE HERE:
            loss = ...
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            batch_train_mse.append(loss.item())  # .item() also works if the loss lives on a GPU

        train_mse.append(np.mean(batch_train_mse))

        model.eval()
        with torch.no_grad():  # no gradients are needed for validation
            for batch in tqdm(val_loader, desc=f'Epoch {epoch} valid'):
                predictions = model_forward(model, batch).squeeze()
                targets = batch.y[:, target_ix].squeeze()
                batch_val_mse.append(mse_loss(predictions, targets).item())

        val_mse.append(np.mean(batch_val_mse))
        
    return model, train_mse, val_mse


def plot_progress(train_mse, val_mse):
    fig, (train_ax, val_ax) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

    train_ax.plot(train_mse)
    train_ax.set_title('Train MSE')
    train_ax.set_xlabel('Epoch')

    val_ax.plot(val_mse)
    val_ax.set_title('Val MSE')
    val_ax.set_xlabel('Epoch')

    plt.show()
    
def evaluate_graph_regression(model, loader, target_ix):
    loss = []
    model.eval()
    for batch in tqdm(loader):

        # Get predicted values
        # YOUR CODE HERE:
        pred = ...

        # Get ground-truth values
        # YOUR CODE HERE:
        true = ...

        loss.append(mse_loss(pred, true).item())

    return np.mean(loss)

In [None]:
import torch.optim as optim

# Create GNN
# YOUR CODE HERE:
model = ...
optimizer = ...

In [None]:
# In this seminar we will predict the HOMO energy.
# It is the 3rd entry (index 2) in the list of targets of the PyG QM9 dataset.
# For details, see https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.QM9
target_homo_ix = 2

epochs = 10
model, train_mse, val_mse = train_graph_regression(
    model, 
    optimizer, 
    train_loader, 
    val_loader, 
    target_homo_ix, 
    epochs
)
plot_progress(train_mse, val_mse)

In [None]:
loss = evaluate_graph_regression(model, test_loader, target_homo_ix)

print('MSE:', loss)
if loss <= 0.2:
    print('Good job!')
else:
    print('Try better!')

# SchNet and Equivariance

* [original paper](https://arxiv.org/abs/1706.08566)
* [original code](https://github.com/atomistic-machine-learning/SchNet)

SchNet is one of the first graph convolutional models whose authors attempted to take the geometry of the underlying data into account.

For molecular graphs, each node $v_i$ of graph $G$ is an atom characterized by its feature vector $\boldsymbol{z}_i\in\mathbb{R}^d$ and its position $\boldsymbol{r}_i\in\mathbb{R}^3$. In general, the graph itself may not encode the spatial relations between atoms. To take them into account, rotation-invariant _continuous-filter convolutions_ were proposed:

$$
\boldsymbol{z}'_i=\sum_{j=1}^N\boldsymbol{z}_j\circ\boldsymbol{W}(\boldsymbol{r}_i-\boldsymbol{r}_j),\tag{9}
$$

where $\boldsymbol{W}(\boldsymbol{r}_i-\boldsymbol{r}_j)$ is a trainable filter that depends on the relative position of atoms $i$ and $j$, and $\circ$ denotes element-wise multiplication.
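To make (9) concrete, here is a minimal NumPy sketch of a continuous-filter convolution. It is an illustration only, not the SchNet implementation: the names `cfconv` and `toy_filter` are hypothetical, and the filter is a fixed hand-written function of the interatomic distance, whereas the real filter is trainable.

```python
import numpy as np

def cfconv(z, r, filter_fn):
    """Eq. (9): z'_i = sum_j z_j * W(r_i - r_j), with element-wise product.

    z: (N, d) atom features, r: (N, 3) atom positions,
    filter_fn: maps (N, N, 3) relative positions to (N, N, d) filter values.
    """
    rel = r[:, None, :] - r[None, :, :]       # rel[i, j] = r_i - r_j
    w = filter_fn(rel)                        # w[i, j] = W(r_i - r_j), shape (N, N, d)
    return (w * z[None, :, :]).sum(axis=1)    # sum over j, shape (N, d)

def toy_filter(rel, d=4):
    # Hypothetical fixed filter (an assumption for illustration):
    # channel k decays as exp(-k * distance).
    dist = np.linalg.norm(rel, axis=-1, keepdims=True)
    scales = np.arange(1, d + 1)
    return np.exp(-dist * scales)

rng = np.random.default_rng(0)
z = rng.normal(size=(5, 4))   # 5 atoms, 4 features each
r = rng.normal(size=(5, 3))   # 5 positions in 3D
z_new = cfconv(z, r, toy_filter)
print(z_new.shape)
```

Each output row mixes the features of all atoms, weighted channel by channel by a function of how far away they are.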

One of the most important aspects of learning on 3D objects is that in most cases there is no fixed, preferred global orientation. When constructing geometric-learning models, one should take this into account and impose additional symmetry-related constraints on the model. Such a constraint is called _equivariance_; in 3D space it is typically required with respect to rotations and translations. More formally, given a group $G$ (e.g. the group of rotations in 3D) and a function $f:X\to Y$, the function is called $G$-_equivariant_ if for any $x\in{X}$ and any $g\in{G}$

$$
f(\rho_X(g)(x))=\rho_Y(g)(f(x)),\tag{10}
$$

where $\rho_X:G\to{GL}(X)$ and $\rho_Y:G\to{GL}(Y)$ are representations of the group $G$ on the spaces $X$ and $Y$, respectively. If $\rho_Y(g)$ is the identity for every $g\in{G}$ (the trivial representation), $f$ is called $G$-_invariant_.
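Definition (10) can be checked numerically on simple functions of a point cloud. The sketch below is only an illustration with hypothetical helper names: the centroid is rotation-equivariant (rotating the input rotates the output), while the matrix of pairwise distances is rotation-invariant (rotating the input changes nothing).

```python
import numpy as np

def random_rotation(rng):
    # QR decomposition of a random matrix yields an orthogonal matrix Q;
    # flipping one column if needed makes det(Q) = +1, i.e. a proper rotation.
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def centroid(x):
    # Maps an (N, 3) point cloud to a point in R^3.
    return x.mean(axis=0)

def pairwise_distances(x):
    # Maps an (N, 3) point cloud to an (N, N) distance matrix.
    return np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))
R = random_rotation(rng)
x_rot = x @ R.T   # apply the rotation to every point: x_i -> R x_i

# Equivariance: f(R x) = R f(x)
equivariant_ok = np.allclose(centroid(x_rot), R @ centroid(x))
# Invariance: f(R x) = f(x)
invariant_ok = np.allclose(pairwise_distances(x_rot), pairwise_distances(x))
print(equivariant_ok, invariant_ok)
```

The same kind of check can be applied to a trained model: rotate the input coordinates and compare the predictions.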

In (9), the filters $\boldsymbol{W}(\boldsymbol{r}_i-\boldsymbol{r}_j)$ are constructed to be rotation- and translation-invariant. The difference $\boldsymbol{r}_i-\boldsymbol{r}_j$ is already translation-invariant; to obtain rotation invariance as well, we make the filters depend only on the distances between atoms:

$$
\boldsymbol{W}(\boldsymbol{r}_i-\boldsymbol{r}_j)=\boldsymbol{W}(\|\boldsymbol{r}_i-\boldsymbol{r}_j\|).\tag{11}
$$

In SchNet, these filters are constructed as MLPs that operate on $K$-dimensional radial basis functions $\{\boldsymbol{e}_{ij}\}$:

$$
\boldsymbol{W}(\boldsymbol{r}_i-\boldsymbol{r}_j)=
\sigma(\boldsymbol{W}_2\sigma(\boldsymbol{W}_1\boldsymbol{e}_{ij})),\tag{12}
$$

where 

$$
\boldsymbol{e}_{ij}=\big(e_{ij}^{(1)},\dots,e_{ij}^{(K)}\big)^{\text{T}}, \tag{13}
$$

and

$$
e_{ij}^{(k)}=\exp\big(-\gamma\big[\|\boldsymbol{r}_i-\boldsymbol{r}_j\|-\mu_k\big]^2\big).\tag{14}
$$

Here, $\sigma$ is a non-linearity; $K$, $\gamma$ and $\{\mu_k\}$ are adjustable hyperparameters; and $\boldsymbol{W}_1$ and $\boldsymbol{W}_2$ are trainable matrices.

In the original paper, $\sigma(x)=\ln(0.5e^x+0.5)$, $K=300$, $0Å\leq\mu_k\leq30Å$ every $0.1$Å, and $\gamma=10$Å.
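Equations (12)-(14) can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the SchNet code: the function names are hypothetical, random matrices stand in for the trainable $\boldsymbol{W}_1$ and $\boldsymbol{W}_2$, and the grid of centers is assumed to exclude the 30 Å endpoint so that $K=300$.

```python
import numpy as np

def shifted_softplus(x):
    # sigma(x) = ln(0.5 * e^x + 0.5), computed stably as ln(1 + e^x) - ln(2)
    return np.logaddexp(x, 0.0) - np.log(2.0)

def rbf_expand(dist, centers, gamma):
    # Eq. (14): e^(k) = exp(-gamma * (dist - mu_k)^2), one value per center mu_k
    return np.exp(-gamma * (dist[..., None] - centers) ** 2)

def filter_network(e, w1, w2):
    # Eq. (12): W = sigma(W_2 sigma(W_1 e)), applied to the last axis of e
    return shifted_softplus(shifted_softplus(e @ w1.T) @ w2.T)

rng = np.random.default_rng(0)
K, d, gamma = 300, 8, 10.0
centers = 0.1 * np.arange(K)                # mu_k = 0.0, 0.1, ..., 29.9 (assumed grid)
w1 = rng.normal(scale=0.1, size=(d, K))     # stand-in for the trainable W_1
w2 = rng.normal(scale=0.1, size=(d, d))     # stand-in for the trainable W_2

r = rng.normal(size=(5, 3))                                       # 5 atom positions
dist = np.linalg.norm(r[:, None, :] - r[None, :, :], axis=-1)     # (5, 5) distances
e = rbf_expand(dist, centers, gamma)                              # (5, 5, K)
W = filter_network(e, w1, w2)                                     # (5, 5, d)
print(e.shape, W.shape)
```

Because the filter values depend on the positions only through `dist`, they are unchanged by any rotation or translation of the molecule, as required by (11).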

<div>
<img src="https://user-images.githubusercontent.com/7134790/65094886-d8dda800-d9f9-11e9-92e8-2737c4913ab7.png" width="800"/>
</div>

We will not implement SchNet ourselves; instead, we can [import it from PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.SchNet) and train it on QM9.

In [None]:
from torch_geometric.nn import SchNet

# YOUR CODE HERE
model = ...
optimizer = ...

In [None]:
epochs = 10
target_homo_ix = 2
model, train_mse, val_mse = train_graph_regression(
    model, 
    optimizer, 
    train_loader, 
    val_loader, 
    target_homo_ix, 
    epochs
)
plot_progress(train_mse, val_mse)

In [None]:
loss = evaluate_graph_regression(model, test_loader, target_homo_ix)

print('MSE:', loss)
if loss <= 0.2:
    print('Good job!')
else:
    print('Try better!')

# Further reading

To read more about geometric learning, message passing and equivariance, see the following sources:
* [Geometric Learning](https://arxiv.org/abs/2104.13478)
* [Directional Message Passing](https://arxiv.org/abs/2003.03123)
* [E(n)-equivariant graph neural networks](https://arxiv.org/abs/2102.09844)
* [SE(3)-Transformers](https://arxiv.org/abs/2006.10503)