In [1]:
import torch
import torch.nn as nn
import math

class InputEmbedding(nn.Module):
    """Token-id to vector lookup whose output is scaled by sqrt(d_model).

    Attributes:
        d_model: Width of each embedding vector.
        vocab_size: Number of rows in the embedding table.
        embedding: The underlying ``nn.Embedding(vocab_size, d_model)`` table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        """Create the embedding table and remember the two sizes.

        Args:
            d_model: Size of each embedding vector.
            vocab_size: Number of distinct token ids the table can look up.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.vocab_size = vocab_size
        self.d_model = d_model

    def forward(self, x):
        """Look up ``x`` in the embedding table and scale the result.

        Args:
            x: Tensor of token ids accepted by ``nn.Embedding``.

        Returns:
            The looked-up vectors multiplied by ``sqrt(d_model)``; the shape is
            that of ``x`` with a trailing ``d_model`` axis appended.
        """
        token_vectors = self.embedding(x)
        scale = math.sqrt(self.d_model)
        return token_vectors * scale
    

In [2]:
# Build an embedding layer with a 10000-entry table of 512-dimensional vectors.
input_embedding = InputEmbedding(d_model=512, vocab_size=10000)
# Bare name as the last expression so the notebook displays the module's repr.
input_embedding

InputEmbedding(
  (embedding): Embedding(10000, 512)
)

In [3]:
# Random token ids in [0, 10000), matching the vocab_size used for input_embedding.
x = torch.randint(0, 10000, (1, 20))  # batch size of 1, sequence length of 20
# NOTE: 0:10 slices the batch axis, not the sequence axis; with a single row
# this simply displays the whole tensor.
x[0:10]

tensor([[5761, 5436, 2743, 5380,  677, 1109, 6021, 7436, 8929, 3170, 1531, 7865,
         2056, 1050, 7356,  341, 4089, 8091, 2500, 1220]])

In [4]:
# Each of the 20 ids in the single row becomes a 512-dim vector: (1, 20, 512).
input_embedding(x).shape

torch.Size([1, 20, 512])

In [5]:
# Positions 0..4 as a 1-D float tensor (shape (5,)), before adding any extra axis.
position = torch.arange(0, 5, dtype=torch.float)
print(position)
print(position.shape)

tensor([0., 1., 2., 3., 4.])
torch.Size([5])


In [6]:
# Same positions, but unsqueeze(1) adds a trailing axis: shape (5,) -> (5, 1).
# A column vector can later be broadcast against a row of per-dimension frequencies.
position = torch.arange(0, 5, dtype=torch.float).unsqueeze(1) 
print(position)
print(position.shape)

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.]])
torch.Size([5, 1])


In [7]:
# Even indices 0, 2, 4 (step of 2 below 5), converted to float.
torch.arange(0,5,2).float()

tensor([0., 2., 4.])

In [8]:
# Constant multiplier -ln(10000) / 5; the 5 is the same toy width that the
# full positional-encoding example further down uses as d_model.
(-math.log(10000.0) / 5)

-1.8420680743952367

In [9]:
# Even indices times the constant: the exponents k * (-ln(10000) / 5) for k = 0, 2, 4.
torch.arange(0,5,2).float() * (-math.log(10000.0) / 5)

tensor([-0.0000, -3.6841, -7.3683])

In [10]:
# exp of those exponents, i.e. 10000 ** (-k / 5) for k = 0, 2, 4.
torch.exp(torch.arange(0,5,2).float() * (-math.log(10000.0) / 5))

tensor([1.0000e+00, 2.5119e-02, 6.3096e-04])

In [11]:
# Same expression as the previous cell, now kept in a name: one frequency per
# even index (0, 2, 4), equal to 10000 ** (-k / 5).
div_term = torch.exp(torch.arange(0,5,2).float() * (-math.log(10000.0) / 5))
print(div_term)

tensor([1.0000e+00, 2.5119e-02, 6.3096e-04])


In [12]:
seq_len, d_model = 6, 5

# Table to fill in: one row per position, one column per feature.
pe = torch.zeros(seq_len, d_model)

# Positions as a (seq_len, 1) column so they broadcast against the frequencies.
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)

# One frequency per even column index k: exp(k * -ln(10000) / d_model).
even_indices = torch.arange(0, d_model, 2).float()
div_term = torch.exp(even_indices * (-math.log(10000.0) / d_model))

# Even columns (0, 2, 4, ...) get the sine of position * frequency.
pe[:, 0::2] = (position * div_term).sin()
print("before apply cos",pe)

