In [2]:
# Install PyTorch
!pip install torch torchvision

# Install PyTorch Geometric and its dependencies
# Check the PyTorch Geometric website for the correct installation command based on your PyTorch version
!pip install torch-scatter -f https://data.pyg.org/whl/torch-<torch_version>+<cpu_or_cuda>.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-<torch_version>+<cpu_or_cuda>.html
!pip install torch-geometric

# RDKit
!conda install -c conda-forge rdkit

# OpenMM
!conda install -c conda-forge openmm

# Additional libraries
!pip install numpy pandas tqdm scikit-learn

!pip install matplotlib

# Data Collection and Preparation
# Downloading the QM9 Dataset

from torch_geometric.datasets import QM9

# Saved it in the argument dataset
dataset = QM9(root='data/QM9')


/bin/bash: line 1: torch_version: No such file or directory
/bin/bash: line 1: torch_version: No such file or directory
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
/bin/bash: line 1: conda: command not found
/bin/bash: line 1: conda: command not found


Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting data/QM9/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


In [3]:
# Let's explore the Dataset
# We know that each molecule in the dataset is represented as a graph. Let's examine one example.

from torch_geometric.data import Data

# Get a single data point
data = dataset[0]

print(data)

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5])


In [19]:
# Data Preprocessing
#Our goal is to predict total energy (E0) and forces on each atom.
#However, QM9 provides total energies but not forces.
# We'll need to approximate forces using gradients of the energy with respect to positions, but since we don't have potential energy surfaces, we'll focus on energy prediction for now.

from torch_geometric.data import DataLoader
from sklearn.model_selection import train_test_split

# Indices for splitting
indices = list(range(len(dataset)))
train_idx, test_idx = train_test_split(indices, test_size=0.1, random_state=42)
train_idx, val_idx = train_test_split(train_idx, test_size=0.1, random_state=42)

# Create subsets
train_dataset = dataset[train_idx]
val_dataset = dataset[val_idx]
test_dataset = dataset[test_idx]

# Data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [20]:
# Let's define the GNN Model Architecture

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.nn import radius_graph
from torch_geometric.utils import remove_self_loops

# Let's define the MPNN Model

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class MPNN(MessagePassing):
    def __init__(self, node_features, edge_features, hidden_dim):
        super(MPNN, self).__init__(aggr='add')  # Aggregation method ('add', 'mean', 'max')
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        # Initial node embedding
        self.node_embedding = torch.nn.Linear(node_features, hidden_dim)

        # Edge embedding
        self.edge_embedding = torch.nn.Linear(edge_features, hidden_dim)

        # Message function
        self.message_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

        # Update function
        self.update_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

        # Output MLP
        self.output_mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        # Initial node embeddings
        x = self.node_embedding(x)
        edge_attr = self.edge_embedding(edge_attr)

        # Message passing
        for _ in range(3):  # Number of message passing steps
            x = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        # Readout
        out = global_mean_pool(x, batch)
        out = self.output_mlp(out)
        return out.squeeze()

    def message(self, x_i, x_j, edge_attr):
        # x_i: Nodes receiving the message
        # x_j: Nodes sending the message
        # edge_attr: Edge features
        m = torch.cat([x_i, x_j + edge_attr], dim=-1)
        return self.message_mlp(m)

    def update(self, aggr_out, x):
        # aggr_out: Aggregated messages
        m = torch.cat([x, aggr_out], dim=-1)
        return self.update_mlp(m)


In [28]:
# Training the GNN Model

# Prepare Input Features

# Node Features:

from torch.nn.functional import one_hot

# Maximum atomic number in QM9 is 9 (F)
max_atomic_number = 11

# max_atomic_number = dataset.data.z.max().item()
print(f"Updated max_atomic_number: {max_atomic_number}")


def preprocess_data(data):
    # One-hot encode atomic numbers
    z = data.z  # Atomic numbers
    x = one_hot(z - 1, num_classes=max_atomic_number).to(torch.float)
    data.x = x

    # Edge attributes: Compute distances between connected nodes
    row, col = data.edge_index
    pos_diff = data.pos[row] - data.pos[col]
    dist = torch.norm(pos_diff, p=2, dim=-1).unsqueeze(-1)
    data.edge_attr = dist

    return data

# Apply preprocessing
dataset.transform = preprocess_data

Updated max_atomic_number: 11


In [29]:
# Let's instantiate the Model


# node_features = max_atomic_number  # One-hot encoding size
# edge_features = 1  # Distance as edge attribute
# hidden_dim = 128

node_features = 11  # Number of node features
edge_features = 1
hidden_dim = 128

