In [1]:
# Run the argon runtime setup hook first, before any other cell in this notebook.
# NOTE(review): presumably this configures the runtime (e.g. logging / JAX backend) — confirm.
import argon.runtime
argon.runtime.setup()

In [2]:
import ipywidgets
import jax
import argon.numpy as jnp
import argon.datasets.env

# Names of the environment datasets offered in the selector below.
env_datasets = [
    "pusht/chi",
    "robomimic/pickplace/can/ph",
    "robomimic/nutassembly/square/ph",
]
# Entry that is pre-selected when the widget is first created; must be one of `env_datasets`.
default = "robomimic/nutassembly/square/ph"
dropdown = ipywidgets.Dropdown(
    description="Dataset:",
    options=env_datasets,
    value=default,
    disabled=False,
)
# Bare trailing expression so the widget is rendered as this cell's output.
# Later cells read `dropdown.value`, so the selection made here determines the dataset.
dropdown

Dropdown(description='Dataset:', index=2, options=('pusht/chi', 'robomimic/pickplace/can/ph', 'robomimic/nutas…

In [3]:
# Create the dataset currently selected in `dropdown`. The value is read once when this
# cell runs, so re-run this cell (and everything after it) after changing the selection.
dataset = argon.datasets.env.datasets.create(dropdown.value)
# Environment created by the dataset; later cells use it for `full_state`, `observe`,
# `step` and `render`.
env = dataset.create_env()

In [4]:
from argon.env.mujoco.robosuite import EEfPose
from argon.core.dataclasses import dataclass
from typing import Any

# Number of leading states in each chunk that are turned into observations by `process_data`.
obs_length = 1
# Number of trailing states in each chunk that are turned into actions by `process_data`.
action_length = 1
# Observation config passed to `env.observe` to derive actions from states.
# NOTE(review): presumably selects the end-effector pose as the action representation — confirm.
action_config = EEfPose()

@dataclass
class Sample:
    """One training sample built by `process_data` from a fixed-length chunk of states."""
    # Full environment state at the last observed timestep of the chunk.
    state: Any
    # Result of `env.observe` vmapped over the first `obs_length` states of the chunk.
    # NOTE(review): this may be a structured pytree rather than a single jax.Array — confirm.
    observations: jax.Array
    # Result of `env.observe(state, action_config)` for the last `action_length` states
    # of the chunk.
    actions: jax.Array

def process_data(env, data):
    """Turn a dataset of reduced states into a dataset of `Sample`s.

    Each element's `reduced_state` is expanded with `env.full_state` and the
    result is cached. The data is then chunked into windows of
    `action_length + obs_length` states, and every window is mapped to a
    `Sample`.

    Args:
        env: environment providing `full_state` and `observe`.
        data: dataset exposing `map_elements`, `cache`, `chunk` and `map`.

    Returns:
        The chunked dataset mapped through the per-window conversion; it is
        not cached here.

    Reads the notebook-level `obs_length`, `action_length` and `action_config`.
    """
    def _to_full_state(element):
        return env.full_state(element.reduced_state)

    def _window_to_sample(chunk):
        window = chunk.elements
        # Observe every state in the window with the action config, then keep
        # only the trailing `action_length` entries as the actions.
        observed_as_actions = jax.vmap(
            lambda state: env.observe(state, action_config)
        )(window)
        window_actions = jax.tree.map(
            lambda leaf: leaf[-action_length:], observed_as_actions
        )
        # The leading `obs_length` states are the ones that get observed.
        history = jax.tree.map(lambda leaf: leaf[:obs_length], window)
        observations = jax.vmap(env.observe)(history)
        # The sample's state is the most recent observed state.
        current_state = jax.tree.map(lambda leaf: leaf[-1], history)
        return Sample(current_state, observations, window_actions)

    full_states = data.map_elements(_to_full_state).cache()
    windows = full_states.chunk(action_length + obs_length)
    return windows.map(_window_to_sample)
    
# Keep only the leading `slice(0, 1)` portion of the train split (presumably a single
# trajectory — confirm), convert it to `Sample`s and cache the processed result.
train_data = process_data(
    env,
    dataset.splits["train"].slice(0, 1),
).cache()

In [5]:
from argon.policy import PolicyOutput, rollout
from argon.env import ImageRender

# All recorded actions, stacked into one pytree by `as_pytree()`.
actions = train_data.as_pytree().actions
# Leading-axis size of every leaf in the action pytree.
lengths = [leaf.shape[0] for leaf in jax.tree.leaves(actions)]
# Rollout length: one more than the number of recorded actions along the leading axis
# of the first leaf.
length = lengths[0] + 1
def actions_policy(input):
    """Open-loop policy that replays the recorded `actions` one index at a time.

    The policy state is the index of the next action to emit. A `policy_state`
    of `None` is treated as index 0.

    Returns:
        A `PolicyOutput` holding the action at the current index and the
        incremented index as the new policy state.
    """
    step = 0 if input.policy_state is None else input.policy_state
    return PolicyOutput(
        action=jax.tree.map(lambda leaf: leaf[step], actions),
        policy_state=step + 1,
    )

def roll_video(rng_key):
    """Replay the recorded actions from the first sample's state and render the result.

    Args:
        rng_key: accepted for the caller's convenience but not used in the body.

    Returns:
        `env.render` with a 256x256 `ImageRender`, vmapped over the `states`
        of the rollout.
    """
    def _render(state):
        return env.render(state, ImageRender(256, 256))

    trajectory = rollout(
        env.step,
        train_data[0].state,
        policy=actions_policy,
        length=length,
    )
    return jax.vmap(_render)(trajectory.states)


In [6]:
from argon.util.ipython import as_video
# Render the replayed trajectory and show it as this cell's output.
# The key is required by `roll_video`'s signature, but `roll_video` does not use it.
as_video(roll_video(jax.random.key(42)))

Video(value=b'\x00\x00\x00 ftypisom\x00\x00\x02\x00isomiso2avc1mp41\x00\x00\x00\x08free...')