# Variational Flow Matching for Graph Generation

### by Floor Eijkelboom et al. (2024)

Necessary imports

In [41]:
import os
import time
import math
import glob
import yaml
import torch
import random
import einops
import numpy as np
from torch import nn
from tqdm import tqdm
import torch.optim as optim
from torchdiffeq import odeint
from torch.nn import functional as F
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops

In [42]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

### We directly follow Eijkelboom et al. by adopting the implementation by Vignac et al. (2023), from **DiGress: Discrete Denoising diffusion for graph generation** by Vignac et al. (2023)

In [44]:
from src.transformer_model import GraphTransformer

In [45]:
class PlaceHolder:
    def __init__(self, X, E, y):
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor):
        """ Changes the device and dtype of X, E, y. """
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        self.y = self.y.type_as(x)
        return self

    def mask(self, node_mask, collapse=False):
        x_mask = node_mask.unsqueeze(-1)          # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = - 1
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self

In [58]:
class QM9GraphDataset(Dataset):
    def __init__(self, data, nodes_classes=5, edge_classes=4):
        self.data = data
        self.nodes_classes = nodes_classes
        self.edge_classes = edge_classes
        
    def __len__(self):
        return len(self.data['node_idx_array'])  # Number of graphs

    def __getitem__(self, idx):
        # Get node and edge ranges for the idx-th molecule
        node_start, node_end = self.data['node_idx_array'][idx]
        edge_start, edge_end = self.data['edge_idx_array'][idx]

        # Extract node and edge information
        atom_types = self.data['atom_types'][node_start:node_end]  # Shape: [num_nodes, nodes_classes]
        bond_idxs = self.data['bond_idxs'][edge_start:edge_end]    # Shape: [num_edges, 2]
        bond_types = self.data['bond_types'][edge_start:edge_end].to(torch.long)  # Shape: [num_edges]

        num_nodes = atom_types.shape[0]

        # Initialize the node feature tensor (one-hot encoded node types)
        node_labels = torch.argmax(atom_types.to(torch.float), dim=1)  # Get the node class labels
        node_features = F.one_hot(node_labels, num_classes=self.nodes_classes)  # Shape: [num_nodes, nodes_classes]

        # Initialize edge feature tensor
        edge_class_matrix = torch.full((num_nodes, num_nodes), self.edge_classes - 1, dtype=torch.long)
        for (src, dst), bond_type in zip(bond_idxs, bond_types):
            edge_class_matrix[src, dst] = bond_type
            edge_class_matrix[dst, src] = bond_type  # Ensure symmetry by assigning both directions

        # Check for symmetry
        assert torch.equal(edge_class_matrix, edge_class_matrix.T), "edge_class_matrix is not symmetric"
        
        edge_features = F.one_hot(edge_class_matrix, num_classes=self.edge_classes)  # Shape: [num_nodes, num_nodes, edge_classes]

        # Create the PyTorch Geometric Data object (graph)
        edge_index = torch.tensor(bond_idxs).T  # Shape: [2, num_edges]
        
        # Flatten edge feature matrix for each edge: [num_edges, edge_classes]
        edge_attr = edge_features[edge_index[0], edge_index[1]]  # Shape: [num_edges, edge_classes]
        
        data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
        return data
    
def create_dataloader(data, batch_size=32, nodes_classes=5, edge_classes=4):
    # Create the dataset
    dataset = QM9GraphDataset(data, nodes_classes, edge_classes)

    # Create the DataLoader (PyTorch Geometric handles variable graph sizes)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader


# Load your processed data
train_data = torch.load("qm9/train_data_processed.pt")
val_data = torch.load("qm9/val_data_processed.pt")

# Create the DataLoader
batch_size = 1
train_dataloader = create_dataloader(train_data, batch_size=batch_size)
val_dataloader = create_dataloader(val_data, batch_size=batch_size)

  train_data = torch.load("qm9/train_data_processed.pt")
  val_data = torch.load("qm9/val_data_processed.pt")


