In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import yaml
import pickle

# Larger default font so titles and tick labels stay readable on wide figures.
plt.rcParams['font.size'] = 16

# The PPO config is looked up next to the current working directory.
plotting_dir = Path().resolve()
config_dir = plotting_dir / "ppo_config.yaml"

with config_dir.open("r") as file:
    config = yaml.safe_load(file)

# Agent counts to visualise, one subplot column each
# (an earlier sweep used 8, 16, 24 and 32 agents).
n_agents = [8, 16, 32, 48, 64]
# Half-width used for the training-scale axis limits: a quarter of the agent count.
training_world_dims = [count / 4 for count in n_agents]

# Two independent figures with identical layouts: one scaled to the
# training world, the other to the evaluation world.
n_panels = len(n_agents)
subplot_options = dict(figsize=(18, 7), constrained_layout=True)
fig_train, axes_train = plt.subplots(1, n_panels, **subplot_options)
fig_eval, axes_eval = plt.subplots(1, n_panels, **subplot_options)

# Half-extent of the evaluation world used for the evaluation-scale axis
# limits; it is the same for every agent count.
EVAL_WORLD_DIMS = (32, 32)


def _draw_pose_scatter(ax, chain_x, chain_y, target_x, target_y):
    """Scatter chain and target positions on ``ax`` and apply the shared styling.

    Titles and axis limits are left to the caller because they differ between
    the training-scale and evaluation-scale figures.
    """
    ax.scatter(chain_x, chain_y, color='blue', label='Chain Pose', alpha=0.7, marker='o')
    ax.scatter(target_x, target_y, color='red', label='Target Pose', alpha=0.7, marker='^')
    ax.set_xlabel('X Position')
    ax.set_ylabel('Y Position')
    ax.legend()
    ax.grid(True)
    ax.set_aspect('equal')


# Draw one training-scale and one evaluation-scale panel per agent count.
for idx, n_agent in enumerate(n_agents):
    # Fresh accumulators for this agent count.
    all_chain_poses_x = []
    all_chain_poses_y = []
    all_target_poses_x = []
    all_target_poses_y = []

    # Reset for every panel: without this, a missing or empty rollout file
    # would either raise NameError (first panel) or silently reuse the
    # dimensions of the previous panel, and the "N/A" title could never show.
    eval_world_x_dim = eval_world_y_dim = None

    # Matching subplot in each of the two figures.
    ax_train = axes_train[idx]
    ax_eval = axes_eval[idx]

    checkpoint_path = Path(f"{config['base_path']}/salp_navigate_8a_ver_1/gcn/0/logs/test_rollouts_info_{n_agent}.dat")

    if checkpoint_path.is_file():
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only open rollout files produced by a trusted run.
        with open(checkpoint_path, "rb") as handle:
            file_data = pickle.load(handle)

        # The handle is closed at this point; everything below works on the
        # in-memory data only.
        for item in file_data:
            # Only mark the evaluation world as known once data was actually read.
            eval_world_x_dim, eval_world_y_dim = EVAL_WORLD_DIMS

            # assumes each pose entry has a leading singleton axis followed by
            # rows of (x, y, ...) points — TODO confirm against the rollout writer
            chain_pose_array = np.asarray(item['chain_pose'].squeeze(0))
            all_chain_poses_x.extend(chain_pose_array[:, 0])
            all_chain_poses_y.extend(chain_pose_array[:, 1])

            target_pose_array = np.asarray(item['target_pose'].squeeze(0))
            all_target_poses_x.extend(target_pose_array[:, 0])
            all_target_poses_y.extend(target_pose_array[:, 1])

    # Plot 1: training world dimensions.
    train_dim = training_world_dims[idx]
    _draw_pose_scatter(ax_train, all_chain_poses_x, all_chain_poses_y,
                       all_target_poses_x, all_target_poses_y)
    ax_train.set_title(f'{n_agent} Agents (Training: {train_dim}x{train_dim})')
    ax_train.set_xlim(-train_dim, train_dim)
    ax_train.set_ylim(-train_dim, train_dim)

    # Plot 2: evaluation world dimensions.
    _draw_pose_scatter(ax_eval, all_chain_poses_x, all_chain_poses_y,
                       all_target_poses_x, all_target_poses_y)
    if eval_world_x_dim is not None:
        ax_eval.set_title(f'{n_agent} Agents (Eval: {eval_world_x_dim}x{eval_world_y_dim})')
        ax_eval.set_xlim(-eval_world_x_dim, eval_world_x_dim)
        ax_eval.set_ylim(-eval_world_y_dim, eval_world_y_dim)
    else:
        # No rollout data for this agent count: leave the limits automatic.
        ax_eval.set_title(f'{n_agent} Agents (Eval: N/A)')

# Give each figure its overall heading, then display both.
for figure, heading in (
    (fig_train, 'State Distribution (Training World Dimensions)'),
    (fig_eval, 'State Distribution (Evaluation World Dimensions)'),
):
    figure.suptitle(heading, fontsize=20)

plt.show()