# Planning Evaluator

> Evaluator for planning performance using the Cross-Entropy Method (CEM) for optimization of discrete action sequences.

In [None]:
# | default_exp evaluators.planning_eval

In [None]:
#| hide
from nbdev.showdoc import *  

In [None]:
#| export 
from fastcore.utils import *
import pandas as pd
import wandb
import os
from mawm.data.utils import base_tf, msg_tf
import numpy as np
import torch
from mawm.planners.cem_planner import CEMPlanner

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
def preprocessor(env, obs, pos=True, get_msg=True):
    """Convert a raw multi-agent observation dict into per-agent tensors.

    For every agent ``agent_0 .. agent_{env.num_agents - 1}`` the agent's
    ``'pov'`` frame and its goal view (``env.get_goal(...)[0]``) are cast to
    ``uint8`` and passed through ``base_tf``. Optionally the agent's
    ``'selfpos'`` array is wrapped as a tensor and a message is built with
    ``msg_tf``.

    Args:
        env: Environment exposing ``num_agents``, ``agents`` and ``get_goal``.
        obs: Observation dict with one entry per agent id plus
            ``obs["global"]["goal_pos"]``.
        pos: When true, collect each agent's ``'selfpos'``.
        get_msg: When true, build a message per agent from its ``'pov'``.

    Returns:
        The number of returned dicts depends on the flags:

        * ``pos`` and ``get_msg``: ``(observations, positions, goals, messages)``
        * ``pos`` only: ``(observations, positions, goals)``
        * ``get_msg`` only: ``(observations, messages)`` (goals are not returned)
        * neither: ``observations``

        Every dict is keyed by agent id.
    """
    target = obs["global"]["goal_pos"]
    frames, coords, goal_views, msgs = {}, {}, {}, {}

    for idx in range(env.num_agents):
        name = f'agent_{idx}'
        agent_obs = obs[name]

        frames[name] = base_tf(agent_obs['pov'].astype(np.uint8))
        # get_goal returns a sequence; only its first element is used here.
        goal_view = env.get_goal(env.agents[idx], target)[0]
        goal_views[name] = base_tf(goal_view.astype(np.uint8))

        if pos:
            coords[name] = torch.from_numpy(agent_obs['selfpos'])
        if get_msg:
            msgs[name] = msg_tf((agent_obs['pov'], name, False, True))

    # Goals are only part of the result when positions are requested.
    if not pos:
        return (frames, msgs) if get_msg else frames
    if get_msg:
        return frames, coords, goal_views, msgs
    return frames, coords, goal_views

In [None]:
#| export
from einops import repeat
class FindGoalPlanner:
    """State container for a cross-entropy-method planner over discrete actions.

    The planner keeps, for every agent, a ``(horizon, action_dim)`` table of
    categorical action probabilities (``current_probs``) that starts uniform.
    The planning and distribution-update logic is attached to this class
    elsewhere in the notebook via ``@patch``.
    """

    def __init__(self, model, msg_enc, comm_module, action_dim=5, horizon=10, num_samples=1000,
                 topk=100, opt_steps=10, agents=None, device='cpu'):
        """Store the models and CEM hyper-parameters.

        Args:
            model: World model; the planning code calls ``model.backbone`` and
                ``model.dynamics`` on it.
            msg_enc: Module applied to a message before it is fed to the dynamics.
            comm_module: Module that produces a message from a latent state.
            action_dim: Number of discrete actions.
            horizon: Number of planning steps per rollout.
            num_samples: Action sequences sampled per CEM iteration.
            topk: Number of lowest-cost samples used to refit the distribution.
            opt_steps: CEM iterations per environment step.
            agents: Iterable of agent ids. Defaults to ``['agent_0', 'agent_1']``.
            device: Torch device on which tensors are created.
        """
        self.model = model
        self.msg_enc = msg_enc
        self.comm_module = comm_module
        # Copy into a fresh list: a list literal as the default argument would be
        # shared by every instance, and a caller's list would be aliased.
        self.agents = ['agent_0', 'agent_1'] if agents is None else list(agents)
        self.device = device
        self.action_dim = action_dim
        self.num_samples = num_samples
        self.topk = topk
        self.opt_steps = opt_steps
        self.horizon = horizon
        # Uniform categorical distribution over actions for every planning step.
        self.current_probs = {
            agent: torch.full((self.horizon, self.action_dim), 1.0 / self.action_dim, device=self.device)
            for agent in self.agents
        }
        # Element-wise loss; callers reduce over the dimensions they need.
        self.loss = torch.nn.MSELoss(reduction='none')


