In [1]:
import torch
import torch.nn.functional as F

In [12]:
# Seed the RNG so the random example is identical on every Restart & Run All.
torch.manual_seed(0)

batch, seq_len, embedding_dim = 10, 20, 30  # Example dimensions
x = torch.randn(batch, seq_len, embedding_dim)  # Example tensor of shape (batch_size, seq_len, embedding_dim)

# Raw weights w'_ij = x_i . x_j for every pair of positions: shape (batch, seq_len, seq_len).
raw_weights = torch.bmm(x, x.transpose(1, 2))
# Softmax over the last axis (j), so for each position i the weights sum to 1.
normalized_weights = F.softmax(raw_weights, dim=-1)

raw_weights.shape, normalized_weights.shape  # Check shapes of the computed weights

(torch.Size([10, 20, 20]), torch.Size([10, 20, 20]))

In [13]:
y = normalized_weights @ x  # y_i = sum_j w_ij x_j: each output is a weighted average of the inputs

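This is all we need for a basic self-attention operation in PyTorch: each output vector $y_i$ is a weighted average of all input vectors $x_j$, with weights given by a softmax over the dot products $x_i^\top x_j$.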

# Additional tricks
Self-attention as used in practice (e.g. in transformers) adds three tricks on top of this basic operation.

### 1. Query, Key, Value
In self-attention, we compute three different representations of each input vector: a query, a key, and a value. All three are derived from the input tensor. Queries and keys are used to compute the attention weights, and the values are what those weights are applied to.

We implement this with three linear transformations of the input: $W_q$ for the query, $W_k$ for the key, and $W_v$ for the value.
The cells below first use plain weight matrices to illustrate the computation; the `SelfAttention` module then uses PyTorch's `nn.Linear` for the same transformations.

Concretely, each input vector $x_i$ is projected to a query, a key, and a value; the attention weights then come from query–key dot products normalized with a softmax. The projections are:

$q_i = W_q(x_i), \qquad k_i = W_k(x_i), \qquad v_i = W_v(x_i)$

The raw attention weights are the dot products between queries and keys:
$w'_{ij} = q_i^\top k_j$

We then normalize these weights with a softmax over $j$, so that for each $i$ the weights sum to 1:
$w_{ij} = \dfrac{\exp w'_{ij}}{\sum_{j'} \exp w'_{ij'}}$

Finally, we apply the attention weights to the Value matrix:
$ y_i = \sum_j w_{ij} v_j $

In [14]:
# Random (embedding_dim, embedding_dim) projection matrices for queries, keys and
# values, drawn from a standard normal in the order W_q, W_k, W_v.
W_q, W_k, W_v = (torch.randn(embedding_dim, embedding_dim) for _ in range(3))

In [17]:
W_q[None].shape, W_q[None].expand(batch, -1, -1).shape  # add a leading batch axis, then repeat it `batch` times (expand returns a view, no copy)

(torch.Size([1, 30, 30]), torch.Size([10, 30, 30]))

In [None]:
# `@` (torch.matmul) broadcasts the 2-D weight matrix over the batch dimension of x,
# so no explicit unsqueeze/expand is needed and the cell no longer depends on the
# global `batch` matching x.size(0). Note this computes x @ W_q (row-vector convention);
# nn.Linear instead computes x @ weight.T.
q = x @ W_q  # Compute queries, shape (batch, seq_len, embedding_dim)

In [None]:
class SelfAttention(torch.nn.Module):
    def __init__(self, k, num_heads=1, mask=False):
        super.__init__()
        assert k % num_heads == 0, "k must be divisible by num_heads"
        self.k, self.num_heads = k, num_heads

        # These compute the queries, keys and values for all
        # heads
        self.W_q = torch.nn.Linear(k, k, bias=False)
        self.W_k = torch.nn.Linear(k, k, bias=False)
        self.W_v = torch.nn.Linear(k, k, bias=False)

        # Applied after the multi-head self-attention operation

    def forward(self, x):
        b, t, k = x.size()
        h = self.num_heads

        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        # to divide queries, keys, values into chunks
        s = k//h

        queries = queries.view(b, t, h, s)
        keys = keys.view(b, t, h, s)
        values = values.view(b, t, h, s)
