In [1]:
%%bash
cd ../torchgfn
pip install .

Processing /Users/erostrate9/Desktop/CSI5340 DL/Project/code/GFNEval/torchgfn
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: torchgfn
  Building wheel for torchgfn (pyproject.toml): started
  Building wheel for torchgfn (pyproject.toml): finished with status 'done'
  Created wheel for torchgfn: filename=torchgfn-1.1.1-py3-none-any.whl size=82819 sha256=0f5154dc9daaf72191b9a400de50b33bc2fb28062c98aa813fca48a70274e129
  Stored in directory: /private/var/folders/c_/9pzrss116732p7dxch3kn_bc0000gn/T/pip-ephem-wheel-cache-s6ns00hc/wheels/56/de/11/edbaf478c4bdb3bf4d2dadfda48c78d0790413f2f66eee7a21
Successfully built torchgfn
Installing collected packages: torchgfn
  Attemptin

In [1]:
import torch
import numpy as np
from scipy.stats import spearmanr
from tqdm import tqdm
from gfn.env import DiscreteEnv
from gfn.gflownet import GFlowNet, TBGFlowNet, SubTBGFlowNet, FMGFlowNet, DBGFlowNet
from gfn.gym import HyperGrid2, HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils.modules import MLP
from gfn.states import States, DiscreteStates
from gfn.utils.evaluation import get_random_test_set, get_sampled_test_set, evaluate_GFNEvalS, evaluate_GFNEvalS_with_monte_carlo

# Demo

In [2]:
# 0 - Reproducibility and device selection
# Seed the global torch / NumPy generators so that the torch.randint call below
# (and any torch- or NumPy-based sampling done by the library) is repeatable
# under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# NOTE(review): `device` is kept because a later cell may read it, but this cell
# runs on the CPU only: the environment is created with device_str='cpu' and the
# networks are never moved to another device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# 1 - Define the environment
# env = HyperGrid(ndim=4, height=8, R0=0.01)
env = HyperGrid2(
    ndim=4,
    height=8,
    ncenters=4,
    # Deterministic now that torch is seeded above.
    seed=torch.randint(0, 10000, (1,)).item(),
    device_str='cpu',
)

# 2 - Define the neural network modules
# The backward network reuses the forward network's trunk (shared parameters)
# and has one output fewer than the forward network.
module_PF = MLP(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions)
module_PB = MLP(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1, trunk=module_PF.trunk)

# 3 - Define the estimators
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - Define the GFlowNet
gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator)

# 5 - Define the sampler and optimizer
# The logZ parameters get their own parameter group with a larger learning rate
# (1e-1) than the policy networks (1e-3).
sampler = Sampler(estimator=pf_estimator)
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - Train the GFlowNet
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n=16)
    optimizer.zero_grad()
    # The loss is used where it was computed. It was previously copied to
    # `device`, which on a CUDA machine moved only this scalar to the GPU while
    # the model and environment stayed on the CPU.
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    # Refresh the progress-bar readout every 25 iterations.
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})

100%|██████████| 1000/1000 [00:12<00:00, 82.00it/s, loss=0.0298]


In [3]:
# Build two evaluation sets of the same size: one from get_sampled_test_set
# (which is given the trained gfn) and one from get_random_test_set (which is
# given the environment only). Each call returns a (states, rewards) pair.
n_tests = 100
test_states_sample, test_rewards_sample = get_sampled_test_set(
    gfn,
    env,
    n=n_tests,
)
test_states_random, test_rewards_random = get_random_test_set(
    env,
    n=n_tests,
)

## Verify numerical correctness using Monte Carlo (MC) estimation

In [7]:
# Monte Carlo budget: 1x the number of environment states.
n_samples = 1 * env.n_states
# Run both evaluators on each test set, random set first.
for banner, eval_states, eval_rewards in (
    ('------------Random test set------------', test_states_random, test_rewards_random),
    ('------------Sampled test set------------', test_states_sample, test_rewards_sample),
):
    print(banner)
    # Both evaluators return three values; they are discarded because the
    # functions print their own summary lines.
    _, _, _ = evaluate_GFNEvalS(gfn, env, eval_states, eval_rewards)
    _, _, _ = evaluate_GFNEvalS_with_monte_carlo(
        gfn, env, eval_states, eval_rewards, n_samples=n_samples
    )


------------Random test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 112.31it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.9585478547854784. Runtime: 0.8987269401550293 seconds.
Function 'evaluate_GFNEvalS' executed in 0.8991 seconds


Processing terminal_states: 100%|██████████| 4096/4096 [00:00<00:00, 295506.63it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 40768.90it/s]


Spearman's Rank Correlation (Monte Carlo): 0.6734990475698972. MC sample number: 4096. Runtime: 0.4656798839569092 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 0.4657 seconds
------------Sampled test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 170.12it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.8144907707448213. Runtime: 0.5948591232299805 seconds.
Function 'evaluate_GFNEvalS' executed in 0.5953 seconds


Processing terminal_states: 100%|██████████| 4096/4096 [00:00<00:00, 305164.91it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 33759.69it/s]

Spearman's Rank Correlation (Monte Carlo): 0.8152987444729859. MC sample number: 4096. Runtime: 0.5411820411682129 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 0.5412 seconds





