<a href="https://colab.research.google.com/github/TahaBinhuraib/Transformers_Notes/blob/main/TranformersTutorial_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
#@title Imports and downloads
%%capture
!pip install transformers
import torch.nn as nn
from transformers import AutoConfig
from transformers import AutoTokenizer
import torch
from math import sqrt
import torch.nn.functional as F


In [7]:
# Model checkpoint name and its matching tokenizer.
# model_ckpt is reused in a later cell to load the model configuration, so the
# tokenizer and the config always refer to the same checkpoint.
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

In [3]:
# The text we want to analyze
text = 'time flies like an arrow'
# We want to implement the first step of the transformer model, where we find
# Q, K, V.

# return_tensors='pt' asks for PyTorch tensors; add_special_tokens=False keeps
# the tokenizer's special tokens out of the sequence, so only the IDs of the
# five words are returned.
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
# Each token is mapped to an ID in the tokenizer vocabulary.
# We will then use an nn.Embedding layer that will transform each token ID into
# a 768-dimensional vector.
print(inputs.input_ids)

tensor([[ 2051, 10029,  2066,  2019,  8612]])


In [5]:
# Load the configuration that belongs to model_ckpt and use its sizes to build
# a fresh embedding table: one row per vocabulary entry, hidden_size columns.
# The layer is constructed here, not loaded from the checkpoint's weights.
config = AutoConfig.from_pretrained(model_ckpt)
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [8]:
# Look up one embedding vector per token ID.
# Batch size = 1, 5 tokens, each token is a 768-dimensional vector:
# (batch_size, seq_len, hidden_dim)
input_embeds = token_emb(inputs.input_ids)
input_embeds.size()

torch.Size([1, 5, 768])

In [9]:
# Simplest form of self-attention: the same embeddings are used as queries,
# keys and values (no learned projections yet).
query = key = value = input_embeds
# Size of the last (hidden) dimension; should be equal to 768
dim_k = key.size(-1)

# torch.transpose(input, dim0, dim1)
# Dimension 0 is the batch, which is why we don't transpose it; only the
# sequence and hidden dimensions (1 and 2) are swapped.
# Usually we have weight matrices W_q, W_k, W_v applied to the embeddings.
# Per batch element: (5, 768) @ (768, 5) = (5, 5), then scaled by sqrt(dim_k).
scores = torch.bmm(query, key.transpose(1,2))/sqrt(dim_k)
scores.size()

torch.Size([1, 5, 5])

In [14]:
# Sanity check on the shape of the score matrix: multiplying a matrix by its
# own transpose always produces a symmetric matrix, so z equals its transpose.
x = torch.rand(3,5)
z = x.mm(x.t())
print(z)

tensor([[3.1543, 1.4546, 1.5078],
        [1.4546, 0.7140, 0.7770],
        [1.5078, 0.7770, 1.1728]])


In [19]:
# Now we normalize the scores with softmax so that each row of attention
# weights sums to one.
# dim=-1 is the last dimension; it runs over the keys that each query attends
# to, so every query gets its own distribution over the sequence.
weights = F.softmax(scores, dim = -1)

# Now we will multiply the attention weights by the values.
# The aim is to get back to the dimensionality of the input.
attention_out = torch.bmm(weights, value)
attention_out.shape

torch.Size([1, 5, 768])

In [21]:
# Let's turn all this into a function we can use later

def scaled_attn_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(dim_k)) @ v.

  Args:
    q: queries, 3-D tensor (batch, seq_q, dim_k); torch.bmm only accepts 3-D input.
    k: keys, 3-D tensor (batch, seq_k, dim_k).
    v: values, 3-D tensor (batch, seq_k, dim_v).

  Returns:
    Tensor of shape (batch, seq_q, dim_v): for every query position, the
    attention-weighted average of the value vectors.
  """
  # Scale by the feature size of the tensors that were actually passed in.
  # Reading a notebook-level variable here would apply the wrong scale as soon
  # as the inputs are projected to a smaller head dimension.
  dim_k = k.size(-1)
  scores = torch.bmm(q, k.transpose(1, 2)) / sqrt(dim_k)
  # Softmax over the last dimension: each query's weights over the keys sum to one.
  weights = F.softmax(scores, dim=-1)
  out = torch.bmm(weights, v)
  return out

In [22]:
# multiheaded attention will allow the self-attention layer to focus on different 
# semantic aspects of the sequence
# We will apply three independent linear transformations to each embedding(q, k, v)
# They will all carry their own set of learnable parameters
# Don't forget nn.Linear only applies: y = torch.mm(x,transpose(A)) + b
# The A matrix will contain parameters we can learn!

class attention_head(nn.Module):
  """One attention head: project the input to q, k, v and attend.

  Each of the three projections is its own nn.Linear(embed_dim, head_dim),
  so each carries its own learnable weight and bias.
  """

  def __init__(self, embed_dim, head_dim):
    super().__init__()
    # Created in q, k, v order; the attribute names become the parameter names.
    self.q = nn.Linear(embed_dim, head_dim)
    self.k = nn.Linear(embed_dim, head_dim)
    self.v = nn.Linear(embed_dim, head_dim)

  def forward(self, hidden_state):
    """Project hidden_state three ways and return the attention output."""
    queries = self.q(hidden_state)
    keys = self.k(hidden_state)
    values = self.v(hidden_state)
    return scaled_attn_dot_product(queries, keys, values)




In [50]:
class multiheaded_attention(nn.Module):
  """Multi-head self-attention built from independent attention_head modules.

  The hidden size is split evenly across the heads, every head attends over
  the same input, the head outputs are concatenated along the last dimension
  and a final linear layer mixes them.

  Args:
    config: object exposing `hidden_size` and `num_attention_heads`.

  Raises:
    ValueError: if `num_attention_heads` is not positive or does not divide
      `hidden_size` evenly.
  """

  def __init__(self, config):
    super().__init__()
    embed_dims = config.hidden_size
    num_heads = config.num_attention_heads
    # Fail early: with an uneven split the concatenated head outputs would have
    # num_heads * head_dim != embed_dims features and out_linear would only
    # fail later, inside forward, with a shape-mismatch error.
    if num_heads <= 0 or embed_dims % num_heads != 0:
      raise ValueError(
        f"hidden_size ({embed_dims}) must be a positive multiple of "
        f"num_attention_heads ({num_heads})")
    # In BERT they used 768/12
    head_dim = embed_dims // num_heads
    self.heads = nn.ModuleList([attention_head(embed_dims, head_dim)
      for _ in range(num_heads)])

    self.out_linear = nn.Linear(embed_dims, embed_dims)

  def forward(self, hidden_state):
    """Run every head on hidden_state, concatenate, then apply out_linear."""
    x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
    x = self.out_linear(x)
    return x


In [51]:
# Keep the instance under its own name. Assigning it to `multiheaded_attention`
# would overwrite the class, and running this cell a second time would then
# call the instance with `config` instead of building a new layer.
multihead_attn = multiheaded_attention(config)
attn_out = multihead_attn(input_embeds)
# 5 tokens each with a dimensionality of config.hidden_size
attn_out.size()

torch.Size([1, 5, 768])