<!-- # Planning Evaluator

> Evaluates planning performance, using the Cross-Entropy Method (CEM) to optimize discrete action sequences. -->

In [None]:
# | default_exp evaluators.planning_eval

In [None]:
#| hide
from nbdev.showdoc import *  

In [None]:
#| export 
from fastcore.utils import *
import pandas as pd
import wandb
import os
from mawm.data.utils import base_tf, msg_tf
import numpy as np
import torch
from mawm.planners.cem_planner import CEMPlanner

In [None]:
#| export
def preprocessor(env, obs, pos=True, get_msg=True):
    """Convert a raw multi-agent observation into per-agent model inputs.

    For every agent id ``agent_0 ... agent_{env.num_agents - 1}`` this builds:

    * the agent's ``'pov'`` observation, cast to ``np.uint8`` and passed through ``base_tf``;
    * the agent's goal observation, ``env.get_goal(env.agents[i], obs["global"]["goal_pos"])[0]``,
      cast to ``np.uint8`` and passed through ``base_tf``;
    * if ``pos``: ``obs[agent]['selfpos']`` converted with ``torch.from_numpy``;
    * if ``get_msg``: ``msg_tf((pov, agent_id, False))`` applied to the *uncast* POV.

    No batch dimension is added; for the marl_grid env used in this notebook the recorded
    shapes are (3, 42, 42) for POV/goal tensors and (2,) for positions.

    Returns (every item is a dict keyed by agent id):
        * ``pos and get_msg`` -> ``(obs_transformed, positions, goals, messages)``
        * ``pos`` only        -> ``(obs_transformed, positions, goals)``
        * ``get_msg`` only    -> ``(obs_transformed, messages)`` (goals are NOT returned)
        * neither             -> ``obs_transformed``

    Goals are computed on every call, even when they are not part of the return value.
    """
    goal_pos = obs["global"]["goal_pos"]
    views, coords, targets, msgs = {}, {}, {}, {}

    for idx in range(env.num_agents):
        name = f'agent_{idx}'
        agent_obs = obs[name]
        pov = agent_obs['pov']

        views[name] = base_tf(pov.astype(np.uint8))
        goal_obs = env.get_goal(env.agents[idx], goal_pos)[0]
        targets[name] = base_tf(goal_obs.astype(np.uint8))

        if pos:
            coords[name] = torch.from_numpy(agent_obs['selfpos'])
        if get_msg:
            msgs[name] = msg_tf((pov, name, False))

    if not pos:
        return (views, msgs) if get_msg else views
    return (views, coords, targets, msgs) if get_msg else (views, coords, targets)

In [None]:
#| export
class PlanEvaluator:
    "Evaluates planning performance, using the Cross-Entropy Method (CEM) to optimize discrete action sequences."
    def __init__(self, planner, agents=('agent_0', 'agent_1'), device='cpu'):
        """Set up one planner entry per agent.

        Args:
            planner: planning object used for every agent. The SAME instance is stored
                under each agent id, so any internal state it keeps is shared.
            agents: sequence of agent ids. It is copied into a fresh list, so later
                mutation of the caller's sequence (or of ``self.agents``) cannot leak
                into other evaluator instances.
            device: stored as ``self.device``; not read by this class itself.
        """
        # A tuple default plus an explicit copy avoids the shared-mutable-default pitfall:
        # a list default would be one object reused by every instance.
        self.agents = list(agents)
        self.device = device
        self.planners = {agent: planner for agent in self.agents}