In [59]:
def encode_no_edge(E):
    assert len(E.shape) == 4, "Expected shape [batch, nodes, nodes, edge_classes]"
    if E.shape[-1] == 0:
        return E
    
    # Find locations where no edge is present in any class
    no_edge = torch.sum(E, dim=3) == 0  # Shape: [batch, nodes, nodes]

    # Set absence indicator (channel 0) at no-edge locations symmetrically
    E[:, :, :, 0][no_edge] = 1
    E[:, :, :, 0] = torch.max(E[:, :, :, 0], E[:, :, :, 0].transpose(1, 2))  # Make channel 0 symmetric
    
    # Copy all channels symmetrically for each [i, j] and [j, i] pair
    for k in range(E.shape[-1]):
        E[:, :, :, k] = torch.max(E[:, :, :, k], E[:, :, :, k].transpose(1, 2))

    # Set diagonal elements to zero for all channels
    diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
    E[diag] = 0
    
    # Ensure final symmetry in all channels
    assert torch.allclose(E, E.transpose(1, 2)), "encode_no_edge produced a non-symmetric tensor"
    return E


def to_dense(x, edge_index, edge_attr, batch):
    X, node_mask = to_dense_batch(x=x, batch=batch)
    node_mask = node_mask.float()
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    max_num_nodes = X.size(1)
    E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
    E = encode_no_edge(E)

    return PlaceHolder(X=X, E=E, y=None), node_mask

In [60]:
# Taken from Davis et al. (2024) Fisher Flow Matching, https://github.com/olsdavis/fisher-flow

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

In [61]:
def sample_gaussian(size):
    x = torch.randn(size)
    return x

def sample_feature_noise(X_size, E_size, y_size, node_mask):
    """Standard normal noise for all features.
        Output size: X.size(), E.size(), y.size() """
    # TODO: How to change this for the multi-gpu case?
    epsX = sample_gaussian(X_size)
    epsE = sample_gaussian(E_size)
    epsy = sample_gaussian(y_size)

    float_mask = node_mask.float()
    epsX = epsX.type_as(float_mask)
    epsE = epsE.type_as(float_mask)
    epsy = epsy.type_as(float_mask)

    # Get upper triangular part of edge noise, without main diagonal
    upper_triangular_mask = torch.zeros_like(epsE)
    indices = torch.triu_indices(row=epsE.size(1), col=epsE.size(2), offset=1)
    upper_triangular_mask[:, indices[0], indices[1], :] = 1

    epsE = epsE * upper_triangular_mask
    epsE = (epsE + torch.transpose(epsE, 1, 2))

    assert (epsE == torch.transpose(epsE, 1, 2)).all()

    return PlaceHolder(X=epsX, E=epsE, y=epsy).mask(node_mask)


def sample_normal(mu_X, mu_E, mu_y, sigma, node_mask):
    """Samples from a Normal distribution."""
    # TODO: change for multi-gpu case
    eps = sample_feature_noise(mu_X.size(), mu_E.size(), mu_y.size(), node_mask).type_as(mu_X)
    X = mu_X + sigma * eps.X
    E = mu_E + sigma.unsqueeze(1) * eps.E
    y = mu_y + sigma.squeeze(1) * eps.y
    return PlaceHolder(X=X, E=E, y=y)

### CatFlow Model

