In [1]:
import json
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange

import argparse

from env import Box, get_last_states
from model import CirclePF, CirclePB, NeuralNet
from sampling import (
    sample_trajectories,
    evaluate_backward_logprobs,
)

from utils import (
    fit_kde,
    plot_reward,
    sample_from_reward,
    plot_samples,
    estimate_jsd,
    plot_trajectories,
)

import config

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default=config.DEVICE)
parser.add_argument("--dim", type=int, default=config.DIM)
parser.add_argument("--delta", type=float, default=config.DELTA)
parser.add_argument(
    "--n_components",
    type=int,
    default=config.N_COMPONENTS,
    help="Number of components in Mixture Of Betas",
)

parser.add_argument("--reward_debug", action="store_true", default=config.REWARD_DEBUG)
parser.add_argument(
    "--reward_type",
    type=str,
    choices=["baseline", "ring", "angular_ring", "multi_ring", "curve", "gaussian_mixture", "corner_squares", "two_corners", "edge_boxes", "edge_boxes_corner_squares"],
    default=config.REWARD_TYPE,
    help="Type of reward function to use. To modify reward-specific parameters (radius, sigma, etc.), edit rewards.py"
)
parser.add_argument("--R0", type=float, default=config.R0, help="Baseline reward value")
parser.add_argument("--R1", type=float, default=config.R1, help="Medium reward value (e.g., outer square)")
parser.add_argument("--R2", type=float, default=config.R2, help="High reward value (e.g., inner square)")
parser.add_argument(
    "--n_components_s0",
    type=int,
    default=config.N_COMPONENTS_S0,
    help="Number of components in Mixture Of Betas",
)
parser.add_argument(
    "--beta_min",
    type=float,
    default=config.BETA_MIN,
    help="Minimum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--beta_max",
    type=float,
    default=config.BETA_MAX,
    help="Maximum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--PB",
    type=str,
    choices=["learnable", "tied", "uniform"],
    default=config.PB,
)
parser.add_argument("--gamma_scheduler", type=float, default=config.GAMMA_SCHEDULER)
parser.add_argument("--scheduler_milestone", type=int, default=config.SCHEDULER_MILESTONE)
parser.add_argument("--seed", type=int, default=config.SEED)
parser.add_argument("--lr", type=float, default=config.LR)
parser.add_argument("--lr_Z", type=float, default=config.LR_Z)
parser.add_argument("--lr_F", type=float, default=config.LR_F)
parser.add_argument("--tie_F", action="store_true", default=config.TIE_F)
parser.add_argument("--BS", type=int, default=config.BS)
parser.add_argument("--n_iterations", type=int, default=config.N_ITERATIONS)
parser.add_argument("--hidden_dim", type=int, default=config.HIDDEN_DIM)
parser.add_argument("--n_hidden", type=int, default=config.N_HIDDEN)
parser.add_argument("--n_evaluation_trajectories", type=int, default=config.N_EVALUATION_TRAJECTORIES)
parser.add_argument("--no_plot", action="store_true", default=config.NO_PLOT)
parser.add_argument("--no_wandb", action="store_true", default=config.NO_WANDB)
parser.add_argument("--wandb_project", type=str, default=config.WANDB_PROJECT)

# Use parse_args([]) in Jupyter to avoid conflicts with Jupyter's kernel arguments
args = parser.parse_args([])


In [3]:
device = args.device
dim = args.dim
delta = args.delta
seed = args.seed
lr = args.lr
lr_Z = args.lr_Z
lr_F = args.lr_F
n_iterations = args.n_iterations
BS = args.BS
n_components = args.n_components
n_components_s0 = args.n_components_s0

torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")

env = Box(
    dim=dim,
    delta=delta,
    device_str=device,
    reward_type=args.reward_type,
    reward_debug=args.reward_debug,
    R0=args.R0,
    R1=args.R1,
    R2=args.R2,
)

# Get the true KDE
samples = sample_from_reward(env, n_samples=10000)
true_kde, fig1 = fit_kde(samples, plot=True)


model = CirclePF(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    n_components=n_components,
    n_components_s0=n_components_s0,
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)

bw_model = CirclePB(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    torso=model.torso if args.PB == "tied" else None,
    uniform=args.PB == "uniform",
    n_components=n_components,
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)


logZ = torch.zeros(1, requires_grad=True, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
if args.PB != "uniform":
    optimizer.add_param_group(
        {
            "params": bw_model.output_layer.parameters()
            if args.PB == "tied"
            else bw_model.parameters(),
            "lr": lr,
        }
    )
optimizer.add_param_group({"params": [logZ], "lr": lr_Z})


scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[i * args.scheduler_milestone for i in range(1, 10)],
    gamma=args.gamma_scheduler,
)

jsd = float("inf")

Using device: cuda:3


In [5]:

for i in trange(n_iterations):
    optimizer.zero_grad()
    trajectories, actionss, logprobs, all_logprobs = sample_trajectories(
        env,
        model,
        BS,
    )


    last_states = get_last_states(env, trajectories)
    logrewards = env.reward(last_states).log()
    bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
        env, bw_model, trajectories
    )
    print(trajectories.shape)
    print(actionss.shape)
    print(logprobs.shape)
    print(all_logprobs.shape)
    break
    # TB (Trajectory Balance) loss
    loss = torch.mean((logZ + logprobs - bw_logprobs - logrewards) ** 2)

    if torch.isinf(loss):
        raise ValueError("Infinite loss")
    loss.backward()
    # clip the gradients for bw_model
    for p in bw_model.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
    optimizer.step()
    scheduler.step()

    if any(
        [
            torch.isnan(list(model.parameters())[i]).any()
            for i in range(len(list(model.parameters())))
        ]
    ):
        raise ValueError("NaN in model parameters")

    if i % 100 == 0:
        log_dict = {
            "loss": loss.item(),
            "sqrt(logZdiff**2)": np.sqrt((np.log(env.Z) - logZ.item())**2),
            "states_visited": (i + 1) * BS,
        }

        # Evaluate JSD every 500 iterations and add to the same log
        if i % 500 == 0:
            trajectories, _, _, _ = sample_trajectories(
                env, model, args.n_evaluation_trajectories
            )
            last_states = get_last_states(env, trajectories)
            kde, fig4 = fit_kde(last_states, plot=True)
            jsd = estimate_jsd(kde, true_kde)

            log_dict["JSD"] = jsd

            if not NO_PLOT:
                colors = plt.cm.rainbow(np.linspace(0, 1, 10))
                fig1 = plot_samples(last_states[:2000].detach().cpu().numpy())
                fig2 = plot_trajectories(trajectories.detach().cpu().numpy()[:20])

                log_dict["last_states"] = wandb.Image(fig1)
                log_dict["trajectories"] = wandb.Image(fig2)
                log_dict["kde"] = wandb.Image(fig4)

        if USE_WANDB:
            wandb.log(log_dict, step=i)

        tqdm.write(
            # Loss with 3 digits of precision, logZ with 2 digits of precision, true logZ with 2 digits of precision
            # Last computed JSD with 4 digits of precision
            f"States: {(i + 1) * BS}, Loss: {loss.item():.3f}, logZ: {logZ.item():.2f}, true logZ: {np.log(env.Z):.2f}, JSD: {jsd:.4f}"
        )


# if USE_WANDB:
#     wandb.finish()

# # Save model and arguments as JSON
# save_path = os.path.join("saved_models", run_name)
# if not os.path.exists(save_path):
#     os.makedirs(save_path)
#     torch.save(model.state_dict(), os.path.join(save_path, "model.pt"))
#     torch.save(bw_model.state_dict(), os.path.join(save_path, "bw_model.pt"))
#     torch.save(logZ, os.path.join(save_path, "logZ.pt"))
#     with open(os.path.join(save_path, "args.json"), "w") as f:
#         json.dump(vars(args), f)


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

torch.Size([256, 8, 2])
torch.Size([256, 7, 2])
torch.Size([256])
torch.Size([256, 7])





In [210]:

def trajectories_to_transitions(trajectories, actionss, all_bw_logprobs, logrewards, env):
    """
    Convert trajectories to transitions for replay buffer.

    Args:
        trajectories: tensor of shape (batch_size, trajectory_length, dim)
        actionss: tensor of shape (batch_size, trajectory_length, dim)
        all_bw_logprobs: tensor of shape (batch_size, trajectory_length)
        last_states: tensor of shape (batch_size, dim)
        logrewards: tensor of shape (batch_size,)
        env: environment object

    Returns:
        Tuple of (states, actions, rewards, next_states, dones) as tensors
    """

    # Extract states and next_states for intermediate transitions
    # Match the length to all_bw_logprobs

    states = trajectories[:, :-1, :]  
    next_states = trajectories[:, 1:, :] 
    is_not_sink = torch.all(states != env.sink_state, dim=-1)
    is_next_sink = torch.all(next_states == env.sink_state, dim=-1)
    last_state = is_not_sink & is_next_sink
    dones = torch.zeros_like(last_state, dtype=torch.float32)  # (batch_size, bw_length)
    dones[last_state] = 1.0
    dones = dones[:, 1:]
    rewards = all_bw_logprobs
    rewards = torch.where(last_state[:,1:], rewards + logrewards.unsqueeze(1), rewards)
    states = states[:, :-1, :]  
    next_states = next_states[:, :-1, :] 
    actions = actionss[:, :-1, :] 
    # Check which rewards are valid (not inf/nan)
    is_valid = torch.isfinite(rewards)  # (batch_size, bw_length)

    # Flatten batch and time dimensions for transitions
    states_flat = states[is_valid]
    actions_flat = actions[is_valid]
    rewards_flat = rewards[is_valid]
    next_states_flat = next_states[is_valid]
    dones_flat = dones[is_valid]

    return states_flat, actions_flat, rewards_flat, next_states_flat, dones_flat



In [211]:

for i in trange(n_iterations):
    optimizer.zero_grad()
    trajectories, actionss, logprobs, all_logprobs = sample_trajectories(
        env,
        model,
        BS,
    )

    last_states = get_last_states(env, trajectories)
    logrewards = env.reward(last_states).log()
    bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
        env, bw_model, trajectories
    )
    all_states, all_actions, all_rewards, all_next_states, all_dones = trajectories_to_transitions(
        trajectories, actionss, all_bw_logprobs, logrewards, env
    )


    for i in range(15):
        print(f"==={i}===")
        print(all_states[i])
        print(all_actions[i])
        print(torch.exp(all_rewards[i]))
        print(all_next_states[i])
        print(all_dones[i])
    break

  0%|          | 0/20000 [00:00<?, ?it/s]

===0===
tensor([0., 0.], device='cuda:0')
tensor([0.0515, 0.0231], device='cuda:0')
tensor(1., device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.0515, 0.0231], device='cuda:0')
tensor(0., device='cuda:0')
===1===
tensor([0.0515, 0.0231], device='cuda:0')
tensor([0.1983, 0.1522], device='cuda:0')
tensor(3.4034, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.2498, 0.1753], device='cuda:0')
tensor(0., device='cuda:0')
===2===
tensor([0.2498, 0.1753], device='cuda:0')
tensor([0.1348, 0.2105], device='cuda:0')
tensor(3.8133, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.3846, 0.3859], device='cuda:0')
tensor(0., device='cuda:0')
===3===
tensor([0.3846, 0.3859], device='cuda:0')
tensor([0.1413, 0.2063], device='cuda:0')
tensor(3.9550, device='cuda:0', grad_fn=<ExpBackward0>)
tensor([0.5259, 0.5921], device='cuda:0')
tensor(0., device='cuda:0')
===4===
tensor([0.5259, 0.5921], device='cuda:0')
tensor([0.2142, 0.1288], device='cuda:0')
tensor(3.8797, device='cuda:0', grad_fn=<ExpBack




In [216]:
from torch.distributions import Beta


class CirclePF_Uniform():
    def __init__(self):
        pass
    
    def to_dist(self, x):
        if torch.all(x[0] == 0.0):
            assert torch.all(
                x == 0.0
            )  # If one of the states is s0, all of them must be
            alpha = torch.ones(x.shape[0], device=x.device)
            beta = torch.ones(x.shape[0], device=x.device)
            
            return Beta(alpha, beta), Beta(alpha, beta)
        
        alpha = torch.ones(x.shape[0], device=x.device)
        beta = torch.ones(x.shape[0], device=x.device)
        return Beta(alpha, beta)

In [12]:
states_ex = torch.rand(10, 2, device=env.device)
states_ex_0 = torch.zeros(10, 2, device=env.device)

In [19]:
ations, logprobs, samples = sample_actions(env, model, states_ex)
ations_0, logprobs_0, samples_r_0, samples_theta_0 = sample_actions(env, model, states_ex_0)
samples_0 = torch.cat([samples_r_0.unsqueeze(-1), samples_theta_0.unsqueeze(-1)], dim=-1)



In [20]:
print(ations.shape)
print(samples.shape)
print(samples_r_0.shape)
print(samples_theta_0.shape)
print(samples_0.shape)

torch.Size([10, 2])
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10, 2])


In [440]:
def sample_actions(env, model, states):
    # states is a tensor of shape (n, dim)
    batch_size = states.shape[0]
    out = model.to_dist(states)
    if isinstance(out, tuple):  # s0 input returns (dist_r, dist_theta)
        dist_r, dist_theta = out
        if model.uniform_ratio:
            samples_r = torch.rand(batch_size, device=env.device)
            samples_theta = torch.rand(batch_size, device=env.device)
        else:
            samples_r = dist_r.sample(torch.Size((batch_size,)))
            samples_theta = dist_theta.sample(torch.Size((batch_size,)))

        actions = (
            torch.stack(
                [
                    samples_r * torch.cos(torch.pi / 2.0 * samples_theta),
                    samples_r * torch.sin(torch.pi / 2.0 * samples_theta),
                ],
                dim=1,
            )
            * env.delta
        )
        logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)
            - np.log(np.pi / 2)
            - np.log(env.delta)  # why ?
        )
        return actions, logprobs, samples_r, samples_theta
    else:
        dist = out

        # Automatic termination: check if min_i(1 - state_i) <= env.delta
        # This means at least one dimension is within delta of the boundary
        should_terminate = torch.any(states >= 1 - env.delta, dim=-1)

        A = torch.where(
            states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - states[:, 0]) / env.delta),
        )
        B = torch.where(
            states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - states[:, 1]) / env.delta),
        )
        assert torch.all(
            B[~should_terminate] >= A[~should_terminate]
        )
        if model.uniform_ratio:
            samples = torch.rand(batch_size, device=env.device)
        else:
            samples = dist.sample()

        actions = samples * (B - A) + A
        actions *= torch.pi / 2.0
        actions = (
            torch.stack([torch.cos(actions), torch.sin(actions)], dim=1) * env.delta
        )

        logprobs = (
            dist.log_prob(samples)
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )

        # Set terminal actions and zero logprobs for terminated states
        actions[should_terminate] = -float("inf")
        logprobs[should_terminate] = 0.0
        samples[should_terminate] = -float("inf")
    return actions, logprobs, samples


def sample_trajectories(env, model, n_trajectories):
    step = 0
    states = torch.zeros((n_trajectories, env.dim), device=env.device)
    actionss = []
    sampless = []
    trajectories = [states]
    trajectories_logprobs = torch.zeros((n_trajectories,), device=env.device)
    all_logprobs = []
    first = True
    while not torch.all(states == env.sink_state):
        step_logprobs = torch.full((n_trajectories,), -float("inf"), device=env.device)
        non_terminal_mask = torch.all(states != env.sink_state, dim=-1)
        actions = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        samples = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        if first:
            first = False
            non_terminal_actions, logprobs, non_terminal_samples_r, non_terminal_samples_theta = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples_r.unsqueeze(-1), non_terminal_samples_theta.unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        else:
            non_terminal_actions, logprobs, non_terminal_samples = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples.unsqueeze(-1), torch.zeros_like(non_terminal_samples).unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        
        actions[non_terminal_mask] = non_terminal_actions.reshape(-1, env.dim)
        actionss.append(actions)
        sampless.append(samples)
        states = env.step(states, actions)
        trajectories.append(states)
        trajectories_logprobs[non_terminal_mask] += logprobs
        step_logprobs[non_terminal_mask] = logprobs
        all_logprobs.append(step_logprobs)
        step += 1
    trajectories = torch.stack(trajectories, dim=1)
    actionss = torch.stack(actionss, dim=1)
    sampless = torch.stack(sampless, dim=1)
    all_logprobs = torch.stack(all_logprobs, dim=1)
    return trajectories, actionss, trajectories_logprobs, all_logprobs, sampless


In [441]:
t, a, l, al, s = sample_trajectories(env, model, 10)

In [444]:
def evaluate_forward_step_logprobs(env, model, current_states, samples):
    if torch.all(current_states[0] == 0.0):
        dist_r, dist_theta = model.to_dist(current_states)
        samples_r = samples[:, 0]
        samples_theta = samples[:, 1]

        step_logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)
            - np.log(np.pi / 2)
            - np.log(env.delta)  # why ?
            )
        
    else:
        step_logprobs = torch.zeros((current_states.shape[0],), device=env.device)
        should_terminate = torch.any(samples == -float("inf"), dim=-1)
        if current_states.shape[0] == should_terminate.sum():
            all_terminate = torch.all(samples == -float("inf"), dim=-1)
            step_logprobs[all_terminate] = -float("inf")
            return step_logprobs
        non_terminal_states = current_states[~should_terminate]
        non_terminal_samples = samples[~should_terminate]
        dist = model.to_dist(non_terminal_states)

        A = torch.where(
            non_terminal_states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - non_terminal_states[:, 0]) / env.delta),
        )
        B = torch.where(
            non_terminal_states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - non_terminal_states[:, 1]) / env.delta),
        )

        non_terminal_step_logprobs = (
            dist.log_prob(non_terminal_samples[:,0])
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )
        step_logprobs[~should_terminate] = non_terminal_step_logprobs
        all_terminate = torch.all(samples == -float("inf"), dim=-1)
        step_logprobs[all_terminate] = -float("inf")

    return step_logprobs

def evaluate_forward_logprobs(env, model, trajectories, sampless):
    logprobs = torch.zeros((trajectories.shape[0],), device=env.device)
    all_logprobs = []
    for i in range(trajectories.shape[1] - 1):
        current_states = trajectories[:, i]
        samples = sampless[:, i]
        step_logprobs = evaluate_forward_step_logprobs(env, model, current_states, samples)
        is_finite = torch.isfinite(step_logprobs)
        logprobs[is_finite] += step_logprobs[is_finite]
        all_logprobs.append(step_logprobs)
    all_logprobs = torch.stack(all_logprobs, dim=1)
    return logprobs, all_logprobs


In [445]:
evaluate_forward_logprobs(env, model, t, s)

(tensor([7.4568, 9.5071, 8.3924, 9.0600, 6.0962, 7.6606, 5.3401, 6.5564, 6.2242,
         5.8879], device='cuda:3', grad_fn=<IndexPutBackward0>),
 tensor([[3.0008, 0.9719, 0.6390, 0.8684, 0.9920, 0.9846, 0.0000],
         [4.6895, 0.9936, 0.9179, 0.9814, 0.9347, 0.9901, 0.0000],
         [4.5363, 0.9946, 0.9075, 0.9828, 0.9712, 0.0000,   -inf],
         [5.0931, 0.9923, 0.9913, 0.9907, 0.9926, 0.0000,   -inf],
         [2.2358, 0.9778, 0.9930, 0.9683, 0.9213, 0.0000,   -inf],
         [3.7149, 0.9903, 0.9890, 0.9782, 0.9881, 0.0000,   -inf],
         [2.4656, 0.9727, 0.9071, 0.9947, 0.0000,   -inf,   -inf],
         [2.6280, 0.9865, 0.9925, 0.9558, 0.9936, 0.0000,   -inf],
         [2.3661, 0.9700, 0.9829, 0.9284, 0.9769, 0.0000,   -inf],
         [2.4897, 0.9052, 0.9395, 0.7858, 0.7677, 0.0000,   -inf]],
        device='cuda:3', grad_fn=<StackBackward0>))

In [455]:
i = 6
print(evaluate_forward_step_logprobs(env, model, t[:,i, :], s[:,i,:]))
print(al[:,i])

tensor([0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:3')
tensor([0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:3', grad_fn=<SelectBackward0>)


In [456]:
t.shape

torch.Size([10, 8, 2])

In [457]:
s.shape

torch.Size([10, 7, 2])

In [329]:
import torch
import numpy as np
import random


class TrajectoryReplayMemory:
    """
    Replay buffer for trajectories with variable lengths.
    
    Stores:
    - trajectories: shape (capacity, max_traj_len, 2)
    - samples: shape (capacity, max_traj_len - 1, 2)
    
    Uses lazy initialization like SAC replay memory - buffers are allocated on first push.
    Automatically handles varying max_traj_len across batches by tracking and updating.
    """
    
    def __init__(self, capacity, seed, device='cpu'):
        """
        Args:
            capacity: Maximum number of trajectories to store
            seed: Random seed for reproducibility
            device: Device to store tensors on
        """
        random.seed(seed)
        np.random.seed(seed)
        self.capacity = capacity
        self.device = device
        self.position = 0
        self.size = 0
        
        # Will be initialized on first push
        self.trajectories = None
        self.samples = None
        self.max_traj_len = 0
        
    def push_batch(self, trajectories, samples):
        """
        Push a batch of trajectories to the replay buffer.
        
        Args:
            trajectories: tensor of shape (batch_size, traj_len, 2)
            samples: tensor of shape (batch_size, traj_len - 1, 2)
        """
        batch_size = trajectories.shape[0]
        traj_len = trajectories.shape[1]
        
        # Move to device if needed
        trajectories = trajectories.to(self.device)
        samples = samples.to(self.device)
        
        # Initialize or resize buffers if needed
        if self.trajectories is None:
            # First push - initialize buffers
            self.max_traj_len = traj_len
            self.trajectories = torch.full(
                (self.capacity, traj_len, 2), 
                -float('inf'), 
                dtype=trajectories.dtype, 
                device=self.device
            )
            self.samples = torch.full(
                (self.capacity, traj_len - 1, 2), 
                -float('inf'), 
                dtype=samples.dtype, 
                device=self.device
            )
        elif traj_len > self.max_traj_len:
            # Need to resize buffers to accommodate longer trajectories
            old_max_len = self.max_traj_len
            self.max_traj_len = traj_len
            
            # Create new larger buffers
            new_trajectories = torch.full(
                (self.capacity, traj_len, 2), 
                -float('inf'), 
                dtype=self.trajectories.dtype, 
                device=self.device
            )
            new_samples = torch.full(
                (self.capacity, traj_len - 1, 2), 
                -float('inf'), 
                dtype=self.samples.dtype, 
                device=self.device
            )
            
            # Copy old data to new buffers
            new_trajectories[:, :old_max_len, :] = self.trajectories
            new_samples[:, :old_max_len - 1, :] = self.samples
            
            self.trajectories = new_trajectories
            self.samples = new_samples
        
        # Calculate indices
        end_pos = self.position + batch_size
        
        if end_pos <= self.capacity:
            # No wrap around
            # Reset to -inf first (to handle varying lengths)
            self.trajectories[self.position:end_pos] = -float('inf')
            self.samples[self.position:end_pos] = -float('inf')
            
            # Write actual data
            self.trajectories[self.position:end_pos, :traj_len, :] = trajectories
            self.samples[self.position:end_pos, :traj_len - 1, :] = samples
        else:
            # Wrap around
            first_part = self.capacity - self.position
            
            # First part
            self.trajectories[self.position:] = -float('inf')
            self.samples[self.position:] = -float('inf')
            self.trajectories[self.position:, :traj_len, :] = trajectories[:first_part]
            self.samples[self.position:, :traj_len - 1, :] = samples[:first_part]
            
            # Second part
            second_part = batch_size - first_part
            self.trajectories[:second_part] = -float('inf')
            self.samples[:second_part] = -float('inf')
            self.trajectories[:second_part, :traj_len, :] = trajectories[first_part:]
            self.samples[:second_part, :traj_len - 1, :] = samples[first_part:]
        
        self.position = end_pos % self.capacity
        self.size = min(self.size + batch_size, self.capacity)
    
    def sample(self, batch_size):
        """
        Sample a random batch from the buffer.
        
        Returns:
            trajectories: tensor of shape (batch_size, max_traj_len, 2)
            samples: tensor of shape (batch_size, max_traj_len - 1, 2)
        """
        indices = np.random.choice(self.size, batch_size, replace=False)
        indices = torch.from_numpy(indices).to(self.device)
        
        return (
            self.trajectories[indices],
            self.samples[indices],
        )
    
    def __len__(self):
        return self.size



In [343]:
t.shape
s.shape

torch.Size([256, 7, 2])

In [365]:
mem=TrajectoryReplayMemory(300, 1, device)

In [366]:
mem.push_batch(t,s)

In [367]:
mem.sample(10)

(tensor([[[0.0000, 0.0000],
          [0.1631, 0.0285],
          [0.3407, 0.2044],
          [0.5881, 0.2406],
          [0.8025, 0.3691],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0037, 0.0870],
          [0.0910, 0.3213],
          [0.1967, 0.5478],
          [0.3667, 0.7311],
          [0.5479, 0.9034],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1948, 0.0543],
          [0.2764, 0.2906],
          [0.4698, 0.4490],
          [0.6908, 0.5660],
          [0.8798, 0.7295],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0251, 0.1095],
          [0.2633, 0.1853],
          [0.4738, 0.3202],
          [0.7218, 0.3514],
          [0.9509, 0.4517],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [

In [351]:
def sample_actions(env, model, states):
    # states is a tensor of shape (n, dim)
    batch_size = states.shape[0]
    out = model.to_dist(states)
    if isinstance(out, tuple):  # s0 input returns (dist_r, dist_theta)
        dist_r, dist_theta = out
        if model.uniform_ratio:
            samples_r = torch.rand(batch_size, device=env.device)
            samples_theta = torch.rand(batch_size, device=env.device)
        else:
            samples_r = dist_r.sample(torch.Size((batch_size,)))
            samples_theta = dist_theta.sample(torch.Size((batch_size,)))

        actions = (
            torch.stack(
                [
                    samples_r * torch.cos(torch.pi / 2.0 * samples_theta),
                    samples_r * torch.sin(torch.pi / 2.0 * samples_theta),
                ],
                dim=1,
            )
            * env.delta
        )

        return actions, samples_r, samples_theta
    else:
        dist = out

        # Automatic termination: check if min_i(1 - state_i) <= env.delta
        # This means at least one dimension is within delta of the boundary
        should_terminate = torch.any(states >= 1 - env.delta, dim=-1)

        A = torch.where(
            states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - states[:, 0]) / env.delta),
        )
        B = torch.where(
            states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - states[:, 1]) / env.delta),
        )
        assert torch.all(
            B[~should_terminate] >= A[~should_terminate]
        )
        if model.uniform_ratio:
            samples = torch.rand(batch_size, device=env.device)
        else:
            samples = dist.sample()

        actions = samples * (B - A) + A
        actions *= torch.pi / 2.0
        actions = (
            torch.stack([torch.cos(actions), torch.sin(actions)], dim=1) * env.delta
        )

        # Set terminal actions and zero logprobs for terminated states
        actions[should_terminate] = -float("inf")
        samples[should_terminate] = -float("inf")
    return actions, samples


def sample_trajectories(env, model, n_trajectories):
    states = torch.zeros((n_trajectories, env.dim), device=env.device)
    actionss = []
    sampless = []
    trajectories = [states]
    first = True
    while not torch.all(states == env.sink_state):
        non_terminal_mask = torch.all(states != env.sink_state, dim=-1)
        actions = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        samples = torch.full(
            (n_trajectories, env.dim), -float("inf"), device=env.device
        )
        if first:
            first = False
            non_terminal_actions, non_terminal_samples_r, non_terminal_samples_theta = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples_r.unsqueeze(-1), non_terminal_samples_theta.unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        else:
            non_terminal_actions, non_terminal_samples = sample_actions(
                env,
                model,
                states[non_terminal_mask],
            )
            non_terminal_samples = torch.cat([non_terminal_samples.unsqueeze(-1), torch.zeros_like(non_terminal_samples).unsqueeze(-1)], dim=-1)
            samples[non_terminal_mask] = non_terminal_samples.reshape(-1, env.dim)
        
        actions[non_terminal_mask] = non_terminal_actions.reshape(-1, env.dim)
        actionss.append(actions)
        sampless.append(samples)
        states = env.step(states, actions)
        trajectories.append(states)
    trajectories = torch.stack(trajectories, dim=1)
    actionss = torch.stack(actionss, dim=1)
    sampless = torch.stack(sampless, dim=1)
    return trajectories, actionss, sampless


def evaluate_forward_step_logprobs(env, model, current_states, samples):
    if torch.all(current_states[0] == 0.0):
        dist_r, dist_theta = model.to_dist(current_states)
        samples_r = samples[:, 0]
        samples_theta = samples[:, 1]

        step_logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)
            - np.log(np.pi / 2)
            - np.log(env.delta)  # why ?
            )
        
    else:
        step_logprobs = torch.zeros((current_states.shape[0],), device=env.device)
        should_terminate = torch.any(samples == -float("inf"), dim=-1)
        if current_states.shape[0] == should_terminate.sum():
            all_terminate = torch.all(samples == -float("inf"), dim=-1)
            step_logprobs[all_terminate] = -float("inf")
            return step_logprobs
        non_terminal_states = current_states[~should_terminate]
        non_terminal_samples = samples[~should_terminate]
        dist = model.to_dist(non_terminal_states)

        A = torch.where(
            non_terminal_states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - non_terminal_states[:, 0]) / env.delta),
        )
        B = torch.where(
            non_terminal_states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - non_terminal_states[:, 1]) / env.delta),
        )

        non_terminal_step_logprobs = (
            dist.log_prob(non_terminal_samples[:,0])
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )
        step_logprobs[~should_terminate] = non_terminal_step_logprobs
        all_terminate = torch.all(samples == -float("inf"), dim=-1)
        step_logprobs[all_terminate] = -float("inf")

    return step_logprobs

