In [1]:
import torch
import torch.nn as nn

In [2]:
import torch
import torch.nn as nn

# Multi-Head Self-Attention Layer
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by separate linear layers, split into
    ``num_heads`` heads of size ``d_model // num_heads``, attended within each
    head, then the heads are concatenated and passed through a final linear
    projection (``dense``).

    Args:
        d_model: Feature dimension of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, depth)``."""
        x = x.view(x.size(0), -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor ``(batch, len_q, d_model)``.
            k: Key tensor ``(batch, len_k, d_model)``.
            v: Value tensor ``(batch, len_k, d_model)``.
            mask: Optional mask broadcastable to ``(batch, num_heads, len_q, len_k)``.
                * Non-boolean mask: added to the attention scores (0 keeps a
                  position; a large negative value suppresses it).
                * Boolean mask: ``True`` marks positions that may be attended
                  to; ``False`` positions receive zero attention weight.

        Returns:
            ``(output, attn_probs)``: ``output`` is ``(batch, len_q, d_model)`` and
            ``attn_probs`` is ``(batch, num_heads, len_q, len_k)``.
        """
        batch_size = q.size(0)
        q = self.split_heads(self.wq(q))
        k = self.split_heads(self.wk(k))
        v = self.split_heads(self.wv(v))

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.depth ** 0.5)
        if mask is not None:
            if mask.dtype == torch.bool:
                # Adding a bool tensor would just add 0/1 and mask nothing. Use the
                # dtype's most negative finite value (not -inf) so a fully masked
                # row yields a uniform distribution rather than NaNs.
                attn_scores = attn_scores.masked_fill(
                    ~mask, torch.finfo(attn_scores.dtype).min
                )
            else:
                attn_scores += mask

        attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, v)

        # Merge heads: (batch, heads, len_q, depth) -> (batch, len_q, d_model).
        attn_output = (
            attn_output.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )
        attn_output = self.dense(attn_output)

        return attn_output, attn_probs

# Position-wise Feed-Forward Network
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension: ``fc2(relu(fc1(x)))``.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (expanded) feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand d_model -> d_ff
        self.fc2 = nn.Linear(d_ff, d_model)  # project d_ff -> d_model

    def forward(self, x):
        hidden = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(hidden)

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature dimension.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Sub-layer 1: self-attention (attention weights are discarded).
        attended = self.self_attention(x, x, x, mask)[0]
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward network.
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x

# Transformer Encoder
class TransformerEncoder(nn.Module):
    """Token embedding + sinusoidal positional encoding + stack of encoder layers.

    Args:
        num_layers: Number of ``TransformerEncoderLayer`` blocks.
        d_model: Embedding / model dimension.
        num_heads: Attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward network.
        input_vocab_size: Number of rows in the token embedding table.
        max_sequence_length: Longest sequence the positional table covers.
        dropout: Dropout probability after the embedding and inside each layer.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size, max_sequence_length, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        # A buffer (not a plain attribute) so ``.to(device)`` / ``.cuda()`` move it
        # with the parameters; non-persistent keeps state_dict keys unchanged.
        self.register_buffer(
            "pos_encoding",
            self.positional_encoding(max_sequence_length, d_model),
            persistent=False,
        )
        self.dropout = nn.Dropout(dropout)

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: Integer token ids ``(batch, seq_len)`` with
                ``seq_len <= max_sequence_length``.
            mask: Optional attention mask passed unchanged to every layer;
                ``None`` disables masking.

        Returns:
            Tensor ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-encoding table.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_sequence_length {max_len}"
            )
        # Slice the table so sequences shorter than max_sequence_length work.
        x = self.embedding(x) + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, mask)

        return x

    def positional_encoding(self, max_len, d_model):
        """Build the sinusoidal positional-encoding table.

        For even feature index ``i`` the value is ``sin(pos / 10000**(i / d_model))``.
        Feature ``i + 1`` holds the matching ``cos``.

        Returns:
            Float tensor ``(1, max_len, d_model)``.
        """
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pos_encoding = torch.zeros(1, max_len, d_model)
        pos_encoding[:, :, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even column.
        pos_encoding[:, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pos_encoding

# Example usage of the Transformer model
def main():
    """Build an encoder and run one inference pass on random token ids."""
    # Hyperparameters
    num_layers = 6
    d_model = 512
    num_heads = 8
    d_ff = 2048
    input_vocab_size = 10000
    max_sequence_length = 100
    dropout = 0.1
    batch_size = 32

    # Create a Transformer Encoder
    encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, input_vocab_size, max_sequence_length, dropout)
    # This is a demo forward pass, so disable dropout.
    encoder.eval()

    # Dummy input data (batch_size x max_sequence_length random token ids)
    input_data = torch.randint(0, input_vocab_size, (batch_size, max_sequence_length))

    # MultiHeadAttention *adds* a float mask to its scores. "Every position is
    # valid" is therefore an all-zeros mask; a ones mask would merely shift
    # every score by 1.
    mask = torch.zeros((batch_size, 1, 1, max_sequence_length))

    # Forward pass without building an autograd graph.
    with torch.no_grad():
        output = encoder(input_data, mask)

    print("Output shape:", output.shape)


if __name__ == "__main__":
    main()

Output shape: torch.Size([32, 100, 512])
