In [1]:
import GPy
import numpy as np
import matplotlib.pyplot as plt
import utils
from functools import partial
from utils import KernelFunction, KernelEnvironment, log_likelihood_reward
from utils import plot_kernel_function, compare_kernels
from gflownet import GFlowNet 
import torch.nn.functional as F
from torch.distributions import Categorical
import torch
from utils import ForwardPolicy, BackwardPolicy
import random
from utils import train

from evaluation import create_random_kernel
import itertools
from functools import partial

from evaluation import calculate_l1_distance

import pandas as pd

from copy import deepcopy

In [2]:
!pip install pandas




In [3]:


def create_env(batch_size=64, max_trajectory_length=None, log_reward=None):
    """Build a fresh ``KernelEnvironment``.

    Args:
        batch_size: Passed through as the environment's ``batch_size``.
        max_trajectory_length: Passed through as ``max_trajectory_length``. When
            ``None`` (the default), the notebook-level ``MAX_LEN`` is read at call
            time — it is assigned in later cells, so it cannot be bound here.
        log_reward: Passed through as ``log_reward``. When ``None`` (the default),
            the notebook-level ``log_reward_fn`` is read at call time.

    Returns:
        A new ``KernelEnvironment`` instance.
    """
    if max_trajectory_length is None:
        max_trajectory_length = MAX_LEN
    if log_reward is None:
        log_reward = log_reward_fn
    return KernelEnvironment(
        batch_size=batch_size,
        max_trajectory_length=max_trajectory_length,
        log_reward=log_reward,
    )

In [4]:
# Draw random ground-truth kernels until one yields a non-negative log marginal
# likelihood on its own training sample.
ll = float('-inf')
while ll < 0:
    true_kernel = create_random_kernel()
    # Two independent samples from the same ground-truth kernel: train and held-out test.
    X, Y, true_kernel_str = utils.generate_gp_data(true_kernel, input_dim=1, n_points=30, noise_var=1e-4)
    X_test, Y_test, _ = utils.generate_gp_data(true_kernel, input_dim=1, n_points=30, noise_var=1e-4)
    ll = utils.evaluate_likelihood(true_kernel, X, Y, runtime=False)

# Reuse the training value the loop already checked instead of recomputing it.
ll_test = utils.evaluate_likelihood(true_kernel, X_test, Y_test, runtime=False)

print("True Kernel:", true_kernel_str, "Log Marginal Likelihood (train):", ll)
print("True Kernel:", true_kernel_str, "Log Marginal Likelihood (test):", ll_test)

  -> Randomizing 'Periodic' params...
  -> Randomizing 'RBF' params...
  -> Randomizing 'RBF' params...


True Kernel: ((Periodic({'period': 1.396, 'variance': 0.799, 'lengthscale': 1.186}) + RBF({'lengthscale': 0.907, 'variance': 1.189})) + RBF({'lengthscale': 0.536, 'variance': 0.619})) Log Marginal Likelihood: 23.196005504503297
True Kernel: ((Periodic({'period': 1.396, 'variance': 0.799, 'lengthscale': 1.186}) + RBF({'lengthscale': 0.907, 'variance': 1.189})) + RBF({'lengthscale': 0.536, 'variance': 0.619})) Log Marginal Likelihood: 19.06329367331628


In [5]:
# Bind the training data as the first two positional arguments of the reward, so the
# resulting callable only needs the remaining argument(s) supplied by the environment.
log_reward_fn = partial(
    log_likelihood_reward,
    X,
    Y,
)

In [6]:
# --- 1. Define the Hyperparameter Grid ---
param_grid = {
    'lr': [1e-4, 1e-3, 1e-2],
    'BATCH_SIZE': [16, 64, 256],
    'criterion': ['db', 'tb', 'subtb', 'cb'],
    'epsilon': [0.5],  # Initial epsilon for the forward policy
    'min_eps': [1e-2],   # Minimum epsilon for the scheduler
    'clamp_g': [1.0] # Clamping value for the gradient
}

