The projection layer maps each decoder output vector (size `d_model`) to one raw score (logit) per vocabulary word; a softmax then turns those logits into probabilities used to predict the next word.

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Title banner for the walkthrough: rule, pipeline headline, rule, then a blank line.
for banner_line in (
    "=" * 80,
    "COMPLETE TRANSFORMER: INPUT → ENCODER → DECODER → PROJECTION → OUTPUT",
    "=" * 80,
):
    print(banner_line)
print()


COMPLETE TRANSFORMER: INPUT → ENCODER → DECODER → PROJECTION → OUTPUT



In [25]:
# =============================================
# SETUP: Vocabulary and Parameters
# =============================================
# English to French translation example.
# Each vocabulary maps token string -> integer ID. IDs 0-2 are the special
# tokens <PAD>, <SOS>, <EOS> in both languages.
src_vocab = {
    "<PAD>": 0, "<SOS>": 1, "<EOS>": 2,
    "I": 3, "love": 4, "cats": 5, "dogs": 6
}

tgt_vocab = {
    "<PAD>": 0, "<SOS>": 1, "<EOS>": 2,
    "je": 3, "aime": 4, "les": 5, "chats": 6, "chiens": 7
}

src_vocab_size = len(src_vocab)  # 7
tgt_vocab_size = len(tgt_vocab)  # 8
# Small for demo. NOTE: d_model happens to equal tgt_vocab_size (both 8), so
# shapes such as "8 → 8" further down are d_model → vocab, not a no-op.
d_model = 8
seq_len = 4

# The embeddings and linear layers created later are randomly initialised.
# Seeding the global RNG once, here, makes every printed tensor reproducible
# under Restart Kernel → Run All.
SEED = 42
torch.manual_seed(SEED)

print("VOCABULARY")
print("-" * 80)
print(f"Source vocab (English): {src_vocab}")
print(f"Target vocab (French): {tgt_vocab}")
print(f"Source vocab size: {src_vocab_size}")
print(f"Target vocab size: {tgt_vocab_size}")
print(f"d_model (embedding size): {d_model}")
print()

VOCABULARY
--------------------------------------------------------------------------------
Source vocab (English): {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, 'I': 3, 'love': 4, 'cats': 5, 'dogs': 6}
Target vocab (French): {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, 'je': 3, 'aime': 4, 'les': 5, 'chats': 6, 'chiens': 7}
Source vocab size: 7
Target vocab size: 8
d_model (embedding size): 8



In [26]:

# =============================================
# INPUT SENTENCES
# =============================================
print("STEP 1: INPUT SENTENCES", "-" * 80, sep="\n")

# English source sentence "I love cats", closed with the end-of-sequence marker.
src_sentence = ["I", "love", "cats", "<EOS>"]
# Look every token up in the source vocabulary -> [3, 4, 5, 2]
src_ids = list(map(src_vocab.__getitem__, src_sentence))

# French target "je aime les chats": only its first 4 tokens (starting with
# <SOS>) are used here; during training this sequence is supplied to the model.
tgt_sentence = ["<SOS>", "je", "aime", "les"]
# Look every token up in the target vocabulary -> [1, 3, 4, 5]
tgt_ids = list(map(tgt_vocab.__getitem__, tgt_sentence))

print(
    f"Source (English): {src_sentence}",
    f"Source IDs: {src_ids}",
    "",
    sep="\n",
)
print(
    f"Target (French): {tgt_sentence}",
    f"Target IDs: {tgt_ids}",
    "",
    sep="\n",
)

STEP 1: INPUT SENTENCES
--------------------------------------------------------------------------------
Source (English): ['I', 'love', 'cats', '<EOS>']
Source IDs: [3, 4, 5, 2]

Target (French): ['<SOS>', 'je', 'aime', 'les']
Target IDs: [1, 3, 4, 5]



In [27]:
# Wrapping each ID list in an outer list gives the tensors a leading batch
# axis of size 1, i.e. shape (1, number_of_tokens).
src_batch = torch.tensor([src_ids])
tgt_batch = torch.tensor([tgt_ids])

print(
    f"Source batch shape: {src_batch.shape}",
    f"Target batch shape: {tgt_batch.shape}",
    "",
    sep="\n",
)

Source batch shape: torch.Size([1, 4])
Target batch shape: torch.Size([1, 4])



In [28]:

