In [2]:
# Compose the Hydra configuration from inside the notebook.
# Pattern taken from the official Hydra example:
# https://github.com/facebookresearch/hydra/blob/main/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb
import os

import hydra
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

# initialize_config_dir is given an absolute path, so resolve "config/"
# against the current working directory first.
abs_config_dir = os.path.abspath("config/")

with initialize_config_dir(config_dir=abs_config_dir, version_base=None):
    config = compose(config_name="test.yaml", overrides=[])
    yaml_dump = OmegaConf.to_yaml(config)
    print(yaml_dump)
    print(config)

# Switch off the `online` flag of the sampler's logger config for this
# notebook session (presumably disables remote/W&B logging -- confirm).
config.sampler.conf.logger.do.online = False



dataset:
  _target_: dataset.grid.BraninDatasetHandler
  grid_size: 100
  normalize_scores: true
  train_fraction: 1.0
  batch_size: 16
  shuffle: true
  train_path: ~/activelearning/my_package/storage/branin/data_100_train.csv
  test_path: null
oracle:
  _target_: oracle.oracle.Branin
  fidelity: 1
  do_domain_map: true
filter:
  _target_: filter.filter.Filter
sampler:
  conf:
    state_flow: null
    policy:
      forward:
        _target_: gflownet.policy.base.Policy
        config:
          type: mlp
          n_hid: 128
          n_layers: 2
          checkpoint: null
          reload_ckpt: false
          is_model: false
      backward:
        _target_: gflownet.policy.base.Policy
        config: null
      shared: null
    env:
      _target_: gflownet.envs.grid.Grid
      id: grid
      func: corners
      n_dim: 2
      length: 100
      max_increment: 1
      max_dim_per_action: 1
      cell_min: 0
      cell_max: 0.99
      buffer:
        train: null
        test: null
  

In [3]:
import torch
from gflownet.utils.common import set_float_precision
import matplotlib.colors as cm
import matplotlib.pyplot as plt

# --- Run settings read from the composed Hydra config ---------------------
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = config.device
n_iterations = config.budget  # number of rounds executed by the main loop
grid_size = config.dataset.grid_size
n_samples = config.n_samples
maximize = config.maximize
float_prec = set_float_precision(config.float_precision)

# --- Plot colours ---------------------------------------------------------
# A colormap is used instead of a fixed list so that every iteration can be
# given its own shade via colors(i / n_iterations).
# colors = ["red", "blue", "green", "orange", "brown", "pink"]
colors = plt.get_cmap("Reds")

In [4]:
# Optional experiment logger (disabled): uncomment these two lines and the
# `PlotHelper(logger)` line below to hand a logger to the plot helper.
# from utils.logger import WandBLogger
# logger = WandBLogger(project_name="test_hartmann", run_name="GFlowNetSampler 100x100 power")

from utils.plotter import PlotHelper
# plotter = PlotHelper(logger)
# Plot helper built without a logger. Later cells test `plotter is not None`
# before drawing or logging figures.
plotter = PlotHelper()

In [5]:
# Project imports for the Hartmann experiment. Only HartmannDatasetHandler is
# used in this cell; the remaining names are used by the main-loop cell
# further down, which repeats these imports.
from dataset.grid import HartmannDatasetHandler
from surrogate.surrogate import SingleTaskGPRegressor
from sampler.sampler import GreedySampler, RandomSampler
from filter.filter import Filter, OracleFilter
from oracle.oracle import HartmannOracle


# Dataset
# Built from the notebook-level `grid_size` and `float_prec` settings; the
# training CSV path is relative to the notebook's working directory.
dataset_handler = HartmannDatasetHandler(
    grid_size=grid_size,
    train_path="./storage/hartmann/data_train.csv",
    train_fraction=1.0,
    float_precision=float_prec,
)

In [6]:
# Ask the dataset handler for its candidate set; the call returns three values
# (the candidates plus `xi` and `yi`). The trailing expression displays the
# candidates' shape as the cell output.
candidates, xi, yi = dataset_handler.get_candidate_set()
candidates.shape

: 

In [20]:
import numpy as np

# Hand-built candidate set: every combination of six integer coordinates,
# each ranging over 0..9.
xi = np.arange(10)
yi = np.arange(10)

# Six coordinate arrays (xi, yi repeated three times) stacked along axis 0.
grid = np.array(np.meshgrid(xi, yi, xi, yi, xi, yi))

# Transposing moves the coordinate axis last, so flattening yields one
# 6-vector per grid point.
grid_points = torch.tensor(grid.T, dtype=torch.float64)
grid_flat = grid_points.reshape(-1, 6)

