In [None]:
import sys
import os

# ABSOLUTE PATH to your experiments directory (edit if needed!)
sys.path.insert(0, "/user_data/sraychau/Storage/ActualProject/openvla-main/experiments")
sys.path.insert(0, "/user_data/sraychau/Storage/ActualProject/Lec14a_camera_sensor")

import numpy as np
import imageio
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Union, Optional
from pathlib import Path
import torch

# --- Project-specific imports ---
from Lec14a_camera_sensor.mujoco_gym_env import MujocoSimpleEnv
from robot.robot_utils import (
    get_model,
    get_action,
    get_image_resize_size,
    set_seed_everywhere,
    invert_gripper_action,
)
from robot.openvla_utils import get_processor


In [None]:
@dataclass
class GenerateConfig:
    """Settings for loading the policy and rolling it out in MujocoSimpleEnv.

    An instance is passed to get_model, get_processor and get_action (project helpers from
    robot.robot_utils / robot.openvla_utils) and is also read directly by this notebook's
    rollout cells (episode count, step cap, image size, seed, output directory).
    """

    # --- Model parameters (consumed by the project's get_model / get_processor / get_action) ---
    model_family: str = "openvla"  # presumably selects the model loader in get_model — confirm there
    pretrained_checkpoint: Union[str, Path] = ""  # checkpoint to load; empty by default, so the user must set it
    load_in_8bit: bool = False  # presumably requests 8-bit quantized loading — confirm in get_model
    load_in_4bit: bool = False  # presumably requests 4-bit quantized loading — confirm in get_model
    center_crop: bool = False  # presumably center-crops the input image before inference — confirm in get_action
    unnorm_key: Optional[str] = None  # dataset-statistics key used for action (un)normalization, e.g. "bridge_orig"
    # --- Env parameters ---
    num_episodes: int = 5  # number of rollouts the rollout loop runs
    max_steps: int = 100  # per-episode step cap in the rollout loop
    image_width: int = 128  # camera image width passed to MujocoSimpleEnv
    image_height: int = 128  # camera image height passed to MujocoSimpleEnv
    seed: int = 42  # passed to set_seed_everywhere
    # --- Output ---
    rollout_dir: str = "./rollouts_mujoco_simple"  # directory for episode MP4s and sparse-code .npy files

def save_rollout_video(rollout_images, idx, success, rollout_dir, fps=30):
    """Write one episode's frames to ``<rollout_dir>/episode_<idx>_success_<success>.mp4``.

    Args:
        rollout_images: Iterable of frames accepted by imageio's writer (presumably
            HxWx3 uint8 arrays — TODO confirm against the env's image format).
        idx: Episode index, embedded in the filename.
        success: Episode outcome, formatted into the filename.
        rollout_dir: Output directory; always created if it does not exist (callers may
            write other per-episode files there).
        fps: Playback frame rate of the video (default 30).

    Returns:
        The path of the written MP4, or ``None`` when there were no frames and nothing
        was written.
    """
    os.makedirs(rollout_dir, exist_ok=True)
    frames = list(rollout_images)
    if not frames:
        print(f"No frames for episode {idx}; skipping video")
        return None
    mp4_path = os.path.join(rollout_dir, f"episode_{idx}_success_{success}.mp4")
    # The context manager closes (and finalizes) the writer even if a frame fails to encode.
    with imageio.get_writer(mp4_path, fps=fps) as video_writer:
        for img in frames:
            video_writer.append_data(img)
    print(f"Saved rollout MP4 at path {mp4_path}")
    return mp4_path

