# The Simplest Transformer: Everything in 3D

See how "to be or" becomes "not" through simple 3D transformations.

In [1]:
# Minimal setup
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Our tiny world: just 5 words
words = ['to', 'be', 'or', 'not', 'that']
input_text = "to be or"
input_words = input_text.split()

print(f"Input: '{input_text}'")
print(f"Goal: Predict the next word (spoiler: it's 'not')")

Input: 'to be or'
Goal: Predict the next word (spoiler: it's 'not')


## Step 1: Words as Points in 3D Space

Each word lives at a specific location in 3D space. Similar words are closer together.

In [2]:
# Each word has a 3D location (pre-defined for clarity)
word_locations = {
    'to':   np.array([-1.0,  0.5, -0.5]),
    'be':   np.array([-0.5, -1.0,  0.3]),
    'or':   np.array([ 1.0, -0.3, -0.8]),
    'not':  np.array([ 0.3,  1.0,  0.7]),
    'that': np.array([ 0.7,  0.2, -0.6])
}

# Visualize
fig = go.Figure()

# Add grid for reference
grid_range = [-1.5, 1.5]
grid_lines = np.linspace(grid_range[0], grid_range[1], 7)
for val in grid_lines:
    # X-Y plane grid
    fig.add_trace(go.Scatter3d(
        x=[grid_range[0], grid_range[1]], y=[val, val], z=[-1.5, -1.5],
        mode='lines', line=dict(color='lightgray', width=1),
        showlegend=False, hoverinfo='skip'
    ))
    fig.add_trace(go.Scatter3d(
        x=[val, val], y=[grid_range[0], grid_range[1]], z=[-1.5, -1.5],
        mode='lines', line=dict(color='lightgray', width=1),
        showlegend=False, hoverinfo='skip'
    ))

# Plot all words
all_positions = np.array([word_locations[w] for w in words])
fig.add_trace(go.Scatter3d(
    x=all_positions[:, 0],
    y=all_positions[:, 1],
    z=all_positions[:, 2],
    mode='markers+text',
    marker=dict(
        size=15,
        color=['red' if w in input_words else 'lightblue' for w in words],
        line=dict(width=2, color='black')
    ),
    text=words,
    textposition='top center',
    textfont=dict(size=14),
    name='Words'
))

# Add origin
fig.add_trace(go.Scatter3d(
    x=[0], y=[0], z=[0],
    mode='markers',
    marker=dict(size=5, color='black'),
    name='Origin',
    showlegend=False
))

fig.update_layout(
    title="Words as Points in 3D Space<br><sub>Red = input words, Blue = other words</sub>",
    scene=dict(
        xaxis=dict(title="X", range=[-1.5, 1.5]),
        yaxis=dict(title="Y", range=[-1.5, 1.5]),
        zaxis=dict(title="Z", range=[-1.5, 1.5]),
        aspectmode='cube'
    ),
    height=500,
    showlegend=False
)

fig.show()

print("üìç Each word is just a point in 3D space!")
for w in input_words:
    pos = word_locations[w]
    print(f"   '{w}' is at ({pos[0]:.1f}, {pos[1]:.1f}, {pos[2]:.1f})")

üìç Each word is just a point in 3D space!
   'to' is at (-1.0, 0.5, -0.5)
   'be' is at (-0.5, -1.0, 0.3)
   'or' is at (1.0, -0.3, -0.8)


## Step 2: Adding Position (Word Order Matters!)

"dog bites man" ‚â† "man bites dog" - position is crucial!

In [3]:
# Position vectors (where in the sentence)
position_vectors = [
    np.array([0.3, 0.1, 0.2]),   # First position
    np.array([0.1, 0.3, 0.1]),   # Second position
    np.array([-0.1, 0.2, 0.3])   # Third position
]

# Get initial positions
initial_positions = np.array([word_locations[w] for w in input_words])

# Add position to each word
final_positions = initial_positions + np.array(position_vectors)

# Visualize the addition
fig = go.Figure()

# Original word positions
fig.add_trace(go.Scatter3d(
    x=initial_positions[:, 0],
    y=initial_positions[:, 1],
    z=initial_positions[:, 2],
    mode='markers+text',
    marker=dict(size=12, color='blue', symbol='circle'),
    text=[f"{w} (word)" for w in input_words],
    textposition='bottom center',
    name='Word positions'
))

# Show position vectors as arrows from each word
for i, word in enumerate(input_words):
    start = initial_positions[i]
    end = final_positions[i]

    # Arrow showing position vector
    fig.add_trace(go.Scatter3d(
        x=[start[0], end[0]],
        y=[start[1], end[1]],
        z=[start[2], end[2]],
        mode='lines+markers',
        line=dict(color='green', width=4),
        marker=dict(size=[0, 8], color='green', symbol=['circle', 'diamond']),
        name=f'Position {i+1}',
        showlegend=(i == 0),
        hovertemplate=f'{word}: position {i+1}<extra></extra>'
    ))

# Final positions
fig.add_trace(go.Scatter3d(
    x=final_positions[:, 0],
    y=final_positions[:, 1],
    z=final_positions[:, 2],
    mode='markers+text',
    marker=dict(size=15, color='red', symbol='diamond'),
    text=[f"{w} (final)" for w in input_words],
    textposition='top center',
    name='With position'
))

fig.update_layout(
    title="Adding Position Information<br><sub>Blue = word alone, Green arrow = position vector, Red = word + position</sub>",
    scene=dict(
        xaxis=dict(title="X"),
        yaxis=dict(title="Y"),
        zaxis=dict(title="Z"),
        aspectmode='cube'
    ),
    height=500
)

fig.show()

print("‚ûï Simple vector addition:")
print("   Word location + Position vector = Final embedding")
print(f"   Example: 'to' at {initial_positions[0]} + position {position_vectors[0]} = {final_positions[0]}")

‚ûï Simple vector addition:
   Word location + Position vector = Final embedding
   Example: 'to' at [-1.   0.5 -0.5] + position [0.3 0.1 0.2] = [-0.7  0.6 -0.3]


## Step 3: Attention - Words Look at Each Other

The magic of transformers: each word asks "who should I pay attention to?"

In [4]:
# Simple attention weights (how much each word looks at others)
# Rows = looking word, Columns = looked-at word
attention_weights = np.array([
    [1.0, 0.0, 0.0],  # 'to' only looks at itself
    [0.3, 0.7, 0.0],  # 'be' mostly looks at itself, a bit at 'to'
    [0.2, 0.3, 0.5]   # 'or' looks at all previous words
])

# Apply attention: weighted average of positions
attended_positions = []
for i in range(len(input_words)):
    # Each word becomes a weighted average of what it looks at
    weights = attention_weights[i, :i+1]  # Can only look at previous words
    weights = weights / weights.sum()      # Normalize to sum to 1

    # Weighted average of positions
    new_position = np.zeros(3)
    for j in range(i+1):
        new_position += weights[j] * final_positions[j]
    attended_positions.append(new_position)

attended_positions = np.array(attended_positions)

# Visualize attention as connections
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=('Attention Weights', '3D Position Changes'),
    specs=[[{'type': 'heatmap'}, {'type': 'scatter3d'}]],
    horizontal_spacing=0.15
)

