# Working with JAX numpy and calculating perplexity

Normally we would import `numpy` under the alias `np`.

However, in this notebook that convention is changed.

Standard `numpy` is imported without an alias, and `trax.fastmath.numpy` is imported as `np`.

The rationale behind this change is that we will be using Trax's numpy (which is compatible with JAX) far more often. Trax's numpy supports most of the same functions as regular numpy, so in most cases the change won't be noticeable.

In [5]:
%pip install trax

In [6]:
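import numpy                       # regular numpy, no alias
import trax
import trax.fastmath.numpy as np   # Trax's numpy, aliased as np

# Seed numpy's global random generator so the random arrays below are reproducible
numpy.random.seed(32)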

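One important difference to keep in mind is that the resulting objects have different types depending on which numpy created them. Regular numpy gives us `numpy.ndarray` as the multi-dimensional array type, whereas Trax's numpy gives us a JAX array: `jax.interpreters.xla.DeviceArray` in the older JAX versions Trax was originally written against, or `jax.Array` (concrete class `ArrayImpl`) in newer JAX releases. The two convert easily: `np.array()` turns a numpy array into a JAX array, and `numpy.asarray()` turns a JAX array back into a `numpy.ndarray`.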

In [None]:
numpy_array = numpy.random.random((5,10))
print(f"The regular numpy array looks like this:\n\n {numpy_array}\n")
print(f"It is of type: {type(numpy_array)}")

The regular numpy array looks like this:

 [[0.72158098 0.36299476 0.15039771 0.89004238 0.71484224 0.65245173
  0.59168053 0.63502934 0.37356814 0.60504975]
 [0.24328694 0.23866972 0.93230853 0.26613939 0.86000716 0.76622879
  0.50854193 0.61018048 0.94483917 0.12428304]
 [0.43527633 0.18559947 0.83212291 0.98959454 0.86460191 0.74405856
  0.72858069 0.38019823 0.43452783 0.65735066]
 [0.47901676 0.42314845 0.2657922  0.69784179 0.58958402 0.72223054
  0.35940943 0.10315196 0.24230629 0.69583213]
 [0.69357502 0.3075907  0.1184685  0.31449128 0.37929997 0.35752695
  0.34253852 0.09281963 0.16274971 0.49312257]]

It is of type: <class 'numpy.ndarray'>


We can easily cast regular numpy arrays or lists into trax numpy arrays using the `trax.fastmath.numpy.array()` function:

In [None]:
trax_numpy_array = np.array(numpy_array)
print(f"The trax numpy array looks like this:\n\n {trax_numpy_array}\n")
print(f"It is of type: {type(trax_numpy_array)}")

Now we will see how to calculate the perplexity of a trained model.


## Calculating Perplexity

Perplexity is a metric that measures how well a probability model predicts a sample, and it is commonly used to evaluate language models. Lower perplexity means the model assigns higher probability to the observed text. It is defined as:

$$P(W) = \sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{i-1})}}$$
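A quick sanity check of the definition: a model that spreads its probability uniformly over a vocabulary of $V$ words has perplexity exactly $V$, so perplexity can be read as the effective number of words the model is choosing between. The sketch below uses plain `numpy` and made-up probabilities; `perplexity` is a hypothetical helper for illustration, not a Trax function.

```python
import numpy

def perplexity(token_probs):
    """Perplexity straight from the definition: the N-th root of the
    product of the inverse probabilities of the N observed words."""
    token_probs = numpy.asarray(token_probs, dtype=numpy.float64)
    n = token_probs.size
    return numpy.prod(1.0 / token_probs) ** (1.0 / n)

# Hypothetical model that is uniform over a 10-word vocabulary:
# every observed word gets probability 1/10.
uniform = numpy.full(5, 1 / 10)
print(perplexity(uniform))  # ~10.0

# Made-up probabilities from a model that is more confident about
# the observed words: its perplexity is lower.
confident = numpy.array([0.9, 0.8, 0.7, 0.9, 0.6])
print(perplexity(confident))
```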

In practice, we usually work with the logarithm of this formula, because multiplying many probabilities smaller than 1 quickly underflows to zero in floating point.
We also need to handle padding: padding elements should not be included when calculating the perplexity, since counting them would produce an artificially good metric.
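To make the underflow problem concrete, the sketch below uses plain `numpy` and a hypothetical 1000-word text in which every word is predicted with probability 0.01 (made-up numbers). The direct product is far smaller than the smallest positive `float64`, while the sum of logs is perfectly representable.

```python
import numpy

# Hypothetical per-word probabilities for a 1000-word text.
probs = numpy.full(1000, 0.01)

# Direct product: the exact value 0.01 ** 1000 = 1e-2000 is far below the
# smallest positive float64 (about 5e-324), so the result underflows to 0.0.
direct_product = numpy.prod(probs)
print(direct_product)  # 0.0

# Summing log probabilities instead stays well within float64 range.
log_sum = numpy.sum(numpy.log(probs))
print(log_sum)  # about -4605.17

# Log perplexity is -(1/N) * sum of logs; exponentiating recovers perplexity.
log_ppx = -log_sum / probs.size
print(log_ppx, numpy.exp(log_ppx))  # about 4.605 and 100.0
```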

After taking the logarithm of $P(W)$ we have:

$$\log P(W) = \log\left(\sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{i-1})}}\right)$$


$$ = \log\left(\left(\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{i-1})}\right)^{\frac{1}{N}}\right)$$

$$ = \log\left(\left({\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\right)^{-\frac{1}{N}}\right)$$

$$ = -\frac{1}{N}{\log\left({\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\right)} $$

$$ = -\frac{1}{N}{{\sum_{i=1}^{N}{\log P(w_i| w_1,...,w_{i-1})}}} $$
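The first and last lines of the derivation can be checked numerically. The sketch below draws hypothetical conditional probabilities $P(w_i \mid w_1, \ldots, w_{i-1})$ at random (made-up values, plain `numpy`) and compares the log of the $N$-th-root formula with $-\frac{1}{N}\sum_i \log P(w_i \mid w_1, \ldots, w_{i-1})$.

```python
import numpy

rng = numpy.random.default_rng(0)

# Hypothetical conditional probabilities for N = 8 words.
p = rng.uniform(0.05, 0.95, size=8)
N = p.size

# First line: log of the N-th root of the product of inverse probabilities.
lhs = numpy.log(numpy.prod(1.0 / p) ** (1.0 / N))

# Last line: -(1/N) times the sum of the log probabilities.
rhs = -numpy.sum(numpy.log(p)) / N

print(lhs, rhs)
print(numpy.isclose(lhs, rhs))  # True
```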


We will now work with an example made up of:
   - `predictions`: log probabilities for every element of the vocabulary, for 32 sequences of 64 elements each (after padding).
   - `targets`: the 32 observed sequences of 64 elements each (after padding).

In [None]:
from trax import layers as tl

# Load from .npy files
predictions = numpy.load('predictions.npy')
targets = numpy.load('targets.npy')

# Convert to JAX arrays using Trax's numpy
predictions = np.array(predictions)
targets = np.array(targets)

# Print shapes
print(f'predictions has shape: {predictions.shape}')
print(f'targets has shape: {targets.shape}')

> `predictions` has an extra (last) dimension whose length equals the size of the vocabulary.

- Because of this, we one-hot encode `targets` so that it matches this shape. `trax.layers.one_hot()` (imported here as `tl.one_hot`) does exactly that.

In [None]:
# tl.one_hot(x, n_categories, dtype=...) maps each integer ID in x to a one-hot vector of length n_categories
reshaped_targets = tl.one_hot(targets, predictions.shape[-1])
print(f'reshaped_targets has shape: {reshaped_targets.shape}')
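To see what the `tl.one_hot` call above does to the shapes, here is a minimal plain-`numpy` sketch of the same idea on toy data. `one_hot_numpy` is a hypothetical helper written for illustration; it is not Trax's implementation, and the toy targets and vocabulary size of 5 are made up.

```python
import numpy

def one_hot_numpy(x, n_categories, dtype=numpy.float32):
    """Add a trailing axis of length n_categories holding a 1 at each
    integer ID and 0 everywhere else."""
    x = numpy.asarray(x)
    # Broadcasting (..., 1) against (n_categories,) compares every ID with every category.
    return (x[..., None] == numpy.arange(n_categories)).astype(dtype)

# Toy batch: 2 sequences of 4 token IDs, vocabulary of size 5 (0 used as padding).
targets = numpy.array([[3, 1, 4, 0],
                       [2, 2, 0, 0]])
one_hot_targets = one_hot_numpy(targets, 5)
print(one_hot_targets.shape)  # (2, 4, 5)
print(one_hot_targets[0, 0])  # [0. 0. 0. 1. 0.]
```

Each `(batch, seq_len)` integer array becomes a `(batch, seq_len, vocab)` array, which is what lets it be multiplied element-wise with `predictions`.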

Multiplying the predictions element-wise by the one-hot targets and summing across the last dimension picks out the log probability the model assigned to each observed element of the sequences:

In [None]:
log_p = np.sum(predictions * reshaped_targets, axis=-1)  # one log probability per element: (32, 64)
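Because each one-hot vector contains a single 1, the multiply-and-sum is just a lookup of the observed token's log probability. The sketch below checks this on random toy data in plain `numpy` (made-up sizes and values), comparing against `numpy.take_along_axis`, which performs the same lookup without building the one-hot tensor. That function is shown only as an equivalent alternative; it is not what the notebook uses.

```python
import numpy

rng = numpy.random.default_rng(0)
batch, seq_len, vocab = 2, 4, 5

# Toy log probabilities: log-softmax of random logits along the vocabulary axis.
logits = rng.normal(size=(batch, seq_len, vocab))
predictions = logits - numpy.log(numpy.sum(numpy.exp(logits), axis=-1, keepdims=True))

# Toy integer targets and their one-hot encoding.
targets = rng.integers(0, vocab, size=(batch, seq_len))
one_hot_targets = (targets[..., None] == numpy.arange(vocab)).astype(predictions.dtype)

# Notebook-style approach: multiply and sum over the vocabulary axis.
log_p = numpy.sum(predictions * one_hot_targets, axis=-1)

# Direct lookup of each target's log probability.
log_p_gather = numpy.take_along_axis(predictions, targets[..., None], axis=-1)[..., 0]

print(log_p.shape)                          # (2, 4)
print(numpy.allclose(log_p, log_p_gather))  # True
```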

Next, we need to account for the padding so the metric is not artificially deflated (a lower perplexity means a better model). To identify which elements are padding, we can use `np.equal()`: `np.equal(targets, 0)` is true at padding positions (padding has token ID 0 here), so subtracting it from `1.0` gives a tensor with `1`s in the positions of actual values and `0`s at the padding.

In [None]:
non_pad = 1.0 - np.equal(targets, 0)
print(f'non_pad has shape: {non_pad.shape}\n')
print(f'non_pad looks like this: \n\n {non_pad}')
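The same masking step on a small toy example in plain `numpy`, assuming, as the notebook's code does, that token ID 0 is padding. The toy targets are made up; the explicit `astype` cast is added here only to make the boolean-to-float conversion obvious.

```python
import numpy

# Toy targets: two sequences padded with token ID 0.
targets = numpy.array([[5, 2, 7, 0, 0],
                       [3, 9, 1, 4, 0]])

# numpy.equal is True at padding; 1.0 minus that gives 1.0 at real tokens, 0.0 at padding.
non_pad = 1.0 - numpy.equal(targets, 0).astype(numpy.float64)
print(non_pad)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]]

# Summing along each row counts the real (non-padding) tokens per sequence.
print(numpy.sum(non_pad, axis=1))  # [3. 4.]
```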

Multiplying the log probabilities by the `non_pad` tensor zeroes them out at the padding positions, which removes the effect of padding on the metric:

In [None]:
real_log_p = log_p * non_pad
print(f'real log probabilities still have shape: {real_log_p.shape}')

We can now check the effect of filtering out the padding by comparing the two log-probability tensors:

In [None]:
print(f'log probabilities before filtering padding: \n\n {log_p}\n')
print(f'log probabilities after filtering padding: \n\n {real_log_p}')

Finally, to get the average log perplexity of the model across all sequences in the batch, we sum the log probabilities in each sequence and divide by the number of non-padding elements, which gives the negative log perplexity of each sequence.

We then negate these values and take their mean across all sequences in the batch.

In [None]:
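# Mean log probability over the real elements of each sequence (= negative log perplexity per sequence)
log_ppx = np.sum(real_log_p, axis=1) / np.sum(non_pad, axis=1)
# Negate and average over the sequences in the batch
log_ppx = np.mean(-log_ppx)
print(f'The log perplexity and perplexity of the model are respectively: {log_ppx} and {np.exp(log_ppx)}')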
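Putting the steps together, the sketch below mirrors the notebook's computation in plain `numpy` so it can run without Trax. `log_perplexity` and its `pad_id` parameter are hypothetical names introduced for illustration. As a check, a model that is uniform over a toy vocabulary of 8 tokens should score perplexity 8 no matter how much padding each sequence has.

```python
import numpy

def log_perplexity(predictions, targets, pad_id=0):
    """Mean per-sequence log perplexity, following the notebook's steps.

    predictions: (batch, seq_len, vocab) log probabilities
    targets:     (batch, seq_len) integer token IDs; pad_id marks padding
    Note: a sequence made entirely of padding would divide by zero.
    """
    vocab = predictions.shape[-1]
    one_hot = (targets[..., None] == numpy.arange(vocab)).astype(predictions.dtype)
    log_p = numpy.sum(predictions * one_hot, axis=-1)                    # (batch, seq_len)
    non_pad = 1.0 - numpy.equal(targets, pad_id).astype(log_p.dtype)    # 1 = real token
    real_log_p = log_p * non_pad                                         # padding zeroed out
    per_seq = numpy.sum(real_log_p, axis=1) / numpy.sum(non_pad, axis=1)  # -log ppx per sequence
    return numpy.mean(-per_seq)

# Toy check: uniform log probabilities over a vocabulary of 8 tokens.
vocab = 8
predictions = numpy.full((2, 6, vocab), numpy.log(1.0 / vocab))
targets = numpy.array([[3, 5, 1, 0, 0, 0],
                       [2, 7, 4, 6, 1, 0]])
log_ppx = log_perplexity(predictions, targets)
print(log_ppx, numpy.exp(log_ppx))  # about 2.079 and 8.0
```

This averages the per-sequence log perplexities, so a short sequence counts as much as a long one. A token-weighted alternative divides the total masked log probability by the total number of real tokens in the batch, `-numpy.sum(real_log_p) / numpy.sum(non_pad)`; the two agree whenever every sequence has the same number of real tokens.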