### Explanation in RNN Context

In an RNN, the hidden state at each time step is computed from the previous hidden state and the current input. The update can be written as:

$ h_t = \tanh(W_x x_t + W_h h_{t-1} + b) $

In the code snippet below:

- `i2h` represents $ W_x x_t $, the transformation of the input by the input weight matrix.
- `h2h` represents $ W_h h_{t-1} $, the transformation of the previous hidden state by the hidden state weight matrix.
- The sum `i2h + h2h` combines these transformations to form the input to the activation function. The snippet omits the bias term $ b $ from the equation above.
- `next_h.tanh()` applies the $\tanh$ activation function to get the next hidden state.
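As a sanity check on the mapping above, the sketch below computes one step by hand, this time including the bias, and compares it with PyTorch's built-in `torch.nn.RNNCell`. The fixed seed, the row-vector (batch-first) layout, and the variable names `manual_h` and `builtin_h` are choices made for this illustration, not part of the original snippet.

```python
import torch

torch.manual_seed(0)  # fixed seed so the check is reproducible

input_size, hidden_size = 10, 20  # same sizes as the snippet in this section

# nn.RNNCell stores W_x as weight_ih and W_h as weight_hh, and splits the
# bias b into two vectors, bias_ih and bias_hh.
cell = torch.nn.RNNCell(input_size, hidden_size, nonlinearity="tanh")

x = torch.randn(1, input_size)        # x_t as a row vector (batch of 1)
prev_h = torch.randn(1, hidden_size)  # h_{t-1} as a row vector

# Manual step: h_t = tanh(W_x x_t + W_h h_{t-1} + b)
i2h = x @ cell.weight_ih.t()          # W_x x_t, shape (1, 20)
h2h = prev_h @ cell.weight_hh.t()     # W_h h_{t-1}, shape (1, 20)
b = cell.bias_ih + cell.bias_hh       # combined bias, shape (20,)
manual_h = torch.tanh(i2h + h2h + b)

builtin_h = cell(x, prev_h)

values_match = torch.allclose(manual_h, builtin_h, atol=1e-5)
print(manual_h.shape, values_match)
```

Because the inputs here are row vectors, the weights are transposed instead of the inputs. The arithmetic is the same as in the column-vector form used in this section.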

Repeating this update at every time step is how an RNN carries information through a sequence: each $ h_t $ depends on $ h_{t-1} $, so the hidden state summarizes the inputs seen so far and lets the network capture temporal dependencies and patterns.
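To make the recurrence concrete, here is a minimal sketch that applies the same update over a short sequence. The sequence length, the zero initial state, the zero bias, and the 0.1 weight scaling are assumptions made for illustration only.

```python
import torch

torch.manual_seed(0)

seq_len, input_size, hidden_size = 5, 10, 20

# Scaled-down random weights keep tanh away from saturation (illustrative choice).
W_x = torch.randn(hidden_size, input_size) * 0.1
W_h = torch.randn(hidden_size, hidden_size) * 0.1
b = torch.zeros(hidden_size, 1)

xs = torch.randn(seq_len, input_size, 1)  # one column vector x_t per time step
h = torch.zeros(hidden_size, 1)           # h_0, assumed to start at zero

hidden_states = []
for x_t in xs:
    # The same weights are reused at every step; only h carries information forward.
    h = torch.tanh(torch.mm(W_x, x_t) + torch.mm(W_h, h) + b)
    hidden_states.append(h)

hs = torch.stack(hidden_states)  # shape (seq_len, hidden_size, 1)
print(hs.shape)
```

Each entry of `hs` is the hidden state after one more input has been consumed, so the last entry depends on the whole sequence.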

In [12]:
import torch

x = torch.randn(1,10, requires_grad=True)
prev_h = torch.randn(1,20, requires_grad=True)
W_h = torch.randn(20,20, requires_grad=True)
W_x = torch.randn(20,10, requires_grad=True)

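# torch.mm performs matrix multiplication.
# .t() transposes x and prev_h from row vectors into column vectors.
i2h = torch.mm(W_x, x.t())        # (20, 10) @ (10, 1) -> (20, 1)
h2h = torch.mm(W_h, prev_h.t())   # (20, 20) @ (20, 1) -> (20, 1)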

next_h = i2h + h2h
next_h = next_h.tanh()
loss = next_h.sum()
loss.backward()

# print(f"Gradient of x: {x.grad}")
# print(f"Gradient of prev_h: {prev_h.grad}")
# print(f"Gradient of W_h: {W_h.grad}")
# print(f"Gradient of W_x: {W_x.grad}")
print(loss)

tensor(1.7539, grad_fn=<SumBackward0>)
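The call to `loss.backward()` fills in the `.grad` field of each tensor. As a rough check, those gradients can be compared with the chain rule worked out by hand: with $ \delta = 1 - h_t^2 $ (the derivative of $\tanh$, taken elementwise), the gradients are $ \delta\, x_t^\top $ for `W_x`, $ \delta\, h_{t-1}^\top $ for `W_h`, $ \delta^\top W_x $ for `x`, and $ \delta^\top W_h $ for `prev_h`. The sketch below repeats the snippet with a fixed seed (an addition for reproducibility) and uses hypothetical `expected_*` names for the hand-derived values.

```python
import torch

torch.manual_seed(0)  # added for reproducibility; the original snippet is unseeded

x = torch.randn(1, 10, requires_grad=True)
prev_h = torch.randn(1, 20, requires_grad=True)
W_h = torch.randn(20, 20, requires_grad=True)
W_x = torch.randn(20, 10, requires_grad=True)

next_h = torch.tanh(torch.mm(W_x, x.t()) + torch.mm(W_h, prev_h.t()))
loss = next_h.sum()
loss.backward()

with torch.no_grad():
    delta = 1 - next_h ** 2            # d(loss)/d(pre-activation), shape (20, 1)
    expected_W_x = delta @ x           # (20, 1) @ (1, 10) -> (20, 10)
    expected_W_h = delta @ prev_h      # (20, 1) @ (1, 20) -> (20, 20)
    expected_x = delta.t() @ W_x       # (1, 20) @ (20, 10) -> (1, 10)
    expected_prev_h = delta.t() @ W_h  # (1, 20) @ (20, 20) -> (1, 20)

checks = {
    "W_x": torch.allclose(W_x.grad, expected_W_x, rtol=1e-4, atol=1e-4),
    "W_h": torch.allclose(W_h.grad, expected_W_h, rtol=1e-4, atol=1e-4),
    "x": torch.allclose(x.grad, expected_x, rtol=1e-4, atol=1e-4),
    "prev_h": torch.allclose(prev_h.grad, expected_prev_h, rtol=1e-4, atol=1e-4),
}
print(checks)
```

Since `x` and `prev_h` are stored as row vectors, they already play the role of $ x_t^\top $ and $ h_{t-1}^\top $ in the formulas, so no extra transpose is needed for the weight gradients.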