# =============================================
# EMBEDDINGS
# =============================================
print("STEP 2: EMBEDDINGS", "-" * 80, sep="\n")

# One lookup table per language: vocab_size rows, each a d_model-sized vector.
src_embedding = nn.Embedding(num_embeddings=src_vocab_size, embedding_dim=d_model)
tgt_embedding = nn.Embedding(num_embeddings=tgt_vocab_size, embedding_dim=d_model)

# Replace every token ID by its row: (1, 4) -> (1, 4, 8)
src_embedded = src_embedding(src_batch)
tgt_embedded = tgt_embedding(tgt_batch)

print(
    f"Source embedded shape: {src_embedded.shape}",
    f"Target embedded shape: {tgt_embedded.shape}",
    "",
    "Sample embedding for 'I' (src token 0):",
    src_embedded[0, 0, :],
    "",
    sep="\n",
)

STEP 2: EMBEDDINGS
--------------------------------------------------------------------------------
Source embedded shape: torch.Size([1, 4, 8])
Target embedded shape: torch.Size([1, 4, 8])

Sample embedding for 'I' (src token 0):
tensor([ 0.6667,  0.8459, -1.0229,  0.3337,  0.5114,  0.4131,  0.4232,  0.6103],
       grad_fn=<SliceBackward0>)



In [29]:
# =============================================
# ENCODER (simplified)
# =============================================
print("STEP 3: ENCODER")
print("-" * 80)

# Stand-in for a real encoder stack: a single linear layer applied to the
# last axis of the source embeddings (no attention between tokens).
encoder = nn.Linear(d_model, d_model)
encoder_output = encoder(src_embedded)  # (batch, src_seq_len, d_model)

# Read the dimensions off the tensor instead of hard-coding them, so the
# printed interpretation stays true if d_model or the sentence changes.
enc_batch, enc_seq_len, enc_dim = encoder_output.shape

print(f"Encoder output shape: {encoder_output.shape}")
print(f"Interpretation: (batch={enc_batch}, src_seq_len={enc_seq_len}, d_model={enc_dim})")
print()
print("Encoder output for token 0:")
print(encoder_output[0, 0, :])
print()

STEP 3: ENCODER
--------------------------------------------------------------------------------
Encoder output shape: torch.Size([1, 4, 8])
Interpretation: (batch=1, src_seq_len=4, d_model=8)

Encoder output for token 0:
tensor([ 0.1039,  0.3812, -0.3746, -0.1416,  0.4212,  0.8765, -0.1701, -0.5222],
       grad_fn=<SliceBackward0>)



In [30]:
# =============================================
# DECODER (simplified)
# =============================================
print("STEP 4: DECODER")
print("-" * 80)

# Stand-in for a real decoder stack: a single linear layer over the target
# embeddings. NOTE: this simplified decoder is called with `tgt_embedded`
# only; it never reads `encoder_output`, so nothing here depends on the
# source sentence.
decoder = nn.Linear(d_model, d_model)
decoder_output = decoder(tgt_embedded)  # (batch, tgt_seq_len, d_model)

# Read the dimensions off the tensor instead of hard-coding them, so the
# printed interpretation stays true if d_model or the sentence changes.
dec_batch, dec_seq_len, dec_dim = decoder_output.shape

print(f"Decoder output shape: {decoder_output.shape}")
print(f"Interpretation: (batch={dec_batch}, tgt_seq_len={dec_seq_len}, d_model={dec_dim})")
print()
print("Decoder output for token 0:")
print(decoder_output[0, 0, :])
print()

STEP 4: DECODER
--------------------------------------------------------------------------------
Decoder output shape: torch.Size([1, 4, 8])
Interpretation: (batch=1, tgt_seq_len=4, d_model=8)

Decoder output for token 0:
tensor([ 0.0964,  0.4175,  1.0548,  1.2416,  0.9222, -0.0118, -0.7893, -0.4573],
       grad_fn=<SliceBackward0>)



In [31]:
# =============================================
# PROJECTION LAYER (THE KEY PART!)
# =============================================
# Section heading followed by a full-width rule.
print("STEP 5: PROJECTION LAYER", "=" * 80, sep="\n")