# Fixed (non-searched) settings. They are loop-invariant, so set them once here.
# `MAX_LEN` must be a notebook-level global because `create_env` reads it.
epochs = 100
MAX_LEN = 4

# --- 2. Prepare for Iteration ---
# Generate a list of all hyperparameter combinations (Cartesian product of the grid).
keys, values = zip(*param_grid.items())
param_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

# Best result seen so far for each criterion. Every entry has the same keys from the
# start ('model' included), so later cells can rely on the schema.
best_results_per_criterion = {
    crit: {'best_l1': float('inf'), 'best_params': None, 'model': None}
    for crit in param_grid['criterion']
}

# --- 3. Run the Grid Search ---
print(f"Starting grid search with {len(param_combinations)} combinations...")

for i, params in enumerate(param_combinations):
    print(f"\n--- Combination {i+1}/{len(param_combinations)} ---")
    print(f"Parameters: {params}")

    # Unpack current combination of parameters
    lr = params['lr']
    BATCH_SIZE = params['BATCH_SIZE']
    criterion = params['criterion']
    initial_epsilon = params['epsilon']
    min_eps = params['min_eps']
    clamp_g = params['clamp_g']

    # --- Model Initialization ---
    # This env is only used to read the size of the action space.
    env = create_env()

    forward_model = ForwardPolicy(
        input_dim=MAX_LEN,
        output_dim=env.action_space_size,
        epsilon=initial_epsilon
    )
    backward_model = BackwardPolicy()

    gflownet = GFlowNet(
        forward_flow=forward_model,
        backward_flow=backward_model,
        criterion=criterion
    )

    # NOTE(review): nothing visible in this notebook reads `gflownet.lr` (`train` gets
    # `lr` explicitly); kept in case project code relies on it — confirm.
    gflownet.lr = lr

    # --- Training ---
    trained_gflownet, losses = train(
        gflownet=gflownet,
        create_env=create_env,
        epochs=epochs,
        batch_size=BATCH_SIZE,
        lr=lr,
        min_eps=min_eps,
        clamp_g=clamp_g,
        use_scheduler=True
    )

    # --- Evaluation and Logging ---
    # Score the model returned by `train` — the same object that is stored below.
    l1 = calculate_l1_distance(trained_gflownet.forward_flow, KernelEnvironment, MAX_LEN, X, Y)
    print(f"    => Result: L1 distance = {l1:.4f}")

    # Check if this is the new best result *for this specific criterion*
    best_entry = best_results_per_criterion[criterion]
    if l1 < best_entry['best_l1']:
        print(f"    ✨ New best L1 for criterion '{criterion}': {l1:.4f} ✨")
        best_entry['best_l1'] = l1
        best_entry['best_params'] = params
        best_entry['model'] = deepcopy(trained_gflownet)


# --- 4. Final Results Table ---
print("\n--- Grid Search Complete ---")
print("Best results per criterion:")

# Prepare data for the pandas DataFrame
table_data = []
for criterion, results in best_results_per_criterion.items():
    if results['best_params'] is not None:  # Skip criteria that never produced a result
        row = {
            'Criterion': criterion,
            'Best L1 Distance': f"{results['best_l1']:.4f}",
            'Learning Rate': results['best_params']['lr'],
            'Batch Size': results['best_params']['BATCH_SIZE'],
            'Clamp Value': results['best_params']['clamp_g']
        }
        table_data.append(row)

if table_data:
    # Create and display the DataFrame
    df = pd.DataFrame(table_data)
    print(df.to_string(index=False))
else:
    print("No results were recorded.")

Starting grid search with 36 combinations...

--- Combination 1/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 16, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:52<00:00,  1.90it/s, loss=8.13]


    => Result: L1 distance = 0.7846
    ✨ New best L1 for criterion 'db': 0.7846 ✨

--- Combination 2/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 16, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:01<00:00, 68.86it/s, loss=8.85]


    => Result: L1 distance = 0.9419
    ✨ New best L1 for criterion 'tb': 0.9419 ✨

--- Combination 3/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 16, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 33.25it/s, loss=43.5]


    => Result: L1 distance = 1.0167
    ✨ New best L1 for criterion 'subtb': 1.0167 ✨

