<a href="https://colab.research.google.com/github/a01110946/transformer/blob/main/src/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention is all you need

In [87]:
# Import all the required libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import Counter
import math
import numpy as np
import re

In [88]:
torch.manual_seed(23)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [89]:
# Define your constants

MAX_SEQ_LEN = 30

In [90]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_seq_len=MAX_SEQ_LEN):
    super().__init__()
    self.pos_embed_matrix = torch.zeros(max_seq_len, d_model, device=device)
    token_pos = torch.arange(0, max_seq_len, dtype = torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float()
                         * (-math.log(10000.0)/d_model))
    self.pos_embed_matrix[:, 0::2] = torch.sin(token_pos * div_term)
    self.pos_embed_matrix[:, 1::2] = torch.cos(token_pos * div_term)
    self.pos_embed_matrix = self.pos_embed_matrix.unsqueeze(0).transpose(0, 1)

  def forward(self, x):
    return x + self.pos_embed_matrix[:x.size(0), :]

In [91]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model = 512, num_heads = 8):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_v = d_model // num_heads
    self.d_k = self.d_v
    self.num_heads = num_heads

    self.W_q = nn.Linear(d_model, d_model) # Query matrix
    self.W_k = nn.Linear(d_model, d_model) # Key matrix
    self.W_v = nn.Linear(d_model, d_model) # Value matrix
    self.W_o = nn.Linear(d_model, d_model) # Output matrix

  def forward(self, Q, K, V, mask=None):
    batch_size = Q.size(0)
    '''
    Q, K, V -> [batch_size, seq_len, num_heads*d_k]
    After transpose, Q, K, V -> [batch_size, num_heads, seq_len, d_k]
    '''
    Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    weighted_values, attention = self.scale_dot_product_attention(Q, K, V, mask=None)
    weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
    weighted_values = self.W_o(weighted_values)
    return weighted_values, attention

  def scale_dot_product_attention(self, Q, K, V, mask=None):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e9)
    attention = F.softmax(scores, dim=-1)
    weighted_values = torch.matmul(attention, V)
    return weighted_values, attention


In [92]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self, d_model = 512, d_ff = 2048):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
  def forward(self, x):
    x = F.relu(self.linear1(x))
    x = self.linear2(x)
    return x

In [93]:
class EncoderSublayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

  def forward(self, x, mask=None):
    attention_score, _ = self.self_attn(x, x, x, mask)
    x = x + self.dropout1(attention_score)
    x = self.norm1(x)
    x = x + self.dropout2(self.feed_forward(x))
    x = self.norm2(x)
    return x

In [94]:
class Encoder(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
    super().__init__()
    self.layers = nn.ModuleList([EncoderSublayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
    self.norm = nn.LayerNorm(d_model)

  def forward(self, x, mask=None):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

In [95]:
class DecoderSublayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.cross_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

  def forward(self, x, encoder_output, target_mask=None, encoder_mask=None):
    attention_score, _ = self.self_attn(x, x, x, target_mask)
    x = x + self.dropout1(attention_score)
    x = self.norm1(x)
    encoder_attn, _ = self.cross_attn(x, encoder_output, encoder_output, encoder_mask)
    x = x + self.dropout2(encoder_attn)
    x = self.norm2(x)
    ff_output = self.feed_forward(x)
    x = x + self.dropout3(ff_output)
    x = self.norm3(x)
    return x

In [96]:
class Decoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderSublayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, target_mask, encoder_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, target_mask, encoder_mask)
        return self.norm(x)

In [97]:
class Transformer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, num_layers,
               input_vocab_size, target_vocab_size,
               max_seq_len=MAX_SEQ_LEN, dropout=0.1):
    super().__init__()
    self.encoder_embedding = nn.Embedding(input_vocab_size, d_model)
    self.decoder_embedding = nn.Embedding(target_vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
    self.dropout = nn.Dropout(dropout)
    self.encoder = Encoder(d_model, num_heads, d_ff, num_layers)
    self.decoder = Decoder(d_model, num_heads, d_ff, num_layers)
    self.output_layer = nn.Linear(d_model, target_vocab_size)

  def mask(self, source, target):
        source_mask = (source != 0).unsqueeze(1).unsqueeze(2)
        target_mask = (target != 0).unsqueeze(1).unsqueeze(3)
        seq_length = target.size(1)
        nopeak_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device)).bool()
        target_mask = target_mask & nopeak_mask

        print(f"Source mask shape: {source_mask.shape}")
        print(f"Target mask shape: {target_mask.shape}")

        return source_mask, target_mask

  def forward(self, source, target):
        source_mask, target_mask = self.mask(source, target)
        source = self.encoder_embedding(source) * math.sqrt(self.encoder_embedding.embedding_dim)
        source = self.positional_encoding(source)
        source = self.dropout(source)

        print(f"Source shape after embedding: {source.shape}")

        encoder_output = self.encoder(source, source_mask)

        print(f"Encoder output shape: {encoder_output.shape}")

        target = self.decoder_embedding(target) * math.sqrt(self.decoder_embedding.embedding_dim)
        target = self.positional_encoding(target)
        target = self.dropout(target)

        print(f"Target shape after embedding: {target.shape}")

        output = self.decoder(target, encoder_output, target_mask, source_mask)
        output = self.output_layer(output)
        return output

## Simple Test

In [98]:
seq_len_source = 10
seq_len_target = 10
batch_size = 4
input_vocab_size = 50
target_vocab_size = 5

source = torch.randint(1, input_vocab_size, (batch_size, seq_len_source))
target = torch.randint(1, target_vocab_size, (batch_size, seq_len_target))
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6

model = Transformer(d_model, num_heads, d_ff, num_layers,
               input_vocab_size, target_vocab_size,
               max_seq_len=MAX_SEQ_LEN, dropout=0.1)

model = model.to(device)
source = source.to(device)
target = target.to(device)

output = model(source, target)
print(output.shape)

Source mask shape: torch.Size([4, 1, 1, 10])
Target mask shape: torch.Size([4, 1, 10, 10])
Source shape after embedding: torch.Size([4, 10, 512])
Encoder output shape: torch.Size([4, 10, 512])
Target shape after embedding: torch.Size([4, 10, 512])
torch.Size([4, 10, 5])
