In [1]:
import torch.nn as nn
import torch

In [2]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out,context_length,dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))
    

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill(
            self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
        
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5 , dim = -1
        )

        attn_weights = self.dropout(attn_weights)
        context_vector = attn_weights @ values

        return context_vector

In [3]:
class MultiAttentionWrapper(nn.Module):

    def __init__(self,d_in:int ,d_out:int, context_length, dropout:float,
                 num_heads:int,qkv_bias:bool=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out,context_length,dropout, qkv_bias)
             for _ in range(num_heads)]
        )
    

    def forward(self,x):
        return torch.cat([head(x) for head in self.heads], dim = -1)

In [4]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your      (x^1)
     [0.55, 0.87, 0.66],  # journey   (x^2)
     [0.57, 0.85, 0.64],  # starts    (x^3)
     [0.22, 0.58, 0.34],  # with      (x^4)
     [0.77, 0.25, 0.10],  # one       (x^5)
     [0.05, 0.81, 0.55]]  # step      (x^6)
 )

In [5]:
batch = torch.stack((inputs,inputs), dim = 0)
print(batch.shape)

torch.Size([2, 6, 3])


In [6]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in,d_out = 3,2

In [7]:
mha = MultiAttentionWrapper(d_in,d_out, context_length, dropout=0.1,num_heads=2)

context_vector = mha(batch)
context_vector

tensor([[[-0.5940, -0.1169,  0.4356,  0.2951],
         [-0.4542, -0.0819,  0.5659,  0.3906],
         [-0.5925, -0.1201,  0.4935,  0.3742],
         [-0.5896, -0.1197,  0.4859,  0.3689],
         [-0.5222, -0.0913,  0.5650,  0.3837],
         [-0.5083, -0.1602,  0.4977,  0.3310]],

        [[-0.4200, -0.1303,  0.5660,  0.3905],
         [-0.4537, -0.0820,  0.4941,  0.3238],
         [-0.5925, -0.1201,  0.4262,  0.3163],
         [-0.5896, -0.1197,  0.5648,  0.3865],
         [-0.5910, -0.1186,  0.5650,  0.3837],
         [-0.4547, -0.0829,  0.4926,  0.3217]]], grad_fn=<CatBackward0>)

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0),\
            "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(context_length,context_length),
                       diagonal=1)
        )

    
    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # we implicitly split the matrix by adding a num_heads dim
        # unroll last dim: (b,Num_tokens, d_out) -> (b, num_token, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # group matricies by num_head by transpose
        # (b, num_token, num_heads, head_dim) -> (b, num_heads, num_token,head_dim)
        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        
        attn_scores = queries @ keys.transpose(2,3)


        mask_bool = self.mask.bool()[:num_tokens,:num_tokens]

        attn_scores.masked_fill(mask_bool, -torch.inf)
        
        #apply softmax
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5 , dim = -1
        )

        # dropout layer
        attn_weights = self.dropout(attn_weights)

        # transpose attn_weights
        context_vec = (attn_weights @ values).transpose(1,2)

        #combine result obtain from multiple heads self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)


        return context_vec        

In [9]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89, 0.34, 0.87, 0.66],  # your      (x^1)
     [0.55, 0.87, 0.66, 0.22, 0.58, 0.33],  # journey   (x^2)
     [0.57, 0.85, 0.64, 0.05, 0.80, 0.55],  # starts    (x^3)
     [0.22, 0.58, 0.34, 0.45, 0.57, 0.34],  # with      (x^4)
     [0.77, 0.25, 0.10, 0.32, 0.05, 0.13],  # one       (x^5)
     [0.05, 0.81, 0.55, 0.78, 0.29, 0.15]]  # step      (x^6)
 )

In [10]:
batch = torch.stack((inputs,inputs), dim = 0)
print(batch.shape)

torch.Size([2, 6, 6])


In [11]:
batch_size , context_length , d_in = batch.shape
d_out = 6

mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.1, num_heads=2)
context_vectors = mha(batch)

print(context_vectors)
print("context_vector shape:", context_vectors.shape)

tensor([[[-0.4720,  0.1788, -0.1759, -0.0175, -0.4530, -0.1445],
         [-0.4636,  0.2161, -0.1830, -0.0332, -0.4252, -0.1709],
         [-0.4987,  0.2123, -0.1469, -0.0131, -0.4898, -0.1171],
         [-0.4987,  0.2162, -0.1468, -0.0141, -0.4894, -0.1183],
         [-0.4617,  0.2538, -0.1623, -0.0280, -0.4483, -0.1747],
         [-0.4658,  0.1950, -0.1687, -0.0453, -0.4531, -0.1846]],

        [[-0.4603,  0.2349, -0.2065, -0.0163, -0.4055, -0.1467],
         [-0.4990,  0.2125, -0.1470, -0.0132, -0.4898, -0.1167],
         [-0.4987,  0.2123, -0.1469, -0.0131, -0.4898, -0.1171],
         [-0.4987,  0.2162, -0.1468, -0.0141, -0.4894, -0.1183],
         [-0.4689,  0.1997, -0.1567, -0.0247, -0.4651, -0.1644],
         [-0.4964,  0.2169, -0.1730, -0.0168, -0.4563, -0.1140]]],
       grad_fn=<ViewBackward0>)
