# In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

# Transformer Encoder Definition
class TransformerEncoder(nn.Module):
    """Token-level Transformer encoder with a vocabulary-sized output head.

    Pipeline: token ids -> embedding -> dropout -> stacked
    ``nn.TransformerEncoderLayer`` -> linear projection back to ``input_dim``.

    NOTE: this module does not add any positional encoding itself; if token
    order matters, positional information has to be injected by the caller
    or added to this class.

    Args:
        input_dim: Vocabulary size. Used both as the number of embedding rows
            and as the width of the output projection.
        embed_dim: Embedding width, passed to the encoder layers as
            ``d_model``. ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: Number of attention heads per encoder layer.
        ff_dim: Hidden width of the feed-forward sub-layer in each encoder
            layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability, used inside every encoder layer and on
            the token embeddings.
    """

    def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super().__init__()

        # Lookup table mapping each token id to a learned embed_dim vector.
        self.embedding = nn.Embedding(input_dim, embed_dim)

        # Template layer; nn.TransformerEncoder clones it num_layers times.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            # Layout choice, not a speed knob: tensors are (batch, seq, feature)
            # instead of PyTorch's default (seq, batch, feature).
            batch_first=True,
        )

        # Stack of num_layers encoder layers.
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Projects every encoded position back to one score per vocabulary entry.
        self.fc = nn.Linear(embed_dim, input_dim)

        # Dropout applied to the token embeddings in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_key_padding_mask=None):
        """Encode a batch of token-id sequences.

        Args:
            src: Integer tensor of token ids with shape ``(batch, seq_len)``.
            src_key_padding_mask: Optional boolean tensor of shape
                ``(batch, seq_len)``; ``True`` marks padding positions that
                attention should ignore. ``None`` (the default) attends to
                every position.

        Returns:
            Tensor of shape ``(batch, seq_len, input_dim)`` holding the raw
            output of the final linear layer (no softmax is applied).
        """
        # Token ids -> embeddings, regularised with dropout (a no-op in eval mode).
        embedded = self.dropout(self.embedding(src))

        # Contextualise every position through the encoder stack.
        encoded = self.transformer_encoder(
            embedded, src_key_padding_mask=src_key_padding_mask
        )

        # Per-position projection to vocabulary size.
        return self.fc(encoded)

# Demo configuration
input_dim, embed_dim = 1000, 64  # vocabulary size, embedding width
num_heads, ff_dim = 8, 256       # attention heads, feed-forward hidden width
num_layers = 3                   # stacked encoder layers

# Build the model from the configuration above.
model = TransformerEncoder(
    input_dim=input_dim,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
    num_layers=num_layers,
)

# One random sequence of five token ids, laid out as (batch, seq_len).
sample_input = torch.randint(low=0, high=input_dim, size=(1, 5))

# Forward pass; model.eval() is not called, so the module is still in its
# default training mode here.
output = model(sample_input)

# Report the sampled ids and the result shape: (batch, seq_len, input_dim).
print("Input:", sample_input)
print("Output shape:", output.shape)


# Example output (token ids are sampled randomly, so the input values differ between runs):
# Input: tensor([[263, 395, 755, 384,  59]])
# Output shape: torch.Size([1, 5, 1000])
