In [20]:
import json
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange

import argparse

from env import Box, get_last_states
from model import CirclePF, CirclePB, NeuralNet
from sampling import (
    sample_trajectories,
    evaluate_backward_logprobs,
)

from utils import (
    fit_kde,
    plot_reward,
    sample_from_reward,
    plot_samples,
    estimate_jsd,
    plot_trajectories,
)

import config

In [21]:
# Command-line style configuration. Every default is read from the project's
# `config` module, so editing config.py changes what this notebook runs with.
parser = argparse.ArgumentParser()

# --- Environment -------------------------------------------------------------
parser.add_argument("--device", type=str, default=config.DEVICE)
parser.add_argument("--dim", type=int, default=config.DIM)
parser.add_argument("--delta", type=float, default=config.DELTA)
parser.add_argument(
    "--n_components",
    type=int,
    default=config.N_COMPONENTS,
    help="Number of components in Mixture Of Betas",
)

# --- Reward ------------------------------------------------------------------
# NOTE: for every `store_true` flag below the default also comes from config.
# If the config value is already True, there is no way to switch the flag off
# through the parser.
parser.add_argument("--reward_debug", action="store_true", default=config.REWARD_DEBUG)
parser.add_argument(
    "--reward_type",
    type=str,
    choices=["baseline", "ring", "angular_ring", "multi_ring", "curve", "gaussian_mixture", "corner_squares", "two_corners", "edge_boxes", "edge_boxes_corner_squares"],
    default=config.REWARD_TYPE,
    help="Type of reward function to use. To modify reward-specific parameters (radius, sigma, etc.), edit rewards.py"
)
parser.add_argument("--R0", type=float, default=config.R0, help="Baseline reward value")
parser.add_argument("--R1", type=float, default=config.R1, help="Medium reward value (e.g., outer square)")
parser.add_argument("--R2", type=float, default=config.R2, help="High reward value (e.g., inner square)")

# --- Policy parameterisation ---------------------------------------------------
# NOTE(review): the help text below is identical to the one for --n_components.
# Presumably this option controls the mixture used at the initial state s0
# (it is passed to CirclePF as `n_components_s0`) — confirm and reword.
parser.add_argument(
    "--n_components_s0",
    type=int,
    default=config.N_COMPONENTS_S0,
    help="Number of components in Mixture Of Betas",
)
parser.add_argument(
    "--beta_min",
    type=float,
    default=config.BETA_MIN,
    help="Minimum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--beta_max",
    type=float,
    default=config.BETA_MAX,
    help="Maximum value for the concentration parameters of the Beta distribution",
)
# How the backward policy is obtained; the setup cell branches on this value
# when building CirclePB and the optimizer parameter groups.
parser.add_argument(
    "--PB",
    type=str,
    choices=["learnable", "tied", "uniform"],
    default=config.PB,
)

# --- Optimisation --------------------------------------------------------------
parser.add_argument("--gamma_scheduler", type=float, default=config.GAMMA_SCHEDULER)
parser.add_argument("--scheduler_milestone", type=int, default=config.SCHEDULER_MILESTONE)
parser.add_argument("--seed", type=int, default=config.SEED)
parser.add_argument("--lr", type=float, default=config.LR)
parser.add_argument("--lr_Z", type=float, default=config.LR_Z)
# NOTE(review): --lr_F and --tie_F are parsed but not consumed by any cell
# visible in this notebook — confirm whether they are still needed.
parser.add_argument("--lr_F", type=float, default=config.LR_F)
parser.add_argument("--tie_F", action="store_true", default=config.TIE_F)
parser.add_argument("--BS", type=int, default=config.BS)
parser.add_argument("--n_iterations", type=int, default=config.N_ITERATIONS)
parser.add_argument("--hidden_dim", type=int, default=config.HIDDEN_DIM)
parser.add_argument("--n_hidden", type=int, default=config.N_HIDDEN)

# --- Evaluation / logging ------------------------------------------------------
parser.add_argument("--n_evaluation_trajectories", type=int, default=config.N_EVALUATION_TRAJECTORIES)
parser.add_argument("--no_plot", action="store_true", default=config.NO_PLOT)
parser.add_argument("--no_wandb", action="store_true", default=config.NO_WANDB)
parser.add_argument("--wandb_project", type=str, default=config.WANDB_PROJECT)

# Use parse_args([]) in Jupyter to avoid conflicts with Jupyter's kernel arguments
# (an empty argv means every option takes its config default).
args = parser.parse_args([])


