In [1]:
import torch
from modules import MultiHeadAttentionWithBias

In [15]:
# Create the attention layer used by every example below.
layer = MultiHeadAttentionWithBias(
    embed_dim=32,
    num_heads=4,
    proj_bias=False,
    attention_dropout=0.1,
    kv_dim=None,  # specify this argument if the input dims of query and key are different
)
# For demonstration purposes, we run the layer in inference mode to avoid the effect of dropout.
# When you instantiate this layer in your own code, the following line is not necessary.
# `eval()` (rather than assigning `layer.training` directly) is used so that the flag is
# also propagated to any submodules -- assumes the layer is a `torch.nn.Module`.
layer.eval()

In [3]:
# Random example tensors shared by the use cases below.
query = torch.rand(5, 10, 32)  # (bsz, q_len, embed_dim)
key = torch.rand(5, 15, 32)  # (bsz, k_len, embed_dim)
attention_bias = torch.rand(5, 10, 15)  # (bsz, q_len, k_len)

# Use cases
## Self-attention
Self-attention is executed when neither `key_value` nor `past_keys_values` is provided (think of the self-attention in the Encoder of a Transformer), or when `key_value` is not provided and `is_autoregressive` is set to `True` (think of the self-attention in the Decoder of a Transformer). In the latter case, the projected keys and values (computed from `query`) are returned so that they can be reused in the next step, provided `return_current_keys_values` is set to `True`.

### Self-attention without caching the projected keys and values

In [7]:
# Minimal call: only the query is passed, and the layer returns a 1-tuple.
(hidden_query,) = layer(query)
print(hidden_query.shape)  # (bsz, q_len, embed_dim)

torch.Size([5, 10, 32])


In [18]:
# Ask the layer to return the attention weights alongside the attended output.
hidden_query, attn_weight = layer(query, return_attn_weights=True)
print("hidden_query shape: ", hidden_query.shape)  # (bsz, q_len, embed_dim)
print("attn-weight shape: ", attn_weight.shape)  # (bsz, num_heads, q_len, q_len)
# Sanity check: summing the weights over the last axis should give one everywhere.
print("Sum of attn_weight is one along the last dimension: \n", attn_weight.sum(dim=-1))

hidden_query shape:  torch.Size([5, 10, 32])
attn-weight shape:  torch.Size([5, 4, 10, 10])
Sum of attn_weight is one along the last dimension: 
 tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 

### Self-attention with caching for auto-regressive decoder

In [35]:
# Run the layer for 5 iterations, feeding the cache returned by each call into the next.
past_keys_values = None  # nothing is cached before the first iteration
for _ in range(5):
    hidden_query, past_keys_values = layer(
        query,
        past_keys_values=past_keys_values,
        is_autoregressive=True,
        return_current_keys_values=True,
    )

# The returned cache is a (keys, values) pair that has grown along the sequence axis.
past_keys, past_values = past_keys_values
print(past_keys.shape)  # (bsz, 5*q_len, embed_dim)
print(past_values.shape)  # (bsz, 5*q_len, embed_dim)

torch.Size([5, 50, 32])
torch.Size([5, 50, 32])


## Cross-attention
### Cross-attention without caching

In [25]:
# Minimal cross-attention call: query attends over key; the layer returns a 1-tuple.
(hidden_query,) = layer(query, key)
print(hidden_query.shape)  # (bsz, q_len, embed_dim)

torch.Size([5, 10, 32])


In [26]:
# Ask the layer to return the cross-attention weights alongside the attended output.
hidden_query, attn_weight = layer(query, key, return_attn_weights=True)
print("hidden_query shape: ", hidden_query.shape)  # (bsz, q_len, embed_dim)
print("attn-weight shape: ", attn_weight.shape)  # (bsz, num_heads, q_len, k_len)
# Sanity check: summing the weights over the last (key) axis should give one everywhere.
print("Sum of attn_weight is one along the last dimension: \n", attn_weight.sum(dim=-1))

hidden_query shape:  torch.Size([5, 10, 32])
attn-weight shape:  torch.Size([5, 4, 10, 15])
Sum of attn_weight is one along the last dimension: 
 tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 

### Cross-attention with caching for auto-regressive decoder

In [34]:
# Run the layer for 5 iterations, feeding the cache returned by each call into the next.
past_keys_values = None  # nothing is cached before the first iteration
for _ in range(5):
    hidden_query, past_keys_values = layer(
        query,
        key,
        past_keys_values=past_keys_values,
        is_autoregressive=True,
        return_current_keys_values=True,
    )

# The returned cache is a (keys, values) pair that has grown along the sequence axis.
past_keys, past_values = past_keys_values
print(past_keys.shape)  # (bsz, 5*k_len, embed_dim)
print(past_values.shape)  # (bsz, 5*k_len, embed_dim)

torch.Size([5, 75, 32])
torch.Size([5, 75, 32])


# Including attention bias

In [28]:
# Pass an attention bias of shape (bsz, q_len, k_len) together with query and key.
(hidden_query,) = layer(query, key, attention_bias=attention_bias)
print(hidden_query.shape)

torch.Size([5, 10, 32])


# Including attention mask

In [33]:
# Define a dummy attention mask of shape (bsz, q_len, k_len). Let's say we want to mask out
# the last 3 tokens of the query and the last 5 tokens of the key for every batch element:
# positions left at 0 are masked, positions set to 1 are kept.
attention_mask = torch.zeros((5, 10, 15))
# A single slice assignment marks the first 7 query rows x first 10 key columns as visible.
attention_mask[:, :7, :10] = 1.0
# Combine the bias and the mask in one call and request the resulting attention weights.
hidden_query, attn_weight = layer(
    query,
    key,
    attention_bias=attention_bias,
    attention_mask=attention_mask,
    return_attn_weights=True,
)

# Inspect the weights of the first batch element and the first head.
print("The last five columns and the last three rows are masked: \n", attn_weight[0, 0])

The last five columns and the last three rows are masked: 
 tensor([[0.1614, 0.0603, 0.0999, 0.1353, 0.0606, 0.1235, 0.0741, 0.0832, 0.0803,
         0.1214, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1369, 0.0959, 0.1065, 0.1270, 0.0598, 0.1096, 0.0756, 0.1253, 0.0661,
         0.0973, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0844, 0.0825, 0.1149, 0.0698, 0.1559, 0.1376, 0.0790, 0.0837, 0.0958,
         0.0964, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0806, 0.1615, 0.0681, 0.0829, 0.1467, 0.0608, 0.1125, 0.0659, 0.1229,
         0.0981, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0829, 0.0906, 0.1431, 0.1379, 0.0601, 0.0558, 0.1146, 0.0677, 0.0878,
         0.1597, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0995, 0.1261, 0.1343, 0.0611, 0.0678, 0.1146, 0.0955, 0.1047, 0.1005,
         0.0959, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1163, 0.1233, 0.0981, 0.0767, 0.1030, 0.1043, 0.1125, 0.0504, 0.1219,
         0.0935, 0.0000, 