In [1]:
import stanza.runtime
# Run stanza's runtime setup before the rest of the notebook.
# NOTE(review): what setup() configures is not visible in this notebook —
# confirm before reordering or removing this cell.
stanza.runtime.setup()

In [2]:
import ipywidgets
import jax
import jax.numpy as jnp
import stanza.datasets.env

# Dataset identifiers offered in the selection widget.
env_datasets = [
    "pusht/chi",
    "robomimic/pickplace/can/ph",
    "robomimic/nutassembly/square/ph",
]
# Pre-select the second entry ("robomimic/pickplace/can/ph").
default = env_datasets[1]
dropdown = ipywidgets.Dropdown(
    description="Dataset:",
    value=default,
    options=env_datasets,
    disabled=False,
)
# Keep the widget as the cell's last expression so it is displayed.
dropdown

Dropdown(description='Dataset:', index=1, options=('pusht/chi', 'robomimic/pickplace/can/ph', 'robomimic/nutas…

In [3]:
# Instantiate the dataset currently selected in the dropdown. Changing the
# selection has no effect until this cell (and the ones below) are re-run.
dataset = stanza.datasets.env.datasets.create(dropdown.value)
# Environment created by the dataset; the cells below use it to expand states,
# compute observations, step the rollout and render frames.
env = dataset.create_env()

In [4]:
from stanza.env.mujoco.robosuite import ManipulationTaskEEFPose
from stanza.dataclasses import dataclass
from typing import Any

# Number of leading states of each chunk that process_data turns into observations.
obs_length = 1
# Number of trailing states of each chunk that process_data turns into actions.
action_length = 1
# Config passed to env.observe when deriving actions from states.
# NOTE(review): presumably selects the end-effector pose (inferred from the
# class name only) — confirm against stanza.env.mujoco.robosuite.
action_config = ManipulationTaskEEFPose()

@dataclass
class Sample:
    """One training sample built by process_data from a chunk of consecutive states.

    Fields are filled positionally as Sample(curr_state, obs, actions), so the
    field order below must not change.
    """
    # Last state of the chunk's observation window (index -1 of the first
    # obs_length states).
    state: Any
    # Result of vmapping env.observe over the first obs_length states of the chunk.
    observations: jax.Array
    # Result of env.observe(state, action_config) for the last action_length
    # states of the chunk.
    actions: jax.Array

def process_data(env, data):
    """Turn a dataset of reduced states into a dataset of ``Sample`` objects.

    Each element's ``reduced_state`` is expanded with ``env.full_state``, the
    expanded states are grouped into chunks of ``obs_length + action_length``
    elements, and every chunk is mapped to one ``Sample``.

    Args:
        env: environment providing ``full_state`` and ``observe``.
        data: dataset providing ``map_elements``, ``cache``, ``chunk`` and ``map``.

    Returns:
        The dataset produced by mapping every chunk to a ``Sample``.
    """
    window = obs_length + action_length
    full_states = data.map_elements(
        lambda element: env.full_state(element.reduced_state)
    ).cache()

    def to_sample(chunk):
        states = chunk.elements
        # Observe every state in the chunk with the action config, then keep
        # only the trailing action_length entries as the actions.
        observe_action = lambda s: env.observe(s, action_config)
        all_actions = jax.vmap(observe_action)(states)
        actions = jax.tree.map(lambda leaf: leaf[-action_length:], all_actions)
        # The leading obs_length states form the observation window.
        obs_states = jax.tree.map(lambda leaf: leaf[:obs_length], states)
        obs = jax.vmap(env.observe)(obs_states)
        # The sample's state is the final state of the observation window.
        curr_state = jax.tree.map(lambda leaf: leaf[-1], obs_states)
        return Sample(curr_state, obs, actions)

    return full_states.chunk(window).map(to_sample)
    
# Restrict the train split with slice(0, 1), convert it to Samples and cache
# the processed result.
train_data = process_data(
    env,
    dataset.splits["train"].slice(0, 1),
).cache()

In [5]:
from stanza.policy import PolicyOutput, rollout
from stanza.env import ImageRender

# All recorded actions of the processed training data as a single pytree.
actions = train_data.as_pytree().actions
jax.debug.print("{s}", s=actions[0])
# Leading-axis size of every leaf in the actions pytree.
lengths = [leaf.shape[0] for leaf in jax.tree.leaves(actions)]
# Roll out for one step more than the number of recorded actions.
length = lengths[0] + 1
def actions_policy(input):
    """Replay the recorded ``actions`` one step at a time.

    The policy state is the index of the next action to emit; a missing
    (``None``) policy state is treated as index 0.

    Args:
        input: policy input; only ``input.policy_state`` is read.

    Returns:
        A ``PolicyOutput`` carrying the action at the current index and the
        index incremented by one as the new policy state.
    """
    step = 0 if input.policy_state is None else input.policy_state
    action = jax.tree.map(lambda leaf: leaf[step], actions)
    return PolicyOutput(action=action, policy_state=step + 1)

def roll_video(rng_key):
    """Replay the recorded actions in ``env`` and render every visited state.

    Args:
        rng_key: accepted but not used by this function.

    Returns:
        The result of rendering each rollout state with ``ImageRender(128, 128)``,
        vmapped over the rollout's states.
    """
    initial_state = train_data[0].state
    trajectory = rollout(
        env.step,
        initial_state,
        policy=actions_policy,
        length=length,
    )
    render_frame = lambda state: env.render(state, ImageRender(128, 128))
    return jax.vmap(render_frame)(trajectory.states)


[[[-6.20085485e-02 -8.79978538e-02  9.95229006e-01]]

 [[-6.20239601e-02 -9.23529863e-02  9.98414993e-01]]

 [[-6.27634749e-02 -9.58189145e-02  9.99652207e-01]]

 [[-6.23682551e-02 -1.00499585e-01  1.00056314e+00]]

 [[-6.12141564e-02 -1.04648940e-01  1.00070846e+00]]

 [[-5.94969355e-02 -1.09998569e-01  1.00091493e+00]]

 [[-5.73470220e-02 -1.14708036e-01  1.00072885e+00]]

 [[-5.45609184e-02 -1.18481725e-01  9.99664307e-01]]

 [[-5.20782135e-02 -1.24501646e-01  9.98422205e-01]]

 [[-4.84721549e-02 -1.31211549e-01  9.96753693e-01]]

 [[-4.45050187e-02 -1.38112128e-01  9.94727969e-01]]

 [[-3.94437350e-02 -1.45298824e-01  9.91755903e-01]]

 [[-3.36112306e-02 -1.53464556e-01  9.87206995e-01]]

 [[-2.67804824e-02 -1.61174923e-01  9.81981635e-01]]

 [[-1.95862986e-02 -1.68955281e-01  9.76845264e-01]]

 [[-1.26574738e-02 -1.76405922e-01  9.72982943e-01]]

 [[-4.91180504e-03 -1.82722196e-01  9.68753338e-01]]

 [[ 3.70044494e-03 -1.88231483e-01  9.63757932e-01]]

 [[ 1.26765426e-02 -1.933785

In [6]:
from stanza.util.ipython import as_video
# Replay the recorded actions and display the rendered frames as a video.
# NOTE(review): roll_video never reads its rng_key argument, so the seed 42
# has no effect on the result.
as_video(roll_video(jax.random.key(42)))

Video(value=b'\x00\x00\x00 ftypisom\x00\x00\x02\x00isomiso2avc1mp41\x00\x00\x00\x08free...')