In [1]:
import torch
import numpy as np
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn.pool import global_mean_pool
from torch_geometric.nn import MessagePassing
from egnn_pytorch import EGNN
from torch_geometric.utils import softmax
from loguru import logger

# Specify where to store the dataset
dataset = QM9(root='../data/external/qm9/')

# Access a sample (first molecule)
data = dataset[0]

# Print the information of the molecule
print(data)

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5])


In [2]:
layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)

In [3]:
# Split the dataset into training and testing sets
train_size = int(0.8 * len(dataset))  # 80% for training
test_size = len(dataset) - train_size  # Remaining 20% for testing
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create DataLoader for training and testing
train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=200, shuffle=False)
# train_loader = DataLoader(dataset, batch_size=200, shuffle=True)

In [4]:
for batch in train_loader:
    print(batch.batch)
    break

tensor([  0,   0,   0,  ..., 199, 199, 199])


In [5]:
batch.size()

(3542, 3542)

In [6]:
def convert_to_tensor_format(batch):
    # Get the number of graphs (molecules) in the batch
    num_graphs = batch.batch.max().item() + 1

    # Initialize lists to store tensors for each molecule
    feature_tensors = []
    pos_tensors = []

    for i in range(num_graphs):
        # Extract the nodes corresponding to the i-th molecule
        node_indices = (batch.batch == i).nonzero(as_tuple=False).squeeze()

        # Get the features and positions for the i-th molecule
        features = batch.x[node_indices]  # Shape [num_nodes_i, dimension]
        positions = batch.pos[node_indices]  # Shape [num_nodes_i, 3]

        # Add batch dimension (1, num_nodes, dim)
        feature_tensors.append(features.unsqueeze(0))  # Shape [1, num_nodes_i, dimension]
        pos_tensors.append(positions.unsqueeze(0))     # Shape [1, num_nodes_i, 3]

    # Concatenate all molecules along the batch dimension

    return feature_tensors, pos_tensors

# Define the MLP class
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size)  # First fully connected layer
        self.relu = nn.ReLU()                         # ReLU activation function
        self.fc2 = nn.Linear(hidden_size, output_size) # Second fully connected layer

    def forward(self, x):
        x = self.fc1(x)   # Pass input through the first layer
        x = self.relu(x)  # Apply ReLU activation
        x = self.fc2(x)   # Pass through second layer
        return x

In [7]:
def build_adjacency_tensor(edge_index, edge_attr, num_nodes):
    # Initialize empty adjacency tensor
    # Shape: [1, num_nodes, num_nodes, edge_feature_dim]
    adj_tensor = torch.zeros(1, num_nodes, num_nodes, edge_attr.size(-1))
    
    # Get source and target nodes
    src, dst = edge_index
    
    # Populate the adjacency tensor
    # Note: we're adding batch dimension hence the 0 index
    adj_tensor[0, src, dst] = edge_attr
    
    return adj_tensor

In [8]:
class QMPredictor(nn.Module):
    def __init__(
            self, 
            node_dim = 11,
            latent_size=128, 
            hidden_size=128, 
            output_size=19,
            num_layers=3
        ):
        super(QMPredictor, self).__init__()
        self.layers = torch.nn.ModuleList([
            # EGNN(dim = node_dim)
            EGNN(dim = 512, edge_dim = 4)
            for _ in range(num_layers)
        ])
        self.pool_fn = global_mean_pool
        self.mlp = MLP(node_dim, hidden_size, output_size)

    def forward(self, feats, coors, edges):
        for layer in self.layers:
            feats, coors = layer(feats, coors, edges)
        batch=torch.zeros([feats.size(1)], dtype=torch.int64)
        pooled_tensor = self.pool_fn(x=feats[0], batch=batch)
        out = self.mlp(pooled_tensor)

        return out

