In [8]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


### The Positional Encoding

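Unlike an RNN, which reads a sequence one token at a time, a Transformer receives all tokens of a sequence in parallel, so self-attention by itself has no notion of word order. The model therefore needs explicit information about *where* each word sits in the input sequence. The positional encoding below provides it by adding a fixed sinusoidal vector to every token embedding: $PE_{(pos,\,2i)} = \sin\!\left(pos / 10000^{2i/d}\right)$ and $PE_{(pos,\,2i+1)} = \cos\!\left(pos / 10000^{2i/d}\right)$, where $d$ is the embedding size.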

In [21]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    Precomputes a ``(1, max_len, embed_size)`` table whose even columns hold
    ``sin(pos * 10000^(-2i / embed_size))`` and whose odd columns hold the
    matching cosine, then adds the first ``seq_len`` rows to the input.

    Args:
        max_len: Maximum length of the sequence the table covers.
        embed_size: Dimensionality of the positional encoding vectors. May be
            odd; the last sine column then simply has no cosine partner.
    """

    def __init__(self, max_len, embed_size):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # 10000^(-2i / embed_size) for every even column index 2i.
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * -(math.log(10000.0) / embed_size)
        )
        encoding = torch.zeros(max_len, embed_size)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one fewer odd column than even column,
        # so only the first embed_size // 2 frequencies get a cosine.
        encoding[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        # A buffer (unlike a plain attribute) moves with .to(device) / .half().
        # persistent=False keeps it out of state_dict, as the old attribute was.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: Embeddings indexed as (batch, seq_len, embed_size); positions
                are taken along dim 1.

        Returns:
            ``x`` plus the encoding for positions ``0 .. seq_len - 1``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` used at construction.
        """
        seq_len = x.size(1)
        max_len = self.encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding's max_len {max_len}"
            )
        # Buffers do not require grad, so no .detach() is needed.
        return x + self.encoding[:, :seq_len]


In [22]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer layer.

    Self-attention and a position-wise feed-forward network (hidden width
    ``4 * embed_size``, ReLU). Each sub-layer is wrapped in a residual
    connection followed by LayerNorm.

    Args:
        embed_size: Feature dimension of the input and output.
        heads: Number of attention heads (nn.MultiheadAttention requires it to
            divide embed_size).
        batch_first: If True (default), inputs are (batch, seq, embed_size),
            the layout produced by the embedding/positional-encoding pipeline
            in this notebook. Set False for (seq, batch, embed_size) inputs.
    """

    def __init__(self, embed_size, heads, batch_first=True):
        super(TransformerBlock, self).__init__()
        # nn.MultiheadAttention defaults to (seq, batch, embed). Without
        # batch_first=True, batch-first inputs would make attention mix tokens
        # across the batch dimension instead of along the sequence.
        self.attention = nn.MultiheadAttention(embed_size, heads, batch_first=batch_first)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size)
        )

    def forward(self, x, mask=None):
        """Apply attention and feed-forward sub-layers.

        Args:
            x: Input features, (batch, seq, embed_size) when batch_first,
                otherwise (seq, batch, embed_size).
            mask: Optional (seq, seq) attention mask. A float mask is added to
                the attention scores, e.g. 0 / -inf for causal masking. None
                lets every position attend to every other.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The mask must be passed by keyword: the 4th positional argument of
        # MultiheadAttention.forward is key_padding_mask, not attn_mask.
        # The attention weights are unused, so skip computing them.
        attention_output, _ = self.attention(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + attention_output)
        feed_forward_output = self.feed_forward(x)
        out = self.norm2(x + feed_forward_output)
        return out


In [23]:
class Transformer(nn.Module):
    """Decoder-only (causal) Transformer language model.

    Pipeline: token ids -> embeddings + sinusoidal positions -> ``num_layers``
    TransformerBlocks under a causal mask -> per-position vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        embed_size: Embedding / model dimension.
        heads: Attention heads per block.
        num_layers: Number of stacked TransformerBlocks.
        max_len: Longest sequence the positional encoding supports
            (defaults to the previously hard-coded 1000).
    """

    def __init__(self, vocab_size, embed_size, heads, num_layers, max_len=1000):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(max_len=max_len, embed_size=embed_size)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_size, heads) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids, indexed as (batch, seq_len). Positions and
                the causal mask are taken along dim 1.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Build the mask directly on the activations' device and dtype
        # instead of allocating on CPU and copying.
        mask = self.generate_square_subsequent_mask(x.size(1), device=x.device, dtype=x.dtype)

        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, mask)

        return self.fc_out(x)

    def generate_square_subsequent_mask(self, sz, device=None, dtype=torch.float32):
        """Return an additive causal mask of shape (sz, sz).

        Entry (i, j) is 0.0 when j <= i and -inf when j > i, so position i can
        attend only to itself and earlier positions.

        Args:
            sz: Sequence length.
            device: Device to allocate the mask on (default: current default device).
            dtype: Floating dtype of the mask (default float32).
        """
        # Keep -inf strictly above the main diagonal; triu zeroes the rest.
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device, dtype=dtype), diagonal=1)


In [25]:
# Example usage: build a model and run one forward pass on random token ids.
torch.manual_seed(0)  # reproducible weights and inputs

vocab_size = 1000  # adjust based on your vocabulary size
embed_size = 512
heads = 8
num_layers = 6

model = Transformer(vocab_size, embed_size, heads, num_layers)

# Random token ids for testing, laid out as (batch_size, seq_len).
# NOTE(review): with batch_size == seq_len, a batch/sequence axis mix-up still
# runs silently; use distinct values (e.g. 32 and 10) to exercise the layout.
batch_size, seq_len = 10, 10
input_tensor = torch.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass (inference only, so skip autograd bookkeeping)
model.eval()
with torch.no_grad():
    output = model(input_tensor)

# One vector of vocabulary logits per input position
assert output.shape == (batch_size, seq_len, vocab_size)
print(output.shape)


torch.Size([10, 10, 1000])
