Model Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass


In [2]:
# Configuration dataclass collecting the model hyperparameters in one place.
@dataclass
class ModelConfig:
    """Container for the model's hyperparameters.

    Every field has a default, so ``ModelConfig()`` yields a complete
    configuration. Field order is part of the generated ``__init__``
    signature and must not be rearranged.
    """

    # Feature width; passed to RMSNorm as the size of the normalised dimension.
    d_model: int = 1_024
    # Presumably the number of attention heads (not used in the visible cells).
    n_heads: int = 16
    # Presumably the feed-forward hidden width (not used in the visible cells).
    d_ff: int = 2_816
    # Presumably the tokenizer vocabulary size (not used in the visible cells).
    vocab_size: int = 32_000
    # Layer counts for the encoder and decoder stacks, going by the names.
    num_encoder_layers: int = 6
    num_decoder_layers: int = 3
    # Presumably the rotary-embedding base frequency -- TODO confirm at use site.
    rope_theta: float = 10_000.0
    # Dropout probability; 0.0 by default.
    dropout: float = 0.0


In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature scale.

    Each vector along the last dimension is divided by
    ``sqrt(mean(x**2) + eps)`` and then multiplied element-wise by ``weight``.
    No mean is subtracted and there is no bias term.

    Args:
        d_model: Size of the last (normalised) dimension; also the length of
            the learnable ``weight`` vector.
        eps: Small constant added to the mean square before the reciprocal
            square root, so an all-zero input does not divide by zero.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Scale starts at one, so a freshly built layer only normalises.
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply the learned scale.

        Args:
            x: Tensor whose last dimension must broadcast against ``weight``
                (i.e. be of size ``d_model``).

        Returns:
            Tensor with the same shape as ``x``. For floating-point inputs the
            dtype is the usual promotion of ``x.dtype`` and ``weight.dtype``.
        """
        # Do the statistics in at least float32: squaring float16 values above
        # ~256 overflows to inf, which would make the reciprocal RMS zero and
        # silently zero the whole output. float32/float64 inputs are unaffected.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_hp = x.to(compute_dtype)

        # Reciprocal of the root mean square of each feature vector.
        inv_rms = x_hp.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        x_norm = x_hp * inv_rms

        # Return to the caller's precision so the output dtype is unchanged.
        if x.is_floating_point():
            x_norm = x_norm.to(x.dtype)
        return self.weight * x_norm


In [4]:
# Smoke test: push a random batch through RMSNorm and verify the result
# instead of only printing shapes.
cfg = ModelConfig()
norm = RMSNorm(cfg.d_model)

x = torch.randn(2, 5, cfg.d_model)
y = norm(x)

print("Input shape :", x.shape)
print("Output shape:", y.shape)

# Normalisation must leave the tensor shape untouched.
assert y.shape == x.shape, "RMSNorm changed the tensor shape"
# With the freshly initialised all-ones weight, every feature vector should
# come out with (approximately) unit mean square.
assert torch.allclose(
    y.pow(2).mean(dim=-1), torch.ones(x.shape[:-1]), atol=1e-4
), "RMSNorm output does not have unit mean square"


Input shape : torch.Size([2, 5, 1024])
Output shape: torch.Size([2, 5, 1024])
