# World Model

> World model (Predictor).

In [None]:
#| default_exp models.worldmodel

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn

In [None]:
#| export
from MAWM.models.dense import DenseModel
class WorldModel(DenseModel):
    """Predict the next latent state from ``(z_t, a_t, prog_emb)``.

    The wrapped ``DenseModel`` is built with ``z_dim + 1`` outputs (next state plus a
    reward slot); ``forward`` returns only the first ``z_dim`` values (the next state)
    and discards the reward slot.
    """
    def __init__(self, z_dim=32, action_dim=1, prog_dim=32, model_info=None):
        """
        Args:
            z_dim: size of the latent state (input ``z_t`` and returned prediction).
            action_dim: size of the action ``a_t``.
            prog_dim: size of the program embedding ``prog_emb``.
            model_info: ``DenseModel`` config dict. ``None`` (the default) means
                ``{'layers': 3, 'node_size': 256, 'activation': nn.ReLU, 'dist': None}``.
        """
        # Build the default config per call: a dict literal in the signature is created
        # once and shared (and mutable) across every instance.
        if model_info is None:
            model_info = {'layers': 3, 'node_size': 256, 'activation': nn.ReLU, 'dist': None}
        input_dim = z_dim + action_dim + prog_dim
        output_dim = z_dim + 1  # next state + reward
        self.z_dim = z_dim
        super().__init__((output_dim,), input_dim, model_info)

    def forward(self, z_t: torch.FloatTensor, a_t: torch.FloatTensor, prog_emb: torch.FloatTensor):
        """Return the predicted next latent state (first ``z_dim`` outputs)."""
        x = torch.cat([z_t, a_t, prog_emb], dim=-1)
        # NOTE(review): assumes DenseModel returns a tensor whose last axis has size
        # z_dim + 1 when dist is None — TODO confirm. Slicing the last axis (rather than
        # axis 1) keeps this correct when inputs carry extra leading (e.g. time) dims.
        out = super().forward(x)
        z_next = out[..., :self.z_dim]
        return z_next

In [None]:
#| hide
z = torch.randn(4, 32)
a = torch.randn(4, 1)
p = torch.randn(4, 32)
wm = WorldModel()
z_next = wm(z, a, p)
# `eq` only returns a bool for display; `assert` makes nbdev_test fail on a shape mismatch.
assert z_next.shape == (4, 32), z_next.shape

True

In [None]:
#| export
class RewardModel(DenseModel):
    """Predict the reward for a latent state ``z_t`` and action ``a_t``.

    With the default config (``'dist': 'binary'``) ``DenseModel`` returns a
    distribution rather than a tensor (an ``Independent(Bernoulli)`` in this
    notebook's demo), so callers sample or score it themselves.
    """
    def __init__(self, z_dim=32, action_dim=1, model_info=None):
        """
        Args:
            z_dim: size of the latent state ``z_t``.
            action_dim: size of the action ``a_t``.
            model_info: ``DenseModel`` config dict. ``None`` (the default) means
                ``{'layers': 3, 'node_size': 256, 'activation': nn.ReLU, 'dist': 'binary'}``.
        """
        # Build the default config per call instead of sharing one mutable dict
        # between every instance.
        if model_info is None:
            model_info = {'layers': 3, 'node_size': 256, 'activation': nn.ReLU, 'dist': 'binary'}
        input_dim = z_dim + action_dim
        output_dim = 1
        self.z_dim = z_dim
        super().__init__((output_dim,), input_dim, model_info)

    def forward(self, z_t, a_t):
        """Return ``DenseModel``'s output for the concatenated ``(z_t, a_t)`` input."""
        x = torch.cat([z_t, a_t], dim=-1)
        out = super().forward(x)
        return out

In [None]:
#| hide
z = torch.randn(4, 32)
a = torch.randn(4, 1)
rm = RewardModel()
r = rm(z, a)
# Pin the layout shown in the cell output, Independent(Bernoulli(logits: (4, 1)), 1):
# batch_shape + event_shape must cover exactly one reward per batch row.
assert tuple(r.batch_shape) + tuple(r.event_shape) == (4, 1)
r

Independent(Bernoulli(logits: torch.Size([4, 1])), 1)

In [None]:
#|hide
# Draw one 0/1 reward sample per batch row from the reward distribution `r`.
r.sample()

tensor([[1.],
        [1.],
        [0.],
        [1.]])

In [None]:
#| export
class MLPPredictor(nn.Module):
    """Two-layer MLP that predicts the next latent state from ``(z, a)``.

    Args:
        latent_dim: size of the latent state ``z``; also the output size.
        action_dim: size of the action ``a``.
        hidden_dim: width of the hidden layer (default 128, the previously
            hard-coded value, so existing checkpoints still load).
    """
    def __init__(self, latent_dim=32, action_dim=1, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z, a):
        """Return the predicted next latent of shape ``(..., latent_dim)``.

        ``z`` and ``a`` are concatenated on the last axis, so any matching leading
        (batch/time) dimensions are supported.
        """
        return self.net(torch.cat([z, a], dim=-1))


In [None]:
#| hide
z_t = torch.randn(16, 32)
a_t = torch.randn(16, 1)
model = MLPPredictor(latent_dim=32, action_dim=1)
pred = model(z_t, a_t)
# Fail nbdev_test on a wrong shape instead of only displaying it.
assert pred.shape == (16, 32), pred.shape
pred.shape

torch.Size([16, 32])

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Multi-head cross-attention world model for Agent I; all of its components
# (symbol embedding table, attention projections, predictor MLP) live in this module.
class WorldModelAttention(nn.Module):
    """
    World Model with Multi-Head Cross-Attention.

    Predicts agent i's next latent state from its current state z_t^i, its action
    a_t^i and a message from agent j. The receiver state supplies the attention
    queries; the message supplies the keys and values.

    Handles message reception:
    - Training: Accepts continuous message vector (h_t^j).
    - Execution: Accepts discrete symbol index (m_hard) and performs embedding lookup.
    """

    def __init__(self, 
                 state_dim: int, 
                 action_dim: int, 
                 vocab_size: int, 
                 embed_dim: int, # Dimension of the actual message vector
                 predictor_embed_dim: int = 64,
                 total_attention_dim: int = 128,
                 num_heads: int = 3):
        """
        Initializes the World Model and the Symbol Embedding Table.

        Args:
            state_dim (int): Dimension of z_t^i and z_{t+1}^i.
            action_dim (int): Dimension of raw action a_t^i.
            vocab_size (int): Size of the shared symbol vocabulary |V|.
            embed_dim (int): Dimension of the message vector (h_t^j or h_received).
            predictor_embed_dim (int): Width of the encoded action fed to the MLP.
            total_attention_dim (int): Combined width of all attention heads and of
                the attention context vector. Must be divisible by ``num_heads``.
            num_heads (int): Number of attention heads. The defaults (128 and 3) are
                NOT divisible, so at least one of the two must be overridden.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not evenly divide
                ``total_attention_dim``.
        """
        super().__init__()

        # Fail fast: `_split_heads` reshapes projections to (num_heads, head_dim), which
        # only works when the heads exactly tile total_attention_dim (e.g. 128 // 3 == 42
        # leaves 2 dims over and the reshape would fail at forward time).
        if num_heads <= 0 or total_attention_dim % num_heads != 0:
            raise ValueError(
                f"total_attention_dim ({total_attention_dim}) must be divisible by "
                f"a positive num_heads ({num_heads})"
            )
        
        # Critical shared component: Agent I's version of the Symbol Embedding Matrix
        self.symbol_embeddings = nn.Embedding(vocab_size, embed_dim)
        
        # --- Multi-Head Attention Setup ---
        self.num_heads = num_heads
        self.total_attention_dim = total_attention_dim
        self.head_dim = total_attention_dim // num_heads
        self.predictor_embed_dim = predictor_embed_dim

        # MHA Projections: Key/Value input size is the message embedding dimension (embed_dim)
        self.query_proj = nn.Linear(state_dim, total_attention_dim, bias=False)
        self.key_proj = nn.Linear(embed_dim, total_attention_dim, bias=False)
        self.value_proj = nn.Linear(embed_dim, total_attention_dim, bias=False)
        self.output_proj = nn.Linear(total_attention_dim, total_attention_dim)

        # --- Action Encoding Layer ---
        self.action_encoder = nn.Linear(action_dim, predictor_embed_dim, bias=True)

        # --- Final Prediction Network (MLP) ---
        # Input is [z_t^i, attention context, encoded action] concatenated.
        mlp_input_dim = state_dim + total_attention_dim + predictor_embed_dim
        
        self.prediction_mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, state_dim * 2),
            nn.ReLU(),
            nn.Linear(state_dim * 2, state_dim) 
        )

    


