# Exercise 3: Neural networks in PyTorch

In this exercise you’ll implement small neural-network building blocks from scratch and use them to train a simple classifier.

You’ll cover:
- **Basic layers**: Linear, Embedding, Dropout
- **Normalization**: LayerNorm and RMSNorm
- **MLPs + residual**: composing layers into deeper networks
- **Classification**: loading MNIST, implementing cross-entropy from logits, and writing a minimal training loop

As before: fill in all `TODO`s without changing function names or signatures.
Use small sanity checks and compare to PyTorch reference implementations when useful.

In [1]:
from __future__ import annotations

import torch
from torch import nn

## Basic layers

In this section you’ll implement a few core layers that appear everywhere:

### `Linear`
A fully-connected layer that follows `nn.Linear` conventions:  
`y = x @ Wᵀ + b`

Important details:
- Parameters should be registered as `nn.Parameter`.
- Store the weight with shape `(out_features, in_features)`, like `nn.Linear`.
- The forward pass should support leading batch dimensions: `x` can have shape `(..., in_features)`.
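A quick way to check the convention is to compare `x @ W.T + b` against PyTorch's functional reference `torch.nn.functional.linear`, which uses the same `(out_features, in_features)` weight layout. The sketch below uses random tensors with two leading batch dimensions; the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features = 5, 4
W = torch.randn(out_features, in_features)  # same layout as nn.Linear.weight
b = torch.randn(out_features)
x = torch.randn(2, 3, in_features)          # two leading batch dimensions

# matmul broadcasts over all leading dimensions of x
y_manual = x @ W.T + b
y_ref = F.linear(x, W, b)

print(y_manual.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(y_manual, y_ref, atol=1e-6))
```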

### `Embedding`
An embedding table maps integer ids to vectors:
- input: token ids `idx` of shape `(...,)`
- output: vectors of shape `(..., embedding_dim)`

This is essentially a learnable lookup table.
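Indexing the weight table gives the same result as multiplying a one-hot encoding of the ids by the table. It also matches `torch.nn.functional.embedding`. Here is a minimal check, with a table size picked arbitrarily:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, embedding_dim = 10, 3
table = torch.randn(num_embeddings, embedding_dim)
idx = torch.tensor([[0, 1, 1], [0, 1, 2]])  # shape (2, 3), dtype int64

lookup = table[idx]                                        # (2, 3, embedding_dim)
one_hot = F.one_hot(idx, num_embeddings).to(table.dtype)   # (2, 3, num_embeddings)
via_matmul = one_hot @ table                               # same values, more compute
via_ref = F.embedding(idx, table)

print(lookup.shape)  # torch.Size([2, 3, 3])
```

The lookup is preferred in practice because it skips building and multiplying the mostly-zero one-hot tensor.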

### `Dropout`
Dropout randomly zeroes activations during training to reduce overfitting.
Implementation details:
- Only active in `model.train()` mode
- In training: drop with probability `p` and scale the kept values by `1/(1-p)` so the expected value stays the same
- In eval: return the input unchanged
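The `1/(1-p)` scaling keeps the expected value unchanged. Each element survives with probability `1-p` and is then multiplied by `1/(1-p)`, so the expected output is `(1-p) · x / (1-p) = x`.

The sketch below checks this empirically on a large tensor of ones. It uses a manually built mask ("inverted dropout"). `torch.nn.functional.dropout` appears only to show the eval-mode reference behaviour. The tensor size and `p` are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.3
x = torch.ones(1_000_000)

# Inverted dropout: keep each element with probability 1 - p, rescale survivors
keep = torch.rand_like(x) >= p
y = x * keep / (1 - p)

print(f"fraction kept: {keep.float().mean().item():.3f}")  # close to 0.7
print(f"mean output:  {y.mean().item():.3f}")               # close to 1.0

# Reference behaviour: identity when training=False
assert torch.equal(F.dropout(x, p=p, training=False), x)
```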

## Instructions
- Do not use PyTorch reference modules for the parts you implement (e.g. don’t call nn.Linear inside your Linear).
- You may use standard tensor ops that you learned before (matmul, sum, mean, rsqrt, indexing, etc.).
- Use a parameter initialization method of your choice. We recommend something like Xavier-uniform.
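For reference, Xavier (Glorot) uniform initialization samples each weight from `U(-a, a)` with `a = gain · sqrt(6 / (fan_in + fan_out))`. For a weight of shape `(out_features, in_features)`, `fan_in = in_features` and `fan_out = out_features`.

The sketch below computes the bound by hand and checks that `torch.nn.init.xavier_uniform_` (default `gain=1.0`) stays inside it. The layer sizes are arbitrary.

```python
import math
import torch

torch.manual_seed(0)
out_features, in_features = 64, 128
weight = torch.empty(out_features, in_features)
torch.nn.init.xavier_uniform_(weight)  # default gain = 1.0

bound = math.sqrt(6.0 / (in_features + out_features))
print(f"bound = {bound:.4f}, max |w| = {weight.abs().max().item():.4f}")

# Manual equivalent: uniform samples in [-bound, bound)
manual = torch.empty(out_features, in_features).uniform_(-bound, bound)
```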


In [21]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Weight layout matches nn.Linear: (out_features, in_features)
        weight = torch.empty(out_features, in_features)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)
        # bias is None when disabled; forward() checks for this
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (..., in_features)
        return: (..., out_features)
        """
        # matmul broadcasts over any leading batch dimensions
        result = x @ self.weight.T
        if self.bias is not None:
            result = result + self.bias  # out-of-place add
        return result

Linear(5, 4)(torch.ones(1, 5))


tensor([[ 1.1575, -0.6848,  0.0579,  0.2306]], grad_fn=<AddBackward0>)

In [23]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(num_embeddings, embedding_dim))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """
        idx: (...,) int64
        return: (..., embedding_dim)
        """
        return self.weight[idx]

Embedding(10, 3)(torch.tensor([[[0, 1, 1], [0, 1, 2]], [[0, 1, 1], [0, 1, 2]]]))

tensor([[[[-0.5734,  0.0258, -0.3688],
          [ 0.4198,  0.5900,  0.0508],
          [ 0.4198,  0.5900,  0.0508]],

         [[-0.5734,  0.0258, -0.3688],
          [ 0.4198,  0.5900,  0.0508],
          [ 0.4743,  0.0624,  0.2419]]],


        [[[-0.5734,  0.0258, -0.3688],
          [ 0.4198,  0.5900,  0.0508],
          [ 0.4198,  0.5900,  0.0508]],

         [[-0.5734,  0.0258, -0.3688],
          [ 0.4198,  0.5900,  0.0508],
          [ 0.4743,  0.0624,  0.2419]]]], grad_fn=<IndexBackward0>)

In [43]:
class Dropout(nn.Module):
    def __init__(self, p: float):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        In train mode: drop with prob p and scale by 1/(1-p).
        In eval mode: return x unchanged.
        """
        if not self.training or self.p == 0.0:
            return x
        if self.p == 1.0:
            # Everything is dropped; avoid dividing by 1 - p = 0
            return torch.zeros_like(x)
        # Keep each element with probability 1 - p, then rescale the survivors
        keep = torch.rand_like(x) >= self.p
        return x * keep / (1.0 - self.p)

dropout = Dropout(0.3)
dropout.train()
dropout(torch.ones(1, 5, 5))

tensor([[[1.4286, 1.4286, 1.4286, 1.4286, 0.0000],
         [0.0000, 1.4286, 1.4286, 1.4286, 0.0000],
         [0.0000, 1.4286, 0.0000, 1.4286, 0.0000],
         [1.4286, 1.4286, 1.4286, 0.0000, 1.4286],
         [1.4286, 1.4286, 1.4286, 1.4286, 1.4286]]])

## Normalization

Normalization layers help stabilize training by controlling activation statistics.

### LayerNorm
LayerNorm normalizes each example across its **feature dimension** (the last dimension):

- compute mean and variance over the last dimension
- normalize: `(x - mean) / sqrt(var + eps)`, where `var` is the biased variance (divide by `D`, not `D - 1`), as in `nn.LayerNorm`
- apply learnable per-feature scale and shift (`weight`, `bias`)

**In this exercise, assume `elementwise_affine=True` (always include `weight` and `bias`).**  
`weight` and `bias` each have shape `(D,)`.

LayerNorm is widely used in transformers because it does not depend on batch statistics.
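The recipe above takes only a few tensor ops. Two details matter when comparing against `torch.nn.functional.layer_norm`:
- The variance is the biased (population) variance.
- `eps` goes inside the square root.

Here is a minimal sketch with `weight = 1`, `bias = 0` and an arbitrary input shape:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D, eps = 8, 1e-5
x = torch.randn(4, 3, D)
weight, bias = torch.ones(D), torch.zeros(D)

# keepdim=True keeps shape (..., 1) so the statistics broadcast over the last dim
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance (divide by D)
y_manual = (x - mean) / torch.sqrt(var + eps) * weight + bias

y_ref = F.layer_norm(x, (D,), weight, bias, eps)
print(torch.allclose(y_manual, y_ref, atol=1e-5))
```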

### RMSNorm
RMSNorm is similar to LayerNorm but normalizes using only the root-mean-square:
- `x / sqrt(mean(x^2) + eps)` over the last dimension
- usually includes a learnable scale (`weight`)
- no mean subtraction

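RMSNorm is popular in modern LLMs because it is cheaper than LayerNorm and tends to work about as well in practice. There is no mean to compute or subtract, and typically no bias.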
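Because there is no mean subtraction, RMSNorm rescales each vector to unit root-mean-square but does not center it. Here is a minimal sketch of the formula with `weight = 1` and arbitrary shapes. It checks the unit-RMS property on an input that is deliberately off-center:

```python
import torch

torch.manual_seed(0)
D, eps = 16, 1e-8
x = torch.randn(4, D) * 3.0 + 2.0  # non-zero mean and non-unit scale on purpose
weight = torch.ones(D)

# eps goes inside the square root; rsqrt computes 1 / sqrt(.)
inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
y = x * inv_rms * weight

rms_after = y.pow(2).mean(dim=-1).sqrt()
print(rms_after)        # each entry is ~1.0
print(y.mean(dim=-1))   # generally not 0: RMSNorm does not center
```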


In [54]:
class LayerNorm(nn.Module):
    def __init__(
        self, normalized_shape: int, eps: float = 1e-5, elementwise_affine: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.normalized_shape = normalized_shape
        # elementwise_affine is assumed True in this exercise: always learn weight and bias
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize over the last dimension.
        x: (..., D)
        """
        # keepdim=True keeps shape (..., 1) so the statistics broadcast per row
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, correction=0)  # biased variance, as in nn.LayerNorm
        result = (x - mean) * torch.rsqrt(var + self.eps)  # eps inside the square root
        return result * self.weight + self.bias

LayerNorm(10)(torch.randn(10, 10))

In [52]:
class RMSNorm(nn.Module):
    def __init__(self, normalized_shape: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.normalized_shape = normalized_shape
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        RMSNorm: x / sqrt(mean(x^2) + eps) * weight
        over the last dimension.
        """
        # eps belongs inside the square root; rsqrt(v) = 1 / sqrt(v)
        mean_sq = torch.mean(x**2, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight

RMSNorm(10)(torch.randn(1, 10))

tensor([[ 1.2734, -1.1069,  0.0739,  0.2700,  1.6816,  0.3672, -0.5372,  0.8667,
         -1.7350, -0.2491]], grad_fn=<MulBackward0>)

## MLPs and residual networks

Now you’ll build larger networks by composing layers.

### MLP
An MLP is a stack of `depth` Linear layers with non-linear activations (use GELU) in between.
In this exercise you’ll support:
- configurable depth
- a hidden dimension
- optional LayerNorm between layers (a common stabilization trick)

A key skill is building networks using `nn.ModuleList` / `nn.Sequential` while keeping shapes consistent.
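One way to keep shapes consistent is to list the layer widths first, then create one Linear per consecutive pair.

The sketch below makes a few assumptions:
- `depth` counts Linear layers, as in the description above.
- The optional LayerNorm goes after each hidden Linear, before the activation. This placement is a choice made for illustration.
- PyTorch's `nn.Linear`, `nn.LayerNorm` and `nn.GELU` stand in for the modules you write yourself.
- `make_mlp` is a hypothetical helper name.

```python
import torch
from torch import nn

def make_mlp(in_dim: int, hidden_dim: int, out_dim: int, depth: int,
             use_layernorm: bool = False) -> nn.Sequential:
    """Hypothetical helper: `depth` Linear layers with GELU in between."""
    assert depth >= 1
    # widths: in_dim -> hidden_dim -> ... -> hidden_dim -> out_dim
    dims = [in_dim] + [hidden_dim] * (depth - 1) + [out_dim]
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if i < depth - 1:  # no norm/activation after the output layer
            if use_layernorm:
                layers.append(nn.LayerNorm(d_out))
            layers.append(nn.GELU())
    return nn.Sequential(*layers)

mlp = make_mlp(in_dim=784, hidden_dim=128, out_dim=10, depth=3, use_layernorm=True)
out = mlp(torch.randn(32, 784))
print(out.shape)  # torch.Size([32, 10])
```

With `depth=1` the helper degenerates to a single Linear layer, which is a useful edge case to test in your own `MLP`.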

### Transformer-style FeedForward (FFN)
A transformer block contains a position-wise feedforward network:
- `D -> 4D -> D` (by default)
- activation is typically **GELU**

This is essentially an MLP applied independently at each token position.
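"Position-wise" means the same weights are applied to every token independently. Running the FFN on a whole `(B, T, D)` tensor therefore gives the same result as running it on each position separately.

The sketch below uses `nn.Linear` and `nn.GELU` as stand-ins for your own modules, and the sizes are arbitrary:

```python
import torch
from torch import nn

torch.manual_seed(0)
B, T, d_model = 2, 5, 8
d_ff = 4 * d_model  # default expansion factor
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

x = torch.randn(B, T, d_model)
y_all = ffn(x)  # Linear acts on the last dim, so all positions are processed at once

# Apply the FFN to one position at a time and stack the results back along T
y_per_token = torch.stack([ffn(x[:, t]) for t in range(T)], dim=1)

print(y_all.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(y_all, y_per_token, atol=1e-5))
```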

### Residual wrapper
Residual connections are the simplest form of “skip connection”:
- output is `x + fn(x)`

They improve gradient flow and allow training deeper networks more reliably.
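Here is a small illustration of why the skip path helps gradients. For `y = x + fn(x)`, the Jacobian is `I + ∂fn/∂x`, so the gradient reaching `x` is not lost even when `fn` contributes nothing.

In the sketch, `fn` is a zero-initialized `nn.Linear`, a deliberately degenerate stand-in chosen only to make the effect visible:

```python
import torch
from torch import nn

fn = nn.Linear(4, 4)
nn.init.zeros_(fn.weight)  # fn(x) == 0 for every x
nn.init.zeros_(fn.bias)

x = torch.randn(3, 4, requires_grad=True)

y_plain = fn(x)      # no skip connection
y_res = x + fn(x)    # residual: x + fn(x)

grad_plain, = torch.autograd.grad(y_plain.sum(), x)
grad_res, = torch.autograd.grad(y_res.sum(), x)
print(grad_plain.abs().max().item())  # 0.0: nothing flows back through fn
print(grad_res)                       # all ones, via the identity path
```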

In [None]:
class MLP(nn.Module):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        depth: int,
        use_layernorm: bool = False,
    ):
        super().__init__()
        # TODO: build modules (list of Linear + activation)
        # Optionally insert LayerNorm between layers.
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement
        raise NotImplementedError

In [None]:
class FeedForward(nn.Module):
    """
    Transformer-style FFN: D -> 4D -> D (default)
    """

    def __init__(self, d_model: int, d_ff: int | None = None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        # TODO: create two Linear layers and choose an activation (GELU)
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement
        raise NotImplementedError

In [None]:
class Residual(nn.Module):
    def __init__(self, fn: nn.Module):
        super().__init__()
        # TODO: implement
        raise NotImplementedError

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # TODO: return x + fn(x, ...)
        raise NotImplementedError

## Classification problem

In this section you’ll put everything together in a minimal MNIST classification experiment.

You will:
1) download and load the MNIST dataset
2) implement cross-entropy from logits (stable, using log-softmax)
3) build a simple MLP-based classifier (flatten MNIST images first)
4) write a minimal training loop
5) report train loss curve and final accuracy
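For one example with logits `z` and target class `y`, cross-entropy is `-log softmax(z)_y = logsumexp(z) - z_y`. Computing the softmax with raw `exp` and then taking the log can overflow for large logits. The `logsumexp` form stays finite.

The sketch below contrasts the two forms on deliberately extreme logits. It compares the stable version with `torch.nn.functional.cross_entropy`, which serves only as a reference check here; your graded function should not call it.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0, -1000.0],
                       [2.0, 1.0, 0.1]])
targets = torch.tensor([0, 2])  # int64 class indices
rows = torch.arange(len(targets))

# Naive: exp(1000) overflows to inf, so the first row becomes inf / inf = nan
probs = logits.exp() / logits.exp().sum(dim=-1, keepdim=True)
naive_loss = -torch.log(probs[rows, targets])

# Stable: log_softmax(z) = z - logsumexp(z), computed without overflow
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
stable_loss = -log_probs[rows, targets]

print(naive_loss)   # first entry is nan
print(stable_loss.mean(), F.cross_entropy(logits, targets))
```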

The goal here is not to reach state-of-the-art accuracy, but to understand the full pipeline:
data → model → logits → loss → gradients → parameter update.
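The whole pipeline fits in a few lines. The sketch below runs it on a small synthetic batch of random inputs and labels, so the numbers themselves are meaningless. It uses PyTorch's `nn.Linear`, `F.cross_entropy` and `torch.optim.SGD` as stand-ins for the pieces you implement. The learning rate and step count are arbitrary.

Repeatedly fitting one fixed batch should drive the loss down. That makes a useful smoke test before training on MNIST.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
x = torch.randn(64, 784)         # data: a fake batch of flattened images
y = torch.randint(0, 10, (64,))  # fake labels in {0, ..., 9}

model = nn.Sequential(nn.Linear(784, 128), nn.GELU(), nn.Linear(128, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
model.train()
for step in range(50):
    logits = model(x)                   # model -> logits, shape (64, 10)
    loss = F.cross_entropy(logits, y)   # logits -> loss
    optimizer.zero_grad()
    loss.backward()                     # loss -> gradients
    optimizer.step()                    # gradients -> parameter update
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```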

### Model notes
- Combine the MLP implemented above with the `ClassificationHead` defined below into a single model.

### MNIST notes
- MNIST images are `28×28` grayscale.
- After `ToTensor()`, each image has shape `(1, 28, 28)` and values in `[0, 1]`.
- For an MLP classifier, we flatten to a vector of length `784`.
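Here is a quick shape check for the flattening step. Seeing the shapes does not require the real MNIST download. The sketch therefore wraps random tensors shaped like MNIST images in a `TensorDataset`, as a stand-in for `datasets.MNIST`.

It shows that `nn.Flatten()` (equivalently `x.flatten(1)`) turns a `(B, 1, 28, 28)` batch into `(B, 784)` vectors while keeping the batch dimension. The batch size is arbitrary.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 100 random "images" in [0, 1) and integer labels
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape, yb.shape)   # torch.Size([32, 1, 28, 28]) torch.Size([32])

flat = nn.Flatten()(xb)     # flattens all dims after the batch dim
print(flat.shape)           # torch.Size([32, 784])
assert torch.equal(flat, xb.flatten(1))
```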

## Deliverables
- Include a plot of your training loss curve and your final accuracy in the video submission.
- **NOTE:** Model performance is not graded, but we expect at least 70% accuracy as a sign that your model is implemented correctly.
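Accuracy is the fraction of examples whose highest-scoring logit matches the label. The sketch below computes it in eval mode and without gradients.

Several names are illustrative only:
- `batch_accuracy` is a hypothetical helper.
- The toy model and the fake data are stand-ins.
- The exercise's `accuracy(loader)` takes only a loader, so there the model would come from the enclosing scope.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()  # no autograd graph needed for evaluation
def batch_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Hypothetical helper: fraction of correct argmax predictions."""
    model.eval()
    correct, total = 0, 0
    for xb, yb in loader:
        preds = model(xb).argmax(dim=-1)  # (B,) predicted class ids
        correct += (preds == yb).sum().item()
        total += yb.numel()
    return correct / total

torch.manual_seed(0)
fake_x = torch.rand(50, 1, 28, 28)
fake_y = torch.randint(0, 10, (50,))
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
acc = batch_accuracy(toy_model, DataLoader(TensorDataset(fake_x, fake_y), batch_size=16))
print(f"accuracy: {acc:.2%}")  # roughly chance level for an untrained model
```

Because evaluation switches the model to eval mode, call `model.train()` again before the next training epoch.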

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
transform = transforms.ToTensor()  # -> float32 in [0,1], shape (1, 28, 28)

train_ds = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# TODO: define the dataloaders

In [None]:
def cross_entropy_from_logits(
    logits: torch.Tensor,
    targets: torch.Tensor,
) -> torch.Tensor:
    """
    Compute mean cross-entropy loss from logits.

    logits: (B, C)
    targets: (B,) int64

    Requirements:
    - Use log-softmax for stability (do not use torch.nn.CrossEntropyLoss, we check this in the autograder).
    """
    # TODO: implement
    raise NotImplementedError

In [None]:
class ClassificationHead(nn.Module):
    def __init__(self, d_in: int, num_classes: int):
        super().__init__()
        # TODO: implement
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (..., d_in)
        return: (..., num_classes) logits
        """
        # TODO: implement
        raise NotImplementedError

In [None]:
def accuracy(loader):
    # TODO: You can use this function to evaluate your model accuracy.
    raise NotImplementedError

In [None]:
def train_classifier(
    model: nn.Module,
    train_data_loader: DataLoader,
    test_data_loader: DataLoader,
    lr: float,
    epochs: int,
    seed: int = 0,
) -> list[float]:
    """
    Minimal training loop for MNIST classification.

    Steps:
    - define optimizer
    - for each epoch:
        - sample minibatches
        - forward -> cross-entropy -> backward -> optimizer step
        - compute test accuracy at the end of each epoch
    - return list of training losses (one per update step)

    Requirements:
    - call model.train() during training and model.eval() during evaluation
    - do not use torch.nn.CrossEntropyLoss (use your cross_entropy_from_logits)
    """
    # TODO: implement
    raise NotImplementedError
