<a href="https://colab.research.google.com/github/JohnPrabhasith/Attention_Transformers/blob/main/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

***Attention*** is an essential component of neural network Transformers, which are driving the current excitement in Large Language Models and AI. Specifically, a Decoder-Only Transformer, illustrated below, is the foundation for the popular model **ChatGPT**.

In [None]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Queries, keys and values are all projected from the same input by three
    bias-free linear layers, each mapping ``d_model`` features to ``d_model``
    features.

    Args:
        d_model: Number of features per token encoding (input and output
            size of every projection).
        row_dim: First of the two dimensions swapped when transposing the
            keys.
        col_dim: Second of the two dimensions swapped when transposing the
            keys; also the dimension whose size sets the scaling factor and
            along which the softmax is taken.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # Creation order (query, key, value) is kept so that a fixed RNG seed
        # yields the same initial weights.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encoding, mask=None):
        """Compute attention values for ``token_encoding``.

        Args:
            token_encoding: Tensor whose last dimension has ``d_model``
                features (the demo in this file passes a 2-D tensor with one
                row per token).
            mask: Optional boolean tensor; positions that are ``True`` are
                overwritten with ``-1e9`` before the softmax so they receive
                effectively zero weight.

        Returns:
            The softmax-weighted combination of the value projections.
        """
        queries = self.W_q(token_encoding)
        keys = self.W_k(token_encoding)
        values = self.W_v(token_encoding)

        # Similarity of every query with every key, scaled by the square root
        # of the key size along ``col_dim``.
        keys_t = torch.transpose(keys, self.row_dim, self.col_dim)
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = (queries @ keys_t) / scale

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        weights = F.softmax(scores, dim=self.col_dim)
        return weights @ values


In [None]:
# Toy input: three token encodings (rows), each with two features (columns).
encoding_matrix = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

In [None]:
# Seed PyTorch's RNG so the randomly initialised layer weights are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x7e62bce9d170>

In [None]:
# Build the module for 2-feature encodings; dims 0 and 1 are the two dims it transposes and softmaxes over.
masked_self_attention = MaskedSelfAttention(d_model = 2, row_dim = 0, col_dim = 1)

In [None]:
# Causal mask for three tokens: True strictly above the diagonal marks the
# positions that get masked out.
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()

In [None]:
# Display the mask; True entries are the positions that will be filled with -1e9 before the softmax.
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [None]:
# Apply masked self-attention to the toy encodings using the causal mask built above.
masked_self_attention(encoding_matrix, mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

In [None]:
# Inspect the query projection weights, transposed (dims 0 and 1 swapped) for display.
masked_self_attention.W_q.weight.transpose(0,1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [None]:
## Encoder-Decoder Attention

In [None]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with separate q/k/v inputs.

    Unlike self-attention, the queries, keys and values may come from
    different encodings. Each is projected by its own bias-free linear layer
    mapping ``d_model`` features to ``d_model`` features.

    Args:
        d_model: Number of features per token encoding.
        row_dim: First of the two dimensions swapped when transposing the
            keys.
        col_dim: Second of the two dimensions swapped when transposing the
            keys; also the dimension whose size sets the scaling factor and
            along which the softmax is taken.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # Creation order (query, key, value) is kept so that a fixed RNG seed
        # yields the same initial weights.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encoding_for_q, encoding_for_v, encoding_for_k, mask=None):
        """Compute attention values.

        Note the positional order of the inputs: query, then value, then key.

        Args:
            encoding_for_q: Encodings projected into queries.
            encoding_for_v: Encodings projected into values.
            encoding_for_k: Encodings projected into keys.
            mask: Optional boolean tensor; positions that are ``True`` are
                overwritten with ``-1e9`` before the softmax so they receive
                effectively zero weight.

        Returns:
            The softmax-weighted combination of the value projections.
        """
        queries = self.W_q(encoding_for_q)
        keys = self.W_k(encoding_for_k)
        values = self.W_v(encoding_for_v)

        # Query/key similarities, scaled by the square root of the key size
        # along ``col_dim``.
        keys_t = torch.transpose(keys, self.row_dim, self.col_dim)
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = (queries @ keys_t) / scale

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        weights = F.softmax(scores, dim=self.col_dim)
        return weights @ values


In [None]:
# Toy inputs: the query, value and key encodings hold identical numbers here,
# but each is its own independent tensor.
encoding_for_q = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])
encoding_for_v = encoding_for_q.clone()
encoding_for_k = encoding_for_q.clone()

# Seed the RNG so the randomly initialised layer weights are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x7e62bce9d170>

In [None]:
# Build the attention module for 2-feature encodings; dims 0 and 1 are the two dims it transposes and softmaxes over.
attention = Attention(d_model = 2, row_dim = 0, col_dim = 1)

In [None]:
# Run unmasked attention; note the positional order expected by forward(): query, value, key.
attention(encoding_for_q, encoding_for_v, encoding_for_k)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [None]:
## MultiHead_Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Run several independent ``Attention`` heads and concatenate the results.

    Every head is constructed with the same ``d_model``, ``row_dim`` and
    ``col_dim`` and receives the same inputs; the per-head outputs are joined
    along ``col_dim``.

    Args:
        d_model: Number of features per token encoding, passed to each head.
        row_dim: Row dimension index, passed to each head.
        col_dim: Column dimension index, passed to each head and used as the
            concatenation dimension.
        num_heads: Number of attention heads to create.
    """

    def __init__(self, d_model = 2, row_dim = 0, col_dim = 1, num_heads = 1):
        super().__init__()
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim) for _ in range(num_heads)]
        )

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encoding_for_q, encoding_for_v, emcoding_for_k, mask = None):
        """Apply every head to the inputs and concatenate along ``col_dim``.

        Note the positional order of the inputs: query, then value, then key.
        The spelling ``emcoding_for_k`` is kept so existing keyword callers
        continue to work.

        Args:
            encoding_for_q: Encodings used for the queries.
            encoding_for_v: Encodings used for the values.
            emcoding_for_k: Encodings used for the keys.
            mask: Optional mask forwarded unchanged to every head.

        Returns:
            The head outputs concatenated along ``col_dim``.
        """
        # Forward the mask to each head; previously it was accepted here but
        # dropped, so masking had no effect in multi-head attention.
        head_outputs = [
            head(encoding_for_q, encoding_for_v, emcoding_for_k, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)


In [None]:
# Re-seed before constructing the module so its randomly initialised weights are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x7e62bce9d170>

In [None]:
# Single-head configuration for 2-feature encodings, spelled out with keywords.
multi_head_attention = MultiHeadAttention(
    d_model=2,
    row_dim=0,
    col_dim=1,
    num_heads=1,
)

In [None]:
# Run multi-head attention; arguments are passed in the order query, value, key.
multi_head_attention(encoding_for_q, encoding_for_v, encoding_for_k)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)