In [9]:
# class EGNNLayer(MessagePassing):
#     def __init__(self, node_feat_dim, edge_feat_dim, coord_dim=3, aggr="mean"):
#         super(EGNNLayer, self).__init__(aggr=aggr)
#         self.node_mlp = torch.nn.Sequential(
#             torch.nn.Linear(2*node_feat_dim + edge_feat_dim + 1, 64),
#             torch.nn.ReLU(),
#             torch.nn.Linear(64, node_feat_dim)
#         )
#         self.coord_mlp = torch.nn.Sequential(
#             torch.nn.Linear(node_feat_dim, 64),
#             torch.nn.ReLU(),
#             torch.nn.Linear(64, 1)
#         )

#     def forward(self, x, edge_index, edge_attr, coord):
#         # batch argument helps to know which nodes belong to which graph in the batch
#         return self.propagate(edge_index, x=x, coord=coord, edge_attr=edge_attr)

#     def message(self, x_i, x_j, coord_i, coord_j, edge_attr):
#         d_ij = coord_j - coord_i
#         distance_ij = torch.norm(d_ij, dim=-1, keepdim=True)

#         m_ij = torch.cat([x_i, x_j, edge_attr, distance_ij], dim=-1)

#         node_output = self.node_mlp(m_ij)
#         # coord_output = self.coord_mlp(x_j)
#         # print("coord: ", coord_output.size())
#         # print("node_output size: ", node_output.size())

#         return node_output

#     def update(self, aggr_out, coord):
#         # print("coord: ", coord.size())
#         # print("aggr_out[1] size: ", aggr_out[1].size())
#         node_update = aggr_out[0]
#         # coord_update = coord + aggr_out[1]

#         return node_update, coord

In [10]:
# # Compatible with batches
# class QMPredictorBatch(nn.Module):
#     def __init__(
#             self, 
#             node_feat_dim = 11,
#             edge_feat_dim = 4,
#             coord_dim = 3,
#             latent_size=128, 
#             hidden_size=128, 
#             output_size=19,
#             num_layers=3
#         ):
#         super(QMPredictorBatch, self).__init__()
#         self.layers = torch.nn.ModuleList([
#             EGNNLayer(node_feat_dim, edge_feat_dim, coord_dim)
#             for _ in range(num_layers)
#         ])
#         self.pool_fn = global_mean_pool
#         self.mlp = MLP(node_feat_dim, hidden_size, output_size)

#     def forward(self, x, edge_index, edge_attr, coord, batch=None):
#         for layer in self.layers:
#             print(x.size())
#             print(edge_index.size())
#             print(edge_attr.size())
#             print(coord.size())
#             print(batch.size())
#             x, _ = layer(x, edge_index, edge_attr, coord)#, batch=batch)
#         print(x.size())
#         print(batch.size())
#         pooled_tensor = self.pool_fn(x=x, batch=batch)
#         out = self.mlp(pooled_tensor)
#         return out

In [11]:
# model = QMPredictor()

# criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# # Initialize lists to store losses
# train_losses = []
# val_losses = []

# num_epochs = 1

# for epoch in range(num_epochs):
#     model.train()
#     epoch_train_loss = 0

#     for graphs in tqdm(train_loader, desc="Training", unit="batch"):
#         feats, coors = convert_to_tensor_format(graphs)
#         labels = graphs.y
#         outputs = []

#         for i in range(len(feats)):
#             optimizer.zero_grad()
#             output = model(feats[i], coors[i])
#             outputs.append(output)

#         outputs = torch.cat(outputs, dim=0)
#         # Check if the shapes of outputs and labels match
#         if outputs.shape != labels.shape:
#             print(f"Shape mismatch: outputs {outputs.shape}, labels {labels.shape}")
#             continue  # Skip this iteration if shapes do not match

#         loss = criterion(outputs, labels) if labels is not None else None  # Handle loss calculation
#         if loss is not None:
#             loss.backward()
#             optimizer.step()
#             epoch_train_loss += loss.item()

#      # Average training loss for the epoch
#     train_losses.append(epoch_train_loss / len(train_loader))

