## 0) Problem Setup and Symbols

We build a compact REINFORCE toy example with the same core notation as in `README.md`.

Main symbols:
- state embedding: $x$
- policy logits: $z$
- action probabilities: $\pi = \mathrm{softmax}(z)$
- sampled action: $a$
- log-probability: $\log \pi_\theta(a\mid s)$
- rewards: $r_t, r_{t+1}, \dots$
- discounted return: $G_t$

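Discounted return for an episode with steps $t = 0, \dots, T-1$, where $r_t$ is the reward received after taking action $a_t$:
$$
G_t = \sum_{k=0}^{T-t-1} \gamma^k r_{t+k}.
$$
For the two-step demonstration used below ($T = t + 2$), this reduces to: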
$$
G_t = r_t + \gamma r_{t+1}.
$$
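A minimal sketch of the definition, assuming the rewards of one episode are given as a plain list `[r_t, r_{t+1}, ...]`; the helper names are illustrative, not project functions. The forward sum with $\gamma^k$ and the backward fold $G \leftarrow r + \gamma G$ compute the same value:

```python
def discounted_return(rewards, gamma):
    """Forward sum: G = sum_k gamma**k * rewards[k]."""
    return sum((gamma ** k) * r for k, r in enumerate(rewards))


def discounted_return_backward(rewards, gamma):
    """Backward fold G <- r + gamma * G, starting from G = 0 after the last step."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


# Two-step demo: G_t = r_t + gamma * r_{t+1}
print(discounted_return([-1.0, -1.0], 0.95))           # ≈ -1.95
print(discounted_return_backward([-1.0, -1.0], 0.95))  # ≈ -1.95
```

The backward fold needs only one pass from the last step to the first, which is why it is the form used for batched returns in Section 5.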

REINFORCE surrogate objective (form from README; $G_t$ is treated as a constant with respect to $\theta$):
$$
J(\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_t G_t \log \pi_\theta(a_t\mid s_t)\right]
$$
Differentiating the term inside the expectation for one sampled trajectory gives the REINFORCE gradient estimate:
$$
\nabla_\theta J(\theta) \approx \sum_t G_t \, \nabla_\theta \log \pi_\theta(a_t\mid s_t).
$$
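To see why the estimator is written with $\approx$, here is a minimal sketch on a hypothetical one-step problem: a softmax policy over fixed per-action rewards, so $G = r_a$ and the parameters are the logits $z$ themselves. Averaging $r_a(e_a - \pi)$ over many sampled actions should approach the exact gradient $\partial\,\mathbb{E}[r_a]/\partial z_j = \pi_j\,(r_j - \mathbb{E}[r_a])$. The logits, rewards, and sample count are arbitrary illustration choices.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()


def exact_grad(z, rewards):
    # d/dz_j E_{a~pi}[r_a] = pi_j * (r_j - E[r])
    pi = softmax(z)
    return pi * (rewards - pi @ rewards)


def reinforce_grad(z, rewards, n_samples, rng):
    # Monte Carlo average of G * grad_z log pi(a), with G = r_a and grad_z log pi(a) = e_a - pi
    pi = softmax(z)
    actions = rng.choice(len(z), size=n_samples, p=pi)
    scores = np.eye(len(z))[actions] - pi  # shape (n_samples, n_actions)
    return (rewards[actions][:, None] * scores).mean(axis=0)


z = np.array([0.5, -0.2, 0.1])
rewards = np.array([1.0, -1.0, 0.5])
rng = np.random.default_rng(0)

print('exact     :', exact_grad(z, rewards))
print('REINFORCE :', reinforce_grad(z, rewards, 200_000, rng))
```

A single sample gives a noisy but unbiased estimate; the average only matches the exact gradient up to Monte Carlo error.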

The next code cell defines constants for our 2x2 toy state/action encoding.


In [59]:
import numpy as np

# 2x2 setup
N_FACES = 6
STICKERS_PER_FACE = 4
STATE_SIZE = N_FACES * STICKERS_PER_FACE  # 24
ACTION_DIM = 12  # 6 faces x 2 turn directions

ACTION_NAMES = [
    'U+', 'U-', 'D+', 'D-', 'L+', 'L-', 'R+', 'R-', 'F+', 'F-', 'B+', 'B-'
]

## 1) Build 2x2 Observation and Action One-Hot Vectors

This cell constructs the RL input in the same style as the project pipeline:
- solved 2x2 color state (length 24),
- state one-hot matrix of shape $(24, 6)$,
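- zero action-history vector of length $4\cdot 12 = 48$ (one 12-way slot for each of the last 4 actions),
- final observation vector $\mathrm{obs}$ of length $24\cdot 6 + 48 = 192$.

It also samples one random action index and builds its one-hot vector.

Apart from that random draw (the generator is unseeded, so the action changes between runs), this is a pure encoding step; no optimization happens yet.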


In [60]:
# Solved cube state for 2x2 (flat color IDs, length 24)
state = np.repeat(np.arange(N_FACES, dtype=np.int8), STICKERS_PER_FACE)
print('Solved state (6x4):')
print(state.reshape(N_FACES, STICKERS_PER_FACE))

# Sample one random action (0..11); the generator is unseeded, so the action differs between runs
rng = np.random.default_rng()
action = int(rng.integers(0, ACTION_DIM))
print('\nSampled action:', action, ACTION_NAMES[action])

# One-hot state: (24, 6)
state_one_hot = np.zeros((STATE_SIZE, N_FACES), dtype=np.int8)
state_one_hot[np.arange(STATE_SIZE), state.astype(np.int64)] = 1

# History one-hot for last 4 actions: start with zeros -> shape (48,)
action_history_one_hot = np.zeros((4 * ACTION_DIM,), dtype=np.int8)

# Observation vector: state_one_hot.flatten + history_one_hot -> shape (192,)
obs = np.concatenate([state_one_hot.reshape(-1), action_history_one_hot], axis=0)

# Current action one-hot: shape (12,)
action_one_hot = np.zeros((ACTION_DIM,), dtype=np.int8)
action_one_hot[action] = 1

print('\nState one-hot shape:', state_one_hot.shape)
print(state_one_hot)
print('\nAction history one-hot shape:', action_history_one_hot.shape)
print(action_history_one_hot)
print('\nOBS shape:', obs.shape)
print(obs)
print('\nAction one-hot shape:', action_one_hot.shape)
print(action_one_hot)


Solved state (6x4):
[[0 0 0 0]
 [1 1 1 1]
 [2 2 2 2]
 [3 3 3 3]
 [4 4 4 4]
 [5 5 5 5]]

Sampled action: 8 F+

State one-hot shape: (24, 6)
[[1 0 0 0 0 0]
 [1 0 0 0 0 0]
 [1 0 0 0 0 0]
 [1 0 0 0 0 0]
 [0 1 0 0 0 0]
 [0 1 0 0 0 0]
 [0 1 0 0 0 0]
 [0 1 0 0 0 0]
 [0 0 1 0 0 0]
 [0 0 1 0 0 0]
 [0 0 1 0 0 0]
 [0 0 1 0 0 0]
 [0 0 0 1 0 0]
 [0 0 0 1 0 0]
 [0 0 0 1 0 0]
 [0 0 0 1 0 0]
 [0 0 0 0 1 0]
 [0 0 0 0 1 0]
 [0 0 0 0 1 0]
 [0 0 0 0 1 0]
 [0 0 0 0 0 1]
 [0 0 0 0 0 1]
 [0 0 0 0 0 1]
 [0 0 0 0 0 1]]

Action history one-hot shape: (48,)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]

OBS shape: (192,)
[1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0
 1 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0
 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0
 0 1 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
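Because every sticker row has exactly one active color, the state part of the encoding is invertible. A minimal sketch, assuming the same layout as above (the 144-entry state block first, then the 48 history slots); `encode_obs` and `decode_state` are illustrative helper names, not functions from the project pipeline:

```python
import numpy as np

N_FACES, STICKERS_PER_FACE, ACTION_DIM, HISTORY_LEN = 6, 4, 12, 4
STATE_SIZE = N_FACES * STICKERS_PER_FACE  # 24


def encode_obs(state, history_one_hot):
    # (24,) color IDs -> (24, 6) one-hot -> flattened 144 values, then append the 48 history slots
    state_one_hot = np.zeros((STATE_SIZE, N_FACES), dtype=np.int8)
    state_one_hot[np.arange(STATE_SIZE), state.astype(np.int64)] = 1
    return np.concatenate([state_one_hot.reshape(-1), history_one_hot])


def decode_state(obs):
    # The first 24*6 entries are the state block; argmax per row recovers each color ID
    return obs[:STATE_SIZE * N_FACES].reshape(STATE_SIZE, N_FACES).argmax(axis=1)


state = np.repeat(np.arange(N_FACES), STICKERS_PER_FACE)
obs = encode_obs(state, np.zeros(HISTORY_LEN * ACTION_DIM, dtype=np.int8))
print(obs.shape)                                 # (192,)
print(np.array_equal(decode_state(obs), state))  # True
print(int(obs.sum()))                            # 24: one active color per sticker, empty history
```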

## 2) Torch Forward Pass, Action Sampling, Environment Step, Reward, and Discounted Return

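We now perform the stochastic policy step in PyTorch. To keep the numbers small, the policy is a tiny two-layer MLP on a random 5-dimensional input $x$ (it does not consume the 192-dimensional `obs` built above):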
$$
h_{\text{pre}} = xW_1 + b_1, \quad h = \mathrm{ELU}(h_{\text{pre}}), \quad z = hW_2 + b_2
$$
$$
\pi = \mathrm{softmax}(z).
$$
Sampling and log-probability are computed with:
```python
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()
log_prob = dist.log_prob(action)
```

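Then we apply one toy step to the solved 2x2 state from Section 1 and compute the immediate reward $r_t$ ($-1$ if the result is unsolved, $0$ if solved). The transition is a placeholder cyclic shift of the sticker array (`np.roll` by `sampled_action + 1`), not a real face turn. Applying the same action once more gives $r_{t+1}$, and together they form the discounted return: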
$$
G_t = r_t + \gamma r_{t+1}.
$$

We also store the initial weights, logits, and log-probability so that the NumPy re-implementation in Section 4 can be compared against Torch element-wise.
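For readers without PyTorch at hand, the same forward pass and sampling step can be sketched in NumPy. This is an illustrative mirror, not the notebook's code: drawing `x` and the weights in the same order from `default_rng(123)` reproduces the parameters copied into `linear1`/`linear2`, but the action is drawn with `Generator.choice`, so it will generally differ from the Torch sample. The log-probability uses a numerically stable log-softmax; `policy_step` is a hypothetical helper name.

```python
import numpy as np


def elu(x, alpha=1.0):
    return np.where(x > 0.0, x, alpha * (np.exp(x) - 1.0))


def log_softmax(z):
    # log pi = z - logsumexp(z), shifted by max(z) for numerical stability
    z_shifted = z - np.max(z)
    return z_shifted - np.log(np.sum(np.exp(z_shifted)))


def policy_step(x, W1, b1, W2, b2, rng):
    h_pre = x @ W1 + b1  # h_pre = x W1 + b1
    h = elu(h_pre)       # h = ELU(h_pre)
    z = h @ W2 + b2      # logits z = h W2 + b2
    log_pi = log_softmax(z)
    a = int(rng.choice(len(z), p=np.exp(log_pi)))  # a ~ Categorical(pi)
    return a, log_pi[a], np.exp(log_pi)


rng = np.random.default_rng(123)
in_dim, hidden_dim, action_dim = 5, 4, 12
x = rng.normal(size=in_dim)
W1 = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
b1 = rng.normal(scale=0.01, size=hidden_dim)
W2 = rng.normal(scale=0.1, size=(hidden_dim, action_dim))
b2 = rng.normal(scale=0.01, size=action_dim)

a, log_prob, pi = policy_step(x, W1, b1, W2, b2, rng)
print('action =', a, 'log_prob =', log_prob, 'sum(pi) =', pi.sum())
```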


In [61]:
# Step 1: Torch forward + sample action, apply one cube step, compute discounted return
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(123)  # fixed seed for reproducibility
rng = np.random.default_rng(123)

in_dim = 5
hidden_dim = 4
action_dim = 12
lr = 1e-2
gamma = 0.95

# Same init for torch and numpy (numpy -> torch copy)
x_np = rng.normal(size=(in_dim,)).astype(np.float64)
W1_init_np = rng.normal(scale=0.1, size=(in_dim, hidden_dim)).astype(np.float64)
b1_init_np = rng.normal(scale=0.01, size=(hidden_dim,)).astype(np.float64)
W2_init_np = rng.normal(scale=0.1, size=(hidden_dim, action_dim)).astype(np.float64)
b2_init_np = rng.normal(scale=0.01, size=(action_dim,)).astype(np.float64)

dtype = torch.float64
linear1 = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
elu_layer = nn.ELU(alpha=1.0)
linear2 = nn.Linear(hidden_dim, action_dim, bias=True, dtype=dtype)

with torch.no_grad():
    linear1.weight.copy_(torch.from_numpy(W1_init_np.T))
    linear1.bias.copy_(torch.from_numpy(b1_init_np))
    linear2.weight.copy_(torch.from_numpy(W2_init_np.T))
    linear2.bias.copy_(torch.from_numpy(b2_init_np))

# Forward + sampling
x_t = torch.from_numpy(x_np).to(dtype=dtype).unsqueeze(0)
h_pre_t = linear1(x_t)
h_t = elu_layer(h_pre_t)
logits_t = linear2(h_t).squeeze(0)
dist = torch.distributions.Categorical(logits=logits_t)
action_t = dist.sample()
log_prob_t = dist.log_prob(action_t)
sampled_action = int(action_t.item())

# Save reference tensors/values BEFORE update
logits_before_torch_np = logits_t.detach().cpu().numpy().copy()
log_prob_before_torch = float(log_prob_t.item())

W1_before_torch_np = linear1.weight.detach().cpu().numpy().T.copy()
b1_before_torch_np = linear1.bias.detach().cpu().numpy().copy()
W2_before_torch_np = linear2.weight.detach().cpu().numpy().T.copy()
b2_before_torch_np = linear2.bias.detach().cpu().numpy().copy()

# Toy transition (placeholder, not a real face turn): cyclic shift of the solved 2x2 state by action + 1
state_after_step = np.roll(state, sampled_action + 1)
is_solved_after_step = bool(np.array_equal(state_after_step, state))
reward_t = -1.0 if not is_solved_after_step else 0.0

# Second step reuses the same action (no re-sampling) to get r_{t+1} for the discounted-return demo
state_after_step_2 = np.roll(state_after_step, sampled_action + 1)
is_solved_after_step_2 = bool(np.array_equal(state_after_step_2, state))
reward_t_plus_1 = -1.0 if not is_solved_after_step_2 else 0.0

# Discounted return used by REINFORCE weight
G_t = reward_t + gamma * reward_t_plus_1

print('sampled_action =', sampled_action, ACTION_NAMES[sampled_action])
print('log_prob_t(action) =', log_prob_before_torch)
print('is_solved_after_step =', is_solved_after_step)
print('reward_t =', reward_t)
print('reward_t_plus_1 =', reward_t_plus_1)
print('gamma =', gamma)
print('G_t = reward_t + gamma * reward_t_plus_1 =', G_t)


sampled_action = 4 L+
log_prob_t(action) = -2.514844358797807
is_solved_after_step = False
reward_t = -1.0
reward_t_plus_1 = -1.0
gamma = 0.95
G_t = reward_t + gamma * reward_t_plus_1 = -1.95
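The toy transition in the cell above can be wrapped in a small helper to see how the two-step return depends on the sampled action. This sketches the notebook's placeholder dynamics only (`np.roll` by `action + 1`, reward $-1$ unless the array equals the solved state), not a real 2x2 face turn; `toy_step` and `two_step_return` are hypothetical names.

```python
import numpy as np

SOLVED = np.repeat(np.arange(6), 4)  # solved 2x2 color IDs, length 24


def toy_step(state, action):
    # Notebook placeholder dynamics: cyclic shift of the sticker array by action + 1
    next_state = np.roll(state, action + 1)
    reward = 0.0 if np.array_equal(next_state, SOLVED) else -1.0
    return next_state, reward


def two_step_return(action, gamma=0.95):
    # Same action applied twice, as in the cell above: G_t = r_t + gamma * r_{t+1}
    s1, r_t = toy_step(SOLVED, action)
    _, r_t1 = toy_step(s1, action)
    return r_t + gamma * r_t1


for action in (4, 11):
    print(action, two_step_return(action))
```

Only action 11 (`B-`) returns to the solved array after two steps, because two shifts of 12 add up to the full length 24, giving $G_t = -1$; every other action yields $G_t = -1.95$, the value printed above for `L+`.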


## 3) Torch REINFORCE Update Step

For one sampled transition, we optimize:
$$
\mathcal{L} = -G_t \, \log \pi_\theta(a\mid s),
$$
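where $G_t$ already includes the discount factor $\gamma$ and is a plain constant (a Python float), so no gradient flows through it.

With one SGD step of learning rate $\eta$:
$$
\theta \leftarrow \theta - \eta \, \nabla_\theta \mathcal{L}
= \theta + \eta \, G_t \, \nabla_\theta \log \pi_\theta(a\mid s).
$$
Minimizing $\mathcal{L}$ with SGD is therefore exactly a gradient-ascent step along the REINFORCE estimate for this single transition.

This cell runs the Torch backward pass and one optimizer step, then stores the updated Torch weights.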
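The sign bookkeeping can be checked without autograd on a hypothetical logits-only policy ($\theta = z$, so $\nabla_z \log\pi(a) = e_a - \pi$). One SGD step on $\mathcal{L}$ moves $z$ by $\eta\,G_t\,(e_a - \pi)$; with a negative return such as $G_t = -1.95$, $\log\pi(a)$ must go down. The logits, action index, and learning rate below are arbitrary illustration choices.

```python
import numpy as np


def log_softmax(z):
    z_shifted = z - np.max(z)
    return z_shifted - np.log(np.sum(np.exp(z_shifted)))


def sgd_step_on_loss(z, a, G, lr):
    # L = -G * log pi(a); dL/dz = -G * (e_a - pi); SGD: z <- z - lr * dL/dz
    pi = np.exp(log_softmax(z))
    e_a = np.eye(len(z))[a]
    grad_loss = -G * (e_a - pi)
    return z - lr * grad_loss


z = np.array([0.2, -0.1, 0.4, 0.0])
a, G, lr = 2, -1.95, 1e-2
z_new = sgd_step_on_loss(z, a, G, lr)

# Same update written as REINFORCE ascent: z + lr * G * grad log pi(a)
pi = np.exp(log_softmax(z))
z_ascent = z + lr * G * (np.eye(len(z))[a] - pi)

print(np.allclose(z_new, z_ascent))               # True
print(log_softmax(z_new)[a] < log_softmax(z)[a])  # True: a negative G lowers log pi(a)
```

Since $\log\pi(a)$ is concave in $z$, a step along $G_t(e_a - \pi)$ with $G_t < 0$ can only decrease it, regardless of the step size.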


In [62]:
# Step 2: Torch REINFORCE-style gradient step with discounted return G_t
import torch

# loss = - G_t * log_prob(a|s), where G_t includes gamma
optimizer = torch.optim.SGD(list(linear1.parameters()) + list(linear2.parameters()), lr=lr)
optimizer.zero_grad()
loss_t = -(G_t * log_prob_t)
loss_t.backward()
optimizer.step()

W1_after_torch_np = linear1.weight.detach().cpu().numpy().T.copy()
b1_after_torch_np = linear1.bias.detach().cpu().numpy().copy()
W2_after_torch_np = linear2.weight.detach().cpu().numpy().T.copy()
b2_after_torch_np = linear2.bias.detach().cpu().numpy().copy()

print('Torch loss =', float(loss_t.item()))
print('Torch SGD step done with lr =', lr)
print('Discounted return G_t =', G_t)


Torch loss = -4.903946499655723
Torch SGD step done with lr = 0.01
Discounted return G_t = -1.95


## 4) Manual NumPy Gradient and One-Step Update (Same Action, Same Discounted Return)

Now we reproduce the same update in pure NumPy, using **the same** sampled action $a$ and the same discounted return $G_t=r_t+\gamma r_{t+1}$.

Differentiating the log-softmax with respect to the logits gives:
$$
\delta = \frac{\partial \log \pi_\theta(a\mid s)}{\partial z} = e_a - \pi
$$
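For completeness, the expression for $\delta$ follows in one line from writing the log-softmax explicitly, with $e_a$ the one-hot vector of action $a$:
$$
\log \pi_\theta(a\mid s) = z_a - \log \sum_{k} e^{z_k}
\quad\Longrightarrow\quad
\frac{\partial \log \pi_\theta(a\mid s)}{\partial z_j} = \mathbb{1}[j=a] - \frac{e^{z_j}}{\sum_k e^{z_k}} = (e_a)_j - \pi_j .
$$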
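Layer gradients (row-vector convention as in Section 2, so $h^\top\delta$ and $x^\top g_{\text{pre}}$ are outer products):
$$
\frac{\partial \log \pi}{\partial W_2}=h^\top\delta, \qquad
\frac{\partial \log \pi}{\partial b_2}=\delta
$$
$$
g_h=\delta W_2^\top, \qquad g_{\text{pre}}=g_h\odot \mathrm{ELU}'(h_{\text{pre}})
$$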
$$
\frac{\partial \log \pi}{\partial W_1}=x^\top g_{\text{pre}}, \qquad
\frac{\partial \log \pi}{\partial b_1}=g_{\text{pre}}.
$$

Then we apply:
$$
\theta \leftarrow \theta + \eta \, G_t \, \nabla_\theta \log \pi_\theta(a\mid s),
$$
and compare final NumPy weights with final Torch weights element-wise.

If both implementations agree, the differences should be at the level of float64 round-off (around $10^{-16}$), far below the tolerance $10^{-10}$ used in the check.
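As an independent check that needs neither PyTorch nor the saved tensors, the layer formulas can be compared against central finite differences. A minimal sketch: the parameter shapes match this notebook, but the values, the action index, and the step size `eps` are arbitrary illustration choices, and the helper names are hypothetical.

```python
import numpy as np


def elu(x):
    return np.where(x > 0.0, x, np.exp(x) - 1.0)  # ELU with alpha = 1


def elu_prime(x):
    return np.where(x > 0.0, 1.0, np.exp(x))


def log_pi(params, x, a):
    # Scalar log pi(a | x) for the two-layer ELU policy
    W1, b1, W2, b2 = params
    z = elu(x @ W1 + b1) @ W2 + b2
    z_shifted = z - np.max(z)
    return z_shifted[a] - np.log(np.sum(np.exp(z_shifted)))


def grad_log_pi(params, x, a):
    # Analytic gradients [dW1, db1, dW2, db2] from the layer formulas above
    W1, b1, W2, b2 = params
    h_pre = x @ W1 + b1
    h = elu(h_pre)
    z = h @ W2 + b2
    pi = np.exp(z - np.max(z))
    pi /= pi.sum()
    delta = np.eye(len(z))[a] - pi               # e_a - pi
    g_pre = (W2 @ delta) * elu_prime(h_pre)      # (delta W2^T) * ELU'(h_pre)
    return [np.outer(x, g_pre), g_pre, np.outer(h, delta), delta]


def finite_diff(params, x, a, eps=1e-6):
    # Central differences, one parameter entry at a time
    grads = []
    for i, p in enumerate(params):
        g = np.zeros_like(p)
        for idx in np.ndindex(p.shape):
            plus = [q.copy() for q in params]
            minus = [q.copy() for q in params]
            plus[i][idx] += eps
            minus[i][idx] -= eps
            g[idx] = (log_pi(plus, x, a) - log_pi(minus, x, a)) / (2.0 * eps)
        grads.append(g)
    return grads


rng = np.random.default_rng(0)
x = rng.normal(size=5)
params = [
    rng.normal(scale=0.5, size=(5, 4)), rng.normal(scale=0.1, size=4),
    rng.normal(scale=0.5, size=(4, 12)), rng.normal(scale=0.1, size=12),
]
a = 3

analytic = grad_log_pi(params, x, a)
numeric = finite_diff(params, x, a)
for name, g_an, g_fd in zip(['W1', 'b1', 'W2', 'b2'], analytic, numeric):
    print(name, 'max |analytic - finite diff| =', np.max(np.abs(g_an - g_fd)))
```

Finite differences are only accurate to roughly $10^{-9}$ here, so this check is looser than the Torch comparison but does not depend on any framework.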


In [63]:
# Step 3: NumPy mirror with the SAME action and discounted return, then compare final weights
import numpy as np

def linear(x, W, b):
    return x @ W + b

def elu(x, alpha=1.0):
    return np.where(x > 0.0, x, alpha * (np.exp(x) - 1.0))

def elu_prime(x, alpha=1.0):
    return np.where(x > 0.0, 1.0, alpha * np.exp(x))

def softmax(z):
    z_shifted = z - np.max(z)
    e = np.exp(z_shifted)
    return e / np.sum(e)

def one_hot(i, n):
    v = np.zeros((n,), dtype=np.float64)
    v[int(i)] = 1.0
    return v

# Start from exactly the same initial parameters
W1_np = W1_init_np.copy()
b1_np = b1_init_np.copy()
W2_np = W2_init_np.copy()
b2_np = b2_init_np.copy()

# Forward (NumPy)
h_pre_np = linear(x_np, W1_np, b1_np)
h_np = elu(h_pre_np)
logits_np = linear(h_np, W2_np, b2_np)
pi_np = softmax(logits_np)

# Same sampled action from torch
a = sampled_action
log_prob_np = float(np.log(pi_np[a]))

# grad log pi(a|s)
delta = one_hot(a, action_dim) - pi_np
dW2_logp = np.outer(h_np, delta)
db2_logp = delta.copy()
g_h = W2_np @ delta
g_pre = g_h * elu_prime(h_pre_np)
dW1_logp = np.outer(x_np, g_pre)
db1_logp = g_pre.copy()

# REINFORCE update: theta += lr * G_t * grad_log_prob, where G_t uses gamma
W2_np = W2_np + lr * G_t * dW2_logp
b2_np = b2_np + lr * G_t * db2_logp
W1_np = W1_np + lr * G_t * dW1_logp
b1_np = b1_np + lr * G_t * db1_logp

print('Pre-update comparison:')
print('max|logits_torch - logits_numpy| =', float(np.max(np.abs(logits_before_torch_np - logits_np))))
print('|log_prob_torch - log_prob_numpy| =', float(abs(log_prob_before_torch - log_prob_np)))

print('\nPost-update weight comparison:')
print('max|W1_torch - W1_numpy| =', float(np.max(np.abs(W1_after_torch_np - W1_np))))
print('max|b1_torch - b1_numpy| =', float(np.max(np.abs(b1_after_torch_np - b1_np))))
print('max|W2_torch - W2_numpy| =', float(np.max(np.abs(W2_after_torch_np - W2_np))))
print('max|b2_torch - b2_numpy| =', float(np.max(np.abs(b2_after_torch_np - b2_np))))

tol = 1e-10
print('\nAll close?')
print('W1 close:', np.allclose(W1_after_torch_np, W1_np, atol=tol, rtol=0.0))
print('b1 close:', np.allclose(b1_after_torch_np, b1_np, atol=tol, rtol=0.0))
print('W2 close:', np.allclose(W2_after_torch_np, W2_np, atol=tol, rtol=0.0))
print('b2 close:', np.allclose(b2_after_torch_np, b2_np, atol=tol, rtol=0.0))


Pre-update comparison:
max|logits_torch - logits_numpy| = 2.0816681711721685e-17
|log_prob_torch - log_prob_numpy| = 4.440892098500626e-16

Post-update weight comparison:
max|W1_torch - W1_numpy| = 1.3877787807814457e-17
max|b1_torch - b1_numpy| = 1.734723475976807e-18
max|W2_torch - W2_numpy| = 4.336808689942018e-19
max|b2_torch - b2_numpy| = 3.469446951953614e-18

All close?
W1 close: True
b1 close: True
W2 close: True
b2 close: True


## 5) Batch REINFORCE Example (Exactly as in `trainer.py`)

Here we reproduce the trainer batch logic with **two parallel trajectories**:
- Env 0: terminates after **1** step
- Env 1: terminates after **2** steps

We use tensors with shape `[T, B]` and the same formulas as in the trainer:

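- discounted returns (backward over time, starting from $R_T = 0$):
$$
R_t = r_t + \gamma R_{t+1}
$$
- masked policy loss:
$$
\mathcal{L} = -\frac{\sum_{t,b} \log \pi(a_{t,b}\mid s_{t,b})\,R_{t,b}\,m_{t,b}}{\sum_{t,b} m_{t,b}}
$$
where $m_{t,b}$ is `active_mask[t,b]`, so inactive (post-termination) timesteps contribute neither to the numerator nor to the count in the denominator (the code also clamps the denominator to at least 1). The return recursion itself is not masked, so it relies on rewards being $0$ after an environment terminates, as in `rewards_total` below.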
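The same two formulas can be sketched in NumPy to make the arithmetic easy to follow by hand. This mirrors the Torch cell below with the same toy numbers; it is an illustration, not a copy of `trainer.py`, and the helper names are hypothetical.

```python
import numpy as np


def discounted_returns(rewards, gamma):
    # rewards: [T, B]; backward recursion R_t = r_t + gamma * R_{t+1}, with R_T = 0
    returns = np.zeros_like(rewards)
    running = np.zeros(rewards.shape[1], dtype=rewards.dtype)
    for t in range(rewards.shape[0] - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def masked_policy_loss(log_probs, returns, mask):
    # -(sum of log pi * R * m) / max(sum of m, 1)
    valid_count = max(mask.sum(), 1.0)
    return -(log_probs * returns * mask).sum() / valid_count


rewards = np.array([[-1.0, -1.0], [0.0, -1.0]])
mask = np.array([[1.0, 1.0], [0.0, 1.0]])
log_probs = np.array([[-0.20, -0.50], [-9.99, -0.70]])

returns = discounted_returns(rewards, gamma=0.9)  # R_{0,1} = -1 + 0.9 * (-1) = -1.9
print(returns)
print(masked_policy_loss(log_probs, returns, mask))  # -(0.2 + 0.95 + 0.7) / 3
```

Because $R_{1,0} = 0$ in this example, the inactive product is already zero, so the mask's visible effect here is on the denominator: 3 valid steps instead of $T \cdot B = 4$ (an unmasked mean would give $-1.85/4$). The mask also matters for the numerator whenever a padded step carries a nonzero return.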


In [64]:
# Step 4: Minimal torch batch demo with masks and loss (same as trainer.py)
import torch

# Two trajectories in a batch: B=2, max horizon T=2
# Env0 length=1, Env1 length=2
T, B = 2, 2
gamma_batch = 0.9

# reward_total[t, b]
# t=0: both envs active and take a step -> -1 each
# t=1: env0 already done (inactive, reward 0), env1 still active -> -1
rewards_total = torch.tensor([
    [-1.0, -1.0],
    [ 0.0, -1.0],
], dtype=torch.float32)

# active_mask[t, b] exactly as trainer uses it
# Env0: [1, 0], Env1: [1, 1]
active_mask = torch.tensor([
    [1.0, 1.0],
    [0.0, 1.0],
], dtype=torch.float32)

# Example log-probabilities for sampled actions at each [t, b]
# Value at inactive step (t=1,b=0) is arbitrary and must be ignored by mask.
log_probs = torch.tensor([
    [-0.20, -0.50],
    [-9.99, -0.70],
], dtype=torch.float32)

# 1) Discounted returns per env over time (same backward loop as trainer.py)
returns = torch.zeros_like(rewards_total)
running = torch.zeros((B,), dtype=torch.float32)
for t in range(T - 1, -1, -1):
    running = rewards_total[t] + gamma_batch * running
    returns[t] = running

# 2) Masked loss (exact trainer formula)
valid_count = torch.clamp(active_mask.sum(), min=1.0)
weighted_terms = log_probs * returns * active_mask
loss = -(weighted_terms.sum() / valid_count)

print('gamma =', gamma_batch)
print('rewards_total [T,B]:\n', rewards_total)
print('active_mask [T,B]:\n', active_mask)
print('log_probs [T,B]:\n', log_probs)
print('returns [T,B]:\n', returns)
print('weighted_terms = log_probs * returns * active_mask:\n', weighted_terms)
print('valid_count =', float(valid_count.item()))
print('loss =', float(loss.item()))

# Sanity checks by trajectory
print('\nPer-env active lengths:')
print('env0 active steps =', int(active_mask[:, 0].sum().item()))
print('env1 active steps =', int(active_mask[:, 1].sum().item()))

print('\nContributions used in loss:')
for t in range(T):
    for b in range(B):
        if active_mask[t, b] > 0:
            print(
                f't={t}, env={b}, logp={log_probs[t,b].item():.3f}, '
                f'return={returns[t,b].item():.3f}, product={weighted_terms[t,b].item():.3f}'
            )


gamma = 0.9
rewards_total [T,B]:
 tensor([[-1., -1.],
        [ 0., -1.]])
active_mask [T,B]:
 tensor([[1., 1.],
        [0., 1.]])
log_probs [T,B]:
 tensor([[-0.2000, -0.5000],
        [-9.9900, -0.7000]])
returns [T,B]:
 tensor([[-1.0000, -1.9000],
        [ 0.0000, -1.0000]])
weighted_terms = log_probs * returns * active_mask:
 tensor([[0.2000, 0.9500],
        [-0.0000, 0.7000]])
valid_count = 3.0
loss = -0.6166666150093079

Per-env active lengths:
env0 active steps = 1
env1 active steps = 2

Contributions used in loss:
t=0, env=0, logp=-0.200, return=-1.000, product=0.200
t=0, env=1, logp=-0.500, return=-1.900, product=0.950
t=1, env=1, logp=-0.700, return=-1.000, product=0.700
