In [None]:
import os
import sys
import math
import torch
import torch.nn as nn

## SwiGLU

In [None]:
class SwiGLU(nn.Module):
    """Element-wise activation returning ``swish(x) * gelu(x)``.

    ``swish(x)`` is ``x * sigmoid(x)`` and ``gelu(x)`` is the tanh
    approximation ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + c * x**3)))``.

    NOTE(review): despite the name, this module has no learned projections
    and does not split its input into a gate and a value branch, so it is
    not the gated "SwiGLU" feed-forward layer; it is a parameter-free
    product of two activations. The output has the same shape as the input.

    Args:
        name: Label stored on the instance as ``self.name``.
    """

    def __init__(self, name: str = "SwiGLU"):
        super(SwiGLU, self).__init__()

        self.name = name
        # Coefficient of the cubic term in the tanh approximation of GELU.
        self.constant = 0.044715

    def forward(self, x: torch.Tensor):
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        swish = x * torch.sigmoid(x)

        # A plain Python float avoids allocating a tensor on every call and
        # keeps full precision for float64 inputs (a float32 0-dim tensor
        # would silently round the constant).
        scale = math.sqrt(2 / math.pi)
        inner = scale * (x + self.constant * torch.pow(x, 3))
        gelu = 0.5 * x * (1 + torch.tanh(inner))

        return swish * gelu
    
if __name__ == "__main__":
    # Smoke test: the activation is element-wise, so the output shape must
    # equal the input shape.
    activation_func = SwiGLU()

    texts = torch.randn((64, 128, 512))

    # The message is written verbatim: str.capitalize() would lowercase the
    # rest of the string and turn "SwiGLU" into "Swiglu".
    assert activation_func(texts).size() == (
        64,
        128,
        512,
    ), "SwiGLU activation function is not working properly"

## RMSNorm

