<a href="https://colab.research.google.com/github/KempnerInstitute/transformer-workshop/blob/main/transformer_instructor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
import requests

# Tokenization

First, we'll load the training data (`tiny_wikipedia.txt`), which you can find [here](https://github.com/KempnerInstitute/transformer-workshop/blob/main/tiny_wikipedia.txt), and set up tokenization. To do this, we identify all unique characters in the dataset and assign each one a unique integer ID.

We then define two functions:

* `encode(text, str_to_int)`: converts text into a list of integer token IDs
* `decode(ids, int_to_str)`: converts a list of integer token IDs back into text


In [None]:
# Load in all training data

url = "https://raw.githubusercontent.com/KempnerInstitute/transformer-workshop/main/tiny_wikipedia.txt"
response = requests.get(url)
with open("tiny_wikipedia.txt", "wb") as f:
    f.write(response.content)

with open('tiny_wikipedia.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
# Get all unique characters in the training data
chars = sorted(list(set(text)))

# Create mappings from characters to integers and back
str_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_str = {i: ch for i, ch in enumerate(chars)}

# Define encode (text → ids) and decode (ids → text) functions
def encode(text, str_to_int):
  """Convert a string into a list of integer token IDs."""
  ids = [str_to_int[c] for c in text]
  return ids

def decode(ids, int_to_str):
  """Convert a list of integer token IDs back into a string."""
  text_list = [int_to_str[id] for id in ids]
  return ''.join(text_list)

# Test the implementation
input_text = "My dog Leo is extremely cute."
ids = encode(input_text, str_to_int)
print(f"Input text: {input_text}")
print(f"Token IDs: {ids}")

decoded_text = decode(ids, int_to_str)
assert input_text == decoded_text, "Decoded text does not match input"
print("Decoded text matches the original")

## Tokenize input data and create splits

Next, we’ll convert our tokenized text into a PyTorch tensor and split it into training and validation sets. The training data will be used to learn model parameters, while the validation set lets us check how well the model generalizes to unseen text.

We'll also define a helper function, `get_batch`, which randomly samples small chunks of text from the dataset. Each row in the returned batch is a different training example: a sequence of `ctx_len` (context length) consecutive tokens. The input tensor `x` contains these sequences, and the target tensor `y` contains the same sequences shifted one position further along the text, so `y[:, t]` is the token that follows `x[:, t]`.

In [None]:
# @markdown Execute to get helper function get_batch to generate batches of data
def get_batch(split, ctx_len, batch_size, device='cpu'):
    """
    Generate a small batch of input (x) and target (y) sequences.

    Args:
        split (str): 'train' or 'val'
        ctx_len (int): length of each sequence (context window)
        batch_size (int): number of sequences per batch
        device (str): device to move the tensors to ('cpu' or 'cuda')

    Returns:
        x, y (torch.Tensor): tensors of shape (batch_size, ctx_len)
    """
    data_split = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_split) - ctx_len, (batch_size,))

    # Each x is a sequence of ctx_len tokens; y is the same sequence shifted by one
    x = torch.stack([data_split[i:i+ctx_len] for i in ix])
    y = torch.stack([data_split[i+1:i+ctx_len+1] for i in ix])
    return x.to(device), y.to(device)

In [None]:
# Train and validation splits
data = torch.tensor(encode(text, str_to_int), dtype=torch.long)
n = int(0.9 * len(data)) # first 90% will be train, remaining 10% for validation
train_data = data[:n]
val_data = data[n:]

We’ll take a look at a small batch to see how the input (x) and target (y) sequences align (by decoding them back to text). Each row represents one data point — a short sequence of text that the model uses for next-token prediction.

In [None]:
# Let's grab a batch to look at the dimensions

x, y = get_batch(split='train', ctx_len=64, batch_size=8)
print(f"x shape: {x.shape}, y shape: {y.shape}")

# Show a few examples to see how x and y align
for i in range(3):
    print(f"\nData point {i}")
    print("x:", decode(x[i].tolist(), int_to_str))
    print("y:", decode(y[i].tolist(), int_to_str))
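The alignment can also be checked mechanically rather than by eye. The sketch below re-implements the sampling logic of `get_batch` on a toy tensor (the toy data and the helper name `toy_get_batch` are hypothetical, chosen only for illustration) and checks that each target row is its input row shifted by one position.

```python
import torch

def toy_get_batch(data, ctx_len, batch_size):
    """Same sampling logic as get_batch, but on an explicit data tensor (hypothetical helper)."""
    ix = torch.randint(len(data) - ctx_len, (batch_size,))
    x = torch.stack([data[i:i + ctx_len] for i in ix])
    y = torch.stack([data[i + 1:i + ctx_len + 1] for i in ix])
    return x, y

# Toy "token stream" 0, 1, 2, ..., 99 so the shift is easy to read
data = torch.arange(100)
x, y = toy_get_batch(data, ctx_len=8, batch_size=4)

# y[:, t] is the token that follows x[:, t]
shift_ok = torch.equal(x[:, 1:], y[:, :-1])
next_ok = torch.equal(y, x + 1)  # holds only because the toy data counts up by 1
print(x.shape, y.shape, shift_ok, next_ok)
```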


# Model configuration

We’ll store all the model’s key hyperparameters in a configuration class.
This makes it easy to keep track of settings (like model size, number of layers, and context length) and to modify them later without changing code in multiple places.

The Config class below defines these values and includes a helper method to update the vocabulary size once we’ve built the tokenizer.

In [None]:
@dataclass
class Config:
  """Configuration settings for the Transformer model."""
  d_model: int = 256 # hidden dimension (embedding size)
  n_heads: int = 4 # number of attention heads (width)
  ctx_len: int = 64 # context length
  batch_size: int = 8 # batch size
  n_layers: int = 12 # number of layers (depth)
  vocab_size: int = -1 # vocab size, to be determined once we have created a tokenizer
  device: str = 'cpu'

  def set_vocab_size(self, vocab_size):
    """Update the vocabulary size once tokenization is defined."""
    self.vocab_size = vocab_size

In [None]:
# Initialize the configuration and set the vocabulary size
config = Config()
config.set_vocab_size(vocab_size=len(chars))  # number of unique characters in the dataset
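Later, the attention code computes `head_dim = d_model // n_heads` and concatenates `n_heads` head outputs of that size, so it quietly assumes `d_model` is divisible by `n_heads`. One optional way to catch a bad combination early is a `__post_init__` check. The sketch below uses a hypothetical `CheckedConfig` that mirrors `Config`; it is not part of the workshop code.

```python
from dataclasses import dataclass

@dataclass
class CheckedConfig:
    """Hypothetical variant of Config that validates its fields on creation."""
    d_model: int = 256
    n_heads: int = 4
    ctx_len: int = 64
    batch_size: int = 8
    n_layers: int = 12
    vocab_size: int = -1
    device: str = 'cpu'

    def __post_init__(self):
        # Each head gets d_model // n_heads dimensions; concatenating the heads
        # only recovers d_model if the division is exact.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )

    @property
    def head_dim(self):
        return self.d_model // self.n_heads

    def set_vocab_size(self, vocab_size):
        self.vocab_size = vocab_size


cfg = CheckedConfig(d_model=256, n_heads=8)
print(cfg.head_dim)  # 32

try:
    CheckedConfig(d_model=256, n_heads=6)  # 256 is not divisible by 6
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```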

# Embeddings

## Exercise: Implement token embeddings

We’ll now implement a class that converts token IDs into their corresponding embedding vectors.
Given a batch of token IDs of shape `(batch_size, ctx_len)`, the embedding layer should output token embeddings of shape `(batch_size, ctx_len, d_model)`.

Use `nn.Embedding` (docs [here](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html)) to map each integer token ID to a learnable embedding vector.



```python
class TokenEmbeddingLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        # TODO: Create the token embedding layer by specifying arguments to nn.Embedding
        self.wte = nn.Embedding(...)

    def forward(self, x):
        batch_size, seq_len = x.shape

        # TODO: Get forward pass of token embedding layer
        x_tok = ...

        return x_tok


# Testing your implementation
xb, yb = get_batch('train', config.ctx_len, config.batch_size, config.device)

token_embedding = TokenEmbeddingLayer(config)
x_tok = token_embedding(xb)

assert x_tok.shape == (config.batch_size, config.ctx_len, config.d_model), "Embedding dimensions are incorrect"
print("Token embedding layer output shape is correct!")

```

In [None]:
# Solution
class TokenEmbeddingLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.wte = nn.Embedding(config.vocab_size, config.d_model)

    def forward(self, x):
        batch_size, seq_len = x.shape

        x_tok = self.wte(x)

        return x_tok


# Test your implementation
xb, yb = get_batch('train', config.ctx_len, config.batch_size, config.device)

token_embedding = TokenEmbeddingLayer(config)
x_tok = token_embedding(xb)

assert x_tok.shape == (config.batch_size, config.ctx_len, config.d_model), "Embedding dimensions are incorrect"
print("Token embedding layer output shape is correct!")

**Reflect & discuss**:

Take a moment to think about what your embedding layer is doing before we move on:

*   What does each row of the embedding matrix represent? How does the output shape of the embedding layer relate to the input shape?
*  How does the embedding layer learn during training?
* What kind of information do the embedding vectors capture as the model trains?


Hint **Reflect & discuss**:

Take a moment to think about what your embedding layer is doing before we move on:

*   What does each row of the embedding matrix represent? How does the output shape of the embedding layer relate to the input shape?

*Each row of the embedding matrix (self.wte.weight) corresponds to a token in the vocabulary.
That row is a learned vector representation of that token, which the model will adjust during training so that tokens used in similar contexts end up with similar vectors. The embedding matrix has shape (vocab_size, d_model).
When we feed in a batch of token IDs shaped (batch_size, seq_len), the embedding layer looks up the appropriate rows for each token, producing an output tensor of shape (batch_size, seq_len, d_model).*

*  How does the embedding layer learn during training?

*The embedding weights are trainable parameters.
During the forward pass, token IDs are used to look up embeddings.
During backpropagation, gradients flow through those lookups and update the corresponding rows in the embedding matrix.*

* What kind of information do the embedding vectors capture as the model trains?

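*As the model learns, embeddings tend to capture statistical relationships between tokens: tokens that appear in similar contexts move closer together in the embedding space. With word-level tokens the classic example is “dog” and “cat”; in our character-level model, one might instead expect, for example, digits or vowels to end up near one another.*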
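The lookup and training behaviour described in these answers can be checked directly. In the sketch below (toy sizes chosen for illustration), `nn.Embedding` applied to a batch of IDs gives the same result as indexing rows of its weight matrix, or as multiplying one-hot vectors by that matrix, and after a backward pass only the rows for IDs that actually appeared receive a nonzero gradient.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model = 10, 4
emb = nn.Embedding(vocab_size, d_model)

ids = torch.tensor([[1, 3, 3], [0, 7, 1]])  # (batch_size=2, seq_len=3)
out = emb(ids)                               # (2, 3, 4)

# 1) An embedding lookup is row indexing into the weight matrix
same_as_indexing = torch.allclose(out, emb.weight[ids])

# 2) ...which is equivalent to one-hot vectors times the weight matrix
one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # (2, 3, 10)
same_as_matmul = torch.allclose(out, one_hot @ emb.weight)

# 3) Gradients only reach the rows that were looked up
out.sum().backward()
rows_with_grad = (emb.weight.grad.abs().sum(dim=1) > 0).nonzero().flatten().tolist()

print(out.shape, same_as_indexing, same_as_matmul)
print(rows_with_grad)  # [0, 1, 3, 7]
```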


## Advanced Exercise: Implement full embedding layer

Now we'll combine token and position embeddings into a single embedding layer. You can reuse your token embedding code from before. Think about how to represent position embeddings so the model knows *where* each token appears in the sequence. This is a little tricky, so feel free to click on the hints below the exercise for help.

In [None]:
class EmbeddingLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.device = config.device

        # TODO: Create the token embedding layer (same as before)
        self.wte = nn.Embedding(...)

        # TODO: implement position embedding
        self.wpe = ...


    def forward(self, x):
        batch_size, seq_len = x.shape

        # TODO: Compute token and position embeddings
        x_tok = ...
        x_pos = ...

        # Combine token and position embeddings
        x_embeddings = x_tok + x_pos

        return x_embeddings


# Test your implementation
xb, yb = get_batch('train', config.ctx_len, config.batch_size, config.device)

embedding = EmbeddingLayer(config)
x_embedding = embedding(xb)

assert x_embedding.shape == (config.batch_size, config.ctx_len, config.d_model), "Embedding dimensions are incorrect"
print("Embedding layer output shape is correct!")

In [None]:
# @markdown **Click to see hint #1**
"""
For position embeddings, you can also use nn.Embedding.
Instead of the first dimension being equal to vocab size,
it should be equal to the context length (so you learn an
embedding for each position in a sequence)
"""

```python
# @markdown **Click to see hint #2**

"""
The output of the token embeddings forward pass has shape
(batch size x context length x model dimension).

For the forward pass of the position embeddings, you only
need to create a matrix of shape (context length by model
dimension) because nothing depends on the actual data in
each batch. Broadcasting will ensure you can still add
this matrix to the token embeddings.
"""

```
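To see the broadcasting from hint #2 in isolation, the sketch below uses plain random tensors as stand-ins for the token and position embeddings (the sizes are arbitrary). The `(seq_len, d_model)` position matrix is treated as if it had a leading batch dimension of size 1, so the same position vectors are added to every sequence in the batch.

```python
import torch

batch_size, seq_len, d_model = 8, 16, 32
x_tok = torch.randn(batch_size, seq_len, d_model)  # stand-in for token embeddings
x_pos = torch.randn(seq_len, d_model)              # stand-in for position embeddings

# Broadcasting aligns trailing dimensions: (8, 16, 32) + (16, 32) -> (8, 16, 32)
x_emb = x_tok + x_pos

# Equivalent explicit version: add a batch dimension and expand it
x_emb_explicit = x_tok + x_pos.unsqueeze(0).expand(batch_size, -1, -1)

print(x_emb.shape)                          # torch.Size([8, 16, 32])
print(torch.equal(x_emb, x_emb_explicit))   # True
```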

In [None]:
# Solution
class EmbeddingLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.wte = nn.Embedding(config.vocab_size, config.d_model)
        self.wpe = nn.Embedding(config.ctx_len, config.d_model)
        self.device = config.device

    def forward(self, x):
        batch_size, seq_len = x.shape

        x_tok = self.wte(x)
        # print(x_tok.shape) uncomment this if you want to see the shape of above tensor
        x_pos = self.wpe(torch.arange(seq_len, device=x.device))  # position IDs on the same device as the input
        # print(x_pos.shape)

        # Combine token and position embeddings
        x_embeddings = x_tok + x_pos

        return x_embeddings


# Testing your implementation
xb, yb = get_batch('train', config.ctx_len, config.batch_size, config.device)

embedding = EmbeddingLayer(config)
x_embedding = embedding(xb)

assert x_embedding.shape == (config.batch_size, config.ctx_len, config.d_model), "Embedding dimensions are incorrect"
print("Embedding layer output shape is correct!")

**Reflect and Discuss**


*   Why does the model need position embeddings in addition to token embeddings?

*   Why are we adding the token and position embeddings together, instead of concatenating them?



Hint **Reflect and Discuss**


*   Why does the model need position embeddings in addition to token embeddings?

*Token embeddings tell the model what each token is, but not where it appears in the sequence. Without position information, the model would treat text as a “bag of words,” unable to distinguish between different word orders (e.g., “dog bites man” vs. “man bites dog”). So, position embeddings give the model a sense of order and structure, which is essential for understanding sequences.*

*   Why are we adding the token and position embeddings together, instead of concatenating them?

*Both token and position embeddings have the same dimensionality (d_model), meaning each represents information in the same feature space. By adding them, we combine what the token is (its identity) and where it is (its position) into a single vector of the same size. This keeps the total embedding dimension fixed — so the next layers of the model (attention, feedforward, etc.) can process the combined information without any change in shape.*
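The “bag of words” point can be demonstrated numerically. The sketch below uses a minimal self-attention with random projection matrices and no position embeddings (all names and sizes are illustrative), and checks that shuffling the input positions simply shuffles the outputs in the same way. The causal mask is deliberately left out here, since the mask itself carries some ordering information.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 5, 8
x = torch.randn(seq_len, d_model)  # token embeddings only, no position info
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

def self_attention(x):
    """Unmasked single-head self-attention on a (seq_len, d_model) input."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    weights = F.softmax(Q @ K.T / d_model**0.5, dim=-1)
    return weights @ V

perm = torch.randperm(seq_len)   # a random reordering of the positions
out = self_attention(x)
out_perm = self_attention(x[perm])

# Permuting the input just permutes the output rows: order is invisible
equivariant = torch.allclose(out[perm], out_perm, atol=1e-5)
print(equivariant)  # True
```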

# Attention

## Exercise: Implementing single-head causal self-attention

Self-attention is a core mechanism in transformers that allows each position in a sequence to attend to other positions in the sequence. The "causal" part ensures each position can only attend to itself and to earlier positions. This is crucial for language modeling: when predicting the next token, the model must not be able to look at the tokens it is supposed to predict.

In this exercise, you'll fill out the `SingleHeadCausalAttention` module below.  

The `__init__` method should define the key, query, and value projection layers. A causal mask (`self.cmask`) is already provided for you (it’s a lower-triangular matrix of 1s that enforces the causal constraint).

The `forward(self, x)` should implement the full attention computation:
$$\textrm{attention}(Q, K, V) = \textrm{softmax}\left( c \odot \frac{Q K^\top}{\sqrt{d_k}} \right) V $$
where $c \odot \dots$ denotes the application of the causal mask.  To do this,

1. Project the input `x` into the K, Q, V matrices using the `self.key`, `self.query`, and `self.value` projections.
2. Compute the scaled attention scores, where $d_k$ is the dimension of each head (`head_dim`):

$$\frac{Q K^\top}{\sqrt{d_k}} $$

3. Apply the causal mask. You can use `torch.masked_fill(...)` to apply the mask. This function takes three arguments: the input tensor you want to mask, where you want to mask it (a boolean condition), and the value to fill those positions with. Think about what value you should fill masked positions with so that after the softmax they have (essentially) zero probability. It may be helpful to recall the softmax formula; the $i$-th component of a vector $u$ after a softmax is: $$ \textrm{softmax}(u)_i =  \frac{e^{u_i}}{\sum_j e^{u_j}}.$$

4. Apply the softmax and compute the weighted sum with V.


Hints:
1. Keep track of the tensor dimensions after each step!
2. You can transpose tensors in PyTorch by calling `A.transpose(dim_1, dim_2)`, where `dim_1` and `dim_2` are the dimensions you want to swap.


```python
class SingleHeadCausalAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Calculate the dimension for each attention head
        self.head_dim = config.d_model // config.n_heads

        # TODO: Initialize the Key, Query, and Value projections
        # Each projects from d_model to head_dim with no bias
        self.key = nn.Linear(..., ...,  bias=False)
        self.query = nn.Linear(..., ...,  bias=False)
        self.value = nn.Linear(..., ...,  bias=False)

        # Create causal mask (lower triangular matrix); you can refer to it as `self.cmask`
        self.register_buffer("cmask", torch.tril(torch.ones([config.ctx_len, config.ctx_len])))

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # TODO Step 1: Compute K, Q, V projections
        K = ...
        Q = ...
        V = ...

        # TODO Step 2: Compute attention scores
        attention_scores = ...

        # TODO Step 3: Apply the causal mask (you can use `torch.masked_fill(...)` here)
        ...

        # TODO Step 4: Apply the softmax and compute the weighted sum with V
        ...

        return # Final output


# Test your implementation
config = Config(d_model=256, n_heads=8, ctx_len=16)
attention = SingleHeadCausalAttention(config)
x = torch.randn(2, 10, 256)  # (batch_size, seq_len, d_model)
output = attention(x)

assert output.shape == (2, 10, 32)  # head_dim = 256/8 = 32
print("Single-head causal attention output shape is correct!")

```

In [None]:
# Solution
class SingleHeadCausalAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.head_dim = config.d_model // config.n_heads
        self.key = nn.Linear(config.d_model, self.head_dim, bias=False)
        self.query = nn.Linear(config.d_model, self.head_dim, bias=False)
        self.value = nn.Linear(config.d_model, self.head_dim, bias=False)

        self.register_buffer("cmask", torch.tril(torch.ones([config.ctx_len, config.ctx_len])))


    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # Step 1: Compute K, Q, V projections
        K = self.key(x) # (batch_size, seq_len, head_dim)
        Q = self.query(x) # (batch_size, seq_len, head_dim)
        V = self.value(x) # (batch_size, seq_len, head_dim)

        # Step 2: Compute scaled attention scores
        attention_scores = Q @ K.transpose(-2, -1) * self.head_dim**-0.5  # (batch_size, seq_len, seq_len)

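        # Step 3: Apply the causal mask (future positions get -inf, i.e. zero weight after softmax)
        masked_scores = torch.masked_fill(attention_scores, self.cmask[:seq_len, :seq_len]==0, float('-inf'))

        # Step 4: Softmax over the keys, then take the weighted sum of the values
        attention_weights = F.softmax(masked_scores, dim=-1)  # (batch_size, seq_len, seq_len)
        outputs = attention_weights @ V  # (batch_size, seq_len, head_dim)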
        return outputs


# Test your implementation
config = Config(d_model=256, n_heads=8, ctx_len=16)
attention = SingleHeadCausalAttention(config)
x = torch.randn(2, 10, 256)  # (batch_size, seq_len, d_model)
output = attention(x)
assert output.shape == (2, 10, 32)  # head_dim = 256/8 = 32
print("Single-head causal attention output shape is correct!")
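As an optional cross-check, PyTorch 2.0 and later provide `F.scaled_dot_product_attention`, which with `is_causal=True` applies the same lower-triangular mask and, by default, the same $1/\sqrt{d_k}$ scaling. The self-contained sketch below uses random tensors standing in for the Q, K, V projections (with an extra heads dimension, shape `(batch, heads, seq_len, head_dim)`) and compares the built-in against the manual computation from the solution. It assumes a PyTorch version recent enough to include that function.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_heads, seq_len, head_dim = 2, 4, 10, 32
Q = torch.randn(batch_size, n_heads, seq_len, head_dim)
K = torch.randn(batch_size, n_heads, seq_len, head_dim)
V = torch.randn(batch_size, n_heads, seq_len, head_dim)

# Manual version, following the solution above (mask broadcasts over batch and heads)
scores = Q @ K.transpose(-2, -1) * head_dim**-0.5
cmask = torch.tril(torch.ones(seq_len, seq_len))
scores = torch.masked_fill(scores, cmask == 0, float('-inf'))
manual = F.softmax(scores, dim=-1) @ V

# Built-in fused version (PyTorch >= 2.0)
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```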

**Reflect & Discuss**

*   What does the causal mask accomplish? What would happen if we removed it?

*   How do the key, query, and value projections differ conceptually?
*  What do the attention weights represent?   


Hint **Reflect & Discuss**

*   What does the causal mask accomplish? What would happen if we removed it?

*The causal mask ensures that each token can only attend to itself and to earlier tokens in the sequence. This enforces the left-to-right flow required for autoregressive language modeling, where the model predicts the next token based only on past context. If the mask were removed, tokens could attend to future positions, effectively letting the model “see the answer” during training and making it unusable for text generation, where future tokens aren’t yet known.*

*   How do the key, query, and value projections differ conceptually?

*The query represents what the current token is trying to find out — the kind of information it’s seeking from the rest of the sequence. The keys represent what information each token has to offer, and the values carry that actual information to be shared. The attention mechanism compares queries to keys to determine which values are most relevant, combining them into a context-aware representation for each position.*

*  What do the attention weights represent?   

*The attention weights represent how much importance the model assigns to each token when computing a new representation for the current position. After applying the softmax, each row forms a probability distribution over the current token and all earlier tokens (the causal mask gives future tokens zero weight), where higher weights indicate tokens the model finds more relevant or informative. In effect, the attention weights show where the model is “looking” in the context: which earlier tokens it considers most useful for predicting what comes next.*
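These answers can be made concrete by inspecting an actual weight matrix. The sketch below uses random scores for a length-5 sequence (purely illustrative), applies the causal mask and softmax as in the solution, then checks that every row sums to 1 and that all weights above the diagonal, i.e. on future positions, are exactly zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 5
scores = torch.randn(seq_len, seq_len)            # stand-in for Q K^T / sqrt(d_k)
cmask = torch.tril(torch.ones(seq_len, seq_len))

weights = F.softmax(scores.masked_fill(cmask == 0, float('-inf')), dim=-1)
print(weights)

rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(seq_len))
future_is_zero = bool((torch.triu(weights, diagonal=1) == 0).all())
first_row = weights[0].tolist()                   # position 0 can only attend to itself
print(rows_sum_to_one, future_is_zero, first_row)
```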


# Multi-head self attention

In transformers, multi-head attention allows the model to attend to information from different representation subspaces simultaneously. Each head runs self-attention independently, and the outputs are concatenated and linearly projected back into the model’s hidden dimension.

## Exercise: implementing multi-head attention


In this exercise, you’ll implement the MultiHeadCausalAttention module using your SingleHeadCausalAttention from the previous exercise. You should not need to write more than a few lines of code here.

1. Define the attention heads. Use `nn.ModuleList(...)` to create a list of attention heads (instances of `SingleHeadCausalAttention`) that will act in parallel on the input.  
2. Define the output projection, using a linear layer, `self.linear`.
3. Implement the forward pass. The input `x` (of shape `(batch_size, seq_len, d_model)`) should be passed through each head. Then concatenate the head outputs along the last dimension (you can use `torch.cat(...)`), which gives a tensor of width `n_heads * head_dim = d_model`, and pass the result through the linear layer.

```python
class MultiHeadCausalAttention(nn.Module):

    def __init__(self, config):
        super().__init__()

        # TODO: Create multiple attention heads
        self.heads = nn.ModuleList([...])

        # TODO: Final linear projection (d_model → d_model)
        self.linear = ...

    def forward(self, x):
        # TODO: Pass input through all heads and concatenate the outputs
        # Hint: use torch.cat(...) to combine head outputs
        # Then apply the final linear layer
        ...


# Test your implementation
config = Config(d_model=256, n_heads=8, ctx_len=16)
mha = MultiHeadCausalAttention(config)

x = torch.randn(2, 10, 256)  # (batch_size=2, seq_len=10, d_model=256)
out = mha(x)

assert out.shape == (2, 10, 256)
print("Multi-head causal attention output shape is correct!")

```

In [None]:
# Solution
class MultiHeadCausalAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList([SingleHeadCausalAttention(config) for _ in range(config.n_heads)])

        self.linear = nn.Linear(config.d_model, config.d_model)


    def forward(self, x):
        y = torch.cat([h(x) for h in self.heads], dim=-1)
        y = self.linear(y)
        return y


# Testing your implementation
config = Config(d_model=256, n_heads=8, ctx_len=16)
mha = MultiHeadCausalAttention(config)

# Test with small batch
x = torch.randn(2, 10, 256)  # (batch_size=2, seq_len=10, d_model=256)
out = mha(x)
assert out.shape == (2, 10, 256)
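The `nn.ModuleList` version mirrors the definition of multi-head attention directly. A common alternative computes all heads at once: a single linear layer produces Q, K and V for every head, and the heads are split out with `view` and `transpose`. The sketch below assumes PyTorch 2.0+ for `F.scaled_dot_product_attention`; the class name `BatchedMultiHeadCausalAttention` and the minimal `SketchConfig` are hypothetical. Its parameter layout differs from the per-head version, so weights are not directly interchangeable, but the output shape and the causal behaviour are the same.

```python
import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class SketchConfig:
    """Hypothetical minimal config with only the fields this sketch needs."""
    d_model: int = 256
    n_heads: int = 8

class BatchedMultiHeadCausalAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # One projection produces Q, K and V for all heads at once
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
        self.linear = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        q, k, v = self.qkv(x).split(d_model, dim=-1)  # each (B, T, d_model)

        def split_heads(t):
            # (B, T, d_model) -> (B, n_heads, T, head_dim)
            return t.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        # Causal attention for every head in parallel
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, n_heads, T, head_dim) -> (B, T, d_model): same layout as concatenating heads
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.linear(y)

mha = BatchedMultiHeadCausalAttention(SketchConfig())
out = mha(torch.randn(2, 10, 256))
print(out.shape)  # torch.Size([2, 10, 256])
```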

# Define the full decoder block


Below, we’ll first define the feedforward network. This module applies two linear transformations with a ReLU activation in between.
The first layer expands the dimensionality by a factor of 4, and the second projects it back to the model dimension. Because `nn.Linear` acts on the last dimension, the network is applied to each position independently.
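The 4× expansion makes the feedforward network one of the larger parts of each block. Counting weights and biases for both layers gives $d \cdot 4d + 4d + 4d \cdot d + d = 8d^2 + 5d$ parameters, which for `d_model = 256` is 525,568. The sketch below checks this arithmetic with an `nn.Sequential` stand-in that has the same layers as `FFN`.

```python
import torch
from torch import nn

d_model = 256
# Same layers as FFN: expand by 4x, ReLU, project back
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.ReLU(),
    nn.Linear(4 * d_model, d_model),
)

n_params = sum(p.numel() for p in ffn.parameters())
formula = 8 * d_model**2 + 5 * d_model
print(n_params, formula)  # 525568 525568

# Applied position-wise: the shape is unchanged
print(ffn(torch.randn(2, 10, d_model)).shape)  # torch.Size([2, 10, 256])
```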

In [None]:
class FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.l1 = nn.Linear(config.d_model, 4*config.d_model)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(4*config.d_model, config.d_model)

    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        return x


## Exercise: Decoder Block

Now you’ll implement a single decoder block, which is a core component that combines multi-head self-attention, the feed-forward network, and layer normalization.

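Each sublayer (`MultiHeadCausalAttention` and the feedforward network) is wrapped in a residual connection following the pattern:

`x = x + sublayer(layer_norm(x))`

This "pre-norm" ordering (LayerNorm → sublayer → residual) is used in GPT-2-style decoders because it tends to make training more stable than the "post-norm" ordering of the original Transformer, which applies LayerNorm after the residual addition.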
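One way to build intuition for the difference is to compare the two orderings side by side. In the sketch below (the function names `pre_norm_step` and `post_norm_step` are illustrative, not from the workshop code), the pre-norm residual path passes `x` through untouched, so a sublayer that contributes nothing leaves the block as the identity; in the post-norm ordering every block renormalizes the residual stream. This is an intuition under simplifying assumptions, not a proof of better training stability.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 16
ln = nn.LayerNorm(d_model)

def pre_norm_step(x, sublayer):
    # Ordering used in this notebook: x + sublayer(LN(x))
    return x + sublayer(ln(x))

def post_norm_step(x, sublayer):
    # Original Transformer ordering: LN(x + sublayer(x))
    return ln(x + sublayer(x))

def zero_sublayer(t):
    # A sublayer that contributes nothing, to isolate the residual path
    return torch.zeros_like(t)

x = 3 * torch.randn(2, 10, d_model) + 1  # residual stream with non-unit scale
pre_is_identity = torch.equal(pre_norm_step(x, zero_sublayer), x)
post_is_identity = torch.allclose(post_norm_step(x, zero_sublayer), x)
print(pre_is_identity, post_is_identity)  # True False
```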

```python
class DecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.mha = MultiHeadCausalAttention(config)

        # TODO: Initialize layer normalization layers
        # Hint: use nn.LayerNorm
        self.ln1 = ...
        self.ffn = FFN(config)
        self.ln2 = ...

    def forward(self, x):
        # TODO: residual around attention
        x = ...

        # TODO: residual around FFN
        x = ...

        return x


# Test your implementation
config = Config(d_model=256)
ffn = FFN(config)
decoder = DecoderBlock(config)

# Test with random input
x = torch.randn(2, 10, 256)  # (batch_size, sequence_length, d_model)
output = decoder(x)

assert output.shape == x.shape
print("Decoder block output shape is correct!")

```

In [None]:
# Solution
class DecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mha = MultiHeadCausalAttention(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ffn = FFN(config)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


# Test your implementation
config = Config(d_model=256)
decoder = DecoderBlock(config)

# Test with random input
x = torch.randn(2, 10, 256)  # (batch_size, sequence_length, d_model)
output = decoder(x)
assert output.shape == x.shape
print("Decoder block output shape is correct!")

# Define the transformer!

We're now ready to put the components together into our final decoder module that can actually generate text! This is the top-level module that:

* Embeds input tokens and adds positional information
* Processes them through multiple transformer layers
* Outputs predictions for the next token through the `forward(...)` function
* Can generate new sequences autoregressively through the `generate(...)` function

Your task is to implement the token generation method of the `Decoder` class below. During training, the model sees the entire sequence at once and learns to predict each token from the previous ones. At inference time, however, we don’t have the next token: we must generate tokens one at a time, feeding each newly generated token back into the model. This iterative process is called autoregressive generation. At every step, the model looks at the most recent context (up to `ctx_len` tokens), predicts a probability distribution over the vocabulary, samples one token, appends it, and repeats until it has produced `max_len` new tokens.
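The cropping step matters because the position embedding table in `EmbeddingLayer` only has `ctx_len` rows: a longer window would ask for a position ID that does not exist. The sketch below (toy sizes, a standalone `nn.Embedding` as a stand-in for `wpe`, run on CPU) shows that slicing with `[:, -ctx_len:]` keeps short sequences whole, trims long ones to the most recent `ctx_len` tokens, and that looking up an out-of-range position fails.

```python
import torch
from torch import nn

ctx_len, d_model = 4, 8
wpe = nn.Embedding(ctx_len, d_model)  # stand-in for the position embedding table

short_seq = torch.tensor([[5, 6]])               # (1, 2): shorter than ctx_len
long_seq = torch.tensor([[1, 2, 3, 4, 5, 6, 7]])  # (1, 7): longer than ctx_len

print(short_seq[:, -ctx_len:].tolist())  # [[5, 6]], nothing is dropped
print(long_seq[:, -ctx_len:].tolist())   # [[4, 5, 6, 7]], only the most recent tokens

# Positions 0..ctx_len-1 exist; position ctx_len does not
wpe(torch.arange(ctx_len))
try:
    wpe(torch.arange(ctx_len + 1))
    overflow_failed = False
except IndexError:
    overflow_failed = True
print(overflow_failed)  # True
```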


In [None]:
class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Stack of decoder blocks
        self.blocks = nn.Sequential(*[DecoderBlock(config) for _ in range(config.n_layers)])

        # Final layer norm (normalize across d_model dimension)
        self.ln = nn.LayerNorm(config.d_model)

        # Linear projection from d_model to vocab_size
        self.lin = nn.Linear(config.d_model, config.vocab_size)

        # Embeddings
        self.emb = EmbeddingLayer(config)

        # Loss function for training
        self.L = nn.CrossEntropyLoss()
        self.ctx_len = config.ctx_len

        self.device = config.device # don't change this (for training model on right device)

    def forward(self, x, targets=None):
        """
        Args:
            x: Input tokens (B, T)
            targets: Optional target tokens (B, T)
        Returns:
            logits: Predictions (B, T, vocab_size)
            loss: Optional cross-entropy loss
        """
        batch_size, seq_len = x.shape

        # Embed tokens (token + positional embeddings)
        x = self.emb(x)

        # Process through the stack of transformer blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln(x)

        # Project from hidden dimension to vocabulary size
        logits = self.lin(x)

        # Compute loss if targets are provided
        if targets is None:
            loss = None
        else:
            # Reshape logits and targets for loss computation
            batch_size, seq_len, vocab_size = logits.shape
            logits = logits.view(batch_size*seq_len, vocab_size)
            targets = targets.view(batch_size*seq_len)

            # Compute loss
            loss = self.L(logits, targets)

        return logits, loss

    def generate(self, token_ids, max_len=256):
        """
        Generate new tokens given initial sequence of token IDs.

        Args:
            token_ids (torch.Tensor):
                The starting sequence of token IDs, shape (batch_size, seq_len).
            max_len (int, optional):
                Maximum number of new tokens to generate.
                Defaults to 256.

        Returns:
            torch.Tensor:
                The complete sequence of generated token IDs, including both the
                original input and the newly generated tokens.
                Shape: (batch_size, seq_len + max_len).
        """

        for _ in range(max_len):

            # TODO: Grab the last ctx_len tokens
            token_window = ...

            # TODO: Get model predictions
            logits, _ = ...

            # Only keep predictions for the last token
            logits = logits[:, -1, :]

            # Convert logits to probabilities
            probs = F.softmax(logits, dim=-1)

            # TODO: Sample the next token (hint: you can use torch.multinomial)
            next_token = ...

            # TODO: Append next token to the sequence
            token_ids = ...

        return token_ids


# Testing your implementation
config = Config(
    vocab_size=100,
    d_model=256,
    ctx_len=64,
    n_layers=4
)
decoder = Decoder(config)

x = torch.randint(0, 100, (1, 10))
logits, loss = decoder(x, x)

out = decoder.generate(torch.tensor([[1, 2, 3]]), max_len=5)
assert out.shape == (1, 8)
print("Decoder generation output shape is correct!")

Let’s inspect the architecture of our decoder to make sure it matches what we expect.

```python
print(decoder)

```

In [None]:
# Solution
class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Stack of decoder blocks
        self.blocks = nn.Sequential(*[DecoderBlock(config) for _ in range(config.n_layers)])

        # Final layer norm (normalize across d_model dimension)
        self.ln = nn.LayerNorm(config.d_model)

        # Linear projection from d_model to vocab_size
        self.lin = nn.Linear(config.d_model, config.vocab_size)

        # Embeddings
        self.emb = EmbeddingLayer(config)

        # Loss function for training
        self.L = nn.CrossEntropyLoss()
        self.ctx_len = config.ctx_len

        self.device = config.device # don't change this (for training model on right device)

    def forward(self, x, targets=None):
        """
        Args:
            x: Input tokens (B, T)
            targets: Optional target tokens (B, T)
        Returns:
            logits: Predictions (B, T, vocab_size)
            loss: Optional cross-entropy loss
        """
        batch_size, seq_len = x.shape

        # Embed tokens (token + positional embeddings)
        x = self.emb(x)

        # Process through the stack of transformer blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln(x)

        # Project from hidden dimension to vocabulary size
        logits = self.lin(x)

        # Compute loss if targets are provided
        if targets is None:
            loss = None
        else:
            # Reshape logits and targets for loss computation
            batch_size, seq_len, vocab_size = logits.shape
            logits = logits.view(batch_size*seq_len, vocab_size)
            targets = targets.view(batch_size*seq_len)

            # Compute loss
            loss = self.L(logits, targets)

        return logits, loss

    @torch.no_grad()  # generation never needs gradients; this saves memory
    def generate(self, token_ids, max_len=256):
        """
        Generate new tokens given initial sequence of token IDs.

        Args:
            token_ids (torch.Tensor):
                The starting sequence of token IDs, shape (batch_size, seq_len).
            max_len (int, optional):
                Maximum number of new tokens to generate.
                Defaults to 256.

        Returns:
            torch.Tensor:
                The complete sequence of generated token IDs, including both the
                original input and the newly generated tokens.
                Shape: (batch_size, seq_len + max_len).
        """

        for _ in range(max_len):

            # Grab the last ctx_len tokens
            token_window = token_ids[:, -self.ctx_len:]

            # Get model predictions
            logits, _ = self(token_window)

            # Only keep predictions for the last token
            logits = logits[:, -1, :]

            # Convert logits to probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token from the predicted distribution
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token to the sequence
            token_ids = torch.cat((token_ids, next_token), dim=1)

        return token_ids


# Testing your implementation
config = Config(
    vocab_size=100,
    d_model=256,
    ctx_len=64,
    n_layers=4
)
decoder = Decoder(config)

x = torch.randint(0, 100, (1, 10))
logits, loss = decoder(x, x)

out = decoder.generate(torch.tensor([[1, 2, 3]]), max_len=5)
assert out.shape == (1, 8)
print("Decoder output shape is correct!")
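To see what the loss in `forward` computes, here is a minimal pure-Python sketch (no PyTorch; `log_softmax` and `flattened_cross_entropy` are hypothetical helpers, not part of the notebook). It assumes the default settings of `nn.CrossEntropyLoss` (mean reduction, no class weights or ignored indices): after flattening `(B, T, vocab_size)` logits to `(B*T, vocab_size)` and `(B, T)` targets to `(B*T,)`, each position becomes one classification example, and the loss is the average of `-log softmax(logits)[target]` over all of them.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def flattened_cross_entropy(logits, targets):
    """
    logits:  nested list with shape (B, T, V)
    targets: nested list with shape (B, T) of integer token IDs
    Flattens to (B*T, V) and (B*T,), then averages -log p(target) over every position.
    """
    flat_logits = [row for seq in logits for row in seq]   # (B*T, V)
    flat_targets = [t for seq in targets for t in seq]     # (B*T,)
    losses = [-log_softmax(row)[t] for row, t in zip(flat_logits, flat_targets)]
    return sum(losses) / len(losses)

# Toy batch with B=1, T=2, V=3 (made-up numbers)
toy_logits = [[[2.0, 0.0, 0.0],    # position 0: fairly confident in token 0
               [0.0, 0.0, 0.0]]]   # position 1: uniform over all 3 tokens
toy_targets = [[0, 2]]
print(flattened_cross_entropy(toy_logits, toy_targets))
```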


**Reflect & Discuss**

1. What does `ctx_len` control, and what might happen if it’s too short or too long?
2. How does the training objective relate to generation quality? If the model’s loss during training is low, does that always mean generation will sound good?
3. During generation, what would happen if we always picked the most likely token instead of sampling?

**Reflect & Discuss: Hints**

1. What does `ctx_len` control, and what might happen if it’s too short or too long?

*ctx_len determines how many of the most recent tokens the model can “see” when predicting the next one; it is the size of the model’s working memory. If it’s too short, the model quickly forgets earlier parts of the sequence and can lose coherence, producing text that drifts off topic or repeats. If it’s very long, the model retains more context but becomes slower and more memory-hungry, since the cost of attention grows quadratically with sequence length (every token in the window is scored against every earlier token). In practice, ctx_len is a trade-off between contextual understanding and computational efficiency.*

2. How does the training objective relate to generation quality? If the model’s loss during training is low, does that always mean generation will sound good?

*The training objective teaches the model to predict the next token given its context, minimizing cross-entropy loss. A low loss means the model is good at local next-token prediction on the training data. But fluent generation depends on repeatedly applying those predictions in sequence, with each sampled token fed back in as input. Small local mistakes can compound over many steps, leading to incoherence, repetition, or drift. So low training loss is necessary but not sufficient for natural, high-quality text: generation quality also depends on how well the model generalizes (which is why the validation loss is the more informative number, since training loss can also fall through memorization) and on how sampling is performed during inference.*

3. During generation, what would happen if we always picked the most likely token instead of sampling?

*Always choosing the highest-probability token (taking the argmax) makes the process deterministic. The model would produce the same output every time for a given prompt, but the text often becomes repetitive or formulaic, because small biases toward common words get reinforced at each step. Sampling from the probability distribution adds randomness and variety, allowing the model to explore alternative word choices and generate more natural, creative language, even though it introduces some unpredictability.*
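The first hint's point about cost can be made concrete by counting attention scores. The sketch below assumes each head materializes a full `ctx_len x ctx_len` score matrix, as a straightforward (unfused) attention implementation does; `attention_score_entries` is a hypothetical helper, and the numbers plugged in are the two configurations used in this notebook (the model trained below and the provided pretrained model).

```python
def attention_score_entries(ctx_len, n_heads, n_layers, batch_size=1):
    """
    Number of attention scores computed in one forward pass over a full window,
    assuming every head in every layer builds a (ctx_len x ctx_len) score matrix.
    """
    return batch_size * n_layers * n_heads * ctx_len * ctx_len

# ctx_len values from the training cell (128) and the provided pretrained model (512);
# both configurations use 12 heads and 12 layers.
short_ctx = attention_score_entries(ctx_len=128, n_heads=12, n_layers=12)
long_ctx = attention_score_entries(ctx_len=512, n_heads=12, n_layers=12)
print(f"{short_ctx:,} vs {long_ctx:,} scores ({long_ctx // short_ctx}x)")
# 2,359,296 vs 37,748,736 scores (16x)
```

Quadrupling `ctx_len` multiplies the number of scores by 16, which is the quadratic growth the hint describes.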

# Train your model

In [None]:
import torch, gc

# Clear cached variables and free up GPU memory
gc.collect()
torch.cuda.empty_cache()

Now that our model architecture is complete, we can train it to predict and generate text.
Below, we define our model configuration, set up the optimizer, and run a simple training loop.
Every `eval_interval` steps (set to 200 below), we’ll estimate the validation loss by averaging over `eval_iters` batches and print a short sample of generated text to see how learning is progressing.
By the end of training, we’ll use the fully trained model to generate a longer passage from scratch.
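A useful reference point when reading the printed losses: a model that assigns equal probability `1/vocab_size` to every character has a cross-entropy of exactly `ln(vocab_size)` (PyTorch's `nn.CrossEntropyLoss` uses the natural log). A freshly initialized model typically starts close to this value, and the losses should fall well below it as training progresses. The sketch below uses a placeholder vocabulary size of 100; in the notebook the real value is `len(chars)`, and `uniform_baseline_loss` is a hypothetical helper.

```python
import math

def uniform_baseline_loss(vocab_size):
    """Cross-entropy (natural log) when every token is predicted with probability 1 / vocab_size."""
    return -math.log(1.0 / vocab_size)  # equal to math.log(vocab_size)

vocab_size = 100  # placeholder; in the notebook use len(chars)
print(f"Uniform-guess loss for {vocab_size} tokens: {uniform_baseline_loss(vocab_size):.4f}")
```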

In [None]:
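# Set up model (d_model is split evenly across heads: 252 / 12 = 21 dimensions per head)
# Falls back to CPU if no GPU is available (training will be much slower)
config = Config(d_model=252, n_heads=12, ctx_len=128, batch_size=64, n_layers=12,
                device='cuda' if torch.cuda.is_available() else 'cpu')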
config.set_vocab_size(vocab_size=len(chars))
model = Decoder(config).to(config.device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {n_params}")

# Define hyperparameters
learning_rate = 3e-4
max_iters = 6000
eval_interval = 200  # How often to evaluate
eval_iters = 100     # How many batches to use for evaluation

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop (`step` avoids shadowing Python's built-in iter)
for step in range(max_iters):
    xb, yb = get_batch('train', config.ctx_len, config.batch_size, config.device)

    # Forward pass
    logits, loss = model(xb, yb)

    # Backward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Evaluate periodically
    if step % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            val_losses = []
            for _ in range(eval_iters):
                xb, yb = get_batch('val', config.ctx_len, config.batch_size, config.device)
                _, val_loss = model(xb, yb)
                val_losses.append(val_loss.item())
            avg_val_loss = sum(val_losses) / len(val_losses)

            # Train loss comes from the current batch only, so it is noisier than the averaged val loss
            print(f"step {step}: train loss {loss.item():.4f}, val loss {avg_val_loss:.4f}")

            # Generate a text sample to monitor learning progress,
            # starting from token ID 0 (the first character in the sorted vocabulary)
            context = torch.zeros((1, 1), dtype=torch.long, device=config.device)
            print(decode(model.generate(context, max_len=100)[0].tolist(), int_to_str))
            print('='*50)

        model.train()

# Final text generation
model.eval()
with torch.no_grad():
    context = torch.zeros((1, 1), dtype=torch.long, device=config.device)
    print("\nFinal generated text:")
    print(decode(model.generate(context, max_len=500)[0].tolist(), int_to_str))

In [None]:
# Save our model
import os
if not os.path.exists('model.pth'):
    torch.save(model.state_dict(), 'model.pth')
else:
    print('Model file (model.pth) already exists! Saving under a different name: model_other.pth.')
    torch.save(model.state_dict(), 'model_other.pth')
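One limitation of the cell above: if `model_other.pth` already exists, it is silently overwritten. A minimal sketch of one way around that, using a hypothetical `unique_path` helper (standard library only) that appends `_1`, `_2`, ... before the extension until it finds an unused filename:

```python
import os
import tempfile

def unique_path(path):
    """Return `path` if it is free; otherwise insert _1, _2, ... before the extension."""
    if not os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    i = 1
    while os.path.exists(f"{root}_{i}{ext}"):
        i += 1
    return f"{root}_{i}{ext}"

# Demonstration in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    first = unique_path(os.path.join(tmp, "model.pth"))    # nothing saved yet
    open(first, "wb").close()                              # stand-in for torch.save
    second = unique_path(os.path.join(tmp, "model.pth"))   # model.pth is now taken
    print(os.path.basename(first), os.path.basename(second))  # model.pth model_1.pth

# Hypothetical usage in the notebook:
# save_path = unique_path("model.pth")
# torch.save(model.state_dict(), save_path)
```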

# Evaluate the trained transformer

We provide a trained model with larger dimensions and a longer context (`d_model=768`, `ctx_len=512`), which you are welcome to use! To load it, uncomment the cell below. If you were able to train your own model fully above, you can skip that cell and use your model for this section instead.

In [None]:
# To load the provided pretrained model, uncomment the code below

# url = "https://osf.io/dt6h4/download"   # replace with your actual OSF file link
# r = requests.get(url)
# with open("model_transformer.pth", "wb") as f:
#     f.write(r.content)

# config = Config(d_model=768, n_heads=12, ctx_len=512, batch_size=64, n_layers=12, device='cuda:0')
# config.set_vocab_size(vocab_size=len(chars))
# model = Decoder(config).to(config.device)
# # Load the same file that was written above; map_location puts the weights on config.device
# model.load_state_dict(torch.load("model_transformer.pth", map_location=config.device))

Let's see what your model generates! Try changing the **prompt** and **max_len** values below and observe how the text changes.


In [None]:
model.eval()
prompt = "Neuroscience is"
max_len = 512
prompt_ids = torch.tensor([encode(prompt, str_to_int)], dtype=torch.long, device=config.device)
with torch.no_grad():
    print(decode(model.generate(prompt_ids, max_len=max_len)[0].tolist(), int_to_str))
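When experimenting with prompts, keep in mind that `encode` looks up every character in `str_to_int`, so a prompt containing any character that never appears in `tiny_wikipedia.txt` raises a `KeyError`. The sketch below shows one way to report all unsupported characters at once; `find_unknown_chars` and `encode_checked` are hypothetical helpers, and `toy_vocab` stands in for the real `str_to_int`.

```python
def find_unknown_chars(prompt, str_to_int):
    """Return the characters in `prompt` that have no token ID, in order of first appearance."""
    unknown = []
    for ch in prompt:
        if ch not in str_to_int and ch not in unknown:
            unknown.append(ch)
    return unknown

def encode_checked(prompt, str_to_int):
    """Like encode(), but raises a readable error listing every unsupported character."""
    unknown = find_unknown_chars(prompt, str_to_int)
    if unknown:
        raise ValueError(f"Prompt contains characters not in the vocabulary: {unknown}")
    return [str_to_int[ch] for ch in prompt]

# Toy vocabulary built the same way as str_to_int, but from a tiny made-up text
toy_vocab = {ch: i for i, ch in enumerate(sorted(set("neuroscience is fun")))}
print(encode_checked("science", toy_vocab))
print(find_unknown_chars("Neuroscience!", toy_vocab))  # ['N', '!'] are not in the toy vocabulary
```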


The cell below generates text without sampling, by taking the argmax of the logits (the single most likely next token) at every step.

In [None]:
# Deterministic (argmax) generation
@torch.no_grad()
def generate_argmax(model, idx, max_len=200):
    for _ in range(max_len):
        idx_window = idx[:, -model.ctx_len:]                      # keep at most ctx_len tokens of context
        logits, _ = model(idx_window)
        logits = logits[:, -1, :]                                 # predictions for the last position only
        next_token = torch.argmax(logits, dim=-1, keepdim=True)   # most likely token, shape (B, 1)
        idx = torch.cat((idx, next_token), dim=1)
    return idx

model.eval()
prompt = "Neuroscience is"
prompt_ids = torch.tensor([encode(prompt, str_to_int)], dtype=torch.long, device=config.device)
sample = generate_argmax(model, prompt_ids.clone(), max_len=200)
print(decode(sample[0].tolist(), int_to_str))
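Greedy output often settles into a loop, and there is a simple reason: with argmax decoding (and the model in eval mode), the next token is a deterministic function of the last `ctx_len` tokens. Once the sequence is at least `ctx_len` long, if any window of `ctx_len` tokens appears a second time, everything after it repeats too. The toy sketch below illustrates this with a hypothetical "model" that sees only 2 characters and a made-up argmax table; it is not the trained transformer.

```python
def greedy_generate(next_token_fn, prompt, ctx_len, max_len):
    """Greedy decoding: each new token depends only on the last ctx_len tokens."""
    seq = prompt
    for _ in range(max_len):
        seq += next_token_fn(seq[-ctx_len:])
    return seq

def first_repeated_window(seq, ctx_len):
    """Return (first_index, repeat_index) of the first length-ctx_len window seen twice, or None."""
    seen = {}
    for i in range(len(seq) - ctx_len + 1):
        window = seq[i:i + ctx_len]
        if window in seen:
            return seen[window], i
        seen[window] = i
    return None

# Hypothetical argmax choices for a character model with a 2-character window
TOY_ARGMAX = {"th": "e", "he": " ", "e ": "t", " t": "h"}

def toy_next_token(window):
    return TOY_ARGMAX.get(window, " ")  # fall back to a space for unseen windows

out = greedy_generate(toy_next_token, "th", ctx_len=2, max_len=12)
print(repr(out))                                # 'the the the th'
print(first_repeated_window(out, ctx_len=2))    # (0, 4): "th" recurs, so the text is periodic
```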

**Reflect & Discuss**

* How does the model’s behavior differ when using sampling versus argmax?

* With sampling, how coherent or structured does the text sound? Does the model stay on topic, or does it drift?

* What happens if you start with a completely different prompt or punctuation?

## Extension: Add Temperature to Generation

You might have noticed that your model sometimes repeats itself or produces very predictable text.  
Let’s make generation more flexible by adding a **temperature** parameter that controls how random or “creative” the model is.
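Concretely, temperature `T` rescales the logits before the softmax, `softmax(logits / T)`. With `T < 1` the distribution becomes sharper (the top token gets more of the probability mass), with `T > 1` it becomes flatter, and `T = 1` leaves it unchanged. A pure-Python sketch on three made-up logits (`softmax_with_temperature` is a hypothetical helper):

```python
import math

def softmax_with_temperature(logits, temperature):
    """softmax(logits / temperature), stabilized by subtracting the max before exp."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

toy_logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate tokens
results = {t: softmax_with_temperature(toy_logits, t) for t in [0.3, 1.0, 2.0]}
for t, probs in results.items():
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

Note that the ranking of tokens never changes; temperature only changes how much probability the lower-ranked tokens receive.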

In [None]:
# Extended version of generate() with temperature control
@torch.no_grad()
def generate_with_temperature(self, idx, max_len=256, temperature=1.0):
    for _ in range(max_len):
        idx_window = idx[:, -self.ctx_len:]
        logits, _ = self(idx_window)
        # Divide by temperature (must be > 0): T < 1 sharpens the distribution, T > 1 flattens it
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_token), dim=1)
    return idx

# Attach the new method to the existing model
from types import MethodType
model.generate_with_temperature = MethodType(generate_with_temperature, model)
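Because `generate_with_temperature` divides the logits by `temperature`, a value of 0 is not usable: in PyTorch the division produces `inf`/`nan` values and the sampling step that follows typically fails. As `T` approaches 0, though, `softmax(logits / T)` concentrates all its mass on the largest logit, so greedy decoding (as in `generate_argmax`) is the natural meaning of `T = 0`. The pure-Python sketch below (a hypothetical `sample_next_token`) treats 0 that way; in the PyTorch version, the same branch would use `torch.argmax(logits, dim=-1, keepdim=True)`.

```python
import math
import random

def sample_next_token(logits, temperature, rng):
    """
    Pick a token index from a list of logits.
    temperature == 0 is treated as greedy (argmax) decoding, the limit of
    softmax(logits / T) as T -> 0; otherwise sample from softmax(logits / T).
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]  # unnormalized softmax weights
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)
toy_logits = [0.5, 2.0, 1.0]  # made-up scores
greedy = [sample_next_token(toy_logits, 0.0, rng) for _ in range(5)]
sampled = [sample_next_token(toy_logits, 1.5, rng) for _ in range(200)]
print(greedy)               # [1, 1, 1, 1, 1]: always the highest-scoring index
print(sorted(set(sampled)))
```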

In [None]:
model.eval()
prompt_ids = torch.tensor([encode("Neuroscience is", str_to_int)], dtype=torch.long, device=config.device)
max_len = 200  # number of new tokens generated for each temperature setting
for temp in [0.3, 0.7, 1.0, 1.5, 2.0]:
    print(f"\nTemperature = {temp}")
    sample = model.generate_with_temperature(
        prompt_ids.clone(), max_len=max_len, temperature=temp
    )
    print(decode(sample[0].tolist(), int_to_str))
    print("=" * 60)