In [None]:
#| export
@patch
def update_dist(self: FindGoalPlanner, costs, samples):
    """Refit each agent's per-step action distribution to its lowest-cost samples.

    Args:
        costs: Dict ``agent -> tensor`` of shape ``(num_samples,)``; lower is better.
        samples: Dict ``agent -> tensor`` of sampled action sequences, either
            integer indices of shape ``(num_samples, horizon)`` or one-hot
            encodings of shape ``(num_samples, horizon, action_dim)``.

    Side effects:
        Replaces ``self.current_probs[agent]`` with the smoothed empirical
        action frequencies of the elite samples at every planning step.
    """
    for agent in self.agents:
        agent_costs = costs[agent]
        agent_samples = samples[agent]

        # Accept one-hot encoded sequences by reducing them back to indices.
        if agent_samples.dim() == 3:
            agent_samples = agent_samples.argmax(dim=-1)
        agent_samples = agent_samples.long()

        # torch.topk raises when k exceeds the number of candidates, so never
        # ask for more elites than there are samples.
        k = min(self.topk, agent_costs.numel())
        _, elite_indices = torch.topk(-agent_costs, k)
        elites = agent_samples[elite_indices][:, :self.horizon]  # (k, horizon)

        # counts[t, a] = number of elites that chose action a at step t.
        counts = torch.nn.functional.one_hot(elites, num_classes=self.action_dim).sum(dim=0)
        counts = counts.to(self.current_probs[agent])

        # Small epsilon keeps every action at a non-zero probability; the
        # denominator makes each row sum to one.
        self.current_probs[agent] = (counts + 1e-6) / (k + 1e-6 * self.action_dim)

In [None]:
# #| export
# @patch
# def Plan(self: FindGoalPlanner, env, preprocessor=preprocessor):
#     obs = env.reset()
#     step = 0
#     plan = {agent: [] for agent in self.agents}
#     _, _, goals, _ = preprocessor(env, obs, pos=True, get_msg=True)

#     goal_pos = obs["global"]["goal_pos"]
#     position= repeat(torch.from_numpy(goal_pos).unsqueeze(0), "b d -> g b d", b=1, g=2)
#     z_goal = self.model.backbone(torch.stack([goals[agent] for agent in self.agents]).to(self.device),
#                                     position=position)
    
#     z_goal = repeat(z_goal, 'b c h w -> (b s) c h w', s=self.num_samples) # TODO
#     z_goal = {agent: z_goal[i:i+1] for i, agent in enumerate(self.agents)}

#     while step < 100:
#         prev_obs, prev_pos, _, msgs = preprocessor(env, obs, pos=True, get_msg=True)
#         for agent in self.agents:
#             self.current_probs[agent] = torch.full((self.horizon, self.action_dim), 1.0/self.action_dim, device=self.device)


#         prev_z = self.model.backbone(torch.stack([prev_obs[agent] for agent in self.agents]).to(self.device),
#                                 position=torch.stack([prev_pos[agent] for agent in self.agents]).to(self.device))
    
#         current_state = {agent: prev_z[i:i+1] for i, agent in enumerate(self.agents)}
        
#         current_state = {agent: repeat(current_state[agent], 'b c h w -> (b s) c h w', s=self.num_samples) \
#                         for agent in self.agents}
        
#         samples = {agent: torch.multinomial(self.current_probs[agent], self.num_samples, replacement=True).T \
#                    for agent in self.agents}
                
#         best_cost = {agent: float("inf") for agent in self.agents}
#         best_plan = {agent: 0 for agent in self.agents}

#         for n in self.opt_steps:
#             for j in range(self.horizon):
#                 next_state = {}
#                 costs = {agent: 0 for agent in self.agents}

#                 for rec in self.agents:
#                     for sender in self.agents:
#                         if sender != rec:
#                             if j > 0:
#                                 msg_sender = self.comm_module(current_state[sender])
#                             else:
#                                 msg_sender = msgs[sender].unsqueeze(0).to(self.device)
                    