--- Combination 4/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 16, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:01<00:00, 59.56it/s, loss=4.34]


    => Result: L1 distance = 0.9243
    ✨ New best L1 for criterion 'cb': 0.9243 ✨

--- Combination 5/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 64, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:02<00:00, 38.86it/s, loss=8.49]


    => Result: L1 distance = 0.8948

--- Combination 6/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 64, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 29.77it/s, loss=8.28]


    => Result: L1 distance = 1.0369

--- Combination 7/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 64, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 24.35it/s, loss=44.3]


    => Result: L1 distance = 0.8129
    ✨ New best L1 for criterion 'subtb': 0.8129 ✨

--- Combination 8/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 64, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 25.03it/s, loss=5.64]


    => Result: L1 distance = 0.6953
    ✨ New best L1 for criterion 'cb': 0.6953 ✨

--- Combination 9/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 256, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:08<00:00, 11.52it/s, loss=8.6]


    => Result: L1 distance = 0.8893

--- Combination 10/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 256, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:08<00:00, 11.12it/s, loss=6.11]


    => Result: L1 distance = 0.9220
    ✨ New best L1 for criterion 'tb': 0.9220 ✨

--- Combination 11/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 256, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:09<00:00, 10.70it/s, loss=50.9]


    => Result: L1 distance = 0.9366

--- Combination 12/36 ---
Parameters: {'lr': 0.0001, 'BATCH_SIZE': 256, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:09<00:00, 10.67it/s, loss=10.8]


    => Result: L1 distance = 0.6863
    ✨ New best L1 for criterion 'cb': 0.6863 ✨

--- Combination 13/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 16, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 32.70it/s, loss=7.36]


    => Result: L1 distance = 0.6132
    ✨ New best L1 for criterion 'db': 0.6132 ✨

--- Combination 14/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 16, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 33.23it/s, loss=1.38]


    => Result: L1 distance = 0.6174
    ✨ New best L1 for criterion 'tb': 0.6174 ✨

--- Combination 15/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 16, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:02<00:00, 47.62it/s, loss=40] 


    => Result: L1 distance = 0.8515

--- Combination 16/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 16, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 31.00it/s, loss=7.35]


    => Result: L1 distance = 0.6942

--- Combination 17/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 64, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 21.10it/s, loss=8.33]


    => Result: L1 distance = 0.7304

--- Combination 18/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 64, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 20.41it/s, loss=4.03]


    => Result: L1 distance = 0.5963
    ✨ New best L1 for criterion 'tb': 0.5963 ✨

--- Combination 19/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 64, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:05<00:00, 16.93it/s, loss=41.6]


    => Result: L1 distance = 0.9710

--- Combination 20/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 64, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:05<00:00, 18.65it/s, loss=6.07]


    => Result: L1 distance = 0.6598
    ✨ New best L1 for criterion 'cb': 0.6598 ✨

--- Combination 21/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 256, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:11<00:00,  8.81it/s, loss=8.16]


    => Result: L1 distance = 0.7104

--- Combination 22/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 256, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:10<00:00,  9.64it/s, loss=2.51]


    => Result: L1 distance = 0.5364
    ✨ New best L1 for criterion 'tb': 0.5364 ✨

--- Combination 23/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 256, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:14<00:00,  7.06it/s, loss=48.5]


    => Result: L1 distance = 0.8562

--- Combination 24/36 ---
Parameters: {'lr': 0.001, 'BATCH_SIZE': 256, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:14<00:00,  6.97it/s, loss=6.98]


    => Result: L1 distance = 0.5611
    ✨ New best L1 for criterion 'cb': 0.5611 ✨

--- Combination 25/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 16, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 22.11it/s, loss=0.949]


    => Result: L1 distance = 0.7305

--- Combination 26/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 16, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 21.77it/s, loss=0.838]


    => Result: L1 distance = 0.5095
    ✨ New best L1 for criterion 'tb': 0.5095 ✨

