# In [7]:
import torch
import torch.nn as nn

class SimpleTransformerEncoder(nn.Module):
    """Encode a sequence of feature vectors into one fixed-length vector.

    Each input token is linearly projected to ``embed_dim``, a learnable
    [CLS] token is prepended, the sequence is passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, and the encoder output at the
    [CLS] position is returned as the summary of the whole sequence.

    No positional encoding is added; the encoder sees token order only
    through the input features themselves.
    """

    def __init__(self, input_dim, embed_dim, num_layers, num_heads):
        """
        Args:
            input_dim (int): Dimensionality of each imaging opportunity input,
                             e.g., 3 for [time, angle, intrinsic utility].
            embed_dim (int): Dimension for embeddings in the transformer.
                             Must be divisible by ``num_heads`` (a requirement
                             of PyTorch's multi-head attention).
            num_layers (int): Number of transformer encoder layers.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        # Project each input_dim-dimensional token into embed_dim.
        self.input_embedding = nn.Linear(input_dim, embed_dim)

        # Learnable [CLS] token to summarize the sequence; stored with a
        # singleton batch axis so it can be expanded to any batch size.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # batch_first=True lets the encoder consume (batch, seq, embed)
        # tensors directly, so forward() needs no transposes. Parameter names
        # and shapes are the same as for a sequence-first layer.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, input_dim)
                        where each token is [time, angle, intrinsic utility].
            src_key_padding_mask (Tensor, optional): Boolean tensor of shape
                        (batch_size, seq_len) in which True marks padding
                        tokens that attention must ignore. The [CLS] position
                        is handled internally and must not be included.
                        Defaults to None (no padding).

        Returns:
            Tensor: The fixed-length representation from the [CLS] token with shape (batch_size, embed_dim).

        Raises:
            ValueError: If ``x`` is not 3-dimensional, or if the mask shape
                        does not equal (batch_size, seq_len).
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape (batch_size, seq_len, input_dim), "
                f"got {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # Embed the input features.
        tokens = self.input_embedding(x)  # (batch_size, seq_len, embed_dim)

        # Prepend the [CLS] token; expand() broadcasts without copying.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        tokens = torch.cat([cls_tokens, tokens], dim=1)  # (batch_size, seq_len+1, embed_dim)

        if src_key_padding_mask is not None:
            if src_key_padding_mask.shape != x.shape[:2]:
                raise ValueError(
                    "expected src_key_padding_mask of shape "
                    f"{tuple(x.shape[:2])}, got {tuple(src_key_padding_mask.shape)}"
                )
            # The [CLS] slot is never padding: prepend an all-zero (False)
            # column so the mask lines up with the lengthened sequence.
            cls_column = src_key_padding_mask.new_zeros((batch_size, 1))
            src_key_padding_mask = torch.cat([cls_column, src_key_padding_mask], dim=1)

        # Pass through the transformer encoder (batch-first layout).
        encoded = self.transformer_encoder(tokens, src_key_padding_mask=src_key_padding_mask)

        # The [CLS] token sits at sequence position 0.
        return encoded[:, 0, :]  # (batch_size, embed_dim)

# ----- Example usage -----
if __name__ == "__main__":
    # Smoke test: push one random batch through the encoder and report the
    # shape of the resulting [CLS] summary.

    # Batch layout: 2 sequences of 10 imaging opportunities, each described
    # by 3 features ([time, angle, intrinsic utility]).
    batch_size, seq_len, input_dim = 2, 10, 3

    # Encoder hyperparameters.
    embed_dim, num_heads, num_layers = 32, 4, 2

    # Build the model first, then the input, so the random draws happen in
    # the same order as parameters-then-data.
    model = SimpleTransformerEncoder(
        input_dim=input_dim,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
    )
    dummy_input = torch.randn(batch_size, seq_len, input_dim)

    # Run the forward pass; the result should be (batch_size, embed_dim).
    cls_output = model(dummy_input)
    print("CLS output shape:", cls_output.shape)


# Output: CLS output shape: torch.Size([2, 32])