# Attention matrix
fig.add_trace(
    go.Heatmap(
        z=attention_weights,
        x=input_words,
        y=input_words,
        text=np.round(attention_weights, 2),
        texttemplate='%{text}',
        textfont=dict(size=16),
        colorscale='Blues',
        showscale=False
    ),
    row=1, col=1
)

# 3D visualization
# Before attention
fig.add_trace(
    go.Scatter3d(
        x=final_positions[:, 0],
        y=final_positions[:, 1],
        z=final_positions[:, 2],
        mode='markers+text',
        marker=dict(size=10, color='blue'),
        text=[f"{w} (before)" for w in input_words],
        textposition='bottom center',
        name='Before attention',
        showlegend=True
    ),
    row=1, col=2
)

# After attention
fig.add_trace(
    go.Scatter3d(
        x=attended_positions[:, 0],
        y=attended_positions[:, 1],
        z=attended_positions[:, 2],
        mode='markers+text',
        marker=dict(size=15, color='red', symbol='diamond'),
        text=[f"{w} (after)" for w in input_words],
        textposition='top center',
        name='After attention',
        showlegend=True
    ),
    row=1, col=2
)

# Movement arrows
for i in range(len(input_words)):
    fig.add_trace(
        go.Scatter3d(
            x=[final_positions[i, 0], attended_positions[i, 0]],
            y=[final_positions[i, 1], attended_positions[i, 1]],
            z=[final_positions[i, 2], attended_positions[i, 2]],
            mode='lines',
            line=dict(color='gray', width=2, dash='dash'),
            showlegend=False,
            hoverinfo='skip'
        ),
        row=1, col=2
    )

