## Cell state versus hidden state

The distinction between **hidden state** and **cell state** in an LSTM (Long Short-Term Memory) cell is central to why LSTMs are more effective than traditional RNNs at capturing **long-term dependencies**.

---

### Detailed Explanation:

#### 1. **Cell State ($C_t$) = Internal Memory**

* The **cell state** is like the "long-term memory" of the LSTM.
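* It flows along the sequence **mostly unchanged**. At each step the **forget gate** scales it elementwise and the **input gate** adds new content to it. No weight matrix or squashing nonlinearity is applied directly to $C_{t-1}$.
* This makes it easier to retain information, and to propagate gradients, over long time spans.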
* Analogy: a notebook you carry and write in (add/delete info) as you go along a sequence.
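For reference, the standard (non-peephole) LSTM update of the cell state can be written as follows. The weight and bias names $W_f, W_i, W_C, b_f, b_i, b_C$ are our notation. $\sigma$ is the logistic sigmoid and $\odot$ is elementwise multiplication. $C_{t-1}$ enters only in the last line, through a gated elementwise sum.

$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t &= \tanh(W_C [h_{t-1}, x_t] + b_C) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
\end{aligned}
$$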

#### 2. **Hidden State ($h_t$) = Output at Time $t$**

* The **hidden state** is the **output** that gets passed to:

  * the **next time step**, and
  * the **next layer** (if stacked).
* It is computed by applying `tanh` to the current cell state and multiplying the result elementwise by the **output gate**: $h_t = o_t \odot \tanh(C_t)$.
* Analogy: the filtered view of your notebook that you share with others at each step.
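The following minimal sketch checks the formula against PyTorch. It recomputes one step of `torch.nn.LSTMCell` by hand from the cell's own parameters. It relies on PyTorch's documented packing of the four gate blocks in `weight_ih`/`weight_hh`, which are stacked in the order input, forget, cell candidate, output. The sizes and the random inputs are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 3, 2
cell = nn.LSTMCell(input_size, hidden_size)  # randomly initialized, illustration only

x_t = torch.randn(1, input_size)       # batch of 1
h_prev = torch.randn(1, hidden_size)   # h_{t-1}
c_prev = torch.randn(1, hidden_size)   # C_{t-1}

# All four gate pre-activations at once; PyTorch stacks them as (i, f, g, o)
gates = (x_t @ cell.weight_ih.T + cell.bias_ih
         + h_prev @ cell.weight_hh.T + cell.bias_hh)
i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=1)

i_t = torch.sigmoid(i_pre)      # input gate
f_t = torch.sigmoid(f_pre)      # forget gate
c_tilde = torch.tanh(g_pre)     # candidate vector C~_t
o_t = torch.sigmoid(o_pre)      # output gate

c_t = f_t * c_prev + i_t * c_tilde   # cell state: gated additive update
h_t = o_t * torch.tanh(c_t)          # hidden state: filtered view of the cell state

# Compare with PyTorch's own forward pass
h_ref, c_ref = cell(x_t, (h_prev, c_prev))
print("h_t matches:", torch.allclose(h_t, h_ref, atol=1e-6))
print("C_t matches:", torch.allclose(c_t, c_ref, atol=1e-6))
```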

---

### Why Are Both Needed?

1. **Hidden state** alone (as in vanilla RNNs) is often **too volatile**. It is completely rewritten at every time step, so older information is quickly overwritten, and gradients over many steps tend to vanish or explode.
2. **Cell state** allows for more **stable and controlled memory**:

   * You decide (via gates) what to forget, what to remember, and what to expose.
3. Having both lets the LSTM **decouple memory from output**:

   * The model can **store** useful information without **exposing** it immediately.
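The sketch below illustrates "store without exposing" on a single step with hand-picked values. The stored vector and the gate logits of ±10 are assumptions for illustration. Because the output gate is a sigmoid, "closed" means close to zero rather than exactly zero.

```python
import torch

# Hypothetical information already stored in the cell state
C_t = torch.tensor([2.0, -1.5])

# Output-gate activations come from a sigmoid; large negative/positive logits close/open it
o_closed = torch.sigmoid(torch.tensor([-10.0, -10.0]))
o_open = torch.sigmoid(torch.tensor([10.0, 10.0]))

h_closed = o_closed * torch.tanh(C_t)   # almost nothing reaches the hidden state
h_open = o_open * torch.tanh(C_t)       # nearly all of tanh(C_t) is exposed

print("C_t       :", C_t)
print("h (closed):", h_closed)
print("h (open)  :", h_open)
```

In both cases `C_t` is untouched. Only what the rest of the network sees through $h_t$ changes.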

---

### Textual Summary:

| Component        | Symbol | Role                                                                                     |
| ---------------- | ------ | ---------------------------------------------------------------------------------------- |
| **Cell State**   | $C_t$  | **Memory** of the cell; stores long-term information across time.                        |
| **Hidden State** | $h_t$  | **Output** of the cell; controls what is exposed to the rest of the network at time $t$. |

---
### Visual Summary:

```
           +--------------+     
           |   Cell State |  <--- long-term memory (C_t)
           +--------------+
                 ↑   ↑
         forget ⊙    ⊕  input
                 ↓   ↓
             +---------+
             |   LSTM  |
             +---------+
                 ↓
           tanh(C_t) ⊙ output gate
                 ↓
           Hidden State h_t  ---> output to next time step/layer
```

## Hidden state as short-term memory, cell state as long-term memory

In an LSTM:

> **The hidden state $h_t$ is the short-term memory.**
> **The cell state $C_t$ is the long-term memory.**

---

### Intuition Behind the Separation

| Component        | Symbol | Memory Type           | Purpose                                                                                        |
| ---------------- | ------ | --------------------- | ---------------------------------------------------------------------------------------------- |
| **Hidden State** | $h_t$  | **Short-term** memory | Captures recent, transient information; also serves as the **output** of the LSTM at time $t$. |
| **Cell State**   | $C_t$  | **Long-term** memory  | Captures persistent memory across time steps; modulated by gates.                              |

---

### Why does this matter?

* **$h_t$** is **fully recomputed** at every time step as $o_t \odot \tanh(C_t)$. It reflects what the LSTM "wants to say now". The output gate shapes it from the **current input** and the recent context.
* **$C_t$** can persist information over **long sequences**, **even if it's not relevant to the immediate output**.

---

### Analogy

Think of an LSTM cell as a **human reading a paragraph**:

* **$C_t$** is your overall understanding of the story so far (the broader context you carry with you).
* **$h_t$** is what you’re thinking about *right now*, maybe just the meaning of the current sentence.

---

### What about in vanilla RNNs?

* In a simple RNN, there's **only $h_t$**, which tries to do **both jobs** (memory + output).
* This is one reason vanilla RNNs struggle with **long-term dependencies**: they can't separate what's important to **remember** from what needs to be **output**.
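The following hand-built illustration is not a trained model. In the LSTM part, the gates are fixed by hand so that the cell state is kept (forget gate 1, input gate 0) while random candidate content keeps arriving. In the vanilla-RNN part, the hidden state is rewritten by a hypothetical contractive recurrent matrix `W_hh = 0.5 * I`, with no input for clarity. Both start from the same vector and run for 50 steps.

```python
import torch

torch.manual_seed(0)
steps = 50

# LSTM-style cell state with hand-fixed (hypothetical) gates
c = torch.tensor([1.0, -0.5])
f_t = torch.ones(2)    # forget gate fully open: keep everything
i_t = torch.zeros(2)   # input gate fully closed: write nothing
for _ in range(steps):
    c_tilde = torch.tanh(torch.randn(2))   # new candidate content arrives every step
    c = f_t * c + i_t * c_tilde            # C_t = f_t * C_{t-1} + i_t * C~_t

# Vanilla-RNN-style hidden state, fully rewritten at every step
h = torch.tensor([1.0, -0.5])
W_hh = 0.5 * torch.eye(2)                  # hypothetical recurrent weights
for _ in range(steps):
    h = torch.tanh(W_hh @ h)               # h_t = tanh(W_hh h_{t-1}), no input term

print(f"cell state after {steps} steps:", c)
print(f"vanilla hidden state after {steps} steps:", h)
```

With larger recurrent weights, the vanilla state would settle at a fixed point set by the weights instead of decaying. In neither case is the original value carried forward. The LSTM can carry it only because the forget gate multiplies $C_{t-1}$ by (nearly) one.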

---


# Toy Example 1

```python
import torch

# Inputs (as tensors)
x_t = torch.tensor([0.5, 0.1])             # input at time t
h_t_minus_1 = torch.tensor([0.4, -0.2])    # previous hidden state
C_t_minus_1 = torch.tensor([0.3, 0.0])     # previous cell state

# Concatenation of h_{t-1} and x_t. In a real LSTM this vector is multiplied by
# learned weight matrices to produce the gate pre-activations; here it is only printed.
concat = torch.cat([h_t_minus_1, x_t])     # shape: (4,)

# No learned weights here: each gate value is picked by hand. We compute the
# pre-activation that maps to the target value, then apply the activation.

def simulate_gate(output_vals):
    # Inverse sigmoid (logit): returns z such that sigmoid(z) == output_vals
    return torch.log(output_vals / (1 - output_vals))

# Target gate values, as in the HTML animation
f_t = torch.sigmoid(simulate_gate(torch.tensor([0.8, 0.2])))   # forget gate
i_t = torch.sigmoid(simulate_gate(torch.tensor([0.6, 0.4])))   # input gate
o_t = torch.sigmoid(simulate_gate(torch.tensor([0.9, 0.3])))   # output gate

# Candidate vector C̃_t (pre-activation recovered with the inverse tanh)
c_tilde = torch.tanh(torch.atanh(torch.tensor([0.7, -0.1])))

# New cell state: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
C_t = f_t * C_t_minus_1 + i_t * c_tilde

# New hidden state: h_t = o_t ⊙ tanh(C_t)
h_t = o_t * torch.tanh(C_t)

# Print results with consistent formatting (2 decimal places)
def format_vector(label, vec):
    values = [f"{v.item():.2f}" for v in vec]
    print(f"{label:<30}: [{', '.join(values)}]")

format_vector("Concat [h_{t-1}, x_t]", concat)
format_vector("f_t (forget gate)", f_t)
format_vector("i_t (input gate)", i_t)
format_vector("C̃_t (candidate vector)", c_tilde)
format_vector("C_t (new cell state)", C_t)
format_vector("o_t (output gate)", o_t)
format_vector("h_t (new hidden state)", h_t)
```


```
Concat [h_{t-1}, x_t]         : [0.40, -0.20, 0.50, 0.10]
f_t (forget gate)             : [0.80, 0.20]
i_t (input gate)              : [0.60, 0.40]
C̃_t (candidate vector)       : [0.70, -0.10]
C_t (new cell state)          : [0.66, -0.04]
o_t (output gate)             : [0.90, 0.30]
h_t (new hidden state)        : [0.52, -0.01]
```
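The printed $C_t$ and $h_t$ can be checked by hand from the gate values above (tanh values rounded to three decimals):

$$
\begin{aligned}
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t = [0.8 \cdot 0.3 + 0.6 \cdot 0.7,\; 0.2 \cdot 0.0 + 0.4 \cdot (-0.1)] = [0.66,\; -0.04] \\
h_t &= o_t \odot \tanh(C_t) = [0.9 \cdot \tanh(0.66),\; 0.3 \cdot \tanh(-0.04)] \approx [0.9 \cdot 0.578,\; 0.3 \cdot (-0.040)] \approx [0.52,\; -0.01]
\end{aligned}
$$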


# Toy Example 2

```python
import torch
import torch.nn as nn

# Set seed for reproducibility
torch.manual_seed(0)

# Input dimensions
seq_len = 4      # number of time steps
input_dim = 3    # features per time step
hidden_dim = 2   # LSTM hidden state size
batch_size = 1   # for simplicity

# Dummy input (sequence of vectors)
x = torch.randn(batch_size, seq_len, input_dim)

# Define LSTM layer (1 layer, unidirectional)
lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

# Initial hidden and cell states (h0, c0)
h0 = torch.zeros(1, batch_size, hidden_dim)
c0 = torch.zeros(1, batch_size, hidden_dim)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Print outputs
print("Input sequence:\n", x.squeeze())
print("\nOutput at each time step:\n", output.squeeze())
print("\nFinal hidden state:\n", hn.squeeze())
print("\nFinal cell state:\n", cn.squeeze())
```


```
Input sequence:
 tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])

Output at each time step:
 tensor([[-0.1491,  0.0378],
        [-0.2774,  0.0892],
        [-0.0008,  0.2280],
        [ 0.0942,  0.1251]], grad_fn=<SqueezeBackward0>)

Final hidden state:
 tensor([0.0942, 0.1251], grad_fn=<SqueezeBackward0>)

Final cell state:
 tensor([0.2080, 0.3955], grad_fn=<SqueezeBackward0>)
```
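The sketch below connects `output`, `hn` and `cn`. It rebuilds the configuration above with the same seed and creation order, leaving the initial states to PyTorch's default of zeros. It then unrolls the same weights one step at a time with `nn.LSTMCell`, copying the parameters from the `nn.LSTM` layer's `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0` and `bias_hh_l0` attributes. The point is that `output` stacks $h_t$ for every step, while $C_t$ is only returned for the final step, as `cn`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 4, 3)                    # (batch, seq_len, input_dim)
lstm = nn.LSTM(3, 2, batch_first=True)
output, (hn, cn) = lstm(x)                  # omitted (h0, c0) default to zeros

# Single-layer, unidirectional: the last time step of `output` is the final hidden state
print("output[:, -1] == hn:", torch.allclose(output[:, -1, :], hn[0]))

# Same weights, unrolled manually so that h_t and C_t are visible at every step
cell = nn.LSTMCell(3, 2)
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

h = torch.zeros(1, 2)
c = torch.zeros(1, 2)
with torch.no_grad():
    for t in range(x.shape[1]):
        h, c = cell(x[:, t, :], (h, c))     # one time step: returns (h_t, C_t)
        print(f"t={t}  h_t={[round(v, 4) for v in h[0].tolist()]}  "
              f"C_t={[round(v, 4) for v in c[0].tolist()]}")

print("final h matches hn:", torch.allclose(h, hn[0], atol=1e-6))
print("final C matches cn:", torch.allclose(c, cn[0], atol=1e-6))
```

Keeping $C_t$ out of `output` matches the design described earlier. The cell state is internal memory, and only the hidden state is passed on to the next layer.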


```python
import torch
import torch.nn as nn

# Config
input_size = 5     # number of features per time step
hidden_size = 8    # size of hidden state and cell state
seq_len = 4        # length of the input sequence
batch_size = 2     # number of sequences in a batch

# Dummy input: shape = (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)

# Define LSTM layer
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)

# Initial hidden state and cell state (num_layers=1 by default)
h0 = torch.zeros(1, batch_size, hidden_size)  # shape: (num_layers, batch_size, hidden_size)
c0 = torch.zeros(1, batch_size, hidden_size)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Print shapes
print("Input shape:", x.shape)
print("Output shape (all time steps):", output.shape)
print("Final hidden state shape:", hn.shape)
print("Final cell state shape:", cn.shape)
```


```
Input shape: torch.Size([2, 4, 5])
Output shape (all time steps): torch.Size([2, 4, 8])
Final hidden state shape: torch.Size([1, 2, 8])
Final cell state shape: torch.Size([1, 2, 8])
```
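The leading `1` in the `hn` and `cn` shapes above is the `num_layers` dimension. The sketch below uses a hypothetical two-layer stack (`num_layers=2`, otherwise the same sizes). In it, `hn` and `cn` gain one entry per layer, while `output` still holds only the top layer's hidden states. Note that `batch_first=True` changes the layout of the input and `output` but not of `hn` and `cn`.

```python
import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size = 2, 4, 5, 8
x = torch.randn(batch_size, seq_len, input_size)

# Two stacked layers: the h_t sequence of layer 1 is the input sequence of layer 2
lstm2 = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
output, (hn, cn) = lstm2(x)   # initial states default to zeros of shape (2, batch, hidden)

print("Output shape:", output.shape)   # (batch, seq_len, hidden): top layer only
print("hn shape:", hn.shape)           # (num_layers, batch, hidden): final h per layer
print("cn shape:", cn.shape)           # (num_layers, batch, hidden): final C per layer
print("last output step == top-layer hn:", torch.allclose(output[:, -1, :], hn[-1]))
```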
