In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# ---------------------------
# Config
# ---------------------------

# Device selection and global seeding (NumPy drives all batch/action sampling).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Feasible per-dimension action box [ELL, UU], each of shape (act_dim,).
# Leaving both as None falls back to [-1, 1]^act_dim once act_dim is known
# (the usual case when the environment's actions live in [-1, 1]^d). For
# flattened action chunks (d*H dims) the same per-dim [-1, 1] box is the
# usual starting point.
ELL, UU = None, None

# Guard box for infeasible sampling (L < ELL, UU < U). GUARD_SCALE is the
# multiplier applied to the feasible box to obtain L / U (see where L and U
# are computed). The paper also allows "100~1000x boundary", but in practice
# start with 2~5x.
GUARD_SCALE = 3.0
DELTA = 0.5  # width of the infeasible sampling band beyond L / U

# Critic / objective
GAMMA = 0.99
C_REWARD = 10.0  # reward scaling factor (experiments: 1 vs 10 vs 100)
USE_LN = True  # LayerNorm on the hidden layers
ALPHA_PA = 1.0  # weight of the PA loss (0 disables PA)

# Training (optional)
DO_TRAIN = True
STEPS, BATCH, LR = 20000, 256, 3e-4

# When the minimum reward is unknown, take it from the dataset.
RMIN_FROM_DATA = True

# ---------------------------
# Load dataset
# ---------------------------
# NOTE: `load_offline_dataset` is not defined in this cell; presumably it comes
# from an earlier notebook cell.
DATA_PATH = "/home/robros/git/qc-flow-priority-sampling/robomimic/dataset/transport/mh/low_dim_v15.hdf5"
obs, act, rew, next_obs, done, info = load_offline_dataset(
    DATA_PATH,
    robomimic_obs_keys=None,  # None: low-dim keys are selected automatically
    exclude_images=True,  # default True (image keys are excluded automatically)
)

print(info)
print(obs.shape, act.shape, rew.shape, next_obs.shape, done.shape)

# The critic runs in float32, and the TD target computes `1.0 - done`, which
# fails for boolean tensors and mixes dtypes for float64 ones. Convert once
# here; np.asarray is a no-op for arrays that are already float32.
obs, act, rew, next_obs, done = (
    np.asarray(x, dtype=np.float32) for x in (obs, act, rew, next_obs, done)
)

N, obs_dim = obs.shape
act_dim = act.shape[1]

if ELL is None or UU is None:
    # Most common default: [-1, 1]^act_dim
    ELL = -np.ones((act_dim,), dtype=np.float32)
    UU = np.ones((act_dim,), dtype=np.float32)
else:
    ELL = np.asarray(ELL, dtype=np.float32)
    UU = np.asarray(UU, dtype=np.float32)

if ELL.shape != (act_dim,) or UU.shape != (act_dim,):
    raise ValueError(
        f"ELL and UU must both have shape ({act_dim},); got {ELL.shape} and {UU.shape}"
    )
if np.any(ELL >= UU):
    raise ValueError("every feasible lower bound ELL must be strictly below UU")

# Guard box: scale the feasible box about its centre so that L < ELL and
# UU < U also hold for asymmetric boxes (plain ELL * GUARD_SCALE collapses
# onto ELL when ELL == 0 and lands above ELL when ELL > 0). For the symmetric
# default [-1, 1] this equals ELL * GUARD_SCALE and UU * GUARD_SCALE.
_guard_center = 0.5 * (ELL + UU)
_guard_half_width = 0.5 * (UU - ELL)
L = _guard_center - GUARD_SCALE * _guard_half_width
U = _guard_center + GUARD_SCALE * _guard_half_width

rmin = float(rew.min()) if RMIN_FROM_DATA else -1.0
# Discounted sum of a constant scaled reward C_REWARD * rmin over an infinite
# horizon: the lowest value Q can take under these rewards.
QMIN = C_REWARD * rmin / (1.0 - GAMMA)

print(f"[INFO] N={N}, obs_dim={obs_dim}, act_dim={act_dim}")
print(f"[INFO] rmin={rmin:.4f}, QMIN={QMIN:.4f}")

# ---------------------------
# Simple replay sampler
# ---------------------------
def sample_batch(batch_size: int):
    """Draw a uniform random minibatch of transitions (with replacement).

    Returns ``(s, a, r, s_next, done)`` as float32 tensors on ``DEVICE``.
    Casting to float32 keeps the batch compatible with the float32 critic and
    makes ``1.0 - done`` valid even when the stored done flags are boolean.
    """
    idx = np.random.randint(0, N, size=(batch_size,))
    return tuple(
        torch.as_tensor(arr[idx], dtype=torch.float32, device=DEVICE)
        for arr in (obs, act, rew, next_obs, done)
    )

