In [1]:
import json
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange

import argparse

from env import Box, get_last_states
from model import CirclePF, CirclePB, NeuralNet
from sampling import (
    sample_trajectories,
    evaluate_backward_logprobs,
)

from utils import (
    fit_kde,
    plot_reward,
    sample_from_reward,
    plot_samples,
    estimate_jsd,
    plot_trajectories,
)

import config

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default=config.DEVICE)
parser.add_argument("--dim", type=int, default=config.DIM)
parser.add_argument("--delta", type=float, default=config.DELTA)
parser.add_argument(
    "--n_components",
    type=int,
    default=config.N_COMPONENTS,
    help="Number of components in Mixture Of Betas",
)

parser.add_argument("--reward_debug", action="store_true", default=config.REWARD_DEBUG)
parser.add_argument(
    "--reward_type",
    type=str,
    choices=["baseline", "ring", "angular_ring", "multi_ring", "curve", "gaussian_mixture", "corner_squares", "two_corners", "edge_boxes", "edge_boxes_corner_squares"],
    default=config.REWARD_TYPE,
    help="Type of reward function to use. To modify reward-specific parameters (radius, sigma, etc.), edit rewards.py"
)
parser.add_argument("--R0", type=float, default=config.R0, help="Baseline reward value")
parser.add_argument("--R1", type=float, default=config.R1, help="Medium reward value (e.g., outer square)")
parser.add_argument("--R2", type=float, default=config.R2, help="High reward value (e.g., inner square)")
parser.add_argument(
    "--n_components_s0",
    type=int,
    default=config.N_COMPONENTS_S0,
    help="Number of components in Mixture Of Betas",
)
parser.add_argument(
    "--beta_min",
    type=float,
    default=config.BETA_MIN,
    help="Minimum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--beta_max",
    type=float,
    default=config.BETA_MAX,
    help="Maximum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--PB",
    type=str,
    choices=["learnable", "tied", "uniform"],
    default=config.PB,
)
parser.add_argument("--gamma_scheduler", type=float, default=config.GAMMA_SCHEDULER)
parser.add_argument("--scheduler_milestone", type=int, default=config.SCHEDULER_MILESTONE)
parser.add_argument("--seed", type=int, default=config.SEED)
parser.add_argument("--lr", type=float, default=config.LR)
parser.add_argument("--lr_Z", type=float, default=config.LR_Z)
parser.add_argument("--lr_F", type=float, default=config.LR_F)
parser.add_argument("--tie_F", action="store_true", default=config.TIE_F)
parser.add_argument("--BS", type=int, default=config.BS)
parser.add_argument("--n_iterations", type=int, default=config.N_ITERATIONS)
parser.add_argument("--hidden_dim", type=int, default=config.HIDDEN_DIM)
parser.add_argument("--n_hidden", type=int, default=config.N_HIDDEN)
parser.add_argument("--n_evaluation_trajectories", type=int, default=config.N_EVALUATION_TRAJECTORIES)
parser.add_argument("--no_plot", action="store_true", default=config.NO_PLOT)
parser.add_argument("--no_wandb", action="store_true", default=config.NO_WANDB)
parser.add_argument("--wandb_project", type=str, default=config.WANDB_PROJECT)

# Use parse_args([]) in Jupyter to avoid conflicts with Jupyter's kernel arguments
args = parser.parse_args([])


In [3]:
device = args.device
dim = args.dim
delta = args.delta
seed = args.seed
lr = args.lr
lr_Z = args.lr_Z
lr_F = args.lr_F
n_iterations = args.n_iterations
BS = args.BS
n_components = args.n_components
n_components_s0 = args.n_components_s0

torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")

env = Box(
    dim=dim,
    delta=delta,
    device_str=device,
    reward_type=args.reward_type,
    reward_debug=args.reward_debug,
    R0=args.R0,
    R1=args.R1,
    R2=args.R2,
)

# Get the true KDE
samples = sample_from_reward(env, n_samples=10000)
true_kde, fig1 = fit_kde(samples, plot=True)


model = CirclePF(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    n_components=n_components,
    n_components_s0=n_components_s0,
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)

bw_model = CirclePB(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    torso=model.torso if args.PB == "tied" else None,
    uniform=args.PB == "uniform",
    n_components=n_components,
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)


logZ = torch.zeros(1, requires_grad=True, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
if args.PB != "uniform":
    optimizer.add_param_group(
        {
            "params": bw_model.output_layer.parameters()
            if args.PB == "tied"
            else bw_model.parameters(),
            "lr": lr,
        }
    )
optimizer.add_param_group({"params": [logZ], "lr": lr_Z})


scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[i * args.scheduler_milestone for i in range(1, 10)],
    gamma=args.gamma_scheduler,
)

jsd = float("inf")

Using device: cuda:3


In [5]:

for i in trange(n_iterations):
    optimizer.zero_grad()
    trajectories, actionss, logprobs, all_logprobs = sample_trajectories(
        env,
        model,
        BS,
    )


    last_states = get_last_states(env, trajectories)
    logrewards = env.reward(last_states).log()
    bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
        env, bw_model, trajectories
    )
    print(trajectories.shape)
    print(actionss.shape)
    print(logprobs.shape)
    print(all_logprobs.shape)
    break
    # TB (Trajectory Balance) loss
    loss = torch.mean((logZ + logprobs - bw_logprobs - logrewards) ** 2)

    if torch.isinf(loss):
        raise ValueError("Infinite loss")
    loss.backward()
    # clip the gradients for bw_model
    for p in bw_model.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
    optimizer.step()
    scheduler.step()

    if any(
        [
            torch.isnan(list(model.parameters())[i]).any()
            for i in range(len(list(model.parameters())))
        ]
    ):
        raise ValueError("NaN in model parameters")

    if i % 100 == 0:
        log_dict = {
            "loss": loss.item(),
            "sqrt(logZdiff**2)": np.sqrt((np.log(env.Z) - logZ.item())**2),
            "states_visited": (i + 1) * BS,
        }

        # Evaluate JSD every 500 iterations and add to the same log
        if i % 500 == 0:
            trajectories, _, _, _ = sample_trajectories(
                env, model, args.n_evaluation_trajectories
            )
            last_states = get_last_states(env, trajectories)
            kde, fig4 = fit_kde(last_states, plot=True)
            jsd = estimate_jsd(kde, true_kde)

            log_dict["JSD"] = jsd

            if not NO_PLOT:
                colors = plt.cm.rainbow(np.linspace(0, 1, 10))
                fig1 = plot_samples(last_states[:2000].detach().cpu().numpy())
                fig2 = plot_trajectories(trajectories.detach().cpu().numpy()[:20])

                log_dict["last_states"] = wandb.Image(fig1)
                log_dict["trajectories"] = wandb.Image(fig2)
                log_dict["kde"] = wandb.Image(fig4)

        if USE_WANDB:
            wandb.log(log_dict, step=i)

        tqdm.write(
            # Loss with 3 digits of precision, logZ with 2 digits of precision, true logZ with 2 digits of precision
            # Last computed JSD with 4 digits of precision
            f"States: {(i + 1) * BS}, Loss: {loss.item():.3f}, logZ: {logZ.item():.2f}, true logZ: {np.log(env.Z):.2f}, JSD: {jsd:.4f}"
        )


# if USE_WANDB:
#     wandb.finish()

# # Save model and arguments as JSON
# save_path = os.path.join("saved_models", run_name)
# if not os.path.exists(save_path):
#     os.makedirs(save_path)
#     torch.save(model.state_dict(), os.path.join(save_path, "model.pt"))
#     torch.save(bw_model.state_dict(), os.path.join(save_path, "bw_model.pt"))
#     torch.save(logZ, os.path.join(save_path, "logZ.pt"))
#     with open(os.path.join(save_path, "args.json"), "w") as f:
#         json.dump(vars(args), f)


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

torch.Size([256, 8, 2])
torch.Size([256, 7, 2])
torch.Size([256])
torch.Size([256, 7])





In [210]:

def trajectories_to_transitions(trajectories, actionss, all_bw_logprobs, logrewards, env):
    """
    Convert trajectories to transitions for replay buffer.

    Args:
        trajectories: tensor of shape (batch_size, trajectory_length, dim)
        actionss: tensor of shape (batch_size, trajectory_length, dim)
        all_bw_logprobs: tensor of shape (batch_size, trajectory_length)
        last_states: tensor of shape (batch_size, dim)
        logrewards: tensor of shape (batch_size,)
        env: environment object

    Returns:
        Tuple of (states, actions, rewards, next_states, dones) as tensors
    """

    # Extract states and next_states for intermediate transitions
    # Match the length to all_bw_logprobs

    states = trajectories[:, :-1, :]  
    next_states = trajectories[:, 1:, :] 
    is_not_sink = torch.all(states != env.sink_state, dim=-1)
    is_next_sink = torch.all(next_states == env.sink_state, dim=-1)
    last_state = is_not_sink & is_next_sink
    dones = torch.zeros_like(last_state, dtype=torch.float32)  # (batch_size, bw_length)
    dones[last_state] = 1.0
    dones = dones[:, 1:]
    rewards = all_bw_logprobs
    rewards = torch.where(last_state[:,1:], rewards + logrewards.unsqueeze(1), rewards)
    states = states[:, :-1, :]  
    next_states = next_states[:, :-1, :] 
    actions = actionss[:, :-1, :] 
    # Check which rewards are valid (not inf/nan)
    is_valid = torch.isfinite(rewards)  # (batch_size, bw_length)

    # Flatten batch and time dimensions for transitions
    states_flat = states[is_valid]
    actions_flat = actions[is_valid]
    rewards_flat = rewards[is_valid]
    next_states_flat = next_states[is_valid]
    dones_flat = dones[is_valid]

    return states_flat, actions_flat, rewards_flat, next_states_flat, dones_flat



In [211]:

for i in trange(n_iterations):
    optimizer.zero_grad()
    trajectories, actionss, logprobs, all_logprobs = sample_trajectories(
        env,
        model,
        BS,
    )

    last_states = get_last_states(env, trajectories)
    logrewards = env.reward(last_states).log()
    bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
        env, bw_model, trajectories
    )
    all_states, all_actions, all_rewards, all_next_states, all_dones = trajectories_to_transitions(
        trajectories, actionss, all_bw_logprobs, logrewards, env
    )


    for i in range(15):
        print(f"==={i}===")
        print(all_states[i])
        print(all_actions[i])
        print(torch.exp(all_rewards[i]))
        print(all_next_states[i])
        print(all_dones[i])
    break

  0%|          | 0/20000 [00:00<?, ?it/s]

===0===
tensor([0., 0.], device='cuda:0')
tensor([0.0515, 0.0231], device='cuda:0')
tensor(1., device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.0515, 0.0231], device='cuda:0')
tensor(0., device='cuda:0')
===1===
tensor([0.0515, 0.0231], device='cuda:0')
tensor([0.1983, 0.1522], device='cuda:0')
tensor(3.4034, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.2498, 0.1753], device='cuda:0')
tensor(0., device='cuda:0')
===2===
tensor([0.2498, 0.1753], device='cuda:0')
tensor([0.1348, 0.2105], device='cuda:0')
tensor(3.8133, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.3846, 0.3859], device='cuda:0')
tensor(0., device='cuda:0')
===3===
tensor([0.3846, 0.3859], device='cuda:0')
tensor([0.1413, 0.2063], device='cuda:0')
tensor(3.9550, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.5259, 0.5921], device='cuda:0')
tensor(0., device='cuda:0')
===4===
tensor([0.5259, 0.5921], device='cuda:0')
tensor([0.2142, 0.1288], device='cuda:0')
tensor(3.8797, device='cuda:0', grad_fn=<ExpBack




In [216]:
from torch.distributions import Beta


class CirclePF_Uniform():
    def __init__(self):
        pass
    
    def to_dist(self, x):
        if torch.all(x[0] == 0.0):
            assert torch.all(
                x == 0.0
            )  # If one of the states is s0, all of them must be
            alpha = torch.ones(x.shape[0], device=x.device)
            beta = torch.ones(x.shape[0], device=x.device)
            
            return Beta(alpha, beta), Beta(alpha, beta)
        
        alpha = torch.ones(x.shape[0], device=x.device)
        beta = torch.ones(x.shape[0], device=x.device)
        return Beta(alpha, beta)

In [12]:
states_ex = torch.rand(10, 2, device=env.device)
states_ex_0 = torch.zeros(10, 2, device=env.device)

In [19]:
ations, logprobs, samples = sample_actions(env, model, states_ex)
ations_0, logprobs_0, samples_r_0, samples_theta_0 = sample_actions(env, model, states_ex_0)
samples_0 = torch.cat([samples_r_0.unsqueeze(-1), samples_theta_0.unsqueeze(-1)], dim=-1)



In [20]:
print(ations.shape)
print(samples.shape)
print(samples_r_0.shape)
print(samples_theta_0.shape)
print(samples_0.shape)

torch.Size([10, 2])
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10, 2])


In [35]:
def sample_actions(env, model, states):
    # states is a tensor of shape (n, dim)
    batch_size = states.shape[0]
    out = model.to_dist(states)
    if isinstance(out, tuple):  # s0 input returns (dist_r, dist_theta)
        dist_r, dist_theta = out
        if model.uniform_ratio:
            samples_r = torch.rand(batch_size, device=env.device)
            samples_theta = torch.rand(batch_size, device=env.device)
        else:
            samples_r = dist_r.sample(torch.Size((batch_size,)))
            samples_theta = dist_theta.sample(torch.Size((batch_size,)))

        actions = (
            torch.stack(
                [
                    samples_r * torch.cos(torch.pi / 2.0 * samples_theta),
                    samples_r * torch.sin(torch.pi / 2.0 * samples_theta),
                ],
                dim=1,
            )
            * env.delta
        )
        logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)
            - np.log(np.pi / 2)
            - np.log(env.delta)  # why ?
        )
        return actions, logprobs, samples_r, samples_theta
    else:
        dist = out

        # Automatic termination: check if min_i(1 - state_i) <= env.delta
        # This means at least one dimension is within delta of the boundary
        should_terminate = torch.any(states >= 1 - env.delta, dim=-1)

        A = torch.where(
            states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - states[:, 0]) / env.delta),
        )
        B = torch.where(
            states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - states[:, 1]) / env.delta),
        )
        assert torch.all(
            B[~should_terminate] >= A[~should_terminate]
        )
        if model.uniform_ratio:
            samples = torch.rand(batch_size, device=env.device)
        else:
            samples = dist.sample()

        actions = samples * (B - A) + A
        actions *= torch.pi / 2.0
        actions = (
            torch.stack([torch.cos(actions), torch.sin(actions)], dim=1) * env.delta
        )

        logprobs = (
            dist.log_prob(samples)
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )

        # Set terminal actions and zero logprobs for terminated states
        actions[should_terminate] = -float("inf")
        logprobs[should_terminate] = 0.0
        samples[should_terminate] = -float("inf")
    return actions, logprobs, samples


