In [1]:
from pathlib import Path

import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy as np
import torch
from huggingface_hub import snapshot_download

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#  Create a directory to store the video of the evaluation
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/2024-11-26/18-20-16_pusht_diffusion_default")

policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()

# Check if GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. Device set to:", device)
else:
    device = torch.device("cpu")
    print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
    # Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
    policy.diffusion.num_inference_steps = 10

policy.to(device)



Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 73002.13it/s]


Loading weights from local directory
GPU is available. Device set to: cuda


DiffusionPolicy(
  (normalize_inputs): Normalize(
    (buffer_observation_image): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
    )
    (buffer_observation_state): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (diffusion): Diff

In [6]:
# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
BATCH_SIZE = 100
env = gym.vector.make(
    "gym_pusht/PushT-v0",
    num_envs=BATCH_SIZE,
    obs_type="pixels_agent_pos",
    max_episode_steps= 200,
)

  gym.logger.warn(


In [None]:
# Reset the policy and environments
BATCH_SIZE = 100
env = gym.vector.make(
    "gym_pusht/PushT-v0",
    num_envs=BATCH_SIZE,
    obs_type="pixels_agent_pos",
    max_episode_steps= 200,
)
policy.reset()
batch_observations, info = env.reset(seed=42)

# Prepare to collect rewards and frames for each environment in the batch
rewards = [[] for _ in range(BATCH_SIZE)]
frames = [[] for _ in range(BATCH_SIZE)]

# Render the initial frames
initial_frames = env.call("render")
for i in range(BATCH_SIZE):
    frames[i].append(initial_frames[i])

done = [False] * BATCH_SIZE
steps = [0] * BATCH_SIZE

while not all(done):
    # Prepare batched observations
    states = torch.from_numpy(batch_observations["agent_pos"]).to(torch.float32).to(device)
    images = torch.from_numpy(batch_observations["pixels"]).to(torch.float32) / 255
    images = images.permute(0, 3, 1, 2).to(device)  # Convert to channel-first format

    # Create batched input dictionary for the policy
    observations = {
        "observation.state": states,
        "observation.image": images,
    }

    # Predict actions in batch
    with torch.inference_mode():
        actions = policy.select_action(observations)

    numpy_actions = actions.cpu().numpy()

    # Step through all environments in the batch
    batch_observations, batch_rewards, batch_terminated, batch_truncated, infos = env.step(numpy_actions)

    # Track rewards, frames, and done status for each environment
    frames_ = env.call("render")
    for i in range(BATCH_SIZE):
        if not done[i]:
            rewards[i].append(batch_rewards[i])
            frames[i].append(frames_[i])
            done[i] = batch_terminated[i] or batch_truncated[i]
            steps[i] += 1
            print(f"Environment {i} has completed {steps[i]} steps. reward={batch_rewards[i]}")

# Encode videos for each environment
for i in range(BATCH_SIZE):
    video_path = output_directory / f"rollout_env_{i}.mp4"
    imageio.mimsave(str(video_path), np.stack(frames[i]), fps=env.metadata["render_fps"])
    print(f"Video for environment {i} is available at '{video_path}'.")

print("Batch evaluation complete.")

Environment 0 has completed 1 steps. reward=0.0
Environment 1 has completed 1 steps. reward=0.0
Environment 2 has completed 1 steps. reward=0.03055885971007947
Environment 3 has completed 1 steps. reward=0.017677395962375143
Environment 4 has completed 1 steps. reward=0.0002236514599414307
Environment 5 has completed 1 steps. reward=0.0
Environment 6 has completed 1 steps. reward=0.0
Environment 7 has completed 1 steps. reward=0.010031352981896939
Environment 8 has completed 1 steps. reward=0.0
Environment 9 has completed 1 steps. reward=0.0
Environment 10 has completed 1 steps. reward=0.0
Environment 11 has completed 1 steps. reward=0.0
Environment 12 has completed 1 steps. reward=0.0
Environment 13 has completed 1 steps. reward=0.0
Environment 14 has completed 1 steps. reward=0.0
Environment 15 has completed 1 steps. reward=0.24366805933367677
Environment 16 has completed 1 steps. reward=0.0
Environment 17 has completed 1 steps. reward=0.0904629963138023
Environment 18 has completed 



Video for environment 0 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_0.mp4'.




Video for environment 1 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_1.mp4'.




Video for environment 2 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_2.mp4'.




Video for environment 3 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_3.mp4'.




Video for environment 4 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_4.mp4'.




Video for environment 5 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_5.mp4'.




Video for environment 6 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_6.mp4'.




Video for environment 7 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_7.mp4'.




Video for environment 8 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_8.mp4'.




Video for environment 9 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_9.mp4'.




Video for environment 10 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_10.mp4'.




Video for environment 11 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_11.mp4'.




Video for environment 12 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_12.mp4'.




Video for environment 13 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_13.mp4'.




Video for environment 14 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_14.mp4'.




Video for environment 15 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_15.mp4'.




Video for environment 16 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_16.mp4'.




Video for environment 17 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_17.mp4'.




Video for environment 18 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_18.mp4'.




Video for environment 19 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_19.mp4'.




Video for environment 20 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_20.mp4'.




Video for environment 21 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_21.mp4'.




Video for environment 22 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_22.mp4'.




Video for environment 23 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_23.mp4'.




Video for environment 24 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_24.mp4'.




Video for environment 25 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_25.mp4'.




Video for environment 26 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_26.mp4'.




Video for environment 27 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_27.mp4'.




Video for environment 28 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_28.mp4'.




Video for environment 29 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_29.mp4'.




Video for environment 30 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_30.mp4'.




Video for environment 31 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_31.mp4'.




Video for environment 32 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_32.mp4'.




Video for environment 33 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_33.mp4'.




Video for environment 34 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_34.mp4'.




Video for environment 35 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_35.mp4'.




Video for environment 36 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_36.mp4'.




Video for environment 37 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_37.mp4'.




Video for environment 38 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_38.mp4'.




Video for environment 39 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_39.mp4'.




Video for environment 40 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_40.mp4'.




Video for environment 41 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_41.mp4'.




Video for environment 42 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_42.mp4'.




Video for environment 43 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_43.mp4'.




Video for environment 44 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_44.mp4'.




Video for environment 45 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_45.mp4'.




Video for environment 46 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_46.mp4'.




Video for environment 47 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_47.mp4'.




Video for environment 48 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_48.mp4'.




Video for environment 49 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_49.mp4'.




Video for environment 50 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_50.mp4'.




Video for environment 51 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_51.mp4'.




Video for environment 52 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_52.mp4'.




Video for environment 53 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_53.mp4'.




Video for environment 54 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_54.mp4'.




Video for environment 55 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_55.mp4'.




Video for environment 56 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_56.mp4'.




Video for environment 57 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_57.mp4'.




Video for environment 58 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_58.mp4'.




Video for environment 59 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_59.mp4'.




Video for environment 60 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_60.mp4'.




Video for environment 61 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_61.mp4'.




Video for environment 62 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_62.mp4'.




Video for environment 63 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_63.mp4'.




Video for environment 64 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_64.mp4'.




Video for environment 65 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_65.mp4'.




Video for environment 66 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_66.mp4'.




Video for environment 67 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_67.mp4'.




Video for environment 68 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_68.mp4'.




Video for environment 69 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_69.mp4'.




Video for environment 70 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_70.mp4'.




Video for environment 71 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_71.mp4'.




Video for environment 72 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_72.mp4'.




Video for environment 73 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_73.mp4'.




Video for environment 74 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_74.mp4'.




Video for environment 75 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_75.mp4'.




Video for environment 76 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_76.mp4'.




Video for environment 77 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_77.mp4'.




Video for environment 78 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_78.mp4'.




Video for environment 79 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_79.mp4'.




Video for environment 80 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_80.mp4'.




Video for environment 81 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_81.mp4'.




Video for environment 82 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_82.mp4'.




Video for environment 83 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_83.mp4'.




Video for environment 84 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_84.mp4'.




Video for environment 85 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_85.mp4'.




Video for environment 86 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_86.mp4'.




Video for environment 87 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_87.mp4'.




Video for environment 88 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_88.mp4'.




Video for environment 89 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_89.mp4'.




Video for environment 90 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_90.mp4'.




Video for environment 91 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_91.mp4'.




Video for environment 92 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_92.mp4'.




Video for environment 93 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_93.mp4'.




Video for environment 94 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_94.mp4'.




Video for environment 95 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_95.mp4'.




Video for environment 96 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_96.mp4'.




Video for environment 97 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_97.mp4'.




Video for environment 98 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_98.mp4'.




Video for environment 99 is available at 'outputs/eval/example_pusht_diffusion/rollout_env_99.mp4'.
Batch evaluation complete.


In [18]:
# Reset the policy and environmens to prepare for rollout

numpy_observation, info = env.reset(seed=1234)

# Prepare to collect every rewards and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []

# Render frame of the initial state
frames.append(env.render())

step = 0
done = False
while not done:
    # Prepare observation for the policy running in Pytorch
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])

    # Convert to float32 with image from channel first in [0,255]
    # to channel last in [0,1]
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)

    # Send data tensors from CPU to GPU
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)

    # Add extra (empty) batch dimension, required to forward the policy
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)

    # Create the policy input dictionary
    observation = {
        "observation.state": state,
        "observation.image": image,
    }
    
    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Prepare the action for the environment
    numpy_action = action.squeeze(0).to("cpu").numpy()
    # print(numpy_action)
    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    frames.append(env.render())

    # The rollout is considered done when the success state is reach (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated | truncated | done
    step += 1



