Transformer from scratch using PyTorch

*torch* is the core PyTorch library: tensors, models, autograd, etc.

*torchvision* and *torchaudio* can be imported for image and audio processing, respectively.

In [1]:
!pip3 install torch tor


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

Multi-Head Attention Class

*Self-attention finds the relations among the words in a sequence. Multi-head self-attention does the same in several ways at once, so it can capture different kinds of information such as grammatical structure and semantic meaning.*

1. word -> token -> q,k,v

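  A word is split into tokens (the smallest units of text), and each token is mapped to a vector. Each vector is then passed through three separate learned linear layers to produce a query, a key, and a value. The query is what a token is looking for, the key is what a token offers to be matched against, and the value is the actual data passed along.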
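2. split heads -> what is a head and why use heads?

  A head is one attention unit that works on its own dim_model / num_heads slice of the projected vectors. Each head can learn to focus on a different aspect of the sentence, so the number of heads is the number of different ways the relationships between tokens are computed in parallel.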
3. masking

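  Masking blocks attention to specific tokens (e.g. padding or future tokens) by setting their scores to -inf (in practice a very large negative number such as -1e9, as in the code below) before softmax, so they receive zero probability.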
4. softmax?

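  Softmax converts the attention scores (the scaled query-key dot products) into probabilities that are non-negative and sum to 1 over the keys. These probabilities are the weights used to average the values.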
5. what is linear projection here?

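  A linear projection (nn.Linear) is a learned weight matrix plus a bias. Here, three of them map the input vectors to queries, keys, and values, and a fourth maps the combined heads back to the desired output dimension (dim_model in this notebook).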
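The steps above (split heads, score, mask, softmax) can be traced on random tensors. This is a minimal standalone sketch, not the class below: it skips the learned projections (the tensor `x` stands in for already-projected queries, keys, and values), uses a causal mask as the example mask, and the helper names `split_heads` and `attention` are illustrative.

```python
import math
import torch

torch.manual_seed(0)
batch, seq_len, dim_model, num_heads = 2, 4, 8, 2
dim_head = dim_model // num_heads  # 4 features per head

def split_heads(x):
    # (B, L, dim_model) -> (B, num_heads, L, dim_head): each head gets its own slice
    b, l, _ = x.shape
    return x.view(b, l, num_heads, dim_head).transpose(1, 2)

def attention(q, k, v, mask=None):
    # one score per (query, key) pair in every head: (B, H, L, L)
    scores = q @ k.transpose(-2, -1) / math.sqrt(dim_head)
    if mask is not None:
        # blocked positions get -inf, so softmax gives them probability 0
        scores = scores.masked_fill(mask == 0, float("-inf"))
    probs = torch.softmax(scores, dim=-1)  # each row sums to 1
    return probs @ v, probs               # weighted average of the values

x = torch.randn(batch, seq_len, dim_model)
q = k = v = split_heads(x)  # stand-in for the projected queries, keys, values
causal = torch.tril(torch.ones(seq_len, seq_len))  # 1 = may attend, 0 = blocked

out, probs = attention(q, k, v, causal)
print(out.shape)    # torch.Size([2, 2, 4, 4])
print(probs[0, 0])  # upper triangle is exactly 0 (no attention to future tokens)
```

The class below uses -1e9 instead of -inf. The result is the same unless every key in a row is masked: then -inf would produce NaN, while -1e9 gives a uniform row.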

In [3]:
class MultiHeadAttention(nn.Module):
  def __init__(self, dim_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.dim_model = dim_model
    self.num_heads = num_heads

    assert dim_model % num_heads == 0, "dim_model must be divisible by num_heads"
    self.dim_head = dim_model // num_heads  # size of each head's slice

    self.query = nn.Linear(dim_model, dim_model)
    self.key = nn.Linear(dim_model, dim_model)
    self.value = nn.Linear(dim_model, dim_model)
    self.output = nn.Linear(dim_model, dim_model)

  def split_heads(self, x):
    batch_size, seq_length, dim_model = x.size()
    return x.view(batch_size, seq_length, self.num_heads, self.dim_head).transpose(1, 2)

  def combine_heads(self, x):
    batch_size, _, seq_length, dim_head = x.size()
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.dim_model)

  def forward(self, q, k, v, mask=None):
    query = self.split_heads(self.query(q))
    key = self.split_heads(self.key(k))
    value = self.split_heads(self.value(v))

    # score calculation
    attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.dim_head)

    # apply mask for decoders/padding
    if mask is not None:
      attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

    # softmax
    attn_probs = torch.softmax(attn_scores, dim=-1)

    # multiply probabilities by the values, then combine the heads and project back to dim_model
    return self.output(self.combine_heads(torch.matmul(attn_probs, value)))

Feed Forward Network

*It is an MLP (multi-layer perceptron) applied to each token separately, refining each token's features on its own after attention has mixed in information from the other tokens.*
1. What do the layers actually do?

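  The first layer expands each token's vector from dim_model to a larger dim_ff, the activation adds non-linearity, and the second layer projects it back to dim_model so the output has the same shape as the input.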

2. ReLU was used in the original Transformer. Is it necessary? Any alternatives?

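  ReLU introduces non-linearity so the model can learn more powerful mappings; without an activation, the two linear layers would collapse into a single linear map. GELU (Gaussian Error Linear Unit), x·Φ(x) where Φ is the standard normal CDF, is a smooth alternative used in BERT and GPT, and it is the one used in this notebook.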
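A quick check of why the activation matters: without it, `fc2(fc1(x))` is itself a single linear map. The sketch below (sizes chosen arbitrarily) builds the equivalent single `nn.Linear` by hand and compares the outputs, then evaluates ReLU and GELU on a few points.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim_model, dim_ff = 4, 16

fc1 = nn.Linear(dim_model, dim_ff)  # expand: 4 -> 16
fc2 = nn.Linear(dim_ff, dim_model)  # project back: 16 -> 4
x = torch.randn(3, dim_model)

# Without an activation: fc2(fc1(x)) = x (W2 W1)^T + (W2 b1 + b2), i.e. one linear layer
merged = nn.Linear(dim_model, dim_model)
with torch.no_grad():
    merged.weight.copy_(fc2.weight @ fc1.weight)
    merged.bias.copy_(fc2.weight @ fc1.bias + fc2.bias)
    no_act = fc2(fc1(x))
    print(torch.allclose(no_act, merged(x), atol=1e-5))  # True: the second layer added nothing

# ReLU vs GELU on a few inputs
z = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(torch.relu(z))  # hard cut: every negative input becomes 0
print(nn.GELU()(z))   # smooth: small negative outputs for negative inputs
```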


In [10]:
class PositionWiseFeedForward(nn.Module):  # more hidden layers are possible, but they add time and compute
  def __init__(self, dim_model, dim_ff):
    super(PositionWiseFeedForward, self).__init__()
    self.fc1 = nn.Linear(dim_model, dim_ff) # expand: dim_ff is usually larger than dim_model (4x in the original paper)
    self.fc2 = nn.Linear(dim_ff, dim_model)
    self.gelu = nn.GELU()

  def forward(self, x):
    return self.fc2(self.gelu(self.fc1(x)))

Positional encoding

*Positional encoding adds position information (i.e. what lies where in the sequence/sentence), because transformers process all tokens in parallel and attention by itself has no notion of word order.*
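1. Why sine and cosine?

  Sine is applied to the even dimensions and cosine to the odd dimensions of the encoding vector (not to even and odd positions). Each (sine, cosine) pair oscillates at its own frequency, fast in the first dimensions and slow in the last. The encoding of position pos + k can be obtained from that of pos by a fixed linear transformation (a rotation of each pair), which was chosen so that relative positions would be easy to learn.
2. How does this help the model learn position?

  Each position gets a unique pattern across the dimensions. Without this, "She ate a cupcake after the dinner" and "After the dinner, she ate a cupcake" would look almost the same to the model, since attention alone treats the input as an unordered set of tokens. Their semantic meaning is the same, but they are not identical.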
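A small numeric check of the formula used in PositionalEncoding, written standalone with tiny sizes (8 positions, 6 dimensions; arbitrary choices): even dimensions hold sines, odd dimensions hold cosines, position 0 is [0, 1, 0, 1, ...], and every position gets a distinct row.

```python
import math
import torch

max_len, d_model = 8, 6  # d_model must be even for the interleaving below

position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (L, 1)
# 1 / 10000^(2i / d_model) for i = 0, 1, ..., d_model/2 - 1
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions

print(pe[0])  # tensor([0., 1., 0., 1., 0., 1.])
# every (sin, cos) pair lies on the unit circle
pairs = pe[:, 0::2] ** 2 + pe[:, 1::2] ** 2
print(torch.allclose(pairs, torch.ones_like(pairs)))  # True
# no two positions share the same encoding
print(len({tuple(row.tolist()) for row in pe}) == max_len)  # True
```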




In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pos_enc = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
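        # 1 / 10000^(2i / d_model) for each pair of dimensions (assumes d_model is even)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pos_enc[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pos_enc[:, 1::2] = torch.cos(position * div_term)  # odd dimensions

        self.register_buffer('pos_enc', pos_enc.unsqueeze(0))  # buffer: saved with the model and moved by .to(device), but not trained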

    def forward(self, x):
        return x + self.pos_enc[:, :x.size(1)]

Encoder

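  One of the two main components of the Transformer; it processes the input (source) sequence. The class below is a single encoder layer, and the full model stacks num_layers of them.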

  1. Why layer normalization?

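    It stabilizes and speeds up training by normalizing each token's features to zero mean and unit variance, followed by a learned scale and shift.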

  2. Why is q = k = v here?

    Because this is self-attention: every token in the input sentence needs every other token to understand its context. So what we want (query), what we have (key), and the actual data (value) all come from the same input x, each through its own linear projection.

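  3. What is dropout?

    It regularizes the model by randomly zeroing elements of its input during training (and scaling the rest by 1 / (1 - p)) to reduce overfitting. It is turned off in eval mode.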

  4. Residual connection? What and why?

    A residual connection adds the original input of a sublayer back to its output: x + sublayer(x).
    It preserves the input signal while the sublayer transforms it, and it gives gradients a direct path through deep stacks of layers.

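  5. Why two norms and two dropouts? Can't I reuse the same ones for the FFN and multi-head attention?

    LayerNorm should not be reused, because each normalization layer learns its own scale and shift. If one were shared, the FFN and multi-head attention sublayers would have to share these parameters, which limits learning. Dropout is different: it has no learnable parameters and draws a fresh random mask on every call, so reusing one nn.Dropout module would behave the same. Separate dropout modules are kept here for readability and so each sublayer's rate can be changed independently.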
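The ideas above (layer normalization, dropout, residual connection) can be checked on random data. A minimal standalone sketch; the sizes are arbitrary and `sublayer` is a hypothetical stand-in for attention or the FFN.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 8)  # (batch, seq_len, dim_model)

# LayerNorm: each token's 8 features get mean 0 and variance 1 (then learned scale/shift)
norm = nn.LayerNorm(8)
y = norm(x)
print(torch.allclose(y.mean(-1), torch.zeros(2, 5), atol=1e-5))  # True

# Dropout: active only in training mode; survivors are scaled by 1 / (1 - p)
drop = nn.Dropout(p=0.5)
ones = torch.ones(1000)
drop.train()
d = drop(ones)
print(sorted(set(d.tolist())))        # [0.0, 2.0]
print(torch.equal(drop(ones), d))     # False: a fresh random mask on every call
drop.eval()
print(torch.equal(drop(ones), ones))  # True: identity in eval mode

# Residual + post-norm, as in the Encoder class: x = norm(x + dropout(sublayer(x)))
sublayer = nn.Linear(8, 8)  # hypothetical stand-in for attention or the FFN
drop.train()
out = norm(x + drop(sublayer(x)))
print(out.shape)  # torch.Size([2, 5, 8])
```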







In [6]:
class Encoder(nn.Module):
    def __init__(self, dim_model, num_heads, dim_ff, dropout=0.1):
        super(Encoder, self).__init__()
        self.attn = MultiHeadAttention(dim_model, num_heads)
        self.ffn = PositionWiseFeedForward(dim_model, dim_ff)

        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_output = self.attn(x, x, x, mask)  # multi-head attention, q=k=v=x
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)

        ff_output = self.ffn(x) #feed-forward
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)

        return x

Decoder

  *The decoder is the crucial part of the Transformer that generates the output, using the contextual representations from the encoder together with the output of the previous decoder layer (or the embedded target tokens, for the first layer).*

  1. Masking in decoders: two types?

    Masking prevents attention to positions that should not be seen. The decoder uses two masks:
    target_mask (tgt_mask) -> look-ahead mask, combined with target padding, in masked self-attention, so each position can only attend to itself and earlier positions
    src_mask -> padding mask in cross-attention, so the decoder ignores padding tokens in the source
  2. What is cross-attention?

    Multi-head attention where the queries come from the decoder (the output of its masked self-attention sublayer) and the keys and values come from the encoder's output is called cross-attention.
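Both masks and the cross-attention shapes can be sketched on the notebook's sample source ids and a shortened, padded version of the sample target. For brevity this uses a single head and skips the learned projections, so it is a simplified stand-in for the MultiHeadAttention class, not a replacement; `dec_x` and `enc_out` are random placeholders.

```python
import math
import torch

torch.manual_seed(0)
pad_id = 0                                # <pad> id in this notebook's vocab
src = torch.tensor([[3, 4, 5, 6, 2, 0]])  # (B=1, src_len=6), last id is padding
tgt = torch.tensor([[1, 7, 8, 2, 0]])     # (B=1, tgt_len=5), last id is padding

# src_mask: padding mask for cross-attention, shape (B, 1, 1, src_len)
src_mask = (src != pad_id).unsqueeze(1).unsqueeze(2)

# tgt_mask: target padding mask AND look-ahead mask, shape (B, 1, tgt_len, tgt_len)
tgt_len = tgt.size(1)
look_ahead = torch.tril(torch.ones(1, 1, tgt_len, tgt_len)).bool()
tgt_mask = (tgt != pad_id).unsqueeze(1).unsqueeze(2) & look_ahead
print(tgt_mask[0, 0].int())  # row i may attend to columns 0..i, never to the padded column

# Cross-attention: queries from the decoder, keys/values from the encoder output
dim = 8
dec_x = torch.randn(1, tgt_len, dim)        # placeholder decoder states
enc_out = torch.randn(1, src.size(1), dim)  # placeholder encoder output
scores = dec_x @ enc_out.transpose(-2, -1) / math.sqrt(dim)  # (B, tgt_len, src_len)
scores = scores.masked_fill(src_mask[:, 0] == 0, -1e9)       # hide source padding
probs = torch.softmax(scores, dim=-1)
context = probs @ enc_out  # (B, tgt_len, dim): one vector per target position
print(probs.shape, context.shape)
```

Note that the output length follows the queries (the target), while the attention weights span the source length.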
  

In [7]:
class Decoder(nn.Module):
  def __init__(self, dim_model, num_heads, dim_ff, dropout=0.1):
    super(Decoder, self).__init__()
    self.attn = MultiHeadAttention(dim_model, num_heads)
    self.cross_attn = MultiHeadAttention(dim_model, num_heads)
    self.ffn = PositionWiseFeedForward(dim_model, dim_ff)

    self.norm1 = nn.LayerNorm(dim_model)
    self.norm2 = nn.LayerNorm(dim_model)
    self.norm3 = nn.LayerNorm(dim_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

  def forward(self, x, encoder_output, src_mask, target_mask):
    attn_output = self.attn(x, x, x, target_mask)
    x = x + self.dropout1(attn_output)
    x = self.norm1(x)

    cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)  # cross-attention: q from decoder, k/v from encoder
    x = x + self.dropout2(cross_attn_output)
    x = self.norm2(x)

    ff_output = self.ffn(x)
    x = x + self.dropout3(ff_output)
    x = self.norm3(x)

    return x


Transformer

1. Token → Embedding → Positional Encoding
2. Encoder → Self-attention + FFN
3. Decoder → Masked Self-attn → Cross-attn → FFN
4. Output → Linear → Logits for prediction
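The shapes at the two ends of this flow (steps 1 and 4) can be traced with stand-in layers. This is a minimal sketch using the same sizes as the test cell at the end (a vocabulary of 9, dim_model = 16); it skips steps 2 and 3, which keep the shape (batch, seq_len, dim_model) unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, dim_model = 9, 16
tokens = torch.tensor([[1, 7, 8, 2, 0, 0]])  # (batch=1, seq_len=6) token ids

# 1. token ids -> embedding vectors
embed = nn.Embedding(vocab_size, dim_model)
h = embed(tokens)  # (1, 6, 16)

# 2-3. positional encoding, encoder and decoder layers would go here;
#      each of them keeps the shape (batch, seq_len, dim_model)

# 4. final linear layer -> one logit per vocabulary entry at every position
to_logits = nn.Linear(dim_model, vocab_size)
logits = to_logits(h)  # (1, 6, 9)

# softmax turns logits into probabilities; it does not change the argmax
probs = torch.softmax(logits, dim=-1)
pred_ids = logits.argmax(dim=-1)  # (1, 6)
print(h.shape, logits.shape, pred_ids.shape)
```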


In [8]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, dim_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()

        self.encoder_embedding = nn.Embedding(src_vocab_size, dim_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, dim_model)
        self.positional_encoding = PositionalEncoding(dim_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([Encoder(dim_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([Decoder(dim_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(dim_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        # Padding mask: 1 for non-padding tokens, 0 for padding tokens
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, src_len)
        tgt_padding_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, tgt_len)

        seq_length = tgt.size(1)
        nopeak_mask = torch.tril(torch.ones((1, 1, seq_length, seq_length), device=tgt.device)).bool()  # (1, 1, tgt_len, tgt_len)

        tgt_mask = tgt_padding_mask & nopeak_mask  # (B, 1, tgt_len, tgt_len)
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output


Let's test the flow. The model is not trained here.

The sample data and the test script below are only meant to catch bugs (e.g. shape or masking errors). Since the weights are random, the predicted tokens are meaningless.
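`torch.no_grad()` and `model.eval()` do different things when running such a test: the first stops gradient tracking, the second switches off dropout. A quick check on a hypothetical stand-in model (`toy` is illustrative, not part of this notebook):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# hypothetical stand-in model with dropout, not the Transformer above
toy = nn.Sequential(nn.Linear(64, 64), nn.Dropout(0.1), nn.Linear(64, 2))
x = torch.randn(3, 64)

with torch.no_grad():      # only turns off gradient tracking
    a, b = toy(x), toy(x)  # toy is still in training mode, so dropout is active
print(torch.equal(a, b))   # False: the two calls used different dropout masks
print(a.requires_grad)     # False: no graph was recorded

toy.eval()                 # turns off dropout
with torch.no_grad():
    c, d = toy(x), toy(x)
print(torch.equal(c, d))   # True: deterministic
```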

In [11]:
vocab = {
    "<pad>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "I": 3,
    "am": 4,
    "a" : 5,
    "cupcake": 6,
    "hello": 7,
    "world": 8
}
vocab_size = len(vocab)

dim_model = 16
num_heads = 2
num_layers = 1
dim_ff = 64
max_seq_length = 10
dropout = 0.1

model = Transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    dim_model=dim_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=dim_ff,
    max_seq_length=max_seq_length,
    dropout=dropout
)

src = torch.tensor([[3, 4, 5, 6, 2, 0]])  # shape: (1, 6)
tgt = torch.tensor([[1, 7, 8, 2, 0, 0]])  # shape: (1, 6)

model.eval()  # disable dropout for a deterministic forward pass
with torch.no_grad():  # no gradient tracking needed for this check
    output = model(src, tgt)

print("Output logits shape:", output.shape)
predicted_ids = torch.argmax(output, dim=-1)
print("Predicted token IDs:\n", predicted_ids)

inv_vocab = {v: k for k, v in vocab.items()}
decoded_output = [[inv_vocab[token.item()] for token in sent] for sent in predicted_ids]
print("Predicted tokens:\n", decoded_output)


Output logits shape: torch.Size([1, 6, 9])
Predicted token IDs:
 tensor([[6, 3, 3, 8, 8, 6]])
Predicted tokens:
 [['cupcake', 'I', 'I', 'world', 'world', 'cupcake']]
