# Playground for FeedForward Layer

## GELU Activation
* Use the Gaussian Error Linear Unit (GELU) instead of ReLU
    * Smoother than ReLU, for better performance

In [22]:
import torch
import torch.nn as nn

# Seed PyTorch's global random number generator so that randomly initialised
# layer weights and random sample tensors are reproducible on a fresh run.
torch.manual_seed(42)

class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit (GELU).

    Computes, element-wise,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    # sqrt(2 / pi) as a plain Python float. Keeping it a Python scalar (instead
    # of building a float32 tensor inside forward) avoids a tensor allocation on
    # every call and lets the constant be applied at the precision of the input,
    # so float64 inputs are no longer limited to a float32-accurate constant.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    # Coefficient of the cubic term in the tanh approximation.
    _CUBIC_COEFF = 0.044715

    def forward(self, x):
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        inner = self._SQRT_2_OVER_PI * (x + self._CUBIC_COEFF * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


In [23]:
import matplotlib.pyplot as plt

def plot_GELU_and_RELU():
    """Plot GELU and ReLU side by side on 100 evenly spaced points in [-3, 3].

    Each activation gets its own subplot with a title, axis labels and a grid;
    the figure is displayed with ``plt.show()``.
    """
    inputs = torch.linspace(-3, 3, 100)  # sample data

    # Evaluate both activations on the same inputs, GELU first, then ReLU.
    activations = {"GELU": GELU(), "ReLU": nn.ReLU()}
    curves = {label: fn(inputs) for label, fn in activations.items()}

    fig, axes = plt.subplots(1, 2, figsize=(8, 3))
    for ax, (label, values) in zip(axes, curves.items()):
        ax.plot(inputs, values)
        ax.set_title(f"{label} activation function")
        ax.set_xlabel("x")
        ax.set_ylabel(f"{label}(x)")
        ax.grid(True)

    fig.tight_layout()
    plt.show()

# _test_plot = plot_GELU_and_RELU()

## FeedForward Layer

In [24]:
class FeedForward(nn.Module):
    """Feed-forward block: Linear -> GELU -> Linear.

    The first linear layer maps ``emb_dim`` features to ``4 * emb_dim`` and the
    second maps them back to ``emb_dim``.
    """

    def __init__(self, emb_dim, verbose=False):
        """Build the layer stack.

        Args:
            emb_dim: Number of input features, and also of output features.
            verbose: When true, print the layer dimensions once at construction.
        """
        super().__init__()

        # Common convention: the hidden layer is four times the embedding width.
        expanded_dim = 4 * emb_dim

        expand = nn.Linear(emb_dim, expanded_dim)
        activation = GELU()
        project = nn.Linear(expanded_dim, emb_dim)
        self.layers = nn.Sequential(expand, activation, project)

        if not verbose:
            return

        print("\n=== FeedForward Initialization ===")
        print("    Input and output dimensions = ", emb_dim)
        print("    Hidden dimension = ", expanded_dim)
        print("=== End FeedForward Initialization ===\n")

    def forward(self, x, verbose=False):
        """Run ``x`` through the layer stack.

        Args:
            x: Input tensor passed to the first linear layer.
            verbose: Accepted for call-site symmetry with ``__init__``;
                currently not used by this method.

        Returns:
            The output of the final linear layer.
        """
        return self.layers(x)

## Test Run

In [25]:
def test_feedForward(verbose=False):
    """Smoke-run FeedForward on a random tensor and print what happens.

    Builds a FeedForward with embedding size 6, feeds it a random tensor of
    shape (2, 3, 6), and prints the input, the output shape and the output.

    Args:
        verbose: Forwarded to both the FeedForward constructor and its
            forward call.
    """
    embed_dim = 6
    print('Embbed_dim: ', embed_dim)

    block = FeedForward(embed_dim, verbose=verbose)

    # Batch of 2 sequences, each with a context length of 3 tokens.
    batch_size, context_length = 2, 3
    sample = torch.rand(batch_size, context_length, embed_dim)
    print("Sample data: ", sample)

    result = block(sample, verbose=verbose)

    print("\nOutput shape ", result.shape)
    print("Output data ", result)


# _test_run = test_feedForward(True)