In [2]:
import re
import numpy as np
from gensim.models import Word2Vec

text = "The animal did not cross the street because it was tired."
# Lower-cased word tokens; \w+ drops the trailing period.
tokens = re.findall(r"\w+", text.lower())

# Word2Vec takes an iterable of tokenised sentences; the whole corpus is this one sentence.
sentences = [tokens]

# NOTE(review): trained on a single short sentence, these vectors are presumably
# close to their random initialisation -- fine for demonstrating attention
# mechanics, not for reading meaning into the scores.
word2vec_model = Word2Vec(
    sentences,
    vector_size=64,   # embedding width; becomes the model width of the attention layer
    window=3,
    min_count=1,      # keep every token -- the corpus is tiny
    sg=1,             # skip-gram
    seed=42,
    workers=1,        # gensim needs a single worker thread for a deterministic run
)

X = np.array([word2vec_model.wv[word] for word in tokens])  # (seq_len, d_model)

def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encoding.

    Column 2k holds sin(pos / 10000**(2k / d_model)) and column 2k+1 holds
    cos(pos / 10000**(2k / d_model)), for every position ``pos`` in
    ``range(seq_len)``.

    Parameters
    ----------
    seq_len : int
        Number of positions (rows).
    d_model : int
        Encoding width (columns). An odd width is allowed; the last column is
        then a sine with no cosine partner.

    Returns
    -------
    np.ndarray
        Float64 array of shape (seq_len, d_model).
    """
    positions = np.arange(seq_len, dtype=np.float64)[:, None]   # (seq_len, 1)
    even_cols = np.arange(0, d_model, 2)                         # 0, 2, 4, ...
    angles = positions / (10000 ** (even_cols / d_model))        # (seq_len, ceil(d_model / 2))

    PE = np.zeros((seq_len, d_model))
    PE[:, 0::2] = np.sin(angles)
    # With an odd d_model there is one fewer odd column than even column.
    PE[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return PE

# Add sinusoidal position information. The width is read from the embeddings
# so it always matches Word2Vec's vector_size instead of repeating the literal.
X = X + positional_encoding(len(tokens), X.shape[1])

In [9]:
# One row per token, one column per embedding dimension.
assert X.shape == (len(tokens), word2vec_model.vector_size)
X.shape

(11, 64)

In [3]:
# Model width comes from the embeddings so it cannot drift from Word2Vec's vector_size.
d_model = X.shape[1]
num_heads = 8
# `//` below would silently truncate, and the head split would then fail later.
assert d_model % num_heads == 0, f"d_model={d_model} must be divisible by num_heads={num_heads}"

d_k = d_v = d_model // num_heads   # per-head query/key width and value width
seq_len = X.shape[0]

In [6]:
# Seeded generator: with the unseeded global RNG every run drew new weights,
# so the attention numbers printed further down could not be reproduced.
rng = np.random.default_rng(42)

# Standard-normal entries scaled by 1/sqrt(fan_in), where fan_in = d_model.
init_scale = np.sqrt(d_model)
W_Q = rng.standard_normal((d_model, d_model)) / init_scale
W_K = rng.standard_normal((d_model, d_model)) / init_scale
W_V = rng.standard_normal((d_model, d_model)) / init_scale

# Output projection, applied after the heads are concatenated.
W_O = rng.standard_normal((d_model, d_model)) / init_scale

In [7]:
# Project the same (seq_len, d_model) inputs three ways: queries, keys, values.
Q = X @ W_Q
K = X @ W_K
V = X @ W_V

In [None]:
def split_heads(m, num_heads):
    """Reshape a (seq_len, width) matrix into (num_heads, seq_len, width // num_heads).

    Column block h of ``m`` (columns h*w .. (h+1)*w - 1, w = width // num_heads)
    becomes head h.
    """
    rows, width = m.shape
    return m.reshape(rows, num_heads, width // num_heads).transpose(1, 0, 2)


# This cell overwrites Q/K/V in place. Re-running it on already-split arrays
# would not raise (the element count is unchanged) but would silently scramble
# the heads, so insist on the 2-D projections from the previous cell.
assert Q.ndim == K.ndim == V.ndim == 2, "re-run the Q/K/V projection cell before splitting heads"
Q = split_heads(Q, num_heads)   # (num_heads, seq_len, d_k)
K = split_heads(K, num_heads)   # (num_heads, seq_len, d_k)
V = split_heads(V, num_heads)   # (num_heads, seq_len, d_v)

In [13]:
def softmax(x):
    """Softmax over the last axis.

    The row-wise maximum is subtracted before exponentiating. This does not
    change the result, but it keeps ``np.exp`` from overflowing on large inputs.
    """
    exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
    totals = np.sum(exps, axis=-1, keepdims=True)
    return exps / totals

In [24]:
# Attention per head: softmax(Q_h K_h^T / sqrt(d_k)) V_h, computed head by head.
score_scale = np.sqrt(d_k)

# (num_heads, seq_len, seq_len): row i of head h is how query i spreads over the keys.
attention_weights = np.array([
    softmax(np.matmul(Q[h], K[h].T) / score_scale)
    for h in range(num_heads)
])

# One (seq_len, d_v) weighted sum of values per head, kept as a list for concatenation.
heads = [np.matmul(attention_weights[h], V[h]) for h in range(num_heads)]

In [25]:
# Every head maps the whole sequence to a d_v-wide slice; check them all, show one.
assert all(head.shape == (seq_len, d_v) for head in heads), "each head should be (seq_len, d_v)"
heads[0].shape

(11, 8)

In [26]:
# Lay the per-head outputs side by side: head h fills columns h*d_v .. (h+1)*d_v - 1.
concat = np.hstack(heads)

In [27]:
# Concatenating num_heads slices of width d_v restores the full model width.
assert concat.shape == (seq_len, num_heads * d_v)
concat.shape

(11, 64)

In [28]:
# Mix the concatenated head outputs back through the output projection.
output = concat @ W_O   # (seq_len, d_model)

In [22]:
# Multi-head attention preserves the (seq_len, d_model) shape of its input.
assert output.shape == (seq_len, d_model)
output.shape

(11, 64)

In [29]:
# One seq_len x seq_len matrix per head; each row came out of softmax, so it must sum to 1.
assert attention_weights.shape == (num_heads, seq_len, seq_len)
assert np.allclose(attention_weights.sum(axis=-1), 1.0), "attention rows should be normalised"
attention_weights.shape

(8, 11, 11)

In [30]:
# Positions (first occurrence) of the pronoun and its intended antecedent.
idx_it, idx_animal = tokens.index("it"), tokens.index("animal")

# Weight from the "it" query to the "animal" key, one value per head.
it_to_animal = attention_weights[:, idx_it, idx_animal]
for h, weight in enumerate(it_to_animal):
    print(f"Head {h}: it → animal = {weight}")


Head 0: it → animal = 0.10873981931134619
Head 1: it → animal = 0.08316267920094819
Head 2: it → animal = 0.10025777849543988
Head 3: it → animal = 0.06540342354556555
Head 4: it → animal = 0.06514002725719846
Head 5: it → animal = 0.08499400186878146
Head 6: it → animal = 0.12383233719946195
Head 7: it → animal = 0.1018221892779414


In [31]:
top_k = 5  # number of keys to list per head

for h in range(num_heads):
    row = attention_weights[h, idx_it]      # weights from the "it" query to every key
    top = np.argsort(row)[::-1][:top_k]     # key positions, highest weight first
    print(f"\nHead {h} top attention for 'it':")
    for i in top:
        # "the" occurs twice in the sentence, so tag each key with its position.
        print(f"{tokens[i]}[{i}]", row[i])



Head 0 top attention for 'it':
the 0.13772885023266387
because 0.11307310648916988
animal 0.10873981931134619
street 0.10566136931353416
it 0.09870851230010597

Head 1 top attention for 'it':
the 0.11290876025590166
cross 0.10897842499740756
street 0.10232231752265228
not 0.09477035922187571
the 0.08795316236034434

Head 2 top attention for 'it':
not 0.12133748410945817
did 0.11548234694473876
cross 0.11414343649127572
animal 0.10025777849543988
the 0.09801103478954434

Head 3 top attention for 'it':
it 0.15088632208246516
because 0.13793642197045644
was 0.13356362087063622
tired 0.10935559517151136
street 0.10238877138965315

Head 4 top attention for 'it':
tired 0.11149779657639766
the 0.10503659010827242
was 0.10457976283216136
street 0.10219177253429736
because 0.09888239090563967

Head 5 top attention for 'it':
it 0.11082552437357568
because 0.10339021745940916
did 0.10338384097839297
not 0.10197611642765242
was 0.0959354275705214

Head 6 top attention for 'it':
the 0.135149303166