In [1]:
import torch
from modules import MultiHeadAttentionWithBias

In [4]:
# Create layer
layer = MultiHeadAttentionWithBias(
    embed_dim=32,
    num_heads=4,
    proj_bias=False,
    attention_dropout=0.1,
    kv_dim=None,  # specify this argument if the input dims of query and key are different
)
# For demonstration purposes, we run the layer in inference mode so that dropout does not
# perturb the outputs shown below. `eval()` switches the layer AND all of its submodules to
# inference mode, whereas assigning `layer.training = False` would only change the flag on
# the top-level module. This line is not needed when you create the layer for training.
layer.eval();  # trailing `;` keeps the notebook from echoing the module repr

In [5]:
# Example inputs: random tensors with entries drawn uniformly from [0, 1)
bsz, q_len, k_len, embed_dim = 5, 10, 15, 32
query = torch.rand(bsz, q_len, embed_dim)
key = torch.rand(bsz, k_len, embed_dim)
attention_bias = torch.rand(bsz, q_len, k_len)

# Use cases
## Self-attention
Self-attention is executed in two cases: (1) neither `key` nor `keys_values_cache` is provided (think of the self-attention in the Encoder of Transformer); or (2) `key` is not provided and `update_keys_values_cache` is set to `True` (think of the self-attention in the Decoder of Transformer). In the latter case, the projected query is also returned, so that it can be reused, if `return_current_keys_values` is set to `True`. 

### Self-attention without caching projected query:

In [6]:
# Minimal call: only the query is passed; the result is unpacked from a single-element sequence
(hidden_query,) = layer(query)
print(hidden_query.size())  # (bsz, q_len, embed_dim)

torch.Size([5, 10, 32])


In [7]:
# Ask the layer to return the attention weights alongside the hidden states
hidden_query, attn_weight = layer(query, return_attn_weights=True)
print("hidden_query shape: ", hidden_query.size())  # (bsz, q_len, embed_dim)
print("attn_weight shape: ", attn_weight.size())  # (bsz, num_heads, q_len, q_len)
# Summing over the last axis should give one for every (batch, head, query position)
print("Sum of attn_weight is one along the last dimension: \n", attn_weight.sum(dim=-1))

hidden_query shape:  torch.Size([5, 10, 32])
attn_weight shape:  torch.Size([5, 4, 10, 10])
Sum of attn_weight is one along the last dimension: 
 tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 

### Self-attention with caching for auto-regressive decoder
In the scenario of self-attention in the Transformer Decoder during inference, the query is updated auto-regressively, and self-attention is performed at every time-step. To avoid recalculating the projections of the past queries, we cache them in `keys_values_cache` and update the cache at each time-step by setting `update_keys_values_cache` to `True`.

In [9]:
# Call the layer 5 times, feeding the returned cache back in on every step
keys_values_cache = None  # nothing is cached before the first step
for _ in range(5):
    hidden_query, keys_values_cache = layer(
        query,
        keys_values_cache=keys_values_cache,
        update_keys_values_cache=True,
        return_current_keys_values=True,
    )

# The cache unpacks into the accumulated keys and values
past_keys, past_values = keys_values_cache
print(past_keys.size())  # (bsz, 5*q_len, embed_dim)
print(past_values.size())  # (bsz, 5*q_len, embed_dim)

torch.Size([5, 50, 32])
torch.Size([5, 50, 32])


## Cross-attention
### Cross-attention without caching

In [10]:
# Minimal cross-attention call: the query attends to a separate key input
(hidden_query,) = layer(query, key)
print(hidden_query.size())  # (bsz, q_len, embed_dim)

torch.Size([5, 10, 32])


In [11]:
# Ask the layer to return the cross-attention weights alongside the hidden states
hidden_query, attn_weight = layer(query, key, return_attn_weights=True)
print("hidden_query shape: ", hidden_query.size())  # (bsz, q_len, embed_dim)
print("attn-weight shape: ", attn_weight.size())  # (bsz, num_heads, q_len, k_len)
# Summing over the key axis should give one for every (batch, head, query position)
print("Sum of attn_weight is one along the last dimension: \n", attn_weight.sum(dim=-1))

