In [2]:
from einops import rearrange
import torch
from torch import einsum

In [3]:
einsum?

Signature: einsum(*args: Any) -> torch.Tensor
Docstring:
einsum(equation, *operands) -> Tensor

Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
based on the Einstein summation convention.

Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
with some subscript and define which subscripts are part of the output. The output is then computed by summing
the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
output. For example, matrix multiplication can 

In [4]:
rearrange?

Signature:
rearrange(
    tensor: Union[~Tensor, List[~Tensor]],
    pattern: str,
    **axes_lengths: Any,
) -> ~Tensor
Docstring:
einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
stack, concatenate and other operations.

Examples:

```python
# suppose we have a set of 32 images in "h w c" format (height-width-channel)
>>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

# stack along first (batch) axis, output is

# Einstein Operations: Making Tensor Math Readable

Tensor operations are everywhere in deep learning. But writing them with standard PyTorch/NumPy can be confusing:

```python
# What does this do?
result = torch.bmm(x.transpose(1, 2), y).sum(dim=-1)
```

This notebook builds intuition for **einsum** and **einops** — two tools that make tensor operations readable and less error-prone. We'll start from explicit loops and see how these tools emerge naturally.
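
As a preview, here is a minimal sketch of what that opaque line computes, rewritten with einsum. The tensor shapes are assumptions chosen for illustration (the snippet above does not specify them): `x` is `(batch, j, i)` and `y` is `(batch, j, k)`.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 5, 3)   # assumed shape: (b, j, i)
y = torch.randn(4, 5, 2)   # assumed shape: (b, j, k)

# Original form: transpose x to (b, i, j), batch-multiply with y to get (b, i, k),
# then sum away the last axis k, leaving (b, i).
opaque = torch.bmm(x.transpose(1, 2), y).sum(dim=-1)

# Einsum form: j and k are missing from the output, so both are summed over.
readable = torch.einsum('bji,bjk->bi', x, y)

print(opaque.shape)                                   # torch.Size([4, 3])
print(torch.allclose(opaque, readable, atol=1e-5))    # True
```

The einsum string names every axis, so the reader no longer has to reconstruct the shapes in their head.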


In [5]:
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (10, 4)
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.facecolor'] = 'white'


## Our Toy Example

Let's create small matrices we can trace through by hand. This makes it easy to verify that our operations are doing what we think.


In [6]:
# A simple 2x3 matrix
A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
], dtype=torch.float32)

# A 3x2 matrix  
B = torch.tensor([
    [1, 2],
    [3, 4],
    [5, 6]
], dtype=torch.float32)

print("Matrix A (2×3):")
print(A)
print("\nMatrix B (3×2):")
print(B)


Matrix A (2×3):
tensor([[1., 2., 3.],
        [4., 5., 6.]])

Matrix B (3×2):
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])


## The Problem: Loops Are Verbose and Error-Prone

Let's implement some common operations with explicit loops to see the pattern.


In [18]:
# Matrix multiplication with explicit loops
# C[i,k] = sum over j of A[i,j] * B[j,k]

def matmul_loops(A, B):
    """Matrix multiplication using explicit loops"""
    I, J = A.shape
    J2, K = B.shape
    assert J == J2, "Inner dimensions must match"
    
    print(f"A.shape (I, J): {A.shape}")
    print(f"B.shape (J2, K): {B.shape}")
    
    C = torch.zeros(I, K)
    for i in range(I):
        # import pdb; pdb.set_trace()
        print(f"i: {i}")
        for k in range(K):
            print(f"  k: {k}")
            for j in range(J):
                print(f"    j: {j}", "A[i, j]:", A[i, j], "B[j, k]:", B[j, k])
                C[i, k] += A[i, j] * B[j, k]
    return C

result_loops = matmul_loops(A, B)
print("A @ B with loops:")
print(result_loops)


A.shape (I, J): torch.Size([2, 3])
B.shape (J2, K): torch.Size([3, 2])
i: 0
  k: 0
    j: 0 A[i, j]: tensor(1.) B[j, k]: tensor(1.)
    j: 1 A[i, j]: tensor(2.) B[j, k]: tensor(3.)
    j: 2 A[i, j]: tensor(3.) B[j, k]: tensor(5.)
  k: 1
    j: 0 A[i, j]: tensor(1.) B[j, k]: tensor(2.)
    j: 1 A[i, j]: tensor(2.) B[j, k]: tensor(4.)
    j: 2 A[i, j]: tensor(3.) B[j, k]: tensor(6.)
i: 1
  k: 0
    j: 0 A[i, j]: tensor(4.) B[j, k]: tensor(1.)
    j: 1 A[i, j]: tensor(5.) B[j, k]: tensor(3.)
    j: 2 A[i, j]: tensor(6.) B[j, k]: tensor(5.)
  k: 1
    j: 0 A[i, j]: tensor(4.) B[j, k]: tensor(2.)
    j: 1 A[i, j]: tensor(5.) B[j, k]: tensor(4.)
    j: 2 A[i, j]: tensor(6.) B[j, k]: tensor(6.)
A @ B with loops:
tensor([[22., 28.],
        [49., 64.]])


In [12]:
# Verify: this matches torch's built-in
result_torch = A @ B
print("A @ B with torch:")
print(result_torch)
print(f"\nMatch: {torch.allclose(result_loops, result_torch)}")


A @ B with torch:
tensor([[22., 28.],
        [49., 64.]])

Match: True


In [13]:
# Let's trace through one element to understand
# C[0, 0] = A[0,0]*B[0,0] + A[0,1]*B[1,0] + A[0,2]*B[2,0]
#         = 1*1 + 2*3 + 3*5 = 1 + 6 + 15 = 22

print("Tracing C[0,0]:")
print(f"  A[0,:] = {A[0,:].tolist()}")
print(f"  B[:,0] = {B[:,0].tolist()}")
print(f"  A[0,0]*B[0,0] + A[0,1]*B[1,0] + A[0,2]*B[2,0]")
print(f"  = {A[0,0]}*{B[0,0]} + {A[0,1]}*{B[1,0]} + {A[0,2]}*{B[2,0]}")
print(f"  = {A[0,0]*B[0,0]} + {A[0,1]*B[1,0]} + {A[0,2]*B[2,0]}")
print(f"  = {A[0,0]*B[0,0] + A[0,1]*B[1,0] + A[0,2]*B[2,0]}")
print(f"\nC[0,0] in result: {result_torch[0,0]}")


Tracing C[0,0]:
  A[0,:] = [1.0, 2.0, 3.0]
  B[:,0] = [1.0, 3.0, 5.0]
  A[0,0]*B[0,0] + A[0,1]*B[1,0] + A[0,2]*B[2,0]
  = 1.0*1.0 + 2.0*3.0 + 3.0*5.0
  = 1.0 + 6.0 + 15.0
  = 22.0

C[0,0] in result: 22.0


## The Pattern: Indices and Summation

Look at what we're doing in matrix multiplication:

```
C[i, k] = Σⱼ A[i, j] × B[j, k]
```

The pattern is:
1. **Label each dimension** with a letter (i, j, k)
2. **Multiply** corresponding elements
3. **Sum** over indices that don't appear in the output (j disappears)

This is exactly what **einsum** notation captures!
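
The three steps can be sketched as a toy, loop-based interpreter. `naive_einsum` is a hypothetical helper written for this notebook, not part of PyTorch; it assumes an explicit `->` in the equation and is far too slow for real use.

```python
import itertools
import torch

def naive_einsum(equation, *operands):
    """Toy einsum: visit every index combination, multiply, accumulate."""
    inputs, output = equation.replace(" ", "").split("->")
    input_labels = inputs.split(",")
    assert len(input_labels) == len(operands)

    # Step 1: label each dimension and record its size.
    sizes = {}
    for labels, tensor in zip(input_labels, operands):
        assert len(labels) == tensor.dim()
        for label, size in zip(labels, tensor.shape):
            assert sizes.setdefault(label, size) == size, f"size mismatch for {label}"

    all_labels = sorted(sizes)
    result = torch.zeros([sizes[label] for label in output])

    for combo in itertools.product(*(range(sizes[label]) for label in all_labels)):
        index = dict(zip(all_labels, combo))
        # Step 2: multiply the elements selected by this index combination.
        term = 1.0
        for labels, tensor in zip(input_labels, operands):
            term = term * tensor[tuple(index[label] for label in labels)]
        # Step 3: labels missing from the output collapse into the same slot,
        # so accumulating here performs the sum over them.
        result[tuple(index[label] for label in output)] += term
    return result

A_demo = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
B_demo = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
print(naive_einsum('ij,jk->ik', A_demo, B_demo))
```

For the toy matrices this reproduces `A @ B`; the only thing the equation string changes is which labels end up in the output index.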


In [19]:
# einsum: express the SAME operation in one line
# "ij,jk->ik" means:
#   - A has dimensions [i, j]
#   - B has dimensions [j, k]  
#   - Output has dimensions [i, k]
#   - j appears in both inputs but not output → sum over j

result_einsum = einsum('ij,jk->ik', A, B)
print("A @ B with einsum('ij,jk->ik', A, B):")
print(result_einsum)
print(f"\nMatch: {torch.allclose(result_einsum, result_torch)}")


A @ B with einsum('ij,jk->ik', A, B):
tensor([[22., 28.],
        [49., 64.]])

Match: True


**Reading einsum notation**: `'ij,jk->ik'`

| Part | Meaning |
|------|---------|
| `ij` | First tensor has dimensions i (rows) and j (cols) |
| `,` | Separator between tensors |
| `jk` | Second tensor has dimensions j (rows) and k (cols) |
| `->` | "produces" |
| `ik` | Output has dimensions i and k |
| (j missing from output) | Sum over j |

The beauty: **the notation directly describes the math!**
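
To practice reading the notation, here is a small sketch with two more equations on the same toy matrix. The vector `v` is a hypothetical input picked so the arithmetic is easy to check by hand.

```python
import torch

A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])
v = torch.tensor([1., 0., -1.])   # hypothetical vector

# 'ij,j->i': j is shared by both inputs and missing from the output, so it is summed.
# This is the matrix-vector product A @ v.
mat_vec = torch.einsum('ij,j->i', A, v)
print(mat_vec)   # tensor([-2., -2.])

# 'ij,kj->ik': both inputs share their second axis j, which gives A @ A.T
# without writing an explicit transpose.
gram = torch.einsum('ij,kj->ik', A, A)
print(gram)
```

In both cases the equation can be read left to right: name the axes of each input, then name the axes you want to keep.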


## Building Up: Simple Operations First

Let's see how simpler operations work in einsum, then build up.


In [20]:
# 1. Sum all elements
# Loop version:
total_loop = 0
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        total_loop += A[i, j]

print("A:")
print(A)
print(f"\nSum all elements:")
print(f"  Loop: {total_loop}")
print(f"  torch.sum: {A.sum()}")
print(f"  einsum('ij->', A): {einsum('ij->', A)}")  # no output indices = sum everything


A:
tensor([[1., 2., 3.],
        [4., 5., 6.]])

Sum all elements:
  Loop: 21.0
  torch.sum: 21.0
  einsum('ij->', A): 21.0


In [24]:
# 2. Sum along rows (keep columns)
# For each column j, sum over all rows i
row_sums_loop = torch.zeros(A.shape[1])
for j in range(A.shape[1]):
    for i in range(A.shape[0]):
        row_sums_loop[j] += A[i, j]

# For each row i, sum over all columns j
col_sums_loop = torch.zeros(A.shape[0])
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        col_sums_loop[i] += A[i, j]

print("A:")
print(A)
print(f"\nSum along rows (for each column):")
print(f"  Loop: {row_sums_loop}")
print(f"  torch.sum(dim=0): {A.sum(dim=0)}")
print(f"  einsum('ij->j', A): {einsum('ij->j', A)}")  # i not in output = sum over i
print()
print(f"Sum along columns (for each row):")
print(f"  Loop: {col_sums_loop}")
print(f"  torch.sum(dim=1): {A.sum(dim=1)}")
print(f"  einsum('ij->i', A): {einsum('ij->i', A)}")  # j not in output = sum over j


A:
tensor([[1., 2., 3.],
        [4., 5., 6.]])

Sum along rows (for each column):
  Loop: tensor([5., 7., 9.])
  torch.sum(dim=0): tensor([5., 7., 9.])
  einsum('ij->j', A): tensor([5., 7., 9.])

Sum along columns (for each row):
  Loop: tensor([ 6., 15.])
  torch.sum(dim=1): tensor([ 6., 15.])
  einsum('ij->i', A): tensor([ 6., 15.])


In [33]:
# 3. Transpose
# Just swap the order of indices in the output!

print("A:")
print(A)
print(f"A.shape: {A.shape}")

# Loop version
A_T_loop = torch.zeros(A.shape[1], A.shape[0], dtype=A.dtype)
print(f"A_T_loop.shape: {A_T_loop.shape}")
print()
for i in range(A.shape[0]):
    print(f"i: {i}")
    for j in range(A.shape[1]):
        print(f"  j: {j}", "A[i, j]:", A[i, j])
        A_T_loop[j, i] = A[i, j]

print(f"\nTranspose (swap i and j):")
print(f"  Loop:\n{A_T_loop}")
print(f"  A.T:\n{A.T}")
print()
print(f"  einsum('ij->ji', A):\n{einsum('ij->ji', A)}")  # swap output order


A:
tensor([[1., 2., 3.],
        [4., 5., 6.]])
A.shape: torch.Size([2, 3])
A_T_loop.shape: torch.Size([3, 2])

i: 0
  j: 0 A[i, j]: tensor(1.)
  j: 1 A[i, j]: tensor(2.)
  j: 2 A[i, j]: tensor(3.)
i: 1
  j: 0 A[i, j]: tensor(4.)
  j: 1 A[i, j]: tensor(5.)
  j: 2 A[i, j]: tensor(6.)

Transpose (swap i and j):
  Loop:
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])
  A.T:
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])

  einsum('ij->ji', A):
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])


In [35]:
# 4. Dot product (inner product) of two vectors
v1 = torch.tensor([1., 2., 3.])
v2 = torch.tensor([4., 5., 6.])

# Loop version
dot_loop = 0
for i in range(len(v1)):
    dot_loop += v1[i] * v2[i]

print(f"v1: {v1.tolist()}")
print(f"v2: {v2.tolist()}")
print(f"\nDot product (v1 · v2):")
print(f"  Loop: {dot_loop}")
print(f"  torch.dot: {torch.dot(v1, v2)}")
print(f"  einsum('i,i->', v1, v2): {einsum('i,i->', v1, v2)}")  # same index, no output = sum products


v1: [1.0, 2.0, 3.0]
v2: [4.0, 5.0, 6.0]

Dot product (v1 · v2):
  Loop: 32.0
  torch.dot: 32.0
  einsum('i,i->', v1, v2): 32.0


In [41]:
# 5. Outer product: every element of v1 times every element of v2
# Result[i,j] = v1[i] * v2[j]

outer_loop = torch.zeros(len(v1), len(v2))
for i in range(len(v1)):
    for j in range(len(v2)):
        outer_loop[i, j] = v1[i] * v2[j]

print(f"v1: {v1.tolist()}")
print(f"v2: {v2.tolist()}")
print(f"\nOuter product (v1 ⊗ v2):")
print(f"  Loop:\n{outer_loop}")
print()
print(f"  torch.outer:\n{torch.outer(v1, v2)}")
print()
print(f"  einsum('i,j->ij', v1, v2):\n{einsum('i,j->ij', v1, v2)}")  # different indices = keep both


v1: [1.0, 2.0, 3.0]
v2: [4.0, 5.0, 6.0]

Outer product (v1 ⊗ v2):
  Loop:
tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.],
        [12., 15., 18.]])

  torch.outer:
tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.],
        [12., 15., 18.]])

  einsum('i,j->ij', v1, v2):
tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.],
        [12., 15., 18.]])


In [42]:
# Harder outer product example: 3 vectors, result is a rank-3 tensor

u = torch.tensor([1., 2.])
v = torch.tensor([3., 4., 5.])
w = torch.tensor([6., 7.])

# Loop version for 3-way outer product: Result[i,j,k] = u[i] * v[j] * w[k]
outer3_loop = torch.zeros(len(u), len(v), len(w))
for i in range(len(u)):
    for j in range(len(v)):
        for k in range(len(w)):
            outer3_loop[i, j, k] = u[i] * v[j] * w[k]

print(f"u: {u.tolist()}")
print(f"v: {v.tolist()}")
print(f"w: {w.tolist()}")
print("\n3-way Outer product (u ⊗ v ⊗ w):")
print(f"  Loop:\n{outer3_loop}")
print()
print("  Using einsum('i,j,k->ijk', u, v, w):")
print(einsum('i,j,k->ijk', u, v, w))  # three different indices


u: [1.0, 2.0]
v: [3.0, 4.0, 5.0]
w: [6.0, 7.0]

3-way Outer product (u ⊗ v ⊗ w):
  Loop:
tensor([[[18., 21.],
         [24., 28.],
         [30., 35.]],

        [[36., 42.],
         [48., 56.],
         [60., 70.]]])

  Using einsum('i,j,k->ijk', u, v, w):
tensor([[[18., 21.],
         [24., 28.],
         [30., 35.]],

        [[36., 42.],
         [48., 56.],
         [60., 70.]]])


In [48]:
import torch

# Challenging heterogeneous outer product example:
# Make one input a matrix (rank-2), and sum one vector (b) out of the result.

# a: shape (2,), b: shape (2,), c: shape (3, 2) [matrix!], d: shape (2,)
a = torch.tensor([1., 2.])
b = torch.tensor([3., 4.])
c = torch.tensor([[2., 1.], [3., 2.], [4., 3.]])  # 3x2 matrix
d = torch.tensor([1., 2.])

print(f"a: {a.tolist()}")
print(f"b: {b.tolist()}  # This will be summed out in final result")
print(f"c:\n{c}")
print(f"d: {d.tolist()}")
print()

# Compute: Outer over a (i), b (j), c (k,l), d (m)
# But sum out b -- i.e., Result[i, k, l, m] = sum_j a[i] * b[j] * c[k, l] * d[m]

# ---------------------------------------------------------------------
# Step 0: Full explicit nested loops (reference implementation)
# ---------------------------------------------------------------------
outer_step0 = torch.zeros(len(a), c.shape[0], c.shape[1], len(d))

for i in range(len(a)):
    for k in range(c.shape[0]):
        for l in range(c.shape[1]):
            for m in range(len(d)):
                s = 0.
                for j in range(len(b)):
                    s += a[i] * b[j] * c[k, l] * d[m]
                outer_step0[i, k, l, m] = s

print("Step 0: Full 5-loop implementation (i, k, l, m, j)")
print("  Result.shape =", tuple(outer_step0.shape))
print(outer_step0)
print()

# ---------------------------------------------------------------------
# Step 1: Remove the innermost 'j' loop by summing b explicitly
#   sum_j a[i]*b[j]*c[k,l]*d[m] = a[i] * (sum_j b[j]) * c[k,l] * d[m]
# ---------------------------------------------------------------------
sum_b = b.sum()
outer_step1 = torch.zeros_like(outer_step0)

for i in range(len(a)):
    for k in range(c.shape[0]):
        for l in range(c.shape[1]):
            for m in range(len(d)):
                outer_step1[i, k, l, m] = a[i] * sum_b * c[k, l] * d[m]

print("Step 1: Removed the 'j' loop (now 4 loops: i, k, l, m)")
print("  Max difference vs Step 0:", (outer_step1 - outer_step0).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 2: Remove the 'm' loop by vectorizing over d
#   For fixed i,k,l, the whole d-dimension is just scaled by a[i]*sum_b*c[k,l]
# ---------------------------------------------------------------------
outer_step2 = torch.zeros_like(outer_step0)

for i in range(len(a)):
    for k in range(c.shape[0]):
        for l in range(c.shape[1]):
            # This fills the entire m-axis at once
            outer_step2[i, k, l, :] = a[i] * sum_b * c[k, l] * d

print("Step 2: Removed the 'm' loop (now 3 loops: i, k, l)")
print("  Max difference vs Step 0:", (outer_step2 - outer_step0).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 3: Remove the 'l' loop by vectorizing over c's second dimension
#   For fixed i,k, we want a[i]*sum_b * c[k, :] (length-2) outer d (length-2)
#   That gives a 2x2 block over (l, m).
# ---------------------------------------------------------------------
outer_step3 = torch.zeros_like(outer_step0)

for i in range(len(a)):
    for k in range(c.shape[0]):
        # c[k] has shape (2,), we make it (2,1) so it can broadcast with d (2,)
        # Result is shape (2,2) filling (l,m).
        outer_step3[i, k, :, :] = a[i] * sum_b * c[k].unsqueeze(-1) * d

print("Step 3: Removed the 'l' loop (now 2 loops: i, k)")
print("  Max difference vs Step 0:", (outer_step3 - outer_step0).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 4: Remove the 'k' loop by vectorizing over the whole matrix c
#   For fixed i, we want a[i]*sum_b * c (3x2), and then an extra outer with d (length-2)
#   c has shape (3,2). We lift it to (3,2,1) and multiply with d (2,) -> (3,2,2)
# ---------------------------------------------------------------------
outer_step4 = torch.zeros_like(outer_step0)

for i in range(len(a)):
    # c.unsqueeze(-1): (3,2,1)
    # d: (2,) broadcasts to (1,1,2)
    # => result (3,2,2) over (k,l,m)
    outer_step4[i, :, :, :] = a[i] * sum_b * c.unsqueeze(-1) * d

print("Step 4: Removed the 'k' loop (now 1 loop: i)")
print("  Max difference vs Step 0:", (outer_step4 - outer_step0).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 5: Remove the final 'i' loop by full broadcasting
#   Now we let all dimensions broadcast:
#   - a: (2,)      -> (2,1,1,1)
#   - c: (3,2)     -> (1,3,2,1)
#   - d: (2,)      -> (1,1,1,2)
#   And sum_b is a scalar.
# ---------------------------------------------------------------------
outer_step5 = (
    a[:, None, None, None] *   # shape (2,1,1,1)
    sum_b *                    # scalar
    c[None, :, :, None] *      # shape (1,3,2,1)
    d[None, None, None, :]     # shape (1,1,1,2)
)

print("Step 5: No Python loops, just broadcasting")
print("  Max difference vs Step 0:", (outer_step5 - outer_step0).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 6: Replace explicit broadcasting with a single einsum
#
# Recall the mathematical definition:
#   Result[i, k, l, m] = sum_j a[i] * b[j] * c[k, l] * d[m]
#
# Map to einsum indices:
#   a: (2,)    -> 'i'
#   b: (2,)    -> 'j'   (summed out)
#   c: (3,2)   -> 'kl'
#   d: (2,)    -> 'm'
#
# We want output indices i,k,l,m and to sum over j:
#   'i,j,kl,m -> iklm'
# ---------------------------------------------------------------------
einsum_result = torch.einsum('i,j,kl,m->iklm', a, b, c, d)

print("Step 6: Single einsum call")
print("  Max difference vs Step 0:", (einsum_result - outer_step0).abs().max().item())
print()
print("Final einsum result:")
print(einsum_result)

a: [1.0, 2.0]
b: [3.0, 4.0]  # This will be summed out in final result
c:
tensor([[2., 1.],
        [3., 2.],
        [4., 3.]])
d: [1.0, 2.0]

Step 0: Full 5-loop implementation (i, k, l, m, j)
  Result.shape = (2, 3, 2, 2)
tensor([[[[ 14.,  28.],
          [  7.,  14.]],

         [[ 21.,  42.],
          [ 14.,  28.]],

         [[ 28.,  56.],
          [ 21.,  42.]]],


        [[[ 28.,  56.],
          [ 14.,  28.]],

         [[ 42.,  84.],
          [ 28.,  56.]],

         [[ 56., 112.],
          [ 42.,  84.]]]])

Step 1: Removed the 'j' loop (now 4 loops: i, k, l, m)
  Max difference vs Step 0: 0.0

Step 2: Removed the 'm' loop (now 3 loops: i, k, l)
  Max difference vs Step 0: 0.0

Step 3: Removed the 'l' loop (now 2 loops: i, k)
  Max difference vs Step 0: 0.0

Step 4: Removed the 'k' loop (now 1 loop: i)
  Max difference vs Step 0: 0.0

Step 5: No Python loops, just broadcasting
  Max difference vs Step 0: 0.0

Step 6: Single einsum call
  Max difference vs Step 0: 0.0

Fi

In [None]:
import torch

# Now we'll start from the final einsum result and work backwards to the nested loops.
# This cell reuses a, b, c, d from the previous cell: c is a matrix (rank-2) and b is summed out.

# Math we want:
#   Result[i, k, l, m] = sum_j a[i] * b[j] * c[k, l] * d[m]
#
# Index mapping for einsum:
#   a: (2,)    -> 'i'
#   b: (2,)    -> 'j'
#   c: (3,2)   -> 'kl'
#   d: (2,)    -> 'm'
#
# We want output indices i,k,l,m and to sum over j:
#   'i,j,kl,m -> iklm'

# ---------------------------------------------------------------------
# Step 0: Fully vectorized einsum reminder (no Python loops)
# ---------------------------------------------------------------------
result_step0 = torch.einsum('i,j,kl,m->iklm', a, b, c, d)

print("Step 0: Full einsum with no Python loops")
print("  Result.shape =", tuple(result_step0.shape))
print(result_step0)
print()

# We'll treat this as our reference "ground truth" for all later steps.
ref = result_step0

# ---------------------------------------------------------------------
# Step 1: Add a loop over l (split c into columns)
#
# Idea:
#   - Keep c as (3,2) but handle each column l separately.
#   - For a fixed l, c[:, l] has shape (3,) and we treat it as a vector with index 'k'.
#   - Use einsum over a (i), b (j), c[:, l] (k), d (m):
#       'i,j,k,m -> ikm'
#   - Then place that into the l-th slice of the result.
# ---------------------------------------------------------------------
result_step1 = torch.zeros_like(ref)

for l in range(c.shape[1]):
    # einsum over: a[i], b[j], c[:,l] (k), d[m]
    # This sums out j and produces shape (i,k,m)
    einsum_res_l = torch.einsum('i,j,k,m->ikm', a, b, c[:, l], d)
    # Insert along l dimension
    result_step1[:, :, l, :] = einsum_res_l

print("Step 1: One Python loop (over l), everything else in einsum")
print("  Max difference vs Step 0:", (result_step1 - ref).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 2: Add loops over k and l
#
# Idea:
#   - Now we peel off both k and l as explicit loops.
#   - For each (k, l) pair, c[k, l] is a scalar.
#   - The rest is an einsum over a (i), b (j), d (m):
#       'i,j,m -> im'
#   - Then we scale by c[k, l] and write into [i, k, l, m].
# ---------------------------------------------------------------------
result_step2 = torch.zeros_like(ref)

for k in range(c.shape[0]):
    for l in range(c.shape[1]):
        # Base (i,m) tensor from a, b, d:
        #   sum_j a[i] * b[j] * d[m]  (j is summed out)
        base_im = torch.einsum('i,j,m->im', a, b, d)
        # Scale by the scalar c[k, l] and store at this (k,l)
        result_step2[:, k, l, :] = base_im * c[k, l]

print("Step 2: Two Python loops (over k, l), contractions in einsum")
print("  Max difference vs Step 0:", (result_step2 - ref).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 3: Add loops over i, k, l
#
# Idea:
#   - Now i, k, l are all handled in Python.
#   - For each fixed (i, k, l), we want the m-dimension:
#       Result[i, k, l, m] = a[i] * c[k, l] * sum_j b[j] * d[m]
#   - The part that depends on j and m is:
#       sum_j b[j] * d[m]   -> einsum 'j,m -> m'
#   - Then we just scale by a[i] and c[k, l].
# ---------------------------------------------------------------------
result_step3 = torch.zeros_like(ref)

for i in range(len(a)):
    for k in range(c.shape[0]):
        for l in range(c.shape[1]):
            # This vector has shape (m,):
            #   sum_j b[j] * d[m]
            base_m = torch.einsum('j,m->m', b, d)
            result_step3[i, k, l, :] = a[i] * c[k, l] * base_m

print("Step 3: Three Python loops (over i, k, l), einsum handles j,m")
print("  Max difference vs Step 0:", (result_step3 - ref).abs().max().item())
print()

# ---------------------------------------------------------------------
# Step 4: Add loops over i, k, l, m
#
# Idea:
#   - Now i, k, l, m are all explicit loops.
#   - For each fixed (i, k, l, m), we want a single scalar:
#       Result[i, k, l, m] = a[i] * c[k, l] * d[m] * sum_j b[j]
#   - The only contraction left is the sum over j:
#       sum_j b[j]           -> einsum 'j ->'
#   - We precompute this once (still with einsum) and reuse it.
# ---------------------------------------------------------------------
sum_b = torch.einsum('j->', b)  # scalar = sum_j b[j]

result_step4 = torch.zeros_like(ref)

for i in range(len(a)):
    for k in range(c.shape[0]):
        for l in range(c.shape[1]):
            for m in range(len(d)):
                result_step4[i, k, l, m] = a[i] * c[k, l] * d[m] * sum_b

print("Step 4: Four Python loops (over i, k, l, m), einsum only for sum over j")
print("  Max difference vs Step 0:", (result_step4 - ref).abs().max().item())
print()

# If we went one step further (5 loops), we'd explicitly loop over j as well and
# manually do the sum, which would just replicate what einsum already does internally.
# So we stop here.


In [49]:
import torch

# Matmul of a 4d tensor (X) with a 3d tensor (Y) that needs to be broadcast

# Let's define a 4d tensor X of shape (2, 3, 4, 5)
X = torch.arange(2*3*4*5, dtype=torch.float32).reshape(2, 3, 4, 5)

# And a 3d tensor Y of shape (1, 4, 5), which will be broadcast across the first dimension
Y = torch.arange(4*5, dtype=torch.float32).reshape(1, 4, 5)

print("X shape:", X.shape)
print("Y shape:", Y.shape)
print()

# We want a "batched matmul" over the last axis (size 5), for each X[i,j,k,:] and Y[m,k,:].
# Result index structure:
#   i in [0, 2)
#   j in [0, 3)
#   k in [0, 4)
#   m in [0, 1)
#
# Mathematically:
#   result[i, j, k, m] = sum_l X[i, j, k, l] * Y[m, k, l]


# ---------------------------------------------------------------------
# PART 1: From all loops down to no loops (NO einsum)
# ---------------------------------------------------------------------

# Step 0: Full explicit loops, including the inner l loop
full_loop = torch.zeros(X.shape[0], X.shape[1], X.shape[2], Y.shape[0])

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        for k in range(X.shape[2]):
            for m in range(Y.shape[0]):
                s = 0.0
                for l in range(X.shape[3]):  # inner dot-product over l
                    s += X[i, j, k, l] * Y[m, k, l]
                full_loop[i, j, k, m] = s

print("Step 0: Full 5-loop implementation (i, j, k, m, l)")
print("  Result shape:", full_loop.shape)
print(full_loop)
print()


# Step 1: Remove the l loop by using vectorized dot-product (still 4 loops)
step1 = torch.zeros_like(full_loop)

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        for k in range(X.shape[2]):
            for m in range(Y.shape[0]):
                # Now we let PyTorch do the l-sum:
                step1[i, j, k, m] = (X[i, j, k, :] * Y[m, k, :]).sum()

print("Step 1: 4 loops (i, j, k, m), l handled by .sum over the last dim")
print("  Max difference vs Step 0:", (step1 - full_loop).abs().max().item())
print()


# Step 2: Remove the m loop (remember Y has shape (1, 4, 5), so m is size 1)
#         We can just use Y[0, k, :] everywhere.
step2 = torch.zeros_like(full_loop)

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        for k in range(X.shape[2]):
            # Dot-product of X[i,j,k,:] with Y[0,k,:], then stored at m=0
            step2[i, j, k, 0] = (X[i, j, k, :] * Y[0, k, :]).sum()

print("Step 2: 3 loops (i, j, k), m absorbed by using Y[0, k, :]")
print("  Max difference vs Step 0:", (step2 - full_loop).abs().max().item())
print()


# Step 3: Remove the k loop by vectorizing over k as well
#         For fixed (i, j), we want all k:
#           result[i, j, k, 0] = sum_l X[i,j,k,l] * Y[0,k,l]
#         This is just a batched dot over the last dim.
step3 = torch.zeros_like(full_loop)

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        # X[i, j, :, :] has shape (4, 5)
        # Y[0, :, :] has shape (4, 5)
        # Elementwise multiply and sum over last dim -> shape (4,)
        step3[i, j, :, 0] = (X[i, j, :, :] * Y[0, :, :]).sum(dim=-1)

print("Step 3: 2 loops (i, j), k and l handled by tensor ops")
print("  Max difference vs Step 0:", (step3 - full_loop).abs().max().item())
print()


# Step 4: Remove the j loop by vectorizing over j
#         For fixed i, X[i] has shape (3, 4, 5), Y[0] has shape (4, 5).
#         We want:
#           result[i, j, k, 0] = sum_l X[i,j,k,l] * Y[0,k,l]
#         We can broadcast Y[0] onto X[i] and sum over the last dim.
step4 = torch.zeros_like(full_loop)

for i in range(X.shape[0]):
    # X[i] shape: (3, 4, 5)
    # Y[0] shape: (4, 5) -> broadcast to (1, 4, 5) then (3, 4, 5)
    prod = X[i] * Y[0]   # shape (3, 4, 5)
    # sum over l to get (3, 4)
    step4[i, :, :, 0] = prod.sum(dim=-1)

print("Step 4: 1 loop (i), j,k,l handled by tensor ops")
print("  Max difference vs Step 0:", (step4 - full_loop).abs().max().item())
print()


# Step 5: No loops at all: fully broadcasted tensor operation
#         X: (2, 3, 4, 5)
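#         Y: (1, 4, 5) -> unsqueeze(1) -> (1, 1, 4, 5) -> broadcast to (2, 3, 4, 5)
#         Then sum over the last dim (l, size 5) and keep a singleton dim for m.
#         (keepdim=True stands in for the m axis only because m has size 1 here.)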
broadcast_result = (X * Y.unsqueeze(1)).sum(dim=-1, keepdim=True)
# Y.unsqueeze(1): (1,1,4,5) -> (2,3,4,5) by broadcasting

print("Step 5: 0 loops, pure broadcasting + sum")
print("  broadcast_result shape:", broadcast_result.shape)
print("  Max difference vs Step 0:", (broadcast_result - full_loop).abs().max().item())
print()


# ---------------------------------------------------------------------
# PART 2: Einsum version (also no loops)
# ---------------------------------------------------------------------
einsum_result = torch.einsum('ijkl,mkl->ijkm', X, Y)

print("Einsum version: einsum('ijkl,mkl->ijkm', X, Y)")
print("  einsum_result shape:", einsum_result.shape)
print(einsum_result)
print()
print("Difference between einsum and full_loop (should be 0):",
      (einsum_result - full_loop).abs().max().item())
print()


# ---------------------------------------------------------------------
# PART 3: Build back up using ONLY einsum (no manual .sum)
# ---------------------------------------------------------------------

ref = einsum_result  # treat einsum_result as the ground truth for this part


# Step E0: Reference, no loops
stepE0 = torch.einsum('ijkl,mkl->ijkm', X, Y)

print("Step E0: No loops, einsum only (same as ref)")
print("  Max difference vs ref:", (stepE0 - ref).abs().max().item())
print()


# Step E1: Add a loop over m, einsum handles i,j,k,l
#   For fixed m, we contract:
#     result[:, :, :, m] = sum_l X[i,j,k,l] * Y[m,k,l]
#   Einsum: 'ijkl,kl->ijk'
stepE1 = torch.zeros_like(ref)

for m in range(Y.shape[0]):
    stepE1[:, :, :, m] = torch.einsum('ijkl,kl->ijk', X, Y[m])

print("Step E1: 1 loop (m), einsum handles (i,j,k,l)")
print("  Max difference vs ref:", (stepE1 - ref).abs().max().item())
print()


# Step E2: Add loops over k and m, einsum handles i,j,l
#   For fixed (k, m):
#     result[:, :, k, m] = sum_l X[i,j,k,l] * Y[m,k,l]
#   Einsum: 'ijl,l->ij'
stepE2 = torch.zeros_like(ref)

for m in range(Y.shape[0]):
    for k in range(X.shape[2]):
        stepE2[:, :, k, m] = torch.einsum('ijl,l->ij', X[:, :, k, :], Y[m, k, :])

print("Step E2: 2 loops (k, m), einsum handles (i,j,l)")
print("  Max difference vs ref:", (stepE2 - ref).abs().max().item())
print()


# Step E3: Add loops over j, k, m, einsum handles i,l
#   For fixed (j, k, m):
#     result[:, j, k, m] = sum_l X[i,j,k,l] * Y[m,k,l]
#   Einsum: 'il,l->i'
stepE3 = torch.zeros_like(ref)

for m in range(Y.shape[0]):
    for k in range(X.shape[2]):
        for j in range(X.shape[1]):
            stepE3[:, j, k, m] = torch.einsum('il,l->i', X[:, j, k, :], Y[m, k, :])

print("Step E3: 3 loops (j, k, m), einsum handles (i,l)")
print("  Max difference vs ref:", (stepE3 - ref).abs().max().item())
print()


# Step E4: Add loops over i, j, k, m, einsum only does the final scalar dot over l
#   For fixed (i, j, k, m):
#     result[i, j, k, m] = sum_l X[i,j,k,l] * Y[m,k,l]
#   Einsum: 'l,l->'
stepE4 = torch.zeros_like(ref)

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        for k in range(X.shape[2]):
            for m in range(Y.shape[0]):
                stepE4[i, j, k, m] = torch.einsum('l,l->', X[i, j, k, :], Y[m, k, :])

print("Step E4: 4 loops (i, j, k, m), einsum only does 1D dot 'l,l->'")
print("  Max difference vs ref:", (stepE4 - ref).abs().max().item())
print()

# If we went one step further (5 loops), we'd also manually loop over l and
# sum up X[i,j,k,l] * Y[m,k,l], which would replicate what einsum is doing
# internally in 'l,l->'. So we stop here.

X shape: torch.Size([2, 3, 4, 5])
Y shape: torch.Size([1, 4, 5])

Step 0: Full 5-loop implementation (i, j, k, m, l)
  Result shape: torch.Size([2, 3, 4, 1])
tensor([[[[  30.],
          [ 255.],
          [ 730.],
          [1455.]],

         [[ 230.],
          [ 955.],
          [1930.],
          [3155.]],

         [[ 430.],
          [1655.],
          [3130.],
          [4855.]]],


        [[[ 630.],
          [2355.],
          [4330.],
          [6555.]],

         [[ 830.],
          [3055.],
          [5530.],
          [8255.]],

         [[1030.],
          [3755.],
          [6730.],
          [9955.]]]])

Step 1: 4 loops (i, j, k, m), l handled by .sum over the last dim
  Max difference vs Step 0: 0.0

Step 2: 3 loops (i, j, k), m absorbed by using Y[0, k, :]
  Max difference vs Step 0: 0.0

Step 3: 2 loops (i, j), k and l handled by tensor ops
  Max difference vs Step 0: 0.0

Step 4: 1 loop (i), j,k,l handled by tensor ops
  Max difference vs Step 0: 0.0

Step 5: 0 lo

## The Einsum Rules

Now we can see the pattern:

| Rule | What happens |
|------|--------------|
| Index in output | Keep that dimension |
| Index NOT in output | Sum over it |
| Same index in multiple tensors | Those dimensions are "aligned" (multiplied element-wise) |
| Different indices | Creates all combinations (like nested loops) |

**Einsum = "for each combination of indices, multiply and optionally sum"**
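
As a quick check of the rules in the table, here is a minimal sketch on two small vectors. The names `a` and `b` are illustrative and not from the cells above; the same two inputs give three different results depending only on which indices are shared and which appear in the output.

```python
import torch
from torch import einsum

a = torch.tensor([1., 2., 3.])
b = torch.tensor([10., 20., 30.])

# Different indices, both kept: every (i, j) combination -> outer product
outer = einsum('i,j->ij', a, b)

# Same index, kept in the output: aligned and multiplied element-wise
hadamard = einsum('i,i->i', a, b)

# Same index, missing from the output: multiplied, then summed -> dot product
dot = einsum('i,i->', a, b)

print(outer.shape)   # torch.Size([3, 3])
print(hadamard)      # tensor([10., 40., 90.])
print(dot)           # tensor(140.)
```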


## Batched Operations

In deep learning, we often have batches. Einsum handles this naturally.

In [None]:
# Batch of 2 matrices, each 2x3
batch_A = torch.tensor([
    [[1, 2, 3],    # batch 0
     [4, 5, 6]],
    [[7, 8, 9],    # batch 1
     [10, 11, 12]]
], dtype=torch.float32)

# Batch of 2 matrices, each 3x2
batch_B = torch.tensor([
    [[1, 0],       # batch 0
     [0, 1],
     [1, 1]],
    [[1, 1],       # batch 1
     [1, 0],
     [0, 1]]
], dtype=torch.float32)

print(f"batch_A shape: {batch_A.shape}  (batch, rows, cols)")
print(f"batch_B shape: {batch_B.shape}  (batch, rows, cols)")


In [None]:
# Batch matrix multiplication: for each batch, multiply corresponding matrices
# C[b, i, k] = Σⱼ A[b, i, j] × B[b, j, k]

# Loop version - notice how many loops!
B_size, I, J = batch_A.shape
_, J2, K = batch_B.shape

result_loop = torch.zeros(B_size, I, K)
for b in range(B_size):
    for i in range(I):
        for k in range(K):
            for j in range(J):
                result_loop[b, i, k] += batch_A[b, i, j] * batch_B[b, j, k]

print("Batch matmul with loops:")
print(result_loop)


In [None]:
# With einsum: just add 'b' to the front!
result_einsum = einsum('bij,bjk->bik', batch_A, batch_B)

print("Batch matmul with einsum('bij,bjk->bik'):")
print(result_einsum)
print(f"\nMatch: {torch.allclose(result_loop, result_einsum)}")

# Compare with torch.bmm
print(f"\ntorch.bmm match: {torch.allclose(torch.bmm(batch_A, batch_B), result_einsum)}")
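
If the number of leading batch dimensions varies, `torch.einsum` also accepts an ellipsis. The sketch below assumes hypothetical shapes with two leading dimensions (say batch and heads); `...` stands for whatever leading dimensions are present.

```python
import torch
from torch import einsum

# Hypothetical shapes: two leading batch-like dims, then a 4x5 and a 5x6 matrix
A = torch.randn(2, 3, 4, 5)
B = torch.randn(2, 3, 5, 6)

# '...' covers all leading dimensions, which are carried through to the output
C = einsum('...ij,...jk->...ik', A, B)

print(C.shape)  # torch.Size([2, 3, 4, 6])
print(torch.allclose(C, A @ B, atol=1e-5))
```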


## Real Example: Attention Scores

In transformers, we compute attention scores: `Q @ K.T` for queries and keys.

Let's see how einsum makes this clearer.


In [None]:
# Attention setup
# batch=2, seq_len=3, d_model=4
batch, seq_len, d_model = 2, 3, 4

torch.manual_seed(42)
Q = torch.randn(batch, seq_len, d_model)  # queries
K = torch.randn(batch, seq_len, d_model)  # keys

print(f"Q shape: {Q.shape}  (batch, seq_len, d_model)")
print(f"K shape: {K.shape}  (batch, seq_len, d_model)")


In [None]:
# Attention scores: for each query position, dot product with each key position
# scores[b, i, j] = Σ_d Q[b, i, d] * K[b, j, d]

# Standard PyTorch way - need to transpose K
scores_torch = torch.bmm(Q, K.transpose(1, 2))
print(f"scores shape: {scores_torch.shape}  (batch, query_pos, key_pos)")

# Einsum way - much clearer what's happening!
# b=batch, i=query position, j=key position, d=dimension (summed over)
scores_einsum = einsum('bid,bjd->bij', Q, K)

print(f"\nUsing torch.bmm(Q, K.transpose(1,2)):")
print(scores_torch[0])  # first batch

print(f"\nUsing einsum('bid,bjd->bij', Q, K):")
print(scores_einsum[0])

print(f"\nMatch: {torch.allclose(scores_torch, scores_einsum)}")
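
To connect the einsum string back to the formula `scores[b, i, j] = Σ_d Q[b, i, d] * K[b, j, d]`, the following sketch writes it as explicit loops and compares. It regenerates `Q` and `K` with its own seed so it runs on its own; the values therefore differ from the cell above.

```python
import torch
from torch import einsum

torch.manual_seed(0)
batch, seq_len, d_model = 2, 3, 4
Q = torch.randn(batch, seq_len, d_model)
K = torch.randn(batch, seq_len, d_model)

# One loop per index in 'bid,bjd->bij'; d is absent from the output, so it is summed
scores_loop = torch.zeros(batch, seq_len, seq_len)
for b in range(batch):
    for i in range(seq_len):          # query position
        for j in range(seq_len):      # key position
            for d in range(d_model):  # summed dimension
                scores_loop[b, i, j] += Q[b, i, d] * K[b, j, d]

scores_einsum = einsum('bid,bjd->bij', Q, K)
print(torch.allclose(scores_loop, scores_einsum, atol=1e-5))
```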


## The Next Problem: Reshape and Permute Are Confusing

Einsum handles computation. But what about just rearranging tensor shapes?

PyTorch has `reshape`, `permute`, `view`, `transpose`... and they're often combined in confusing ways.


In [None]:
# Example: we have an image batch in NHWC format (batch, height, width, channels)
# and need to convert to NCHW format (the default layout PyTorch conv layers expect)

# Fake image batch: 2 images, 4x4 pixels, 3 channels
images_nhwc = torch.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3).float()
print(f"Original shape (NHWC): {images_nhwc.shape}")
print(f"  N={images_nhwc.shape[0]}, H={images_nhwc.shape[1]}, W={images_nhwc.shape[2]}, C={images_nhwc.shape[3]}")


In [None]:
# PyTorch way: permute with dimension indices
# NHWC -> NCHW means: dim 0 stays, dim 3 moves to 1, dim 1 moves to 2, dim 2 moves to 3
# So permute(0, 3, 1, 2)

images_nchw_torch = images_nhwc.permute(0, 3, 1, 2)
print(f"After permute(0, 3, 1, 2): {images_nchw_torch.shape}")

# Quick, what does permute(0, 2, 1, 3) do? 
# Hard to tell without thinking carefully about indices!


In [None]:
# einops way: use NAMES, not numbers!
images_nchw_einops = rearrange(images_nhwc, 'n h w c -> n c h w')
print(f"After rearrange('n h w c -> n c h w'): {images_nchw_einops.shape}")

print(f"\nMatch: {torch.allclose(images_nchw_torch, images_nchw_einops)}")
print("\n→ The einops version is SELF-DOCUMENTING!")


## Einops Power: Splitting and Merging Dimensions

Einops can also split one dimension into multiple, or merge multiple into one. This is incredibly useful for multi-head attention.


In [None]:
# Merging dimensions: flatten image to vector
# (batch, height, width, channels) -> (batch, height*width*channels)

images = torch.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3).float()
print(f"Original: {images.shape}")

# PyTorch way - hard-code the batch size and let -1 infer the rest
flat_torch = images.reshape(2, -1)  # or images.view(2, 4*4*3)
print(f"PyTorch reshape: {flat_torch.shape}")

# Einops way - parentheses merge dimensions
flat_einops = rearrange(images, 'b h w c -> b (h w c)')
print(f"Einops 'b h w c -> b (h w c)': {flat_einops.shape}")

print(f"\nMatch: {torch.allclose(flat_torch, flat_einops)}")


In [None]:
# Splitting dimensions: for multi-head attention
# We have (batch, seq_len, d_model) and want (batch, heads, seq_len, d_head)
# where d_model = heads * d_head

batch, seq_len, d_model = 2, 5, 12
heads = 3  # so d_head = 12/3 = 4

x = torch.randn(batch, seq_len, d_model)
print(f"Original: {x.shape}  (batch, seq_len, d_model={d_model})")

# PyTorch way - reshape then transpose
x_torch = x.view(batch, seq_len, heads, d_model // heads).transpose(1, 2)
print(f"PyTorch view+transpose: {x_torch.shape}")

# Einops way - split d_model into (heads, d_head) using parentheses
x_einops = rearrange(x, 'b s (h d) -> b h s d', h=heads)
print(f"Einops 'b s (h d) -> b h s d': {x_einops.shape}")

print(f"\nMatch: {torch.allclose(x_torch, x_einops)}")


In [None]:
# And back: merge heads back together after attention
# (batch, heads, seq_len, d_head) -> (batch, seq_len, d_model)

# PyTorch way
back_torch = x_torch.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(f"PyTorch transpose+view: {back_torch.shape}")

# Einops way - just reverse the pattern!
back_einops = rearrange(x_einops, 'b h s d -> b s (h d)')
print(f"Einops 'b h s d -> b s (h d)': {back_einops.shape}")

print(f"\nMatch: {torch.allclose(back_torch, back_einops)}")


## Practical Example: Multi-Head Attention in One Cell

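Let's put einsum and einops together for the core of multi-head attention: split into heads, score, softmax, apply to values, and merge. The learned Q/K/V and output projections are left out to keep the focus on the tensor operations.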


In [None]:
def multihead_attention(q, k, v, heads):
    """
    Multi-head attention using einsum and einops.
    
    q, k, v: (batch, seq_len, d_model)
    Returns: (batch, seq_len, d_model)
    """
    # 1. Split into heads: (batch, seq, d_model) -> (batch, heads, seq, d_head)
    q = rearrange(q, 'b s (h d) -> b h s d', h=heads)
    k = rearrange(k, 'b s (h d) -> b h s d', h=heads)
    v = rearrange(v, 'b s (h d) -> b h s d', h=heads)
    
    # 2. Attention scores: Q @ K.T for each head
    # (batch, heads, seq_q, d) @ (batch, heads, seq_k, d) -> (batch, heads, seq_q, seq_k)
    scores = einsum('bhqd,bhkd->bhqk', q, k)
    
    # 3. Scale and softmax
    d_head = q.shape[-1]
    scores = scores / (d_head ** 0.5)
    attn_weights = torch.softmax(scores, dim=-1)
    
    # 4. Apply attention to values
    # (batch, heads, seq_q, seq_k) @ (batch, heads, seq_k, d) -> (batch, heads, seq_q, d)
    out = einsum('bhqk,bhkd->bhqd', attn_weights, v)
    
    # 5. Merge heads back: (batch, heads, seq, d_head) -> (batch, seq, d_model)
    out = rearrange(out, 'b h s d -> b s (h d)')
    
    return out


In [None]:
# Test it!
batch, seq_len, d_model, heads = 2, 4, 12, 3

torch.manual_seed(42)
q = torch.randn(batch, seq_len, d_model)
k = torch.randn(batch, seq_len, d_model)
v = torch.randn(batch, seq_len, d_model)

out = multihead_attention(q, k, v, heads)
print(f"Input shape:  {q.shape}")
print(f"Output shape: {out.shape}")
print(f"\nOutput[0, 0, :6]: {out[0, 0, :6].tolist()}")


## Summary: The Einsum/Einops Mental Model

### Einsum
- Labels dimensions with letters
- Multiply elements where indices align
- Sum over indices that don't appear in output
- **"ij,jk->ik"** = "sum over j, keep i and k"

### Einops Rearrange  
- Names dimensions explicitly
- Parentheses on the output side merge: `(h w)` combines height and width
- Parentheses on the input side split: `(h d)` splits d_model into heads and d_head (given `h=...`)
- **"b h w c -> b c h w"** = "move channels before height/width"

### When to use what:
| Task | Tool |
|------|------|
| Matrix multiply, dot products | einsum |
| Batch operations | einsum |
| Reshape, transpose | einops rearrange |
| Split/merge dimensions | einops rearrange |
| Self-documenting shapes | both! |


## Quick Reference: Common Patterns


In [None]:
# Common einsum patterns
A = torch.randn(3, 4)
B = torch.randn(4, 5)
v = torch.randn(4)

print("EINSUM PATTERNS:")
print(f"Sum all:          einsum('ij->')        = {einsum('ij->', A).shape}")
print(f"Sum over rows:    einsum('ij->j')       = {einsum('ij->j', A).shape}")
print(f"Sum over cols:    einsum('ij->i')       = {einsum('ij->i', A).shape}")
print(f"Transpose:        einsum('ij->ji')      = {einsum('ij->ji', A).shape}")
print(f"Matrix multiply:  einsum('ij,jk->ik')   = {einsum('ij,jk->ik', A, B).shape}")
print(f"Dot product:      einsum('i,i->')       = {einsum('i,i->', v, v).shape}")
print(f"Outer product:    einsum('i,j->ij')     = {einsum('i,j->ij', v, v).shape}")
print(f"Hadamard:         einsum('ij,ij->ij')   = {einsum('ij,ij->ij', A, A).shape}")
print(f"Trace:            einsum('ii->')        = {einsum('ii->', torch.randn(3,3)).shape}")
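
The cell above only prints shapes. As a sketch of a stronger check, each pattern can be compared against the PyTorch built-in it is meant to reproduce. The `checks` dictionary and the matrix `M` are illustrative names introduced here.

```python
import torch
from torch import einsum

A = torch.randn(3, 4)
B = torch.randn(4, 5)
v = torch.randn(4)
M = torch.randn(3, 3)

# Each einsum pattern paired with the built-in it should match
checks = {
    'ij->':      torch.allclose(einsum('ij->', A), A.sum(), atol=1e-5),
    'ij->j':     torch.allclose(einsum('ij->j', A), A.sum(dim=0), atol=1e-5),
    'ij->i':     torch.allclose(einsum('ij->i', A), A.sum(dim=1), atol=1e-5),
    'ij->ji':    torch.equal(einsum('ij->ji', A), A.T),
    'ij,jk->ik': torch.allclose(einsum('ij,jk->ik', A, B), A @ B, atol=1e-5),
    'i,i->':     torch.allclose(einsum('i,i->', v, v), torch.dot(v, v), atol=1e-5),
    'i,j->ij':   torch.allclose(einsum('i,j->ij', v, v), torch.outer(v, v), atol=1e-5),
    'ij,ij->ij': torch.allclose(einsum('ij,ij->ij', A, A), A * A, atol=1e-5),
    'ii->':      torch.allclose(einsum('ii->', M), torch.trace(M), atol=1e-5),
}

for pattern, ok in checks.items():
    print(f"{pattern:<12} matches built-in: {ok}")
```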


In [None]:
# Common einops patterns
x = torch.randn(2, 3, 4, 5)  # (batch, channels, height, width)

print("\nEINOPS REARRANGE PATTERNS:")
print(f"Original:         {x.shape}")
print(f"Transpose HW:     'b c h w -> b c w h'     = {rearrange(x, 'b c h w -> b c w h').shape}")
print(f"NCHW to NHWC:     'b c h w -> b h w c'     = {rearrange(x, 'b c h w -> b h w c').shape}")
print(f"Flatten spatial:  'b c h w -> b c (h w)'   = {rearrange(x, 'b c h w -> b c (h w)').shape}")
print(f"Flatten all:      'b c h w -> b (c h w)'   = {rearrange(x, 'b c h w -> b (c h w)').shape}")
print(f"Split channels:   'b (g c) h w -> b g c h w' = {rearrange(x, 'b (g c) h w -> b g c h w', g=1).shape}")
print(f"Merge batch+chan: 'b c h w -> (b c) h w'   = {rearrange(x, 'b c h w -> (b c) h w').shape}")


## Bonus: Image Patching for Vision Transformers

ViT needs to split images into patches. This is a perfect use case for einops!


In [None]:
# Split image into patches
# Input: (batch, channels, height, width)
# Output: (batch, num_patches, patch_size * patch_size * channels)

batch, channels, height, width = 2, 3, 8, 8
patch_size = 4

images = torch.randn(batch, channels, height, width)
print(f"Input images: {images.shape}")

# Split into patches and flatten each patch
patches = rearrange(
    images, 
    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 
    p1=patch_size, 
    p2=patch_size
)

num_patches = (height // patch_size) * (width // patch_size)
patch_dim = patch_size * patch_size * channels

print(f"Patches: {patches.shape}")
print(f"  num_patches = ({height}//{patch_size}) × ({width}//{patch_size}) = {num_patches}")
print(f"  patch_dim = {patch_size} × {patch_size} × {channels} = {patch_dim}")


In [None]:
# Visualize the patching
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 5, figsize=(15, 3))

# Create a simple test image with a pattern
test_img = torch.zeros(1, 1, 8, 8)
test_img[0, 0, :4, :4] = 1  # top-left = white
test_img[0, 0, :4, 4:] = 0.5  # top-right = gray
test_img[0, 0, 4:, :4] = 0.3  # bottom-left = dark gray
test_img[0, 0, 4:, 4:] = 0.7  # bottom-right = light gray

# Show original
axes[0].imshow(test_img[0, 0], cmap='gray', vmin=0, vmax=1)
axes[0].set_title('Original 8×8')
axes[0].axis('off')

# Extract patches
test_patches = rearrange(test_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)

# Show each patch
for i in range(4):
    patch = test_patches[0, i].reshape(4, 4)
    axes[i+1].imshow(patch, cmap='gray', vmin=0, vmax=1)
    axes[i+1].set_title(f'Patch {i}')
    axes[i+1].axis('off')

plt.suptitle("Image → Patches (for Vision Transformer)", fontsize=14)
plt.tight_layout()
plt.show()


## Key Takeaways

1. **Einsum** expresses tensor operations by labeling dimensions
   - Same index = align/multiply element-wise
   - Missing from output = sum over it
   - Makes complex operations like attention readable

2. **Einops rearrange** expresses shape transformations with named dimensions
   - No more cryptic `permute(0, 3, 1, 2)`
   - Parentheses split/merge dimensions
   - Self-documenting code

3. **Both tools** make your code:
   - Easier to read
   - Easier to debug (shapes are explicit)
   - Less error-prone (no off-by-one dimension mistakes)

**Practice**: Next time you write a `transpose`, `reshape`, or `bmm`, try expressing it with einsum/einops first!

In [None]:
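import einops
import torch as t
import torch.nn as nn
from jaxtyping import Float, Int  # assumed source of the Int/Float shape annotations
from torch import Tensor

# Config is assumed to be defined elsewhere with n_ctx, d_model, and init_range fields


class PosEmbed(nn.Module):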
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Create a learnable position embedding matrix of shape (n_ctx, d_model)
        # Each row corresponds to a position (up to max sequence length), each column to a dimension in the embedding
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        # Initialize the embedding weights with a normal distribution (mean=0, std=cfg.init_range)
        # This initialization helps with model stability at the start
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(
        self, tokens: Int[Tensor, "batch position"]
    ) -> Float[Tensor, "batch position d_model"]:
        # tokens: input tensor of token indices, shape (batch, position)
        # Goal: provide a position embedding per position per batch for further processing
        batch, seq_len = tokens.shape  # Unpack number of examples and sequence length
        # Select position embeddings for the required sequence length ([:seq_len] over positions)
        # Shape at this point: (seq_len, d_model)
        # Use einops.repeat to copy these embeddings along a new batch dimension, so every sequence gets the same position embeddings
        # Output shape: (batch, seq_len, d_model)
        return einops.repeat(
            self.W_pos[:seq_len],           # only positions up to sequence length
            "seq d_model -> batch seq d_model",  # instruct einops on reshaping: broadcast across batch
            batch=batch
        )

# Advanced: Let's build up the transformer as a bunch of einops operations

We'll use a batch size of 2, sequence length of 3, and embedding dimension of 4. We'll use a hidden dimension of 8 for the MLP. 

This follows the original transformer design, with a few later tricks mixed in:
- Self attention (encoder only)
- Multi head attention
- RoPE
- LayerNorm before attention
- KV cache to save on compute
- MLP with GELU activation
- Final layer norm and linear output

The goal here is not to focus on the details of the transformer, but rather to show how we could use only einops to create one.

1. Embeddings
2. RoPE
3. Layernorm
4. Attention (KV Cache, dot product, softmax, mask, keys, sum, add to residual)
5. MLP
6. Output
7. Multiple heads