# Communication Module

> Model architecture that handles communication.

In [None]:
#| default_exp models.comm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

class CommunicationModule(nn.Module):
    """
    Builds the layers behind Agent J's message output.

    The forward pass is attached to this class elsewhere in the file; it
    operates in two modes:
    1. During training, it returns the continuous logits (h_t^j).
    2. During execution, it returns the updated discrete message sequence,
       with the hard-decoded (argmax) symbol appended.

    Layers created here, in registration order:
        symbol_embeddings -> query/key/value/output projections ->
        feed_forward -> logit_proj
    """
    def __init__(self, state_dim: int, message_length: int, vocab_size: int, embed_dim: int, ffn_dim: int, num_heads: int = 4):
        """
        Args:
            state_dim (int): Dimension of the encoded observation (z_t^j).
            message_length (int): Length of the discrete message sequence L.
            vocab_size (int): The number of unique symbolic tokens in the vocabulary |V|.
            embed_dim (int): The dimension of the continuous symbol embeddings (also the K/V input dim).
            ffn_dim (int): Dimension of the intermediate Feed-Forward Network layer.
            num_heads (int): Number of attention heads. Must be positive and divide embed_dim.

        Raises:
            ValueError: If num_heads is not positive, or embed_dim is not
                divisible by num_heads.
        """
        super().__init__()

        # Validate the head configuration before allocating any parameters, so a
        # bad config fails fast with a readable message instead of a bare
        # ZeroDivisionError (num_heads == 0) or a late reshape error (num_heads < 0).
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer.")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads.")

        self.message_length = message_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim

        # --- 1. Symbol Embedding Matrix (E) ---
        self.symbol_embeddings = nn.Embedding(vocab_size, embed_dim)

        # --- 2. Multi-Head Attention (MHA) ---
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Equal to embed_dim thanks to the divisibility check above.
        self.total_attention_dim = self.head_dim * num_heads

        # Q, K, V projections (bias-free); the output projection keeps its bias.
        self.query_proj = nn.Linear(state_dim, self.total_attention_dim, bias=False)  # Q from z_t^j
        self.key_proj = nn.Linear(embed_dim, self.total_attention_dim, bias=False)  # K from message embedding
        self.value_proj = nn.Linear(embed_dim, self.total_attention_dim, bias=False)  # V from message embedding
        self.output_proj = nn.Linear(self.total_attention_dim, self.total_attention_dim)  # MHA output

        # --- 3. Feed-Forward Network (FFN) ---
        # Processes the MHA context vector. The output width is embed_dim, which
        # is what the logit head below expects as input.
        self.feed_forward = nn.Sequential(
            nn.Linear(self.total_attention_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, embed_dim),
        )

        # --- 4. Prediction Head (Logits h_t^j) ---
        self.logit_proj = nn.Linear(embed_dim, vocab_size)



In [None]:
#| export
@patch
def _split_heads(self: CommunicationModule, x: torch.Tensor):
    """Split the feature axis into heads and insert a singleton axis at index 2.

    The last dimension D is viewed as (num_heads, head_dim), then a length-1
    axis is added at position 2. For a (B, D) input the result is therefore
    (B, H, 1, D_h).
    """
    leading_dims = x.shape[:-1]
    per_head = x.view(*leading_dims, self.num_heads, self.head_dim)
    return torch.unsqueeze(per_head, 2)

In [None]:
#| export
@patch
def _combine_heads(self: CommunicationModule, x: torch.Tensor):
    """Merge the per-head features back into a single feature axis.

    Drops axis 2 when it has length 1, then views the last two axes
    (H, D_h) as one axis of size total_attention_dim, so a
    (B, H, 1, D_h) input comes back as (B, D).
    """
    without_singleton = torch.squeeze(x, 2)
    leading_dims = without_singleton.shape[:-2]
    return without_singleton.view(*leading_dims, self.total_attention_dim)