def visualize_rollout(rollout_images, episode_idx, max_cols=8):
    """Show every frame of an episode as a grid of thumbnails.

    Args:
        rollout_images: Iterable of images accepted by ``Axes.imshow``.
        episode_idx: Episode index shown in the figure title.
        max_cols: Maximum thumbnails per row (default 8); must be >= 1.

    When there are no frames, a message is printed and no figure is created. Without this
    check the grid computation would divide by zero.
    """
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")
    frames = list(rollout_images)
    n_images = len(frames)
    if n_images == 0:
        print(f"No frames to visualize for episode {episode_idx}")
        return
    n_cols = min(max_cols, n_images)
    n_rows = -(-n_images // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    # Hide all axes first so unused cells in the last row stay blank.
    for ax in axes.flat:
        ax.axis("off")
    for ax, img in zip(axes.flat, frames):
        ax.imshow(img)
    fig.suptitle(f"Rollout Visualization - Episode {episode_idx}")
    fig.tight_layout()
    plt.show()


In [None]:
# Set your parameters here.
# The checkpoint can be overridden without editing the notebook by setting the
# OPENVLA_CHECKPOINT environment variable; otherwise the hard-coded path below is used.
cfg = GenerateConfig(
    pretrained_checkpoint=os.environ.get(
        "OPENVLA_CHECKPOINT",
        "/user_data/sraychau/Storage/ActualProject/OpenVLA_Model",  # <-- Set this!
    ),
    unnorm_key="bridge_orig",  # Use bridge_orig dataset statistics for normalization
    num_episodes=2,
    max_steps=50,
    image_width=128,
    image_height=128,
    seed=42,
    rollout_dir="./rollouts_mujoco_simple",
)
set_seed_everywhere(cfg.seed)


In [None]:
# Use the bridge_orig dataset statistics for action normalization unless the config already
# chose a key. This is set *before* loading, so cfg is complete when get_model and
# get_processor receive it.
if cfg.unnorm_key is None:
    cfg.unnorm_key = "bridge_orig"

model = get_model(cfg)
processor = get_processor(cfg)

# Fail fast on an unknown key instead of later, during rollout. `norm_stats` is looked up
# defensively: presumably it is present on OpenVLA models (TODO confirm for other families).
norm_stats = getattr(model, "norm_stats", None)
if norm_stats is not None and cfg.unnorm_key not in norm_stats:
    raise KeyError(
        f"Action un-norm key {cfg.unnorm_key!r} not found in model.norm_stats "
        f"(available: {list(norm_stats)})"
    )

env = MujocoSimpleEnv(image_width=cfg.image_width, image_height=cfg.image_height)
print("Environment created successfully")


In [None]:
# Roll out the policy for cfg.num_episodes episodes. For each episode this:
# - saves an MP4 of the frames the policy saw;
# - shows those frames as a thumbnail grid;
# - when get_action returns sparse codes, saves the per-step codes as a .npy file in
#   cfg.rollout_dir.
TASK_LABEL = "move block"  # language instruction passed to the policy at every step


def _sparse_code_to_numpy(code):
    """Return one step's sparse code as a NumPy array with at least two dimensions.

    Torch tensors are detached and copied to CPU. bfloat16 is upcast to float32 because
    NumPy has no bfloat16. A 1-D code of length D becomes shape (1, D), so concatenating
    the steps along axis 0 yields one row per step rather than one flattened vector.
    """
    if isinstance(code, torch.Tensor):
        code = code.detach().cpu()
        if code.dtype == torch.bfloat16:
            code = code.float()
        code = code.numpy()
    return np.atleast_2d(np.asarray(code))


try:
    for ep in range(cfg.num_episodes):
        print(f"Starting episode {ep}")
        obs, _ = env.reset()
        print(f"Episode {ep} reset complete, observation shape: {obs['image'].shape}")
        terminated = truncated = False
        t = 0
        rollout_images = []
        sparse_codes_episode = []
        while not (terminated or truncated) and t < cfg.max_steps:
            # Prepare observation for model
            observation = {
                "full_image": obs["image"],
                "state": obs["proprio"],
            }
            # Keep a private copy of the frame the policy acts on, in case the env reuses
            # its image buffer across steps (not verified for MujocoSimpleEnv).
            rollout_images.append(np.array(observation["full_image"], copy=True))
            # Query model for action
            action, sparse_code = get_action(cfg, model, observation, task_label=TASK_LABEL, processor=processor)
            if sparse_code is not None:
                sparse_codes_episode.append(_sparse_code_to_numpy(sparse_code))
            # Invert gripper action if needed (OpenVLA convention)
            action = invert_gripper_action(action)
            # Gymnasium-style 5-tuple: stop on either termination or truncation.
            obs, reward, terminated, truncated, info = env.step(action)
            t += 1
        done = terminated or truncated
        # NOTE(review): as before, "success" means the env reported termination; this
        # assumes MujocoSimpleEnv only terminates on task success. TODO confirm.
        success = bool(terminated)
        print(f"Episode {ep} finished, saving video and visualizing...")
        save_rollout_video(rollout_images, ep, success=success, rollout_dir=cfg.rollout_dir)
        visualize_rollout(rollout_images, ep)
        if sparse_codes_episode:
            # Don't rely on save_rollout_video having created the directory.
            os.makedirs(cfg.rollout_dir, exist_ok=True)
            sparse_codes_arr = np.concatenate(sparse_codes_episode, axis=0)
            np.save(os.path.join(cfg.rollout_dir, f"episode_{ep}_sparse_codes.npy"), sparse_codes_arr)
            print(f"Saved sparse codes for episode {ep}, shape: {sparse_codes_arr.shape}")
    print("All episodes finished")
finally:
    # Release the simulator even if an episode raises.
    env.close()
