In [56]:
import torch
import numpy as np
import math
from ssm.parallel_scan import parallel_scan_naive

We define the following matrices:

$$
\bar{A}_0 =
\begin{bmatrix}
1 & 0 & 0 \\
0 & 1 & 0 \\
0 & 0 & 1
\end{bmatrix}
\quad
\bar{B}_0 =
\begin{bmatrix}
6 \\
1 \\
2
\end{bmatrix}
\quad
C_0 =
\begin{bmatrix}
1 & 2 & 3
\end{bmatrix}
\quad
u_0 =
\begin{bmatrix}
5
\end{bmatrix}
\\[1em]
\bar{A}_1 =
\begin{bmatrix}
3 & 0 & 0 \\
0 & 1 & 0 \\
0 & 0 & 2
\end{bmatrix}
\quad
\bar{B}_1 =
\begin{bmatrix}
9 \\
8 \\
3
\end{bmatrix}
\quad
C_1 =
\begin{bmatrix}
4 & 5 & 7
\end{bmatrix}
\quad
u_1 =
\begin{bmatrix}
8
\end{bmatrix}
\\[1em]
\bar{A}_2 =
\begin{bmatrix}
5 & 0 & 0 \\
0 & 1 & 0 \\
0 & 0 & 1
\end{bmatrix}
\quad
\bar{B}_2 =
\begin{bmatrix}
3 \\
4 \\
6
\end{bmatrix}
\quad
C_2 =
\begin{bmatrix}
1 & 2 & 6
\end{bmatrix}
\quad
u_2 =
\begin{bmatrix}
3
\end{bmatrix}
$$


In [19]:
# torch inputs: leaf tensors with requires_grad=True so backward() fills .grad.
# One row per time step. Each A_bar row is the diagonal of a transition matrix;
# only the k=1 and k=2 transitions are listed here (no row for the identity
# \bar{A}_0, which the numpy copy below does carry as [1, 1, 1]).
A_bar = torch.tensor(
  [
    [3, 1, 2],
    [5, 1, 1],
  ],
  dtype=torch.float32,
  requires_grad=True,
)
B_bar = torch.tensor(
  [
    [6, 1, 2],
    [9, 8, 3],
    [3, 4, 6],
  ],
  dtype=torch.float32,
  requires_grad=True,
)
C = torch.tensor(
  [
    [1, 2, 3],
    [4, 5, 7],
    [1, 2, 6],
  ],
  dtype=torch.float32,
  requires_grad=True,
)
u = torch.tensor([5, 8, 3], dtype=torch.float32, requires_grad=True)

# numpy mirrors of the same values, used for the hand-written forward/backward pass.
A_bar_np = np.array(
  [
    [1, 1, 1],
    [3, 1, 2],
    [5, 1, 1],
  ],
  dtype=np.float32,
)
B_bar_np = np.array(
  [
    [6, 1, 2],
    [9, 8, 3],
    [3, 4, 6],
  ],
  dtype=np.float32,
)
C_np = np.array(
  [
    [1, 2, 3],
    [4, 5, 7],
    [1, 2, 6],
  ],
  dtype=np.float32,
)
u_np = np.array([5, 8, 3], dtype=np.float32)


Let us manually calculate the forward pass! We start with the recurrence of the hidden states:
$$
\begin{align*}
  x_k = \begin{cases}
          \bar{B}_0 u_0 \quad & k=0 \\
          (\bar{A}_k x_{k-1}) + \bar{B}_k u_k \quad &0 < k \leq L
        \end{cases}
\end{align*}
$$

In [28]:
def x_k(A_bar, B_bar, u):
  """Unroll the hidden-state recurrence and return every state.

  The first state is ``B_bar[0] * u[0]``; each later state is
  ``A_bar[k] * x_{k-1} + B_bar[k] * u[k]``. All products use ``*``, so with
  numpy rows they are elementwise. ``A_bar[0]`` is never read.

  Args:
    A_bar: indexable of per-step transition terms.
    B_bar: indexable of per-step input terms.
    u: indexable of per-step inputs; its length sets the number of steps.

  Returns:
    A list with one state per entry of ``u`` (empty when ``u`` is empty).
  """
  states = []
  for step in range(len(u)):
    drive = B_bar[step] * u[step]
    # Step 0 has no predecessor, so only the input term contributes.
    states.append(drive if step == 0 else (A_bar[step] * states[step - 1]) + drive)
  return states

# Forward pass on the numpy copies: hidden states first, then the outputs.
hidden_states = x_k(A_bar_np, B_bar_np, u_np)
print(f"x_k = {hidden_states}")
# y_k is the sum over the state dimension of C_k * x_k (row-wise reduction).
print(f"y_k = {np.sum(C_np * hidden_states, axis=1)}")

x_k = [array([30.,  5., 10.], dtype=float32), array([162.,  69.,  44.], dtype=float32), array([819.,  81.,  62.], dtype=float32)]
y_k = [  70. 1301. 1353.]


Let us now dive into the backward pass!
We first need to define a loss. Let us use a simple one, such as
$$
\mathcal{L} (\{y_k\}_{\mathbb{N}_L}) = \sum_{k=0}^L y_k
$$
With this simple loss, we have $\frac{\partial \mathcal{L}}{\partial y_k} = 1$. By extension, the direct path from $x_k \rightarrow \mathcal{L}$ is $\frac{\partial \mathcal{L}}{\partial y_k}\frac{\partial y_k}{\partial x_k} = C_k$.

