In [1]:
#@title Install Dependencies
!find . -name "*.pyc" -delete
!find . -name "__pycache__" -delete
# !pip install --no-deps git+https://github.com/GFNOrg/torchgfn.git
# Replace with your GitHub username and personal access token
username = "sdawzy"
token = "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"

# Replace with your private repository URL
repo_url = "https://github.com/Erostrate9/GFNEval.git"

!pip install --no-deps git+https://{username}:{token}@{repo_url.split('https://')[1]}#subdirectory=torchgfn

Collecting git+https://sdawzy:****@github.com/Erostrate9/GFNEval.git#subdirectory=torchgfn
  Cloning https://sdawzy:****@github.com/Erostrate9/GFNEval.git to /tmp/pip-req-build-ptl67x5b
  Running command git clone --filter=blob:none --quiet 'https://sdawzy:****@github.com/Erostrate9/GFNEval.git' /tmp/pip-req-build-ptl67x5b
  Resolved https://sdawzy:****@github.com/Erostrate9/GFNEval.git to commit cb6fcd4370790301808f630059ecd501bb4fe459
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torchgfn
  Building wheel for torchgfn (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torchgfn: filename=torchgfn-1.1.1-py3-none-any.whl size=79472 sha256=96bf1cc91638920397d2d0313c766b807e810b690563c897e1f4d15bc7007340
  Stored in directory: /tmp/pip-ephem-wheel-cache-ch25jz76/wheels/97/74/20/5c5130c3639d55c9ed0b3f7f003fa1a07cb97f41d8

In [2]:
#@title Import Necessary Packages
import torch
torch.set_default_dtype(torch.float)
from torch import Tensor
from torch import nn
from torch import optim

from gfn.gflownet import GFlowNet, TBGFlowNet, SubTBGFlowNet, FMGFlowNet, DBGFlowNet
from gfn.samplers import Sampler
from gfn.env import Env
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.utils.modules import MLP  # is a simple multi-layer perceptron (MLP)
from gfn.containers import Trajectories
from gfn.states import States

from gfn.eval_kl import PhiFunction, calc_KL_using_model, compute_KL
from gfn.gym.hypergrid2 import HyperGrid2, get_final_states

from tqdm import tqdm
import matplotlib.pyplot as plt

import pickle
import os
from google.colab import drive

In [9]:
#@title Experiment Setup, Training, and Testing
def experiment_setup(env: Env, algo: "type[GFlowNet]"):
    """Build the GFlowNet, its on-policy sampler and its optimizer for one run.

    Args:
        env: Environment to train on. ``env.preprocessor.output_dim`` and
            ``env.n_actions`` size the networks; every module is moved to
            ``env.device``.
        algo: The GFlowNet *class* (not an instance), compared by identity.
            Currently ``TBGFlowNet`` or ``SubTBGFlowNet``.

    Returns:
        ``(gfn, sampler, optimizer)``.

    Raises:
        NotImplementedError: If ``algo`` has no setup yet (e.g. ``FMGFlowNet``,
            ``DBGFlowNet``). Previously this silently returned
            ``(None, None, None)``, which only failed later during training.
    """
    # TODO: initialize parameterizations of FMGFlowNet and DBGFlowNet
    if algo is not TBGFlowNet and algo is not SubTBGFlowNet:
        raise NotImplementedError(
            f"experiment_setup does not support {getattr(algo, '__name__', algo)!r} yet"
        )

    # Forward policy network, with as many outputs as there are actions. The environment's
    # preprocessor turns states into the network input.
    module_PF = MLP(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions,
    ).to(env.device)
    # Backward policy network: one output fewer; it reuses P_F's trunk, so P_F and P_B
    # share all parameters except their last layer.
    module_PB = MLP(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        trunk=module_PF.trunk,
    ).to(env.device)
    # SubTB additionally learns a scalar log-flow log F(s). It is created here, before the
    # estimators, to keep the same parameter-initialisation order as the original code.
    if algo is SubTBGFlowNet:
        module_logF = MLP(
            input_dim=env.preprocessor.output_dim,
            output_dim=1,  # Important for ScalarEstimators!
        ).to(env.device)

    pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor).to(env.device)
    pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor).to(env.device)

    if algo is TBGFlowNet:
        gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator).to(env.device)
        # Policy parameters and logZ get separate learning rates (logZ's is larger).
        optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
        optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})
    else:  # SubTBGFlowNet
        logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor).to(env.device)
        gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.9).to(env.device)
        # log F gets its own (higher) learning rate than the policy parameters.
        optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
        optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

    # On-policy sampler, driven by the forward policy estimator.
    sampler = Sampler(estimator=pf_estimator)

    return gfn, sampler, optimizer

