# Communication Module

> Model architecture that handles inter-agent communication.

In [None]:
#| default_exp models.comm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

class CommunicationModuleGRU(nn.Module):
    """
    Agent J's recurrent message generator.

    A single-layer GRU folds each encoded observation z_t^j into a continuous
    message context h_t^j, and a linear head maps that context to one logit
    per vocabulary symbol. The patched `forward` then:
    1. During training, outputs the continuous message context (h_t^j).
    2. During execution, outputs the argmax discrete symbol index (m_hard).
    """
    def __init__(self, observation_dim: int, message_context_dim: int, vocab_size: int):
        """
        Args:
            observation_dim (int): Dimension of the encoded observation (z_t^j).
            message_context_dim (int): Dimension of the embedded message/GRU hidden state (h_t^j).
            vocab_size (int): The number of unique symbolic tokens in the vocabulary |V|.
        """
        super().__init__()
        # Keep the configured sizes so callers can read them back for shape checks.
        self.observation_dim = observation_dim
        self.message_context_dim = message_context_dim
        self.vocab_size = vocab_size

        # --- 1. Recurrent Core (GRU) ---
        # Computes h_t^j = GRU(z_t^j, h_{t-1}^j). batch_first=True means the
        # GRU consumes inputs laid out as (Batch, Seq, observation_dim).
        self.gru = nn.GRU(
            input_size=observation_dim,
            hidden_size=message_context_dim,
            num_layers=1,
            batch_first=True,
        )

        # --- 2. Discretization Head Logits ---
        # Maps the continuous context h_t^j to one logit per vocabulary symbol.
        # It is built after the GRU, matching the original order, so parameter
        # initialisation and state_dict key order are unchanged.
        self.logit_proj = nn.Linear(message_context_dim, vocab_size)


In [None]:
#| export
@patch
def forward(self: CommunicationModuleGRU, z_t_j: torch.Tensor, h_prev_j: torch.Tensor, is_training: bool):
    """
    Advance the message context by one GRU step and emit the mode-appropriate output.

    Args:
        z_t_j (torch.Tensor): The encoded observation at time t. Shape: (Batch, observation_dim)
        h_prev_j (torch.Tensor or None): The previous embedded message context (h_{t-1}^j).
                                    Shape: (Batch, message_context_dim).
                                    Pass None on the first step to start from an
                                    all-zero context (nn.GRU's default initial state).
        is_training (bool): If True, returns h_t_j (continuous).
                            If False, returns the discrete index m_hard.

    Returns:
        - If is_training=True: torch.Tensor (h_t^j), Shape: (Batch, message_context_dim)
        - If is_training=False: torch.Tensor (m_hard, int64), Shape: (Batch,)
    """
    # 1. Update Recurrent Message Context (h_t^j)

    # The GRU is batch_first, so feed the observation as a length-1 sequence:
    # (B, observation_dim) -> (B, 1, observation_dim)
    z_t_j_seq = z_t_j.unsqueeze(1)

    # nn.GRU takes the hidden state as (num_layers, B, message_context_dim);
    # with num_layers=1 that is a leading singleton dimension.
    # None is passed through unchanged so nn.GRU initialises h_0 to zeros.
    h_prev_j_gru = None if h_prev_j is None else h_prev_j.unsqueeze(0).contiguous()

    # Only the final hidden state is needed: (1, B, message_context_dim).
    _, h_t_j_gru = self.gru(z_t_j_seq, h_prev_j_gru)

    # Drop the layer dimension -> (B, message_context_dim).
    h_t_j = h_t_j_gru.squeeze(0)

    # 2. Training: pass the continuous context so gradients flow back into the GRU.
    if is_training:
        return h_t_j

    # Execution: project to vocabulary logits and pick the highest-scoring symbol
    # deterministically (argmax, not stochastic sampling). torch.argmax already
    # returns an int64 (long) tensor of shape (B,), the discrete channel message.
    logits = self.logit_proj(h_t_j)  # (B, vocab_size)
    return torch.argmax(logits, dim=-1)


In [None]:
#| hide

if __name__ == '__main__':
    # Demo / smoke test for the communication module. The asserts below make
    # the printed checks executable, so a notebook test run fails on regressions.

    # Define Dimensions
    BATCH_SIZE = 4
    OBSERVATION_DIM = 64
    MESSAGE_CONTEXT_DIM = 128
    VOCAB_SIZE = 10

    print("--- Agent J Communication Module Initialization ---")
    comm_module_j = CommunicationModuleGRU(OBSERVATION_DIM, MESSAGE_CONTEXT_DIM, VOCAB_SIZE)

    # Mock Inputs
    mock_z_t_j = torch.randn(BATCH_SIZE, OBSERVATION_DIM)
    # h_{t-1}^j is the previous message context (initialized to zero for the first step)
    mock_h_prev_j = torch.zeros(BATCH_SIZE, MESSAGE_CONTEXT_DIM)

    # --- Training Simulation ---
    h_t_j_cont = comm_module_j(mock_z_t_j, mock_h_prev_j, is_training=True)
    print("\nTraining Step:")
    print(f"Output (Continuous Context h_t^j) shape: {h_t_j_cont.shape}")
    assert h_t_j_cont.shape == (BATCH_SIZE, MESSAGE_CONTEXT_DIM)
    assert h_t_j_cont.is_floating_point()

    # --- Execution Simulation ---
    m_hard = comm_module_j(mock_z_t_j, mock_h_prev_j, is_training=False)
    print("\nExecution Step:")
    print(f"Output (Discrete Symbol Indices m_hard): {m_hard.tolist()}")
    print(f"Output shape: {m_hard.shape} (LongTensor of symbol indices)")
    assert m_hard.shape == (BATCH_SIZE,)
    assert m_hard.dtype == torch.long
    # Every emitted symbol must be a valid vocabulary index.
    assert bool(((m_hard >= 0) & (m_hard < VOCAB_SIZE)).all())

--- Agent J Communication Module Initialization ---

Training Step:
Output (Continuous Context h_t^j) shape: torch.Size([4, 128])

Execution Step:
Output (Discrete Symbol Indices m_hard): [7, 7, 7, 6]
Output shape: torch.Size([4]) (LongTensor of symbol indices)


In [None]:
#| hide
import nbdev

# Write the `#| export` cells out to their library modules (this notebook targets `models.comm`).
nbdev.nbdev_export()