In [3]:
import json
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange

import argparse

from env import Box, get_last_states
from sac_model import CirclePB, Uniform
from sac_sampling import (
    sample_trajectories,
    evaluate_backward_logprobs,
)
from sac import SAC
from sac_replay_memory import ReplayMemory, trajectories_to_transitions

from utils import (
    fit_kde,
    plot_reward,
    sample_from_reward,
    plot_samples,
    estimate_jsd,
    plot_trajectories,
)

import config
import sac_config

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default=sac_config.DEVICE)
parser.add_argument("--dim", type=int, default=config.DIM)
parser.add_argument("--delta", type=float, default=config.DELTA)
parser.add_argument("--epsilon", type=float, default=config.EPSILON)
parser.add_argument("--R0", type=float, default=config.R0, help="Baseline reward value")
parser.add_argument("--R1", type=float, default=config.R1, help="Medium reward value (e.g., outer square)")
parser.add_argument("--R2", type=float, default=config.R2, help="High reward value (e.g., inner square)")
parser.add_argument("--reward_debug", action="store_true", default=config.REWARD_DEBUG)
parser.add_argument(
    "--reward_type",
    type=str,
    choices=["baseline", "ring", "angular_ring", "multi_ring", "curve", "gaussian_mixture"],
    default=config.REWARD_TYPE,
    help="Type of reward function to use. To modify reward-specific parameters (radius, sigma, etc.), edit rewards.py"
)
parser.add_argument(
    "--beta_min",
    type=float,
    default=config.BETA_MIN,
    help="Minimum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--beta_max",
    type=float, 
    default=config.BETA_MAX,
    help="Maximum value for the concentration parameters of the Beta distribution",
)
parser.add_argument(
    "--PB",
    type=str,
    choices=["learnable", "tied", "uniform"],
    default=config.PB,
    help="Backward policy type",
)
parser.add_argument("--gamma_scheduler", type=float, default=config.GAMMA_SCHEDULER)
parser.add_argument("--scheduler_milestone", type=int, default=config.SCHEDULER_MILESTONE)
parser.add_argument("--seed", type=int, default=config.SEED)
parser.add_argument("--lr", type=float, default=config.LR, help="Learning rate for SAC")
parser.add_argument("--BS", type=int, default=config.BS)
parser.add_argument("--n_iterations", type=int, default=config.N_ITERATIONS)
parser.add_argument("--n_evaluation_interval", type=int, default=config.N_EVALUATION_INTERVAL)
parser.add_argument("--n_logging_interval", type=int, default=config.N_LOGGING_INTERVAL)
parser.add_argument("--hidden_dim", type=int, default=config.HIDDEN_DIM)
parser.add_argument("--n_hidden", type=int, default=config.N_HIDDEN)
parser.add_argument("--n_evaluation_trajectories", type=int, default=config.N_EVALUATION_TRAJECTORIES)
parser.add_argument("--no_plot", action="store_true", default=config.NO_PLOT)
parser.add_argument("--no_wandb", action="store_true", default=config.NO_WANDB)
parser.add_argument("--wandb_project", type=str, default=config.WANDB_PROJECT)
parser.add_argument("--uniform_ratio", type=float, default=config.UNIFORM_RATIO, help="Ratio of uniform policy")


# SAC-specific arguments
parser.add_argument("--tau", type=float, default=sac_config.TAU, help="Tau for soft update")
parser.add_argument("--target_update_interval", type=int, default=sac_config.TARGET_UPDATE_INTERVAL, help="Target network update interval")
parser.add_argument("--policy_update_interval", type=int, default=sac_config.POLICY_UPDATE_INTERVAL, help="Policy update interval")
parser.add_argument("--Critic_hidden_size", type=int, default=sac_config.CRITIC_HIDDEN_SIZE, help="Hidden size for SAC critic networks")
parser.add_argument("--replay_size", type=int, default=sac_config.REPLAY_SIZE, help="Replay buffer size")
parser.add_argument("--sac_batch_size", type=int, default=sac_config.SAC_BATCH_SIZE, help="SAC batch size")
parser.add_argument("--updates_per_step", type=int, default=sac_config.UPDATES_PER_STEP, help="SAC updates per step")
parser.add_argument("--without_backward_model", type=bool, default=sac_config.WITHOUT_BACKWARD_MODEL, help="Whether to use backward model")
args = parser.parse_args([])


In [4]:
device = args.device
dim = args.dim
delta = args.delta
epsilon = args.epsilon
seed = args.seed
lr = args.lr
n_iterations = args.n_iterations
BS = args.BS

if seed == 0:
    seed = np.random.randint(int(1e6))

run_name = f"SAC_d{delta}_{args.reward_type}_lr{lr}_sd{seed}"
if args.without_backward_model:
    run_name += f"_without_backward_model"