def sample_trajectories(env, model, n_trajectories):
    step = 0
    states = torch.zeros((n_trajectories, env.dim), device=env.device)
    actionss = []
    sampless = []
    trajectories = [states]
    trajectories_logprobs = torch.zeros((n_trajectories,), device=env.device)
    all_logprobs = []
    first = True
    while not torch.all(states == env.sink_state):
        step_logprobs = torch.full((n_trajectories,), -float("inf"), device=env.device)
        non_terminal_mask = torch.all(states != env.sink_state, dim=-1)
        actions = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        samples = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        if first:
            first = False
            non_terminal_actions, logprobs, non_terminal_samples_r, non_terminal_samples_theta = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples_r.unsqueeze(-1), non_terminal_samples_theta.unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        else:
            non_terminal_actions, logprobs, non_terminal_samples = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples.unsqueeze(-1), torch.zeros_like(non_terminal_samples).unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        
        actions[non_terminal_mask] = non_terminal_actions.reshape(-1, env.dim)
        actionss.append(actions)
        sampless.append(samples)
        states = env.step(states, actions)
        trajectories.append(states)
        trajectories_logprobs[non_terminal_mask] += logprobs
        step_logprobs[non_terminal_mask] = logprobs
        all_logprobs.append(step_logprobs)
        step += 1
    trajectories = torch.stack(trajectories, dim=1)
    actionss = torch.stack(actionss, dim=1)
    sampless = torch.stack(sampless, dim=1)
    all_logprobs = torch.stack(all_logprobs, dim=1)
    return trajectories, actionss, trajectories_logprobs, all_logprobs, sampless


In [36]:
t, a, l, al, s = sample_trajectories(env, model, 10)

In [56]:
al[:,0]

tensor([3.1494, 5.2854, 3.1336, 2.7636, 2.9630, 2.5759, 3.5160, 3.1856, 3.0356,
        5.1557], device='cuda:3', grad_fn=<SelectBackward0>)

In [60]:
t[:, 1, :]

tensor([[0.0084, 0.1035],
        [0.0111, 0.0038],
        [0.0680, 0.1033],
        [0.1169, 0.1327],
        [0.1309, 0.0603],
        [0.0670, 0.1886],
        [0.0757, 0.0160],
        [0.1070, 0.0231],
        [0.0984, 0.0957],
        [0.0084, 0.0113]], device='cuda:3')

In [54]:
s[:, 0, :]

tensor([[0.4153, 0.9487],
        [0.0468, 0.2099],
        [0.4949, 0.6293],
        [0.7075, 0.5403],
        [0.5764, 0.2750],
        [0.8008, 0.7827],
        [0.3097, 0.1327],
        [0.4378, 0.1351],
        [0.5491, 0.4911],
        [0.0563, 0.5927]], device='cuda:3')

In [194]:
def evaluate_forward_step_logprobs(env, model, current_states, samples):
    if torch.all(current_states[0] == 0.0):
        dist_r, dist_theta = model.to_dist(current_states)
        samples_r = samples[:, 0]
        samples_theta = samples[:, 1]

        step_logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)
            - np.log(np.pi / 2)
            - np.log(env.delta)  # why ?
            )
        
    else:
        step_logprobs = torch.zeros((current_states.shape[0],), device=env.device)
        should_terminate = torch.any(samples == -float("inf"), dim=-1)
        if current_states.shape[0] == should_terminate.sum():
            all_terminate = torch.all(samples == -float("inf"), dim=-1)
            step_logprobs[all_terminate] = -float("inf")
            return step_logprobs
        non_terminal_states = current_states[~should_terminate]
        non_terminal_samples = samples[~should_terminate]
        dist = model.to_dist(non_terminal_states)

        A = torch.where(
            non_terminal_states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - non_terminal_states[:, 0]) / env.delta),
        )
        B = torch.where(
            non_terminal_states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - non_terminal_states[:, 1]) / env.delta),
        )

        non_terminal_step_logprobs = (
            dist.log_prob(non_terminal_samples[:,0])
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )
        step_logprobs[~should_terminate] = non_terminal_step_logprobs
        all_terminate = torch.all(samples == -float("inf"), dim=-1)
        step_logprobs[all_terminate] = -float("inf")

    return step_logprobs

def evaluate_forward_logprobs(env, model, trajectories, sampless):
    logprobs = torch.zeros((trajectories.shape[0],), device=env.device)
    all_logprobs = []
    for i in range(trajectories.shape[1] - 1):
        current_states = trajectories[:, i]
        samples = sampless[:, i]
        step_logprobs = evaluate_forward_step_logprobs(env, model, current_states, samples)
        is_finite = torch.isfinite(step_logprobs)
        logprobs[is_finite] += step_logprobs[is_finite]
        all_logprobs.append(step_logprobs)
    all_logprobs = torch.stack(all_logprobs, dim=1)
    return logprobs, all_logprobs


