# Exercise 2: PyTorch core

In this exercise you’ll build core PyTorch “muscle memory” that you’ll reuse in basically every model you write:

- **Autograd**: how gradients are created, how they accumulate, and how to compute gradients for one or multiple inputs.
- **Dataloading**: writing small `Dataset`s, using `DataLoader`, and custom `collate_fn`.
- **Optimizers**: implementing **AdamW** updates from scratch (state, bias correction, weight decay).
- **Training basics**: a clean single training step.
- **Initialization**: fan-in/out and common initializers (Xavier / Kaiming), plus a helper to init `nn.Linear`.

As before: fill in all `TODO`s without changing function names or signatures.
When debugging, print shapes/dtypes/devices, and write tiny sanity checks (e.g. compare to PyTorch’s built-ins).


In [1]:
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn

  cpu = _conversion_method_template(device=torch.device("cpu"))


## Autograd fundamentals

PyTorch builds a computation graph when you apply operations to tensors with `requires_grad=True`.
Calling `backward()` (or `torch.autograd.grad`) computes gradients by traversing that graph.

### Key concepts
- **Leaf tensor**: a tensor created by you (not the result of an operation) with `requires_grad=True`. Leaf tensors can store gradients in `.grad`.
- **Gradient accumulation**: calling `backward()` adds into `.grad` (it does not overwrite). You must reset gradients between steps/calls.
- **`torch.autograd.grad` vs `.backward()`**
  - `torch.autograd.grad(f, x)` returns `df/dx` directly and does not write into `x.grad` unless you explicitly do so.
  - `f.backward()` writes gradients into `.grad` of leaf tensors.

In the next functions you’ll compute gradients for a simple scalar function such as `f(x) = sum(x^2)` using both APIs.

### `torch.no_grad()`
Wrap inference-only code to avoid tracking gradients and building graphs:
- saves memory
- speeds up evaluation

### `detach()`
`y = x.detach()` returns a tensor that shares data with `x` but is **not connected** to the autograd graph.
This is useful when you want to treat something as a constant target.

### `model.train()` vs `model.eval()`
- `train()` enables training behavior (e.g. dropout active, batchnorm updates running stats).
- `eval()` enables inference behavior (e.g. dropout off, batchnorm uses running stats).

In [7]:
def grad_with_autograd_grad(x: torch.Tensor) -> torch.Tensor:
    """
    Compute gradient of f(x) = sum(x^2) using torch.autograd.grad

    Requirements:
    - Do not call .backward().
    - x should require grad inside the function (don't assume it does).
    - Must return df/dx
    """
    # Ensure x requires grad
    if not x.requires_grad:
        x.requires_grad_(True)

    # Define the function f(x) = sum(x^2)
    f = torch.sum(x ** 2)

    # Compute the gradient df/dx using torch.autograd.grad
    grad = torch.autograd.grad(f, x, create_graph=True)[0]

    return grad


# Test grad_with_autograd_grad
x_test = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
grad_result = grad_with_autograd_grad(x_test)
print("Testing grad_with_autograd_grad:")
print(f"Input x: {x_test}")
print(f"Gradient df/dx: {grad_result}")
print(f"Expected: {2 * x_test.detach()}")
print()



Testing grad_with_autograd_grad:
Input x: tensor([1., 2., 3.], requires_grad=True)
Gradient df/dx: tensor([2., 4., 6.], grad_fn=<MulBackward0>)
Expected: tensor([2., 4., 6.])



In [8]:

def grad_with_backward(x: torch.Tensor) -> torch.Tensor:
    """
    Compute gradient of f(x) = sum(x^2) using .backward().

    Requirements:
    - Must return df/dx
    - Must not leak gradients across calls (watch x.grad accumulation)
    """
    # Ensure x requires grad
    if not x.requires_grad:
        x.requires_grad_(True)

    # Define the function f(x) = sum(x^2)
    f = torch.sum(x ** 2)

    # Compute the gradient df/dx using torch.backward()
    f.backward()
    # Get the gradient from x.grad
    grad = x.grad.clone()  # Clone to avoid in-place modification
    # Zero out the gradients to prevent accumulation
    x.grad.zero_()
    return grad

# Test grad_with_backward
x_test = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
grad_result = grad_with_backward(x_test)
print("Testing grad_with_backward:")
print(f"Input x: {x_test}")
print(f"Gradient df/dx: {grad_result}")
print(f"Expected: {2 * x_test.detach()}")


