In [1]:
import torch
import torch.nn.functional as F

In [2]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Computes scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: Tensor of shape (batch_size, seq_len_q, d_k)
        key: Tensor of shape (batch_size, seq_len_k, d_k)
        value: Tensor of shape (batch_size, seq_len_k, d_v). Its sequence
            length must equal that of `key`, because each attention weight
            multiplies the value at the same position as its key.
        mask: Optional tensor broadcastable to
            (batch_size, seq_len_q, seq_len_k), e.g. (batch_size, 1, seq_len_k).
            Positions where `mask == 0` are excluded from attention.

    Returns:
        attention_output: Tensor of shape (batch_size, seq_len_q, d_v)
        attention_weights: Tensor of shape (batch_size, seq_len_q, seq_len_k).
            Each row sums to 1, except rows whose keys are all masked out:
            those rows (and the matching output rows) are all zeros.
    """
    # A plain Python scalar keeps the scores on the inputs' dtype and device;
    # a float32 tensor here would round the scale for float64 inputs.
    scale = query.size(-1) ** 0.5

    # Compute scores (batch_size, seq_len_q, seq_len_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale

    if mask is None:
        # Normalize over seq_len_k
        attention_weights = F.softmax(scores, dim=-1)
    else:
        masked_out = mask == 0
        # Rows in which every key is masked. Filling such a row entirely with
        # -inf would make softmax return NaN (in the forward pass and in the
        # gradients), so leave those rows finite and zero them afterwards.
        fully_masked = masked_out.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(masked_out & ~fully_masked, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

    # Compute weighted sum of values (batch_size, seq_len_q, d_v)
    attention_output = torch.matmul(attention_weights, value)

    return attention_output, attention_weights


In [5]:
input_seq_length = 5  # Maximum length of the input sequence
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process

# Draw the toy inputs from a dedicated, seeded generator so that re-running
# the notebook reproduces the same tensors without touching the global RNG.
rng = torch.Generator().manual_seed(0)

queries = torch.rand((batch_size, input_seq_length, d_k), generator=rng)
keys = torch.rand((batch_size, input_seq_length, d_k), generator=rng)
values = torch.rand((batch_size, input_seq_length, d_v), generator=rng)

# Prints the (attention_output, attention_weights) tuple.
print(scaled_dot_product_attention(queries, keys, values))

(tensor([[[0.4433, 0.6196, 0.4616,  ..., 0.7070, 0.6236, 0.3755],
         [0.4315, 0.6178, 0.4547,  ..., 0.6820, 0.6223, 0.3787],
         [0.4412, 0.6139, 0.4564,  ..., 0.7093, 0.6219, 0.3751],
         [0.4459, 0.6204, 0.4533,  ..., 0.6982, 0.6300, 0.3748],
         [0.4320, 0.6125, 0.4729,  ..., 0.7098, 0.6106, 0.3825]],

        [[0.4680, 0.4281, 0.5978,  ..., 0.2877, 0.4766, 0.3701],
         [0.4713, 0.4420, 0.5896,  ..., 0.2842, 0.4769, 0.3599],
         [0.4525, 0.4220, 0.6075,  ..., 0.2948, 0.4650, 0.3747],
         [0.4465, 0.4084, 0.6245,  ..., 0.2980, 0.4618, 0.3819],
         [0.4430, 0.4244, 0.6148,  ..., 0.2979, 0.4555, 0.3687]],

        [[0.5673, 0.5779, 0.6215,  ..., 0.6010, 0.6999, 0.5187],
         [0.5501, 0.5612, 0.5859,  ..., 0.6188, 0.7274, 0.5206],
         [0.5861, 0.5580, 0.6105,  ..., 0.6236, 0.7126, 0.5222],
         [0.5401, 0.5768, 0.6143,  ..., 0.5989, 0.7181, 0.5197],
         [0.5711, 0.5860, 0.6324,  ..., 0.5944, 0.6856, 0.5192]],

        ...,

    

In [4]:
# Display the random query tensor built with torch.rand in the demo cell.
queries

tensor([[[6.3756e-01, 9.0932e-01, 7.0029e-01,  ..., 9.5509e-01,
          6.1598e-01, 1.6296e-01],
         [6.5664e-01, 1.2814e-01, 6.8085e-01,  ..., 4.1152e-01,
          6.7366e-01, 6.2232e-01],
         [9.2248e-01, 4.1807e-01, 9.1756e-01,  ..., 2.8463e-01,
          1.6877e-01, 9.0261e-01],
         [2.2053e-01, 3.3198e-01, 4.3760e-01,  ..., 6.3471e-01,
          3.5757e-01, 7.4478e-01],
         [7.3801e-01, 6.1371e-01, 1.5627e-02,  ..., 9.9245e-01,
          7.0801e-01, 1.7650e-01]],

        [[6.5800e-01, 6.7176e-01, 2.5619e-01,  ..., 7.4651e-01,
          6.8219e-01, 1.8525e-01],
         [2.9137e-01, 3.1204e-01, 6.5756e-01,  ..., 8.5335e-01,
          9.5485e-01, 1.8461e-01],
         [6.1591e-01, 4.8785e-01, 3.4270e-01,  ..., 7.7645e-01,
          5.2495e-01, 7.0066e-01],
         [3.8391e-01, 6.0739e-01, 4.7553e-01,  ..., 7.3965e-01,
          6.7372e-01, 1.7814e-01],
         [6.9567e-01, 6.3398e-02, 2.3861e-01,  ..., 9.9667e-01,
          1.8843e-01, 8.2774e-01]],

      