# Variational Flow Matching for Graph Generation

### by Floor Eijkelboom et al. (2024)

Necessary imports

In [1]:
import os
import time
import math
import random
import glob
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
import dgl.function as fn
import torch.optim as optim
from torchdiffeq import odeint
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

### Backbone Graph Transformer

#### by Vijay Prakash Dwivedi, Xavier Bresson (2021)

The code is taken from the official implementation [here](https://github.com/graphdeeplearning/graphtransformer?tab=readme-ov-file). We use the part for the task of *SBMs_node_classification*, as this exactly reflects our classification objective.

Utils from *layers/graph_transformer_layer.py* and *train/metrics.py*:

In [3]:
# util functions


def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {
            out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(
                -1, keepdim=True
            )
        }

    return func


def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func


def accuracy_SBM(scores, targets):
    S = targets.cpu().numpy()
    C = np.argmax(torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy(), axis=1)
    CM = confusion_matrix(S, C).astype(np.float32)
    nb_classes = CM.shape[0]
    targets = targets.cpu().detach().numpy()
    nb_non_empty_classes = 0
    pr_classes = np.zeros(nb_classes)
    for r in range(nb_classes):
        cluster = np.where(targets == r)[0]
        if cluster.shape[0] != 0:
            pr_classes[r] = CM[r, r] / float(cluster.shape[0])
            if CM[r, r] > 0:
                nb_non_empty_classes += 1
        else:
            pr_classes[r] = 0.0
    acc = 100.0 * np.sum(pr_classes) / float(nb_classes)
    return acc

In [17]:
# Single attention head


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads, use_bias):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads

        if use_bias:
            self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
            self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
            self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
        else:
            self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=False)
            self.K = nn.Linear(in_dim, out_dim * num_heads, bias=False)
            self.V = nn.Linear(in_dim, out_dim * num_heads, bias=False)

    def propagate_attention(self, g):
        # Compute attention score
        g.apply_edges(src_dot_dst("K_h", "Q_h", "score"))  # , edges)
        g.apply_edges(scaled_exp("score", np.sqrt(self.out_dim)))

        # Send weighted values to target nodes
        eids = g.edges()
        g.send_and_recv(
            eids, fn.src_mul_edge("V_h", "score", "V_h"), fn.sum("V_h", "wV")
        )
        g.send_and_recv(eids, fn.copy_edge("score", "score"), fn.sum("score", "z"))

    def forward(self, g, h):

        Q_h = self.Q(h)
        K_h = self.K(h)
        V_h = self.V(h)

        # Reshaping into [num_nodes, num_heads, feat_dim] to
        # get projections for multi-head attention
        g.ndata["Q_h"] = Q_h.view(-1, self.num_heads, self.out_dim)
        g.ndata["K_h"] = K_h.view(-1, self.num_heads, self.out_dim)
        g.ndata["V_h"] = V_h.view(-1, self.num_heads, self.out_dim)

        self.propagate_attention(g)

        head_out = g.ndata["wV"] / g.ndata["z"]

        return head_out

In [18]:
class GraphTransformerLayer(nn.Module):
    """
    Param:
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_heads,
        dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        residual=True,
        use_bias=False,
    ):
        super().__init__()

        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        self.attention = MultiHeadAttentionLayer(
            in_dim, out_dim // num_heads, num_heads, use_bias
        )

        self.O = nn.Linear(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1 = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm1 = nn.BatchNorm1d(out_dim)

        # FFN
        self.FFN_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_layer2 = nn.Linear(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2 = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm2 = nn.BatchNorm1d(out_dim)

    def forward(self, g, h):
        h_in1 = h  # for first residual connection

        # multi-head attention out
        attn_out = self.attention(g, h)
        h = attn_out.view(-1, self.out_channels)

        h = F.dropout(h, self.dropout, training=self.training)

        h = self.O(h)

        if self.residual:
            h = h_in1 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm1(h)

        if self.batch_norm:
            h = self.batch_norm1(h)

        h_in2 = h  # for second residual connection

        # FFN
        h = self.FFN_layer1(h)
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2(h)

        if self.batch_norm:
            h = self.batch_norm2(h)

        return h

    def __repr__(self):
        return "{}(in_channels={}, out_channels={}, heads={}, residual={})".format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.num_heads,
            self.residual,
        )

Contents of the file *layers/mlp_readout_layer.py*:

In [19]:
"""
    MLP Layer used after graph vector representation
