### In this file, we will define our model and parameters

Instructions: 

Finally, you should have at least one notebook where you define and train your model.

In [7]:
import pandas as pd
import numpy as np
import torch
from dgl.data import DGLDataset
from dgl.data.utils import save_graphs, load_graphs

### Model

In [8]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import NNConv

In [32]:
class Electron_MPNN(nn.Module):
    """MPNN that reduces a graph to a single scalar value.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    The model projects the input node features, runs a fixed number of
    message-passing steps (an edge-conditioned ``NNConv`` followed by a GRU
    update), applies a final linear layer to every node and returns the mean
    of the result.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    out_dim : int
        Size of the per-node output of the final linear layer. Default to 1.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    node_hidden_feats : int, optional
        Size of the node representations used during message passing.
        Default to None, in which case ``out_dim`` is used; every layer then
        has the same shape as before this parameter existed. Pass a larger
        value (e.g. 64) so that message passing is not squeezed through an
        ``out_dim``-wide (by default 1-wide) state.
    """
    def __init__(self, node_in_feats, edge_in_feats, out_dim=1,
                 edge_hidden_feats=128, num_step_message_passing=6,
                 node_hidden_feats=None):
        super().__init__()

        # Fall back to the historical behaviour where the hidden width and the
        # output width are the same number.
        if node_hidden_feats is None:
            node_hidden_feats = out_dim

        # Input projection: raw node features -> hidden node representation.
        self.fc1 = nn.Sequential(
            nn.Linear(node_in_feats, node_hidden_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        # NNConv expects the edge function to produce in_feats * out_feats
        # values per edge, hence the squared hidden width.
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_hidden_feats * node_hidden_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_hidden_feats,
            out_feats=node_hidden_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
        self.gru = nn.GRU(node_hidden_feats, node_hidden_feats)

        # Readout head: hidden node representation -> per-node output.
        self.fc2 = nn.Linear(node_hidden_feats, out_dim)


    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and returns the mean of the node outputs.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        0-dim float32 tensor
            Mean over every element of the (V, out_dim) per-node output.
            NOTE(review): ``.mean()`` reduces over all nodes in ``g``, so for a
            batched graph this is one value for the whole batch, not one value
            per graph -- confirm that this is what the training loop expects.
        """
        node_feats = self.fc1(node_feats)                # (V, hidden)

        # The GRU hidden state has shape (num_layers=1, V, hidden); nodes play
        # the role of the GRU "batch" dimension.
        hidden_feats = node_feats.unsqueeze(0)           # (1, V, hidden)

        for _ in range(self.num_step_message_passing):
            # Aggregate edge-conditioned messages from neighbours.
            node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
            # Feed the messages to the GRU as a length-1 sequence.
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = node_feats.squeeze(0)

        node_feats = self.fc2(node_feats)                # (V, out_dim)

        return node_feats.mean()

#### Load the graphs to get dimensions for our model

In [6]:
# Load the pre-built DGL graphs from disk. ``load_graphs`` returns a
# (graph_list, label_dict) pair; only the graph list is used below. The second
# value is bound to a named variable rather than ``_`` because assigning to
# ``_`` in IPython shadows the built-in "last output" variable.
graphs, _graph_labels = load_graphs("./DataGraphs/data_F_graph.bin")
# Fail fast with a clear message instead of an IndexError in a later cell.
assert len(graphs) > 0, "No graphs were loaded from ./DataGraphs/data_F_graph.bin"

#### Construct the model

In [33]:
# Feature widths are read from the first graph only; this relies on every
# graph in the list sharing the same node and edge feature sizes.
node_dim, edge_dim = (
    feats.shape[1]
    for feats in (graphs[0].ndata['atom_feats'], graphs[0].edata['bond_feats'])
)
print("Dimensions:", node_dim, "(node),", edge_dim, "(edge)")

Dimensions: 11 (node), 5 (edge)


In [29]:
# Build the network using the feature widths read from the loaded graphs.
model = Electron_MPNN(
    node_in_feats=node_dim,
    edge_in_feats=edge_dim,
    out_dim=1,
)

In [30]:
# Take one graph as a smoke-test input and cast its node/edge features to
# float so they can be fed to the model's linear layers.
sample_data = graphs[1]
nodes, edges = (
    feats.float()
    for feats in (sample_data.ndata['atom_feats'], sample_data.edata['bond_feats'])
)

In [31]:
# Run a single forward pass on the sample graph; the cell displays the output.
model(sample_data, node_feats=nodes, edge_feats=edges)

tensor(-0.1162, grad_fn=<MeanBackward0>)