def training(gfn: GFlowNet, sample: Sampler, optimizer, num_epochs: int = 1000,
             env: "Env | None" = None, batch_size: int = 16) -> Sampler:
    """Train ``gfn`` on-policy for ``num_epochs`` optimizer steps.

    Each step draws ``batch_size`` trajectories with ``sample`` in ``env``. It then
    computes ``gfn.loss(env, trajectories)`` and takes one optimizer step.

    Args:
        gfn: The GFlowNet whose loss is minimised.
        sample: Sampler used to draw training trajectories. The parameter name is kept
            for backward compatibility; it was previously ignored in favour of the
            notebook-global ``sampler``.
        optimizer: Optimizer over ``gfn``'s parameters.
        num_epochs: Number of gradient steps.
        env: Environment to sample in. If ``None``, falls back to the notebook-global
            ``env`` (the previous implicit behaviour); prefer passing it explicitly.
        batch_size: Trajectories per step (previously hard-coded to 16).

    Returns:
        The sampler that was used (``sample``).
    """
    if env is None:
        # Backward-compatible fallback to the notebook-level environment.
        env = globals()["env"]

    for i in (pbar := tqdm(range(num_epochs))):
        trajectories = sample.sample_trajectories(env=env, n=batch_size)
        optimizer.zero_grad()
        loss = gfn.loss(env, trajectories)
        loss.backward()
        optimizer.step()
        if i % 25 == 0:  # refresh the progress-bar loss only every 25 steps
            pbar.set_postfix({"loss": loss.item()})
    return sample

def testing(env: Env, gfn: GFlowNet, num_samples: int = 10000, num_epochs: int = 250,
            show_progress: bool = False) -> tuple:
    """Estimate the KL divergence between the learned and the true terminal distributions.

    The function draws ``num_samples`` terminating states from ``gfn`` (the learned
    proxy distribution). It also draws ``num_samples`` states from
    ``env.sample_states_from_distribution`` (the true distribution). Both sets are passed,
    as float64 tensors on ``env.device``, to ``compute_KL``.

    Args:
        env: Environment providing the true distribution.
        gfn: Trained GFlowNet to evaluate.
        num_samples: Number of states drawn from each distribution.
        num_epochs: Forwarded to ``compute_KL``.
        show_progress: Forwarded to ``compute_KL``.

    Returns:
        ``(kl, phi)`` exactly as produced by ``compute_KL``.
    """
    # Drawing samples needs no gradients. no_grad avoids recording an autograd graph over
    # num_samples trajectories; compute_KL below still runs with gradients enabled.
    with torch.no_grad():
        # Sample from the proxy distribution, i.e. from the learned sampler.
        samples_proxy_distribution = gfn.sample_terminating_states(env=env, n=num_samples)
        samples_proxy_tensor = samples_proxy_distribution.tensor.double().to(env.device)

        # Sample from the true distribution.
        samples_true_distribution = env.sample_states_from_distribution(num_samples)
        samples_true_tensor = samples_true_distribution.tensor.double().to(env.device)

    kl, phi = compute_KL(samples_proxy_tensor, samples_true_tensor,
                         num_epochs=num_epochs, show_progress=show_progress)
    return kl, phi



In [4]:
#@title Hyper-parameters
# Fix every torch RNG (CPU, current GPU, all GPUs). This makes the per-run seeds drawn
# later with torch.randint reproducible across sessions.
seed = 1234
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
# Prefer deterministic cuDNN kernels over autotuned ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Sweep grid (powers of two).
ndims    = [2 ** k for k in range(1, 5)]   # [2, 4, 8, 16]
heights  = [2 ** k for k in range(3, 9)]   # [8, 16, 32, 64, 128, 256]
ncenters = [2 ** k for k in range(1, 6)]   # [2, 4, 8, 16, 32]
# NOTE: experiment_setup only builds TB and SubTB so far; FM and DB are still a TODO there.
algos    = [TBGFlowNet, SubTBGFlowNet, FMGFlowNet, DBGFlowNet]

In [5]:
#@title Load Progress
drive.mount('/content/drive')
drive_path = "/content/drive/My Drive/KLEval_TB_results.pkl"

# Resume from previously pickled results when the file exists; otherwise start empty.
# NOTE: pickle.load can execute code embedded in the file — only load files you wrote yourself.
try:
    saved_results_file = open(drive_path, "rb")
except FileNotFoundError:
    results = {}  # no saved progress yet