fig.update_xaxes(title="Looked at ‚Üí", row=1, col=1)
fig.update_yaxes(title="Looking ‚Üì", row=1, col=1)

fig.update_layout(
    title="Attention: Words Mix Their Positions",
    height=400,
    showlegend=True
)

fig.show()

print("üëÄ Attention as weighted averaging:")
print("   'to' (100% self): stays at same position")
print("   'be' (70% self, 30% 'to'): moves slightly toward 'to'")
print("   'or' (50% self, 30% 'be', 20% 'to'): blends all three positions")

üëÄ Attention as weighted averaging:
   'to' (100% self): stays at same position
   'be' (70% self, 30% 'to'): moves slightly toward 'to'
   'or' (50% self, 30% 'be', 20% 'to'): blends all three positions


## Step 4: Transformation Layers - Refining Understanding

Each layer transforms the 3D positions to better capture meaning.

In [5]:
# Simple transformation: rotate and scale
def transform_layer(positions, layer_name):
    """Apply a simple transformation to simulate a layer."""
    if layer_name == "Layer 1":
        # Rotate around Z axis and scale
        angle = np.pi / 6  # 30 degrees
        rotation = np.array([
            [np.cos(angle), -np.sin(angle), 0],
            [np.sin(angle), np.cos(angle), 0],
            [0, 0, 1]
        ])
        return positions @ rotation * 1.1
    else:  # Layer 2
        # Different rotation and slight contraction
        angle = -np.pi / 8
        rotation = np.array([
            [1, 0, 0],
            [0, np.cos(angle), -np.sin(angle)],
            [0, np.sin(angle), np.cos(angle)]
        ])
        return positions @ rotation * 0.9

# Apply two layers
layer1_positions = transform_layer(attended_positions, "Layer 1")
layer2_positions = transform_layer(layer1_positions, "Layer 2")

# Create path visualization
fig = go.Figure()

# Plot the journey for each word
colors = ['red', 'green', 'blue']
stages = [
    ("Start", attended_positions),
    ("Layer 1", layer1_positions),
    ("Layer 2", layer2_positions)
]

for word_idx, (word, color) in enumerate(zip(input_words, colors)):
    # Get path for this word
    path = np.array([stage[1][word_idx] for stage in stages])

    # Plot path
    fig.add_trace(go.Scatter3d(
        x=path[:, 0],
        y=path[:, 1],
        z=path[:, 2],
        mode='lines+markers+text',
        line=dict(color=color, width=4),
        marker=dict(size=[8, 10, 15], color=color),
        text=["", "", word],  # Only label the final position
        textposition='top center',
        textfont=dict(size=14),
        name=f"'{word}'",
        hovertemplate='%{text}<br>Stage: %{customdata}<br>Position: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>',
        customdata=[s[0] for s in stages]
    ))

# Add stage markers
for stage_idx, (stage_name, positions) in enumerate(stages):
    if stage_idx == 0:  # Mark starting positions
        fig.add_trace(go.Scatter3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            mode='markers',
            marker=dict(size=5, color='black', symbol='cross'),
            name='Start',
            showlegend=False,
            hoverinfo='skip'
        ))

fig.update_layout(
    title="Journey Through Transformer Layers<br><sub>Each word follows its own path through 3D space</sub>",
    scene=dict(
        xaxis=dict(title="X"),
        yaxis=dict(title="Y"),
        zaxis=dict(title="Z"),
        aspectmode='cube',
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
    ),
    height=500,
    showlegend=True
)

fig.show()