--- Combination 27/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 16, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:01<00:00, 52.39it/s, loss=4.48]


    => Result: L1 distance = 0.7717
    ✨ New best L1 for criterion 'subtb': 0.7717 ✨

--- Combination 28/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 16, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:01<00:00, 55.32it/s, loss=0.859]


    => Result: L1 distance = 0.5411
    ✨ New best L1 for criterion 'cb': 0.5411 ✨

--- Combination 29/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 64, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 28.38it/s, loss=0.618]


    => Result: L1 distance = 0.6795

--- Combination 30/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 64, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 23.55it/s, loss=1.71]


    => Result: L1 distance = 0.5112

--- Combination 31/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 64, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:03<00:00, 26.53it/s, loss=0.757]


    => Result: L1 distance = 0.5212
    ✨ New best L1 for criterion 'subtb': 0.5212 ✨

--- Combination 32/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 64, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:04<00:00, 21.99it/s, loss=2.92]


    => Result: L1 distance = 0.5116
    ✨ New best L1 for criterion 'cb': 0.5116 ✨

--- Combination 33/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 256, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:08<00:00, 12.13it/s, loss=0.539]


    => Result: L1 distance = 0.7544

--- Combination 34/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 256, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:12<00:00,  8.24it/s, loss=1.09]


    => Result: L1 distance = 0.5131

--- Combination 35/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 256, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:12<00:00,  8.00it/s, loss=1.08]


    => Result: L1 distance = 0.5435

--- Combination 36/36 ---
Parameters: {'lr': 0.01, 'BATCH_SIZE': 256, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}


100%|██████████| 100/100 [00:13<00:00,  7.16it/s, loss=1.99]


    => Result: L1 distance = 0.4756
    ✨ New best L1 for criterion 'cb': 0.4756 ✨

--- Grid Search Complete ---
Best results per criterion:
Criterion Best L1 Distance  Learning Rate  Batch Size  Clamp Value
       db           0.6132          0.001          16          1.0
       tb           0.5095          0.010          16          1.0
    subtb           0.5212          0.010          64          1.0
       cb           0.4756          0.010         256          1.0


In [7]:
# Re-display the grid-search summary table (`table_data` is built by the grid-search cell).
if not table_data:
    print("No results were recorded.")
else:
    df = pd.DataFrame(table_data)
    print(df.to_string(index=False))

Criterion Best L1 Distance  Learning Rate  Batch Size  Clamp Value
       db           0.6132          0.001          16          1.0
       tb           0.5095          0.010          16          1.0
    subtb           0.5212          0.010          64          1.0
       cb           0.4756          0.010         256          1.0


In [11]:

# --- 5. Robust Evaluation: Re-train and Evaluate Multiple Times ---
# For each criterion's best grid-search configuration, train N fresh models and
# summarise the test-set log likelihood of the kernels each trained model samples.
print(f"\n{'=' * 50}")
print("--- Robust Evaluation from Best Hyperparameters ---")
print(f"{'=' * 50}\n")

N_ROBUSTNESS_RUNS = 3
final_stats_results = []

