# GFlowNet Ring Graph Generation Tutorial

This tutorial demonstrates how to train a GFlowNet to generate ring graphs: graphs whose edges form a single cycle passing through every vertex. In the undirected case every vertex has exactly two neighbors; in the directed case every vertex has exactly one incoming and one outgoing edge. We'll cover both directed and undirected ring generation.

## Introduction to GFlowNets

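GFlowNets (Generative Flow Networks) are generative models that construct compositional objects, such as graphs, step by step through a sequence of actions. Instead of searching for the single highest-reward object, a trained GFlowNet samples each complete object $x$ with probability proportional to its reward $R(x)$, so it can cover all high-reward objects. In this tutorial, the environment (`GraphBuildingOnEdges`) builds each graph by adding edges between a fixed set of nodes, and the reward is high only when the finished graph is a ring.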
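For reference, the target distribution and the trajectory balance (TB) objective that `TBGFlowNet` is named after can be written as follows. This is the standard textbook form, not a transcription of torchgfn's code; the library may differ in details such as how $\log Z$ is parameterized or how the loss is averaged over a batch. We also assume here that each trajectory starts from the graph with no edges.

$$
P_T(x) = \frac{R(x)}{Z}, \qquad Z = \sum_{x'} R(x')
$$

$$
\mathcal{L}_{\mathrm{TB}}(\tau) = \left( \log Z_\theta + \sum_{t=0}^{T-1} \log P_F(s_{t+1} \mid s_t; \theta) - \log R(x) - \sum_{t=0}^{T-1} \log P_B(s_t \mid s_{t+1}; \theta) \right)^2
$$

Here $\tau = (s_0 \to s_1 \to \dots \to s_T = x)$ is a complete trajectory ending in the finished graph $x$, $P_F$ is the forward policy (`pf` in the training code), $P_B$ is the backward policy (`pb`), and $Z_\theta$ is a learned estimate of $Z$. If the loss is zero on every trajectory, the sampler's distribution over finished graphs equals $P_T(x)$.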

## Setup

Let's first import all the necessary libraries:

```python
import math
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from matplotlib import patches

from gfn.containers import ReplayBuffer
from gfn.gflownet.trajectory_balance import TBGFlowNet
from gfn.gym.graph_building import GraphBuildingOnEdges
from gfn.modules import DiscreteGraphPolicyEstimator
from gfn.states import GraphStates
from gfn.utils.modules import GraphEdgeActionGNN, GraphEdgeActionMLP
```


## Defining the Reward Function

A crucial component of GFlowNet training is the reward function that evaluates whether a generated graph forms a valid ring structure. We define different validation logic for directed and undirected rings:


```python
class RingReward:
    """Reward that checks whether each graph in a batch forms a valid ring.

    A ring is a single directed or undirected cycle that passes through all
    nodes of the graph.

    Args:
        directed: Whether the graphs are directed.
        reward_val: The reward assigned to valid rings.
        eps_val: The small positive reward assigned to all other graphs.
        device: The device on which reward tensors are created.
    """

    def __init__(
        self,
        directed: bool,
        reward_val: float = 100.0,
        eps_val: float = 1e-6,
        device: torch.device = torch.device("cpu"),
    ):
        self.directed = directed
        self.reward_val = reward_val
        self.eps_val = eps_val
        self.device = device

    def __call__(self, states: GraphStates) -> torch.Tensor:
        """Return a tensor of rewards with the same batch shape as `states`."""
        if self.directed:
            return self.directed_reward(states)
        return self.undirected_reward(states)

    def directed_reward(self, states: GraphStates) -> torch.Tensor:
        """Compute rewards for directed ring graphs.

        A valid directed ring must satisfy these conditions:
        1. Each node has exactly one outgoing edge (row sum = 1 in the
           adjacency matrix).
        2. Each node has exactly one incoming edge (column sum = 1 in the
           adjacency matrix).
        3. Following the edges forms a single cycle that includes all nodes.

        Args:
            states: A batch of graph states to evaluate.

        Returns:
            A tensor of rewards with the same batch shape as states.
        """
        if states.tensor.edge_index.numel() == 0:
            return torch.full(states.batch_shape, self.eps_val, device=self.device)

        out = torch.full((len(states),), self.eps_val, device=self.device)  # Default reward.

        for i in range(len(states)):
            graph = states[i]
            n_nodes = graph.tensor.num_nodes
            edge_index = graph.tensor.edge_index
            # Build the adjacency matrix on the same device as the edge indices.
            adj_matrix = torch.zeros(n_nodes, n_nodes, device=edge_index.device)
            adj_matrix[edge_index[0], edge_index[1]] = 1

            # Condition 1: each node has exactly one outgoing edge (row sum = 1).
            if not torch.all(adj_matrix.sum(dim=1) == 1):
                continue

            # Condition 2: each node has exactly one incoming edge (column sum = 1).
            if not torch.all(adj_matrix.sum(dim=0) == 1):
                continue

            # Condition 3: starting from node 0, follow the outgoing edges and
            # check that we visit every node before returning to the start.
            visited, current = [], 0

            while current not in visited:
                visited.append(current)

                # Move to the unique outgoing neighbor.
                current = torch.where(adj_matrix[current] == 1)[0].item()

                # If we've visited all nodes and returned to 0, it's a valid ring.
                if len(visited) == n_nodes and current == 0:
                    out[i] = self.reward_val
                    break

        return out.view(*states.batch_shape)

    def undirected_reward(self, states: GraphStates) -> torch.Tensor:
        """Compute rewards for undirected ring graphs.

        A valid undirected ring must satisfy these conditions:
        1. Each node has exactly two neighbors (degree = 2).
        2. The graph forms a single connected cycle including all nodes.

        The algorithm:
        1. Checks that all nodes have degree 2.
        2. Traverses the graph starting from node 0, following edges.
        3. Checks that the traversal visits all nodes and returns to the start.

        Args:
            states: A batch of graph states to evaluate.

        Returns:
            A tensor of rewards with the same batch shape as states.
        """
        if states.tensor.edge_index.numel() == 0:
            return torch.full(states.batch_shape, self.eps_val, device=self.device)

        out = torch.full((len(states),), self.eps_val, device=self.device)  # Default reward.

        for i in range(len(states)):
            graph = states[i]
            n_nodes = graph.tensor.num_nodes
            if n_nodes == 0:
                continue
            edge_index = graph.tensor.edge_index
            # Symmetric adjacency matrix, on the same device as the edge indices.
            adj_matrix = torch.zeros(n_nodes, n_nodes, device=edge_index.device)
            adj_matrix[edge_index[0], edge_index[1]] = 1
            adj_matrix[edge_index[1], edge_index[0]] = 1

            # In an undirected ring, every vertex has degree 2.
            if not torch.all(adj_matrix.sum(dim=1) == 2):
                continue

            # Traverse the cycle starting from vertex 0.
            start_vertex = 0
            visited = [start_vertex]
            neighbors = torch.where(adj_matrix[start_vertex] == 1)[0]
            if neighbors.numel() == 0:
                continue
            # Arbitrarily choose one neighbor to begin the traversal.
            current = neighbors[0].item()
            prev = start_vertex

            while current != start_vertex:
                visited.append(int(current))
                current_neighbors = torch.where(adj_matrix[int(current)] == 1)[0]
                # Exclude the neighbor we just came from.
                possible = [n.item() for n in current_neighbors if n.item() != prev]
                if len(possible) != 1:
                    break
                prev, current = current, possible[0]

            if current == start_vertex and len(visited) == n_nodes:
                out[i] = self.reward_val

        return out.view(*states.batch_shape)
```
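To see the ring checks without the torchgfn machinery, here is a minimal pure-Python sketch that applies them to a plain list of `(u, v)` edges. The helper names (`is_directed_ring`, `is_undirected_ring`, `count_rings`, `ring_probability_mass`) are hypothetical and not part of torchgfn. The sketch assumes simple graphs with no self-loops, and its undirected check uses a depth-first connectivity test instead of walking the cycle: when every vertex has degree 2 the graph is a union of disjoint cycles, so being connected is equivalent to being a single ring. The brute-force enumeration at the end additionally assumes, for illustration only, that every edge set on `n` nodes is a possible finished graph, and estimates how much of the target distribution $R(x)/Z$ falls on rings with `reward_val=100` and `eps_val=1e-6`.

```python
from itertools import combinations, permutations, product


def is_directed_ring(n, edges):
    """True if the directed edges (u, v) form one cycle through all n nodes."""
    if n == 0:
        return False
    out_deg, in_deg, succ = [0] * n, [0] * n, {}
    for u, v in set(edges):  # set() drops duplicates, like a 0/1 adjacency matrix.
        out_deg[u] += 1
        in_deg[v] += 1
        succ[u] = v
    # Conditions 1 and 2: exactly one outgoing and one incoming edge per node.
    if any(d != 1 for d in out_deg) or any(d != 1 for d in in_deg):
        return False
    # Condition 3: following successors from node 0 must visit every node.
    visited, current = set(), 0
    while current not in visited:
        visited.add(current)
        current = succ[current]
    return current == 0 and len(visited) == n


def is_undirected_ring(n, edges):
    """True if the undirected edges {u, v} form one cycle through all n nodes."""
    adj = [set() for _ in range(n)]
    for u, v in edges:
        if u == v:
            return False  # This sketch rejects self-loops outright.
        adj[u].add(v)
        adj[v].add(u)
    # Condition 1: every node has exactly two distinct neighbors.
    if n == 0 or any(len(nbrs) != 2 for nbrs in adj):
        return False
    # Condition 2: all degrees are 2, so the graph is a union of disjoint
    # cycles; it is a single ring exactly when it is connected.
    seen, stack = {0}, [0]
    while stack:
        for w in adj[stack.pop()]:
            if w not in seen:
                seen.add(w)
                stack.append(w)
    return len(seen) == n


def count_rings(n, directed):
    """Brute-force (rings, all edge sets) on n labeled nodes, no self-loops."""
    if directed:
        candidates, check = list(permutations(range(n), 2)), is_directed_ring
    else:
        candidates, check = list(combinations(range(n), 2)), is_undirected_ring
    rings = total = 0
    for mask in product((False, True), repeat=len(candidates)):
        edges = [e for e, keep in zip(candidates, mask) if keep]
        rings += check(n, edges)
        total += 1
    return rings, total


def ring_probability_mass(n, directed, reward_val=100.0, eps_val=1e-6):
    """Share of R(x)/Z on rings, if every edge set were a finished graph."""
    rings, total = count_rings(n, directed)
    ring_reward = rings * reward_val
    return ring_reward / (ring_reward + (total - rings) * eps_val)


print(count_rings(4, directed=True))   # (6, 4096)
print(count_rings(4, directed=False))  # (3, 64)
print(ring_probability_mass(4, directed=True))
```

For `n=4` the counts match the closed forms $(n-1)! = 6$ directed rings and $(n-1)!/2 = 3$ undirected rings. Even though non-rings vastly outnumber rings (4090 to 6 in the directed case), the $10^8$ ratio between `reward_val` and `eps_val` leaves almost all of the target probability on rings.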

## Visualization Function

To understand what the model is generating, we need a visualization function:

```python
from typing import Callable


def render_states(
    states: GraphStates,
    state_evaluator: Callable[[GraphStates], torch.Tensor],
    directed: bool,
):
    """Visualize up to eight graph states as circular layouts.

    Nodes are positioned evenly around a circle. For directed graphs, edges
    are drawn as arrows; for undirected graphs, edges are drawn as lines.

    The visualization includes:
    - Circular positioning of nodes
    - Edges between connected nodes
    - The reward value of each graph in the panel title

    Args:
        states: A batch of graphs to visualize (only the first eight are drawn).
        state_evaluator: Function that computes the reward of each graph.
        directed: Whether to render directed or undirected edges.
    """
    rewards = state_evaluator(states)
    fig, ax = plt.subplots(2, 4, figsize=(15, 7))
    n_plots = min(len(states), 8)
    radius = 5  # Radius of the circle on which nodes are placed.
    circle_radius = 0.5  # Radius of each node.
    head_thickness = 0.2  # Arrow-head size for directed edges.

    for i in range(8):
        current_ax = ax[i // 4, i % 4]
        if i >= n_plots:
            current_ax.axis("off")  # Hide panels that have no graph to show.
            continue

        state = states[i]
        n_circles = state.tensor.num_nodes
        xs, ys = [], []
        for j in range(n_circles):
            angle = 2 * math.pi * j / n_circles
            x = radius * math.cos(angle)
            y = radius * math.sin(angle)
            xs.append(x)
            ys.append(y)
            current_ax.add_patch(
                patches.Circle((x, y), circle_radius, facecolor="none", edgecolor="black")
            )

        for edge in state.tensor.edge_index.T:
            u, v = int(edge[0]), int(edge[1])
            start_x, start_y = xs[u], ys[u]
            end_x, end_y = xs[v], ys[v]
            dx = end_x - start_x
            dy = end_y - start_y
            length = math.sqrt(dx**2 + dy**2)
            if length == 0:
                continue  # Skip self-loops, which cannot be drawn as a segment.
            dx, dy = dx / length, dy / length

            # Start and end each edge on the node outlines, not at their centers.
            start_x += dx * circle_radius
            start_y += dy * circle_radius
            end_x -= dx * circle_radius
            end_y -= dy * circle_radius
            if directed:
                current_ax.arrow(
                    start_x,
                    start_y,
                    end_x - start_x,
                    end_y - start_y,
                    head_width=head_thickness,
                    head_length=head_thickness,
                    length_includes_head=True,  # Keep the arrow tip on the outline.
                    fc="black",
                    ec="black",
                )
            else:
                current_ax.plot([start_x, end_x], [start_y, end_y], color="black")

        current_ax.set_title(f"State {i}, $r={float(rewards[i]):.2f}$")
        current_ax.set_xlim(-(radius + 1), radius + 1)
        current_ax.set_ylim(-(radius + 1), radius + 1)
        current_ax.set_aspect("equal")
        current_ax.set_xticks([])
        current_ax.set_yticks([])

    plt.show()
```
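The layout arithmetic in `render_states` can be isolated and checked on its own. The sketch below uses hypothetical helper names (`circle_layout`, `trim_segment`) and only the standard `math` module: node `j` of `n` sits at angle $2\pi j / n$ on a circle of radius 5, and each edge is shortened by the node radius (0.5) at both ends so that it starts and stops on the node outlines rather than at their centers.

```python
import math


def circle_layout(n, radius=5.0):
    """Positions of n nodes spaced evenly on a circle (node j at angle 2*pi*j/n)."""
    return [
        (radius * math.cos(2 * math.pi * j / n), radius * math.sin(2 * math.pi * j / n))
        for j in range(n)
    ]


def trim_segment(p, q, node_radius=0.5):
    """Shorten the segment p -> q by node_radius at both ends."""
    dx, dy = q[0] - p[0], q[1] - p[1]
    length = math.hypot(dx, dy)
    ux, uy = dx / length, dy / length  # Unit direction from p to q.
    start = (p[0] + ux * node_radius, p[1] + uy * node_radius)
    end = (q[0] - ux * node_radius, q[1] - uy * node_radius)
    return start, end


positions = circle_layout(4)
start, end = trim_segment(positions[0], positions[1])
print(start, end)
```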

## Training Function

Now we'll define the main training function that will put everything together:



```python
def train_ring_gflownet(
    n_nodes=4,
    n_iterations=200,
    lr=0.001,
    batch_size=128,
    directed=True,
    use_buffer=False,
    use_gnn=True,
    num_conv_layers=1,
    device="cpu",
    plot=True,
):
    """Train a GFlowNet to generate ring graphs.

    Args:
        n_nodes: Number of nodes in each graph.
        n_iterations: Number of training iterations.
        lr: Learning rate for the Adam optimizer.
        batch_size: Number of trajectories sampled per iteration.
        directed: Whether to generate directed rings.
        use_buffer: Whether to mix replayed trajectories into each batch.
        use_gnn: Whether to use a GNN-based policy (True) or an MLP-based policy (False).
        num_conv_layers: Number of graph convolution layers (only used if use_gnn=True).
        device: Device to run on ("cpu" or "cuda").
        plot: Whether to render sample graphs after training. The training
            curves are plotted regardless.

    Returns:
        tuple: (trained GFlowNet, environment, state evaluator, training losses)
    """
    device = torch.device(device)

    state_evaluator = RingReward(
        directed=directed,
        reward_val=100.0,
        eps_val=1e-6,
        device=device,
    )
    # Fixed seed: this overrides any seed set before calling the function.
    torch.random.manual_seed(7)

    env = GraphBuildingOnEdges(
        n_nodes=n_nodes,
        state_evaluator=state_evaluator,
        directed=directed,
        device=device,
    )

    # Choose the policy network type based on the use_gnn flag.
    if use_gnn:
        module_pf = GraphEdgeActionGNN(
            env.n_nodes,
            directed,
            num_conv_layers=num_conv_layers,
            num_edge_classes=env.num_edge_classes,
        )
        module_pb = GraphEdgeActionGNN(
            env.n_nodes,
            directed,
            is_backward=True,
            num_conv_layers=num_conv_layers,
            num_edge_classes=env.num_edge_classes,
        )
    else:
        module_pf = GraphEdgeActionMLP(
            env.n_nodes,
            directed,
            num_edge_classes=env.num_edge_classes,
        )
        module_pb = GraphEdgeActionMLP(
            env.n_nodes,
            directed,
            is_backward=True,
            num_edge_classes=env.num_edge_classes,
        )

    pf = DiscreteGraphPolicyEstimator(
        module=module_pf,
    )
    pb = DiscreteGraphPolicyEstimator(
        module=module_pb,
        is_backward=True,
    )
    gflownet = TBGFlowNet(pf, pb).to(device)
    optimizer = torch.optim.Adam(gflownet.parameters(), lr=lr)
    # Multiply the learning rate by 0.1 every 100 iterations.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    replay_buffer = ReplayBuffer(
        env,
        capacity=batch_size,
        prioritized=True,
    )

    losses = []
    ring_percentages = []

    print(f"Starting training on {device} device...")
    print(f"Training parameters: n_nodes={n_nodes}, directed={directed}, use_gnn={use_gnn}")

    t1 = time.time()
    epsilon_dict = defaultdict(float)
    for iteration in range(n_iterations):
        epsilon_dict["action_type"] = 0.0  # 0.2 * (1 - iteration / n_iterations)

        trajectories = gflownet.sample_trajectories(
            env,
            n=batch_size,
            save_logprobs=True,
            epsilon=epsilon_dict,
        )
        training_samples = gflownet.to_training_samples(trajectories)

        # Collect rewards for reporting.
        terminating_states = training_samples.terminating_states
        assert isinstance(terminating_states, GraphStates)
        rewards = state_evaluator(terminating_states)

        if use_buffer:
            with torch.no_grad():
                replay_buffer.add(training_samples)
                # After a short warm-up, train on half fresh and half replayed trajectories.
                if iteration > 20:
                    training_samples = training_samples[: batch_size // 2]
                    buffer_samples = replay_buffer.sample(
                        n_trajectories=batch_size // 2
                    )
                    training_samples.extend(buffer_samples)  # type: ignore

        optimizer.zero_grad()
        loss = gflownet.loss(env, training_samples, recalculate_all_logprobs=True)
        # A graph counts as a ring when it received reward_val rather than eps_val.
        pct_rings = (rewards > 0.1).float().mean().item() * 100
        ring_percentages.append(pct_rings)

        if iteration % 10 == 0:
            print(
                "Iteration {} - Loss: {:.02f}, rings: {:.0f}%".format(
                    iteration, loss.item(), pct_rings
                )
            )

        loss.backward()
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())

    t2 = time.time()
    print(f"Training completed in {t2 - t1:.2f} seconds")
    print(f"Final percentage of valid rings: {ring_percentages[-1]:.0f}%")

    # Plot training curves.
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(losses)
    plt.title("Training Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")

    plt.subplot(1, 2, 2)
    plt.plot(ring_percentages)
    plt.title("Percentage of Valid Rings")
    plt.xlabel("Iteration")
    plt.ylabel("Valid Rings (%)")
    plt.ylim(0, 100)

    plt.tight_layout()
    plt.show()

    # Visualize graphs from the final training batch.
    if plot:
        print("\nGenerating sample graphs...")
        samples_to_render = trajectories.terminating_states[:8]
        assert isinstance(samples_to_render, GraphStates)
        render_states(samples_to_render, state_evaluator, directed)

    return gflownet, env, state_evaluator, losses
```
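`train_ring_gflownet` calls `scheduler.step()` once per iteration with `StepLR(step_size=100, gamma=0.1)`, so the learning rate is multiplied by 0.1 every 100 iterations. The closed form below (the helper name `step_lr` is hypothetical) mirrors StepLR's decay rule, $\text{lr} = \text{lr}_0 \cdot \gamma^{\lfloor t / \text{step\_size} \rfloor}$, and makes the consequence concrete: a 150-iteration run spends its last 50 iterations at $10^{-4}$, while a 100-iteration run trains entirely at the initial $10^{-3}$.

```python
def step_lr(base_lr, iteration, step_size=100, gamma=0.1):
    """Learning rate in effect at `iteration` (0-based) under StepLR's decay rule."""
    return base_lr * gamma ** (iteration // step_size)


# Learning rate used by optimizer.step() at each of 150 iterations.
schedule = [step_lr(1e-3, it) for it in range(150)]
print(schedule[0], schedule[99], schedule[100], schedule[149])
```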

## Training a Directed Ring Generator

Let's train a GFlowNet to generate directed ring graphs on five nodes:

```python
# train_ring_gflownet() seeds torch with a fixed seed (7) internally,
# so its results are reproducible without seeding here.
gflownet_directed, env_directed, evaluator_directed, losses_directed = train_ring_gflownet(
    n_nodes=5,
    n_iterations=150,
    directed=True,
    use_gnn=True,
    batch_size=64,
    plot=True,
)
```

## Training an Undirected Ring Generator

Now let's train a GFlowNet to generate undirected ring graphs:

```python
# train_ring_gflownet() seeds torch with a fixed seed (7) internally.
gflownet_undirected, env_undirected, evaluator_undirected, losses_undirected = train_ring_gflownet(
    n_nodes=5,
    n_iterations=150,
    directed=False,
    use_gnn=True,
    batch_size=64,
    plot=True,
)
```

## Sampling from Trained Models

We can use our trained models to generate additional ring graphs:

```python
def generate_samples(gflownet, env, evaluator, n_samples=8, directed=True):
    """Sample graphs from a trained GFlowNet, render up to eight of them, and return them."""
    with torch.no_grad():  # Sampling only; no gradients are needed.
        trajectories = gflownet.sample_trajectories(env, n=n_samples)
    terminating_states = trajectories.terminating_states
    render_states(terminating_states, evaluator, directed)
    return terminating_states
```

```python
# Generate samples from the directed model
directed_samples = generate_samples(
    gflownet_directed, env_directed, evaluator_directed, n_samples=8, directed=True
)
```

```python
# Generate samples from the undirected model
undirected_samples = generate_samples(
    gflownet_undirected, env_undirected, evaluator_undirected, n_samples=8, directed=False
)
```

## Comparing GNN vs MLP Policies

Let's compare the training loss of GNN-based and MLP-based policies on directed rings with five nodes:

```python
# Train models with different policy networks and compare their loss curves
def compare_policy_networks():
    results = {}

    # Train with GNN
    _, _, _, losses_gnn = train_ring_gflownet(
        n_nodes=5,
        n_iterations=100,
        use_gnn=True,
        plot=False
    )
    results['GNN'] = losses_gnn

    # Train with MLP
    _, _, _, losses_mlp = train_ring_gflownet(
        n_nodes=5,
        n_iterations=100,
        use_gnn=False,
        plot=False
    )
    results['MLP'] = losses_mlp

    # Plot comparison
    plt.figure(figsize=(10, 6))
    plt.plot(results['GNN'], label='GNN Policy')
    plt.plot(results['MLP'], label='MLP Policy')
    plt.title('GNN vs MLP Policy Performance')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    return results

# Run the comparison
policy_comparison = compare_policy_networks()
```

## Scaling to Larger Graphs

Let's explore how training behaves as we increase the number of nodes:

```python
def node_scaling_experiment():
    node_counts = [4, 6, 8]
    results = {}

    for n_nodes in node_counts:
        print(f"\nTraining with {n_nodes} nodes...")
        _, _, _, losses = train_ring_gflownet(
            n_nodes=n_nodes,
            n_iterations=100,
            plot=False
        )
        results[n_nodes] = losses

    # Plot comparison
    plt.figure(figsize=(10, 6))
    for n_nodes, losses in results.items():
        plt.plot(losses, label=f'{n_nodes} nodes')

    plt.title('Training Loss for Different Graph Sizes')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    return results

# Run the scaling experiment
scaling_results = node_scaling_experiment()
```

## Conclusion

In this tutorial, we've trained GFlowNets to generate both directed and undirected ring graphs, compared two policy network architectures (GNN vs MLP), and examined how training behaves as the graph size grows.

Key takeaways:
1. GFlowNets can learn to generate graphs that satisfy hard structural constraints when the constraint is encoded in the reward: here, a large reward for rings and a near-zero reward for every other graph.
2. GNN-based policies are a natural fit for graph generation because they operate directly on the graph structure; the GNN vs MLP comparison shows how they fare against an MLP on this task.
3. The same training loop handles both directed and undirected rings; only the reward check and the environment's `directed` flag change.
4. The number of possible edge sets grows exponentially with the number of nodes, so larger graphs can be expected to need more training iterations; the scaling experiment shows how the loss behaves for 4, 6 and 8 nodes.

The same approach extends to other graph families, such as trees, grids, or molecular graphs, by changing the reward function.