"""


class MLPReadout(nn.Module):

    def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers
        super().__init__()
        list_FC_layers = [
            nn.Linear(input_dim // 2**l, input_dim // 2 ** (l + 1), bias=True)
            for l in range(L)
        ]
        list_FC_layers.append(nn.Linear(input_dim // 2**L, output_dim, bias=True))
        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L

    def forward(self, x):
        y = x
        for l in range(self.L):
            y = self.FC_layers[l](y)
            y = F.relu(y)
        y = self.FC_layers[self.L](y)
        return y

Contents of the file *nets/SBMs_node_classification/graph_transformer_net.py*:

In [5]:
class GraphTransformerNet(nn.Module):

    def __init__(self, net_params):
        super().__init__()

        in_dim_node = net_params["in_dim"]  # node_dim (feat is an integer)
        hidden_dim = net_params["hidden_dim"]
        out_dim = net_params["out_dim"]
        n_classes = net_params["n_classes"]
        num_heads = net_params["n_heads"]
        in_feat_dropout = net_params["in_feat_dropout"]
        dropout = net_params["dropout"]
        n_layers = net_params["L"]

        self.readout = net_params["readout"]
        self.layer_norm = net_params["layer_norm"]
        self.batch_norm = net_params["batch_norm"]
        self.residual = net_params["residual"]
        self.dropout = dropout
        self.n_classes = n_classes
        self.device = net_params["device"]
        self.lap_pos_enc = net_params["lap_pos_enc"]
        self.wl_pos_enc = net_params["wl_pos_enc"]
        max_wl_role_index = 100

        if self.lap_pos_enc:
            pos_enc_dim = net_params["pos_enc_dim"]
            self.embedding_lap_pos_enc = nn.Linear(pos_enc_dim, hidden_dim)
        if self.wl_pos_enc:
            self.embedding_wl_pos_enc = nn.Embedding(max_wl_role_index, hidden_dim)

        self.embedding_h = nn.Embedding(
            in_dim_node, hidden_dim
        )  # node feat is an integer

        self.in_feat_dropout = nn.Dropout(in_feat_dropout)

        self.layers = nn.ModuleList(
            [
                GraphTransformerLayer(
                    hidden_dim,
                    hidden_dim,
                    num_heads,
                    dropout,
                    self.layer_norm,
                    self.batch_norm,
                    self.residual,
                )
                for _ in range(n_layers - 1)
            ]
        )
        self.layers.append(
            GraphTransformerLayer(
                hidden_dim,
                out_dim,
                num_heads,
                dropout,
                self.layer_norm,
                self.batch_norm,
                self.residual,
            )
        )
        self.MLP_layer = MLPReadout(out_dim, n_classes)

    def forward(self, g, h, e, h_lap_pos_enc=None, h_wl_pos_enc=None):

        # input embedding
        h = self.embedding_h(h)
        if self.lap_pos_enc:
            h_lap_pos_enc = self.embedding_lap_pos_enc(h_lap_pos_enc.float())
            h = h + h_lap_pos_enc
        if self.wl_pos_enc:
            h_wl_pos_enc = self.embedding_wl_pos_enc(h_wl_pos_enc)
            h = h + h_wl_pos_enc
        h = self.in_feat_dropout(h)

        # GraphTransformer Layers
        for conv in self.layers:
            h = conv(g, h)

        # output
        h_out = self.MLP_layer(h)

        return h_out

    def loss(self, pred, label):

        # calculating label weights for weighted loss computation
        V = label.size(0)
        label_count = torch.bincount(label)
        label_count = label_count[label_count.nonzero()].squeeze()
        cluster_sizes = torch.zeros(self.n_classes).long().to(self.device)
        cluster_sizes[torch.unique(label)] = label_count
        weight = (V - cluster_sizes).float() / V
        weight *= (cluster_sizes > 0).float()

        # weighted cross-entropy for unbalanced classes
        criterion = nn.CrossEntropyLoss(weight=weight)
        loss = criterion(pred, label)

        return loss

Contents of the file *train/train_SBMs_node_classification.py*:

In [10]:
"""
    Utility functions for training one epoch
    and evaluating one epoch
