In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)
print("PyTorch version:", torch.__version__)

# The GPU name can only be queried when a CUDA device was selected above.
if device.type == "cuda":
    print("CUDA device name:", torch.cuda.get_device_name(0))

Using device: cuda
PyTorch version: 2.7.1+cu126
CUDA device name: NVIDIA GeForce RTX 4090


# Self Attention

In [70]:
# Toy batch of random token embeddings: (batch, sequence, embedding).
batch_size = 4
sequence_length = 64
embed_dim = 128
x = torch.randn(batch_size, sequence_length, embed_dim)
print("Shape of input tensor:", x.shape)

# Pairwise dot products between all tokens of a sequence, scaled by sqrt(embed_dim).
scale = embed_dim ** 0.5
similarity = torch.matmul(x, x.transpose(1, 2)) / scale

# Turn every row of scores into weights that sum to one.
attention_matrix = torch.softmax(similarity, dim=-1)

# Each output token is the weighted average of all tokens in its sequence.
output = torch.matmul(attention_matrix, x)

print("Shape of output tensor:", output.shape)

Shape of input tensor: torch.Size([4, 64, 128])
Shape of output tensor: torch.Size([4, 64, 128])


In [71]:
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Args:
        embedding_dimension: size of the last input dimension; the query, key
            and value projections all map to this same size.
    """

    def __init__(self, embedding_dimension):
        super().__init__()
        self.embed_dim = embedding_dimension

        dim = self.embed_dim
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend every token of ``x`` to every other token of the same sequence.

        Args:
            x: tensor whose dims 1 and 2 are (sequence, embedding), e.g.
                ``(batch, seq_len, embed_dim)`` as used in the demo below.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # Scores are scaled by sqrt(embed_dim) before the softmax over keys.
        scores = torch.matmul(queries, keys.transpose(1, 2)) / self.embed_dim**0.5
        weights = torch.softmax(scores, dim=-1)

        return torch.matmul(weights, values)
    
# Smoke test: the layer must return a tensor shaped like its input.
rand = torch.rand(size=(4, 64, 128))
attn = Attention(128)
output = attn(rand)
print(output.shape)

torch.Size([4, 64, 128])


## Multi-Head Attention

In [72]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with one independent Q/K/V projection set per head.

    Every head projects the input down to ``head_dim = embedding_dimension //
    num_heads`` features, attends on its own, and the head outputs are
    concatenated and mixed by a final linear layer.

    Args:
        embedding_dimension: size of the last input dimension.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``embedding_dimension`` is not divisible by ``num_heads``.
            Without this check the concatenated heads would not have
            ``embedding_dimension`` features and ``proj`` would fail at call time.
    """

    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(f"Embedding dimension {self.embed_dim} must be divisible by number of heads {self.num_heads}")
        self.head_dim = self.embed_dim // self.num_heads

        # One {"Q", "K", "V"} projection dict per head.
        self.multihead_qkv = nn.ModuleList([
            nn.ModuleDict({
                "Q": nn.Linear(self.embed_dim, self.head_dim),
                "K": nn.Linear(self.embed_dim, self.head_dim),
                "V": nn.Linear(self.embed_dim, self.head_dim),
            })
            for _ in range(self.num_heads)
        ])

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        """Run every head on ``x`` and merge the results.

        Args:
            x: tensor of shape ``(..., seq_len, embedding_dimension)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        heads_out = []
        for head in self.multihead_qkv:
            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            # Scale by sqrt(head_dim): that is the feature size each head attends over.
            similarity = (q @ k.transpose(-2, -1)) / self.head_dim**0.5
            attention = similarity.softmax(dim=-1)
            heads_out.append(attention @ v)

        # Concatenating num_heads chunks of head_dim restores embedding_dimension.
        heads_out = torch.cat(heads_out, dim=-1)

        return self.proj(heads_out)
    

# Smoke test with two heads: the output shape must match the input shape.
rand = torch.randn(4, 64, 128)
attn = MultiHeadAttention(embedding_dimension=128, num_heads=2)
out = attn(rand)
print(out.shape)

torch.Size([4, 64, 128])


In [73]:
class SelfAttentionEncoder(nn.Module):
    """Unmasked multi-head self-attention using one fused projection per Q/K/V.

    Instead of one small linear layer per head, a single ``embed_dim -> embed_dim``
    projection is computed and then split into ``num_heads`` chunks of ``head_dim``.

    Args:
        embed_dim: size of the last input dimension.
        num_heads: number of heads; must divide ``embed_dim``.
        attn_p: dropout probability applied to the attention weights.
        proj_p: dropout probability applied after the output projection.
        bias: whether the query/key/value projections have a bias term.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, attn_p=0.0, proj_p=0.0, bias=True):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError(f"Embedding dimension {embed_dim} must be divisible by number of heads {num_heads}")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch_size, seq_len, embed_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embed_dim)``.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``embed_dim``.
        """
        batch_size, seq_len, embed_dim = x.shape
        if embed_dim != self.embed_dim:
            raise ValueError(f"Input embedding dimension {embed_dim} does not match model's expected dimension {self.embed_dim}")

        # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        per_head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q, k, v = (
            layer(x).reshape(per_head_shape).transpose(1, 2)
            for layer in (self.query, self.key, self.value)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)

        # Undo the head split: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        merged = context.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        return self.proj_drop(self.proj(merged))


# Smoke test: the fused-projection encoder keeps the (batch, seq, embed) shape.
rand = torch.randn(4, 16, 128)
attn = SelfAttentionEncoder(embed_dim=128, num_heads=2)
output = attn(rand)
output.shape

torch.Size([4, 16, 128])

## Padding Masking

In [74]:
### Create an example attention matrix (b x h x n x n) ###
rand_attn = torch.rand(1,2,6,6) # 2 heads here

### Create Attention Mask in the shape (b x n): True = real token, False = padding ###
attention_mask = torch.tensor([1,1,1,1,0,0]).unsqueeze(0).bool()

print("Method 1:")
print("--------")
### Add two extra dimensions so the mask broadcasts against (b x h x n x n) ###
### i.e. unsqueeze the mask to (b x 1 x 1 x n) ###
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

### Unsqueezed with dummy broadcast dimensions ###
print(attention_mask)
# Out-of-place masked_fill keeps rand_attn intact, so Method 2 below starts from
# the same unmasked scores and re-running the cell gives the same demonstration.
masked_broadcast = rand_attn.masked_fill(~attention_mask, float("-inf"))
print(masked_broadcast)

print("Method 2:")
print("--------")
### Repeat the dummy dimension seq_len times so the mask is (b x 1 x n x n) ###
attention_mask = attention_mask.repeat(1,1,rand_attn.size(-2),1)
print(attention_mask)
masked_repeated = rand_attn.masked_fill(~attention_mask, float("-inf"))
print(masked_repeated)

# Both ways of shaping the mask must hide exactly the same key columns.
assert torch.equal(masked_broadcast, masked_repeated)

Method 1:
--------
tensor([[[[ True,  True,  True,  True, False, False]]]])
tensor([[[[0.9137, 0.6858, 0.8030, 0.1105,   -inf,   -inf],
          [0.5115, 0.3593, 0.1674, 0.7502,   -inf,   -inf],
          [0.6505, 0.0342, 0.9939, 0.2707,   -inf,   -inf],
          [0.1533, 0.9564, 0.7679, 0.0235,   -inf,   -inf],
          [0.5737, 0.1570, 0.0986, 0.3515,   -inf,   -inf],
          [0.7620, 0.3097, 0.0724, 0.0090,   -inf,   -inf]],

         [[0.9492, 0.2097, 0.2889, 0.2773,   -inf,   -inf],
          [0.7308, 0.7691, 0.6744, 0.4852,   -inf,   -inf],
          [0.3991, 0.8991, 0.6009, 0.3711,   -inf,   -inf],
          [0.8267, 0.2496, 0.6943, 0.1098,   -inf,   -inf],
          [0.1869, 0.0369, 0.0927, 0.2785,   -inf,   -inf],
          [0.6996, 0.4437, 0.6242, 0.7508,   -inf,   -inf]]]])
Method 2:
--------
tensor([[[[ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True, 

In [75]:
class SelfAttentionPadding(nn.Module):
    """Multi-head self-attention that can ignore padded key positions.

    Args:
        embed_dim: size of the last input dimension.
        num_heads: number of heads; must divide ``embed_dim``.
        attn_p: dropout probability applied to the attention weights.
        proj_p: dropout probability applied after the output projection.
        bias: whether the query/key/value projections have a bias term.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, attn_p=0.0, proj_p=0.0, bias=True):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(f"Embedding dimension {self.embed_dim} must be divisible by number of heads {self.num_heads}")
        self.head_dim = self.embed_dim // self.num_heads

        self.query = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.key = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.value = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x, attention_mask=None):
        """Apply self-attention, optionally hiding padded tokens from every query.

        Args:
            x: tensor of shape ``(batch_size, seq_len, embed_dim)``.
            attention_mask: optional ``(batch_size, seq_len)`` mask; truthy
                entries mark real tokens, falsy entries mark padding. Boolean
                as well as 0/1 integer or float masks are accepted.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embed_dim)``. Rows that
            belong to padded query positions are still computed; callers are
            expected to ignore them.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``embed_dim``.
        """
        batch_size, seq_len, embed_dim = x.shape
        if embed_dim != self.embed_dim:
            raise ValueError(f"Input embedding dimension {embed_dim} does not match model's expected dimension {self.embed_dim}")

        # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
        k = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)

        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            # Cast first: `~` is a logical NOT only on boolean tensors, and
            # masked_fill rejects non-boolean masks.
            # (batch, seq) -> (batch, 1, 1, seq) broadcasts over heads and queries.
            key_is_valid = attention_mask.bool().unsqueeze(1).unsqueeze(1)
            # -inf scores become exactly zero weight after the softmax.
            attn = attn.masked_fill(~key_is_valid, float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v

        # Undo the head split: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        out = self.proj(out)
        out = self.proj_drop(out)

        return out


### Variable-length batch: number of real (non-padding) tokens in each sequence ###
seq_lens = [3,5,4]
embed_dim = 9
num_heads = 3
a = SelfAttentionPadding(embed_dim=embed_dim, num_heads=num_heads)

### Random input of shape (Batch x Seq Len x Embed Dim), padded up to max(seq_lens) ###
rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

### Padding mask built from the lengths: position j of sequence i is real iff j < seq_lens[i] ###
masks = torch.arange(max(seq_lens)).unsqueeze(0) < torch.tensor(seq_lens).unsqueeze(1)
print("Attention Mask:")
print(masks)

### Run the masked multi-head attention ###
output = a(rand, attention_mask=masks)
print("Final Output:", output.shape)

Attention Mask:
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
Final Output: torch.Size([3, 5, 9])


## Causal Masking

In [76]:
seq_len = 8

# Lower-triangular mask: query i may look at keys 0..i only.
ones = torch.ones(seq_len, seq_len)
causal_mask = ones.tril().bool()
print("Causal Mask:\n", causal_mask)

# Key-padding mask for one sequence, copied onto every query row.
padding_mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]).bool()
padding_mask = padding_mask[None, :].repeat(seq_len, 1)
print("Padding Mask:\n", padding_mask)

