In [None]:
import mediapy as media
import mujoco
import mujoco.viewer
import numpy as np
import tqdm
from dm_control import composer

from aloha_sim import task_suite
from aloha_sim.aloha import Aloha
from aloha_sim.motion_planner import Planner
from aloha_sim.tasks.base.aloha2_task import (
    right_qpos_to_ctrl,
)
from aloha_sim.utils.visualization import add_frame_to_renderer
from aloha_sim.scripted_policies.marker import get_pick_pose
from aloha_sim.scripted_policies.marker import Policy

In [None]:
# Key into `task_suite.TASK_FACTORIES` selecting which task the cells below
# build and roll out.
task_name = "MarkerRemoveLid"

In [None]:
_DT = 0.02
_IMAGE_SIZE = (480, 848)
_ALOHA_CAMERAS = {
    "overhead_cam": _IMAGE_SIZE,
    "worms_eye_cam": _IMAGE_SIZE,
    "wrist_cam_left": _IMAGE_SIZE,
    "wrist_cam_right": _IMAGE_SIZE,
}

In [None]:
# Instantiate the selected task from its registered factory. The factory's own
# kwargs are splatted separately so any clash with the explicit arguments
# still raises instead of being silently overridden.
task_class, kwargs = task_suite.TASK_FACTORIES[task_name]
task = task_class(
    cameras=_ALOHA_CAMERAS,
    control_timestep=_DT,
    update_interval=1,
    image_observation_delay_secs=0.0,
    **kwargs,
)

# Options for the composer wrapper, gathered in one place for easy tweaking.
env_options = dict(
    time_limit=float("inf"),  # the environment never truncates an episode
    random_state=np.random.RandomState(0),  # fixed seed for reproducible resets
    recompile_mjcf_every_episode=False,
    strip_singleton_obs_buffer_dim=True,
    delayed_observation_padding=composer.ObservationPadding.INITIAL_VALUE,
)
env = composer.Environment(task=task, **env_options)
time_step = env.reset()

In [None]:
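# NOTE(review): disabled — `joint_limits` is not defined anywhere in this
# notebook, so this line would raise NameError if uncommented.
# random_qpos = env.random_state.uniform(joint_limits[:, 0], joint_limits[:, 1])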

In [None]:
# Sanity check: render the scene right after reset using `physics.render`'s
# default camera and size.
reset_image = env.physics.render()
media.show_image(reset_image)

In [None]:
# Scripted policy bound to this environment; the rollout cell below drives it
# through its reset(), is_done() and step() methods.
policy = Policy(env)

In [None]:
from collections import defaultdict

# Camera name -> list of recorded images, one per call to `add_cameras`.
frames = defaultdict(list)


def add_cameras(timestep):
    """Record the current image from every camera listed in `_ALOHA_CAMERAS`.

    Args:
      timestep: A timestep whose `observation` mapping holds one entry per
        camera name in `_ALOHA_CAMERAS`.

    Side effects:
      Appends one observation per camera to the module-level `frames` dict.
    """
    for name in _ALOHA_CAMERAS.keys():
        camera_frames = frames[name]
        camera_frames.append(timestep.observation[name])


# Roll out the scripted policy, recording every camera image (including the
# one from the initial observation). The loop stops when the policy reports it
# is done, when the environment ends the episode, or after
# `max_rollout_steps` steps. The environment is built with time_limit=inf, so
# without a cap a policy that never finishes would hang the kernel.
max_rollout_steps = 10_000  # 200 s of simulated time at _DT = 0.02

timestep = env.reset()
add_cameras(timestep)
policy.reset()

num_steps = 0
while not policy.is_done():
    if timestep.last():
        # composer.Environment resets itself on the step after a terminal
        # timestep, so continuing would silently append a new episode's frames.
        print(f"Episode ended after {num_steps} steps, before the policy finished.")
        break
    if num_steps >= max_rollout_steps:
        raise RuntimeError(
            f"Policy did not finish within {max_rollout_steps} steps; "
            "frames recorded so far are kept in `frames`."
        )
    action = policy.step(timestep.observation)
    timestep = env.step(action)
    add_cameras(timestep)
    num_steps += 1

In [None]:
# One frame is recorded per control step of `_DT` seconds, so playing back at
# 1 / _DT fps shows the rollout in real (simulated) time. A hard-coded rate
# would desynchronize playback from the simulation whenever `_DT` differs.
media.show_videos(frames, fps=1.0 / _DT)

In [None]:
# Snapshot of the current simulation state from MuJoCo's free camera, with the
# extra geometry from `add_frame_to_renderer` drawn into the scene.
camera = -1  # -1 selects the free camera in mujoco.Renderer.update_scene

# _IMAGE_SIZE unpacks as (height, width) for the Renderer constructor.
with mujoco.Renderer(env.physics.model.ptr, *_IMAGE_SIZE) as renderer:
    model, data = env.physics.model.ptr, env.physics.data.ptr
    renderer.update_scene(data, camera)
    # NOTE(review): presumably appends geoms to renderer.scene, so it must run
    # after update_scene (which rebuilds the scene) and before render — confirm.
    add_frame_to_renderer(renderer, model, data)
    media.show_image(renderer.render())