In [None]:
import furniture_bench  # noqa: F401

from src.behavior.base import Actor
from src.eval.load_model import load_bc_actor

import wandb
import numpy as np
import torch

from omegaconf import OmegaConf

from furniture_bench.envs.observation import DEFAULT_STATE_OBS
import hydra
from src.gym import turn_off_april_tags
from src.gym.env_rl_wrapper import ResidualPolicyEnvWrapper
from src.gym.furniture_sim_env import FurnitureRLSimEnv
from src.models.residual import ResidualPolicy

from tqdm import trange


# Must run before any environment is constructed in the cells below.
# NOTE(review): presumably disables AprilTag textures/detection in the simulator
# assets -- confirm against src.gym.turn_off_april_tags.
turn_off_april_tags()

# Client for the Weights & Biases public API; used to fetch run configs and files.
api = wandb.Api()

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# W&B run that holds the trained residual policy (config + checkpoint files).
residual_run = api.run("residual-ppo-2/runs/sdoolikr")
# Alternative run, kept for quick switching:
# residual_run = api.run("residual-ppo-2/runs/3iom50to")

# Evaluation-time overrides applied on top of the run's stored config.
# Later keys win, so "env" from the run config is replaced wholesale.
# NOTE(review): "base_bc_poliy" is misspelled, but the environment-setup cell reads
# the same key (`cfg.base_bc_poliy`); rename both together if it is ever fixed.
eval_overrides = {
    "env": {"randomness": "low"},
    "base_bc_poliy": "ankile/one_leg-diffusion-state-1/runs/7623y5vn",
}

cfg = OmegaConf.create({**residual_run.config, **eval_overrides})

# Display the merged config as the cell output.
cfg

In [None]:
# Build the raw simulator first. Action settings and randomness come from the
# merged config; headless mode and the number of parallel environments are
# hard-coded here instead of using cfg.headless / cfg.num_envs.
sim_env: FurnitureRLSimEnv = FurnitureRLSimEnv(
    furniture="one_leg",
    randomness=cfg.env.randomness,
    # Observations
    observation_space="state",
    obs_keys=DEFAULT_STATE_OBS,
    concat_robot_state=True,
    april_tags=False,
    # Actions / control
    action_type=cfg.action_type,
    act_rot_repr=cfg.act_rot_repr,
    ctrl_mode="diffik",
    # Simulation
    gpu_id=0,
    headless=True,
    num_envs=128,
    max_env_steps=100_000_000,
)

# NOTE(review): presumably these cap the magnitude of random force/torque
# perturbations applied in the sim -- confirm against FurnitureRLSimEnv.
sim_env.max_force_magnitude = 0.05
sim_env.max_torque_magnitude = 0.0025

# Base behavior-cloning actor, identified by the W&B run path stored in the config.
bc_actor: Actor = load_bc_actor(cfg.base_bc_poliy)

# Wrap the simulator for residual-policy rollouts; episode length and reset
# behavior are taken from the residual run's config.
env: ResidualPolicyEnvWrapper = ResidualPolicyEnvWrapper(
    sim_env,
    max_env_steps=cfg.num_env_steps,
    reset_on_success=cfg.reset_on_success,
    reset_on_failure=cfg.reset_on_failure,
)

# Give the wrapper the BC actor's normalizer so both share one normalization.
env.set_normalizer(bc_actor.normalizer)

In [None]:

# Residual policy setup: instantiate the architecture described in the run config,
# sized to the wrapped environment's observation and action spaces.
residual_policy: ResidualPolicy = hydra.utils.instantiate(
    cfg.residual_policy,
    obs_shape=env.observation_space.shape,
    action_shape=env.action_space.shape,
)

# Load the residual policy weights: use the first file attached to the run whose
# name contains ".pt", and fail with a clear message if the run has none.
checkpoint_files = [f for f in residual_run.files() if ".pt" in f.name]
if not checkpoint_files:
    raise RuntimeError(f"No '.pt' checkpoint file found on W&B run {residual_run!r}")
wts = checkpoint_files[0]
wts.download(replace=True)

# map_location remaps the saved tensors onto the selected device, so a checkpoint
# written on a GPU machine still loads when `device` has fallen back to "cpu".
checkpoint = torch.load(wts.name, map_location=device)
residual_policy.load_state_dict(checkpoint["model_state_dict"])

# Last expression: moves the module to `device` and shows its repr as cell output.
residual_policy.to(device)

In [None]:
# Start a fresh rollout: reset the environments and the base actor's internal state.
next_obs = env.reset()
bc_actor.reset()

# Running sum of rewards over all parallel environments and all steps.
total_reward = 0

# Scale applied to the residual before it is added to the base action (loop-invariant).
action_scale = cfg.residual_policy.action_scale

for step in trange(850):
    # The base policy proposes a normalized action from the raw observation.
    base_naction = bc_actor.action_normalized(next_obs)

    # The residual policy sees the processed observation concatenated with the
    # base action along the last dimension.
    next_obs = env.process_obs(next_obs)
    next_residual_obs = torch.cat([next_obs, base_naction], dim=-1)

    with torch.no_grad():
        policy_outputs = residual_policy.get_action_and_value(next_residual_obs)
    residual_naction_samp, logprob, _, value, naction_mean = policy_outputs

    # Evaluation uses the fifth output (`naction_mean`) rather than the sampled action.
    # NOTE(review): presumably the mean of the action distribution -- confirm in ResidualPolicy.
    residual_naction = naction_mean
    naction = base_naction + residual_naction * action_scale

    next_obs, reward, next_done, truncated, infos = env.step(naction)

    total_reward = total_reward + reward.sum()


# Average summed reward per environment, read as the success rate.
# NOTE(review): this equals a success rate only if each environment earns a reward
# of 1 at most once during the rollout -- confirm against the wrapper's reward/reset logic.
total_reward / env.num_envs

In [None]:
# Recorded success rates:
# Base policy only: 51%
# With an OK residual: 87%
# With a better residual: 95%