run_name += f"_R0,R1,R2_{args.R0},{args.R1},{args.R2}"
run_name += f"_tau{args.tau}"
run_name += f"_BS{BS}"
run_name += f"_replay_size{args.replay_size}"
run_name += f"_sac_batch_size{args.sac_batch_size}"
run_name += f"_update_per_step{args.updates_per_step}"
run_name += f"_target_update_interval{args.target_update_interval}"
run_name += f"_UR{args.uniform_ratio}"
run_name += f"_device{device}"
print(run_name)

torch.manual_seed(seed)
np.random.seed(seed)

print(f"Using device: {device}")

env = Box(
    dim=dim,
    delta=delta,
    epsilon=epsilon,
    device_str=device,
    reward_type=args.reward_type,
    reward_debug=args.reward_debug,
    R0=args.R0,
    R1=args.R1,
    R2=args.R2,
)


# Create SAC agent (includes CirclePF as policy)
sac_agent = SAC(args, env)
Uniform_model = Uniform()
memory = ReplayMemory(args.replay_size, seed, device=device)

bw_model = CirclePB(
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    torso=sac_agent.policy.torso if args.PB == "tied" else None,
    uniform=args.PB == "uniform",
    beta_min=args.beta_min,
    beta_max=args.beta_max,
).to(device)


SAC_d0.25_ring_lr0.001_sd483492_R0,R1,R2_0.1,0.5,2_tau0.1_BS256_replay_size1000000_sac_batch_size256_update_per_step5_target_update_interval5_UR0.2_devicecuda:2
Using device: cuda:2


In [5]:
sac_updates = 0  # Track SAC update steps
# Initialize loss tracking variables
qf1_loss = 0.0
qf2_loss = 0.0
policy_loss = None

for i in trange(1, n_iterations + 1):
    with torch.no_grad():   # ★ 여기 추가
        if np.random.rand() < args.uniform_ratio:
            trajectories, actionss, _, _  = sample_trajectories(
                env,
                Uniform_model,
                BS,
            )
        else:
            trajectories, actionss, _, _  = sample_trajectories(
                env,
                sac_agent.policy,
                BS,
            )

        last_states = get_last_states(env, trajectories)
        logrewards = env.reward(last_states).log()
        
        bw_logprobs, all_bw_logprobs = evaluate_backward_logprobs(
            env, bw_model, trajectories
        )

        if args.without_backward_model:
            intermediate_rewards = torch.where(
                all_bw_logprobs != -float("inf"),
                torch.zeros_like(all_bw_logprobs),
                all_bw_logprobs,
                )
        else:
            intermediate_rewards = all_bw_logprobs

        # Convert trajectories to transitions and push to replay memory
        all_states, all_actions, all_rewards, all_next_states, all_dones = trajectories_to_transitions(
            trajectories, actionss, intermediate_rewards, logrewards, env
        )
    memory.push_batch(all_states, all_actions, all_rewards, all_next_states, all_dones)

    if len(memory) > args.sac_batch_size:
        # Accumulate losses over multiple updates
        qf1_losses = []
        qf2_losses = []
        policy_losses = []
        
        for _ in range(args.updates_per_step):
            qf1_loss_step, qf2_loss_step, policy_loss_step = sac_agent.update_parameters(memory, args.sac_batch_size, sac_updates)
            qf1_losses.append(qf1_loss_step)
            qf2_losses.append(qf2_loss_step)
            if policy_loss_step is not None:
                policy_losses.append(policy_loss_step)
            sac_updates += 1
        
        # Average the losses
        qf1_loss = sum(qf1_losses) / len(qf1_losses)
        qf2_loss = sum(qf2_losses) / len(qf2_losses)
        policy_loss = sum(policy_losses) / len(policy_losses) if len(policy_losses) > 0 else None

    if any(
        [
            torch.isnan(list(sac_agent.policy.parameters())[i]).any()
            for i in range(len(list(sac_agent.policy.parameters())))
        ]
    ):
        raise ValueError("NaN in model parameters")
    break


  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]


In [26]:
trajectories[99]

tensor([[0.0000, 0.0000],
        [0.1174, 0.0928],
        [0.2980, 0.2656],
        [  -inf,   -inf],
        [  -inf,   -inf],
        [  -inf,   -inf],
        [  -inf,   -inf],
        [  -inf,   -inf]], device='cuda:2')

In [24]:
last_states[:100]