#                     h_rec = self.msg_enc(msg_sender)
#                     h_rec = repeat(h_rec, 'b t d -> (s b t) d', s=self.num_samples)
#                     z_next = self.model.dynamics(current_state[rec], samples[rec], h_rec)
#                     next_state[rec] = z_next
#                     costs[rec] += self.loss(z_goal[rec], next_state[rec])
#                 current_state = next_state

#             for agent in self.agents:
#                 if costs[agent].min() < best_cost[agent]:
#                     best_cost[agent] = costs[agent].min()
#                     best_plan[agent] = samples[agent][costs[agent].argmin()]
#             self.update_dist(costs, samples)

#         plan_step = {agent: best_plan[agent][0].unsqueeze(0).cpu().numpy() for agent in self.agents}
#         for agent in self.agents:
#             plan[agent].append(plan_step[agent])

#         actions = {agent: int(plan_step[agent]) for agent in self.agents}
#         obs, rewards, done, infos = env.step(actions)
#         print(f"Step: {step}, Actions taken: {actions}, Rewards: {rewards}, Done: {done}")

#         if done['__all__']:
#             break

#         step += 1
        
#     env.close()
#     return plan

In [None]:
#| export
@patch
def Plan(self: FindGoalPlanner, env, preprocessor, max_steps=100):
    """Run one episode, choosing every action with CEM in the model's latent space.

    At each environment step the per-agent action distributions are reset to
    uniform and refined for ``self.opt_steps`` iterations: action sequences are
    sampled, rolled out through ``self.model.dynamics`` for ``self.horizon``
    steps, scored by the mean squared error between each predicted latent and
    the agent's goal latent, and the distribution is refit to the best samples
    with ``self.update_dist``. The most probable first action is then executed.

    Args:
        env: Environment providing ``reset()`` and ``step(actions)``; ``step``
            must return ``(obs, rewards, done, infos)`` with ``done`` a dict
            that may contain ``'__all__'``.
        preprocessor: Callable ``(env, obs, pos=True, get_msg=True)`` returning
            ``(observations, positions, goals, messages)`` dicts keyed by agent id.
        max_steps: Maximum number of environment steps to execute.

    Returns:
        Dict ``agent -> list[int]`` with the action executed at every step.

    Notes:
        Each receiver only uses the message of the first *other* agent in
        ``self.agents``, so at least two agents are required (an ``IndexError``
        is raised otherwise).
    """
    obs = env.reset()
    plan = {agent: [] for agent in self.agents}
    # Message source for each receiving agent: the first other agent.
    senders = {rec: [a for a in self.agents if a != rec][0] for rec in self.agents}

    # Planning is inference only; skip autograd bookkeeping.
    with torch.no_grad():
        # 1. Goal latents, one per agent, expanded to one copy per sampled trajectory.
        _, _, goals, _ = preprocessor(env, obs, pos=True, get_msg=True)

        goal_pos = obs["global"]["goal_pos"]
        position = repeat(torch.from_numpy(goal_pos).unsqueeze(0), "b d -> g b d",
                          b=1, g=len(self.agents)).to(self.device)
        goal_latents = self.model.backbone(
            torch.stack([goals[agent] for agent in self.agents]).to(self.device),
            position=position)
        # Index the agent first, then replicate, so every agent is compared
        # against its own goal latent with a target of matching batch size.
        z_goals = {agent: repeat(goal_latents[i], 'c h w -> s c h w', s=self.num_samples)
                   for i, agent in enumerate(self.agents)}

        for step in range(max_steps):
            prev_obs, prev_pos, _, msgs = preprocessor(env, obs, pos=True, get_msg=True)

            for agent in self.agents:
                self.current_probs[agent] = torch.full(
                    (self.horizon, self.action_dim), 1.0 / self.action_dim, device=self.device)

            # The start latent and the observed messages do not change across
            # CEM iterations, so compute them once per environment step.
            start_z = self.model.backbone(
                torch.stack([prev_obs[agent] for agent in self.agents]).to(self.device),
                position=torch.stack([prev_pos[agent] for agent in self.agents]).to(self.device))
            first_msgs = {
                agent: repeat(msgs[agent].to(self.device).unsqueeze(0),
                              'b c h w -> (s b) c h w', s=self.num_samples, b=1)
                for agent in self.agents
            }

            # 2. Optimization loop (CEM)
            for _ in range(self.opt_steps):
                # Integer action sequences, shape (num_samples, horizon).
                action_idx = {
                    agent: torch.multinomial(self.current_probs[agent], self.num_samples,
                                             replacement=True).T
                    for agent in self.agents
                }
                # One-hot view for the dynamics model, shape (num_samples, horizon, action_dim).
                action_onehot = {
                    agent: torch.nn.functional.one_hot(action_idx[agent],
                                                       num_classes=self.action_dim).float()
                    for agent in self.agents
                }

                states = {agent: repeat(start_z[i], 'c h w -> s c h w', s=self.num_samples)
                          for i, agent in enumerate(self.agents)}
                total_costs = {agent: torch.zeros(self.num_samples, device=self.device)
                               for agent in self.agents}

                # 3. Trajectory rollout
                for t in range(self.horizon):
                    next_states = {}
                    for rec in self.agents:
                        sender = senders[rec]

                        if t == 0:
                            # Use the message actually observed at this step.
                            m = first_msgs[sender]
                        else:
                            # Predict the message from the sender's imagined latent.
                            m = self.comm_module(states[sender])

                        h_rec = self.msg_enc(m)

                        # Predict next state from (current latent, action at t, message context).
                        z_next = self.model.dynamics(states[rec], action_onehot[rec][:, t],
                                                     h_rec.squeeze(1))
                        next_states[rec] = z_next

                        # Per-sample MSE to the goal latent, averaged over channel/spatial dims.
                        total_costs[rec] += self.loss(z_next, z_goals[rec]).mean(dim=[1, 2, 3])

                    states = next_states

                # 4. Refit the distribution using the integer samples.
                self.update_dist(total_costs, action_idx)

            # 5. Execute the most likely first action of the final distribution.
            executed_actions = {}
            for agent in self.agents:
                act = torch.argmax(self.current_probs[agent][0]).item()
                executed_actions[agent] = act
                plan[agent].append(act)

            obs, rewards, done, infos = env.step(executed_actions)
            print(f"Step: {step} | Actions: {executed_actions} | Rewards: {rewards}")

            if done.get('__all__', False):
                break

    return plan

