In [None]:
import torch
from torch import nn

In [None]:
# Fixing nomenclature: expose PyTorch's softmax objects under the
# "softargmax" names used in the rest of this notebook. These are plain
# attribute assignments on the imported modules, so each new name is bound
# to the very same object as the original one.
setattr(torch, 'softargmax', torch.softmax)
setattr(nn, 'LogSoftArgMax', nn.LogSoftmax)

In [None]:
# Draw the input: a single row with 2 features, sampled from a standard
# normal. The seed is fixed first so the same x is produced on every run.
torch.manual_seed(0)
x = torch.randn((1, 2))
print(f'{x = }')

In [None]:
# Build a random one-hot target over 5 classes.
torch.manual_seed(2)
# c holds the index of the correct class, drawn uniformly from {0, ..., 4}
c = torch.randint(low=0, high=5, size=(1,))
# y is the matching one-hot vector: all zeros except a 1 at position c
y = torch.zeros(5, dtype=torch.int64)
y[c] = 1
print(f'{c = }, {y = }')

Model definition:

$$\begin{eqnarray*}
\boldsymbol{h} &=&
f(\boldsymbol{W_h x} + \boldsymbol{b_h}) \\
\boldsymbol{s} &=&
a(\boldsymbol{h}) = \boldsymbol{W_y h} + \boldsymbol{b_y} \\
\boldsymbol{o} &=&
g(\boldsymbol{s}) \\
f &=&
(\cdot)^+ \\
g &=&
\operatorname{logsoftargmax} \\
D(\boldsymbol{y}, \boldsymbol{o}) &=&
- \boldsymbol{y}^\top \boldsymbol{o}
\end{eqnarray*}$$

In [None]:
# Define model
# Seed before building the layers so that their randomly initialised
# parameters are the same on every run.
torch.manual_seed(1)

# h = f(W_h x + b_h): affine map from 2 inputs to 7 hidden units, then ReLU
predictor = nn.Sequential(
    nn.Linear(in_features=2, out_features=7),
    nn.ReLU(),
)
# s = a(h) = W_y h + b_y: affine map from the 7 hidden units to 5 scores
a = nn.Linear(in_features=7, out_features=5)
# o = g(s): log-softargmax taken over the last dimension
g = nn.LogSoftArgMax(dim=-1)

# D: negative log-likelihood criterion, fed with the log-softargmax output
D = nn.NLLLoss()

In [None]:
# Forward pass, keeping every intermediate result under its own name
h = predictor(x)  # hidden representation
s = a(h)          # scores produced by the affine map
o = g(s)          # log-softargmax of the scores

# Ask autograd to keep .grad on these intermediate tensors so that the
# gradients can be inspected after the backward pass
o.retain_grad()
s.retain_grad()

# Show values and gradient bookkeeping; no backward pass has run yet
print(f'{s = },\n{s.retains_grad = },\n{s.grad = }\n')
print(f'{o = },\n{o.retains_grad = },\n{o.grad = }')

In [None]:
# Compute cost, energy, and loss.
# D(o, c) is evaluated once; both F and L are bound to that same tensor.
F = D(o, c)
L = F
print(f'{L = }')

In [None]:
# Run back-propagation & grad accumulation.
# Functional spelling of L.backward(): gradients of L are accumulated into
# the .grad attribute of the parameters and of the tensors that called
# retain_grad().
torch.autograd.backward(L)

In [None]:
# Show the gradients that the backward pass stored on the intermediate
# tensors o and s, one per line
print(f'{o.grad = }', f'{s.grad = }', sep='\n')

In [None]:
# Check for correctness.
# Analytic gradient of the loss with respect to the scores: softargmax(s)
# minus the one-hot target y. (There is a single input sample here, so no
# batch averaging has to be accounted for.)
expected_grad_s = torch.softargmax(s.detach(), dim=-1) - y

# Actually compare against what autograd stored, instead of relying on the
# reader eyeballing two printed tensors
assert torch.allclose(s.grad, expected_grad_s), \
    'autograd gradient s.grad differs from softargmax(s) - y'

# Display the analytic gradient as the cell output
expected_grad_s

What about the affine transformation?

$$
\begin{gather}
a: \mathbb{R}^d \to \mathbb{R}^K, \quad \boldsymbol{h} \mapsto a(\boldsymbol{h}) = \boldsymbol{s}\\
\boldsymbol{s} = \boldsymbol{W_y h} + \boldsymbol{b_y} =
\boldsymbol{w}_1 h_1 +
\boldsymbol{w}_2 h_2 + \cdots +
\boldsymbol{w}_d h_d + 
\boldsymbol{b_y} \\
\boldsymbol{W_y} =
[\boldsymbol{w}_1\; \boldsymbol{w}_2\; \cdots\; \boldsymbol{w}_d]
\in \mathbb{R}^{K \times d}, \quad \boldsymbol{b_y} \in \mathbb{R}^K\\
\Rightarrow
{\partial \mathcal{L} \over \partial \boldsymbol{b_y}} = \cdots, \quad
{\partial \mathcal{L} \over \partial \boldsymbol{W_y}} = \cdots
\end{gather}
$$

In [None]:
# Check gradBias.
# The bias is added to every row of s, so its gradient is s.grad summed
# over the batch dimension (a single row here). Assert this rather than
# only printing the value.
assert torch.allclose(a.bias.grad, s.grad.sum(dim=0)), \
    'a.bias.grad differs from s.grad summed over the batch dimension'
print(f'{a.bias.grad = }')

In [None]:
# Check sizes: the layer input h, the gradient at the layer output s, and
# the weight matrix together with its gradient, one per line
print(
    f'{h.size() = }',
    f'{s.grad.size() = }',
    f'{a.weight.size() = }',
    f'{a.weight.grad.size() = }',
    sep='\n',
)

In [None]:
# Compute gradWeight by hand: multiply the transposed gradient at the
# layer output s by the layer input h (detached from the autograd graph)
torch.matmul(s.grad.t(), h.detach())

In [None]:
# Verify what PyTorch computed.
# Recompute the weight gradient by hand (transposed output gradient times
# the detached layer input) and assert that autograd stored the same
# values, instead of relying on a visual comparison of two printed tensors.
manual_grad_weight = s.grad.t() @ h.detach()
assert torch.allclose(a.weight.grad, manual_grad_weight), \
    'a.weight.grad differs from s.grad.t() @ h'
print(f'a.weight.grad =\n{a.weight.grad}')