In [62]:
class CatFlow(nn.Module):
    def __init__(
        self,
        config: dict,
        device: torch.device,
        eps: float = 1e-6,
    ) -> None:
        """
        Constructor of the CatFlow model.

        Args:
            backbone_model (Backbone): Backbone model to extract features. In the case of our experiments, we use a graph transformer network.
            batch_size (int): Batch size. Default value is 32.
            num_nodes (int): Number of nodes. Default value is 10.
            num_classes (int): Number of classes. Default value is 10.
            eps (float): Epsilon value to avoid numerical instability. Default value is 1e-6.
        """
        super(CatFlow, self).__init__()
        self.backbone_model = GraphTransformer(
            n_layers=config['n_layers'],
            input_dims=config['input_dims'],
            hidden_mlp_dims=config['hidden_mlp_dims'],
            hidden_dims=config['hidden_dims'],
            output_dims=config['output_dims'],
        ).to(device)
        self.batch_size = config['batch_size']
        self.num_classes_nodes = config['input_dims']['X']
        self.num_classes_edges = config['input_dims']['E']
        self.eps = eps

    def sample_time(self, lambd: torch.tensor = torch.tensor([1.0])) -> torch.tensor:
        """
        Function to sample the time step for the CatFlow model.

        Args:
            lambd (torch.tensor): Rate parameter of the exponential distribution. Default value is 1.0.

        Returns:
            torch.tensor: Time step. Shape: (batch_size,).
        """
        # As in Dirichlet Flow Matching, we sample the time step from Exp(1)
        return torch.distributions.exponential.Exponential(lambd).sample().expand(self.batch_size)

    def sample_noise(self, kind, num_nodes) -> torch.tensor:
        """
        Function to sample the noise for the CatFlow model.
        
        Args:
            kind (str): Type of noise to sample. Options: 'node' or 'edge'.

        Returns:
            torch.tensor: Noise. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).
        """
        # Judging by the page 7 of the paper: the noise is not constrained to the simplex; so we can sample from a normal distribution
        # TODO: after figuring out the dimensions, use the proper forward pass
        # return torch.randn(self.batch_size, self.num_nodes + (self.num_nodes - 1)**2, self.num_classes + 1)
        if kind == 'node':
            return torch.randn(self.batch_size, num_nodes, self.num_classes_nodes)
        elif kind == 'edge':
            return torch.randn(self.batch_size, num_nodes, num_nodes, self.num_classes_edges)
        else:
            raise ValueError(f"Invalid noise type: {kind}")
    
    
    def forward(self, t: torch.tensor, x: torch.tensor, e: torch.tensor, node_mask: torch.tensor) -> torch.tensor:
        """
        Forward pass of the backbone model.

        Args:
            t (torch.tensor): Time step. Shape: (batch_size,).
            x (torch.tensor): Node noise. Shape: (batch_size, num_nodes, num_classes_nodes).
            e (torch.tensor): Edge noise. Shape: (batch_size, num_nodes, num_nodes, num_classes_edges).
            

        Returns:
            torch.tensor: Parameters of the variational distribution. Shape: (batch_size, num_nodes + (num_nodes - 1)^2, num_classes + 1).
        """
        # embed the timestep using the sinusoidal positional encoding
        t_embedded_nodes = timestep_embedding(t, dim=x.shape[-1]) # Shape: (batch_size, num_classes_nodes)
        t_embedded_edges = timestep_embedding(t, dim=e.shape[-1]) # Shape: (batch_size, num_classes_edges)
        # add time embedding to the input across the class feature dimension
        x += einops.rearrange(t_embedded_nodes, 'b c -> b 1 c')
        e += einops.rearrange(t_embedded_edges, 'b c -> b 1 1 c')

        return self.backbone_model(
            X=x,
            E=e,
            # dummy y
            y=torch.ones(self.batch_size, 1).to(device),
            node_mask=node_mask,
        )


    def vector_field(self, t: torch.tensor, x: torch.tensor) -> torch.tensor:
        """
        Function that returns the vector field of the CatFlow model for a given timestamp.

        Args:
            t (torch.tensor): Time step. Shape: (batch_size, 1).
            x (torch.tensor): Input noise. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).

        Returns:
            torch.tensor: Vector field. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).
        """

        return (self.backbone_model(t, x) - x) / (1 - t)

    def sampling(self, x: torch.tensor) -> torch.tensor:
        """
        Function to sample a new instance following the learned vector field.

        Args:
            x (torch.tensor): Input noise. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).

        Returns:
            torch.tensor: Sampled result.
        """
        # Define the time points over which to solve the ODE
        time_points = torch.linspace(0, 0.95, steps=20)  # Adjust steps as needed

        # Run the ODE solver with fixed time points
        return odeint(self.vector_field, x, time_points)

### Training

In [63]:
# prepare config
config = yaml.safe_load(open('configs/catflow.yaml', 'r'))

In [68]:
# Define the model
catflow = CatFlow(config=config, device=device).to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(catflow.parameters(), lr=0.0002, weight_decay=1e-12)

