# Task 2.4: Einsum Mastery - SOLUTIONS

**Module:** 2 - Python for AI/ML  

This notebook contains solutions to all exercises from the Einsum Mastery Lab.

---

In [None]:
import numpy as np

print("Solutions Notebook for Einsum Mastery")
print("=" * 50)

---

## Exercise 1: Batch Outer Product

**Task:** Compute the outer product for each pair of vectors in a batch.
- Input: Two tensors of shape `(batch, dim)`
- Output: Shape `(batch, dim, dim)`

In [None]:
# SOLUTION - Exercise 1
np.random.seed(42)
a = np.random.randn(16, 64)  # 16 vectors of dim 64
b = np.random.randn(16, 64)  # 16 vectors of dim 64

# Solution: 'bi,bj->bij'
# - a has indices (b, i)
# - b has indices (b, j)
# - No index is summed over (b, i and j all appear in the output)
# - Result has indices (b, i, j)
batch_outer = np.einsum('bi,bj->bij', a, b)

print(f"a shape: {a.shape}")
print(f"b shape: {b.shape}")
print(f"Result shape: {batch_outer.shape}")

# Verify against loop version
expected = np.array([np.outer(a[i], b[i]) for i in range(16)])
print(f"\nCorrect? {np.allclose(batch_outer, expected)}")

### Understanding the Solution

```
'bi,bj->bij'

a[b, i]  ×  b[b, j]  =  result[b, i, j]

For each batch element b:
  For each i and j:
    result[b, i, j] = a[b, i] * b[b, j]
```

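This is exactly the outer product definition, applied independently to each batch element! (In the subscript string, `b` is the batch index; it is unrelated to the array that happens to be named `b`.)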
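As a cross-check, the same batch outer product can be written with plain NumPy broadcasting: insert a length-1 axis into each input so that shapes `(batch, dim, 1)` and `(batch, 1, dim)` broadcast to `(batch, dim, dim)`. This is a minimal sketch; the array names `u` and `v` are illustrative and were chosen to avoid the `b` name clash.

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.standard_normal((16, 64))  # batch of 16 vectors, dim 64
v = rng.standard_normal((16, 64))

# einsum version: no index is summed, so every (b, i, j) combination is kept
outer_einsum = np.einsum('bi,bj->bij', u, v)

# Broadcasting version: (16, 64, 1) * (16, 1, 64) -> (16, 64, 64)
outer_broadcast = u[:, :, np.newaxis] * v[:, np.newaxis, :]

print(outer_einsum.shape)                          # (16, 64, 64)
print(np.allclose(outer_einsum, outer_broadcast))  # True
```

The einsum form states the index relationship directly, while the broadcasting form makes you place the new axes by hand; both produce the same array.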

---

## Exercise 2: Bilinear Form

**Task:** Compute $x^T A y$ for batches of vectors and a shared matrix.
- x: shape `(batch, m)`
- A: shape `(m, n)` (shared across batch)
- y: shape `(batch, n)`
- Result: shape `(batch,)`

In [None]:
# SOLUTION - Exercise 2
np.random.seed(42)
x = np.random.randn(32, 64)   # (batch=32, m=64)
A = np.random.randn(64, 128)  # (m=64, n=128)
y = np.random.randn(32, 128)  # (batch=32, n=128)

# Solution: 'bm,mn,bn->b'
# - x has indices (b, m)
# - A has indices (m, n)
# - y has indices (b, n)
# - m and n are summed over (don't appear in output)
# - b is preserved
bilinear = np.einsum('bm,mn,bn->b', x, A, y)

print(f"x shape: {x.shape}")
print(f"A shape: {A.shape}")
print(f"y shape: {y.shape}")
print(f"Result shape: {bilinear.shape}")

# Verify against loop version
expected = np.array([x[i] @ A @ y[i] for i in range(32)])
print(f"\nCorrect? {np.allclose(bilinear, expected)}")

# Sample values
print(f"\nFirst 5 values: {bilinear[:5].round(2)}")

### Understanding the Solution

The bilinear form $x^T A y$ expands to:
$$\sum_m \sum_n x_m \cdot A_{mn} \cdot y_n$$

In einsum notation: `'bm,mn,bn->b'`

```
For each batch b:
  result[b] = sum over m and n of: x[b,m] * A[m,n] * y[b,n]
```
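To see that the einsum string really computes the double sum, here is a tiny hand-checkable case with made-up numbers. For the first batch element, $x^T A = [7, 10]$ and $[7, 10] \cdot [5, 6] = 95$; for the second, $x^T A = [3, 4]$ and $[3, 4] \cdot [1, 0] = 3$.

```python
import numpy as np

# Tiny example: batch=2, m=2, n=2 (values chosen only for easy hand checking)
x = np.array([[1.0, 2.0],
              [0.0, 1.0]])   # (b, m)
A = np.array([[1.0, 2.0],
              [3.0, 4.0]])   # (m, n)
y = np.array([[5.0, 6.0],
              [1.0, 0.0]])   # (b, n)

# Explicit loops that mirror 'bm,mn,bn->b': b is kept, m and n are summed
loop_result = np.zeros(x.shape[0])
for b in range(x.shape[0]):
    for m in range(x.shape[1]):
        for n in range(y.shape[1]):
            loop_result[b] += x[b, m] * A[m, n] * y[b, n]

einsum_result = np.einsum('bm,mn,bn->b', x, A, y)
print(loop_result)    # expected values: 95.0 and 3.0
print(einsum_result)  # expected values: 95.0 and 3.0
```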

### Alternative: Two-Step Computation
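The two-step route works because the double sum can be regrouped so that the inner sum over $n$ is done first (writing the batch index explicitly):

$$x_b^T A\, y_b = \sum_m x_{bm} \left( \sum_n A_{mn}\, y_{bn} \right) = \sum_m x_{bm}\, (A y_b)_m$$

The parenthesized term is what `'mn,bn->bm'` computes (one vector $A y_b$ per batch element), and the remaining sum over $m$ is `'bm,bm->b'`.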

In [None]:
# Alternative: Break into two steps

# Step 1: Compute A @ y[b] for each batch element: (m, n) @ (n,) -> (m,), stacked to (b, m)
# Using einsum: 'mn,bn->bm' (equivalent to y @ A.T)
Ay = np.einsum('mn,bn->bm', A, y)
print(f"Ay shape: {Ay.shape}")

# Step 2: Compute x · (Ay) for each batch element
# Using einsum: 'bm,bm->b'
bilinear_v2 = np.einsum('bm,bm->b', x, Ay)

print(f"Same result? {np.allclose(bilinear, bilinear_v2)}")
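The two-step version is a hand-picked contraction order. NumPy can report and choose such an order itself through `np.einsum_path` and the `optimize` argument of `np.einsum`. A minimal sketch; the exact path and FLOP report depend on the shapes and the NumPy version, so no specific output is shown for them:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 64))    # (b, m)
A = rng.standard_normal((64, 128))   # (m, n)
y = rng.standard_normal((32, 128))   # (b, n)

# Ask NumPy which pairwise contraction order it would use
path, report = np.einsum_path('bm,mn,bn->b', x, A, y, optimize='greedy')
print(path)    # a list starting with 'einsum_path', followed by operand index pairs
print(report)  # human-readable cost estimates for naive vs optimized order

# optimize=True lets np.einsum pick a pairwise order automatically
fast = np.einsum('bm,mn,bn->b', x, A, y, optimize=True)
plain = np.einsum('bm,mn,bn->b', x, A, y)
print(np.allclose(fast, plain))  # True
```

Results can differ in the last few bits because the summation order changes, which is why the comparison uses `np.allclose`.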

---

## Bonus: Multi-Head Attention with Einsum

In [None]:
# Complete multi-head attention implementation
def softmax(x, axis=-1):
    x_max = x.max(axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, mask=None):
    """
    Multi-head scaled dot-product attention using einsum.
    
    Args:
        Q: Query tensor (batch, heads, seq_q, dim)
        K: Key tensor (batch, heads, seq_k, dim)
        V: Value tensor (batch, heads, seq_k, dim)
        mask: Optional attention mask broadcastable to (batch, heads, seq_q, seq_k);
              1 = attend, 0 = masked out
    
    Returns:
        output: (batch, heads, seq_q, dim)
        attention_weights: (batch, heads, seq_q, seq_k)
    """
    d_k = Q.shape[-1]
    
    # Compute attention scores
    # Q: (b, h, sq, d) @ K^T: (b, h, d, sk) -> (b, h, sq, sk)
    scores = np.einsum('bhqd,bhkd->bhqk', Q, K) / np.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)
    
    # Softmax
    weights = softmax(scores, axis=-1)
    
    # Apply to values
    # weights: (b, h, sq, sk) @ V: (b, h, sk, d) -> (b, h, sq, d)
    output = np.einsum('bhqk,bhkd->bhqd', weights, V)
    
    return output, weights

# Test
np.random.seed(42)
batch, heads, seq_len, dim = 2, 4, 8, 16

Q = np.random.randn(batch, heads, seq_len, dim).astype(np.float32)
K = np.random.randn(batch, heads, seq_len, dim).astype(np.float32)
V = np.random.randn(batch, heads, seq_len, dim).astype(np.float32)

output, weights = multi_head_attention(Q, K, V)

print("Multi-Head Attention Results:")
print(f"  Q shape: {Q.shape}")
print(f"  K shape: {K.shape}")
print(f"  V shape: {V.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Weights shape: {weights.shape}")
print(f"  Weights sum to 1? {np.allclose(weights.sum(axis=-1), 1.0)}")
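A quick way to gain confidence in the einsum strings is to compare them with the `@` operator, which treats all leading axes as batch dimensions. This sketch re-declares `softmax` so it runs on its own; the sizes are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    x_max = x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    exp_x = np.exp(x - x_max)
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
b, h, s, d = 2, 4, 8, 16
Q = rng.standard_normal((b, h, s, d))
K = rng.standard_normal((b, h, s, d))
V = rng.standard_normal((b, h, s, d))

# einsum formulation (same strings as multi_head_attention above)
scores_e = np.einsum('bhqd,bhkd->bhqk', Q, K) / np.sqrt(d)
out_e = np.einsum('bhqk,bhkd->bhqd', softmax(scores_e), V)

# matmul formulation: @ batches over the leading (b, h) axes
scores_m = Q @ K.swapaxes(-1, -2) / np.sqrt(d)
out_m = softmax(scores_m) @ V

print(np.allclose(scores_e, scores_m), np.allclose(out_e, out_m))  # True True
```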

In [None]:
# With causal mask (for autoregressive models)
causal_mask = np.triu(np.ones((seq_len, seq_len)), k=1)
causal_mask = 1 - causal_mask  # 1 = attend, 0 = mask
causal_mask = causal_mask[np.newaxis, np.newaxis, :, :]  # Add batch and head dims

print("Causal mask row for position 4 (the 5th token):")
print(causal_mask[0, 0, 4, :])  # Position 4 can attend to positions 0-4

output_causal, weights_causal = multi_head_attention(Q, K, V, mask=causal_mask)

print("\nCausal attention weights for position 4:")
print(f"  {weights_causal[0, 0, 4, :].round(3)}")
print("  (positions 5-7 should be ~0)")
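The same causal mask can be built directly with `np.tril`, and a toy example shows why masked positions end up with zero weight: `exp(-1e9 - max)` underflows to 0. A minimal sketch using all-zero scores, so that the allowed positions share the weight equally:

```python
import numpy as np

seq_len = 8

# Mask built as in the notebook: zero out the strict upper triangle
mask_triu = 1 - np.triu(np.ones((seq_len, seq_len)), k=1)

# Equivalent one-liner: keep the lower triangle including the diagonal
mask_tril = np.tril(np.ones((seq_len, seq_len)))
print(np.array_equal(mask_triu, mask_tril))  # True

# Row i allows positions 0..i, so row 4 has five ones followed by three zeros
print(mask_tril[4])

# Toy scores: all zeros, then -1e9 in masked slots, then a stable softmax
scores = np.where(mask_tril == 0, -1e9, np.zeros((seq_len, seq_len)))
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights[4].round(3))  # 0.2 for positions 0-4, 0 for positions 5-7
```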

---

## Complete Einsum Reference

### Single Array Operations

| Operation | Einsum | NumPy |
|-----------|--------|-------|
| Sum all | `'ij->'` | `np.sum(A)` |
| Row sums | `'ij->i'` | `np.sum(A, axis=1)` |
| Column sums | `'ij->j'` | `np.sum(A, axis=0)` |
| Transpose | `'ij->ji'` | `A.T` |
| Diagonal | `'ii->i'` | `np.diag(A)` |
| Trace | `'ii->'` | `np.trace(A)` |
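Each row of the table can be checked mechanically. A minimal sketch on an arbitrary 3x3 matrix (square, so the diagonal and trace rows apply):

```python
import numpy as np

A = np.arange(9.0).reshape(3, 3)  # small square matrix

checks = {
    'sum all':     (np.einsum('ij->', A),   np.sum(A)),
    'row sums':    (np.einsum('ij->i', A),  np.sum(A, axis=1)),
    'column sums': (np.einsum('ij->j', A),  np.sum(A, axis=0)),
    'transpose':   (np.einsum('ij->ji', A), A.T),
    'diagonal':    (np.einsum('ii->i', A),  np.diag(A)),
    'trace':       (np.einsum('ii->', A),   np.trace(A)),
}
for name, (via_einsum, via_numpy) in checks.items():
    print(f"{name:12s} {np.allclose(via_einsum, via_numpy)}")
```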

### Two Array Operations

| Operation | Einsum | NumPy |
|-----------|--------|-------|
| Dot product | `'i,i->'` | `np.dot(a, b)` |
| Outer product | `'i,j->ij'` | `np.outer(a, b)` |
| Matrix multiply | `'ik,kj->ij'` | `A @ B` |
| Matrix-vector | `'ij,j->i'` | `A @ v` |
| Element-wise | `'ij,ij->ij'` | `A * B` |
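The two-array identities can be verified the same way. In this sketch the element-wise row uses a separate matrix `C` with the same shape as `A`, since `A * B` requires matching (or broadcastable) shapes; all arrays are random and illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.standard_normal(4), rng.standard_normal(4)
A, B = rng.standard_normal((3, 4)), rng.standard_normal((4, 5))
C = rng.standard_normal((3, 4))  # same shape as A, for the element-wise row
v = rng.standard_normal(4)

assert np.allclose(np.einsum('i,i->', a, b), np.dot(a, b))        # dot product
assert np.allclose(np.einsum('i,j->ij', a, b), np.outer(a, b))    # outer product
assert np.allclose(np.einsum('ik,kj->ij', A, B), A @ B)           # matrix multiply
assert np.allclose(np.einsum('ij,j->i', A, v), A @ v)             # matrix-vector
assert np.allclose(np.einsum('ij,ij->ij', A, C), A * C)           # element-wise
print("All two-array identities hold")
```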

### Batch Operations

| Operation | Einsum |
|-----------|--------|
| Batch matmul | `'bik,bkj->bij'` |
| Batch outer | `'bi,bj->bij'` |
| Batch dot | `'bi,bi->b'` |
| Attention scores | `'bhsd,bhtd->bhst'` |
| Attention apply | `'bhst,bhtd->bhsd'` |
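Each batch row also has a plain-NumPy equivalent built from `@` (which batches over leading axes) or broadcasting. This sketch checks them on random arrays; using different query and key lengths (`s` and `t`) confirms the attention strings do not assume a square score matrix. The array `W` merely stands in for attention weights.

```python
import numpy as np

rng = np.random.default_rng(0)
bsz, h, s, t, d = 2, 3, 5, 6, 4

X, Y = rng.standard_normal((bsz, 3, 4)), rng.standard_normal((bsz, 4, 2))
u, w = rng.standard_normal((bsz, 7)), rng.standard_normal((bsz, 7))
Q = rng.standard_normal((bsz, h, s, d))
K = rng.standard_normal((bsz, h, t, d))
W = rng.standard_normal((bsz, h, s, t))   # stands in for attention weights
V = rng.standard_normal((bsz, h, t, d))

# Batch matmul: @ treats the leading axis as a batch axis
assert np.allclose(np.einsum('bik,bkj->bij', X, Y), X @ Y)
# Batch outer: broadcasting (bsz, 7, 1) * (bsz, 1, 7)
assert np.allclose(np.einsum('bi,bj->bij', u, w), u[:, :, None] * w[:, None, :])
# Batch dot: element-wise product, then sum over the feature axis
assert np.allclose(np.einsum('bi,bi->b', u, w), np.sum(u * w, axis=1))
# Attention scores: swap the last two axes of K instead of transposing by hand
assert np.allclose(np.einsum('bhsd,bhtd->bhst', Q, K), Q @ K.swapaxes(-1, -2))
# Attention apply: (s, t) @ (t, d) for every batch and head
assert np.allclose(np.einsum('bhst,bhtd->bhsd', W, V), W @ V)
print("All batch identities hold")
```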

---

**End of Solutions**