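# Experiments

Hyperparameter sweeps for the GFlowNet trainer: each cell launches `train.py` runs with command-line overrides and logs them to the `GFlowNet-Experiments` Weights & Biases project.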


In [17]:
import itertools

In [18]:
# === CONFIG SPACE ===
# Random seeds swept by every experiment below.
SEEDS = [0, 42, 100]
# Trainer back-ends; each value is forwarded as `trainer.mode=<framework>`.
FRAMEWORKS = ["legacy", "jax"]

In [19]:
# === BASE PARAMETERS ===
# Defaults shared by the sweeps below; individual sweeps override single values.
# Each constant is interpolated into the train.py override named on its right.

# Environment / proxy
BASE_ITERATIONS = 5_000          # gflownet.optimizer.n_train_steps
BASE_PROXY = "box/corners"       # proxy
BASE_ENV = "grid"                # env (the JAX sweep cells rebind it to "grid_jax")
BASE_LENGTH = 16                 # env.length
BASE_DIM = 2                     # env.n_dim

# Optimisation
BASE_RANDOM_ACTION_PROB = 0.0    # gflownet.random_action_prob
BASE_BATCH_SIZE = 100            # gflownet.optimizer.batch_size.forward
BASE_LR = 1e-4                   # gflownet.optimizer.lr (renders as "0.0001")
BASE_Z_DIM = 16                  # gflownet.optimizer.z_dim
BASE_LR_MULT = 100               # gflownet.optimizer.lr_z_mult

# Evaluation
EVAL_PERIOD = 500                # evaluator.period
EVAL_N = 1_000                   # evaluator.n
EVAL_CHECKPOINT = 500            # evaluator.checkpoints_period

# Neural Network Architecture

In [15]:
# Default MLP shape, held fixed while the other dimension is swept.
BASE_N_HID = 128     # policy.forward.n_hid
BASE_N_LAYERS = 2    # policy.forward.n_layers

# === CONFIG SPACE ===
N_HID = [64, 128, 256]    # hidden widths for the N_HID sweep
N_LAYERS = [2, 4, 8]      # depths for the N_LAYERS sweep
EXP_TYPE = "mlp"          # experiment label used in the `exp_` tag and run names

In [None]:
# N_HID EXPERIMENT
for seed, framework, hid in itertools.product(SEEDS, FRAMEWORKS, N_HID):
  
  if framework == 'legacy':
    continue
  else:
    framework_tag = 'jax-full-eval1'
    BASE_ENV = 'grid_jax'
  
  run_name = f"run-{EXP_TYPE}-seed{seed}-{framework_tag}-hid{hid}-layers{BASE_N_LAYERS}"
  print(f"\n🚀 Launching: {run_name}\n")

  tags = [str('exp_' + EXP_TYPE), str('seed_' + str(seed)), str('framework_' + framework_tag), str('nhid_' + str(hid))]
  tags_str = f"\"['{'\',\''.join(tags)}']\""

  cmd = f"""
  python train.py \
      trainer.mode={framework} \
      buffer.test.type=all \
      env={BASE_ENV} \
      env.n_dim={BASE_DIM} \
      env.length={BASE_LENGTH} \
      proxy={BASE_PROXY} \
      gflownet.random_action_prob={BASE_RANDOM_ACTION_PROB} \
      gflownet.optimizer.batch_size.forward={BASE_BATCH_SIZE} \
      gflownet.optimizer.lr={BASE_LR} \
      gflownet.optimizer.z_dim={BASE_Z_DIM} \
      gflownet.optimizer.lr_z_mult={BASE_LR_MULT} \
      gflownet.optimizer.n_train_steps={BASE_ITERATIONS} \
      logger.do.online=true \
      logger.project_name=GFlowNet-Experiments \
      logger.run_name={run_name} \
      logger.tags={tags_str} \
      seed={seed} \
      policy.forward.n_hid={hid} \
      policy.forward.n_layers={BASE_N_LAYERS} \
      evaluator.first_it=true \
      evaluator.period={EVAL_PERIOD} \
      evaluator.n={EVAL_N} \
      evaluator.checkpoints_period={EVAL_CHECKPOINT}
  """

  cmd = " ".join(cmd.split())
  !{cmd}

In [None]:
# N_LAYERS EXPERIMENT
for seed, framework, layers in itertools.product(SEEDS, FRAMEWORKS, N_LAYERS):
    
    if framework == 'legacy':
        continue
    else:
        framework_tag = 'jax-full-eval1'
        BASE_ENV = 'grid_jax'

    run_name = f"run-seed{seed}-{framework_tag}-hid{BASE_N_HID}-layers{layers}"
    print(f"\n🚀 Launching: {run_name}\n")
    
    tags = [str('exp_' + EXP_TYPE), str('seed_' + str(seed)), str('framework_' + framework_tag), str('nlayers_' + str(layers))]
    tags_str = f"\"['{'\',\''.join(tags)}']\""

    cmd = f"""
    python train.py \
        trainer.mode={framework} \
        buffer.test.type=all \
        env={BASE_ENV} \
        env.n_dim={BASE_DIM} \
        env.length={BASE_LENGTH} \
        proxy={BASE_PROXY} \
        gflownet.random_action_prob={BASE_RANDOM_ACTION_PROB} \
        gflownet.optimizer.batch_size.forward={BASE_BATCH_SIZE} \
        gflownet.optimizer.lr={BASE_LR} \
        gflownet.optimizer.z_dim={BASE_Z_DIM} \
        gflownet.optimizer.lr_z_mult={BASE_LR_MULT} \
        gflownet.optimizer.n_train_steps={BASE_ITERATIONS} \
        logger.do.online=true \
        logger.project_name=GFlowNet-Experiments \
        logger.run_name={run_name} \
        logger.tags={tags_str} \
        seed={seed} \
        policy.forward.n_hid={BASE_N_HID} \
        policy.forward.n_layers={layers} \
        evaluator.first_it=true \
        evaluator.period={EVAL_PERIOD} \
        evaluator.n={EVAL_N} \
        evaluator.checkpoints_period={EVAL_CHECKPOINT}
    """

    cmd = " ".join(cmd.split())
    !{cmd}

# Environment Experiments

In [21]:
# === CONFIG SPACE ===
LENGTHS = [16, 32, 64]    # env.length values for the LENGTHS sweep
DIMS = [2, 3, 4]          # env.n_dim values for the DIMS sweep
EXP_TYPE = "env"          # experiment label used in the `exp_` tag and run names

In [22]:
# DIMS EXPERIMENT
for seed, framework, dim in itertools.product(SEEDS, FRAMEWORKS, DIMS):

    if framework == 'legacy':
        continue
    else:
        framework_tag = 'jax-full-eval1'
        BASE_ENV = 'grid_jax'
    
    run_name = f"run-{EXP_TYPE}-seed{seed}-{framework_tag}-dim{dim}-len{BASE_LENGTH}"
    print(f"\n🚀 Launching: {run_name}\n")
    
    tags = [str('exp_' + EXP_TYPE), str('seed_' + str(seed)), str('framework_' + framework_tag), str('dim_' + str(dim))]
    tags_str = f"\"['{'\',\''.join(tags)}']\""

    cmd = f"""
    python train.py \
        trainer.mode={framework} \
        buffer.test.type=all \
        env={BASE_ENV} \
        env.n_dim={dim} \
        env.length={BASE_LENGTH} \
        proxy={BASE_PROXY} \
        gflownet.random_action_prob={BASE_RANDOM_ACTION_PROB} \
        gflownet.optimizer.batch_size.forward={BASE_BATCH_SIZE} \
        gflownet.optimizer.lr={BASE_LR} \
        gflownet.optimizer.z_dim={BASE_Z_DIM} \
        gflownet.optimizer.lr_z_mult={BASE_LR_MULT} \
        gflownet.optimizer.n_train_steps={BASE_ITERATIONS} \
        logger.do.online=true \
        logger.project_name=GFlowNet-Experiments \
        logger.run_name={run_name} \
        logger.tags={tags_str} \
        seed={seed} \
        evaluator.first_it=true \
        evaluator.period={EVAL_PERIOD} \
        evaluator.n={EVAL_N} \
        evaluator.checkpoints_period={EVAL_CHECKPOINT}
    """

    cmd = " ".join(cmd.split())
    !{cmd}


🚀 Launching: run-env-seed0-jax-full-eval1-dim2-len16


Working directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet
Logging directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/logs/local/2025-12-10_11-51-24_549503

Initializing JAX components from config...
[34m[1mwandb[0m: Currently logged in as: [33mthom-mousseau[0m ([33mfatty_data[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: [38;5;178m⢿[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣻[0m Waiting for wandb.init()...
[34m[1mwandb[0m: Tracking run with wandb version 0.22.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/wandb/run-20251210_115126-rd7n82c3[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mrun-env-seed0-jax-full-eval1-dim2-len16 10/12-11:51:24[0m
[34m

In [23]:
#LENGTHS EXPERIMENT
for seed, framework, length in itertools.product(SEEDS, FRAMEWORKS, LENGTHS):
    
    if framework == 'legacy':
        continue
    else:
        framework_tag = 'jax-full-eval1'
        BASE_ENV = 'grid_jax'
        
    run_name = f"run-{EXP_TYPE}-seed{seed}-{framework_tag}-len{length}-dim{BASE_DIM}"
    print(f"\n🚀 Launching: {run_name}\n")
    
    tags = [str('exp_' + EXP_TYPE), str('seed_' + str(seed)), str('framework_' + framework_tag), str('length_' + str(length))]
    tags_str = f"\"['{'\',\''.join(tags)}']\""

    cmd = f"""
    python train.py \
        trainer.mode={framework} \
        buffer.test.type=all \
        env={BASE_ENV} \
        env.n_dim={BASE_DIM} \
        env.length={length} \
        proxy={BASE_PROXY} \
        gflownet.random_action_prob={BASE_RANDOM_ACTION_PROB} \
        gflownet.optimizer.batch_size.forward={BASE_BATCH_SIZE} \
        gflownet.optimizer.lr={BASE_LR} \
        gflownet.optimizer.z_dim={BASE_Z_DIM} \
        gflownet.optimizer.lr_z_mult={BASE_LR_MULT} \
        gflownet.optimizer.n_train_steps={BASE_ITERATIONS} \
        logger.do.online=true \
        logger.project_name=GFlowNet-Experiments \
        logger.run_name={run_name} \
        logger.tags={tags_str} \
        seed={seed} \
        evaluator.first_it=true \
        evaluator.period={EVAL_PERIOD} \
        evaluator.n={EVAL_N} \
        evaluator.checkpoints_period={EVAL_CHECKPOINT}
    """

    cmd = " ".join(cmd.split())
    !{cmd}


🚀 Launching: run-env-seed0-jax-full-eval1-len16-dim2


Working directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet
Logging directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/logs/local/2025-12-10_12-06-51_513750

Initializing JAX components from config...
[34m[1mwandb[0m: Currently logged in as: [33mthom-mousseau[0m ([33mfatty_data[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: [38;5;178m⢿[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣻[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣽[0m setting up run 81vdkmqx (0.3s)
[34m[1mwandb[0m: [38;5;178m⣾[0m setting up run 81vdkmqx (0.3s)
[34m[1mwandb[0m: [38;5;178m⣷[0m setting up run 81vdkmqx (0.3s)
[34m[1mwandb[0m: Tracking run with wandb version 0.22.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/wandb

In [9]:
#Scrabble and Tetris
envs = {
    # 'scrabble': 'scrabble/jay',
    'tetris': 'tetris/simple'
}

for seed, framework, env_key in itertools.product(SEEDS, FRAMEWORKS, envs.keys()):
    if framework == 'legacy' and seed == 0:
        continue
    run_name = f"run-{EXP_TYPE}-seed{seed}-{framework}-env{env_key}"
    print(f"\n🚀 Launching: {run_name}\n")
    
    tags = [str('exp_' + EXP_TYPE), str('seed_' + str(seed)), str('framework_' + framework), str('env_' + str(env_key))]
    tags_str = f"\"['{'\',\''.join(tags)}']\""

    cmd = f"""
    python train.py \
        trainer.mode={framework} \
        +experiments={envs[env_key]} \
        logger.do.online=true \
        logger.project_name=GFlowNet-Experiments \
        logger.run_name={run_name} \
        logger.tags={tags_str} \
        seed={seed} \
        evaluator.first_it=true \
        evaluator.period={EVAL_PERIOD} \
        evaluator.n={EVAL_N} \
        evaluator.checkpoints_period={EVAL_CHECKPOINT}
    """

    cmd = " ".join(cmd.split())
    !{cmd}


🚀 Launching: run-env-seed0-jax-envtetris


Working directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet
Logging directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/logs/local/2025-11-24_13-46-15_407852

[34m[1mwandb[0m: Currently logged in as: [33mthom-mousseau[0m ([33mfatty_data[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: [38;5;178m⢿[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣻[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣽[0m Waiting for wandb.init()...
[34m[1mwandb[0m: Tracking run with wandb version 0.22.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/wandb/run-20251124_134617-0vix9kmt[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mrun-env-seed0-jax-envtetris 24/11-13:46:15[0m
[34m[

# Reduced padding experiments

In [18]:

EXP_TYPE = "env"
SEEDS = [42] # [ [100] [0]
# length = {16: 20}
# length = {32: 40}
length = {64: 80}

for seed in SEEDS:
    run_name = f"run-{EXP_TYPE}-seed{seed}-jax-padding-{str(list(length.keys())[0])}"
    print(f"\n🚀 Launching: {run_name}\n")

    tags = [str('exp_' + EXP_TYPE), str('seed_' + str(seed)), str('framework_jax'), str('padding_' + str(length[list(length.keys())[0]]))]
    tags_str = f"\"['{'\',\''.join(tags)}']\""

    cmd = f"""
    python train.py \
        trainer.mode=jax \
        buffer.test.type=all \
        env={BASE_ENV} \
        env.n_dim={BASE_DIM} \
        env.length={list(length.keys())[0]} \
        proxy={BASE_PROXY} \
        gflownet.random_action_prob={BASE_RANDOM_ACTION_PROB} \
        gflownet.optimizer.batch_size.forward={BASE_BATCH_SIZE} \
        gflownet.optimizer.lr={BASE_LR} \
        gflownet.optimizer.z_dim={BASE_Z_DIM} \
        gflownet.optimizer.lr_z_mult={BASE_LR_MULT} \
        gflownet.optimizer.n_train_steps={BASE_ITERATIONS} \
        logger.do.online=true \
        logger.project_name=GFlowNet-Experiments \
        logger.run_name={run_name} \
        logger.tags={tags_str} \
        seed={seed} \
        evaluator.first_it=true \
        evaluator.period={EVAL_PERIOD} \
        evaluator.n={EVAL_N} \
        evaluator.checkpoints_period={EVAL_CHECKPOINT}
    """

    cmd = " ".join(cmd.split())
    !{cmd}
    


🚀 Launching: run-env-seed42-jax-padding-64



133052.13s - pydevd: Sending message related to process being replaced timed-out after 5 seconds



Working directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet
Logging directory of this run: /Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/logs/local/2025-11-25_23-07-43_384004

[34m[1mwandb[0m: Currently logged in as: [33mthom-mousseau[0m ([33mfatty_data[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: [38;5;178m⢿[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣻[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣽[0m Waiting for wandb.init()...
[34m[1mwandb[0m: [38;5;178m⣾[0m setting up run l04sa6rm (0.3s)
[34m[1mwandb[0m: [38;5;178m⣷[0m setting up run l04sa6rm (0.3s)
[34m[1mwandb[0m: Tracking run with wandb version 0.22.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/Users/thomasmousseau/School/Maitrise/GFlowNet/gflownet/wandb/run-20251125_230745-l04sa6rm[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing