In [1]:
import json
import os
from socket import if_indextoname
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange

import argparse

from env import Box, get_last_states
from sac_model import CirclePB, Uniform
from sac_sampling import (
    sample_trajectories,
    evaluate_backward_logprobs,
)
from sac import SAC
from sac_replay_memory import ReplayMemory, trajectories_to_transitions

from utils import (
    fit_kde,
    plot_reward,
    sample_from_reward,
    plot_samples,
    estimate_jsd,
    plot_trajectories,
    plot_termination_probabilities,
)

import config
import sac_config



def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is True for every
    non-empty string, so ``--without_backward_model False`` would silently enable the
    flag. This converter accepts the usual spellings explicitly and rejects the rest.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default=sac_config.DEVICE)
parser.add_argument("--dim", type=int, default=config.DIM)
parser.add_argument("--delta", type=float, default=config.DELTA)
parser.add_argument("--epsilon", type=float, default=config.EPSILON)
parser.add_argument("--R0", type=float, default=config.R0, help="Baseline reward value")
parser.add_argument("--R1", type=float, default=config.R1, help="Medium reward value (e.g., outer square)")
parser.add_argument("--R2", type=float, default=config.R2, help="High reward value (e.g., inner square)")
parser.add_argument("--reward_debug", action="store_true", default=config.REWARD_DEBUG)
parser.add_argument(
    "--reward_type",
    type=str,
    choices=["baseline", "ring", "angular_ring", "multi_ring", "curve", "gaussian_mixture"],
    default=config.REWARD_TYPE,
    help="Type of reward function to use. To modify reward-specific parameters (radius, sigma, etc.), edit rewards.py"
)
parser.add_argument(
    "--beta_min",
    type=float,
    default=config.BETA_MIN,
    help="Minimum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--beta_max",
    type=float,
    default=config.BETA_MAX,
    help="Maximum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--PB",
    type=str,
    choices=["learnable", "tied", "uniform"],
    default=config.PB,
    help="Backward policy type",
)
parser.add_argument("--gamma_scheduler", type=float, default=config.GAMMA_SCHEDULER)
parser.add_argument("--scheduler_milestone", type=int, default=config.SCHEDULER_MILESTONE)
parser.add_argument("--seed", type=int, default=config.SEED)
parser.add_argument("--lr", type=float, default=config.LR, help="Learning rate for SAC")
parser.add_argument("--BS", type=int, default=config.BS)
parser.add_argument("--n_iterations", type=int, default=config.N_ITERATIONS)
parser.add_argument("--n_evaluation_interval", type=int, default=config.N_EVALUATION_INTERVAL)
parser.add_argument("--n_logging_interval", type=int, default=config.N_LOGGING_INTERVAL)
parser.add_argument("--hidden_dim", type=int, default=config.HIDDEN_DIM)
parser.add_argument("--n_hidden", type=int, default=config.N_HIDDEN)
parser.add_argument("--n_evaluation_trajectories", type=int, default=config.N_EVALUATION_TRAJECTORIES)
parser.add_argument("--no_plot", action="store_true", default=config.NO_PLOT)
parser.add_argument("--no_wandb", action="store_true", default=config.NO_WANDB)
parser.add_argument("--wandb_project", type=str, default=config.WANDB_PROJECT)
parser.add_argument("--uniform_ratio", type=float, default=config.UNIFORM_RATIO, help="Ratio of uniform policy")


# SAC-specific arguments
parser.add_argument("--tau", type=float, default=sac_config.TAU, help="Tau for soft update")
parser.add_argument("--target_update_interval", type=int, default=sac_config.TARGET_UPDATE_INTERVAL, help="Target network update interval")
parser.add_argument("--Critic_hidden_size", type=int, default=sac_config.CRITIC_HIDDEN_SIZE, help="Hidden size for SAC critic networks")
parser.add_argument("--replay_size", type=int, default=sac_config.REPLAY_SIZE, help="Replay buffer size")
parser.add_argument("--sac_batch_size", type=int, default=sac_config.SAC_BATCH_SIZE, help="SAC batch size")
parser.add_argument("--updates_per_step", type=int, default=sac_config.UPDATES_PER_STEP, help="SAC updates per step")
parser.add_argument(
    "--without_backward_model",
    type=_str2bool,
    default=sac_config.WITHOUT_BACKWARD_MODEL,
    help="If true, replace backward-policy log-probs with 0 in the intermediate rewards",
)
# Running inside a notebook: parse an empty argv so only the config defaults are used.
args = parser.parse_args([])


device = args.device
dim = args.dim
delta = args.delta
epsilon = args.epsilon
seed = args.seed
lr = args.lr
n_iterations = args.n_iterations
BS = args.BS

# seed == 0 means "draw a random seed"; print it so the run can be reproduced later.
if seed == 0:
    seed = np.random.randint(int(1e6))
    print(f"Drew random seed: {seed}")

torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")



Using device: cuda:3


In [2]:
import random
import numpy as np
import torch
import os
import pickle


def trajectories_to_transitions(trajectories, actionss, all_bw_logprobs, logrewards, env):
    """
    Flatten a batch of sampled trajectories into (s, a, r, s', done) transitions.

    Transition t of trajectory b is
    (trajectories[b, t], actionss[b, t], r, trajectories[b, t + 1], done).
    Its reward r is all_bw_logprobs[b, t] for intermediate moves. The single move
    whose source is a real state and whose next state is the sink gets r = logrewards[b]
    and done = 1. Transitions whose reward is not finite (e.g. sink -> sink padding)
    are dropped.

    Args:
        trajectories: tensor (batch_size, T + 1, dim) of visited states, padded with
            env.sink_state after termination.
        actionss: tensor (batch_size, T, dim) of actions taken at each step.
        all_bw_logprobs: tensor (batch_size, T - 1) of per-step rewards for the first
            T - 1 transitions; a -inf column is appended for the last transition.
        logrewards: tensor (batch_size,) of terminal log-rewards.
        env: environment exposing ``sink_state``.

    Returns:
        Tuple (states, actions, rewards, next_states, dones) of flattened tensors with
        shapes (N, dim), (N, dim), (N,), (N, dim), (N,), in row-major (batch, time) order.

    Raises:
        ValueError: if the number of terminating transitions differs from len(logrewards).
    """
    states = trajectories[:, :-1, :]
    next_states = trajectories[:, 1:, :]

    # A transition terminates when it leaves a real state and lands in the sink.
    is_not_sink = torch.all(states != env.sink_state, dim=-1)
    is_next_sink = torch.all(next_states == env.sink_state, dim=-1)
    terminal = is_not_sink & is_next_sink  # (batch_size, T)
    dones = terminal.to(torch.float32)

    # Pad per-step rewards to T columns with -inf (same dtype/device as the input) so the
    # padding is filtered out below unless it is overwritten by a terminal reward.
    padding = all_bw_logprobs.new_full((all_bw_logprobs.shape[0], 1), float("-inf"))
    rewards = torch.cat([all_bw_logprobs, padding], dim=1)

    # Each trajectory must terminate exactly once so logrewards lines up row by row.
    n_terminal = int(terminal.sum())
    if n_terminal != logrewards.shape[0]:
        raise ValueError(
            f"Expected one terminating transition per trajectory ({logrewards.shape[0]}), "
            f"found {n_terminal}"
        )
    rewards[terminal] = logrewards.to(rewards.dtype)

    # Keep only transitions with a finite reward.
    is_valid = torch.isfinite(rewards)  # (batch_size, T)

    states_flat = states[is_valid]
    actions_flat = actionss[is_valid]
    rewards_flat = rewards[is_valid]
    next_states_flat = next_states[is_valid]
    dones_flat = dones[is_valid]

    return states_flat, actions_flat, rewards_flat, next_states_flat, dones_flat