# ---------------------------
# Infeasible sampler (box-out) : a in A_I
# ---------------------------
def sample_infeasible_actions(batch_size: int):
    """Sample ``batch_size`` actions from the infeasible set A_I.

    Every action starts uniformly inside the feasible box [ELL, UU]. Then, per
    row, one randomly chosen dimension k is pushed outside the guard box,
    uniformly into [U_k, U_k + DELTA) or (L_k - DELTA, L_k], with each side
    picked with probability 1/2.

    Returns a float32 tensor of shape (batch_size, act_dim) on DEVICE.
    """
    a = np.random.uniform(ELL, UU, size=(batch_size, act_dim)).astype(np.float32)

    rows = np.arange(batch_size)
    ks = np.random.randint(0, act_dim, size=(batch_size,))
    upper = np.random.rand(batch_size) < 0.5
    # Distance beyond the guard bound, uniform in [0, DELTA).
    offset = np.random.uniform(0.0, DELTA, size=(batch_size,))
    # One vectorized write replaces the former per-row Python loop.
    a[rows, ks] = np.where(upper, U[ks] + offset, L[ks] - offset)
    return torch.from_numpy(a).to(DEVICE)

# ---------------------------
# Critic network (ReLU MLP + optional LN)
# ---------------------------
class Critic(nn.Module):
    """Q(s, a) critic: a ReLU MLP with two hidden layers on the concatenation [s, a].

    Each hidden layer is Linear -> (LayerNorm if ``use_ln`` else identity) -> ReLU,
    followed by a linear head with a single output. ``forward`` squeezes that
    trailing size-1 dimension, so it returns one Q value per (s, a) row.
    """

    def __init__(self, obs_dim, act_dim, hidden=256, use_ln=True):
        super().__init__()
        self.use_ln = use_ln

        def make_norm():
            # LayerNorm when enabled, otherwise a parameter-free pass-through.
            return nn.LayerNorm(hidden) if use_ln else nn.Identity()

        # Registration order fc1, ln1, fc2, ln2, fc3 fixes both the state_dict
        # layout and the order in which the Linear layers draw their init.
        self.fc1 = nn.Linear(obs_dim + act_dim, hidden)
        self.ln1 = make_norm()
        self.fc2 = nn.Linear(hidden, hidden)
        self.ln2 = make_norm()
        self.fc3 = nn.Linear(hidden, 1)

    def forward(self, s, a):
        h = torch.cat([s, a], dim=-1)
        for linear, ln in ((self.fc1, self.ln1), (self.fc2, self.ln2)):
            h = F.relu(ln(linear(h)))
        return self.fc3(h).squeeze(-1)

# Online critic plus its target network. Both are built in that order; the
# target then copies the critic's weights. Only the critic's parameters are
# given to the optimizer.
# NOTE(review): the recorded run died with "CUDA error: out of memory" right
# after the [INFO] prints, i.e. presumably at the first GPU allocation below.
# This MLP is small, so the GPU is probably already occupied (other processes
# or earlier notebook cells) - check nvidia-smi / restart the kernel.
critic, target = (Critic(obs_dim, act_dim, use_ln=USE_LN).to(DEVICE) for _ in range(2))
target.load_state_dict(critic.state_dict())
opt = torch.optim.Adam(critic.parameters(), lr=LR)

# ---------------------------
# (Optional) Training: minimal offline SARSA-style 1-step TD target
# For quick diagnostics only: y = c*r + gamma*(1-done)*Q_target(s', a')
# where a' is NOT the logged next action of s': it is the action of a random
# dataset transition (a cheap proxy for a' ~ behavior policy).
# ---------------------------
def soft_update(tgt, src, tau=0.005):
    """Polyak-average the parameters of ``src`` into ``tgt`` in place.

    Each target parameter becomes ``(1 - tau) * p_t + tau * p``. Parameters
    are paired in ``.parameters()`` order, so both modules are expected to
    share the same architecture.

    Args:
        tgt: module whose parameters are updated in place.
        src: module providing the new values (left unchanged).
        tau: interpolation weight; 1.0 copies ``src`` exactly.
    """
    # no_grad + in-place lerp_ instead of mutating `.data`: autograd does not
    # record the update, and no temporary `tau * p` tensor is allocated.
    with torch.no_grad():
        for p_t, p in zip(tgt.parameters(), src.parameters()):
            p_t.lerp_(p, tau)

if DO_TRAIN:
    # Ensure training mode in case the evaluation section below (which calls
    # critic.eval()) ran earlier in this kernel session.
    critic.train()
    for step in range(1, STEPS + 1):
        # .float() is a no-op for float32 tensors; it keeps the TD target below
        # valid if the sampled arrays arrive as float64 or boolean.
        s, a, r, s2, d = (t.float() for t in sample_batch(BATCH))

        # reward scaling
        r_scaled = C_REWARD * r

        # Proxy "behavior next action": the action of a uniformly random
        # dataset transition (not the logged next action of s2). `act` is
        # indexed directly, so no extra full batch is sampled and copied to
        # the device only to keep its actions.
        # (If the true next action for the same index is available, use it.)
        a2_idx = np.random.randint(0, N, size=(BATCH,))
        a2 = torch.as_tensor(act[a2_idx], dtype=torch.float32, device=DEVICE)

        with torch.no_grad():
            q2 = target(s2, a2)
            y = r_scaled + GAMMA * (1.0 - d) * q2

        q = critic(s, a)
        td_loss = F.mse_loss(q, y)

        # PA loss: push Q(s, a_I) for infeasible actions a_I toward QMIN.
        if ALPHA_PA > 0:
            aI = sample_infeasible_actions(BATCH)
            qI = critic(s, aI)
            pa_loss = F.mse_loss(qI, torch.full_like(qI, float(QMIN)))
            loss = td_loss + ALPHA_PA * pa_loss
        else:
            pa_loss = torch.tensor(0.0, device=DEVICE)
            loss = td_loss

        opt.zero_grad()
        loss.backward()
        opt.step()
        soft_update(target, critic)

        if step % 2000 == 0:
            print(f"[step {step}] td={td_loss.item():.4f} pa={pa_loss.item():.4f} total={loss.item():.4f}")

