# Install dependency in Colab

In [None]:
%%bash
# Install torchgfn from the sibling checkout at ../torchgfn.
# `set -e` aborts the cell as soon as a command fails, so a failed `cd`
# can no longer lead to `pip install .` running in the wrong directory.
set -e
cd ../torchgfn
pip install .

# Demo (Pseudocode)

In [3]:
import torch
import numpy as np
from scipy.stats import spearmanr
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils.modules import MLP
from gfn.states import DiscreteStates

In [4]:
# 0 - Find Available GPU resource
# `torch.backends.mps.is_available()` is used instead of `torch.mps.is_available()`
# because the latter shortcut is missing from older PyTorch releases.
# NOTE: `device` is only reported here; the environment and networks below are
# created with their defaults and are never moved to it.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Tunable settings for this demo, gathered in one place.
SEED = 0             # fixes torch's global RNG so a fresh re-run is repeatable
N_ITERATIONS = 1000  # number of optimisation steps
BATCH_SIZE = 16      # trajectories sampled per optimisation step
torch.manual_seed(SEED)

# 1 - Define the environment
env = HyperGrid(ndim=4, height=8, R0=0.01)

# 2 - Define the neural network modules
# The backward module reuses the forward module's trunk, so only the output
# heads are separate. It has one output fewer than the forward module
# (presumably because the exit action has no backward counterpart; confirm
# against the torchgfn docs).
module_PF = MLP(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions)
module_PB = MLP(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk)

# 3 - Define the estimators
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - Define the GFlowNet
gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator)

# 5 - Define the sampler and optimizer
# Trajectories are drawn from the forward policy estimator. logZ gets its own
# parameter group with a larger learning rate than the policy networks.
sampler = Sampler(estimator=pf_estimator)
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - Train the GFlowNet
for i in (pbar := tqdm(range(N_ITERATIONS))):
    trajectories = sampler.sample_trajectories(env=env, n=BATCH_SIZE)
    optimizer.zero_grad()
    # The loss is left where it was computed: copying only this scalar to
    # `device` would not move any of the actual computation there.
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})

Using device: mps


100%|██████████| 1000/1000 [00:14<00:00, 68.39it/s, loss=0.129]


In [56]:
import torch
# TODO
def compute_log_probability(gfn, state, memo=None):
    """
    Recursively computes the log of the sampling probability π_θ(s) for a given state `state`
    in a GFlowNet `gfn`, via

        log π_θ(s) = logsumexp over parents s' of [ log P_F(s | s') + log π_θ(s') ]

    with log π_θ(s_initial) = 0 as the base case.

    NOTE(review): as written, the recursion accumulates the probability of *reaching*
    `state`. If the probability of *terminating* at `state` is what is wanted, the
    log-probability of the exit action at `state` probably has to be added on top.
    Confirm against the intended definition.

    Args:
        gfn (GFlowNet): The GFlowNet model instance.
        state (States): The state for which we want to compute log π_θ(s). It must expose
            `.tensor` (its coordinates) and `.is_initial_state`.
        memo (dict | None): Cache of previously computed log probabilities, keyed by the
            state's coordinates as a tuple. Pass the same dict across calls to reuse work.
            When omitted, a fresh cache is created for this call.

    Returns:
        torch.Tensor: The log probability log π_θ(s).

    Raises:
        RuntimeError: If a non-initial state has no parents (e.g. a sink/padding state
            was passed in).
    """
    # A `memo={}` default would be created once and silently shared by every call
    # (and by every model), so a fresh dict is created per top-level call instead.
    if memo is None:
        memo = {}

    # Key the cache by the state's coordinates rather than by the States object:
    # two distinct objects describing the same grid cell must hit the same entry,
    # otherwise the memoization never fires and the recursion blows up.
    key = tuple(state.tensor.flatten().tolist())
    if key in memo:
        return memo[key]

    # Base case: if the state is the initial state, log π_θ(s_initial) = 0
    if state.is_initial_state.all():
        log_prob = torch.tensor(0.0, requires_grad=False)
        memo[key] = log_prob
        return log_prob

    # Recursive case: compute log π_θ(s) from parent states
    # TODO: `get_parents` is not defined anywhere in this notebook; it must return an
    # iterable of the parent States of `state` before this branch can run.
    parent_states = get_parents(state)

    # Collect log P_F(s | parent) + log π_θ(parent) for each parent transition
    log_probs = []
    for parent_state in parent_states:
        # TODO(review): `get_forward_transition_probability` is pseudocode; confirm
        # which torchgfn call provides P_F(state | parent_state).
        log_forward_prob = torch.log(gfn.get_forward_transition_probability(state, parent_state))

        # Recursively compute log π_θ(parent_state), sharing the same cache
        log_parent_prob = compute_log_probability(gfn, parent_state, memo)

        log_probs.append(log_forward_prob + log_parent_prob)

    if not log_probs:
        # torch.stack([]) would raise an opaque error; say what actually went wrong.
        raise RuntimeError(f"Non-initial state {list(key)} has no parent states")

    # Sum the per-parent probabilities in log space (log-sum-exp for numerical stability)
    log_prob = torch.logsumexp(torch.stack(log_probs), dim=0)

    # Memoize and return
    memo[key] = log_prob
    return log_prob

In [42]:
# 8 - Generate a test set of trajectories
# This cell only samples; log-probabilities and rewards are computed in later cells.
n_test = 100  # Number of test trajectories
# Drawn with the same forward-policy sampler that produced the training batches.
test_trajectories = sampler.sample_trajectories(env=env, n=n_test)

In [None]:
# Each Trajectories instance contains n_trajectories trajectories, probably with different lengths.
# Indexing with an integer returns a Trajectories instance holding just that one trajectory;
# in this case it has a length of 15.
test_trajectories[0]

Trajectories(n_trajectories=1, max_length=15, First 10 trajectories:states=
[0 0 0 0]-> [0 0 1 0]-> [0 1 1 0]-> [0 2 1 0]-> [1 2 1 0]-> [2 2 1 0]-> [2 3 1 0]-> [3 3 1 0]-> [4 3 1 0]-> [4 4 1 0]-> [5 4 1 0]-> [5 5 1 0]-> [5 6 1 0]-> [6 6 1 0]-> [6 6 2 0]-> [-1 -1 -1 -1]
when_is_done=[15])

In [54]:
# states[0] denotes the initial state: the first entry of the trajectory's state sequence.
# Print its coordinates, then the flag that marks it as the initial state.
print(test_trajectories[0].states[0].tensor)
print(f'states[0] is_initial_state: {test_trajectories[0].states[0].is_initial_state}')

tensor([[0, 0, 0, 0]])
states[0] is_initial_state: tensor([True])


In [None]:
# states[-1] denotes the sink state that closes the trajectory, so the last
# real grid state of the trajectory is states[-2].
# Print its coordinates, then the flag that marks it as the sink state.
print(test_trajectories[0].states[-1].tensor)
print(f'states[-1] is_sink_state: {test_trajectories[0].states[-1].is_sink_state}')

tensor([[-1, -1, -1, -1]])
states[-1] is_sink_state: tensor([True])


$$\log \pi_\theta(s) = \log \left( \sum_{s' \in \text{Parent}(s)} \exp \left( \log P_{F_\theta}(s \mid s') + \log \pi_\theta(s') \right) \right)$$

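where $P_{F_\theta}(s \mid s')$ is the forward transition probability from a parent state $s'$ to $s$, and $s$ is a state in the trajectory. The recursion bottoms out at the initial state $s_0$, where $\log \pi_\theta(s_0) = 0$.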

In [60]:
import torch
from gfn.samplers import Sampler
from collections import Counter

# Define the function to compute the sampling probability
def compute_sampling_probability_with_monte_carlo(env, sampler, terminal_state, n_samples=10000):
    """
    Estimates the sampling probability of a given terminal state using Monte Carlo.

    `n_samples` trajectories are drawn from `sampler`, and the estimate is the fraction
    of them whose last state before the sink state equals `terminal_state`.

    Args:
        env: The environment instance.
        sampler: An initialized Sampler using the forward policy estimator.
        terminal_state: The terminal state whose probability we want to estimate, given
            as a tensor (or anything `torch.as_tensor` accepts) of its coordinates.
        n_samples: The number of trajectories to sample. Must be positive.

    Returns:
        float: The estimated sampling probability of the terminal state.

    Raises:
        ValueError: If `n_samples` is not positive.
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be a positive integer")

    # Sample trajectories
    trajectories = sampler.sample_trajectories(env=env, n=n_samples)

    # Hashable/comparable form of the state we are looking for
    target = tuple(torch.as_tensor(terminal_state).flatten().tolist())

    # Every sampled trajectory has to be inspected. Looking only at
    # `trajectories[0]` while dividing by `n_samples` can never give an estimate
    # above 1 / n_samples.
    hits = 0
    for index in range(n_samples):
        # Integer indexing yields a single trajectory whose final entry is the
        # sink state, so [-2] is the state it terminated in.
        # NOTE(review): torchgfn may offer a vectorized accessor for the
        # terminating states of a batch; this loop only relies on indexing.
        last_state = trajectories[index].states[-2]
        if tuple(last_state.tensor.flatten().tolist()) == target:
            hits += 1

    # Fraction of sampled trajectories that ended in the requested state
    return hits / n_samples


# Define the terminal state (replace with the actual state representation)
# [6, 6, 2, 0] is the last state before the sink state in the recorded output of
# `test_trajectories[0]`; a fresh run will sample a different trajectory.
terminal_state = torch.tensor([6, 6, 2, 0])

# Compute the sampling probability
sampling_probability = compute_sampling_probability_with_monte_carlo(env, sampler, terminal_state, n_samples=10000)
print(f"Sampling probability of the terminal state {terminal_state.tolist()}: {sampling_probability}")

Sampling probability of the terminal state [6, 6, 2, 0]: 0.0


In [None]:
# Initialize lists to hold the probabilities and rewards
log_probs = []
log_rewards = []
# One cache shared across all test trajectories so sub-results are reused
memo = {}
# Calculate the log probability and log reward for each terminal state
for index in range(n_test):
    traj = test_trajectories[index]
    # states[-1] is the sink state, so the state the trajectory terminated in
    # is states[-2] (the whole `.states` sequence is not a single state).
    terminal_states = traj.states[-2]
    # Score the state of *this* trajectory, not the unrelated global
    # `terminal_state` defined in the Monte Carlo cell.
    reward = env.reward(terminal_states)
    # assumes the reward comes back as a one-element tensor — TODO confirm
    log_reward = torch.log(reward).item()
    # TODO
    log_prob = compute_log_probability(gfn, terminal_states, memo)
    # Store plain floats so scipy can rank them directly
    log_probs.append(float(log_prob))
    log_rewards.append(log_reward)

In [None]:
# 9 - Compute Spearman's Rank Correlation
# Rank correlation between the model's log-probabilities and the log-rewards of
# the sampled terminal states. The coefficient is taken by index instead of
# unpacking the p-value into `_`, because assigning to `_` shadows IPython's
# "last output" variable for the rest of the session.
spearman_corr = spearmanr(log_probs, log_rewards)[0]
print(f"Spearman's Rank Correlation (GFNEvalS): {spearman_corr}")