class ReplayMemory:
    """Fixed-capacity ring buffer of (state, action, reward, next_state, done) transitions.

    Storage tensors are allocated on the first ``push_batch``; their widths and dtypes
    come from that batch. They live on ``device``. Once the buffer is full, new
    transitions overwrite the oldest ones.
    """

    def __init__(self, capacity, seed, device='cpu'):
        # NOTE: this reseeds the process-wide `random` and NumPy RNGs; `sample` draws from NumPy's.
        random.seed(seed)
        np.random.seed(seed)
        self.capacity = capacity
        self.device = device
        self.position = 0  # next slot to write
        self.size = 0  # number of valid transitions, at most `capacity`

        # Will be initialized on first push
        self.states = None
        self.actions = None
        self.rewards = None
        self.next_states = None
        self.dones = None

    def _buffers(self):
        """Return the five storage tensors in field order."""
        return (self.states, self.actions, self.rewards, self.next_states, self.dones)

    def push_batch(self, states, actions, rewards, next_states, dones):
        """
        Push a batch of transitions to the replay buffer.

        Args:
            states: tensor of shape (batch_size, state_dim)
            actions: tensor of shape (batch_size, action_dim)
            rewards: tensor of shape (batch_size,)
            next_states: tensor of shape (batch_size, state_dim)
            dones: tensor of shape (batch_size,)

        Inputs are detached, moved to ``self.device`` and cast to the buffer dtypes.
        A batch larger than ``capacity`` keeps only its last ``capacity`` transitions,
        laid out exactly as if the transitions had been pushed one at a time.
        """
        batch_size = states.shape[0]

        # Initialize buffers on first push
        if self.states is None:
            state_dim = states.shape[1]
            action_dim = actions.shape[1]
            self.states = torch.zeros((self.capacity, state_dim), dtype=states.dtype, device=self.device)
            self.actions = torch.zeros((self.capacity, action_dim), dtype=actions.dtype, device=self.device)
            self.rewards = torch.zeros(self.capacity, dtype=rewards.dtype, device=self.device)
            self.next_states = torch.zeros((self.capacity, state_dim), dtype=next_states.dtype, device=self.device)
            self.dones = torch.zeros(self.capacity, dtype=dones.dtype, device=self.device)

        if batch_size > self.capacity:
            # The oldest entries of an oversized batch would be overwritten by later ones
            # anyway: skip them, but advance the write position as if they were written.
            skipped = batch_size - self.capacity
            self.position = (self.position + skipped) % self.capacity
            states, actions, rewards, next_states, dones = (
                t[skipped:] for t in (states, actions, rewards, next_states, dones)
            )
            batch_size = self.capacity

        # Destination slots, wrapping past the end; all distinct since batch_size <= capacity.
        slots = (self.position + torch.arange(batch_size, device=self.device)) % self.capacity
        for buffer, values in zip(self._buffers(), (states, actions, rewards, next_states, dones)):
            buffer[slots] = values.detach().to(device=self.device, dtype=buffer.dtype)

        self.position = (self.position + batch_size) % self.capacity
        self.size = min(self.size + batch_size, self.capacity)

    def sample(self, batch_size):
        """Sample ``batch_size`` distinct stored transitions uniformly at random.

        Returns tensors on ``self.device``: states (B, state_dim), actions (B, action_dim),
        rewards (B, 1), next_states (B, state_dim), dones (B, 1).
        NumPy raises ValueError if ``batch_size > len(self)``.
        """
        indices = np.random.choice(self.size, batch_size, replace=False)
        indices = torch.from_numpy(indices).to(self.device)

        return (
            self.states[indices],
            self.actions[indices],
            self.rewards[indices].unsqueeze(-1),  # (batch_size,) -> (batch_size, 1)
            self.next_states[indices],
            self.dones[indices].unsqueeze(-1),  # (batch_size,) -> (batch_size, 1)
        )

    def __len__(self):
        return self.size

    def save_buffer(self, env_name, suffix="", save_path=None):
        """Pickle the valid part of the buffer (moved to CPU) plus ``position`` and ``size``.

        The default path is ``checkpoints/sac_buffer_<env_name>_<suffix>``; the parent
        directory of the target path is created if needed.
        """
        if save_path is None:
            save_path = "checkpoints/sac_buffer_{}_{}".format(env_name, suffix)
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        print('Saving buffer to {}'.format(save_path))

        buffer_dict = {
            'states': self.states[:self.size].cpu() if self.states is not None else None,
            'actions': self.actions[:self.size].cpu() if self.actions is not None else None,
            'rewards': self.rewards[:self.size].cpu() if self.rewards is not None else None,
            'next_states': self.next_states[:self.size].cpu() if self.next_states is not None else None,
            'dones': self.dones[:self.size].cpu() if self.dones is not None else None,
            'position': self.position,
            'size': self.size,
        }

        with open(save_path, 'wb') as f:
            pickle.dump(buffer_dict, f)

    def load_buffer(self, save_path):
        """Restore a buffer written by ``save_buffer``, keeping the saved dtypes.

        SECURITY: this uses ``pickle.load``, which can execute arbitrary code; only load
        files you created yourself.

        Raises:
            ValueError: if the saved buffer holds more transitions than ``capacity``.
        """
        print('Loading buffer from {}'.format(save_path))

        with open(save_path, "rb") as f:
            buffer_dict = pickle.load(f)

        size = buffer_dict['size']
        if size > self.capacity:
            raise ValueError(
                "Saved buffer holds {} transitions but capacity is {}".format(size, self.capacity)
            )
        self.size = size
        # A buffer saved with a larger capacity can record position == size == self.capacity.
        self.position = buffer_dict['position'] % self.capacity

        if buffer_dict['states'] is not None:
            for name in ('states', 'actions', 'rewards', 'next_states', 'dones'):
                saved = buffer_dict[name]
                restored = torch.zeros(
                    (self.capacity, *saved.shape[1:]), dtype=saved.dtype, device=self.device
                )
                restored[:size] = saved.to(self.device)
                setattr(self, name, restored)


In [3]:
import torch
from torch.distributions import Distribution, Beta
import numpy as np