"""


def train_epoch(model, optimizer, device, data_loader, epoch):

    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
        batch_graphs = batch_graphs.to(device)
        batch_x = batch_graphs.ndata["feat"].to(device)  # num x feat
        batch_e = batch_graphs.edata["feat"].to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        try:
            batch_lap_pos_enc = batch_graphs.ndata["lap_pos_enc"].to(device)
            sign_flip = torch.rand(batch_lap_pos_enc.size(1)).to(device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            batch_lap_pos_enc = batch_lap_pos_enc * sign_flip.unsqueeze(0)
        except:
            batch_lap_pos_enc = None

        try:
            batch_wl_pos_enc = batch_graphs.ndata["wl_pos_enc"].to(device)
        except:
            batch_wl_pos_enc = None

        batch_scores = model.forward(
            batch_graphs, batch_x, batch_e, batch_lap_pos_enc, batch_wl_pos_enc
        )

        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_acc += accuracy_SBM(batch_scores, batch_labels)
    epoch_loss /= iter + 1
    epoch_train_acc /= iter + 1

    return epoch_loss, epoch_train_acc, optimizer


def evaluate_network(model, device, data_loader, epoch):

    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    with torch.no_grad():
        for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
            batch_graphs = batch_graphs.to(device)
            batch_x = batch_graphs.ndata["feat"].to(device)
            batch_e = batch_graphs.edata["feat"].to(device)
            batch_labels = batch_labels.to(device)
            try:
                batch_lap_pos_enc = batch_graphs.ndata["lap_pos_enc"].to(device)
            except:
                batch_lap_pos_enc = None

            try:
                batch_wl_pos_enc = batch_graphs.ndata["wl_pos_enc"].to(device)
            except:
                batch_wl_pos_enc = None

            batch_scores = model.forward(
                batch_graphs, batch_x, batch_e, batch_lap_pos_enc, batch_wl_pos_enc
            )
            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.detach().item()
            epoch_test_acc += accuracy_SBM(batch_scores, batch_labels)
        epoch_test_loss /= iter + 1
        epoch_test_acc /= iter + 1

    return epoch_test_loss, epoch_test_acc

Contents of the file *main_SBMs_node_classification.py*:

In [24]:
"""
    TRAINING CODE
