In [2]:
import torch.nn as nn

In [3]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out,context_length,dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))
    

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill(
            self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
        
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5 , dim = -1
        )

        attn_weights = self.dropout(attn_weights)
        context_vector = attn_weights @ values

        return context_vector

In [4]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self,d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in=d_in,d_out=d_out,context_length=context_length,dropout=dropout,qkv_bias=qkv_bias)
             for _ in range(num_heads)]
        )


    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim =-1)

In [5]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your      (x^1)
     [0.55, 0.87, 0.66],  # journey   (x^2)
     [0.57, 0.85, 0.64],  # starts    (x^3)
     [0.22, 0.58, 0.34],  # with      (x^4)
     [0.77, 0.25, 0.10],  # one       (x^5)
     [0.05, 0.81, 0.55]]  # step      (x^6)
 )

In [6]:
batch = torch.stack((inputs,inputs), dim = 0)
print(batch.shape)

torch.Size([2, 6, 3])


In [7]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in,d_out = 3,2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.2, num_heads=2)

context_vector = mha(batch)

print("context vector shape:", context_vector.shape)
print("context vector:", context_vector)

context vector shape: torch.Size([2, 6, 4])
context vector: tensor([[[-0.5110, -0.0883,  0.6368,  0.4393],
         [-0.5787, -0.1032,  0.5558,  0.3643],
         [-0.6665, -0.1351,  0.4901,  0.3319],
         [-0.6633, -0.1346,  0.6355,  0.4348],
         [-0.6649, -0.1334,  0.4611,  0.3365],
         [-0.6635, -0.1353,  0.6357,  0.4376]],

        [[-0.5694, -0.0991,  0.3843,  0.2654],
         [-0.4338, -0.0615,  0.4903,  0.3317],
         [-0.1747, -0.0626,  0.6366,  0.4391],
         [-0.5744, -0.1023,  0.6355,  0.4348],
         [-0.6649, -0.1334,  0.6356,  0.4316],
         [-0.4856, -0.0715,  0.3453,  0.2246]]], grad_fn=<CatBackward0>)


In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0),\
            "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(context_length,context_length),
                       diagonal=1)
        )

    
    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # we implicitly split the matrix by adding a num_heads dim
        # unroll last dim: (b,Num_tokens, d_out) -> (b, num_token, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # group matricies by num_head by transpose
        # (b, num_token, num_heads, head_dim) -> (b, num_heads, num_token,head_dim)
        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        
        attn_scores = queries @ keys.transpose(2,3)


        mask_bool = self.mask.bool()[:num_tokens,:num_tokens]

        attn_scores.masked_fill(mask_bool, -torch.inf)
        
        #apply softmax
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5 , dim = -1
        )

        # dropout layer
        attn_weights = self.dropout(attn_weights)

        # transpose attn_weights
        context_vec = (attn_weights @ values).transpose(1,2)

        #combine result obtain from multiple heads self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)


        return context_vec        

In [13]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89, 0.34, 0.87, 0.66],  # your      (x^1)
     [0.55, 0.87, 0.66, 0.22, 0.58, 0.33],  # journey   (x^2)
     [0.57, 0.85, 0.64, 0.05, 0.80, 0.55],  # starts    (x^3)
     [0.22, 0.58, 0.34, 0.45, 0.57, 0.34],  # with      (x^4)
     [0.77, 0.25, 0.10, 0.32, 0.05, 0.13],  # one       (x^5)
     [0.05, 0.81, 0.55, 0.78, 0.29, 0.15]]  # step      (x^6)
 )

In [14]:
batch = torch.stack((inputs,inputs), dim = 0)
print(batch.shape)

torch.Size([2, 6, 6])


In [15]:
batch_size , context_length , d_in = batch.shape
d_out = 6

mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.1, num_heads=2)
context_vectors = mha(batch)

print(context_vectors)
print("context_vector shape:", context_vectors.shape)

tensor([[[ 0.3980, -0.1468,  0.0311, -0.2454, -0.4149,  0.2771],
         [ 0.3973, -0.1466,  0.0313, -0.2445, -0.4155,  0.2776],
         [ 0.3976, -0.1468,  0.0320, -0.2450, -0.4162,  0.2772],
         [ 0.3948, -0.1291,  0.0339, -0.2582, -0.3931,  0.2916],
         [ 0.3659, -0.1050,  0.0071, -0.2283, -0.3692,  0.3204],
         [ 0.3884, -0.1324,  0.0384, -0.2502, -0.4000,  0.2874]],

        [[ 0.3477, -0.1118, -0.0356, -0.1733, -0.3515,  0.3293],
         [ 0.3894, -0.1286,  0.0118, -0.2411, -0.3704,  0.3039],
         [ 0.3945, -0.1281,  0.0122, -0.2509, -0.3808,  0.2917],
         [ 0.3962, -0.1452,  0.0288, -0.2430, -0.4126,  0.2798],
         [ 0.3949, -0.1289,  0.0330, -0.2582, -0.3919,  0.2917],
         [ 0.3750, -0.1207,  0.0155, -0.2324, -0.3839,  0.3038]]],
       grad_fn=<ViewBackward0>)
context_vector shape: torch.Size([2, 6, 6])