def sample_actions(env, model, states):
    """Sample one forward action per state from ``model`` and score it.

    ``model.to_dist(states)`` returns either two distributions (the s0 case: normalized
    radius and angle, each sampled ``batch_size`` times) or ``(exit_proba, dist)`` for
    other states, where ``dist`` gives the normalized angle of a step of length
    ``env.delta`` inside the admissible range [A, B] (in units of pi/2).

    Args:
        env: environment providing ``delta``, ``epsilon`` and ``device``.
        model: forward policy exposing ``to_dist``.
        states: tensor (n, dim) of current states.

    Returns:
        actions: (n, 2) displacements; rows that exit are -inf.
        logprobs: (n,) log-probability of ``actions``. For a sampled exit this is
            log(exit_proba); for a forced exit it is 0. Otherwise it is the move
            log-density, including log(1 - exit_proba) outside s0.
        exit_proba_naive: (n,) exit probability, set to 1 where the exit is forced
            (all zeros at s0).
        actions_naive: (n, 2) sampled move ignoring Bernoulli exits; -inf only where the
            exit is forced.
        logprobs_naive: (n,) move log-density matching ``actions_naive``; 0 where the exit
            is forced.
    """
    # states is a tensor of shape (n, dim)
    batch_size = states.shape[0]
    out = model.to_dist(states)
    if isinstance(out[0], Distribution):  # s0 input
        dist_r, dist_theta = out
        samples_r = dist_r.rsample(torch.Size((batch_size,)))
        samples_theta = dist_theta.rsample(torch.Size((batch_size,)))

        # Polar step: radius rho = r * delta, angle phi = theta * pi / 2 (first quadrant).
        actions = (
            torch.stack(
                [
                    samples_r * torch.cos(torch.pi / 2.0 * samples_theta),
                    samples_r * torch.sin(torch.pi / 2.0 * samples_theta),
                ],
                dim=1,
            )
            * env.delta
        )

        # Change of variables (r, theta) -> (x, y):
        # |J| = (d rho / d r) * (d phi / d theta) * rho = delta * (pi / 2) * (r * delta).
        logprobs = (
            dist_r.log_prob(samples_r)
            + dist_theta.log_prob(samples_theta)
            - torch.log(samples_r * env.delta)  # polar Jacobian factor rho
            - np.log(np.pi / 2)  # d phi / d theta
            - np.log(env.delta)  # d rho / d r
        )

        exit_proba_naive = torch.zeros((batch_size,), device=env.device)
        actions_naive = actions
        logprobs_naive = logprobs

    else:
        exit_proba, dist = out
        exit_proba_naive = exit_proba.clone()
        near_goal_mask = torch.norm(1 - states, dim=1) <= env.delta
        boundary_mask = torch.any(states >= 1 - env.epsilon, dim=-1)
        forced_exit = near_goal_mask | boundary_mask
        # Exit either by a Bernoulli(exit_proba) draw or because the exit is forced.
        # (Previously the Bernoulli draw was computed and then overwritten, so the
        # policy could never exit voluntarily even though logprobs charged log(1 - p).)
        exit = torch.bernoulli(exit_proba).bool() | forced_exit

        # Admissible angle range [A, B] (units of pi/2) keeping the next state inside the box.
        A = torch.where(
            states[:, 0] <= 1 - env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((1 - states[:, 0]) / env.delta),
        )
        B = torch.where(
            states[:, 1] <= 1 - env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((1 - states[:, 1]) / env.delta),
        )
        assert torch.all(
            B[~torch.any(states >= 1 - env.delta, dim=-1)]
            >= A[~torch.any(states >= 1 - env.delta, dim=-1)]
        )
        samples = dist.rsample()

        angles = (samples * (B - A) + A) * (torch.pi / 2.0)
        actions = (
            torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * env.delta
        )

        actions_naive = actions.clone()

        # Density of a move: not exiting, times the Beta density rescaled to the arc of
        # radius delta spanning [A, B] * pi / 2.
        logprobs = (
            dist.log_prob(samples)
            + torch.log(1 - exit_proba)
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )

        logprobs_naive = logprobs.clone()

        actions[exit] = -float("inf")
        logprobs[exit] = torch.log(exit_proba[exit])
        logprobs[forced_exit] = 0.0

        exit_proba_naive[forced_exit] = 1.0
        actions_naive[forced_exit] = -float("inf")
        logprobs_naive[forced_exit] = 0.0

    return actions, logprobs, exit_proba_naive, actions_naive, logprobs_naive


def sample_trajectories(env, model, n_trajectories):
    """Roll out ``n_trajectories`` trajectories from s0 until every one reaches the sink.

    Finished trajectories stay in the sink; their subsequent actions and per-step
    log-probs are -inf.

    Returns:
        trajectories: (n_trajectories, T + 1, dim) visited states, starting at s0.
        actionss: (n_trajectories, T, dim) actions taken at each step.
        trajectories_logprobs: (n_trajectories,) summed forward log-probabilities.
        all_logprobs: (n_trajectories, T) per-step forward log-probabilities.
    """
    neg_inf = -float("inf")
    current = torch.zeros((n_trajectories, env.dim), device=env.device)
    visited = [current]
    taken_actions = []
    per_step_logprobs = []
    total_logprobs = torch.zeros((n_trajectories,), device=env.device)

    while not torch.all(current == env.sink_state):
        active = torch.all(current != env.sink_state, dim=-1)
        active_actions, active_logprobs, _, _, _ = sample_actions(env, model, current[active])

        step_actions = torch.full((n_trajectories, env.dim), neg_inf, device=env.device)
        step_actions[active] = active_actions.reshape(-1, env.dim)
        taken_actions.append(step_actions)

        current = env.step(current, step_actions)
        visited.append(current)

        total_logprobs[active] += active_logprobs
        step_logprobs = torch.full((n_trajectories,), neg_inf, device=env.device)
        step_logprobs[active] = active_logprobs
        per_step_logprobs.append(step_logprobs)

    return (
        torch.stack(visited, dim=1),
        torch.stack(taken_actions, dim=1),
        total_logprobs,
        torch.stack(per_step_logprobs, dim=1),
    )


