In [2]:
import torch
import torch.nn.functional as F

In [3]:


# Toy embeddings for the sentence "Your journey starts with one step":
# one row per token, three features per token.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: "Your"
        [0.55, 0.87, 0.66],  # x^2: "journey"
        [0.57, 0.85, 0.64],  # x^3: "starts"
        [0.22, 0.58, 0.33],  # x^4: "with"
        [0.77, 0.25, 0.10],  # x^5: "one"
        [0.05, 0.80, 0.55],  # x^6: "step"
    ]
)

In [4]:
# Number of tokens (rows) and embedding width (columns) of the toy input.
seq_len, d_model = inputs.size()

In [5]:
d_k = 2  # width of the query/key/value projections

# Draw the initial weights from a dedicated, seeded generator so the
# projection matrices (and everything printed from them) are identical after
# a kernel restart, without altering the global RNG state.
_init_gen = torch.Generator().manual_seed(123)

# Each matrix maps a d_model-wide embedding to a d_k-wide vector.
W_q = torch.nn.Parameter(torch.randn(d_model, d_k, generator=_init_gen))
W_k = torch.nn.Parameter(torch.randn(d_model, d_k, generator=_init_gen))
W_v = torch.nn.Parameter(torch.randn(d_model, d_k, generator=_init_gen))

In [6]:
def self_attention(X, W_q, W_k, W_v):
  """Single-head causal scaled dot-product self-attention.

  Args:
    X: Input embeddings of shape (seq_len, d_model); any number of leading
      batch dimensions is also accepted, i.e. (..., seq_len, d_model).
    W_q: Query projection matrix of shape (d_model, d_k).
    W_k: Key projection matrix of shape (d_model, d_k).
    W_v: Value projection matrix of shape (d_model, d_v).

  Returns:
    A tuple (output, attn_weights):
      output: attended values, shape (..., seq_len, d_v).
      attn_weights: softmax-normalised attention weights, shape
        (..., seq_len, seq_len). Entries above the diagonal are zero, so a
        position never attends to later positions.
  """
  seq_len = X.shape[-2]
  d_k = W_q.shape[1]

  Q = X @ W_q
  K = X @ W_k
  V = X @ W_v

  # Swap only the last two axes so leading batch dimensions are preserved
  # (`.T` would reverse every axis of a batched tensor).
  scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)

  # Strictly-upper-triangular mask marks the "future" positions. It is created
  # on the same device as the scores so masked_fill also works off-CPU, and it
  # broadcasts over any leading batch dimensions.
  mask = torch.triu(
    torch.ones(seq_len, seq_len, device=scores.device), diagonal=1
  ).bool()
  scores = scores.masked_fill(mask, float('-inf'))

  # -inf scores become exactly zero weight after the softmax.
  attn_weights = F.softmax(scores, dim=-1)
  output = attn_weights @ V

  return output, attn_weights

In [7]:
# Apply the functional attention to the toy sentence.
output, attn_weighs = self_attention(X=inputs, W_q=W_q, W_k=W_k, W_v=W_v)

# One tensor per line: the attended outputs first, then the attention weights.
print(output, attn_weighs, sep="\n")

tensor([[0.4509, 0.6372],
        [0.3008, 1.3833],
        [0.2660, 1.6186],
        [0.1883, 1.4945],
        [0.3701, 1.4193],
        [0.2607, 1.3759]], grad_fn=<MmBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4744, 0.5256, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3029, 0.3433, 0.3539, 0.0000, 0.0000, 0.0000],
        [0.2553, 0.2492, 0.2529, 0.2426, 0.0000, 0.0000],
        [0.1134, 0.1870, 0.1915, 0.2028, 0.3053, 0.0000],
        [0.2000, 0.1578, 0.1601, 0.1462, 0.2088, 0.1271]],
       grad_fn=<SoftmaxBackward0>)


In [8]:
# Implementing trainable weights with backprop

In [9]:
import torch.nn as nn