In [202]:
t[3,:,:]

tensor([[0.0000, 0.0000],
        [0.1169, 0.1327],
        [0.1721, 0.3766],
        [0.4055, 0.4662],
        [0.5486, 0.6712],
        [0.7860, 0.7494],
        [  -inf,   -inf],
        [  -inf,   -inf]], device='cuda:3')

In [195]:
evaluate_forward_logprobs(env, model, t, s)

(tensor([5.9741, 9.7854, 6.0274, 6.5951, 6.7972, 6.2694, 7.1664, 7.7247, 6.6555,
         8.8568], device='cuda:3', grad_fn=<IndexPutBackward0>),
 tensor([[3.1494, 0.9376, 0.8984, 0.9887, 0.0000,   -inf,   -inf],
         [5.2854, 0.9943, 0.7742, 0.9605, 0.9792, 0.7918, 0.0000],
         [3.1336, 0.9710, 0.9796, 0.9432, 0.0000,   -inf,   -inf],
         [2.7636, 0.9170, 0.9648, 0.9814, 0.9683, 0.0000,   -inf],
         [2.9630, 0.9706, 0.9903, 0.9809, 0.8925, 0.0000,   -inf],
         [2.5759, 0.9134, 0.9355, 0.9877, 0.8569, 0.0000,   -inf],
         [3.5160, 0.8519, 0.8581, 0.9440, 0.9964, 0.0000,   -inf],
         [3.1856, 0.7344, 0.9747, 0.9043, 0.9582, 0.9676, 0.0000],
         [3.0356, 0.8622, 0.9374, 0.8929, 0.9274, 0.0000,   -inf],
         [5.1557, 0.9358, 0.8403, 0.9957, 0.9292, 0.0000,   -inf]],
        device='cuda:3', grad_fn=<StackBackward0>))

In [181]:
t[:,4]

tensor([[0.3078, 0.7561],
        [0.4252, 0.5238],
        [0.7631, 0.3748],
        [0.5486, 0.6712],
        [0.5817, 0.6379],
        [0.5750, 0.6237],
        [0.6584, 0.3025],
        [0.4808, 0.5300],
        [0.6701, 0.3931],
        [0.7008, 0.2464]], device='cuda:3')

In [182]:
s[:,4]

tensor([[  -inf, 0.0000],
        [0.6205, 0.0000],
        [  -inf, 0.0000],
        [0.2024, 0.0000],
        [0.0626, 0.0000],
        [0.9027, 0.0000],
        [0.4285, 0.0000],
        [0.1791, 0.0000],
        [0.1079, 0.0000],
        [0.1143, 0.0000]], device='cuda:3')

In [183]:
al[:,4]

tensor([0.0000, 0.9792, 0.0000, 0.9683, 0.8925, 0.8569, 0.9964, 0.9582, 0.9274,
        0.9292], device='cuda:3', grad_fn=<SelectBackward0>)

In [184]:
s[:,6,:]

tensor([[-inf, -inf],
        [-inf, 0.],
        [-inf, -inf],
        [-inf, -inf],
        [-inf, -inf],
        [-inf, -inf],
        [-inf, -inf],
        [-inf, 0.],
        [-inf, -inf],
        [-inf, -inf]], device='cuda:3')

In [185]:
al[:,6]

tensor([-inf, 0., -inf, -inf, -inf, -inf, -inf, 0., -inf, -inf],
       device='cuda:3', grad_fn=<SelectBackward0>)

In [186]:
i = 6
print(evaluate_forward_step_logprobs(env, model, t[:,i, :], s[:,i,:]))
print(al[:,i])

tensor([-inf, 0., -inf, -inf, -inf, -inf, -inf, 0., -inf, -inf],
       device='cuda:3')
tensor([-inf, 0., -inf, -inf, -inf, -inf, -inf, 0., -inf, -inf],
       device='cuda:3', grad_fn=<SelectBackward0>)


In [208]:
t.shape

torch.Size([10, 8, 2])

In [209]:
s.shape

torch.Size([10, 7, 2])

In [210]:
import torch
import numpy as np
import random