# A key is visible only if it is both in the past/present AND a real token.
combined_mask = causal_mask & padding_mask
print("Combined Mask:\n", combined_mask)

Causal Mask:
 tensor([[ True, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True]])
Padding Mask:
 tensor([[ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, Fa

In [77]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with optional causal and key-padding masking.

    Args:
        embed_dim: size of the last input dimension.
        num_heads: number of heads; must divide ``embed_dim``.
        causal: when True, query ``i`` may only attend to keys ``0..i``.
        attn_p: dropout probability applied to the attention weights.
        proj_p: dropout probability applied after the output projection.
        bias: whether the query/key/value projections have a bias term.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, causal=True, attn_p=0.0, proj_p=0.0, bias=True):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(f"Embedding dimension {self.embed_dim} must be divisible by number of heads {self.num_heads}")
        self.head_dim = self.embed_dim // self.num_heads
        self.causal = causal

        self.query = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.key = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.value = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x, attention_mask=None):
        """Apply (optionally causal) self-attention to ``x``.

        Args:
            x: tensor of shape ``(batch_size, seq_len, embed_dim)``.
            attention_mask: optional ``(batch_size, seq_len)`` key-padding mask;
                truthy entries mark real tokens. Boolean as well as 0/1 integer
                or float masks are accepted.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embed_dim)``.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``embed_dim``.
        """
        batch_size, seq_len, embed_dim = x.shape
        if embed_dim != self.embed_dim:
            raise ValueError(f"Input embedding dimension {embed_dim} does not match model's expected dimension {self.embed_dim}")

        # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
        k = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Build one boolean "may attend" mask that broadcasts against
        # (batch, heads, query, key); None means nothing is hidden.
        allowed = None
        if attention_mask is not None:
            # Cast so that `~` below is a logical NOT even for 0/1 int/float masks.
            # (batch, seq) -> (batch, 1, 1, seq)
            allowed = attention_mask.bool().unsqueeze(1).unsqueeze(1)
        if self.causal:
            # Created directly on the input's device; entry [i, j] is True iff j <= i.
            positions = torch.arange(seq_len, device=x.device)
            causal_mask = positions.unsqueeze(0) <= positions.unsqueeze(1)
            # A key must be both a real token and not in the query's future.
            allowed = causal_mask if allowed is None else allowed & causal_mask

        if allowed is not None:
            # -inf scores become exactly zero weight after the softmax.
            attn = attn.masked_fill(~allowed, float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v

        # Undo the head split: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        out = self.proj(out)
        out = self.proj_drop(out)

        return out


### Variable-length batch: number of real (non-padding) tokens in each sequence ###
seq_lens = [3,5,4]
embed_dim = 9
num_heads = 3
a = SelfAttention(embed_dim=embed_dim, num_heads=num_heads)

### Random input of shape (Batch x Seq Len x Embed Dim), padded up to max(seq_lens) ###
rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

### Padding mask built from the lengths: position j of sequence i is real iff j < seq_lens[i] ###
masks = torch.arange(max(seq_lens)).unsqueeze(0) < torch.tensor(seq_lens).unsqueeze(1)
print("Attention Mask:")
print(masks)

### Run the causal + padding-masked multi-head attention ###
output = a(rand, attention_mask=masks)
print("Final Output:", output.shape)

Attention Mask:
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
Final Output: torch.Size([3, 5, 9])


# Cross Attention

In cross attention, the **query** comes from the decoder, while the **key** and **value** come from the encoder. This mechanism is essential in models like the Transformer, where the decoder needs to attend to the encoder's output to generate each token in the target sequence.

**Example:**  
When translating a sentence from English to French, the query is the French sentence being generated (decoder), and the key and value are the encoded representations of the English sentence (encoder).

Unlike self-attention in the decoder, **causal masking is not used in cross attention**. This is because, during translation, the decoder should be able to attend to all tokens in the encoder's output at each step, as the entire source sentence is available. But, **padding masking** is applied to ignore the padded tokens in the encoder's output, ensuring that attention is only paid to valid (non-padded) tokens.

However, **causal masking is used in the decoder's self-attention** when generating the French sentence. This ensures that each position can only attend to previous (or current) French words, preventing the model from "seeing the future" during generation.


In [78]:
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``tgt``, keys and values from ``src``.

    Args:
        embed_dim: size of the last dimension of both ``src`` and ``tgt``.
        num_heads: number of heads; must divide ``embed_dim``.
        attn_p: dropout probability applied to the attention weights.
        proj_p: dropout probability applied after the output projection.
        bias: whether the query/key/value projections have a bias term.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, attn_p=0.0, proj_p=0.0, bias=True):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(f"Embedding dimension {self.embed_dim} must be divisible by number of heads {self.num_heads}")
        self.head_dim = self.embed_dim // self.num_heads

        self.query = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.key = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.value = nn.Linear(self.embed_dim, self.embed_dim, bias = bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, src, tgt, attention_mask=None):
        """Let every ``tgt`` position attend over the ``src`` sequence.

        Args:
            src: encoder-side tensor of shape ``(batch_size, src_seq_len, embed_dim)``;
                provides keys and values.
            tgt: decoder-side tensor of shape ``(batch_size, tgt_seq_len, embed_dim)``;
                provides queries.
            attention_mask: optional ``(batch_size, src_seq_len)`` key-padding
                mask for ``src``; truthy entries mark real tokens. Boolean as
                well as 0/1 integer or float masks are accepted.

        Returns:
            Tensor of shape ``(batch_size, tgt_seq_len, embed_dim)``.

        Raises:
            ValueError: if ``src`` or ``tgt`` does not have ``embed_dim``
                features, or if their batch sizes differ.
        """
        batch_size, src_seq_len, embed_dim = src.shape
        tgt_batch_size, tgt_seq_len, tgt_embed_dim = tgt.shape
        if embed_dim != self.embed_dim:
            raise ValueError(f"Input embedding dimension {embed_dim} does not match model's expected dimension {self.embed_dim}")
        if tgt_embed_dim != self.embed_dim:
            raise ValueError(f"Target embedding dimension {tgt_embed_dim} does not match model's expected dimension {self.embed_dim}")
        if tgt_batch_size != batch_size:
            raise ValueError(f"Target batch size {tgt_batch_size} does not match source batch size {batch_size}")

        q = self.query(tgt).reshape(batch_size, tgt_seq_len, self.num_heads, self.head_dim).transpose(1,2)
        k = self.key(src).reshape(batch_size, src_seq_len, self.num_heads, self.head_dim).transpose(1,2)
        v = self.value(src).reshape(batch_size, src_seq_len, self.num_heads, self.head_dim).transpose(1,2)

        # attn.shape - (batch x num_heads x tgt_seq_len x src_seq_len)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            # Cast so that `~` is a logical NOT even for 0/1 int/float masks.
            # (batch, src_seq_len) -> (batch, 1, 1, src_seq_len) broadcasts over
            # heads and over every target (query) position.
            key_is_valid = attention_mask.bool().unsqueeze(1).unsqueeze(1)
            # -inf scores become exactly zero weight after the softmax.
            attn = attn.masked_fill(~key_is_valid, float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v

        # Undo the head split: (batch, heads, tgt_seq, head_dim) -> (batch, tgt_seq, embed)
        out = out.transpose(1, 2).reshape(batch_size, tgt_seq_len, self.embed_dim)
        out = self.proj(out)
        out = self.proj_drop(out)

        return out


### Variable-length batches: number of real tokens per English / French sequence ###
english_seq_lens = [3,5,4]
french_seq_lens = [7,6,2]

embed_dim = 18
num_heads = 3
a = CrossAttention(embed_dim=embed_dim, num_heads=num_heads)

### Random (Batch x Seq Len x Embed Dim) tensors, each padded up to its own max length ###
rand_english = torch.randn(len(english_seq_lens), max(english_seq_lens), embed_dim)
rand_french = torch.randn(len(french_seq_lens), max(french_seq_lens), embed_dim)

### Padding masks built from the lengths: position j of sequence i is real iff j < length[i] ###
english_masks = torch.arange(max(english_seq_lens)).unsqueeze(0) < torch.tensor(english_seq_lens).unsqueeze(1)
french_masks = torch.arange(max(french_seq_lens)).unsqueeze(0) < torch.tensor(french_seq_lens).unsqueeze(1)

print("English Attention Mask:")
print(english_masks)
print("French Attention Mask:")
print(french_masks)

### French queries attend over English keys/values; only the English padding is masked ###
output = a(src=rand_english, tgt=rand_french, attention_mask=english_masks)
print("Final Output:", output.shape)

English Attention Mask:
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
French Attention Mask:
tensor([[ True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True, False],
        [ True,  True, False, False, False, False, False]])
Final Output: torch.Size([3, 7, 18])


# Flash Attention

[Flash Attention](https://github.com/Dao-AILab/flash-attention) is a highly efficient implementation of the attention mechanism that fuses matrix multiplication, scaling, and softmax into a single CUDA kernel. This reduces memory overhead and speeds up computation, especially for long sequences.

PyTorch provides access to Flash Attention via [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). This function expects:

- **Queries**: `(B x H x L x E)`
- **Keys**: `(B x H x S x E)`
- **Values**: `(B x H x S x E)`
- **attn_mask**: `(B x 1 x L x S)`, where `False` indicates positions to mask.

It also has an `is_causal` flag to automatically apply causal masking.

Flash Attention supports both self-attention (`L = S`) and cross-attention (`L ≠ S`), where the queries differ from the keys/values, as we saw in the cross-attention section. The only extra step on our end is to make sure our `attn_mask` is of shape (B x 1 x L x S), which means we have to do the extra repeat, as it won't automatically broadcast along the $L$ dimension if we left it as (B x 1 x 1 x S).

By merging operations and minimizing memory transfers, Flash Attention achieves significant speedups over standard attention implementations.

In [79]:
class Attention(nn.Module):
    """Multi-head self-/cross-attention backed by ``F.scaled_dot_product_attention``.

    Args:
        embed_dim: size of the last dimension of the inputs.
        num_heads: number of heads; must divide ``embed_dim``.
        attn_drop: dropout probability on the attention weights (training mode only).
        proj_drop: dropout probability after the output projection (training mode only).
        bias: whether the four linear projections have a bias term.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, attn_drop=0.0, proj_drop=0.0, bias=True):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop

        assert self.embed_dim % self.num_heads == 0, f"Embedding dimension {self.embed_dim} must be divisible by number of heads {self.num_heads}"
        self.head_dim = self.embed_dim // self.num_heads

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)

    def forward(self,
                src,
                tgt=None,
                attention_mask=None,
                causal=False):
        """Run self-attention on ``src``, or cross-attention from ``tgt`` onto ``src``.

        Args:
            src: tensor of shape ``(batch, src_len, embed_dim)``; always provides
                keys and values, and also the queries when ``tgt`` is None.
            tgt: optional tensor of shape ``(batch, tgt_len, embed_dim)``; when
                given, it provides the queries (cross-attention).
            attention_mask: optional ``(batch, src_len)`` key-padding mask for
                ``src``; truthy entries mark real tokens.
            causal: apply a lower-triangular mask. Only honoured for
                self-attention; it is ignored when ``tgt`` is given.

        Returns:
            Tensor of shape ``(batch, query_len, embed_dim)`` where ``query_len``
            is ``tgt_len`` for cross-attention and ``src_len`` otherwise.
        """
        batch, src_len, _ = src.shape

        # Queries come from tgt in cross-attention and from src in self-attention;
        # keys and values always come from src.
        is_cross_attention = tgt is not None
        query_input = tgt if is_cross_attention else src
        query_len = query_input.shape[1]

        q = self.q_proj(query_input).reshape(batch, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(src).reshape(batch, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(src).reshape(batch, src_len, self.num_heads, self.head_dim).transpose(1, 2)

        if is_cross_attention:
            causal = False  # Causal masking is not used in cross attention

        if attention_mask is not None:
            # (batch, src_len) -> (batch, 1, query_len, src_len)
            attention_mask = attention_mask.bool()
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, 1, query_len, 1)

            if causal:
                # scaled_dot_product_attention is documented to raise when both
                # attn_mask and is_causal are set, so fold the causal constraint
                # into the explicit mask and turn the flag off.
                positions = torch.arange(src_len, device=attention_mask.device)
                causal_mask = positions.unsqueeze(0) <= positions.unsqueeze(1)  # [i, j]: j <= i
                attention_mask = attention_mask & causal_mask
                causal = False

        attention_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            # `self.training` is the mode flag; `self.train` is a method and
            # therefore always truthy, which would keep dropout on in eval mode.
            dropout_p=self.attn_drop if self.training else 0.0,
            is_causal=causal
        )

        # (batch, num_heads, query_len, head_dim) -> (batch, query_len, num_heads * head_dim)
        attention_out = attention_out.transpose(1, 2).flatten(2)
        attention_out = self.out_proj(attention_out)
        attention_out = F.dropout(attention_out, p=self.proj_drop, training=self.training)

        return attention_out


Testing Flash Attention

In [80]:
### Self-attention smoke test ###
print("TESTING SELF-ATTENTION!!!")
print("-------------------------")
seq_lens = [3,5,4]
embed_dim = 9
num_heads = 3
a = Attention(embed_dim=embed_dim, num_heads=num_heads)

### Random input of shape (Batch x Seq Len x Embed Dim), padded up to max(seq_lens) ###
rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

### Padding mask built from the lengths: position j of sequence i is real iff j < seq_lens[i] ###
masks = torch.arange(max(seq_lens)).unsqueeze(0) < torch.tensor(seq_lens).unsqueeze(1)
print("Attention Mask:")
print(masks)

### Causal self-attention with padding ###
output = a(src=rand, attention_mask=masks, causal=True)
print("Final Output:", output.shape, "\n")


### Cross-attention smoke test ###
print("TESTING CROSS-ATTENTION!!!")
print("-------------------------")
### Number of real tokens per English (source) and French (target) sequence ###
english_seq_lens = [3,5,4]
french_seq_lens = [7,6,2]

embed_dim = 9
num_heads = 3
a = Attention(embed_dim=embed_dim, num_heads=num_heads)

### Random (Batch x Seq Len x Embed Dim) tensors, each padded up to its own max length ###
rand_english = torch.randn(len(english_seq_lens), max(english_seq_lens), embed_dim)
rand_french = torch.randn(len(french_seq_lens), max(french_seq_lens), embed_dim)

### Padding masks built from the lengths, as above ###
english_masks = torch.arange(max(english_seq_lens)).unsqueeze(0) < torch.tensor(english_seq_lens).unsqueeze(1)
french_masks = torch.arange(max(french_seq_lens)).unsqueeze(0) < torch.tensor(french_seq_lens).unsqueeze(1)

print("English Attention Mask:")
print(english_masks)
print("French Attention Mask:")
print(french_masks)

### French queries attend over English keys/values; only the English padding is masked ###
output = a(src=rand_english, tgt=rand_french, attention_mask=english_masks)
print("Final Output:", output.shape)

TESTING SELF-ATTENTION!!!
-------------------------
Attention Mask:
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
Final Output: torch.Size([3, 5, 9]) 

TESTING CROSS-ATTENTION!!!
-------------------------
English Attention Mask:
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
French Attention Mask:
tensor([[ True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True, False],
        [ True,  True, False, False, False, False, False]])
Final Output: torch.Size([3, 7, 9])
