<a href="https://colab.research.google.com/github/Pythonista7/deeply-learning/blob/main/Transofrmer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building a transformer


In [2]:
import torch
import torch.nn as nn

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook cell displays the selected device

device(type='cpu')

## Multi Head Attention
Starting off with the core matter at hand, let's implement the attention block.

In [4]:
# @title
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Each of the Q/K/V projections is a single ``[d_model, H * d_head]`` linear
    layer. Think of it as H blocks concatenated horizontally,
    ``W_Q = [W_Q(1) | W_Q(2) | ... | W_Q(H)]``, where each ``W_Q(i)`` has shape
    ``[d_model, d_head]``.

    Args:
        d_model: Model / embedding width.
        heads: Number of attention heads; must divide ``d_model``.
        device: Optional device on which to create the parameters. ``None``
            (the default) uses PyTorch's default device, so the module can be
            moved with ``.to(...)`` like any other ``nn.Module``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(self, d_model, heads=6, device=None):
        super().__init__()
        if d_model % heads != 0:
            raise ValueError("d_model must be divisible by heads")

        self.d_model = d_model
        self.heads = heads
        self.d_head = d_model // heads

        self.Q_linear_projection_layer = nn.Linear(d_model, d_model, device=device)
        self.K_linear_projection_layer = nn.Linear(d_model, d_model, device=device)
        self.V_linear_projection_layer = nn.Linear(d_model, d_model, device=device)

        self.softmax = nn.Softmax(dim=-1)
        self.W_o = nn.Linear(d_model, d_model, device=device)

    def _split_heads(self, projected):
        """Reshape ``[B, T, H * d_head]`` into per-head ``[B, H, T, d_head]``."""
        batch, seq_len, _ = projected.shape
        return projected.reshape(batch, seq_len, self.heads, self.d_head).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: Queries of shape ``[B, T_q, d_model]``.
            K: Keys of shape ``[B, T_k, d_model]``.
            V: Values of shape ``[B, T_k, d_model]``.
            mask: Optional keep-mask broadcastable to ``[B, H, T_q, T_k]``.
                Non-zero (or ``True``) entries are kept and zero (or ``False``)
                entries are hidden. A causal mask can be generated with
                ``torch.tril(torch.ones(T, T))``.

        Returns:
            Tensor of shape ``[B, T_q, d_model]``.

        Raises:
            ValueError: If ``K`` and ``V`` have different sequence lengths.
        """
        if K.shape[1] != V.shape[1]:
            raise ValueError("K and V must have the same sequence length")

        # Linear projections, then split the last dim into heads: [B, H, T, d_head]
        q = self._split_heads(self.Q_linear_projection_layer(Q))
        k = self._split_heads(self.K_linear_projection_layer(K))
        v = self._split_heads(self.V_linear_projection_layer(V))

        # T_q and T_k are tracked separately so cross attention works too.
        batch, _, t_q, _ = q.shape
        t_k = k.shape[2]

        # Scaled attention scores: [B, H, T_q, T_k]
        scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        assert scores.shape == (batch, self.heads, t_q, t_k)

        if mask is not None:
            # Hidden positions get -inf so softmax sends their weight to 0.
            # Filling the scores directly works for bool, int and float masks.
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = self.softmax(scores)

        # This catches a wrong softmax dim and mask weirdness instantly: every
        # row of attention weights must sum to (approximately) 1.
        row_sums = weights.sum(dim=-1)
        assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4, rtol=1e-4)

        context = weights @ v  # [B, H, T_q, d_head]

        # Concat heads: the reverse of the head split done before attention.
        context = context.transpose(1, 2).reshape(batch, t_q, self.d_model)

        # Output projection
        return self.W_o(context)


In [5]:
mha = MultiHeadAttention(840)

In [6]:
random_input = torch.rand((2,10,840))

In [7]:
# Call the module itself rather than .forward() so registered hooks are run.
res = mha(random_input, random_input, random_input)

In [8]:
res.shape

torch.Size([2, 10, 840])

Now that we have a basic attention block set up, let's focus on a couple of other fundamental blocks required to piece together an Encoder and a Decoder! Next we will require two things: 1. an embedding layer and 2. a positional encoder. `1.` tells the attention block "what" and `2.` conveys "where" in the sequence, both of which are crucial.

## Positional Encoding
We need to calculate a P.E. value of dims [T, d_model] as per the paper and add it to the input embedding. Broadcasting automatically handles adding the PE across the batch, since the PE for a given position is the same for every sequence in the batch.

$PE_{(pos,2i)} = sin(pos/10000^{2i/d_{model}} )$

$PE_{(pos,2i+1)} = cos(pos/10000^{2i/d_{model}} )$

where `pos` is the position and `i` is the dimension

Notice that the input angles to both sin and cos are the same, so let's generate those first and then apply sin and cos alternately across the range.

In [9]:
def get_positional_encoding(T, d_model, device=None):
  """
  Build the sinusoidal positional-encoding table:

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  T: sequence length
  d_model: model dimensions, also same as embedding dims (may be odd)
  device: optional device for the result; None uses PyTorch's default device

  Returns a float tensor of shape [T, d_model].
  """
  # A lot of things from this function can be precomputed, stored and reused for better performance.
  # Every tensor is created on the same device so the arithmetic never mixes devices.
  even_dims = torch.arange(0, d_model, 2, dtype=torch.float32, device=device)  # the "2i" values
  # 1 / 10000^(2i/d_model), written as exp(-log(10000) * 2i / d_model) so it can be multiplied instead of divided
  inv_freq = torch.exp(-torch.log(torch.tensor(10000.0, device=device)) * even_dims / d_model).unsqueeze(0)  # [1, ceil(d_model/2)]
  pos = torch.arange(0, T, 1, dtype=torch.float32, device=device).unsqueeze(1)  # [T, 1]
  angles = pos * inv_freq  # [T, ceil(d_model/2)]

  positional_encodings = torch.zeros((T, d_model), device=device)
  positional_encodings[:, 0::2] = torch.sin(angles)
  # For odd d_model there is one fewer odd column than even columns, so trim the angles to fit.
  positional_encodings[:, 1::2] = torch.cos(angles[:, : d_model // 2])
  return positional_encodings

In [10]:
# Dummy [T=10, d_model=840] data; not used by the call below.
test_pos_enc_data = torch.rand((10,840))
# Display the encoding table for 10 positions and 840 dimensions.
get_positional_encoding(10,840)

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2955e-01,  ...,  1.0000e+00,
          1.0222e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.2649e-01,  ...,  1.0000e+00,
          2.0443e-04,  1.0000e+00],
        ...,
        [ 6.5699e-01,  7.5390e-01,  5.3540e-01,  ...,  1.0000e+00,
          7.1552e-04,  1.0000e+00],
        [ 9.8936e-01, -1.4550e-01,  9.9962e-01,  ...,  1.0000e+00,
          8.1774e-04,  1.0000e+00],
        [ 4.1212e-01, -9.1113e-01,  5.8103e-01,  ...,  1.0000e+00,
          9.1995e-04,  1.0000e+00]])

Since the positional encodings are the same for every input, they can be cached and reused instead of being recomputed for each token. We can use `register_buffer` as a device-aware cache: it stores a tensor on the module, outside the set of trainable parameters, to hold such values.

In [11]:
class PositionalEncoder(nn.Module):
  """
  Adds fixed sinusoidal positional encodings to a batch of embeddings.

  The [T, d_model] table is computed once in __init__ and stored with
  register_buffer under the name `pos_enc`, so it is not a trainable
  parameter but still moves with the module on `.to(...)`.

  T: maximum sequence length supported by this encoder
  d_model: model dimensions, also same as embedding dims
  kwargs: forwarded unchanged to nn.Module.__init__
  """
  def __init__(self, T, d_model, **kwargs) -> None:
    super().__init__(**kwargs)
    self.sequence_len = T
    self.d_model = d_model

    pe = self.get_positional_encoding(self.sequence_len, self.d_model)
    self.register_buffer('pos_enc', pe)

  def get_positional_encoding(self, T, d_model):
    """
    Build the sinusoidal table PE[pos, 2i] = sin(angle), PE[pos, 2i+1] = cos(angle)
    with angle = pos / 10000^(2i / d_model).

    T: sequence length
    d_model: model dimensions, also same as embedding dims (may be odd)

    Returns a float tensor of shape [T, d_model] on PyTorch's default device;
    once registered as a buffer it follows the module to other devices.
    """
    even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)  # the "2i" values
    # 1 / 10000^(2i/d_model), written as exp(-log(10000) * 2i / d_model) so it can be multiplied instead of divided
    inv_freq = torch.exp(-torch.log(torch.tensor(10000.0)) * even_dims / d_model).unsqueeze(0)  # [1, ceil(d_model/2)]
    pos = torch.arange(0, T, 1, dtype=torch.float32).unsqueeze(1)  # [T, 1]
    angles = pos * inv_freq  # [T, ceil(d_model/2)]

    positional_encodings = torch.zeros((T, d_model))
    positional_encodings[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so trim the angles to fit.
    positional_encodings[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return positional_encodings

  def forward(self, X):
    """
    X: embeddings of shape [B, T_in, d_model] with T_in <= the configured T

    Returns X plus the first T_in rows of the table (broadcast over the batch).
    """
    input_seq_len = X.shape[1]
    assert input_seq_len <= self.sequence_len , f"Input sequence length {input_seq_len} is greater than positional encoder sequence length {self.sequence_len}"
    return X + self.pos_enc[:input_seq_len, :]

In [12]:
test_pos_enc_data = torch.rand((4,10,840))

In [13]:
# Build an encoder for up to 10 positions of width 840 and apply it to the dummy batch.
test_pe = PositionalEncoder(10,840)
test_pe(test_pos_enc_data)

tensor([[[ 0.7473,  1.3762,  0.2759,  ...,  1.4673,  0.0939,  1.8608],
         [ 0.8826,  1.4445,  1.7402,  ...,  1.4413,  0.2259,  1.8707],
         [ 1.0566,  0.0261,  0.9311,  ...,  1.4470,  0.8691,  1.4051],
         ...,
         [ 0.9384,  1.2002,  0.6811,  ...,  1.9732,  0.3912,  1.7361],
         [ 1.9576,  0.2992,  1.6011,  ...,  1.3092,  0.0315,  1.9008],
         [ 0.7229, -0.3882,  1.5355,  ...,  1.5789,  0.5343,  1.9023]],

        [[ 0.4694,  1.9663,  0.7528,  ...,  1.9207,  0.2482,  1.0108],
         [ 1.0322,  0.7881,  1.2624,  ...,  1.9114,  0.6086,  1.8982],
         [ 1.0248, -0.1682,  1.7162,  ...,  1.4582,  0.9976,  1.2365],
         ...,
         [ 1.0778,  1.7237,  0.7035,  ...,  1.2487,  0.2576,  1.0321],
         [ 1.0873,  0.5317,  1.6854,  ...,  1.5681,  0.5731,  1.1959],
         [ 1.1321, -0.8077,  0.8261,  ...,  1.4930,  0.0783,  1.0357]],

        [[ 0.0038,  1.4287,  0.1158,  ...,  1.8288,  0.4352,  1.6492],
         [ 1.5294,  0.8382,  0.9557,  ...,  1

## Token Embedding
What this does is take the token input and convert it into representational embeddings, encoding per-token data that includes both "what the token is" (via the embedding layer) and "where the token is" (via the positional encoding).

In [22]:
class InputEmbedding(nn.Module):
  """
  Turns token ids into position-aware embeddings.

  Looks each id up in an nn.Embedding table, scales the result by
  sqrt(d_model), and passes it through a PositionalEncoder.

  vocab_size: number of rows in the embedding table
  d_model: embedding width
  max_seq_len: sequence length handed to the PositionalEncoder
  """
  def __init__(self, vocab_size, d_model, max_seq_len=1024):
    super().__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.embedding_layer = nn.Embedding(self.vocab_size, self.d_model)
    self.pos_encoder = PositionalEncoder(max_seq_len, self.d_model)

  def forward(self, X):
    """
    X: integer token ids of shape (batch_size, seq_len) with values in range [0, vocab_size)

    Returns a tensor of shape (batch_size, seq_len, d_model).
    """
    # Scaling by sqrt(d_model) as suggested in the paper. A plain Python float is
    # used so nothing is allocated per call and the scale works on any device.
    scale = self.d_model ** 0.5
    input_embedding = self.embedding_layer(X) * scale
    return self.pos_encoder(input_embedding)


In [24]:
# A small test to see if this works
B, T = 2, 5
vocab_size = 100
d_model = 32

# Random token ids; torch.randint yields int64 ("long") by default.
X = torch.randint(0, vocab_size, (B, T))   # must be long
embed = InputEmbedding(vocab_size, d_model, max_seq_len=10)

Y = embed(X)

# Expect X to be [B, T] integer ids and Y to be [B, T, d_model] floats.
print("X:", X.shape, X.dtype)
print("Y:", Y.shape, Y.dtype)


X: torch.Size([2, 5]) torch.int64
Y: torch.Size([2, 5, 32]) torch.float32


## Encoder


In [31]:
class Encoder(nn.Module):
  """
  A single encoder layer: self-attention -> add & norm -> feed-forward -> add & norm.

  The input embedding is generated outside the Encoder since we want to "stack"
  up encoders, i.e. feed the output of one layer into the next, repeatedly,
  with new learnable parameters each time.

  d_model: model / embedding width
  vocab_size: accepted for signature compatibility; not used by this layer
  no_of_heads: number of attention heads passed to MultiHeadAttention
  max_seq_len: accepted for signature compatibility; not used by this layer
  d_ff: hidden width of the position-wise feed-forward network (default 2048)
  """
  def __init__(self, d_model, vocab_size, no_of_heads, max_seq_len=1024, d_ff=2048):
    super().__init__()
    self.d_ff = d_ff
    self.MHA = MultiHeadAttention(d_model=d_model, heads=no_of_heads)
    self.norm1 = nn.LayerNorm(d_model)
    self.linear1 = nn.Linear(d_model, self.d_ff)
    self.relu1 = nn.ReLU()
    self.linear2 = nn.Linear(self.d_ff, d_model)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self, X):
    """
    X: input of dims [B, T, d_model]

    Returns a tensor with the same dims as X.
    """
    # Self-attention sub-layer with a residual skip connection, then LayerNorm
    attn = self.MHA(Q=X, K=X, V=X)
    attn_out = self.norm1(X + attn)

    # Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear
    hidden = self.relu1(self.linear1(attn_out))
    ff_out = self.linear2(hidden)

    # Second residual skip connection, then LayerNorm
    return self.norm2(ff_out + attn_out)

## Decoder


In [None]:
class Decoder(nn.Module):
  """Placeholder for the decoder; it has no layers and no logic yet."""

  def __init__(self) -> None:
    super().__init__()

  def forward(self, encoder_output):
    """Accept the encoder output, ignore it, and return None."""
    return None