# Exercise 1: Tensor basics 
In this exercise you will learn the basics of tensor creation, manipulation, indexing, broadcasting, vectorization, einsum, and attention masking fundamentals. These basics are important for understanding any complex implementation later on, so make sure you understand them well.

**To complete this exercise fill in all TODOs in the functions below.** 

Make sure to check the output of your function and whether or not it fulfills the requirements outlined in the function definition. Do NOT change the function signature or name since we will be running checks on your functions during grading.

### Shape legend used in this notebook
- `B`: batch size
- `T`: sequence length / time
- `D`: feature dimension
- `H`: number of attention heads
- `Dh`: per-head feature dimension

### Debugging tip: what to print
When you get a shape error, print:
- `x.shape`, `x.dtype`, `x.device`
- `x.is_contiguous()` (important for `view`)
For masks also print:
- `mask.shape`, `mask.dtype`, `mask.sum()` and a small slice like `mask[0, :10]`
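As a minimal sketch, the checklist above can be wrapped in a small helper. The name `describe` and its output format are our own hypothetical choices, not part of PyTorch or of the grading code.

```python
import torch


def describe(name: str, t: torch.Tensor) -> None:
    """Hypothetical helper: print the debugging info listed above for one tensor."""
    print(
        f"{name}: shape={tuple(t.shape)} dtype={t.dtype} "
        f"device={t.device} contiguous={t.is_contiguous()}"
    )
    if t.dtype == torch.bool:
        # Extra info for masks: number of True entries and a small slice
        print(f"  true_count={int(t.sum())}")
        if t.dim() >= 2:
            print(f"  first row (up to 10): {t[0, :10]}")


x = torch.randn(2, 3, 4)
mask = torch.ones(2, 12, dtype=torch.bool)
describe("x", x)
describe("x.transpose(1, 2)", x.transpose(1, 2))  # reports contiguous=False
describe("mask", mask)
```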

### Reproducibility tip: seeding in PyTorch
Many operations in deep learning involve randomness (e.g., initializing model weights, shuffling data, dropout, random augmentations).
**Seeding** sets the starting state of PyTorch’s random number generator so that these random choices become **repeatable**.

- If you set the same seed and run the same code again, you should get the same *random* tensors / initial weights.
- If you don’t set a seed, results can vary between runs.

Common usage: `torch.manual_seed(seed)`

Note: even with fixed seeds, some GPU operations can still be non-deterministic due to performance optimizations. For this assignment, seeding is mainly to make debugging easier and to ensure everyone can reproduce the same intermediate results. If you are given a seed, make sure to use it when creating tensors or performing other operations.
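A short sketch of what seeding buys you, using the global `torch.manual_seed` and a local `torch.Generator` (the seed values are arbitrary):

```python
import torch

# Same global seed -> same "random" draws
torch.manual_seed(42)
a = torch.randn(3)
torch.manual_seed(42)
b = torch.randn(3)
print(torch.equal(a, b))  # True

# Different seed -> (almost surely) different draws
torch.manual_seed(43)
c = torch.randn(3)
print(torch.equal(a, c))  # False

# A local generator is reproducible without touching the global RNG state
g = torch.Generator().manual_seed(7)
d1 = torch.randn(3, generator=g)
g.manual_seed(7)
d2 = torch.randn(3, generator=g)
print(torch.equal(d1, d2))  # True
```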

## Tensor creation
This warmup exercise teaches you how to create tensors with different shapes and values. A few details about tensor creation that are good to know:
- `torch.tensor([...])` infers dtype from Python values (ints → integer tensor, floats → float tensor).
- `torch.arange(start, end)` is **end-exclusive**.
- `torch.linspace(start, end, steps)` is **end-inclusive**.
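The three rules above can be checked directly. This assumes the default float dtype is still `torch.float32` (i.e. `torch.set_default_dtype` has not been changed):

```python
import torch

print(torch.tensor([1, 2]).dtype)        # torch.int64: Python ints
print(torch.tensor([1.0, 2]).dtype)      # torch.float32: any float -> default float dtype
print(torch.arange(0, 5).tolist())       # [0, 1, 2, 3, 4]: end (5) is excluded
print(torch.linspace(0, 1, 5).tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]: end (1) is included
```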

In [2]:
from collections.abc import Sequence
import torch



In [25]:
def make_tensor(data, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """ Create a tensor from Python data (list/tuple/nested lists). """
    return torch.tensor(data, dtype=dtype, device=device)

x = make_tensor([[1, 2], [3, 4]], dtype=torch.float32)
x

tensor([[1., 2.],
        [3., 4.]])

In [26]:
def make_zeros(shape: Sequence[int], dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a tensor filled with zeros."""
    return torch.zeros(shape, dtype=dtype, device=device)

z = make_zeros((2, 3), dtype=torch.float64)
z

tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)

In [27]:
def make_ones_like(x: torch.Tensor) -> torch.Tensor:
    """Create a tensor of ones with the same shape, dtype, and device as x. """
    return torch.ones_like(x)

base = torch.randn(2, 3, dtype=torch.float32)
ones = make_ones_like(base)
ones

tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [28]:
def make_arange(start: int, end: int, step: int = 1, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a 1D tensor containing values [start, start+step, ..., < end]."""
    return torch.arange(start, end, step, dtype=dtype, device=device)

ar = make_arange(0, 5, 2, dtype=torch.int64)
ar

tensor([0, 2, 4])

In [29]:
def make_linspace(start: float, end: float, steps: int, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a 1D tensor with evenly spaced values from start to end (inclusive)."""
    return torch.linspace(start, end, steps, dtype=dtype, device=device)

ls = make_linspace(0.0, 1.0, steps=5, dtype=torch.float32)
ls

tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])

In [30]:
def make_randn(shape: Sequence[int], seed: int | None = None, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> torch.Tensor:
    """Create a tensor filled with values from a standard normal distribution."""
    # Use a dedicated generator only when a seed is given; manual_seed(None) would fail.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device if device is not None else "cpu").manual_seed(seed)
    return torch.randn(shape, generator=generator, dtype=dtype, device=device)

a = make_randn((2, 3), seed=123, dtype=torch.float32)
a

tensor([[-0.1115,  0.1204, -0.3696],
        [-0.2404, -1.1969,  0.2093]])

In [37]:
def cast_dtype_and_move(x: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Convert tensor dtype and move to device."""
    return x.to(device=device, dtype=dtype)

casted = cast_dtype_and_move(torch.tensor([1, 2, 3]), torch.device("cpu"), torch.float32)
print(f"Casted: {casted}, dtype: {casted.dtype}")

Casted: tensor([1., 2., 3.]), dtype: torch.float32


## Shape manipulation
Now that we have covered the basic tensor creation schemes, we turn to shape manipulation. Understanding the differences between these mechanisms is key for building larger systems, and they remain a common source of bugs.
The core ideas to understand are:
- **Contiguous tensors** store data in a single, row-major memory layout.
- Many ops (especially slicing like `x[:, ::2]`, `transpose`, `permute`) often create **non-contiguous** tensors (no copy but different strides).
- `view(...)` is **zero-copy**, but it only works when the new shape is compatible with the tensor's existing strides (in practice, usually a contiguous tensor); otherwise it raises a `RuntimeError`.
- `reshape(...)` returns a view when the strides allow it and otherwise **allocates and copies**.
- `contiguous()` forces a contiguous copy when the tensor isn’t contiguous.

If you *need* a view after reordering dims: call `x = x.contiguous()` first (this makes a contiguous copy).
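A sketch of the view/reshape/contiguous behavior on a transposed tensor, which shares storage with its source but is not contiguous:

```python
import torch

x = torch.arange(6).reshape(2, 3)    # contiguous, strides (3, 1)
xt = x.t()                           # shape (3, 2), shares storage, strides (1, 3)
print(xt.is_contiguous())            # False

view_failed = False
try:
    xt.view(6)                       # strides are incompatible with a flat view
except RuntimeError:
    view_failed = True
print(view_failed)                   # True

flat_reshape = xt.reshape(6)         # falls back to a copy
flat_view = xt.contiguous().view(6)  # explicit copy, then a zero-copy view of it
print(flat_reshape.tolist())         # [0, 3, 1, 4, 2, 5]
print(torch.equal(flat_reshape, flat_view))     # True
print(flat_reshape.data_ptr() == x.data_ptr())  # False: new memory
```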

In [41]:
def reshape_tensor(x: torch.Tensor, new_shape: Sequence[int]) -> torch.Tensor:
    """Reshape tensor to new_shape (may return a view or a copy)."""
    return torch.reshape(x, new_shape)

x = torch.arange(6)
y = reshape_tensor(x, (2, 3))
print(f"Before reshaping: {x}, shape: {x.shape}")
print(f"After reshaping: {y}, shape: {y.shape}")

Before reshaping: tensor([0, 1, 2, 3, 4, 5]), shape: torch.Size([6])
After reshaping: tensor([[0, 1, 2],
        [3, 4, 5]]), shape: torch.Size([2, 3])


In [43]:
def view_tensor(x: torch.Tensor, new_shape: Sequence[int]) -> torch.Tensor:
    """View tensor as new_shape (requires contiguous memory and doesn't allocate new memory for the tensor data)."""
    return x.view(new_shape)

y_view = view_tensor(x, (2, 3))
print(f"Before view: {x}, shape: {x.shape}")
print(f"After view: {y_view}, shape: {y_view.shape}")

Before view: tensor([0, 1, 2, 3, 4, 5]), shape: torch.Size([6])
After view: tensor([[0, 1, 2],
        [3, 4, 5]]), shape: torch.Size([2, 3])


In [45]:
def flatten_from_dim(x: torch.Tensor, start_dim: int = 0) -> torch.Tensor:
    """Flatten a tensor starting from start_dim into a single dimension."""
    return x.flatten(start_dim)

x2 = torch.randn(2, 3, 4)
flat = flatten_from_dim(x2, start_dim=1)
print(f"Before flattening: {x2.shape}")
print(f"After flattening: {flat.shape}")

Before flattening: torch.Size([2, 3, 4])
After flattening: torch.Size([2, 12])


In [46]:
def add_singleton_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Insert a size-1 dimension at position dim."""
    return x.unsqueeze(dim)

x3 = torch.randn(5, 7)
x3s = add_singleton_dim(x3, dim=1)
print(f"Before adding singleton dim: {x3.shape}")
print(f"After adding singleton dim at dim=1: {x3s.shape}")

Before adding singleton dim: torch.Size([5, 7])
After adding singleton dim at dim=1: torch.Size([5, 1, 7])


In [48]:
def remove_singleton_dims(x: torch.Tensor, dim: int | None = None) -> torch.Tensor:
    """Remove size-1 dimensions."""
    if dim is None:
        return x.squeeze()
    return x.squeeze(dim)

x4 = torch.randn(2, 1, 3)
x4s = remove_singleton_dims(x4, dim=1)
print(f"Before removing singleton dim: {x4.shape}")
print(f"After removing singleton dim at dim=1: {x4s.shape}")

Before removing singleton dim: torch.Size([2, 1, 3])
After removing singleton dim at dim=1: torch.Size([2, 3])


In [49]:
def transpose_last_two(x: torch.Tensor) -> torch.Tensor:
    """Swap the last two dimensions of x."""
    return x.transpose(-2, -1)

x6 = torch.randn(2, 3, 4)
x6t = transpose_last_two(x6)
print(f"Before transpose: {x6.shape}")
print(f"After transpose: {x6t.shape}")

Before transpose: torch.Size([2, 3, 4])
After transpose: torch.Size([2, 4, 3])


In [50]:
def permute_bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Convert (B, H, W, C) tensor into (B, C, H, W)."""
    return x.permute(0, 3, 1, 2)

x7 = torch.randn(8, 32, 32, 3)
x7p = permute_bhwc_to_bchw(x7)
print(f"Before permute: {x7.shape}")
print(f"After permute: {x7p.shape}")

Before permute: torch.Size([8, 32, 32, 3])
After permute: torch.Size([8, 3, 32, 32])


In [55]:
def make_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Check if tensor is contiguous and if not make contiguous."""
    if not x.is_contiguous():
        return x.contiguous()
    return x

x8 = torch.randn(4, 6)[:, ::2]
print(f"Before making contiguous - is_contiguous: {x8.is_contiguous()}")
print(f"Shape: {x8.shape}")
x8c = make_contiguous(x8)
print(f"After making contiguous - is_contiguous: {x8c.is_contiguous()}")
print(f"Shape: {x8c.shape}")

Before making contiguous - is_contiguous: False
Shape: torch.Size([4, 3])
After making contiguous - is_contiguous: True
Shape: torch.Size([4, 3])
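To see *why* `x[:, ::2]` is non-contiguous, you can inspect `stride()`: the number of elements to skip in memory per step along each dimension. A small sketch:

```python
import torch

x = torch.randn(4, 6)
print(x.stride())                         # (6, 1): row-major, contiguous

sliced = x[:, ::2]                        # every other column, no copy
print(sliced.shape, sliced.stride())      # torch.Size([4, 3]) (6, 2)
print(sliced.data_ptr() == x.data_ptr())  # True: same underlying storage

packed = sliced.contiguous()
print(packed.stride())                    # (3, 1): densely packed copy
print(packed.data_ptr() == x.data_ptr())  # False
```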


## Indexing
Now that we know how to create and manipulate tensors, we need to understand how to extract parts of them using indexing.
- Basic slicing and integer indexing (`x[a:b]`, `x[:, 1]`, `x[::2]`) return a view that shares memory with `x`.
- “Fancy” (advanced) indexing with lists or tensors of indices returns a newly allocated copy.
- In-place vs out-of-place matters: if a function says “return a copy, leave the input unchanged”, you need `clone()`.
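The view-vs-copy distinction can be observed by writing to the result and checking the source. A minimal sketch:

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

row = x[0]              # basic indexing -> view sharing storage with x
row[0] = 100            # writes through to x
print(x[0, 0].item())   # 100

picked = x[[1]]         # advanced indexing (a list) -> new tensor
picked[0, 0] = -1       # does not touch x
print(x[1, 0].item())   # 3

safe = x[0].clone()     # explicit copy of a slice
safe[1] = 7
print(x[0, 1].item())   # 1
```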

In [57]:
def slice_rows(x: torch.Tensor, start: int, end: int) -> torch.Tensor:
    """Slice rows in a 2D tensor: x[start:end, :]."""
    return x[start:end]

x = torch.arange(12).reshape(4, 3)
rows = slice_rows(x, 1, 3)
print(f"Original tensor:\n{x}")
print(f"\nSliced rows [1:3]:\n{rows}")

Original tensor:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

Sliced rows [1:3]:
tensor([[3, 4, 5],
        [6, 7, 8]])


In [60]:
def select_columns(x: torch.Tensor, cols: Sequence[int]) -> torch.Tensor:
    """Select specific columns from a 2D tensor."""
    return x[:, cols]

cols = select_columns(x, [0, 2])
print(f"Original tensor:\n{x}\n")
print(f"Selected columns [0, 2]:\n{cols}")

Original tensor:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

Selected columns [0, 2]:
tensor([[ 0,  2],
        [ 3,  5],
        [ 6,  8],
        [ 9, 11]])


In [66]:
def get_diagonal(x: torch.Tensor) -> torch.Tensor:
    """Get the diagonal of a 2D tensor."""
    return x.diagonal()

d = get_diagonal(torch.tensor([[1, 2], [3, 4]]))
print(f"Original:\n{torch.tensor([[1, 2], [3, 4]])}\n")
print(f"Diagonal:\n{d}")

Original:
tensor([[1, 2],
        [3, 4]])

Diagonal:
tensor([1, 4])


In [69]:
def set_subtensor(x: torch.Tensor, row_idx: int, col_idx: int, value: float) -> torch.Tensor:
    """Return a copy of x where x[row_idx, col_idx] is set to value."""
    result = x.clone()
    result[row_idx, col_idx] = value
    return result

base = torch.zeros(2, 2)
out = set_subtensor(base, 0, 1, 5.0)
print(f"Original tensor:\n{base}\n")
print(f"After setting [0, 1] to 5.0:\n{out}")

Original tensor:
tensor([[0., 0.],
        [0., 0.]])

After setting [0, 1] to 5.0:
tensor([[0., 5.],
        [0., 0.]])


In [72]:
def gather_rows(x: torch.Tensor, row_indices: torch.Tensor) -> torch.Tensor:
    """Gather (concat) rows from x using row_indices."""
    return x[row_indices]

x2 = torch.tensor([[10, 11], [20, 21], [30, 31]])
idx = torch.tensor([2, 0])
gathered = gather_rows(x2, idx)
print(f"Original tensor:\n{x2}\n")
print(f"Row indices: {idx}\n")
print(f"Gathered rows:\n{gathered}")

Original tensor:
tensor([[10, 11],
        [20, 21],
        [30, 31]])

Row indices: tensor([2, 0])

Gathered rows:
tensor([[30, 31],
        [10, 11]])
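Row gathering has a few equivalent spellings. The sketch below compares advanced indexing with `index_select`, and also shows `torch.gather`, which picks one element per row (a pattern that shows up, for example, when selecting the log-probability of each target token):

```python
import torch

x = torch.tensor([[10, 11], [20, 21], [30, 31]])
idx = torch.tensor([2, 0])

# Whole-row gathering: advanced indexing and index_select agree
same = torch.equal(x[idx], x.index_select(0, idx))
print(same)  # True

# Per-row element picking: take column cols[i] from row i
cols = torch.tensor([1, 0, 1])
picked = torch.gather(x, dim=1, index=cols.unsqueeze(1)).squeeze(1)
print(picked.tolist())  # [11, 20, 31]

# Equivalent advanced-indexing form
print(x[torch.arange(3), cols].tolist())  # [11, 20, 31]
```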


## Broadcasting and reducing
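Next we cover broadcasting, the PyTorch mechanism that lets you apply elementwise ops to tensors of different shapes without Python loops. Understanding it is essential for tracing shapes through complicated systems. The broadcasting rules to know are:
- Dimensions align from the **right**.
- Two aligned dimensions are compatible if they are equal or one of them is **1** (the size-1 dimension is stretched to match).
- If one tensor has fewer dimensions, it is treated as if size-1 dimensions were prepended on the left.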
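The rules can be checked without allocating anything via `torch.broadcast_shapes`; the sketch below also shows an incompatible pair:

```python
import torch

# Right-align (3, 1) and (4,): treat (4,) as (1, 4) -> result (3, 4)
print(torch.broadcast_shapes((3, 1), (4,)))  # torch.Size([3, 4])

col = torch.arange(3).reshape(3, 1)  # (3, 1)
row = torch.arange(4)                # (4,)
table = col * 10 + row               # (3, 4): table[i, j] = 10 * i + j
print(table.shape)                   # torch.Size([3, 4])

# (2, 3) vs (2,): trailing sizes 3 and 2 differ and neither is 1
failed = False
try:
    torch.ones(2, 3) + torch.ones(2)
except RuntimeError:
    failed = True
print(failed)                        # True
```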

### Reduction ops and `keepdim`

When you reduce over a dimension (e.g. `sum`, `mean`, `max`), PyTorch can either:

- **remove** the reduced dimension (`keepdim=False`, default), or
- **keep** it as size 1 (`keepdim=True`)

Keeping the dimension is often helpful because it makes broadcasting back “just work”.

#### Shape diagram examples

Assume `x` has shape `(B, T, D)`:

**Sum over time**
- `x.sum(dim=1)` → shape `(B, D)`
- `x.sum(dim=1, keepdim=True)` → shape `(B, 1, D)`

**Mean over features**
- `x.mean(dim=2)` → shape `(B, T)`
- `x.mean(dim=2, keepdim=True)` → shape `(B, T, 1)`

#### Why `keepdim=True` helps with broadcasting

Example: center `x` by subtracting the mean over `T`

- If `m = x.mean(dim=1)` has shape `(B, D)`, then `x - m` aligns `(B, D)` against the trailing `(T, D)` of `x`: it usually **fails**, and when `B == T` it silently subtracts along the wrong axis.
- If `m = x.mean(dim=1, keepdim=True)` has shape `(B,1,D)`, then `x - m` **works** via broadcasting.
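A sketch of the centering example above with concrete sizes (chosen so that `B != T`):

```python
import torch

torch.manual_seed(0)
B, T, D = 2, 5, 3
x = torch.randn(B, T, D)

m = x.mean(dim=1, keepdim=True)  # (B, 1, D)
centered = x - m                 # broadcasts over T -> (B, T, D)
print(centered.shape)            # torch.Size([2, 5, 3])
print(torch.allclose(centered.mean(dim=1), torch.zeros(B, D), atol=1e-5))  # True

m_flat = x.mean(dim=1)           # (B, D)
# x - m_flat would raise here because B=2 and T=5 differ;
# unsqueezing restores the reduced axis explicitly:
print(torch.allclose(x - m_flat.unsqueeze(1), centered))  # True
```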

In [75]:
def sum_over_dim(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    """Sum tensor values along dimension dim."""
    return x.sum(dim=dim, keepdim=keepdim)

x = torch.ones(2, 3)
y = sum_over_dim(x, dim=1)
print("x:",x)
print(f"Sum over dim 1: {y}")
print(f"Shape: {y.shape}")

x: tensor([[1., 1., 1.],
        [1., 1., 1.]])
Sum over dim 1: tensor([3., 3.])
Shape: torch.Size([2])


In [76]:
def mean_over_dim(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    """Mean along dimension dim."""
    return x.mean(dim=dim, keepdim=keepdim)

x2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y2 = mean_over_dim(x2, dim=0)
print("x2:", x2)
print(f"Mean over dim 0: {y2}")
print(f"Shape: {y2.shape}")

x2: tensor([[1., 2.],
        [3., 4.]])
Mean over dim 0: tensor([2., 3.])
Shape: torch.Size([2])


In [77]:
def max_over_dim(x: torch.Tensor, dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Max values and argmax indices along dimension dim."""
    values, indices = x.max(dim=dim)
    return values, indices

x3 = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
values, idx = max_over_dim(x3, dim=1)
print(f"Input tensor:\n{x3}")
print(f"Max values over dim=1: {values}")
print(f"Argmax indices over dim=1: {idx}")

Input tensor:
tensor([[1., 5.],
        [3., 2.]])
Max values over dim=1: tensor([5., 3.])
Argmax indices over dim=1: tensor([1, 0])


In [78]:
def argmax_over_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Argmax indices along dimension dim."""
    return x.argmax(dim=dim)

idx2 = argmax_over_dim(x3, dim=1)
print(f"Input tensor for argmax_over_dim:\n{x3}")
print(f"Argmax indices over dim=1: {idx2}")

Input tensor for argmax_over_dim:
tensor([[1., 5.],
        [3., 2.]])
Argmax indices over dim=1: tensor([1, 0])


In [79]:
def broadcast_add_vector(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Add a vector v to each row of a 2D tensor x using broadcasting."""
    return x + v

x4 = torch.zeros(3, 2)
v = torch.tensor([10.0, 20.0])
y4 = broadcast_add_vector(x4, v)
print(f"Input tensor x4:\n{x4}")
print(f"Vector v: {v}")
print(f"Result of broadcast_add_vector(x4, v):\n{y4}")

Input tensor x4:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
Vector v: tensor([10., 20.])
Result of broadcast_add_vector(x4, v):
tensor([[10., 20.],
        [10., 20.],
        [10., 20.]])


## Vectorization
We want to avoid Python loops as much as possible: every iteration pays interpreter and operator-dispatch overhead, which makes them slow. PyTorch gives us many tools to avoid them. We cover these basics:
- `cat` vs `stack` (concatenate existing dims vs create a new dim)
- `repeat` vs `expand`
- `scatter_add` / `index_add` for accumulation
- `where` for conditional selection
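As an illustration, here is the same per-row L2 norm computed with a Python loop and with vectorized ops. The sizes are arbitrary and no timing is claimed; the point is that both give the same values while the vectorized form issues far fewer ops:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000, 16)

# Loop version: 1000 tiny ops, each paying Python and dispatch overhead
loop_norms = torch.stack([row.pow(2).sum().sqrt() for row in x])

# Vectorized version: a handful of ops over the whole tensor
vec_norms = x.pow(2).sum(dim=1).sqrt()

print(torch.allclose(loop_norms, vec_norms))  # True
```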

### `expand` vs `repeat`

- `repeat(...)` **copies** data → larger tensor with independent storage.
- `expand(...)` **does not copy** data → it creates a *view* with clever strides.

This has two important implications:

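1) `expand` can only grow **size-1 dimensions** (broadcasting a singleton) or prepend new leading dimensions; every other dimension must keep its size (or be given as `-1`).
2) The expanded tensor may have **many positions pointing to the same memory**.  
   Writing to the expanded tensor can therefore change many positions at once (e.g. multiple rows), and many in-place ops on it raise a `RuntimeError` about overlapping memory.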

Rule of thumb:
- Use `expand` for read-only broadcasting.
- Use `repeat` if you truly need independent copies.
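A small sketch of the memory-sharing behavior described above, using `stride()` and `data_ptr()` to tell views from copies. The write is a single-element assignment, which lands in the shared storage:

```python
import torch

base = torch.tensor([[1, 2, 3]])  # shape (1, 3)
expanded = base.expand(4, 3)      # view: no new memory
repeated = base.repeat(4, 1)      # copy: independent storage

print(expanded.stride())                       # (0, 1): stride 0 reuses row 0 four times
print(expanded.data_ptr() == base.data_ptr())  # True
print(repeated.data_ptr() == base.data_ptr())  # False

expanded[0, 1] = 99               # writes into base[0, 1] ...
print(expanded[:, 1].tolist())    # [99, 99, 99, 99]: every "row" changed
print(repeated[:, 1].tolist())    # [2, 2, 2, 2]: the copy is unaffected
```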


NOTE: From now on, we strongly encourage you to write your own quick checks that call the functions and inspect their output. As before, you are still required to fill in the TODOs in each function.

In [80]:
def concat_tensors(tensors: Sequence[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """Concatenate tensors along dim. NOTE: This will always allocate new memory"""
    return torch.cat(list(tensors), dim=dim)

# Quick check for concat_tensors
c1 = torch.tensor([[1, 2], [3, 4]])
c2 = torch.tensor([[5, 6]])
concatenated = concat_tensors([c1, c2], dim=0)
print(f"First tensor c1:\n{c1}")
print(f"Second tensor c2:\n{c2}")
print(f"Concatenated along dim=0:\n{concatenated}")

First tensor c1:
tensor([[1, 2],
        [3, 4]])
Second tensor c2:
tensor([[5, 6]])
Concatenated along dim=0:
tensor([[1, 2],
        [3, 4],
        [5, 6]])


In [81]:
def stack_tensors(tensors: Sequence[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """Stack tensors along a new dimension dim."""
    return torch.stack(list(tensors), dim=dim)

# Quick check for stack_tensors
s1 = torch.tensor([1, 2])
s2 = torch.tensor([3, 4])
stacked = stack_tensors([s1, s2], dim=0)
print(f"First tensor s1: {s1}")
print(f"Second tensor s2: {s2}")
print(f"Stacked along dim=0:\n{stacked}")

First tensor s1: tensor([1, 2])
Second tensor s2: tensor([3, 4])
Stacked along dim=0:
tensor([[1, 2],
        [3, 4]])
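One way to remember the difference: `stack` behaves like unsqueezing every input at `dim` and then calling `cat` along that dim. A quick sketch:

```python
import torch

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])

print(torch.cat([a, b], dim=0).shape)  # torch.Size([4]): existing dim grows
stacked = torch.stack([a, b], dim=1)   # torch.Size([2, 2]): new dim inserted at 1
print(stacked.tolist())                # [[1, 3], [2, 4]]

# stack == unsqueeze each input at `dim`, then cat along `dim`
via_cat = torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1)
print(torch.equal(stacked, via_cat))   # True
```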


In [82]:
def repeat_tensor(x: torch.Tensor, repeats: Sequence[int]) -> torch.Tensor:
    """Repeat tensor along each dimension."""
    return x.repeat(*repeats)

# Quick check for repeat_tensor
r = torch.tensor([[1, 2]])
repeated = repeat_tensor(r, (2, 3))
print(f"Original tensor r:\n{r}")
print(f"Repeated with repeats=(2, 3):\n{repeated}")

Original tensor r:
tensor([[1, 2]])
Repeated with repeats=(2, 3):
tensor([[1, 2, 1, 2, 1, 2],
        [1, 2, 1, 2, 1, 2]])


In [83]:
def expand_tensor(x: torch.Tensor, *sizes: int) -> torch.Tensor:
    """Expand tensor to a larger size without copying data.(Sizes can be -1 to keep original dimension.)"""
    return x.expand(*sizes)

# Quick check for expand_tensor
e = torch.tensor([[1], [2], [3]])
expanded = expand_tensor(e, 3, 4)
print(f"Original tensor e:\n{e}")
print(f"Expanded to size (3, 4):\n{expanded}")



Original tensor e:
tensor([[1],
        [2],
        [3]])
Expanded to size (3, 4):
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])


In [84]:
def cumsum_over_dim(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Cumulative sum along dim."""
    return x.cumsum(dim)

# Quick check for cumsum_over_dim
c = torch.tensor([[1, 2, 3], [4, 5, 6]])
cumsum_dim0 = cumsum_over_dim(c, dim=0)
cumsum_dim1 = cumsum_over_dim(c, dim=1)
print(f"Original tensor c:\n{c}")
print(f"Cumulative sum along dim=0:\n{cumsum_dim0}")
print(f"Cumulative sum along dim=1:\n{cumsum_dim1}")



Original tensor c:
tensor([[1, 2, 3],
        [4, 5, 6]])
Cumulative sum along dim=0:
tensor([[1, 2, 3],
        [5, 7, 9]])
Cumulative sum along dim=1:
tensor([[ 1,  3,  6],
        [ 4,  9, 15]])


In [87]:
def where_select(mask: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Elementwise select: return a where mask is True else b. mask must be broadcastable to a and b."""
    return torch.where(mask, a, b)

# Quick check for where_select
mask = torch.tensor([[True, False], [False, True]])
a = torch.tensor([[1, 2], [3, 4]])
b = torch.zeros_like(a)
result = where_select(mask, a, b)
print(f"Mask:\n{mask}")
print(f"Tensor a:\n{a}")
print(f"Tensor b:\n{b}")
print(f"where_select result:\n{result}")


Mask:
tensor([[ True, False],
        [False,  True]])
Tensor a:
tensor([[1, 2],
        [3, 4]])
Tensor b:
tensor([[0, 0],
        [0, 0]])
where_select result:
tensor([[1, 0],
        [0, 4]])


In [88]:
def one_hot(indices: torch.Tensor, num_classes: int, dtype: torch.dtype | None = None) -> torch.Tensor:
    """
    Create one-hot encodings.
    Output is a tensor of the same shape as indices with an added dimension of size num_classes at the end,
    where the value along that dimension is 1 if it matches the index and 0 otherwise.

    Shapes:
    - indices: (...,) integer tensor
    Return:
    - out: (..., num_classes)

    Requirements:
    - Must work for arbitrary leading shape.
    - No Python loops.
    """
    # Create a tensor of zeros with shape (*indices.shape, num_classes)
    out = torch.zeros(*indices.shape, num_classes, dtype=dtype if dtype is not None else torch.float32, device=indices.device)
    # Use scatter to set 1s at the appropriate positions
    # We need to add a dimension to indices to match the output shape
    out.scatter_(-1, indices.unsqueeze(-1), 1)
    return out

# Quick check for one_hot
indices = torch.tensor([0, 2, 1, 3])
one_hot_result = one_hot(indices, num_classes=4)
print(f"Indices: {indices}")
print(f"One-hot encoding:\n{one_hot_result}")

# Test with 2D indices
indices_2d = torch.tensor([[0, 1], [2, 0]])
one_hot_2d = one_hot(indices_2d, num_classes=3)
print(f"\n2D Indices:\n{indices_2d}")
print(f"One-hot encoding (2D):\n{one_hot_2d}")

Indices: tensor([0, 2, 1, 3])
One-hot encoding:
tensor([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.]])

2D Indices:
tensor([[0, 1],
        [2, 0]])
One-hot encoding (2D):
tensor([[[1., 0., 0.],
         [0., 1., 0.]],

        [[0., 0., 1.],
         [1., 0., 0.]]])
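PyTorch also ships `torch.nn.functional.one_hot`, which can serve as a reference when checking your own implementation. Note that it returns `int64` rather than the `float32` default used above. A sketch, including a broadcasting-based alternative:

```python
import torch
import torch.nn.functional as F

indices = torch.tensor([[0, 1], [2, 0]])
ref = F.one_hot(indices, num_classes=3)  # int64, shape (2, 2, 3)
print(ref.dtype, ref.shape)              # torch.int64 torch.Size([2, 2, 3])

# Same result by broadcasting a comparison against all class ids
via_eq = (indices.unsqueeze(-1) == torch.arange(3)).long()
print(torch.equal(ref, via_eq))          # True
```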


In [None]:
def scatter_add_1d(
    values: torch.Tensor, indices: torch.Tensor, size: int
) -> torch.Tensor:
    """
    Sum `values` into an output vector at positions `indices`.

    Shapes:
    - values: (N,)
    - indices: (N,) integer indices in [0, size)
    Return:
    - out: (size,) with same dtype and device as values

    Requirement:
    - no Python loops
    """
    # Create output tensor of zeros with the specified size
    out = torch.zeros(size, dtype=values.dtype, device=values.device)
    # Use scatter_add_ to accumulate values at the specified indices
    out.scatter_add_(0, indices, values)
    return out

# Quick check for scatter_add_1d
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
indices = torch.tensor([0, 2, 1, 2])
result = scatter_add_1d(values, indices, size=4)
print(f"Values: {values}")
print(f"Indices: {indices}")
print(f"scatter_add_1d result: {result}")
print(f"Expected: [1., 3., 6., 0.] (values at index 2 are summed: 2.0 + 4.0 = 6.0)")

Values: tensor([1., 2., 3., 4.])
Indices: tensor([0, 2, 1, 2])
scatter_add_1d result: tensor([1., 3., 6., 0.])
Expected: [1., 3., 6., 0.] (values at index 2 are summed: 2.0 + 4.0 = 6.0)
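`index_add_` is an alternative to `scatter_add_` for this 1D case. The sketch also shows why a plain indexed `+=` is not a substitute when indices repeat:

```python
import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0])
indices = torch.tensor([0, 2, 1, 2])

via_scatter = torch.zeros(4).scatter_add_(0, indices, values)
via_index_add = torch.zeros(4).index_add_(0, indices, values)
print(via_index_add.tolist())                   # [1.0, 3.0, 6.0, 0.0]
print(torch.equal(via_scatter, via_index_add))  # True

# Pitfall: plain advanced-index assignment does NOT accumulate duplicates.
# Index 2 receives either 2.0 or 4.0 (which one wins is unspecified), never 6.0.
naive = torch.zeros(4)
naive[indices] += values
print(naive[2].item() in (2.0, 4.0))            # True
```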


In [93]:
def batched_token_histogram(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """
    Count token occurrences per batch item.

    Shapes:
    - tokens: (B, T) int64
    Return:
    - counts: (B, vocab_size) where counts[b, v] = number of times token v appears in tokens[b]

    Requirements:
    - No Python loops over B or T.
    """
    B, T = tokens.shape
    # Create output tensor of zeros with shape (B, vocab_size)
    counts = torch.zeros(B, vocab_size, dtype=torch.long, device=tokens.device)

    # Create batch indices for scatter_add_
    # batch_indices: (B, T) where each row is [0, 0, ..., 0], [1, 1, ..., 1], etc.
    batch_indices = torch.arange(B, device=tokens.device).unsqueeze(1).expand(B, T)

    # Flatten everything for scatter_add_
    batch_flat = batch_indices.reshape(-1)  # (B*T,)
    tokens_flat = tokens.reshape(-1)  # (B*T,)
    ones = torch.ones(B * T, dtype=torch.long, device=tokens.device)

    # Use scatter_add_ to count occurrences
    # We need to convert 2D indices (batch, token) into 1D indices
    flat_indices = batch_flat * vocab_size + tokens_flat
    counts_flat = counts.reshape(-1)
    counts_flat.scatter_add_(0, flat_indices, ones)

    return counts.reshape(B, vocab_size)

# Quick check for batched_token_histogram
tokens = torch.tensor([[0, 1, 2, 1], [2, 2, 0, 1]])
vocab_size = 3
result = batched_token_histogram(tokens, vocab_size)
print(f"Tokens:\n{tokens}")
print(f"Vocab size: {vocab_size}")
print(f"Batched token histogram:\n{result}")
print(f"Expected: [[1, 2, 1], [1, 1, 2]] (batch 0: token 0 appears 1x, token 1 appears 2x, token 2 appears 1x)")

Tokens:
tensor([[0, 1, 2, 1],
        [2, 2, 0, 1]])
Vocab size: 3
Batched token histogram:
tensor([[1, 2, 1],
        [1, 1, 2]])
Expected: [[1, 2, 1], [1, 1, 2]] (batch 0: token 0 appears 1x, token 1 appears 2x, token 2 appears 1x)
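The flatten-and-offset trick above works, but `scatter_add_` can also operate along `dim=1` of a 2D tensor directly, which avoids the index arithmetic. A sketch of two alternatives that should give the same counts:

```python
import torch
import torch.nn.functional as F

tokens = torch.tensor([[0, 1, 2, 1], [2, 2, 0, 1]])
vocab_size = 3
B = tokens.shape[0]

# Option 1: scatter along dim=1 directly: counts[b, tokens[b, t]] += 1
counts = torch.zeros(B, vocab_size, dtype=torch.long)
counts.scatter_add_(1, tokens, torch.ones_like(tokens))

# Option 2: one-hot then sum over T (simple, but materializes a (B, T, vocab_size) tensor)
counts_oh = F.one_hot(tokens, num_classes=vocab_size).sum(dim=1)

print(counts.tolist())                 # [[1, 2, 1], [1, 1, 2]]
print(torch.equal(counts, counts_oh))  # True
```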


In [94]:
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Mean over `dim` considering only mask==True entries.

    Convention:
    - mask: bool tensor broadcastable to x
    - mask==True means "keep this entry"

    Return: same shape as x.mean(dim=dim)

    Requirements:
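    - Avoid division by zero: if all mask entries are False along `dim`, define the mean as 0.
    """
    # Convert mask to x's dtype and broadcast it to x's full shape, so the count
    # along `dim` is correct even when mask has size 1 in that dimension
    mask_float = mask.to(x.dtype).expand_as(x)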

    # Zero out entries where mask is False
    masked_x = x * mask_float

    # Sum the masked values along the specified dimension
    masked_sum = masked_x.sum(dim=dim)

    # Count how many True entries along the dimension
    count = mask_float.sum(dim=dim)

    # Avoid division by zero: where count is 0, set it to 1 (result will be 0/1 = 0)
    count = torch.where(count == 0, torch.ones_like(count), count)

    # Compute the mean
    result = masked_sum / count

    return result

# Quick check for masked_mean
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False], [True, False, False]])
result_dim1 = masked_mean(x, mask, dim=1)
print(f"Input x:\n{x}")
print(f"Mask:\n{mask}")
print(f"Masked mean (dim=1):\n{result_dim1}")
print(f"Expected: [1.5, 4.0] (batch 0: (1+2)/2=1.5, batch 1: 4/1=4.0)")

# Test with all False mask
mask_all_false = torch.tensor([[False, False, False], [True, True, True]])
result_all_false = masked_mean(x, mask_all_false, dim=1)
print(f"\nMasked mean with some all-False rows:\n{result_all_false}")
print(f"Expected: [0.0, 5.0] (batch 0: all masked out -> 0, batch 1: (4+5+6)/3=5.0)")

Input x:
tensor([[1., 2., 3.],
        [4., 5., 6.]])
Mask:
tensor([[ True,  True, False],
        [ True, False, False]])
Masked mean (dim=1):
tensor([1.5000, 4.0000])
Expected: [1.5, 4.0] (batch 0: (1+2)/2=1.5, batch 1: 4/1=4.0)

Masked mean with some all-False rows:
tensor([0., 5.])
Expected: [0.0, 5.0] (batch 0: all masked out -> 0, batch 1: (4+5+6)/3=5.0)
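A common use of this pattern is mean-pooling a batch of padded sequences. The sketch below builds a padding mask from per-sequence lengths (the `lengths` values are made up for illustration) and keeps the mask as `(B, T, 1)` so it broadcasts over the feature dimension:

```python
import torch

torch.manual_seed(0)
B, T, D = 2, 4, 3
x = torch.randn(B, T, D)
lengths = torch.tensor([4, 2])                      # hypothetical: valid tokens per sequence

mask = torch.arange(T)[None, :] < lengths[:, None]  # (B, T) bool, True = real token
m = mask.unsqueeze(-1).to(x.dtype)                  # (B, T, 1), broadcasts over D

summed = (x * m).sum(dim=1)                         # (B, D): padded positions contribute 0
count = m.sum(dim=1).clamp(min=1)                   # (B, 1): avoid division by zero
pooled = summed / count                             # (B, D)

print(mask.tolist())  # [[True, True, True, True], [True, True, False, False]]
print(pooled.shape)   # torch.Size([2, 3])
```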


## Einsum warmup
Now that you’re comfortable with shapes and broadcasting, we’ll introduce `torch.einsum`, a concise way to express tensor operations by explicitly naming axes and summing over repeated indices.

### The idea
You describe each input tensor by labeling its dimensions with letters, e.g.
- `x: (B, T, D)` → `"btd"`
- `W: (D, H)`    → `"dh"`

Then you tell einsum what output labels you want:
- `"btd,dh->bth"`

### Rules of einsum
1) **Same letter = same axis** (must match in size, except broadcastable size-1).
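2) **Letters missing from the output are summed over** (a “contraction”). A letter shared by several inputs but kept in the output is matched elementwise, not summed.
3) **Letters that appear in the output are kept** (in that order).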
4) You can **reorder axes** just by changing the output label order.

### Tiny cheat sheet
- Sum over an axis: `"btd->bt"` (sums over `d`)
- Transpose: `"ij->ji"`
- Dot product: `"d,d->"` or batched `"btd,btd->bt"`
- Matrix multiply: `"ik,kj->ij"`
- Batched matmul: `"bij,bjk->bik"`
- Outer product: `"i,j->ij"`
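Each cheat-sheet entry can be sanity-checked against the equivalent non-einsum op. This check is only for building intuition; the graded exercises below still require `torch.einsum`:

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)
Bm = torch.randn(4, 5)
u, v = torch.randn(4), torch.randn(4)
x = torch.randn(2, 3, 4)
P, Q = torch.randn(2, 3, 4), torch.randn(2, 4, 5)

tol = dict(atol=1e-6)  # small absolute tolerance for float32 summation-order differences
checks = {
    "sum over d": torch.allclose(torch.einsum("btd->bt", x), x.sum(dim=-1), **tol),
    "transpose": torch.equal(torch.einsum("ij->ji", A), A.T),
    "dot product": torch.allclose(torch.einsum("d,d->", u, v), u @ v, **tol),
    "matmul": torch.allclose(torch.einsum("ik,kj->ij", A, Bm), A @ Bm, **tol),
    "batched matmul": torch.allclose(torch.einsum("bij,bjk->bik", P, Q), torch.bmm(P, Q), **tol),
    "outer product": torch.allclose(torch.einsum("i,j->ij", u, v), torch.outer(u, v), **tol),
}
print(checks)  # every value should be True
```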

### How to derive an einsum (recommended workflow)
1) Write down shapes with named axes (e.g. `q: b h t d`, `k: b h s d`).
2) Decide which axes you want to **sum over** (give them the same letter in both inputs).
3) Decide which axes you want to **keep** in the output (write them after `->`).
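As a worked example of the workflow, consider a per-head projection. The weight layout `W: (H, D, Dh)` is a hypothetical choice for illustration, not necessarily the layout used later in the course:

```python
import torch

torch.manual_seed(0)
B, T, D, H, Dh = 2, 3, 8, 4, 2
# Step 1: name the axes
x = torch.randn(B, T, D)   # b t d
W = torch.randn(H, D, Dh)  # h d e (hypothetical per-head weight layout)

# Step 2: d appears in both inputs and not in the output -> summed over
# Step 3: keep b, h, t, e in that order -> output (B, H, T, Dh)
y = torch.einsum("btd,hde->bhte", x, W)

# Reference: one matmul per head, stacked along a new head axis
ref = torch.stack([x @ W[h] for h in range(H)], dim=1)
print(y.shape)                            # torch.Size([2, 4, 3, 2])
print(torch.allclose(y, ref, atol=1e-5))  # True
```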

In this section, you’ll use einsum to implement building blocks that show up in attention:
- linear projections (`x @ W`)
- dot products
- attention score matrices (`QKᵀ`)
- applying attention weights (`softmax(scores) @ V`)

NOTE: For these exercises you are required to use `torch.einsum`, not `matmul` (we check). You do not need to understand the attention mechanism yet; the exercises are solvable without it. However, it is worth remembering these implementations, since they will come up again in later exercises.

In [97]:
def einsum_linear_btd_dh_to_bth(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """
    Linear projection using einsum.

    Shapes:
    - x: (B, T, D)
    - W: (D, H)
    Return:
    - y: (B, T, H)
    """
    # x has shape (B, T, D) -> label as "btd"
    # W has shape (D, H) -> label as "dh"
    # We want to sum over D (the shared dimension)
    # Output should be (B, T, H) -> label as "bth"
    y = torch.einsum("btd,dh->bth", x, W)

    return y

# Example usage
B, T, D, H = 2, 3, 4, 5
x = torch.randn(B, T, D)
W = torch.randn(D, H)
y = einsum_linear_btd_dh_to_bth(x, W)
print(f"Input x:\n{x}")
print(f"Weight W:\n{W}")
print(f"Output y:\n{y}")


Input x:
tensor([[[-1.2894, -0.6417,  1.0130,  1.5455],
         [-0.4208,  1.7069, -0.8938,  0.5079],
         [ 1.4580,  1.1857,  1.1367,  0.5306]],

        [[-0.5157, -0.1709, -2.5697,  1.0639],
         [ 0.0086,  1.5308, -0.0896, -0.6444],
         [-0.2858,  1.1921, -0.0117, -1.0445]]])
Weight W:
tensor([[ 0.3487,  0.2342,  0.2083,  1.3587, -0.0191],
        [ 0.1549,  0.1116, -2.0587,  0.5292, -0.2444],
        [ 0.3462,  0.0850, -0.7461,  2.7192,  0.5329],
        [-0.8640, -0.4612,  1.9042,  0.9572,  1.3263]])
Output y:
tensor([[[-1.5336, -1.0003,  3.2398,  2.1423,  2.7710],
         [-0.6306, -0.2183, -1.9676, -1.6126, -0.2118],
         [ 0.6272,  0.3257, -1.9750,  6.2074,  0.9919]],

        [[-2.0152, -0.8490,  4.1877, -6.7604,  0.0933],
         [ 0.7659,  0.4624, -4.3099, -0.0388, -1.2767],
         [ 0.9834,  0.5469, -4.4941, -0.7890, -1.6774]]])


In [100]:
def einsum_pairwise_dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Pairwise dot product between x and y.

    Shapes:
    - x: (B, T, D)
    - y: (B, T, D)
    Return:
    - dots: (B, T) where dots[b,t] = dot(x[b,t], y[b,t])
    """
    # x has shape (B, T, D) -> label as "btd"
    # y has shape (B, T, D) -> label as "btd"
    # We want to sum over D (the shared dimension for dot product)
    # Output should be (B, T) -> label as "bt"
    dots = torch.einsum("btd,btd->bt", x, y)

    return dots

# Example
B, T, D = 2, 3, 4
x = torch.randn(B, T, D)
y = torch.randn(B, T, D)
dots = einsum_pairwise_dot(x, y)
print(f"Input x shape: {x.shape}")
print(f"Input x : \n{x}\n")
print(f"Input y shape: {y.shape}")
print(f"Input y: \n{y}\n")
print(f"Output dots shape: {dots.shape}")
print(f"Output dots:\n{dots}")

Input x shape: torch.Size([2, 3, 4])
Input x : 
tensor([[[ 0.9832,  0.2583,  0.7991, -0.4400],
         [-0.4836,  1.7729, -1.6043,  0.6841],
         [-1.3693, -0.2251,  0.1797, -1.8506]],

        [[-0.7692, -1.0373, -0.7382, -0.0743],
         [ 0.3102, -0.9968,  0.9680, -0.6660],
         [-0.4834, -0.5480, -0.4567, -0.4177]]])

Input y shape: torch.Size([2, 3, 4])
Input y: 
tensor([[[ 1.3701, -0.4849,  0.2198,  0.0501],
         [ 0.1318,  2.3606, -0.4911, -0.2938],
         [ 0.8714, -0.8666,  0.5537, -0.5355]],

        [[ 1.1904, -0.3813, -1.4756, -0.7867],
         [ 0.6286,  0.7151, -0.9377, -0.2933],
         [ 1.2918, -0.1248, -0.7089,  0.3701]]])

Output dots shape: torch.Size([2, 3])
Output dots:
tensor([[ 1.3754,  4.7083,  0.0923],
        [ 0.6275, -1.2302, -0.3869]])
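As a sanity check, the einsum string `"btd,btd->bt"` should agree with an elementwise product followed by a sum over the last axis. This is a minimal sketch: the seed and the local name `reference` are illustrative choices, not part of the exercise.

```python
import torch

torch.manual_seed(0)
B, T, D = 2, 3, 4
x = torch.randn(B, T, D)
y = torch.randn(B, T, D)

# einsum: the label "d" appears in both inputs but not in the output, so it is summed over
dots = torch.einsum("btd,btd->bt", x, y)

# Same computation without einsum: multiply elementwise, then sum over D
reference = (x * y).sum(dim=-1)

print(dots.shape)                       # torch.Size([2, 3])
print(torch.allclose(dots, reference))  # True
```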


In [101]:
def einsum_qk_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """
    Compute attention scores QK^T using einsum.

    Shapes:
    - q: (B, H, T, Dh)
    - k: (B, H, T, Dh)
    Return:
    - scores: (B, H, T, T) where scores[b,h,i,j] = dot(q[b,h,i], k[b,h,j])
    """
    # q has shape (B, H, T, Dh) -> label as "bhtd"
    # k has shape (B, H, T, Dh) -> label as "bhsd" (use 's' for the second T dimension)
    # We want to compute dot product over Dh dimension
    # Output should be (B, H, T, T) -> label as "bhts"
    scores = torch.einsum("bhtd,bhsd->bhts", q, k)

    return scores

# Example usage
B, H, T, Dh = 2, 3, 4, 5
q = torch.randn(B, H, T, Dh)
k = torch.randn(B, H, T, Dh)
scores = einsum_qk_scores(q, k)
print(f"Query q shape: {q.shape}")
print(f"Key k shape: {k.shape}")
print(f"Scores shape: {scores.shape}")
print(f"Scores:\n{scores}")
print("\nVerification - manual dot product for batch 0, head 0, query 0 vs. key 0:")
manual_score = torch.sum(q[0, 0, 0, :] * k[0, 0, 0, :])
print(f"Manual: {manual_score}")
print(f"Einsum: {scores[0, 0, 0, 0]}")

Query q shape: torch.Size([2, 3, 4, 5])
Key k shape: torch.Size([2, 3, 4, 5])
Scores shape: torch.Size([2, 3, 4, 4])
Scores:
tensor([[[[-2.1125, -1.2437, -1.3071,  0.7863],
          [-1.2806,  2.0429, -2.7935, -2.8358],
          [ 1.9617, -3.2367,  1.3090, -0.4283],
          [-4.0495, -0.6375, -2.5105,  5.6523]],

         [[ 1.2895,  0.0360,  2.2279, -0.6961],
          [ 0.6300, -1.7143,  2.7321,  0.4183],
          [-1.7585, -0.3796, -2.4269, -1.2258],
          [ 0.6286,  1.4669,  1.3121, -0.2500]],

         [[-0.8792,  2.6554,  2.1585, -3.6343],
          [-0.4860, -4.2922, -3.5201, -1.7220],
          [-0.2678, -2.5024, -0.9727, -2.7100],
          [ 1.6127,  4.1782,  2.2001,  2.7633]]],


        [[[ 0.4462, -0.8451,  1.9368, -4.6323],
          [ 3.3634,  0.1512,  2.2604, -0.7003],
          [ 2.2317,  0.3510, -0.2277,  2.8992],
          [ 1.4594, -2.1479,  0.7099,  1.3697]],

         [[ 2.8657, -0.2721,  0.9616,  0.9559],
          [ 0.6156, -0.8614, -1.1999,  0.8588],
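The manual check above compares only a single diagonal entry (query 0 against key 0). A fuller check compares the whole score tensor against the equivalent matmul `q @ k.transpose(-2, -1)`, which is the usual way to write QK^T without einsum. This is a minimal sketch; the seed and tolerance are illustrative choices:

```python
import torch

torch.manual_seed(0)
B, H, T, Dh = 2, 3, 4, 5
q = torch.randn(B, H, T, Dh)
k = torch.randn(B, H, T, Dh)

# einsum version: sum over the shared "d" label, keep query index t and key index s
scores = torch.einsum("bhtd,bhsd->bhts", q, k)

# Matmul version: swap the last two dims of k so (T, Dh) @ (Dh, T) -> (T, T)
reference = q @ k.transpose(-2, -1)

print(scores.shape)                                   # torch.Size([2, 3, 4, 4])
print(torch.allclose(scores, reference, atol=1e-5))   # True

# Spot-check an off-diagonal entry: query 1 against key 2 in batch 0, head 0
manual = torch.dot(q[0, 0, 1], k[0, 0, 2])
print(torch.allclose(scores[0, 0, 1, 2], manual, atol=1e-5))  # True
```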
 

In [None]:
def einsum_apply_attention(weights: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply attention weights to values using einsum.

    Shapes:
    - weights: (B, H, T, T)
    - v:       (B, H, T, Dh)
    Return:
    - out:     (B, H, T, Dh) where out[b,h,i] = sum_j weights[b,h,i,j] * v[b,h,j]
    """
    # TODO: implement
    raise NotImplementedError
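While implementing this, it can help to have a slow, loop-based reference that follows the docstring formula literally, so you can compare your einsum result against it on small random inputs. The name `reference_apply_attention` is hypothetical and not part of the graded interface; this is a sketch for testing only:

```python
import torch

def reference_apply_attention(weights: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: out[b,h,i] = sum_j weights[b,h,i,j] * v[b,h,j]."""
    B, H, T, _ = weights.shape
    out = torch.zeros(B, H, T, v.shape[-1], dtype=v.dtype, device=v.device)
    for b in range(B):
        for h in range(H):
            for i in range(T):
                for j in range(T):
                    # Scalar weight times the value vector v[b, h, j] of shape (Dh,)
                    out[b, h, i] += weights[b, h, i, j] * v[b, h, j]
    return out

torch.manual_seed(0)
B, H, T, Dh = 2, 3, 4, 5
weights = torch.softmax(torch.randn(B, H, T, T), dim=-1)  # rows sum to 1, like attention weights
v = torch.randn(B, H, T, Dh)
ref = reference_apply_attention(weights, v)
print(ref.shape)  # torch.Size([2, 3, 4, 5])

# With identity weights (query i attends only to key i), the output should copy v exactly
identity_weights = torch.eye(T).expand(B, H, T, T)
print(torch.allclose(reference_apply_attention(identity_weights, v), v))  # True
```

Once `einsum_apply_attention` is implemented, `torch.allclose(einsum_apply_attention(weights, v), ref, atol=1e-6)` should hold for these inputs.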

## Attention Fundamentals
This exercise introduces some building blocks of the attention mechanism, which we will encounter extensively throughout the course. You do not need to fully understand the mechanism yet to implement these exercises, but keep these building blocks in mind, since they will come up again.

To complete the exercises you should familiarize yourself with these topics:
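- Stable softmax (recommended reading): https://jaykmody.com/blog/stable-softmax/
- Masking: typically this means setting masked logits to -inf (or a large negative number such as -1e9) *before* the softmax.
- For attention: a causal mask blocks attention to future positions, so its masked (True) entries form the strictly upper-triangular part of the (T, T) matrix.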
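To see why a stable version matters, the sketch below computes softmax naively as `exp(x) / exp(x).sum()` on large logits. In float32, `exp(1000)` overflows to `inf`, and `inf / inf` is `nan`. Softmax is unchanged when the same constant is subtracted from every logit; that shift invariance is what makes a numerically stable softmax possible. Here `torch.softmax` is used only as a reference:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax: exp overflows to inf in float32, and inf / inf gives nan
naive = torch.exp(x) / torch.exp(x).sum()
print(naive)  # tensor([nan, nan, nan])

# Shift invariance: softmax(x) == softmax(x - c) for any constant c (here c = 1000)
shifted = torch.tensor([0.0, 1.0, 2.0])
print(torch.allclose(torch.softmax(x, dim=-1), torch.softmax(shifted, dim=-1)))  # True
```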

In [None]:
def stable_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Numerically stable softmax along `dim`.

    Requirements:
    - Must not overflow for large values in x.
    - Output sums to 1 along `dim`.
    """
    # TODO: implement
    raise NotImplementedError
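A small property checker can help test an implementation against the stated requirements (finite output for large inputs, summing to 1 along `dim`). The helper name `check_softmax_properties` is hypothetical; it is exercised here on `torch.softmax` as a stand-in, and you could pass `stable_softmax` instead once it is implemented:

```python
import torch
from typing import Callable

def check_softmax_properties(softmax_fn: Callable[[torch.Tensor, int], torch.Tensor]) -> None:
    """Hypothetical test helper for the stable_softmax requirements."""
    torch.manual_seed(0)
    # 4 ordinary rows plus 1 row of very large logits (naive exp would overflow on it)
    x = torch.cat([torch.randn(4, 6), 1000.0 * torch.ones(1, 6) + torch.arange(6.0)], dim=0)
    for dim in (0, 1):
        p = softmax_fn(x, dim)
        assert torch.isfinite(p).all(), f"non-finite output along dim={dim}"
        assert torch.allclose(p.sum(dim=dim), torch.ones(x.shape[1 - dim])), f"does not sum to 1 along dim={dim}"
        assert torch.allclose(p, torch.softmax(x, dim=dim), atol=1e-6), f"differs from torch.softmax along dim={dim}"
    print("all softmax checks passed")

check_softmax_properties(lambda x, dim: torch.softmax(x, dim=dim))
```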

In [None]:
def masked_fill_tensor(x: torch.Tensor, mask: torch.Tensor, value: float) -> torch.Tensor:
    """
    Return a copy of x where positions with mask == True are replaced by `value`.

    Requirements:
    - mask must be broadcastable to x.
    - do NOT modify x in-place.
    """
    # TODO: implement
    raise NotImplementedError
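The requirement that `mask` be broadcastable to `x` follows PyTorch's broadcasting rules: shapes are aligned from the right, and each pair of dimensions must be equal or contain a 1. `torch.broadcast_shapes` checks this without allocating any tensors. The shapes below are illustrative assumptions (a causal `(T, T)` mask and a padding mask of shape `(B, 1, 1, T)` against `(B, H, T, T)` logits):

```python
import torch

B, H, T = 2, 3, 4
logits_shape = (B, H, T, T)

# A (T, T) causal mask is treated as (1, 1, T, T) and repeated over B and H
print(torch.broadcast_shapes((T, T), logits_shape))        # torch.Size([2, 3, 4, 4])

# A per-key padding mask of shape (B, 1, 1, T) is repeated over H and the query axis
print(torch.broadcast_shapes((B, 1, 1, T), logits_shape))  # torch.Size([2, 3, 4, 4])

# A (B, T) mask does NOT broadcast here: right-aligned, it pairs B=2 with the query axis T=4
try:
    torch.broadcast_shapes((B, T), logits_shape)
except RuntimeError as err:
    print("not broadcastable:", err)
```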

In [None]:
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax over x with a boolean mask.

    Convention:
    - mask == True means "invalid and must receive probability 0".
    - Do masking before softmax (i.e., set invalid logits to a large negative value).

    Requirements:
    - Must be numerically stable.
    - Output must be exactly 0 where mask==True.
    - If all entries are masked along `dim`, return all zeros along `dim`.
    - You may reuse functions you implemented above.
    """
    # TODO: implement using masked_fill_tensor + stable_softmax
    raise NotImplementedError
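Two details of this convention are worth seeing numerically. With a fill value of `-1e9` in float32, `exp(-1e9 - max)` underflows to exactly `0.0`, so masked positions get probability exactly 0 as long as at least one position is unmasked. When every position is masked, however, plain softmax does not return zeros: a row of `-inf` becomes `nan`, and a row of `-1e9` becomes uniform. That is why the all-masked case needs explicit handling. This sketch writes the already-masked logits out by hand and uses `torch.softmax` only for illustration:

```python
import torch

# Partially masked: the last logit has been replaced by -1e9
p = torch.softmax(torch.tensor([1.0, 2.0, -1e9]), dim=-1)
print(p[2].item())  # 0.0

# Fully masked with -inf: (-inf) - (-inf) is nan inside softmax, so every entry is nan
all_inf = torch.softmax(torch.full((3,), float("-inf")), dim=-1)
print(all_inf)  # tensor([nan, nan, nan])

# Fully masked with -1e9: all logits are equal, so softmax is uniform, not zero
all_large_neg = torch.softmax(torch.full((3,), -1e9), dim=-1)
print(all_large_neg)  # tensor([0.3333, 0.3333, 0.3333])
```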


In [None]:
def make_causal_mask(T: int, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Create a causal (future-masking) boolean mask of shape (T, T).

    Convention:
    - mask[i, j] == True  => query i is NOT allowed to attend to key j (j > i, i.e. j is in the future)
    - mask[i, j] == False => allowed

    So the True entries lie strictly above the diagonal (the diagonal itself is False).

    Return:
    - mask: boolean tensor on the specified device

    Example (T=4):
        [[F, T, T, T],
         [F, F, T, T],
         [F, F, F, T],
         [F, F, F, F]]
    """
    # TODO: implement
    raise NotImplementedError
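To test a mask, you can build the expected pattern directly from the docstring convention (`True` exactly when `j > i`) with plain Python loops and compare. `check_causal_mask` is a hypothetical helper; it is shown here checking the T=4 example from the docstring, written out by hand:

```python
import torch

def check_causal_mask(mask: torch.Tensor, T: int) -> None:
    """Hypothetical checker: mask[i, j] must be True exactly when j > i."""
    expected = torch.tensor([[j > i for j in range(T)] for i in range(T)], dtype=torch.bool)
    assert mask.dtype == torch.bool, f"expected bool dtype, got {mask.dtype}"
    assert mask.shape == (T, T), f"expected shape {(T, T)}, got {tuple(mask.shape)}"
    assert torch.equal(mask.cpu(), expected), f"mask mismatch:\n{mask}"

# The T=4 example from the docstring (F = False, Tr = True)
F, Tr = False, True
docstring_example = torch.tensor([
    [F, Tr, Tr, Tr],
    [F, F, Tr, Tr],
    [F, F, F, Tr],
    [F, F, F, F],
])
check_causal_mask(docstring_example, T=4)
print("docstring example matches the j > i convention")
```

Once `make_causal_mask` is implemented, `check_causal_mask(make_causal_mask(T), T)` should pass for any `T`.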

In [None]:
def apply_causal_mask(attn_logits: torch.Tensor, value: float = -1e9) -> torch.Tensor:
    """
    Apply a causal mask to attention logits.

    Expected shapes:
    - attn_logits: (..., T, T)

    Returns:
    - masked logits (same shape) where masked positions have been set to `value`.

    Notes:
    - Create a causal mask for the final two dims.
    - Broadcast it across leading dims.
    - You may reuse functions declared above.
    """
    # TODO: implement using make_causal_mask + masked_fill_tensor
    raise NotImplementedError
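The default `value=-1e9` works well in float32, but it lies outside the range of float16, whose most negative finite value is -65504. If you later run attention in half precision, one common dtype-aware alternative (an assumption here, not a requirement of this exercise) is to use `torch.finfo(attn_logits.dtype).min` as the fill value. The sketch below compares the ranges:

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # -1e9 is within range only if it is not below the dtype's most negative finite value
    in_range = -1e9 >= info.min
    print(f"{str(dtype):16s} min={info.min:.4g}  -1e9 within range: {in_range}")

# Casting -1e9 to float16 overflows to -inf
print(torch.tensor(-1e9).to(torch.float16))  # tensor(-inf, dtype=torch.float16)
```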