In [None]:
#| export
@patch
def eval_all_agents(self: PlanEvaluator, env, preprocessor=preprocessor, negotiation_rounds=3,
                    max_steps=100, verbose=True):
    """Run one episode where every action is chosen by iterated per-agent planning ("negotiation").

    At each environment step:

    1. ``preprocessor(env, obs, pos=True, get_msg=True)`` produces per-agent observations,
       positions, goals and messages.
    2. ``negotiation_rounds`` rounds are run. In each round every agent calls its planner's
       ``plan`` with its own inputs, the partner's message, and the partner's intent (plan)
       from the previous round. All intents are then replaced at once.
    3. The first action of each final plan is sent to ``env.step``.
    4. Each plan is shifted left by one step and padded with action 0, seeding the next step.

    Args:
        env: environment providing ``reset()``, ``step(actions)`` returning
            ``(obs, rewards, done, infos)`` with a ``done['__all__']`` flag, and ``close()``.
            ``close()`` is always called, even if planning or stepping raises.
        preprocessor: callable with the same contract as the module-level ``preprocessor``,
            returning ``(obs_transformed, positions, goals, messages)`` dicts keyed by agent id.
        negotiation_rounds: number of planning rounds per environment step.
        max_steps: maximum number of environment steps (default 100).
        verbose: if True, print per-round and per-step progress.

    Returns:
        A list with one entry per executed step: a dict mapping agent id to the negotiated
        plan whose first action was executed at that step.
    """
    obs = env.reset()
    try:
        agents = self.agents
        horizon = self.planners[agents[0]].horizon
        # Each agent negotiates with the first *other* agent id, which assumes a two-agent
        # setup. A lone agent raises IndexError here, before any stepping.
        partner = {agent: [a for a in agents if a != agent][0] for agent in agents}
        log = print if verbose else (lambda *args, **kwargs: None)

        # Initial intents (draft plans): all-zero action sequences of length `horizon`.
        # Action 0 is presumably "stay" (per the author's notes); verify against the env.
        intents = {agent: torch.zeros(horizon, dtype=torch.long) for agent in agents}
        lst_intents = []

        step = 0
        while step < max_steps:
            obs_transformed, pos, goals, msgs = preprocessor(env, obs, pos=True, get_msg=True)

            for r in range(negotiation_rounds):
                log("Negotiation Round:", r + 1)
                new_intents = {}
                for agent in agents:
                    other_agent = partner[agent]
                    # Anchor on the partner's intent from the PREVIOUS round (not this one),
                    # so all agents plan against the same snapshot.
                    new_intents[agent] = self.planners[agent].plan(
                        o_t=obs_transformed[agent],
                        pos_t=pos[agent],
                        o_g=goals[agent],
                        m_other=msgs[other_agent],
                        other_actions=intents[other_agent],
                    )
                intents = new_intents

            # Shallow copy: the shift below rebinds entries of `intents`, not the stored ones.
            lst_intents.append(intents.copy())

            # Execute the first action of each agent's final plan.
            actions = {agent: np.int64(intents[agent][0].item()) for agent in agents}
            obs, rewards, done, infos = env.step(actions)
            log(f"Step: {step}, Actions taken: {actions}, Rewards: {rewards}, Done: {done}")

            # Drop the executed action and pad the tail with action 0. The pad copies the
            # plan's dtype, device and trailing shape, so GPU plans or plans with an extra
            # action dimension can be concatenated too.
            for agent in agents:
                plan = intents[agent]
                intents[agent] = torch.cat([plan[1:], plan.new_zeros((1, *plan.shape[1:]))])

            if done['__all__']:
                break
            step += 1
    finally:
        env.close()

    return lst_intents

In [None]:
#| hide
from mawm.envs.marl_grid import make_env
from mawm.envs.marl_grid.cfg import config
import copy
import numpy as np

# A new random seed is drawn on every run, so episodes are only reproducible if the
# value of `seed` is recorded.
seed = np.random.randint(0, 10000)
# Work on a deep copy so the imported default `config` object is left unmodified.
cfg = copy.deepcopy(config)
cfg.env_cfg.seed = int(seed)
cfg.env_cfg.max_steps = 512  # presumably the env's per-episode step limit — TODO confirm

env = make_env(cfg.env_cfg)

In [None]:
#| hide

# Smoke-test the preprocessor on a freshly reset environment. With pos=True and
# get_msg=True it returns (obs_transformed, positions, goals, messages).
obs = env.reset()
items = preprocessor(env, obs, pos=True, get_msg=True)

In [None]:
#| hide

