# Implement Transformer from Scratch

In this coding homework, you will:

- Implement a simple transformer model from scratch to enhance your understanding of how it works.
- Create a hand-designed transformer model capable of solving a basic problem. This will help you comprehend the various operations that transformers can perform.
- Analyze the attention patterns of a trained network to gain insights into how learned models often utilize features that differ greatly from those employed by humans.

A GPU is not necessary for this task. If you're using Colab, open the "Runtime" -> "Change runtime type" menu and select "CPU" (labeled "None" in older versions of Colab) as the hardware accelerator.

**Note:** The same variables will be defined in different ways in various subparts of the homework. If you encounter errors stating that a variable has the wrong shape or a function is missing an argument, ensure that you have re-run the cells in that particular problem subpart.

In [None]:
"""
Setup, Configuration, and Helper Functions
===========================================
"""

from __future__ import annotations

import time
import json
import inspect
import random
import math
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Callable, Optional, Tuple, List, Any

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# =============================================================================
# Type Aliases
# =============================================================================
# Aliases used in the function signatures of this notebook.
TensorFloat = torch.Tensor  # a torch tensor (float dtype is a convention, not enforced)
NDArrayFloat = npt.NDArray[np.floating[Any]]  # a numpy array with any floating dtype

# =============================================================================
# Configuration Classes
# =============================================================================
@dataclass
class PlotConfig:
    """Default appearance, layout and save options for the plotting helpers.

    Field order defines the positional constructor signature, so new fields
    should be appended rather than inserted.
    """

    # --- figure size / resolution (fed to plt.rcParams) ------------------
    figure_width: float = 20.0
    figure_height: float = 5.0
    dpi: int = 100

    # --- heatmap color scaling (defaults for imshow in rescale_and_plot) --
    colormap: str = "Reds"
    vmin: Optional[float] = None
    vmax: Optional[float] = None

    # --- font sizes -------------------------------------------------------
    title_fontsize: int = 12
    label_fontsize: int = 10

    # --- axis decoration --------------------------------------------------
    show_xticks: bool = False
    show_yticks: bool = False
    tight_layout: bool = True

    # --- subplot grid -----------------------------------------------------
    nrows: int = 1
    ncols: int = 8

    # --- writing figures to disk -----------------------------------------
    save_dir: str = "plots"
    save_format: str = "png"    # expected: "png", "pdf" or "svg"
    save_dpi: int = 150
    auto_save: bool = False     # intended switch for saving every plot automatically
    add_timestamp: bool = True  # append a timestamp to saved filenames


@dataclass
class TrainingConfig:
    """Default hyperparameters for training.

    ``num_epochs``, ``learning_rate`` and ``log_interval`` provide the default
    arguments of ``train_loop``.
    """

    num_epochs: int = 10_001        # number of optimizer steps
    learning_rate: float = 3e-2
    log_interval: int = 1000        # steps between loss printouts
    optimizer: str = "sgd"          # "sgd" or "adam"
    momentum: float = 0.0
    weight_decay: float = 0.0
    loss_fn: str = "mse"            # "mse" or "cross_entropy"


# Module-wide default configurations, plus matplotlib defaults derived from them.
PLOT_CONFIG = PlotConfig()
TRAIN_CONFIG = TrainingConfig()
plt.rcParams.update({
    'figure.figsize': [PLOT_CONFIG.figure_width, PLOT_CONFIG.figure_height],
    'figure.dpi': PLOT_CONFIG.dpi,
})

# Constants for the equivalence tests and the task checks.
RELATIVE_TOLERANCE = 1e-3   # rtol handed to np.allclose
TEST_ITERATIONS = 10
TEST_MIN_SEQ_LEN = 1
TEST_MAX_SEQ_LEN = 4
TEST_INPUT_DIM = 5

# =============================================================================
# Utility Functions
# =============================================================================
def set_random_seed(seed: int = 42) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

_set_seed = set_random_seed  # Backward compatibility

# =============================================================================
# Visualization Functions
# =============================================================================
def rescale_and_plot(
    arr: NDArrayFloat,
    title: str = '',
    ax: Optional[Axes] = None,
    x_lab: Optional[str] = None,
    y_lab: Optional[str] = None,
    colormap: str = PLOT_CONFIG.colormap,
    vmin: Optional[float] = PLOT_CONFIG.vmin,
    vmax: Optional[float] = PLOT_CONFIG.vmax,
    title_fontsize: int = PLOT_CONFIG.title_fontsize,
    label_fontsize: int = PLOT_CONFIG.label_fontsize,
    show_xticks: bool = PLOT_CONFIG.show_xticks,
    show_yticks: bool = PLOT_CONFIG.show_yticks,
) -> None:
    """Plot a 2-D matrix as a heatmap after min-max rescaling it to [0, 1].

    Args:
        arr: 2-D array (or array-like) to display.
        title: Axes title.
        ax: Axes to draw on. When None, the current axes (``plt.gca()``) is
            used instead of failing on ``None.imshow``.
        x_lab: Optional x-axis label.
        y_lab: Optional y-axis label.
        colormap: Colormap name passed to ``imshow``.
        vmin: Lower color limit passed to ``imshow``.
        vmax: Upper color limit passed to ``imshow``.
        title_fontsize: Font size of the title.
        label_fontsize: Font size of the axis labels.
        show_xticks: Keep x ticks if True; otherwise they are removed.
        show_yticks: Keep y ticks if True; otherwise they are removed.

    Raises:
        AssertionError: If ``arr`` is not 2-D.
    """
    arr = np.asarray(arr)
    assert arr.ndim == 2, f"arr must be 2D, got shape {arr.shape}"

    if ax is None:
        ax = plt.gca()

    # Shift so the minimum becomes 0, then divide by the new maximum. This
    # makes relative magnitudes comparable across plots. A constant array is
    # left at all zeros to avoid dividing by zero.
    arr = arr - arr.min()
    peak = arr.max()
    if peak > 0:
        arr = arr / peak

    ax.imshow(arr, cmap=colormap, vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=title_fontsize)

    if not show_xticks:
        ax.set_xticks([])
    if not show_yticks:
        ax.set_yticks([])
    if x_lab is not None:
        ax.set_xlabel(x_lab, fontsize=label_fontsize)
    if y_lab is not None:
        ax.set_ylabel(y_lab, fontsize=label_fontsize)