class TrajectoryReplayMemory:
    """
    Replay buffer for trajectories with variable lengths.
    
    Stores:
    - trajectories: shape (capacity, max_traj_len, 2)
    - samples: shape (capacity, max_traj_len - 1, 2)
    
    Uses lazy initialization like SAC replay memory - buffers are allocated on first push.
    Automatically handles varying max_traj_len across batches by tracking and updating.
    """
    
    def __init__(self, capacity, seed, device='cpu'):
        """
        Args:
            capacity: Maximum number of trajectories to store
            seed: Random seed for reproducibility
            device: Device to store tensors on
        """
        random.seed(seed)
        np.random.seed(seed)
        self.capacity = capacity
        self.device = device
        self.position = 0
        self.size = 0
        
        # Will be initialized on first push
        self.trajectories = None
        self.samples = None
        self.max_traj_len = 0
        
    def push_batch(self, trajectories, samples):
        """
        Push a batch of trajectories to the replay buffer.
        
        Args:
            trajectories: tensor of shape (batch_size, traj_len, 2)
            samples: tensor of shape (batch_size, traj_len - 1, 2)
        """
        batch_size = trajectories.shape[0]
        traj_len = trajectories.shape[1]
        
        # Move to device if needed
        trajectories = trajectories.to(self.device)
        samples = samples.to(self.device)
        
        # Initialize or resize buffers if needed
        if self.trajectories is None:
            # First push - initialize buffers
            self.max_traj_len = traj_len
            self.trajectories = torch.full(
                (self.capacity, traj_len, 2), 
                -float('inf'), 
                dtype=trajectories.dtype, 
                device=self.device
            )
            self.samples = torch.full(
                (self.capacity, traj_len - 1, 2), 
                -float('inf'), 
                dtype=samples.dtype, 
                device=self.device
            )
        elif traj_len > self.max_traj_len:
            # Need to resize buffers to accommodate longer trajectories
            old_max_len = self.max_traj_len
            self.max_traj_len = traj_len
            
            # Create new larger buffers
            new_trajectories = torch.full(
                (self.capacity, traj_len, 2), 
                -float('inf'), 
                dtype=self.trajectories.dtype, 
                device=self.device
            )
            new_samples = torch.full(
                (self.capacity, traj_len - 1, 2), 
                -float('inf'), 
                dtype=self.samples.dtype, 
                device=self.device
            )
            
            # Copy old data to new buffers
            new_trajectories[:, :old_max_len, :] = self.trajectories
            new_samples[:, :old_max_len - 1, :] = self.samples
            
            self.trajectories = new_trajectories
            self.samples = new_samples
        
        # Calculate indices
        end_pos = self.position + batch_size
        
        if end_pos <= self.capacity:
            # No wrap around
            # Reset to -inf first (to handle varying lengths)
            self.trajectories[self.position:end_pos] = -float('inf')
            self.samples[self.position:end_pos] = -float('inf')
            
            # Write actual data
            self.trajectories[self.position:end_pos, :traj_len, :] = trajectories
            self.samples[self.position:end_pos, :traj_len - 1, :] = samples
        else:
            # Wrap around
            first_part = self.capacity - self.position
            
            # First part
            self.trajectories[self.position:] = -float('inf')
            self.samples[self.position:] = -float('inf')
            self.trajectories[self.position:, :traj_len, :] = trajectories[:first_part]
            self.samples[self.position:, :traj_len - 1, :] = samples[:first_part]
            
            # Second part
            second_part = batch_size - first_part
            self.trajectories[:second_part] = -float('inf')
            self.samples[:second_part] = -float('inf')
            self.trajectories[:second_part, :traj_len, :] = trajectories[first_part:]
            self.samples[:second_part, :traj_len - 1, :] = samples[first_part:]
        
        self.position = end_pos % self.capacity
        self.size = min(self.size + batch_size, self.capacity)
    
    def sample(self, batch_size):
        """
        Sample a random batch from the buffer.
        
        Returns:
            trajectories: tensor of shape (batch_size, max_traj_len, 2)
            samples: tensor of shape (batch_size, max_traj_len - 1, 2)
        """
        indices = np.random.choice(self.size, batch_size, replace=False)
        indices = torch.from_numpy(indices).to(self.device)
        
        return (
            self.trajectories[indices],
            self.samples[indices],
        )
    
    def __len__(self):
        return self.size



In [219]:
mem=TrajectoryReplayMemory(5, 1, device)

In [220]:
mem.push_batch(t,s)

In [227]:
def sample_actions(env, model, states):
    # states is a tensor of shape (n, dim)
    batch_size = states.shape[0]
    out = model.to_dist(states)
    if isinstance(out, tuple):  # s0 input returns (dist_r, dist_theta)
        dist_r, dist_theta = out
        if model.uniform_ratio:
            samples_r = torch.rand(batch_size, device=env.device)
            samples_theta = torch.rand(batch_size, device=env.device)
        else:
            samples_r = dist_r.sample(torch.Size((batch_size,)))
            samples_theta = dist_theta.sample(torch.Size((batch_size,)))

        actions = (
            torch.stack(
                [
                    samples_r * torch.cos(torch.pi / 2.0 * samples_theta),
                    samples_r * torch.sin(torch.pi / 2.0 * samples_theta),
                ],
                dim=1,
            )
            * env.delta
        )

        return actions, samples_r, samples_theta
    else:
        dist = out

        # Automatic termination: check if min_i(1 - state_i) <= env.delta
        # This means at least one dimension is within delta of the boundary
        should_terminate = torch.any(states >= 1 - env.delta, dim=-1)

        A = torch.where(
            states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - states[:, 0]) / env.delta),
        )
        B = torch.where(
            states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - states[:, 1]) / env.delta),
        )
        assert torch.all(
            B[~should_terminate] >= A[~should_terminate]
        )
        if model.uniform_ratio:
            samples = torch.rand(batch_size, device=env.device)
        else:
            samples = dist.sample()

        actions = samples * (B - A) + A
        actions *= torch.pi / 2.0
        actions = (
            torch.stack([torch.cos(actions), torch.sin(actions)], dim=1) * env.delta
        )

        # Set terminal actions and zero logprobs for terminated states
        actions[should_terminate] = -float("inf")
        samples[should_terminate] = -float("inf")
    return actions, samples