Testing grad_with_backward:
Input x: tensor([1., 2., 3.], requires_grad=True)
Gradient df/dx: tensor([2., 4., 6.])
Expected: tensor([2., 4., 6.])


In [13]:
def grad_wrt_multiple_inputs(
    a: torch.Tensor, b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute gradients w.r.t. multiple inputs. The function is f(a, b) = sum(a^2 + ab).

    Return:
        (df/da, df/db)

    Requirements:
    - Use torch.autograd.grad
    - Ensure both a and b require grad in this function.
    """
    # Ensure both a and b require grad
    if not a.requires_grad:
        a.requires_grad_(True)
    if not b.requires_grad:
        b.requires_grad_(True)

    # Define the function f(a, b) = sum(a^2 + ab)
    f = torch.sum(a ** 2 + a * b)

    # Compute the gradients df/da and df/db using torch.autograd.grad
    grad_a, grad_b = torch.autograd.grad(f, (a, b), create_graph=True)

    return grad_a, grad_b


# Test grad_wrt_multiple_inputs
a_test = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b_test = torch.tensor([0.5, 1.0, 1.5], requires_grad=True)

grad_result = grad_wrt_multiple_inputs(a_test, b_test)
print("Testing grad_wrt_multiple_inputs:")
print(f"Input a: {a_test}")
print(f"Input b: {b_test}")
print()
print(f"Gradient df/da: {grad_result[0]}")
print(f"Gradient df/db: {grad_result[1]}")
print()
print(f"Expected df/da: {2 * a_test.detach() + b_test.detach()}")
print(f"Expected df/db: {a_test.detach()}")


Testing grad_wrt_multiple_inputs:
Input a: tensor([1., 2., 3.], requires_grad=True)
Input b: tensor([0.5000, 1.0000, 1.5000], requires_grad=True)

Gradient df/da: tensor([2.5000, 5.0000, 7.5000], grad_fn=<AddBackward0>)
Gradient df/db: tensor([1., 2., 3.], grad_fn=<MulBackward0>)

Expected df/da: tensor([2.5000, 5.0000, 7.5000])
Expected df/db: tensor([1., 2., 3.])


## Dataloading

In PyTorch, a `Dataset` defines how to fetch a *single* training example, and a `DataLoader` handles:
- batching
- shuffling
- parallel workers
- optional custom batching logic via `collate_fn`

### `Dataset` in one sentence
A `Dataset` only needs:
- `__len__`: number of items
- `__getitem__`: return one item (e.g. `(x, y)`)

### Why `collate_fn` matters
The default DataLoader collation stacks items along a new batch dimension.
That works for fixed-size tensors, but it breaks for **variable-length sequences**.

So we’ll implement padding ourselves:
- Convert a list of 1D token sequences into a padded tensor `(B, T_max)`
- Track `lengths` and a `padding_mask`

### Mask convention for padding
For padding masks in this exercise:
- `padding_mask[b, t] == True` means **this is padding / invalid**
- `padding_mask[b, t] == False` means **this is a real token**

In [14]:
from torch.utils.data import DataLoader, Dataset

In [17]:
class TensorPairDataset(Dataset):
    """
    Minimal dataset wrapping (x, y).

    x: (N, ...)
    y: (N, ...)

    N is the number of samples. The dataset should return tuples of (x[i], y[i]).
    """

    def __init__(self, x: torch.Tensor, y: torch.Tensor):
        assert len(x) == len(y), "x and y must have the same number of samples"
        self.x = x
        self.y = y

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.x[idx], self.y[idx]

In [None]:
class NextTokenDataset(Dataset):
    """
    Next-token prediction dataset.

    Given tokens of shape (N, T), produce:
      input_ids  = tokens[:, :-1]
      target_ids = tokens[:, 1:]

    Return per item:
      (input_ids, target_ids)

    Notes:
    - Returned tensors should be 1D of length (T-1).
    - dtype should remain integer.
    """

    def __init__(self, tokens: torch.Tensor):
        self.tokens = tokens

    def __len__(self) -> int:
        return len(self.tokens)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        sequence = self.tokens[idx]
        input_ids = sequence[:-1]
        target_ids = sequence[1:]
        return input_ids, target_ids

# Test NextTokenDataset
tokens_test = torch.tensor([[1,2,3,4,5,6], [7,8,9,10,11,12]])
dataset = NextTokenDataset(tokens_test)
print("Testing NextTokenDataset:")
for i in range(len(dataset)):
    input_ids, target_ids = dataset[i]
    print(f"Sample {i}:")
    print(f"Input IDs: {input_ids}")
    print(f"Target IDs: {target_ids}")
    print()

Testing NextTokenDataset:
Sample 0:
Input IDs: tensor([1, 2, 3, 4, 5])
Target IDs: tensor([2, 3, 4, 5, 6])

Sample 1:
Input IDs: tensor([ 7,  8,  9, 10, 11])
Target IDs: tensor([ 8,  9, 10, 11, 12])



In [None]:

class RandomCropSequenceDataset(Dataset):
    """
    Sequence dataset that returns random crops of fixed length.

    tokens: (N, T_total)
    crop_len: L

    For each __getitem__:
      - sample a start index s so that s+L <= T_total
      - return tokens[idx, s:s+L]

    Requirements:
    - Use a torch.Generator for deterministic behavior if seed is provided.
    - Do NOT use Python's random module.
    """

    def __init__(self, tokens: torch.Tensor, crop_len: int, seed: int | None = None):
        self.tokens = tokens
        self.crop_len = crop_len
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)


    def __len__(self) -> int:
        return len(self.tokens)

    def __getitem__(self, idx: int) -> torch.Tensor:
        sequence = self.tokens[idx]
        T_total = sequence.size(0)
        if T_total <= self.crop_len:
            raise ValueError(f"Crop length {self.crop_len} must be less than sequence length {T_total}")
        max_start = T_total - self.crop_len
        start_idx = torch.randint(0, max_start + 1, (1,), generator=self.generator).item()
        crop = sequence[start_idx:start_idx + self.crop_len]
        return crop

# Test RandomCropSequenceDataset
tokens_test = torch.tensor([[1,2,3,4,5,6], [7,8,9,10,11,12]])
crop_len = 3
dataset = RandomCropSequenceDataset(tokens_test, crop_len, seed=42)
print("Testing RandomCropSequenceDataset:")
for i in range(len(dataset)):
    crop = dataset[i]
    print(f"Sample {i}: Crop: {crop}")


Testing RandomCropSequenceDataset:
Sample 0: Crop: tensor([3, 4, 5])
Sample 1: Crop: tensor([10, 11, 12])


In [22]:


@dataclass(frozen=True)
class PaddedBatch:
    """
    A padded batch for variable-length sequences.

    tokens: LongTensor (B, T_max)
    lengths: LongTensor (B,)
    padding_mask: BoolTensor (B, T_max) where True means "this is padding"
    """

    tokens: torch.Tensor
    lengths: torch.Tensor
    padding_mask: torch.Tensor


def pad_1d_sequences(seqs: list[torch.Tensor], pad_value: int = 0) -> PaddedBatch:
    """
    Pad a list of 1D integer tensors to the same length.

    Requirements:
    - Return PaddedBatch(tokens, lengths, padding_mask)
    - padding_mask[b, t] == True iff t >= lengths[b]
    - tokens should be dtype long, if not cast them
    """
    batch_size = len(seqs)
    lengths = torch.tensor([len(seq) for seq in seqs], dtype=torch.long)
    max_length = lengths.max().item()

    # Initialize padded tokens and padding mask
    padded_tokens = torch.full((batch_size, max_length), pad_value, dtype=torch.long)
    padding_mask = torch.ones((batch_size, max_length), dtype=torch.bool)

    for i, seq in enumerate(seqs):
        seq_len = lengths[i].item()
        padded_tokens[i, :seq_len] = seq
        padding_mask[i, :seq_len] = False  # Mark non-padding positions

    return PaddedBatch(tokens=padded_tokens, lengths=lengths, padding_mask=padding_mask)

# Test pad_1d_sequences
seqs_test = [torch.tensor([1,2,3]), torch.tensor([4,5]), torch.tensor([6])]
padded_batch = pad_1d_sequences(seqs_test, pad_value=0)
print("Testing pad_1d_sequences:")
print(f"Tokens:\n{padded_batch.tokens}")
print(f"Lengths:\n{padded_batch.lengths}")
print(f"Padding Mask:\n{padded_batch.padding_mask}")


Testing pad_1d_sequences:
Tokens:
tensor([[1, 2, 3],
        [4, 5, 0],
        [6, 0, 0]])
Lengths:
tensor([3, 2, 1])
Padding Mask:
tensor([[False, False, False],
        [False, False,  True],
        [False,  True,  True]])


In [23]:
def collate_next_token_batch(
    batch: list[tuple[torch.Tensor, torch.Tensor]], pad_value: int = 0
) -> dict[str, torch.Tensor]:
    """
    Collate for NextTokenDataset samples that may have variable lengths.

    batch: list of (input_ids, target_ids), each 1D

    Return dict with:
      - input_ids: (B, T_max)
      - target_ids: (B, T_max)
      - attention_mask: (B, T_max) where True means "keep" (NOT padding)
      - padding_mask: (B, T_max) where True means "padding"

    Requirements:
    - pad input_ids and target_ids consistently
    - attention_mask is the logical NOT of padding_mask
    """
    # Separate input_ids and target_ids from the batch
    input_ids_list = [item[0] for item in batch]
    target_ids_list = [item[1] for item in batch]

    # Pad both input_ids and target_ids using the pad_1d_sequences function
    padded_inputs = pad_1d_sequences(input_ids_list, pad_value=pad_value)
    padded_targets = pad_1d_sequences(target_ids_list, pad_value=pad_value)

    # Create attention_mask as the logical NOT of padding_mask
    attention_mask = ~padded_inputs.padding_mask

    return {
        "input_ids": padded_inputs.tokens,
        "target_ids": padded_targets.tokens,
        "attention_mask": attention_mask,
        "padding_mask": padded_inputs.padding_mask,
    }

# Test collate_next_token_batch
batch_test = [
    (torch.tensor([1,2,3]), torch.tensor([2,3,4])),
    (torch.tensor([4,5]), torch.tensor([5,6])),
    (torch.tensor([6]), torch.tensor([7])),
]
collated = collate_next_token_batch(batch_test, pad_value=0)
print("Testing collate_next_token_batch:")
print(f"Input IDs:\n{collated['input_ids']}")
print(f"Target IDs:\n{collated['target_ids']}")
print(f"Attention Mask:\n{collated['attention_mask']}")
print(f"Padding Mask:\n{collated['padding_mask']}")


Testing collate_next_token_batch:
Input IDs:
tensor([[1, 2, 3],
        [4, 5, 0],
        [6, 0, 0]])
Target IDs:
tensor([[2, 3, 4],
        [5, 6, 0],
        [7, 0, 0]])
Attention Mask:
tensor([[ True,  True,  True],
        [ True,  True, False],
        [ True, False, False]])
Padding Mask:
tensor([[False, False, False],
        [False, False,  True],
        [False,  True,  True]])


In [None]:
def make_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    drop_last: bool = False,
    collate_fn=None,
    num_workers: int = 0,
) -> DataLoader:
    """
    Create a DataLoader with optional collate_fn.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

# Test make_dataloader with NextTokenDataset and collate_next_token_batch
# Create a list of 1D tensors
tokens_test = [torch.tensor([1,2,3,4]), torch.tensor([5,6,7]), torch.tensor([8,9])]
dataset = NextTokenDataset(tokens_test)
dataloader = make_dataloader(dataset, batch_size=2, shuffle=False, collate_fn=collate_next_token_batch)
print("Testing make_dataloader with NextTokenDataset:")
for batch in dataloader:
    print(batch)


Testing make_dataloader with NextTokenDataset:
{'input_ids': tensor([[1, 2, 3],
        [5, 6, 0]]), 'target_ids': tensor([[2, 3, 4],
        [6, 7, 0]]), 'attention_mask': tensor([[ True,  True,  True],
        [ True,  True, False]]), 'padding_mask': tensor([[False, False, False],
        [False, False,  True]])}
{'input_ids': tensor([[8]]), 'target_ids': tensor([[9]]), 'attention_mask': tensor([[True]]), 'padding_mask': tensor([[False]])}


## Optimizers (AdamW from scratch)

PyTorch optimizers keep **state** for each parameter (e.g. moment estimates in Adam).
In this section you’ll implement **AdamW**, which is Adam + *decoupled* weight decay.

### AdamW state
For each parameter tensor `p` we store:
- `m`: first moment (EMA of gradients)
- `v`: second moment (EMA of squared gradients)
- `t`: step counter

### Update overview (high level)
1) Update moments `m, v`
2) Bias-correct them (`m_hat, v_hat`)
3) Apply parameter update:
   `p -= lr * ( m_hat / (sqrt(v_hat) + eps) + weight_decay * p )`