# Odd columns (1, 3, ...) get the cosine. When d_model is odd there is one
# fewer odd column than even column, so only the first d_model // 2
# frequencies are used; without the truncation the assignment shapes mismatch.
num_odd_columns = d_model // 2
pe[:, 1::2] = torch.cos(position * div_term[:num_odd_columns])
print("after apply cos",pe)

before apply cos tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 8.4147e-01,  0.0000e+00,  2.5116e-02,  0.0000e+00,  6.3096e-04],
        [ 9.0930e-01,  0.0000e+00,  5.0217e-02,  0.0000e+00,  1.2619e-03],
        [ 1.4112e-01,  0.0000e+00,  7.5285e-02,  0.0000e+00,  1.8929e-03],
        [-7.5680e-01,  0.0000e+00,  1.0031e-01,  0.0000e+00,  2.5238e-03],
        [-9.5892e-01,  0.0000e+00,  1.2526e-01,  0.0000e+00,  3.1548e-03]])
after apply cos tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  2.5116e-02,  9.9968e-01,  6.3096e-04],
        [ 9.0930e-01, -4.1615e-01,  5.0217e-02,  9.9874e-01,  1.2619e-03],
        [ 1.4112e-01, -9.8999e-01,  7.5285e-02,  9.9716e-01,  1.8929e-03],
        [-7.5680e-01, -6.5364e-01,  1.0031e-01,  9.9496e-01,  2.5238e-03],
        [-9.5892e-01,  2.8366e-01,  1.2526e-01,  9.9212e-01,  3.1548e-03]])


In [13]:
import plotly.graph_objects as go
import numpy as np

def visualize_ffn_plotly(d_model=8, d_ff=32):
    """Show a schematic of a position-wise feed-forward block as a Plotly figure.

    Three columns of markers are drawn (input, hidden, output) and joined by
    faint lines. Only every 4th hidden node is wired up so the picture stays
    readable; every hidden node is still drawn as a marker.

    Args:
        d_model: Number of markers in the input column and in the output column.
        d_ff: Number of markers in the hidden column.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Horizontal position of each column.
    x_input, x_hidden, x_output = 0, 3, 6

    # Vertical positions; the hidden column spans a taller range than the others.
    y_input = np.linspace(0, 1, d_model)
    y_hidden = np.linspace(-0.5, 1.5, d_ff)
    y_output = np.linspace(0, 1, d_model)

    fig = go.Figure()

    # ---- Node columns: (x, ys, node count, marker size, trace name) ----
    node_columns = [
        (x_input, y_input, d_model, 14, "Input (d_model)"),
        (x_hidden, y_hidden, d_ff, 8, "Hidden (d_ff)"),
        (x_output, y_output, d_model, 14, "Output (d_model)"),
    ]
    for x_pos, y_vals, count, marker_size, label in node_columns:
        fig.add_trace(go.Scatter(
            x=[x_pos] * count,
            y=y_vals,
            mode="markers",
            marker=dict(size=marker_size),
            name=label,
        ))

    # ---- Connections (every 4th hidden node, to reduce clutter) ----
    # Order matters: input->hidden segments first, then hidden->output, so the
    # traces are appended in a fixed sequence.
    y_hidden_sampled = y_hidden[::4]
    segments = [
        ([x_input, x_hidden], [y_in, y_hid])
        for y_in in y_input
        for y_hid in y_hidden_sampled
    ]
    segments += [
        ([x_hidden, x_output], [y_hid, y_out])
        for y_hid in y_hidden_sampled
        for y_out in y_output
    ]
    for seg_x, seg_y in segments:
        fig.add_trace(go.Scatter(
            x=seg_x,
            y=seg_y,
            mode="lines",
            line=dict(width=1),
            opacity=0.15,
            showlegend=False,
        ))

    # ---- Annotations: (x, y, text) ----
    labels = [
        (x_input, 1.25, "Input<br>(d_model)"),
        (x_hidden, 1.75, "Linear + ReLU + Dropout<br>(d_ff)"),
        (x_output, 1.25, "Output<br>(d_model)"),
        (1.5, 1.55, "Linear(d_model → d_ff)"),
        (4.5, 1.55, "Linear(d_ff → d_model)"),
    ]
    for label_x, label_y, label_text in labels:
        fig.add_annotation(x=label_x, y=label_y, text=label_text, showarrow=False)

    # ---- Layout: hide axes and legend, fixed canvas size ----
    fig.update_layout(
        title="Transformer Feed Forward Network (Position-wise FFN)",
        showlegend=False,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        height=600,
        width=900,
    )

    fig.show()


# Render the diagram with 8 input/output nodes and 32 hidden nodes.
# The function displays the figure itself and returns nothing, so this cell
# produces no separate output value.
visualize_ffn_plotly(d_model=8, d_ff=32)
