### Minimal exploration of mean-field ideas using PettingZoo

In [3]:
import numpy as np
from pettingzoo.mpe import simple_spread_v3

def make_env(
    seed: int = 0,
    *,
    n_agents: int = 5,
    local_ratio: float = 0.5,
    max_cycles: int = 25,
):
    """
    Build and reset a PettingZoo MPE ``simple_spread_v3`` parallel environment.

    Args:
        seed: Seed forwarded to ``env.reset`` so the initial state is
            reproducible. (Previously this argument was silently ignored.)
        n_agents: Number of agents, passed to the env as ``N``.
        local_ratio: Passed through to the env as ``local_ratio``.
        max_cycles: Passed through to the env as ``max_cycles`` (episode
            step cap).

    Returns:
        ``(env, obs)``: the reset environment and its initial observation
        dict (agent name -> observation array).
    """
    env = simple_spread_v3.parallel_env(
        N=n_agents,
        local_ratio=local_ratio,
        max_cycles=max_cycles,
        continuous_actions=False,  # discrete actions so action_space().sample() gives ints
        render_mode=None,
    )

    obs, _infos = env.reset(seed=seed)
    return env, obs


def entropy(pos: np.ndarray, eps: float=1e-8):
    """
    Shannon entropy (natural log) of a discrete probability vector.

    Computes H(p) = -sum_i p_i * log(p_i) over the strictly positive entries
    only, so zero-probability cells contribute nothing. Despite its name,
    ``pos`` is the probability vector (e.g. the output of
    ``pos_to_histogram``); the name is kept for caller compatibility.
    ``eps`` is added inside the log purely as a numerical guard.

    Returns:
        The entropy as a Python float.
    """
    support = pos[pos > 0]
    weighted_logs = support * np.log(support + eps)
    return float(-weighted_logs.sum())


def get_pos_from_obs(obs: dict, agents):
    """
    Stack each agent's (x, y) position, read from its observation vector.

    Reads ``obs[agent][2:4]`` as the position. NOTE(review): this assumes the
    observation starts with the agent's own velocity (2 values) followed by
    its own position (2 values), as in PettingZoo MPE simple_spread. Confirm
    this if the environment changes. Positions are presumably around [-1, 1],
    but are not bounded here; ``pos_to_histogram`` clips out-of-range values.

    Args:
        obs: Mapping agent name -> 1-D observation array.
        agents: Iterable of agent names; output rows follow this order.

    Returns:
        Array of shape (len(agents), 2). With no agents an empty (0, 2) array
        is returned instead of raising (``np.stack`` rejects an empty list).
    """
    rows = [obs[agent][2:4] for agent in agents]
    if not rows:
        return np.empty((0, 2))
    return np.stack(rows, axis=0)


def pos_to_histogram(pos: np.ndarray, grid_size: int = 10, xlim=(-1.0, 1.0), ylim=(-1.0, 1.0)):
    """
    Bin 2-D points into a normalized occupancy histogram on a square grid.

    The rectangle ``xlim`` x ``ylim`` is split into ``grid_size`` x
    ``grid_size`` equal cells. Each point is assigned to a cell, and points
    outside the rectangle are clipped into the nearest border cell.

    Args:
        pos: Array of shape (N, 2); column 0 is x, column 1 is y.
        grid_size: Number of cells along each axis.
        xlim: (min, max) extent of the grid along x.
        ylim: (min, max) extent of the grid along y.

    Returns:
        Flat array of length ``grid_size ** 2`` that sums to 1, laid out
        x-major (flat index = gx * grid_size + gy). With N == 0, a uniform
        distribution over all cells is returned instead.
    """
    n_points = pos.shape[0]
    n_cells = grid_size * grid_size

    if n_points == 0:
        # Nothing to bin: fall back to the uniform (maximum-entropy) distribution.
        return np.ones(n_cells) / n_cells

    def to_cell(coords, lim):
        # Map continuous coordinates to integer cell indices in [0, grid_size - 1].
        # Truncation toward zero only matters for negatives, which clip to 0 anyway.
        frac = (coords - lim[0]) / (lim[1] - lim[0])
        return np.clip((frac * grid_size).astype(int), 0, grid_size - 1)

    gx = to_cell(pos[:, 0], xlim)
    gy = to_cell(pos[:, 1], ylim)

    counts = np.zeros((grid_size, grid_size), dtype=np.float32)
    # np.add.at is unbuffered, so several points landing in the same cell
    # are all counted (a plain fancy-index `+=` would count them once).
    np.add.at(counts, (gx, gy), 1)

    p = counts.ravel()
    p /= p.sum()
    return p





In [7]:
def run_random_rollout():
    """
    Run one episode with uniformly random actions and print per-step entropy.

    At each step, the agents' positions are binned into a 10x10 occupancy
    histogram and the histogram's entropy is recorded. The list is printed
    once the episode ends.

    Note: with the default 5 agents, the entropy can be at most ln(5) ~= 1.609.
    That maximum is reached whenever every agent occupies its own cell, which
    is what a short random episode typically shows. A flat trace therefore
    reflects the coarse metric, not a bug in the rollout.
    """
    seed = 42
    env, obs = make_env(seed=seed)
    episode_entropy = []

    try:
        # NOTE(review): random actions come from each action space's own RNG
        # (Gymnasium Space.seed), which env.reset(seed=...) does not appear to
        # touch. Seed the spaces explicitly so the rollout is reproducible.
        for offset, agent in enumerate(env.agents):
            env.action_space(agent).seed(seed + offset)

        while env.agents:
            pos = get_pos_from_obs(obs, env.agents)
            p = pos_to_histogram(pos, grid_size=10)
            episode_entropy.append(entropy(p))

            actions = {
                agent: env.action_space(agent).sample()
                for agent in env.agents
            }

            obs, rewards, term, trunc, infos = env.step(actions)
            if all(term.values()) or all(trunc.values()):
                break
    finally:
        # The original wrote `env.close` without calling it, so the env was
        # never closed; always release it, even if the rollout raises.
        env.close()

    print("Episode entropies:", episode_entropy)


# Execute one random-policy episode; prints the per-step occupancy entropies.
run_random_rollout()
        

Episode entropies: [1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828, 1.6094379425048828]
