In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Both linear layers act on the last dimension of the input only, so the
    same transform is applied to every position of every sequence.
    """

    def __init__(self, embed_dim, ff_hidden_dim, dropout=0.1):
        """
        Build the two projections and the dropout layer.

        Args:
            embed_dim: Size of the last dimension of the input and the output
            ff_hidden_dim: Size of the intermediate (hidden) representation
            dropout: Probability handed to nn.Dropout
        """
        super().__init__()
        # Layers are created in the order linear1, linear2 so that parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.linear1 = nn.Linear(embed_dim, ff_hidden_dim)
        self.linear2 = nn.Linear(ff_hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply the network to a batch of sequences.

        Args:
            x: Input tensor [batch_size, seq_len, embed_dim]

        Returns:
            output: Output tensor [batch_size, seq_len, embed_dim]
        """
        hidden = self.linear1(x)       # expand: embed_dim -> ff_hidden_dim
        hidden = F.relu(hidden)        # element-wise non-linearity
        hidden = self.dropout(hidden)  # only drops units in training mode
        return self.linear2(hidden)    # project back: ff_hidden_dim -> embed_dim

In [7]:
# Example configuration
# Seed PyTorch's global RNG first so that later random draws in this notebook
# (torch.randn inputs, nn.Linear weight initialisation, dropout masks) are
# identical on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Tensor sizes used by the demo.
batch_size = 4        # sequences per batch
seq_len = 10          # positions per sequence
embed_dim = 512       # input/output width of the feed-forward block
ff_hidden_dim = 2048  # hidden width of the feed-forward block

In [8]:
# Random input tensor with shape [batch_size, seq_len, embed_dim]
x = torch.randn(batch_size, seq_len, embed_dim)

# Build the block and run a single forward pass. A newly constructed
# nn.Module is in training mode, so dropout is active here; that does not
# matter for this demo because only the output shape is inspected.
feed_forward = FeedForward(embed_dim=embed_dim, ff_hidden_dim=ff_hidden_dim, dropout=0.1)
output = feed_forward(x)

# The network maps embed_dim -> ff_hidden_dim -> embed_dim on the last axis,
# so the output must have exactly the input's shape. Assert it instead of
# relying on reading the printed values.
assert output.shape == x.shape, f"shape mismatch: {tuple(output.shape)} != {tuple(x.shape)}"

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([4, 10, 512])
Output shape: torch.Size([4, 10, 512])