print("üîÑ Each layer rotates and scales the positions")
print("   Layer 1: Rotate 30¬∞ around Z, scale up 10%")
print("   Layer 2: Rotate -22.5¬∞ around X, scale down 10%")
print("   Result: Words end up in better positions for prediction")

üîÑ Each layer rotates and scales the positions
   Layer 1: Rotate 30¬∞ around Z, scale up 10%
   Layer 2: Rotate -22.5¬∞ around X, scale down 10%
   Result: Words end up in better positions for prediction


## Step 5: Final Prediction - From 3D Position to Next Word

The last word's final position determines what comes next.

In [6]:
# Final position of 'or' (the last word)
final_position = layer2_positions[-1]

# Each word in vocabulary has a "target direction" in 3D
word_directions = {
    'to':   np.array([-1.0,  0.0,  0.0]),
    'be':   np.array([ 0.0, -1.0,  0.0]),
    'or':   np.array([ 0.0,  0.0, -1.0]),
    'not':  np.array([ 0.7,  0.7,  0.0]),  # Points similar to our final position!
    'that': np.array([-0.7,  0.7,  0.0])
}

# Normalize directions
for word in word_directions:
    word_directions[word] = word_directions[word] / np.linalg.norm(word_directions[word])

# Calculate similarity (dot product) with each word
final_normalized = final_position / np.linalg.norm(final_position)
similarities = {}
for word, direction in word_directions.items():
    similarities[word] = np.dot(final_normalized, direction)

# Convert to probabilities (softmax)
scores = np.array(list(similarities.values()))
exp_scores = np.exp(scores * 5)  # Scale up for clearer differences
probabilities = exp_scores / exp_scores.sum()

# Visualization
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=('Final Vector vs Word Directions', 'Prediction Probabilities'),
    specs=[[{'type': 'scatter3d'}, {'type': 'bar'}]],
    horizontal_spacing=0.2
)

# 3D: Final position and word directions
# Final position vector
fig.add_trace(
    go.Scatter3d(
        x=[0, final_normalized[0]],
        y=[0, final_normalized[1]],
        z=[0, final_normalized[2]],
        mode='lines+markers+text',
        line=dict(color='red', width=8),
        marker=dict(size=[5, 15], color='red'),
        text=['', "'or' final"],
        textposition='top center',
        name='Final vector',
        showlegend=True
    ),
    row=1, col=1
)

# Word direction vectors
for word, direction in word_directions.items():
    color = 'darkgreen' if word == 'not' else 'lightblue'
    width = 6 if word == 'not' else 3

    fig.add_trace(
        go.Scatter3d(
            x=[0, direction[0]],
            y=[0, direction[1]],
            z=[0, direction[2]],
            mode='lines+markers+text',
            line=dict(color=color, width=width),
            marker=dict(size=[0, 10], color=color),
            text=['', word],
            textposition='top center',
            name=word,
            showlegend=(word == 'not')
        ),
        row=1, col=1
    )

# Probability bar chart
fig.add_trace(
    go.Bar(
        x=words,
        y=probabilities,
        marker_color=['darkgreen' if w == 'not' else 'lightblue' for w in words],
        text=[f'{p:.1%}' for p in probabilities],
        textposition='outside',
        showlegend=False
    ),
    row=1, col=2
)

fig.update_xaxes(title="Words", row=1, col=2)
fig.update_yaxes(title="Probability", tickformat='.0%', row=1, col=2)

fig.update_layout(
    title="Final Prediction: Which Word Direction Best Matches?",
    height=400
)

fig.show()

predicted_word = words[np.argmax(probabilities)]
print(f"üéØ Prediction: '{predicted_word}' with {max(probabilities):.1%} confidence")
print(f"   Complete: '{input_text} {predicted_word}'")
print("\nüìê How it works:")
print("   1. Final position points in a direction")
print("   2. Each word has its own direction")
print("   3. Most similar direction wins!")

üéØ Prediction: 'or' with 88.6% confidence
   Complete: 'to be or or'

üìê How it works:
   1. Final position points in a direction
   2. Each word has its own direction
   3. Most similar direction wins!


## The Complete Picture

Let's see the entire journey in one visualization!

