# In [9]:
from stable_baselines3 import PPO
import gymnasium as gym
from robot_animation.data_processing import process_raw_robot_data, robot_data_to_qpos_qvel

import numpy as np
# Load the trained model from the zip file
model = PPO.load("../models/ppo_robot_animation_qkqryqza/model.zip")
csv_path = "../data/kuka_2.csv"

frame_rate = 153
# Joint count passed to both the data conversion and the environment; keep a
# single definition so the two cannot drift apart.
num_q = 7

animation_df = process_raw_robot_data(csv_path)
target_qpos, _ = robot_data_to_qpos_qvel(animation_df, num_q=num_q)

# Finite-difference joint velocities: (q[t] - q[t-1]) * frame_rate.
# Row 0 has no previous frame, so it keeps the zeros from np.zeros_like.
target_qvel = np.zeros_like(target_qpos)
target_qvel[1:] = (target_qpos[1:] - target_qpos[:-1]) * frame_rate  # TODO: shift this upstream


eval_env = gym.make(
    "RobotAnimationEnv-kuka",
    animation_frame_rate=frame_rate,
    target_qpos=target_qpos,
    target_qvel=target_qvel,
    num_q=num_q,
    reset_noise_scale=0.01,
)


eval_env.reset()

def evaluate_policy(
    model: PPO,
    env: gym.Env,
    num_episodes: int = 1,
    *,
    print_target_base_rotation: bool = True,
) -> list[np.ndarray]:
    """
    Roll out a trained policy deterministically and collect rendered frames.

    Each episode starts with ``env.reset()`` and steps until the environment
    reports ``terminated`` or ``truncated``. The environment is rendered after
    every step; steps for which ``env.render()`` returns ``None`` contribute
    no frame.

    Args:
        model: Trained PPO model; actions come from
            ``model.predict(obs, deterministic=True)``.
        env: Environment to evaluate in.
        num_episodes: Number of episodes to run. Zero or a negative value runs
            no episodes and returns an empty list.
        print_target_base_rotation: If True (the default, matching earlier
            behavior), print ``env.unwrapped.target_base_rotation`` after each
            reset. Pass False for environments that do not expose that
            attribute.

    Returns:
        Frames from all episodes, in the order they were rendered.
    """
    all_frames: list[np.ndarray] = []

    for _ in range(num_episodes):
        obs, _ = env.reset()
        if print_target_base_rotation:
            # Diagnostic output; only meaningful for envs defining this attribute.
            print(env.unwrapped.target_base_rotation)

        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)

            obs, _reward, terminated, truncated, _info = env.step(action)
            done = terminated or truncated

            frame = env.render()
            if frame is not None:
                all_frames.append(frame)

    return all_frames


# Close the environment even if the rollout raises, so whatever it holds open
# is released; `frames` is only bound when evaluation succeeds.
try:
    frames = evaluate_policy(model, eval_env, num_episodes=1)
finally:
    eval_env.close()


# Cell output: -2.019330049813105  (presumably the target_base_rotation value printed during evaluation)
