In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected three times (queries, keys, values) by separate
    affine layers. Each query is compared with every key of the same
    sequence, the scaled scores are normalised with a softmax, and the
    resulting weights are used to average the values.

    Args:
        d_in: Number of features in the last dimension of the input.
        d_out: Number of features produced by each projection (and therefore
            by this head).
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # `self` is the instance being constructed; remember both sizes on it.
        self.d_in, self.d_out = d_in, d_out
        # nn.Linear computes y = Ax + b. Its weight matrix A and bias vector b
        # are nn.Parameter objects, so they are reported by .parameters() and
        # updated by an optimizer. One independent layer per role:
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        """Apply self-attention to a batch of sequences.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_in). ``torch.bmm``
                requires exactly three dimensions.

        Returns:
            Tensor of shape (batch_size, seq_len, d_out).
        """
        q, k, v = self.Q(x), self.K(x), self.V(x)
        # Swap the last two axes of the keys,
        # (batch_size, seq_len, d_out) -> (batch_size, d_out, seq_len),
        # so the batched matrix product yields one score per (query, key)
        # pair; then scale: scores = Q K^T / sqrt(d_out).
        logits = torch.bmm(q, k.transpose(1, 2)) / (self.d_out ** 0.5)
        # Softmax over the key axis: every query's weights sum to 1.
        weights = F.softmax(logits, dim=2)
        # Weighted average of the values for every query position.
        return torch.bmm(weights, v)


##### Q. Why do we have multiple attention heads?

To attend to information from different representation subspaces at different positions. The module computes several attention operations in parallel, each with its own learned projections.

The token embeddings of a sequence are identical for all heads; for each head, the model applies a different linear transformation to the embeddings to produce Q, K, and V. Each head then computes attention with its own Q, K, and V, so each head sees the input differently. The head outputs are concatenated and mixed by a final linear layer, allowing the model to combine the information from all heads.

##### Q. If all projection matrices start randomly and see the same input, why don't all attention heads learn the same thing?

W_Q, W_K, and W_V start with random values and do not necessarily converge to the same solution. Each head's weights are updated independently during training: during backpropagation, the gradient for a head's parameters depends on that head's own output, the loss function, and the interaction with the other heads. Hence each head receives different gradient updates.

- Random Initialization -- Different starting points
- Independent Weights -- Unique learning paths
- Gradient Updates -- Driven by head-specific outputs
- Loss Optimization -- Encourages diversity for better results


In [4]:
# MultiheadAttention class that uses multiple Attention heads.
# What are hidden_size and num_heads?
# - hidden_size is the dimensionality of the input and output vectors;
#   if the model is used for NLP tasks, it is the size of the word embeddings.
# - num_heads is the number of attention heads to use;
#   each head will learn a different representation of the input data.