# Unpack the four per-agent dicts and display agent_0's tensor shapes (the last
# expression is the cell's displayed output).
obs_transformed, positions, goals, messages = items
obs_transformed['agent_0'].shape, positions['agent_0'].shape, goals['agent_0'].shape, messages['agent_0'].shape


(torch.Size([3, 42, 42]),
 torch.Size([2]),
 torch.Size([3, 42, 42]),
 torch.Size([5, 7, 7]))

In [None]:
#| hide 
from mawm.models.jepa import JEPA
from omegaconf import OmegaConf
# NOTE(review): this rebinds the notebook-level `cfg`. In a top-to-bottom run it previously
# held the deep-copied environment config, so any later use of the env config must rebuild it.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")
# input_dim matches the (3, 42, 42) preprocessor outputs recorded above. action_dim=1 is
# presumably one discrete action index per step — TODO confirm.
model = JEPA(cfg.model, input_dim=(3, 42, 42), action_dim=1)

# Auxiliary modules whose weights are restored from the checkpoint in a later cell.
# NOTE(review): the sizes of 32 presumably must match the JEPA latent size in mpc.yaml, and
# in_channels=16 matches the backbone's final 16-channel conv — confirm against the config.
from mawm.models.misc import ObsPred, MsgPred
from mawm.models.vision import SemanticEncoder
obs_pred = ObsPred(h_dim=32, out_channels=18)
msg_pred = MsgPred(h_dim=32, in_channels=16)
msg_encoder = SemanticEncoder(latent_dim=32)

In [None]:
#| hide
import torch
# The checkpoint holds more than tensors (epoch, losses, lr, optimizer state), so it is loaded
# with full unpickling. `weights_only` is pinned explicitly because torch's default has changed
# across releases (and emits a FutureWarning when left implicit).
# SECURITY: unpickling can execute arbitrary code — only load checkpoints from a trusted source.
ckpt = torch.load("./models/best.pth", map_location='cpu', weights_only=False)
ckpt.keys()

  ckpt = torch.load("./models/best.pth", map_location='cpu')


dict_keys(['epoch', 'jepa', 'msg_encoder', 'msg_predictor', 'obs_predictor', 'train_loss', 'val_loss', 'optimizer', 'lr'])

In [None]:
# Restore each component from its own entry in the checkpoint dict. The result of the last
# call is the cell's displayed output.
model.load_state_dict(ckpt['jepa'])
msg_encoder.load_state_dict(ckpt['msg_encoder'])
msg_pred.load_state_dict(ckpt['msg_predictor'])
obs_pred.load_state_dict(ckpt['obs_predictor'])

<All keys matched successfully>

In [None]:
# Display the JEPA backbone's module structure.
model.backbone

MeNet6(
  (layers): Sequential(
    (0): Identity()
    (1): Sequential(
      (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
      (1): GroupNorm(4, 16, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
      (4): GroupNorm(8, 32, eps=1e-05, affine=True)
      (5): ReLU()
      (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (7): GroupNorm(8, 32, eps=1e-05, affine=True)
      (8): ReLU()
      (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (10): GroupNorm(8, 32, eps=1e-05, affine=True)
      (11): ReLU()
      (12): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (position_encoder): Expander2D()
)

In [None]:
#| hide
# Build the CEM planner around the JEPA world model plus the message encoder and the
# message / observation predictor modules loaded above.
planner = CEMPlanner(
    model=model,
    msg_enc=msg_encoder,
    msg_pred=msg_pred,
    obs_pred=obs_pred,
)

In [None]:
#| hide
# evaluator = PlanEvaluator(planner)  # signature: PlanEvaluator(planner, agents=..., device=...)


In [None]:
# #| hide
# evaluator.eval_all_agents(env, preprocessor, negotiation_rounds=3)

In [None]:
#| hide
# Regenerate the library modules from the `#| export` cells via nbdev. Called with no
# arguments, nbdev_export uses the project's configured notebook path.
import nbdev
nbdev.nbdev_export() # type: ignore  # noqa: E702