for criterion, results in best_results_per_criterion.items():
    best_params = results.get('best_params')
    if best_params:
        print(f"Starting robust evaluation for criterion '{criterion}' with params: {best_params}")

        # Per-run summaries of the sampled kernels' test log likelihoods.
        run_mean_likelihoods, run_max_likelihoods = [], []

        for run in range(N_ROBUSTNESS_RUNS):
            print(f"  -> Run {run + 1}/{N_ROBUSTNESS_RUNS}...")

            # Hyperparameters selected for this criterion.
            lr, BATCH_SIZE = best_params['lr'], best_params['BATCH_SIZE']

            # A brand-new, untrained model for every run.
            env = create_env()
            forward_model = ForwardPolicy(
                input_dim=MAX_LEN,
                output_dim=env.action_space_size,
                epsilon=best_params['epsilon'],
            )
            backward_model = BackwardPolicy()
            gflownet = GFlowNet(
                forward_flow=forward_model,
                backward_flow=backward_model,
                criterion=criterion,
            )

            trained_gflownet, _ = train(
                gflownet=gflownet,
                create_env=create_env,
                epochs=100,
                batch_size=BATCH_SIZE,
                lr=lr,
                min_eps=best_params['min_eps'],
                clamp_g=best_params['clamp_g'],
                use_scheduler=True,
            )

            # Sample a batch from an evaluation env (batch_size=100) and score each
            # sampled kernel on the held-out data.
            eval_env = KernelEnvironment(
                batch_size=100,
                max_trajectory_length=MAX_LEN,
                log_reward=log_reward_fn,
            )
            trained_gflownet.eval()
            final_batch = trained_gflownet.sample(eval_env)

            likelihoods = [
                utils.evaluate_likelihood(kernel, X_test, Y_test)
                for kernel in final_batch.state
            ]

            run_mean_likelihoods.append(np.mean(likelihoods))
            run_max_likelihoods.append(np.max(likelihoods))

        # Aggregate across runs (np.std defaults to the population std, ddof=0).
        mean_of_means, std_of_means = np.mean(run_mean_likelihoods), np.std(run_mean_likelihoods)
        mean_of_maxs, std_of_maxs = np.mean(run_max_likelihoods), np.std(run_max_likelihoods)

        print(f"  => Final Stats for '{criterion}':")
        print(f"     - Mean of Mean Likelihoods: {mean_of_means:.4f} (Std: {std_of_means:.4f})")
        print(f"     - Mean of Max Likelihoods:  {mean_of_maxs:.4f} (Std: {std_of_maxs:.4f})\n")

        final_stats_results.append({
            'Criterion': criterion,
            'Mean of Means (LL)': f"{mean_of_means:.4f}",
            'Std of Means (LL)': f"{std_of_means:.4f}",
            'Mean of Maxs (LL)': f"{mean_of_maxs:.4f}",
            'Std of Maxs (LL)': f"{std_of_maxs:.4f}",
        })
    else:
        print(f"Skipping robust evaluation for '{criterion}' as no best parameters were found.\n")

# --- 6. Final Statistics Summary Table ---
if final_stats_results:
    print(f"\n{'=' * 50}")
    print("--- Final Likelihood Statistics Summary ---")
    print('=' * 50)
    df_stats = pd.DataFrame(final_stats_results)
    print(df_stats.to_string(index=False))



--- Robust Evaluation from Best Hyperparameters ---

Starting robust evaluation for criterion 'db' with params: {'lr': 0.001, 'BATCH_SIZE': 16, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:02<00:00, 40.79it/s, loss=8.03]


  -> Run 2/3...


100%|██████████| 100/100 [00:03<00:00, 30.96it/s, loss=6.99]


  -> Run 3/3...


100%|██████████| 100/100 [00:01<00:00, 52.00it/s, loss=6.21]


  => Final Stats for 'db':
     - Mean of Mean Likelihoods: -3.3880 (Std: 0.3925)
     - Mean of Max Likelihoods:  22.3995 (Std: 0.0000)

Starting robust evaluation for criterion 'tb' with params: {'lr': 0.01, 'BATCH_SIZE': 16, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:01<00:00, 50.58it/s, loss=4.28]


  -> Run 2/3...


100%|██████████| 100/100 [00:02<00:00, 35.38it/s, loss=0.833]


  -> Run 3/3...


100%|██████████| 100/100 [00:01<00:00, 66.42it/s, loss=0.518]


  => Final Stats for 'tb':
     - Mean of Mean Likelihoods: 13.7562 (Std: 3.5199)
     - Mean of Max Likelihoods:  22.3726 (Std: 0.0190)

Starting robust evaluation for criterion 'subtb' with params: {'lr': 0.01, 'BATCH_SIZE': 64, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:04<00:00, 23.44it/s, loss=8.19]


  -> Run 2/3...


100%|██████████| 100/100 [00:03<00:00, 27.82it/s, loss=12] 


  -> Run 3/3...


100%|██████████| 100/100 [00:04<00:00, 20.73it/s, loss=1.13]


  => Final Stats for 'subtb':
     - Mean of Mean Likelihoods: 6.2352 (Std: 2.5363)
     - Mean of Max Likelihoods:  22.3857 (Std: 0.0196)

