In [2]:
import os
import sys

# Project root = parent of the kernel's current working directory.
# Notebooks have no real __file__, so the path is resolved from the cwd; the
# notebook must therefore be launched from a directory directly under the root.
# (The previous spelling, abspath("__file__" + "/../../"), produced the same
# value only because the literal "__file__" component cancelled against "..".)
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Put the root first on sys.path so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not stack duplicate entries.
if sys.path[:1] != [ROOT_DIR]:
    sys.path.insert(0, ROOT_DIR)

In [3]:
import numpy as np

# --- Experiment 1 configuration: single-vertex spinfoam at spin j ---
spin_j = 3

# env_name is reused by later cells as the sub-directory for generated data.
env_name = f"single vertex spinfoam/j={float(spin_j)}"
batch_size = 16
n_iterations = 100_000

# Load the stored vertex data for this spin.
# NOTE(review): np.load on a genuine .npz archive returns an NpzFile, not an
# ndarray; the arithmetic below only works if this file actually holds a single
# array in .npy format despite its extension -- confirm how it was written.
vertex_path = f"{ROOT_DIR}/data/EPRL_vertices/Python/Dl_20/vertex_j_{float(spin_j)}.npz"
vertex = np.load(vertex_path)

# Rewards are the squared entries, rescaled so that they sum to 1.
sq_ampl = vertex**2
grid_rewards = sq_ampl / sq_ampl.sum()

In [3]:
from src.MCMC.batched_mcmc import MCMCRunner

# MCMC run over the normalised single-vertex reward grid.
# The output directory is keyed by env_name (presumably where the runner
# stores its generated data -- confirm against MCMCRunner).
spinfoam_mcmc_dir = f"{ROOT_DIR}/data/MCMC/{env_name}"

mcmc = MCMCRunner(grid_rewards=grid_rewards)
# The second return value is not used in this cell.
mcmc_chains, _ = mcmc.run_mcmc_chains(
    batch_size=batch_size,
    n_iterations=n_iterations,
    generated_data_dir=spinfoam_mcmc_dir,
)

100%|██████████| 100000/100000 [00:46<00:00, 2162.86it/s]


In [4]:
from src.grid_environments.base import BaseGrid
from src.trainers.trainer import base_train_gfn

# GFN training on the same reward grid as the MCMC run above, using the same
# batch_size and n_iterations.
# The saved output of this cell reports roughly 4h53m for 100000 iterations.
spinfoam_gfn_hyperparams = {
    "hidden_dim": 256,
    "n_hidden_layers": 2,
    "activation_fn": "relu",
    "exploration_rate": 0.0,
    "learning_rate": 0.001,
}

# The second return value is not used in this cell.
terminal_states, _ = base_train_gfn(
    env=BaseGrid(grid_rewards=grid_rewards),
    generated_data_dir=f"{ROOT_DIR}/data/GFN/{env_name}",
    batch_size=batch_size,
    n_iterations=n_iterations,
    **spinfoam_gfn_hyperparams,
)

100%|██████████| 100000/100000 [4:53:27<00:00,  5.68it/s] 


In [4]:
from src.MCMC.MCMC import grid_rewards_2d

# --- Experiment 2 configuration: 2D grid rewards built by grid_rewards_2d ---
# NOTE: env_name, batch_size and n_iterations are rebound here, replacing the
# values set for the single-vertex experiment earlier in the notebook. Cells
# from that experiment must not be re-run after this one without first
# re-running its configuration cell.
grid_len = 64
dimensions = 2
env_name = f"GFN Paper Grid Peaks/grid_length={grid_len}, grid_dim={dimensions}"

batch_size = 16
n_iterations = 200_000

# Reward grid for this experiment; only grid_len is passed to the helper.
gfn_paper_rewards = grid_rewards_2d(grid_len)

In [5]:
from src.MCMC.batched_mcmc import MCMCRunner

# MCMC run over the 2D paper-grid rewards.
# Rebinds `mcmc` and `mcmc_chains` from the single-vertex experiment.
paper_grid_mcmc_dir = f"{ROOT_DIR}/data/MCMC/{env_name}"

mcmc = MCMCRunner(grid_rewards=gfn_paper_rewards)
# The second return value is not used in this cell.
mcmc_chains, _ = mcmc.run_mcmc_chains(
    batch_size=batch_size,
    n_iterations=n_iterations,
    generated_data_dir=paper_grid_mcmc_dir,
)

  0%|          | 0/200000 [00:00<?, ?it/s]

100%|██████████| 200000/200000 [02:52<00:00, 1160.94it/s]


In [6]:
from src.grid_environments.base import BaseGrid
from src.trainers.trainer import base_train_gfn

# GFN training on the 2D paper-grid rewards. Same network settings as the
# single-vertex run except for the learning rate (0.0001 here).
# NOTE(review): the saved output for this cell ends at 32676/200000
# iterations -- confirm whether this run was ever completed.
paper_grid_gfn_hyperparams = {
    "hidden_dim": 256,
    "n_hidden_layers": 2,
    "activation_fn": "relu",
    "exploration_rate": 0.0,
    "learning_rate": 0.0001,
}

# Rebinds `terminal_states`; the second return value is not used in this cell.
terminal_states, _ = base_train_gfn(
    env=BaseGrid(grid_rewards=gfn_paper_rewards),
    generated_data_dir=f"{ROOT_DIR}/data/GFN/{env_name}",
    batch_size=batch_size,
    n_iterations=n_iterations,
    **paper_grid_gfn_hyperparams,
)

 16%|█▋        | 32676/200000 [3:29:04<61:29:43,  1.32s/it]