# Chapter 4 - Lab 4 - Exercise
> Author : Badr TAJINI - Large Language model (LLMs) - ESIEE 2024-2025

> Response by Paul CASCARINO E5-DSIA

# Exercise 4.1: Parameters in the feed forward versus attention module

### How do the parameter counts differ between the `feed-forward` neural network module and `multi-head attention` mechanism in our transformer architecture?**

#### 0. Setup

In [26]:
from importlib.metadata import version

import matplotlib
import tiktoken
import torch
import torch.nn as nn

GPT_CONFIG_124M = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}


class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * 
            (x + 0.044715 * torch.pow(x, 3))
        ))
    

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)
    

# Previous lab
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        #assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec



class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"], 
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x
    


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
    

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)

import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

batch = []

txt1 = "Every effort moves you"
txt2 = "Every day holds a"

batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)
print(batch)

out = model(batch)
print("Input batch:\n", batch)
print("\nOutput shape:", out.shape)
print(out)


total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])
Input batch:
 tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])

Output shape: torch.Size([2, 4, 50257])
tensor([[[ 0.1381,  0.0077, -0.1963,  ..., -0.0222, -0.1060,  0.1717],
         [ 0.3865, -0.8408, -0.6564,  ..., -0.5163,  0.2369, -0.3357],
         [ 0.6989, -0.1829, -0.1631,  ...,  0.1472, -0.6504, -0.0056],
         [-0.4290,  0.1669, -0.1258,  ...,  1.1579,  0.5303, -0.5549]],

        [[ 0.1094, -0.2894, -0.1467,  ..., -0.0557,  0.2911, -0.2824],
         [ 0.0882, -0.3552, -0.3527,  ...,  1.2930,  0.0053,  0.1898],
         [ 0.6091,  0.4702, -0.4094,  ...,  0.7688,  0.3787, -0.1974],
         [-0.0612, -0.0737,  0.4751,  ...,  1.2463, -0.3834,  0.0609]]],
       grad_fn=<UnsafeViewBackward0>)
Total number of parameters: 163,009,536


#### 1. Enumerate parameters in `feed-forward` module

The `feed-forward` module is composed of 2 linear layers and an intermediate activation function (ReLU or GeLU). We use the code below to implement that part in the lab : 

In [27]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)

As we can see, the total number of parameters ($n_{par}$) in `feed-forward` module can be obtained with : 

- First Layer : $n_{par1} = Weights + bias = emb_{dim} * 4 * emb_{dim} + 4 *emb_{dim} = 4*emb_{dim}^2+4*emb_{dim}$

- Second Layer : $n_{par2} = Weights + bias = 4 * emb_{dim} * emb_{dim} + emb_{dim} = 4*emb_{dim}^2+emb_{dim}$

- Total : $n_{par} = 8*emb_{dim}^2+5*emb_{dim}$

We can have $emb_{dim}$ with the configuration code below : 

In [28]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}

So, the total number of parameters ($n_{par}$) in `feed-forward` module is 4 722 732:

$n_{par} = 8*768^2+5*768 = 4722732$ 


#### 2. Enumerate parameters in `multi-head attention` module

The `multi-head attention` module is composed of 3 linear layers and a Output Projection layer. We use the code below to implement that part in the lab : 

In [29]:
# class MultiHeadAttention(nn.Module):
#     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):

#         #...

#         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
#         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
#         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
#         self.out_proj = nn.Linear(d_out, d_out)

#         #...

As we can see, the total number of parameters ($n_{par}$) in `multi-head attention` module can be obtained with : 

- 3 First Layer : $n_{par1} = 3*(d_{in}*d_{out} + d_{out})=3*d_{in}*d_{out}+3*d_{out}$

- Output Projection layer : $n_{par2} = Weights + bias = d_{out}*d_{out}+d_{out}=d_{out}^2+d_{out}$

- Total : $n_{par} = 3*d_{in}*d_{out}+3*d_{out} + d_{out}^2+d_{out} = d_{out}^2 + 3*d_{in}*d_{out} + 4*d_{out}$

We can have $d_{out}$ and $d_{in}$ with the code below : 

In [30]:
# class TransformerBlock(nn.Module):
#     def __init__(self, cfg):
#         super().__init__()
#         self.att = MultiHeadAttention(
#             d_in=cfg["emb_dim"],
#             d_out=cfg["emb_dim"],
#             ...


GPT_CONFIG_124M = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}

So, the total number of parameters ($n_{par}$) in `multi-head attention` module is 879 168:

$n_{par} = 768^2 + 3*768*768 + 4*768$ = 4*768^2+4*768 = 879168

#### 3. Perform comparative statistical analysis


To conclude, we have $n_{par feed-forward } = 4 722 732$ $n_{par multi-head attention } = 879 168$.

