# Install dependency in Colab

In [329]:
%%bash
# Install the torchgfn checkout that sits next to this notebook's directory.
# `&&` makes the install conditional on the cd succeeding, so a missing
# ../torchgfn directory can never cause the current directory to be installed.
cd ../torchgfn && pip install .

Processing /Users/erostrate9/Desktop/CSI5340 DL/Project/code/GFNEval/torchgfn
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: torchgfn
  Building wheel for torchgfn (pyproject.toml): started
  Building wheel for torchgfn (pyproject.toml): finished with status 'done'
  Created wheel for torchgfn: filename=torchgfn-1.1.1-py3-none-any.whl size=82659 sha256=0aaa942db23e923d172421a5e00a64411452a8d8bcbe2a4071cd59d367f371d4
  Stored in directory: /private/var/folders/c_/9pzrss116732p7dxch3kn_bc0000gn/T/pip-ephem-wheel-cache-8gu9mafz/wheels/56/de/11/edbaf478c4bdb3bf4d2dadfda48c78d0790413f2f66eee7a21
Successfully built torchgfn
Installing collected packages: torchgfn
  Attemptin

# GFNEvalS Demo

In [None]:
import numpy as np
import torch
from scipy.stats import spearmanr
from tqdm import tqdm

from gfn.gflownet import GFlowNet, TBGFlowNet
from gfn.gym import HyperGrid, HyperGrid2
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.states import DiscreteStates, States
from gfn.utils.modules import MLP

ImportError: cannot import name 'HyperGrid2' from 'gfn.gym' (/Users/erostrate9/miniconda3/envs/gfn/lib/python3.10/site-packages/gfn/gym/__init__.py)

## Train GFlowNet

In [2]:
# 0 - Find Available GPU resource
# torch.backends.mps.is_available() is the long-standing availability check for
# Apple-silicon GPUs; torch.mps.is_available() is missing on older PyTorch builds.
device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# NOTE(review): nothing below is moved to `device`; the environment and the
# modules are constructed without any device argument.

# 1 - Define the environment
env = HyperGrid(ndim=4, height=8, R0=0.01)

# 2 - Define the neural network modules
# The backward module reuses the forward module's trunk and has one output
# fewer than the forward module (n_actions - 1).
module_PF = MLP(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions)
module_PB = MLP(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk)

# 3 - Define the estimators
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - Define the GFlowNet
gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator)

# 5 - Define the sampler and optimizer
# logZ gets its own parameter group with a larger learning rate than the policies.
sampler = Sampler(estimator=pf_estimator)
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - Train the GFlowNet
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n=16)
    optimizer.zero_grad()
    # The loss is used where it was computed: copying only the scalar loss to
    # another device would add a transfer per step without speeding anything up.
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})

Using device: mps


100%|██████████| 1000/1000 [00:14<00:00, 67.16it/s, loss=0.215]


In [18]:
# env.all_states holds every state of the environment; the assert below checks
# that there are height**ndim of them.
# NOTE(review): the bare expression is not the last line of the cell, so it displays nothing.
env.all_states
assert len(env.all_states)==env.height**env.ndim

In [19]:
# n_actions = ndim + 1
# Actions are represented by an integer in {0, ..., n_actions - 1}; the last one is the exit action.
env.n_actions

5

## Compute Sampling Probability

In [20]:
import torch
from collections import defaultdict