def save_figure(
    fig: Figure,
    name: str,
    save_dir: str = PLOT_CONFIG.save_dir,
    save_format: str = PLOT_CONFIG.save_format,
    save_dpi: int = PLOT_CONFIG.save_dpi,
    add_timestamp: bool = PLOT_CONFIG.add_timestamp,
    run_name: Optional[str] = None,
) -> Path:
    """Write ``fig`` to disk and return the path of the written file.

    The file goes into ``save_dir``, or ``save_dir/run_name`` when
    ``run_name`` is non-empty. Missing directories are created. The file is
    named ``<name>[_<YYYYmmdd_HHMMSS>].<save_format>``.

    Args:
        fig: Matplotlib figure to save.
        name: Base filename without extension.
        save_dir: Root directory for saved plots.
        save_format: File format/extension, e.g. "png", "pdf" or "svg".
        save_dpi: Resolution used for rasterized formats.
        add_timestamp: Append the current local time to the filename when True.
        run_name: Optional sub-directory used to group plots of one experiment.

    Returns:
        Path of the saved figure.
    """
    out_dir = Path(save_dir) / run_name if run_name else Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    stem = f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" if add_timestamp else f"{name}"
    filepath = out_dir / f"{stem}.{save_format}"

    fig.savefig(filepath, dpi=save_dpi, bbox_inches='tight', format=save_format)
    print(f"✓ Figure saved: {filepath}")
    return filepath


def get_current_figure() -> Figure:
    """Return the figure pyplot is currently drawing into (``plt.gcf()``)."""
    current = plt.gcf()
    return current