grid_flat.shape, grid.shape

(torch.Size([1000000, 6]), (6, 10, 10, 10, 10, 10, 10))

In [None]:
# Main experiment cell: build the dataset / oracle / filter, then for
# `n_iterations` rounds fit a surrogate, draw samples, filter them, score them
# with the oracle and add them to the dataset. The lowest oracle score of each
# round is collected in `best_scores` and plotted at the end.
from dataset.grid import HartmannDatasetHandler
from surrogate.surrogate import SingleTaskGPRegressor
from sampler.sampler import GreedySampler, RandomSampler
from filter.filter import Filter, OracleFilter
from oracle.oracle import HartmannOracle


# Dataset
# The path has no format placeholder, so it is passed verbatim. Applying
# `% grid_size` to it raises
# "TypeError: not all arguments converted during string formatting".
dataset_handler = HartmannDatasetHandler(
    grid_size=grid_size,
    train_path="./storage/hartmann/data_train.csv",
    train_fraction=1.0,
    float_precision=float_prec,
)

candidate_set, xi, yi = dataset_handler.get_candidate_set()

# Oracle
oracle = HartmannOracle(
    fidelity=1, device=device, float_precision=float_prec
)
# Filter
# NOTE(review): this name shadows the builtin `filter` for the rest of the
# notebook session; kept as-is because the loop below uses it.
filter = Filter()
# filter = OracleFilter(oracle)

# Bind the oracle figure handles up front so the `ax_oracle is not None`
# checks below are valid even when plotting is disabled (plotter is None).
fig_oracle, ax_oracle = None, None
if plotter is not None:
    fig_oracle, ax_oracle = plotter.plot_function(oracle, candidate_set.clone().to(device), xi=xi, yi=yi)


best_scores = []

for i in range(n_iterations):

    train_data, test_data = dataset_handler.get_dataloader()
    # print("iteration", i)
    # Surrogate (e.g., Bayesian Optimization)
    # starts with a clean slate each iteration
    surrogate = SingleTaskGPRegressor(
        float_precision=float_prec, device=device, maximize=maximize
    )
    surrogate.fit(train_data)

    # Sampler (e.g., GFlowNet, or Random Sampler)
    # also starts with a clean slate; TODO: experiment with NOT training from scratch
    # sampler = RandomSampler(surrogate)
    sampler = GreedySampler(surrogate)
    # sampler = hydra.utils.instantiate(
    #     config.sampler,
    #     surrogate=surrogate,
    #     device=device,
    #     float_precision=float_prec,
    #     _recursive_=False,
    # )

    sampler.fit()  # only necessary for samplers that train a model

    # Over-sample by a factor of 3, then let the filter cut the batch back
    # down to `n_samples`.
    samples = sampler.get_samples(
        n_samples * 3, candidate_set=candidate_set.clone().to(device)
    )
    filtered_samples = filter(
        n_samples=n_samples, candidate_set=samples.clone(), maximize=maximize
    )

    if plotter is not None:
        fig_acq, ax_acq = plotter.plot_function(surrogate, candidate_set.clone().to(device), xi=xi, yi=yi)
        fig_acq, ax_acq = plotter.plot_samples(filtered_samples, ax_acq, fig_acq)
        ax_acq.set_title("acquisition fn + selected samples of iteration %i"%i)
        plotter.log_figure(fig_acq, "acq")

    if ax_oracle is not None:
        # Only the first two coordinates of each sample are drawn; the colour
        # shade encodes the iteration index.
        ax_oracle.scatter(
            x=filtered_samples[:, 0].cpu(),
            y=filtered_samples[:, 1].cpu(),
            c=cm.to_hex(colors(i / n_iterations)),
            marker="x",
            label="it %i" % i,
        )

    scores = oracle(filtered_samples.clone())
    dataset_handler.update_dataset(filtered_samples.cpu(), scores.cpu())
    # NOTE(review): the "best" score is always the minimum, independent of
    # `maximize` -- confirm this is intended when maximize is True.
    best_scores.append(scores.min().cpu())

if ax_oracle is not None:
    fig_oracle.legend()
    ax_oracle.set_title("oracle fn + samples")
    plotter.log_figure(fig_oracle, key="oracle")


fig = plt.figure()
plt.plot(best_scores)
plt.xlabel("iterations")
plt.ylabel("scores")
plt.title("Best Score in each iteration")
if plotter is not None:
    plotter.log_figure(fig, key="best_scores")