In [None]:
#| hide
# Build a grid environment with a random seed for the manual checks below.
from mawm.envs.marl_grid import make_env
from mawm.envs.marl_grid.cfg import config
import copy
import numpy as np

# Fresh seed on every run, so the cells below are not reproducible across runs.
seed = np.random.randint(0, 10000)
# Work on a deep copy so the shared default config object is not mutated.
cfg = copy.deepcopy(config)
cfg.env_cfg.seed = int(seed)
cfg.env_cfg.max_steps = 512

env = make_env(cfg.env_cfg)

In [None]:
#| hide 
# Instantiate the world model and the communication modules.
# No checkpoint is loaded in this cell.
from mawm.models.jepa import JEPA
from omegaconf import OmegaConf
# NOTE: this rebinds `cfg`, replacing the environment config created in the previous cell.
cfg = OmegaConf.load("../cfgs/findgoal/mawm/main/mawm-seq-40.yaml")
model = JEPA(cfg.model, input_dim=(3, 42, 42), action_dim=5)

from mawm.models.comm import CommModule, MSGEnc
msg_enc = MSGEnc()
comm_module = CommModule()

In [None]:
#| hide
# Small planner configuration for a quick manual run.
# NOTE(review): topk keeps its default (100) while num_samples is 20; consider passing
# an explicit topk <= num_samples here.
planner = FindGoalPlanner(model, msg_enc, comm_module, horizon= 5, num_samples = 20, opt_steps= 5)

In [None]:
# planner.Plan(env, preprocessor)

torch.Size([20, 32, 15, 15])


  return F.mse_loss(input, target, reduction=self.reduction)


UnboundLocalError: cannot access local variable 't' where it is not associated with a value

In [None]:
#| hide

# Reset the environment and run the preprocessor with both optional outputs enabled.
obs = env.reset()
items = preprocessor(env, obs, pos=True, get_msg=True)

