In [1]:
import os
import sys
import math
import torch
import torch.nn as nn

## SwiGLU

In [2]:
class SwiGLU(nn.Module):
    """Element-wise activation: Swish multiplied by the tanh approximation of GELU.

    For an input ``x`` the module returns ``swish(x) * gelu(x)`` where

    * ``swish(x) = x * sigmoid(x)``
    * ``gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + c * x ** 3)))`` with
      ``c = self.constant``.

    The module holds no learnable parameters and every operation is
    element-wise, so the output has the same shape as the input.

    NOTE(review): despite the class name, no linear projections or gating
    weights are used here — confirm this Swish x GELU product is the intended
    activation.

    Args:
        name: Label stored on the instance as ``self.name``.
    """

    # sqrt(2 / pi) kept as a Python float: it is computed once instead of
    # allocating a tensor on every forward pass, and it retains full double
    # precision when the input is float64 (a float32 tensor constant would not).
    _SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

    def __init__(self, name: str = "SwiGLU"):
        super().__init__()

        self.name = name
        # Cubic coefficient of the tanh-based GELU approximation.
        self.constant = 0.044715

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x`` holding ``swish(x) * gelu(x)``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        swish = x * torch.sigmoid(x)

        inner = self._SQRT_2_OVER_PI * (x + self.constant * torch.pow(x, 3))
        gelu = 0.5 * x * (1 + torch.tanh(inner))

        return swish * gelu
    
if __name__ == "__main__":
    # Smoke test: the activation is element-wise, so the output size must
    # equal the input size.
    activation_func = SwiGLU()

    texts = torch.randn((64, 128, 512))

    # The message is a plain literal: str.capitalize() lower-cases everything
    # after the first character and would report "Swiglu" instead of "SwiGLU".
    assert activation_func(texts).size() == texts.size(), "SwiGLU activation function is not working properly"

## RMSNorm

In [None]:
class RMSNormalization(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Each vector along the last axis is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then multiplied element-wise by the
    learnable parameter ``gamma`` (initialised to ones).

    Args:
        dimension: Number of entries in ``gamma``; positive integer.
        eps: Constant added to the mean square before taking the square root.

    Raises:
        ValueError: If ``dimension`` is not positive.
    """

    def __init__(self, dimension: int = 512, eps: float = 1e-4):
        super().__init__()

        if dimension <= 0:
            raise ValueError(f"dimension must be a positive integer, got {dimension}")

        self.dimension = dimension
        self.eps = eps

        # The parameter keeps its original (1, 1, dimension) shape so that
        # previously saved state_dicts remain loadable.
        self.gamma = nn.Parameter(data=torch.ones((1, 1, self.dimension)), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis and apply the gain.

        Args:
            x: Input tensor; its last axis is the one being normalised.

        Returns:
            The normalised, gain-scaled tensor. 1-D, 2-D and higher-rank
            inputs whose last axis matches ``dimension`` keep their shape.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        # keepdim=True leaves a trailing axis of size 1 so the division
        # broadcasts over the last axis without a separate unsqueeze.
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        normalized = x / rms

        # Flatten the gain to (dimension,): broadcasting the stored
        # (1, 1, dimension) shape directly would promote 1-D / 2-D inputs to
        # rank 3 and silently change the output shape.
        return torch.mul(normalized, self.gamma.reshape(-1))
        


if __name__ == "__main__":
    # Smoke test: normalisation must leave the input shape unchanged.
    norm = RMSNormalization(dimension=512)

    # The message is a plain literal: str.capitalize() lower-cases everything
    # after the first character and would report "Rmsnormalization".
    assert norm(torch.randn(64, 128, 512)).size() == (64, 128, 512), "RMSNormalization is not working properly"