In [None]:
#| export
@patch
def _multi_head_attention(self: CommunicationModule, Q_input: torch.Tensor, KV_input: torch.Tensor):
    """Compute scaled dot-product MHA where Q is the state and K/V is the message context.

    Args:
        Q_input (torch.Tensor): Query source, projected by `query_proj`.
            Assumed to be (B, state_dim), as passed by `forward`.
        KV_input (torch.Tensor): Key/value source, projected by `key_proj`
            and `value_proj`. Assumed to be (B, embed_dim), as passed by `forward`.

    Returns:
        torch.Tensor: Context vector after `output_proj`; (B, total_attention_dim)
        for the 2-D inputs described above.

    NOTE(review): for 2-D inputs `_split_heads` yields (B, H, 1, D_h) for Q, K
    and V alike, so `scores` is (B, H, 1, 1) and the softmax over a single key
    is always 1. The context then equals V and does not depend on Q_input.
    Confirm whether attending over per-token embeddings was intended.
    """
    # Imported locally so the scaling factor does not depend on `math` happening
    # to be re-exported by a star import elsewhere in the module.
    import math

    Q_proj = self.query_proj(Q_input)
    K_proj = self.key_proj(KV_input)
    V_proj = self.value_proj(KV_input)

    Q = self._split_heads(Q_proj)
    K = self._split_heads(K_proj)
    V = self._split_heads(V_proj)

    # Scaled dot-product scores; scaling by sqrt(head_dim) keeps magnitudes
    # independent of the per-head width.
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

    # Normalise over the key axis, then mix the values.
    weights = F.softmax(scores, dim=-1)
    context_head_output = torch.matmul(weights, V)

    # Merge heads and apply the output projection.
    context_vector_combined = self._combine_heads(context_head_output)
    context_vector = self.output_proj(context_vector_combined)

    return context_vector


In [None]:
#| hide
# Scratch cell: sample a 1-D tensor of random token ids to experiment with.
import torch
vocab_size = 100
L = 10  # number of token ids to sample
# randint draws from [1, vocab_size), so id 0 is never produced here.
inp = torch.randint(1, vocab_size, (L,))
print(inp, inp.dtype, inp.shape)


tensor([71, 18, 85, 93, 17, 44, 66, 97, 63, 17]) torch.int64 torch.Size([10])


In [None]:
#| hide
# Scratch cell: look up an embedding for each sampled token id from the
# previous cell (relies on `vocab_size`, `inp` and `nn` already being defined).
embed_size = 32
embedding_layer = nn.Embedding(vocab_size, embed_size)
# One embed_size-wide row per token id, i.e. shape (L, embed_size).
embedded_inp = embedding_layer(inp)
print(embedded_inp.dtype, embedded_inp.shape)

torch.float32 torch.Size([10, 32])


In [None]:
#| export
@patch
def embed_discrete_message(self: CommunicationModule, m_j: torch.Tensor) -> torch.Tensor:
    """Collapse a discrete message into one continuous vector.

    Each symbol id in m_j is cast to int64 and looked up in
    `symbol_embeddings`; the resulting embeddings are then averaged over
    dim 1 (the message-length axis for a (B, L) input).
    """
    token_ids = m_j.long()
    per_token = self.symbol_embeddings(token_ids)
    return torch.mean(per_token, dim=1)

In [None]:
#| export
@patch
def forward(self: CommunicationModule, 
            z_t_j: torch.Tensor, 
            m_prev_j: torch.Tensor, 
            is_training: bool):
    """
    Calculates the logits from the current state and previous message, then
    returns either the logits or the updated discrete message.

    Args:
        z_t_j (torch.Tensor): Agent J's current state. Shape: (Batch, state_dim)
        m_prev_j (torch.Tensor): Agent J's discrete message from t-1. Shape: (Batch, message_length)
        is_training (bool): Flag to determine output mode.

    Returns:
        - If is_training=True: 
            - torch.Tensor (h_t_j_logits), Continuous logits. Shape: (Batch, vocab_size)
        - If is_training=False: 
            - torch.Tensor (m_t_j), New discrete message sequence. Shape: (Batch, message_length)

    NOTE(review): the previous message is averaged into a single vector before
    attention, so the attention has exactly one key per batch element and its
    softmax weight is always 1. That appears to make the logits independent of
    z_t_j; confirm whether attending over per-token embeddings was intended.
    """

    # 1. Embed Previous Message (m_{t-1}^j)
    prev_message_embed = self.embed_discrete_message(m_prev_j) # (B, embed_dim)

    # 2. Multi-Head Cross-Attention (Context Vector)
    context_vector = self._multi_head_attention(
        Q_input=z_t_j, 
        KV_input=prev_message_embed
    )

    # 3. Feed-Forward Processing
    ffn_output = self.feed_forward(context_vector)

    # 4. Logit Prediction (h_t^j)
    h_t_j_logits = self.logit_proj(ffn_output) # (B, vocab_size)

    # 5. Output Determination
    if is_training:
        # --- TRAINING PATH: Send Logits ---
        # Agent I receives these continuous logits for its loss calculation.
        return h_t_j_logits

    # --- EXECUTION PATH: Update Discrete Message Sequence ---

    # Predict the next symbol index (Argmax/Hard Decoding)
    predicted_token = torch.argmax(h_t_j_logits, dim=-1).unsqueeze(1) # (B, 1)

    # Drop the first symbol (oldest) and append the predicted token, so the
    # message keeps its length and behaves like a sliding window.
    m_truncated = m_prev_j[:, 1:] # (B, L-1)
    m_t_j = torch.cat([m_truncated, predicted_token], dim=1) # (B, L)

    # The message sent is the full discrete sequence m_t^j
    return m_t_j

