In [11]:
import torch

# Seed torch's global RNG so stochastic steps in this notebook (e.g. the
# torch.randn_like noise in `reparameterize`) are repeatable on Restart & Run All.
# NOTE(review): whether the train/val/test split in create_data_loaders honours
# this seed depends on its implementation -- TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA first, then Apple-silicon MPS, else CPU.
# (Previously only MPS/CPU were considered, so CUDA machines silently ran on CPU.)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

## QM9 Dataset Augmented with BBBP (Blood–Brain Barrier Penetration) Data

In [1]:
from dataset import QM9GraphDataset, create_data_loaders

In [2]:
# Build the molecular-graph dataset from the BBBP-augmented QM9 CSV.
# The "Processing..." log shown below is emitted by this call.
dataset = QM9GraphDataset(
    csv_path='qm9_bbbp.csv',
)

Processing...


Processing 2142 molecules...
Successfully processed 2142 molecules, failed: 0


Done!


In [3]:
# Split the dataset into train / val / test loaders (split sizes are printed below).
# NOTE(review): batch_size=4 looks like a quick-inspection value -- revisit before training.
train_loader, val_loader, test_loader = create_data_loaders(
    dataset,
    batch_size=4,
)

Dataset splits — Train: 1713, Val: 214, Test: 215


In [8]:
# Peek at a single mini-batch to sanity-check tensor shapes, then stop.
for batch in train_loader:
    shape_checks = (
        ("Batch x shape:", batch.x),
        ("Batch edge_index shape:", batch.edge_index),
        ("Batch edge_attr shape:", batch.edge_attr),
        ("Batch y shape:", batch.y),
    )
    for label, value in shape_checks:
        print(label, value.shape)
    break

Batch x shape: torch.Size([72, 29])
Batch edge_index shape: torch.Size([2, 150])
Batch edge_attr shape: torch.Size([150, 6])
Batch y shape: torch.Size([4])


## GraphEncoder
Graph neural network encoder for molecular graphs, built from techniques established in the literature. For each graph in a batch it returns a mean and a log-variance vector (`mu`, `log_var`), each of size `latent_dim`.

References:
- Uses multi-head GAT for attention-based aggregation (Veličković et al., 2017)
- Incorporates residual connections (He et al., 2016) and layer normalization (Ba et al., 2016)
- Multiple pooling strategies for graph-level representation (Xu et al., 2018)
- Dropout and batch normalization for regularization

In [4]:
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_add_pool
from torch_geometric.nn import BatchNorm

In [5]:
from models.graphencoder import GraphEncoder

In [9]:
# Model hyperparameters.
# Input feature widths must match the dataset featurization
# (cf. batch.x -> [num_nodes, 29] and batch.edge_attr -> [num_edges, 6] printed above).
node_dim, edge_dim = 29, 6
# Hidden width and latent size; mu / log_var come out as [num_graphs, latent_dim].
hidden_dim, latent_dim = 64, 32
# Depth and number of (presumably GAT attention) heads passed to GraphEncoder.
num_layers, heads = 4, 4

In [12]:
# Instantiate the encoder with the hyperparameters above and move it to `device`.
GE = GraphEncoder(
    node_dim, edge_dim, hidden_dim, latent_dim, num_layers, heads,
).to(device)

In [13]:
# Grab one mini-batch (on `device`) for a forward-pass smoke test.
# NOTE(review): if train_loader shuffles, this need not be the batch printed earlier.
one_batch = next(
    iter(train_loader)
).to(device)

In [15]:
# Forward pass: the encoder maps the batched graph to per-graph mean and
# log-variance tensors (one row per graph in the batch).
encoder_inputs = (one_batch.x, one_batch.edge_index, one_batch.edge_attr, one_batch.batch)
mu, log_var = GE(*encoder_inputs)
print(mu.shape, log_var.shape)

torch.Size([4, 32]) torch.Size([4, 32])


In [17]:
from models.propertypredictor import PropertyPredictor

In [18]:
def reparameterize(mu, log_var):
    """Draw a sample z ~ N(mu, exp(log_var)) via the reparameterization trick.

    The sample is written as ``mu + eps * sigma`` with ``eps ~ N(0, I)`` drawn
    independently of the inputs, so gradients flow back into ``mu`` and ``log_var``.

    Args:
        mu: Mean tensor.
        log_var: Log-variance tensor; expected to match ``mu``'s shape.

    Returns:
        One random sample per entry, with the broadcast shape of ``mu`` and ``log_var``.

    Note:
        ``log_var`` is not clamped, so very large values make ``sigma`` overflow.
    """
    sigma = (0.5 * log_var).exp()  # standard deviation = sqrt(exp(log_var))
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

In [19]:
# Property head: maps a latent_dim vector to a single predicted property, on `device`.
PP = PropertyPredictor(
    latent_dim,
    num_properties=1,
).to(device)

In [20]:
# Smoke test: feed one sampled latent per graph through the property head.
# NOTE(review): magnitudes of ~1e5-1e8 from an untrained head look suspicious;
# check the scale of log_var (it is exponentiated in reparameterize) -- TODO.
z_sample = reparameterize(mu, log_var)
PP(z_sample)

tensor([[1.6938e+08],
        [9.5779e+04],
        [8.7629e+05],
        [5.0267e+05]], device='mps:0', grad_fn=<LinearBackward0>)