context_vector shape: torch.Size([2, 6, 6])


### MULTI HEAD ATTENTION WITH WEIGHT SPLITS

In [30]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_in:int,d_out:int, num_of_heads:int,
                 context_length:int,dropout:float,
                 qkv_bias:bool = False):
        
        super().__init__()
        assert (d_out % num_of_heads == 0),\
            "d_out must be divisible by num_heads"        
        self.d_out = d_out
        self.num_of_heads  = num_of_heads
        self.head_dim = self.d_out // self.num_of_heads
        self.w_query = nn.Linear(d_in,d_in, bias = qkv_bias)
        self.w_key = nn.Linear(d_in,d_in,bias = qkv_bias)
        self.w_value = nn.Linear(d_in,d_in,bias = qkv_bias)
        self.proj = nn.Linear(d_out,d_out) # linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length,context_length),
                       diagonal=1)
        )
    
    def forward(self, x):
        b,num_of_token, d_in = x.shape
        
        # claculate query, keys value matrix
        query = self.w_query(x)
        keys = self.w_key(x)
        value = self.w_value(x)

        # transforming to shape (batch, num_of_token, d_in) ---> (batch, num_of_token, num_of_heads, head_dim)
        query = query.view(b, num_of_token, self.num_of_heads, self.head_dim)
        keys = keys.view(b, num_of_token, self.num_of_heads, self.head_dim)
        value = value.view(b, num_of_token, self.num_of_heads, self.head_dim)

        #shape Transform (batch, num_of_token, num_of_heads, head_dim) ---> (batch, num_of_heads,num_of_token, head_dim)
        query = query.transpose(1,2)
        keys = keys.transpose(1,2)
        value = value.transpose(1,2)

        #find attention scores
        attention_score = query @ keys.transpose(2,3)

        # create Mask
        mask_bool = self.mask.bool()[:num_of_token,:num_of_token]

        # apply mask
        attention_score.masked_fill(mask_bool,-torch.inf)

        #caluclate attention weight
        attention_weight = torch.softmax(
            attention_score / keys.shape[-1]**0.5,
            dim = -1
        )

        #shape : (b,num_of_heads ,num_of_token,head_dim) ---> (b,num_of_token,num_of_heads,head_dim)
        context_vector =  (attention_weight @ value).transpose(1,2)
        
        # combine all heads
        context_vector = context_vector.contiguous().view(b,num_of_token,self.d_out)

        #output projection
        context_vector = self.proj(context_vector)

        #return context vector after dropout
        return self.dropout(context_vector)

In [27]:
batch_size, context_length,d_in = batch.shape

In [31]:
mha = MultiHeadAttention(d_in,d_out, num_of_heads=2,
                         context_length= 6, 
                         dropout=0.1, qkv_bias=False)

In [32]:
context_vector = mha(batch)
context_vector

tensor([[[ 0.0287,  0.0996,  0.5098, -0.4397, -0.1321,  0.0645],
         [ 0.0300,  0.1010,  0.5131, -0.4423, -0.1324,  0.0652],
         [ 0.0296,  0.1010,  0.5125, -0.4421, -0.1323,  0.0650],
         [ 0.0304,  0.1011,  0.5135, -0.4432, -0.1323,  0.0666],
         [ 0.0307,  0.1007,  0.5135, -0.4417, -0.1327,  0.0663],
         [ 0.0310,  0.1013,  0.5146, -0.4438, -0.1325,  0.0669]],

        [[ 0.0000,  0.0996,  0.5098, -0.4397, -0.1321,  0.0000],
         [ 0.0300,  0.1010,  0.5131, -0.4423, -0.1324,  0.0000],
         [ 0.0296,  0.1010,  0.5125, -0.4421, -0.1323,  0.0650],
         [ 0.0304,  0.1011,  0.5135, -0.4432, -0.1323,  0.0666],
         [ 0.0000,  0.1007,  0.5135, -0.4417, -0.1327,  0.0663],
         [ 0.0310,  0.1013,  0.5146, -0.4438, -0.1325,  0.0669]]],
       grad_fn=<MulBackward0>)