Notes:
- This update is **in-place** (mutates `p`).
- Gradients should not be modified.
- State tensors must match parameter shape/device/dtype.

In [29]:
from this import d


@dataclass
class AdamWState:
    """
    Per-parameter AdamW state.

    m: first moment
    v: second moment
    t: step count
    """

    m: torch.Tensor
    v: torch.Tensor
    t: int


def init_adamw_state(p: torch.Tensor) -> AdamWState:
    """
    Initialize AdamW state tensors for a parameter tensor p.

    What to create:
    - m: zeros like p, same shape/device/dtype
    - v: zeros like p, same shape/device/dtype
    - t: step counter starting at 0

    Notes / requirements:
    - Use torch.zeros_like(p) for m and v.
    - Do NOT attach gradients to the state (initialize under torch.no_grad()).
    - t starts at 0. In adamw_step_, increment t to 1 on the first update *before*
      computing bias correction terms (1 - beta1^t) and (1 - beta2^t).
    - State tensors must live on the same device as p (CPU vs GPU) and have the
      same dtype as p.
    """
    with torch.no_grad():
        m = torch.zeros_like(p, device=p.device, dtype=p.dtype)
        v = torch.zeros_like(p, device=p.device, dtype=p.dtype)
    t = 0
    return AdamWState(m=m, v=v, t=t)