We want to obtain $\nabla_{x_k} \mathcal{L}$, $\nabla_{\bar{A}_k} \mathcal{L}$ and $\nabla_{\bar{B}_k} \mathcal{L}$. They are:
$$
\begin{align*}
  \nabla_{x_k} \mathcal{L} &= \frac{\partial \mathcal{L}}{\partial y_k}\frac{\partial y_k}{\partial x_k} + \bar{A}_{k+1} \nabla_{x_{k+1}} \mathcal{L}\\
  \nabla_{\bar{A}_k}  \mathcal{L} &= x_{k-1} \cdot \nabla_{x_k} \mathcal{L}\\
  \nabla_{\bar{B}_k}  \mathcal{L} &= u_{k} \nabla_{x_k} \mathcal{L}
\end{align*}
$$

In [53]:
def nabla_x_k(A_bar, dl_dx, x_k):
  """Accumulate the loss gradient w.r.t. each hidden state, last step first.

  The final step receives only its direct term ``dl_dx[-1]``. Every earlier
  step ``k`` receives ``A_bar[k+1] * grad_x[k+1] + dl_dx[k]``.

  Args:
    A_bar: indexable of per-step transition terms.
    dl_dx: indexable of direct per-step gradients; its length sets the
      number of steps processed.
    x_k: hidden states; only used as the template for the output's
      shape and dtype via ``np.zeros_like``.

  Returns:
    An ndarray shaped like ``x_k`` holding the gradient for every step.
  """
  grad_x = np.zeros_like(x_k)
  final = len(dl_dx) - 1
  for k in range(final, -1, -1):
    if k < final:
      # Contribution flowing back from the next state, plus the direct term.
      upstream = A_bar[k + 1] * grad_x[k + 1]
      grad_x[k] = upstream + dl_dx[k]
    else:
      grad_x[k] = dl_dx[k]
  return grad_x

def nabla_a_k(x_k, grad_x, A_bar):
  """Per-step loss gradient w.r.t. the transition terms.

  Step 0 gets a zero array shaped like ``A_bar[0]``; every later step ``k``
  gets ``x_k[k-1] * grad_x[k]``.

  Args:
    x_k: hidden states from the forward pass; its length sets the step count.
    grad_x: gradient of the loss w.r.t. each hidden state.
    A_bar: transition terms; only ``A_bar[0]`` is read, as the zero template.

  Returns:
    A list with one gradient entry per step.
  """
  return [
    np.zeros_like(A_bar[0]) if k == 0 else x_k[k - 1] * grad_x[k]
    for k in range(len(x_k))
  ]

def nabla_b_k(u_k, grad_x):
  """Per-step loss gradient w.r.t. the input terms.

  Args:
    u_k: indexable of per-step inputs; its length sets the step count.
    grad_x: gradient of the loss w.r.t. each hidden state.

  Returns:
    A list whose k-th entry is ``u_k[k] * grad_x[k]``.
  """
  return [u_k[k] * grad_x[k] for k in range(len(u_k))]

# Manual backward pass on the numpy copies. C_np is passed as the direct
# per-step gradient of the loss w.r.t. the hidden states.
grad_x = nabla_x_k(A_bar_np, C_np, hidden_states)
grad_A_bar = nabla_a_k(hidden_states, grad_x, A_bar_np)
grad_B_bar = nabla_b_k(u_np, grad_x)

print(f"nabla_x L = \n{grad_x}")
# Both parameter gradients are lists of per-step rows; print them the same way.
for label, rows in (("nabla_a_bar L =", grad_A_bar), ("nabla_b_bar L =", grad_B_bar)):
    print(label)
    for row in rows:
        print("  ", row)

nabla_x L = 
[[28.  9. 29.]
 [ 9.  7. 13.]
 [ 1.  2.  6.]]
nabla_a_bar L =
   [0. 0. 0.]
   [270.  35. 130.]
   [162. 138. 264.]
nabla_b_bar L =
   [140.  45. 145.]
   [ 72.  56. 104.]
   [ 3.  6. 18.]


Let us see if our implemented torch autograd function is correct

In [55]:
# Reset gradients before backward so re-running this cell does not add to
# gradients left in .grad by a previous run.
A_bar.grad = None
B_bar.grad = None
u.grad = None
C.grad = None

result = parallel_scan_naive.apply(A_bar, B_bar, u, C)
loss = result.sum()
loss.backward()
print("Forward pass:")
print(result)
print("Gradient wrt A_bar:")
print(A_bar.grad)

print("Gradient wrt B_bar:")
print(B_bar.grad)

# Check the autograd function against the manual derivation instead of
# comparing the printed numbers by eye. These raise AssertionError on mismatch.
np.testing.assert_allclose(
    result.detach().cpu().numpy(),
    (C_np * hidden_states).sum(axis=1),
)
# The torch A_bar has no row for step 0, so skip the manual step-0 entry.
np.testing.assert_allclose(
    A_bar.grad.detach().cpu().numpy(),
    np.stack(grad_A_bar[1:]),
)
np.testing.assert_allclose(
    B_bar.grad.detach().cpu().numpy(),
    np.stack(grad_B_bar),
)


Forward pass:
tensor([  70., 1301., 1353.], grad_fn=<parallel_scan_naiveBackward>)
Gradient wrt A_bar:
tensor([[270.,  35., 130.],
        [162., 138., 264.]])
Gradient wrt B_bar:
tensor([[140.,  45., 145.],
        [ 72.,  56., 104.],
        [  3.,   6.,  18.]])
