In [1]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
from torch_geometric.datasets import QM9
from torch_geometric.data import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import numpy as np
import rdkit
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

## Dataset

In [6]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device)

mps


In [14]:
dataset = QM9(root='./QM9')

In [15]:
from torch.utils.data import random_split
g = torch.Generator().manual_seed(67)
sample_size = int(0.67 * len(dataset))
dataset_67, _ = random_split(dataset, [sample_size, len(dataset) - sample_size], generator=g)

num_train = int(0.8 * len(dataset_67))
num_val = int(0.1 * len(dataset_67))
train_set, val_set, test_set = random_split(
    dataset_67, [num_train, num_val, len(dataset_67) - num_train - num_val], generator=g
)

In [16]:
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=64)
test_loader  = DataLoader(test_set,  batch_size=64)
print("Train set size:", len(train_set))
print("Validation set size:", len(val_set))
print("Test set size:", len(test_set))


Train set size: 70124
Validation set size: 8765
Test set size: 8767


## Encoder Architecture
The **Encoder** consists of:
- Message-passing layers (e.g., GCNConv, GINEConv) to compute node embeddings.
- Global readout (e.g., `global_add_pool`) to obtain graph representation $h$
- Two separate MLPs mapping $h$ to latent mean $\mu$ and log-variance log $\sigma^2$

**Reparameterization**    
Sample $z = \mu + \exp(\frac{1}{2} \sigma^2) \odot \varepsilon$ where $\varepsilon \sim N(0,1)$

In [17]:
from torch_geometric.nn import global_add_pool

In [18]:
class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_dim, latent_dim, edge_dim):
        super(Encoder, self).__init__()
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(in_channels, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim=edge_dim
        )
        self.conv2 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim=edge_dim
        )
        self.readout = global_add_pool
        self.lin_mu  = nn.Linear(hidden_dim, latent_dim)
        self.lin_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = torch.relu(x)
        h = self.readout(x, batch)
        mu = self.lin_mu(h)
        logvar = self.lin_logvar(h)
        return mu, logvar


In [19]:
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

## Decoder Architecture
Condition on $z$ to reconstruct adjacency and node features. For example:
- Edge-wise MLP: For each potential node pair, predict bond existence/type.
- Autoregressive (e.g., graph sequential decoding).
- Graph deconvolution layers.

In [21]:
class Decoder(nn.Module):
    def __init__(
        self,
        latent_dim,
        hidden_dim,
        num_node_types,
        num_edge_types,
    ):
        super(Decoder, self).__init__()
        self.node_mlp = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_node_types),
        )
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_edge_types),
        )

    def forward(self, z, batch, edge_index=None):
        # z: [num_graphs_in_batch, latent_dim]
        # batch: [num_nodes] mapping each node to its graph index
        # Broadcast z to node-level using batch
        z_nodes = z[batch]  # [num_nodes, latent_dim]
        node_logits = self.node_mlp(z_nodes)

        edge_logits = None
        if edge_index is not None:
            src, dst = edge_index
            edge_inputs = torch.cat([z_nodes[src], z_nodes[dst]], dim=-1)
            edge_logits = self.edge_mlp(edge_inputs)

        return node_logits, edge_logits

## VAE Setup and Loss
**Loss**: Reconstruction Loss: Cross-entropy for discrete features (atom types, bond types)    
**KL Divergence**: $\text{KL} = -\frac{1}{2}\sum_{i=1}^d(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2)$



In [22]:
def vae_loss(node_logits, true_node_labels, edge_logits, true_edge_labels, mu, logvar):
    recon_node = nn.CrossEntropyLoss()(node_logits, true_node_labels)
    recon_edge = nn.CrossEntropyLoss()(edge_logits, true_edge_labels)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_node + recon_edge + 1e-4 * kl

In [23]:
E = Encoder(
    in_channels=dataset.num_node_features,
    hidden_dim=128,
    latent_dim=64,
    edge_dim=dataset.num_edge_features
)
E.to(device)
print(E)

Encoder(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=11, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  ))
  (conv2): GINEConv(nn=Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  ))
  (lin_mu): Linear(in_features=128, out_features=64, bias=True)
  (lin_logvar): Linear(in_features=128, out_features=64, bias=True)
)


In [24]:
D = Decoder(
    latent_dim=64,
    hidden_dim=128,
    num_node_types=5, # C, N, O, F, H
    num_edge_types=4 # single, double, triple, aromatic
)
D.to(device)
print(D)


Decoder(
  (node_mlp): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=5, bias=True)
  )
  (edge_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4, bias=True)
  )
)


In [25]:
optimizer = Adam(list(E.parameters()) + list(D.parameters()), lr=0.001)

## Model Training (10 Epochs)

In [26]:
for epoch in range(1, 11):
    E.train()
    D.train()
    total_loss = 0
    for data in train_loader:
        # Move batch to the selected device to avoid CPU/MPS mismatch
        data = data.to(device)
        optimizer.zero_grad()
        mu, logvar = E(data.x, data.edge_index, data.edge_attr, data.batch)
        z = reparameterize(mu, logvar)
        
        # Decoder now expects per-node broadcast via batch and optional edge_index
        node_logits, edge_logits = D(z, data.batch, data.edge_index)

        # Targets aligned to per-node and per-edge predictions
        node_targets = data.x.argmax(dim=1)
        edge_targets = data.edge_attr.argmax(dim=1)

        loss = vae_loss(
            node_logits, node_targets,
            edge_logits, edge_targets,
            mu, logvar
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch:03d}, Loss: {total_loss/len(train_loader):.4f}")

Epoch 001, Loss: 0.2327
Epoch 002, Loss: 0.2135
Epoch 003, Loss: 0.2133
Epoch 004, Loss: 0.2131


KeyboardInterrupt: 

Ok this uses too much memory and training is about done

In [28]:
ckpt = {
    "epoch": epoch,
    "encoder": E.state_dict(),
    "decoder": D.state_dict(),
    "optimizer": optimizer.state_dict(),
    "train_args": {"lr": 1e-3, "batch_size": 64},
}
torch.save(ckpt, "checkpoints/qvae_epoch_{:03d}.pt".format(epoch))

In [29]:
torch.save(ckpt, "checkpoints/qvae_best.pt")

## Resume Training

In [None]:
ckpt = torch.load("checkpoints/qvae_best.pt", map_location=device)
E.load_state_dict(ckpt["encoder"])
D.load_state_dict(ckpt["decoder"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt.get("epoch", 0) + 1