# Neural Networks: Linear Layers & Matrix Multiplication

**Inference Engineering Series - Notebook 1**

---

Welcome to the first notebook in the Inference Engineering series. Before we can understand how large language models (LLMs) generate text, process tokens, or manage memory during inference, we need to understand the most fundamental operation in neural networks: **matrix multiplication**.

Every forward pass through a neural network is, at its core, a sequence of matrix multiplications interspersed with non-linear activation functions. Understanding this deeply will help you reason about compute costs, memory bandwidth, and optimization opportunities later in this series.

## What You'll Learn

1. **How matrix multiplication works** - step by step, from first principles
2. **What a linear layer does** - the transformation `y = Wx + b`
3. **How to build a neural network from scratch** - using only NumPy
4. **How PyTorch implements the same operations** - and why it matches
5. **How data flows through multiple layers** - visualizing hidden states
6. **Why this matters for inference** - compute costs and memory access patterns

## Part 1: Matrix Multiplication from First Principles

Let's start with the most basic building block. When we multiply two matrices $A$ and $B$, each element of the result $C$ is computed as a **dot product** between a row of $A$ and a column of $B$.

$$C_{ij} = \sum_{k} A_{ik} \cdot B_{kj}$$

For this to work, the number of columns in $A$ must equal the number of rows in $B$.
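Here is a quick check of the formula on the same matrices used below. Each entry of the product is one row-column dot product, and `np.einsum` with the subscripts `'ik,kj->ij'` writes out the sum over the shared index $k$ directly. The `_demo` names are illustrative and avoid clashing with the notebook's `A` and `B`.

```python
import numpy as np

A_demo = np.array([[1, 2, 3],
                   [4, 5, 6]])      # shape (2, 3)
B_demo = np.array([[7, 8],
                   [9, 10],
                   [11, 12]])       # shape (3, 2)

# C_ij = sum_k A_ik * B_kj, written in index notation
C_einsum = np.einsum('ik,kj->ij', A_demo, B_demo)

# One entry computed as a row-column dot product: 1*7 + 2*9 + 3*11
c_00 = np.dot(A_demo[0, :], B_demo[:, 0])

print(C_einsum)   # values: [[58, 64], [139, 154]]
print(c_00)       # 58

# The inner dimensions must agree, so (2, 3) @ (2, 3) is rejected
try:
    A_demo @ A_demo
except ValueError as err:
    print("Shape error:", err)
```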

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Let's start with two small matrices
A = np.array([[1, 2, 3],
              [4, 5, 6]])  # Shape: (2, 3)

B = np.array([[7, 8],
              [9, 10],
              [11, 12]])  # Shape: (3, 2)

print(f"Matrix A shape: {A.shape}")
print(f"Matrix B shape: {B.shape}")
print(f"Result shape will be: ({A.shape[0]}, {B.shape[1]})")
print()
print("Matrix A:")
print(A)
print("\nMatrix B:")
print(B)

### Manual Matrix Multiplication (Triple Loop)

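Let's implement matmul the naive way, with three nested loops. This computes exactly the same sums as an optimized library. Optimized implementations (BLAS libraries, GPU kernels) get their speed by reordering, tiling, and vectorizing these loops, not by doing different arithmetic.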

In [None]:
def manual_matmul(A, B):
    """Matrix multiplication using three nested loops."""
    rows_A, cols_A = A.shape
    rows_B, cols_B = B.shape
    
    assert cols_A == rows_B, f"Incompatible shapes: {A.shape} x {B.shape}"
    
    # Initialize result matrix with zeros
    C = np.zeros((rows_A, cols_B))
    
    for i in range(rows_A):        # For each row in A
        for j in range(cols_B):    # For each column in B
            for k in range(cols_A): # Dot product
                C[i, j] += A[i, k] * B[k, j]
    
    return C

C_manual = manual_matmul(A, B)
C_numpy = A @ B  # NumPy's optimized matmul

print("Manual matmul result:")
print(C_manual)
print("\nNumPy matmul result:")
print(C_numpy)
print(f"\nResults match: {np.allclose(C_manual, C_numpy)}")
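
The triple loop and an optimized kernel compute the same sums. The difference is how the work is organized. Below is a minimal sketch of one common idea, *blocking* (tiling). The loops walk over small sub-matrices so that each block of `A` and `B` is reused while it is still in fast memory. The function name and the `tile` parameter are illustrative. Real libraries tune block sizes per processor and add SIMD vectorization and multithreading on top.

```python
import numpy as np

def tiled_matmul(A, B, tile=2):
    """Blocked matmul (illustrative): same result as A @ B, computed tile by tile."""
    M, K = A.shape
    K_b, N = B.shape
    assert K == K_b, f"Incompatible shapes: {A.shape} x {B.shape}"
    C = np.zeros((M, N))
    for i0 in range(0, M, tile):          # block of output rows
        for j0 in range(0, N, tile):      # block of output columns
            for k0 in range(0, K, tile):  # block of the shared dimension
                # Multiply one block of A by one block of B and accumulate
                # into the matching block of C (the small block product uses
                # @ for brevity). Slices past the edge are truncated, so
                # shapes that are not multiples of `tile` still work.
                C[i0:i0 + tile, j0:j0 + tile] += (
                    A[i0:i0 + tile, k0:k0 + tile] @ B[k0:k0 + tile, j0:j0 + tile]
                )
    return C

rng_demo = np.random.default_rng(0)
X_demo = rng_demo.standard_normal((5, 7))
Y_demo = rng_demo.standard_normal((7, 3))
print(np.allclose(tiled_matmul(X_demo, Y_demo, tile=2), X_demo @ Y_demo))  # True
```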

### Visualizing Matrix Multiplication Step by Step

Let's visualize exactly how each element of the output matrix is computed. Each output element is the dot product of one row from A and one column from B.

In [None]:
def visualize_matmul_step(A, B, target_i, target_j):
    """Visualize how C[i,j] is computed from A's row i and B's column j."""
    C = A @ B
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    # Matrix A - highlight row i
    ax = axes[0]
    ax.set_title(f"Matrix A\n(highlight row {target_i})", fontsize=12)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            color = '#FF6B6B' if i == target_i else '#E8E8E8'
            ax.add_patch(patches.Rectangle((j, A.shape[0]-1-i), 1, 1, 
                         facecolor=color, edgecolor='black', linewidth=2))
            ax.text(j+0.5, A.shape[0]-0.5-i, str(A[i,j]), 
                   ha='center', va='center', fontsize=14, fontweight='bold')
    ax.set_xlim(0, A.shape[1])
    ax.set_ylim(0, A.shape[0])
    ax.set_aspect('equal')
    ax.axis('off')
    
    # Multiplication symbol
    axes[1].text(0.5, 0.5, '×', fontsize=40, ha='center', va='center')
    axes[1].axis('off')
    
    # Matrix B - highlight column j
    ax = axes[2]
    ax.set_title(f"Matrix B\n(highlight col {target_j})", fontsize=12)
    for i in range(B.shape[0]):
        for j in range(B.shape[1]):
            color = '#4ECDC4' if j == target_j else '#E8E8E8'
            ax.add_patch(patches.Rectangle((j, B.shape[0]-1-i), 1, 1, 
                         facecolor=color, edgecolor='black', linewidth=2))
            ax.text(j+0.5, B.shape[0]-0.5-i, str(B[i,j]), 
                   ha='center', va='center', fontsize=14, fontweight='bold')
    ax.set_xlim(0, B.shape[1])
    ax.set_ylim(0, B.shape[0])
    ax.set_aspect('equal')
    ax.axis('off')
    
    # Result - highlight C[i,j]
    ax = axes[3]
    dot_product_terms = [f"{A[target_i,k]}*{B[k,target_j]}" for k in range(A.shape[1])]
    ax.set_title(f"Result C[{target_i},{target_j}]\n{' + '.join(dot_product_terms)} = {int(C[target_i,target_j])}", fontsize=11)
    for i in range(C.shape[0]):
        for j in range(C.shape[1]):
            color = '#FFD93D' if (i == target_i and j == target_j) else '#E8E8E8'
            ax.add_patch(patches.Rectangle((j, C.shape[0]-1-i), 1, 1, 
                         facecolor=color, edgecolor='black', linewidth=2))
            ax.text(j+0.5, C.shape[0]-0.5-i, str(int(C[i,j])), 
                   ha='center', va='center', fontsize=14, fontweight='bold')
    ax.set_xlim(0, C.shape[1])
    ax.set_ylim(0, C.shape[0])
    ax.set_aspect('equal')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Show how each element is computed
for i in range(2):
    for j in range(2):
        visualize_matmul_step(A, B, i, j)

### Counting Operations (FLOPs)

For a matrix multiplication of shapes `(M, K) @ (K, N)`, how many floating-point operations do we need?

- Each output element requires `K` multiplications and `K-1` additions
- There are `M * N` output elements
- Total: approximately `2 * M * N * K` FLOPs (counting multiply and add separately)

This is crucial for understanding inference compute costs later.
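To see why the factor of 2 is a good approximation, the sketch below counts multiplies and adds exactly. The exact count is $MN(2K - 1)$. It differs from $2MNK$ only by $MN$ additions, a relative error of $1/(2K)$, which is negligible at transformer sizes. The helper names are illustrative.

```python
def exact_matmul_flops(M, K, N):
    """Exact multiply + add count for (M, K) @ (K, N), with no fused multiply-add."""
    multiplies = M * N * K          # K products per output element
    additions = M * N * (K - 1)     # K - 1 additions to sum those products
    return multiplies + additions

def approx_matmul_flops(M, K, N):
    """The usual 2 * M * N * K approximation."""
    return 2 * M * N * K

for M, K, N in [(2, 3, 2), (512, 4096, 11008)]:
    exact = exact_matmul_flops(M, K, N)
    approx = approx_matmul_flops(M, K, N)
    print(f"({M}, {K}) @ ({K}, {N}): exact={exact:,}  2MNK={approx:,}  ratio={exact / approx:.5f}")
```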

In [None]:
def count_matmul_flops(M, K, N):
    """Count FLOPs for (M, K) @ (K, N) matmul."""
    return 2 * M * N * K

# Example: the FFN up-projection of a Llama-7B-sized transformer layer
batch_size = 1
seq_len = 512
hidden_dim = 4096
ffn_dim = 11008  # Typical for Llama-7B

flops = count_matmul_flops(batch_size * seq_len, hidden_dim, ffn_dim)
print(f"Matrix multiply: ({batch_size * seq_len}, {hidden_dim}) @ ({hidden_dim}, {ffn_dim})")
print(f"FLOPs: {flops:,}")
print(f"GFLOPs: {flops / 1e9:.2f}")
print()
print("For context:")
print(f"  A100 GPU peak: ~312 TFLOPS (FP16)")
print(f"  Time for this matmul (theoretical): {flops / 312e12 * 1000:.4f} ms")

## Part 2: The Linear Layer

A **linear layer** (also called a fully connected layer or dense layer) is the most fundamental building block in neural networks. It computes:

$$y = W \cdot x + b$$

Where:
- $x$ is the input vector (or batch of vectors)
- $W$ is the **weight matrix** (learned parameters)
- $b$ is the **bias vector** (learned parameters)
- $y$ is the output

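This is just a matrix-vector multiplication followed by a vector addition. Strictly speaking, the map is *affine* rather than linear because of the bias $b$, but "linear layer" is the standard name. This simple operation, repeated at enormous scale, is what powers modern AI.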
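Written out one output at a time, the layer computes $y_i = \sum_j W_{ij} x_j + b_i$. The sketch below does exactly that with explicit loops and checks it against the vectorized `W @ x + b`. The weights are hand-picked illustrative values, and the `_demo` names avoid overwriting the notebook's `x` and `y`.

```python
import numpy as np

def linear_elementwise(W, b, x):
    """y_i = sum_j W[i, j] * x[j] + b[i], computed one output at a time."""
    out_dim, in_dim = W.shape
    y = np.empty(out_dim)
    for i in range(out_dim):
        # Output i is the dot product of row i of W with x, plus bias i
        y[i] = sum(W[i, j] * x[j] for j in range(in_dim)) + b[i]
    return y

W_demo = np.array([[0.1, 0.2, 0.3, 0.4],
                   [0.5, 0.6, 0.7, 0.8],
                   [0.9, 1.0, 1.1, 1.2]])   # (output_dim=3, input_dim=4)
b_demo = np.array([0.01, 0.02, 0.03])
x_demo = np.array([1.0, 2.0, 3.0, 4.0])

y_demo = linear_elementwise(W_demo, b_demo, x_demo)
print(y_demo)                                              # approx [3.01, 7.02, 11.03]
print(np.allclose(y_demo, W_demo @ x_demo + b_demo))       # True
print("parameters:", W_demo.size + b_demo.size)            # 3*4 weights + 3 biases = 15
```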

In [None]:
class LinearLayerFromScratch:
    """A linear layer implemented from scratch using NumPy."""
    
    def __init__(self, input_dim, output_dim):
        # Initialize weights with small random values (Xavier initialization)
        scale = np.sqrt(2.0 / (input_dim + output_dim))
        self.W = np.random.randn(output_dim, input_dim) * scale
        self.b = np.zeros(output_dim)
        
        print(f"Linear layer: {input_dim} -> {output_dim}")
        print(f"  Weight matrix shape: {self.W.shape}")
        print(f"  Bias vector shape:   {self.b.shape}")
        print(f"  Total parameters:    {self.W.size + self.b.size:,}")
    
    def forward(self, x):
        """Compute y = Wx + b"""
        return self.W @ x + self.b

# Create a linear layer: 4 inputs -> 3 outputs
layer = LinearLayerFromScratch(4, 3)

# Create an input vector
x = np.array([1.0, 2.0, 3.0, 4.0])
print(f"\nInput x: {x}")
print(f"Input shape: {x.shape}")

# Forward pass
y = layer.forward(x)
print(f"\nOutput y: {y}")
print(f"Output shape: {y.shape}")

### Visualizing the Linear Transformation

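Let's see how the weight matrix transforms input vectors. Entry $W_{ij}$ sets how much input dimension $j$ contributes to output dimension $i$. Each **row** of $W$ acts as a "filter" whose dot product with $x$ produces one output. Each **column** records how one input dimension is spread across all outputs.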
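One way to make the row/column picture concrete is to nudge a single input dimension $j$ by 1 and watch the outputs. The change in $y$ is exactly column $j$ of $W$, because the bias cancels. Output $i$, meanwhile, comes from row $i$. The random matrices here are placeholders with the same `(output_dim, input_dim)` layout as `LinearLayerFromScratch`, not the `layer` object above.

```python
import numpy as np

rng_demo = np.random.default_rng(1)
W_demo = rng_demo.standard_normal((3, 4))   # (output_dim, input_dim)
b_demo = rng_demo.standard_normal(3)
x_demo = rng_demo.standard_normal(4)

def linear_demo(x):
    return W_demo @ x + b_demo

j_demo = 2
e_j = np.zeros(4)
e_j[j_demo] = 1.0   # unit step in input dimension j

# How much every output changes when input j increases by 1 ...
delta_y = linear_demo(x_demo + e_j) - linear_demo(x_demo)
# ... equals column j of W
print(np.allclose(delta_y, W_demo[:, j_demo]))   # True

# Row i of W, dotted with x (plus b_i), produces output i
i_demo = 0
print(np.isclose(W_demo[i_demo] @ x_demo + b_demo[i_demo], linear_demo(x_demo)[i_demo]))  # True
```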

In [None]:
def visualize_linear_layer(W, b, x, y):
    """Visualize the weight matrix and how it transforms the input."""
    fig, axes = plt.subplots(1, 4, figsize=(18, 5))
    
    # Weight matrix heatmap
    ax = axes[0]
    im = ax.imshow(W, cmap='RdBu_r', aspect='auto')
    ax.set_title('Weight Matrix W', fontsize=13, fontweight='bold')
    ax.set_xlabel('Input dimensions')
    ax.set_ylabel('Output dimensions')
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            ax.text(j, i, f'{W[i,j]:.2f}', ha='center', va='center', fontsize=10)
    plt.colorbar(im, ax=ax, shrink=0.8)
    
    # Input vector
    ax = axes[1]
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
    ax.barh(range(len(x)), x, color=colors[:len(x)])
    ax.set_title('Input Vector x', fontsize=13, fontweight='bold')
    ax.set_xlabel('Value')
    ax.set_ylabel('Dimension')
    ax.set_yticks(range(len(x)))
    ax.invert_yaxis()
    
    # Bias vector
    ax = axes[2]
    ax.barh(range(len(b)), b, color='#FFD93D')
    ax.set_title('Bias Vector b', fontsize=13, fontweight='bold')
    ax.set_xlabel('Value')
    ax.set_ylabel('Dimension')
    ax.set_yticks(range(len(b)))
    ax.invert_yaxis()
    
    # Output vector
    ax = axes[3]
    colors_out = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    ax.barh(range(len(y)), y, color=colors_out[:len(y)])
    ax.set_title('Output y = Wx + b', fontsize=13, fontweight='bold')
    ax.set_xlabel('Value')
    ax.set_ylabel('Dimension')
    ax.set_yticks(range(len(y)))
    ax.invert_yaxis()
    
    plt.tight_layout()
    plt.show()

visualize_linear_layer(layer.W, layer.b, x, y)

## Part 3: Building a Multi-Layer Network from Scratch

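A single linear layer can only represent linear (affine) functions. Stacking several linear layers with nothing in between still gives a single linear layer. To learn complex patterns, we stack multiple layers and add **activation functions** between them (covered in detail in the next notebook).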

For now, let's use a simple ReLU activation: $\text{ReLU}(x) = \max(0, x)$
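Why is the activation needed? Without it, two stacked linear layers are algebraically one layer, since $W_2(W_1 x + b_1) + b_2 = (W_2 W_1)x + (W_2 b_1 + b_2)$. The sketch below checks this numerically with random weights. The shapes and the helper name `merge_linear_layers` are illustrative.

```python
import numpy as np

def merge_linear_layers(W1, b1, W2, b2):
    """Collapse y = W2 @ (W1 @ x + b1) + b2 into a single layer y = W @ x + b."""
    return W2 @ W1, W2 @ b1 + b2

rng_demo = np.random.default_rng(0)
W1, b1 = rng_demo.standard_normal((16, 8)), rng_demo.standard_normal(16)   # 8 -> 16
W2, b2 = rng_demo.standard_normal((4, 16)), rng_demo.standard_normal(4)    # 16 -> 4
x_demo = rng_demo.standard_normal(8)

W_merged, b_merged = merge_linear_layers(W1, b1, W2, b2)   # shapes (4, 8) and (4,)

two_layers = W2 @ (W1 @ x_demo + b1) + b2
one_layer = W_merged @ x_demo + b_merged
print(np.allclose(two_layers, one_layer))   # True: without an activation, depth adds nothing

# With a ReLU in between, the output generally no longer matches the merged layer
with_relu = W2 @ np.maximum(0, W1 @ x_demo + b1) + b2
print(np.allclose(with_relu, one_layer))    # almost always False
```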

In [None]:
def relu(x):
    """ReLU activation: max(0, x)"""
    return np.maximum(0, x)

class SimpleNeuralNetwork:
    """A simple feedforward neural network built from scratch."""
    
    def __init__(self, layer_dims):
        """
        Args:
            layer_dims: list of dimensions, e.g. [784, 128, 64, 10]
        """
        self.layers = []
        self.layer_dims = layer_dims
        
        print("=" * 50)
        print("Building Neural Network")
        print("=" * 50)
        
        total_params = 0
        for i in range(len(layer_dims) - 1):
            layer = LinearLayerFromScratch(layer_dims[i], layer_dims[i+1])
            self.layers.append(layer)
            total_params += layer.W.size + layer.b.size
            print()
        
        print(f"Total parameters: {total_params:,}")
        print("=" * 50)
    
    def forward(self, x, return_intermediates=False):
        """Forward pass through all layers."""
        intermediates = [x.copy()]
        
        for i, layer in enumerate(self.layers):
            x = layer.forward(x)
            
            # Apply ReLU to all but the last layer
            if i < len(self.layers) - 1:
                x = relu(x)
            
            intermediates.append(x.copy())
        
        if return_intermediates:
            return x, intermediates
        return x

# Build a network: 8 -> 16 -> 8 -> 4
net = SimpleNeuralNetwork([8, 16, 8, 4])

In [None]:
# Run a forward pass and capture intermediate states
np.random.seed(42)  # fixes the input below; net's weights were drawn when it was built
x_input = np.random.randn(8)
output, intermediates = net.forward(x_input, return_intermediates=True)

print("Input:", x_input.round(3))
print("\nAfter Layer 1 + ReLU:", intermediates[1].round(3))
print("\nAfter Layer 2 + ReLU:", intermediates[2].round(3))
print("\nFinal output:", intermediates[3].round(3))

### Visualizing Hidden States Flowing Through Layers

Let's visualize how the activation values change as data flows through each layer. This is exactly what happens during inference in any neural network.

In [None]:
def visualize_forward_pass(intermediates, layer_dims):
    """Visualize activations at each layer of the network."""
    n_layers = len(intermediates)
    
    fig, axes = plt.subplots(1, n_layers, figsize=(4 * n_layers, 6))
    
    layer_names = ['Input'] + [f'Layer {i+1}\n({"+ ReLU" if i < n_layers-2 else "output"})' 
                                for i in range(n_layers - 1)]
    
    vmin = min(h.min() for h in intermediates)
    vmax = max(h.max() for h in intermediates)
    abs_max = max(abs(vmin), abs(vmax))
    
    for idx, (hidden, name) in enumerate(zip(intermediates, layer_names)):
        ax = axes[idx]
        
        # Draw as a vertical bar chart
        colors = ['#FF6B6B' if v > 0 else '#4ECDC4' for v in hidden]
        bars = ax.barh(range(len(hidden)), hidden, color=colors, edgecolor='black', linewidth=0.5)
        ax.set_title(name, fontsize=12, fontweight='bold')
        ax.set_xlim(-abs_max * 1.1, abs_max * 1.1)
        ax.axvline(x=0, color='black', linewidth=0.5)
        ax.set_yticks(range(len(hidden)))
        ax.set_ylabel(f'Dim (size={len(hidden)})')
        ax.invert_yaxis()
        
        # Add value labels
        for i, v in enumerate(hidden):
            ax.text(v + 0.05 * abs_max * np.sign(v), i, f'{v:.2f}', 
                   va='center', fontsize=8)
    
    plt.suptitle('Data Flow Through Neural Network', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

visualize_forward_pass(intermediates, net.layer_dims)

Notice how:
- The input has both positive and negative values
- After ReLU, all negative values become 0 (shown in layer 1 and 2)
- The final output layer has no ReLU, so it can have negative values
- The dimensionality changes at each layer (8 -> 16 -> 8 -> 4)

## Part 4: The Same Thing in PyTorch

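Now let's implement the exact same network using PyTorch and verify the results match. PyTorch's `nn.Linear` stores its weight with shape `(out_features, in_features)` and computes `y = xW^T + b`, treating each input as a row vector. For a single input, this is the same computation as our `y = Wx + b`, just written in the row-vector convention.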
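The two conventions can be checked side by side. The sketch below compares `nn.Linear`, the row-vector formula `x @ W.T + b`, the column-vector formula `W @ x + b` applied to each input, and `torch.nn.functional.linear`. The sizes and seed are arbitrary. A small `atol` absorbs float32 rounding differences between the equivalent orderings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer_demo = nn.Linear(in_features=4, out_features=3)
print(layer_demo.weight.shape)   # torch.Size([3, 4]): (out_features, in_features)

x_rows = torch.randn(5, 4)       # 5 inputs, stored as rows

with torch.no_grad():
    y_module = layer_demo(x_rows)                                   # nn.Linear
    y_rowvec = x_rows @ layer_demo.weight.T + layer_demo.bias       # y = x W^T + b
    y_colvec = (layer_demo.weight @ x_rows.T).T + layer_demo.bias   # y = W x + b, per column
    y_func = F.linear(x_rows, layer_demo.weight, layer_demo.bias)   # functional form

print(torch.allclose(y_module, y_rowvec, atol=1e-6),
      torch.allclose(y_module, y_colvec, atol=1e-6),
      torch.allclose(y_module, y_func, atol=1e-6))
```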

In [None]:
import torch
import torch.nn as nn

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Create a PyTorch linear layer
torch_layer = nn.Linear(in_features=4, out_features=3)

print("PyTorch Linear Layer:")
print(f"  Weight shape: {torch_layer.weight.shape}")
print(f"  Bias shape:   {torch_layer.bias.shape}")
print(f"  Weight:\n{torch_layer.weight.data}")
print(f"  Bias: {torch_layer.bias.data}")

In [None]:
# Let's verify PyTorch does the same thing as our manual implementation
# Copy our numpy weights into PyTorch
W_np = np.array([[0.1, 0.2, 0.3, 0.4],
                 [0.5, 0.6, 0.7, 0.8],
                 [0.9, 1.0, 1.1, 1.2]])
b_np = np.array([0.01, 0.02, 0.03])
x_np = np.array([1.0, 2.0, 3.0, 4.0])

# Manual computation
y_manual = W_np @ x_np + b_np
print(f"Manual result: {y_manual}")

# PyTorch computation
torch_layer_test = nn.Linear(4, 3)
with torch.no_grad():
    torch_layer_test.weight.copy_(torch.tensor(W_np, dtype=torch.float32))
    torch_layer_test.bias.copy_(torch.tensor(b_np, dtype=torch.float32))

x_torch = torch.tensor(x_np, dtype=torch.float32)
y_torch = torch_layer_test(x_torch)
print(f"PyTorch result: {y_torch.detach().numpy()}")
print(f"\nResults match: {np.allclose(y_manual, y_torch.detach().numpy())}")

### Building a Full Network in PyTorch

In [None]:
class PyTorchNetwork(nn.Module):
    def __init__(self, layer_dims):
        super().__init__()
        layers = []
        for i in range(len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i], layer_dims[i+1]))
            if i < len(layer_dims) - 2:  # ReLU on all but last
                layers.append(nn.ReLU())
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x)

# Same architecture as our NumPy network
model = PyTorchNetwork([8, 16, 8, 4])
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Let's look at parameter counts per layer
print("Parameter breakdown:")
print("-" * 40)
for name, param in model.named_parameters():
    print(f"{name:30s} shape={str(param.shape):15s} params={param.numel():,}")

total = sum(p.numel() for p in model.parameters())
print("-" * 40)
print(f"{'Total':30s} {'':15s} params={total:,}")
print(f"\nMemory at FP32: {total * 4 / 1024:.2f} KB")
print(f"Memory at FP16: {total * 2 / 1024:.2f} KB")

## Part 5: Batched Operations - Processing Multiple Inputs

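In practice, we often process multiple inputs at once (a **batch**). This simply adds another dimension to the matrix multiplication. Instead of $y = Wx + b$ for a single vector, we compute $Y = XW^T + b$. Here $X$ has shape `(batch_size, input_dim)`, each row is one input, and $b$ is broadcast across the rows.

Batching is critical for inference efficiency for two reasons. First, the same weight matrix is reused for every row of $X$, so each weight fetched from memory does more work. Second, one large matmul gives the hardware more parallel work to spread across its cores than many small ones.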
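As a sanity check, one batched matmul gives the same result as looping over the inputs one at a time. The sketch below uses random placeholder weights. Conceptually, the batched version can read `W_demo` once for all 32 rows, while the loop touches it 32 times. That reuse is what the benchmark below measures indirectly.

```python
import numpy as np

rng_demo = np.random.default_rng(0)
W_demo = rng_demo.standard_normal((16, 8))    # (output_dim, input_dim)
b_demo = rng_demo.standard_normal(16)
X_batch = rng_demo.standard_normal((32, 8))   # 32 inputs, one per row

# One matmul for the whole batch: (32, 8) @ (8, 16) -> (32, 16); b broadcasts over rows
Y_batch = X_batch @ W_demo.T + b_demo

# The same result, one input at a time
Y_loop = np.stack([W_demo @ x + b_demo for x in X_batch])

print(Y_batch.shape, np.allclose(Y_batch, Y_loop))   # (32, 16) True
```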

In [None]:
# Single input vs batched input
single_input = torch.randn(8)           # Shape: (8,)
batched_input = torch.randn(32, 8)      # Shape: (32, 8) - 32 inputs at once

with torch.no_grad():
    single_output = model(single_input)
    batched_output = model(batched_input)

print(f"Single input shape:  {single_input.shape}  -> Output: {single_output.shape}")
print(f"Batched input shape: {batched_input.shape} -> Output: {batched_output.shape}")

In [None]:
# Benchmark: single vs batched
import time

large_model = PyTorchNetwork([512, 1024, 1024, 256])
large_model.eval()

# 1000 inputs, once as a list of single vectors and once as one batch
single_inputs = [torch.randn(512) for _ in range(1000)]
batched_inputs = torch.randn(1000, 512)

with torch.no_grad():
    # Warm-up so one-time setup costs are not counted in either timing
    _ = large_model(single_inputs[0])
    _ = large_model(batched_inputs)

    # Process 1000 inputs one at a time
    # (perf_counter is a high-resolution clock intended for timing intervals)
    start = time.perf_counter()
    for inp in single_inputs:
        _ = large_model(inp)
    single_time = time.perf_counter() - start

    # Process 1000 inputs as a batch
    start = time.perf_counter()
    _ = large_model(batched_inputs)
    batched_time = time.perf_counter() - start

print("Processing 1000 inputs:")
print(f"  One at a time: {single_time*1000:.1f} ms")
print(f"  As a batch:    {batched_time*1000:.1f} ms")
print(f"  Speedup:       {single_time/batched_time:.1f}x")

## Part 6: Comparing Matmul Implementations

Let's compare the performance of different matrix multiplication implementations to understand why optimized libraries matter.

In [None]:
import time

def benchmark_matmul(sizes):
    """Benchmark different matmul implementations across sizes."""
    results = {'size': [], 'numpy': [], 'torch_cpu': []}
    
    for N in sizes:
        A_np = np.random.randn(N, N).astype(np.float32)
        B_np = np.random.randn(N, N).astype(np.float32)
        A_torch = torch.tensor(A_np)
        B_torch = torch.tensor(B_np)
        
        results['size'].append(N)
        
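        # NumPy: one warm-up call, then the median of 5 timed runs.
        # perf_counter has far finer resolution than time.time(), which can
        # report 0 for small matrices and break the GFLOPS division below.
        _ = A_np @ B_np
        times = []
        for _ in range(5):
            start = time.perf_counter()
            _ = A_np @ B_np
            times.append(time.perf_counter() - start)
        results['numpy'].append(np.median(times) * 1000)
        
        # PyTorch CPU: same protocol
        _ = torch.matmul(A_torch, B_torch)
        times = []
        for _ in range(5):
            start = time.perf_counter()
            _ = torch.matmul(A_torch, B_torch)
            times.append(time.perf_counter() - start)
        results['torch_cpu'].append(np.median(times) * 1000)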
        
        print(f"N={N:5d}: NumPy={results['numpy'][-1]:8.2f}ms, PyTorch={results['torch_cpu'][-1]:8.2f}ms")
    
    return results

sizes = [64, 128, 256, 512, 1024, 2048]
results = benchmark_matmul(sizes)

In [None]:
# Visualize benchmark results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison
ax1.plot(results['size'], results['numpy'], 'o-', label='NumPy', color='#FF6B6B', linewidth=2)
ax1.plot(results['size'], results['torch_cpu'], 's-', label='PyTorch (CPU)', color='#4ECDC4', linewidth=2)
ax1.set_xlabel('Matrix Size (N x N)', fontsize=12)
ax1.set_ylabel('Time (ms)', fontsize=12)
ax1.set_title('Matrix Multiplication Time', fontsize=13, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_yscale('log')

# GFLOPS achieved
for label, times, color in [('NumPy', results['numpy'], '#FF6B6B'), 
                              ('PyTorch', results['torch_cpu'], '#4ECDC4')]:
    gflops = [2 * N**3 / (t/1000) / 1e9 for N, t in zip(results['size'], times)]
    ax2.plot(results['size'], gflops, 'o-', label=label, color=color, linewidth=2)

ax2.set_xlabel('Matrix Size (N x N)', fontsize=12)
ax2.set_ylabel('GFLOPS', fontsize=12)
ax2.set_title('Achieved Compute Throughput', fontsize=13, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 7: Matrix Multiplication in Transformers

In a transformer-based LLM, matrix multiplications happen everywhere. Let's map out where they occur in a single transformer layer:

1. **QKV Projection**: `[batch, seq, hidden] @ [hidden, 3*hidden]` - projects input to queries, keys, values
2. **Attention Score**: `[batch, heads, seq, head_dim] @ [batch, heads, head_dim, seq]` - computes attention weights
3. **Attention Output**: `[batch, heads, seq, seq] @ [batch, heads, seq, head_dim]` - applies attention to values
4. **Output Projection**: `[batch, seq, hidden] @ [hidden, hidden]` - projects attention output back
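5. **FFN Up**: `[batch, seq, hidden] @ [hidden, ffn_dim]` - feed-forward network expansion. In GPT-style models `ffn_dim = 4*hidden`. Llama-style gated FFNs use two such projections (gate and up), with `ffn_dim = 11008` for `hidden = 4096`.
6. **FFN Down**: `[batch, seq, ffn_dim] @ [ffn_dim, hidden]` - feed-forward network contraction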
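Steps 2 and 3 are *batched* matmuls. `@` multiplies the last two axes and treats the leading `[batch, heads]` axes as a stack of independent matrices. The sketch below checks the shapes with small illustrative sizes (not Llama's). It includes the $1/\sqrt{d_{head}}$ scaling and softmax but omits the causal mask, since only shapes and FLOP counts matter here.

```python
import numpy as np

# Small, illustrative sizes
batch, heads, seq, head_dim = 2, 4, 6, 8
rng_demo = np.random.default_rng(0)
q = rng_demo.standard_normal((batch, heads, seq, head_dim))
k = rng_demo.standard_normal((batch, heads, seq, head_dim))
v = rng_demo.standard_normal((batch, heads, seq, head_dim))

# Step 2: [b, h, s, d] @ [b, h, d, s] -> [b, h, s, s]
scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)

# Softmax over the last axis (max subtracted for numerical stability)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Step 3: [b, h, s, s] @ [b, h, s, d] -> [b, h, s, d]
out = weights @ v

# FLOPs for step 2, counted the same way as in the analysis below:
# an (batch*heads*seq, head_dim) @ (head_dim, seq) matmul
flops_scores = 2 * (batch * heads * seq) * head_dim * seq
print(scores.shape, out.shape, flops_scores)
```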

In [None]:
def analyze_transformer_matmuls(hidden_dim=4096, num_heads=32, seq_len=2048, batch_size=1):
    """Analyze all matmuls in a single transformer layer."""
    head_dim = hidden_dim // num_heads
    ffn_dim = int(hidden_dim * 2.6875)  # Typical for Llama: 11008 for hidden=4096
    
    matmuls = {
        'QKV Projection': (batch_size * seq_len, hidden_dim, 3 * hidden_dim),
        'Attention Scores': (batch_size * num_heads * seq_len, head_dim, seq_len),
        'Attention x Values': (batch_size * num_heads * seq_len, seq_len, head_dim),
        'Output Projection': (batch_size * seq_len, hidden_dim, hidden_dim),
        'FFN Gate+Up': (batch_size * seq_len, hidden_dim, 2 * ffn_dim),
        'FFN Down': (batch_size * seq_len, ffn_dim, hidden_dim),
    }
    
    print(f"Transformer Layer Analysis (hidden={hidden_dim}, heads={num_heads}, seq={seq_len})")
    print("=" * 80)
    
    total_flops = 0
    names = []
    flops_list = []
    
    for name, (M, K, N) in matmuls.items():
        flops = 2 * M * K * N
        total_flops += flops
        names.append(name)
        flops_list.append(flops / 1e9)
        print(f"{name:25s}: ({M:6d}, {K:5d}) x ({K:5d}, {N:5d}) = {flops/1e9:8.2f} GFLOPs")
    
    print(f"{'TOTAL':25s}: {'':38s} = {total_flops/1e9:8.2f} GFLOPs")
    
    return names, flops_list

names, flops_list = analyze_transformer_matmuls()

In [None]:
# Visualize the FLOP distribution
fig, ax = plt.subplots(figsize=(10, 6))

colors = ['#FF6B6B', '#FFD93D', '#4ECDC4', '#45B7D1', '#96CEB4', '#DDA0DD']
bars = ax.barh(names, flops_list, color=colors)

# Add value labels
for bar, val in zip(bars, flops_list):
    ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height()/2, 
           f'{val:.1f} GFLOPs', va='center', fontsize=11)

ax.set_xlabel('GFLOPs', fontsize=12)
ax.set_title('FLOPs per Matrix Multiplication in a Transformer Layer\n(Llama-7B scale, seq_len=2048)', 
            fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.show()

print(f"\nFFN matmuls (gate+up, down) account for {sum(flops_list[-2:])/sum(flops_list)*100:.1f}% of compute")
print(f"Attention-block matmuls (QKV, scores, values, output proj) account for {sum(flops_list[:4])/sum(flops_list)*100:.1f}% of compute")

## Part 8: Memory vs Compute - The Arithmetic Intensity

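A critical concept for inference engineering is **arithmetic intensity**: the number of floating-point operations performed per byte moved between memory and the compute units.

$$\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes accessed}}$$

For a matmul `(M, K) @ (K, N)`, assuming each matrix is read or written exactly once:
- FLOPs = $2 \times M \times K \times N$
- Bytes = $(M \times K + K \times N + M \times N) \times \text{bytes per element}$

Every processor has a **ridge point**: its peak FLOP/s divided by its memory bandwidth in bytes/s. When an operation's arithmetic intensity is below the ridge point, it is **memory-bound**, and the compute units sit idle waiting for data. Above the ridge point, it is **compute-bound**, and the cores rather than memory are the bottleneck.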
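For the square-weight case used below, $(M, d) @ (d, d)$ in FP16 (2 bytes per element), the formula simplifies in a way that explains the results:

$$\text{AI} = \frac{2Md^2}{2\,(Md + d^2 + Md)} = \frac{Md}{2M + d} \approx M \quad \text{when } M \ll d$$

So for small batches, the arithmetic intensity is roughly the batch size itself. Setting $\frac{Md}{2M + d}$ equal to a ridge point $R$ and solving for $M$ gives:

$$M = \frac{R\,d}{d - 2R}$$

With $d = 4096$ and $R \approx 156$ (the A100 figure used below), this is $M \approx 169$. As $M \to \infty$, the intensity approaches $d/2 = 2048$.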

In [None]:
def arithmetic_intensity(M, K, N, dtype_bytes=2):
    """Calculate arithmetic intensity of a matmul."""
    flops = 2 * M * K * N
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes
    return flops / bytes_accessed

# Compare different scenarios
scenarios = [
    ("Single token decode (batch=1)", 1, 4096, 4096),
    ("Small batch decode (batch=8)", 8, 4096, 4096),
    ("Large batch decode (batch=64)", 64, 4096, 4096),
    ("Prefill (seq=512)", 512, 4096, 4096),
    ("Prefill (seq=2048)", 2048, 4096, 4096),
]

print(f"{'Scenario':<40s} {'M':>6s} {'K':>6s} {'N':>6s} {'AI':>8s} {'Bound':>12s}")
print("-" * 80)

# A100 (80 GB): ~312 TFLOPS dense FP16 / ~2 TB/s HBM bandwidth = ~156 ops/byte
gpu_ridge_point = 156  # ops/byte for A100

ais = []
labels = []
for name, M, K, N in scenarios:
    ai = arithmetic_intensity(M, K, N)
    ais.append(ai)
    labels.append(name)
    bound = "COMPUTE" if ai > gpu_ridge_point else "MEMORY"
    print(f"{name:<40s} {M:>6d} {K:>6d} {N:>6d} {ai:>8.1f} {bound:>12s}")

print(f"\nA100 ridge point: ~{gpu_ridge_point} ops/byte")
print("Operations with AI < ridge point are memory-bound")

In [None]:
# Visualize arithmetic intensity across batch sizes.
# The range extends past M ~ 169, where (M, 4096) @ (4096, 4096) in FP16 crosses
# the A100 ridge point; with batch sizes up to 128 only, the compute-bound region never appears.
batch_sizes = np.arange(1, 513)
ai_values = [arithmetic_intensity(bs, 4096, 4096) for bs in batch_sizes]

fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(batch_sizes, ai_values, linewidth=2.5, color='#4ECDC4')
ax.axhline(y=gpu_ridge_point, color='#FF6B6B', linestyle='--', linewidth=2, label=f'A100 Ridge Point ({gpu_ridge_point} ops/byte)')
ax.fill_between(batch_sizes, ai_values, gpu_ridge_point, 
                where=[ai < gpu_ridge_point for ai in ai_values],
                color='#FF6B6B', alpha=0.15, label='Memory-bound region')
ax.fill_between(batch_sizes, ai_values, gpu_ridge_point, 
                where=[ai >= gpu_ridge_point for ai in ai_values],
                color='#4ECDC4', alpha=0.15, label='Compute-bound region')

ax.set_xlabel('Batch Size (M dimension)', fontsize=12)
ax.set_ylabel('Arithmetic Intensity (ops/byte)', fontsize=12)
ax.set_title('Arithmetic Intensity vs Batch Size for (M, 4096) @ (4096, 4096)\nFP16 on A100 GPU', 
            fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xlim(1, 512)

plt.tight_layout()
plt.show()

print("\nKey insight: during decode, batch_size=1 is heavily memory-bound (AI ~ 1 op/byte).")
print("This is why LLM inference servers batch many concurrent requests together during decode.")

## Part 9: Putting It All Together - A Real Weight Matrix

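Let's load a real model and inspect its weight matrices to see these concepts in action. One detail to keep in mind: the Hugging Face GPT-2 implementation uses a `Conv1D` module for its projections. `Conv1D` stores weights as `(in_features, out_features)` and computes `y = xW + b`, the transpose of `nn.Linear`'s layout.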

In [None]:
!pip install transformers -q

In [None]:
from transformers import AutoModel, AutoConfig

# Load GPT-2 small - a real transformer model
model_name = "gpt2"
config = AutoConfig.from_pretrained(model_name)

print(f"Model: {model_name}")
print(f"Hidden size: {config.n_embd}")
print(f"Num layers: {config.n_layer}")
print(f"Num heads: {config.n_head}")
print(f"Vocab size: {config.vocab_size}")
print(f"Context length: {config.n_positions}")

In [None]:
model_real = AutoModel.from_pretrained(model_name)
model_real.eval()

# Count all parameters
total_params = sum(p.numel() for p in model_real.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f"Memory at FP32: {total_params * 4 / 1e9:.2f} GB")
print(f"Memory at FP16: {total_params * 2 / 1e9:.2f} GB")

# Show first layer's weight matrices
print("\n--- First Transformer Layer ---")
for name, param in model_real.named_parameters():
    if 'h.0.' in name:  # First layer only
        print(f"{name:50s} shape={str(param.shape):20s} params={param.numel():>10,}")

In [None]:
# Visualize the weight distribution of a real model
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

weight_matrices = [
    ('h.0.attn.c_attn.weight', 'Layer 0: QKV Projection'),
    ('h.0.attn.c_proj.weight', 'Layer 0: Attention Output'),
    ('h.0.mlp.c_fc.weight', 'Layer 0: FFN Up'),
    ('h.0.mlp.c_proj.weight', 'Layer 0: FFN Down'),
]

for ax, (param_name, title) in zip(axes.flat, weight_matrices):
    for name, param in model_real.named_parameters():
        if name == param_name:
            weights = param.detach().numpy().flatten()
            ax.hist(weights, bins=100, color='#4ECDC4', alpha=0.7, edgecolor='black', linewidth=0.3)
            ax.set_title(f'{title}\nshape={list(param.shape)}, mean={weights.mean():.4f}, std={weights.std():.4f}', fontsize=11)
            ax.set_xlabel('Weight Value')
            ax.set_ylabel('Count')
            ax.axvline(x=0, color='red', linestyle='--', alpha=0.5)
            break

plt.suptitle('Weight Distributions in GPT-2 (Real Model)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Visualize the actual weight matrix as a heatmap
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

for name, param in model_real.named_parameters():
    if name == 'h.0.attn.c_attn.weight':
        W = param.detach().numpy()
        
        # Full weight matrix (downsampled for visualization)
        ax = axes[0]
        step = max(1, W.shape[0] // 64)
        im = ax.imshow(W[::step, ::step], cmap='RdBu_r', aspect='auto', 
                      vmin=-0.3, vmax=0.3)
        ax.set_title(f'QKV Weight Matrix\n(downsampled from {W.shape})', fontsize=12)
        ax.set_xlabel('Output dimension')
        ax.set_ylabel('Input dimension')
        plt.colorbar(im, ax=ax, shrink=0.8)
        
        # Zoom into a small region
        ax = axes[1]
        small_W = W[:32, :32]
        im = ax.imshow(small_W, cmap='RdBu_r', aspect='auto',
                      vmin=-0.3, vmax=0.3)
        ax.set_title(f'Top-left 32x32 corner (zoomed)', fontsize=12)
        ax.set_xlabel('Output dimension')
        ax.set_ylabel('Input dimension')
        plt.colorbar(im, ax=ax, shrink=0.8)
        break

plt.suptitle('Real Weight Matrix from GPT-2', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 10: From Input to Output - A Complete Forward Pass

Let's trace a complete forward pass through GPT-2, showing how matrix multiplications transform the input at each step.

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "The matrix multiplication is"
inputs = tokenizer(text, return_tensors='pt')

print(f"Input text: '{text}'")
print(f"Token IDs: {inputs['input_ids'].tolist()}")
print(f"Tokens: {[tokenizer.decode(t) for t in inputs['input_ids'][0]]}")
print(f"Input shape: {inputs['input_ids'].shape}")

In [None]:
# Run forward pass and capture hidden states at every layer
with torch.no_grad():
    outputs = model_real(**inputs, output_hidden_states=True)

hidden_states = outputs.hidden_states  # Tuple of (n_layers + 1) tensors

print(f"Number of hidden state snapshots: {len(hidden_states)} (embedding + {config.n_layer} layers)")
print(f"Each hidden state shape: {hidden_states[0].shape}")
print(f"  - Batch size: {hidden_states[0].shape[0]}")
print(f"  - Sequence length: {hidden_states[0].shape[1]}")
print(f"  - Hidden dimension: {hidden_states[0].shape[2]}")

In [None]:
# Visualize how the hidden state of the last token evolves through layers
last_token_states = [hs[0, -1, :].numpy() for hs in hidden_states]  # Last token, all layers

fig, axes = plt.subplots(3, 1, figsize=(14, 12))

# 1. Norm of hidden states across layers
norms = [np.linalg.norm(s) for s in last_token_states]
axes[0].plot(range(len(norms)), norms, 'o-', color='#FF6B6B', linewidth=2, markersize=6)
axes[0].set_xlabel('Layer', fontsize=12)
axes[0].set_ylabel('L2 Norm', fontsize=12)
axes[0].set_title('Hidden State Norm Across Layers (Last Token)', fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# 2. Heatmap of hidden states evolving
state_matrix = np.array(last_token_states)[:, :64]  # First 64 dims for visibility
im = axes[1].imshow(state_matrix, aspect='auto', cmap='RdBu_r')
axes[1].set_xlabel('Hidden Dimension (first 64)', fontsize=12)
axes[1].set_ylabel('Layer', fontsize=12)
axes[1].set_title('Hidden State Values Across Layers', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=axes[1], shrink=0.8)

# 3. Cosine similarity between consecutive layers
cos_sims = []
for i in range(1, len(last_token_states)):
    a, b = last_token_states[i-1], last_token_states[i]
    cos_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    cos_sims.append(cos_sim)

axes[2].bar(range(1, len(cos_sims)+1), cos_sims, color='#4ECDC4', edgecolor='black', linewidth=0.5)
axes[2].set_xlabel('Layer Transition', fontsize=12)
axes[2].set_ylabel('Cosine Similarity', fontsize=12)
axes[2].set_title('Cosine Similarity Between Consecutive Layer Outputs', fontsize=13, fontweight='bold')
axes[2].grid(True, alpha=0.3)
axes[2].set_ylim(0, 1)

plt.tight_layout()
plt.show()

### What This Tells Us

- **Hidden state norms** tend to grow across the transformer blocks as each layer adds its update to the residual stream. In the Hugging Face GPT-2 implementation, the last snapshot is taken after the final LayerNorm (`ln_f`), so expect its norm to drop.
- **The heatmap** shows different dimensions taking different values at different layers, as the model gradually builds up its representation
- **High cosine similarity** between adjacent layers is consistent with residual connections: each layer adds an incremental update to its input instead of replacing the representation
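Here is a toy illustration of the residual argument. The vectors are random, not GPT-2 activations, and the 0.1 update scale is an arbitrary assumption. Adding a small update to a vector barely changes its direction, whereas replacing it with an unrelated vector does.

```python
import numpy as np

rng_demo = np.random.default_rng(0)
d_demo = 768                                   # GPT-2 small hidden size
h_demo = rng_demo.standard_normal(d_demo)      # stand-in hidden state
update = 0.1 * rng_demo.standard_normal(d_demo)  # hypothetical small layer output

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

sim_residual = cosine(h_demo, h_demo + update)   # residual: h + f(h)
sim_replace = cosine(h_demo, update)             # no residual: f(h) only
print(f"residual (h + update): {sim_residual:.3f}")   # close to 1
print(f"replace  (update):     {sim_replace:.3f}")    # close to 0
```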

---

## Key Takeaways

1. **Matrix multiplication is THE fundamental operation** in neural networks. Every linear layer, every attention computation, every projection is a matmul.

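2. **A linear layer computes `y = Wx + b`** - a matrix multiply followed by a bias addition. Each output "neuron" is a dot product between one row of $W$ and the input, plus a bias. The non-linearity comes from the activation function applied afterward.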

3. **FLOPs scale as O(MNK)** for a `(M,K) @ (K,N)` matmul (about $2MNK$). For LLMs, this means compute grows with hidden dimension, sequence length, and batch size, and the attention-score matmuls grow quadratically with sequence length.

4. **Arithmetic intensity determines whether you're compute-bound or memory-bound.** Single-token decode (batch=1) is heavily memory-bound. Prefill with long sequences is compute-bound. This distinction is central to inference optimization.

5. **Batching matters enormously.** Processing inputs one at a time wastes hardware capability. Batching increases arithmetic intensity and GPU utilization.

6. **Weight matrices in real models** such as GPT-2 have roughly bell-shaped value distributions centered near zero. Their shapes determine the model's parameter count and compute requirements.

7. **Residual connections** in transformers let each layer add an update to the hidden state rather than rewrite it, which is why consecutive layers' hidden states stay highly similar.

---

**Next notebook:** We'll explore activation functions - the non-linear operations between these matrix multiplications that give neural networks their power.