In [1]:
import torch
import torch.nn as nn

In [2]:
# First I'll implement the self-attention block then i'll implement the Multi-head-attention as a wrapper on top of a self-attention
# I'll use B = batch_dim, d_model = number of embedding dimensions, T = input sequence length, d_k, d_v as dimensions of key and values matrix

# Symbols
# X --> Input embeddings = embeddings + positional embeddings --> (B, T, d_model)
# W_Q --> Query projection --> (d_model, h * d_k) [assuming h number of heads]
# W_K --> Key projection --> (d_model, h * d_k) [assuming h number of heads]
# W_V --> Value projection --> (d_model, h * d_v) [assuming h number of heads]
# Q, K, V --> (B, h, d_model)


In [3]:
class DummySelfAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, debugger: bool = False):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.debugger = debugger

        self.W_Q = torch.nn.Linear(d_model, d_k, bias = False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias = False)
        self.W_V = torch.nn.Linear(d_model, d_v, bias = False)
    
    def forward(self, X: torch.Tensor):
        B, T, d_model = X.shape
        assert self.d_model == d_model, "self.d_model must be equal to d_model from x.shape"
        # X.shape = (B, T, d_model)
        Q = self.W_Q(X) # Q.shape = (B, T, d_k)
        K = self.W_K(X) # K.shape = (B, T, d_k)
        V = self.W_V(X) # V.shape = (B, T, d_v)
        S = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5) # S.shape = (B, T, T)
        A = torch.softmax(S, dim = -1)
        Z = A @ V # Z.shape = (B, T, d_v)
        if self.debugger:
            print(f"X.shape  = (B, T, d_model) = {X.shape}")
            print(f"B = {B}")
            print(f"T = {T}")
            print(f"d_model = {d_model}")
            print("-" * 100)
            print(f"Q.shape = {Q.shape}")
            print(f"K.shape = {K.shape}")
            print(f"V.shape = {V.shape}")
            print(f"S = Q @ K.T")
            print(f"A = softmax(S) | Attention must normalize over keys.")
            print(f"Z = A @ V")
            print(f"so, Z = (softmax(Q @ K.T)/d_k ** 0.5) @ V")
            print(A.sum(dim = -1))
            print(f"S.shape = {S.shape}")
            print(f"A.shape = {A.shape}")
            print(f"Z.shape = {Z.shape}")
        
        return Z
        

In [4]:
dummy_d_model = 4
dummy_d_k = 3
dummy_d_v = 5
dummy_seq_length = 20
dummy_batch_size = 1
X = torch.rand(dummy_batch_size, dummy_seq_length, dummy_d_model)
dummy_sf = DummySelfAttention(d_model = dummy_d_model, d_k = dummy_d_k, d_v = dummy_d_v, debugger = True)
dummy_sf(X)

X.shape  = (B, T, d_model) = torch.Size([1, 20, 4])
B = 1
T = 20
d_model = 4
----------------------------------------------------------------------------------------------------
Q.shape = torch.Size([1, 20, 3])
K.shape = torch.Size([1, 20, 3])
V.shape = torch.Size([1, 20, 5])
S = Q @ K.T
A = softmax(S) | Attention must normalize over keys.
Z = A @ V
so, Z = (softmax(Q @ K.T)/d_k ** 0.5) @ V
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000]], grad_fn=<SumBackward1>)
S.shape = torch.Size([1, 20, 20])
A.shape = torch.Size([1, 20, 20])
Z.shape = torch.Size([1, 20, 5])