#     # Validation step
#     model.eval()  # Set the model to evaluation mode
#     epoch_val_loss = 0
#     with torch.no_grad():
#         for graphs in test_loader:
#             feats, coors = convert_to_tensor_format(graphs)
#             labels = graphs.y
#             outputs = []
#             for i in range(len(feats)):
#                 output = model(feats[i], coors[i])
#                 outputs.append(output)

#             outputs = torch.cat(outputs, dim=0)
#             loss = criterion(outputs, labels) if labels is not None else None
#             if loss is not None:
#                 epoch_val_loss += loss.item()

#     # Average validation loss for the epoch
#     val_losses.append(epoch_val_loss / len(test_loader))


In [12]:
dataset[10]

Data(x=[7, 11], edge_index=[2, 12], edge_attr=[12, 4], y=[1, 19], pos=[7, 3], idx=[1], name='gdb_11', z=[7])

In [13]:
for graphs in tqdm(train_loader, desc="Training", unit="batch"):
    graphs
    print(graphs)
    print(graphs.batch)
    break

Training:   0%|          | 0/524 [00:00<?, ?batch/s]

DataBatch(x=[3673, 11], edge_index=[2, 7610], edge_attr=[7610, 4], y=[200, 19], pos=[3673, 3], idx=[200], name=[200], z=[3673], batch=[3673], ptr=[201])
tensor([  0,   0,   0,  ..., 199, 199, 199])





In [14]:
model = QMPredictor()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Initialize lists to store losses
train_losses = []
val_losses = []

num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    epoch_train_loss = 0

    for graphs in tqdm(train_loader, desc="Training", unit="batch"):
        optimizer.zero_grad()

        node_features = graphs.x
        labels = graphs.y
        edge_index = graphs.edge_index
        edge_attr = graphs.edge_attr
        coord = graphs.pos  # Random initial coordinates
        batch_index = graphs.batch  # Index of each node in the batch

        node_features = node_features.unsqueeze(0)
        coord = coord.unsqueeze(0)

        # Get number of nodes in the graph
        num_nodes = graphs.x.size(0)
        
        # Build adjacency tensor
        edge_tensor = build_adjacency_tensor(
            graphs.edge_index, 
            graphs.edge_attr, 
            num_nodes
        )

        logger.info(node_features.size())
        logger.info(coord.size())
        logger.info(edge_tensor.size())

        outputs = model(node_features, coord, edge_tensor)#, batch=batch_index)

        loss = criterion(outputs, labels) if labels is not None else None  # Handle loss calculation
        if loss is not None:
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()

     # Average training loss for the epoch
    train_losses.append(epoch_train_loss / len(train_loader))

    # Validation step
    model.eval()  # Set the model to evaluation mode
    epoch_val_loss = 0
    with torch.no_grad():
        for graphs in test_loader:
            labels = graphs.y
            node_features = graphs.x
            edge_index = graphs.edge_index
            edge_attr = graphs.edge_attr
            coord = graphs.pos  # Random initial coordinates
            batch_index = graphs.batch  # Index of each node in the batch
            outputs = model(node_features, edge_index, edge_attr, coord)#, batch=batch_index)

            loss = criterion(outputs, labels) if labels is not None else None
            if loss is not None:
                epoch_val_loss += loss.item()

    # Average validation loss for the epoch
    val_losses.append(epoch_val_loss / len(test_loader))


Training:   0%|          | 0/524 [00:00<?, ?batch/s][32m2024-11-07 19:09:54.235[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m39[0m - [1mtorch.Size([1, 3655, 11])[0m
[32m2024-11-07 19:09:54.236[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m40[0m - [1mtorch.Size([1, 3655, 3])[0m
[32m2024-11-07 19:09:54.236[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m41[0m - [1mtorch.Size([1, 3655, 3655, 4])[0m
Training:   0%|          | 0/524 [00:00<?, ?batch/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (13359025x27 and 1029x2058)

In [None]:
val_losses

In [None]:
train_losses

In [None]:
# Plotting the learning curves
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Learning Curve')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()