def sample_trajectories(env, model, n_trajectories):
    states = torch.zeros((n_trajectories, env.dim), device=env.device)
    actionss = []
    sampless = []
    trajectories = [states]
    first = True
    while not torch.all(states == env.sink_state):
        non_terminal_mask = torch.all(states != env.sink_state, dim=-1)
        actions = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        samples = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        if first:
            first = False
            non_terminal_actions, non_terminal_samples_r, non_terminal_samples_theta = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples_r.unsqueeze(-1), non_terminal_samples_theta.unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        else:
            non_terminal_actions, non_terminal_samples = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples.unsqueeze(-1), torch.zeros_like(non_terminal_samples).unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        
        actions[non_terminal_mask] = non_terminal_actions.reshape(-1, env.dim)
        actionss.append(actions)
        sampless.append(samples)
        states = env.step(states, actions)
        trajectories.append(states)
    trajectories = torch.stack(trajectories, dim=1)
    actionss = torch.stack(actionss, dim=1)
    sampless = torch.stack(sampless, dim=1)
    return trajectories, actionss, sampless


def evaluate_forward_step_logprobs(env, model, current_states, samples):
    if torch.all(current_states[0] == 0.0):
        dist_r, dist_theta = model.to_dist(current_states)
        samples_r = samples[:, 0]
        samples_theta = samples[:, 1]

        step_logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)
            - np.log(np.pi / 2)
            - np.log(env.delta)  # why ?
            )
        
    else:
        step_logprobs = torch.zeros((current_states.shape[0],), device=env.device)
        should_terminate = torch.any(samples == -float("inf"), dim=-1)
        if current_states.shape[0] == should_terminate.sum():
            all_terminate = torch.all(samples == -float("inf"), dim=-1)
            step_logprobs[all_terminate] = -float("inf")
            return step_logprobs
        non_terminal_states = current_states[~should_terminate]
        non_terminal_samples = samples[~should_terminate]
        dist = model.to_dist(non_terminal_states)

        A = torch.where(
            non_terminal_states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - non_terminal_states[:, 0]) / env.delta),
        )
        B = torch.where(
            non_terminal_states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - non_terminal_states[:, 1]) / env.delta),
        )

        non_terminal_step_logprobs = (
            dist.log_prob(non_terminal_samples[:,0])
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )
        step_logprobs[~should_terminate] = non_terminal_step_logprobs
        all_terminate = torch.all(samples == -float("inf"), dim=-1)
        step_logprobs[all_terminate] = -float("inf")

    return step_logprobs

def evaluate_forward_logprobs(env, model, trajectories, sampless):
    logprobs = torch.zeros((trajectories.shape[0],), device=env.device)
    all_logprobs = []
    for i in range(trajectories.shape[1] - 1):
        current_states = trajectories[:, i]
        samples = sampless[:, i]
        step_logprobs = evaluate_forward_step_logprobs(env, model, current_states, samples)
        is_finite = torch.isfinite(step_logprobs)
        logprobs[is_finite] += step_logprobs[is_finite]
        all_logprobs.append(step_logprobs)
    all_logprobs = torch.stack(all_logprobs, dim=1)

    return logprobs, all_logprobs

(tensor([[[0.0000, 0.0000],
          [0.1070, 0.0231],
          [0.1155, 0.2729],
          [0.2335, 0.4933],
          [0.4808, 0.5300],
          [0.7210, 0.5994],
          [0.9604, 0.6714],
          [  -inf,   -inf]]], device='cuda:3'),
 tensor([[[0.4378, 0.1351],
          [0.9784, 0.0000],
          [0.6869, 0.0000],
          [0.0938, 0.0000],
          [0.1791, 0.0000],
          [0.1860, 0.0000],
          [  -inf, 0.0000]]], device='cuda:3'))

In [249]:
t, a, s= sample_trajectories(env, model, 256)

In [250]:
evaluate_forward_logprobs(env, model, t, s)