In [22]:
# Build everything the training cells need: environment, reference KDE,
# forward/backward policies, optimizer, LR scheduler.

# Unpack the most frequently used arguments into cell-global names; later
# cells read these directly.
device = args.device
dim = args.dim
delta = args.delta
seed = args.seed
lr = args.lr
lr_Z = args.lr_Z
lr_F = args.lr_F  # NOTE(review): not referenced by any cell visible in this notebook
n_iterations = args.n_iterations
BS = args.BS
n_components = args.n_components
n_components_s0 = args.n_components_s0

# Seed both RNGs before anything is sampled or initialised below.
torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")

env = Box(
    dim=dim,
    delta=delta,
    device_str=device,
    reward_type=args.reward_type,
    reward_debug=args.reward_debug,
    R0=args.R0,
    R1=args.R1,
    R2=args.R2,
)

# Get the true KDE
# (fitted on samples drawn from the reward; the training loop compares the
# learned sampler's KDE against it via estimate_jsd).
samples = sample_from_reward(env, n_samples=10000)
true_kde, fig1 = fit_kde(samples, plot=True)


# Forward policy.
model = CirclePF(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    n_components=n_components,
    n_components_s0=n_components_s0,
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)

# Backward policy. With PB == "tied" it is handed the forward model's torso;
# with PB == "uniform" it is constructed with uniform=True.
bw_model = CirclePB(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    torso=model.torso if args.PB == "tied" else None,
    uniform=args.PB == "uniform",
    n_components=n_components,
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)


# Learnable scalar log-partition estimate used by the Trajectory Balance loss.
logZ = torch.zeros(1, requires_grad=True, device=device)
# Parameter group 0: forward policy.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Parameter group 1 (skipped for a uniform backward policy): when tied, only
# bw_model.output_layer is added because the torso is already covered by
# model.parameters(); otherwise all of bw_model is added.
if args.PB != "uniform":
    optimizer.add_param_group(
        {
            "params": bw_model.output_layer.parameters()
            if args.PB == "tied"
            else bw_model.parameters(),
            "lr": lr,
        }
    )
# Last parameter group: logZ, with its own learning rate lr_Z.
optimizer.add_param_group({"params": [logZ], "lr": lr_Z})


# Milestones are scheduler_milestone * 1..9; MultiStepLR scales the learning
# rates by gamma_scheduler each time one is reached.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[i * args.scheduler_milestone for i in range(1, 10)],
    gamma=args.gamma_scheduler,
)

# Last computed JSD; stays at +inf until the training loop evaluates it.
jsd = float("inf")

Using device: cuda:0


In [33]:

# Trajectory Balance (TB) training loop.
#
# While INSPECT_FIRST_BATCH is True the cell samples one batch, prints one
# example trajectory together with its actions and per-step backward
# log-probabilities, and stops (this is what the saved output below shows).
# Set it to False to actually train.
INSPECT_FIRST_BATCH = True

# The loop used to reference bare NO_PLOT / USE_WANDB / wandb names that are
# not defined anywhere in this notebook; derive them from the parsed args.
no_plot = args.no_plot
use_wandb = not args.no_wandb

if use_wandb and not INSPECT_FIRST_BATCH:
    # Imported lazily so the notebook still runs without wandb installed when
    # logging is disabled or when only inspecting a batch.
    import wandb

    if wandb.run is None:
        wandb.init(project=args.wandb_project, config=vars(args))

for i in trange(n_iterations):
    optimizer.zero_grad()
    trajectories, actionss, logprobs, all_logprobs = sample_trajectories(
        env,
        model,
        BS,
    )

    last_states = get_last_states(env, trajectories)
    logrewards = env.reward(last_states).log()
    bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
        env, bw_model, trajectories
    )

    if INSPECT_FIRST_BATCH:
        print(trajectories[1])
        print(actionss[1])
        print(all_bw_logprobs[1])
        break

    # TB (Trajectory Balance) loss
    loss = torch.mean((logZ + logprobs - bw_logprobs - logrewards) ** 2)

    if torch.isinf(loss):
        raise ValueError("Infinite loss")
    loss.backward()

    # Clip gradients element-wise to [-10, 10] and zero out NaNs for both
    # policies (logZ is left unclipped, as before).
    for net in (bw_model, model):
        for p in net.parameters():
            if p.grad is not None:
                p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
    optimizer.step()
    scheduler.step()

    if any(torch.isnan(p).any() for p in model.parameters()):
        raise ValueError("NaN in model parameters")

    if i % 100 == 0:
        true_logZ = np.log(env.Z)
        log_dict = {
            "loss": loss.item(),
            "sqrt(logZdiff**2)": np.sqrt((true_logZ - logZ.item()) ** 2),
            "states_visited": (i + 1) * BS,
        }

        # Evaluate JSD every 500 iterations and add to the same log
        if i % 500 == 0:
            # Separate names so the training batch (and the reference-KDE
            # figure `fig1` from the setup cell) are not overwritten.
            eval_trajectories, _, _, _ = sample_trajectories(
                env, model, args.n_evaluation_trajectories
            )
            eval_last_states = get_last_states(env, eval_trajectories)
            kde, fig_kde = fit_kde(eval_last_states, plot=True)
            jsd = estimate_jsd(kde, true_kde)

            log_dict["JSD"] = jsd

            if not no_plot:
                fig_samples = plot_samples(
                    eval_last_states[:2000].detach().cpu().numpy()
                )
                fig_trajectories = plot_trajectories(
                    eval_trajectories.detach().cpu().numpy()[:20]
                )

                if use_wandb:
                    log_dict["last_states"] = wandb.Image(fig_samples)
                    log_dict["trajectories"] = wandb.Image(fig_trajectories)
                    log_dict["kde"] = wandb.Image(fig_kde)

        if use_wandb:
            wandb.log(log_dict, step=i)

        tqdm.write(
            # Loss with 3 digits of precision, logZ with 2 digits of precision, true logZ with 2 digits of precision
            # Last computed JSD with 4 digits of precision
            f"States: {(i + 1) * BS}, Loss: {loss.item():.3f}, logZ: {logZ.item():.2f}, true logZ: {true_logZ:.2f}, JSD: {jsd:.4f}"
        )


# if USE_WANDB:
#     wandb.finish()

# # Save model and arguments as JSON
# save_path = os.path.join("saved_models", run_name)
# if not os.path.exists(save_path):
#     os.makedirs(save_path)
#     torch.save(model.state_dict(), os.path.join(save_path, "model.pt"))
#     torch.save(bw_model.state_dict(), os.path.join(save_path, "bw_model.pt"))
#     torch.save(logZ, os.path.join(save_path, "logZ.pt"))
#     with open(os.path.join(save_path, "args.json"), "w") as f:
#         json.dump(vars(args), f)


  0%|          | 0/20000 [00:00<?, ?it/s]

  0%|          | 0/20000 [00:00<?, ?it/s]

tensor([[0.0000, 0.0000],
        [0.0547, 0.0772],
        [0.2881, 0.1668],
        [0.4994, 0.3005],
        [0.6544, 0.4965],
        [0.7699, 0.7183],
        [  -inf,   -inf],
        [  -inf,   -inf]], device='cuda:0')
tensor([[0.0547, 0.0772],
        [0.2334, 0.0896],
        [0.2112, 0.1337],
        [0.1551, 0.1961],
        [0.1154, 0.2218],
        [  -inf,   -inf],
        [  -inf,   -inf]], device='cuda:0')
tensor([0.0000, 2.2548, 1.3788, 1.4409, 1.1008,   -inf], device='cuda:0',
       grad_fn=<SelectBackward0>)





In [210]:

def trajectories_to_transitions(trajectories, actionss, all_bw_logprobs, logrewards, env):
    """
    Flatten a batch of trajectories into single-step transitions for a replay buffer.

    Step ``t`` of a trajectory becomes the transition
    ``(trajectories[t], actionss[t], trajectories[t + 1])`` with reward
    ``all_bw_logprobs[t]``. On the step whose next state is the last state
    before the sink padding, the trajectory's log-reward is added to the
    reward and ``done`` is set to 1. Steps whose reward is not finite are
    dropped.

    Args:
        trajectories: tensor of shape (batch_size, K + 2, dim); padded with
            ``env.sink_state`` after termination.
        actionss: tensor of shape (batch_size, K + 1, dim).
        all_bw_logprobs: tensor of shape (batch_size, K).
        logrewards: tensor of shape (batch_size,).
        env: environment object; only ``env.sink_state`` is read.

    Returns:
        Tuple of (states, actions, rewards, next_states, dones) as tensors,
        each with the batch and time dimensions merged into one leading
        dimension that contains only the kept transitions.
    """
    sink = env.sink_state

    # Align every per-step quantity with the K entries of all_bw_logprobs.
    states = trajectories[:, :-2, :]
    next_states = trajectories[:, 1:-1, :]
    actions = actionss[:, :-1, :]

    # A step is terminal when it lands on a real state that is immediately
    # followed by the sink state.
    lands_on_real_state = torch.all(next_states != sink, dim=-1)
    followed_by_sink = torch.all(trajectories[:, 2:, :] == sink, dim=-1)
    is_terminal = lands_on_real_state & followed_by_sink

    dones = is_terminal.to(torch.float32)
    rewards = torch.where(
        is_terminal,
        all_bw_logprobs + logrewards.unsqueeze(1),
        all_bw_logprobs,
    )

    # Padding steps carry non-finite rewards; the boolean mask both filters
    # them out and merges the batch and time dimensions.
    keep = torch.isfinite(rewards)
    return states[keep], actions[keep], rewards[keep], next_states[keep], dones[keep]