class ProjectionLayer(nn.Module):
    """Linear output head: one raw score (logit) per vocabulary entry.

    Args:
        d_model: Size of each incoming feature vector.
        vocab_size: Number of scores produced for each feature vector.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        # The attribute name `proj` is part of the interface: other cells
        # inspect `projection.proj.weight` and `projection.proj.bias`.
        self.proj = nn.Linear(in_features=d_model, out_features=vocab_size)
        print(f"  Created projection: {d_model} → {vocab_size}")

    def forward(self, x):
        """Map (batch, seq_len, d_model) features to (batch, seq_len, vocab_size) logits."""
        logits = self.proj(x)
        return logits

# Output head for the target language: d_model-sized decoder features in,
# one score per target-vocabulary entry out.
projection = ProjectionLayer(d_model=d_model, vocab_size=tgt_vocab_size)
print()

STEP 5: PROJECTION LAYER
  Created projection: 8 → 8



In [32]:
# Apply projection: the linear head acts on the last axis only, so
# (batch, seq_len, d_model) becomes (batch, seq_len, vocab_size).
logits = projection(decoder_output)
print(logits)

# Read the dimensions off the tensor instead of hard-coding batch/seq_len,
# so the interpretation line cannot drift from the real shape.
logit_batch, logit_seq_len, logit_vocab = logits.shape
print(f"Logits shape: {logits.shape}")
print(f"Interpretation: (batch={logit_batch}, seq_len={logit_seq_len}, vocab_size={logit_vocab})")
print()

print("What are logits?")
print("  Raw scores for EACH word in the vocabulary")
print("  Higher score = model thinks that word is more likely")
print()

print("Logits for position 0 (after <SOS>):")
print(logits[0, 0])
print()

tensor([[[ 0.1597,  0.4743, -0.6621, -0.0593, -0.4071,  0.7157,  0.3807,
          -0.0580],
         [ 0.4536,  0.2123, -0.1139, -0.1537, -0.1311,  0.2143,  0.1252,
          -0.2146],
         [ 0.5522,  0.0049,  0.1381,  0.0944,  0.2438,  0.1069,  0.2276,
           0.0667],
         [ 0.2487,  0.2507, -0.6990, -0.2207, -0.4781,  0.3978,  0.1632,
          -0.2684]]], grad_fn=<ViewBackward0>)
Logits shape: torch.Size([1, 4, 8])
Interpretation: (batch=1, seq_len=4, vocab_size=8)

What are logits?
  Raw scores for EACH word in the vocabulary
  Higher score = model thinks that word is more likely

Logits for position 0 (after <SOS>):
tensor([ 0.1597,  0.4743, -0.6621, -0.0593, -0.4071,  0.7157,  0.3807, -0.0580],
       grad_fn=<SelectBackward0>)



In [33]:
# Invert the target vocabulary (ID -> token) so score indices can be labelled.
reverse_tgt_vocab = dict(zip(tgt_vocab.values(), tgt_vocab.keys()))
print("Interpretation:")
# Scores at batch item 0, sequence position 0: one entry per vocabulary ID.
position0_scores = logits[0, 0]
for token_id in range(len(position0_scores)):
    word = reverse_tgt_vocab[token_id]
    score = position0_scores[token_id]
    print(f"  {word:>10s}: {score.item():>8.3f}")
print()

Interpretation:
       <PAD>:    0.160
       <SOS>:    0.474
       <EOS>:   -0.662
          je:   -0.059
        aime:   -0.407
         les:    0.716
       chats:    0.381
      chiens:   -0.058



In [34]:
# =============================================
# SOFTMAX: Convert to probabilities
# =============================================
print("STEP 6: SOFTMAX (Convert logits → probabilities)", "-" * 80, sep="\n")

# Normalise along the last (vocabulary) axis; the shape is unchanged and each
# position's scores now sum to 1.
probabilities = F.softmax(logits, dim=-1)

print(f"Probabilities shape: {probabilities.shape}", "", sep="\n")

print(
    "Probabilities for position 0 (predicting next word after <SOS>):",
    probabilities[0, 0],
    "",
    sep="\n",
)

STEP 6: SOFTMAX (Convert logits → probabilities)
--------------------------------------------------------------------------------
Probabilities shape: torch.Size([1, 4, 8])

Probabilities for position 0 (predicting next word after <SOS>):
tensor([0.1254, 0.1717, 0.0551, 0.1007, 0.0711, 0.2186, 0.1564, 0.1009],
       grad_fn=<SelectBackward0>)



In [35]:
print("Word probabilities:")
# Distribution at batch item 0, sequence position 0, labelled by vocabulary word.
position0_probs = probabilities[0, 0]
for idx in range(position0_probs.shape[0]):
    word, prob = reverse_tgt_vocab[idx], position0_probs[idx]
    print(f"  {word:>10s}: {prob.item()*100:>6.2f}%")
print()

# Sanity check on the softmax output: the entries of one distribution add up to 1.0
print(f"Sum of probabilities: {probabilities[0, 0].sum().item():.6f} (should be 1.0)", "", sep="\n")

Word probabilities:
       <PAD>:  12.54%
       <SOS>:  17.17%
       <EOS>:   5.51%
          je:  10.07%
        aime:   7.11%
         les:  21.86%
       chats:  15.64%
      chiens:  10.09%

Sum of probabilities: 1.000000 (should be 1.0)



In [36]:
# =============================================
# PREDICTION: Select word with highest probability
# =============================================
print("STEP 7: PREDICTION")
print("-" * 80)

# Greedy choice: for every position keep the vocabulary ID whose probability
# is largest. Result shape: (batch, tgt_seq_len).
predicted_ids = torch.argmax(probabilities, dim=-1)

print(f"Predicted IDs: {predicted_ids}")
print()

print("Predicted words for each position:")
# Iterate over the positions that were actually predicted rather than the
# global `seq_len` constant, which is not tied to the target sentence length
# and would raise IndexError (or skip positions) if the two ever differ.
for pos in range(predicted_ids.shape[1]):
    predicted_id = predicted_ids[0, pos].item()
    predicted_word = reverse_tgt_vocab[predicted_id]
    prob = probabilities[0, pos, predicted_id].item()
    print(f"  Position {pos}: '{predicted_word}' (ID={predicted_id}, prob={prob*100:.2f}%)")
print()

STEP 7: PREDICTION
--------------------------------------------------------------------------------
Predicted IDs: tensor([[5, 0, 0, 5]])

Predicted words for each position:
  Position 0: 'les' (ID=5, prob=21.86%)
  Position 1: '<PAD>' (ID=0, prob=18.27%)
  Position 2: '<PAD>' (ID=0, prob=17.90%)
  Position 3: 'les' (ID=5, prob=18.80%)



In [37]:

# =============================================
# DETAILED BREAKDOWN: How Projection Works
# =============================================
print("=" * 80)
print("HOW PROJECTION LAYER WORKS INTERNALLY")
print("=" * 80)
print()

print("Input to projection layer (one token):")
single_token = decoder_output[0, 1, :]  # Token 1: a single d_model-sized vector
print(f"  Shape: {single_token.shape}")
print(f"  Values: {single_token}")
print()

# nn.Linear stores its weight as (out_features, in_features), i.e.
# (vocab_size, d_model). Deriving the numbers from the layer itself keeps the
# explanation correct and shows which dimension is which even though both
# happen to be 8 in this demo.
proj_weight = projection.proj.weight
proj_bias = projection.proj.bias
proj_out_dim, proj_in_dim = proj_weight.shape

print(f"Projection layer is nn.Linear(d_model={proj_in_dim}, vocab_size={proj_out_dim})")
print(f"  Weight matrix shape: {proj_weight.shape}")
print(f"  Bias shape: {proj_bias.shape}")
print()

print("Matrix multiplication:")
print(f"  decoder_output: (batch, seq, d_model) = {tuple(decoder_output.shape)}")
print(f"  weight:         (vocab_size, d_model) = {tuple(proj_weight.shape)}")
print(f"  result:         (batch, seq, vocab_size) = {tuple(decoder_output.shape[:-1]) + (proj_out_dim,)}")
print()

# Check the formula printed below against the layer's real output for the
# sample token, so the explanation is verified rather than merely asserted.
assert torch.allclose(
    projection.proj(single_token),
    single_token @ proj_weight.T + proj_bias,
    atol=1e-5,
)

print("For each position, it computes:")
print("  logits[pos] = decoder_output[pos] @ weight.T + bias")
# Shape annotations line up under the operands above when the sizes are single-digit.
print(f"               └─({proj_in_dim},)─┘              └─({proj_in_dim}, {proj_out_dim})─┘   └─({proj_out_dim},)─┘")
print(f"               = ({proj_out_dim},) result for each vocab word")
print()

# =============================================
# VISUAL SUMMARY
# =============================================
for heading_line in ("=" * 80, "VISUAL SUMMARY", "=" * 80):
    print(heading_line)
print()

print("FULL PIPELINE:")
print()

# Each entry is one stage of the diagram (one or more text lines); a
# down-arrow line is printed between consecutive stages.
pipeline_stages = [
    ["  Input: 'I love cats'"],
    ["  Token IDs: [3, 4, 5, 2]"],
    ["  Embeddings: (1, 4, 8)  ← Each token → 8-dim vector"],
    ["  ENCODER: (1, 4, 8)  ← Contextualized representations"],
    ["  Target: '<SOS> je aime les'"],
    ["  Token IDs: [1, 3, 4, 5]"],
    ["  Embeddings: (1, 4, 8)"],
    ["  DECODER: (1, 4, 8)  ← Contextualized with encoder output"],
    [
        "  PROJECTION: (1, 4, 8) → (1, 4, 8)",
        "              └─d_model─┘   └─vocab─┘",
    ],
    ["  Logits: Raw scores for each vocab word"],
    ["  SOFTMAX: Convert to probabilities"],
    ["  ARGMAX: Select word with highest probability"],
    ["  Output: Predicted word at each position"],
]
for stage_number, stage_lines in enumerate(pipeline_stages):
    if stage_number > 0:
        print("    ↓")
    for stage_line in stage_lines:
        print(stage_line)
print()

# =============================================
# WHY IS PROJECTION NEEDED?
# =============================================
print("=" * 80)
print("WHY IS PROJECTION LAYER NEEDED?")
print("=" * 80)
print()

print("Problem:")
print(f"  Decoder outputs vectors of size d_model = {d_model}")
print(f"  But we need to predict one of {tgt_vocab_size} words!")
print()

print("Solution:")
print(f"  Projection layer: Linear({d_model}, {tgt_vocab_size})")
print("  Maps each d_model vector → vocab_size scores")
print("  Each score = 'how likely is this word?'")
print()

# The numbers in this example are illustrative, not taken from the tensors
# above. The element counts are derived so it is clear that the input has
# d_model entries and the output has one entry per target-vocabulary word.
print("Example (position 0):")
print(f"  Decoder output: [0.12, -0.34, 0.56, ...]  ({d_model} numbers)")
print(f"  Projection →    [1.2, 0.3, -0.5, 2.1, ...] ({tgt_vocab_size} numbers)")
print("                   ↑    ↑     ↑    ↑")
print("                  <PAD> <SOS> <EOS> je  ... (one per vocab word)")
print()

print("After softmax:")
print("  [0.05, 0.02, 0.01, 0.12, ...]  (probabilities sum to 1)")
print()

# 'je' is entry 3 of the vocabulary axis. Elsewhere in this notebook
# "position" means the sequence position (this example is position 0), so the
# winning entry is called a vocab index here.
print("Pick highest: 'je' (vocab index 3) with prob 0.12")
print()

print("✅ Projection layer converts abstract embeddings → concrete word predictions!")


HOW PROJECTION LAYER WORKS INTERNALLY

Input to projection layer (one token):
  Shape: torch.Size([8])
  Values: tensor([-0.0262, -0.4131, -0.4404,  0.1978,  0.4935,  0.0317,  0.0225,  1.0157],
       grad_fn=<SliceBackward0>)

Projection layer is nn.Linear(d_model=8, vocab_size=8)
  Weight matrix shape: torch.Size([8, 8])
  Bias shape: torch.Size([8])

Matrix multiplication:
  decoder_output: (batch, seq, d_model) = (1, 4, 8)
  weight:         (vocab_size, d_model) = (8, 8)
  result:         (batch, seq, vocab_size) = (1, 4, 8)

For each position, it computes:
  logits[pos] = decoder_output[pos] @ weight.T + bias
               └─(8,)─┘              └─(8, 8)─┘   └─(8,)─┘
               = (8,) result for each vocab word

VISUAL SUMMARY

FULL PIPELINE:

  Input: 'I love cats'
    ↓
  Token IDs: [3, 4, 5, 2]
    ↓
  Embeddings: (1, 4, 8)  ← Each token → 8-dim vector
    ↓
  ENCODER: (1, 4, 8)  ← Contextualized representations
    ↓
  Target: '<SOS> je aime les'
    ↓
  Token IDs: [1, 3, 