In [None]:
#| hide

if __name__ == '__main__':
    # Toy dimensions for a quick smoke test of both forward modes.
    BATCH_SIZE, STATE_DIM = 4, 64
    MESSAGE_LENGTH, VOCAB_SIZE = 3, 10
    EMBED_DIM, FFN_DIM, NUM_HEADS = 32, 128, 4

    print(f"--- Agent J Non-Recurrent MHA Communication Module Initialization ---")
    comm_module_j = CommunicationModule(
        state_dim=STATE_DIM,
        message_length=MESSAGE_LENGTH,
        vocab_size=VOCAB_SIZE,
        embed_dim=EMBED_DIM,
        ffn_dim=FFN_DIM,
        num_heads=NUM_HEADS,
    )

    # Mock inputs: a random encoded state and a random previous message made
    # of valid symbol ids in [0, VOCAB_SIZE).
    mock_z_t_j = torch.randn(BATCH_SIZE, STATE_DIM)
    m_t_minus_1_j = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MESSAGE_LENGTH)).long()

    # ------------------------------------------------------------------
    # Training step: the module should hand back raw logits.
    # ------------------------------------------------------------------
    print("\n=============================================")
    print("SIMULATION: TRAINING STEP (Logits Output)")
    print("=============================================")

    h_t_j_logits = comm_module_j(mock_z_t_j, m_t_minus_1_j, is_training=True)

    print(f"Input message m_prev_j (symbols): {m_t_minus_1_j[0].tolist()}")
    print(f"Output Message (Logits) shape: {h_t_j_logits.shape}")
    print(f"Expected Logit Dimension: {VOCAB_SIZE}. Matches: {h_t_j_logits.shape[1] == VOCAB_SIZE}")

    # ------------------------------------------------------------------
    # Execution step: the module should return the shifted message with
    # the newly decoded symbol appended.
    # ------------------------------------------------------------------
    print("\n=============================================")
    print("SIMULATION: EXECUTION STEP (Discrete Message Update)")
    print("=============================================")

    m_t_j = comm_module_j(mock_z_t_j, m_t_minus_1_j, is_training=False)

    print(f"Output to Agent I: New discrete message m_t^j (symbols): {m_t_j[0].tolist()}")
    print(f"Verification: Message shifted and new token appended.")
    print(f"Old first token: {m_t_minus_1_j[0, 0]}")
    print(f"New first token: {m_t_j[0, 0]}")

    

--- Agent J Non-Recurrent MHA Communication Module Initialization ---

SIMULATION: TRAINING STEP (Logits Output)
torch.Size([4, 32])
torch.Size([4, 32])
Input message m_prev_j (symbols): [4, 2, 7]
Output Message (Logits) shape: torch.Size([4, 10])
Expected Logit Dimension: 10. Matches: True

SIMULATION: EXECUTION STEP (Discrete Message Update)
torch.Size([4, 32])
torch.Size([4, 32])
Output to Agent I: New discrete message m_t^j (symbols): [2, 7, 9]
Verification: Message shifted and new token appended.
Old first token: 4
New first token: 2


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()