In [1]:
import math

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first tensor.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the buffer ``pe``. Even feature columns hold ``sin(pos * freq)`` and
    odd feature columns hold ``cos(pos * freq)``, where the frequencies decay
    geometrically from 1 towards 1/10000.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
            Both even and odd values are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len).float().unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )  # one frequency per even column: ceil(d_model / 2) entries
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even column,
        # so drop the last frequency for the cosine half instead of failing
        # with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built with.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the slice below is silently truncated and the
            # addition either fails with an opaque broadcast error or, for
            # max_len == 1, broadcasts position 0 over every position.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} "
                "of the positional encoding table"
            )
        return x + self.pe[:, :seq_len]

In [3]:
# Smoke test: adding the encoding must leave the input shape unchanged.
pe = PositionalEncoding(d_model=10)
# Random integers standing in for a (batch=2, seq=2, d_model=10) feature tensor.
src = torch.randint(low=0, high=5, size=(2, 2, 10))
pe(src).shape

torch.Size([2, 2, 10])

In [4]:
class EmbeddingProjectionModel(nn.Module):
    """Learned affine map applied to the last dimension of the input.

    Args:
        input_dim: Size of the last dimension of the incoming tensor.
        output_dim: Size of the last dimension of the result.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # One linear layer: input_dim features in, output_dim features out.
        self.projection = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project ``x`` from ``input_dim`` to ``output_dim`` features.

        All leading dimensions of ``x`` are passed through unchanged; only
        the last one is transformed.
        """
        projected = self.projection(x)
        return projected

In [5]:
class SimpleTransformer(nn.Module):
    """Transformer encoder mapping per-position feature vectors to one scalar each.

    Pipeline: linear projection of the input features to ``d_model`` ->
    sinusoidal positional encoding -> a stack of batch-first encoder layers ->
    linear head with a single output unit per position.

    Args:
        vocab_size: Accepted for interface compatibility; no layer uses it.
        d_model: Width of the encoder.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        max_len: Number of positions covered by the positional encoding.
        input_dim: Size of the last dimension of ``src``. Defaults to 3, the
            value that used to be hard-coded.
    """

    def __init__(
        self,
        vocab_size,
        d_model,
        nhead,
        num_encoder_layers,
        dim_feedforward,
        max_len=5000,
        input_dim=3,
    ):
        super().__init__()

        # Input projection: continuous features -> d_model (not a token lookup).
        self.embedding = EmbeddingProjectionModel(input_dim, d_model)

        # Positional Encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len)

        # Prototype encoder layer. batch_first=True means every tensor handed
        # to the encoder is laid out as (batch_size, seq_len, d_model).
        # Kept as an attribute so existing state_dict keys stay unchanged.
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )

        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer, num_layers=num_encoder_layers
        )

        # Output head: one scalar per position.
        self.output_layer = nn.Linear(d_model, 1)

    def forward(self, src, attention_mask=None):
        """Run the encoder and return one scalar per sequence position.

        Args:
            src: Float tensor of shape ``(batch_size, seq_len, input_dim)``.
            attention_mask: Optional bool tensor of shape
                ``(batch_size, seq_len)``. It is forwarded unchanged as
                ``src_key_padding_mask``, so ``True`` marks positions that
                attention must ignore (padding). ``None`` disables masking.

        Returns:
            Tensor of shape ``(batch_size, seq_len, 1)``.
        """
        # (batch_size, seq_len, input_dim) -> (batch_size, seq_len, d_model)
        src = self.embedding(src)

        # Add positional encoding; shape is unchanged.
        src = self.positional_encoding(src)

        # The encoder was built with batch_first=True, so inputs and mask are
        # passed batch-first as they are. Transposing them to sequence-first
        # would make attention mix different batch elements instead of
        # different positions of the same sequence.
        memory = self.transformer_encoder(
            src, src_key_padding_mask=attention_mask
        )  # (batch_size, seq_len, d_model)

        # Final output layer
        output = self.output_layer(memory)  # (batch_size, seq_len, 1)

        return output

In [None]:
# Example usage
torch.manual_seed(0)  # Reproducible weights and inputs on Restart & Run All

vocab_size = 1_000  # Passed to the model (which does not use it) and used as the upper bound of the random inputs
d_model = 4  # Model width used throughout the Transformer
nhead = 2  # Number of attention heads
num_encoder_layers = 6  # Number of layers in the Transformer Encoder
dim_feedforward = 2048  # Feedforward layer dimension
max_len = 100  # Max sequence length
N = 3  # Features per position; must match the model's input projection size (3)

# Initialize the model
model = SimpleTransformer(
    vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_len
)

# Example input: a batch of feature vectors of shape (batch_size, seq_len, N)
batch_size = 2
seq_len = 2  # Length of each input sequence
src = torch.randint(0, vocab_size, (batch_size, seq_len, N)).to(
    torch.float32
)  # Random integer-valued features cast to float

print("src", src.shape)

# The model forwards this mask as src_key_padding_mask, where True marks a
# position to ignore. Nothing is padded here, so every entry is False;
# an all-True mask would hide every position from attention.
attention_mask = torch.full((batch_size, seq_len), False)
print("attention_mask", attention_mask.shape)

# Forward pass through the model
output = model(src, attention_mask)
print(output)  # Expected shape: (batch_size, seq_len, 1)

src torch.Size([2, 2, 3])
attention_mask torch.Size([2, 2])
tensor([[[0.8675],
         [0.9520]],

        [[0.9310],
         [0.8122]]], grad_fn=<AddBackward0>)