# ---------------------------
# Evaluation & Visualization
# ---------------------------
critic.eval()

# Choose states for evaluation
M = 1024
idx = np.random.randint(0, N, size=(M,))
# Cast to float32 to match the critic's parameters regardless of the stored dtype.
S_eval = torch.as_tensor(obs[idx], dtype=torch.float32, device=DEVICE)

# Dataset actions at those states + infeasible actions
A_data = torch.as_tensor(act[idx], dtype=torch.float32, device=DEVICE)
A_inf = sample_infeasible_actions(M)

with torch.no_grad():
    Q_data = critic(S_eval, A_data).cpu().numpy()
    Q_inf = critic(S_eval, A_inf).cpu().numpy()

# (A) Histogram: Q on dataset vs infeasible
plt.figure()
plt.hist(Q_data, bins=60, alpha=0.7, label="Q on dataset actions")
plt.hist(Q_inf, bins=60, alpha=0.7, label="Q on infeasible actions")
plt.axvline(QMIN, linestyle="--", label="Q_min")
plt.legend()
plt.title("Q distribution: dataset vs infeasible")
plt.xlabel("Q")
plt.ylabel("count")
plt.show()

# (B) Scatter: action norm vs Q
# Both point clouds are divided by the SAME reference norm ||UU|| so they share
# one x-axis. A different denominator per set would rescale one cloud and make
# infeasible actions appear closer to the origin than they are.
A_data_np = A_data.cpu().numpy()
A_inf_np = A_inf.cpu().numpy()
norm_ref = np.linalg.norm(UU) + 1e-8
norm_data = np.linalg.norm(A_data_np, axis=1) / norm_ref
norm_inf = np.linalg.norm(A_inf_np, axis=1) / norm_ref

plt.figure()
plt.scatter(norm_data, Q_data, s=10, alpha=0.6, label="dataset")
plt.scatter(norm_inf, Q_inf, s=10, alpha=0.6, label="infeasible")
plt.axhline(QMIN, linestyle="--", label="Q_min")
plt.legend()
plt.title("Action norm vs Q (should suppress infeasible)")
plt.xlabel("normalized action norm")
plt.ylabel("Q")
plt.show()

# (C) 2D slice heatmap at a fixed state: vary action dims (i, j) over a grid
# that extends DELTA past the guard box, keep the other dims at a dataset action.
s0 = torch.as_tensor(obs[idx[0:1]], dtype=torch.float32, device=DEVICE)
a0 = torch.as_tensor(act[idx[0:1]], dtype=torch.float32, device=DEVICE)  # anchor action
# With act_dim == 1, i == j and the slice degenerates to a single dimension.
i, j = 0, min(1, act_dim - 1)  # choose two dims (change if needed)

grid = 101
x = np.linspace(L[i] - DELTA, U[i] + DELTA, grid).astype(np.float32)  # includes infeasible zone
y = np.linspace(L[j] - DELTA, U[j] + DELTA, grid).astype(np.float32)

# Row-major grid: x varies fastest, so reshape(grid, grid) gives [y_index, x_index].
A_grid = np.repeat(a0.cpu().numpy(), grid * grid, axis=0)
A_grid[:, i] = np.tile(x, grid)
A_grid[:, j] = np.repeat(y, grid)

S_grid = np.repeat(s0.cpu().numpy(), grid * grid, axis=0)

with torch.no_grad():
    Q_grid = critic(torch.from_numpy(S_grid).to(DEVICE),
                    torch.from_numpy(A_grid).to(DEVICE)).cpu().numpy()
Q_grid = Q_grid.reshape(grid, grid)

plt.figure()
plt.imshow(Q_grid, origin="lower",
           extent=[x.min(), x.max(), y.min(), y.max()],
           aspect="auto")
plt.colorbar(label="Q")
plt.title(f"2D slice of Q(s,a): dims ({i},{j})")
plt.xlabel(f"a[{i}]")
plt.ylabel(f"a[{j}]")
plt.show()


{'source': 'robomimic_hdf5', 'exclude_images': True, 'num_demos': 300}
(195800, 127) (195800, 14) (195800,) (195800, 127) (195800,)
[INFO] N=195800, obs_dim=127, act_dim=14
[INFO] rmin=0.0000, QMIN=0.0000


AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