class TensorDict:
    """Dictionary-like container that accepts tensors (or nested lists) as keys.

    Keys are normalised to nested tuples, so tensors with equal contents map to
    the same entry. Storage is a ``defaultdict``: when a ``default_factory`` is
    given, reading a missing key creates and stores the default value.
    """

    def __init__(self, default_factory=None):
        self.default_factory = default_factory
        self.data = defaultdict(default_factory)

    def _tensor_to_hashable(self, tensor):
        """Return a hashable (nested-tuple) version of ``tensor``."""
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.tolist()
        if isinstance(tensor, list):
            return tuple(map(self._tensor_to_hashable, tensor))
        # Scalars (and anything else) are returned unchanged
        return tensor

    def __setitem__(self, tensor, value):
        """Store ``value`` under the hashable form of ``tensor``."""
        self.data[self._tensor_to_hashable(tensor)] = value

    def __getitem__(self, tensor):
        """Look up the value stored for ``tensor``."""
        return self.data[self._tensor_to_hashable(tensor)]

    def __contains__(self, tensor):
        """Tell whether an entry exists for ``tensor``."""
        return self._tensor_to_hashable(tensor) in self.data

    def __str__(self):
        """Render the content as a plain dict whose keys are stringified tuples."""
        return str({str(k): v for k, v in self.data.items()})

    def __repr__(self):
        """Developer-friendly representation wrapping ``str(self)``."""
        return "TensorDict(" + str(self) + ")"

# Example usage
# Demo: tensors of different shapes can all be used as keys
x1 = torch.tensor([[6, 6], [2, 0]])
x2 = torch.tensor([[1, 2], [3, 4]])
x3 = torch.tensor([0, 0, 0, 1])

dic = TensorDict(default_factory=lambda: False)
for key_tensor, flag in ((x1, True), (x2, False), (x3, True)):
    dic[key_tensor] = flag

print(dic)

{'((6, 6), (2, 0))': True, '((1, 2), (3, 4))': False, '(0, 0, 0, 1)': True}


In [289]:
def get_all_transition_log_probs(env, pf_estimator):
    """
    Evaluate the forward policy once on every state of the environment.

    Args:
        env: Environment exposing ``all_states`` and ``n_actions``.
        pf_estimator: Forward policy estimator. It is called on ``env.all_states`` and its
            ``to_probability_distribution`` result must provide ``log_prob``.

    Returns:
        transition_log_probs, a list of ``env.n_actions`` tensors.
        transition_log_probs[i][j] is the log probability of taking action i at state
        env.all_states[j], for i in [0, n_actions-1]. The tensors carry no autograd graph.
    """
    all_states = env.all_states
    # This is an evaluation-only lookup table: skipping autograd keeps every
    # quantity derived from it from retaining the network's graph.
    with torch.no_grad():
        estimator_output = pf_estimator(all_states)
        dist = pf_estimator.to_probability_distribution(all_states, estimator_output)
        # Actions are integer indices, so the query is an integer tensor; the
        # shape stays (1,) exactly as before.
        return [
            dist.log_prob(torch.tensor([i], dtype=torch.long))
            for i in range(env.n_actions)
        ]

