In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F



In [2]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn


In [3]:
# Example usage
temperature = 0.5
attn_dropout = 0.1

# Create an instance of ScaledDotProductAttention
attention = ScaledDotProductAttention(temperature, attn_dropout)

# Generate random input tensors
batch_size = 16
num_heads = 4
query_length = 10
key_length = 10
value_length = 10
d_model = 32

q = torch.randn(batch_size, num_heads, query_length, d_model)
k = torch.randn(batch_size, num_heads, key_length, d_model)
v = torch.randn(batch_size, num_heads, value_length, d_model)

# Perform attention calculation
output, attn = attention(q, k, v)

# Print the shapes of the output and attention weights
print("Output shape:", output.shape)
print("Attention weights shape:", attn.shape)

Output shape: torch.Size([16, 4, 10, 32])
Attention weights shape: torch.Size([16, 4, 10, 10])


In [19]:
q[15]

tensor([[[ 0.4623,  1.0891,  0.8992,  ..., -1.5840, -0.7746,  0.4222],
         [ 1.9272, -1.4795, -0.8634,  ..., -0.3600, -2.4656, -1.6880],
         [-0.0750, -0.3233, -0.6432,  ...,  0.7246, -0.8531,  0.6246],
         ...,
         [ 1.5696, -0.4765, -0.7247,  ...,  1.1428,  0.3428, -0.4813],
         [ 1.2285, -0.2478, -0.0394,  ..., -0.2524, -1.5987,  1.4104],
         [-0.9387,  0.1707, -0.0802,  ...,  0.4169, -0.1990,  0.3882]],

        [[ 1.1757,  1.3537, -0.4263,  ...,  0.4719, -0.5359,  0.7690],
         [ 2.3273,  0.8985,  1.0585,  ..., -1.2662, -1.9204, -1.9196],
         [ 0.6902, -0.3572, -2.7102,  ...,  0.3293, -0.4928,  0.7565],
         ...,
         [-1.4449,  0.7698, -0.6906,  ..., -0.4847,  1.2833,  1.0203],
         [-0.6540, -0.9430, -1.0137,  ...,  0.0882, -0.2902, -2.0249],
         [-0.2889,  0.7320,  1.0809,  ...,  1.2050, -0.1545,  0.1530]],

        [[-0.6106, -0.6404,  1.2460,  ...,  0.7122, -0.2981,  0.3409],
         [ 1.8062,  0.3735, -2.4609,  ..., -1

In [10]:
k.transpose(2,3).shape

torch.Size([16, 4, 32, 10])

In [4]:
print(q.shape)

torch.Size([16, 4, 10, 32])


In [12]:
torch.matmul(q, k.transpose(2, 3)).shape

torch.Size([16, 4, 10, 10])

In [5]:
print(output)

tensor([[[[ 1.2636e+00,  1.4966e+00,  8.6099e-01,  ..., -8.1101e-01,
           -1.0101e+00,  2.0210e+00],
          [ 1.2327e+00,  6.7414e-01,  1.3825e-02,  ..., -2.9316e-01,
            1.0192e+00, -6.3272e-01],
          [-1.4290e+00, -5.9201e-01,  1.0125e+00,  ...,  2.1902e+00,
           -9.0501e-01,  1.7709e-01],
          ...,
          [ 5.6921e-01,  5.5836e-01,  2.8869e-01,  ...,  1.5721e-01,
            5.0428e-01, -8.5901e-01],
          [ 1.3018e+00,  1.5802e+00,  9.3175e-01,  ..., -8.7202e-01,
           -5.6696e-01,  1.5901e+00],
          [ 5.0998e-01,  3.5238e-01,  1.0520e+00,  ..., -1.0474e+00,
            1.1918e+00, -8.1592e-01]],

         [[-6.8774e-01, -4.4927e-01, -3.4106e-01,  ..., -2.8457e+00,
            1.0816e+00, -4.7335e-01],
          [-6.5265e-01, -4.2920e-01, -3.5713e-01,  ..., -2.7635e+00,
            9.7008e-01, -4.1293e-01],
          [-1.3177e+00, -7.9676e-01, -6.8202e-01,  ...,  2.1684e+00,
           -8.6106e-01, -9.9091e-01],
          ...,
     

In [6]:
print(attn)

tensor([[[[9.4468e-01, 9.7510e-18, 0.0000e+00,  ..., 0.0000e+00,
           1.6643e-01, 4.7429e-20],
          [1.2316e-04, 1.9971e-11, 0.0000e+00,  ..., 1.1100e+00,
           4.7168e-05, 6.2452e-15],
          [5.8909e-03, 7.4559e-09, 9.9872e-06,  ..., 1.6094e-12,
           1.1029e+00, 0.0000e+00],
          ...,
          [2.2999e-05, 3.3986e-01, 8.8623e-10,  ..., 7.4013e-01,
           2.7496e-10, 3.3385e-06],
          [7.8586e-01, 0.0000e+00, 3.1357e-03,  ..., 3.3402e-12,
           1.2683e-16, 0.0000e+00],
          [1.0195e-10, 0.0000e+00, 5.3173e-12,  ..., 4.2602e-11,
           0.0000e+00, 1.1067e+00]],

         [[4.0311e-10, 5.2481e-05, 4.4362e-08,  ..., 2.7263e-08,
           1.1108e+00, 2.6764e-04],
          [3.0444e-07, 1.8827e-10, 1.6555e-11,  ..., 3.3612e-02,
           1.0769e+00, 2.1057e-08],
          [0.0000e+00, 1.0808e+00, 1.4976e-03,  ..., 2.3543e-03,
           4.2537e-12, 3.4160e-09],
          ...,
          [0.0000e+00, 5.2919e-07, 1.0558e+00,  ..., 5.1442