The Zen of Python, by Tim Peters

Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!


In [30]:
def adamw_step_(
    p: torch.Tensor,
    grad: torch.Tensor,
    state: AdamWState,
    lr: float,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> AdamWState:
    """
    In-place AdamW parameter update (updates p).

    Algorithm (AdamW):
      m = beta1*m + (1-beta1)*grad
      v = beta2*v + (1-beta2)*grad^2
      m_hat = m / (1 - beta1^t)
      v_hat = v / (1 - beta2^t)
      p = p - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * p)

    Requirements:
    - Update p in-place.
    - Return updated state (with incremented t).
    - Do not modify grad.
    - Should work for any tensor shape.
    """
    t= state.t + 1  # Increment step count
    m = state.m
    v = state.v
    beta1, beta2 = betas

    # Update first moment
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    # Update second moment
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Compute bias-corrected moments
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)

    # Update parameters in-place
    p -= lr * (m_hat / (torch.sqrt(v_hat) + eps) + weight_decay * p)

    return AdamWState(m=m, v=v, t=t)


In [31]:
def adamw_step_many_(
    params: list[torch.Tensor],
    grads: list[torch.Tensor],
    states: list[AdamWState],
    lr: float,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> list[AdamWState]:
    """
    Apply AdamW to many parameters.

    Requirements:
    - len(params) == len(grads) == len(states)
    - Update each param in-place.
    - Return the list of updated states.
    """
    updated_states = []
    for p, grad, state in zip(params, grads, states):
        updated_state = adamw_step_(p, grad, state, lr, betas, eps, weight_decay)
        updated_states.append(updated_state)
    return updated_states


## Training basics

A minimal training step follows the same pattern almost everywhere:

1) set model to train mode
2) reset gradients
3) forward pass
4) compute loss
5) backward pass
6) step optimizer

