# Lecture 3: Recurrent and Masking-Based Autoregressive Models

## Recap: Neural Autoregressive Models

**Factorization of joint distribution:**

$$
P(x_1, x_2, \ldots, x_T) = \prod_{t=1}^T P(x_t \mid x_{<t})
$$

**How to compute conditionals with neural nets?**
1. Process context  
   - RNNs (sequential)  
   - Masking-based models (parallel)  
2. Generate probability distribution for next token  

**Example:**
"What do pigs __"

Predict:
$$
P(\text{eat} \mid \text{"What do pigs"})
$$


## Recurrent Neural Networks (RNNs)


![RNN](https://aiml.com/wp-content/uploads/2023/10/RNN-Language-model1-1024x704.png)

**Hidden state recurrence:**

$$
h_t = f(W_{xh}x_t + W_{hh}h_{t-1} + b_h)
$$
$$
P(x_{t+1} \mid x_{\le t}) = \text{softmax}(W_{ho}h_t + b_o)
$$

- $h_t$: hidden state at time $t$  
- $f$: nonlinearity (tanh, ReLU)  
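
A minimal sketch, assuming PyTorch's `nn.RNN` with its default tanh nonlinearity, that runs the recurrence by hand and checks it against the library. PyTorch keeps two bias vectors (`bias_ih_l0`, `bias_hh_l0`) whose sum plays the role of $b_h$. The sizes and the `nn.Linear` head standing in for $W_{ho}, b_o$ are illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, H, V, T = 5, 8, 10, 4          # input dim, hidden dim, vocab size, sequence length (arbitrary)

rnn = nn.RNN(input_size=D, hidden_size=H, nonlinearity="tanh", batch_first=True)
head = nn.Linear(H, V)            # stands in for W_ho, b_o

x = torch.randn(1, T, D)          # a batch containing one sequence
h = torch.zeros(H)                # h_0

# Manual recurrence h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), with b_h = bias_ih + bias_hh.
manual_states = []
for t in range(T):
    a_t = rnn.weight_ih_l0 @ x[0, t] + rnn.weight_hh_l0 @ h + rnn.bias_ih_l0 + rnn.bias_hh_l0
    h = torch.tanh(a_t)
    manual_states.append(h)
manual_states = torch.stack(manual_states)    # [T, H]

out, _ = rnn(x)                               # out: [1, T, H], all hidden states
probs = torch.softmax(head(out[0]), dim=-1)   # P(x_{t+1} | x_{<=t}) for every t: [T, V]

print(torch.allclose(manual_states, out[0], atol=1e-5))
print(probs.sum(dim=-1))                      # each row of probabilities sums to 1
```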

**Strengths**
- Fits sequential data naturally  
- Parameter sharing across time  
- Compact memory footprint  

**Weaknesses**
- Sequential training (slow)  
- Vanishing/exploding gradients  
- Poor long-range memory  


The vanishing and exploding gradient problem is typically analyzed through the hidden-state gradients  

$$
\delta_t = \frac{\partial L_T}{\partial h_t},
$$

since they are the recursive quantities that accumulate products of Jacobians across time.  
The parameter gradients for the recurrent and input weights are built linearly from these signals (derived below), so whatever happens to $\delta_t$ carries over to them.

---

#### Derivation

Hidden pre-activation:
$$
a_t = W_{xh}x_t + W_{hh}h_{t-1} + b_h, \quad h_t = f(a_t).
$$

Gradient w.r.t. $W_{hh}$:
$$
\frac{\partial L_T}{\partial W_{hh}}
= \sum_{t=1}^T \frac{\partial L_T}{\partial a_t}\;\frac{\partial a_t}{\partial W_{hh}}.
$$

- **First factor $\frac{\partial L_T}{\partial a_t}$:**  
  Recall that $h_t = f(a_t)$. By the chain rule (column-gradient convention),  
  $$
  \frac{\partial L_T}{\partial a_t}
  \;=\; \left(\frac{\partial h_t}{\partial a_t}\right)^{\!\top}\frac{\partial L_T}{\partial h_t}
  \;=\; D_t\,\delta_t,
  $$
  where
  $$
  D_t = \mathrm{diag}\!\big(f'(a_t)\big)
  $$
  is the diagonal (hence symmetric) Jacobian of the elementwise nonlinearity $f$.  


---

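- **Second factor $\frac{\partial a_t}{\partial W_{hh}}$:**  
  The pre-activation is
  $$
  a_t = W_{xh}x_t + W_{hh}h_{t-1} + b_h.
  $$
  Holding $h_{t-1}$ fixed, $a_t$ is linear in $W_{hh}$, so contracting the upstream gradient $D_t\delta_t$ with $\partial a_t/\partial W_{hh}$ yields an outer product with $h_{t-1}$ (derived elementwise below). Summing over time steps:

$$
\frac{\partial L_T}{\partial W_{hh}}
= \sum_{t=1}^T (D_t \delta_t)\, h_{t-1}^\top.
$$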

Similarly, for the input weights $W_{xh}$,
$$
\frac{\partial L_T}{\partial W_{xh}}
= \sum_{t=1}^T (D_t \delta_t)\, x_t^\top.
$$
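
To check both parameter-gradient formulas numerically, the sketch below builds a tiny tanh RNN by hand (sizes and random parameters are arbitrary), lets autograd report $\delta_t$ through `retain_grad()`, and compares $\sum_t (D_t\delta_t)h_{t-1}^\top$ and $\sum_t (D_t\delta_t)x_t^\top$ with the gradients autograd assigns to $W_{hh}$ and $W_{xh}$. For $f=\tanh$, $D_t=\mathrm{diag}(1-h_t^2)$. The loss is an unnormalised sum of per-step cross-entropies.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D, H, V, T = 3, 4, 6, 5

# Illustrative tanh RNN with hand-made parameters.
W_xh = torch.randn(H, D, requires_grad=True)
W_hh = (0.5 * torch.randn(H, H)).requires_grad_()
b_h = torch.zeros(H, requires_grad=True)
W_ho = torch.randn(V, H)
xs = torch.randn(T, D)
targets = torch.randint(0, V, (T,))          # index of x_{t+1} for each step t

h_prev, hs, h_prevs, loss = torch.zeros(H), [], [], 0.0
for t in range(T):
    h_prevs.append(h_prev)
    h = torch.tanh(W_xh @ xs[t] + W_hh @ h_prev + b_h)
    h.retain_grad()                          # after backward, h.grad holds delta_t = dL/dh_t
    hs.append(h)
    loss = loss + F.cross_entropy((W_ho @ h).unsqueeze(0), targets[t:t + 1])
    h_prev = h
loss.backward()

# D_t delta_t for f = tanh: D_t = diag(1 - h_t^2), applied elementwise.
Dd = [(1 - h.detach() ** 2) * h.grad for h in hs]
grad_hh = sum(torch.outer(d, hp.detach()) for d, hp in zip(Dd, h_prevs))
grad_xh = sum(torch.outer(d, x) for d, x in zip(Dd, xs))

print(torch.allclose(grad_hh, W_hh.grad, atol=1e-5), torch.allclose(grad_xh, W_xh.grad, atol=1e-5))
```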

---

#### Interpretation

- If $\delta_t$ **vanishes**, the parameter gradients also vanish → the RNN fails to learn long-term dependencies.  
- If $\delta_t$ **explodes**, the parameter gradients explode → training becomes unstable.  
- The output-layer gradients (e.g. for $W_{ho}$) do not involve long Jacobian chains and remain well-behaved.  

Hence, analyzing $\delta_t$ suffices to understand how vanishing and exploding gradients propagate to recurrent parameter updates.

### Why the result is an outer product (tensor → matrix via contraction)

Let  
$$
a_t = W_{xh}x_t + W_{hh}h_{t-1} + b_h \in \mathbb{R}^H,\quad
h_{t-1}\in\mathbb{R}^H.
$$  
We want $\frac{\partial L_T}{\partial W_{hh}}\in\mathbb{R}^{H\times H}$.

---

#### Elementwise derivation

Write components:  
$$
(a_t)_i = \sum_{j=1}^H (W_{hh})_{ij}\,(h_{t-1})_j + \cdots
$$  
Then  
$$
\frac{\partial (a_t)_i}{\partial (W_{hh})_{pq}}
= \mathbf{1}_{\{i=p\}}\,(h_{t-1})_q.
$$  
Chain rule, for the contribution of a single step $t$ (holding $h_{t-1}$ fixed; we mark this contribution with $[\,\cdot\,]_t$):  
$$
\left[\frac{\partial L_T}{\partial (W_{hh})_{pq}}\right]_t
= \sum_{i=1}^H \frac{\partial L_T}{\partial (a_t)_i}\,
   \frac{\partial (a_t)_i}{\partial (W_{hh})_{pq}}
= \left(\frac{\partial L_T}{\partial a_t}\right)_p\,(h_{t-1})_q.
$$  
Using $\frac{\partial L_T}{\partial a_t}=D_t\delta_t$, we get  
$$
\left[\frac{\partial L_T}{\partial (W_{hh})_{pq}}\right]_t
= (D_t\delta_t)_p\,(h_{t-1})_q.
$$  
Stacking all $(p,q)$ yields the **outer product**:  
$$
\boxed{\;
\left[\frac{\partial L_T}{\partial W_{hh}}\right]_t
= (D_t\delta_t)\,h_{t-1}^\top
\;}
$$  
Because $W_{hh}$ is shared across all time steps, the full gradient is the sum of these contributions:  
$$
\boxed{\;
\frac{\partial L_T}{\partial W_{hh}}
= \sum_{t=1}^T (D_t\delta_t)\,h_{t-1}^\top
\;}
$$  

---

**Key idea:** $\frac{\partial a_t}{\partial W_{hh}}$ is a tensor, but in reverse-mode AD it is **immediately contracted** with the upstream vector $ \frac{\partial L_T}{\partial a_t} $, producing the matrix outer product $(D_t\delta_t)\,h_{t-1}^\top$.


## Backpropagation Through Time (BPTT)

**Training objective:**

$$
{L} = - \frac{1}{T}\sum_{t=1}^T \log P(x_t \mid x_{<t})
$$

**Unroll RNN across time:**
- Looks like a deep feedforward net of depth $T$ whose layers all share the same weights.  
- Apply backprop through the unrolled graph.  

**Challenges**
- Gradients flow across many steps.  
- Leads to vanishing/exploding gradients.  

**Tricks**
- Gradient clipping (for exploding)  
- Truncated BPTT (limit to last $k$ steps)  
- Gated RNNs (LSTM/GRU)  
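
Both tricks take only a few lines in PyTorch. The sketch below is a toy training loop (vocabulary size, hidden size, random data and truncation length $k$ are arbitrary assumptions): `h.detach()` keeps the hidden state's value between chunks but cuts the graph, so backprop covers at most $k$ steps, and `torch.nn.utils.clip_grad_norm_` rescales all gradients whenever their global norm exceeds the threshold.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, H, k = 20, 32, 10                     # vocab size, hidden size, truncation length (assumed)
embed = nn.Embedding(V, H)
rnn = nn.RNN(H, H, batch_first=True)
head = nn.Linear(H, V)
params = list(embed.parameters()) + list(rnn.parameters()) + list(head.parameters())
opt = torch.optim.SGD(params, lr=0.1)

tokens = torch.randint(0, V, (4, 101))   # toy data: 4 sequences of 101 tokens
h = None
for start in range(0, tokens.size(1) - 1, k):
    inp = tokens[:, start:start + k]
    tgt = tokens[:, start + 1:start + k + 1]   # next-token targets
    # Truncated BPTT: keep the value of h but drop its history.
    if h is not None:
        h = h.detach()
    out, h = rnn(embed(inp), h)
    loss = nn.functional.cross_entropy(head(out).reshape(-1, V), tgt.reshape(-1))
    opt.zero_grad()
    loss.backward()
    # Gradient clipping: rescale if the global L2 norm of all gradients exceeds 1.0.
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    opt.step()
```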


## Vanishing & Exploding Gradients

### Conventions
- **Gradient (column vector):** for a scalar $f:\mathbb{R}^n\!\to\!\mathbb{R}$,
  $$
  \nabla_x f \;\equiv\; \frac{\partial f}{\partial x}\in\mathbb{R}^{n}.
  $$
- **Jacobian (matrix):** for a vector-valued function $y:\mathbb{R}^n\!\to\!\mathbb{R}^m$,
  $$
  J_y(x)\;\equiv\;\frac{\partial y}{\partial x}\in\mathbb{R}^{m\times n}.
  $$
We will **only** call something a Jacobian if its output is a **matrix**. All derivatives of a scalar loss are **gradients (vectors)**.

---

## Setup

**Hidden state recurrence**
$$
h_t \;=\; f\!\big(W_{xh}x_t + W_{hh}h_{t-1} + b_h\big)
$$

**Softmax head**
$$
P(x_{t+1}\mid x_{\le t}) \;=\; \mathrm{softmax}(W_{ho}h_t+b_o)
$$

Total loss up to time $T$ (we drop the constant factor $1/T$ of the training objective; it only rescales every gradient):
$$
L_T \;=\; \sum_{k=1}^T \ell_k,\qquad
\ell_k=\mathrm{CE}\!\big(\mathrm{softmax}(W_{ho}h_k+b_o),\;x_{k+1}\big).
$$

**Backprop signal (gradient)**
$$
\delta_t \;\equiv\; \frac{\partial L_T}{\partial h_t}\in\mathbb{R}^{H}\quad\text{(a column gradient vector)}.
$$

---

## Local quantities

Let
$$
a_t \;\equiv\; W_{xh}x_t + W_{hh}h_{t-1} + b_h, \quad h_t=f(a_t).
$$

- **Jacobian of the transition** (matrix):
  $$
  J_t \;\equiv\; \frac{\partial h_t}{\partial h_{t-1}}
  \;=\; \frac{\partial f(a_t)}{\partial a_t}\,\frac{\partial a_t}{\partial h_{t-1}}
  \;=\; D_t\,W_{hh}\;\in\;\mathbb{R}^{H\times H},
  $$
  where $D_t=\mathrm{diag}\!\big(f'(a_t)\big)$.

- **Local loss gradient w.r.t. $h_t$** (vector, **not** a Jacobian):
  $$
  g_t \;\equiv\; \frac{\partial \ell_t}{\partial h_t}
  \;=\; W_{ho}^\top\!\big(p_t - y_{t+1}\big)\;\in\;\mathbb{R}^{H},
  \quad p_t=\mathrm{softmax}(W_{ho}h_t+b_o),\quad y_{t+1}=\mathrm{onehot}(x_{t+1}).
  $$
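
A quick numerical check of $J_t = D_t W_{hh}$, assuming $f=\tanh$ (so $f'(a)=1-\tanh^2 a$) and random illustrative parameters. `torch.autograd.functional.jacobian` returns the full $H\times H$ matrix $\partial h_t/\partial h_{t-1}$ with $x_t$ held fixed.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
D, H = 3, 4
W_xh, W_hh, b_h = torch.randn(H, D), torch.randn(H, H), torch.randn(H)
x_t, h_prev = torch.randn(D), torch.randn(H)

def step(h):
    """One tanh RNN transition h_{t-1} -> h_t, with x_t held fixed."""
    return torch.tanh(W_xh @ x_t + W_hh @ h + b_h)

J_autograd = jacobian(step, h_prev)              # [H, H]; row i is d h_t[i] / d h_{t-1}
a_t = W_xh @ x_t + W_hh @ h_prev + b_h
D_t = torch.diag(1 - torch.tanh(a_t) ** 2)       # diag(f'(a_t)) for f = tanh
J_formula = D_t @ W_hh

print(torch.allclose(J_autograd, J_formula, atol=1e-5))
```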



# Expanding the Local Loss Gradient w.r.t. $h_t$

## 1. Softmax output

Define the logits at time $t$:
$$
z_t = W_{ho} h_t + b_o \in \mathbb{R}^V,
$$
where $V$ is the vocabulary size.

The softmax distribution is
$$
p_t = \text{softmax}(z_t),
\qquad
p_{t,i} = \frac{e^{z_{t,i}}}{\sum_{j=1}^V e^{z_{t,j}}}.
$$

---

## 2. Cross-entropy loss

Given the one-hot target vector $y_{t+1} \in \{0,1\}^V$,
$$
\ell_t = -\sum_{i=1}^V y_{t+1,i}\, \log p_{t,i}.
$$

---

## 3. Gradient wrt logits $z_t$

Differentiate $\ell_t$ wrt $z_t$:

$$
\frac{\partial \ell_t}{\partial z_t} \in \mathbb{R}^V.
$$

Using the well-known **softmax + cross-entropy gradient identity**:
$$
\frac{\partial \ell_t}{\partial z_t} = p_t - y_{t+1}.
$$

**Derivation of the above equality**

---

### Step 1: Cross-entropy loss
For one time step,
$$
\ell_t = -\sum_{i=1}^V y_{t+1,i} \log p_{t,i}.
$$

Since $y_{t+1}$ is one-hot, $\ell_t = -\log p_{t,c}$ where $c$ is the correct class index.

---

### Step 2: Derivative wrt probabilities $p_{t,i}$

The gradient of $\ell_t$ wrt $p_{t,i}$ is
$$
\frac{\partial \ell_t}{\partial p_{t,i}} = -\frac{y_{t+1,i}}{p_{t,i}}.
$$

---

### Step 3: Softmax derivative wrt logits $z_t$

The softmax function is
$$
p_{t,i} = \frac{e^{z_{t,i}}}{\sum_{j=1}^V e^{z_{t,j}}}.
$$

Its derivative wrt $z_{t,j}$ is
$$
\frac{\partial p_{t,i}}{\partial z_{t,j}}
= p_{t,i}(\delta_{ij} - p_{t,j}),
$$
where $\delta_{ij}$ is the Kronecker delta.

---

### Step 4: Chain rule

By chain rule,
$$
\frac{\partial \ell_t}{\partial z_{t,j}}
= \sum_{i=1}^V \frac{\partial \ell_t}{\partial p_{t,i}} \cdot \frac{\partial p_{t,i}}{\partial z_{t,j}}.
$$

Substitute the formulas:
$$
\frac{\partial \ell_t}{\partial z_{t,j}}
= \sum_{i=1}^V \left(-\frac{y_{t+1,i}}{p_{t,i}}\right) \cdot p_{t,i}(\delta_{ij} - p_{t,j}).
$$

Simplify:
$$
\frac{\partial \ell_t}{\partial z_{t,j}}
= \sum_{i=1}^V \big(-y_{t+1,i}\delta_{ij} + y_{t+1,i}p_{t,j}\big).
$$

---

### Step 5: Collapse the sums

- The first term gives $-y_{t+1,j}$ (since only $i=j$ contributes).
- The second term gives $p_{t,j}\sum_{i=1}^V y_{t+1,i}$.

Since $y_{t+1}$ is one-hot, $\sum_{i=1}^V y_{t+1,i} = 1$.

So,
$$
\frac{\partial \ell_t}{\partial z_{t,j}} = p_{t,j} - y_{t+1,j}.
$$

---

### Final vector form

Stacking over all $j=1,\dots,V$,
$$
\boxed{\;\frac{\partial \ell_t}{\partial z_t} = p_t - y_{t+1}\;}
$$

---


- This is a **gradient vector** in $\mathbb{R}^V$, *not* a Jacobian.  
- Each entry is simply the difference between predicted probability and the target indicator.

---

## 4. Gradient wrt hidden state $h_t$

Since $z_t = W_{ho} h_t + b_o$, applying the chain rule:
$$
\frac{\partial \ell_t}{\partial h_t}
= \left(\frac{\partial z_t}{\partial h_t}\right)^\top
\frac{\partial \ell_t}{\partial z_t}.
$$

Here,
$$
\frac{\partial z_t}{\partial h_t} = W_{ho},
$$
so
$$
\boxed{\;\frac{\partial \ell_t}{\partial h_t}
= W_{ho}^\top \big(p_t - y_{t+1}\big).\;}
$$
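
Both boxed identities can be checked against autograd. In this sketch (random $W_{ho}$, $b_o$, $h_t$ and an arbitrary target index $c$), `F.cross_entropy` computes $\ell_t=-\log p_{t,c}$ from the logits, and the gradients it produces for $z_t$ and $h_t$ are compared with $p_t-y_{t+1}$ and $W_{ho}^\top(p_t-y_{t+1})$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, V = 4, 6
W_ho, b_o = torch.randn(V, H), torch.randn(V)
h_t = torch.randn(H, requires_grad=True)
target = 2                                   # index c of the correct next token x_{t+1}

z_t = W_ho @ h_t + b_o                       # logits
z_t.retain_grad()
loss = F.cross_entropy(z_t.unsqueeze(0), torch.tensor([target]))   # = -log p_{t,c}
loss.backward()

p_t = torch.softmax(z_t.detach(), dim=0)
y = torch.zeros(V)
y[target] = 1.0                              # one-hot y_{t+1}

print(torch.allclose(z_t.grad, p_t - y, atol=1e-6))             # dl/dz_t = p_t - y_{t+1}
print(torch.allclose(h_t.grad, W_ho.T @ (p_t - y), atol=1e-5))  # dl/dh_t = W_ho^T (p_t - y_{t+1})
```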

---




## Chain rule at time $t$

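Differentiate $L_T$ w.r.t. $h_t$. Terms $\ell_k$ with $k<t$ do not depend on $h_t$, $\ell_t$ depends on it directly, and every later $\ell_k$ depends on it only through $h_{t+1},\dots,h_k$. In the column-gradient convention,
$$
\frac{\partial L_T}{\partial h_t}
= \underbrace{\frac{\partial \ell_t}{\partial h_t}}_{g_t}
+ \sum_{k=t+1}^T \left(\frac{\partial h_{t+1}}{\partial h_t}\right)^{\!\top}
      \left(\frac{\partial h_{t+2}}{\partial h_{t+1}}\right)^{\!\top}\cdots
      \left(\frac{\partial h_k}{\partial h_{k-1}}\right)^{\!\top}
      \frac{\partial \ell_k}{\partial h_k}.
$$

Factor out the **first** transposed Jacobian $\left(\frac{\partial h_{t+1}}{\partial h_t}\right)^{\top}$; what remains is the same expansion one step later (for $k=t+1$ the remaining product is empty, i.e. the identity):
$$
\sum_{k=t+1}^T \left(\frac{\partial h_{t+2}}{\partial h_{t+1}}\right)^{\!\top}\cdots
      \left(\frac{\partial h_k}{\partial h_{k-1}}\right)^{\!\top}
      \frac{\partial \ell_k}{\partial h_k}
= \underbrace{\frac{\partial L_T}{\partial h_{t+1}}}_{\delta_{t+1}}.
$$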

Thus
$$
\frac{\partial L_T}{\partial h_t}
= g_t + \left(\frac{\partial h_{t+1}}{\partial h_t}\right)^\top \frac{\partial L_T}{\partial h_{t+1}}.
$$

Using the shorthand $ \delta_t = \frac{\partial L_T}{\partial h_t} $ and $ J_{t+1} = \frac{\partial h_{t+1}}{\partial h_t} $,
$$
\boxed{\;\delta_t = g_t + J_{t+1}^\top \delta_{t+1}\;}
$$
with terminal condition $\delta_{T+1}=0$.
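
The recursion can be run by hand and compared with autograd. This sketch assumes a bias-free tanh RNN with random weights and the unnormalised loss $L_T=\sum_k \ell_k$. Python indices run $0,\dots,T-1$ instead of $1,\dots,T$, and the last step starts from $\delta_{T+1}=0$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D, H, V, T = 3, 4, 5, 6
W_xh, W_hh, W_ho = torch.randn(H, D), 0.5 * torch.randn(H, H), torch.randn(V, H)
xs, targets = torch.randn(T, D), torch.randint(0, V, (T,))

# Forward pass of a bias-free tanh RNN; retain_grad lets autograd report delta_t.
h, hs = torch.zeros(H, requires_grad=True), []
for t in range(T):
    h = torch.tanh(W_xh @ xs[t] + W_hh @ h)
    h.retain_grad()
    hs.append(h)
loss = sum(F.cross_entropy((W_ho @ h_t).unsqueeze(0), targets[t:t + 1]) for t, h_t in enumerate(hs))
loss.backward()

# Manual backward pass: delta_t = g_t + J_{t+1}^T delta_{t+1}, with delta_{T+1} = 0.
delta = torch.zeros(H)
deltas = [None] * T
for t in reversed(range(T)):
    h_t = hs[t].detach()
    y = torch.zeros(V)
    y[targets[t]] = 1.0                                      # one-hot y_{t+1}
    g_t = W_ho.T @ (torch.softmax(W_ho @ h_t, dim=0) - y)    # g_t = W_ho^T (p_t - y_{t+1})
    if t + 1 < T:
        J_next = torch.diag(1 - hs[t + 1].detach() ** 2) @ W_hh   # J_{t+1} = D_{t+1} W_hh
        delta = g_t + J_next.T @ delta
    else:
        delta = g_t                                          # last step: delta_T = g_T
    deltas[t] = delta

print(all(torch.allclose(d, h_t.grad, atol=1e-5) for d, h_t in zip(deltas, hs)))
```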


# Backpropagation Through Time (BPTT): Vanishing & Exploding Gradients

---

## General recursion (loss at every time step)

By the chain rule under the **column-gradient convention**,
$$
\boxed{\;\delta_t = g_t + J_{t+1}^\top\,\delta_{t+1}\;}, \qquad t=T,T-1,\dots,1,
$$
with terminal condition
$$
\delta_{T+1} = 0.
$$

Here:
- $\delta_t = \dfrac{\partial L_T}{\partial h_t} \in \mathbb{R}^H$ is the backprop signal,
- $g_t = \dfrac{\partial \ell_t}{\partial h_t} \in \mathbb{R}^H$ is the local gradient at time $t$,
- $J_{t+1} = \dfrac{\partial h_{t+1}}{\partial h_t} \in \mathbb{R}^{H\times H}$ is the Jacobian of the recurrence.

---

## Unrolled form

Expanding the recursion step by step yields:
$$
\delta_t = g_t
+ J_{t+1}^\top g_{t+1}
+ J_{t+1}^\top J_{t+2}^\top g_{t+2}
+ \cdots
+ J_{t+1}^\top J_{t+2}^\top \cdots J_T^\top g_T.
$$

Compactly:
$$
\boxed{\;\delta_t = \sum_{k=t}^T \left( \Big(\prod_{j=t+1}^k J_j^\top\Big) g_k \right)\;}
$$

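- If $k=t$, the product is empty $\Rightarrow I$, giving $g_t$.  
- The product is ordered $J_{t+1}^\top J_{t+2}^\top\cdots J_k^\top$ (increasing $j$ from left to right); the matrices do not commute, so the order matters.  
- Each future $g_k$ is transported back through all Jacobians from step $k$ down to step $t$.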

---

# Why the recursion $\delta_t = g_t + J_{t+1}^\top \delta_{t+1}$ unrolls into a sum

---

## Step 1. Recursion formula

We start from the backpropagation-through-time recurrence:
$$
\delta_t = g_t + J_{t+1}^\top \,\delta_{t+1}.
$$

Here:
- $\delta_t = \dfrac{\partial L_T}{\partial h_t} \in \mathbb{R}^H$ (gradient vector),
- $g_t = \dfrac{\partial \ell_t}{\partial h_t} \in \mathbb{R}^H$ (local gradient at time $t$),
- $J_{t+1} = \dfrac{\partial h_{t+1}}{\partial h_t} \in \mathbb{R}^{H\times H}$ (Jacobian of the hidden recurrence).

---

## Step 2. Substitute one step ahead

Insert the definition of $\delta_{t+1}$:
$$
\delta_{t+1} = g_{t+1} + J_{t+2}^\top \delta_{t+2}.
$$

So,
$$
\delta_t = g_t + J_{t+1}^\top g_{t+1} + J_{t+1}^\top J_{t+2}^\top \delta_{t+2}.
$$

---

## Step 3. Substitute further

Now expand $\delta_{t+2}$:
$$
\delta_{t+2} = g_{t+2} + J_{t+3}^\top \delta_{t+3}.
$$

Plugging this in:
$$
\delta_t = g_t
+ J_{t+1}^\top g_{t+1}
+ J_{t+1}^\top J_{t+2}^\top g_{t+2}
+ J_{t+1}^\top J_{t+2}^\top J_{t+3}^\top \delta_{t+3}.
$$

---

## Step 4. Continue until the end

Repeating this expansion up to time $T$ (where $\delta_{T+1}=0$), we collect all contributions:
$$
\delta_t = g_t
+ J_{t+1}^\top g_{t+1}
+ J_{t+1}^\top J_{t+2}^\top g_{t+2}
+ \cdots
+ J_{t+1}^\top J_{t+2}^\top \cdots J_T^\top g_T.
$$

---

## Step 5. Compact summation notation

We can write the whole expansion in a single summation:
$$
\delta_t = \sum_{k=t}^{T}
\left( \prod_{j=t+1}^k J_j^\top \right) g_k.
$$

- For $k=t$: the product $\prod_{j=t+1}^t$ is **empty**, so by convention it equals the identity $I$, giving just $g_t$.
- For $k=t+1$: the term is $J_{t+1}^\top g_{t+1}$.
- For $k=t+2$: the term is $J_{t+1}^\top J_{t+2}^\top g_{t+2}$.
- …
- For $k=T$: the term is $J_{t+1}^\top J_{t+2}^\top \cdots J_T^\top g_T$.
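
The equivalence of the recursion and the unrolled sum is pure linear algebra, so it can be checked with arbitrary random matrices standing in for the $J_j$ and random vectors for the $g_k$ (not tied to any RNN). The sketch keeps the 1-based indices of the text, uses float64 to keep round-off negligible, and makes the product order explicit.

```python
import torch

torch.manual_seed(0)
H, T = 3, 5
f64 = torch.float64
# Random stand-ins J[1..T] and g[1..T]; index 0 is unused so indices match the text.
J = [None] + [torch.randn(H, H, dtype=f64) for _ in range(T)]
g = [None] + [torch.randn(H, dtype=f64) for _ in range(T)]

def delta_recursive(t):
    # delta_t = g_t + J_{t+1}^T delta_{t+1}, with delta_{T+1} = 0
    return g[t] if t == T else g[t] + J[t + 1].T @ delta_recursive(t + 1)

def delta_unrolled(t):
    # delta_t = sum_{k=t}^{T} (J_{t+1}^T J_{t+2}^T ... J_k^T) g_k
    total = torch.zeros(H, dtype=f64)
    for k in range(t, T + 1):
        P = torch.eye(H, dtype=f64)          # empty product for k = t
        for j in range(t + 1, k + 1):
            P = P @ J[j].T                   # append on the right: J_{t+1}^T ... J_j^T
        total = total + P @ g[k]
    return total

print(all(torch.allclose(delta_recursive(t), delta_unrolled(t)) for t in range(1, T + 1)))
```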

---


## Special case: loss only at the final step

If $\ell_t = 0$ for all $t < T$, then $g_k=0$ for $k<T$ and $\delta_T = g_T$, so
$$
\delta_t = \Big(\prod_{j=t+1}^T J_j^\top\Big)\,\delta_T
= \big(J_T J_{T-1}\cdots J_{t+1}\big)^\top \delta_T.
$$

This is the closed form of the gradient recurrence when the loss depends only on $h_T$.

---

## Vanishing and exploding gradients

Using the spectral norm $\|\cdot\|_2$ (note $\|A^\top\|_2=\|A\|_2$),
$$
\|\delta_t\|
\;\le\; \sum_{k=t}^T \|g_k\| \prod_{j=t+1}^k \|J_j^\top\|
\;=\; \sum_{k=t}^T \|g_k\| \prod_{j=t+1}^k \|J_j\|
\;\le\; \sum_{k=t}^T \|g_k\| \prod_{j=t+1}^k \|D_j\|\,\|W_{hh}\|.
$$

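- For $f=\tanh$, $f'(a)=1-\tanh^2(a)\in(0,1]$, so $\|D_j\|\le 1$ (strictly $<1$ unless some pre-activation is exactly $0$, and much smaller once units saturate). For ReLU, $f'\in\{0,1\}$, so again $\|D_j\|\le 1$.
- If $\|W_{hh}\|_2 < 1$, each product is bounded by $\|W_{hh}\|_2^{\,k-t}$ and decays exponentially ⇒ **vanishing gradients** (a sufficient condition).  
- If $\|W_{hh}\|_2 > 1$ (which holds whenever the spectral radius $\rho(W_{hh}) > 1$, since $\rho(A)\le\|A\|_2$), the bound no longer forces decay and the products *can* grow exponentially ⇒ **exploding gradients**.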

---

## Simple linear example

Let $f$ be the identity ($D_t=I$, $J_t=W_{hh}$).  
With loss only at the final step:
$$
\delta_t = \big(W_{hh}^{\,T-t}\big)^\top \delta_T.
$$

If $W_{hh}=\alpha I$,
$$
\|\delta_t\| = |\alpha|^{\,T-t}\,\|\delta_T\|.
$$

- **Vanishing:** $\alpha=0.5$, $T-t=20 \;\Rightarrow\; (0.5)^{20}\approx 10^{-6}$.  
- **Exploding:** $\alpha=1.5$, $T-t=20 \;\Rightarrow\; (1.5)^{20}\approx 3.3\times 10^3$.
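
The linear example is easy to reproduce with autograd. This sketch assumes $W_{hh}=\alpha I$, $H=4$, 20 transitions, and a loss that is simply the sum of the final hidden state (so $\delta_T$ is a vector of ones); the returned ratio $\|\delta_t\|/\|\delta_T\|$ should match $|\alpha|^{20}$.

```python
import torch

def grad_norm_ratio(alpha, steps=20, H=4):
    """Return ||delta_t|| / ||delta_T|| for h_k = alpha * h_{k-1} over `steps` transitions."""
    W_hh = alpha * torch.eye(H)          # linear RNN: f = identity, so every J_k = W_hh
    h_t = torch.randn(H, requires_grad=True)
    h = h_t
    for _ in range(steps):               # T - t = steps
        h = W_hh @ h
    h.retain_grad()
    h.sum().backward()                   # loss only at the final step; delta_T = ones
    return (h_t.grad.norm() / h.grad.norm()).item()

for alpha in (0.5, 1.0, 1.5):
    print(f"alpha={alpha}: ratio={grad_norm_ratio(alpha):.3e}, |alpha|^20={abs(alpha) ** 20:.3e}")
```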

---




## Deep RNNs

**Motivation**
- Single-layer RNN has limited capacity.  
- Stacking multiple layers improves expressiveness.  

**Equations (2-layer RNN):**

$$
h_t^{(1)} = f(W_{xh}x_t + W_{hh}^{(1)}h_{t-1}^{(1)})
$$

$$
h_t^{(2)} = f(W_{h^{(1)}h^{(2)}}h_t^{(1)} + W_{hh}^{(2)}h_{t-1}^{(2)})
$$

$$
y_t = \text{softmax}(W_{hy}h_t^{(2)})
$$

**Pros**
- Capture hierarchical representations.  

**Cons**
- Gradients must now flow through the stacked layers as well as through time, so vanishing/exploding issues can get worse.

![Deep RNN](https://i.imgur.com/J3DwxSF.png)


```python
import torch
import torch.nn as nn

# ------------------------
# Single-layer RNN
# ------------------------
rnn_single = nn.RNN(input_size=5, hidden_size=8, num_layers=1, batch_first=True)

# Example input: batch=2, seq_len=4, input_dim=5
x = torch.randn(2, 4, 5)

# Initial hidden state: [num_layers, batch, hidden_size]
h0_single = torch.zeros(1, 2, 8)

# Forward pass
output_single, hn_single = rnn_single(x, h0_single)

print("=== Single-layer RNN ===")
print("Input shape:", x.shape)             # [2, 4, 5]
print("Output shape:", output_single.shape) # [2, 4, 8]
print("Last hidden state shape:", hn_single.shape) # [1, 2, 8]

# ------------------------
# Deep RNN (3 layers stacked)
# ------------------------
rnn_deep = nn.RNN(input_size=5, hidden_size=8, num_layers=3, batch_first=True)

# Initial hidden state: [num_layers, batch, hidden_size]
h0_deep = torch.zeros(3, 2, 8)

# Forward pass
output_deep, hn_deep = rnn_deep(x, h0_deep)

print("\n=== Deep RNN (3 layers) ===")
print("Input shape:", x.shape)              # [2, 4, 5]
print("Output shape:", output_deep.shape)   # [2, 4, 8] (top layer's hidden states)
print("Last hidden state shape:", hn_deep.shape) # [3, 2, 8] (one per layer)
```

## Long Short-Term Memory (LSTM)

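**Key idea:** a gated cell state $c_t$ that is updated *additively*. Holding $h_{t-1}$ fixed, $\partial c_t/\partial c_{t-1} = \mathrm{diag}(f_t)$, so gradients along the cell path are scaled by the forget gate instead of being repeatedly multiplied by a fixed $W_{hh}$; this gives more stable long-range memory.  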

![LSTM](https://external-preview.redd.it/jL97dVEQcqoX79lFAI5J4fomXXU_hBHtlzyAs-7xQ-Q.png?format=pjpg&auto=webp&s=6400685c6a57df685241f8d6d456d4aae3c74105)

**Equations (one time step):**
$$
\begin{aligned}
i_t &= \sigma(W_{xi}x_t + W_{hi}h_{t-1} + b_i) \quad &\text{(input gate)} \\
f_t &= \sigma(W_{xf}x_t + W_{hf}h_{t-1} + b_f) \quad &\text{(forget gate)} \\
o_t &= \sigma(W_{xo}x_t + W_{ho}h_{t-1} + b_o) \quad &\text{(output gate)} \\
\tilde{c}_t &= \tanh(W_{xc}x_t + W_{hc}h_{t-1} + b_c) \quad &\text{(candidate)} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \quad &\text{(cell state)} \\
h_t &= o_t \odot \tanh(c_t) \quad &\text{(hidden state)} \\
\end{aligned}
$$
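
As a sanity check of these equations, the sketch below recomputes one step of PyTorch's `nn.LSTMCell` by hand. It relies on PyTorch's documented layout: the four gate blocks in `weight_ih`/`weight_hh` are stacked in the order input, forget, candidate, output, and the two biases `bias_ih` and `bias_hh` add up to the single $b$ of each equation. It also checks that, holding $h_{t-1}$ fixed, $\partial c_t/\partial c_{t-1}=\mathrm{diag}(f_t)$. Sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, H = 5, 8
cell = nn.LSTMCell(input_size=D, hidden_size=H)
x_t, h_prev, c_prev = torch.randn(1, D), torch.randn(1, H), torch.randn(1, H)

# Pre-activations of all four gates at once; PyTorch stacks them as (i, f, candidate, o).
pre = x_t @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
i_pre, f_pre, c_pre, o_pre = pre.chunk(4, dim=1)
i_t, f_t, o_t = torch.sigmoid(i_pre), torch.sigmoid(f_pre), torch.sigmoid(o_pre)
c_tilde = torch.tanh(c_pre)
c_t = f_t * c_prev + i_t * c_tilde        # additive cell-state update
h_t = o_t * torch.tanh(c_t)

h_ref, c_ref = cell(x_t, (h_prev, c_prev))
print(torch.allclose(h_t, h_ref, atol=1e-5), torch.allclose(c_t, c_ref, atol=1e-5))

# With h_{t-1} held fixed, the Jacobian dc_t/dc_{t-1} is diag(f_t).
J_c = torch.autograd.functional.jacobian(lambda c: cell(x_t, (h_prev, c))[1], c_prev).reshape(H, H)
print(torch.allclose(J_c, torch.diag(f_t[0]), atol=1e-6))
```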



```python
import torch
import torch.nn as nn

# ------------------------
# Single-layer LSTM
# ------------------------
lstm_single = nn.LSTM(input_size=5, hidden_size=8, num_layers=1, batch_first=True)

# Example input: batch=2, seq_len=4, feature_dim=5
x = torch.randn(2, 4, 5)

# Initial hidden state (h0) and cell state (c0): [num_layers, batch, hidden_size]
h0_single = torch.zeros(1, 2, 8)
c0_single = torch.zeros(1, 2, 8)

# Forward pass
output_single, (hn_single, cn_single) = lstm_single(x, (h0_single, c0_single))

print("=== Single-layer LSTM ===")
print("Input shape:", x.shape)              # [2, 4, 5]
print("Output shape:", output_single.shape) # [2, 4, 8]
print("hn shape:", hn_single.shape)         # [1, 2, 8]
print("cn shape:", cn_single.shape)         # [1, 2, 8]


# ------------------------
# Deep LSTM (3 layers)
# ------------------------
lstm_deep = nn.LSTM(input_size=5, hidden_size=8, num_layers=3, batch_first=True)

# Initial hidden & cell states: [num_layers, batch, hidden_size]
h0_deep = torch.zeros(3, 2, 8)
c0_deep = torch.zeros(3, 2, 8)

# Forward pass
output_deep, (hn_deep, cn_deep) = lstm_deep(x, (h0_deep, c0_deep))

print("\n=== Deep LSTM (3 layers) ===")
print("Input shape:", x.shape)               # [2, 4, 5]
print("Output shape:", output_deep.shape)    # [2, 4, 8]
print("hn shape:", hn_deep.shape)            # [3, 2, 8] (one hidden state per layer)
print("cn shape:", cn_deep.shape)            # [3, 2, 8] (one cell state per layer)
```


![GRU](https://miro.medium.com/1*DwL2ygleKXtRbYeVi8Qb_g.png)

## Gated Recurrent Unit (GRU)

For input $x_t$ and previous hidden state $h_{t-1}$:

$$
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) &\quad& \text{(update gate)} \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) &\quad& \text{(reset gate)} \\
\tilde{h}_t &= \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h) &\quad& \text{(candidate state)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t &\quad& \text{(new hidden state)} \\
\end{aligned}
$$

---

### Intuition
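- **Update gate $z_t$:** how much of the candidate to write into the state; $1-z_t$ is the fraction of the old state kept.  
- **Reset gate $r_t$:** how much of $h_{t-1}$ feeds into the candidate ($r_t\approx 0$ lets the candidate ignore the past).  
- **Candidate $\tilde{h}_t$:** proposed new state, computed from the input and the reset-gated history.  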
- **Final hidden state $h_t$:** interpolation between old state and candidate.  
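
A direct transcription of the GRU equations above, as a sketch with hypothetical randomly initialised parameters (`make_params`, `gru_step` and the dictionary keys are illustrative names). The final check forces $z_t\approx 0$ with a large negative $b_z$, in which case the cell should copy $h_{t-1}$. PyTorch's `nn.GRU` follows a slightly different convention: it writes $h_t=(1-z_t)\odot n_t+z_t\odot h_{t-1}$ (the roles of $z_t$ and $1-z_t$ are swapped) and applies the reset gate after the hidden-state matrix multiply, so its weights do not plug directly into these equations.

```python
import torch

torch.manual_seed(0)
D, H = 5, 8

def make_params():
    """Hypothetical parameter set (W_*, U_*, b_*) for the update, reset and candidate paths."""
    p = {}
    for g in ("z", "r", "h"):
        p["W_" + g] = 0.3 * torch.randn(H, D)
        p["U_" + g] = 0.3 * torch.randn(H, H)
        p["b_" + g] = torch.zeros(H)
    return p

def gru_step(x_t, h_prev, p):
    z_t = torch.sigmoid(p["W_z"] @ x_t + p["U_z"] @ h_prev + p["b_z"])            # update gate
    r_t = torch.sigmoid(p["W_r"] @ x_t + p["U_r"] @ h_prev + p["b_r"])            # reset gate
    h_tilde = torch.tanh(p["W_h"] @ x_t + p["U_h"] @ (r_t * h_prev) + p["b_h"])   # candidate
    return (1 - z_t) * h_prev + z_t * h_tilde                                     # interpolation

params = make_params()
x_t, h_prev = torch.randn(D), torch.randn(H)
h_t = gru_step(x_t, h_prev, params)

# Sanity check: drive the update gate to ~0 and the GRU simply copies h_{t-1}.
params["b_z"] = torch.full((H,), -50.0)
h_copy = gru_step(x_t, h_prev, params)
print(torch.allclose(h_copy, h_prev, atol=1e-6))
```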



```python
import torch
import torch.nn as nn

# ------------------------
# Single-layer GRU
# ------------------------
gru = nn.GRU(input_size=5, hidden_size=8, num_layers=1, batch_first=True)

# Example input: batch=2, seq_len=4, feature_dim=5
x = torch.randn(2, 4, 5)

# Initial hidden state: [num_layers, batch, hidden_size]
h0 = torch.zeros(1, 2, 8)

# Forward pass
output, hn = gru(x, h0)

print("Input shape:", x.shape)        # [2, 4, 5]
print("Output shape:", output.shape)  # [2, 4, 8] (hidden states for all timesteps)
print("hn shape:", hn.shape)          # [1, 2, 8] (last hidden state for each sequence)

# Optional: map hidden states to logits (e.g., for vocab size = 10)
head = nn.Linear(8, 10)
logits = head(output)  # [2, 4, 10]
print("Logits shape:", logits.shape)

# ------------------------
# (Optional) Deep GRU: just change num_layers
# ------------------------
# gru_deep = nn.GRU(input_size=5, hidden_size=8, num_layers=3, batch_first=True)
# h0_deep = torch.zeros(3, 2, 8)
# output_deep, hn_deep = gru_deep(x, h0_deep)
# print("Deep GRU hn shape:", hn_deep.shape)  # [3, 2, 8]
```


## Masking-Based Autoregressive Models

**Key property:**  
- All conditionals computed **in parallel**.  
- Enforced by **causal masking**.  

$$
P(x_1, \ldots, x_T) = \prod_{t=1}^T P(x_t \mid x_{<t})
$$

**Examples**
- Masked MLP (MADE)  
- Masked convolutions (PixelCNN)  
- Masked self-attention (Transformers, next lecture)  

**Benefits**
- Parallelizable training  
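- Still exactly autoregressive (the mask guarantees each conditional sees only $x_{<t}$)  
- Parameter sharing across positions (for convolutions and self-attention)  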


# Masked Autoregressive Models

We want to model the joint distribution of a sequence or vector:

$$
P(x_1, x_2, \dots, x_D) = \prod_{i=1}^D P(x_i \mid x_{<i})
$$

- Each conditional $P(x_i \mid x_{<i})$ should depend **only** on earlier variables.
- This is the **autoregressive property**.

---

## Problem
- A standard feedforward network lets each output depend on every input it can reach: all inputs for an MLP or unmasked self-attention, a local window for a CNN.
- Without constraints, the output for $x_i$ might depend on "future" inputs (like $x_j$ with $j > i$).
- That breaks the autoregressive property.

---



## Solution: Masking
- Introduce a **binary mask** (matrix of 0s and 1s) applied to the network weights.
- The mask zeroes out connections that would allow information flow from "future" variables.

**Rule** (for a single layer mapping inputs directly to outputs):
$$
W_{jk} = 0 \quad \text{if input index } k \geq \text{ output index } j
$$

- Each output neuron is only connected to valid past inputs.
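
For a single layer, the rule $W_{jk}=0$ for $k\ge j$ means the mask is strictly lower triangular. A minimal sketch (0-based indices) that builds it both with `torch.tril` and directly from the rule, for $D=3$ as in the MADE example:

```python
import torch

D = 3
# M[j, k] = 1 iff input k comes strictly before output j.
M = torch.tril(torch.ones(D, D), diagonal=-1)
M_rule = torch.tensor([[1.0 if k < j else 0.0 for k in range(D)] for j in range(D)])
print(M)
print(torch.equal(M, M_rule))
```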

---
## Example 1: MADE (Masked Autoencoder for Distribution Estimation)

Suppose we want to model a simple 3D vector:

$$
x = (x_1, x_2, x_3)
$$

with autoregressive factorization:

$$
P(x) = P(x_1) \cdot P(x_2 \mid x_1) \cdot P(x_3 \mid x_1, x_2).
$$

---

### Without a mask
- A normal MLP would connect **every input to every output**.
- That means the output for $x_2$ could depend on $x_3$, which **breaks the autoregressive rule**.

---

### With a mask
- We add a binary mask to the MLP weights so that:
  - **Output for $x_1$**: no inputs allowed → only a bias term.
  - **Output for $x_2$**: can only depend on $x_1$.
  - **Output for $x_3$**: can depend on both $x_1$ and $x_2$.

**Masked connections table:**

| Output | Visible inputs | Mask vector |
|--------|----------------|-------------|
| o1     | —              | (0, 0, 0)   |
| o2     | x1             | (1, 0, 0)   |
| o3     | x1, x2         | (1, 1, 0)   |

---

### Intuition
- Think of the mask as a **set of rules** saying:
  - "Who can look at whom?"
- MADE enforces that each output **only looks left** (earlier variables).


✅ In practice: one forward pass of MADE gives you **all conditionals in parallel**, respecting autoregressive order.

---


## Why the First Row of the Mask is All Zeros in MADE

### 1. What the rows mean
In a **masked linear layer**:

- **Columns** = inputs (e.g., $x_1, x_2, x_3$)  
- **Rows** = outputs (e.g., logits for $o_1, o_2, o_3$)  

So, each row of the mask tells us:  
> Which inputs this output is allowed to see.

---

### 2. Why row 1 (for $o_1$) is all zeros
- Output $o_1$ corresponds to $P(x_1)$.  
- By definition of autoregressive factorization:

$$
P(x_1) \quad \text{has no conditioning variables.}
$$

- This means it **cannot depend** on:
  - $x_1$ itself (the value being predicted; seeing it would leak the answer),  
  - or $x_2, x_3$ (future).  

✅ Therefore, row 1 of the mask must be **all zeros**.  
The only thing that influences $o_1$ is the **bias term**, which acts like a learnable unconditional prior.

---

### 3. Contrast with other rows
- **Row 2 (output $o_2$):**  
  Models $P(x_2 \mid x_1)$.  
  - Allowed to see $x_1$.  
  - Not allowed to see $x_2, x_3$.  

- **Row 3 (output $o_3$):**  
  Models $P(x_3 \mid x_1, x_2)$.  
  - Allowed to see $x_1, x_2$.  
  - Not allowed to see $x_3$.  

---

### 4. Big picture
- Row 1 (all zeros): "first variable has no context."  
- Row 2 (partially filled): "second variable conditioned on earlier ones."  
- Row 3 (more filled): "third variable conditioned on first two."  

This is how MADE enforces **causality** with a feedforward mask.

---

**Summary:**  
The first row of the mask is all zeros because $P(x_1)$ is unconditional.  
The model should only learn its marginal distribution, not cheat by looking at the inputs themselves or any future variables.


```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Custom masked linear layer
class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features, bias=True)
        self.register_buffer("mask", mask)

    def forward(self, x):
        return F.linear(x, self.weight * self.mask, self.bias)

# Suppose input vector is (x1, x2, x3)
D = 3
# Create a mask so that o1 depends on nothing, o2 depends only on x1, o3 depends on x1,x2
mask = torch.tensor([[0,0,0],   # o1 ← no inputs
                     [1,0,0],   # o2 ← x1
                     [1,1,0]], dtype=torch.float)  # o3 ← x1, x2

layer = MaskedLinear(D, D, mask)

# Example batch of inputs (2 samples, each of dimension 3)
x = torch.tensor([[1., 2., 3.],
                  [0.5, -1., 2.]])

output = layer(x)

print("Input:\n", x)
print("Masked Weight Matrix:\n", layer.weight * layer.mask)
print("Output:\n", output)
```

## Refresh: Convolution

**Operation:**  
- Slide kernel across input.  
- Compute weighted sum at each location.  

![CNN](https://miro.medium.com/v2/resize:fit:1400/0*LeK_gmCf3DfO3gj_.jpeg)


![Convolution](https://viso.ai/wp-content/uploads/2024/04/Illustrating-the-first-5-steps-of-convolution-operation-1.jpg)

![2D Convolution](https://miro.medium.com/v2/resize:fit:1400/0*H_6KDnWyFj_JDstS)
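
In 1D (deep-learning libraries implement exactly this cross-correlation form and call it convolution):
$$
y[i] = \sum_k w[k] \cdot x[i+k]
$$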

**Key property:** Translation equivariance  
- Input shift → Output shift  
- Same filter applies everywhere  

**Why useful?**  
- Strong inductive bias for natural signals  
- Local patterns + weight sharing + efficiency  
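
Translation equivariance can be checked numerically. This sketch uses `F.conv1d` (which, like the formula above, computes a cross-correlation) with circular padding, an assumption made so that shifts wrap around and the property holds exactly at the borders too.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 16)        # [batch, channels, length]
w = torch.randn(1, 1, 3)         # one 3-tap filter, shared across all positions

def conv(signal):
    # Circular padding keeps the length at 16 and makes shifts wrap around.
    return F.conv1d(F.pad(signal, (1, 1), mode="circular"), w)

shift = 5
lhs = conv(torch.roll(x, shifts=shift, dims=-1))   # shift the input, then convolve
rhs = torch.roll(conv(x), shifts=shift, dims=-1)   # convolve, then shift the output
print(torch.allclose(lhs, rhs, atol=1e-5))
```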

## PixelCNN: Convolution with Causal Masks

![Masked CNN](https://velog.velcdn.com/images/2712qwer/post/b776ad05-fdd8-43b1-b6d5-a7fd60dc09ee/image.png)

**Goal.** Model the joint distribution of an image $x \in \mathbb{R}^{H \times W \times C}$ as an autoregressive factorization in **raster order** (left→right, top→bottom):

$$
P(x) \;=\; \prod_{i=1}^{H}\prod_{j=1}^{W}\prod_{c=1}^{C}
P\!\big(x_{i,j,c}\,\big|\,x_{<i,j,:},\,x_{i,j,<c}\big).
$$

- At pixel $(i,j)$ and channel $c$, the model may only “see” previously generated pixels ($x_{<i,j,:}$) and **earlier channels of the same pixel** ($x_{i,j,<c}$).

---

### Causal masking for convolutions

A standard $k\times k$ convolution at $(i,j)$ would read from a $(k\times k)$ patch centered at $(i,j)$, which includes **future** pixels.  
**PixelCNN** enforces causality by **masking** the kernel weights (elementwise multiply with a 0/1 matrix) so the conv never reads from future positions.

For a $3\times 3$ kernel (raster order):

- **Mask A (first conv layer):** forbid the center as well as the future.






$$
\begin{bmatrix}
1 & 1 & 1 \\
1 & 0 & 0 \\
0 & 0 & 0
\end{bmatrix}
$$

- **Mask B (subsequent conv layers):** allow the center (information from the current location that is already causal through previous layers), still forbid the future.

$$
\begin{bmatrix}
1 & 1 & 1 \\
1 & 1 & 0 \\
0 & 0 & 0
\end{bmatrix}
$$


- “1” = allowed connection, “0” = masked (blocked).

**Why two masks?**  
The first layer must not look at the current pixel’s raw value; deeper layers can read the current-location **features** as long as those features are themselves built from causal context.
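
A minimal sketch of masked convolutions for a single-channel image, assuming $3\times 3$ kernels and raster order (`MaskedConv2d` is an illustrative class, not a library API, and the per-channel R/G/B masking is omitted). It builds Mask A and Mask B as above, stacks one A layer with two B layers, and uses autograd to check that the output at $(i,j)$ has zero gradient with respect to $x_{i,j}$ and every later pixel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """Conv2d whose k x k kernel is zeroed at 'future' positions in raster order (type 'A' or 'B')."""
    def __init__(self, mask_type, in_ch, out_ch, k=3):
        super().__init__(in_ch, out_ch, k, padding=k // 2)
        assert mask_type in ("A", "B")
        c = k // 2
        mask = torch.ones(k, k)
        mask[c, c + 1:] = 0          # right of the centre (both types)
        mask[c + 1:, :] = 0          # rows below the centre (both types)
        if mask_type == "A":
            mask[c, c] = 0           # type A also hides the centre pixel itself
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias, padding=self.padding)

torch.manual_seed(0)
net = nn.Sequential(MaskedConv2d("A", 1, 8), nn.ReLU(),
                    MaskedConv2d("B", 8, 8), nn.ReLU(),
                    MaskedConv2d("B", 8, 1))
print(net[0].mask[0, 0])             # Mask A
print(net[2].mask[0, 0])             # Mask B

# Causality check: the output at (i, j) must not depend on x[i, j] or any later pixel.
H = W = 6
x = torch.randn(1, 1, H, W, requires_grad=True)
i, j = 3, 2
net(x)[0, 0, i, j].backward()
grad = x.grad[0, 0].abs() > 0
future = torch.zeros(H, W, dtype=torch.bool)
future[i, j:] = True
future[i + 1:, :] = True
print(bool((grad & future).any()))   # False means no leakage from current/future pixels
```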

---

### Multi-channel (RGB) masking

Within a pixel, channels are ordered (e.g., R → G → B). The conditional becomes:

$$
P(x) \;=\; \prod_{i,j} P(x_{i,j,R}\mid x_{<i,j,:})\;
P(x_{i,j,G}\mid x_{<i,j,:},x_{i,j,R})\;
P(x_{i,j,B}\mid x_{<i,j,:},x_{i,j,R},x_{i,j,G}).
$$

Practically, the mask is extended so that for channel $c$ at $(i,j)$ the conv can only see:
- all channels at **past** pixels, and
- **earlier** channels at the **same** pixel.

---

### Training objective (discretized pixels)

For 8-bit images, each channel has 256 classes. The model outputs a categorical distribution (or a mixture of logistics as in PixelCNN++).  
Training is **fully parallel** (teacher forcing): feed the whole image, predict all conditionals at once, minimize cross-entropy:

$$
{L} \;=\; - \sum_{i,j,c} \log P_\theta\!\big(x_{i,j,c}\,\big|\,x_{<i,j,:},x_{i,j,<c}\big).
$$
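
In PyTorch the whole objective can be a single `F.cross_entropy` call, because it accepts logits shaped `[N, classes, d1, d2, ...]` with integer targets shaped `[N, d1, d2, ...]`. In this sketch the logits are random stand-ins for a PixelCNN's output (one 256-way distribution per pixel and channel), and `reduction="sum"` matches the sum in ${L}$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, K = 2, 3, 8, 8, 256         # batch, colour channels, height, width, intensity levels
x = torch.randint(0, K, (B, C, H, W))   # 8-bit images as integer class labels

# Stand-in for the network output: one logit vector over 256 values per (b, c, i, j).
# A real PixelCNN would compute these from x with masked convolutions.
logits = torch.randn(B, K, C, H, W, requires_grad=True)

# Every conditional is scored in one call (teacher forcing).
nll = F.cross_entropy(logits, x, reduction="sum")   # -sum log P(x_{i,j,c} | context)
nll.backward()
print(nll.item())
```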

---

### Generation (sampling) procedure

1. For $i=1..H$:  
2. &nbsp;&nbsp;for $j=1..W$:  
3. &nbsp;&nbsp;&nbsp;&nbsp;for $c=1..C$:  
4. &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;sample $x_{i,j,c} \sim P_\theta(\cdot \mid x_{<i,j,:},x_{i,j,<c})$  

This is **sequential** over pixels/channels (slow at sample time), but **parallel** during training.
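
The sampling loop can be sketched as follows. The "model" here is a hypothetical untrained stand-in (a single Mask-A $3\times3$ convolution producing 2 logits per pixel, i.e. binary single-channel pixels), so the samples are meaningless; the point is the raster-order loop with one full forward pass per pixel, which is the recomputation that caching intermediate features aims to avoid.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
H = W = 4
K = 2                                   # binary pixels keep the sketch tiny (assumption)

# Hypothetical stand-in model: one Mask-A 3x3 conv mapping a 1-channel image to K logits per pixel.
conv = nn.Conv2d(1, K, 3, padding=1)
mask = torch.tensor([[1., 1., 1.], [1., 0., 0.], [0., 0., 0.]])

def model(img):
    return F.conv2d(img, conv.weight * mask, conv.bias, padding=1)   # [1, K, H, W]

img = torch.zeros(1, 1, H, W)
with torch.no_grad():
    for i in range(H):                  # rows, top to bottom
        for j in range(W):              # columns, left to right
            logits = model(img)[0, :, i, j]          # full forward pass for every pixel
            value = torch.multinomial(F.softmax(logits, dim=0), 1)
            img[0, 0, i, j] = float(value.item())
print(img[0, 0])
```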

---

### Architectural notes

- Deep stacks of **masked** $3\times 3$ convolutions (often with residual connections) grow the **receptive field**, letting each conditional see a larger causal neighborhood.
- Variants:
  - **Gated PixelCNN**: replaces ReLU with gated activations for better modeling.
  - **PixelCNN++**: discretized logistic mixture likelihood, downsampling, and other refinements.
- Common speedups:
  - Caching intermediate features during generation.
  - Coarse-to-fine pyramids or parallelizable approximations.

---