In [None]:
#| export
# --- Utility methods for MHA (as before) ---
@patch
def _split_heads(self: WorldModelAttention, x: torch.Tensor):
    """Reshape ``(B, total_attention_dim)`` into ``(B, num_heads, 1, head_dim)``.

    The inserted axis 2 is a length-1 "sequence" so the batched matmuls in the
    attention step see one token per head.
    """
    *leading, _ = x.size()
    per_head = x.view(*leading, self.num_heads, self.head_dim)
    return per_head.unsqueeze(2)

@patch
def _combine_heads(self: WorldModelAttention, x: torch.Tensor):
    """Inverse of ``_split_heads``: ``(B, num_heads, 1, head_dim)`` -> ``(B, total_attention_dim)``."""
    merged = x.squeeze(2)  # drop the length-1 sequence axis
    return merged.view(*merged.size()[:-2], self.total_attention_dim)

# --- Core Attention Logic ---

@patch
def _multi_head_attention(self: WorldModelAttention, Q_input: torch.Tensor, KV_input: torch.Tensor):
    """Scaled dot-product multi-head attention of ``Q_input`` over ``KV_input``.

    Args:
        Q_input: query source (the receiver state z_t^i), shape (B, state_dim).
        KV_input: key/value source (the continuous message h_t^j or the looked-up
            h_received), shape (B, embed_dim).

    Returns:
        Context vector after ``output_proj``, shape (B, total_attention_dim).

    NOTE(review): ``_split_heads`` gives every head exactly one key (sequence
    length 1), so the softmax over the last axis is always 1.0 for finite scores and
    the result reduces to ``output_proj(value_proj(KV_input))``. ``query_proj`` and
    ``key_proj`` therefore do not influence the output. Confirm this is intended, or
    attend over several message tokens.
    """
    q = self._split_heads(self.query_proj(Q_input))   # (B, num_heads, 1, head_dim)
    k = self._split_heads(self.key_proj(KV_input))
    v = self._split_heads(self.value_proj(KV_input))

    logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, num_heads, 1, 1)
    attn = logits.softmax(dim=-1)
    per_head = attn @ v                                             # (B, num_heads, 1, head_dim)

    return self.output_proj(self._combine_heads(per_head))


In [None]:
#| export
@patch
def forward(self: WorldModelAttention, 
            z_t_i, 
            message_input, 
            action_t,
            is_training):
    """Predict Agent i's next latent state z_{t+1}^i.

    Args:
        z_t_i (torch.Tensor): current state, shape (Batch, state_dim).
        message_input (torch.Tensor):
            - ``is_training=True``: continuous message h_t^j, shape (Batch, embed_dim),
              used as-is.
            - ``is_training=False``: discrete symbol indices m_hard, shape (Batch),
              cast to long and looked up in ``self.symbol_embeddings``.
        action_t (torch.Tensor): action, shape (Batch, action_dim).
        is_training (bool): selects the message path above. This is an explicit
            argument and is independent of ``nn.Module.training`` (``.train()``/``.eval()``).

    Returns:
        torch.Tensor: predicted z_{t+1}^i, shape (Batch, state_dim).
    """
    # Continuous path: the message tensor is used directly, so gradients can flow back
    # through it. Discrete path: a lookup into this agent's own embedding table.
    h_message = message_input if is_training else self.symbol_embeddings(message_input.long())

    # Cross-attention: queries from the receiver state, keys/values from the message.
    attended = self._multi_head_attention(Q_input=z_t_i, KV_input=h_message)
    action_feat = self.action_encoder(action_t)

    features = torch.cat((z_t_i, attended, action_feat), dim=1)
    return self.prediction_mlp(features)


