In [6]:
def number_to_words(n):
    """Return the English word for ``n`` when it lies between 0 and 19.

    Values outside that range produce the string ``"Number out of range"``
    rather than an exception.
    """
    names = (
        "zero", "one", "two", "three", "four",
        "five", "six", "seven", "eight", "nine",
        "ten", "eleven", "twelve", "thirteen", "fourteen",
        "fifteen", "sixteen", "seventeen", "eighteen", "nineteen",
    )

    # The table has exactly twenty entries, so its length is the upper bound.
    if 0 <= n < len(names):
        return names[n]
    return "Number out of range"

# Test: print the word for every value from 0 to 19, one "<n>: <word>" per line.
print("\n".join(f"{i}: {number_to_words(i)}" for i in range(20)))


0: zero
1: one
2: two
3: three
4: four
5: five
6: six
7: seven
8: eight
9: nine
10: ten
11: eleven
12: twelve
13: thirteen
14: fourteen
15: fifteen
16: sixteen
17: seventeen
18: eighteen
19: nineteen


In [None]:
class CausalCrossAttention(nn.Module):
    """Cross-attention block whose queries attend causally to a context.

    Queries come from ``x``; keys and values come from ``context`` after a
    linear projection from ``observation_dim`` to ``embed_dim``. The attention
    output is added to ``x`` and layer-normalised, then a feed-forward network
    is applied with a second residual connection.

    Args:
        observation_dim: Size of the last dimension of ``context``.
        embed_dim: Size of the last dimension of ``x`` and of the output.
        num_heads: Number of attention heads (``nn.MultiheadAttention``
            requires it to divide ``embed_dim``).
    """

    def __init__(self, observation_dim, embed_dim, num_heads):
        super().__init__()
        # batch_first=False: the attention module works on (seq, batch, embed).
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=False)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.Mish(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        # Projects raw observations into the attention embedding space.
        self.linear = nn.Linear(observation_dim, embed_dim)

    def forward(self, x, context):
        """Attend from ``x`` to ``context`` and apply the feed-forward network.

        Args:
            x: Query tensor of shape ``(batch, seq_len, embed_dim)``.
            context: Either ``(batch, observation_dim)`` (a single context
                entry per sample) or ``(batch, context_len, observation_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        # A 2-D context is treated as a context sequence of length one.
        if context.dim() == 2:
            context = context.unsqueeze(1)

        # Transform context to the embedding space: (batch, context_len, embed_dim).
        context = self.linear(context)

        # nn.MultiheadAttention was built with batch_first=False, so both
        # tensors must be laid out as (seq, batch, embed).
        x = x.transpose(0, 1)
        context = context.transpose(0, 1)

        # Causal mask of shape (seq_len, context_len). For boolean masks,
        # nn.MultiheadAttention treats True as "NOT allowed to attend", so
        # only context positions strictly after the query position are True.
        # Context position 0 is never masked, so no row is fully masked.
        query_pos = torch.arange(x.shape[0], device=x.device).unsqueeze(1)
        key_pos = torch.arange(context.shape[0], device=x.device).unsqueeze(0)
        causal_mask = key_pos > query_pos

        attn_output, _ = self.multihead_attn(
            x, context, context, attn_mask=causal_mask)

        # Residual connection and layer normalization
        x = x + attn_output
        x = self.layer_norm(x)

        # Feed-forward network
        x = x + self.ffn(x)

        # Return to the original (batch, seq_len, embed_dim) layout
        return x.transpose(0, 1)