In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
from torch_geometric.datasets import QM9
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import numpy as np
from rdkit import Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

## Dataset

In [5]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device)

mps


In [6]:
dataset = QM9(root='./QM9')

In [9]:
from torch.utils.data import random_split
g = torch.Generator().manual_seed(67)
sample_size = int(0.1 * len(dataset))
dataset_67, _ = random_split(dataset, [sample_size, len(dataset) - sample_size], generator=g)

num_train = int(0.8 * len(dataset_67))
num_val = int(0.1 * len(dataset_67))
train_set, val_set, test_set = random_split(
    dataset_67, [num_train, num_val, len(dataset_67) - num_train - num_val], generator=g
)

In [10]:
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=64)
test_loader  = DataLoader(test_set,  batch_size=64)
print("Train set size:", len(train_set))
print("Validation set size:", len(val_set))
print("Test set size:", len(test_set))


Train set size: 10466
Validation set size: 1308
Test set size: 1309


## Encoder Architecture
The **Encoder** consists of:
- Message-passing layers (e.g., GCNConv, GINEConv) to compute node embeddings.
- Global readout (e.g., `global_add_pool`) to obtain graph representation $h$
- Two separate MLPs mapping $h$ to latent mean $\mu$ and log-variance log $\sigma^2$

**Reparameterization**    
Sample $z = \mu + \exp(\frac{1}{2} \sigma^2) \odot \varepsilon$ where $\varepsilon \sim N(0,1)$

In [11]:
class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_dim, latent_dim, edge_dim):
        super(Encoder, self).__init__()
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(in_channels, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim=edge_dim
        )
        self.conv2 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim=edge_dim
        )
        self.readout = global_add_pool
        self.lin_mu  = nn.Linear(hidden_dim, latent_dim)
        self.lin_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = torch.relu(x)
        h = self.readout(x, batch)
        mu = self.lin_mu(h)
        logvar = self.lin_logvar(h)
        return mu, logvar


## Decoder Architecture
Condition on $z$ to reconstruct adjacency and node features. For example:
- Edge-wise MLP: For each potential node pair, predict bond existence/type.
- Autoregressive (e.g., graph sequential decoding).
- Graph deconvolution layers.

In [15]:
class Decoder(nn.Module):
    def __init__(
        self,
        latent_dim,
        hidden_dim,
        num_node_types,
        num_edge_types,
    ):
        super(Decoder, self).__init__()
        self.node_mlp = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_node_types),
        )
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_edge_types),
        )

    def forward(self, z, edge_index, batch):
        # z: [num_graphs_in_batch, latent_dim]
        # batch: [num_nodes] mapping each node to its graph index
        # Broadcast z to node-level using batch
        z_nodes = z[batch]               # -> [num_nodes, latent_dim]
        node_logits = self.node_mlp(z_nodes)

        src, dst   = edge_index
        edge_inputs = torch.cat([z_nodes[src], z_nodes[dst]], dim=-1)
        edge_logits = self.edge_mlp(edge_inputs)

        return node_logits, edge_logits

## Graph VAE Setup
**Loss**: Reconstruction Loss: Cross-entropy for discrete features (atom types, bond types)    
**KL Divergence**: $\text{KL} = -\frac{1}{2}\sum_{i=1}^d(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2)$



In [12]:
from torch_geometric.nn.models import VGAE

In [13]:
E = Encoder(
    in_channels=dataset.num_node_features,
    hidden_dim=128,
    latent_dim=64,
    edge_dim=dataset.num_edge_features
)
E.to(device)

Encoder(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=11, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  ))
  (conv2): GINEConv(nn=Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  ))
  (lin_mu): Linear(in_features=128, out_features=64, bias=True)
  (lin_logvar): Linear(in_features=128, out_features=64, bias=True)
)

In [16]:
D = Decoder(
    latent_dim=64,
    hidden_dim=128,
    num_node_types=5, # C, N, O, F, H
    num_edge_types=4 # single, double, triple, aromatic
)
D.to(device)


Decoder(
  (node_mlp): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=5, bias=True)
  )
  (edge_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4, bias=True)
  )
)

In [17]:
model = VGAE(encoder=E, decoder=D)
model.to(device)

VGAE(
  (encoder): Encoder(
    (conv1): GINEConv(nn=Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    ))
    (conv2): GINEConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    ))
    (lin_mu): Linear(in_features=128, out_features=64, bias=True)
    (lin_logvar): Linear(in_features=128, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (node_mlp): Sequential(
      (0): Linear(in_features=64, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=5, bias=True)
    )
    (edge_mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=4, bias=True)
    )
  )
)

In [18]:
optimizer =  Adam(model.parameters(), lr=0.001)

In [19]:
recon_node_loss = nn.CrossEntropyLoss()
recon_edge_loss = nn.CrossEntropyLoss()

In [20]:
def loss_fn(data, node_logits, edge_logits, mu, logvar):
    recon_node_loss = nn.CrossEntropyLoss()(node_logits, data.x.argmax(dim=1))
    recon_edge_loss = nn.CrossEntropyLoss()(edge_logits, data.edge_attr.argmax(dim=1))
    kl = model.kl_loss(mu, logvar)
    return recon_node_loss + recon_edge_loss + 1e-4 * kl # 1e-4 is a small constant to prevent KL loss from dominating


## Model Training (10 Epochs)