criterion = nn.BCEWithLogitsLoss()

# define the seed
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def train_epoch(model, optimizer, dataloader, device):
    model.train()
    total_loss = 0
    for data in dataloader:#tqdm(dataloader):
        data = data.to(device)
        # Get the dense representation of the graph
        dense_data, node_mask = to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
        x_true, e_true = dense_data.X.to(device), dense_data.E.to(device)
        node_mask = node_mask.to(device)
        num_nodes = x_true.size(1)#.to(device)
        # Zero the gradients
        optimizer.zero_grad()
        # CatFlow forward pass
        # Step 1: Sample t ~ Exp(1), x ~ N(0, I), e ~ N(0, I)
        t = model.sample_time().to(device)
        x = model.sample_noise(kind='node', num_nodes=num_nodes).to(device)
        e = model.sample_noise(kind='edge', num_nodes=num_nodes).to(device)
        # Step 2: Forward pass of the graph transformer
        inferred = model.forward(t=t, x=x, e=e, node_mask=node_mask)
        # Step 3: Calculate the loss
        loss = criterion(inferred.X, x_true.float()) + criterion(inferred.E, e_true.float())
        # Step 4: Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    return total_loss / len(dataloader)

def validate_epoch(model, dataloader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in dataloader:#tqdm(dataloader):
            data = data.to(device)
            # Get the dense representation of the graph
            dense_data, node_mask = to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
            x_true, e_true = dense_data.X.to(device), dense_data.E.to(device)
            node_mask = node_mask.to(device)
            num_nodes = x_true.size(1)
            # CatFlow forward pass
            # Step 1: Sample t ~ Exp(1), x ~ N(0, I), e ~ N(0, I)
            t = model.sample_time().to(device)
            x = model.sample_noise(kind='node', num_nodes=num_nodes).to(device)
            e = model.sample_noise(kind='edge', num_nodes=num_nodes).to(device)
            # Step 2: Forward pass of the graph transformer
            inferred = model.forward(t=t, x=x, e=e, node_mask=node_mask)
            # Step 3: Calculate the loss
            loss = criterion(inferred.X, x_true.float()) + criterion(inferred.E, e_true.float())
            total_loss += loss.item()

    return total_loss / len(dataloader)

In [69]:
for i in train_dataloader:
    trial_batch = i
    break   

  edge_index = torch.tensor(bond_idxs).T  # Shape: [2, num_edges]


In [70]:
# Train the model
num_epochs = 1000

for epoch in range(num_epochs):
    train_loss = train_epoch(catflow, optimizer, DataLoader(trial_batch), device)
    val_loss = validate_epoch(catflow, DataLoader(trial_batch), device)
    if epoch % 25 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}: Train loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}")

Epoch 1/1000: Train loss: 1.8866, Validation loss: 1.4932
Epoch 26/1000: Train loss: 0.7687, Validation loss: 0.7307
Epoch 51/1000: Train loss: 0.6128, Validation loss: 0.6401
Epoch 76/1000: Train loss: 0.5009, Validation loss: 0.5600
Epoch 101/1000: Train loss: 0.4733, Validation loss: 0.4771
Epoch 126/1000: Train loss: 0.4648, Validation loss: 0.4714
Epoch 151/1000: Train loss: 0.4520, Validation loss: 0.4431
Epoch 176/1000: Train loss: 0.4376, Validation loss: 0.4720
Epoch 201/1000: Train loss: 0.4454, Validation loss: 0.4321
Epoch 226/1000: Train loss: 0.4379, Validation loss: 0.4187
Epoch 251/1000: Train loss: 0.4174, Validation loss: 0.4353
Epoch 276/1000: Train loss: 0.4318, Validation loss: 0.4168
Epoch 301/1000: Train loss: 0.4086, Validation loss: 0.4316
Epoch 326/1000: Train loss: 0.4385, Validation loss: 0.4126
Epoch 351/1000: Train loss: 0.4309, Validation loss: 0.4194
Epoch 376/1000: Train loss: 0.4133, Validation loss: 0.4254
Epoch 401/1000: Train loss: 0.4278, Validatio