tensor([[[ 0.3819,  0.0240,  0.1881,  0.4650, -0.6376],
         [ 0.3804,  0.0230,  0.1929,  0.4665, -0.6356],
         [ 0.3816,  0.0251,  0.1833,  0.4621, -0.6370],
         [ 0.3817,  0.0222,  0.1955,  0.4691, -0.6373],
         [ 0.3826,  0.0232,  0.1898,  0.4669, -0.6381],
         [ 0.3840,  0.0244,  0.1858,  0.4654, -0.6403],
         [ 0.3793,  0.0232,  0.1918,  0.4651, -0.6340],
         [ 0.3780,  0.0238,  0.1878,  0.4621, -0.6316],
         [ 0.3833,  0.0229,  0.1913,  0.4682, -0.6392],
         [ 0.3800,  0.0234,  0.1880,  0.4640, -0.6342],
         [ 0.3828,  0.0239,  0.1888,  0.4661, -0.6389],
         [ 0.3800,  0.0236,  0.1910,  0.4650, -0.6352],
         [ 0.3789,  0.0247,  0.1841,  0.4608, -0.6330],
         [ 0.3799,  0.0238,  0.1863,  0.4630, -0.6340],
         [ 0.3788,  0.0231,  0.1901,  0.4642, -0.6327],
         [ 0.3815,  0.0242,  0.1867,  0.4640, -0.6369],
         [ 0.3822,  0.0233,  0.1894,  0.4663, -0.6376],
         [ 0.3823,  0.0229,  0.1912,  0.4674, -0

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, d_model:int , d_k: int, d_v: int):
        """
        > Implements single-head scaled dot-product self-attention.

        > This class is only responsible for Self-Attention Mechanism and not for transforming the Attention's Output's Value space to d_model's space
        > Explanation: Self-Attention's output lives in d_v vector space,
                       this class is not responsible for projecting output back to d_model's vector space
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        # Initializing Learned Linear Projection Matrix
        # The original Transformer paper uses no bias for Q/K/V, so I'm also not using them by setting bias = False
        self.W_Q = nn.Linear(d_model, d_k, bias = False)
        self.W_K = nn.Linear(d_model, d_k, bias = False)
        self.W_V = nn.Linear(d_model, d_v, bias = False)
    
    def forward(self, x):
        _, _, d_model = x.shape # x.shape = (B, T, d_model)
        assert self.d_model == d_model, "self.d_model must be equal to d_model from x.shape"
        # Calculate Query, Key and Value Matrix
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Get the Updated Information and update each token's representation
        S = (Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)) # to ensure that standard deviation is consistent after transformation 
        A = torch.softmax(S, dim = -1)
        Z = A @ V # Right now, Z.shape != x.shape because attention outputs are in value space we project Z back into,
                  # d_model so residual connections and multiple Transformer layers can be stacked
        return Z

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float):
        """
        This class is responsible for Performing the Multi-head-Attention and,
        this will also project back into d_model with correct basis vector system
        the Architecture's Output is also in d_model's vector space but in wrong basis vector because of concatenation
        so we need to Project the output back to d_model to fix the vecto's basis system

        Also this supports the cross attention, so q, k, v are forwarded explicitly
        becaue in deoder q comes from hidden states, k and v comes from encoders output
        so: forward(self, q, k, v) --> supports the encoder-decoder attention
        so: forward(self, x) --> supports the self-MHA
        """
        super().__init__()
        assert d_model % h == 0, "self.d_model must be a multiple of h"
        self.d_model = d_model
        self.h = h
        self.d_k = self.d_model // self.h
        self.d_v = self.d_model // self.h
        self.dropout = nn.Dropout(dropout)

        self.W_Q = nn.Linear(self.d_model, self.h * self.d_k, bias = False)
        self.W_K = nn.Linear(self.d_model, self.h * self.d_k, bias = False)
        self.W_V = nn.Linear(self.d_model, self.h * self.d_v, bias = False)
        self.W_O = nn.Linear(self.d_model, self.d_model, bias = False) # This is the output Projection 
        # The output of This Multi-head Attention is in d_model space we need an output projection layer so that it must learn,
        # the importance of Each head and how each head are relevent to each other
        # It is required to enable Stacking, Residual Conncetions
        # mathematically,
                # The Dimensions are numerically equal but the basis systems of vectors are different
                # changes the basis,  not the dimensionality.
        # W_O does three things simultaneously:
            # Mixes information across heads
            # Learns head importance
            # Re-embeds into a shared semantic space
    
    def forward_selfMHA(self, x):
        """
        This forward method is not acceptable for cross-attention (encoder-decoder transformers)
        because cross-attention expects,
        # q from decoders hidden's state
        # k, v from encoders output
        """
        B, T, _ = x.shape
        Q = self.W_Q(x).view(B, T, self.h, self.d_k).transpose(1, 2) # We split this and transpose this --> shape: [B, h, T, d_k]
        K = self.W_K(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        # We split this and transpose this
        # After Splitting --> shape: [B, T, h, d_k]
        # Now Transpose: --> shape: [B, h, T, d_k] --> each head now have its' completely own sequence
        print("Q.shape", Q.shape)
        print("K.shape", K.shape)
        print("V.shape", V.shape)

        S = (Q @ K.transpose(-1, -2))/(self.d_k ** 0.5) # or we can do, K.transpose(-2, -1)
        S = torch.softmax(S, dim = -1)
        
        # Update the information of each of the tokens for each of the heads
        # Z.shape = (B, h, T, d_v)
        Z = (S @ V) # This will update the information and this will project into the d_v = d_k = d_model/h dimentional space

        # Now we concatenate this to get the Representation into our embedding space i.e. d_model's Vector Space
        print("Before concatenation: output.Shape = ", Z.shape)
        Z = Z.transpose(1, 2).contiguous().view(B, T, self.d_model)
        Z = self.W_O(Z) # This is the output projection we need this, compulsorily
        print("After concatenation: output.Shape = ", Z.shape)
        # we want to view this as, (B, T, d_model)
        return Z
    
    def forward(self, q, k, v, mask = None):
        # Shape of q: (B, T, d_model) --> (B, T, h, d_model/h) --> --> (B, h, T, d_model/h)
        B, T, _ = q.shape
        query = self.W_Q(q).view(B, T, self.h, self.d_k).transpose(1, 2)
        key = self.W_K(k).view(B, T, self.h, self.d_k).transpose(1, 2)
        value = self.W_V(v).view(B, T, self.h, self.d_v).transpose(1, 2) # (B, h, T, d_model/h)

        S = ((query @ key.transpose(-2, -1)) / (self.d_k ** 0.5)) # S.Shape: (B, h, T, T)
        if mask is not None:
            S = S.masked_fill(mask == 0, -1e9)
        S = torch.softmax(S, dim = -1)
        S = self.dropout(S)
        Z = S @ value
        # (B, h, T, T) @ (B, h, T, d_model/h) = (B, h, T, d_model/h) --> (B, T, h, d_model/h)
        return self.dropout(self.W_O(Z.transpose(1, 2).contiguous().view(B, T, -1)))


In [7]:
x = torch.rand(2, 10, 64)
MultiHeadAttention(d_model = 64, h = 8, dropout = 0.1)(x, x, x, mask = None)

tensor([[[-0.0567, -0.0278,  0.0725,  ...,  0.0000, -0.0933,  0.0224],
         [-0.0315, -0.0000,  0.0219,  ...,  0.1418, -0.1089,  0.0003],
         [ 0.0059, -0.0657,  0.0499,  ...,  0.0000, -0.1144,  0.0086],
         ...,
         [-0.0141, -0.0465, -0.0024,  ...,  0.0966, -0.1304, -0.0165],
         [-0.0496, -0.0172,  0.0437,  ...,  0.1461, -0.1158,  0.0000],
         [ 0.0021, -0.0493,  0.0219,  ...,  0.1515, -0.0959,  0.0083]],

        [[-0.1196, -0.0000,  0.0462,  ...,  0.1545, -0.1757,  0.0294],
         [-0.1239, -0.0412,  0.0602,  ...,  0.1473, -0.1611,  0.0597],
         [-0.0890, -0.0603,  0.0585,  ...,  0.0865, -0.1498,  0.0127],
         ...,
         [-0.0826, -0.0387,  0.0564,  ...,  0.1146, -0.1289,  0.0003],
         [-0.1009, -0.0204,  0.0352,  ...,  0.1145, -0.1133,  0.0081],
         [-0.1303, -0.0231,  0.0386,  ...,  0.1276, -0.1231, -0.0049]]],
       grad_fn=<MulBackward0>)

torch.Size([2, 10, 64]) --> we have 2 examples of each 10, 64 --> meaning 2 examples having seq_length = 10 and each char is represented as 64-Dimensional vector

torch.Size([2, 10, 8, 8]) --> we have 2 examples of each containing seq_length = 10 having 8 different heads and each heads are 8-Dimensional vectors

torch.Size([2, 8, 10, 8]) ---> we have 2 examples of each containing 8 heads and each head will be seq_length = 10 long and each will opearate on 8-D vectors

In [8]:
SelfAttention(d_model = 4, d_k = 3, d_v = 5)(X)

tensor([[[ 0.2105, -0.0680, -0.0706,  0.0245, -0.5060],
         [ 0.2104, -0.0675, -0.0711,  0.0248, -0.5056],
         [ 0.2104, -0.0677, -0.0705,  0.0246, -0.5059],
         [ 0.2108, -0.0679, -0.0717,  0.0250, -0.5054],
         [ 0.2107, -0.0683, -0.0705,  0.0243, -0.5062],
         [ 0.2111, -0.0681, -0.0725,  0.0255, -0.5050],
         [ 0.2103, -0.0664, -0.0732,  0.0263, -0.5043],
         [ 0.2094, -0.0671, -0.0685,  0.0236, -0.5070],
         [ 0.2111, -0.0681, -0.0725,  0.0254, -0.5050],
         [ 0.2104, -0.0664, -0.0734,  0.0264, -0.5042],
         [ 0.2110, -0.0676, -0.0732,  0.0260, -0.5045],
         [ 0.2103, -0.0670, -0.0719,  0.0254, -0.5051],
         [ 0.2097, -0.0669, -0.0697,  0.0244, -0.5062],
         [ 0.2102, -0.0665, -0.0725,  0.0259, -0.5046],
         [ 0.2099, -0.0668, -0.0708,  0.0248, -0.5057],
         [ 0.2104, -0.0679, -0.0701,  0.0243, -0.5062],
         [ 0.2109, -0.0674, -0.0732,  0.0260, -0.5045],
         [ 0.2108, -0.0680, -0.0716,  0.0250, -0

### Now that i've implemented the Transformer Architecture in src/transformer

- Let's see an Inference time optimization technique called as KV Cache

In [9]:
text = open("tiny_shakespeare.txt", 'r', encoding='utf-8').read()
input_text = "Hi, I'm Himanshu Sing"

chars = sorted(list(set(text)))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
vocab_size = len(chars)
d_model = 2

# Let's Create an Emebedding
embeddings = torch.rand(vocab_size, d_model)

In [10]:
x_idx = torch.tensor([stoi[x] for x in input_text]).unsqueeze(dim = 0)

In [11]:
x_enc = embeddings[x_idx]

In [12]:
x_enc

tensor([[[0.8190, 0.0844],
         [0.6219, 0.3023],
         [0.5767, 0.7185],
         [0.0599, 0.1221],
         [0.3400, 0.0207],
         [0.4904, 0.4405],
         [0.6588, 0.4593],
         [0.0599, 0.1221],
         [0.8190, 0.0844],
         [0.6219, 0.3023],
         [0.6588, 0.4593],
         [0.0848, 0.1006],
         [0.6150, 0.9959],
         [0.3872, 0.2297],
         [0.2240, 0.5658],
         [0.2021, 0.2924],
         [0.0599, 0.1221],
         [0.2628, 0.3887],
         [0.6219, 0.3023],
         [0.6150, 0.9959],
         [0.7015, 0.8609]]])

In [14]:
multi_head = MultiHeadAttention(d_model = d_model, h = 2, dropout = 0.1)
output = multi_head(x_enc, x_enc, x_enc, mask = None)

In [15]:
output

tensor([[[ 0.1882, -0.0909],
         [ 0.2497, -0.1380],
         [ 0.2690, -0.1536],
         [ 0.2712, -0.1551],
         [ 0.2561, -0.1431],
         [ 0.2238, -0.1202],
         [ 0.2396, -0.1343],
         [ 0.2431, -0.1425],
         [ 0.2376, -0.1347],
         [ 0.0000, -0.1095],
         [ 0.0000, -0.1369],
         [ 0.0000, -0.1517],
         [ 0.2649, -0.1569],
         [ 0.0000, -0.1494],
         [ 0.0000, -0.1077],
         [ 0.2175, -0.1115],
         [ 0.2347, -0.1369],
         [ 0.2684, -0.1530],
         [ 0.2368, -0.1449],
         [ 0.2339, -0.1362],
         [ 0.2294, -0.1240]]], grad_fn=<MulBackward0>)

In [59]:
char_1 = "h"
char_2 = "hi"
char_3 = "him"
char_4 = "hima"
char_5 = "himan"
char_6 = "himans"
char_7 = "himansh"
char_8 = "himanshu"
scores = {}

# For char = h
char_enc = embeddings[[stoi[x] for x in char_1]]
query_h = multi_head.W_Q(char_enc[-1])
key_h = multi_head.W_K(char_enc[-1])
value_h = multi_head.W_V(char_enc[-1])
scores['for_h'] = {}
scores['for_h']['query_h'] = query_h
scores['for_h']['key_h'] = key_h
scores['for_h']['value_h'] = value_h

# For char = i
char_enc = embeddings[[stoi[x] for x in char_2]]
query_i = multi_head.W_Q(char_enc[-1])
key_i = multi_head.W_K(char_enc[-1])
value_i = multi_head.W_V(char_enc[-1])
scores['for_i'] = {}
scores['for_i']['query_i'] = query_i
scores['for_i']['key_i'] = key_i
scores['for_i']['value_i'] = value_i

# For char = m
char_enc = embeddings[[stoi[x] for x in char_3]]
query_m = multi_head.W_Q(char_enc[-1])
key_m = multi_head.W_K(char_enc[-1])
value_m = multi_head.W_V(char_enc[-1])
scores['for_m'] = {}
scores['for_m']['query_m'] = query_m
scores['for_m']['key_m'] = key_m
scores['for_m']['value_m'] = value_m

scores


# -------------- For char_3 | char_3 = "him" -----------------------
x_enc_3 = embeddings[[stoi[x] for x in char_3]]
x_enc_3
h_enc = x_enc_3[0]
i_enc = x_enc_3[1]
m_enc = x_enc_3[2]
dict_scores_char_3 = {}
dict_scores_char_3['h_enc'] = {}
dict_scores_char_3['h_enc']['query_h'] = {}
dict_scores_char_3['h_enc']['key_h'] = {}
dict_scores_char_3['h_enc']['value_h'] = {}
query_h = multi_head.W_Q(h_enc)
key_h = multi_head.W_K(h_enc)
value_h = multi_head.W_V(h_enc)
dict_scores_char_3['h_enc']['query_h'] = query_h
dict_scores_char_3['h_enc']['key_h'] = key_h
dict_scores_char_3['h_enc']['value_h'] = value_h

dict_scores_char_3['i_enc'] = {}
dict_scores_char_3['i_enc']['query_i'] = {}
dict_scores_char_3['i_enc']['key_i'] = {}
dict_scores_char_3['i_enc']['value_i'] = {}
query_i = multi_head.W_Q(i_enc)
key_i = multi_head.W_K(i_enc)
value_i = multi_head.W_V(i_enc)
dict_scores_char_3['i_enc']['query_i'] = query_i
dict_scores_char_3['i_enc']['key_i'] = key_i
dict_scores_char_3['i_enc']['value_i'] = value_i

dict_scores_char_3['m_enc'] = {}
dict_scores_char_3['m_enc']['query_m'] = {}
dict_scores_char_3['m_enc']['key_m'] = {}
dict_scores_char_3['m_enc']['value_m'] = {}
query_m = multi_head.W_Q(m_enc)
key_m = multi_head.W_K(m_enc)
value_m = multi_head.W_V(m_enc)
dict_scores_char_3['m_enc']['query_m'] = query_m
dict_scores_char_3['m_enc']['key_m'] = key_m
dict_scores_char_3['m_enc']['value_m'] = value_m

dict_scores_char_3

{'h_enc': {'query_h': tensor([-0.1649, -0.3664], grad_fn=<SqueezeBackward4>),
  'key_h': tensor([-0.4249,  0.3146], grad_fn=<SqueezeBackward4>),
  'value_h': tensor([-0.0890,  0.1684], grad_fn=<SqueezeBackward4>)},
 'i_enc': {'query_i': tensor([-0.0372, -0.3314], grad_fn=<SqueezeBackward4>),
  'key_i': tensor([-0.3966,  0.0626], grad_fn=<SqueezeBackward4>),
  'value_i': tensor([0.2937, 0.3658], grad_fn=<SqueezeBackward4>)},
 'm_enc': {'query_m': tensor([-0.0855, -0.4262], grad_fn=<SqueezeBackward4>),
  'key_m': tensor([-0.5059,  0.1552], grad_fn=<SqueezeBackward4>),
  'value_m': tensor([0.2518, 0.3986], grad_fn=<SqueezeBackward4>)}}

In [57]:
scores

{'for_h': {'query_h': tensor([-0.1649, -0.3664], grad_fn=<SqueezeBackward4>),
  'key_h': tensor([-0.4249,  0.3146], grad_fn=<SqueezeBackward4>),
  'value_h': tensor([-0.0890,  0.1684], grad_fn=<SqueezeBackward4>)},
 'for_i': {'query_i': tensor([-0.0372, -0.3314], grad_fn=<SqueezeBackward4>),
  'key_i': tensor([-0.3966,  0.0626], grad_fn=<SqueezeBackward4>),
  'value_i': tensor([0.2937, 0.3658], grad_fn=<SqueezeBackward4>)},
 'for_m': {'query_m': tensor([-0.0855, -0.4262], grad_fn=<SqueezeBackward4>),
  'key_m': tensor([-0.5059,  0.1552], grad_fn=<SqueezeBackward4>),
  'value_m': tensor([0.2518, 0.3986], grad_fn=<SqueezeBackward4>)}}

In [23]:
query_h, key_h, value_h

(tensor([ 0.0551, -0.2668], grad_fn=<SqueezeBackward4>),
 tensor([-0.3285, -0.1181], grad_fn=<SqueezeBackward4>),
 tensor([0.5205, 0.4565], grad_fn=<SqueezeBackward4>))