$$\log \pi_\theta(s) = \log \left( \sum_{s' \in \text{Parent}(s)} \exp \left( \log P_{F_\theta}(s \mid s') + \log \pi_\theta(s') \right) \right)$$

where $P_{F_\theta}(s \mid s')$ is the forward transition probability from a parent state $s'$ to $s$, $s$ is a state in the trajectory, and $\log \pi_\theta(s_0) = 0$ for the initial state $s_0$.

In [None]:
import torch
from gfn.states import stack_states
def compute_log_probability(env: HyperGrid, gfn, state: DiscreteStates, memo, transition_log_probs):
    """
    Recursively computes log π_θ(s), the log-probability that the forward policy reaches `state`.

    Applies log π_θ(s) = logsumexp over parents s' of [log P_F(s | s') + log π_θ(s')],
    with log π_θ = 0 for the initial state, and caches every intermediate result in `memo`.
    The exit (termination) action is not part of the result.

    Args:
        env (HyperGrid): Environment used to enumerate backward actions, step backward
            and map states to indices.
        gfn (GFlowNet): Not used by the computation; kept so existing call sites keep working.
        state (States): The state for which we want log π_θ(s). A state whose tensor is
            1-D is first stacked into a batch of one.
        memo: Tensor-keyed cache (supports `in`, item get and item set with tensors, e.g.
            TensorDict) holding previously computed log probabilities. Updated in place.
        transition_log_probs: transition_log_probs[i][j] is read as the log-probability of
            forward action i at the state whose index (from env.get_states_indices) is j.

    Returns:
        torch.Tensor: log π_θ(s). A non-initial state without any valid backward action
        gets -inf (probability zero).
    """
    if len(state.tensor.shape) == 1:
        state = stack_states([state])
    state_key = state.tensor

    # Reuse a previously computed value
    if state_key in memo:
        return memo[state_key]

    # Base case: the initial state is reached with probability 1
    if state.is_initial_state.all():
        log_prob = torch.tensor([0.0], requires_grad=False)
        memo[state_key] = log_prob
        return log_prob

    # The masks are refreshed for `state` only, so once before the loop is enough
    env.update_masks(state)

    # One term log P_F(s | s') + log π_θ(s') per valid parent s'
    log_probs = []
    for i in range(env.n_actions - 1):  # the last action (exit) is not a backward move
        action = env.actions_from_tensor(torch.tensor([[i]], dtype=torch.int64))
        if not env.is_action_valid(state, action, backward=True):
            continue
        # s': the parent obtained by undoing action i
        parent_state_tensor = env.backward_step(state, action)
        parent_state = env.states_from_tensor(parent_state_tensor)
        parent_state_idx = env.get_states_indices(parent_state)
        # log P_F(s | s'): forward action i taken at the parent
        log_forward_prob = transition_log_probs[i][parent_state_idx]
        # log π_θ(s'): recurse (memoized)
        log_parent_prob = compute_log_probability(env, gfn, parent_state, memo, transition_log_probs)
        log_probs.append(log_forward_prob + log_parent_prob)

    if log_probs:
        # log-sum-exp over parents for numerical stability
        log_prob = torch.logsumexp(torch.stack(log_probs), dim=0)
    else:
        # torch.stack rejects an empty list; a state with no parent has probability zero
        log_prob = torch.tensor([float("-inf")], requires_grad=False)

    memo[state_key] = log_prob
    return log_prob

In [302]:
# 8 - Generate a test set: sample trajectories with the sampler built on the forward policy.
# Their terminal states are scored in the cells below.
n_test = 100  # Number of test trajectories
test_trajectories = sampler.sample_trajectories(env=env, n=n_test)

## Compute GFNEvalS

In [303]:
def compute_log_prob_termination(env: HyperGrid, terminal_state: DiscreteStates, memo, transition_log_probs):
    """
    Log-probability of reaching `terminal_state` and then taking the exit action there.

    Args:
        env (HyperGrid): Environment used to validate the exit action and index the state.
        terminal_state (DiscreteStates): State at which the trajectory terminates. A state
            whose tensor is 1-D is first stacked into a batch of one.
        memo: Tensor-keyed cache that must already contain log π_θ for `terminal_state`.
        transition_log_probs: Per-action log-probabilities; the last entry is used as the
            exit action (index env.n_actions - 1).

    Returns:
        torch.Tensor: memo[terminal_state] + log-probability of the exit action at that state.

    Raises:
        AssertionError: If the exit action is not valid at `terminal_state`.
        KeyError: If `memo` has no entry for `terminal_state`.
    """
    if len(terminal_state.tensor.shape)==1:
        terminal_state = stack_states([terminal_state])
    terminal_state_tensor = terminal_state.tensor
    termination_action = env.actions_from_tensor(torch.Tensor([[env.n_actions-1]]).to(torch.int64))
    env.update_masks(terminal_state)
    assert env.is_action_valid(terminal_state, termination_action, backward=False), f"Error: Termination at given state {terminal_state.tensor} is invalid!"
    # Reading a missing key from a defaultdict-backed memo would insert the default
    # value and make later lookups treat the state as already computed.
    if terminal_state_tensor not in memo:
        raise KeyError(
            f"No memoized log-probability for state {terminal_state_tensor}; "
            "run compute_log_probability on it first."
        )
    terminal_state_idx = env.get_states_indices(terminal_state)
    # log π_θ(s_terminal) + log termination
    return memo[terminal_state_tensor] + transition_log_probs[-1][terminal_state_idx]

In [305]:
import time
start_time = time.time()

# Log-probability of every forward action in every state, computed once up front.
# Defining it here keeps the cell runnable on a fresh kernel.
transition_log_probs = get_all_transition_log_probs(env, pf_estimator=pf_estimator)

# Per-trajectory results
log_probs = []
log_probs_termination = []
log_rewards = []
# Default for unseen states is log(0). float('-inf') is required here:
# torch.tensor cannot build a tensor from the string '-inf'.
memo = TensorDict(default_factory=lambda: torch.tensor([float('-inf')], requires_grad=False))
# Calculate the log probability and log reward for each terminal state
for traj in tqdm(test_trajectories, desc="Processing trajectories"):
    # states[-2] is used as the terminal state of the trajectory
    terminal_state = traj.states[-2]
    reward = env.reward(terminal_state)
    log_reward = torch.log(reward)
    log_prob=compute_log_probability(env, gfn, terminal_state, memo, transition_log_probs)
    log_prob_termination = compute_log_prob_termination(env, terminal_state, memo, transition_log_probs)
    log_probs.append(log_prob.detach().numpy())
    log_probs_termination.append(log_prob_termination.detach().numpy())
    log_rewards.append(log_reward.detach().numpy())

# 9 - Compute Spearman's Rank Correlation
spearman_corr_termination, _ = spearmanr(log_probs_termination, log_rewards)
print(f"Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): {spearman_corr_termination}. Runtime: {time.time()-start_time} seconds.")

Processing trajectories: 100%|██████████| 100/100 [00:00<00:00, 102.21it/s]

Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.8283671542265011. Runtime: 0.9845030307769775 seconds.





In [307]:
# Inspect the first test trajectory: its terminal state, the memoized log π_θ for it,
# and the value once the exit-action log-probability is added.
terminal_state = test_trajectories[0].states[-2]
terminal_state_tensor = terminal_state.tensor
print(f's_terminal: {terminal_state_tensor}')
print(f'log π_θ(s_terminal): {memo[terminal_state_tensor]}')
log_prob_termination = compute_log_prob_termination(env, terminal_state, memo, transition_log_probs)
print(f'log_prob when termination at s_terminal: {log_prob_termination}')

s_terminal: tensor([[6, 1, 6, 0]])
log π_θ(s_terminal): tensor([-4.3418], grad_fn=<LogsumexpBackward0>)
log_prob when termination at s_terminal: tensor([-5.9106], grad_fn=<AddBackward0>)


In [308]:
# 9 - Compute Spearman's Rank Correlation (Original GFNEvalS, excluding termination actions)
# log_probs holds the compute_log_probability results, i.e. without the exit-action term.
spearman_corr, _ = spearmanr(log_probs, log_rewards)
print(f"Spearman's Rank Correlation (Original GFNEvalS, excluding termination actions): {spearman_corr}")

Spearman's Rank Correlation (Original GFNEvalS, excluding termination actions): 0.567878302378827


In [309]:
# 10 - Compute Spearman's Rank Correlation (Modified GFNEvalS, including termination actions)
# log_probs_termination holds the compute_log_prob_termination results (exit action included).
spearman_corr_termination, _ = spearmanr(log_probs_termination, log_rewards)
print(f"Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): {spearman_corr_termination}")

Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.8283671542265011


## Approximating the sampling probability with Monte Carlo

In [310]:
import torch
from gfn.samplers import Sampler
from collections import Counter
from gfn.states import States

# Estimate the sampling probability with Monte Carlo
def count_occurrences_with_monte_carlo(env, sampler, n_samples=10000):
    """
    Samples trajectories and counts how often each terminal state is reached.

    Args:
        env: The environment instance.
        sampler: An initialized Sampler using the forward policy estimator.
        n_samples: The number of trajectories to sample.

    Returns:
        occurrences: TensorDict mapping each terminal state's tensor (as a batch of one)
        to the number of sampled trajectories that ended there. It is built with an
        `int` default, so states that were never reached read as 0.
    """
    # Sample trajectories
    trajectories = sampler.sample_trajectories(env=env, n=n_samples)
    # Take all terminal states in one batched lookup instead of slicing every
    # trajectory individually.
    terminal_tensors = trajectories.last_states.tensor
    occurrences = TensorDict(int)
    for state_tensor in tqdm(terminal_tensors, desc="Processing terminal_states"):
        # Key by a batch-of-one tensor, the shape that `traj.states[-2].tensor` has
        # and that compute_log_prob_with_monte_carlo looks up.
        occurrences[state_tensor.unsqueeze(0)] += 1
    return occurrences
def compute_log_prob_with_monte_carlo(occurrences, terminal_state, n_samples: int) -> torch.Tensor:
    """
    Monte Carlo estimate of the log sampling probability of one terminal state.

    Args:
        occurrences: Tensor-keyed counts of sampled terminal states (e.g. the result of
            count_occurrences_with_monte_carlo).
        terminal_state: Either a States object or a tensor. A States object whose tensor
            is 1-D is stacked into a batch of one before its tensor is used as the key.
        n_samples: Number of trajectories the counts were collected from.

    Returns:
        torch.Tensor: 0-dim tensor log(count / n_samples); -inf when the count is 0.
    """
    if isinstance(terminal_state, States):
        if len(terminal_state.tensor.shape)==1:
            terminal_state = stack_states([terminal_state])
        terminal_state = terminal_state.tensor
    # Empirical frequency of the state among the sampled terminal states
    frequency = occurrences[terminal_state] / n_samples
    return torch.log(torch.tensor(frequency, requires_grad=False))

In [311]:
import time
start_time = time.time()
# Sample 20 trajectories per environment state and count the terminal states reached
n_samples = 20 * env.n_states
occurrences = count_occurrences_with_monte_carlo(env, sampler, n_samples=n_samples)

# Spot-check one hand-picked terminal state against the exact computation
terminal_state = torch.tensor([[6, 1, 1, 6]])
log_prob = compute_log_prob_with_monte_carlo(occurrences, terminal_state, n_samples)
# Go through compute_log_probability rather than reading memo directly: it returns the
# memoized value when present and computes it otherwise, so the check does not depend
# on this state having been visited by an earlier cell.
log_prob_exact = compute_log_probability(env, gfn, env.states_from_tensor(terminal_state), memo, transition_log_probs)
print(f"Log_prob of the terminal state via Monte Carlo {terminal_state.tolist()}: {log_prob}")
print(f"Log_prob of the terminal state via GFNEvalS {terminal_state.tolist()}: {log_prob_exact}")

# Monte Carlo log-probability and log reward for every test trajectory
log_probs_monte_carlo = []
log_rewards_monte_carlo = []
for traj in tqdm(test_trajectories, desc="Processing trajectories"):
    terminal_state = traj.states[-2]
    reward = env.reward(terminal_state)
    log_reward = torch.log(reward)
    log_prob=compute_log_prob_with_monte_carlo(occurrences, terminal_state, n_samples)
    log_probs_monte_carlo.append(log_prob.detach().numpy())
    log_rewards_monte_carlo.append(log_reward.detach().numpy())
# Compute Spearman's Rank Correlation
spearman_corr_monte_carlo, _ = spearmanr(log_probs_monte_carlo, log_rewards_monte_carlo)
print(f"Spearman's Rank Correlation (Monte Carlo): {spearman_corr_monte_carlo}. MC sample number: {n_samples}. Runtime: {time.time()-start_time} seconds")

Processing terminal_states: 100%|██████████| 81920/81920 [00:00<00:00, 316073.69it/s]


Log_prob of the terminal state via Monte Carlo [[6, 1, 1, 6]]: -4.631389617919922
Log_prob of the terminal state via GFNEvalS [[6, 1, 1, 6]]: tensor([-4.1811], grad_fn=<LogsumexpBackward0>)


Processing trajectories: 100%|██████████| 100/100 [00:00<00:00, 11005.78it/s]

Spearman's Rank Correlation (Monte Carlo): 0.8284914593953175. MC sample number: 81920. Runtime: 7.849664926528931 seconds





# Environments

In [312]:
import time

def timer(func):
    """
    Decorator that prints how long each call to `func` takes.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to `func`, prints
        the elapsed time in seconds, and returns `func`'s result unchanged. The wrapper
        keeps `func`'s name and docstring.
    """
    from functools import wraps  # local import keeps this cell self-contained

    @wraps(func)  # without this the decorated function would be reported as 'wrapper'
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()  # monotonic clock meant for measuring intervals
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' executed in {elapsed:.4f} seconds")
        return result
    return wrapper

In [313]:
def get_random_test_set(env: HyperGrid, n=100):
    """
    Draw a test set uniformly at random, without repetition, from `env.all_states`.

    Args:
        env: Environment providing `all_states` and `reward`.
        n: Number of states to keep from a random permutation of all state indices.

    Returns:
        (terminal_states, log_rewards): the selected states and the log of their rewards.
    """
    n_total = len(env.all_states)
    # A random permutation truncated to n gives distinct indices
    chosen = torch.randperm(n_total)[:n]
    terminal_states = env.all_states[chosen]
    return terminal_states, torch.log(env.reward(terminal_states))

In [314]:
def get_sampled_test_set(gfn, env, n=100):
    """
    Build a test set from the last states of trajectories sampled with `gfn.pf`.

    Args:
        gfn: GFlowNet whose forward policy estimator `gfn.pf` drives the sampler.
        env: Environment to sample in; also provides `reward`.
        n: Number of trajectories to sample.

    Returns:
        (terminal_states, log_rewards): the trajectories' last states and the log of their rewards.
    """
    sampled = Sampler(estimator=gfn.pf).sample_trajectories(env=env, n=n)
    terminal_states = sampled.last_states
    return terminal_states, torch.log(env.reward(terminal_states))

In [315]:
# Test set of 100 states picked at random from env.all_states, with their log rewards
test_terminal_states_random, test_log_rewards_random = get_random_test_set(env, 100)

In [316]:
# Test set of 100 terminal states sampled with the GFlowNet's forward policy, with their log rewards
test_terminal_states_sample, test_log_rewards_sample = get_sampled_test_set(gfn, env, 100)

In [None]:
from gfn.gflownet import GFlowNet

@timer
def evaluate_GFNEvalS(gfn: GFlowNet, env: HyperGrid, terminal_states, log_rewards):
    """
    Scores a GFlowNet by the Spearman rank correlation between the log sampling probability
    of each test terminal state (computed by backtracking with memoization, exit action
    included) and its log reward.

    Args:
        gfn: The GFlowNet to evaluate; its forward policy estimator `gfn.pf` is used.
        env: The HyperGrid environment instance.
        terminal_states: Terminal states of the test set, iterated one state at a time.
        log_rewards: Log rewards of `terminal_states` in the same order; must support `.detach()`.

    Returns:
        spearman_corr_termination: Spearman's Rank Correlation (Modified GFNEvalS, including termination actions)
        memo: TensorDict, memo[s] is the log probability of reaching s from the initial state,
            without counting the probability of terminating at s.
        transition_log_probs: a Tensor list with length of env.n_actions.
            transition_log_probs[i][j] indicates the log probability of taking action i at a State env.all_states[j], i in [0, n_actions-1]
    """
    start_time = time.time()
    # Default for unseen states is log(0). float('-inf') is required here:
    # torch.tensor cannot build a tensor from the string '-inf'.
    memo = TensorDict(default_factory=lambda: torch.tensor([float('-inf')], requires_grad=False))
    transition_log_probs = get_all_transition_log_probs(env, gfn.pf)
    log_probs = []
    log_probs_termination = []
    # Calculate the log probability (with and without the exit action) for each terminal state
    for terminal_state in tqdm(terminal_states, desc="Evaluating test set..."):
        # Fills memo for this state and all of its ancestors
        log_prob=compute_log_probability(env, gfn, terminal_state, memo, transition_log_probs)
        log_probs.append(log_prob.detach().numpy())
        log_prob_termination = compute_log_prob_termination(env, terminal_state, memo, transition_log_probs)
        log_probs_termination.append(log_prob_termination.detach().numpy())
    # Compute Spearman's Rank Correlation
    spearman_corr_termination, _ = spearmanr(log_probs_termination, log_rewards.detach())
    print(f"Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): {spearman_corr_termination}. Runtime: {time.time()-start_time} seconds.")
    return spearman_corr_termination, memo, transition_log_probs

In [321]:
# Exact GFNEvalS on the random test set; memo and transition_log_probs are kept as notebook globals
spearman_corr_termination, memo, transition_log_probs = evaluate_GFNEvalS(gfn, env, test_terminal_states_random, test_log_rewards_random)

Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 114.36it/s]

Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.4113598445305147. Runtime: 0.8838462829589844 seconds.
Function 'evaluate_GFNEvalS' executed in 0.8842 seconds





In [323]:
# Exact GFNEvalS on the policy-sampled test set; only the printed correlation is of interest here
_, _, _ = evaluate_GFNEvalS(gfn, env, test_terminal_states_sample, test_log_rewards_sample)

Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 105.79it/s]

Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.870711363609113. Runtime: 0.9630138874053955 seconds.
Function 'evaluate_GFNEvalS' executed in 0.9636 seconds





In [None]:
import time
@timer
def evaluate_GFNEvalS_with_monte_carlo(gfn: GFlowNet, env: HyperGrid, terminal_states, log_rewards, sample_multiples=20):
    """
    Monte Carlo counterpart of the GFNEvalS score.

    Samples `sample_multiples * env.n_states` trajectories with `gfn.pf`, estimates the log
    sampling probability of every test terminal state from the terminal-state counts, and
    correlates these estimates with `log_rewards`.

    Args:
        gfn: GFlowNet whose forward policy estimator `gfn.pf` drives the sampler.
        env: The HyperGrid environment instance.
        terminal_states: Terminal states of the test set, iterated one state at a time.
        log_rewards: Log rewards of `terminal_states` in the same order.
        sample_multiples: Number of sampled trajectories per environment state.

    Returns:
        (spearman_corr_monte_carlo, occurrences, log_probs_monte_carlo)
    """
    t0 = time.time()
    n_samples = sample_multiples * env.n_states
    # Large Monte Carlo run to count how often each terminal state appears
    occurrences = count_occurrences_with_monte_carlo(env, Sampler(estimator=gfn.pf), n_samples=n_samples)
    log_probs_monte_carlo = [
        compute_log_prob_with_monte_carlo(occurrences, terminal_state, n_samples).detach().numpy()
        for terminal_state in tqdm(terminal_states, desc="Evaluating GFNEvalS with monte carlo")
    ]
    # Spearman's rank correlation between estimated log-probabilities and log rewards
    spearman_corr_monte_carlo, _ = spearmanr(log_probs_monte_carlo, log_rewards)
    print(f"Spearman's Rank Correlation (Monte Carlo): {spearman_corr_monte_carlo}. MC sample number: {n_samples}. Runtime: {time.time()-t0} seconds")
    return spearman_corr_monte_carlo, occurrences, log_probs_monte_carlo

In [327]:
# Monte Carlo estimate on the random test set, 20 sampled trajectories per environment state
_, _, _ = evaluate_GFNEvalS_with_monte_carlo(gfn, env, test_terminal_states_random, test_log_rewards_random, sample_multiples=20)

Processing terminal_states: 100%|██████████| 81920/81920 [00:00<00:00, 223622.69it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 35499.82it/s]

Spearman's Rank Correlation (Monte Carlo): 0.41592602767382925. MC sample number: 81920. Runtime: 8.38442587852478 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 8.3844 seconds





In [325]:
# Monte Carlo estimate on the policy-sampled test set; the counts and per-state estimates are kept as globals
spearman_corr_monte_carlo, occurrences, log_probs_monte_carlo = evaluate_GFNEvalS_with_monte_carlo(gfn, env, test_terminal_states_sample, test_log_rewards_sample, sample_multiples=20)

Processing terminal_states: 100%|██████████| 81920/81920 [00:00<00:00, 312232.11it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 38657.18it/s]

Spearman's Rank Correlation (Monte Carlo): 0.8707819195578946. MC sample number: 81920. Runtime: 8.04563570022583 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 8.0457 seconds