def evaluate_backward_logprobs(env, model, trajectories):
    """Score each trajectory under the backward policy ``model``.

    For each state index i from ``trajectories.shape[1] - 2`` down to 2, computes the
    log-density of stepping back from ``trajectories[:, i]`` to ``trajectories[:, i - 1]``.
    Rows whose state at i is the sink get -inf for that step. The step back from index 1
    to s0 is not evaluated; a column of zeros stands in for it.

    Args:
        env: environment providing ``sink_state``, ``delta`` and ``device``.
        model: backward policy whose ``to_dist(states)`` returns a distribution over the
            normalized backward step angle (it must provide ``log_prob``).
        trajectories: tensor (n, L, dim) of states padded with ``env.sink_state``.

    Returns:
        logprobs: (n,) sum of the evaluated per-step log-probs.
        all_logprobs: (n, L - 2) per-step log-probs in forward order. Column 0 is the zero
            column for the step back to s0; column j >= 1 is the step back from index
            j + 1 to index j.

    Raises:
        ValueError: if any evaluated step log-prob is NaN or infinite.
    """
    logprobs = torch.zeros((trajectories.shape[0],), device=env.device)
    all_logprobs = []
    # Walk backwards over state indices shape[1]-2, ..., 2 (index 1 -> s0 handled after the loop).
    for i in range(trajectories.shape[1] - 2, 1, -1):
        # Default -inf for rows already in the sink at index i.
        all_step_logprobs = torch.full(
            (trajectories.shape[0],), -float("inf"), device=env.device
        )
        non_sink_mask = torch.all(trajectories[:, i] != env.sink_state, dim=-1)
        current_states = trajectories[:, i][non_sink_mask]
        previous_states = trajectories[:, i - 1][non_sink_mask]
        # x-component of the forward step previous -> current; equals delta * cos(angle).
        difference_1 = current_states[:, 0] - previous_states[:, 0]
        difference_1.clamp_(
            min=0.0, max=env.delta
        )  # Should be the case already - just to avoid numerical issues
        # [A, B] is the admissible backward-step angle range in units of pi/2. Stepping back by
        # delta at angle phi keeps both coordinates >= 0 iff delta*cos(phi) <= x (lower bound A)
        # and delta*sin(phi) <= y (upper bound B).
        A = torch.where(
            current_states[:, 0] >= env.delta,
            0.0,
            2.0 / torch.pi * torch.arccos((current_states[:, 0]) / env.delta),
        )
        B = torch.where(
            current_states[:, 1] >= env.delta,
            1.0,
            2.0 / torch.pi * torch.arcsin((current_states[:, 1]) / env.delta),
        )

        dist = model.to_dist(current_states)

        # Recover the step angle from its x-component and rescale it into [0, 1] within [A, B].
        # It is clamped away from the endpoints and the density is capped at 100 so the log-prob
        # stays finite. The trailing terms are the same change-of-variables corrections the
        # forward policy uses: step length delta, angle scale pi/2, and the width of [A, B].
        step_logprobs = (
            dist.log_prob(
                (
                    1.0
                    / (B - A)
                    * (2.0 / torch.pi * torch.acos(difference_1 / env.delta) - A)
                ).clamp(1e-4, 1 - 1e-4)
            ).clamp_max(100)
            - np.log(env.delta)
            - np.log(np.pi / 2)
            - torch.log(B - A)
        )

        if torch.any(torch.isnan(step_logprobs)):
            raise ValueError("NaN in backward logprobs")

        if torch.any(torch.isinf(step_logprobs)):
            raise ValueError("Inf in backward logprobs")

        logprobs[non_sink_mask] += step_logprobs
        all_step_logprobs[non_sink_mask] = step_logprobs

        all_logprobs.append(all_step_logprobs)

    # Zero column for the step back to s0 (it is appended last, so it becomes column 0 after the flip).
    all_logprobs.append(torch.zeros((trajectories.shape[0],), device=env.device))
    all_logprobs = torch.stack(all_logprobs, dim=1)

    # Columns were collected in reverse time order; flip them back to forward order.
    return logprobs, all_logprobs.flip(1)



In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta
from torch.optim import Adam
from sac_utils import soft_update, hard_update
from sac_model import QNetwork
from sac_model import CirclePF
# Weight initializer for the critic networks (applied via Module.apply).
def weights_init_(m):
    """Xavier-uniform weights (gain 1) and zero bias for ``nn.Linear``; other modules are untouched."""
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=1)
    nn.init.constant_(m.bias, 0)