In this exercise you’ll implement a single MSE training step using a standard PyTorch optimizer.
Return a Python float loss value.

In [32]:
def train_step_mse(
    model: nn.Module,
    batch: tuple[torch.Tensor, torch.Tensor],
    optimizer: torch.optim.Optimizer,
) -> float:
    """
    One MSE train step using standard torch optimizer.
    """
    model.train()
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = nn.functional.mse_loss(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()




## Parameter initialization

Initialization matters because it controls signal and gradient scales at the start of training.

### Fan-in / fan-out
- `fan_in`: number of input connections to a unit
- `fan_out`: number of output connections from a unit

For a Linear layer weight of shape `(out_features, in_features)`:
- `fan_in = in_features`
- `fan_out = out_features`

### Common schemes
- **Xavier / Glorot** (often good for tanh / linear-ish nets):
  keeps variance stable across layers when activations are roughly symmetric.
- **Kaiming / He** (often good for ReLU-like nets):
  accounts for the fact that ReLU zeroes out about half the inputs.

In this section you’ll implement Xavier uniform and Kaiming uniform and use them to initialize `nn.Linear`.
We also always zero the bias unless explicitly told otherwise.

In [33]:
def fan_in_fan_out(weight: torch.Tensor) -> tuple[int, int]:
    """Compute (fan_in, fan_out) for a weight tensor."""
    if weight.ndim < 2:
        raise ValueError("Weight tensor must have at least 2 dimensions")
    fan_in = weight.size(1) * weight[0][0].numel()  # Product of dimensions except the first
    fan_out = weight.size(0) * weight[0][0].numel()  # Product of dimensions except the second
    return fan_in, fan_out

# Test fan_in_fan_out
weight_test = torch.randn(64, 128)  # Example weight tensor
fan_in, fan_out = fan_in_fan_out(weight_test)
print("Testing fan_in_fan_out:")
print(f"Weight shape: {weight_test.shape}")
print(f"Fan-in: {fan_in}")
print(f"Fan-out: {fan_out}")



Testing fan_in_fan_out:
Weight shape: torch.Size([64, 128])
Fan-in: 128
Fan-out: 64


In [35]:
def xavier_uniform_(weight: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """
    In-place Xavier/Glorot uniform init:
      bound = gain * sqrt(6 / (fan_in + fan_out))
      U(-bound, bound)
    """
    fan_in, fan_out = fan_in_fan_out(weight)
    bound = gain * (6 / (fan_in + fan_out)) ** 0.5
    with torch.no_grad():
        return weight.uniform_(-bound, bound)

# Test xavier_uniform_
weight_test = torch.empty(64, 128)  # Example weight tensor
xavier_uniform_(weight_test, gain=1.0)
print("Testing xavier_uniform_:")
print(f"Weight shape: {weight_test.shape}")
print(f"Weight stats: mean={weight_test.mean().item():.4f}, std={weight_test.std().item():.4f}")


Testing xavier_uniform_:
Weight shape: torch.Size([64, 128])
Weight stats: mean=0.0004, std=0.1027


In [37]:
def kaiming_uniform_(weight: torch.Tensor, nonlinearity: str = "relu") -> torch.Tensor:
    """
    In-place Kaiming/He uniform init.

    Follow this common choice:
      gain = sqrt(2) for ReLU
      std = gain / sqrt(fan_in)
      bound = sqrt(3) * std
      U(-bound, bound)
    """
    fan_in, fan_out = fan_in_fan_out(weight)
    gain = 2 ** 0.5 if nonlinearity == "relu" else 1.0
    std = gain / (fan_in ** 0.5)
    bound = (3 ** 0.5) * std
    with torch.no_grad():
        return weight.uniform_(-bound, bound)

# Test kaiming_uniform_
weight_test = torch.empty(64, 128)  # Example weight tensor
kaiming_uniform_(weight_test, nonlinearity= "relu")
print("Testing kaiming_uniform_:")
print(f"Weight shape: {weight_test.shape}")
print(f"Weight stats: mean={weight_test.mean().item():.4f}, std={weight_test.std().item():.4f}")


Testing kaiming_uniform_:
Weight shape: torch.Size([64, 128])
Weight stats: mean=0.0014, std=0.1247


In [38]:
def init_linear_(layer: nn.Linear, scheme: str = "xavier") -> nn.Linear:
    """
    Initialize an nn.Linear in-place.

    scheme:
      - "xavier"
      - "kaiming_relu"
      - "zero" (weights and bias = 0)
    """
    if scheme == "xavier":
        xavier_uniform_(layer.weight, gain=1.0)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    elif scheme == "kaiming_relu":
        kaiming_uniform_(layer.weight, nonlinearity="relu")
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    elif scheme == "zero":
        nn.init.zeros_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    else:
        raise ValueError(f"Unknown initialization scheme: {scheme}")
    return layer

# Test init_linear_
linear_test = nn.Linear(128, 64)
init_linear_(linear_test, scheme="xavier")
print("Testing init_linear_ with Xavier:")
print(f"Weight stats: mean={linear_test.weight.mean().item():.4f}, std={linear_test.weight.std().item():.4f}")
init_linear_(linear_test, scheme="kaiming_relu")
print("Testing init_linear_ with Kaiming ReLU:")
print(f"Weight stats: mean={linear_test.weight.mean().item():.4f}, std={linear_test.weight.std().item():.4f}")
init_linear_(linear_test, scheme="zero")
print("Testing init_linear_ with Zero:")
print(f"Weight stats: mean={linear_test.weight.mean().item():.4f}, std={linear_test.weight.std().item():.4f}")


Testing init_linear_ with Xavier:
Weight stats: mean=-0.0003, std=0.1028
Testing init_linear_ with Kaiming ReLU:
Weight stats: mean=0.0025, std=0.1245
Testing init_linear_ with Zero:
Weight stats: mean=0.0000, std=0.0000