step=0 reward=0.3549526479034709 terminated=False
step=1 reward=0.3549526479034709 terminated=False
step=2 reward=0.3549526479034709 terminated=False
step=3 reward=0.3549526479034709 terminated=False
step=4 reward=0.3549526479034709 terminated=False
step=5 reward=0.3549526479034709 terminated=False
step=6 reward=0.3549526479034709 terminated=False
step=7 reward=0.3549526479034709 terminated=False
step=8 reward=0.3549526479034709 terminated=False
step=9 reward=0.3549526479034709 terminated=False
step=10 reward=0.3549526479034709 terminated=False
step=11 reward=0.3549526479034709 terminated=False
step=12 reward=0.3549526479034709 terminated=False
step=13 reward=0.3583105184393292 terminated=False
step=14 reward=0.37652863787608654 terminated=False
step=15 reward=0.39283125939387686 terminated=False
step=16 reward=0.36670874447322566 terminated=False
step=17 reward=0.3135768316292473 terminated=False
step=18 reward=0.26680185903854986 terminated=False
step=19 reward=0.22731800821590223 te

In [19]:
if terminated:
    print("Success!")
else:
    print("Failure!")

# Get the speed of environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]

# Encode all frames into a mp4 video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

print(f"Video of the evaluation is available in '{video_path}'.")

Failure!




Video of the evaluation is available in 'outputs/eval/example_pusht_diffusion/rollout.mp4'.
