In [1]:
import torch
import torch.nn as nn

In [2]:
# First I'll implement the self-attention block then i'll implement the Multi-head-attention as a wrapper on top of a self-attention
# I'll use B = batch_dim, d_model = number of embedding dimensions, T = input sequence length, d_k, d_v as dimensions of key and values matrix

# Symbols
# X --> Input embeddings = embeddings + positional embeddings --> (B, T, d_model)
# W_Q --> Query projection --> (d_model, h * d_k) [assuming h number of heads]
# W_K --> Key projection --> (d_model, h * d_k) [assuming h number of heads]
# W_V --> Value projection --> (d_model, h * d_v) [assuming h number of heads]
# Q, K, V --> (B, h, d_model)


In [136]:
class DummySelfAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, debugger: bool = False):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.debugger = debugger

        self.W_Q = torch.nn.Linear(d_model, d_k, bias = False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias = False)
        self.W_V = torch.nn.Linear(d_model, d_v, bias = False)
    
    def forward(self, X: torch.Tensor):
        B, T, d_model = X.shape
        assert self.d_model == d_model, "self.d_model must be equal to d_model from x.shape"
        # X.shape = (B, T, d_model)
        Q = self.W_Q(X) # Q.shape = (B, T, d_k)
        K = self.W_K(X) # K.shape = (B, T, d_k)
        V = self.W_V(X) # V.shape = (B, T, d_v)
        S = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5) # S.shape = (B, T, T)
        A = torch.softmax(S, dim = -1)
        Z = A @ V # Z.shape = (B, T, d_v)
        if self.debugger:
            print(f"X.shape  = (B, T, d_model) = {X.shape}")
            print(f"B = {B}")
            print(f"T = {T}")
            print(f"d_model = {d_model}")
            print("-" * 100)
            print(f"Q.shape = {Q.shape}")
            print(f"K.shape = {K.shape}")
            print(f"V.shape = {V.shape}")
            print(f"S = Q @ K.T")
            print(f"A = softmax(S) | Attention must normalize over keys.")
            print(f"Z = A @ V")
            print(f"so, Z = (softmax(Q @ K.T)/d_k ** 0.5) @ V")
            print(A.sum(dim = -1))
            print(f"S.shape = {S.shape}")
            print(f"A.shape = {A.shape}")
            print(f"Z.shape = {Z.shape}")
        
        return Z
        

In [139]:
dummy_d_model = 4
dummy_d_k = 3
dummy_d_v = 5
dummy_seq_length = 20
dummy_batch_size = 1
X = torch.rand(dummy_batch_size, dummy_seq_length, dummy_d_model)
dummy_sf = DummySelfAttention(d_model = dummy_d_model, d_k = dummy_d_k, d_v = dummy_d_v, debugger = True)
dummy_sf(X)

X.shape  = (B, T, d_model) = torch.Size([1, 20, 4])
B = 1
T = 20
d_model = 4
----------------------------------------------------------------------------------------------------
Q.shape = torch.Size([1, 20, 3])
K.shape = torch.Size([1, 20, 3])
V.shape = torch.Size([1, 20, 5])
S = Q @ K.T
A = softmax(S) | Attention must normalize over keys.
Z = A @ V
so, Z = (softmax(Q @ K.T)/d_k ** 0.5) @ V
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000]], grad_fn=<SumBackward1>)
S.shape = torch.Size([1, 20, 20])
A.shape = torch.Size([1, 20, 20])
Z.shape = torch.Size([1, 20, 5])