"""


def train_val_pipeline(dataset, params, net_params, dirs):

    start0 = time.time()
    per_epoch_time = []

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params["device"]

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)

    # setting seeds
    random.seed(params["seed"])
    np.random.seed(params["seed"])
    torch.manual_seed(params["seed"])
    if device.type == "cuda":
        torch.cuda.manual_seed(params["seed"])

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))
    print("Number of Classes: ", net_params["n_classes"])

    model = GraphTransformerNet(net_params)
    model = model.to(device)

    optimizer = optim.Adam(
        model.parameters(), lr=params["init_lr"], weight_decay=params["weight_decay"]
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=params["lr_reduce_factor"],
        patience=params["lr_schedule_patience"],
        verbose=True,
    )

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_accs, epoch_val_accs = [], []

    train_loader = DataLoader(
        trainset,
        batch_size=params["batch_size"],
        shuffle=True,
        collate_fn=dataset.collate,
    )
    val_loader = DataLoader(
        valset,
        batch_size=params["batch_size"],
        shuffle=False,
        collate_fn=dataset.collate,
    )
    test_loader = DataLoader(
        testset,
        batch_size=params["batch_size"],
        shuffle=False,
        collate_fn=dataset.collate,
    )

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params["epochs"])) as t:
            for epoch in t:
                t.set_description("Epoch %d" % epoch)
                start = time.time()

                epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                    model, optimizer, device, train_loader, epoch
                )

                epoch_val_loss, epoch_val_acc = evaluate_network(
                    model, device, val_loader, epoch
                )
                _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_accs.append(epoch_train_acc)
                epoch_val_accs.append(epoch_val_acc)

                writer.add_scalar("train/_loss", epoch_train_loss, epoch)
                writer.add_scalar("val/_loss", epoch_val_loss, epoch)
                writer.add_scalar("train/_acc", epoch_train_acc, epoch)
                writer.add_scalar("val/_acc", epoch_val_acc, epoch)
                writer.add_scalar("test/_acc", epoch_test_acc, epoch)
                writer.add_scalar(
                    "learning_rate", optimizer.param_groups[0]["lr"], epoch
                )

                t.set_postfix(
                    time=time.time() - start,
                    lr=optimizer.param_groups[0]["lr"],
                    train_loss=epoch_train_loss,
                    val_loss=epoch_val_loss,
                    train_acc=epoch_train_acc,
                    val_acc=epoch_val_acc,
                    test_acc=epoch_test_acc,
                )

                per_epoch_time.append(time.time() - start)

                # Saving checkpoint
                ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                torch.save(
                    model.state_dict(),
                    "{}.pkl".format(ckpt_dir + "/epoch_" + str(epoch)),
                )

                files = glob.glob(ckpt_dir + "/*.pkl")
                for file in files:
                    epoch_nb = file.split("_")[-1]
                    epoch_nb = int(epoch_nb.split(".")[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                scheduler.step(epoch_val_loss)

                if optimizer.param_groups[0]["lr"] < params["min_lr"]:
                    print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - start0 > params["max_time"] * 3600:
                    print("-" * 89)
                    print(
                        "Max_time for training elapsed {:.2f} hours, so stopping".format(
                            params["max_time"]
                        )
                    )
                    break

    except KeyboardInterrupt:
        print("-" * 89)
        print("Exiting from training early because of KeyboardInterrupt")

    _, test_acc = evaluate_network(model, device, test_loader, epoch)
    _, train_acc = evaluate_network(model, device, train_loader, epoch)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - start0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    writer.close()

    """
        Write the results in out_dir/results folder
    """
    with open(write_file_name + ".txt", "w") as f:
        f.write(
            """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n""".format(
                "Test",
                "Test",
                params,
                net_params,
                model,
                net_params["total_param"],
                test_acc,
                train_acc,
                epoch,
                (time.time() - start0) / 3600,
                np.mean(per_epoch_time),
            )
        )

### Instead of taking the original implementation, we directly follow Eijkelboom et al. by adopting the implementation by Vignac et al. (2023), from **DiGress: Discrete Denoising diffusion for graph generation**

In [13]:
# add softmax to the output layer
from transformer_model import GraphTransformer

In [14]:
# Taken from Davis et al. (2024) Fisher Flow Matching, https://github.com/olsdavis/fisher-flow
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

### Our definition of the Backbone for the CatFlow model

In [24]:
# TODO: add +1 to the num_classes that denotes the absence of an edge


class Backbone(nn.Module):
    def __init__(self):
        super(Backbone, self).__init__()
        # TODO: Figure out the proper model dictionary dimensions and add to the signature
        self.batch_size = 10
        self.num_nodes = 10
        self.num_classes = 10

        n_layers = 2
        input_dims = {"X": 10, "E": 10, "y": 10}
        hidden_mlp_dims = {"X": 10, "E": 10, "y": 10}
        hidden_dims = {
            "dx": 10,
            "de": 10,
            "dy": 10,
            "n_head": 2,
            "dim_ffX": 10,
            "dim_ffE": 10,
        }
        output_dims = {"X": 10, "E": 10, "y": 10}

        self.graph_transformer = GraphTransformer(
            n_layers=n_layers,
            input_dims=input_dims,
            hidden_mlp_dims=hidden_mlp_dims,
            hidden_dims=hidden_dims,
            output_dims=output_dims,
        )

    def forward(self, t: torch.tensor, x: torch.tensor):
        """
        Forward pass of the backbone model.

        Args:
            t (torch.tensor): Time step. Shape: (batch_size, num_nodes, 1).
            x (torch.tensor): Input noise. Shape: (batch_size, num_nodes + (num_nodes - 1)^2, num_classes + 1). The second dimension corresponds to the nodes and edges of the graph, respectively.

        Returns:
            torch.tensor: Parameters of the variational distribution. Shape: (batch_size, num_nodes + (num_nodes - 1)^2, num_classes + 1).
        """
        # embed the timestep using the sinusoidal positional encoding
        print(t.shape)
        t_embedded = timestep_embedding(t, dim=self.num_classes)
        # add time embedding to the input
        print(x.shape)
        print(t_embedded.shape)
        print(t_embedded[:, None, :].shape)
        x += t_embedded[:, None, :]
        # TODO: after figuring out the dimensions, use the proper forward pass
        # return self.graph_transformer(X=x[:, :self.num_nodes, :], E=x[:, self.num_nodes:, :], y=torch.zeros(self.batch_size, self.num_classes + 1), node_mask=torch.ones(self.batch_size, self.num_nodes))
        return self.graph_transformer(
            X=x,
            E=x,
            y=torch.zeros(self.batch_size, self.num_classes),
            node_mask=torch.ones(self.batch_size, self.num_nodes),
        )

### CatFlow Model

In [34]:
class CatFlow:

    def __init__(
        self,
        backbone_model: Backbone,
        batch_size: int = 10,
        num_nodes: int = 10,
        num_classes: int = 10,
        eps: float = 1e-6,
    ) -> None:
        """
        Constructor of the CatFlow model.

        Args:
            backbone_model (Backbone): Backbone model to extract features. In the case of our experiments, we use a graph transformer network.
            batch_size (int): Batch size. Default value is 32.
            num_nodes (int): Number of nodes. Default value is 10.
            num_classes (int): Number of classes. Default value is 10.
            eps (float): Epsilon value to avoid numerical instability. Default value is 1e-6.
        """
        self.backbone_model = backbone_model
        self.batch_size = batch_size
        self.num_nodes = num_nodes
        self.num_classes = num_classes
        self.eps = eps

    def sample_time(self, lambd: torch.tensor = torch.tensor([1.0])) -> torch.tensor:
        """
        Function to sample the time step for the CatFlow model.

        Args:
            lambd (torch.tensor): Rate parameter of the exponential distribution. Default value is 1.0.

        Returns:
            torch.tensor: Time step. Shape: (1,).
        """
        # As in Dirichlet Flow Matching, we sample the time step from Exp(1)
        return torch.distributions.exponential.Exponential(lambd).sample()

    def sample_noise(self) -> torch.tensor:
        """
        Function to sample the noise for the CatFlow model.

        Returns:
            torch.tensor: Noise. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).
        """
        # Judging by the page 7 of the paper: the noise is not constrained to the simplex; so we can sample from a normal distribution
        # TODO: after figuring out the dimensions, use the proper forward pass
        # return torch.randn(self.batch_size, self.num_nodes + (self.num_nodes - 1)**2, self.num_classes + 1)
        return torch.randn(self.batch_size, self.num_nodes, self.num_classes)

    def loss(self, theta_true: torch.tensor) -> torch.tensor:
        """
        Function to calculate the cross entropy loss of the CatFlow model.
        Theta corresponds to the parameters of the variational distribution for the categorical case.

        Args:
            theta_true (torch.tensor): One-hot encoded true classes for the batch of samples. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).

        Returns:
            torch.tensor: Cross entropy loss.
        """
        # Sample time steps
        t = self.sample_time()
        # Sample noise
        x = self.sample_noise()
        # Forward pass of the backbone model
        theta_pred = self.backbone_model(t, x)
        # TODO: remove, just for testing
        theta_pred = F.relu(theta_pred.X)
        # Calculate cross entropy loss
        return -torch.sum(theta_true * torch.log(theta_pred + self.eps))

    def vector_field(self, t: torch.tensor, x: torch.tensor) -> torch.tensor:
        """
        Function that returns the vector field of the CatFlow model for a given timestamp.

        Args:
            t (torch.tensor): Time step. Shape: (batch_size, 1).
            x (torch.tensor): Input noise. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).

        Returns:
            torch.tensor: Vector field. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).
        """

        return (self.backbone_model(t, x) - x) / (1 - t)

    def sampling(self, x: torch.tensor) -> torch.tensor:
        """
        Function to sample a new instance following the learned vector field.

        Args:
            x (torch.tensor): Input noise. Shape: (batch_size, num_nodes + (num_nodes - 1)**2, num_classes + 1).

        Returns:
            torch.tensor: Sampled result.
        """
        # Define the time points over which to solve the ODE
        time_points = torch.linspace(0, 0.95, steps=20)  # Adjust steps as needed

        # Run the ODE solver with fixed time points
        return odeint(self.vector_field, x, time_points)

### Sanity check: a single forward pass

In [35]:
# Define the batch

# Generate random indices for the one-hot encoding
indices = torch.randint(0, 10, (10, 10))
# Create a one-hot encoded tensor using scatter
x = torch.zeros(10, 10, 10).scatter_(2, indices.unsqueeze(-1), 1)

graph_transformer = Backbone()
catflow = CatFlow(backbone_model=graph_transformer)

In [36]:
catflow.loss(x)

torch.Size([1])
torch.Size([10, 10, 10])
torch.Size([1, 10])
torch.Size([1, 1, 10])


tensor(372.9319, grad_fn=<NegBackward0>)

### Training

In [38]:
# TBD: implement the training loop for the CatFlow model