In [3]:
# GLU Variants FFN Output Statistics
# Implements six FFN variants and prints output stats using PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sample input tensor
x = torch.randn(2, 768)  # batch_size=2, d_model=768

# Utility function to compute output statistics
def output_stats(model, x):
    with torch.no_grad():
        out = model(x)
    return {
        "mean": out.mean().item(),
        "std": out.std().item(),
        "min": out.min().item(),
        "max": out.max().item()
    }


In [4]:
# 1. Baseline FFN with ReLU
class FFNReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(768, 3072)
        self.w2 = nn.Linear(3072, 768)

    def forward(self, x):
        return self.w2(F.relu(self.w1(x)))


In [5]:
# 2. FFN with GELU
class FFNGELU(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(768, 3072)
        self.w2 = nn.Linear(3072, 768)

    def forward(self, x):
        return self.w2(F.gelu(self.w1(x)))


In [6]:
# 3. FFN with GLU
class FFNGLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Linear(768, 2048)
        self.v = nn.Linear(768, 2048)
        self.w2 = nn.Linear(2048, 768)

    def forward(self, x):
        gate = self.w(x)
        value = self.v(x)
        return self.w2(gate * value)


In [7]:
# 4. FFN with Bilinear
class FFNBilinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Linear(768, 2048)
        self.v = nn.Linear(768, 2048)
        self.w2 = nn.Linear(2048, 768)

    def forward(self, x):
        return self.w2(self.w(x) * self.v(x))

In [8]:
# 5. FFN with GEGLU
class FFNGEGLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Linear(768, 2048)
        self.v = nn.Linear(768, 2048)
        self.w2 = nn.Linear(2048, 768)

    def forward(self, x):
        gate = F.gelu(self.w(x))
        value = self.v(x)
        return self.w2(gate * value)


In [9]:
# 6. FFN with SwiGLU
class FFNSwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Linear(768, 2048)
        self.v = nn.Linear(768, 2048)
        self.w2 = nn.Linear(2048, 768)

    def swish(self, x):
        return x * torch.sigmoid(x)

    def forward(self, x):
        gate = self.swish(self.w(x))
        value = self.v(x)
        return self.w2(gate * value)


In [12]:
# Instantiate models
models = {
    "FFNReLU": FFNReLU(),
    "FFNGELU": FFNGELU(),
    "FFNGLU": FFNGLU(),
    "FFNBilinear": FFNBilinear(),
    "FFNGEGLU": FFNGEGLU(),
    "FFNSwiGLU": FFNSwiGLU()
}
# Print output statistics
for name, model in models.items():
    stats = output_stats(model, x)
    print(f"{name} Output Statistics:")
    for k, v in stats.items():
        print(f"  {k}: {v:.4f}")
    print()


FFNReLU Output Statistics:
  mean: 0.0098
  std: 0.2458
  min: -0.7638
  max: 0.7390

FFNGELU Output Statistics:
  mean: -0.0058
  std: 0.2092
  min: -0.6795
  max: 0.5599

FFNGLU Output Statistics:
  mean: -0.0003
  std: 0.2020
  min: -0.7058
  max: 0.8357

FFNBilinear Output Statistics:
  mean: 0.0012
  std: 0.1969
  min: -0.6839
  max: 0.6964

FFNGEGLU Output Statistics:
  mean: 0.0018
  std: 0.1248
  min: -0.4356
  max: 0.3517

FFNSwiGLU Output Statistics:
  mean: 0.0002
  std: 0.1018
  min: -0.3763
  max: 0.3306