In [None]:
class RMSNormalization(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * gamma``.

    Args:
        dimension: Size of the last axis of the inputs, and the number of
            learned scale values.
        eps: Constant added to the mean square before taking the square root.

    Attributes:
        gamma: Learnable scale of shape ``(1, 1, dimension)``, initialised
            to ones.

    Raises:
        ValueError: If ``dimension`` is smaller than 1.
    """

    def __init__(self, dimension: int = 512, eps: float = 1e-4):
        super(RMSNormalization, self).__init__()

        if dimension < 1:
            raise ValueError("dimension must be a positive integer")

        self.dimension = dimension
        self.eps = eps

        # The stored shape stays (1, 1, dimension) so that state_dicts saved
        # from this module keep loading unchanged.
        self.gamma = nn.Parameter(
            data=torch.ones((1, 1, self.dimension)),
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` along its last axis and apply the learned scale.

        Args:
            x: Tensor whose last axis has size ``dimension``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        normalized = x / rms

        # Broadcasting against the raw (1, 1, dimension) parameter would add
        # leading axes to inputs with fewer than three dimensions; a flat
        # (dimension,) view scales the last axis without changing the rank.
        return torch.mul(normalized, self.gamma.reshape(-1))


if __name__ == "__main__":
    # Smoke test: normalization must not change the shape of its input.
    norm = RMSNormalization(dimension=512)

    # The message is written verbatim: str.capitalize() would lowercase the
    # rest of the string and turn the class name into "Rmsnormalization".
    assert norm(torch.randn(64, 128, 512)).size() == (
        64,
        128,
        512,
    ), "RMSNormalization is not working properly"

## RoPE - Rotary Positional Encoding

In [None]:
class RoPE(nn.Module):
    """Rotary position embedding applied to interleaved feature pairs.

    Features ``(x[2k], x[2k + 1])`` at sequence position ``p`` are rotated by
    the angle ``p * base ** (-2k / dimension)``.

    Args:
        dimension: Size of the last axis of the inputs; must be even.
        sequence_length: Number of positions for which angles are precomputed.
        base: Base of the geometric progression of rotation frequencies.

    Attributes:
        dimension: Number of feature pairs, i.e. the constructor argument
            divided by two.
        sin, cos: Buffers of shape ``(1, sequence_length, dimension)``
            holding the sine and cosine of every rotation angle. They are
            registered with ``register_buffer`` and therefore appear in the
            module's ``state_dict``.
        sin_values, cos_values: The same tables without the leading axis.

    Raises:
        ValueError: If the ``dimension`` argument is odd.
    """

    def __init__(self, dimension: int = 512, sequence_length: int = 128, base: int = 10000):
        super(RoPE, self).__init__()

        if dimension % 2 != 0:
            raise ValueError("dimension must be even so features can be rotated in pairs")

        # One rotation angle per (even, odd) feature pair.
        self.dimension = dimension // 2
        self.sequence_length = sequence_length
        self.base = base

        # Angles are built in float64 and cast once, so large positions do
        # not lose precision before sin/cos are taken.
        positions = torch.arange(self.sequence_length, dtype=torch.float64)
        pair_index = torch.arange(self.dimension, dtype=torch.float64)

        # Pair k rotates with frequency base^(-2k / d), d being the full
        # feature size; with self.dimension == d / 2 the exponent is
        # k / self.dimension. Every pair therefore gets its own frequency.
        exponents = pair_index / self.dimension
        inverse_frequency = 1.0 / torch.pow(
            torch.tensor(float(self.base), dtype=torch.float64), exponents
        )

        theta = torch.outer(positions, inverse_frequency)

        table_dtype = torch.get_default_dtype()
        self.sin_values = torch.sin(theta).to(table_dtype)
        self.cos_values = torch.cos(theta).to(table_dtype)

        self.register_buffer("sin", self.sin_values.unsqueeze(dim=0))
        self.register_buffer("cos", self.cos_values.unsqueeze(dim=0))

    def forward(self, x: torch.Tensor):
        """Rotate every feature pair of ``x`` by its position-dependent angle.

        Args:
            x: Tensor of shape ``(..., sequence, features)`` where
                ``features`` equals the ``dimension`` constructor argument
                and ``sequence`` does not exceed ``sequence_length``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
            ValueError: If ``x`` has fewer than two axes, a feature size that
                does not match the tables, or more positions than were
                precomputed.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        if x.dim() < 2:
            raise ValueError("Input must have at least a sequence and a feature axis")

        steps, features = x.size(-2), x.size(-1)

        if features != 2 * self.dimension:
            raise ValueError(
                f"Expected {2 * self.dimension} features on the last axis, got {features}"
            )
        if steps > self.sequence_length:
            raise ValueError(
                f"Sequence length {steps} exceeds the precomputed maximum {self.sequence_length}"
            )

        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        # (steps, pairs) tables broadcast over any leading batch axes.
        sin = self.sin[0, :steps, :]
        cos = self.cos[0, :steps, :]

        rotated_even = x1 * cos - x2 * sin
        rotated_odd = x1 * sin + x2 * cos

        # Stack on a new last axis and flatten it away to re-interleave the
        # rotated values as (even0, odd0, even1, odd1, ...).
        output = torch.stack((rotated_even, rotated_odd), dim=-1)

        return output.flatten(start_dim=-2)
        
        
if __name__ == "__main__":
    # Smoke test: rotating feature pairs must not change the input shape.
    encoding = RoPE(dimension=512, sequence_length=128)

    texts = torch.randn((64, 128, 512))

    # The message is written verbatim: str.capitalize() would lowercase the
    # rest of the string and turn "RoPE" into "Rope".
    assert encoding(texts).size() == (
        64,
        128,
        512,
    ), "RoPE is not working properly"

## Grouped Query Attention Layer - GQA

In [None]:
class GroupedQueryAttention(nn.Module):
    """Self-attention in which groups of query heads share one key/value head.

    The input is projected to ``query_heads`` query heads and ``kv_heads``
    key/value heads of size ``dimension // query_heads``. Each key/value head
    is repeated ``query_heads // kv_heads`` times so it lines up with its
    group of query heads, scaled dot-product attention is computed per head,
    and the concatenated heads go through a final linear layer. No attention
    mask is applied and none of the projections has a bias.

    Args:
        dimension: Feature size of the input and the output.
        query_heads: Number of query heads.
        kv_heads: Number of key/value heads.

    Raises:
        AssertionError: If ``dimension`` is not divisible by ``query_heads``
            or by ``kv_heads``, or if ``query_heads`` is not divisible by
            ``kv_heads``.
    """

    def __init__(self, dimension: int = 512, query_heads: int = 8, kv_heads: int = 4):
        super(GroupedQueryAttention, self).__init__()

        self.dimension = dimension
        self.query_heads = query_heads
        self.kv_heads = kv_heads

        assert (
            self.dimension % self.query_heads == 0
        ), "Dimension must be divisible by query heads"
        assert (
            self.dimension % self.kv_heads == 0
        ), "Dimension must be divisible by kv heads"
        # Each key/value head is repeated a whole number of times; without
        # this check a bad combination only fails later inside matmul.
        assert (
            self.query_heads % self.kv_heads == 0
        ), "Query heads must be divisible by kv heads"

        self.head_dim = self.dimension // self.query_heads
        self.num_of_repeatation = self.query_heads // self.kv_heads

        self.query = nn.Linear(
            in_features=self.dimension,
            out_features=self.query_heads * self.head_dim,
            bias=False,
        )
        self.key = nn.Linear(
            in_features=self.dimension,
            out_features=self.kv_heads * self.head_dim,
            bias=False,
        )
        self.value = nn.Linear(
            in_features=self.dimension,
            out_features=self.kv_heads * self.head_dim,
            bias=False,
        )
        self.output = nn.Linear(
            in_features=self.dimension, out_features=self.dimension, bias=False
        )

    def forward(self, x: torch.Tensor):
        """Run grouped-query self-attention over ``x``.

        Args:
            x: Tensor of shape ``(batch, sequence, dimension)``.

        Returns:
            Tensor of shape ``(batch, sequence, dimension)``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        assert (
            key.size() == value.size()
        ), "Key and value must have the same size"

        # Split the feature axis into heads: (batch, sequence, heads, head_dim).
        query = query.view(
            query.size(0),
            query.size(1),
            self.query_heads,
            query.size(-1) // self.query_heads,
        )
        key = key.view(
            key.size(0), key.size(1), self.kv_heads, key.size(-1) // self.kv_heads
        )
        value = value.view(
            value.size(0), value.size(1), self.kv_heads, value.size(-1) // self.kv_heads
        )

        # Move heads in front of the sequence axis: (batch, heads, sequence, head_dim).
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        # Repeat every key/value head so it is paired with each query head
        # of its group.
        key = torch.repeat_interleave(input=key, repeats=self.num_of_repeatation, dim=1)
        value = torch.repeat_interleave(
            input=value, repeats=self.num_of_repeatation, dim=1
        )

        # A Python float scale avoids building a tensor on every call.
        attention = torch.matmul(
            query, torch.transpose(input=key, dim0=-1, dim1=-2)
        ) / math.sqrt(self.head_dim)
        attention = torch.softmax(input=attention, dim=-1)

        attention = torch.matmul(input=attention, other=value)

        # Back to (batch, sequence, heads, head_dim), then merge the heads.
        attention = torch.permute(input=attention, dims=(0, 2, 1, 3))
        attention = attention.reshape(
            attention.size(0), attention.size(1), attention.size(2) * attention.size(3)
        )

        attention = self.output(attention)

        return attention


if __name__ == "__main__":
    # Smoke test: attention must map (batch, sequence, dimension) to the
    # same shape.
    attention = GroupedQueryAttention()
    texts = torch.randn((64, 128, 512))

    output = attention(texts)

    # The message is written verbatim: str.capitalize() would lowercase the
    # rest of the string and turn the class name into "Groupedqueryattention".
    assert output.size() == (
        64,
        128,
        512,
    ), "GroupedQueryAttention is not working properly"

## MLP - Feed Forward Neural Network

In [None]:
class FeedForwardNeuralNetwork(nn.Module):
    """Gated feed-forward block: ``down(gate(x) * SiLU(up(x)))``.

    Two parallel linear layers expand the input from ``hidden_dimension`` to
    ``output_dimension``. The result of ``up_projection`` is passed through
    SiLU and multiplied element-wise with the raw result of
    ``gate_projection``; ``down_projection`` maps the product back to
    ``hidden_dimension``.

    Args:
        hidden_dimension: Feature size of the input and the output.
        output_dimension: Feature size of the expanded intermediate tensors.
        bias: Whether the three linear layers carry a bias term.
    """

    def __init__(
        self,
        hidden_dimension: int = 4096,
        output_dimension: int = 14336,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_dimension = hidden_dimension
        self.output_dimension = output_dimension
        self.bias = bias

        # Layers are created in the order gate -> up -> down so parameter
        # registration and random initialisation follow the same sequence.
        self.gate_projection = nn.Linear(hidden_dimension, output_dimension, bias=bias)
        self.up_projection = nn.Linear(hidden_dimension, output_dimension, bias=bias)
        self.down_projection = nn.Linear(output_dimension, hidden_dimension, bias=bias)

        self.swish = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor):
        """Apply the gated feed-forward transformation.

        Args:
            x: Tensor whose last axis has size ``hidden_dimension``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if isinstance(x, torch.Tensor):
            gated = self.gate_projection(x)
            # SiLU runs in place on the fresh output of the up projection.
            activated = self.swish(self.up_projection(x))
            return self.down_projection(gated * activated)

        raise TypeError("Input must be a torch.Tensor")


if __name__ == "__main__":
    # Smoke test: the block must map (batch, sequence, hidden) back to the
    # same shape.
    network = FeedForwardNeuralNetwork(hidden_dimension=512, output_dimension=4 * 512)
    texts = torch.randn((64, 128, 512))

    # Run the forward pass once and reuse its shape for both the printout
    # and the check instead of evaluating the network twice.
    output_size = network(texts).size()
    print(output_size)

    # The message is written verbatim: str.capitalize() would lowercase the
    # rest of the string and turn the class name into
    # "Feedforwardneuralnetwork".
    assert output_size == (
        64,
        128,
        512,
    ), "FeedForwardNeuralNetwork is not working properly"