In [1]:
import jax.numpy as jnp
import numpy as np
import optax
import networkx as nx
import pickle
import jax

from tqdm import trange
from numpy.random import default_rng

from dag_gflownet.env import GFlowNetDAGEnv
from dag_gflownet.gflownet import DAGGFlowNet
from dag_gflownet.utils.replay_buffer import ReplayBuffer
from dag_gflownet.utils.factories import get_scorer
from dag_gflownet.utils.gflownet import posterior_estimate
from dag_gflownet.utils.metrics import expected_shd, expected_edges, threshold_metrics
from dag_gflownet.utils import io

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import json

# Run directory written by the training script, and which saved checkpoint
# (training iteration) to evaluate.
folder_path = 'output_ell'
sample_iteration = 100_000

In [3]:
class args_class:
    """Attribute-style view over a dict, so ``args['seed']`` reads as ``args.seed``."""

    def __init__(self, args):
        # vars(self) is the instance __dict__; copy every key in as an attribute.
        vars(self).update(args)


# Hyperparameters recorded by the training run in `arguments.json`,
# exposed as attributes (`args.seed`, `args.lr`, ...).
with open(folder_path + '/arguments.json', 'r') as f:
    args = args_class(json.load(f))

In [4]:
# Seed NumPy and JAX randomness from the seed stored with the training run.
# `subkey` is consumed by the network init; `key` is kept for later sampling.
rng = default_rng(args.seed)
key, subkey = jax.random.split(jax.random.PRNGKey(args.seed))

# Rebuild the scorer (and the data / graph get_scorer returns alongside it)
# and the environment from the saved arguments.
scorer, data, graph = get_scorer(args, rng=rng)
env = GFlowNetDAGEnv(
    num_envs=args.num_envs,
    scorer=scorer,
    num_workers=args.num_workers,
    context=args.mp_context,
)

# Replay buffer sized as in training; its `dummy` batch is used to
# initialize the network parameters.
replay = ReplayBuffer(args.replay_capacity, num_variables=env.num_variables)

# GFlowNet configured with the training-time hyperparameters.
gflownet = DAGGFlowNet(delta=args.delta, update_target_every=args.update_target_every)

# `init` is given the optimizer plus a dummy (adjacency, mask) pair from the
# replay buffer; the returned `params` are initialized from `subkey`.
optimizer = optax.adam(args.lr)
params, state = gflownet.init(subkey, optimizer, replay.dummy['adjacency'], replay.dummy['mask'])

# Exploration schedule as configured for training: linear ramp from 0 to
# 1 - min_exploration over num_iterations // 2 steps, starting after `prefill` steps.
exploration_schedule = jax.jit(
    optax.linear_schedule(
        init_value=jnp.array(0.),
        end_value=jnp.array(1. - args.min_exploration),
        transition_steps=args.num_iterations // 2,
        transition_begin=args.prefill,
    )
)

# Restore the replay buffer saved alongside the checkpoint.
replay.load(f'{folder_path}/replay_buffer.npz')

# Load the trained network parameters for the chosen iteration.
# NOTE(review): passing the loaded tree directly raised
# "Unable to retrieve parameter 'embeddings' for module 'embed'", i.e. the
# tree is not the bare haiku params dict. Presumably the checkpoint was written
# as `io.save(path, params=...)`, which nests the params under a 'params' key
# -- TODO confirm against the training script.
checkpoint = io.load(f'{folder_path}/model_{sample_iteration}.npz')
model_params = checkpoint['params'] if 'params' in checkpoint else checkpoint

# Evaluate the posterior estimate by sampling from the trained GFlowNet.
posterior, _ = posterior_estimate(
    gflownet,
    model_params,
    env,
    key,
    num_samples=args.num_samples_posterior,
    desc='Sampling from posterior'
)

# Save next to the checkpoint it was sampled from. `args.output_folder` comes
# back from JSON as a plain str, so `args.output_folder / ...` raised TypeError
# (and `/` binds tighter than `+`, so the name was mis-built anyway).
np.save(f'{folder_path}/posterior{sample_iteration}.npy', posterior)

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
Sampling from posterior:   0%|          | 0/100000 [00:00<?, ?it/s]


ValueError: Unable to retrieve parameter 'embeddings' for module 'embed' All parameters must be created as part of `init`.