<a href="https://colab.research.google.com/github/Papa-Panda/llm/blob/main/Activation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://gemini.google.com/app/90bd61d394533684

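# "Why do today's large models run GeLU or SwiGLU in high precision instead of
# switching back to ReLU in low precision?" - answer by 伊斯特伍德 on Zhihu
# (original title: 为什么现在的大模型要高精度跑GeLU或SwiGLU，而不是改回ReLU跑低精度？)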
# https://www.zhihu.com/question/15527003900/answer/129423629478

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Gated linear unit whose gate is activated with SiLU (Swish).

    A single fused ``nn.Linear`` maps the last dimension of the input from
    ``input_dim`` to ``2 * output_dim``. The first ``output_dim`` features
    form the gate, the remaining ``output_dim`` features form the value
    branch, and the layer returns ``silu(gate) * value``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the returned tensor.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_dim = output_dim
        # One projection produces both halves (gate first, value second),
        # so it is twice as wide as the layer's output.
        self.linear = nn.Linear(input_dim, output_dim * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``silu(gate) * value`` computed from the fused projection.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``output_dim``.
        """
        # Fused projection: last dimension is 2 * output_dim.
        projected = self.linear(x)

        # Cut the last dimension into the gate half and the value half.
        gate, value = projected.split(self.output_dim, dim=-1)

        # Only the gate goes through SiLU; the value branch stays linear.
        return F.silu(gate) * value

# --- Example Usage (PyTorch) ---
if __name__ == "__main__":
    # --- Smoke test / usage demo for the SwiGLU layer (PyTorch) ---
    # Feature width fed into the layer and feature width it produces.
    input_dim, output_dim = 768, 2048

    # Build the layer and show its structure.
    swiglu_layer = SwiGLU(input_dim, output_dim)
    print(f"PyTorch SwiGLU layer: {swiglu_layer}")

    # Random activations laid out as (batch=2, sequence=10, features=input_dim).
    dummy_input = torch.randn(2, 10, input_dim)
    print(f"PyTorch Dummy input shape: {dummy_input.shape}")

    # Forward pass: only the last dimension should change, to output_dim.
    output = swiglu_layer(dummy_input)
    print(f"PyTorch Output shape: {output.shape}")
    assert output.shape == (2, 10, output_dim), "PyTorch Output shape mismatch!"
    print("PyTorch SwiGLU implementation test passed!")

    # --- Feed-forward style usage: SwiGLU followed by a down-projection ---
    # A second linear layer maps the output_dim-wide activations back to
    # input_dim, so the block's output has the same shape as its input.
    final_projection_layer = nn.Linear(output_dim, input_dim)
    final_output = final_projection_layer(output)
    print(f"PyTorch Final output after projection shape: {final_output.shape}")
    assert final_output.shape == (2, 10, input_dim), "PyTorch Final output shape mismatch!"
    print("PyTorch Full FFN-like block (SwiGLU + projection) test passed!")

PyTorch SwiGLU layer: SwiGLU(
  (linear): Linear(in_features=768, out_features=4096, bias=True)
)
PyTorch Dummy input shape: torch.Size([2, 10, 768])
PyTorch Output shape: torch.Size([2, 10, 2048])
PyTorch SwiGLU implementation test passed!
PyTorch Final output after projection shape: torch.Size([2, 10, 768])
PyTorch Full FFN-like block (SwiGLU + projection) test passed!