In [7]:
# Create comprehensive visualization
fig = go.Figure()

# All stages for the last word ('or')
or_journey = [
    ("1. Word", word_locations['or']),
    ("2. +Position", word_locations['or'] + position_vectors[2]),
    ("3. Attention", attended_positions[2]),
    ("4. Layer 1", layer1_positions[2]),
    ("5. Layer 2", layer2_positions[2])
]

# Plot the complete path
path = np.array([stage[1] for stage in or_journey])

fig.add_trace(go.Scatter3d(
    x=path[:, 0],
    y=path[:, 1],
    z=path[:, 2],
    mode='lines+markers+text',
    line=dict(color='red', width=6),
    marker=dict(
        size=[10, 12, 14, 16, 20],
        color=['lightblue', 'blue', 'purple', 'orange', 'red'],
        line=dict(width=2, color='black')
    ),
    text=[stage[0] for stage in or_journey],
    textposition='top center',
    textfont=dict(size=12),
    name="'or' journey",
    hovertemplate='Stage: %{text}<br>Position: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>'
))

# Add the prediction arrow
final_pos = layer2_positions[2]
final_norm = final_pos / np.linalg.norm(final_pos)
not_direction = word_directions['not']

fig.add_trace(go.Scatter3d(
    x=[final_pos[0], final_pos[0] + not_direction[0] * 0.5],
    y=[final_pos[1], final_pos[1] + not_direction[1] * 0.5],
    z=[final_pos[2], final_pos[2] + not_direction[2] * 0.5],
    mode='lines+text',
    line=dict(color='green', width=8),
    text=['', "‚Üí 'not'"],
    textposition='top center',
    textfont=dict(size=16, color='green'),
    name='Prediction',
    showlegend=False
))

# Add grid at z=0 for reference
xx, yy = np.meshgrid(np.linspace(-2, 2, 5), np.linspace(-2, 2, 5))
zz = np.zeros_like(xx)
fig.add_trace(go.Surface(
    x=xx, y=yy, z=zz,
    opacity=0.1,
    colorscale='Greys',
    showscale=False,
    name='Ground plane'
))

fig.update_layout(
    title="Complete Transformer Journey: 'or' ‚Üí 'not'<br><sub>From word embedding to final prediction</sub>",
    scene=dict(
        xaxis=dict(title="X", range=[-2, 2]),
        yaxis=dict(title="Y", range=[-2, 2]),
        zaxis=dict(title="Z", range=[-2, 2]),
        aspectmode='cube',
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5),
            center=dict(x=0, y=0, z=0)
        )
    ),
    height=600,
    showlegend=False
)

# Add annotations for each stage
fig.add_trace(go.Scatter3d(
    x=[-1.5],
    y=[1.5],
    z=[1.5],
    mode='text',
    text=["Steps:<br>1. Start with word<br>2. Add position<br>3. Apply attention<br>4. Transform (Layer 1)<br>5. Transform (Layer 2)<br>6. Predict next word"],
    textfont=dict(size=10),
    showlegend=False
))

fig.show()

print("‚ú® THE COMPLETE TRANSFORMER:")
print("="*40)
print("1. Words start as points in 3D space")
print("2. Add position vectors (word order)")
print("3. Attention mixes positions (context)")
print("4. Layers transform the space (understanding)")
print("5. Final position determines prediction")
print("\nIt's all just geometry in 3D space!")

‚ú® THE COMPLETE TRANSFORMER:
1. Words start as points in 3D space
2. Add position vectors (word order)
3. Attention mixes positions (context)
4. Layers transform the space (understanding)
5. Final position determines prediction

It's all just geometry in 3D space!


## What We Learned

### Transformers Are Just 3D Geometry!

1. **Words = Points in space**
2. **Position = Vector addition**
3. **Attention = Weighted averaging**
4. **Layers = Geometric transformations**
5. **Prediction = Direction matching**

### Real Transformers:
- Use 512+ dimensions (not 3)
- Have 12-96 layers (not 2)
- Process 1000s of words at once

But the principles are **exactly the same** - just harder to visualize!

This is how ChatGPT, GPT-4, and all language models work at their core.