In [211]:

# Sample a single batch, flatten it into transitions and print the first few
# so the (state, action, reward, next_state, done) layout can be checked by eye.
# This used to be wrapped in a `trange(n_iterations)` loop that always broke
# on the first pass; a single straight-line run does the same work.
trajectories, actionss, logprobs, all_logprobs = sample_trajectories(
    env,
    model,
    BS,
)

last_states = get_last_states(env, trajectories)
logrewards = env.reward(last_states).log()
bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
    env, bw_model, trajectories
)
all_states, all_actions, all_rewards, all_next_states, all_dones = trajectories_to_transitions(
    trajectories, actionss, all_bw_logprobs, logrewards, env
)

# Show at most 15 transitions; fewer if the batch produced fewer valid ones.
n_preview = min(15, all_states.shape[0])
for k in range(n_preview):
    print(f"==={k}===")
    print(all_states[k])
    print(all_actions[k])
    # Rewards are log-values; exponentiate for readability.
    print(torch.exp(all_rewards[k]))
    print(all_next_states[k])
    print(all_dones[k])

  0%|          | 0/20000 [00:00<?, ?it/s]

===0===
tensor([0., 0.], device='cuda:0')
tensor([0.0515, 0.0231], device='cuda:0')
tensor(1., device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.0515, 0.0231], device='cuda:0')
tensor(0., device='cuda:0')
===1===
tensor([0.0515, 0.0231], device='cuda:0')
tensor([0.1983, 0.1522], device='cuda:0')
tensor(3.4034, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.2498, 0.1753], device='cuda:0')
tensor(0., device='cuda:0')
===2===
tensor([0.2498, 0.1753], device='cuda:0')
tensor([0.1348, 0.2105], device='cuda:0')
tensor(3.8133, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.3846, 0.3859], device='cuda:0')
tensor(0., device='cuda:0')
===3===
tensor([0.3846, 0.3859], device='cuda:0')
tensor([0.1413, 0.2063], device='cuda:0')
tensor(3.9550, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.5259, 0.5921], device='cuda:0')
tensor(0., device='cuda:0')
===4===
tensor([0.5259, 0.5921], device='cuda:0')
tensor([0.2142, 0.1288], device='cuda:0')
tensor(3.8797, device='cuda:0', grad_fn=<ExpBack




In [216]:
from torch.distributions import Beta


class CirclePF_Uniform():
    """Parameter-free policy whose distributions are all Beta(1, 1), i.e. uniform on [0, 1]."""

    def __init__(self):
        pass

    def to_dist(self, x):
        """
        Build the uniform distribution(s) for a batch of states.

        Args:
            x: tensor of states whose first dimension is the batch.

        Returns:
            A pair of ``Beta(1, 1)`` distributions when the batch is at s0
            (the all-zeros state), otherwise a single ``Beta(1, 1)``. Each has
            batch shape ``(x.shape[0],)`` and lives on ``x.device``.
            NOTE(review): the pair presumably mirrors the two-distribution
            return of CirclePF at s0 — confirm against model.py.

        Raises:
            AssertionError: if the first state is s0 but some other entry of
                the batch is non-zero.
        """
        ones = torch.ones(x.shape[0], device=x.device)

        def uniform():
            # Beta with both concentrations equal to one.
            return Beta(ones, ones)

        first_is_s0 = torch.all(x[0] == 0.0)
        if not first_is_s0:
            return uniform()

        # If one of the states is s0, all of them must be
        assert torch.all(x == 0.0)
        return uniform(), uniform()

tensor([0.6858])