Starting robust evaluation for criterion 'cb' with params: {'lr': 0.01, 'BATCH_SIZE': 256, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:19<00:00,  5.04it/s, loss=1.9]


  -> Run 2/3...


100%|██████████| 100/100 [00:15<00:00,  6.37it/s, loss=2.14]


  -> Run 3/3...


100%|██████████| 100/100 [00:10<00:00,  9.79it/s, loss=2.7]

  => Final Stats for 'cb':
     - Mean of Mean Likelihoods: 13.4221 (Std: 1.3614)
     - Mean of Max Likelihoods:  22.3857 (Std: 0.0196)


--- Final Likelihood Statistics Summary ---
Criterion Mean of Means (LL) Std of Means (LL) Mean of Maxs (LL) Std of Maxs (LL)
       db            -3.3880            0.3925           22.3995           0.0000
       tb            13.7562            3.5199           22.3726           0.0190
    subtb             6.2352            2.5363           22.3857           0.0196
       cb            13.4221            1.3614           22.3857           0.0196





In [15]:
# --- 5. Robust Evaluation: Re-train and Evaluate Multiple Times ---
# For each criterion's best grid-search configuration, train N_ROBUSTNESS_RUNS fresh
# models and report the mean/std across runs of:
#   * the L1 distance (computed on the training data X, Y, as in the grid search), and
#   * the mean and max test-set log likelihood of the kernels each model samples.
print("\n" + "="*50)
print("--- Robust Evaluation from Best Hyperparameters ---")
print("="*50 + "\n")

N_ROBUSTNESS_RUNS = 3
final_stats_results = []

for criterion, results in best_results_per_criterion.items():
    best_params = results.get('best_params')
    if not best_params:
        print(f"Skipping robust evaluation for '{criterion}' as no best parameters were found.\n")
        continue

    print(f"Starting robust evaluation for criterion '{criterion}' with params: {best_params}")

    # The hyperparameters are the same for every run of this criterion: unpack once.
    lr = best_params['lr']
    BATCH_SIZE = best_params['BATCH_SIZE']

    # Lists to store metrics from each of the N runs
    run_mean_likelihoods = []
    run_max_likelihoods = []
    run_l1_distances = []

    for run in range(N_ROBUSTNESS_RUNS):
        print(f"  -> Run {run + 1}/{N_ROBUSTNESS_RUNS}...")

        # Initialize and train a new model from scratch
        env = create_env()  # only used for the action-space size
        forward_model = ForwardPolicy(input_dim=MAX_LEN, output_dim=env.action_space_size, epsilon=best_params['epsilon'])
        backward_model = BackwardPolicy()
        gflownet = GFlowNet(forward_flow=forward_model, backward_flow=backward_model, criterion=criterion)
        # NOTE(review): nothing visible reads `gflownet.lr` (`train` gets `lr` explicitly); confirm it is needed.
        gflownet.lr = lr

        trained_gflownet, _ = train(gflownet=gflownet, create_env=create_env, epochs=100, batch_size=BATCH_SIZE, lr=lr, min_eps=best_params['min_eps'], clamp_g=best_params['clamp_g'], use_scheduler=True)

        # Calculate L1 distance for this run
        l1 = calculate_l1_distance(trained_gflownet.forward_flow, KernelEnvironment, MAX_LEN, X, Y)
        run_l1_distances.append(l1)

        # Sample from the trained model and score each sampled kernel on held-out data
        eval_env = KernelEnvironment(batch_size=100, max_trajectory_length=MAX_LEN, log_reward=log_reward_fn)
        trained_gflownet.eval()
        final_batch = trained_gflownet.sample(eval_env)

        likelihoods = [utils.evaluate_likelihood(k, X_test, Y_test) for k in final_batch.state]

        run_mean_likelihoods.append(np.mean(likelihoods))
        run_max_likelihoods.append(np.max(likelihoods))

    # Calculate statistics over the N runs (np.std defaults to the population std, ddof=0)
    mean_of_l1s = np.mean(run_l1_distances)
    std_of_l1s = np.std(run_l1_distances)
    mean_of_means = np.mean(run_mean_likelihoods)
    std_of_means = np.std(run_mean_likelihoods)
    mean_of_maxs = np.mean(run_max_likelihoods)
    std_of_maxs = np.std(run_max_likelihoods)

    print(f"  => Final Stats for '{criterion}':")
    print(f"     - Mean L1 Distance:         {mean_of_l1s:.4f} (Std: {std_of_l1s:.4f})")
    print(f"     - Mean of Mean Likelihoods: {mean_of_means:.4f} (Std: {std_of_means:.4f})")
    print(f"     - Mean of Max Likelihoods:  {mean_of_maxs:.4f} (Std: {std_of_maxs:.4f})\n")

    # Each key appears exactly once; the table's column order follows insertion order.
    final_stats_results.append({
        'Criterion': criterion,
        'Mean L1': f"{mean_of_l1s:.4f}",
        'Std L1': f"{std_of_l1s:.4f}",
        'Mean of Means (LL)': f"{mean_of_means:.4f}",
        'Std of Means (LL)': f"{std_of_means:.4f}",
        'Mean of Maxs (LL)': f"{mean_of_maxs:.4f}",
        'Std of Maxs (LL)': f"{std_of_maxs:.4f}",
    })

