In [1]:
import imageio
import numpy
import gym_pusht  # noqa: F401
import gymnasium as gym
import matplotlib.pyplot as plt
from IPython.display import clear_output
from collections import deque
from cvae_utilities import *
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.init(
    project="difftop-training",
    config={
        "horizon": 7,
        "action_dim": 2,
        "state_dim": 512*3*3 + 2,
        "batch_size": 16,
        "latent_dim_state": 50,
        "latent_dim_action": 2,  # Same as action_dim
        "posterior_dim": 64,
        "beta": 1.0,
        "training_steps": 5000,
        "learning_rate": 3e-4,
        "max_iterations": 50,
        "step_size": 1e-4,
        "damping": 1e-3,
        "min_std": 1e-4,
        "discount": 0.99,
        "log_freq": 1,
        "eval_freq": 100,
        "output_directory": "./output",
        "device": torch.device("cuda:0")
    },
    resume=True  # Enable resuming previous run
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlawrence-rs-lin[0m ([33mlawrence-rs-lin-university-of-toronto[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:

policy = CVAEWithTrajectoryOptimization(wandb.config)

In [6]:
torch.load(f"{wandb.config.output_directory}/best_model_/model.pt",map_location=wandb.config.device)

  torch.load(f"{wandb.config.output_directory}/best_model_/model.pt",map_location=wandb.config.device)


{'state_encoder': OrderedDict([('obs_encoder.conv1.weight',
               tensor([[[[ 2.5355e-02,  5.3796e-02, -3.6716e-02,  ...,  7.8015e-02,
                           1.4165e-02, -8.3887e-03],
                         [-2.0371e-02,  5.3505e-02, -5.8239e-02,  ...,  7.4255e-02,
                          -3.9345e-02,  3.0756e-03],
                         [-2.4187e-02,  1.5720e-01, -7.5871e-02,  ...,  8.4658e-02,
                          -4.2633e-02,  3.6720e-02],
                         ...,
                         [ 4.5052e-02,  1.6086e-01, -3.6101e-01,  ...,  7.1354e-02,
                          -1.6139e-02,  2.0667e-02],
                         [ 6.6252e-02,  1.4736e-02, -2.2778e-01,  ...,  7.2878e-03,
                           3.4059e-02,  1.3289e-02],
                         [ 2.8422e-02, -2.4875e-03, -1.1260e-01,  ..., -2.7386e-02,
                           3.6710e-02, -1.9755e-02]],
               
                        [[-4.4240e-02,  1.0913e-01,  5.4083e-02,  ..., 

In [None]:


env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
    render_mode="rgb_array"
)

# Prepare to collect every rewards and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []
numpy_observation, info = env.reset(seed=123)

# Initialize action queue
action_queue = deque(maxlen=3)

# Render frame of the initial state
img = env.render()
frames.append(img)

# Display initial frame
plt.figure(figsize=(8,6))
plt.imshow(img)
plt.axis('off')
plt.show()

step = 0
done = False
while not done:
    # Generate new actions when queue is empty
    if len(action_queue) == 0:
        # Prepare observation for the policy running in Pytorch
        state = torch.from_numpy(numpy_observation["agent_pos"])
        image = torch.from_numpy(numpy_observation["pixels"])

        # Convert to float32 with image from channel first in [0,255]
        # to channel last in [0,1]
        state = state.to(torch.float32)
        image = image.to(torch.float32) / 255
        image = image.permute(2, 0, 1)

        # Send data tensors from CPU to GPU
        state = state.to(device, non_blocking=True)
        image = image.to(device, non_blocking=True)
        action = torch.randn(1,28).cuda()

        # Add extra (empty) batch dimension, required to forward the policy
        state = state.unsqueeze(0).unsqueeze(0)/512 * 2 - 1
        obs = image.unsqueeze(0).unsqueeze(0)

        # Predict the next actions with respect to the current observation
        with torch.no_grad():
            output_dict = policy.plan_with_theseus_update(obs, state, action, 7, cfg.discount, cfg, eval_mode=True)

        print(output_dict['best_actions'].shape)
        
        # Get first 3 actions and add to queue
        actions = output_dict['best_actions'][0][0].reshape(-1, 2)[:3]
        print(actions)
        for act in actions:
            numpy_act = (act.numpy() + 1)/2 * 512
            action_queue.append(numpy_act)

    # Get next action from queue
    numpy_action = action_queue.popleft()

    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    img = env.render()
    frames.append(img)
    
    # Display frame
    clear_output(wait=True)
    plt.figure(figsize=(8,6))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

    # The rollout is considered done when the success state is reach (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated | truncated | done
    step += 1

if terminated:
    print("Success!")
else:
    print("Failure!")

# Get the speed of environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]

# Encode all frames into a mp4 video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

print(f"Video of the evaluation is available in '{video_path}'.")