It implies that  ($n_{par}$) in our example multi-head attention` is around 5.3 times less important than ($n_{par}$) in `feed-forward` and require around 5.3 times less ressources to perform


#### 4. Interpret parametric distribution characteristics

We have : 

- $n_{par feed-forward } = 8*emb_{dim}^2+5*emb_{dim}$ whose have 
  - a quadratic term 8*emb_{dim}^2
  - a linear term 5*emb_{dim}

- $n_{par multi-head attention } = d_{out}^2 + 3*d_{in}*d_{out} + 4*d_{out}$ with $d_{out}=d_{in}=emb_{dim}$
  - a quadratic term 4*emb_{dim}^2
  - a linear term 4*emb_{dim}

The ratio is $(8*emb_{dim}^2+5*emb_{dim}) / (4*emb_{dim}^2+4*emb_{dim}) \simeq 8/4= 2$

`feed-forward` modules are more parameter-heavy, implying they require more memory and computational resources for training and inference.

# Exercise 4.2: Initialize larger GPT models

### Can you systematically scale the GPT-2 model architecture from the small configuration to medium, large, and XL variants by exclusively modifying the configuration parameters?

#### 1. GPT-2 Small (Current Implementation)


In [31]:
GPT_CONFIG_small = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}

model = GPTModel(GPT_CONFIG_small)


total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 163,009,536


$n_{tot} = vocab_{size}*context_{length} + n_{layers}*((4+8)*emb_{dim}^2+(4+5)*emb_{dim}) + vocab_{size}*context_{length}$


- With counting the of output layer : 

  $n_{tot} = 2*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 2*50257*768 + 12*(12*768^2+9*768) = 162212352  \simeq 162M$

- Without counting the of output layer : 

  $n_{tot} = 1*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 1*50257*768 + 12*(12*768^2+9*768) = 123614976  \simeq 123M$

  The number of parameter of GPT2 small is 162M counting the of output layer 123M otherwise

#### 2. GPT-2 Medium

In [32]:
GPT_CONFIG_medium= {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 1024,         # Embedding dimension
    "n_heads": 24,          # Number of attention heads
    "n_layers": 16,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}

model = GPTModel(GPT_CONFIG_medium)


total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 305,467,392


- With counting the of output layer : 

  $n_{tot} = 2*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 2*50257*1024 + 24*(12*1024^2+9*1024) = 405 137 408  \simeq 405M$

- Without counting the of output layer : 

  $n_{tot} = 1*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 1*50257*1024 + 24*(12*1024^2+9*1024) = 353 674 240  \simeq 353M$

  The number of parameter of GPT2 medium is 405M counting the of output layer 353M otherwise


#### 3. GPT-2 Large


In [33]:
GPT_CONFIG_large = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 1280,         # Embedding dimension
    "n_heads": 36,          # Number of attention heads
    "n_layers": 20,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}


model = GPTModel(GPT_CONFIG_large)


total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")


Total number of parameters: 523,443,200


- With counting the of output layer : 

  $n_{tot} = 2*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 2*50257*1280 + 36*(12*1280^2+9*1280) = 836 861 440  \simeq 836M$

- Without counting the of output layer : 

  $n_{tot} = 1*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 1*50257*1280 + 36*(12*1280^2+9*1280) = 772532480  \simeq 772M$

  The number of parameter of GPT2 large is 836M counting the of output layer 772M otherwise

#### 4. GPT-2 XL

In [34]:
GPT_CONFIG_xl = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 1600,         # Embedding dimension
    "n_heads": 48,          # Number of attention heads
    "n_layers": 25,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}


model = GPTModel(GPT_CONFIG_xl)


total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 930,864,000


- With counting the of output layer : 

  $n_{tot} = 2*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 2*50257*1600 + 48*(12*1600^2+9*1600) = 1 636 073 600  \simeq 1636M$

- Without counting the of output layer : 

  $n_{tot} = 1*vocab_{size}*context_{length} + n_{layers}*(12*emb_{dim}^2+9*emb_{dim})$

  $n_{tot small} = 1*50257*1600 + 48*(12*1600^2+9*1600) = 1555662400  \simeq 1555M$

  The number of parameter of GPT2 extra large is 1636M counting the of output layer 1555M otherwise


  To conclude :
  
  - $n{par M} > 2.5*n{par S}$ 
  - $n{par L} > 2*n{par M}$ 
  - $n{par XL} > 2*n{par L}$ 


# Exercise 4.3: Using separate dropout parameters

### How can we enhance the dropout configuration of the GPT model by implementing layer-specific dropout rates?

#### 1. Replace the monolithic `drop_rate` parameter

In the current implementation, a single dropout rate is used for the entire model via the *drop_rate* variable. 

However, this approach does not allow customization of the regularizations applied to the different components of the architecture.

We will add in our config : 

- drop_emb : This will control the dropout rate for the embedding layer.

- drop_residual : This will handle the dropout applied to the residual (shortcut) connections.

- drop_attention : This will manage the dropout used specifically in the multi-head attention layer.

In [None]:
GPT_CONFIG_124M_new = {
    "vocab_size": 50257,        # Vocabulary size
    "context_length": 1024,     # Context length
    "emb_dim": 768,             # Embedding dimension
    "n_heads": 12,              # Number of attention heads
    "n_layers": 12,             # Number of layers
    "drop_emb": 0.1,            # Embedding dropout rate
    "drop_residual": 0.1,       # Shortcut dropout rate
    "drop_attention": 0.1,      # Attention dropout rate
    "qkv_bias": False           # Query-Key-Value bias
}


#### 2. Introduce a hierarchical dropout configuration and Maintain the overall structural integrity of the model architecture

*drop_emb* applied in the embedding layers (tok_emb and pos_emb) in the GPTModel. 

We also pass he *drop_residual* and *drop_attention* parameters to the respective components in the TransformerBlock and MultiHeadAttention.

In [36]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_emb"])  # Embedding dropout

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)         # Apply embedding dropout
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


*drop_residual*: Applied to the residual (shortcut) connections in the TransformerBlock.

we need to pass the *drop_residual* parameter for the residual connections and *drop_attention* for the attention dropout.

In [37]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"], 
            dropout=cfg["drop_attention"],  # Apply attention-specific dropout
            qkv_bias=cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_residual = nn.Dropout(cfg["drop_residual"])  # Apply residual dropout

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_residual(x)  # Apply residual dropout
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_residual(x)  # Apply residual dropout
        x = x + shortcut  # Add the original input back

        return x

*drop_attention:* Applied to the attention weights in the MultiHeadAttention.

We will pass *drop_attention* parameter for the attention-specific dropout within the attention mechanism.

In [41]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # Apply attention dropout

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec