In [1]:
import stanza.runtime
# Run stanza's runtime setup first, before any other cell executes.
# NOTE(review): what `setup()` actually configures is not visible from this notebook.
stanza.runtime.setup()

In [2]:
import ipywidgets
import jax
import jax.numpy as jnp
import stanza.datasets.env

# Dataset names offered by the picker below.
env_datasets = [
    "pusht/chi",
    "robomimic/pickplace/can/ph",
    "robomimic/nutassembly/square/ph",
]
default = "robomimic/pickplace/can/ph"  # entry selected when the widget is first shown

# The next cell reads `dropdown.value`, so the selection made here decides
# which dataset gets created.
dropdown = ipywidgets.Dropdown(
    description='Dataset:',
    options=env_datasets,
    value=default,
    disabled=False,
)
dropdown  # bare last expression so the widget is rendered as the cell output

Dropdown(description='Dataset:', index=1, options=('pusht/chi', 'robomimic/pickplace/can/ph', 'robomimic/nutas…

In [3]:
# Create the dataset whose name is currently selected in the dropdown widget.
dataset = stanza.datasets.env.datasets.create(dropdown.value)
# Environment supplied by the dataset itself; later cells use it for
# `full_state`, `observe`, `step` and `render`.
env = dataset.create_env()

In [4]:
from stanza.env.mujoco.robosuite import ManipulationTaskEEFPose
from stanza.dataclasses import dataclass
from typing import Any

# Number of leading states in each chunk that `process_data` turns into observations.
obs_length = 1
# Number of trailing states in each chunk whose action observation `process_data` keeps.
action_length = 1
# Passed as the second argument to `env.observe` when extracting actions.
action_config = ManipulationTaskEEFPose()

@dataclass
class Sample:
    """One training sample built by `process_data` from a chunk of consecutive states."""
    # Last of the first `obs_length` states in the chunk.
    state: Any
    # `env.observe` applied to each of the first `obs_length` states.
    # NOTE(review): annotated as jax.Array, but the notebook reads
    # `.observations.obj_pos`, so this is a structured value rather than a bare array.
    observations: jax.Array
    # `env.observe(state, action_config)` for the last `action_length` states.
    actions: jax.Array

def process_data(env, data):
    """Convert a dataset of recorded elements into `Sample` values.

    Each element's `reduced_state` is expanded with `env.full_state` and the
    result is cached. The states are then grouped with
    `chunk(action_length + obs_length)` and every chunk is mapped to one
    `Sample`.

    Reads the notebook-level globals `obs_length`, `action_length` and
    `action_config`.

    Args:
        env: environment providing `full_state` and `observe`.
        data: dataset supporting `map_elements`, `cache`, `chunk` and `map`.

    Returns:
        The chunked dataset with every chunk mapped to a `Sample`.
    """
    def to_sample(chunk):
        chunk_states = chunk.elements

        # Observe every state of the chunk through the action config, then
        # keep only the trailing `action_length` entries.
        # Note: with action_length == 0 the slice `[-0:]` keeps every entry.
        def observe_action(state):
            return env.observe(state, action_config)
        all_actions = jax.vmap(observe_action)(chunk_states)
        tail_actions = jax.tree.map(lambda leaf: leaf[-action_length:], all_actions)

        # The leading `obs_length` states are the ones that get observed; the
        # last of them is stored as the sample's state.
        head_states = jax.tree.map(lambda leaf: leaf[:obs_length], chunk_states)
        latest_state = jax.tree.map(lambda leaf: leaf[-1], head_states)
        head_obs = jax.vmap(env.observe)(head_states)

        # Positional order matches the field order: state, observations, actions.
        return Sample(latest_state, head_obs, tail_actions)

    full_states = data.map_elements(
        lambda element: env.full_state(element.reduced_state)
    ).cache()
    window = action_length + obs_length
    return full_states.chunk(window).map(to_sample)
    
# Take `slice(2, 1)` of the training split, convert it to `Sample`s and cache
# the processed result.
# NOTE(review): `slice(2, 1)` is presumably (start, length), i.e. a single
# entry starting at index 2 — confirm against the stanza dataset API.
train_data = process_data(
    env,
    dataset.splits["train"].slice(2, 1),
).cache()
# Print the `obj_pos` field of the processed observations.
print(train_data.as_pytree().observations.obj_pos)

[[[[-0.05744484 -0.08471354  1.0990576 ]]]


 [[[-0.05798843 -0.08456369  1.0989343 ]]]


 [[[-0.05706813 -0.08962288  1.0984526 ]]]


 [[[-0.05369795 -0.09856503  1.0976169 ]]]


 [[[-0.04845822 -0.10885226  1.0967803 ]]]


 [[[-0.04202615 -0.11960639  1.0968764 ]]]


 [[[-0.03471596 -0.13086478  1.0972005 ]]]


 [[[-0.02615533 -0.14249249  1.0963988 ]]]


 [[[-0.01668859 -0.15434335  1.0945865 ]]]


 [[[-0.00663048 -0.16608125  1.0922397 ]]]


 [[[ 0.00310565 -0.17783043  1.0891715 ]]]


 [[[ 0.01221754 -0.18891025  1.0865681 ]]]


 [[[ 0.02052312 -0.19959556  1.0847088 ]]]


 [[[ 0.02872953 -0.2101968   1.0832103 ]]]


 [[[ 0.03667995 -0.22076967  1.0819963 ]]]


 [[[ 0.04457056 -0.2309938   1.0813384 ]]]


 [[[ 0.05160362 -0.24030407  1.0808965 ]]]


 [[[ 0.05852003 -0.24895504  1.081024  ]]]


 [[[ 0.06550799 -0.2567253   1.0805848 ]]]


 [[[ 0.07246666 -0.26364788  1.0781631 ]]]


 [[[ 0.07902813 -0.2691421   1.0737363 ]]]


 [[[ 0.08512916 -0.2736114   1.0679866 ]]]


 [[[ 0.090

In [5]:
from stanza.policy import PolicyOutput, rollout
from stanza.env import ImageRender

# Actions of the processed training data as one pytree; `actions_policy`
# indexes into it step by step.
actions = train_data.as_pytree().actions
# Leading-axis size of every leaf of the action pytree. `jax.tree.leaves` is
# used rather than unpacking `tree_flatten` into `(lengths, _)`: binding `_`
# at the top level of a cell stops IPython from updating its last-output
# variable.
lengths = jax.tree.leaves(jax.tree.map(lambda leaf: leaf.shape[0], actions))
# One more than the number of stored actions.
# NOTE(review): presumably `rollout`'s `length` counts states, so N actions
# correspond to N + 1 states — confirm against the stanza rollout API.
length = lengths[0] + 1
def actions_policy(input):
    """Policy that replays the recorded `actions` one index per call.

    The policy state is the index of the next action to emit; a missing
    (`None`) policy state is treated as index 0.

    Args:
        input: policy input; only its `policy_state` attribute is read.

    Returns:
        A `PolicyOutput` holding the action at the current index and the
        incremented index as the new policy state.
    """
    step = 0 if input.policy_state is None else input.policy_state
    return PolicyOutput(
        action=jax.tree.map(lambda leaf: leaf[step], actions),
        policy_state=step + 1,
    )

def roll_video(rng_key):
    """Replay the recorded actions and render every resulting state.

    The rollout starts from the state of the first sample in `train_data`,
    is driven by `actions_policy` and runs for `length` steps.

    Args:
        rng_key: accepted but not used by this function.

    Returns:
        `env.render(state, ImageRender(256, 256))` applied to each state of
        the rollout.
    """
    del rng_key  # unused; the replay does not consume any randomness here

    trajectory = rollout(
        env.step,
        train_data[0].state,
        policy=actions_policy,
        length=length,
    )

    def render_frame(state):
        return env.render(state, ImageRender(256, 256))

    return jax.vmap(render_frame)(trajectory.states)


In [6]:
from stanza.util.ipython import as_video
# Roll out the recorded actions and display the rendered frames as a video.
# NOTE: `roll_video` never reads its key argument, so the seed 42 has no
# effect on the result.
as_video(roll_video(jax.random.key(42)))

Video(value=b'\x00\x00\x00 ftypisom\x00\x00\x02\x00isomiso2avc1mp41\x00\x00\x00\x08free...')