In [21]:
for epoch in range(1, 11):
    model.train()
    total_loss = 0.0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # 1. Encode: get mu and logvar
        mu, logvar = model.encoder(data.x, data.edge_index, data.edge_attr, data.batch)

        # 2. Reparameterize to obtain z
        z = model.reparametrize(mu, logvar)

        # 3. Decode: predict node and edge logits
        node_logits, edge_logits = model.decoder(z, data.edge_index, data.batch)

        # 4. Compute loss
        loss = loss_fn(data, node_logits, edge_logits, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch:02d} — Loss: {avg_loss:.4f}")

Epoch 01 — Loss: 238.3771
Epoch 02 — Loss: 0.2649
Epoch 03 — Loss: 0.2523
Epoch 04 — Loss: 0.2423
Epoch 05 — Loss: 0.2344
Epoch 06 — Loss: 0.2300
Epoch 07 — Loss: 0.2272
Epoch 08 — Loss: 0.2248
Epoch 09 — Loss: 0.2228
Epoch 10 — Loss: 0.2220


Less than $1\%$ of the dataset...

In [22]:
ckpt = {
    "epoch": epoch,
    "encoder": E.state_dict(),
    "decoder": D.state_dict(),
    "optimizer": optimizer.state_dict(),
    "train_args": {"lr": 1e-3, "batch_size": 64},
}
torch.save(ckpt, "checkpoints/qvae_epoch_{:03d}.pt".format(epoch))

In [23]:
torch.save(ckpt, "checkpoints/qvae_best.pt")

## Inference

In [24]:
ckpt = torch.load("checkpoints/qvae_best.pt", map_location=device)
E.load_state_dict(ckpt["encoder"])
D.load_state_dict(ckpt["decoder"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt.get("epoch", 0) + 1

In [25]:
vae = VGAE(encoder=E, decoder=D)
vae.to(device)
vae.eval()

VGAE(
  (encoder): Encoder(
    (conv1): GINEConv(nn=Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    ))
    (conv2): GINEConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    ))
    (lin_mu): Linear(in_features=128, out_features=64, bias=True)
    (lin_logvar): Linear(in_features=128, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (node_mlp): Sequential(
      (0): Linear(in_features=64, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=5, bias=True)
    )
    (edge_mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=4, bias=True)
    )
  )
)

In [26]:
def sample_molecules(model, num_samples, edge_index, batch, device):
    model.eval()
    with torch.no_grad():
        # 1. Sample latent vectors
        z = torch.randn(num_samples, model.encoder.lin_mu.out_features).to(device)

        # 2. Decode to get logits
        node_logits, edge_logits = model.decoder(z, edge_index, batch)

        # 3. Discretize
        atom_types = node_logits.argmax(dim=-1)
        bond_types = edge_logits.argmax(dim=-1)

    return atom_types, bond_types

In [27]:
def make_generation_graph(num_samples: int, max_nodes: int, device):
    # 1. Build batch: [0,0,…,0,1,1,…,1,…,num_samples-1,…]
    batch = torch.arange(num_samples, device=device).unsqueeze(1).repeat(1, max_nodes).view(-1)  # shape [num_samples * max_nodes]

    # 2. Build node indices for each sample
    node_offsets = torch.arange(num_samples, device=device) * max_nodes
    node_indices = (node_offsets.unsqueeze(1) + torch.arange(max_nodes, device=device)).view(-1)  # shape [num_samples * max_nodes]

    # 3. Create all possible directed edges (excluding self-loops)
    src = node_indices.unsqueeze(1).repeat(1, max_nodes).view(-1)  # repeat each node max_nodes times
    dst = node_indices.repeat(max_nodes)                          # tile the entire list max_nodes times

    # 4. Mask out self-loops
    mask = src != dst
    edge_index = torch.stack([src[mask], dst[mask]], dim=0)       # shape [2, num_edges]

    return batch, edge_index

In [28]:
gen_batch, gen_edge_index = make_generation_graph(10, 9, device)
atom_types, bond_types = sample_molecules(vae, 10, gen_edge_index, gen_batch, device)

In [29]:
element_list = ['H','C','N','O','F']
bond_type_list = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC
]

In [30]:
molecules = []
start = 0
for i in range(10):
    # Select nodes for this molecule
    node_mask = gen_batch == i
    n_nodes = node_mask.sum().item()

    # Build node features: one-hot from atom_types
    ats = atom_types[start:start+n_nodes]
    x = torch.nn.functional.one_hot(ats, num_classes=5).float()

    # Select edges for this molecule
    edge_mask = (gen_batch[gen_edge_index[0]] == i) & (gen_batch[gen_edge_index[1]] == i)
    ei = gen_edge_index[:, edge_mask]
    bs = bond_types[edge_mask]
    edge_attr = torch.nn.functional.one_hot(bs, num_classes=4).float()

    data = Data(x=x, edge_index=ei, edge_attr=edge_attr)
    molecules.append(data)
    start += n_nodes

In [31]:
def to_rdkit(data):
    mol = Chem.RWMol()
    # Add atoms
    for at in data.x.argmax(dim=1).tolist():
        mol.AddAtom(Chem.Atom(element_list[at]))
    # Add bonds
    src, dst = data.edge_index
    for u, v, bt in zip(src.tolist(), dst.tolist(), data.edge_attr.argmax(dim=1).tolist()):
        # Avoid duplicates in undirected graph
        if u < v:
            mol.AddBond(u, v, bond_type_list[bt])
    Chem.SanitizeMol(mol)
    return mol

In [32]:
rdkit_mols = [to_rdkit(d) for d in molecules]

[09:00:13] Explicit valence for atom # 0 C, 16, is greater than permitted


AtomValenceException: Explicit valence for atom # 0 C, 16, is greater than permitted