In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import yaml
import pickle

plt.rcParams.update({'font.size': 16})

# Path() is '.', so this resolves to the current working directory (for a
# notebook, typically the directory it was launched from), not necessarily the
# folder this code lives in.
plotting_dir = Path().resolve()
# Despite the name, this is the path to the YAML config *file*.
config_dir = plotting_dir / "ppo_config.yaml"

# Explicit encoding so the config is decoded the same way on every platform
# (open() otherwise falls back to the locale's preferred encoding).
with open(config_dir, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# safe_load returns None for an empty file (or a scalar/list for other content);
# fail here with a clear message instead of a TypeError on config[...] later.
if not isinstance(config, dict):
    raise ValueError(f"{config_dir} must contain a YAML mapping, got {type(config).__name__}")

# Agent counts to compare; each one gets its own subplot column below.
n_agents = [8, 16, 24, 32]

# A single row of subplots, one per agent count. constrained_layout lets
# Matplotlib size the margins for titles and axis labels automatically.
fig, axes = plt.subplots(
    nrows=1,
    ncols=len(n_agents),
    figsize=(18, 7),
    constrained_layout=True,
)


# Draw one scatter subplot per agent count: chain poses vs. target poses pooled
# over every rollout stored in that agent count's log file.
# np.atleast_1d keeps axes indexable even when plt.subplots returned a single
# Axes object (which happens when there is only one subplot).
axes_list = np.atleast_1d(axes)

for idx, n_agent in enumerate(n_agents):
    # Point buffers for this subplot only (x and y kept separate for ax.scatter).
    all_chain_poses_x = []
    all_chain_poses_y = []
    all_target_poses_x = []
    all_target_poses_y = []

    ax = axes_list[idx]

    # Rollout log for this agent count; the experiment sub-path
    # (salp_local_8a/gcn) is fixed for every subplot.
    checkpoint_path = Path(f"{config['base_path']}/salp_local_8a/gcn/test/logs/test_rollouts_info_{n_agent}.dat")

    if checkpoint_path.is_file():
        # pickle.load can execute arbitrary code: only use it on trusted,
        # locally generated rollout logs.
        with open(checkpoint_path, "rb") as handle:
            file_data = pickle.load(handle)

        # Each item is read as a dict with 'chain_pose' and 'target_pose'.
        # squeeze(0) drops a leading axis (presumably a batch dim of size 1 —
        # confirm against the logger); the [:, 0] / [:, 1] indexing then
        # requires a 2-D array whose first two columns are x and y.
        for item in file_data:
            chain_pose_array = np.array(item['chain_pose'].squeeze(0))
            all_chain_poses_x.extend(chain_pose_array[:, 0])
            all_chain_poses_y.extend(chain_pose_array[:, 1])

            target_pose_array = np.array(item['target_pose'].squeeze(0))
            all_target_poses_x.extend(target_pose_array[:, 0])
            all_target_poses_y.extend(target_pose_array[:, 1])
    else:
        # Without this, a missing log would silently yield an empty subplot
        # that looks like "no points" rather than "no data".
        print(f"Warning: rollout file not found, subplot left empty: {checkpoint_path}")
        ax.text(0.5, 0.5, "no data", transform=ax.transAxes, ha="center", va="center")

    # Create the scatter plot on this subplot
    ax.scatter(all_chain_poses_x, all_chain_poses_y, color='blue', label='Chain Pose', alpha=0.7, marker='o')
    ax.scatter(all_target_poses_x, all_target_poses_y, color='red', label='Target Pose', alpha=0.7, marker='^')
    ax.set_title(f'Distribution for {n_agent} Agents')
    ax.set_xlabel('X Position')
    ax.set_ylabel('Y Position')
    ax.legend()
    ax.grid(True)
    ax.set_aspect('equal')  # Equal scaling for both axes

# The figure was created with constrained_layout=True, which already reserves
# room for a suptitle. Calling plt.tight_layout() on top of it mixes two layout
# engines: recent Matplotlib warns ("The figure layout has changed to tight")
# and discards constrained layout. So only the title is set here.
fig.suptitle('Distribution of Chain Pose and Target Pose', fontsize=20)