In [14]:
class SelfAttention(nn.Module):
  """Single-head causal self-attention with trainable Q/K/V projections.

  Args:
    d_model: Width of the input embeddings.
    d_k: Width of the query, key and value projections (and of the output).
  """

  def __init__(self, d_model, d_k):
    super().__init__()
    # Keep d_k on the instance: forward() must not depend on a notebook global.
    self.d_k = d_k
    self.W_q = nn.Linear(d_model, d_k, bias=False)
    self.W_k = nn.Linear(d_model, d_k, bias=False)
    self.W_v = nn.Linear(d_model, d_k, bias=False)

  def forward(self, X):
    """Compute causal attention over X.

    Args:
      X: Tensor of shape (seq_len, d_model); leading batch dimensions are
        also accepted, i.e. (..., seq_len, d_model).

    Returns:
      A tuple (output, attn_weights) with shapes (..., seq_len, d_k) and
      (..., seq_len, seq_len). Weights above the diagonal are zero.
    """
    seq_len = X.shape[-2]

    Q = self.W_q(X)
    K = self.W_k(X)
    V = self.W_v(X)

    # Transpose only the last two axes so batched input keeps its batch axes.
    scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)

    # Mask out future positions; build the mask on the scores' device so the
    # module keeps working after being moved off the CPU.
    mask = torch.triu(
      torch.ones(seq_len, seq_len, device=scores.device), diagonal=1
    ).bool()
    scores = scores.masked_fill(mask, float('-inf'))

    attn_weights = F.softmax(scores, dim=-1)
    output = attn_weights @ V

    return output, attn_weights



In [15]:
# Hyperparameters for the trainable attention module.
d_model, d_k = 3, 2
attn = SelfAttention(d_model=d_model, d_k=d_k)

# Regression target: an independent, gradient-free copy of the inputs.
target = inputs.detach().clone()

In [16]:
# Adam over every trainable parameter of the attention module.
optimizer = torch.optim.Adam(params=attn.parameters(), lr=1e-2)

In [21]:
# Run a few Adam steps that regress the attention output onto the leading
# columns of `target`.
for step in range(10):
  optimizer.zero_grad()  # drop gradients left over from the previous step

  output, _ = attn(inputs)
  # Slice the target to the output's width rather than hard-coding 2, so the
  # loss stays shape-consistent if the projection width is changed.
  loss = F.mse_loss(output, target[:, :output.shape[-1]])

  loss.backward()
  optimizer.step()



In [22]:
# Multi-head attention

In [28]:
class MultiHeadSelfAttention(nn.Module):
  """Causal multi-head self-attention with a fused Q/K/V projection.

  Args:
    d_model: Width of the input (and output) embeddings.
    num_heads: Number of attention heads; must divide d_model evenly.

  Raises:
    AssertionError: If d_model is not divisible by num_heads.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()

    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_head = d_model // num_heads
    # A single projection yields Q, K and V side by side, hence 3 * d_model.
    self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
    self.out_proj = nn.Linear(d_model, d_model, bias=False)

  def forward(self, X):
    """Apply causal multi-head attention to X.

    Args:
      X: Tensor of shape (seq_len, d_model); leading batch dimensions are
        also accepted, i.e. (..., seq_len, d_model).

    Returns:
      Tensor with the same shape as X.
    """
    *lead, seq_len, _ = X.shape
    per_head_shape = (*lead, seq_len, self.num_heads, self.d_head)

    # (..., seq_len, 3 * d_model) -> three (..., seq_len, d_model) tensors.
    Q, K, V = self.qkv_proj(X).chunk(3, dim=-1)

    # Split d_model into heads and move the head axis in front of the
    # sequence axis: (..., num_heads, seq_len, d_head).
    Q = Q.reshape(per_head_shape).transpose(-3, -2)
    K = K.reshape(per_head_shape).transpose(-3, -2)
    V = V.reshape(per_head_shape).transpose(-3, -2)

    scores = (Q @ K.transpose(-2, -1)) / (self.d_head ** 0.5)

    # Future positions are masked. The mask lives on the scores' device so the
    # module works after being moved off the CPU, and it broadcasts over the
    # head and batch axes.
    mask = torch.triu(
      torch.ones(seq_len, seq_len, device=scores.device), diagonal=1
    ).bool()
    scores = scores.masked_fill(mask, float('-inf'))

    attn = F.softmax(scores, dim=-1)
    out = attn @ V

    # Put the heads back next to their features and merge them into d_model.
    out = out.transpose(-3, -2).reshape(*lead, seq_len, self.d_model)

    return self.out_proj(out)

In [29]:
# Toy input: six token embeddings with three features each
X = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ],
    dtype=torch.float32,
)

# Single-head model; num_heads has to divide d_model evenly
d_model, num_heads = 3, 1
mhsa = MultiHeadSelfAttention(d_model=d_model, num_heads=num_heads)

# Run the model once and report the shape of the result
output = mhsa(X)
print(output.size())  # (6, 3)


torch.Size([6, 3])