tensor([[[-0.5317,  0.2370, -0.7122,  0.0348, -0.0190],
         [-0.5251,  0.2359, -0.7033,  0.0320, -0.0138],
         [-0.5250,  0.2336, -0.7029,  0.0324, -0.0150],
         [-0.5316,  0.2372, -0.7121,  0.0347, -0.0188],
         [-0.5253,  0.2363, -0.7036,  0.0320, -0.0137],
         [-0.5183,  0.2320, -0.6938,  0.0299, -0.0104],
         [-0.5209,  0.2328, -0.6973,  0.0308, -0.0120],
         [-0.5212,  0.2346, -0.6980,  0.0305, -0.0112],
         [-0.5255,  0.2347, -0.7037,  0.0324, -0.0149],
         [-0.5253,  0.2362, -0.7036,  0.0321, -0.0139],
         [-0.5228,  0.2332, -0.7000,  0.0316, -0.0135],
         [-0.5259,  0.2351, -0.7043,  0.0324, -0.0148],
         [-0.5277,  0.2359, -0.7067,  0.0332, -0.0161],
         [-0.5219,  0.2332, -0.6987,  0.0311, -0.0126],
         [-0.5245,  0.2354, -0.7023,  0.0319, -0.0136],
         [-0.5251,  0.2342, -0.7031,  0.0325, -0.0150],
         [-0.5239,  0.2350, -0.7016,  0.0317, -0.0133],
         [-0.5300,  0.2374, -0.7100,  0.0339, -0

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, d_model:int , d_k: int, d_v: int):
        """
        > Implements single-head scaled dot-product self-attention.

        > This class is only responsible for Self-Attention Mechanism and not for transforming the Attention's Output's Value space to d_model's space
        > Explanation: Self-Attention's output lives in d_v vector space,
                       this class is not responsible for projecting output back to d_model's vector space
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        # Initializing Learned Linear Projection Matrix
        # The original Transformer paper uses no bias for Q/K/V, so I'm also not using them by setting bias = False
        self.W_Q = nn.Linear(d_model, d_k, bias = False)
        self.W_K = nn.Linear(d_model, d_k, bias = False)
        self.W_V = nn.Linear(d_model, d_v, bias = False)
    
    def forward(self, x):
        _, _, d_model = x.shape # x.shape = (B, T, d_model)
        assert self.d_model == d_model, "self.d_model must be equal to d_model from x.shape"
        # Calculate Query, Key and Value Matrix
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Get the Updated Information and update each token's representation
        S = (Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)) # to ensure that standard deviation is consistent after transformation 
        A = torch.softmax(S, dim = -1)
        Z = A @ V # Right now, Z.shape != x.shape because attention outputs are in value space we project Z back into,
                  # d_model so residual connections and multiple Transformer layers can be stacked
        return Z

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h: int = 8):
        """
        This class is responsible for Performing the Multi-head-Attention and,
        this will also project back into d_model with correct basis vector system
        the Architecture's Output is also in d_model's vector space but in wrong basis vector because of concatenation
        so we need to Project the output back to d_model to fix the vecto's basis system
        """
        super().__init__()
        assert d_model % h == 0, "self.d_model must be a multiple of h"
        self.d_model = d_model
        self.h = h
        self.d_k = self.d_model // self.h
        self.d_v = self.d_model // self.h

        self.W_Q = nn.Linear(self.d_model, self.h * self.d_k, bias = False)
        self.W_K = nn.Linear(self.d_model, self.h * self.d_k, bias = False)
        self.W_V = nn.Linear(self.d_model, self.h * self.d_v, bias = False)
        self.W_O = nn.Linear(self.d_model, self.d_model, bias = False) # This is the output Projection 
        # The output of This Multi-head Attention is in d_model space we need an output projection layer so that it must learn,
        # the importance of Each head and how each head are relevent to each other
        # It is required to enable Stacking, Residual Conncetions
        # mathematically,
                # The Dimensions are numerically equal but the basis systems of vectors are different
                # changes the basis,  not the dimensionality.
    
    def forward(self, x):
        B, T, _ = x.shape
        Q = self.W_Q(x).view(B, T, self.h, self.d_k).transpose(1, 2) # We split this and transpose this --> shape: [B, T, h, d_k]
        K = self.W_K(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        # We split this and transpose this
        # After Splitting --> shape: [B, T, h, d_k]
        # Now Transpose: --> shape: [B, h, T, d_k] --> each head now have its' completely own sequence
        print("Q.shape", Q.shape)
        print("K.shape", K.shape)
        print("V.shape", V.shape)

        S = (Q @ K.transpose(-1, -2))/(self.d_k ** 0.5) # or we can do, K.transpose(-2, -1)
        S = torch.softmax(S, dim = -1)
        
        # Update the information of each of the tokens for each of the heads
        # Z.shape = (B, h, T, d_v)
        Z = (S @ V) # This will update the information and this will project into the d_v = d_k = d_model/h dimentional space

        # Now we concatenate this to get the Representation into our embedding space i.e. d_model's Vector Space
        print("Before concatenation: output.Shape = ", Z.shape)
        Z = Z.transpose(1, 2).contiguous().view(B, T, self.d_model)
        Z = self.W_O(Z)
        print("After concatenation: output.Shape = ", Z.shape)
        # we want to view this as, (B, T, d_model)
        return Z

In [209]:
MultiHeadAttention(d_model = 64, h = 8)(torch.rand(2, 10, 64))

Q.shape torch.Size([2, 8, 10, 8])
K.shape torch.Size([2, 8, 10, 8])
V.shape torch.Size([2, 8, 10, 8])
Before concatenation: output.Shape =  torch.Size([2, 8, 10, 8])
After concatenation: output.Shape =  torch.Size([2, 10, 64])


tensor([[[-0.0062, -0.1860,  0.0581,  ..., -0.0554,  0.4067,  0.1713],
         [-0.0062, -0.1853,  0.0566,  ..., -0.0565,  0.4077,  0.1714],
         [-0.0083, -0.1852,  0.0561,  ..., -0.0572,  0.4082,  0.1701],
         ...,
         [-0.0078, -0.1853,  0.0570,  ..., -0.0569,  0.4103,  0.1712],
         [-0.0055, -0.1842,  0.0575,  ..., -0.0578,  0.4099,  0.1710],
         [-0.0065, -0.1850,  0.0569,  ..., -0.0566,  0.4092,  0.1714]],

        [[ 0.0648, -0.1600,  0.0764,  ...,  0.0227,  0.3762,  0.1466],
         [ 0.0635, -0.1603,  0.0777,  ...,  0.0238,  0.3758,  0.1477],
         [ 0.0654, -0.1598,  0.0779,  ...,  0.0246,  0.3755,  0.1452],
         ...,
         [ 0.0639, -0.1601,  0.0773,  ...,  0.0246,  0.3746,  0.1472],
         [ 0.0646, -0.1603,  0.0773,  ...,  0.0242,  0.3767,  0.1477],
         [ 0.0644, -0.1595,  0.0768,  ...,  0.0244,  0.3746,  0.1464]]],
       grad_fn=<UnsafeViewBackward0>)

torch.Size([2, 10, 64]) --> we have 2 examples of each 10, 64 --> meaning 2 examples having seq_length = 10 and each char is represented as 64-Dimensional vector

torch.Size([2, 10, 8, 8]) --> we have 2 examples of each containing seq_length = 10 having 8 different heads and each heads are 8-Dimensional vectors

torch.Size([2, 8, 10, 8]) ---> we have 2 examples of each containing 8 heads and each head will be seq_length = 10 long and each will opearate on 8-D vectors

In [161]:
SelfAttention(d_model = 4, d_k = 3, d_v = 5)(X)

tensor([[[-0.2190,  0.2350,  0.4183,  0.1907,  0.0375],
         [-0.2196,  0.2394,  0.4180,  0.1888,  0.0397],
         [-0.2183,  0.2350,  0.4203,  0.1861,  0.0359],
         [-0.2192,  0.2357,  0.4179,  0.1913,  0.0382],
         [-0.2197,  0.2405,  0.4182,  0.1876,  0.0400],
         [-0.2191,  0.2363,  0.4186,  0.1894,  0.0379],
         [-0.2189,  0.2374,  0.4195,  0.1865,  0.0376],
         [-0.2196,  0.2393,  0.4181,  0.1886,  0.0396],
         [-0.2187,  0.2351,  0.4191,  0.1888,  0.0369],
         [-0.2200,  0.2389,  0.4169,  0.1921,  0.0405],
         [-0.2190,  0.2368,  0.4192,  0.1877,  0.0377],
         [-0.2188,  0.2393,  0.4203,  0.1833,  0.0377],
         [-0.2193,  0.2367,  0.4181,  0.1902,  0.0384],
         [-0.2189,  0.2370,  0.4193,  0.1871,  0.0376],
         [-0.2197,  0.2371,  0.4172,  0.1924,  0.0394],
         [-0.2190,  0.2355,  0.4186,  0.1898,  0.0375],
         [-0.2195,  0.2386,  0.4183,  0.1887,  0.0392],
         [-0.2193,  0.2387,  0.4187,  0.1874,  0