In [None]:
#| hide
# With pos=True and get_msg=True the preprocessor returns four dicts keyed by agent id.
obs_transformed, positions, goals, messages = items
# Inspect the per-agent tensor shapes (see the recorded output below).
obs_transformed['agent_0'].shape, positions['agent_0'].shape, goals['agent_0'].shape, messages['agent_0'].shape


(torch.Size([3, 42, 42]),
 torch.Size([2]),
 torch.Size([3, 42, 42]),
 torch.Size([5, 7, 7]))

In [None]:
#| hide
# Encode the current observations of both agents in a single backbone call.
device = "cpu"
agents = ['agent_0', 'agent_1']
prev_obs, prev_pos, goals, msgs = preprocessor(env, obs, pos=True, get_msg=True)
# Observations and positions are stacked in the order of `agents`.
prev_z = model.backbone(torch.stack([prev_obs[agent] for agent in agents]).to(device),
                                position=torch.stack([prev_pos[agent] for agent in agents]).to(device))
    

In [None]:
# Slicing with i:i+1 keeps the leading dimension (see the recorded output below).
prev_z.shape, prev_z[0:1].shape, prev_z[1:2].shape

(torch.Size([2, 32, 15, 15]),
 torch.Size([1, 32, 15, 15]),
 torch.Size([1, 32, 15, 15]))

In [None]:
#| hide
# Split the stacked latents into one entry per agent, keeping the leading dimension.
current_state = {agent: prev_z[i:i+1] for i, agent in enumerate(agents)}
current_state['agent_0'].shape, current_state['agent_1'].shape

(torch.Size([1, 32, 15, 15]), torch.Size([1, 32, 15, 15]))

In [None]:
#| hide
# Goal position taken from the global part of the observation, with a leading axis added.
goal_pos = obs["global"]["goal_pos"]
torch.from_numpy(goal_pos).unsqueeze(0).shape

torch.Size([1, 2])

In [None]:
#| hide
# Compare a single goal tensor with the stack over all agents.
goals['agent_0'].shape, torch.stack([goals[agent] for agent in agents]).shape

(torch.Size([3, 42, 42]), torch.Size([2, 3, 42, 42]))

In [None]:
#| hide
from einops import repeat
# Replicate the goal position once per agent (g=2 matches the two entries in `agents`).
position= repeat(torch.from_numpy(goal_pos).unsqueeze(0), "b d -> g b d", b=1, g=2)
# NOTE: this rebinds `prev_z`; from here on it holds the goal encodings, not the
# encodings of the current observations.
prev_z = model.backbone(torch.stack([goals[agent] for agent in agents]).to(device),
                                position=position)

prev_z.shape

torch.Size([2, 32, 15, 15])

In [None]:
# #| hide
# import torch
# ckpt = torch.load("./models/best.pth", map_location='cpu')
# ckpt.keys()

  ckpt = torch.load("./models/best.pth", map_location='cpu')


dict_keys(['epoch', 'jepa', 'msg_encoder', 'msg_predictor', 'obs_predictor', 'train_loss', 'val_loss', 'optimizer', 'lr'])

In [None]:
# model.load_state_dict(ckpt['jepa'])
# msg_encoder.load_state_dict(ckpt['msg_encoder'])
# msg_pred.load_state_dict(ckpt['msg_predictor'])
# obs_pred.load_state_dict(ckpt['obs_predictor'])

In [None]:
# model.backbone

MeNet6(
  (layers): Sequential(
    (0): Identity()
    (1): Sequential(
      (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
      (1): GroupNorm(4, 16, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
      (4): GroupNorm(8, 32, eps=1e-05, affine=True)
      (5): ReLU()
      (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (7): GroupNorm(8, 32, eps=1e-05, affine=True)
      (8): ReLU()
      (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (10): GroupNorm(8, 32, eps=1e-05, affine=True)
      (11): ReLU()
      (12): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

In [None]:
# #| hide
# planner = CEMPlanner(model=model, msg_enc= msg_encoder, msg_pred=msg_pred, obs_pred=obs_pred)

In [None]:
#| hide
# evaluator = PlanEvaluator(model, msg_encoder, msg_pred, planner)


In [None]:
# #| hide
# evaluator.eval_all_agents(env, preprocessor, negotiation_rounds=3)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export() # type: ignore  # noqa: E702