In [6]:
# Monte Carlo budget: 5x the number of environment states.
n_samples = 5 * env.n_states
# Run both evaluators on each test set, random set first.
for banner, eval_states, eval_rewards in (
    ('------------Random test set------------', test_states_random, test_rewards_random),
    ('------------Sampled test set------------', test_states_sample, test_rewards_sample),
):
    print(banner)
    # Both evaluators return three values; they are discarded because the
    # functions print their own summary lines.
    _, _, _ = evaluate_GFNEvalS(gfn, env, eval_states, eval_rewards)
    _, _, _ = evaluate_GFNEvalS_with_monte_carlo(
        gfn, env, eval_states, eval_rewards, n_samples=n_samples
    )


------------Random test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 114.78it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.9585478547854784. Runtime: 0.881892204284668 seconds.
Function 'evaluate_GFNEvalS' executed in 0.8824 seconds


Processing terminal_states: 100%|██████████| 20480/20480 [00:00<00:00, 147699.20it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 34402.10it/s]


Spearman's Rank Correlation (Monte Carlo): 0.8808481087198067. MC sample number: 20480. Runtime: 2.643099308013916 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 2.6431 seconds
------------Sampled test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 159.10it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.8144907707448213. Runtime: 0.6345369815826416 seconds.
Function 'evaluate_GFNEvalS' executed in 0.6350 seconds


Processing terminal_states: 100%|██████████| 20480/20480 [00:00<00:00, 159636.91it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 37536.28it/s]

Spearman's Rank Correlation (Monte Carlo): 0.9585927172335587. MC sample number: 20480. Runtime: 2.6412699222564697 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 2.6413 seconds





In [8]:
# Monte Carlo budget: 10x the number of environment states.
n_samples = 10 * env.n_states
# Run both evaluators on each test set, random set first.
for banner, eval_states, eval_rewards in (
    ('------------Random test set------------', test_states_random, test_rewards_random),
    ('------------Sampled test set------------', test_states_sample, test_rewards_sample),
):
    print(banner)
    # Both evaluators return three values; they are discarded because the
    # functions print their own summary lines.
    _, _, _ = evaluate_GFNEvalS(gfn, env, eval_states, eval_rewards)
    _, _, _ = evaluate_GFNEvalS_with_monte_carlo(
        gfn, env, eval_states, eval_rewards, n_samples=n_samples
    )


------------Random test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 120.09it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.9585478547854784. Runtime: 0.8431618213653564 seconds.
Function 'evaluate_GFNEvalS' executed in 0.8436 seconds


Processing terminal_states: 100%|██████████| 40960/40960 [00:00<00:00, 282864.65it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 38969.66it/s]


Spearman's Rank Correlation (Monte Carlo): 0.8999019098914222. MC sample number: 40960. Runtime: 5.286223888397217 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 5.2862 seconds
------------Sampled test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 160.49it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.8144907707448213. Runtime: 0.6285898685455322 seconds.
Function 'evaluate_GFNEvalS' executed in 0.6291 seconds


Processing terminal_states: 100%|██████████| 40960/40960 [00:00<00:00, 315467.75it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 34433.17it/s]

Spearman's Rank Correlation (Monte Carlo): 0.9602812436398143. MC sample number: 40960. Runtime: 5.014269113540649 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 5.0143 seconds





In [10]:
# Monte Carlo budget: 20x the number of environment states.
n_samples = 20 * env.n_states
# Run both evaluators on each test set, random set first.
for banner, eval_states, eval_rewards in (
    ('------------Random test set------------', test_states_random, test_rewards_random),
    ('------------Sampled test set------------', test_states_sample, test_rewards_sample),
):
    print(banner)
    # Both evaluators return three values; they are discarded because the
    # functions print their own summary lines.
    _, _, _ = evaluate_GFNEvalS(gfn, env, eval_states, eval_rewards)
    _, _, _ = evaluate_GFNEvalS_with_monte_carlo(
        gfn, env, eval_states, eval_rewards, n_samples=n_samples
    )


------------Random test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 119.52it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.9585478547854784. Runtime: 0.8468070030212402 seconds.
Function 'evaluate_GFNEvalS' executed in 0.8473 seconds


Processing terminal_states: 100%|██████████| 81920/81920 [00:00<00:00, 317863.18it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 37593.47it/s]


Spearman's Rank Correlation (Monte Carlo): 0.9584972891457396. MC sample number: 81920. Runtime: 10.354176044464111 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 10.3542 seconds
------------Sampled test set------------


Evaluating test set...: 100%|██████████| 100/100 [00:00<00:00, 174.00it/s]


Spearman's Rank Correlation (Modified GFNEvalS, including termination actions): 0.8144907707448213. Runtime: 0.5798079967498779 seconds.
Function 'evaluate_GFNEvalS' executed in 0.5803 seconds


Processing terminal_states: 100%|██████████| 81920/81920 [00:00<00:00, 322136.64it/s]
Evaluating GFNEvalS with monte carlo: 100%|██████████| 100/100 [00:00<00:00, 40532.51it/s]

Spearman's Rank Correlation (Monte Carlo): 0.9704449994404256. MC sample number: 81920. Runtime: 10.310059070587158 seconds
Function 'evaluate_GFNEvalS_with_monte_carlo' executed in 10.3101 seconds



