In [2]:
import torch
import torch.nn.functional as F
# Seed PyTorch's default random number generator so every random tensor
# created below comes out the same on each fresh run.
torch.random.manual_seed(0)

<torch._C.Generator at 0x121c024d0>

In [3]:
# Toy problem sizes used throughout the notebook:
#   batch_size - how many sequences are processed together
#   seq_length - how many tokens each sequence contains
#   d_model    - width of each token embedding
batch_size, seq_length, d_model = 2, 4, 8

In [4]:
# Random stand-in for a batch of token embeddings, drawn uniformly from [0, 1).
# Layout: [batch_size, seq_length, d_model]
x = torch.rand((batch_size, seq_length, d_model))

Projection:
Three different weight matrices (W_q, W_k, and W_v) are used to linearly project the input embeddings into queries (Q), keys (K), and values (V).

In [5]:
# Square [d_model, d_model] projection matrices for queries, keys and values.
# A trained model would hold these as the weights of nn.Linear layers; here
# they are simply sampled from a standard normal distribution.
# They are drawn in the order W_q, W_k, W_v so the seeded random stream is
# consumed one matrix at a time, in that sequence.
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

In [6]:
# Linearly map the embeddings into query, key and value spaces.
# Each matrix product broadcasts over the batch dimension, so Q, K and V all
# keep the layout [batch_size, seq_length, d_model].
Q, K, V = (x @ weight for weight in (W_q, W_k, W_v))

Computing Scores:
The dot product between Q and the transpose of K is computed for each sequence, resulting in a score matrix of shape [batch_size, seq_length, seq_length].
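These scores are divided by the square root of d_model (i.e. scaled by 1/sqrt(d_model)), which keeps their magnitude from growing with the embedding dimension; this prevents the softmax from saturating and helps maintain stable gradients.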

In [7]:
# Scaled dot-product scores: every query is compared against every key of the
# same sequence (Q @ K^T, batched), then the result is divided by sqrt(d_model).
# Resulting layout: [batch_size, seq_length, seq_length]
scores = (Q @ K.transpose(-2, -1)).div(
    torch.tensor(d_model, dtype=torch.float32).sqrt()
)


In [8]:
# Normalise the scores into attention weights: softmax runs over the last
# axis, so each query's weights across all keys sum to 1.
attn_weights = scores.softmax(dim=-1)

In [9]:
# Blend the value vectors using the attention weights (a weighted sum per query).
# Resulting layout: [batch_size, seq_length, d_model]
attn_output = attn_weights @ V

In [10]:
# Report the shape of every intermediate tensor, one "<caption> <shape>" line
# each, in the order the tensors were produced.
print(
    "\n".join(
        f"{caption} {tensor.shape}"
        for caption, tensor in (
            ("Input x shape:", x),
            ("Queries Q shape:", Q),
            ("Keys K shape:", K),
            ("Values V shape:", V),
            ("Scores shape:", scores),
            ("Attention Weights shape:", attn_weights),
            ("Attention Output shape:", attn_output),
        )
    )
)

Input x shape: torch.Size([2, 4, 8])
Queries Q shape: torch.Size([2, 4, 8])
Keys K shape: torch.Size([2, 4, 8])
Values V shape: torch.Size([2, 4, 8])
Scores shape: torch.Size([2, 4, 4])
Attention Weights shape: torch.Size([2, 4, 4])
Attention Output shape: torch.Size([2, 4, 8])


In [11]:
# Show the actual numbers: each heading is followed by its tensor, separated
# by a single space, with the two sections emitted back to back.
print(
    "\n".join(
        f"{heading} {values!s}"
        for heading, values in (
            ("\nAttention Weights:\n", attn_weights),
            ("\nAttention Output:\n", attn_output),
        )
    )
)


Attention Weights:
 tensor([[[0.1208, 0.4147, 0.0720, 0.3925],
         [0.1173, 0.5753, 0.0603, 0.2471],
         [0.1112, 0.4619, 0.0702, 0.3567],
         [0.1528, 0.4312, 0.1098, 0.3061]],

        [[0.1542, 0.3456, 0.1380, 0.3622],
         [0.0951, 0.4722, 0.2063, 0.2265],
         [0.0795, 0.0594, 0.7971, 0.0640],
         [0.2178, 0.2033, 0.1923, 0.3866]]])

Attention Output:
 tensor([[[-0.1093,  0.1661,  1.3634,  0.7949, -0.6423,  0.6541,  1.7390,
          -1.2077],
         [-0.1311,  0.1290,  1.2045,  0.7495, -0.4215,  0.4963,  1.5629,
          -1.0194],
         [-0.1063,  0.1598,  1.3230,  0.7857, -0.5873,  0.6143,  1.6820,
          -1.1566],
         [-0.2032,  0.1623,  1.2805,  0.7886, -0.5417,  0.6091,  1.7694,
          -1.1958]],

        [[-0.7167,  0.3874,  0.9260,  0.4987, -0.4800,  1.2840,  1.6569,
          -1.6015],
         [-0.6674,  0.2024,  1.0323,  0.4462, -0.3512,  1.2285,  1.5263,
          -1.5770],
         [-0.0712,  0.4308,  0.5893,  1.1798, -0.07