class QNetwork(nn.Module):
    """Twin Q-network: two independent 3-layer ReLU MLPs over the concatenated (state, action).

    ``forward`` returns both heads' estimates, each of shape (batch, 1).
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()
        in_features = num_inputs + num_actions

        # First head (attribute names are the state_dict keys; keep them stable).
        self.linear1 = nn.Linear(in_features, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Second head.
        self.linear4 = nn.Linear(in_features, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    @staticmethod
    def _head(inputs, first, second, last):
        """Linear -> ReLU -> Linear -> ReLU -> Linear."""
        return last(F.relu(second(F.relu(first(inputs)))))

    def forward(self, state, action):
        joint = torch.cat([state, action], 1)
        return (
            self._head(joint, self.linear1, self.linear2, self.linear3),
            self._head(joint, self.linear4, self.linear5, self.linear6),
        )

class NeuralNet(nn.Module):
    """MLP made of a ``torso`` followed by a linear ``output_layer``.

    Args:
        dim: input size of the default torso.
        hidden_dim: torso output width, i.e. the input size of ``output_layer``.
        n_hidden: number of extra (Linear, ELU) blocks in the default torso.
        torso: optional pre-built module used (and possibly shared) instead of the default.
        output_dim: number of outputs.
    """

    def __init__(self, dim=2, hidden_dim=64, n_hidden=2, torso=None, output_dim=3):
        super().__init__()
        self.dim = dim
        self.n_hidden = n_hidden
        self.output_dim = output_dim
        if torso is None:
            layers = [nn.Linear(dim, hidden_dim), nn.ELU()]
            for _ in range(n_hidden):
                # Each hidden block is its own Sequential (state_dict keys torso.<k>.0.*).
                layers.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ELU()))
            torso = nn.Sequential(*layers)
        self.torso = torso
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        features = self.torso(x)
        return self.output_layer(features)


class CirclePF(NeuralNet):
    """Forward policy for the 2-D box environment.

    At s0 it returns two Beta distributions (normalized radius and angle) built from the
    four free parameters in ``PFs0``. For any other batch of states, the network outputs
    an exit probability plus one Beta distribution over the normalized step angle.
    Concentrations are mapped to ``beta_max * sigmoid(raw) + beta_min``.
    """

    def __init__(
        self,
        hidden_dim=64,
        n_hidden=2,
        beta_min=0.1,
        beta_max=2.0,
    ):
        # Three network outputs per state: exit logit, raw alpha, raw beta.
        output_dim = 3
        super().__init__(
            dim=2, hidden_dim=hidden_dim, n_hidden=n_hidden, output_dim=output_dim
        )

        # Free parameters of PF(. | s0). The "log_" key names are kept for checkpoint
        # compatibility; the raw values are squashed with a sigmoid, not exponentiated.
        self.PFs0 = nn.ParameterDict(
            {
                "log_alpha_r": nn.Parameter(torch.zeros(1)),
                "log_alpha_theta": nn.Parameter(torch.zeros(1)),
                "log_beta_r": nn.Parameter(torch.zeros(1)),
                "log_beta_theta": nn.Parameter(torch.zeros(1)),
            }
        )

        self.beta_min = beta_min
        self.beta_max = beta_max

    def _concentration(self, raw):
        """Map an unconstrained value into (beta_min, beta_min + beta_max)."""
        return self.beta_max * torch.sigmoid(raw) + self.beta_min

    def forward(self, x):
        """Return (exit_proba, alpha, beta), each with shape ``x.shape[:-1]``."""
        out = super().forward(x)
        exit_proba = torch.sigmoid(out[..., 0])
        return (
            exit_proba,
            self._concentration(out[..., 1]),
            self._concentration(out[..., 2]),
        )

    def to_dist(self, x):
        """Distributions for a batch of states ``x`` of shape (n, 2).

        Returns ``(dist_r, dist_theta)`` (scalar Betas) when the batch is s0, otherwise
        ``(exit_proba, Beta(alpha, beta))`` with one entry per state.
        """
        if torch.all(x[0] == 0.0):
            assert torch.all(
                x == 0.0
            )  # If one of the states is s0, all of them must be
            alpha_r = self._concentration(self.PFs0["log_alpha_r"])
            alpha_theta = self._concentration(self.PFs0["log_alpha_theta"])
            beta_r = self._concentration(self.PFs0["log_beta_r"])
            beta_theta = self._concentration(self.PFs0["log_beta_theta"])

            dist_r = Beta(alpha_r[0], beta_r[0])
            dist_theta = Beta(alpha_theta[0], beta_theta[0])
            return dist_r, dist_theta

        # Otherwise, we use the neural network
        exit_proba, alpha, beta = self.forward(x)
        dist = Beta(alpha, beta)

        return exit_proba, dist

class Uniform():
    """Uniform forward policy usable with ``sample_actions``.

    At s0 it returns two scalar Beta(1, 1) distributions (normalized radius and angle).
    For other states it returns ``(exit_proba, Beta(1, 1))`` with one entry per state,
    matching the ``(exit_proba, dist)`` pair that ``sample_actions`` unpacks. Returning
    a bare scalar Beta there would fail on ``out[0]``, and every state would share one
    sample.

    Args:
        exit_proba: constant probability of exiting at non-s0 states. The default 0.0
            means trajectories end only at the forced exits applied in ``sample_actions``.
    """

    def __init__(self, exit_proba=0.0):
        self.exit_proba = exit_proba

    def to_dist(self, x):
        # Set device to match x (input tensor)
        device = x.device
        if torch.all(x[0] == 0.0):
            assert torch.all(
                x == 0.0
            )  # If one of the states is s0, all of them must be
            return Beta(torch.tensor(1., device=device), torch.tensor(1., device=device)), Beta(torch.tensor(1., device=device), torch.tensor(1., device=device))
        n_states = x.shape[0]
        exit_proba = torch.full((n_states,), float(self.exit_proba), device=device)
        dist = Beta(torch.ones(n_states, device=device), torch.ones(n_states, device=device))
        return exit_proba, dist

class CirclePB(NeuralNet):
    """Backward policy: one Beta distribution over the normalized backward step angle.

    With ``uniform=True`` the network is bypassed and Beta(1, 1) is returned for every
    state. ``torso`` lets the caller share a torso module (e.g. the forward policy's).
    """

    def __init__(
        self,
        hidden_dim=64,
        n_hidden=2,
        torso=None,
        uniform=False,
        beta_min=0.1,
        beta_max=2.0,
    ):
        # Two network outputs per state: raw alpha and raw beta.
        super().__init__(dim=2, hidden_dim=hidden_dim, n_hidden=n_hidden, output_dim=2)
        # The parent builds its own torso first; a supplied torso replaces it afterwards.
        if torso is not None:
            self.torso = torso
        self.uniform = uniform
        self.beta_min = beta_min
        self.beta_max = beta_max

    def _squash(self, raw):
        """Map an unconstrained value into (beta_min, beta_min + beta_max)."""
        return self.beta_max * torch.sigmoid(raw) + self.beta_min

    def forward(self, x):
        # x: (batch_size, 2) states
        raw = super().forward(x)
        return self._squash(raw[:, 0]), self._squash(raw[:, 1])

    def to_dist(self, x):
        if not self.uniform:
            alpha, beta = self.forward(x)
            return Beta(alpha, beta)
        n_states = x.shape[0]
        return Beta(torch.ones(n_states, device=x.device), torch.ones(n_states, device=x.device))




class SAC(object):
    """Soft actor-critic agent whose actor is a ``CirclePF`` forward policy.

    The critic is a twin ``QNetwork`` over (state, action), with a soft-updated target
    copy. The discount is fixed to 1. Exits are handled analytically in the targets:
    with probability ``exit_proba`` the value is ``log R(s) - log exit_proba``, and
    otherwise it is ``Q(s, a) - log pi(a | s)`` for a reparameterized sampled move ``a``.
    """

    def __init__(self, args, env):

        self.gamma = 1.0  # Fixed to 1.0 for GFlowNet
        self.target_update_interval = args.target_update_interval
        self.device = env.device
        self.tau = args.tau
        self.env = env  # Store env for sample_actions
        self.critic = QNetwork(env.dim, env.dim, args.Critic_hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
        self.critic_target = QNetwork(env.dim, env.dim, args.Critic_hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)
        self.policy = CirclePF(
            hidden_dim=args.hidden_dim,
            n_hidden=args.n_hidden,
            beta_min=args.beta_min,
            beta_max=args.beta_max,
        ).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
        # Step-wise learning-rate decay for the policy optimizer.
        self.policy_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.policy_optim,
            milestones=[i * args.scheduler_milestone for i in range(1, 10)],
            gamma=args.gamma_scheduler,
        )

    def _policy_outputs(self, states):
        """Return the policy's "naive" outputs (exit_proba, action_naive, log_pi_naive) per row.

        ``CirclePF.to_dist`` treats a whole batch as s0 or non-s0 based on its first row,
        so s0 rows and the remaining rows are sampled separately (s0 first) and scattered
        back into the original row order. Either group, or ``states`` itself, may be empty.
        """
        n_rows = states.shape[0]
        exit_proba = states.new_zeros(n_rows)
        action_naive = states.new_zeros((n_rows, self.env.dim))
        log_pi_naive = states.new_zeros(n_rows)
        s0_mask = torch.all(states == 0, dim=-1)
        for group in (s0_mask, ~s0_mask):
            if group.any():
                _, _, group_exit, group_action, group_log_pi = sample_actions(
                    self.env, self.policy, states[group]
                )
                exit_proba[group] = group_exit
                action_naive[group] = group_action
                log_pi_naive[group] = group_log_pi
        return exit_proba, action_naive, log_pi_naive

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic step and one actor step on a minibatch from ``memory``.

        Args:
            memory: replay buffer whose ``sample(batch_size)`` returns
                (states, actions, rewards (B, 1), next_states, dones (B, 1)).
            batch_size: minibatch size.
            updates: global update counter; the target critic is soft-updated when it is
                a multiple of ``target_update_interval``.

        Returns:
            (qf1_loss, qf2_loss, policy_loss) as Python floats.

        NOTE: the exit terms use log R(s) = log(env.reward(s)), matching the log-reward
        stored on terminal transitions; this assumes env.reward(s) > 0.
        """
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = memory.sample(batch_size=batch_size)

        # Make sure no autograd graph comes along from the buffer, and move to the agent's device.
        state_batch = state_batch.detach().to(self.device)
        action_batch = action_batch.detach().to(self.device)
        reward_batch = reward_batch.detach().to(self.device)
        next_state_batch = next_state_batch.detach().to(self.device)
        done_batch = done_batch.detach().to(self.device)

        # ---- Critic: regress Q(s, a) on r + V(s') over non-terminal transitions ----
        done_mask = (done_batch == 1).squeeze(-1)  # (batch_size,)
        with torch.no_grad():
            next_states = next_state_batch[~done_mask]
            target_q_value = reward_batch[~done_mask].squeeze(-1)

            next_exit_proba, next_action_naive, next_log_pi_naive = self._policy_outputs(next_states)
            # -inf moves mark forced exits (their exit probability was set to 1).
            next_move = ~torch.all(torch.isinf(next_action_naive), dim=-1)

            # Exit branch: Q(s', exit) = log R(s'). Rows with exit_proba == 0 contribute
            # nothing (avoids 0 * inf = nan).
            next_log_reward = self.env.reward(next_states).log()
            exit_value = torch.where(
                next_exit_proba > 0,
                next_exit_proba * (next_log_reward - next_exit_proba.log()),
                torch.zeros_like(next_exit_proba),
            )
            target_q_value += exit_value

            # Move branch: soft value of the sampled move under the target critic.
            qf1_next_target, qf2_next_target = self.critic_target(
                next_states[next_move], next_action_naive[next_move]
            )
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target).squeeze(-1)
            target_q_value[next_move] += (1 - next_exit_proba[next_move]) * (
                min_qf_next_target - next_log_pi_naive[next_move]
            )

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch[~done_mask], action_batch[~done_mask])
        qf1 = qf1.squeeze(-1)
        qf2 = qf2.squeeze(-1)
        qf1_loss = F.mse_loss(qf1, target_q_value)
        qf2_loss = F.mse_loss(qf2, target_q_value)

        qf_loss = qf1_loss + qf2_loss
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # ---- Actor: minimize the soft policy loss at every sampled state ----
        exit_proba, action_naive, log_pi_naive = self._policy_outputs(state_batch)
        move = ~torch.all(torch.isinf(action_naive), dim=-1)
        can_exit = exit_proba != 0  # s0 rows have exit_proba == 0

        policy_loss = torch.zeros_like(reward_batch).squeeze(-1)

        # Exit branch: p * (log p - log R(s)).
        log_reward = self.env.reward(state_batch[can_exit]).log()
        policy_loss[can_exit] += exit_proba[can_exit] * (exit_proba[can_exit].log() - log_reward)

        # Move branch: (1 - p) * (log pi(a | s) - min Q(s, a)).
        qf1_pi, qf2_pi = self.critic(state_batch[move], action_naive[move])
        min_qf_pi = torch.min(qf1_pi, qf2_pi).squeeze(-1)
        policy_loss[move] += (1 - exit_proba[move]) * (log_pi_naive[move] - min_qf_pi)
        policy_loss = policy_loss.mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item()

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        """Save policy/critic/target weights and both optimizer states with ``torch.save``.

        The default path is ``checkpoints/sac_checkpoint_<env_name>_<suffix>``; the parent
        directory of the target path is created if needed.
        """
        if ckpt_path is None:
            ckpt_path = "checkpoints/sac_checkpoint_{}_{}".format(env_name, suffix)
        directory = os.path.dirname(ckpt_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        print('Saving models to {}'.format(ckpt_path))
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_target.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optim.state_dict()}, ckpt_path)

    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        """Restore a checkpoint written by ``save_checkpoint`` onto this agent's device.

        Does nothing when ``ckpt_path`` is None. Afterwards the networks are put in eval
        mode if ``evaluate`` is true, and in train mode otherwise.
        """
        if ckpt_path is None:
            return
        print('Loading models from {}'.format(ckpt_path))
        # map_location lets a checkpoint saved on another device (e.g. cuda:3) load here.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
        self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])

        training = not evaluate
        for module in (self.policy, self.critic, self.critic_target):
            module.train(training)