class MultiheadAttention(nn.Module):
    """Multi-head self-attention built from independent ``Attention`` heads.

    Every head projects the input down to ``hidden_size // num_heads``
    features and attends on its own. The head outputs are concatenated along
    the last dimension and mixed by a final linear layer.

    Args:
        hidden_size: Size of the last dimension of both the input and the
            output vectors.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # If the split is uneven, the concatenated head outputs have fewer
        # than hidden_size features and self.out would fail in forward().
        # Fail early with a clear message instead.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size  # size of input & output vectors
        self.num_heads = num_heads  # number of attention heads
        # Final linear layer that mixes the concatenated head outputs.
        self.out = nn.Linear(hidden_size, hidden_size)
        # One Attention module per head, each with its own weights and biases.
        # Dividing hidden_size by the number of heads gives every head a
        # smaller dimensionality, so the concatenation is hidden_size wide.
        self.head = nn.ModuleList([
            Attention(hidden_size, hidden_size // num_heads)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        """Run every head on ``x`` and combine the results.

        The method must be named ``forward``: ``nn.Module.__call__`` dispatches
        to it, so a misspelled name makes ``module(x)`` raise.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        head_outputs = [head(x) for head in self.head]
        # Each head yields hidden_size // num_heads features; concatenating
        # them along the last dimension restores hidden_size features.
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.out(concatenated)

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with one fused query/key/value projection.

    A single linear layer produces the queries, keys and values of all heads
    at once. The result is split per head, scaled dot-product attention is
    computed for all heads in parallel, and the head outputs are merged and
    passed through a final linear layer.

    Args:
        hidden_size: Size of the last dimension of both input and output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()

        # An uneven split would only surface later as an opaque reshape error
        # inside forward(); reject it here with a clear message.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # features per head

        # One projection yields Q, K and V for every head (3 * hidden_size).
        self.qkv_linear = nn.Linear(hidden_size, hidden_size * 3)
        # Mixes the merged head outputs.
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_length, hidden_size).
        """
        batch_size, seq_length, hidden_size = x.size()

        qkv = self.qkv_linear(x)  # [batch_size, seq_length, hidden_size * 3]
        # Give every head its own contiguous slice of 3 * head_dim features.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)  # [batch_size, seq_length, num_heads, head_dim * 3]
        qkv = qkv.transpose(1, 2)  # [batch_size, num_heads, seq_length, head_dim * 3]
        # Split each head's slice into its query, key and value parts.
        queries, keys, values = qkv.chunk(3, dim=-1)  # each [batch_size, num_heads, seq_length, head_dim]

        scores = torch.matmul(queries, keys.transpose(2, 3))  # [batch_size, num_heads, seq_length, seq_length]
        scores = scores / (self.head_dim ** 0.5)  # [batch_size, num_heads, seq_length, seq_length]
        attention = F.softmax(scores, dim=-1)  # [batch_size, num_heads, seq_length, seq_length]
        context = torch.matmul(attention, values)  # [batch_size, num_heads, seq_length, head_dim]

        # Bring the head axis next to the feature axis, then merge the two.
        context = context.transpose(1, 2)  # [batch_size, seq_length, num_heads, head_dim]
        context = context.reshape(batch_size, seq_length, hidden_size)  # [batch_size, seq_length, hidden_size]
        output = self.out(context)  # [batch_size, seq_length, hidden_size]

        return output

In [12]:
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Every position may attend to itself and to earlier positions only; scores
    for later positions are set to -inf before the softmax so they receive
    zero weight.

    Args:
        d_in: Number of input features per token.
        d_out: Number of features produced by the query/key/value projections.
        context_length: Side length of the square mask that is stored on the
            module; inputs are sliced out of it by their token count.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout=0.5, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the main diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        # A buffer moves to the right device together with the module, so the
        # mask never has to be placed on the parameters' device by hand.
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Compute causal self-attention.

        Args:
            x: Tensor of shape (batch, num_tokens, d_in); the batch stays at
                dimension 0.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        _, num_tokens, _ = x.shape
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        # Swap dimensions 1 and 2 of the keys to get token-by-token scores.
        scores = torch.matmul(q, k.transpose(1, 2))
        # Crop the stored mask to the current sequence length.
        blocked = self.mask.bool()[:num_tokens, :num_tokens]
        # The trailing underscore marks an in-place op: the scores tensor is
        # overwritten directly instead of allocating a masked copy.
        scores.masked_fill_(blocked, -torch.inf)

        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return torch.matmul(weights, v)


# Fix the global RNG seed before the module is constructed so that its
# randomly initialised projection weights are the same on every run.
torch.manual_seed(123)

# Toy input: six tokens, each represented by a 3-dimensional embedding.
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)

# Stack two copies of the sequence along a new leading dimension to form a
# batch of two identical sequences.
batch = torch.stack((inputs, inputs), dim=0) 
print(batch.shape)

# The input width is read off the embeddings; the output width is set to 2.
d_in, d_out = batch.shape[-1], 2
# The mask size is taken from the number of tokens in the batch.
context_length = batch.shape[1]

# dropout=0.0 so that no attention weights are dropped in this demo.
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
print(ca(batch))

torch.Size([2, 6, 3])
tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


#### Stacking multiple single-head attention layers


In [13]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention made by stacking independent CausalAttention heads.

    Every head receives the same input and produces ``d_out`` features; the
    per-head results are joined along the last dimension, so the output has
    ``num_heads * d_out`` features per token.

    Args:
        d_in: Number of input features per token.
        d_out: Number of output features of each individual head.
        context_length: Mask size forwarded to every head.
        dropout: Dropout probability forwarded to every head.
        qkv_bias: Whether the heads' projection layers carry a bias term.
        num_heads: How many heads to create.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False, num_heads=2):
        super().__init__()
        head_list = []
        for _ in range(num_heads):
            head_list.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        # ModuleList registers every head so their parameters are tracked.
        self.heads = nn.ModuleList(head_list)

    def forward(self, x):
        """Run all heads on ``x`` and concatenate their outputs feature-wise."""
        per_head = [attn_head(x) for attn_head in self.heads]
        return torch.cat(per_head, dim=-1)

# Re-seed so that the heads' randomly initialised weights are reproducible.
torch.manual_seed(123)
# `batch` is the (2, 6, 3) tensor built in the CausalAttention demo cell.
d_in, d_out = batch.shape[-1], 2
context_length = batch.shape[1]

# num_heads keeps its default of 2, so the concatenated output carries
# 2 * d_out = 4 features per token.
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0)
print(mha(batch))

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads=4, qkv_bias=False):
        super().__init__()

        assert(d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim


        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) # Shape -> (3, 4)
        self.W_keys = nn.Linear(d_in, d_out, bias=qkv_bias) # Shape -> (3, 4)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias) # Shape -> (3, 4)
        self.out_proj = nn.Liner(d_in, d_out) # layer to combine output heads
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    def forward(self, x):
        b, num_tokens, d_in = x.shape # x Shape -> (2, 6, 3)

        queries = self.W_query(x) # queries Shape -> (2, 6, 4) - (b, num_token, d_out)
        keys = self.W_keys(x)
        values = self.W_values(x)

        # unroll last dim -> (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose to (b, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2) 
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