(tensor([ 6.3049,  8.0581, 10.6425,  5.4570,  6.0205,  6.9582, 11.1457,  8.3006,
          9.4937,  6.6070,  6.1785,  9.1929,  6.5551,  5.3420,  8.7167,  9.6875,
          5.5324,  8.1981, 10.0643,  5.1203,  6.6051,  6.2696,  6.8125,  7.5544,
          8.0624,  7.0409,  7.3496,  5.7086,  8.5025,  7.6846,  6.8814,  8.3864,
          6.3313,  7.9154,  5.5941,  5.5203,  9.9390,  6.5189,  8.5146,  7.2337,
          7.2535,  5.3556,  6.8767,  6.9977,  7.7717,  6.6833,  6.8756,  5.8708,
          9.7061,  5.9830,  6.7616,  7.4976,  6.8159,  5.6410,  6.7580,  8.9903,
          6.8381,  5.7219,  5.6903,  9.8904,  7.5038,  8.8829,  5.9186,  8.1108,
          4.9215,  6.3091,  8.6169,  7.2356,  6.9763,  6.4249,  7.0409,  6.4821,
          6.8562,  5.6809,  8.6254,  6.5991,  5.2656,  6.2546,  7.0892,  8.0111,
          8.0538,  5.9393,  8.9342,  9.1009,  7.3086,  6.0919,  5.0975,  5.3005,
          6.8856,  6.6196,  7.9170,  7.4525,  6.5360,  8.4903,  6.7787,  8.2739,
          5.9961,  8.2345,  

In [251]:
mem.push_batch(t,s)

In [275]:
tj, sa= mem.sample(2)
tj,sa

(tensor([[[0.0000, 0.0000],
          [0.2091, 0.0612],
          [0.4370, 0.1639],
          [0.6056, 0.3485],
          [0.7357, 0.5620],
          [0.9854, 0.5741],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1055, 0.1287],
          [0.3468, 0.1940],
          [0.5263, 0.3680],
          [0.7670, 0.4353],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]]], device='cuda:3'),
 tensor([[[0.8714, 0.1812],
          [0.2696, 0.0000],
          [0.5289, 0.0000],
          [0.6514, 0.0000],
          [0.0310, 0.0000],
          [  -inf, 0.0000],
          [  -inf,   -inf]],
 
         [[0.6655, 0.5629],
          [0.1682, 0.0000],
          [0.4901, 0.0000],
          [0.1736, 0.0000],
          [  -inf, 0.0000],
          [  -inf,   -inf],
          [  -inf,   -inf]]], device='cuda:3'))

In [280]:
tj[:,-3,:]

tensor([[0.7357, 0.5620],
        [0.7670, 0.4353]], device='cuda:3')

In [281]:
torch.all(tj[:,-3,:] == env.sink_state)

tensor(False, device='cuda:3')

In [282]:
while torch.all(tj[:,-2,:] == env.sink_state):
    tj = tj[:,:-1,:]
    sa = sa[:,:-1,:]

In [283]:
tj, sa

(tensor([[[0.0000, 0.0000],
          [0.2091, 0.0612],
          [0.4370, 0.1639],
          [0.6056, 0.3485],
          [0.7357, 0.5620],
          [0.9854, 0.5741],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1055, 0.1287],
          [0.3468, 0.1940],
          [0.5263, 0.3680],
          [0.7670, 0.4353],
          [  -inf,   -inf],
          [  -inf,   -inf]]], device='cuda:3'),
 tensor([[[0.8714, 0.1812],
          [0.2696, 0.0000],
          [0.5289, 0.0000],
          [0.6514, 0.0000],
          [0.0310, 0.0000],
          [  -inf, 0.0000]],
 
         [[0.6655, 0.5629],
          [0.1682, 0.0000],
          [0.4901, 0.0000],
          [0.1736, 0.0000],
          [  -inf, 0.0000],
          [  -inf,   -inf]]], device='cuda:3'))

In [None]:
logprobs, all_logprobs= evaluate_forward_logprobs(env, model, tj, sa)

(tensor([6.2270, 5.7186], device='cuda:3', grad_fn=<IndexPutBackward0>),
 tensor([[2.4423, 0.9734, 0.9939, 0.9756, 0.8418, 0.0000],
         [2.8312, 0.9373, 0.9946, 0.9555, 0.0000,   -inf]], device='cuda:3',
        grad_fn=<StackBackward0>))

In [289]:
last_states = get_last_states(env, tj)
last_states


tensor([[0.9854, 0.5741],
        [0.7670, 0.4353]], device='cuda:3')

In [290]:
logrewards = env.reward(last_states).log()
logrewards


tensor([-6.9078, -6.9078], device='cuda:3')

In [291]:
bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
    env, bw_model, tj
)


In [292]:
bw_logprobs

tensor([6.0479, 3.4188], device='cuda:3')

In [293]:
all_bw_logprobs

tensor([[0.0000, 1.7219, 0.9347, 0.9347, 2.4566],
        [0.0000, 1.5050, 0.9347, 0.9791,   -inf]], device='cuda:3')

In [294]:
loss = torch.mean((logZ + logprobs - bw_logprobs - logrewards) ** 2)


RuntimeError: The size of tensor a (256) must match the size of tensor b (2) at non-singleton dimension 0