In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

@dataclass
class ModelConfig:
    """Default hyperparameters shared by the cells of this notebook.

    Every field has a default, so ``ModelConfig()`` yields a complete config.
    Individual values can be overridden positionally (in declaration order)
    or by keyword.
    """

    # Hidden width; used as the input/output feature size of the SwiGLU block.
    d_model: int = 1_024
    # Presumably the number of attention heads -- not used in the visible cells.
    n_heads: int = 16
    # Inner width of the feed-forward block (passed to SwiGLU as ``d_ff``).
    d_ff: int = 2_816
    # Presumably the token vocabulary size -- not used in the visible cells.
    vocab_size: int = 32_000
    # Presumably the encoder / decoder stack depths -- TODO confirm where the
    # stacks are actually built.
    num_encoder_layers: int = 6
    num_decoder_layers: int = 3
    # Presumably the rotary-embedding base -- TODO confirm against the
    # attention code.
    rope_theta: float = 10_000.0
    # Dropout probability handed to the feed-forward block.
    dropout: float = 0.0



In [2]:
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w3(w1(x) * SiLU(w2(x)))``, then dropout.

    Args:
        d_model: Feature size of the input's last dimension and of the output.
        d_ff: Width of the two parallel projections that are multiplied together.
        dropout: Probability passed to ``nn.Dropout`` on the block's output.

    All three projections are ``nn.Linear`` layers without bias.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0):
        super().__init__()
        # Two parallel up-projections: w1 carries the value path, w2 the gate path.
        self.w1 = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        self.w2 = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        # Down-projection back to the model width.
        self.w3 = nn.Linear(in_features=d_ff, out_features=d_model, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``.

        ``x`` must have ``d_model`` features in its last dimension; the result
        has the same shape as ``x``.
        """
        gate = F.silu(self.w2(x))
        value = self.w1(x)
        # Element-wise gating, then project back down and regularise.
        return self.dropout(self.w3(value * gate))


In [3]:
# Smoke test: build the feed-forward block from the default config and check
# that it maps a (batch, seq_len, d_model) tensor back to the same shape.
torch.manual_seed(0)  # seed before construction so weights and input are reproducible

cfg = ModelConfig()
ffn = SwiGLU(cfg.d_model, cfg.d_ff, dropout=cfg.dropout)

x = torch.randn(2, 5, cfg.d_model)  # (batch=2, seq_len=5, hidden_dim=1024)
y = ffn(x)

# Fail loudly on Restart & Run All if the block ever stops preserving shape,
# instead of relying on someone reading the printed sizes.
assert y.shape == x.shape, (
    f"SwiGLU changed the shape: {tuple(x.shape)} -> {tuple(y.shape)}"
)

print("Input shape :", x.shape)
print("Output shape:", y.shape)


Input shape : torch.Size([2, 5, 1024])
Output shape: torch.Size([2, 5, 1024])
