In [1]:
import torch
from torch.nn import Module, Linear, Parameter, ReLU, GELU, SiLU

## Standard MLP Block

- basic up, down, and activation func
- x is in shape (bs, seqlen, hidden_size)
- use standard layer norm `((x-mean(x))/(var(x)+eps))*W +B`
- standard mlp applies layer norm after mlp
- skip connection before norm 

In [2]:

class LayerNorm(Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = Parameter(torch.rand(hidden_size))
        self.bias = Parameter(torch.rand(hidden_size))
        self.eps = eps
    
    def forward(self, x):
        x_mean = torch.mean(x, dim=1).unsqueeze(1)
        x_variance = torch.var(x + self.eps, dim=1).unsqueeze(1)
        norm_x = (x - x_mean) / (x_variance + self.eps)
        return self.weight * norm_x + self.bias
        

class StandardMLP(Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ReLU()
        self.ln = LayerNorm(hidden_size)
    
    def forward(self, x):
        up = self.act_fn(self.up_proj(x))
        down = self.down_proj(up)
        return self.ln(down + x) # residual connection
        

In [3]:
x = torch.rand(4, 32, 128)
mlp = StandardMLP(128, 256)

In [4]:
mlp(x)

tensor([[[ 0.3871,  3.4959, -1.2076,  ..., -1.6955,  0.8514,  0.0924],
         [ 0.1074,  1.6093, -0.3789,  ...,  3.3887,  0.7555,  0.5051],
         [ 0.6063, -3.2555,  0.6370,  ..., -0.2145,  0.7132,  0.4433],
         ...,
         [ 0.5588, -1.0086, -0.5843,  ...,  2.7614,  0.6521,  0.6618],
         [ 0.2326, -1.6392,  1.5063,  ..., -1.8309,  0.7629, -0.0828],
         [ 0.2114,  2.4136, -1.5328,  ...,  0.3696,  0.7693,  0.3181]],

        [[ 0.4951, -0.0211, -2.1303,  ...,  2.4653,  0.6135,  0.1638],
         [ 0.6043, -1.3277,  0.0066,  ...,  4.1590,  0.6071, -0.1181],
         [ 0.4841,  2.2919,  0.7837,  ..., -0.1330,  0.6588,  0.4361],
         ...,
         [ 0.3120, -3.3606,  2.1393,  ...,  0.6211,  0.5291, -0.0717],
         [ 0.3507,  2.5743,  1.8553,  ...,  1.9771,  0.5814,  0.5648],
         [ 0.5014,  0.6742,  1.2610,  ...,  4.3965,  0.6582,  0.0211]],

        [[ 0.2781,  2.0835, -1.7272,  ...,  3.2816,  0.6353,  0.2760],
         [ 0.1631,  2.9376, -1.9473,  ..., -1

## GPT MLP Block

- intermediate size is fixed (4 * hidden_size)
- use GELU as act_fn
- standard mlp applies layer norm before mlp

In [5]:
class GPTMLP(StandardMLP):
    def __init__(self, hidden_size):
        super().__init__(hidden_size, 4 * hidden_size)
        self.act_fn = GELU()
    
    def forward(self, x):
        up = self.act_fn(self.up_proj(self.ln(x)))
        down = self.down_proj(up)
        return x + down
        


In [6]:
x = torch.rand(4, 32, 128)
mlp = GPTMLP(128)

In [7]:
mlp(x).shape

torch.Size([4, 32, 128])

## LLaMa MLP Block

- use RMSNorm (`x/(mean(x^2)+eps) * W`) instead of LN
- use SiLU instead of GELU
- smaller intermediate size than 4h
- norm before mlp

In [8]:
class RMSNorm(Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = Parameter(torch.rand(hidden_size))
        self.eps = eps
    
    def forward(self, x):
        rms_x = torch.mean(x ** 2, dim=1).unsqueeze(1)
        return self.weight * x/(rms_x+self.eps)

In [9]:
class LLaMaMLP(Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = Linear(intermediate_size, hidden_size, bias=False)
        self.rms_norm = RMSNorm(hidden_size)
        self.act_fn = SiLU()
    
    def forward(self, x):
        gated = self.act_fn(self.gate_proj(self.rms_norm(x)))
        up = gated * self.up_proj(x)
        down = self.down_proj(up)
        return down + x
        

In [10]:
x = torch.rand(4, 32, 128)
mlp = LLaMaMLP(128, 342)

In [11]:
mlp(x)

tensor([[[ 0.7474,  0.8281,  0.5100,  ...,  0.2617,  0.9431,  0.0669],
         [ 0.7033,  0.3637,  0.8419,  ...,  0.2813,  0.1584,  0.3039],
         [ 0.4184,  0.7633,  0.2466,  ...,  0.7423,  0.4674,  0.4228],
         ...,
         [ 0.5765,  0.8865,  0.5871,  ...,  0.6940,  0.5202,  0.9704],
         [ 0.5382,  0.4332,  0.1566,  ...,  0.6802,  0.1593,  0.2337],
         [ 0.6232,  0.0181,  0.4253,  ...,  0.6246,  0.7077,  0.6082]],

        [[ 0.8238,  0.1485,  0.0760,  ...,  0.6720,  0.4226,  0.8507],
         [-0.0671,  0.6248,  0.1084,  ...,  0.1321,  0.1907,  0.3431],
         [ 0.1610,  0.3132,  0.4587,  ...,  0.6930,  0.8047,  0.2974],
         ...,
         [ 0.4569,  0.9068,  0.2996,  ...,  0.0283,  0.4840,  0.4164],
         [ 0.4503,  0.0296,  0.7272,  ...,  0.6038,  0.6601,  0.7697],
         [ 0.7394,  0.3530,  0.9824,  ...,  0.3131,  0.2823,  0.4291]],

        [[ 0.5240,  0.7660,  0.7887,  ...,  0.7162,  0.0810,  0.3355],
         [ 0.6720,  0.2293,  0.8425,  ...,  0