# RL for LLMs

## Background: REINFORCE

The loss $l(f_\theta(x)) = \left(10 - \lvert f_\theta(x)\rvert\right)^2$ is non-differentiable in $\theta$, where $x$ is a sentence and $\theta$ are the model weights. The objective is:

$$\min_\theta l(f_\theta(x))$$

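Since we cannot differentiate through $l$, we optimize a similar objective instead: the expected loss under a distribution $p(\theta;\mu)$ over the weights, parameterized by $\mu$.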

\begin{align}
&\min_\mu\int l(f_\theta(x))\ p(\theta;\mu)\ d\theta\\
=&\min_\mu E_{p(\theta;\mu)} [l(f_\theta(x))]
\end{align}

$p(\theta;\mu)$ could be, for example, a Gaussian $\mathcal{N}(\mu,1)$ centered at $\mu$.

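To optimize for $\mu$, take the gradient. It can move inside the integral because $l(f_\theta(x))$ does not depend on $\mu$: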

\begin{align}
&\nabla_\mu\int l(f_\theta(x))\ p(\theta;\mu)\ d\theta\\
=&\int \underbrace{l(f_\theta(x))}_\text{network is non-differentiable}\ \nabla_\mu p(\theta;\mu)\ d\theta
\end{align}

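---

Aside, the log-derivative trick: by the chain rule, $\nabla_\mu p$ can be expressed through $\nabla_\mu \ln p$.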

\begin{align}
\nabla_x\ln(x) =& \frac{1}{x}\nabla_xx\\
\nabla_\mu\ln(p(\theta;\mu)) =& \frac{1}{p(\theta;\mu)}\nabla_\mu p(\theta;\mu)\quad\text{(multiply both sides by } p(\theta;\mu)\text{)}\\
\nabla_\mu\ln(p(\theta;\mu)) \cdot p(\theta;\mu) =& \nabla_\mu p(\theta;\mu)
\end{align}
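
The identity can be sanity-checked numerically. The following is a minimal sketch, assuming $p(\theta;\mu) = \mathcal{N}(\mu, 1)$ with scalar $\theta$ and $\mu$ (so $\nabla_\mu \ln p(\theta;\mu) = \theta - \mu$); the function names and the test point are made up for illustration.

```python
import math

def gaussian_pdf(theta, mu):
    """Density of N(mu, 1) evaluated at theta."""
    return math.exp(-0.5 * (theta - mu) ** 2) / math.sqrt(2.0 * math.pi)

def score(theta, mu):
    """Analytic gradient of ln p(theta; mu) w.r.t. mu for N(mu, 1)."""
    return theta - mu

def grad_pdf_fd(theta, mu, eps=1e-6):
    """Central finite-difference estimate of the gradient of p(theta; mu) w.r.t. mu."""
    return (gaussian_pdf(theta, mu + eps) - gaussian_pdf(theta, mu - eps)) / (2.0 * eps)

theta, mu = 1.3, 0.4                              # arbitrary test point
lhs = score(theta, mu) * gaussian_pdf(theta, mu)  # grad ln p * p
rhs = grad_pdf_fd(theta, mu)                      # grad p
print(abs(lhs - rhs))                             # difference between the two sides
```

The two sides should agree up to finite-difference error.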

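---

Substituting the identity into the gradient and approximating the resulting expectation with samples: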

\begin{align}
&\int l(f_\theta(x))\ \nabla_\mu p(\theta;\mu)\ d\theta\\
=&\int l(f_\theta(x))\ \nabla_\mu\ln(p(\theta;\mu)) \cdot p(\theta;\mu)\ d\theta\\
\approx&\frac{1}{n}\sum_{i=1}^n l(f_{\theta_i}(x))\ \nabla_\mu\ln(p(\theta_i;\mu))\quad\text{with } \theta_i\sim p(\theta;\mu)
\end{align}
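
To make the estimator concrete, here is a toy sketch in plain Python. It is not an LLM setup: it assumes a scalar $\theta$, $p(\theta;\mu) = \mathcal{N}(\mu, 1)$, and a hypothetical non-differentiable model $f_\theta(x) = \mathrm{round}(\theta)$ that ignores $x$. The step count, sample size, and learning rate are arbitrary choices.

```python
import random

def loss(theta):
    """l(f_theta(x)) = (10 - |f_theta(x)|)^2 with a toy, non-differentiable f_theta.

    Assumption: f_theta(x) = round(theta), ignoring x. It is piecewise constant,
    so its gradient w.r.t. theta is zero almost everywhere.
    """
    return (10.0 - abs(round(theta))) ** 2

def reinforce_gradient(mu, n, rng):
    """Monte Carlo estimate of grad_mu E_{N(mu,1)}[l] via the score function."""
    total = 0.0
    for _ in range(n):
        theta = rng.gauss(mu, 1.0)           # theta_i ~ p(theta; mu)
        total += loss(theta) * (theta - mu)  # l * grad_mu ln p; equals theta - mu for N(mu, 1)
    return total / n

def optimize(mu=3.0, steps=500, n=100, lr=0.01, seed=0):
    """Plain gradient descent on mu using the REINFORCE estimate."""
    rng = random.Random(seed)
    for _ in range(steps):
        mu -= lr * reinforce_gradient(mu, n, rng)
    return mu

mu_final = optimize()
print(mu_final)
```

Starting from $\mu = 3$, the estimate should push $\mu$ toward 10, where the expected loss is smallest. Note that only `loss` is ever evaluated; no gradient of `loss` is needed.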

Does REINFORCE work for LLMs? One concern is that sampling $\theta_i \sim p(\theta;\mu)$ may not work well when $\theta$ is high-dimensional (e.g., 70M parameters, or 175B in real-world models).

## Background: PPO

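Importance sampling rewrites an expectation under $p$ as an expectation under another distribution $q$ (with $q(x) > 0$ wherever $f(x)p(x) \neq 0$):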

\begin{align}
&\int f(x)p(x)\ dx = E_p[f(x)]\\
=&\int f(x)p(x) \cdot 1\ dx\\
=&\int f(x)p(x) \cdot \frac{q(x)}{q(x)}\ dx\\
=&\int f(x)q(x) \cdot \frac{p(x)}{q(x)}\ dx = E_q\left[f(x)\frac{p(x)}{q(x)}\right]
\end{align}
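
As a quick numerical illustration of the identity above, the sketch below estimates $E_p[x^2]$ for $p = \mathcal{N}(0, 1)$ using only samples from $q = \mathcal{N}(0, 2^2)$. Both distributions and the test function are assumptions chosen for the example; the true value is $E_p[x^2] = 1$.

```python
import math
import random

def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def importance_estimate(f, n=100_000, seed=0):
    """Estimate E_p[f(x)] for p = N(0, 1) using samples drawn from q = N(0, 2^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, 2.0)                                     # x ~ q
        weight = normal_pdf(x, 0.0, 1.0) / normal_pdf(x, 0.0, 2.0)  # p(x) / q(x)
        total += f(x) * weight
    return total / n

estimate = importance_estimate(lambda x: x * x)
print(estimate)  # should be close to E_p[x^2] = 1
```

Choosing $q$ wider than $p$ keeps the weights $p(x)/q(x)$ bounded, which keeps the variance of the estimate small.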

TODO: rewrite this section after looking into `grpo_loss` in [this file](https://github.com/aburkov/theLMbook/blob/main/GRPO.py).

| Symbol | Meaning |
| ---:| --- |
| $\phi$ | model weights |
| $x,y$ | model input (prompt) and sampled outputs |
| $r(x,y)$ | reward |
| $\pi_\phi^\text{RL}$ | policy (model) |

Our objective is to maximize the expected reward, so what we want to find is its gradient with respect to $\phi$:

\begin{align}
&\nabla_\phi E_{\pi_\phi^\text{RL}(x,y)}\left[r(x,y)\right]\\
=&\nabla_\phi\int r(x,y)\ \pi_\phi^\text{RL}(x,y)\ dxdy\\
=&\int r(x,y)\ \nabla_\phi \pi_\phi^\text{RL}(x,y)\ dxdy\\
=&\int r(x,y)\ \nabla_\phi \ln \pi_\phi^\text{RL}(x,y) \cdot \pi_\phi^\text{RL}(x,y)\ dxdy\\
\approx&\frac{1}{N}\sum_{n=1}^N r(x_n,y_n)\  \nabla_\phi \ln \pi_\phi^\text{RL}(x_n,y_n) & \text{with } (x_n,y_n) \sim \pi_\phi^\text{RL}(x,y)
\end{align}
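
The sampled gradient above can be tried on a tiny problem. This is only a sketch under strong assumptions, not PPO or GRPO: there is a single fixed prompt $x$, three candidate outputs $y$ with hypothetical rewards, and the policy is a softmax over three logits $\phi$ (so $\partial \ln \pi(y) / \partial \phi_k = \mathbb{1}[k = y] - \pi(k)$). All names and hyperparameters are illustrative.

```python
import math
import random

REWARDS = [0.0, 0.5, 1.0]  # hypothetical r(x, y) for one fixed prompt x and three outputs y

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def policy_gradient(phi, n, rng):
    """Monte Carlo estimate of grad_phi E_pi[r]: mean of r(y_n) * grad_phi ln pi(y_n)."""
    probs = softmax(phi)
    grad = [0.0] * len(phi)
    for _ in range(n):
        y = rng.choices(range(len(phi)), weights=probs)[0]  # y_n ~ pi_phi
        for k in range(len(phi)):
            # For a softmax policy: d ln pi(y) / d phi_k = 1[k == y] - pi(k)
            grad[k] += REWARDS[y] * ((1.0 if k == y else 0.0) - probs[k]) / n
    return grad

def train(steps=300, n=64, lr=0.5, seed=0):
    """Gradient ascent on the logits, since the reward is maximized."""
    rng = random.Random(seed)
    phi = [0.0, 0.0, 0.0]
    for _ in range(steps):
        g = policy_gradient(phi, n, rng)
        phi = [p + lr * gk for p, gk in zip(phi, g)]
    return softmax(phi)

final_probs = train()
print(final_probs)
```

After training, most of the probability mass should sit on the output with the highest reward. The estimator uses no baseline, so it is unbiased but noisier than the variants used in practice.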