<div class="alert alert-block alert-info">

<b>Note:</b> this template was adapted from TeachOpenCADD's [GitHub repository](https://github.com/volkamerlab/teachopencadd).

</div>

# Topic 7 - DiGress: Discrete Denoising Diffusion for Graph Generation

**Note:** This talktorial was created as a part of the 2024 [Hands-on Graph Neural Networks seminar](https://cms.sic.saarland/hognn2024/) at Saarland University.

Author:
- Jilixin Tang, 2024, Saarland University

## Aim of this talktorial

In this tutorial, I will walk you through the DiGress (Discrete Graph Denoising Diffusion) model, a diffusion model for graph generation that uses discrete noise. I will first explain the intuition behind diffusion models, the difference between adding Gaussian noise and discrete noise, how training and generation are done, and methods to augment the generation process. I will then demonstrate the training and generation algorithms in the PyTorch framework. Finally, I will discuss how DiGress performs across different tasks.

## Contents in *theory*


* Why are experts interested in AI-generated molecules?
* Why diffusion models?
* What is a diffusion model?
* Why discrete diffusion?
* Mathematical framework of DiGress

## Contents in *practice*


* Python dependencies
* Training algorithm
* Sampling algorithm
* DiGress's Performance

## References

- Papers:

  [1] DiGress: Discrete Denoising diffusion for graph generation: [Vignac et al., *ICLR* (2023)](https://arxiv.org/abs/2209.14734).   
  [2] Denoising Diffusion Probabilistic Models: [Ho et al., *CoRR* (2020)](https://arxiv.org/abs/2006.11239).   
  [3] Fast Graph Generation via Spectral Diffusion: [Luo et al., *IEEE TPAMI* (2023), **46**](https://arxiv.org/abs/2211.08892).   
  [4] Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations: [Jo et al., *CoRR* (2022)](https://arxiv.org/abs/2202.02514).  
  [5] Can graph neural networks count substructures?: [Chen et al., *CoRR* (2020)](https://arxiv.org/abs/2002.04025)

- Github:  
  [6] [DiGress: Discrete Denoising diffusion models for graph generation](https://github.com/cvignac/DiGress)

- Webpages:    
  [7] [Open Catalyst Project](https://opencatalystproject.org/)




## Theory

### Why are experts interested in AI-generated molecules? (i.e., why not generate molecules using established rules and guidelines?)

*   Rules might constrain innovation. Rules and guidelines like Lipinski’s Rule of Five or Quantitative Structure-Activity Relationships (QSAR) are based on historical data and may not fully capture the potential of unexplored chemistries.

*   The number of theoretically possible drug-like molecules is often estimated at around $10^{60}$. AI, particularly generative models, can explore this enormous chemical space efficiently and find novel, unexpected candidates that traditional methods might overlook.

*   Graphs consist of nodes and edges, which naturally correspond to atoms and bonds, so molecules are easy to model as graphs. Given a set of existing molecules, generative models combined with graph neural networks are good at learning latent representations of molecules and approximating their underlying distribution, potentially accelerating the discovery of molecules with desired properties.

*   In drug discovery and materials science, scientists try to discover new drugs to treat diseases, as well as new materials, such as catalysts that enable clean energy. Take the [Open Catalyst Project](https://opencatalystproject.org/) (by Meta and Carnegie Mellon University): as we rely more on renewable energy sources such as wind and solar, storage is needed to shift power from times of peak generation to times of peak demand. This may require storing power for hours, days, or months.

  One solution that could scale to nation-sized grids (Figure 1, centre) is converting renewable energy into other fuels, such as hydrogen. Hydrogen then serves as a medium for storing energy generated from wind and solar, as an alternative to batteries. An open challenge is finding low-cost catalysts that drive these reactions at high rates. Since testing new catalysts is also expensive, AI or machine learning may provide a way to approximate these calculations efficiently, leading to new approaches for finding effective catalysts [7].



<img src="https://www.dropbox.com/scl/fi/6todn0jjbv2dpczmnu477/Energy-Storage.jpg?rlkey=4t986uszn4a9xnjvyljybd7dt&st=j6zu9w0r&raw=1" style="width: 100px;"/>



**Figure 1**: high-voltage power lines (left); hydrogen storage tanks labeled with $H_2$ (centre); localized energy storage solutions (right).


### Why diffusion models?


*   Diffusion models, like autoregressive models, generative adversarial networks (GANs), and variational autoencoders (VAEs), form a powerful class of generative models.

*   They excel in image and video generation and often outperform other generative methods.

*   Molecule generation has ample room for improvement, and successes in image and video generation naturally raise hope for graph generation.



### What is a diffusion model?

For images, as shown in Figure 2, we start on the right with an image of a person. Diffusion models learn the underlying distribution of images by first adding **noise** with the forward process $q$, gradually corrupting the image as we move to the left; the reverse diffusion process $p_\theta$ then learns to **denoise**, recovering the clean image step by step.

Information is removed and recovered at a global level, with the entire image processed at once. This lets diffusion models capture broader patterns than, for example, **autoregressive models**, which learn pixel-by-pixel dependencies in a fixed order and focus primarily on local information.

<img src="https://www.dropbox.com/scl/fi/rv7qtu44s1tuqhv7iqbcv/Gaussian-Pipeline.png?rlkey=2694pw07nakj8nsj59zq8726e&st=lsjjj28i&dl=1" style="width: 800px;"/>


**Figure 2**: The Markov chain of forward (reverse) diffusion process of generating a sample by slowly adding (removing) noise, image taken from [2].







For graphs, one way to diffuse a graph is by adding noise to its adjacency matrix. As shown in Figure 3 on the left, a clean adjacency matrix $A_0$, derived from the graph $G_0$, undergoes a diffusion process where Gaussian noise is gradually added from left to right, corrupting the structure. A reverse process is then modeled by a denoising network, which learns to reconstruct the clean adjacency matrix $A_0$, ultimately recovering the original graph $G_0$.

<img src="https://www.dropbox.com/scl/fi/nulsg1u1l817ts5rdfgoq/Gaussian-noise-diffusion_edited.jpg?rlkey=urodkcpmk1bdu5tqbxam7bd1f&st=ce344wa0&dl=1" style="width: 800px;"/>

**Figure 3**: Graph Diffusion via SDE systems (GDSS), first proposed in [4], with Gaussian noise added to node features and adjacency matrix. This image, taken from [3], only displays Gaussian noise insertion to adjacency matrix. **Note**: the original image also contains the formulas of Stochastic Differential Equation (SDE), which describe how noise is added and removed. However, these formulas differ from the diffusion models discussed in this notebook, which can be thought of as a specific application of SDEs. Since the focus here is on *Gaussian noise* rather than the precise formulation of how it is added, the SDE formulas have been removed to avoid confusion.



### Why discrete and not Gaussian diffusion?

It has been argued that, unlike densely distributed image data, graphs are poorly served by Gaussian noise because it destroys graph sparsity, the property of graphs having relatively few edges compared to the maximum possible number of edges. Additionally, graphs have a discrete edge structure, where an edge either exists or does not, but Gaussian noise introduces continuous perturbations that blur this binary nature. As shown in Figure 3, the originally sparse adjacency matrix turns into a fully dense matrix. This makes adjacency matrices less interpretable and harder to reconstruct during denoising.

A fundamental problem with Gaussian noise lies in its isotropic nature: it treats all dimensions equally, adding noise independently to every entry of the adjacency matrix regardless of its initial value. This uniform treatment disregards the graph's inherent structure, blindly encouraging message passing in sparsely connected parts and severely distorting the global message-passing pattern [1], [3].

Therefore, it has been proposed that discrete noise should be used instead.

### Overview of DiGress

What is ***discrete*** noise? Instead of adding random continuous values drawn from a Gaussian distribution, we insert **categorical** noise. Node types and edge types come from pre-defined categories; in other words, the *state space* is finite, and adding noise means gradually changing the types of the original nodes and edges. As shown in Figure 4, the noisy graph $G^{t-1}$ is formed by changing a maroon edge of the clean graph $G$ to no edge, and an orange edge to a green edge.

As we have discussed above, (discrete) diffusion models learn the underlying distribution of images and graphs by first gradually adding noise, and then trying to reverse the diffusion process step by step.

\\

**In this notebook we focus on DiGress**:

Adding noise gradually to the clean graph requires a **noise model**, which is typically a Markov process consisting of successive graph edits (edge addition or deletion, node or edge category edit). This is called the **forward diffusion** process, as shown in Figure 4 from graph $G$ to $G^T$.

Reversing the diffusion process step by step needs two components: 1) a **graph neural network** that predicts the clean state of a noisy graph, usually referred to as the **denoising network** (in DiGress, a graph transformer). In Figure 4, the denoising network $\phi_\theta$ takes in the noisy graph $G^t$ and predicts its clean state; 2) the reverse of the noise model used in the forward process. This distribution is conditioned on both the noisy graph and its original clean state, so no prediction is involved. It is often called the **ideal reverse**; in Figure 4, $q(G^{t-1}|G^{t}, G)$ directly reverses the forward diffusion.

The denoising network is trained by optimising the **cross-entropy** between the network predictions of the clean state and the true labels.

Combining 1) with 2), we can compute the **network-predicted reverse diffusion** $p_{θ}(G^{t-1}|G^{t})$. This is a **learned** distribution for the step-by-step denoising process that does not require knowing the true graph $G$.

Since $p_{θ}(G^{t-1}|G^{t})$ combines the predictions of the **denoising network** with the **ideal reverse**, the less noisy graph it produces at each time step may not exactly match the one the ideal reverse would produce.

Therefore, after each training epoch, the denoising network is evaluated on the validation set using several metrics. One of them is the **KL divergence** between the distribution of the slightly less noisy graph given by the ideal reverse, $q(G^{t-1}|G^{t}, G)$, and the one given by the network-predicted reverse diffusion, $p_{θ}(G^{t-1}|G^{t})$. Whenever the validation loss improves, the model is saved as a checkpoint; after all epochs, the checkpoint with the lowest validation loss is selected. This step evaluates how well the model **approximates** the ideal reverse process.

After training is completed and the network has learned the distribution of the provided dataset, it can take a randomly sampled noisy graph $G^T$, denoise it step by step with $p_{θ}(G^{t-1}|G^{t})$, and finally compute $p_{θ}(G^{0}|G^{1})$. Since we start from a randomly sampled noisy graph that has no label, $G^{0}$ is a newly **generated** graph.
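As a toy illustration of this metric, the sketch below computes a per-node KL divergence $\mathrm{KL}(q \,\|\, p_\theta)$ between two categorical distributions over node types. The probability values are made up and `categorical_kl` is a hypothetical helper; DiGress's actual validation loss combines several terms and is not reproduced here.

```python
import torch
import torch.nn.functional as F


def categorical_kl(q: torch.Tensor, p: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """KL(q || p) for each row of two (n, a) matrices of categorical probabilities."""
    # F.kl_div takes the first argument in log-space and computes
    # target * (log(target) - input), i.e. q * (log q - log p) here.
    return F.kl_div((p + eps).log(), q, reduction="none").sum(dim=-1)


# Made-up distributions for 2 nodes over 3 types:
# ideal reverse q(x^{t-1} | x^t, x) vs. network-predicted p_theta(x^{t-1} | x^t)
q_ideal = torch.tensor([[0.7, 0.2, 0.1],
                        [0.1, 0.1, 0.8]])
p_model = torch.tensor([[0.6, 0.3, 0.1],
                        [0.2, 0.2, 0.6]])

kl_per_node = categorical_kl(q_ideal, p_model)
print(kl_per_node, kl_per_node.sum())   # per-node values and their sum over the graph
```

The KL is zero only when the two distributions agree, so a smaller value means the network-predicted reverse step is closer to the ideal one.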

<img src="https://www.dropbox.com/scl/fi/pd9apx6id1qacikzonao1/Discrete-Pipeline.png?rlkey=9rc79ahc25eu8lgx0qjz0mfk4&st=mi6i0kyh&dl=1" style="width: 800px;"/>

**Figure 4**: Overview of DiGress, image taken from [1].

### Mathematical framework of DiGress

**Notation**:  

$G = (\mathbf{X}, \mathbf{E})$: a clean graph used for training.  \\
$G^t = (\mathbf{X}^t, \mathbf{E}^t)$: a noisy graph at time step $t$.  \\
$G^T = (\mathbf{X}^T, \mathbf{E}^T)$: a fully noised graph, where $T$ is the pre-defined total number of time steps.  \\
$\mathbf{x}_i \in \mathbb{R}^a$: the one-hot encoding for node $i$, $a$ stands for the cardinality of node types. \\
$\mathbf{X}\in \mathbb{R}^{n\times a}$: a node matrix that organises all node encodings, where $n$ is the number of nodes in a graph.  \\
$\mathbf{e}_{ij} \in \mathbb{R}^b$: the one-hot encoding for the attribute of the edge between nodes $i$ and $j$, $b$ stands for the cardinality of edge types. \\
$\mathbf{E} \in \mathbb{R}^{n\times n\times b}$: an edge tensor holding the edge type of each of the $n\times n$ possible node pairs.   \\
$\mathbf{Q}^{'}$: the matrix transpose of $\mathbf{Q}$. \\
$G^0$: a new graph generated during sampling \\



In [None]:
# One hot encodings for node, edge types
import torch
import torch.nn.functional as F


atom_types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
bond_types = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}

atom_one_hot_encodings = torch.tensor([
                        [1, 0, 0, 0, 0],  # Hydrogen
                        [0, 1, 0, 0, 0],  # Carbon
                        [0, 0, 1, 0, 0],  # Nitrogen
                        [0, 0, 0, 1, 0],  # Oxygen
                        [0, 0, 0, 0, 1]   # Fluorine
                        ])

bond_one_hot_encodings = torch.tensor([
                        [1, 0, 0, 0],  # SINGLE
                        [0, 1, 0, 0],  # DOUBLE
                        [0, 0, 1, 0],  # TRIPLE
                        [0, 0, 0, 1]   # AROMATIC
                        ])
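To connect the notation with a real molecule, here is a minimal sketch that builds $\mathbf{X}$ and $\mathbf{E}$ for formaldehyde (H$_2$C=O) with explicit hydrogens. One assumption: the `bond_types` dictionary above has no category for an absent bond, but $\mathbf{E}$ must also encode node pairs without an edge, so this sketch adds a hypothetical "no bond" class at index 0 (`edge_types` is an illustrative name, not taken from the DiGress code). The names `X_mol` and `E_mol` avoid overwriting the notebook's `X`.

```python
import torch
import torch.nn.functional as F

atom_types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
# Assumption: edge class 0 means "no bond"; the bond classes are shifted by one
edge_types = {"NONE": 0, "SINGLE": 1, "DOUBLE": 2, "TRIPLE": 3, "AROMATIC": 4}

# Formaldehyde with explicit hydrogens: node 0 = C, node 1 = O, nodes 2 and 3 = H
atoms = ['C', 'O', 'H', 'H']
bonds = [(0, 1, "DOUBLE"), (0, 2, "SINGLE"), (0, 3, "SINGLE")]

n, a, b = len(atoms), len(atom_types), len(edge_types)

# Node matrix X: shape (n, a), one one-hot row per atom
X_mol = F.one_hot(torch.tensor([atom_types[s] for s in atoms]), num_classes=a).float()

# Edge tensor E: shape (n, n, b); every pair starts as "no bond"
edge_index = torch.zeros(n, n, dtype=torch.long)
for i, j, bond in bonds:
    edge_index[i, j] = edge_types[bond]
    edge_index[j, i] = edge_types[bond]   # undirected graph: keep E symmetric
E_mol = F.one_hot(edge_index, num_classes=b).float()

print(X_mol.shape, E_mol.shape)   # torch.Size([4, 5]) torch.Size([4, 4, 5])
print(E_mol.argmax(dim=-1))       # edge type of every node pair
```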

**Forward noise model**: $$q(G^t|G^{t-1}) = (\mathbf{X}^{t-1}\mathbf{Q}^t_X, \mathbf{E}^{t-1}\mathbf{Q}^t_E) \quad (1)$$

The noise model is defined by Markov transition matrices $\mathbf{Q}^t$, whose cumulative product is $\mathbf{\bar{Q}}^t = \mathbf{Q}^1\mathbf{Q}^2\cdots\mathbf{Q}^t$. Therefore, $$q(G^t|G) = (\mathbf{X}\mathbf{\bar{Q}}^t_X, \mathbf{E}\mathbf{\bar{Q}}^t_E) \quad (2)$$

$[\mathbf{Q}^t_X]_{ij}$ is the probability that a node transitions from type $i$ at time $t-1$ to type $j$ at time $t$, so $\mathbf{Q}^t_X \in \mathbb{R}^{a\times a}$ and each of its rows sums to 1. Note: node (resp. edge) diffusion is applied separately to each node (resp. edge), but all nodes share the same transition matrix $\mathbf{Q}^t_X$ at step $t$.
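Eq. (2) follows from Eq. (1) because matrix multiplication is associative: pushing a one-hot row vector through $\mathbf{Q}^1, \dots, \mathbf{Q}^t$ one step at a time gives the same distribution as a single multiplication by $\mathbf{\bar{Q}}^t$, which is what lets training jump straight to any step $t$. The sketch below checks this numerically with made-up 3-type transition matrices (`random_transition_matrix` is a hypothetical helper).

```python
import torch

torch.manual_seed(0)


def random_transition_matrix(k: int) -> torch.Tensor:
    """A made-up row-stochastic matrix: non-negative entries, every row sums to 1."""
    M = torch.rand(k, k) + 0.1
    return M / M.sum(dim=-1, keepdim=True)


k, T = 3, 4
Qs = [random_transition_matrix(k) for _ in range(T)]   # Q^1, ..., Q^T

x = torch.tensor([[1.0, 0.0, 0.0]])   # a clean node of type 0 as a (1, k) row vector

# Eq. (1) applied T times: propagate the type distribution one step at a time
step_by_step = x
for Q in Qs:
    step_by_step = step_by_step @ Q

# Eq. (2): one product with the cumulative matrix Qbar^T = Q^1 Q^2 ... Q^T
Q_bar = torch.linalg.multi_dot(Qs)
one_shot = x @ Q_bar

print(torch.allclose(step_by_step, one_shot))   # True
print(Q_bar.sum(dim=-1))                         # rows of Qbar^T still sum to 1 (up to rounding)
```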

In [None]:
# Markov transition matrix (a made-up example for illustration)

total_timesteps = 3
num_nodes = 3
torch.manual_seed(234)

# Transition probability matrix at step 1
Q1_X = torch.tensor([
                    [0.4, 0.2, 0.2, 0.1, 0.1],        # Type 1 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.1, 0.6, 0.1, 0.1, 0.1],        # Type 2 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.2, 0.2, 0.5, 0.05, 0.05],      # Type 3 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.1, 0.1, 0.1, 0.6, 0.1],        # Type 4 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.05, 0.05, 0.05, 0.15, 0.7]     # Type 5 node transition probabilities to type 1, 2, 3, 4, 5
                    ])

# Transition probability matrix at step 2
Q2_X = torch.tensor([
                    [0.5, 0.3, 0.1, 0.05, 0.05],
                    [0.2, 0.5, 0.2, 0.05, 0.05],
                    [0.1, 0.1, 0.6, 0.1, 0.1],
                    [0.05, 0.2, 0.1, 0.6, 0.05],
                    [0.1, 0.05, 0.05, 0.1, 0.7]
                    ])

# Transition probability matrix at step 3
Q3_X = torch.tensor([
                    [0.6, 0.2, 0.1, 0.05, 0.05],
                    [0.1, 0.7, 0.1, 0.05, 0.05],
                    [0.05, 0.1, 0.7, 0.1, 0.05],
                    [0.05, 0.2, 0.2, 0.5, 0.05],
                    [0.1, 0.05, 0.05, 0.1, 0.7]
                    ])

# Cumulative product
Q_bar3_X = torch.matmul(torch.matmul(Q1_X, Q2_X), Q3_X)
print("The cumulative product of transition probabilities matrices:\n", Q_bar3_X)


# Clean node features
x_1 = torch.tensor([1, 0, 0, 0, 0])
x_2 = torch.tensor([0, 1, 0, 0, 0])
x_3 = torch.tensor([0, 1, 0, 0, 0])

X = torch.stack([x_1, x_2, x_3])
print("The node matrix:\n", X)

# Applying the cumulative transition probability matrix to the clean node matrix
X = X.float()
XQbar3_X = torch.matmul(X, Q_bar3_X)
print("The noisy node features after three steps of diffusion:\n", XQbar3_X)  # noisy node features X^3


'''
XQbar3_X is a 3 x 5 matrix of probabilities: each row is one node's distribution over the
node types 0, 1, 2, 3, 4. To make the noisy graph discrete, we sample one node type from
each row's distribution.
'''

# No softmax is needed: every Q^t is row-stochastic, so each row of XQbar3_X already sums
# to 1. Applying softmax to probabilities would distort them and flatten them towards a
# uniform distribution.

# Sample one discrete node type per node from its row of probabilities
discrete_nodes = torch.multinomial(XQbar3_X, num_samples=1)

print("Discrete node types after sampling:\n", discrete_nodes.squeeze())

The cumulative product of transition probabilities matrices:
 tensor([[0.2207, 0.2922, 0.2348, 0.1210, 0.1312],
        [0.1815, 0.3440, 0.2285, 0.1180, 0.1280],
        [0.1751, 0.2584, 0.3256, 0.1210, 0.1199],
        [0.1328, 0.2740, 0.2260, 0.2393, 0.1280],
        [0.1455, 0.1694, 0.1486, 0.1566, 0.3799]])
The node matrix:
 tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0]])
The noisy node features after three steps of diffusion:
 tensor([[0.2207, 0.2922, 0.2348, 0.1210, 0.1312],
        [0.1815, 0.3440, 0.2285, 0.1180, 0.1280],
        [0.1815, 0.3440, 0.2285, 0.1180, 0.1280]])



**Denoising neural network $\phi_\theta$**:  
a graph transformer parametrised by $\theta$ that takes a noisy graph $G^t = (\mathbf{X}^t, \mathbf{E}^t)$ as input and predicts its clean state $G$. The model $\phi_\theta$ is trained by optimising the cross-entropy between the network predictions and the true graph:

\\
$$
\text{cross-entropy}\left(x_i, \widehat{p}_i^X\right) = - \sum_{k} x_i^{(k)} \log \widehat{p}_i^{X(k)}
$$

$$
l({\widehat{p}}^G, G)=\sum_{1 \leq i \leq n} \text{cross-entropy}\left(x_i, {\widehat{p}}^X_i\right) + \lambda \sum_{1 \leq i, j \leq n} \text{cross-entropy}\left(e_{ij}, {\widehat{p}}^E_{ij}\right) \quad (3)
$$  

In [None]:
# Recall the clean node features
x_1 = torch.tensor([1, 0, 0, 0, 0])       # Type 0
x_2 = torch.tensor([0, 1, 0, 0, 0])       # Type 1
x_3 = torch.tensor([0, 1, 0, 0, 0])       # Type 1
X = torch.stack([x_1, x_2, x_3]).float()  # clean node matrix, rebuilt so this cell does not rely on earlier state



# Suppose the sampled noisy node types are tensor([3, 0, 4])
X_3 = torch.tensor([[0, 0, 0, 1, 0],    # Type 3
                    [1, 0, 0, 0, 0],    # Type 0
                    [0, 0, 0, 0, 1]     # Type 4
                    ])
'''
Given the noisy discrete node matrix X_3 at the final step T = 3, the denoising network
takes X_3 as input and outputs an n x a matrix (a is the number of node types), here
3 x 5, in which each row scores how likely a node is to be node type 0, 1, 2, 3, 4.
Assigning a noisy input to one of several types is essentially a classification problem,
so we compare the prediction directly with the one-hot clean state of each node,
without first sampling a discrete type from it.
'''

# A made-up network prediction (unnormalised logits)
p_hat_X = torch.tensor([
                        [1.2, -0.5, -1.2, -3.0, -0.5],  # logits for node 1
                        [-0.8, 0.4, -1.5, 0.2, -1.0],   # logits for node 2
                        [0.3, -1.2, -3.5, 1.0, -0.8]    # logits for node 3
                        ])

# Cross-entropy between the predictions and the clean state
X_labels = torch.argmax(X, dim=-1)  # position of the 1 in each one-hot row, i.e. the clean type
print("Labels of clean state:\n",X_labels)

loss = F.cross_entropy(p_hat_X, X_labels)  # softmax is applied internally; averages over nodes by default
print("Cross-entropy loss:", loss.item())

Labels of clean state:
 tensor([0, 1, 1])
Cross-entropy loss: 1.3624825477600098
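`F.cross_entropy` averages over nodes by default, whereas Eq. (3) sums over all nodes and all node pairs and weights the edge term by $\lambda$. A sketch of the full loss with made-up logits and labels follows; `lam = 5.0` is an arbitrary choice here, not a value taken from the paper, and edge class 0 is assumed to mean "no edge".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n, a, b = 3, 5, 3      # nodes, node types, edge types (edge class 0 assumed to be "no edge")
lam = 5.0              # lambda in Eq. (3); arbitrary value for illustration

# Made-up clean labels and network logits
node_labels = torch.tensor([0, 1, 1])            # clean node types, shape (n,)
edge_labels = torch.tensor([[0, 1, 0],
                            [1, 0, 2],
                            [0, 2, 0]])          # clean edge types, shape (n, n)
node_logits = torch.randn(n, a)                  # network output for X
edge_logits = torch.randn(n, n, b)               # network output for E

# Eq. (3) sums the per-node and per-pair terms, so use reduction="sum"
node_term = F.cross_entropy(node_logits, node_labels, reduction="sum")
edge_term = F.cross_entropy(edge_logits.reshape(-1, b), edge_labels.reshape(-1), reduction="sum")
loss = node_term + lam * edge_term

print(node_term.item(), edge_term.item(), loss.item())
```

Flattening the $(n, n, b)$ edge logits to $(n^2, b)$ lets the same classification loss treat every node pair as one sample.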


The above are the **two main components** of the diffusion model. You may ask: shouldn't that be enough to **generate new graphs**? Given a noisy graph, the denoising network predicts its clean state as a probability distribution for each node, and we could sample from this distribution to get *new* graphs.

However, such a one-shot approach would involve only two sources of randomness: sampling a noisy graph from a distribution, and sampling a discrete state from the predicted probability distribution. This is not enough for a generative model: when one of the predicted probabilities dominates, we are very likely to see the same output over and over again.

Diffusion models generate graphs by denoising noisy graphs **step by step**, much as autoregressive models generate graphs **node by node or edge by edge**. Each reverse step adds more **stochasticity** to the process: at each time step, $p_{θ}(x_i^{t-1}|x_i^{t})$ returns a vector of probabilities of node $x_i$ being each type, and a discrete state is then *sampled* at random from that categorical (multinomial) distribution.

We will explain $p_{θ}(x_i^{t-1}|x_i^{t})$ later. First, instead of using the noisy version of a clean graph, we need a distribution from which to directly sample noisy graphs.

\\

**Prior - the limit distribution of the noise model**:

We have described the noising process with the transition probability matrix $\mathbf{Q}^t$, but not how the noise model is defined mathematically, i.e. *how* we choose the transition probability $[\mathbf{Q}^t]_{ij}$ of jumping from type $i$ at time $t-1$ to type $j$ at time $t$.

The limit (prior) distribution of the noise model is the distribution that the forward noise model $q(\mathbf{x}^t|\mathbf{x}) = \mathbf{x}\mathbf{\bar{Q}}^t$ converges to as $t \to \infty$. It is also the distribution from which we can directly sample noisy graphs.

\\

It has been argued in DiGress that the **optimal noise model** is:

$$\mathbf{Q}^t_X = \alpha^t\mathbf{I} + (1-\alpha^t)\mathbf{1}_a\mathbf{m}_X^{'}  \quad\mathbf{Q}^t_E = \alpha^t\mathbf{I} + (1-\alpha^t)\mathbf{1}_b\mathbf{m}_E^{'} \quad (4)$$

where $\mathbf{m}_X$ is the **marginal distribution** of node types in the training data (and $\mathbf{m}_E$ that of edge types). Every row of $\mathbf{\bar{Q}}^T_X$ then converges to $\mathbf{m}_X^{'}$, i.e. $\lim_{T\to\infty}\mathbf{\bar{Q}}^T_X = \mathbf{1}_a\mathbf{m}_X^{'}$, so the limit distribution of every node is the marginal distribution of node types (and similarly for edges). This makes training easier, since this limit distribution is closer to the true data distribution than a **uniform distribution** over categories would be.

The **noise schedule**, governed by $\alpha^t$, *plans* each step of the diffusion process: $\alpha^t$ decreases from 1 to 0 as $t$ increases, and $\mathbf{I}$ is the identity matrix, representing no added noise. When $\alpha^t$ is close to 1 (early steps), $\mathbf{Q}^t$ is close to the identity matrix; when $\alpha^t$ is close to 0 (later steps), $\mathbf{Q}^t$ is dominated by the marginal-distribution term $\mathbf{1}\mathbf{m}^{'}$.

Note: $\mathbf{\bar{Q}}^t_X$ can also be computed directly from $\bar{\alpha}^t$; the details are omitted here to avoid confusion and are spelt out later in the *Practice* section.
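The closed form in terms of $\bar{\alpha}^t$ works because $\mathbf{1}\mathbf{m}^{'}$ is idempotent ($\mathbf{m}$ sums to 1), so products of matrices of the form in Eq. (4) keep that form: $\mathbf{\bar{Q}}^t = \bar{\alpha}^t\mathbf{I} + (1-\bar{\alpha}^t)\mathbf{1}\mathbf{m}^{'}$ with $\bar{\alpha}^t = \prod_{s \le t}\alpha^s$. The sketch below checks this numerically. It assumes a cosine-shaped schedule for $\bar{\alpha}^t$ (`cosine_alpha_bar` is a hypothetical helper; whether it matches the exact schedule in the DiGress code is an assumption), and the marginals are made up.

```python
import math

import torch


def cosine_alpha_bar(T: int, s: float = 0.008) -> torch.Tensor:
    """Cumulative schedule alpha_bar^t for t = 0..T (cosine shape, an assumed choice)."""
    t = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos(0.5 * math.pi * (t / T + s) / (1 + s)) ** 2
    return f / f[0]                       # alpha_bar^0 = 1, decreasing towards 0


T = 50
m_X = torch.tensor([0.65, 0.15, 0.10, 0.05, 0.05], dtype=torch.float64)  # made-up marginals
a = m_X.numel()
I = torch.eye(a, dtype=torch.float64)
one_m = torch.ones(a, 1, dtype=torch.float64) @ m_X.unsqueeze(0)       # 1_a m_X'

alpha_bar = cosine_alpha_bar(T)
alpha = alpha_bar[1:] / alpha_bar[:-1]    # per-step alpha^t, so alpha_bar^t = prod_{s<=t} alpha^s

# Eq. (4): Q^t = alpha^t I + (1 - alpha^t) 1_a m_X'
Qs = [alpha[t] * I + (1 - alpha[t]) * one_m for t in range(T)]

# Cumulative product Qbar^T = Q^1 Q^2 ... Q^T
Q_bar = I.clone()
for Q in Qs:
    Q_bar = Q_bar @ Q

# Closed form: Qbar^T = alpha_bar^T I + (1 - alpha_bar^T) 1_a m_X'
closed_form = alpha_bar[T] * I + (1 - alpha_bar[T]) * one_m
print(torch.allclose(Q_bar, closed_form))   # True
print(Q_bar[0])                             # every row ends up at (almost exactly) m_X
```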

In [None]:
# Made-up marginal distribution of node types

limit_dist_marginal_X = torch.tensor([0.65, 0.15, 0.1, 0.05, 0.05])
limit_dist_uniform_X  = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])


# The marginal distribution matrix 1_a m_X'
ones_a = torch.ones(len(limit_dist_marginal_X), 1)
marginal_distribution = limit_dist_marginal_X.unsqueeze(0)  # Shape (1 x a)
matrix_1_m_prime = ones_a @ marginal_distribution

print("The resulting matrix 1_m_prime:\n", matrix_1_m_prime)

print('''
Every row of this matrix is the same distribution: a node of any type transitions to
type 0 with probability 0.65, to type 1 with 0.15, to type 2 with 0.10, and to types
3 and 4 with 0.05 each.

As t grows and Q^t is increasingly dominated by this matrix, the type distribution of
every node converges to [0.6500, 0.1500, 0.1000, 0.0500, 0.0500].
''')


The resulting matrix 1_m_prime:
 tensor([[0.6500, 0.1500, 0.1000, 0.0500, 0.0500],
        [0.6500, 0.1500, 0.1000, 0.0500, 0.0500],
        [0.6500, 0.1500, 0.1000, 0.0500, 0.0500],
        [0.6500, 0.1500, 0.1000, 0.0500, 0.0500],
        [0.6500, 0.1500, 0.1000, 0.0500, 0.0500]])

Every row of this matrix is the same distribution: a node of any type transitions to
type 0 with probability 0.65, to type 1 with 0.15, to type 2 with 0.10, and to types
3 and 4 with 0.05 each.

As t grows and Q^t is increasingly dominated by this matrix, the type distribution of
every node converges to [0.6500, 0.1500, 0.1000, 0.0500, 0.0500].



**Posterior - ideal reverse diffusion**:  
Using Bayes' rule and the Markov property, the ideal reverse of the forward noise model has a closed form.

\\
$$q(x^{t-1}_i|x, x^t_i) \propto {\mathbf{x}^t{(\mathbf{Q}^t)}^{'} \odot \mathbf{x}\mathbf{\bar{Q}}^{t-1}} \quad (5)$$  

We need to **condition** on both $x^t_i$ and the clean state $x$ because $q(x^{t-1}_i|x^t_i)$ on its own is not tractable: many different states at step $t-1$ can lead to the same $x^t_i$, and weighing them requires the marginal $q(x^{t-1}_i)$, which depends on the whole data distribution. Once the clean state $x$ is fixed, both the forward step $q(x^t_i|x^{t-1}_i)$ and the marginal $q(x^{t-1}_i|x)$ are known in closed form, and Bayes' rule combines them into Eq. (5).
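Spelling out the Bayes step behind Eq. (5) makes the omitted normaliser explicit. This is a short derivation in the notation above, treating $\mathbf{x}$ and $\mathbf{x}^t$ as the one-hot row vectors of node $i$ and using the Markov property $q(x^t_i \mid x^{t-1}_i, x) = q(x^t_i \mid x^{t-1}_i)$:

$$
q(x^{t-1}_i \mid x, x^t_i)
= \frac{q(x^t_i \mid x^{t-1}_i)\, q(x^{t-1}_i \mid x)}{q(x^t_i \mid x)}
= \frac{\mathbf{x}^t(\mathbf{Q}^t)^{'} \odot \mathbf{x}\mathbf{\bar{Q}}^{t-1}}{\mathbf{x}\mathbf{\bar{Q}}^t(\mathbf{x}^t)^{'}}
$$

The $k$-th entry of $\mathbf{x}^t(\mathbf{Q}^t)^{'}$ is $q(x^t_i \mid x^{t-1}_i = k)$ and the $k$-th entry of $\mathbf{x}\mathbf{\bar{Q}}^{t-1}$ is $q(x^{t-1}_i = k \mid x)$. Because $\mathbf{\bar{Q}}^{t-1}\mathbf{Q}^t = \mathbf{\bar{Q}}^t$, the numerator entries sum to exactly the scalar denominator $\mathbf{x}\mathbf{\bar{Q}}^t(\mathbf{x}^t)^{'}$, so the result is a proper distribution over types.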

In [None]:
# An illustration for the differences between with/without the clean state
torch.manual_seed(234)

Q_bar2_X = torch.matmul(Q1_X, Q2_X)
XQbar2_X = torch.matmul(X, Q_bar2_X)
print('Forward diffusion at step 2:\n',XQbar2_X)


def compute_posterior(X_t, Qt_X, Qsb_X, Qtb_X, X=None, use_clean_state=True):
    """
    This function is taken from the DiGress official code with some alterations.

    Compute the posterior distribution for reverse diffusion with and without clean state.

    Args:
        X_t (Tensor): Noisy node state at time t. Shape: (N, d).
        Qt_X (Tensor): One-step transition matrix from t-1 to t. Shape: (d, d).
        Qsb_X (Tensor): Cumulative transition matrix from 0 to t-1. Shape: (d, d).
        Qtb_X (Tensor): Cumulative transition matrix from 0 to t. Shape: (d, d).
        X (Tensor): Clean state (optional). Shape: (N, d).
        use_clean_state (bool): Whether to include clean state in the computation.

    Returns:
        Tensor: Posterior probabilities. Shape: (N, d).
    """

    Qt_X_T = Qt_X.T

    # Left term: X_t @ Qt.T
    left_term = X_t @ Qt_X_T

    if use_clean_state and X is not None:
        # Right term: X @ Qsb
        right_term = X @ Qsb_X
        numerator = left_term * right_term
    else:
        numerator = left_term


    # Denominator: (X @ Qtb) * X_t
    if use_clean_state and X is not None:
        denom = (X @ Qtb_X) * X_t
        denom = denom.sum(dim=-1, keepdim=True)
    else:
        denom = (X_t @ Qtb_X).sum(dim=-1, keepdim=True)

    #denom[denom == 0] = 1e-8

    prob = numerator / denom

    return prob


posterior_with_clean = compute_posterior(XQbar3_X.type(torch.float32), Q3_X, Q_bar2_X, Q_bar3_X, X, use_clean_state=True)
print("Reverse to step 2 with clean state:\n", posterior_with_clean)

posterior_without_clean = compute_posterior(XQbar3_X.type(torch.float32), Q3_X, Q_bar2_X, Q_bar3_X, use_clean_state=False)
print("Reverse to step 2 without clean state:\n", posterior_without_clean)

print("\nConditioning on the clean state clearly changes the result. The posterior\n"
      "q(x^2 | x^3, x) is not expected to equal the forward marginal q(x^2 | x) printed\n"
      "above, because it also conditions on the noisy state at step 3; only its average\n"
      "over x^3 ~ q(x^3 | x) equals q(x^2 | x). Note also that X_t is passed here as a\n"
      "probability vector rather than a sampled one-hot state. Without the clean state,\n"
      "the exact posterior would need the marginal over the data distribution, so that\n"
      "branch is only a rough stand-in: its rows do not even sum to 1."
)


Forward diffusion at step 2:
 tensor([[0.2750, 0.2650, 0.2150, 0.1200, 0.1250],
        [0.1950, 0.3650, 0.2050, 0.1150, 0.1200],
        [0.1950, 0.3650, 0.2050, 0.1150, 0.1200]])
Reverse to step 2 with clean state:
 tensor([[0.2823, 0.3149, 0.2171, 0.0996, 0.0862],
        [0.1775, 0.4591, 0.1943, 0.0930, 0.0761],
        [0.1775, 0.4591, 0.1943, 0.0930, 0.0761]])
Reverse to step 2 without clean state:
 tensor([[0.2270, 0.2627, 0.2233, 0.1835, 0.1524],
        [0.2129, 0.2941, 0.2216, 0.1890, 0.1482],
        [0.2129, 0.2941, 0.2216, 0.1890, 0.1482]])

Conditioning on the clean state clearly changes the result. The posterior
q(x^2 | x^3, x) is not expected to equal the forward marginal q(x^2 | x) printed
above, because it also conditions on the noisy state at step 3; only its average
over x^3 ~ q(x^3 | x) equals q(x^2 | x). Note also that X_t is passed here as a
probability vector rather than a sampled one-hot state. Without the clean state,
the exact posterior would need the marginal over the data distribution, so that
branch is only a rough stand-in: its rows do not even sum to 1.
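What does hold exactly is that averaging the posterior over every possible noisy state recovers the forward marginal: $\sum_j q(x^3 = j \mid x)\, q(x^2 \mid x^3 = j, x) = q(x^2 \mid x)$. The sketch below checks this with one-hot $x^t$, as used during sampling. All matrices are made up, and `stochastic` and `posterior` are hypothetical helpers rather than DiGress functions.

```python
import torch

torch.manual_seed(0)


def stochastic(k: int) -> torch.Tensor:
    """A made-up row-stochastic transition matrix."""
    M = torch.rand(k, k, dtype=torch.float64) + 0.1
    return M / M.sum(dim=-1, keepdim=True)


def posterior(x_t, Q_t, Q_bar_s, Q_bar_t, x0):
    """Eq. (5) including its normaliser, for one-hot (1, k) row vectors x_t and x0."""
    numerator = (x_t @ Q_t.T) * (x0 @ Q_bar_s)
    denominator = ((x0 @ Q_bar_t) * x_t).sum(dim=-1, keepdim=True)
    return numerator / denominator


k = 4
Q1, Q2, Q3 = stochastic(k), stochastic(k), stochastic(k)
Q_bar2 = Q1 @ Q2             # cumulative matrix up to step t-1 = 2
Q_bar3 = Q_bar2 @ Q3         # cumulative matrix up to step t = 3

eye = torch.eye(k, dtype=torch.float64)
x0 = eye[[1]]                # clean node of type 1, shape (1, k)
forward_3 = x0 @ Q_bar3      # q(x^3 | x)
forward_2 = x0 @ Q_bar2      # q(x^2 | x)

# Average the posterior over every possible one-hot x^3, weighted by q(x^3 | x)
recovered = torch.zeros(1, k, dtype=torch.float64)
for j in range(k):
    post = posterior(eye[[j]], Q3, Q_bar2, Q_bar3, x0)   # q(x^2 | x^3 = j, x)
    recovered += forward_3[0, j] * post

print(torch.allclose(recovered, forward_2))   # True
```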



Finally, with everything above we can define the **Generation model**:

The generation model is the **network-predicted reverse diffusion** discussed earlier, an iterative process from $x_i^T$ through $x_i^{T-1}, x_i^{T-2}, \dots$ to $x_i^0$. We first sample a noisy graph from the limit distribution described above and then run the reverse iterations: at each reverse step we compute $p_{\theta}(x^{t-1}_i | x^t_i)$, a vector of probabilities for node $x_i$ being each type at step $t-1$, and sample a discrete type from this categorical distribution, which adds stochasticity to the generation process.

As discussed before, $p_{\theta}(x^{t-1}_i | x^t_i)$ combines the **denoising network prediction** with the **ideal reverse**, marginalising over all possible clean types $x$ of $x_i$ (five for the QM9 dataset):


\\
\begin{align*}
p_{\theta}(x^{t-1}_i | x^t_i) &=
\sum_{x \in \mathscr{X}}
\begin{cases}
q(x^{t-1}_i | x_i = x, x^t_i) \cdot \widehat{p}^X_i(x) & \quad \text{if } q(x^t_i | x_i = x) > 0, \\
0 & \quad \text{otherwise.}
\end{cases}
\end{align*}

\\

The ideal reverse is computed by supposing that the clean state of $x_i$ is $x$ (one of the five node types). It is then multiplied by the network prediction $\widehat{p}^X_i(x)$, the predicted probability (a scalar) that the clean state of $x_i$ is $x$. Together, the term can be interpreted as a **network-weighted ideal reverse**; summing it over all five types gives the final probability vector for $x_i^{t-1}$.

This is how DiGress gets from prediction to generation: the **ideal reverse** ***formulates*** a reverse diffusion process, the trained **denoising network** ***guides*** this process, and the stochasticity added at each step produces new graphs.

Note: we focus on the generation of nodes; edges are handled similarly. The ideal reverse is set to $0$ if the transition from $x_i = x$ to $x_i^t$ is impossible, i.e. if $q(x^t_i | x_i = x) = 0$.
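Putting the pieces together, here is a minimal sketch of one reverse step for a single node: compute the ideal reverse for every candidate clean type at once, divide each row by its own normaliser $q(x^t_i \mid x_i = x)$, weight the rows by $\widehat{p}^X_i$, sum, and sample. The function name `reverse_step_probs` and all matrices are made up for illustration. A full sampler would start from $x_i^T$ drawn from the limit distribution and repeat this step for $t = T, \dots, 1$.

```python
import torch

torch.manual_seed(0)


def reverse_step_probs(x_t, Q_t, Q_bar_s, Q_bar_t, p_hat):
    """
    p_theta(x^{t-1} | x^t) for one node, as a sum over candidate clean types c.
    x_t: one-hot (k,) noisy type; Q_t: (k, k) one-step matrix; Q_bar_s, Q_bar_t:
    cumulative matrices up to t-1 and t; p_hat: (k,) predicted clean-type probabilities.
    """
    # Row c: numerator of Eq. (5) assuming the clean type is c
    numer = (x_t @ Q_t.T).unsqueeze(0) * Q_bar_s            # (k, k)
    # Row c's own normaliser q(x^t | x = c) = (Q_bar_t @ x_t)[c]
    denom = (Q_bar_t @ x_t).unsqueeze(1)                     # (k, 1)
    # Clean types that cannot reach x^t contribute 0 (the "otherwise" case)
    ideal = torch.where(denom > 0, numer / denom.clamp(min=1e-12), torch.zeros_like(numer))
    # Weight row c by p_hat[c] and sum over the clean types
    probs = (p_hat.unsqueeze(1) * ideal).sum(dim=0)
    return probs / probs.sum()


# Made-up 3-type example
Q_t = torch.tensor([[0.8, 0.1, 0.1],
                    [0.1, 0.8, 0.1],
                    [0.1, 0.1, 0.8]])
Q_bar_s = torch.tensor([[0.6, 0.2, 0.2],
                        [0.2, 0.6, 0.2],
                        [0.2, 0.2, 0.6]])
Q_bar_t = Q_bar_s @ Q_t
x_t = torch.tensor([0.0, 1.0, 0.0])       # current noisy type: 1
p_hat = torch.tensor([0.7, 0.2, 0.1])     # made-up network prediction of the clean type

probs = reverse_step_probs(x_t, Q_t, Q_bar_s, Q_bar_t, p_hat)
x_prev = torch.multinomial(probs, num_samples=1).item()   # sampled type at step t-1
print(probs, x_prev)
```

Normalising each row before weighting matters: the rows have different normalisers, so dropping them would change how much each candidate clean type contributes.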

\\

Beyond the basics of a DiGress model, there are also ways to augment its sampling process. In drug design, it is often advantageous to generate molecules tailored to specific biological or chemical properties that meet therapeutic objectives; this is called **conditional generation**. Besides the denoising network, which predicts the clean state of a noisy graph, we train a **regressor model** $g_\eta$ that predicts the target properties $\mathbf{y}_G$ (label) of the clean graph from a noisy one: $g_\eta(G^t) = \mathbf{\bar{y}}$ (prediction). The reverse process is then conditioned not only on the supposed clean states but also on the **target properties**, so sampling is ***guided/weighted*** by both the denoising network and the regressor model.

Another way to improve DiGress is to add **structural features**, such as **cycle counts**, as extra inputs to the denoising network. Since standard message passing neural networks (MPNNs) have difficulty detecting cycles in a graph [5], feeding this extra information to the denoising network improves performance experimentally [1].
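To make the idea of cycle features concrete, here is a sketch of the classic matrix-power formulas for triangles and 4-cycles on an undirected graph without self-loops. It is only an illustration: DiGress's own structural features are computed in its official code [6] and are not reproduced here, and `cycle_features` is a hypothetical helper.

```python
import torch


def cycle_features(A: torch.Tensor):
    """
    Cycle-count features from a symmetric 0/1 adjacency matrix A (n x n, no self-loops).
    Returns per-node triangle counts, the total triangle count and the 4-cycle count.
    """
    A = A.to(torch.float64)
    A3 = A @ A @ A
    A4 = A3 @ A
    deg = A.sum(dim=1)
    num_edges = A.sum() / 2

    # (A^3)_ii counts closed 3-walks from i; each triangle through i is walked in 2 directions
    triangles_per_node = torch.diagonal(A3) / 2
    # Each triangle is counted once at each of its 3 nodes
    triangles = triangles_per_node.sum() / 3
    # Closed 4-walks: tr(A^4) = 8 * (#4-cycles) + 2 * sum(deg^2) - 2 * |E|
    four_cycles = (torch.trace(A4) - 2 * (deg ** 2).sum() + 2 * num_edges) / 8
    return triangles_per_node, triangles, four_cycles


K4 = torch.ones(4, 4) - torch.eye(4)          # complete graph on 4 nodes
C4 = torch.tensor([[0, 1, 0, 1],
                   [1, 0, 1, 0],
                   [0, 1, 0, 1],
                   [1, 0, 1, 0]])             # a single 4-cycle
print(cycle_features(K4))   # 3 triangles per node, 4 triangles, 3 four-cycles
print(cycle_features(C4))   # no triangles, 1 four-cycle
```

Such per-node and per-graph counts can be appended to the node features or to a global feature vector before they enter the denoising network.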



In [None]:
'''
In practice, rather than calculating the ideal reverse for each of the 5 possible clean
node types separately, it is much easier to leave the clean-state term out of the
calculation: since the clean state is one-hot encoded, multiplying by it only selects a
row of Qbars. Please refer to the explanation below.
'''
import torch

torch.manual_seed(234)

# To calculate q(xi_s | xi_t, xi = x), we need xi_tQt_transpose * xQbars
# (formula spelt out in an earlier section); xi_s denotes node xi at step t-1. Suppose:

xi_tQt_transpose = torch.tensor([0.3,
                                 0.11,
                                 0.42,
                                 0.15,
                                 0.02])
Qbars = torch.tensor([
                      [0.4, 0.3, 0.2, 0.05, 0.05], # Row 0: probabilities of a type 0 node becoming type 0, 1, 2, 3, 4
                      [0.1, 0.5, 0.2, 0.1, 0.1],
                      [0.25, 0.25, 0.3, 0.1, 0.1],
                      [0.2, 0.3, 0.1, 0.3, 0.1],
                      [0.1, 0.1, 0.1, 0.2, 0.5]
                      ])

print('Compute xQbars and suppose the clean state of xi is type 0 node first:\n')
x = torch.tensor([1, 0, 0, 0, 0]) # Type 0 node
xQbars = torch.matmul(x.type(torch.float32), Qbars)
print(xQbars, 'results in the first row of *Qbars*')



print('\nDoing this for all five node types means multiplying each row of Qbars element-wise\n',
    'with xi_tQt_transpose. This equals Qbars @ diag(xi_tQt_transpose), or more simply,\n',
    'broadcasting in PyTorch.\n'
     )


ideal_reverse = Qbars * xi_tQt_transpose.unsqueeze(0)  # row x: xQbars * xi_tQt_transpose (unnormalised)
# The denominator x Qbar_t xi_t^T is left out here to keep the focus on the broadcasting trick.
# Strictly, it depends on the assumed clean state x, so it is not only a global normaliser;
# the sampling code in the Practice section does divide by it.
print('The ideal reverse is:\n', ideal_reverse)

# A made-up denoising network prediction vector for node i
p_hat_i = torch.tensor([0.13, 0.29, 0.08, 0.3, 0.2])

weighted_ideal_reverse = ideal_reverse * p_hat_i.unsqueeze(1)  # weight row x by p_hat_i(x)

xi_s = weighted_ideal_reverse.sum(dim=0)  # sum over the five possible clean types
print('\nThe weighted ideal reverse is:\n', weighted_ideal_reverse)
print('\nThe final probability vector for xi_s is:\n', xi_s)


# Sample from the distribution
xi_s_probs_normalized = xi_s / xi_s.sum()
print("Normalized probabilities:\n", xi_s_probs_normalized)

sample = torch.multinomial(xi_s_probs_normalized, num_samples=1)
print("The discrete xi at t-1 step (node type index):", sample.item())


Compute xQbars and suppose the clean state of xi is type 0 node first:

tensor([0.4000, 0.3000, 0.2000, 0.0500, 0.0500]) results in the first row of *Qbars*

Doing this for all five node types means multiplying each row of Qbars element-wise
 with xi_tQt_transpose. This equals Qbars @ diag(xi_tQt_transpose), or more simply,
 broadcasting in PyTorch.

The ideal reverse is:
 tensor([[0.1200, 0.0330, 0.0840, 0.0075, 0.0010],
        [0.0300, 0.0550, 0.0840, 0.0150, 0.0020],
        [0.0750, 0.0275, 0.1260, 0.0150, 0.0020],
        [0.0600, 0.0330, 0.0420, 0.0450, 0.0020],
        [0.0300, 0.0110, 0.0420, 0.0300, 0.0100]])

The weighted ideal reverse is:
 tensor([[0.0156, 0.0043, 0.0109, 0.0010, 0.0001],
        [0.0087, 0.0159, 0.0244, 0.0044, 0.0006],
        [0.0060, 0.0022, 0.0101, 0.0012, 0.0002],
        [0.0180, 0.0099, 0.0126, 0.0135, 0.0006],
        [0.0060, 0.0022, 0.0084, 0.0060, 0.0020]])

The final probability vector fo

## Practice

The section below is a much simplified version of the original DiGress code, implemented by myself. I have followed their mathematics and main implementation logic, but some features have been left out because of the scope of this talktorial. If you are interested, please check out their GitHub repository, [DiGress](https://github.com/cvignac/DiGress) [6].


### Python Dependencies

In [None]:
# In Runtime -> Change runtime type, choose an available GPU runtime

!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
!pip install torch-geometric

# Note: without a CUDA runtime, the installation kept running indefinitely at
# "Building wheel for torch-scatter (setup.py)". Installing from a requirements.txt
# did not work either; it failed with "ERROR:
# torch_scatter-2.1.1+pt2.5.1cu121-cp310-cp310-linux_x86_64.whl is not a supported wheel on this platform."

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_scatter-2.1.2%2Bpt25cu121-cp311-cp311-linux_x86_64.whl (10.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.9/10.9 MB 79.9 MB/s eta 0:00:00
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt25cu121
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_sparse-0.6.18%2Bpt25cu121-cp311-cp311-linux_x86_64.whl (5.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.1/5.1 MB 86.6 MB/s eta 0:00:00
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt25cu121
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-cluster
  Downloading https://data.p

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import TransformerConv
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import to_dense_adj

In [None]:
torch.cuda.manual_seed(234)

### Training Algorithm

Step 1: Input the graph.   
Step 2: Define noise schedule.   
Step 3: Define the graph transformer, which takes a noisy graph as input and predicts its clean state.   
Step 4: Train the graph transformer by minimising the cross-entropy between the true graph and the predictions.   


Note: as discussed in the Theory section, model selection should be done in the validation phase using several metrics, such as the KL divergence used in the official DiGress code. Due to the scope of this talktorial, however, this part is not implemented.
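For reference, the basic quantity behind such a metric, the KL divergence between two categorical distributions (e.g. a true reverse posterior $q$ and a predicted posterior $p$ for one node), can be sketched as follows. How the official code aggregates it into a validation score is not reproduced here; the helper name `categorical_kl` and the clamping constant `eps` are our own choices.

```python
import torch

def categorical_kl(q: torch.Tensor, p: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """KL(q || p) = sum_k q_k * (log q_k - log p_k), computed along the last dimension."""
    q = q.clamp_min(eps)   # avoid log(0); contributes ~0 for zero-probability classes
    p = p.clamp_min(eps)
    return (q * (q.log() - p.log())).sum(dim=-1)

q = torch.tensor([0.5, 0.5])   # e.g. true posterior of one node at t-1
p = torch.tensor([0.9, 0.1])   # e.g. predicted posterior
print(categorical_kl(q, p))    # a batch of nodes also works: shape [n, K] -> [n]
```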

#### Step 1: Load the Dataset

We use [QM9](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.QM9), a dataset widely used as a benchmark for molecule generation. It consists of about 130,000 molecules, and we directly use the processed version provided by PyTorch Geometric. To keep the runtime manageable, only the first 1,000 molecules are used in the training phase.

In [None]:
# Define a transform to slice the node feature matrix
def slice_node_features(data):
    data.x = data.x[:, :5]
    # the first 5 columns use one hot encoding for atom types
    # the last 6 are [atomic_number, aromatic, sp, sp2, sp3, num_hs] which will not be used in our model
    return data

In [None]:
dataset = QM9(root="./data/qm9", pre_transform=slice_node_features, force_reload=True)[:1000]
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting data/qm9/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


In [None]:
# An example of the QM9 dataset
data = dataset[988]
print('The dataset has already been processed:\n',data)
print('\nThis is a node matrix where each row represents the encoding of a node:\n',data.x)

The dataset has already been processed:
 Data(x=[10, 5], edge_index=[2, 20], edge_attr=[20, 4], y=[1, 19], pos=[10, 3], idx=[1], name='gdb_1015', z=[10])

This is a node matrix where each row represents the encoding of a node:
 tensor([[0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]])


In [None]:
# Predefined node and edge types (based on the official implementation)
atoms = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
bonds = {0: 0, 1: 1, 2: 2, 3: 3}    # They represent single, double, triple, aromatic bonds.

nodes_one_hot_encoding = torch.tensor([
                        [1, 0, 0, 0, 0],  # Hydrogen
                        [0, 1, 0, 0, 0],  # Carbon
                        [0, 0, 1, 0, 0],  # Nitrogen
                        [0, 0, 0, 1, 0],  # Oxygen
                        [0, 0, 0, 0, 1]   # Fluorine
                        ])



#### Step 2: Define the Noise Model



In [None]:
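# Step 2.1: Marginal distributions precomputed in the DiGress code (QM9 with explicit hydrogens)
# node_count_distribution[k] is the probability that a molecule has k nodes (k = 0..29)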

node_count_distribution = torch.tensor([0, 0, 0, 1.5287e-05, 3.0574e-05, 3.8217e-05,
                        9.1721e-05, 1.5287e-04, 4.9682e-04, 1.3147e-03, 3.6918e-03, 8.0486e-03,
                        1.6732e-02, 3.0780e-02, 5.1654e-02, 7.8085e-02, 1.0566e-01, 1.2970e-01,
                        1.3332e-01, 1.3870e-01, 9.4802e-02, 1.0063e-01, 3.3845e-02, 4.8628e-02,
                        5.4421e-03, 1.4698e-02, 4.5096e-04, 2.7211e-03, 0.0000e+00, 2.6752e-04])


node_distribution = torch.tensor([0.5122, 0.3526, 0.0562, 0.0777, 0.0013])

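# In the DiGress code, index 0 of this edge marginal is the 'no edge' class; this simplified
# model reuses the vector as-is for the four PyG bond types, which is only an approximation
edge_distribution = torch.tensor([0.88162,  0.11062,  5.9875e-03,  1.7758e-03])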

# Define the number of node types and edge types directly
a = len(atoms)  # Number of node types (5 types: 'H', 'C', 'N', 'O', 'F')
b = len(bonds)
n = len(node_count_distribution) # 30
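The marginals above are copied from the DiGress code. As a sketch of where such numbers come from (our assumption, not the official preprocessing), the node-type marginal is the fraction of all nodes that have each type, and the node-count distribution is a normalised histogram of graph sizes. The helper `estimate_marginals` is hypothetical, and applied to the 1,000-molecule subset it would only approximate the full-dataset values.

```python
import torch

def estimate_marginals(node_feature_list, max_nodes):
    """Estimate (node-type marginal, node-count distribution) from one-hot node matrices.

    node_feature_list: list of [num_nodes, num_types] one-hot tensors, one per graph.
    """
    type_counts = torch.stack([x.sum(dim=0) for x in node_feature_list]).sum(dim=0)
    node_marginal = type_counts / type_counts.sum()

    sizes = torch.tensor([x.size(0) for x in node_feature_list])
    size_hist = torch.bincount(sizes, minlength=max_nodes + 1).float()
    node_count_dist = size_hist / size_hist.sum()
    return node_marginal, node_count_dist

# Toy example: three tiny "molecules" with atom types indexed as H=0, C=1, N=2, O=3, F=4
eye = torch.eye(5)
graphs = [eye[[1, 0, 0, 0, 0]], eye[[1, 3, 0]], eye[[2, 0, 0, 0]]]
node_marginal, node_count_dist = estimate_marginals(graphs, max_nodes=6)
print(node_marginal)     # fraction of nodes per atom type
print(node_count_dist)   # probability of each graph size 0..6
```

On the dataset loaded above, a call such as `estimate_marginals([d.x for d in dataset], max_nodes=29)` would give the subset's empirical marginals.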

In [None]:
# Step 2.2: Define the Markovian Noise Model

def cosine_beta_schedule_discrete(timesteps, s=0.008): #  taken from DiGress code
    """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)

    alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas = 1 - alphas
    return betas.squeeze()

class NoiseModel:
    def __init__(self, T, s, node_distribution, edge_distribution):
        self.T = T
        self.s = s
        self.betas = torch.tensor(cosine_beta_schedule_discrete(self.T))
        self.alphas = 1 - torch.clamp(self.betas, min=0, max=0.9999)
        log_alpha = torch.log(self.alphas)
        log_alpha_bar = torch.cumsum(log_alpha, dim=0)
        self.alphas_bar = torch.exp(log_alpha_bar)


        self.node_distribution = node_distribution
        self.edge_distribution = edge_distribution

        self.cardinality_x = len(node_distribution)
        self.cardinality_e = len(edge_distribution)

        # 1_a as an all-ones matrix, so that (1_a * m^T) has every row equal to the marginal m
        self.one_a_x = torch.ones((self.cardinality_x, self.cardinality_x))
        self.one_a_e = torch.ones((self.cardinality_e, self.cardinality_e))


    def get_alpha_bar(self, t_normalized=None, t_int=None):
        assert int(t_normalized is None) + int(t_int is None) == 1
        if t_int is None:
            t_int = torch.round(torch.as_tensor(t_normalized) * self.T)
        if not isinstance(t_int, torch.Tensor):
            t_int = torch.tensor(t_int, device=self.alphas_bar.device)
        return self.alphas_bar.to(t_int.device)[t_int.long()]

    def qt(self, t):
        '''
        One-step Markovian noise model for node types and edge types:
        Q_t = alpha_t * I + (1 - alpha_t) * 1_a m^T, with m the marginal distribution.
        Every row sums to 1, and the cumulative product approaches rows equal to m as t grows.
        '''
        if t <= 0 or t > self.T:
            raise ValueError(f"Timestep t={t} is out of range (1 to {self.T}).")
        alpha_t = self.alphas[t]

        # Create the transition matrix: each row mixes "stay" (identity) with the marginal distribution
        Qt_x = alpha_t * torch.eye(self.cardinality_x) + (1 - alpha_t) * (self.one_a_x * self.node_distribution.unsqueeze(1).T) # M_transpose_x
        Qt_e = alpha_t * torch.eye(self.cardinality_e) + (1 - alpha_t) * (self.one_a_e * self.edge_distribution.unsqueeze(1).T) # M_transpose_e
        return Qt_x, Qt_e

    def get_Qt_bar(self, alpha_bar_t, device):
        # Returns t-step transition matrices for nodes and edges from step 0 to step t using alpha_bar_t

        alpha_bar_t = alpha_bar_t.unsqueeze(0).to(device).float()
        node_distribution = self.node_distribution.to(device).float()
        edge_distribution = self.edge_distribution.to(device).float()

        # Each row is alpha_bar * e_i + (1 - alpha_bar) * m, i.e. 1_a m^T (m as a row), so rows sum to 1
        q_x = alpha_bar_t * torch.eye(self.cardinality_x, device=device).unsqueeze(0) + (1 - alpha_bar_t) * node_distribution.unsqueeze(0)
        q_e = alpha_bar_t * torch.eye(self.cardinality_e, device=device).unsqueeze(0) + (1 - alpha_bar_t) * edge_distribution.unsqueeze(0)
        return q_x, q_e
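A standalone sanity check of the marginal transition, written without the class above so it can run on its own. Assuming, as in the DiGress paper, $Q^t = \alpha^t I + (1-\alpha^t)\mathbf{1}\mathbf{m}^\top$, multiplying the one-step matrices should reproduce the closed form $\bar{Q}^t = \bar{\alpha}^t I + (1-\bar{\alpha}^t)\mathbf{1}\mathbf{m}^\top$ with $\bar{\alpha}^t = \prod_{\tau \le t} \alpha^\tau$; every row should sum to 1, and the rows should approach $\mathbf{m}$ as $\bar{\alpha}^t \to 0$. The $\alpha$ values below are made up.

```python
import torch

def marginal_Q(alpha, m):
    """alpha * I + (1 - alpha) * 1 m^T (each row mixes 'stay' with the marginal m)."""
    K = m.numel()
    return alpha * torch.eye(K) + (1 - alpha) * torch.ones(K, 1) @ m.unsqueeze(0)

m = torch.tensor([0.5122, 0.3526, 0.0562, 0.0777, 0.0013])
m = m / m.sum()
alphas = torch.tensor([0.99, 0.95, 0.9, 0.7, 0.4, 0.1])   # made-up schedule

Qbar = torch.eye(5)
for alpha in alphas:
    Qbar = Qbar @ marginal_Q(alpha, m)       # compose Q^1 Q^2 ... one step at a time

closed_form = marginal_Q(torch.prod(alphas), m)
print(torch.allclose(Qbar, closed_form, atol=1e-6))  # product vs. closed form
print(Qbar.sum(dim=1))                               # row sums
print(Qbar[0], m)                                    # rows are close to m for small alpha_bar
```

The identity holds because $\mathbf{m}^\top\mathbf{1} = 1$, which is also why `NoiseModel` only needs `alphas_bar` to build $\bar{Q}^t$ directly.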

#### Step 3: Define the Denoising Network

In [None]:
class GraphTransformer(nn.Module):
    def __init__(self, num_node_types, num_edge_types):
        super(GraphTransformer, self).__init__()

        # Node transformer layers
        self.node_conv1 = TransformerConv(in_channels=num_node_types, out_channels=64, edge_dim=num_edge_types)
        self.node_conv2 = TransformerConv(in_channels=64, out_channels=64, edge_dim=num_edge_types)
        self.node_fc = nn.Linear(64, num_node_types)  # Predict node classes

        # Edge transformer layers
        self.edge_fc1 = nn.Linear(num_edge_types, 64)  # Initial edge embedding
        self.edge_fc2 = nn.Linear(64, 64)             # Edge processing
        self.edge_fc3 = nn.Linear(64, num_edge_types) # Predict edge classes

    def forward(self, x, edge_index, edge_attr):
        """
        x: Node features [num_nodes, num_node_features]
        edge_index: Edge indices [2, num_edges]
        edge_attr: Edge attributes [num_edges, num_edge_features]
        """
        # ---- Node Predictions ----
        node_x = self.node_conv1(x, edge_index, edge_attr)  # [num_nodes, 64]
        node_x = F.relu(node_x)
        node_x = self.node_conv2(node_x, edge_index, edge_attr)  # [num_nodes, 64]
        node_x = F.relu(node_x)
        node_predictions = self.node_fc(node_x)  # [num_nodes, num_node_types]

        # ---- Edge Predictions ----
        edge_features = F.relu(self.edge_fc1(edge_attr))  # [num_edges, 64]
        edge_features = F.relu(self.edge_fc2(edge_features))  # [num_edges, 64]
        edge_predictions = self.edge_fc3(edge_features)  # [num_edges, num_edge_types]

        return node_predictions, edge_predictions
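In the class above, the edge head (`edge_fc1` to `edge_fc3`) only sees each edge's own noisy attribute and ignores both the node embeddings and `edge_index`, so its predictions cannot use the surrounding graph. A hypothetical variant, sketched in plain PyTorch and not taken from the official code (which uses a dense graph transformer that updates node, edge and graph-level features jointly), concatenates the embeddings of both endpoints with the edge embedding before classifying:

```python
import torch
import torch.nn as nn

class EndpointEdgeHead(nn.Module):
    """Classify each edge from [h_src, h_dst, e_ij] instead of e_ij alone (hypothetical design)."""

    def __init__(self, node_dim: int, num_edge_types: int, hidden: int = 64):
        super().__init__()
        self.edge_embed = nn.Linear(num_edge_types, hidden)
        self.mlp = nn.Sequential(
            nn.Linear(2 * node_dim + hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_edge_types),
        )

    def forward(self, node_h, edge_index, edge_attr):
        src, dst = edge_index                        # each of shape [num_edges]
        e = torch.relu(self.edge_embed(edge_attr))   # [num_edges, hidden]
        features = torch.cat([node_h[src], node_h[dst], e], dim=-1)
        return self.mlp(features)                    # [num_edges, num_edge_types] logits

# Usage with made-up shapes: 6 nodes with 64-dim embeddings, 5 edges, 4 bond types
node_h = torch.randn(6, 64)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
edge_attr = torch.eye(4)[torch.tensor([0, 1, 0, 2, 3])]
head = EndpointEdgeHead(node_dim=64, num_edge_types=4)
print(head(node_h, edge_index, edge_attr).shape)   # torch.Size([5, 4])
```

Inside `GraphTransformer.forward`, `node_h` would be the 64-dimensional `node_x` after `node_conv2`.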

#### Step 4: Training loop

In [None]:
model = GraphTransformer(num_node_types=a, num_edge_types=b)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = nn.CrossEntropyLoss()


num_epochs = 20
noise_model = NoiseModel(T=10, s=0.008, node_distribution=node_distribution, edge_distribution=edge_distribution) # updated to directly use dists
# DiGress used T = 500

# Initialize variables to track the total node-edge ratio and the number of graphs
total_ratio = 0
total_graphs = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(num_epochs):
    total_loss = 0
    for data in data_loader:
        x, edge_index, edge_attr, batch = data.x.to(device), data.edge_index.to(device), data.edge_attr.to(device), data.batch.to(device)
        '''
        'batch' information is not used in this *simplified* DiGress model,
         since it only makes node- and edge-level predictions. As discussed
         in the theory section, every node/edge is subject to the same transition
         probability matrix at each step, and noise is added to each node and edge
         feature separately, without the features interfering with each other. The
         denoising network does not make predictions on graph-level features, and
         no structural features of the graphs are given to it as input in this model.
        '''

        # Convert node and edge attributes to one-hot encoding
        x = F.one_hot(x.argmax(dim=-1), num_classes = a).float()  #  num_nodes, a
        edge_attr = F.one_hot(edge_attr.argmax(dim=-1), num_classes = b).float()  # num_edges, b
        '''
        The official DiGress code first processes the dataset using the sparse representation,
        and then converts it to a dense adjacency matrix. In contrast, for runtime efficiency,
        I will only be using the sparse representation to avoid memory overhead.
        '''

        # Sample a timestep t
        t_int = torch.randint(0, noise_model.T, size=(1,), device=x.device).item()
        t_normalized = t_int / noise_model.T

        # Compute alpha_bar for timestep t
        alpha_bar_t = noise_model.get_alpha_bar(t_int=t_int).clone().detach().to(device)


        # Compute Qt_bar using alpha_bar_t directly
        Qt_node_bar = noise_model.get_Qt_bar(alpha_bar_t, device=device)[0]
        Qt_edge_bar = noise_model.get_Qt_bar(alpha_bar_t, device=device)[1]

        # Apply noise to nodes and edges using the cumulative transition matrices directly
        # Use matrix multiplication to apply the Markovian transition for nodes and edges
        x_noisy = torch.matmul(x, Qt_node_bar)  # Apply cumulative noise
        edge_attr_noisy = torch.matmul(edge_attr, Qt_edge_bar)


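        # With a row-stochastic Qbar, each row of x @ Qbar is already a probability distribution;
        # renormalise explicitly instead of applying softmax, which would distort these probabilities
        x_noisy_prob = x_noisy / x_noisy.sum(dim=-1, keepdim=True)  # Shape: [1, num_nodes, 5]
        edge_attr_noisy_prob = edge_attr_noisy / edge_attr_noisy.sum(dim=-1, keepdim=True)  # Shape: [1, num_edges, 4]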


        x_noisy_prob = x_noisy_prob.squeeze(0)  # Shape: [num_nodes, 5]
        edge_attr_noisy_prob = edge_attr_noisy_prob.squeeze(0)  # Shape: [num_edges, 4]

        # Sample from the computed multinomial distribution
        sampled_x = torch.multinomial(x_noisy_prob, 1).squeeze(-1)  # Shape: [n_nodes]
        sampled_edge_attr = torch.multinomial(edge_attr_noisy_prob, 1).squeeze(-1)  # Shape: [num_edges]

        # Convert sampled indices into one-hot encoded tensors
        x_noisy_discrete = F.one_hot(sampled_x, num_classes=Qt_node_bar.shape[-1]).float()  # Shape: [num_nodes, d0]
        edge_attr_noisy_discrete = F.one_hot(sampled_edge_attr, num_classes=Qt_edge_bar.shape[-1]).float()  # Shape: [num_edges, d0]


        # Forward pass
        node_predictions, edge_predictions = model(x_noisy_discrete, edge_index, edge_attr_noisy_discrete)


        # Cross Entropy Loss
        node_loss = loss_function(node_predictions.view(-1, a), x.argmax(dim=-1).view(-1))
        edge_loss = loss_function(edge_predictions.view(-1, b), edge_attr.argmax(dim=-1).view(-1))
        # CrossEntropyLoss expects raw logits and applies log-softmax internally


        # For simplicity, we chose not to use the hyperparameter lambda
        loss = node_loss + edge_loss
        total_loss += loss.item()

        # Compute the node-edge ratio for the current batch
        num_nodes = x.size(0)
        num_edges = edge_index.size(1)
        if num_nodes > 0:
            batch_ratio = num_edges / num_nodes
            total_ratio += batch_ratio
            total_graphs += 1

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()



    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}")
print("The final result is high because we are cumulating the loss")


# Compute the final average node-edge ratio
average_node_edge_ratio = total_ratio / total_graphs if total_graphs > 0 else 0
average_node_edge_ratio = torch.tensor(average_node_edge_ratio, device=device)
print(f"\n\nAverage Node-Edge Ratio: {average_node_edge_ratio.item():.4f}")

Epoch 1/20, Loss: 118.1463
Epoch 2/20, Loss: 76.6991
Epoch 3/20, Loss: 62.2255
Epoch 4/20, Loss: 57.2331
Epoch 5/20, Loss: 55.6253
Epoch 6/20, Loss: 55.2134
Epoch 7/20, Loss: 54.4743
Epoch 8/20, Loss: 53.7410
Epoch 9/20, Loss: 53.7376
Epoch 10/20, Loss: 53.3483
Epoch 11/20, Loss: 53.0723
Epoch 12/20, Loss: 52.9757
Epoch 13/20, Loss: 53.0750
Epoch 14/20, Loss: 52.6078
Epoch 15/20, Loss: 52.8955
Epoch 16/20, Loss: 52.3903
Epoch 17/20, Loss: 52.8890
Epoch 18/20, Loss: 52.7958
Epoch 19/20, Loss: 52.6954
Epoch 20/20, Loss: 52.4746
The final result is high because we are cumulating the loss


Average Node-Edge Ratio: 1.9639
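The noising step in the training loop above can be isolated into a small function. This sketch assumes that `Qbar` is a proper row-stochastic $\bar{Q}^t$ (rows summing to 1), in which case the rows of $X\bar{Q}^t$ are already probability distributions and can be sampled directly, independently for every node or edge. The name `apply_discrete_noise` and the marginal values are hypothetical.

```python
import torch
import torch.nn.functional as F

def apply_discrete_noise(x_onehot: torch.Tensor, Qbar: torch.Tensor) -> torch.Tensor:
    """Sample x^t ~ Cat(x Qbar) independently for every row (node or edge) of x_onehot."""
    probs = x_onehot @ Qbar                                # [num_items, K], rows sum to 1
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return F.one_hot(sampled, num_classes=Qbar.size(-1)).float()

m = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.05])              # made-up marginal
alpha_bar = 0.3
Qbar = alpha_bar * torch.eye(5) + (1 - alpha_bar) * torch.ones(5, 1) @ m.unsqueeze(0)

x = F.one_hot(torch.tensor([1, 1, 0, 3]), num_classes=5).float()
x_noisy = apply_discrete_noise(x, Qbar)
print(x_noisy)   # still one-hot, but some rows may have switched type
```

With $\bar{\alpha}^t = 1$ (i.e. `Qbar = torch.eye(5)`) the function returns its input unchanged, and as $\bar{\alpha}^t \to 0$ each row is drawn from the marginal regardless of its original type.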


In [None]:
torch.save(model.state_dict(), "graph_transformer_model.pth")


In [None]:
model.load_state_dict(torch.load("graph_transformer_model.pth", map_location=device, weights_only=True))


<All keys matched successfully>

### Sampling Algorithm

Step 0: Sample a noisy graph $G^T$ from the limit distribution.  
Step 1: The graph transformer takes the noisy graph $G^t$ as input and predicts its clean state $\widehat{p}^X, \widehat{p}^E$.    
Step 2: Sample $G^{t-1}$ from the posterior (the ideal reverse weighted by the network predictions). Repeating Steps 1 and 2 for $t = T \text{ to } 1$ yields $G^{T-1}, G^{T-2}, \ldots, G^{0}$.

In [None]:
# Step 5: Generation Step with network-predicted reverse diffusion computation

@torch.no_grad()  # no gradients are needed during sampling
def generate_graph(noise_model, model, T):

    model.eval()
    device = next(model.parameters()).device

    node_count_dist = node_count_distribution.clone()  # copy, so the global tensor is not modified in place
    node_count_dist[node_count_dist == 0] = 1e-10    # avoids zero entries

    # Normalize to ensure it sums to 1
    node_count_dist /= node_count_dist.sum()
    node_count_dist = node_count_dist.to(device)

    # Start from the limit distributions (using marginal distributions as limit distributions)
    limit_dist_node = noise_model.node_distribution.to(device)
    limit_dist_edge = noise_model.edge_distribution.to(device)


    # Sample number of nodes based on node count distribution
    num_nodes = torch.multinomial(node_count_dist, 1).item()

    # All candidate (undirected) node pairs
    total_possible_edges = torch.combinations(torch.arange(num_nodes), r=2).to(device)

    # Impose limit on the number of edges using pre-calculated Average Node-Edge Ratio
    # (QM9's edge_index stores each bond in both directions, so this ratio counts bonds twice),
    # capped at the number of available node pairs since pairs are sampled without replacement
    num_edges = min(int(average_node_edge_ratio * num_nodes), total_possible_edges.size(0))


    # Generate the node and edge types from the limit distribution
    X = torch.multinomial(limit_dist_node, num_nodes, replacement=True).to(device)
    E = torch.multinomial(limit_dist_edge, num_edges, replacement=True).to(device)

    # Create edge_index for sparse representation of edges
    edge_indices = torch.multinomial(
        torch.ones(total_possible_edges.size(0), device=device),  # Uniform distribution over edge indices
        num_edges,
        replacement=False)
    edge_index = total_possible_edges[edge_indices].T  # Shape: [2, num_edges]


    # Because we still need to make it DISCRETE (aligns with DiGress code)
    X = F.one_hot(X, num_classes=a).float()  #  n, a
    E = F.one_hot(E, num_classes=b).float()  # num_edges, b

    edge_attr = E
    print('The shapes of the noisy X and E are :',X.shape, E.shape)


    for t in range(T, 0, -1):  # Start from T and stop at t=1 (inclusive)

        # Transformer predictions
        node_predictions, edge_predictions = model(X, edge_index, edge_attr)
        print(node_predictions.shape, edge_predictions.shape)

        # Use the loop's own timestep t (t_int would be a leftover global from the training loop)
        alpha_bar_t = noise_model.get_alpha_bar(t_int=t).clone().detach().to(device)
        alpha_bar_s = noise_model.get_alpha_bar(t_int=t - 1).clone().detach().to(device)  # s = t - 1 >= 0 since t >= 1


        # transition matrices at t and t-1
        Qt_node = noise_model.qt(t)[0].to(device)
        Qt_edge = noise_model.qt(t)[1].to(device)
        Qtb_node = noise_model.get_Qt_bar(alpha_bar_t, device=device)[0]
        Qtb_edge = noise_model.get_Qt_bar(alpha_bar_t, device=device)[1]
        Qsb_node = noise_model.get_Qt_bar(alpha_bar_s, device=device)[0]
        Qsb_edge = noise_model.get_Qt_bar(alpha_bar_s, device=device)[1]

        # Normalize predictions
        pred_X = F.softmax(node_predictions, dim=-1)  #  n, d0
        pred_E = F.softmax(edge_predictions, dim=-1)  #  n, n, d0



        # ----- Node Posterior Calculation -----

        # Compute ideal reverse transition for nodes: xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0
        Qt_node_T = Qt_node.T
        # b in Qtb means bar

        # Qtb_node before transformation: [1, 5, 5]
        Qtb_node = Qtb_node.squeeze()

        # Element-wise multiplication
        numerator_X = torch.matmul(X, Qt_node_T).unsqueeze(1) * Qsb_node # Qsb encodes all possible x0 states: [n,1,5] * [1,5,5] -> [n,5,5]
        # X here is the noisy node matrix at step t


        denominator_X = torch.matmul(Qtb_node.squeeze(), X.T).unsqueeze(-1)
        denominator_X = denominator_X.permute(1, 0, 2)  # [5, n, 1] -> [n, 5, 1]
        denominator_X[denominator_X == 0] = 1e-6  # avoid division by zero, as in the DiGress code
        ideal_reverse_X = numerator_X / denominator_X


        # Multiply the ideal reverse by the predictions (this step is directly from DiGress.sample_batch)
        p_s_and_t_given_0_X = ideal_reverse_X * pred_X.unsqueeze(-1)  # n, d0, d_{t-1}

        # Sum over all possible node types to get the unnormalized posterior
        unnormalized_prob_X = p_s_and_t_given_0_X.sum(dim=2)  #  n, d_{t-1}
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5

        # Normalize to get the posterior distribution
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  #  n, d_{t-1}



        # ----- Edge Posterior Calculation -----

        # Compute ideal reverse transition for edges
        Qt_edge_T = Qt_edge.T
        Qtb_edge = Qtb_edge.squeeze()


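        # Element-wise multiplication; use the current noisy edge_attr, since E only holds
        # the initial sample from the limit distribution and is never updated in the loop
        numerator_E = torch.matmul(edge_attr, Qt_edge_T).unsqueeze(1) * Qsb_edge # Qsb encodes all possible clean states

        denominator_E = torch.matmul(Qtb_edge.squeeze(), edge_attr.T).unsqueeze(-1)
        denominator_E = denominator_E.permute(1, 0, 2)
        denominator_E[denominator_E == 0] = 1e-6  # avoid division by zero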

        ideal_reverse_E = numerator_E / denominator_E


        # Multiply the ideal reverse by the predictions
        p_s_and_t_given_0_E = ideal_reverse_E * pred_E.unsqueeze(-1)

        # Sum over all possible edge types to get the unnormalized posterior
        unnormalized_prob_E = p_s_and_t_given_0_E.sum(dim=-2)  # num_edges, d_{t-1}
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

        # Normalize to get the posterior distribution
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)



        # ----- Sampling from Posterior -----


        # Sample a discrete class from the computed posterior distribution
        sampled_s_X = torch.multinomial(prob_X, 1).squeeze(-1).to(device)  # Shape: [n]
        sampled_s_E = torch.multinomial(prob_E, 1).squeeze(-1).to(device)  # Shape: [num_edges]


        # Because we need to stay DISCRETE (aligns with DiGress code)
        X = F.one_hot(sampled_s_X, num_classes=a).float().to(device)
        edge_attr = F.one_hot(sampled_s_E, num_classes=b).float().to(device)


        # Assertions to verify correctness
        assert (X.shape[-1] == a), "Node shape mismatch detected"
        assert (edge_attr.shape[-1] == b), "Edge attribute shape mismatch detected"

    return X, edge_index, edge_attr


In [None]:
# Generation
torch.cuda.manual_seed(234)
noise_model = NoiseModel(T=10, s=0.008, node_distribution=node_distribution, edge_distribution=edge_distribution)
generated_X, generated_edge_index, generated_edge_attr = generate_graph(
                                                                        noise_model,
                                                                        model,
                                                                        T=10,
                                                                        )

# Decode
decoded_node_types = torch.argmax(generated_X, dim=-1).tolist()
decoded_edge_types = torch.argmax(generated_edge_attr, dim=-1).tolist()

print("Generated Node Types (decoded):", decoded_node_types)
print("Generated Edge Index:", generated_edge_index)
print("Generated Edge Types (decoded):", decoded_edge_types)


The shapes of the noisy X and E are : torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
torch.Size([16, 5]) torch.Size([31, 4])
Generated Node Types (decoded): [0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 1, 1, 0, 1, 1]
Generated Edge Index: tensor([[ 2,  0,  8,  5, 12,  9,  1,  4,  0, 12,  6,  1,  1,  5,  7,  2,  7, 10,
          4,  9,  3,  8,  8,  5,  2,  5, 12,  7,  2,  6,  3],
        [ 4,  7, 12,  9, 14, 15,  4, 12, 12, 13,  7,  9, 14,  8, 11, 12, 10, 11,
          9, 11,  6, 15, 14, 11,  8, 14, 15, 14, 15, 11,  7]], device='cuda:0')
Generated Edge Types (decoded): [0, 2, 0, 1, 0, 0, 0, 0, 3, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2

### DiGress's Performance

DiGress performs exceptionally well on the QM9 small-molecule dataset, achieving near-perfect validity (99.0%) and high uniqueness (96.2%), with a training time of 1.0 hour compared with 7.2 hours for continuous models. Beyond QM9, DiGress scales effectively to larger and more complex datasets such as MOSES and GuacaMol, handling diverse molecular graphs and achieving competitive performance without molecule-specific representations. It also excels at non-molecular graph generation, outperforming other models on planar graphs and stochastic block models, which shows its versatility across graph types [1].

## Discussion

In this talktorial, we have gone over the theory and practice of DiGress. We began with the motivation for AI-based molecule generation and the intuition behind diffusion models. In particular, we discussed in depth why discrete diffusion may be more appropriate for graph generation, as well as its mathematical framework and its training and sampling algorithms. We only briefly explained how to improve DiGress with conditional generation and structural features; the DiGress paper explains these thoroughly, and interested readers should refer to Section 4 of the paper.

To gain a deeper understanding of diffusion models, I strongly recommend the paper "Deep Unsupervised Learning using Nonequilibrium Thermodynamics" by [Sohl-Dickstein et al., 2015](https://arxiv.org/abs/1503.03585) and Lilian Weng's blog post ['What are diffusion models'](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/).

## Quiz

1. Do nodes and edges have the same noise model?
2. What do the authors of DiGress mean when they say that the diffusion process under discussion is defined independently for *each* node and edge?
3. Why is the denoising network essentially solving a classification problem in the training phase?
4. Why is the graph transformer (i.e. the denoising network) alone not enough for sampling?
5. Does using the sparse representation of graphs instead of the dense adjacency matrix limit the predictive power of my simplified DiGress model? (Answer: yes, it does. Since we only use `edge_index` and `edge_attr`, and 'no edge' is not encoded as an edge type, the graph transformer (i.e. the denoising network) is never trained to predict the 'no edge' type. Note that in our simplified model, `num_edges` is calculated from the average node-edge ratio of the dataset, and the edge indices are sampled randomly together with the noisy graph, before the reverse diffusion begins. In the official DiGress code, by contrast, a dense adjacency matrix is used: during sampling, the number of candidate edges is initially the theoretical upper bound for an undirected graph (the upper triangular part of the adjacency matrix), and because the graph transformer has been trained on clean graphs and can predict the 'no edge' type, the actual number of edges may decrease. The edge set is therefore dynamic during the reverse diffusion process as well.)
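To illustrate the dense alternative described in the last answer, here is a minimal plain-PyTorch sketch (a hypothetical helper, not the official preprocessing) that turns `edge_index`/`edge_attr` into a dense $n \times n \times (b+1)$ one-hot tensor in which channel 0 explicitly encodes 'no edge', so that a denoising network could be trained to predict it. PyG's `to_dense_adj`, imported earlier, provides a related conversion.

```python
import torch

def to_dense_with_no_edge(edge_index, edge_attr, num_nodes):
    """Dense one-hot edge tensor [n, n, b + 1]; channel 0 = 'no edge', channels 1..b = bond types."""
    b = edge_attr.size(-1)
    dense = torch.zeros(num_nodes, num_nodes, b + 1)
    dense[..., 0] = 1.0                                  # start with "no edge" everywhere
    src, dst = edge_index
    dense[src, dst] = torch.cat([torch.zeros(edge_attr.size(0), 1), edge_attr], dim=-1)
    return dense

# Toy molecule: 3 nodes, one single bond 0-1 stored in both directions, as in PyG's QM9
edge_index = torch.tensor([[0, 1], [1, 0]])
edge_attr = torch.tensor([[1., 0., 0., 0.], [1., 0., 0., 0.]])
E = to_dense_with_no_edge(edge_index, edge_attr, num_nodes=3)
print(E.argmax(dim=-1))   # 0 = no edge, 1 = single bond
```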



In [None]:
import torch

# Transition probability matrix at step 1
Q1_X = torch.tensor([
                    [0.4, 0.2, 0.2, 0.1, 0.1],        # Type 1 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.1, 0.6, 0.1, 0.1, 0.1],        # Type 2 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.2, 0.2, 0.5, 0.05, 0.05],      # Type 3 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.1, 0.1, 0.1, 0.6, 0.1],        # Type 4 node transition probabilities to type 1, 2, 3, 4, 5
                    [0.05, 0.05, 0.05, 0.15, 0.7]     # Type 5 node transition probabilities to type 1, 2, 3, 4, 5
                    ])

# Transition probability matrix at step 2
Q2_X = torch.tensor([
                    [0.5, 0.3, 0.1, 0.05, 0.05],
                    [0.2, 0.5, 0.2, 0.05, 0.05],
                    [0.1, 0.1, 0.6, 0.1, 0.1],
                    [0.05, 0.2, 0.1, 0.6, 0.05],
                    [0.1, 0.05, 0.05, 0.1, 0.7]
                    ])

# Cumulative product
Q_bar2_X = torch.matmul(Q1_X, Q2_X)
print("The cumulative product of transition probabilities matrices:\n", Q_bar2_X)


The cumulative product of transition probabilities matrices:
 tensor([[0.2750, 0.2650, 0.2150, 0.1200, 0.1250],
        [0.1950, 0.3650, 0.2050, 0.1150, 0.1200],
        [0.1975, 0.2225, 0.3675, 0.1050, 0.1075],
        [0.1200, 0.2150, 0.1550, 0.3900, 0.1200],
        [0.1175, 0.1100, 0.0950, 0.1700, 0.5075]])