In [9]:
env = Box(
    dim=dim,
    delta=delta,
    epsilon=epsilon,
    device_str=device,
    reward_type=args.reward_type,
    reward_debug=args.reward_debug,
    R0=args.R0,
    R1=args.R1,
    R2=args.R2,
)

# Create SAC agent (includes CirclePF as policy) and the uniform exploration policy.
sac_agent = SAC(args, env)
Uniform_model = Uniform()

# Create replay memory
memory = ReplayMemory(args.replay_size, seed, device=device)

# Backward policy used to turn trajectories into per-step rewards.
# NOTE(review): nothing in this cell optimizes bw_model. With --PB learnable it stays at its
# random initialization; with --PB tied only the shared torso changes (through the policy).
# Confirm this is intended.
bw_model = CirclePB(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    torso=sac_agent.policy.torso if args.PB == "tied" else None,
    uniform=args.PB == "uniform",
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)

jsd = float("inf")
sac_updates = 0  # Track SAC update steps

for i in trange(1, n_iterations + 1):
    # Data collection only: no autograd graph is needed for sampling or for the buffer.
    with torch.no_grad():
        behaviour_model = Uniform_model if np.random.rand() < args.uniform_ratio else sac_agent.policy
        trajectories, actionss, _, _ = sample_trajectories(env, behaviour_model, BS)

        last_states = get_last_states(env, trajectories)
        logrewards = env.reward(last_states).log()

        bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
            env, bw_model, trajectories
        )

        if args.without_backward_model:
            # Zero out finite backward log-probs, keeping -inf markers for invalid steps.
            intermediate_rewards = torch.where(
                all_bw_logprobs != -float("inf"),
                torch.zeros_like(all_bw_logprobs),
                all_bw_logprobs,
            )
        else:
            intermediate_rewards = all_bw_logprobs

        # Convert trajectories to transitions and push to replay memory
        all_states, all_actions, all_rewards, all_next_states, all_dones = trajectories_to_transitions(
            trajectories, actionss, intermediate_rewards, logrewards, env
        )

    memory.push_batch(all_states, all_actions, all_rewards, all_next_states, all_dones)
    if len(memory) > args.sac_batch_size:
        for _ in range(args.updates_per_step):
            qf1_loss, qf2_loss, policy_loss = sac_agent.update_parameters(memory, args.sac_batch_size, sac_updates)
            sac_updates += 1
        # Step the scheduler once per iteration (not per update)
        sac_agent.policy_scheduler.step()

    if any(torch.isnan(param).any() for param in sac_agent.policy.parameters()):
        raise ValueError("NaN in model parameters")

        


  1%|▏         | 63/5000 [00:14<19:12,  4.28it/s]


KeyboardInterrupt: 