tensor([[0.1684, 0.4574],
        [0.0992, 0.2525],
        [0.0882, 0.1209],
        [0.0309, 0.1111],
        [0.0275, 0.0436],
        [0.4047, 0.3297],
        [0.0886, 0.0286],
        [0.3834, 0.2544],
        [0.1001, 0.4296],
        [0.0021, 0.0030],
        [0.3249, 0.1248],
        [0.0597, 0.0507],
        [0.2873, 0.8435],
        [0.0445, 0.1881],
        [0.2105, 0.4049],
        [0.1433, 0.0220],
        [0.0712, 0.0571],
        [0.0688, 0.0115],
        [0.2533, 0.0743],
        [0.1627, 0.0208],
        [0.0208, 0.0543],
        [0.0034, 0.0294],
        [0.0816, 0.0659],
        [0.0639, 0.1013],
        [0.0246, 0.2014],
        [0.6037, 0.3071],
        [0.5743, 0.1599],
        [0.5024, 0.4289],
        [0.0909, 0.0981],
        [0.0485, 0.0140],
        [0.0941, 0.0775],
        [0.8887, 0.8427],
        [0.1823, 0.0887],
        [0.0364, 0.1095],
        [0.0800, 0.0409],
        [0.8634, 0.9678],
        [0.5422, 0.5688],
        [0.2534, 0.2848],
        [0.0

In [36]:
logrewards[:10]

tensor([ 0.1319, -2.3026, -2.3026, -2.3026, -2.3026, -2.2589, -2.3026,  0.3275,
        -2.2647, -2.3026], device='cuda:2')

In [35]:
trajectories[5]

tensor([[0.0000, 0.0000],
        [0.0858, 0.0235],
        [0.1644, 0.2608],
        [0.4047, 0.3297],
        [  -inf,   -inf],
        [  -inf,   -inf],
        [  -inf,   -inf],
        [  -inf,   -inf]], device='cuda:2')

In [31]:
all_bw_logprobs[:10]

tensor([[0.0000, 1.6884,   -inf,   -inf,   -inf,   -inf],
        [0.0000, 2.2829,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0000, 1.7183, 0.9347,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0000, 0.9347,   -inf,   -inf,   -inf,   -inf],
        [0.0000, 2.2734,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf]], device='cuda:2')

In [37]:
logrewards[:10]

tensor([ 0.1319, -2.3026, -2.3026, -2.3026, -2.3026, -2.2589, -2.3026,  0.3275,
        -2.2647, -2.3026], device='cuda:2')

In [38]:
intermediate_rewards

tensor([[0.0000, 1.6884,   -inf,   -inf,   -inf,   -inf],
        [0.0000, 2.2829,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        ...,
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf,   -inf]], device='cuda:2')

In [43]:
all_states[:10],  all_actions[:10], all_rewards[:10], all_next_states[:10], all_dones[:10]

(tensor([[0.0000, 0.0000],
         [0.0417, 0.2419],
         [0.1684, 0.4574],
         [0.0000, 0.0000],
         [0.0205, 0.0151],
         [0.0992, 0.2525],
         [0.0000, 0.0000],
         [0.0882, 0.1209],
         [0.0000, 0.0000],
         [0.0309, 0.1111]], device='cuda:2'),
 tensor([[0.0417, 0.2419],
         [0.1268, 0.2155],
         [  -inf,   -inf],
         [0.0205, 0.0151],
         [0.0786, 0.2373],
         [  -inf,   -inf],
         [0.0882, 0.1209],
         [  -inf,   -inf],
         [0.0309, 0.1111],
         [  -inf,   -inf]], device='cuda:2'),
 tensor([ 0.0000,  1.6884,  0.1319,  0.0000,  2.2829, -2.3026,  0.0000, -2.3026,
          0.0000, -2.3026], device='cuda:2'),
 tensor([[0.0417, 0.2419],
         [0.1684, 0.4574],
         [  -inf,   -inf],
         [0.0205, 0.0151],
         [0.0992, 0.2525],
         [  -inf,   -inf],
         [0.0882, 0.1209],
         [  -inf,   -inf],
         [0.0309, 0.1111],
         [  -inf,   -inf]], device='cuda:2'),
 tenso

In [39]:
all_states, all_actions, all_rewards, all_next_states, all_dones 

(tensor([[0.0000, 0.0000],
         [0.0417, 0.2419],
         [0.1684, 0.4574],
         ...,
         [0.1730, 0.0357],
         [0.0000, 0.0000],
         [0.0156, 0.0259]], device='cuda:2'),
 tensor([[0.0417, 0.2419],
         [0.1268, 0.2155],
         [  -inf,   -inf],
         ...,
         [  -inf,   -inf],
         [0.0156, 0.0259],
         [  -inf,   -inf]], device='cuda:2'),
 tensor([ 0.0000,  1.6884,  0.1319,  0.0000,  2.2829, -2.3026,  0.0000, -2.3026,
          0.0000, -2.3026,  0.0000, -2.3026,  0.0000,  1.7183,  0.9347, -2.2589,
          0.0000, -2.3026,  0.0000,  0.9347,  0.3275,  0.0000,  2.2734, -2.2647,
          0.0000, -2.3026,  0.0000,  2.0349, -2.2881,  0.0000, -2.3026,  0.0000,
          1.5620,  1.4669,  0.9347, -2.2545,  0.0000, -2.3026,  0.0000,  1.3853,
          0.7301,  0.0000, -2.3026,  0.0000, -2.3026,  0.0000, -2.3026,  0.0000,
          2.5851, -2.3026,  0.0000, -2.3026,  0.0000, -2.3026,  0.0000, -2.3026,
          0.0000, -2.3026,  0.0000, -2.3026