hidden_query shape:  torch.Size([5, 10, 32])
attn-weight shape:  torch.Size([5, 4, 10, 15])
Sum of attn_weight is one along the last dimension: 
 tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 

### Cross-attention with caching for auto-regressive decoder
In this case, we want to perform attention from different queries to the same key (think of the cross-attention in the decoder of Transformer during inference). Therefore, to avoid recalculating the projection of the key, we cache it in `keys_values_cache`. Unlike in self-attention, we do not update `keys_values_cache`, since the key is fixed; this is done by setting `update_keys_values_cache` to `False`. In that case, if both `key` and `keys_values_cache` are provided, `key` is ignored.

In [12]:
# Run the layer for 5 iterations
keys_values_cache = None  # None initialization
# `key` is only used on the first iteration, when there is no cache yet. On later
# iterations the layer uses `keys_values_cache` instead and ignores `key`.
for _ in range(5):
    hidden_query, keys_values_cache = layer(
        query,
        key,
        keys_values_cache=keys_values_cache,
        update_keys_values_cache=False,  # the key is fixed, so the cache must not be updated
        return_current_keys_values=True,
    )

# The cache still holds a single copy of the projected key/value (k_len positions, not 5*k_len)
past_keys, past_values = keys_values_cache
print(past_keys.shape)  # (bsz, k_len, embed_dim)
print(past_values.shape)  # (bsz, k_len, embed_dim)

torch.Size([5, 15, 32])
torch.Size([5, 15, 32])


# Including attention bias

In [13]:
# Cross-attention with an attention bias; the bias used here has shape (bsz, q_len, k_len)
(hidden_query,) = layer(query, key, attention_bias=attention_bias)
print(hidden_query.size())

torch.Size([5, 10, 32])


# Including attention mask

In [14]:
# Define a dummy attention mask. Let's say we want to mask out the last 3 tokens of the query
# and the last 5 tokens of the key for every batch element.
# The mask has shape (bsz, q_len, k_len) = (5, 10, 15): entries set to 1 mark the
# (query, key) pairs that stay visible, entries left at 0 are the ones masked out.
attention_mask = torch.zeros((5, 10, 15))
# One slice assignment replaces the nested Python loop over the first 7 query rows and
# the first 10 key columns; the scalar is broadcast over the whole batch.
attention_mask[:, :7, :10] = 1.0
# Cross-attention with both a bias and a mask, also returning the attention weights
hidden_query, attn_weight = layer(
    query,
    key,
    attention_mask=attention_mask,
    attention_bias=attention_bias,
    return_attn_weights=True,
)

# Show the weights of the first head for the first batch element
print("The last five columns and the last three rows are masked: \n", attn_weight[0][0])

The last five columns and the last three rows are masked: 
 tensor([[0.0591, 0.1158, 0.0931, 0.1535, 0.1061, 0.0920, 0.1145, 0.0699, 0.1090,
         0.0871, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0962, 0.0664, 0.1532, 0.0722, 0.0799, 0.0647, 0.0907, 0.1626, 0.1158,
         0.0982, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0832, 0.1229, 0.1075, 0.1069, 0.0832, 0.0537, 0.1680, 0.0825, 0.0646,
         0.1275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0752, 0.0747, 0.0772, 0.1410, 0.1029, 0.0584, 0.1310, 0.1429, 0.0967,
         0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1234, 0.0695, 0.0999, 0.0698, 0.0975, 0.0803, 0.1269, 0.1555, 0.0720,
         0.1051, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0684, 0.1184, 0.0678, 0.1099, 0.0745, 0.1369, 0.0762, 0.0945, 0.1356,
         0.1178, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0686, 0.0740, 0.1078, 0.1014, 0.1097, 0.0935, 0.1322, 0.1344, 0.0752,
         0.1032, 0.0000, 