else:
    with saved_results_file:
        results = pickle.load(saved_results_file)
    print("Loaded saved results from Google Drive.")

Mounted at /content/drive
Loaded saved results from Google Drive.


In [None]:
#@title Some test
# Smoke test: draw a few trajectories from a saved sampler, in the environment it was trained on.
test_key = (4, 256, 4)  # (ndim, height, ncenter)
if test_key not in results:
    raise KeyError(
        f"No saved run for (ndim, height, ncenter)={test_key}; available: {sorted(results)}. "
        "Note: the experiment loop skips grids with height ** ndim > 1e8, which includes (4, 256, 4)."
    )
saved_run = results[test_key]
# Rebuild the env with the seed and device the run was saved with. Previously a fresh random
# seed on 'cpu' was used, which need not match the trained run and can clash with a CUDA sampler.
env = HyperGrid2(ndim=test_key[0], height=test_key[1], ncenters=test_key[2],
                 seed=saved_run['seed'],
                 device_str=saved_run['device_str'])
sampler = saved_run['sampler']
sampler.sample_trajectories(env=env, n=10)

In [8]:
#@title (Optional) Reset results
# Clears the in-memory results only. The next save in the experiment loop overwrites the
# file at drive_path with whatever `results` holds at that point.
results = dict()

Start experiments from here:

In [None]:
device_str = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # a torch.device, despite the name
MAX_GRID_STATES = 1e8  # skip grids whose state count height ** ndim exceeds this

# Specify the algorithm
algo = TBGFlowNet


def _save_results(results_to_save: dict, path: str) -> None:
    """Pickle ``results_to_save`` to ``path`` without risking a half-written file.

    The data is written to a sibling temporary file, which is then renamed over ``path``.
    Interrupting the kernel mid-write therefore cannot corrupt previously saved progress.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(results_to_save, f)
    os.replace(tmp_path, path)


# Sweep every (ndim, height, ncenter) configuration. A configuration is finished once its
# 'kl' is stored. One that was trained but never evaluated (e.g. the kernel stopped between
# the partial save and the final save) is resumed by evaluating the saved model, not skipped.
for ndim in ndims:
    for height in heights:
        if height ** ndim > MAX_GRID_STATES:
            continue
        for ncenter in ncenters:
            key = (ndim, height, ncenter)
            entry = results.get(key, {})
            if 'kl' in entry:
                print(f"Skipping already processed: ndim={ndim}, height={height}, ncenter={ncenter}")
                continue

            print(f"ndim={ndim}, height={height}, ncenter={ncenter}, algo={algo.__name__}")
            if 'gfn' in entry and 'seed' in entry:
                # Trained earlier but not evaluated: rebuild the same env and reuse the model.
                seed = entry['seed']
                env = HyperGrid2(ndim=ndim, height=height, ncenters=ncenter,
                                 seed=seed,
                                 device_str=entry.get('device_str', device_str))
                gfn = entry['gfn']
                print("Found a trained model without KL; evaluating it.")
            else:
                seed = torch.randint(0, 10000, (1,)).item()
                env = HyperGrid2(ndim=ndim, height=height, ncenters=ncenter,
                                 seed=seed,
                                 device_str=device_str)
                gfn, sampler, optimizer = experiment_setup(env, algo)

                sampler = training(gfn, sampler, optimizer)
                # Save partial results (the trained model) before the KL evaluation.
                results[key] = {
                    'sampler': sampler,  # Save the sampler object
                    'gfn': gfn,
                    'optimizer': optimizer,
                    'seed': seed,
                    'device_str': device_str,
                    'algo': algo.__name__,
                    'ndim': ndim,
                    'height': height,
                    'ncenter': ncenter,
                }
                _save_results(results, drive_path)
                print(f"Partial results saved to {drive_path}")

            # Calculate KL and phi
            kl, phi = testing(env, gfn)
            print(f"KL={kl}")

            # Save final results
            results[key].update({
                'kl': kl,
                'phi': phi,
            })
            _save_results(results, drive_path)
            print(f"Results saved to {drive_path}")



ndim=2, height=8, ncenter=2, algo=TB


100%|██████████| 1000/1000 [01:20<00:00, 12.50it/s, loss=0.000562]


Partial results saved to /content/drive/My Drive/KLEval_TB_results.pkl
KL=0.08946420041107372
Results saved to /content/drive/My Drive/KLEval_TB_results.pkl
ndim=2, height=8, ncenter=4, algo=TB


 77%|███████▋  | 773/1000 [00:48<00:15, 14.67it/s, loss=0.0023]