def evaluate_forward_logprobs(env, model, trajectories, sampless):
    logprobs = torch.zeros((trajectories.shape[0],), device=env.device)
    all_logprobs = []
    for i in range(trajectories.shape[1] - 1):
        current_states = trajectories[:, i]
        samples = sampless[:, i]
        step_logprobs = evaluate_forward_step_logprobs(env, model, current_states, samples)
        is_finite = torch.isfinite(step_logprobs)
        logprobs[is_finite] += step_logprobs[is_finite]
        all_logprobs.append(step_logprobs)
    all_logprobs = torch.stack(all_logprobs, dim=1)

    return logprobs, all_logprobs

In [352]:
t, a, s= sample_trajectories(env, model, 256)

In [353]:
evaluate_forward_logprobs(env, model, t, s)

(tensor([ 7.2703,  6.7745,  7.0488,  9.2890,  6.2542,  7.5546,  6.0784,  9.3962,
          6.7712,  6.5864,  7.2523,  7.8289,  6.8927,  5.8318,  7.4712,  9.8495,
          7.1471,  7.5509,  6.0035,  6.3502,  7.1690,  6.4905,  6.0889,  6.4749,
          9.4738,  7.7401,  9.5404,  5.5503,  5.6407,  7.1850,  7.6787,  5.4053,
          6.4027,  6.9018,  7.1598,  6.8983,  7.7725,  6.4910,  6.8409,  7.4630,
          8.4053,  8.0030,  6.4169,  7.8157,  8.4612,  6.7934,  6.9402,  5.4510,
          8.1682,  6.4215,  7.1978,  5.6409,  5.9790,  6.8391,  6.7521,  9.4780,
          6.5982,  5.3990,  9.0980,  6.6128,  6.3377,  8.3008,  6.6721,  7.1374,
          6.3204,  9.2481,  5.5821,  6.4647,  6.4622,  6.2567,  7.3263,  5.1845,
          5.3985,  7.1261,  6.1677,  7.7498,  6.7811,  6.8985,  6.3412,  5.2898,
          6.3936,  7.7283,  6.5746,  5.1075,  6.8604,  6.9429,  7.0991,  6.2015,
          6.3545,  8.3286,  6.9783,  6.3119,  6.8757,  7.1931,  6.5056,  5.4936,
          6.0764,  9.4744,  

In [371]:
mem.push_batch(t,s)

In [370]:
tj, sa= mem.sample(10)
tj,sa

(tensor([[[0.0000, 0.0000],
          [0.1489, 0.1100],
          [0.3772, 0.2118],
          [0.3888, 0.4615],
          [0.5637, 0.6401],
          [0.6486, 0.8753],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0909, 0.0076],
          [0.1256, 0.2552],
          [0.3711, 0.3022],
          [0.6205, 0.3196],
          [0.7845, 0.5082],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1175, 0.1339],
          [0.2130, 0.3650],
          [0.3746, 0.5557],
          [0.5897, 0.6830],
          [0.8266, 0.7630],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0341, 0.1356],
          [0.2495, 0.2624],
          [0.4929, 0.3197],
          [0.5544, 0.5620],
          [0.7269, 0.7430],
          [0.9661, 0.8154],
          [  -inf,   -inf],
          [

In [372]:
while torch.all(tj[:,-2,:] == env.sink_state):
    tj = tj[:,:-1,:]
    sa = sa[:,:-1,:]
tj, sa

(tensor([[[0.0000, 0.0000],
          [0.1489, 0.1100],
          [0.3772, 0.2118],
          [0.3888, 0.4615],
          [0.5637, 0.6401],
          [0.6486, 0.8753],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0909, 0.0076],
          [0.1256, 0.2552],
          [0.3711, 0.3022],
          [0.6205, 0.3196],
          [0.7845, 0.5082],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1175, 0.1339],
          [0.2130, 0.3650],
          [0.3746, 0.5557],
          [0.5897, 0.6830],
          [0.8266, 0.7630],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0341, 0.1356],
          [0.2495, 0.2624],
          [0.4929, 0.3197],
          [0.5544, 0.5620],
          [0.7269, 0.7430],
          [0.9661, 0.8154],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1003, 0.1028],
          [0.2599, 0.2952],
        

In [357]:
tj, sa

(tensor([[[0.0000, 0.0000],
          [0.0161, 0.0266],
          [0.1754, 0.2192],
          [0.3531, 0.3950],
          [0.4934, 0.6020],
          [0.5379, 0.8480],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0440, 0.1575],
          [0.2931, 0.1783],
          [0.5409, 0.2112],
          [0.7909, 0.2159],
          [  -inf,   -inf],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0442, 0.1397],
          [0.1832, 0.3475],
          [0.4177, 0.4340],
          [0.5797, 0.6245],
          [0.7494, 0.8080],
          [  -inf,   -inf],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.0052, 0.1514],
          [0.2551, 0.1554],
          [0.5007, 0.2023],
          [0.5159, 0.4518],
          [0.6165, 0.6807],
          [0.7437, 0.8959],
          [  -inf,   -inf]],
 
         [[0.0000, 0.0000],
          [0.1550, 0.1933],
          [0.3593, 0.3374],
        

In [358]:
logprobs, all_logprobs = evaluate_forward_logprobs(env, model, tj, sa)
logprobs, all_logprobs

(tensor([8.2738, 5.2398, 6.8477, 6.9588, 5.8801, 5.6194, 5.8480, 6.6392, 8.2814,
         6.6549], device='cuda:3', grad_fn=<IndexPutBackward0>),
 tensor([[4.4295, 0.9932, 0.9931, 0.9816, 0.8764, 0.0000,   -inf],
         [2.7961, 0.8342, 0.8867, 0.7228, 0.0000,   -inf,   -inf],
         [2.9219, 0.9865, 0.9625, 0.9884, 0.9883, 0.0000,   -inf],
         [2.6865, 0.6700, 0.9152, 0.7643, 0.9542, 0.9686, 0.0000],
         [2.1087, 0.9897, 0.8781, 0.9666, 0.9370, 0.0000,   -inf],
         [2.8767, 0.9910, 0.7727, 0.9791, 0.0000,   -inf,   -inf],
         [2.0117, 0.9875, 0.9597, 0.9007, 0.9884, 0.0000,   -inf],
         [3.0426, 0.7359, 0.9778, 0.9930, 0.8898, 0.0000,   -inf],
         [4.4501, 0.9354, 0.9300, 0.9718, 0.9941, 0.0000,   -inf],
         [2.7169, 0.9926, 0.9775, 0.9892, 0.9787, 0.0000,   -inf]],
        device='cuda:3', grad_fn=<StackBackward0>))

In [359]:
last_states = get_last_states(env, tj)
last_states


tensor([[0.5379, 0.8480],
        [0.7909, 0.2159],
        [0.7494, 0.8080],
        [0.7437, 0.8959],
        [0.9704, 0.6295],
        [0.3154, 0.7861],
        [0.7670, 0.8963],
        [0.7576, 0.6421],
        [0.7663, 0.5644],
        [0.6567, 0.9331]], device='cuda:3')

In [360]:
logrewards = env.reward(last_states).log()
logrewards


tensor([-6.9078, -0.6931, -6.9078, -6.9078, -6.9078, -6.9078, -0.6931, -6.9078,
        -6.9078, -6.9078], device='cuda:3')

In [361]:
bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
    env, bw_model, tj
)


In [362]:
bw_logprobs

tensor([5.7731, 4.3418, 4.5475, 6.5405, 4.9009, 4.3858, 4.4536, 6.0224, 5.3782,
        4.9392], device='cuda:3')

In [363]:
all_bw_logprobs

tensor([[0.0000, 2.6728, 0.9347, 0.9347, 1.2309,   -inf],
        [0.0000, 1.6166, 1.3801, 1.3451,   -inf,   -inf],
        [0.0000, 1.5819, 0.9347, 0.9347, 1.0962,   -inf],
        [0.0000, 1.7856, 1.4452, 0.9347, 0.9347, 1.4402],
        [0.0000, 0.9347, 0.9347, 0.9347, 2.0968,   -inf],
        [0.0000, 1.7235, 1.6308, 1.0315,   -inf,   -inf],
        [0.0000, 1.0672, 0.9347, 0.9347, 1.5170,   -inf],
        [0.0000, 3.1987, 0.9347, 0.9347, 0.9543,   -inf],
        [0.0000, 2.5316, 0.9347, 0.9347, 0.9772,   -inf],
        [0.0000, 1.3946, 0.9347, 0.9347, 1.6752,   -inf]], device='cuda:3')

tensor(58.9589, device='cuda:3', grad_fn=<MeanBackward0>)

In [391]:
mem = TrajectoryReplayMemory(3, 1, device)

In [437]:
t, a, s = sample_trajectories(env, model, 3)
t, s

(tensor([[[0.0000e+00, 0.0000e+00],
          [8.0978e-02, 4.0087e-04],
          [3.0653e-01, 1.0822e-01],
          [3.6991e-01, 3.5005e-01],
          [5.8475e-01, 4.7790e-01],
          [8.1528e-01, 5.7463e-01],
          [      -inf,       -inf]],
 
         [[0.0000e+00, 0.0000e+00],
          [3.5976e-02, 6.5828e-02],
          [2.1181e-01, 2.4354e-01],
          [2.6638e-01, 4.8751e-01],
          [4.9330e-01, 5.9242e-01],
          [6.3604e-01, 7.9767e-01],
          [      -inf,       -inf]],
 
         [[0.0000e+00, 0.0000e+00],
          [5.5799e-02, 4.3961e-02],
          [2.2473e-01, 2.2825e-01],
          [4.7125e-01, 2.6983e-01],
          [5.0480e-01, 5.1757e-01],
          [5.2800e-01, 7.6649e-01],
          [      -inf,       -inf]]], device='cuda:3'),
 tensor([[[0.3239, 0.0032],
          [0.2839, 0.0000],
          [0.8368, 0.0000],
          [0.3417, 0.0000],
          [0.2529, 0.0000],
          [  -inf, 0.0000]],
 
         [[0.3001, 0.6816],
          [0.5034, 

In [438]:
mem.push_batch(t,s)

In [439]:
print(mem.trajectories, mem.samples)

tensor([[[0.0000e+00, 0.0000e+00],
         [8.0978e-02, 4.0087e-04],
         [3.0653e-01, 1.0822e-01],
         [3.6991e-01, 3.5005e-01],
         [5.8475e-01, 4.7790e-01],
         [8.1528e-01, 5.7463e-01],
         [      -inf,       -inf]],

        [[0.0000e+00, 0.0000e+00],
         [3.5976e-02, 6.5828e-02],
         [2.1181e-01, 2.4354e-01],
         [2.6638e-01, 4.8751e-01],
         [4.9330e-01, 5.9242e-01],
         [6.3604e-01, 7.9767e-01],
         [      -inf,       -inf]],

        [[0.0000e+00, 0.0000e+00],
         [5.5799e-02, 4.3961e-02],
         [2.2473e-01, 2.2825e-01],
         [4.7125e-01, 2.6983e-01],
         [5.0480e-01, 5.1757e-01],
         [5.2800e-01, 7.6649e-01],
         [      -inf,       -inf]]], device='cuda:3') tensor([[[0.3239, 0.0032],
         [0.2839, 0.0000],
         [0.8368, 0.0000],
         [0.3417, 0.0000],
         [0.2529, 0.0000],
         [  -inf, 0.0000]],

        [[0.3001, 0.6816],
         [0.5034, 0.0000],
         [0.8599, 0.0000