In [1]:
import os
import sys

ROOT_DIR = os.path.abspath("__file__" + "/../../")
sys.path.insert(0, f"{ROOT_DIR}")

In [3]:
from src.spinfoam.spinfoams import SingleVertexSpinFoam
from src.spinfoam.sf_env import SpinFoamEnvironment

spin_j = 3.0
env = SpinFoamEnvironment(
    spinfoam_model=SingleVertexSpinFoam(spin_j=spin_j)
)

In [4]:
import torch

from gfn import LogitPBEstimator, LogitPFEstimator, LogZEstimator
from gfn.losses import TBParametrization, TrajectoryBalance
from gfn.samplers import DiscreteActionsSampler, TrajectoriesSampler

logit_PF = LogitPFEstimator(env=env, module_name="NeuralNet")
logit_PB = LogitPBEstimator(
    env=env,
    module_name="NeuralNet",
    torso=logit_PF.module.torso,  # To share parameters between PF and PB
)
logZ = LogZEstimator(torch.tensor(0.0))


training_sampler = TrajectoriesSampler(
    env=env,
    actions_sampler=DiscreteActionsSampler(
        estimator=logit_PF,
        epsilon=0.5
    )
)

eval_sampler = TrajectoriesSampler(
    env=env, actions_sampler=DiscreteActionsSampler(estimator=logit_PF)
)

parametrization = TBParametrization(logit_PF, logit_PB, logZ)
loss_fn = TrajectoryBalance(
    parametrization=parametrization,
    log_reward_clip_min=-500.0
)

params = [
    {
        "params": [
            val for key, val in parametrization.parameters.items() if "logZ" not in key
        ],
        "lr": 0.001,
    },
#     {"params": [val for key, val in parametrization.parameters.items() if "logZ" in key], "lr": 0.1},
]
optimizer = torch.optim.Adam(params)

  super().__init__(params, defaults)


In [5]:
from tqdm import tqdm

losses = []
terminal_states = []

for i in (pbar := tqdm(range(int(1e3)))):
    trajectories = training_sampler.sample(
        n_trajectories=int(1e3)
    )
    optimizer.zero_grad()
    loss = loss_fn(trajectories)
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        pbar.set_postfix({"loss": loss.item()})
        eval_trajectories = eval_sampler.sample(
            n_trajectories=int(1e3)
        )
        terminal_states.append(
            eval_trajectories.last_states.states_tensor
        )
       
    losses.append(loss.item())

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:27<00:00,  6.78it/s, loss=0.974]
