In [16]:
import torch
import numpy as np
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader

# Specify where to store the dataset
dataset = QM9(root='../data/external/qm9/')

# Access a sample (first molecule)
data = dataset[0]

# Print the information of the molecule
print(data)

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5])


In [17]:
train_loader = DataLoader(dataset, batch_size=200, shuffle=True)

In [41]:
for batch in train_loader:
    print(batch.batch)
    break

tensor([  0,   0,   0,  ..., 199, 199, 199])


In [52]:
import torch
from egnn_pytorch import EGNN
import torch.nn as nn
from torch_geometric.nn.pool import global_mean_pool

def convert_to_tensor_format(batch):
    # Get the number of graphs (molecules) in the batch
    num_graphs = batch.batch.max().item() + 1

    # Initialize lists to store tensors for each molecule
    feature_tensors = []
    pos_tensors = []

    for i in range(num_graphs):
        # Extract the nodes corresponding to the i-th molecule
        node_indices = (batch.batch == i).nonzero(as_tuple=False).squeeze()

        # Get the features and positions for the i-th molecule
        features = batch.x[node_indices]  # Shape [num_nodes_i, dimension]
        positions = batch.pos[node_indices]  # Shape [num_nodes_i, 3]

        # Add batch dimension (1, num_nodes, dim)
        feature_tensors.append(features.unsqueeze(0))  # Shape [1, num_nodes_i, dimension]
        pos_tensors.append(positions.unsqueeze(0))     # Shape [1, num_nodes_i, 3]

    # Concatenate all molecules along the batch dimension

    return feature_tensors, pos_tensors

# Define the MLP class
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size)  # First fully connected layer
        self.relu = nn.ReLU()                         # ReLU activation function
        self.fc2 = nn.Linear(hidden_size, output_size) # Second fully connected layer

    def forward(self, x):
        x = self.fc1(x)   # Pass input through the first layer
        x = self.relu(x)  # Apply ReLU activation
        x = self.fc2(x)   # Pass through second layer
        return x
    
class QMPredictor(nn.Module):
    def __init__(
            self, 
            latent_size=128, 
            hidden_size=128, 
            output_size=1
        ):
        super(QMPredictor, self).__init__()
        self.layer1 = EGNN(dim = 128)
        self.layer2 = EGNN(dim = 128)
        self.layer3 = EGNN(dim = 128)
        self.pool_fn = global_mean_pool
        self.mlp = MLP(latent_size, hidden_size, output_size)

    def forward(self, feats, coors):
        feats, coors = self.layer1(feats, coors)
        feats, coors = self.layer2(feats, coors)
        feats, coors = self.layer3(feats, coors)
        pooled_tensor = self.pool_fn(feats)
        out = self.mlp(pooled_tensor)

        return out


In [53]:
train_loader

<torch_geometric.loader.dataloader.DataLoader at 0x15ad15a80>

In [54]:
model = QMPredictor()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for graphs in train_loader:
    feats, coors = convert_to_tensor_format(graphs)
    labels = graphs.y
    for i in range(len(feats)):
        optimizer.zero_grad()
        outputs = model(feats[i], coors[i])
        loss = criterion(outputs, labels) if labels is not None else None  # Handle loss calculation
        if loss is not None:
            loss.backward()
            optimizer.step()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (289x23 and 257x514)