In [3]:
import torch
import torch.nn as nn

In [5]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89, 0.34, 0.87, 0.66],  # your      (x^1)
     [0.55, 0.87, 0.66, 0.22, 0.58, 0.33],  # journey   (x^2)
     [0.57, 0.85, 0.64, 0.05, 0.80, 0.55],  # starts    (x^3)
     [0.22, 0.58, 0.34, 0.45, 0.57, 0.34],  # with      (x^4)
     [0.77, 0.25, 0.10, 0.32, 0.05, 0.13],  # one       (x^5)
     [0.05, 0.81, 0.55, 0.78, 0.29, 0.15]]  # step      (x^6)
 )

In [6]:
batch = torch.stack((inputs,inputs), dim = 0)
print(batch.shape)

torch.Size([2, 6, 6])


In [7]:
batch_size , context_length , d_in = batch.shape
d_out = 6

In [25]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in:int, d_out:int, num_of_heads:int,
                 context_length:int,dropout:float,qkv_bias:bool=False):
        super().__init__()
        assert (d_out % num_of_heads == 0)\
            , "d_out is not divisible by num_of_heads"
        self.d_out = d_out
        self.num_of_heads = num_of_heads
        self.context_length = context_length
        self.head_dim = d_out // num_of_heads
        self.W_Query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_Key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_Value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.projLayer = nn.Linear(d_out,d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length,context_length),
                       diagonal=1)
        )
    

    def forward(self,x):
        b,num_of_token,d_in = x.shape
        Query = self.W_Query(x)
        Key   = self.W_Key(x)
        Value = self.W_Value(x)

        # reshape: (b,Num_of_token,d_in) ---> (b,num_of_token,num_of_heads,head_dim)
        Query = Query.view(b,num_of_token,self.num_of_heads,self.head_dim)
        Key   = Key.view(b,num_of_token,self.num_of_heads,self.head_dim)
        Value = Value.view(b,num_of_token,self.num_of_heads,self.head_dim)

        #grouping by Heads ( Transpose with respect to 1 and 2 shape)
        Query = Query.transpose(1,2)
        Key   = Key.transpose(1,2)
        Value = Value.transpose(1,2)

        #calculate attention score and reshape (b,num_of_heads, num_of_token,,head_dim) ---> (b,num_of_token,num_of_heads,head_dim)
        Atten_Score = Query @ Key.transpose(2,3)

        # create Mask 
        mask_bool = self.mask.bool()[:num_of_token,:num_of_token]

        # apply mask
        Atten_Score.masked_fill(mask_bool,-torch.inf)

        #calculate attention weight (softmax and devide by sqrt of head_dim)
        Atten_Weight = torch.softmax(
            Atten_Score/ self.head_dim ** 0.5 ,
            dim = -1
        )

        context_vector = (Atten_Weight @ Value).transpose(1,2)

        context_vector = context_vector.contiguous().view(b, num_of_token,self.d_out)

        context_vector = self.projLayer(context_vector)

        return self.dropout(context_vector)

In [26]:
mha = MultiHeadAttention(d_in = d_in, d_out = d_out, num_of_heads=2,context_length=context_length,
                         dropout=0.1,qkv_bias=False)

mha(batch)

tensor([[[ 1.2440e-03, -4.3129e-02,  2.7337e-01,  2.1775e-01,  2.2385e-01,
          -5.3225e-01],
         [ 1.3242e-04, -4.3032e-02,  2.7282e-01,  2.1738e-01,  2.2316e-01,
          -5.3243e-01],
         [-7.7254e-05, -4.3099e-02,  2.7289e-01,  2.1734e-01,  2.2295e-01,
          -5.3218e-01],
         [-4.8952e-04, -4.3376e-02,  2.7316e-01,  2.1807e-01,  0.0000e+00,
          -5.3218e-01],
         [ 1.9017e-03, -4.2268e-02,  2.7227e-01,  2.1845e-01,  2.2485e-01,
          -0.0000e+00],
         [-4.1028e-04, -4.3252e-02,  2.7321e-01,  2.1785e-01,  2.2255e-01,
          -5.3176e-01]],

        [[ 1.2440e-03, -0.0000e+00,  2.7337e-01,  0.0000e+00,  2.2385e-01,
          -5.3225e-01],
         [ 1.3242e-04, -4.3032e-02,  2.7282e-01,  2.1738e-01,  2.2316e-01,
          -5.3243e-01],
         [-7.7254e-05, -4.3099e-02,  2.7289e-01,  2.1734e-01,  2.2295e-01,
          -5.3218e-01],
         [-4.8952e-04, -4.3376e-02,  2.7316e-01,  2.1807e-01,  2.2285e-01,
          -5.3218e-01],
        