# Introduction

In this notebook, we'll implement simple RNNs and LSTMs, then explore how gradients flow through these different networks.

This notebook does not require a Colab GPU. If it's enabled, you can turn it off through Runtime -> Change runtime type. (This will make it more likely for you to get Colab GPU access later in the REAL_RNN_LSTM.ipynb problem.)


# Imports

Note: installing ipympl will require you to restart the Colab runtime.


In [None]:
!pip install ipympl


In [None]:
from __future__ import annotations
import copy
from typing import Callable

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive, widgets, Layout

# Detect whether the notebook is running inside Google Colab. The
# google.colab package is only importable there; when the import succeeds,
# the custom widget manager is enabled as well.
try:
    from google.colab import output
    output.enable_custom_widget_manager()
    RUNNING_IN_COLAB = True
except ImportError:
    # Not on Colab (e.g. a local Jupyter session): nothing extra to enable.
    RUNNING_IN_COLAB = False

# === Reproducibility ===
# Seed both PyTorch's and NumPy's global random number generators with the
# same fixed value.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)


In [None]:
%matplotlib ipympl


# 1.A: Implementing an RNN Layer

Consider using PyTorch's [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear). You can implement this with either one Linear layer or two. If you use two, remember that you only need to include a bias term for one of the linear layers.

**Mathematical Formulation:**

$$h_t = \sigma(W^h h_{t-1} + W^x x_t + b)$$

where:
- $W^h$ is the hidden-to-hidden weight matrix (hidden_size × hidden_size)
- $W^x$ is the input-to-hidden weight matrix (hidden_size × input_size)
- $b$ is the bias vector (hidden_size)
- $\sigma$ is the nonlinearity (e.g., tanh)

![RNN Unrolling](img/1a.png)

*Figure: RNN unrolling across timesteps. Source: [Colah's Blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)*


In [None]:
class RNNLayer(nn.Module):
    """Single-layer Elman RNN implementation.

    Implements the recurrence relation:
        h_t = σ(W_hh @ h_{t-1} + W_ih @ x_t + b)

    This is an exercise template: the sections marked TODO are intentionally
    left blank, and ``forward`` will raise ``NameError`` (``last_h`` is
    unassigned) until they are filled in.

    Learning Objectives:
        - Understand how hidden state carries temporal information
        - See the recurrence relation implemented in code
        - Observe gradient flow through time via retain_grad()

    Attributes:
        input_size: Dimension of input features at each timestep.
        hidden_size: Dimension of hidden state (also output dimension).
        nonlinearity: Activation function applied after linear transformation.
        h_list: Hidden states [h_1, ..., h_T] of the most recent forward
            pass. Only set once ``forward`` has been called.

    References:
        - Elman, J. L. (1990). "Finding structure in time"
        - PyTorch nn.RNN for production implementation
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        nonlinearity: Callable[[torch.Tensor], torch.Tensor] = torch.tanh
    ) -> None:
        """Initialize a single RNN layer.

        Args:
            input_size: Data input feature dimension.
            hidden_size: RNN hidden state size (also the output feature dimension).
            nonlinearity: Nonlinearity applied to the RNN output. Default: tanh.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity

        ##############################################################################
        # TODO: Initialize any parameters your class needs.                          #
        ##############################################################################

        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """RNN forward pass implementing h_t = σ(W_hh @ h_{t-1} + W_ih @ x_t + b).

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_size).

        Returns:
            all_h: All hidden states of shape (batch_size, seq_len, hidden_size).
                   Contains [h_1, h_2, ..., h_T] for each sequence.
            last_h: Final hidden state of shape (batch_size, hidden_size).
                    This is h_T, useful for sequence classification.
        """
        hidden_states = []  # Will store [h_1, h_2, ..., h_T]

        ##############################################################################
        # TODO: Implement the RNN forward step                                       #
        # 1. Initialize h_0 with zeros: shape (batch_size, hidden_size)              #
        # 2. Loop over timesteps t = 0, 1, ..., seq_len-1:                           #
        #    - Compute h_{t+1} = σ(W_hh @ h_t + W_ih @ x_t + b)                      #
        #    - Append h_{t+1} to hidden_states list                                  #
        # 3. Set last_h = h_T (the final hidden state)                               #
        ##############################################################################

        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################

        # Store hidden states for gradient analysis (used in visualization)
        # hidden_states should contain T tensors, each of shape (batch_size, hidden_size)
        self.store_hidden_states_for_grad(hidden_states)

        # Stack list of tensors into single tensor
        # Shape transformation: List[T × (B, H)] → (B, T, H)
        all_h = torch.stack(hidden_states, dim=1)  # (batch_size, seq_len, hidden_size)
        return all_h, last_h

    def store_hidden_states_for_grad(self, hidden_states: list[torch.Tensor]) -> None:
        """Store hidden states and ask autograd to keep their gradients.

        Hidden states are non-leaf tensors, so their ``.grad`` is normally
        discarded after ``backward()``; ``retain_grad()`` keeps it available
        for analysis.

        Tensors that are not part of the autograd graph (for example when the
        forward pass runs under ``torch.no_grad()``) are stored as-is, because
        ``retain_grad()`` raises ``RuntimeError`` for tensors with
        ``requires_grad=False``.

        Args:
            hidden_states: List of hidden state tensors [h_1, ..., h_T].
        """
        for h in hidden_states:
            # Guard: retain_grad() is only legal on tensors that require grad.
            if h.requires_grad:
                h.retain_grad()
        self.h_list = hidden_states


### Test Cases

If your implementation is correct, you should expect to see errors of less than 1e-4.


In [None]:
# === Test 1: Single input, single timestep ===
rnn = RNNLayer(input_size=1, hidden_size=1)

# Overwrite every entry of the state dict with the constant 0.1 so the
# expected output can be derived by hand.
rnn.load_state_dict({k: v * 0 + 0.1 for k, v in rnn.state_dict().items()})

# One sequence, one timestep, one feature, with value 1.
data = torch.ones((1, 1, 1))
# Reference value equals tanh(0.2): with h_0 = 0, this is
# tanh(0.1 * 0 + 0.1 * 1 + 0.1), assuming the layer has a single bias term.
expected_out = torch.FloatTensor([[[0.1973753273487091]]])

all_h, last_h = rnn(data)

assert all_h.shape == expected_out.shape, (
    f"Shape mismatch: expected {expected_out.shape}, got {all_h.shape}"
)
# With only one timestep, the stacked history and the final state hold the
# same single value (they broadcast against each other in isclose).
assert torch.all(torch.isclose(all_h, last_h)), (
    "For seq_len=1, all_h and last_h should be identical"
)
# The error is only printed, not asserted; compare it with 1e-4 by eye.
print(f"Expected: {expected_out.item():.6f}")
print(f"Got:      {last_h.item():.6f}")
print(f"Max error: {torch.max(torch.abs(expected_out - last_h)).item():.2e}")


In [None]:
# === Test 2: Multiple inputs, multiple timesteps, linear activation ===
# The identity nonlinearity makes the recurrence purely linear, so the
# reference values below can be reproduced with simple arithmetic.
rnn = RNNLayer(
    input_size=2,
    hidden_size=3,
    nonlinearity=lambda x: x
)

# Verify parameter count

num_params = sum(p.numel() for p in rnn.parameters())
assert num_params == 18, (
    f"Expected 18 parameters: W_ih(2×3=6) + W_hh(3×3=9) + b(3) = 18, "
    f"but found {num_params}"
)
print(f"Parameter count correct: {num_params}")

# Overwrite every entry of the state dict with the constant -0.1.
rnn.load_state_dict({k: v * 0 - 0.1 for k, v in rnn.state_dict().items()})

# Two sequences of four timesteps with two features each: (2, 4, 2).
data = torch.FloatTensor([
    [[0.1, 0.15], [0.2, 0.25], [0.3, 0.35], [0.4, 0.45]],
    [[-0.1, -1.5], [-0.2, -2.5], [-0.3, -3.5], [-0.4, -0.45]]
])

# Reference hidden states, rounded to four decimals. All hidden units are
# equal because every weight and bias has the same value.
expected_all_h = torch.FloatTensor([
    [[-0.1250, -0.1250, -0.1250],
     [-0.1075, -0.1075, -0.1075],
     [-0.1328, -0.1328, -0.1328],
     [-0.1452, -0.1452, -0.1452]],
    [[ 0.0600,  0.0600,  0.0600],
     [ 0.1520,  0.1520,  0.1520],
     [ 0.2344,  0.2344,  0.2344],
     [-0.0853, -0.0853, -0.0853]]
])

# The final hidden state is the last timestep of each sequence above.
expected_last_h = torch.FloatTensor([
    [-0.1452, -0.1452, -0.1452],
    [-0.0853, -0.0853, -0.0853]
])

all_h, last_h = rnn(data)

assert all_h.shape == expected_all_h.shape, (
    f"all_h shape mismatch: expected {expected_all_h.shape}, got {all_h.shape}"
)
assert last_h.shape == expected_last_h.shape, (
    f"last_h shape mismatch: expected {expected_last_h.shape}, got {last_h.shape}"
)

# The errors are only printed, not asserted; compare them with 1e-4 by eye.
print(f"Max error all_h:  {torch.max(torch.abs(expected_all_h - all_h)).item():.2e}")
print(f"Max error last_h: {torch.max(torch.abs(expected_last_h - last_h)).item():.2e}")

# 1.B: Implementing an RNN Regression Model

Now we'll use the RNN layer in a regression model by adding a final linear layer on top:

$$\hat{y}_t = W^f h_t + b^f$$

This transforms the hidden state at each timestep into a prediction.


In [None]:
class RecurrentRegressionModel(nn.Module):
    """RNN-based regression model with linear output layer.
    
    Architecture: Input → RNN/LSTM → Linear → Output
    Computes: h_t = RNN(x_t, h_{t-1}), then ŷ_t = W^f @ h_t + b^f
    
    This is an exercise template: the sections marked TODO are intentionally
    left blank, and ``forward`` will raise ``NameError`` (``out`` and
    ``all_h`` are unassigned) until they are filled in.
    
    Attributes:
        recurrent_net: The underlying RNN or LSTM module.
        output_dim: Dimension of the output predictions.
    """
    
    def __init__(
        self,
        recurrent_net: nn.Module,
        output_dim: int = 1
    ) -> None:
        """Initialize a simple RNN regression model.
        
        Args:
            recurrent_net: An RNN or LSTM module (single or multi-layer).
                Expected to expose a ``hidden_size`` attribute (see the hint
                in the TODO below).
            output_dim: Feature dimension of the output predictions.
        """
        super().__init__()
        # Assigning an nn.Module as an attribute registers it as a submodule,
        # so its parameters are included in this model's parameters().
        self.recurrent_net = recurrent_net
        self.output_dim = output_dim
        
        ##############################################################################
        # TODO: Initialize any parameters you need                                   #
        # HINT: Use recurrent_net.hidden_size to find the hidden state size          #
        ##############################################################################
        
        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through RNN and output layer.
        
        The recurrent modules in this notebook return a pair whose first
        element is the per-timestep hidden states of shape
        (batch_size, seq_len, hidden_size); the output layer is applied to
        that tensor.
        
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_size).
        
        Returns:
            out: Predictions of shape (batch_size, seq_len, output_dim).
            all_h: Hidden states of shape (batch_size, seq_len, hidden_size).
        """
        ##############################################################################
        # TODO: Implement the forward step.                                          #
        ##############################################################################
        
        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################
        # Both names must be assigned inside the TODO block above.
        return out, all_h


## Tests


In [None]:
# Regression model: RNN layer (2 inputs -> 3 hidden units) plus a 4-dim output.
rnn = RecurrentRegressionModel(RNNLayer(2, 3), 4)

# 34 = 18 parameters in the RNN layer + 16 in the output layer
# (3 * 4 weights + 4 biases, assuming a single nn.Linear head).
num_params = sum(p.numel() for p in rnn.parameters())
assert num_params == 34, f'expected 34 parameters but found {num_params}'

# Overwrite every entry of the state dict with the constant -0.1.
rnn.load_state_dict({k: v * 0 - 0.1 for k, v in rnn.state_dict().items()})
# Two sequences of four timesteps with two features each: (2, 4, 2).
data = torch.FloatTensor([
    [[0.1, 0.15], [0.2, 0.25], [0.3, 0.35], [0.4, 0.45]],
    [[-0.1, -1.5], [-0.2, -2.5], [-0.3, -3.5], [-0.4, -0.45]]
])
# Reference predictions of shape (2, 4, 4), rounded to four decimals.
expected_preds = torch.FloatTensor([
    [[-0.0627, -0.0627, -0.0627, -0.0627],
     [-0.0678, -0.0678, -0.0678, -0.0678],
     [-0.0604, -0.0604, -0.0604, -0.0604],
     [-0.0567, -0.0567, -0.0567, -0.0567]],
    [[-0.1180, -0.1180, -0.1180, -0.1180],
     [-0.1453, -0.1453, -0.1453, -0.1453],
     [-0.1692, -0.1692, -0.1692, -0.1692],
     [-0.0748, -0.0748, -0.0748, -0.0748]]
])
# Reference hidden states of shape (2, 4, 3). RNNLayer uses its default
# tanh nonlinearity here.
expected_all_h = torch.FloatTensor([
    [[-0.1244, -0.1244, -0.1244],
     [-0.1073, -0.1073, -0.1073],
     [-0.1320, -0.1320, -0.1320],
     [-0.1444, -0.1444, -0.1444]],
    [[ 0.0599,  0.0599,  0.0599],
     [ 0.1509,  0.1509,  0.1509],
     [ 0.2305,  0.2305,  0.2305],
     [-0.0840, -0.0840, -0.0840]]
])
preds, all_h = rnn(data)
assert all_h.shape == expected_all_h.shape
assert preds.shape == expected_preds.shape
# The errors are only printed, not asserted; inspect them by eye.
print(f'Max error all_h: {torch.max(torch.abs(expected_all_h - all_h)).item()}')
print(f'Max error preds: {torch.max(torch.abs(expected_preds - preds)).item()}')


# Problem 1.C: Dataset and Loss Function


## 1.C.i: Understanding the Dataset (no implementation needed)

Inspect the code and plots below to visualize the dataset.

![RNN Prediction Types](img/1c.png)

*Figure: Different RNN prediction patterns - we focus on many-to-many and many-to-one.*


In [None]:
def generate_batch(
    seq_len: int = 10,
    batch_size: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample random sequences together with their running means.

    Each input sequence is drawn i.i.d. from a standard normal distribution.
    The target at timestep t is the mean of the inputs at timesteps 1..t.

    Args:
        seq_len: Number of timesteps in every sequence.
        batch_size: Number of sequences to generate.

    Returns:
        data: Random inputs of shape (batch_size, seq_len, 1).
        target: Running means of ``data`` along the time axis, of shape
            (batch_size, seq_len, 1).
    """
    samples = torch.randn(size=(batch_size, seq_len, 1))
    # Count of elements seen so far at each timestep (1, 2, ..., seq_len),
    # shaped (1, seq_len, 1) so it broadcasts over batch and feature axes.
    counts = (torch.arange(seq_len) + 1)[None, :, None]
    running_mean = torch.cumsum(samples, dim=1) / counts
    return samples, running_mean


In [None]:
# Draw four example sequences of length 10 to visualize.
x, y = generate_batch(seq_len=10, batch_size=4)

# Many-to-many view: one figure per sequence, showing the raw input and the
# target (cumulative average) at every timestep.
for i in range(4):
    fig, ax1 = plt.subplots(1)
    ax1.plot(x[i, :, 0])
    ax1.plot(y[i, :, 0])
    ax1.legend(['x', 'y'])
    plt.title('Targets at all timesteps')
    plt.show()

# Many-to-one view: same inputs, but the target is only the final cumulative
# average, drawn as a constant line across all 10 timesteps.
for i in range(4):
    fig, ax1 = plt.subplots(1)
    ax1.plot(x[i, :, 0])
    ax1.plot(np.arange(10), [y[i, -1].item()] * 10)
    ax1.legend(['x', 'y'])
    plt.title('Predict only at the last timestep')
    plt.show()


## 1.C.ii: Implement the Loss Function


In [None]:
def loss_fn(
    pred: torch.Tensor,
    y: torch.Tensor,
    last_timestep_only: bool = False
) -> torch.Tensor:
    """Compute MSE loss for sequence prediction.
    
    This is an exercise template: the TODO section is intentionally left
    blank, and the function will raise ``NameError`` (``loss`` is
    unassigned) until it is filled in.
    
    Args:
        pred: Model predictions of size (batch, seq_len, 1).
        y: Targets of size (batch, seq_len, 1).
        last_timestep_only: If True, compute loss only at final timestep.
    
    Returns:
        loss: Scalar MSE loss between pred and true labels.
    
    Note:
        The accompanying test cell passes 2-D tensors of shape
        (batch, seq_len), so an implementation should select the final
        timestep along dimension 1 rather than rely on a trailing feature
        axis.
    """
    ##############################################################################
    # TODO: Implement the loss (HINT: look for pytorch's MSELoss function)       #
    ##############################################################################
    
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    # ``loss`` must be assigned inside the TODO block above.
    return loss


### Tests

You should see errors < 1e-4


In [None]:
# Two sequences of three timesteps each. These tensors are 2-D
# (batch, seq_len), with no trailing feature axis.
pred = torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
y = torch.FloatTensor([[-1.1, -1.2, -1.3], [-1.4, -1.5, -1.6]])
loss_all = loss_fn(pred, y, last_timestep_only=False)
loss_last = loss_fn(pred, y, last_timestep_only=True)
# Both losses must be 0-dim (scalar) tensors.
assert loss_all.shape == loss_last.shape == torch.Size([])
# Reference values: mean squared error over all six entries is about 3.0067;
# over the two final-timestep entries it is (1.6**2 + 2.2**2) / 2 = 3.7.
print(f'Max error loss_all: {torch.abs(loss_all - torch.tensor(3.0067)).item()}')
print(f'Max error loss_last: {torch.abs(loss_last - torch.tensor(3.7)).item()}')


# 1.D: Analyzing RNN Gradients

**Key Insight:** When backpropagating through many timesteps, gradients can vanish or explode.

$$\frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial h_T} \frac{\partial h_T}{\partial W} + \frac{\partial \mathcal{L}}{\partial h_{T-1}} \frac{\partial h_{T-1}}{\partial W} + \ldots + \frac{\partial \mathcal{L}}{\partial h_1} \frac{\partial h_1}{\partial W}$$

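The later terms in this sum correspond to earlier timesteps, whose gradients $\frac{\partial \mathcal{L}}{\partial h_t}$ must be backpropagated through many recurrent steps. They often end up with either very small magnitude (vanishing) or very large magnitude (exploding).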


You do not need to understand the details of the GradientVisualizer class in order to complete this problem.


In [None]:
def biggest_eig_magnitude(matrix: torch.Tensor) -> float:
    """Compute the magnitude of the largest eigenvalue (the spectral radius).

    Important because: |λ_max| > 1 → explosion, |λ_max| < 1 → vanishing

    Args:
        matrix: A non-empty square matrix (n × n) of floating-point or
            complex dtype.

    Returns:
        The magnitude of the eigenvalue with largest absolute value.

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional, is not square, or
            has no elements.
    """
    # Validate explicitly instead of with ``assert`` so the check still runs
    # when Python is started with -O.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(
            f'Matrix has shape {matrix.shape}, but eigenvalues can only be computed for square matrices'
        )
    if matrix.numel() == 0:
        raise ValueError('Matrix is empty, so it has no eigenvalues')

    # Eigenvalues of a real matrix may be complex; abs() gives their moduli.
    eig_magnitudes = torch.linalg.eigvals(matrix).abs()
    # A single max reduction replaces sorting a Python list and taking [0].
    return eig_magnitudes.max().item()


In [None]:
class GradientVisualizer:
    """Interactive visualization for RNN gradient flow analysis.
    
    Creates an interactive plot showing:
    1. Hidden state magnitudes at each timestep
    2. Gradient magnitudes (∂L/∂h_t) flowing backward through time
    
    The recurrent module is wrapped in a RecurrentRegressionModel and
    evaluated repeatedly on one fixed batch while sliders rescale its
    parameters. The TODO sections in ``update_plots`` are intentionally left
    blank as part of the exercise.
    """

    def __init__(self, rnn: RNNLayer, last_timestep_only: bool) -> None:
        """Initialize the gradient visualizer.
        
        Args:
            rnn: RNN module to visualize.
            last_timestep_only: If True, compute loss only at final timestep.
        """
        self.rnn = rnn
        self.last_timestep_only = last_timestep_only
        # Regression head with the default output dimension of 1.
        self.model = RecurrentRegressionModel(rnn)
        # Snapshot of the initial parameters; update_plots starts from a copy
        # of this snapshot on every call instead of compounding rescalings.
        self.original_weights = copy.deepcopy(rnn.state_dict())

        # Generate a single batch to be used repeatedly
        self.x, self.y = generate_batch(seq_len=10)
        print(f'Data point: x={np.round(self.x[0, :, 0].detach().cpu().numpy(), 2)}, y={np.round(self.y.squeeze().detach().cpu().numpy(), 2)}')

    def plot_visuals(self):
        """Generate plots which will be updated in realtime.
        
        Creates one figure with two axes, each holding a placeholder line
        that ``update_plots`` later overwrites, and stores the figure and
        line handles on ``self``.
        
        Returns:
            A tuple ``(plt_1, plt_2, fig)``: the hidden-state line, the
            gradient line, and the figure.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.set_title('RNN Outputs')
        ax1.set_xlabel('Unroll Timestep')
        ax1.set_ylabel('Hidden State Norm')
        ax1.set_ylim(-1, 5)
        # Placeholder data for 10 timesteps; ax.plot returns a list of lines.
        plt_1 = ax1.plot(np.arange(1, 11), np.zeros(10) + 1)
        plt_1 = plt_1[0]

        ax2.set_title('Gradients')
        ax2.set_xlabel('Unroll Timestep')
        ax2.set_ylabel('RNN dLoss/d a_t Gradient Magitude')
        ax2.set_ylim((10**-6, 1e5))
        ax2.set_yscale('log')
        # Tick positions 0..9 are labelled 10..1: the x-axis runs from the
        # last timestep back to the first.
        ax2.set_xticks(np.arange(10), np.arange(10, 0, -1))
        plt_2 = ax2.plot(np.arange(10), np.arange(10) + 1)
        plt_2 = plt_2[0]
        self.fig = fig
        self.plots = [plt_1, plt_2]
        return plt_1, plt_2, fig

    def update_plots(self, weight_val: float = 0, bias_val: float = 0) -> None:
        """Update visualization with scaled weights.
        
        Reloads a (to-be-scaled) copy of the original parameters into the
        RNN, runs the model on the stored batch, and redraws both lines.
        Requires ``plot_visuals`` to have been called first, since it reads
        ``self.plots`` and ``self.fig``.
        
        Args:
            weight_val: Scale factor intended for the weight matrices.
            bias_val: Scale factor intended for the bias terms.
        """
        # Scale the original RNN weights by a constant
        w_dict = copy.deepcopy(self.original_weights)
        ##############################################################################
        # TODO: Scale all W matrices by weight_val, and all bias vectors by bias_val #
        # If you're using PyTorch nn.Linear layers, you don't need to modify the code#
        # provided, but if you're using custom layers, modify this block.            #
        ##############################################################################
        
        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################
        self.rnn.load_state_dict(w_dict)

        # Don't compute for LSTMs, which don't have behavior dependent on a single eigenvalue
        if isinstance(self.rnn, RNNLayer):
            ##############################################################################
            # TODO: Set W = the weight which most affects exploding/vanishing gradients  #
            # Hint: Call module.weight or module.bias on the module you want to use      #
            # If you used a single Linear layer, slice a square matrix from it.          #
            ##############################################################################
            
            ##############################################################################
            #                               END OF YOUR CODE                             #
            ##############################################################################
            # ``W`` must be assigned inside the TODO block above.
            biggest_eig = biggest_eig_magnitude(W)
            print(f' Biggest eigenvalue magnitude: {biggest_eig:.3}')

        # Run model
        pred, h = self.model(self.x)
        loss = loss_fn(pred, self.y, self.last_timestep_only)
        # h[0] holds the hidden states of the first sequence, one row per
        # timestep, so its length is the sequence length.
        n_steps = len(h[0])

        plt_1, plt_2 = self.plots

        # Plot the hidden state magnitude
        # (the norm over the hidden dimension at each timestep)
        max_h = torch.linalg.norm(h[0], dim=-1).detach().cpu().numpy()
        print('Max H', ' '.join([f'{num:.3}' for num in max_h]))
        plt_1.set_data(np.arange(1, n_steps + 1), np.array(max_h))
        
        # Compute the gradient for the loss wrt the stored hidden states
        # The list is reversed so index 0 is the last timestep, matching the
        # tick labels set in plot_visuals.
        grads = [torch.linalg.norm(num).item() for num in torch.autograd.grad(loss, self.rnn.h_list)][::-1]
        print('gradients d Loss/d h_t', ' '.join([f'{num:.3}' for num in grads]))
        # The 1e-6 offset keeps exactly-zero gradients drawable on the log
        # axis, whose lower limit is 1e-6.
        plt_2.set_data(np.arange(n_steps), np.array(grads) + 1e-6)
        self.fig.canvas.draw_idle()

    def create_visualization(self):
        """Create interactive widget with sliders.
        
        Builds the figure, then binds ``update_plots`` to two sliders
        (range -5 to 5, step 0.05, initial value 0).
        
        Returns:
            The ipywidgets ``interactive`` object.
        """
        self.plot_visuals()
        ip = interactive(
            self.update_plots,
            weight_val=widgets.FloatSlider(value=0, min=-5, max=5, step=0.05, description="weight_scale", layout=Layout(width='100%')),
            bias_val=widgets.FloatSlider(value=0, min=-5, max=5, step=0.05, description="bias_scale", layout=Layout(width='100%')),
        )
        return ip


Adjust the sliders to rescale the weight and bias parameters in the RNN. Observe the effect on exploding and vanishing gradients.

**Parameters to try varying:**
- `nonlinearity`: Try `lambda x: x` (identity), `F.relu`, or `torch.tanh`
- `last_target_only`: Compare `True` vs `False`


In [None]:
# Experiment settings for the interactive gradient plot.
hidden_size = 16
nonlinearity = lambda x: x  # options: lambda x: x, F.relu, torch.tanh
# Passed to GradientVisualizer as last_timestep_only.
last_target_only = True

# One-dimensional input, matching the data produced by generate_batch.
rnn = RNNLayer(1, hidden_size, nonlinearity=nonlinearity)
gv = GradientVisualizer(rnn, last_target_only)
gv.create_visualization()

# If the slider doesn't work, try calling gv.update_plots with various values


# Problem 1.K: Making a Multi-Layer RNN

![Multi-layer RNN](img/1h.png)

*Figure: Multi-layer RNN architecture. Depth (vertical) vs Time (horizontal).*


## 1.K.i: Implementing Multi-Layer Models


In [None]:
class RNN(nn.Module):
    """Multi-layer RNN implementation.
    
    This is an exercise template: the sections marked TODO are intentionally
    left blank, and ``forward`` will raise ``NameError`` (``last_layer_h``
    and ``last_step_h`` are unassigned) until they are filled in.
    
    Attributes:
        input_size: Dimension of input features.
        hidden_size: Dimension of hidden state (same for all layers).
        num_layers: Number of stacked RNN layers.
    """
    
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int
    ) -> None:
        """Initialize a multilayer RNN.
        
        Args:
            input_size: Data input feature dimension.
            hidden_size: Hidden state size (also the output feature dimension).
            num_layers: Number of layers. Must be at least 1.
        """
        super().__init__()
        # At least one layer is needed to produce any hidden state.
        assert num_layers >= 1
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        ##############################################################################
        # TODO: Initialize any parameters your class needs.                          #
        # Consider using nn.ModuleList or nn.ModuleDict.                             #
        ##############################################################################
        
        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Multilayer RNN forward pass.
        
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_size).
        
        Returns:
            last_layer_h: Outputs from the last layer (batch_size, seq_len, hidden_size).
            last_step_h: All hidden states from the last step (num_layers, batch_size, hidden_size).
        """
        ##############################################################################
        # TODO: Implement the RNN forward step                                       #
        ##############################################################################
        
        ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################
        # Both names must be assigned inside the TODO block above.
        return last_layer_h, last_step_h


### Test Cases


In [None]:
# --- Case 1: a single layer ---
# The reference values are the same as the hidden states in the regression
# model test above.
rnn = RNN(2, 3, 1)
# Overwrite every entry of the state dict with the constant -0.1.
rnn.load_state_dict({k: v * 0 - 0.1 for k, v in rnn.state_dict().items()})
# Two sequences of four timesteps with two features each: (2, 4, 2).
data = torch.FloatTensor([[[0.1, 0.15], [0.2, 0.25], [0.3, 0.35], [0.4, 0.45]], [[-0.1, -1.5], [-0.2, -2.5], [-0.3, -3.5], [-0.4, -0.45]]])
# Last-layer hidden states at every timestep: (batch=2, seq_len=4, hidden=3).
expected_all_h = torch.FloatTensor([[[-0.1244, -0.1244, -0.1244],
         [-0.1073, -0.1073, -0.1073],
         [-0.1320, -0.1320, -0.1320],
         [-0.1444, -0.1444, -0.1444]],
        [[ 0.0599,  0.0599,  0.0599],
         [ 0.1509,  0.1509,  0.1509],
         [ 0.2305,  0.2305,  0.2305],
         [-0.0840, -0.0840, -0.0840]]])
# Final-timestep hidden state of each layer: (num_layers=1, batch=2, hidden=3).
expected_last_h = torch.FloatTensor([[[-0.1444, -0.1444, -0.1444],
         [-0.0840, -0.0840, -0.0840]]])
all_h, last_h = rnn(data)
assert all_h.shape == expected_all_h.shape
assert last_h.shape == expected_last_h.shape
# The errors are only printed, not asserted; inspect them by eye.
print(f'Max error all_h: {torch.max(torch.abs(expected_all_h - all_h)).item()}')
print(f'Max error last_h: {torch.max(torch.abs(expected_last_h - last_h)).item()}')

# --- Case 2: two stacked layers, same input data ---
rnn = RNN(2, 3, 2)
rnn.load_state_dict({k: v * 0 - 0.1 for k, v in rnn.state_dict().items()})
# Hidden states of the second (top) layer at every timestep.
expected_all_h = torch.FloatTensor([[[-0.0626, -0.0626, -0.0626],
         [-0.0490, -0.0490, -0.0490],
         [-0.0457, -0.0457, -0.0457],
         [-0.0430, -0.0430, -0.0430]],
        [[-0.1174, -0.1174, -0.1174],
         [-0.1096, -0.1096, -0.1096],
         [-0.1354, -0.1354, -0.1354],
         [-0.0342, -0.0342, -0.0342]]])
# Final-timestep hidden states: index 0 is the first layer (identical to
# case 1), index 1 is the second layer.
expected_last_h = torch.FloatTensor([[[-0.1444, -0.1444, -0.1444],
         [-0.0840, -0.0840, -0.0840]],
        [[-0.0430, -0.0430, -0.0430],
         [-0.0342, -0.0342, -0.0342]]])
all_h, last_h = rnn(data)
assert all_h.shape == (2, 4, 3)
assert last_h.shape == (2, 2, 3)
print(f'Max error all_h: {torch.max(torch.abs(expected_all_h - all_h)).item()}')
print(f'Max error last_h: {torch.max(torch.abs(expected_last_h - last_h)).item()}')


## 1.K.ii: Training Your Model


In [None]:
def train(
    model: nn.Module,
    optimizer: optim.Optimizer,
    num_batches: int,
    last_timestep_only: bool,
    seq_len: int = 10,
    batch_size: int = 32
) -> list[float]:
    """Fit a model to the running-average task on freshly sampled batches.

    Every iteration draws a new random batch, evaluates the loss, and takes
    one optimizer step. The progress bar description is refreshed every
    100 iterations with the mean of the most recent (up to) 10 losses.

    Args:
        model: Model to train; called as ``model(data)`` and expected to
            return a pair whose first element is the prediction.
        optimizer: PyTorch optimizer that updates the model's parameters.
        num_batches: Number of training iterations.
        last_timestep_only: Whether the loss uses only the final timestep.
        seq_len: Length of the generated sequences.
        batch_size: Number of sequences per batch.

    Returns:
        The loss value recorded at each iteration, in order.
    """
    model.train()

    # Imported lazily so tqdm is only required once training actually runs.
    from tqdm import tqdm

    history = []
    progress = tqdm(range(0, num_batches))
    for step in progress:
        inputs, targets = generate_batch(seq_len=seq_len, batch_size=batch_size)
        predictions, _ = model(inputs)
        batch_loss = loss_fn(predictions, targets, last_timestep_only)
        history.append(batch_loss.item())

        # Standard update: clear stale gradients, backpropagate, step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        progress.set_description(f"Batch: {step} Loss: {np.mean(history[-10:])}")
    return history


In [None]:
def train_all(
    hidden_size: int,
    lr: float,
    num_batches: int,
    last_timestep_only: bool
) -> tuple[list[nn.Module], list[list[float]]]:
    """Train a 1-layer and a 2-layer RNN regressor and plot the comparison.

    Both models are trained with Adam. Afterwards the loss curves are
    plotted together, followed by one figure per example for four freshly
    sampled sequences showing the input, the target, and each model's
    prediction.

    Args:
        hidden_size: Hidden state size used by both models.
        lr: Learning rate for Adam.
        num_batches: Number of training iterations per model.
        last_timestep_only: Whether to train and plot only the prediction
            at the final timestep.

    Returns:
        models: The trained models, 1-layer first.
        losses: The per-iteration loss history of each model, same order.
    """
    input_size = 1
    model_names = ['rnn_1_layer', 'rnn_2_layer']
    # Built in order of increasing depth, matching model_names.
    models = [
        RecurrentRegressionModel(RNN(input_size, hidden_size, depth))
        for depth in (1, 2)
    ]

    losses = []
    for net in models:
        adam = optim.Adam(net.parameters(), lr=lr)
        losses.append(train(net, adam, num_batches, last_timestep_only))

    # Visualize the results
    _, loss_ax = plt.subplots(1)
    for curve in losses:
        loss_ax.plot(curve)
    loss_ax.legend(model_names)
    plt.show()

    # Qualitative check on a few new sequences.
    n_examples = 4
    timesteps = np.arange(10)
    x, y = generate_batch(seq_len=10, batch_size=n_examples)
    preds_list = [net(x)[0] for net in models]
    for row in range(n_examples):
        _, ax = plt.subplots(1)
        ax.plot(x[row, :, 0])
        if last_timestep_only:
            # Only the final value matters, so draw it as a constant line.
            ax.plot(timesteps, [y[row, -1].item()] * 10, 'bo')
            for pred in preds_list:
                ax.plot(timesteps, [pred[row, -1, 0].detach().cpu().numpy()] * 10)
        else:
            ax.plot(y[row, :, 0], 'bo')
            for pred in preds_list:
                ax.plot(pred[row, :, 0].detach().cpu().numpy())
        ax.legend(['x', 'y'] + model_names)
        plt.show()
    return models, losses


In [None]:
# Hyperparameters shared by both experiments below.
HIDDEN_SIZE = 32
LEARNING_RATE = 1e-4
NUM_BATCHES = 5000

# Re-seed PyTorch's generator so both training runs start from a fixed seed.
torch.manual_seed(0)

# Train with last_timestep_only=False (predict at all timesteps)
last_timestep_only = False
predict_all_models, predict_all_losses = train_all(HIDDEN_SIZE, LEARNING_RATE, NUM_BATCHES, last_timestep_only)

# Train with last_timestep_only=True (predict only at last timestep)
last_timestep_only = True
predict_one_models, predict_one_losses = train_all(HIDDEN_SIZE, LEARNING_RATE, NUM_BATCHES, last_timestep_only)