In [None]:
# # # | hide
# from MAWM.models.comm import CommunicationModuleGRU
# # # --- Example Usage (Simulation of one step) ---

# # Define Shared Dimensions
# BATCH_SIZE = 16
# STATE_DIM = 100       
# ACTION_DIM = 8        
# VOCAB_SIZE = 10       
# EMBED_DIM = 64        # Message vector dimension (h_t^j and h_received)
# PREDICTOR_EMBED_DIM = 32
# NUM_HEADS = 4
# TOTAL_ATTENTION_DIM = 64
# MESSAGE_CONTEXT_DIM = EMBED_DIM  # For Agent J's GRU output dimension

# # --- Setup ---
# # Need the two modules we defined
# sender = CommunicationModuleGRU(observation_dim=EMBED_DIM, message_context_dim=EMBED_DIM, vocab_size=VOCAB_SIZE)
# receiver_wm = WorldModelAttention(
#     state_dim=STATE_DIM,
#     action_dim=ACTION_DIM,
#     vocab_size=VOCAB_SIZE,
#     embed_dim=EMBED_DIM,
#     predictor_embed_dim=PREDICTOR_EMBED_DIM,
#     total_attention_dim=TOTAL_ATTENTION_DIM,
#     num_heads=NUM_HEADS
# )

# # CRITICAL: In a real system, the symbol embeddings must be synchronized/shared.
# # Here, we copy Agent I's learned embeddings to Agent J's output projection for demonstration.
# # In training, Agent I's embedding matrix E is trained by the World Model loss.
# # Agent J's Logit Projection (W_L) is also trained.

# #     # We won't copy/sync here, but note that the final message dimension must match (EMBED_DIM).

# #     # Mock Inputs
# mock_z_t_i = torch.randn(BATCH_SIZE, STATE_DIM)
# mock_z_t_j_GRU_output = torch.randn(BATCH_SIZE, EMBED_DIM) # The continuous GRU output h_t^j
# mock_action_t = torch.randn(BATCH_SIZE, ACTION_DIM)
# mock_h_prev_j = torch.zeros(BATCH_SIZE, MESSAGE_CONTEXT_DIM) 

# print("--- Simulation: Training Step (Continuous Channel) ---")

# # 1. Agent J Output (Training)
# h_t_j_cont = sender(mock_z_t_j_GRU_output, mock_h_prev_j, is_training=True)

# # 2. Agent I Prediction (Training)
# z_t_plus_1_i_train = receiver_wm(
#     z_t_i=mock_z_t_i,
#     message_input=h_t_j_cont,
#     action_t=mock_action_t,
#     is_training=True
# )
# print(f"Agent J sent: Continuous vector (shape {h_t_j_cont.shape})")
# print(f"Predicted state z_t+1_i (Training) shape: {z_t_plus_1_i_train.shape}")


# print("\n--- Simulation: Execution Step (Discrete Channel) ---")

# # 1. Agent J Output (Execution) - Sends DISCRETE INDEX
# m_hard = sender(mock_z_t_j_GRU_output, mock_h_prev_j, is_training=False)

# # 2. Agent I Prediction (Execution) - Performs lookup using the index
# z_t_plus_1_i_exec = receiver_wm(
#     z_t_i=mock_z_t_i,
#     message_input=m_hard, # Passes the discrete index (LongTensor)
#     action_t=mock_action_t,
#     is_training=False
# )

# print(f"Agent J sent: Discrete Index (example indices: {m_hard[0:4].tolist()})")
# print(f"Agent I performed lookup to get message vector.")
# print(f"Predicted state z_t+1_i (Execution) shape: {z_t_plus_1_i_exec.shape}")



--- Simulation: Training Step (Continuous Channel) ---
Agent J sent: Continuous vector (shape torch.Size([16, 64]))
Predicted state z_t+1_i (Training) shape: torch.Size([16, 100])

--- Simulation: Execution Step (Discrete Channel) ---
Agent J sent: Discrete Index (example indices: [5, 3, 5, 9])
Agent I performed lookup to get message vector.
Predicted state z_t+1_i (Execution) shape: torch.Size([16, 100])


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()