model = MPNN(node_features, edge_features, hidden_dim)


# model = MPNN(node_features, edge_features, hidden_dim)


In [30]:
# Let's define Loss Function and Optimizer
# We will use Mean Squared Error (MSE) loss for energy prediction

import torch.optim as optim

criterion = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [32]:
# Training loop

from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

def train(epoch):
    model.train()
    total_loss = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        # Forward pass
        output = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # Target: Total energy E0 (index 12 in QM9 targets)
        y = data.y[:, 12]
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def validate(loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.edge_attr, data.batch)
            y = data.y[:, 12]
            loss = criterion(output, y)
            total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)

# Training process
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    train_loss = train(epoch)
    val_loss = validate(val_loader)
    print(f'Epoch: {epoch}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')


100%|██████████| 3312/3312 [04:03<00:00, 13.57it/s]


Epoch: 1, Train Loss: 21.349504, Val Loss: 20.809556


100%|██████████| 3312/3312 [03:58<00:00, 13.89it/s]


Epoch: 2, Train Loss: 21.265583, Val Loss: 19.559981


100%|██████████| 3312/3312 [04:01<00:00, 13.72it/s]


Epoch: 3, Train Loss: 21.316550, Val Loss: 19.568311


100%|██████████| 3312/3312 [03:59<00:00, 13.81it/s]


Epoch: 4, Train Loss: 20.955014, Val Loss: 19.770341


100%|██████████| 3312/3312 [03:59<00:00, 13.80it/s]


Epoch: 5, Train Loss: 20.937122, Val Loss: 20.596571


100%|██████████| 3312/3312 [04:03<00:00, 13.61it/s]


Epoch: 6, Train Loss: 20.808076, Val Loss: 20.708014


100%|██████████| 3312/3312 [04:01<00:00, 13.72it/s]


Epoch: 7, Train Loss: 20.574783, Val Loss: 24.843819


100%|██████████| 3312/3312 [04:00<00:00, 13.78it/s]


Epoch: 8, Train Loss: 20.129895, Val Loss: 19.024179


100%|██████████| 3312/3312 [03:57<00:00, 13.97it/s]


Epoch: 9, Train Loss: 20.047264, Val Loss: 18.965532


100%|██████████| 3312/3312 [03:57<00:00, 13.92it/s]


Epoch: 10, Train Loss: 19.223372, Val Loss: 18.218576


100%|██████████| 3312/3312 [03:56<00:00, 13.98it/s]


Epoch: 11, Train Loss: 18.924302, Val Loss: 20.274628


100%|██████████| 3312/3312 [03:56<00:00, 14.03it/s]


Epoch: 12, Train Loss: 18.655296, Val Loss: 19.004444


100%|██████████| 3312/3312 [03:55<00:00, 14.04it/s]


Epoch: 13, Train Loss: 18.278770, Val Loss: 20.244577


100%|██████████| 3312/3312 [03:56<00:00, 14.03it/s]


Epoch: 14, Train Loss: 18.002099, Val Loss: 19.472459


100%|██████████| 3312/3312 [03:57<00:00, 13.93it/s]


Epoch: 15, Train Loss: 17.763243, Val Loss: 16.733207


100%|██████████| 3312/3312 [03:57<00:00, 13.96it/s]


Epoch: 16, Train Loss: 17.274656, Val Loss: 17.607585


100%|██████████| 3312/3312 [03:56<00:00, 13.99it/s]


Epoch: 17, Train Loss: 17.096800, Val Loss: 14.993468


100%|██████████| 3312/3312 [03:56<00:00, 14.00it/s]


Epoch: 18, Train Loss: 16.797847, Val Loss: 14.971330


100%|██████████| 3312/3312 [03:56<00:00, 13.99it/s]


Epoch: 19, Train Loss: 16.634123, Val Loss: 15.536028


100%|██████████| 3312/3312 [03:56<00:00, 13.98it/s]


Epoch: 20, Train Loss: 16.208611, Val Loss: 15.640961


In [33]:
# Handling Units
# The target energies in QM9 are in Hartree units
# For practical purposes, you might want to convert energies to eV or kcal/mol
hartree_to_ev = 27.2114
hartree_to_kcalmol = 627.509

# Modify the target variable in the training and validation loops
y = data.y[:, 12] * hartree_to_ev  # Convert to eV

In [34]:
# Testing the Model
test_loss = validate(test_loader)
print(f'Test Loss: {test_loss:.6f} eV')

Test Loss: 16.767337 eV


In [35]:
# Saving the model
torch.save(model.state_dict(), 'mpnn_qm9_energy.pth')