# =============================================================================
# Training Functions
# =============================================================================
def train_loop(
    make_batch: Callable[[], Tuple[TensorFloat, TensorFloat]],
    input_dim: int,
    qk_dim: int,
    v_dim: int,
    pos_dim: Optional[int] = None,
    max_seq_len: Optional[int] = None,
    remove_cls: bool = False,
    num_epochs: int = TRAIN_CONFIG.num_epochs,
    lr: float = TRAIN_CONFIG.learning_rate,
    log_interval: int = TRAIN_CONFIG.log_interval,
    seed: Optional[int] = None,
) -> Tuple['PytorchTransformer', float]:
    """Train a freshly initialized PytorchTransformer on single-example batches.

    Every step draws one ``(seq, target)`` pair from ``make_batch``, runs the
    model, and takes one plain-SGD step on the MSE loss.

    NOTE: ``TRAIN_CONFIG.optimizer``, ``momentum``, ``weight_decay`` and
    ``loss_fn`` are not read here; the optimizer is always SGD and the loss
    is always MSE.

    Args:
        make_batch: Zero-argument callable returning ``(seq, target)``.
            ``seq`` has shape (seq_len, input_dim). ``target`` must match the
            model output shape, after row 0 is dropped when ``remove_cls``.
        input_dim: Width of the raw input, excluding positional features.
        qk_dim: Query/key width.
        v_dim: Value (output) width.
        pos_dim: Width of the learned positional embedding, or None for none.
        max_seq_len: Number of positional embeddings. When None, the
            PytorchTransformer default is used rather than forwarding None.
        remove_cls: If True, drop the first output row before the loss.
        num_epochs: Number of optimizer steps.
        lr: SGD learning rate.
        log_interval: Print the loss every ``log_interval`` steps. Values
            <= 0 disable logging.
        seed: If given, passed to ``set_random_seed`` before building the model.

    Returns:
        The trained transformer and the loss of the last step (0.0 when
        ``num_epochs`` is 0).
    """
    if seed is not None:
        set_random_seed(seed)

    # Dimension validation
    assert input_dim > 0, f"input_dim must be positive, got {input_dim}"
    assert qk_dim > 0, f"qk_dim must be positive, got {qk_dim}"
    assert v_dim > 0, f"v_dim must be positive, got {v_dim}"

    # Forwarding max_seq_len=None would reach nn.Embedding(None, pos_dim) and
    # crash whenever pos_dim is set; fall back to the constructor default instead.
    model_kwargs = {} if max_seq_len is None else {"max_seq_len": max_seq_len}
    transformer = PytorchTransformer(input_dim, qk_dim, v_dim, pos_dim, **model_kwargs)
    optimizer = torch.optim.SGD(transformer.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    final_loss = 0.0
    for epoch in range(num_epochs):
        seq, target = make_batch()

        assert seq.dim() == 2, f"Input must be 2D (seq_len, input_dim), got {seq.shape}"
        assert seq.shape[1] == input_dim, f"Input dim mismatch: expected {input_dim}, got {seq.shape[1]}"

        optimizer.zero_grad()
        out = transformer(seq)
        if remove_cls:
            # Drop the leading (CLS) position; the target has no row for it.
            out = out[1:]

        assert out.shape == target.shape, f"Output {out.shape} != target {target.shape}"

        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        final_loss = loss.item()

        if log_interval > 0 and epoch % log_interval == 0:
            print(f'Step {epoch}: loss {final_loss:.6f}')

    return transformer, final_loss

# =============================================================================
# Comparison and Testing Functions
# =============================================================================
def compare_transformers(
    hand_transformer: 'NumpyTransformer',
    learned_transformer: 'PytorchTransformer',
    seq: NDArrayFloat,
    plot: bool = True,
    save_plot: bool = False,
    run_name: Optional[str] = None,
) -> Tuple[NDArrayFloat, NDArrayFloat]:
    """Run a hand-designed and a learned transformer side by side on ``seq``.

    The learned model's weights are copied into a NumpyTransformer, so both
    models go through the same NumPy forward pass, which can plot.

    Args:
        hand_transformer: Hand-designed NumpyTransformer.
        learned_transformer: Trained PytorchTransformer.
        seq: Input sequence of shape (seq_len, input_dim).
        plot: If True, display the attention plots.
        save_plot: If True, save the plots to the plots directory.
        run_name: Optional run identifier for organizing saved plots.

    Returns:
        ``(out_hand, out_learned)``, the two models' outputs.
    """
    assert seq.ndim == 2, f"seq must be 2D (seq_len, input_dim), got shape {seq.shape}"

    bar = '=' * 40
    forward_kwargs = dict(verbose=False, plot=plot, save_plot=save_plot, run_name=run_name)

    print(f'{bar} Hand Designed {bar}')
    out_hand = hand_transformer.forward(seq, plot_name="hand_designed", **forward_kwargs)
    assert out_hand.shape[0] == seq.shape[0], \
        f"Output seq_len {out_hand.shape[0]} != input seq_len {seq.shape[0]}"

    # PyTorch Linear stores (out, in) weights; transpose back to NumPy's (in, out).
    learned_layers = (learned_transformer.Km, learned_transformer.Qm, learned_transformer.Vm)
    km_np, qm_np, vm_np = (layer.weight.T.detach().numpy() for layer in learned_layers)
    pos_np = (
        None if learned_transformer.pos is None
        else learned_transformer.pos.weight.detach().numpy()
    )

    print(f'{bar}    Learned    {bar}')
    out_learned = NumpyTransformer(km_np, qm_np, vm_np, pos_np).forward(
        seq, plot_name="learned", **forward_kwargs
    )
    assert out_learned.shape == out_hand.shape, \
        f"Shape mismatch: hand={out_hand.shape}, learned={out_learned.shape}"

    return out_hand, out_learned


def test(seed: int = 42) -> None:
    """Verify NumpyTransformer and PytorchTransformer produce identical outputs.

    Runs TEST_ITERATIONS random trials. The first half use a positional
    table and the second half do not. Each trial loads the NumPy weights
    into a PytorchTransformer and compares the two outputs with
    ``np.allclose(..., rtol=RELATIVE_TOLERANCE)``. Afterwards it records a
    fixed-seed forward pass, plus the source of PytorchTransformer's
    ``__init__`` and ``forward``, in the module-level ``TO_SAVE`` dict for
    grading.

    Args:
        seed: Seed passed to ``set_random_seed`` before the trials.

    Raises:
        ValueError: If any trial's outputs differ beyond the tolerance.
    """
    set_random_seed(seed)
    # qk_dim and v_dim are drawn once and shared by every trial.
    qk_dim = np.random.randint(1, 5)
    v_dim = np.random.randint(1, 5)
    
    for i in range(TEST_ITERATIONS):
        Km = np.random.randn(TEST_INPUT_DIM, qk_dim)
        Qm = np.random.randn(TEST_INPUT_DIM, qk_dim)
        Vm = np.random.randn(TEST_INPUT_DIM, v_dim)
        
        # The first half of the trials use positions. Positional columns are
        # concatenated inside the model, so the raw sequence is narrower by
        # pos_dim and the total width stays TEST_INPUT_DIM.
        if i < TEST_ITERATIONS // 2:
            pos_dim = np.random.randint(2, 4)
            pos = np.random.randn(TEST_MAX_SEQ_LEN, pos_dim)
            seq_dim = TEST_INPUT_DIM - pos_dim
        else:
            pos_dim = None
            pos = None
            seq_dim = TEST_INPUT_DIM

        # Random length in [TEST_MIN_SEQ_LEN, TEST_MAX_SEQ_LEN].
        seq = np.random.randn(np.random.randint(TEST_MIN_SEQ_LEN, TEST_MAX_SEQ_LEN + 1), seq_dim)
        out_np = NumpyTransformer(Km, Qm, Vm, pos).forward(seq, verbose=False)
        
        transformer = PytorchTransformer(seq_dim, qk_dim, v_dim, pos_dim, TEST_MAX_SEQ_LEN)
        state_dict = transformer.state_dict()
        # Projection weights are stored transposed (PyTorch Linear convention),
        # so the NumPy (in, out) matrices are loaded as .T.
        state_dict['Km.weight'] = torch.FloatTensor(Km.T)
        state_dict['Qm.weight'] = torch.FloatTensor(Qm.T)
        state_dict['Vm.weight'] = torch.FloatTensor(Vm.T)
        if pos is not None:
            state_dict['pos.weight'] = torch.FloatTensor(pos)
        transformer.load_state_dict(state_dict)
        out_py = transformer(torch.FloatTensor(seq)).detach().numpy()
        
        if not np.allclose(out_np, out_py, rtol=RELATIVE_TOLERANCE):
            print('ERROR: Implementation mismatch!')
            print(f'NumPy output: {out_np}')
            print(f'PyTorch output: {out_py}')
            raise ValueError('NumPy and PyTorch outputs do not match')
    
    print('✓ All equivalence tests passed!')
    
    # Save test results for grading: a fixed-seed forward pass (shape and a
    # slice of values) plus the implementation source.
    set_random_seed(1998)
    test_transformer = PytorchTransformer(7, 4, 3, 2, 9)
    o = test_transformer(torch.randn(8, 7))
    TO_SAVE["torch_transformer_shape"] = list(o.shape)
    TO_SAVE["torch_transformer_value"] = o.view(-1).tolist()[2:7]
    TO_SAVE["torch_transformer_init"] = inspect.getsource(PytorchTransformer.__init__)
    TO_SAVE["torch_transformer_forward"] = inspect.getsource(PytorchTransformer.forward)

# =============================================================================
# Submission Data Storage
# =============================================================================
# Results collected for grading; starts with a creation timestamp and is filled in by later cells.
TO_SAVE: dict[str, Any] = dict(time=time.time())

## Implement a Simple Transformer

Below, you'll find a simple Transformer implementation in NumPy that we have provided for you. Note that this implementation differs from the Transformers used in real applications in the following ways:

- Only a single layer with a single head is in the network.
- There are no residual connections.
- There is no normalization or dropout.
- We concatenate the positional encoding rather than adding it to the inputs.
- There are no activation functions or MLP layers.
- It does not support attention masking.
- The input is a single sequence instead of a batch. So there is no need to implement padding.

To ensure that you fully understand the Transformer model, your task is to **implement an equivalent model in PyTorch**. You don't need to include the printing and plotting code found in the NumPy version. **You should implement a vectorized version of the attention operation**: compute all attention scores at once with matrix operations, rather than looping over queries as the NumPy version does. Once you have completed your implementation, make sure it passes the tests in the cell below.

Implement the `PytorchTransformer` class. Its forward pass should be identical to that of the `NumpyTransformer` class.

**Hint:** The attention operation should be implemented as:

$$\mathrm{softmax}(\dfrac{QK^T}{ \sqrt{d_k}}) \cdot V$$

where the softmax is applied to the last dimension, meaning that the softmax is applied independently to each query's scores.

In [None]:
"""
Transformer Implementations
===========================
NumPy (reference) and PyTorch (student implementation) versions of a
simplified single-head, single-layer Transformer.
"""


class NumpyTransformer:
    """Reference single-head, single-layer attention in NumPy.

    For an input sequence ``X`` it returns
    ``softmax(X Qm (X Km)^T / sqrt(qk_dim)) X Vm``. When a positional table
    is given, ``X`` is first concatenated column-wise with the table's first
    ``seq_len`` rows. Attention is computed one query at a time so that each
    step can be printed or plotted.
    """

    def __init__(
        self,
        Km: NDArrayFloat,
        Qm: NDArrayFloat,
        Vm: NDArrayFloat,
        pos: Optional[NDArrayFloat] = None,
    ) -> None:
        """Store the projection matrices and the optional positional table.

        Args:
            Km: Key projection, shape (input_dim, qk_dim).
            Qm: Query projection, shape (input_dim, qk_dim).
            Vm: Value projection, shape (input_dim, v_dim).
            pos: Optional positional table of shape (max_seq_len, pos_dim).
                When given, ``input_dim`` above includes ``pos_dim``.

        Raises:
            AssertionError: If a matrix is not 2-D or the shapes disagree.
        """
        # Dimension validation
        assert Km.ndim == 2, f"Km must be 2D, got shape {Km.shape}"
        assert Qm.ndim == 2, f"Qm must be 2D, got shape {Qm.shape}"
        assert Vm.ndim == 2, f"Vm must be 2D, got shape {Vm.shape}"
        assert Km.shape[0] == Qm.shape[0] == Vm.shape[0], \
            f"Matrices must have same input_dim: Km={Km.shape[0]}, Qm={Qm.shape[0]}, Vm={Vm.shape[0]}"
        assert Km.shape[1] == Qm.shape[1], \
            f"Km and Qm must have same qk_dim: {Km.shape[1]} vs {Qm.shape[1]}"

        if pos is not None:
            assert pos.ndim == 2, f"pos must be 2D, got shape {pos.shape}"

        self.Km = Km
        self.Qm = Qm
        self.Vm = Vm
        self.pos = pos
        self.qk_dim = Qm.shape[1]

    def forward(
        self,
        seq: NDArrayFloat,
        verbose: bool = False,
        plot: bool = False,
        save_plot: bool = False,
        plot_name: str = "transformer_attention",
        run_name: Optional[str] = None,
    ) -> NDArrayFloat:
        """Compute forward pass: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) @ V

        Args:
            seq: Input sequence of shape (seq_len, input_dim), without
                positional columns.
            verbose: If True, print intermediate computation steps.
            plot: If True, visualize attention matrices.
            save_plot: If True, save the plot to the plots directory.
            plot_name: Base name for the saved plot file.
            run_name: Optional run identifier for organizing saved plots.

        Returns:
            Array of shape (seq_len, v_dim). An empty sequence gives
            (0, v_dim).

        Raises:
            AssertionError: If ``seq`` is not 2-D or is longer than ``pos``.
        """
        assert seq.ndim == 2, f"seq must be 2D (seq_len, input_dim), got shape {seq.shape}"

        seq_len = seq.shape[0]

        # Concatenate positional encodings if provided
        if self.pos is not None:
            assert seq_len <= self.pos.shape[0], \
                f"seq_len {seq_len} > max pos length {self.pos.shape[0]}"
            seq = np.concatenate([seq, self.pos[:seq_len]], axis=-1)

        # Project to Q, K, V spaces
        K = seq @ self.Km
        Q = seq @ self.Qm
        V = seq @ self.Vm

        assert K.shape == (seq_len, self.qk_dim), f"K shape: expected {(seq_len, self.qk_dim)}, got {K.shape}"
        assert Q.shape == (seq_len, self.qk_dim), f"Q shape: expected {(seq_len, self.qk_dim)}, got {Q.shape}"

        if verbose:
            print(f'K (Keys):\n{K.tolist()}')
            print(f'Q (Queries):\n{Q.tolist()}')
            print(f'V (Values):\n{V.tolist()}')

        if plot or save_plot:
            fig, axs = plt.subplots(nrows=1, ncols=8)
            self._plot_all(axs, K, Q, V)

        # Compute attention (non-vectorized for educational clarity)
        outputs = []
        attn_weights = []
        scale = np.sqrt(self.qk_dim)

        for i, q in enumerate(Q):
            if verbose:
                print(f'Item {i}: Computing attention for query {q}')

            dot = K @ q / scale
            if verbose:
                print(f'  Dot products (q · K): {dot}')

            # Subtracting the max leaves the softmax mathematically unchanged.
            # It stops np.exp from overflowing to inf (and inf/inf to nan) when
            # hand-designed weights produce large scores.
            exp_dot = np.exp(dot - np.max(dot))
            softmax_dot = exp_dot / np.sum(exp_dot)

            if verbose:
                print(f'  Attention weights: {softmax_dot}')

            attn_weights.append(softmax_dot)
            out_i = softmax_dot @ V

            if verbose:
                print(f'  Output: {out_i}')

            outputs.append(out_i)

        if plot or save_plot:
            rescale_and_plot(np.array(attn_weights).T, 'Attn', axs[6], x_lab='Q', y_lab='K')
            rescale_and_plot(np.array(outputs).T, 'Out', axs[7], x_lab='seq', y_lab='d_v')
            # Lay out only after every panel is drawn, so titles and labels
            # are taken into account.
            if PLOT_CONFIG.tight_layout:
                fig.tight_layout()
            if save_plot:
                save_figure(fig, plot_name, run_name=run_name)
            if plot:
                plt.show()
            else:
                plt.close(fig)

        # The reshape keeps the (seq_len, v_dim) contract for an empty
        # sequence, where np.array([]) would have shape (0,).
        output = np.array(outputs).reshape(seq_len, V.shape[1])
        assert output.shape == (seq_len, V.shape[1]), \
            f"Output shape: expected {(seq_len, V.shape[1])}, got {output.shape}"

        return output

    def _plot_all(self, axs, K: NDArrayFloat, Q: NDArrayFloat, V: NDArrayFloat) -> None:
        """Plot the weight matrices (panels 0-2) and K, Q, V projections (panels 3-5)."""
        rescale_and_plot(self.Km.T, 'Km', axs[0], x_lab='d_in', y_lab='d_qk')
        rescale_and_plot(self.Qm.T, 'Qm', axs[1], x_lab='d_in', y_lab='d_qk')
        rescale_and_plot(self.Vm.T, 'Vm', axs[2], x_lab='d_in', y_lab='d_v')
        rescale_and_plot(K.T, 'K', axs[3], x_lab='seq', y_lab='d_qk')
        rescale_and_plot(Q.T, 'Q', axs[4], x_lab='seq', y_lab='d_qk')
        rescale_and_plot(V.T, 'V', axs[5], x_lab='seq', y_lab='d_v')

class PytorchTransformer(nn.Module):
    """PyTorch implementation of simplified single-head Transformer.

    Mirrors ``NumpyTransformer.forward``. Optional learned positional
    embeddings are concatenated to the input, then
    ``softmax(Q K^T / sqrt(d_k)) V`` is computed for all queries at once.
    """

    def __init__(
        self,
        input_dim: int,
        qk_dim: int,
        v_dim: int,
        pos_dim: Optional[int] = None,
        max_seq_len: int = 10,
    ) -> None:
        """Create the bias-free projections and the optional positional table.

        Args:
            input_dim: Width of the raw input, excluding positional features.
            qk_dim: Query/key width (``d_k``).
            v_dim: Value (output) width.
            pos_dim: Width of the learned positional embedding, or None.
            max_seq_len: Number of rows in the positional table.

        Raises:
            AssertionError: If any dimension is not positive.
        """
        super().__init__()

        # Dimension validation
        assert input_dim > 0, f"input_dim must be positive, got {input_dim}"
        assert qk_dim > 0, f"qk_dim must be positive, got {qk_dim}"
        assert v_dim > 0, f"v_dim must be positive, got {v_dim}"

        self._input_dim = input_dim
        self._max_seq_len = max_seq_len

        if pos_dim is not None:
            self.pos: Optional[nn.Embedding] = nn.Embedding(max_seq_len, pos_dim)
        else:
            self.pos = None

        total_input_dim = input_dim + (pos_dim if pos_dim is not None else 0)

        # Bias-free linear projections. nn.Linear stores weight as
        # (out_features, in_features), so weight.T equals the NumPy
        # (input_dim, out_dim) matrices used by NumpyTransformer.
        self.Qm = nn.Linear(total_input_dim, qk_dim, bias=False)
        self.Km = nn.Linear(total_input_dim, qk_dim, bias=False)
        self.Vm = nn.Linear(total_input_dim, v_dim, bias=False)

        self.d_k = qk_dim

    def forward(self, seq) -> torch.Tensor:
        """
        Transformer forward pass

        Inputs: seq is a torch tensor of shape (seq_len, input_dim).
        Outputs: a torch tensor of shape (seq_len, v_dim), the output of the attention operation
        """
        assert seq.dim() == 2, f"seq must be 2D (seq_len, input_dim), got shape {tuple(seq.shape)}"
        seq_len = seq.shape[0]

        # Concatenate (not add) the first seq_len positional embeddings,
        # exactly like NumpyTransformer.
        if self.pos is not None:
            assert seq_len <= self.pos.num_embeddings, \
                f"seq_len {seq_len} > max pos length {self.pos.num_embeddings}"
            positions = torch.arange(seq_len, device=seq.device)
            seq = torch.cat([seq, self.pos(positions)], dim=-1)

        Q = self.Qm(seq)
        K = self.Km(seq)
        V = self.Vm(seq)

        # All query/key scores at once: (seq_len, seq_len). Softmax over the
        # last dim normalizes each query's row independently.
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        out = attn @ V

        assert out.shape[0] == seq_len, f"Output seq_len {out.shape[0]} != input seq_len {seq_len}"
        return out


# Run equivalence tests to verify implementation.
# test() raises ValueError on a NumPy/PyTorch mismatch and records grading data in TO_SAVE.
test()

## Self-Attention: Attention by Content

In this part, we will explore how Transformers can attend to different tokens in a variable-length sequence based on their contents. We will do this by **implementing a Transformer that performs the *identity* operation on a sequence of one-hot vectors**. We will then compare the performance and weights of this hand-coded Transformer with those of a PyTorch model trained on the same task.

To hand-design the Transformer, we will **choose values for `Km`, `Qm`, and `Vm` that enable the model to attend to the content of each token in the input sequence**. We will then use this Transformer to process several example data points, and verify that the output matches the input.

Once your hand-written Transformer is working correctly, we will run the PyTorch training loop to train a model on the identity operation task. We will then compare the weights and intermediate outputs of this model with those of our hand-coded transformer, and comment on their similarities and differences. Note that when we generate plots, we will rescale the range of the weights and outputs to 0-1, so we can compare their relative values without comparing absolute values.

The test cases for our hand-coded transformer are as follows:

```
Input sequence -->   Output sequence
[A, B, C, C]   -->   [A, B, C, C]
[C, A, C]      -->   [C, A, C]
[B, B, C]      -->   [B, B, C]
```

We have provided some hints below, but to enhance your understanding of attention and the Transformer, we highly recommend attempting this problem to the best of your abilities before referring to the hints.

In [None]:
#@title Hints

# Hint 1: To attend to a specific element, ensure that its pre-softmax score is
#         significantly higher than that of the other elements.
def softmax(x):
    """Softmax over the last axis (no max-shift; fine for these small demo scores)."""
    exps = np.exp(x)
    return exps / np.sum(exps, axis=-1, keepdims=True)


banner = '=' * 20
print(banner, 'Hint 1', banner)
print('Selecting index 0', softmax(np.array([9, 0, 0])))
print('Selecting index 1', softmax(np.array([-3, 5, -5])))


# Hint 2: Attending to a particular element is more manageable if the keys are
#         orthogonal.
print(banner, 'Hint 2', banner)
keys = np.array([[2, 0],
                 [0, 1]])  # orthogonal rows
q = np.array([5, 0])
print('Selecting index 0', softmax(q @ keys))
q = np.array([0, 5])
print('Selecting index 1', softmax(q @ keys))


# Hint 3: Helper functions for checking the keys, queries, and values your
#         matrices produce for every valid sequence element.
#         Km, Qm, and Vm are the matrices you will define below; they are
#         looked up when a helper is called.
all_token_seq = np.eye(3)  # one row per token; the rows correspond to [A, B, C]


def get_K():
    """Each row of the output is a key."""
    return all_token_seq @ Km


def get_Q():
    """Each row of the output is a query."""
    return all_token_seq @ Qm


def get_V():
    """Each row of the output is a value."""
    return all_token_seq @ Vm


# Hint 4: To test different attention weights, use the softmax function defined
#         above.


# Hint 5: When there are repeated elements in a sequence with the same content,
#         attending to all of them rather than a single one will be simpler.
#         Since they have the same content, taking a "weighted average" over
#         values weighted by attention scores will produce the same output as
#         attending to a single one.

In [None]:
# Token definitions (one-hot encodings)
A = np.array([1, 0, 0])
B = np.array([0, 1, 0])
C = np.array([0, 0, 1])
tokens = [A, B, C]

################################################################################
# Hand-designed weights for the identity task (input_dim = v_dim = qk_dim = 3).
# - Km = Qm = IDENTITY_SCALE * I. The score of query token i against key token
#   j is IDENTITY_SCALE**2 / sqrt(3) (~57.7) when the tokens are equal and 0
#   otherwise. Each query therefore puts essentially all of its attention on
#   positions holding its own token; other tokens get relative weight
#   ~exp(-57.7), about 1e-25.
# - Vm = I. Values are the tokens themselves, so averaging over identical
#   tokens returns the query's own token.
################################################################################
IDENTITY_SCALE = 10.0
Km = IDENTITY_SCALE * np.eye(3)
Qm = IDENTITY_SCALE * np.eye(3)
Vm = np.eye(3)
############################################ END OF YOUR CODE ##################

# Dimension validation
assert Km.shape[0] == 3 and Qm.shape[0] == 3 and Vm.shape == (3, 3), \
    f"Shape error: Km={Km.shape}, Qm={Qm.shape}, Vm={Vm.shape}"
print(f"✓ Matrices: Km {Km.shape}, Qm {Qm.shape}, Vm {Vm.shape}")


def generate_test_cases_identity(
    tokens: List[NDArrayFloat],
    max_len: int = 7,
) -> Tuple[NDArrayFloat, NDArrayFloat]:
    """Sample a random token sequence for the identity task.

    The length comes from ``np.random.randint(1, max_len)``, so it ranges from
    1 to ``max_len - 1`` inclusive. Tokens are drawn with replacement via
    ``random.choices``. The same array object is returned as input and target.
    """
    length = np.random.randint(1, max_len)
    sequence = np.stack(random.choices(tokens, k=length))
    target = sequence  # identity task: the target is the input itself
    return sequence, target


# Test implementation: each of 10 random sequences must come back unchanged.
print("Running identity task tests...")
for i in range(10):
    seq, expected_out = generate_test_cases_identity(tokens)
    np_transformer = NumpyTransformer(Km, Qm, Vm)
    out = np_transformer.forward(seq)
    assert np.allclose(out, expected_out, rtol=RELATIVE_TOLERANCE), \
        f"Test {i} failed: out shape {out.shape}, max error {np.abs(out - expected_out).max():.6f}"
print("✓ All 10 identity tests passed!")

# Save results for grading.
# The seed is reset right before sampling so the graded sequence is
# reproducible; keep the RNG calls in this order.
set_random_seed(1997)
seq, _ = generate_test_cases_identity(tokens)
# `np_transformer` is also what the comparison cell below passes to compare_transformers.
np_transformer = NumpyTransformer(Km, Qm, Vm)
out = np_transformer.forward(seq, verbose=False)
TO_SAVE["attention_by_content"] = out.reshape(-1).tolist()
TO_SAVE["attention_by_content_Q"] = Qm.reshape(-1).tolist()
TO_SAVE["attention_by_content_K"] = Km.reshape(-1).tolist()
TO_SAVE["attention_by_content_V"] = Vm.reshape(-1).tolist()

In [None]:
# Compare hand-designed and trained transformers
def make_batch_identity(tokens: List[NDArrayFloat] = tokens, max_len: int = 7):
    """Draw one identity-task example and return ``(seq, target)`` as float32 tensors.

    The default ``tokens`` is the module-level list as bound when this
    function was defined.
    """
    seq_np, target_np = generate_test_cases_identity(tokens, max_len=max_len)
    to_tensor = torch.FloatTensor
    return to_tensor(seq_np), to_tensor(target_np)

# Seed before training so the learned weights are reproducible.
set_random_seed(227)

# Re-define the one-hot tokens (same values as in the cell above).
A = np.array([1, 0, 0])
B = np.array([0, 1, 0])
C = np.array([0, 0, 1])
# The learned model uses the same qk/v widths as the hand-designed Km and Vm.
transformer_py, loss = train_loop(make_batch_identity, input_dim=len(A), qk_dim=Km.shape[1], v_dim=Vm.shape[1])
seq = np.stack([A, B, B, C, C])
print("seq:", seq)
# Plots weights, projections, attention and outputs of both models.
compare_transformers(np_transformer, transformer_py, seq)

### Question

In the figure provided, compare the variables of your hand-designed Transformer with those of the learned Transformer. **Identify the similarities and differences between the two sets of variables and provide a brief explanation for each difference**. Please include your answers in your written submission for this assignment.

## Self-Attention: Attention by Position

In Transformers, tokens can decide what other tokens to attend to by looking at their positions. In this section, we'll explore how this works by **hand-designing a Transformer for the task of copying the first token of a sequence across the entire sequence.**

To accomplish this, we'll add a positional encoding to the input sequence. Transformers typically use a sinusoidal or a learned positional encoding, but here we'll **set the positional-encoding weights by hand to any values we choose**. For simplicity, the positional encodings are *concatenated* to the input embeddings inside the Transformer rather than added to them.

Here are the example data points (where `A`, `B`, and `C` are vectors and `A:pos_0` represents the concatenation between vectors `A` and `pos_0`):

```
Input sequence --> Input sequence with positional encoding --> Output sequence
[A, B, C, C]   --> [A:pos_0, B:pos_1, C:pos_2, C:pos_3]    --> [A, A, A, A]
[C, A, C]      --> [C:pos_0, A:pos_1, C:pos_2]             --> [C, C, C]
[B, B, C]      --> [B:pos_0, B:pos_1, C:pos_2]             --> [B, B, B]
```

Once you've passed the test cases, run the training loop below to train the PyTorch model.

We have provided some hints below, but to enhance your understanding of attention and the Transformer, we highly recommend attempting this problem to the best of your abilities before referring to the hints.

In [None]:
#@title Hints

# Hint 1: All hints from the previous part still apply.


# Hint 2: If you only want to use part of the information in a sequence element,
#         choose key/query/value matrices which remove the unwanted information.
# Here seq holds one 3-d element; Qm keeps features 0 and 2 and drops feature 1.
seq = np.array([[1, 2, 3]])
Qm = np.array([[1, 0],
               [0, 0],
               [0, 1]])
print('Selecting only the first and last vector elements', seq @ Qm)


# Hint 3: Helper functions for checking which keys, queries, and values your
#         matrices produce. Pass a sequence such as np.stack([A, B, C]).
#         Km, Qm, Vm, and pos are the globals you define below; they are
#         looked up when a helper is called.
def _with_positions(seq):
    """Append the first len(seq) rows of the global ``pos`` to each element."""
    return np.concatenate([seq, pos[:seq.shape[0]]], axis=1)


def get_K(seq):
    """Each row of the output is a key."""
    return _with_positions(seq) @ Km


def get_Q(seq):
    """Each row of the output is a query."""
    return _with_positions(seq) @ Qm


def get_V(seq):
    """Each row of the output is a value."""
    return _with_positions(seq) @ Vm


In [None]:
A = np.array([1, 0, 0])
B = np.array([0, 1, 0])
C = np.array([0, 0, 1])
tokens = [A, B, C]

################################################################################
# Hand-designed weights for copying the first token.
# Dimensions: input_dim = 3, pos_dim = 4, qk_dim = 1, v_dim = 3, max_len = 4.
# Every matrix has rows ordered as [token features (3) | position features (4)].
# - pos = I_4: position t is encoded as the one-hot vector e_t.
# - Qm reads only the position part and maps every one-hot position to 1, so
#   every query is the scalar 1.
# - Km maps position 0 to POSITION_SCALE and all other positions (and all
#   tokens) to 0. Every query therefore scores the first key POSITION_SCALE
#   higher than the rest, and the other weights are ~exp(-100).
# - Vm copies the token part and discards the position part.
################################################################################
POSITION_SCALE = 100.0
pos = np.eye(4)
Qm = np.zeros((3 + 4, 1))
Qm[3:, 0] = 1.0
Km = np.zeros((3 + 4, 1))
Km[3, 0] = POSITION_SCALE
Vm = np.zeros((3 + 4, 3))
Vm[:3, :] = np.eye(3)
############################################ END OF YOUR CODE ##################

# Dimension validation
pos_dim = pos.shape[1]
total_dim = 3 + pos_dim
assert Km.shape[0] == total_dim and Qm.shape[0] == total_dim and Vm.shape == (total_dim, 3), \
    f"Shape error: Km={Km.shape}, Qm={Qm.shape}, Vm={Vm.shape}, expected input_dim={total_dim}"
print(f"✓ Matrices: pos {pos.shape}, Km {Km.shape}, Qm {Qm.shape}, Vm {Vm.shape}")


def generate_test_cases_first(
    tokens: List[NDArrayFloat],
    max_len: int = 5,
) -> Tuple[NDArrayFloat, NDArrayFloat]:
    """Generate test cases for copy-first-token task.

    Args:
        tokens: Candidate token vectors to sample from (with replacement).
        max_len: Exclusive upper bound on the sequence length; lengths are
            drawn from ``[1, max_len)``, so the default yields 1-4 tokens.

    Returns:
        ``(sequence, target)`` where ``target`` repeats ``sequence[0]`` once
        per position, giving both arrays the shape ``(length, token_dim)``.
    """
    length = np.random.randint(1, max_len)
    sequence = np.stack(random.choices(tokens, k=length))
    # Every output position should reproduce the very first input token.
    target = np.repeat(sequence[:1], length, axis=0)
    return sequence, target


# Test implementation: ten random sequences, each checked against a freshly
# built NumpyTransformer using the matrices defined above.
print("Running copy-first-token task tests...")
for trial in range(10):
    seq, expected_out = generate_test_cases_first(tokens)
    np_transformer = NumpyTransformer(Km, Qm, Vm, pos=pos)
    out = np_transformer.forward(seq)
    assert np.allclose(out, expected_out, rtol=RELATIVE_TOLERANCE), (
        f"Test {trial} failed: max error {np.abs(out - expected_out).max():.6f}"
    )
print("✓ All 10 copy-first-token tests passed!")

# Save results for grading: one seeded forward pass plus the flattened
# matrices that produced it, stored under fixed keys in TO_SAVE.
set_random_seed(2017)
seq, _ = generate_test_cases_first(tokens)
np_transformer = NumpyTransformer(Km, Qm, Vm, pos=pos)
out = np_transformer.forward(seq)
for save_key, save_arr in (
    ("attention_by_position", out),
    ("attention_by_position_pos", pos),
    ("attention_by_position_Q", Qm),
    ("attention_by_position_K", Km),
    ("attention_by_position_V", Vm),
):
    TO_SAVE[save_key] = save_arr.reshape(-1).tolist()

In [None]:
# Compare hand-designed and trained transformers
def make_batch_first(tokens: List[NDArrayFloat] = tokens, max_len: int = 5):
    """Create a training batch for the copy-first-token task.

    Args:
        tokens: Candidate token vectors. The default is the module-level
            ``tokens`` list as it was when this ``def`` executed.
        max_len: Exclusive upper bound on the sequence length, forwarded to
            ``generate_test_cases_first``.

    Returns:
        ``(seq, target)`` as float32 tensors, each of shape
        ``(seq_len, token_dim)``.
    """
    seq, target = generate_test_cases_first(tokens, max_len=max_len)
    # torch.tensor with an explicit dtype replaces the legacy, discouraged
    # torch.FloatTensor constructor; both yield float32 copies here.
    return torch.tensor(seq, dtype=torch.float32), torch.tensor(target, dtype=torch.float32)

# Train a PyTorch transformer whose dimensions mirror the hand-designed one,
# then compare both on a fixed sequence.
pos_dim = pos.shape[1]
transformer_py, loss = train_loop(
    make_batch_first,
    input_dim=len(A),
    qk_dim=Km.shape[1],
    v_dim=Vm.shape[1],
    pos_dim=pos_dim,
    max_seq_len=pos.shape[0],
)
seq = np.stack([A, B, B])
out_np, out_py = compare_transformers(np_transformer, transformer_py, seq)
print("seq:", seq)
print(f'Out (Hand designed) \n {np.round(out_np, 2)}')
print(f'Out (Learned) \n {np.round(out_py, 2)}')

### Question

In the figure provided, compare the variables of your hand-designed Transformer with those of the learned Transformer. **Identify the similarities and differences between the two sets of variables and provide a brief explanation for each.** Please include your findings in your written submission for this assignment.

## Generate the Submission Log

Please download `submission_log.json` and submit it to Gradescope.

In [None]:
# Serialize before opening the file. json.dump streams into an already
# truncated file, so a non-serializable entry would leave a partial log.
submission_payload = json.dumps(TO_SAVE)
with open("submission_log.json", "w", encoding="utf-8") as f:
    f.write(submission_payload)

## (Optional) Self-Attention: Attention by Content and Position

Finally, we'll explore how transformers can attend to tokens by looking at both their position and their content. In this section, we'll design a transformer for the following task: given a sequence of tokens, output a positive number for every unique token and a negative number for every repeated token.

To make implementing this easier, we'll add a CLS token to the beginning of the sequence. We will ignore the output of the CLS token index, which means we can use the CLS token to represent whatever we want. (In practice, the CLS token is often thought of as a representation of the entire sequence, but you can use it however is useful.)


\# Example data points (in each case, A, B, and C are vectors. A:pos_0 represents concatenation between vectors A and pos_0. The target outputs shown are +/-1, but any number with the right sign is fine. "Ignore" means that the output can be anything and will not be used to compute the loss.): \
Input sequence --> Input sequence with CLS and pos encoding --> Output sequence \
[A, B, C, C] --> [CLS: pos_0, A:pos_1, B:pos_2, C:pos_3, C:pos_4] --> [Ignore, 1, 1, -1, -1] \
[C, A, C] --> [CLS: pos_0, C:pos_1, A:pos_2, C:pos_3] --> [Ignore, -1, 1, -1] \
[B, B, C] --> [CLS: pos_0, B:pos_1, B:pos_2, C:pos_3] --> [Ignore, -1, -1, 1]


Once the test cases pass, run the training loop below a few times to train the PyTorch model. Comment on the similarities and differences between the weights and intermediate outputs of the learned and hand-coded model.

In [None]:
# One-hot embeddings for the three content tokens plus a CLS token, which is
# the only vector with a 1 in the last slot.
A, B, C, CLS = (
    np.array(one_hot)
    for one_hot in ([1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1])
)
tokens = [A, B, C]

In [None]:
# Hints (feel free to ignore this block if it's not useful)

# Hint 1: All hints from the previous part still apply.

# Hint 2: To check whether a token is unique, use what you discovered in the "select by content" part to find rows with the same value and
# what you learned in the "select by position" part to NOT select the key which comes from the same position as the query.

# Hint 3: If you need an offset value, consider using the CLS token. The CLS token is the first token in a sequence, and it is orthogonal
# to all other tokens. This means you can create a query or value which selects it but not any other token (e.g. by putting 0s in all
# indexes except the index where only CLS has a 1).

# Hint 4: You can use the following helper functions to test what keys, queries, and values would be produced by your matrix.
# You will need to provide a sequence (e.g. np.stack([A, B, C])). Km, Qm, Vm, and pos are the matrices you will define below.
# The helpers read the module-level CLS/Km/Qm/Vm/pos at call time.
def _with_cls_and_pos(seq):
    """Prepend CLS to ``seq``, then append one row of ``pos`` to every token.

    The output has ``len(seq) + 1`` rows because CLS occupies position 0.
    """
    with_cls = np.stack([CLS] + list(seq))
    return np.concatenate([with_cls, pos[:seq.shape[0] + 1]], axis=1)


def get_K(seq):
    """Return the keys for CLS + ``seq`` (each row of the output is a key)."""
    return _with_cls_and_pos(seq) @ Km


def get_Q(seq):
    """Return the queries for CLS + ``seq`` (each row of the output is a query)."""
    return _with_cls_and_pos(seq) @ Qm


def get_V(seq):
    """Return the values for CLS + ``seq`` (each row of the output is a value)."""
    return _with_cls_and_pos(seq) @ Vm


In [None]:
################################################################################################
# TODO: Implement numpy arrays for Km, Qm, and Vm and pos.
#      The dimensions of Km, and Qm are (input_dim + pos_dim, qk_dim).
#      The dimensions of Vm are (input_dim + pos_dim, v_dim).
#      The dimensions of pos are (max_len + 1, pos_dim). (Each row is a position vector.)
#      In this case, input_dim = 4, and v_dim = 1. qk_dim can be any value you choose, but 8 is
#      a reasonable choice. max_len is the maximum sequence length you will encounter (before CLS is added),
#      4 in this case.  pos_dim can be any value you choose, but 4 is a reasonable choice.
#################################################################################################
############################################ END OF YOUR CODE ####################################

def generate_test_cases_unique(tokens: List[NDArrayFloat], max_len: int = 5):
    """Generate test cases for unique token detection.

    Args:
        tokens: Candidate token vectors to sample from (with replacement).
        max_len: Exclusive upper bound on the sequence length (before CLS);
            lengths are drawn from ``[1, max_len)``.

    Returns:
        ``(sequence, labels)``. ``sequence`` has the module-level ``CLS``
        prepended, giving ``length + 1`` rows. ``labels`` has shape
        ``(length, 1)`` with +1 for tokens that occur exactly once and -1
        for repeated tokens.
    """
    length = np.random.randint(1, max_len)
    body = np.stack(random.choices(tokens, k=length))
    # For each row, count how many rows of ``body`` match it element-wise.
    # A count of 1 means the row only matches itself, so it is unique.
    match_counts = [(body == row).all(axis=1).sum() for row in body]
    labels = np.stack([1 if count == 1 else -1 for count in match_counts]).reshape(-1, 1)
    sequence = np.stack([CLS, *body])
    return sequence, labels


# Test implementation: only the sign of each non-CLS output is checked.
print("Running unique token detection tests...")
for trial in range(10):
    seq, expected_out = generate_test_cases_unique([A, B, C])
    np_transformer = NumpyTransformer(Km, Qm, Vm, pos=pos)
    out = np_transformer.forward(seq)
    signs = np.sign(out[1:])  # drop the CLS position, whose output is ignored
    assert np.allclose(signs, expected_out, rtol=RELATIVE_TOLERANCE), (
        f"Test {trial} failed: got {signs.flatten()}, expected {expected_out.flatten()}"
    )
print("✓ All 10 unique token detection tests passed!")


In [None]:
# Compare hand-designed and trained transformers
# Note: The PyTorch model must output exactly +/-1, not just the sign.
def make_batch_unique(tokens: List[NDArrayFloat] = tokens, max_len: int = 5):
    """Create a training batch for the unique-token detection task.

    Args:
        tokens: Candidate token vectors. The default is the module-level
            ``tokens`` list as it was when this ``def`` executed.
        max_len: Exclusive upper bound on the sequence length (before CLS),
            forwarded to ``generate_test_cases_unique``.

    Returns:
        ``(seq, target)`` as float32 tensors. ``seq`` includes the prepended
        CLS row, and ``target`` holds one +1/-1 label per non-CLS token.
    """
    seq, target = generate_test_cases_unique(tokens, max_len=max_len)
    # torch.tensor with an explicit dtype replaces the legacy, discouraged
    # torch.FloatTensor constructor; both yield float32 copies here.
    return torch.tensor(seq, dtype=torch.float32), torch.tensor(target, dtype=torch.float32)

# Train the PyTorch model (CLS output removed from the loss), then compare
# its output signs with the hand-designed transformer on [A, B, C, C].
pos_dim = pos.shape[1]
transformer_py, loss = train_loop(
    make_batch_unique,
    input_dim=len(A),
    qk_dim=Km.shape[1],
    v_dim=Vm.shape[1],
    pos_dim=pos_dim,
    max_seq_len=pos.shape[0],
    remove_cls=True,
)
seq = np.stack([CLS, A, B, C, C])
expected_out = np.stack([1, 1, -1, -1]).reshape(-1, 1)
out_npy, out_pyt = compare_transformers(np_transformer, transformer_py, seq)
out_npy = np.sign(out_npy[1:])
out_pyt = np.sign(out_pyt[1:])

# Visualize comparison (CLS token excluded): three side-by-side panels that
# share the same color scale and axis labels.
plt.figure(figsize=(10, 5))
for panel_idx, (panel_data, panel_title) in enumerate(
    [
        (out_npy, 'Hand-Designed Transformer'),
        (out_pyt, 'Trained Transformer'),
        (expected_out, 'Expected Output'),
    ],
    start=1,
):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(panel_data.T, vmin=-1, vmax=1)
    plt.title(panel_title)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel('Sequence')
    plt.ylabel('Output')
plt.show()