# --- 6. Final Statistics Summary Table ---
if final_stats_results:
    print("\n" + "="*50)
    print("--- Final Likelihood & L1 Statistics Summary ---")
    print("="*50)
    df_stats = pd.DataFrame(final_stats_results)
    print(df_stats.to_string(index=False))



--- Robust Evaluation from Best Hyperparameters ---

Starting robust evaluation for criterion 'db' with params: {'lr': 0.001, 'BATCH_SIZE': 16, 'criterion': 'db', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


  0%|          | 0/100 [00:00<?, ?it/s, loss=8.06]

100%|██████████| 100/100 [00:02<00:00, 40.56it/s, loss=7.66]


  -> Run 2/3...


100%|██████████| 100/100 [00:02<00:00, 44.11it/s, loss=5.51]


  -> Run 3/3...


100%|██████████| 100/100 [00:03<00:00, 25.74it/s, loss=8.74]


  => Final Stats for 'db':
     - Mean L1 Distance:         0.8615 (Std: 0.0563)
     - Mean of Mean Likelihoods: -8.4253 (Std: 0.7168)
     - Mean of Max Likelihoods:  22.3995 (Std: 0.0000)

Starting robust evaluation for criterion 'tb' with params: {'lr': 0.01, 'BATCH_SIZE': 16, 'criterion': 'tb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:02<00:00, 45.04it/s, loss=4.92]


  -> Run 2/3...


100%|██████████| 100/100 [00:01<00:00, 52.96it/s, loss=0.775]


  -> Run 3/3...


100%|██████████| 100/100 [00:02<00:00, 34.02it/s, loss=0.51]


  => Final Stats for 'tb':
     - Mean L1 Distance:         0.5121 (Std: 0.0093)
     - Mean of Mean Likelihoods: 11.0030 (Std: 0.9731)
     - Mean of Max Likelihoods:  22.3995 (Std: 0.0000)

Starting robust evaluation for criterion 'subtb' with params: {'lr': 0.01, 'BATCH_SIZE': 64, 'criterion': 'subtb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:04<00:00, 22.23it/s, loss=6.12]


  -> Run 2/3...


100%|██████████| 100/100 [00:06<00:00, 15.87it/s, loss=0.704]


  -> Run 3/3...


100%|██████████| 100/100 [00:06<00:00, 16.62it/s, loss=8.16]


  => Final Stats for 'subtb':
     - Mean L1 Distance:         0.5749 (Std: 0.0371)
     - Mean of Mean Likelihoods: 7.8658 (Std: 1.7695)
     - Mean of Max Likelihoods:  22.3995 (Std: 0.0000)

Starting robust evaluation for criterion 'cb' with params: {'lr': 0.01, 'BATCH_SIZE': 256, 'criterion': 'cb', 'epsilon': 0.5, 'min_eps': 0.01, 'clamp_g': 1.0}
  -> Run 1/3...


100%|██████████| 100/100 [00:12<00:00,  7.92it/s, loss=2.5]


  -> Run 2/3...


100%|██████████| 100/100 [00:10<00:00,  9.50it/s, loss=1.72]


  -> Run 3/3...


100%|██████████| 100/100 [00:10<00:00,  9.41it/s, loss=1.66]


  => Final Stats for 'cb':
     - Mean L1 Distance:         0.5008 (Std: 0.0022)
     - Mean of Mean Likelihoods: 14.3573 (Std: 1.2281)
     - Mean of Max Likelihoods:  22.3995 (Std: 0.0000)


--- Final Likelihood & L1 Statistics Summary ---
Criterion Mean L1 Std L1 Mean of Means (LL) Std of Means (LL) Mean of Maxs (LL) Std of Maxs (LL)
       db  0.8615 0.0563            -8.4253            0.7168           22.3995           0.0000
       tb  0.5121 0.0093            11.0030            0.9731           22.3995           0.0000
    subtb  0.5749 0.0371             7.8658            1.7695           22.3995           0.0000
       cb  0.5008 0.0022            14.3573            1.2281           22.3995           0.0000


In [9]:
# Display the raw per-kernel test log likelihoods held in `likelihoods`. That variable
# is assigned only inside the robust-evaluation cells, so it holds the values from
# the last run of the last criterion evaluated.
# NOTE(review): this cell's execution count (9) is lower than the robust-evaluation
# cells' (11, 15), so the saved output below predates them — hidden state; re-run in order.
likelihoods

[21.754836933213255,
 21.754837276057494,
 21.754835696815867,
 -24.535347280501647,
 21.754836908564336,
 21.75483732641125,
 -26.234023441334703,
 21.75483744850833,
 22.35798314131126,
 21.754833475082002,
 22.357923829376045,
 -22.921964358772556,
 21.75477707014084,
 21.754835696815867,
 -26.04262692299885,
 22.357923829376045,
 21.90018796744321,
 22.39950774571875,
 21.707848445403283,
 21.754837448416104,
 21.710066160652595,
 21.754029861316866,
 22.111277062964167,
 -22.921964358772556,
 21.75483744850833,
 21.75477707014084,
 22.35798314131126,
 21.754837448507907,
 -17.950360561387587,
 -26.231603561657735,
 21.75483744804797,
 21.75483675370739,
 21.754777092694482,
 21.754837206718907,
 21.475047969212284,
 21.75477579643438,
 21.75477715822685,
 -26.047071516860356,
 21.754837448054086,
 21.754836909301687,
 21.75483744850833,
 21.754833475082002,
 21.90018796745239,
 21.754836908564336,
 -24.53534728047471,
 22.111277062972174,
 21.754836966146144,
 21.475047968948807,


In [10]:

# Single stand-alone training run with hand-picked settings (outside the grid search).
epochs, BATCH_SIZE, MAX_LEN, lr = 100, 64, 4, 1e-3

# Re-bind the training data into the log-likelihood reward (same function as the
# `log_likelihood_reward` imported from utils).
log_reward_fn = partial(utils.log_likelihood_reward, X, Y)
env = create_env()

forward_model = ForwardPolicy(
    input_dim=MAX_LEN,
    output_dim=env.action_space_size,
    epsilon=0.5,
)
backward_model = BackwardPolicy()
criterion = 'db'

gflownet = GFlowNet(forward_flow=forward_model, backward_flow=backward_model, criterion=criterion)

# NOTE: clamp_g=10 here, versus the 1.0 used throughout the grid search.
gflownet, losses = train(
    gflownet=gflownet,
    create_env=create_env,
    epochs=epochs,
    batch_size=BATCH_SIZE,
    lr=lr,
    min_eps=1e-2,
    clamp_g=10,
    use_scheduler=True,
)

100%|██████████| 100/100 [00:06<00:00, 15.63it/s, loss=8.18]
