In [0]:
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torch import nn
import torch
from math import sqrt
import transformers
from typing import List, Tuple



In [0]:
model_ckpt = "bert-base-uncased"

In [0]:

def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """
    Applies scaled dot product attention mechanism.

    Args:
        query (torch.Tensor): Query tensor with shape `(batch_size, query_len, hidden_dim)`.
        key (torch.Tensor): Key tensor with shape `(batch_size, key_len, hidden_dim)`.
        value (torch.Tensor): Value tensor with shape `(batch_size, value_len, hidden_dim)`.

    Returns:
        torch.Tensor: Attention output tensor with shape `(batch_size, query_len, hidden_dim)`.
    """
    # Get the dimension of the key tensor which the number of laten factors (features) in our case 768
    dim_k = query.size(-1)

    # Calculate the scaled dot product of the query and key tensors, then scale it
    attention_scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)

    # Apply softmax to the attention scores
    weights = torch.nn.functional.softmax(attention_scores, dim=-1)

    # Multiply the value tensor with the weights tensor to get the updated version of your inputs, but now it's included the potential encoder + embedding itself,
    # which means you have now contextual information in your embedding
    attention_output = torch.bmm(weights, value)

    return attention_output


In [0]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        """
        Initializes the AttentionHead module.

        Args:
            embed_dim: The dimensionality of the input embeddings.
            head_dim: The dimensionality of the projected query, key, and value tensors.
        """
        super().__init__()
        # These three projection layers are learned parameters of the model, 
        # initialized randomly, and trained to fit the data during the training process.
       #  Each of these projection layers takes as input a tensor with dimension embed_dim and outputs a tensor with dimension head_dim.
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)
        # print('=========================== Query looks like =========================')
        # print(self.q)
    def forward(self, hidden_state):
        """
        Computes the attention output of the module given the input embeddings.

        Args:
            hidden_state: A tensor of shape (batch_size, sequence_length, embed_dim) representing the input embeddings.

        Returns:
            A tensor of shape (batch_size, sequence_length, head_dim) representing the attention output.
        """
        # Compute query, key, and value tensors using the projection layers
        # hidden state is our input_embeds of (batch_size, sequence_length, embed_dim)
        # But the linear projection will map this hidden_state of size (batch_size, sequence_length, embed_dim) into (batch_size, sequence_length, head_dim)
        query = self.q(hidden_state)
        key   = self.k(hidden_state)
        value = self.v(hidden_state)

        # Uncomment to see the size is (batch_size, sequence_length, head_dim) because of the projection
        # print(query.size())

        # Compute attention using the query, key, and value tensors
        # Here is the return was before [batch_size, sequence_length, embed_dim], but because of we project this into head_dim it will be [batch_size, sequence_length, head_dim]
        attn_outputs = scaled_dot_product_attention(query, key, value)

        # is the same as hidden_state except embed_dim is now the projection size.
        # In practice this projection size is multiple of the embed_dim which head_dim = embed_dim / number of heads
        # So at the end we will have same embed_dim number but its distributed across different heads
        # For example, BERT has 12 attention heads, so the dimension of each head is 768/12 = 64 for each head
        print(attn_outputs.size())
        return attn_outputs 


In [0]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention layer.

    Args:
    - config: A configuration object with the following attributes:
        * hidden_size: The hidden size of the input tensors.
        * num_attention_heads: The number of attention heads to use.
    
    Attributes:
    - heads: A list of `AttentionHead` instances, each representing a single attention head.
    - output_linear: A linear layer applied to the concatenated outputs of all the attention heads.
    """
    def __init__(self, config):
        super().__init__()
        # Calculate the dimension of each attention head
        embed_dim = config.hidden_size # for beart 768
        num_heads = config.num_attention_heads # for bert 12 attention head
        head_dim = embed_dim // num_heads # which we project the embed_dim to head_dim so from 768 into 64 

        # Create a list of AttentionHead instances, each of these heads are 3 indpendent layers of Q, K, V, each takes embed_dim and return head_dim
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        print(len(self.heads))
        print('='*50)
        print(self.heads)

        # Linear layer applied to the concatenated outputs of all the attention heads
        self.output_linear = nn.Linear(embed_dim, embed_dim)


    def forward(self, hidden_state):
        """
        Apply multi-head attention to the input tensor.

        Args:
        - hidden_state: A tensor of shape `(batch_size, seq_len, hidden_size)` representing the input.

        Returns:
        - A tensor of shape `(batch_size, seq_len, hidden_size)` representing the output of the layer.
        """
        
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        print('='*50)
        print(x.size())
        print(type(x))
        print('='*50)
        # return back to my original input dims
        x = self.output_linear(x)
        print(x.size())
        print(type(x))
        print('='*50)
        return x


In [0]:
model_config = AutoConfig.from_pretrained(model_ckpt)
multihead_attn = MultiHeadAttention(model_config)

12
ModuleList(
  (0): AttentionHead(
    (q): Linear(in_features=768, out_features=64, bias=True)
    (k): Linear(in_features=768, out_features=64, bias=True)
    (v): Linear(in_features=768, out_features=64, bias=True)
  )
  (1): AttentionHead(
    (q): Linear(in_features=768, out_features=64, bias=True)
    (k): Linear(in_features=768, out_features=64, bias=True)
    (v): Linear(in_features=768, out_features=64, bias=True)
  )
  (2): AttentionHead(
    (q): Linear(in_features=768, out_features=64, bias=True)
    (k): Linear(in_features=768, out_features=64, bias=True)
    (v): Linear(in_features=768, out_features=64, bias=True)
  )
  (3): AttentionHead(
    (q): Linear(in_features=768, out_features=64, bias=True)
    (k): Linear(in_features=768, out_features=64, bias=True)
    (v): Linear(in_features=768, out_features=64, bias=True)
  )
  (4): AttentionHead(
    (q): Linear(in_features=768, out_features=64, bias=True)
    (k): Linear(in_features=768, out_features=64, bias=True)
    (

In [0]:
text = ["time flies like an arrow and tokenizer"]
# Call the tokenizer related to this model
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# send the tokenizer and text to get the input_ids and attention_mask from the returned inputs
tokenized_texts = tokenized_texts = tokenizer.batch_encode_plus(
        text,
        return_tensors="pt",  # Return PyTorch tensors
        add_special_tokens=False  # Add special tokens if required
    )


In [0]:
model = AutoModel.from_pretrained(model_ckpt)

input_ids = tokenized_texts["input_ids"]
attention_masks = tokenized_texts["attention_mask"]



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [0]:

with torch.no_grad():
        inputs_embeds = model(input_ids)[0]

attn_output = multihead_attn(inputs_embeds)
attn_output.size()
    

torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 64])
torch.Size([1, 8, 768])
<class 'torch.Tensor'>
torch.Size([1, 8, 768])
<class 'torch.Tensor'>
Out[23]: torch.Size([1, 8, 768])