In [1]:
import torch as t
import torch.nn.functional as F
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap, random
from jax.scipy.special import logsumexp

In [2]:
key = random.PRNGKey(0)
rng = np.random.default_rng()



# A different way of calculating softmax

Standard way of calculating softmax is 

$$
p_k = \frac{e^{h_k}}{\sum e^{h_k}} \\
l_k = log(p_k) \\
$$

Where $\mathbf h$ is the vector of logits.

In [3]:
logits = np.array([1, 2, 3], dtype=np.float32)

In [4]:
p = np.exp(logits) / np.sum(np.exp(logits))
l = np.log(p)
print(f"probs = {p}")
print(f"logprobs = {l}")

probs = [0.09003058 0.24472846 0.66524094]
logprobs = [-2.407606   -1.407606   -0.40760598]


Lets verify this with using PyTorch's softmax to ensure that I have not made any errors in my manual calculation.

In [5]:
t.log(F.softmax(t.from_numpy(logits), dim=0))

tensor([-2.4076, -1.4076, -0.4076])

The other way of calculating softmax is - 

$$
\begin{align}
l_k &= log(p_k) \\
&= log\left(\frac{e^{h_k}}{\sum e^{h_k}}\right) \\
&= log\left(e^{h_k}\right) - log\left(\sum e^{h_k}\right) \\
&= h_k - log \left( \sum e^{h_k} \right) \\
\end{align}
$$

In [6]:
logits - np.log(np.sum(np.exp(logits)))

array([-2.407606 , -1.4076059, -0.4076059], dtype=float32)

## Define the net weights

In [7]:
w_key, b_key = random.split(key)
scale = 0.1
n_targets = 10

In [8]:
# W \in in_features \times out_features

w1 = scale * random.normal(w_key, (784, 512))
b1 = scale * random.normal(b_key, (512,))

w2 = scale * random.normal(w_key, (512, 256))
b2 = scale * random.normal(b_key, (256,))

w3 = scale * random.normal(w_key, (256, 10))
b3 = scale * random.normal(b_key, (10,))

## Define the forward function

In [9]:
def relu(x):
    return jnp.maximum(0, x)

In [10]:
x = random.normal(key, (28*28,))

In [11]:
z1 = jnp.dot(w1.T, x) + b1
a1 = relu(z1)
a1.shape

(512,)

In [12]:
z2 = jnp.dot(w2.T, a1) + b2
a2 = relu(z2)
a2.shape

(256,)

In [13]:
logits = jnp.dot(w3.T, a2) + b3
logits.shape

(10,)

In [14]:
logprobs = logits - logsumexp(logits)
logprobs

DeviceArray([-5.9046612e+00, -8.3773727e+00, -1.5425657e+01,
             -8.2383070e+00, -1.9056673e+01, -1.2999908e+01,
             -5.3692966e+00, -9.7851343e+00, -8.2426071e-03,
             -8.2100334e+00], dtype=float32)

In [15]:
def forward(x):
    z1 = jnp.dot(w1.T, x) + b1
    a1 = relu(z1)
    
    z2 = jnp.dot(w2.T, a1) + b2
    a2 = relu(z2)
    
    logits = jnp.dot(w3.T, a2) + b3
    logprobs = logits - logsumexp(logits)
    return logprobs

In [16]:
img = random.normal(key, (28*28,))
y_hat = forward(img)
y_hat

DeviceArray([-5.9046612e+00, -8.3773727e+00, -1.5425657e+01,
             -8.2383070e+00, -1.9056673e+01, -1.2999908e+01,
             -5.3692966e+00, -9.7851343e+00, -8.2426071e-03,
             -8.2100334e+00], dtype=float32)

In [17]:
imgs = random.normal(key, (3, 28*28))
imgs.shape

(3, 784)

In [19]:
try:
    forward(imgs)
except TypeError as te:
    print(te)

Incompatible shapes for dot: got (512, 784) and (3, 784).


In [20]:
forward_batch = vmap(forward, in_axes=(0,))

In [21]:
y_hat = forward_batch(imgs)
y_hat.shape

(3, 10)

## Define loss function

In [25]:
y = jnp.array(rng.integers(low=0, high=10, size=3))
y

DeviceArray([0, 2, 6], dtype=int32)

In [27]:
jnp.arange(10)

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [28]:
y[:, None] == jnp.arange(10)

DeviceArray([[ True, False, False, False, False, False, False, False,
              False, False],
             [False, False,  True, False, False, False, False, False,
              False, False],
             [False, False, False, False, False, False,  True, False,
              False, False]], dtype=bool)

In [32]:
y = rng.integers(low=0, high=10, size=3).reshape(-1, 1)
y

array([[9],
       [4],
       [4]])

In [33]:
y == np.arange(10)

array([[False, False, False, False, False, False, False, False, False,
         True],
       [False, False, False, False,  True, False, False, False, False,
        False],
       [False, False, False, False,  True, False, False, False, False,
        False]])

In [34]:
tp = np.array([
    np.arange(10),
    np.arange(10),
    np.arange(10)
])

In [35]:
tp

array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [36]:
y == tp

array([[False, False, False, False, False, False, False, False, False,
         True],
       [False, False, False, False,  True, False, False, False, False,
        False],
       [False, False, False, False,  True, False, False, False, False,
        False]])

In [37]:
tp == y

array([[False, False, False, False, False, False, False, False, False,
         True],
       [False, False, False, False,  True, False, False, False, False,
        False],
       [False, False, False, False,  True, False, False, False, False,
        False]])