In [None]:
import os
import sys
import numpy as np
import imageio
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Union, Optional
from pathlib import Path
import torch

# Configure paths
sys.path.insert(0, "/user_data/sraychau/Storage/ActualProject/openvla-main/experiments")
sys.path.insert(0, "/user_data/sraychau/Storage/ActualProject/Lec14a_camera_sensor")

# Project imports
from Lec14a_camera_sensor.mujoco_gym_env import MujocoSimpleEnv
from robot.robot_utils import (
    get_model,
    get_action,
    get_image_resize_size,
    set_seed_everywhere,
    invert_gripper_action,
)
from robot.openvla_utils import get_processor
from prismatic.models.vlas.sparse_autoencoder import SparseAutoencoderTorch

In [None]:
# Hardcoded configuration -- the single place to tune a run.

# Model checkpoint location; a relative path, so presumably resolved
# against the notebook's working directory (TODO confirm with get_model).
PRETRAINED_CHECKPOINT = "OpenVLA_Model/"

# Rollout budget: how many episodes to run and the per-episode step cap.
NUM_EPISODES, MAX_STEPS = 2, 50

# Image size (width x height, square) used as the env-parameter defaults.
IMAGE_WIDTH = IMAGE_HEIGHT = 128

# Random seed (default for GenerateConfig.seed).
SEED = 42

# Output directory for rollout videos (default for save_rollout_video).
ROLLOUT_DIR = "./rollouts_mujoco_simple"
ROLLOUT_DIR = "./rollouts_mujoco_simple"

In [None]:
@dataclass
class GenerateConfig:
    """Settings for an OpenVLA rollout run in the simple MuJoCo environment.

    Every default is read from a module-level constant in the config cell.
    Dataclass defaults are captured when this class body executes, so after
    editing the config cell this cell must be re-run as well.

    Model fields (presumably consumed by ``get_model`` / ``get_action`` --
    not shown here, confirm there):
        model_family: model family identifier.
        pretrained_checkpoint: checkpoint path or identifier to load.
        load_in_8bit, load_in_4bit: flags that appear to select quantized loading.
        center_crop: flag that appears to toggle center-cropping of inputs.
        unnorm_key: optional key, likely naming action un-normalization stats.

    Environment fields:
        num_episodes, max_steps, image_width, image_height, seed.

    Output fields:
        rollout_dir: directory where rollout videos are written.
    """

    # --- model ---
    model_family: str = 'openvla'
    pretrained_checkpoint: Union[str, Path] = PRETRAINED_CHECKPOINT
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    center_crop: bool = False
    unnorm_key: Optional[str] = None

    # --- environment ---
    num_episodes: int = NUM_EPISODES
    max_steps: int = MAX_STEPS
    image_width: int = IMAGE_WIDTH
    image_height: int = IMAGE_HEIGHT
    seed: int = SEED

    # --- output ---
    rollout_dir: str = ROLLOUT_DIR

def save_rollout_video(rollout_images, idx, success, rollout_dir=ROLLOUT_DIR, fps=30):
    """Write a sequence of frames to an MP4 file and return its path.

    Args:
        rollout_images: iterable of frames accepted by imageio's
            ``append_data`` (presumably HxWx3 uint8 arrays -- TODO confirm
            against the env's rendered output).
        idx: episode index, embedded in the filename.
        success: episode outcome, embedded in the filename via f-string
            formatting (e.g. ``True`` -> ``episode_0_success_True.mp4``).
        rollout_dir: output directory; created if it does not exist.
            Defaults to ``ROLLOUT_DIR`` as it was when this cell ran.
        fps: playback frame rate of the written video (default 30, the
            previously hard-coded value).

    Returns:
        The path of the written MP4 file.
    """
    os.makedirs(rollout_dir, exist_ok=True)
    mp4_path = os.path.join(rollout_dir, f'episode_{idx}_success_{success}.mp4')
    # The context manager closes the writer even if a frame fails to encode;
    # previously an exception in append_data leaked the open writer (and,
    # for MP4, typically its encoder process) and left a truncated file.
    with imageio.get_writer(mp4_path, fps=fps) as video_writer:
        for img in rollout_images:
            video_writer.append_data(img)
    print(f'Saved rollout MP4 at path {mp4_path}')
    return mp4_path

def visualize_rollout(rollout_images, episode_idx, max_cols=8):
    """Display the frames of one rollout as a grid of thumbnails.

    Args:
        rollout_images: sized sequence of images accepted by ``Axes.imshow``.
        episode_idx: episode number shown in the figure title.
        max_cols: maximum number of thumbnails per row (default 8, the
            previously hard-coded value). Must be >= 1.

    Returns:
        None. The figure is rendered with ``plt.show()``.

    Raises:
        ValueError: if ``max_cols`` is less than 1.
    """
    if max_cols < 1:
        raise ValueError(f'max_cols must be >= 1, got {max_cols}')
    n_images = len(rollout_images)
    if n_images == 0:
        # Previously this path crashed with ZeroDivisionError (n_cols == 0).
        print(f'No frames to visualize for episode {episode_idx}')
        return None

    n_cols = min(max_cols, n_images)
    n_rows = -(-n_images // n_cols)  # integer ceiling division
    # squeeze=False keeps `axes` 2-D even for a single row/column.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False
    )
    # Hide every axis, including unused slots in a partially filled last row.
    for ax in axes.flat:
        ax.axis('off')
    for ax, img in zip(axes.flat, rollout_images):
        ax.imshow(img)
    fig.suptitle(f'Rollout Visualization - Episode {episode_idx}')
    fig.tight_layout()
    plt.show()