In [1]:
# Partly adapted from: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html
import torch
import torch.nn.functional as F

# Plotting utils
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

In [2]:
# Source: Oceanic, Greg Egan (Highly recommended) | Slightly edited
# Example sentence that the following cells clean up, tokenize and embed.
sentence = "Mathematics catalogues everything not self-contradictory; within its vast inventory, physics is an island of structures rich enough to contain their own beholders."

In [3]:
# Normalise the text: lower-case it, then strip commas, semicolons and full stops
# in a single pass. Hyphens are kept, so "self-contradictory" stays one token.
sentence = sentence.lower().translate(str.maketrans('', '', ',;.'))
sentence

'mathematics catalogues everything not self-contradictory within its vast inventory physics is an island of structures rich enough to contain their own beholders'

In [4]:
# Tokenize by splitting on whitespace, then display the tokens and how many there are
listofwords = sentence.split()
listofwords, len(listofwords)

(['mathematics',
  'catalogues',
  'everything',
  'not',
  'self-contradictory',
  'within',
  'its',
  'vast',
  'inventory',
  'physics',
  'is',
  'an',
  'island',
  'of',
  'structures',
  'rich',
  'enough',
  'to',
  'contain',
  'their',
  'own',
  'beholders'],
 22)

In [5]:
# Create a vocabulary or dictionary mapping words to integers.
# Deduplicate first (set) so the ids are contiguous, 0 .. len(vocab) - 1, even when a
# word occurs more than once in the sentence; without this, a repeated word would be
# enumerated several times, leaving gaps in the ids and ids larger than len(vocab) - 1.
# Sorting keeps the id assignment deterministic.
vocab = {s: i for i, s in enumerate(sorted(set(listofwords)))}
vocab

{'an': 0,
 'beholders': 1,
 'catalogues': 2,
 'contain': 3,
 'enough': 4,
 'everything': 5,
 'inventory': 6,
 'is': 7,
 'island': 8,
 'its': 9,
 'mathematics': 10,
 'not': 11,
 'of': 12,
 'own': 13,
 'physics': 14,
 'rich': 15,
 'self-contradictory': 16,
 'structures': 17,
 'their': 18,
 'to': 19,
 'vast': 20,
 'within': 21}

In [6]:
# Convert the sentence to an integer vector, using the vocab.
# (torch is already imported in the first cell, so it is not re-imported here.)

torch.manual_seed(123) # Seed for the pRNG - for reproducibility

# One vocab id per token, in sentence order
sentence_int = torch.tensor([vocab[s] for s in listofwords])
sentence_int

tensor([10,  2,  5, 11, 16, 21,  9, 20,  6, 14,  7,  0,  8, 12, 17, 15,  4, 19,
         3, 18, 13,  1])

### Generate embeddings for each word in the sentence

In [7]:
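# torch.nn.Embedding is a lookup table that maps each integer index to a row of a learnable weight matrix (equivalent to a linear transformation of one-hot vectors); the weights are initialised by sampling from the standard Normal distribution.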
# Uncomment and run the following to learn more about torch.nn.Embedding:
# ? torch.nn.Embedding

In [8]:
# Vocab size: derived from the vocab itself rather than hard-coded, so the embedding
# table always has exactly one row per word id
n = len(vocab)
# Embedding size
d = 64

# Lookup table with n rows of d randomly initialised values
embed = torch.nn.Embedding(num_embeddings=n, embedding_dim=d)
# Look up one embedding per token (rows follow sentence order). detach() returns a
# tensor cut off from the autograd graph, i.e. no gradients are tracked through the
# embedding table for the rest of the notebook.
embedded_sentence = embed(sentence_int).detach()

embedded_sentence, embedded_sentence.shape

(tensor([[-1.1065,  1.2682,  0.3147,  ...,  0.4466, -0.8970,  0.1009],
         [-0.2582, -2.0407, -0.8016,  ...,  0.1132,  0.8365,  0.0285],
         [ 0.3277, -0.8331, -1.6179,  ..., -1.7984, -0.6822, -0.5191],
         ...,
         [-1.0693,  0.4660,  0.7012,  ...,  2.7196,  0.4816,  0.2409],
         [-1.4284,  0.5617,  0.7701,  ..., -1.3574, -1.1745, -0.5126],
         [ 0.5146,  0.9938, -0.2587,  ...,  1.2774, -1.4596, -2.1595]]),
 torch.Size([22, 64]))

### Vanilla Attention

Let $X_{n \times d}$ denote our `embedded_sentence` (i.e., the matrix storing the embedding of each word in the sentence), where $n$ is the number of words in the sentence (here it equals the vocab size, because every word occurs exactly once) and $d$ is the embedding size. And, let $W_{n\times n}$ denote a "similarity matrix", which stores the similarity scores (dot products) between the embeddings of every pair of words in the `sentence`. We can use $W$ as a weight matrix to compute the attention output, stored in $Y_{n \times d}$: each row of $Y$ is a weighted sum of the word embeddings. This gives us "vanilla attention".
$$
\begin{align}
W_{n\times n} &:= XX^T \\
W_{n\times n} &:= X_{n \times d} \times X_{d \times n}^T \\
Y_{n \times d} &:= W_{n \times n} \times X_{n \times d}
\end{align}
$$

In [9]:
# Vanilla attention computed directly on the embeddings
X = embedded_sentence

W = torch.matmul(X, X.T)  # pairwise dot-product similarities (same as X @ X.T)
Y = torch.matmul(W, X)    # each row: similarity-weighted sum of the embeddings

W.shape, Y.shape, X.shape

(torch.Size([22, 22]), torch.Size([22, 64]), torch.Size([22, 64]))

In [10]:
# Plotting Similarity Heatmap
# Row/column i of W belongs to the i-th word of the sentence, so the tick labels must
# follow sentence order (listofwords); the vocab keys are in alphabetical order and
# would attach the wrong word to every row and column.
fig = px.imshow(W,
                labels=dict(x="Words", y="Words", color="Similarity"),
                x=listofwords,
                y=listofwords,
                color_continuous_scale="Viridis")

fig.update_layout(title="Similarity Heatmap")

fig.show()

In [11]:
# Plotting Y = W @ X (one row per word, one column per embedding dimension)
# Row i of Y belongs to the i-th word of the sentence, so the row labels must follow
# sentence order (listofwords), not the alphabetically sorted vocab keys.
fig = px.imshow(Y,
                labels=dict(x="Embedding dimension", y="Words", color="Attention Weights"),
                y=listofwords,
                color_continuous_scale="Viridis")

fig.update_layout(title="Attention Weights")

fig.show()

In [12]:
# Rescaling W and then applying SoftMax to turn the values into probabilities.
# Uncomment and run the following to learn more.
# ? F.softmax

In [13]:
# Scale the similarities by sqrt(d), then softmax over the last axis so that every
# row of W2 sums to 1
W2 = W / torch.sqrt(torch.tensor(d).float()) # W / sqrt(d)
W2 = F.softmax(W2, dim=-1)

# Plotting Similarity Heatmap
# Rows/columns of W2 follow sentence order, so label them with listofwords; the
# alphabetically sorted vocab keys would attach the wrong word to every row and column.
fig = px.imshow(W2,
                labels=dict(x="Words", y="Words", color="Similarity"),
                x=listofwords,
                y=listofwords,
                color_continuous_scale="Viridis")

fig.update_layout(title="Similarity Heatmap")

fig.show()

In [14]:
# Weighted sum of the embeddings, using the softmax-normalised weights
Y2 = W2 @ X

# Plotting Y2 (one row per word, one column per embedding dimension)
# Row i of Y2 belongs to the i-th word of the sentence, so the row labels must follow
# sentence order (listofwords), not the alphabetically sorted vocab keys.
fig = px.imshow(Y2,
                labels=dict(x="Embedding dimension", y="Words", color="Attention Weights"),
                y=listofwords,
                color_continuous_scale="Viridis")

fig.update_layout(title="Attention Weights")

fig.show()

## Self-Attention (AKA Scaled Dot-Product Attention)

### Define the Weight Matrices for Self-Attention

In this case, we have three weight matrices $W^Q_{d \times d_q}$, $W^K_{d \times d_k}$ (with $d_q = d_k$), and $W^V_{d \times d_v}$. Here $d_i$, $i \in \{q, k, v\}$, is the output dimension of the corresponding projection.

$$
\begin{align}
Q_{n \times d_q} &:= X_{n \times d} \times W_{d \times d_q}^Q \\
K_{n \times d_k} &:= X_{n \times d} \times W_{d \times d_k}^K \\
V_{n \times d_v} &:= X_{n \times d} \times W_{d \times d_v}^V
\end{align}
$$

$$
\begin{align}
W_{n\times n} &:= Q_{n\times d_q} \times K_{d_k \times n}^T \\
W_{n\times n} &:= \text{softmax}\left(\frac{W_{n\times n}}{\sqrt{d_q}}\right) \\
Y_{n \times d_v} &:= W_{n\times n} \times V_{n \times d_v}
\end{align}
$$

In [15]:
# Output dimensions of the query, key and value projections
# d = 64
d_q = d_k = 32 # To satisfy the dimension requirements for multiplying Q & K
d_v = 48 # d_v can be arbitrary, though usually kept the same as for Q & K

# Projection matrices of shape (d, d_out), filled with uniform random values from
# [0, 1) and wrapped in Parameter so they track gradients. The three draws consume
# the RNG in this order, so reordering these lines changes the values.
W_Q = torch.nn.Parameter(torch.rand(d, d_q))
W_K = torch.nn.Parameter(torch.rand(d, d_k))
W_V = torch.nn.Parameter(torch.rand(d, d_v))

W_Q.shape, W_K.shape, W_V.shape

(torch.Size([64, 32]), torch.Size([64, 32]), torch.Size([64, 48]))

In [16]:
# Inspect the randomly initialised query projection matrix
W_Q

Parameter containing:
tensor([[0.0605, 0.2699, 0.3683,  ..., 0.3112, 0.9992, 0.9132],
        [0.0440, 0.0074, 0.2083,  ..., 0.8483, 0.9896, 0.1457],
        [0.3154, 0.6381, 0.6555,  ..., 0.6815, 0.6295, 0.5264],
        ...,
        [0.7914, 0.1441, 0.5487,  ..., 0.1384, 0.0780, 0.5441],
        [0.0409, 0.8966, 0.1534,  ..., 0.7153, 0.6619, 0.6170],
        [0.7900, 0.6138, 0.2550,  ..., 0.8380, 0.2511, 0.0179]],
       requires_grad=True)

In [17]:
# Project the embeddings into queries, keys and values
Q = torch.matmul(X, W_Q)
K = torch.matmul(X, W_K)
V = torch.matmul(X, W_V)

# Scaled dot-product attention: softmax(Q K^T / sqrt(d_q)) applied to V.
# Note: this rebinds W and Y, which held the vanilla-attention results until now.
W = F.softmax(torch.matmul(Q, K.T) / torch.sqrt(torch.tensor(d_q).float()), dim=-1)
Y = torch.matmul(W, V)

W.shape, Y.shape

(torch.Size([22, 22]), torch.Size([22, 48]))

In [18]:
# Plotting Similarity Heatmap - NOT TRAINED
# W depends on the Parameter matrices, so detach() it from autograd before plotting.
# Rows/columns of W follow sentence order, so label them with listofwords; the
# alphabetically sorted vocab keys would attach the wrong word to every row and column.
fig = px.imshow(W.detach(),
                labels=dict(x="Words", y="Words", color="Similarity"),
                x=listofwords,
                y=listofwords,
                color_continuous_scale="Viridis")

fig.update_layout(title="Similarity Heatmap (Untrained)")

fig.show()

In [19]:
# Plotting Y = W @ V - NOT TRAINED (one row per word, one column per value dimension)
# Y depends on the Parameter matrices, so detach() it from autograd before plotting.
# Row i of Y belongs to the i-th word of the sentence, so the row labels must follow
# sentence order (listofwords), not the alphabetically sorted vocab keys.
fig = px.imshow(Y.detach(),
                labels=dict(x="Values (V) Dimension", y="Words", color="Attention Weights"),
                y=listofwords,
                color_continuous_scale="Viridis")

fig.update_layout(title="Attention Weights (Untrained)")

fig.show()

### MultiHead Attention

In [20]:
# Number of heads
h = 4
# One projection matrix per head, stacked along the first axis. Each matrix is laid
# out as (output dim, d) - the transpose of the single-head W_Q / W_K / W_V layout -
# so it multiplies the input from the left (MW @ x).
# The three draws consume the RNG in this order, so reordering changes the values.
MW_Q = torch.nn.Parameter(torch.rand(h, d_q, d))
MW_K = torch.nn.Parameter(torch.rand(h, d_k, d))
MW_V = torch.nn.Parameter(torch.rand(h, d_v, d))

In [21]:
# Pick a single token's embedding (row index 1) and display it with its shape
x_1 = embedded_sentence[1] # Second word's embedding
x_1, x_1.shape

(tensor([-0.2582, -2.0407, -0.8016, -0.8183, -1.1820, -0.2877, -0.6043,  0.6002,
         -1.4053, -0.5922, -0.2548,  1.1517, -0.0179,  0.4264, -0.7657, -0.0545,
         -1.2743,  0.4513, -0.2280,  0.9224,  0.2056, -0.4970,  0.5821,  0.2053,
         -0.3018, -0.6703, -0.6171, -0.8334,  0.4839, -0.1349,  0.2119, -0.8714,
          0.6851,  2.0024, -0.5469,  1.6014, -2.2577, -1.8009,  0.7015,  0.5703,
         -1.1766, -2.0524,  0.1132,  1.4353,  0.0883, -1.2037,  1.0964,  2.4210,
          0.1538, -0.4452,  0.5503,  0.0658,  0.6805,  1.2064,  1.6250,  0.3459,
          0.1343,  0.7662,  2.2760, -1.3255, -0.8970,  0.1132,  0.8365,  0.0285]),
 torch.Size([64]))

In [22]:
# Query vector of the second word under every head:
# (h, d_q, d) @ (d,) -> (h, d_q)
mq2 = torch.matmul(MW_Q, x_1)
mq2.shape # h=4, d_q=32

torch.Size([4, 32])

In [23]:
# To compute for all the heads
# To compute for all the heads: transpose the embeddings to (d, n), add a leading
# head axis and copy it h times -> (h, d, n)
stacked_inputs = embedded_sentence.T.unsqueeze(0).repeat(h, 1, 1)
stacked_inputs, stacked_inputs.shape

(tensor([[[-1.1065, -0.2582,  0.3277,  ..., -1.0693, -1.4284,  0.5146],
          [ 1.2682, -2.0407, -0.8331,  ...,  0.4660,  0.5617,  0.9938],
          [ 0.3147, -0.8016, -1.6179,  ...,  0.7012,  0.7701, -0.2587],
          ...,
          [ 0.4466,  0.1132, -1.7984,  ...,  2.7196, -1.3574,  1.2774],
          [-0.8970,  0.8365, -0.6822,  ...,  0.4816, -1.1745, -1.4596],
          [ 0.1009,  0.0285, -0.5191,  ...,  0.2409, -0.5126, -2.1595]],
 
         [[-1.1065, -0.2582,  0.3277,  ..., -1.0693, -1.4284,  0.5146],
          [ 1.2682, -2.0407, -0.8331,  ...,  0.4660,  0.5617,  0.9938],
          [ 0.3147, -0.8016, -1.6179,  ...,  0.7012,  0.7701, -0.2587],
          ...,
          [ 0.4466,  0.1132, -1.7984,  ...,  2.7196, -1.3574,  1.2774],
          [-0.8970,  0.8365, -0.6822,  ...,  0.4816, -1.1745, -1.4596],
          [ 0.1009,  0.0285, -0.5191,  ...,  0.2409, -0.5126, -2.1595]],
 
         [[-1.1065, -0.2582,  0.3277,  ..., -1.0693, -1.4284,  0.5146],
          [ 1.2682, -2.0407,

In [24]:
# Batched matrix product, one per head: (h, d_out, d) @ (h, d, n) -> (h, d_out, n)
M_Q = MW_Q @ stacked_inputs
M_K = MW_K @ stacked_inputs
M_V = MW_V @ stacked_inputs

M_Q.shape, M_K.shape, M_V.shape

(torch.Size([4, 32, 22]), torch.Size([4, 32, 22]), torch.Size([4, 48, 22]))

In [25]:
# Swap the last two axes so each head holds one row per word: (h, d_out, n) -> (h, n, d_out)
# NOTE: the names are rebound in place, so re-running this cell swaps the axes back.
M_Q = M_Q.transpose(1, 2)
M_K = M_K.transpose(1, 2)
M_V = M_V.transpose(1, 2)

M_Q.shape, M_K.shape, M_V.shape

(torch.Size([4, 22, 32]), torch.Size([4, 22, 32]), torch.Size([4, 22, 48]))

### Try implementing the attention weight calculation for MHA

In [26]:
# your code here

### Bonus: SoftMax Function

In [27]:
import numpy as np

def softmax(x, axis=0):
    """Numerically stable softmax of ``x`` along ``axis``.

    Parameters
    ----------
    x : array_like
        Input scores.
    axis : int, optional
        Axis along which the probabilities are normalised. Defaults to 0, which
        matches the previous behaviour; for 1-D input it is the only axis.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtract the maximum taken along the normalised axis, not the global maximum:
    # this keeps the largest exponent in every slice at exp(0) = 1, so no slice can
    # underflow to all zeros and produce 0 / 0 = NaN.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))

    # keepdims makes the denominator broadcast correctly for any axis
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Evaluate softmax on an evenly spaced grid of inputs in [-5, 5) with step 0.1
x_values = np.arange(-5, 5, 0.1)
y_values = softmax(x_values) # Try both functions
# y_values = F.softmax(torch.Tensor(x_values), dim=-1)


# Line plot of the resulting probabilities against the inputs
fig = px.line(
    x=x_values,
    y=y_values,
    title='Softmax Function',
    labels={'x': 'Input Values', 'y': 'Probability'},
    line_shape='linear',
)

# Axis titles first, then legend and grid cosmetics
fig.update_xaxes(title_text='Input Values')
fig.update_yaxes(title_text='Probability')
fig.update_layout(
    showlegend=False,
    xaxis={'showgrid': True, 'zeroline': False},
    yaxis={'showgrid': True, 'zeroline': False},
)

fig.show()