In [1]:
import torch 
from torch import nn 



# Layer Norm

In [3]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``, where ``mean`` and
    ``var`` are taken over the last dimension of ``x`` (``var`` is the biased,
    population variance, as in ``torch.nn.LayerNorm``).

    Args:
        hidden_size: Size of the input's last dimension; sets the length of the
            learnable ``gamma`` (init 1) and ``beta`` (init 0) vectors.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        # `super` must be called: `super.__init__()` raises a TypeError.
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension, then apply ``gamma`` and ``beta``."""
        mean = x.mean(dim=-1, keepdim=True)
        # Layer norm uses the biased variance; Tensor.var defaults to unbiased.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        # Center first, then divide; eps is *added* so the denominator stays > 0
        # even for constant rows (var == 0).
        y = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * y + self.beta

# RMS Norm


In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``weight * x / sqrt(mean(x ** 2) + eps)``. Unlike LayerNorm, the
    input is not mean-centered and there is no additive bias.

    Args:
        hidden_size: Size of the input's last dimension; sets the length of the
            learnable ``weight`` vector (initialized to ones).
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        # `super` must be called: `super.__init__()` raises a TypeError.
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """RMS-normalize ``x`` and apply the learnable per-feature ``weight``."""
        output = self._norm(x)
        return self.weight * output


# DyT (Dynamic Tanh): transformers without normalization


In [6]:
class DynamicalTanh(nn.Module):
    """Dynamic Tanh (DyT): ``gamma * tanh(alpha * x) + beta``.

    An element-wise, statistics-free stand-in for a normalization layer:
    a learnable scalar ``alpha`` scales the input before ``tanh`` squashes it,
    followed by a per-feature affine transform.

    Args:
        hidden_size: Size of the input's last dimension; sets the length of the
            learnable ``gamma`` (init 1) and ``beta`` (init 0) vectors.
        init_alpha: Initial value of the learnable scalar ``alpha``.
    """

    def __init__(self, hidden_size: int, init_alpha: float = 0.01):
        # `super` must be called: `super.__init__()` raises a TypeError.
        super().__init__()
        # Scale (not shift) a ones tensor so alpha starts exactly at init_alpha;
        # `torch.ones(1) + init_alpha` would have started it at 1 + init_alpha.
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``tanh(alpha * x)`` element-wise, then the ``gamma``/``beta`` affine."""
        x = torch.tanh(self.alpha * x)
        return self.gamma * x + self.beta
    