In [1]:
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Input/Output Embedding

<img src="img/embedding_softmax.png" height="150px">

## Smoke Test

In [5]:
tokens_dictionary_size = 1000
d_model = 512
embedding = nn.Embedding(num_embeddings=tokens_dictionary_size, embedding_dim=d_model)

batch_size = 128
sequence_length = 20

# Random token ids in [0, tokens_dictionary_size). Named `token_ids` so the builtin `input` is not shadowed.
token_ids = torch.randint(low=0, high=tokens_dictionary_size, size=(batch_size, sequence_length))
embedded_tokens = embedding(token_ids)

print(f"Batch size: {batch_size}")
print(f"Sequence length: {sequence_length}")
print(f"Embedding dimension: {d_model}")

print(f"Input shape: {token_ids.shape}")

print(f"Embedding output shape: {embedded_tokens.shape}")

# Each token id is mapped to one d_model-wide vector.
assert embedded_tokens.shape == (batch_size, sequence_length, d_model)

Batch size: 128
Sequence length: 20
Embedding dimension: 512
Input shape: torch.Size([128, 20])
Embedding output shape: torch.Size([128, 20, 512])


# Positional Encoding

<img src="img/positional_encoding.png" height="150px">

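$pos \in [0, L_{max} - 1]$ is the token position, where $L_{max}$ is `max_seq_length`.  
$i \in [0, d_{model}/2 - 1]$: the even embedding dimension $2i$ uses the sine and the odd dimension $2i+1$ uses the cosine.  
Note that the implementation computes the frequency term $1/10000^{2i/d_{model}}$ in log space, as $\exp\left(-2i \cdot \ln(10000) / d_{model}\right)$, for numerical stability.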

In [6]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional encoding to a batch of embeddings.

    The encoding matrix is precomputed once for ``max_seq_length`` positions. It is stored as a
    non-persistent buffer, so it follows the module through ``.to(device)`` / ``.cuda()``.
    It is not written to ``state_dict()``.
    """

    def __init__(self, max_seq_length: int, d_model: int):
        super().__init__()
        # A buffer (rather than a plain attribute) is moved together with the module's parameters.
        # persistent=False keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            "positional_encoding_matrix",
            self.compute_positional_encoding_matrix(max_seq_length=max_seq_length, d_model=d_model),
            persistent=False,
        )

    def compute_positional_encoding_matrix(self, max_seq_length: int, d_model: int) -> torch.Tensor:
        """Computes the sinusoidal positional encoding matrix.

        Even dimensions ``2i`` hold ``sin(pos / 10000^(2i / d_model))``.
        Odd dimensions ``2i + 1`` hold ``cos(pos / 10000^(2i / d_model))``.

        Args:
            max_seq_length (int): Maximum sequence length (number of positions encoded)
            d_model (int): Embedding dimension (odd values are supported)

        Returns:
            torch.Tensor: Positional encoding matrix of shape (1, max_seq_length, d_model)
        """
        positions = torch.arange(start=0, end=max_seq_length).unsqueeze(1)

        # 1 / 10000^(2i / d_model), computed as exp(-2i * ln(10000) / d_model) for numerical stability.
        even_dimensions = torch.arange(start=0, end=d_model, step=2)
        division_term = torch.exp(even_dimensions * (-math.log(10000.0) / d_model))

        positional_encoding_matrix = torch.zeros(1, max_seq_length, d_model)
        positional_encoding_matrix[0, :, 0::2] = torch.sin(positions * division_term)
        # With an odd d_model there is one fewer odd column than even column, so drop the last frequency.
        positional_encoding_matrix[0, :, 1::2] = torch.cos(positions * division_term[: d_model // 2])

        return positional_encoding_matrix

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Forward pass of the PositionalEncoding module.

        Args:
            embedding (torch.Tensor): Embedding tensor of shape (batch_size, sequence_length, d_model)

        Returns:
            torch.Tensor: Embedding plus positional encoding, of shape (batch_size, sequence_length, d_model)

        Raises:
            ValueError: If sequence_length is larger than the max_seq_length the module was built for.
        """
        sequence_length = embedding.shape[1]
        max_seq_length = self.positional_encoding_matrix.shape[1]
        if sequence_length > max_seq_length:
            raise ValueError(
                f"sequence length {sequence_length} exceeds max_seq_length {max_seq_length} of the positional encoding"
            )
        return embedding + self.positional_encoding_matrix[:, :sequence_length, :]

## Smoke Test

In [8]:
# Define every size this cell uses up front so it does not depend on earlier cells' state.
batch_size = 128
sequence_length = 50
max_seq_length = 1000
d_model = 512

# Random stand-in for token embeddings; the positional encoding is added on top of it.
dummy_embeddings = torch.randn(size=(batch_size, sequence_length, d_model))
print(f"embedding shape: {dummy_embeddings.shape}")

positional_encoding = PositionalEncoding(max_seq_length=max_seq_length, d_model=d_model)
# Call the module itself rather than .forward() so nn.Module hooks are honoured.
encoded_embeddings = positional_encoding(dummy_embeddings)
print(f"positional_encoding.shape: {encoded_embeddings.shape}")

# Adding the positional encoding must not change the shape.
assert encoded_embeddings.shape == dummy_embeddings.shape

embedding shape: torch.Size([128, 50, 512])
positional_encoding.shape: torch.Size([128, 50, 512])


# Scaled Dot-Product Attention

<img src="img/scaled_dot_product_attention.png" height="300px">

In [24]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    It uses ``torch.matmul`` rather than ``torch.bmm``, so it accepts any number of leading
    batch dimensions, e.g. ``(batch, seq, d)`` or ``(batch, heads, seq, d)``.
    """

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of the scaled dot product attention.

        Args:
            queries (torch.Tensor): Queries of shape (batch_size, query_length, d_k)
            keys (torch.Tensor): Keys of shape (batch_size, key_length, d_k)
            values (torch.Tensor): Values of shape (batch_size, key_length, d_v)
            mask (Optional[torch.Tensor]): Optional mask broadcastable to (batch_size, query_length, key_length).
                True / non-zero entries mark keys a query may attend to; False / zero entries are excluded.
                A query whose keys are all masked out yields NaNs.

        Returns:
            torch.Tensor: Outputs of the scaled dot product attention of shape (batch_size, query_length, d_v)
        """
        d_k = queries.shape[-1]
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.logical_not(), float("-inf"))
        # Normalise over the key axis (last dim) so each query's attention weights sum to 1.
        queries_keys_similarity_probabilities = F.softmax(scores, dim=-1)
        return torch.matmul(queries_keys_similarity_probabilities, values)

## Smoke Test

In [26]:
scaled_dot_product_attention = ScaledDotProductAttention()

# Toy sizes: 2 sequences of 20 tokens, query/key width 4, value width 5.
batch_size, sequence_length = 2, 20
d_k, d_v = 4, 5

queries = torch.randn(batch_size, sequence_length, d_k)
keys = torch.randn(batch_size, sequence_length, d_k)
values = torch.randn(batch_size, sequence_length, d_v)

print(f"queries.shape: {queries.shape}")
print(f"keys.shape: {keys.shape}")
print(f"values.shape: {values.shape}")

print(f"scaled_dot_product_attention shape : {scaled_dot_product_attention(queries, keys, values).shape}")

queries.shape: torch.Size([2, 20, 4])
keys.shape: torch.Size([2, 20, 4])
values.shape: torch.Size([2, 20, 5])
scaled_dot_product_attention shape : torch.Size([2, 20, 5])


# Multi-Head Attention

<img src="img/multi_head_attention.png" height="300px">

In [35]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from ``num_heads`` independent scaled dot-product attentions.

    Each head owns bias-free query / key / value projections from ``d_model`` down to
    ``d_k`` / ``d_k`` / ``d_v``. The head outputs are concatenated and mapped back to
    ``d_model`` by one final linear layer.
    """

    def __init__(self, num_heads: int, d_k: int, d_v: int, d_model: int):
        super().__init__()

        # The per-head widths must split d_model evenly across the heads.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        head_width = d_model // num_heads
        assert d_k == head_width, "d_k must be equal to d_model // num_heads"
        assert d_v == head_width, "d_v must be equal to d_model // num_heads"

        def make_projections(out_features: int) -> nn.ModuleList:
            # One bias-free d_model -> out_features projection per head.
            return nn.ModuleList(
                nn.Linear(in_features=d_model, out_features=out_features, bias=False)
                for _ in range(num_heads)
            )

        # Submodules are created in the order queries, keys, values, output, so parameter
        # initialisation draws from the RNG in the same order.
        self.queries_projections = make_projections(d_k)
        self.keys_projections = make_projections(d_k)
        self.values_projections = make_projections(d_v)
        self.attentions = nn.ModuleList(ScaledDotProductAttention() for _ in range(num_heads))
        self.multi_head_linear = nn.Linear(in_features=num_heads * d_v, out_features=d_model)

    def forward(self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Runs every head on its own projections, concatenates the results and mixes them.

        Args:
            queries (torch.Tensor): Queries of shape (batch_size, sequence_length, d_model)
            keys (torch.Tensor): Keys of shape (batch_size, sequence_length, d_model)
            values (torch.Tensor): Values of shape (batch_size, sequence_length, d_model)

        Returns:
            torch.Tensor: Outputs of the multi-head attention of shape (batch_size, sequence_length, d_model)
        """
        per_head_outputs = []
        for project_queries, project_keys, project_values, attend in zip(
            self.queries_projections, self.keys_projections, self.values_projections, self.attentions
        ):
            per_head_outputs.append(
                attend(project_queries(queries), project_keys(keys), project_values(values))
            )

        return self.multi_head_linear(torch.cat(per_head_outputs, dim=2))

## Smoke Test

In [38]:
# 8 heads of width 64 each, so 8 * 64 = 512 = d_model.
batch_size, num_heads, sequence_length = 2, 8, 20
d_k = d_v = 64
d_model = 512

multi_head_attention = MultiHeadAttention(num_heads=num_heads, d_k=d_k, d_v=d_v, d_model=d_model)

# Three independent random inputs, drawn in the order queries, keys, values.
queries, keys, values = (torch.randn(batch_size, sequence_length, d_model) for _ in range(3))

print(f"multi head attention output shape: {multi_head_attention(queries, keys, values).shape}")

multi head attention output shape: torch.Size([2, 20, 512])


# Feed Forward

<img src="img/feed_forward.png" height="250px">

In [39]:
class PositionWiseFeedForwardNetwork(nn.Module):
    """Two-layer MLP applied independently at every position: linear -> ReLU -> linear."""

    def __init__(self, d_model: int, d_ff: int = 2048):
        super().__init__()
        # Expand d_model -> d_ff, then project back d_ff -> d_model.
        self.linear_1 = nn.Linear(d_model, d_ff, bias=True)
        self.linear_2 = nn.Linear(d_ff, d_model, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the feed-forward network to the last dimension of ``x``.

        Args:
            x (torch.Tensor): Input of shape (batch_size, sequence_length, d_model)

        Returns:
            torch.Tensor: Output of shape (batch_size, sequence_length, d_model)
        """
        hidden = self.linear_1(x)
        hidden = F.relu(hidden)
        return self.linear_2(hidden)

## Smoke Test

In [40]:
d_model = 512
batch_size, sequence_length = 2, 3

# Build the network first (its weight init draws from the RNG), then sample the input.
position_wise_feed_forward_network = PositionWiseFeedForwardNetwork(d_model=d_model, d_ff=2048)

x = torch.randn(batch_size, sequence_length, d_model)

print(f"x.shape: {x.shape}")
print(f"position wise feed forward network output shape: {position_wise_feed_forward_network(x).shape}")

x.shape: torch.Size([2, 3, 512])
position wise feed forward network output shape: torch.Size([2, 3, 512])


# Transformer Network

<img src="img/transformer.png" height="500px">

In [41]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer in the post-norm arrangement.

    ``x -> LayerNorm(x + MultiHeadAttention(x, x, x)) -> LayerNorm(x + FeedForward(x))``
    """

    def __init__(self, num_heads: int, d_k: int, d_v: int, d_model: int, d_ff: int = 2048):
        """
        Args:
            num_heads (int): Number of attention heads
            d_k (int): Per-head query/key width (MultiHeadAttention requires d_model // num_heads)
            d_v (int): Per-head value width (MultiHeadAttention requires d_model // num_heads)
            d_model (int): Model width
            d_ff (int): Hidden width of the position-wise feed-forward network (formerly hard-coded to 2048)
        """
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(num_heads=num_heads, d_k=d_k, d_v=d_v, d_model=d_model)
        self.layer_normalization_1 = nn.LayerNorm(normalized_shape=d_model)
        self.position_wise_feed_forward_network = PositionWiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff)
        self.layer_normalization_2 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder layer.

        Args:
            x (torch.Tensor): Input of shape (batch_size, sequence_length, d_model)

        Returns:
            torch.Tensor: Output of shape (batch_size, sequence_length, d_model)
        """
        # Self-attention sub-layer: residual connection, then LayerNorm (post-norm).
        x = x + self.multi_head_attention(x, x, x)
        x = self.layer_normalization_1(x)
        # Feed-forward sub-layer: residual connection, then LayerNorm.
        x = x + self.position_wise_feed_forward_network(x)
        x = self.layer_normalization_2(x)
        return x

In [None]:
# TODO: implement the decoder layer.
class DecoderLayer(nn.Module):
    """Placeholder for the Transformer decoder layer — not implemented yet.

    TODO: the intended structure mirrors ``EncoderLayer``, with three sub-layers, each followed by
    a residual connection and LayerNorm:
    masked self-attention, then attention over the encoder output, then the position-wise
    feed-forward network.
    The class currently defines neither ``__init__`` arguments nor its own ``forward``.
    """

In [None]:
class Transformer(nn.Module):
    """Work-in-progress Transformer, currently encoder-only.

    Pipeline: token embedding -> sinusoidal positional encoding -> ``num_layers`` encoder layers
    -> linear projection to the vocabulary -> softmax.
    The decoder stack is not built yet because ``DecoderLayer`` is still an unimplemented stub.
    """

    def __init__(
        self,
        tokens_dictionary_size: int,
        max_sequence_length: int,
        num_heads: int,
        sequence_length: int,
        d_model: int = 512,
        num_layers: int = 6
    ):
        """
        Args:
            tokens_dictionary_size (int): Vocabulary size (embedding rows and output classes)
            max_sequence_length (int): Longest sequence the positional encoding supports
            num_heads (int): Number of attention heads; must divide d_model
            sequence_length (int): Unused. Kept for backward compatibility; it was previously, and
                incorrectly, passed as the per-head d_k / d_v
            d_model (int): Model / embedding width
            num_layers (int): Number of encoder layers
        """
        super().__init__()
        # MultiHeadAttention requires d_k == d_v == d_model // num_heads.
        head_dim = d_model // num_heads
        self.embedding = nn.Embedding(num_embeddings=tokens_dictionary_size, embedding_dim=d_model)
        self.positional_encoding = PositionalEncoding(max_seq_length=max_sequence_length, d_model=d_model)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(num_heads=num_heads, d_k=head_dim, d_v=head_dim, d_model=d_model) for _ in range(num_layers)]
        )
        # TODO: fill with num_layers decoder layers once DecoderLayer is implemented.
        # The current stub accepts no constructor arguments, so building it here with the
        # attention sizes would fail. An empty list keeps the attribute in place for now.
        self.decoder_layers = nn.ModuleList()
        self.output_linear = nn.Linear(in_features=d_model, out_features=tokens_dictionary_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the transformer.

        Args:
            x (torch.Tensor): Token ids of shape (batch_size, sequence_length)

        Returns:
            torch.Tensor: Softmax probabilities over the vocabulary, of shape
            (batch_size, sequence_length, tokens_dictionary_size). Losses such as
            nn.CrossEntropyLoss expect the pre-softmax logits instead.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
        # TODO: decoder part. A decoder layer needs the target tokens and the encoder output,
        # not just x; self.decoder_layers stays empty until DecoderLayer is implemented.

        x = self.output_linear(x)

        return F.softmax(x, dim=-1)