# Examples

This example covers the basic usage of `jaxnnls` to solve a non-negative least squares (NNLS) problem.

```{warning}
While the algorithm can sometimes work with JAX's default 32-bit precision, enabling 64-bit precision is recommended.  The Cholesky decompositions it uses can become unstable at lower precision and produce `nan` results.
```
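
One way to confirm that 64-bit mode is active is to inspect the default dtype of a new array. The snippet below is a small sketch using only standard `jax.numpy` calls; the flag name is the same one used later in this example.

```python
import jax

# enable 64-bit mode before creating any arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# with x64 enabled, new floating point arrays default to float64
default_dtype = jnp.ones(3).dtype
print(default_dtype)

# machine epsilon shows how much precision each dtype carries
eps32 = float(jnp.finfo(jnp.float32).eps)
eps64 = float(jnp.finfo(jnp.float64).eps)
print(eps32, eps64)
```

If the printed dtype is `float32`, the flag was most likely set after arrays had already been created.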

## Basic usage

To begin we will write a function that randomly generates a non-trivial NNLS system.

In [6]:
import jax

# enable 64 bit mode
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jaxnnls

# adjust the print options to make the output easier to read
jnp.set_printoptions(suppress=True)


def generate_random_qp(key, nx):
    # split the random key
    key_q, key_mask, key_x, key_z = jax.random.split(key, 4)
    # make a positive definite Q matrix
    Q = jax.random.normal(key_q, (nx, nx))
    Q = Q.T @ Q
    # make the primal and dual variables (all positive)
    x = jnp.abs(jax.random.normal(key_x, (nx,)))
    z = jnp.abs(jax.random.normal(key_z, (nx,)))
    # mask out 50% of the values to zero
    mask = jax.random.choice(key_mask, jnp.array([True, False]), (nx,))
    x = jnp.where(mask, x, 0)
    z = jnp.where(mask, 0, z)
    # make the "observed" vector that has x as its NNLS solution
    q = Q @ x - z
    return Q, q, x, z
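
Reading the generator, `q = Q @ x - z` with `x` and `z` both non-negative and never non-zero in the same position. Assuming the problem being solved is to minimize `0.5 * x.T @ Q @ x - q @ x` subject to `x >= 0` (an inference from this construction, not something stated above), these are exactly its optimality (KKT) conditions. The sketch below checks them numerically; `kkt_report` is a hypothetical helper and not part of `jaxnnls`.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def generate_random_qp(key, nx):
    # same generator as above, repeated so this block runs on its own
    key_q, key_mask, key_x, key_z = jax.random.split(key, 4)
    Q = jax.random.normal(key_q, (nx, nx))
    Q = Q.T @ Q
    x = jnp.abs(jax.random.normal(key_x, (nx,)))
    z = jnp.abs(jax.random.normal(key_z, (nx,)))
    mask = jax.random.choice(key_mask, jnp.array([True, False]), (nx,))
    x = jnp.where(mask, x, 0)
    z = jnp.where(mask, 0, z)
    q = Q @ x - z
    return Q, q, x, z


def kkt_report(Q, q, x, z):
    # hypothetical helper: summarises the assumed optimality conditions
    return {
        "stationarity": float(jnp.max(jnp.abs(Q @ x - q - z))),
        "complementarity": float(jnp.max(jnp.abs(x * z))),
        "primal_feasible": bool(jnp.all(x >= 0)),
        "dual_feasible": bool(jnp.all(z >= 0)),
    }


Q, q, x, z = generate_random_qp(jax.random.key(0), 5)
report = kkt_report(Q, q, x, z)
print(report)
```

The stationarity residual should sit at round-off level and the complementarity term should be exactly zero, since the mask sets one of `x[i]` or `z[i]` to zero in every position.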

Now let's make a JAX random key and generate an example system with a 5x5 `Q` matrix.

In [2]:
_key = jax.random.key(0)
key, _key = jax.random.split(_key)

Q, q, x, z = generate_random_qp(key, 5)

Next let's find the unconstrained solution using `jnp.linalg.solve`.

In [64]:
print(jnp.linalg.solve(Q, q))

[-0.10993943  0.01985061  1.00367504  0.12005477 -0.06829775]


The unconstrained solution clearly contains negative values.  Now let's take the same system but use the NNLS solver.  If you are only interested in the primal solution (i.e. `x`), you can use `jaxnnls.solve_nnls_primal`.  If you want both the primal and dual solutions (along with some extra diagnostic information), you should use `jaxnnls.solve_nnls`.

In [80]:
jit_solve_nnls_primal = jax.jit(jaxnnls.solve_nnls_primal)
x_solve = jit_solve_nnls_primal(Q, q)
print(x_solve)

[0.         0.         0.97650555 0.10741558 0.        ]


Now we can see that the solution found is non-negative, as desired.  We can also check it against the known solution.

In [81]:
print(jnp.allclose(x, x_solve))

True
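
For small systems, a result can also be cross-checked without knowing the answer in advance. The sketch below is a brute-force reference that tries every possible set of non-zero entries and returns the first candidate that satisfies the optimality conditions. It assumes the problem is to minimize `0.5 * x.T @ Q @ x - q @ x` subject to `x >= 0` (inferred from the generator above). `nnls_brute_force` is a hypothetical helper written with NumPy; it is not how `jaxnnls` works, and its cost grows as `2**n`.

```python
import itertools

import numpy as np


def nnls_brute_force(Q, q, tol=1e-8):
    # hypothetical reference solver: enumerate every candidate support
    n = q.shape[0]
    for size in range(n + 1):
        for support in itertools.combinations(range(n), size):
            idx = list(support)
            x = np.zeros(n)
            if idx:
                # solve the unconstrained problem restricted to the support
                x[idx] = np.linalg.solve(Q[np.ix_(idx, idx)], q[idx])
            # dual variable implied by stationarity
            z = Q @ x - q
            if np.all(x >= -tol) and np.all(z >= -tol):
                return np.clip(x, 0.0, None)
    raise ValueError("no support satisfied the optimality conditions")


# small hand-built system with known solution x = [1, 0, 2], z = [0, 3, 0]
Q_small = np.array([[2.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 2.0]])
x_known = np.array([1.0, 0.0, 2.0])
z_known = np.array([0.0, 3.0, 0.0])
q_small = Q_small @ x_known - z_known

x_ref = nnls_brute_force(Q_small, q_small)
x_unconstrained = np.linalg.solve(Q_small, q_small)
print(x_ref, x_unconstrained)
```

In this hand-built system the unconstrained solve has a negative middle entry, while the brute-force search recovers the known non-negative solution. JAX arrays can be passed to the helper after conversion with `np.asarray`.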


## Solving a batch of problems

The solver is fully compatible with `vmap` for solving a batch of problems at the same time.  First we will generate a set of random problems with known solutions.

In [15]:
key, _key = jax.random.split(_key)

Qs, qs, xs, zs = jax.vmap(generate_random_qp, in_axes=(0, None))(jax.random.split(key, 20), 5)
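
Here `in_axes=(0, None)` maps over the leading axis of the stacked keys while passing the same `nx` to every call, so each output gains a leading batch axis. A minimal sketch of the same pattern with a simpler stand-in function (`make_system` is hypothetical and only used to show the resulting shapes):

```python
import jax
import jax.numpy as jnp


def make_system(key, nx):
    # hypothetical stand-in for the generator: one matrix and one vector
    key_q, key_v = jax.random.split(key)
    A = jax.random.normal(key_q, (nx, nx))
    return A.T @ A, jax.random.normal(key_v, (nx,))


keys = jax.random.split(jax.random.key(1), 20)
# map over the keys (axis 0) and hold nx fixed (None)
Qs_demo, qs_demo = jax.vmap(make_system, in_axes=(0, None))(keys, 5)
print(Qs_demo.shape, qs_demo.shape)
```

The matrices stack to shape `(20, 5, 5)` and the vectors to `(20, 5)`, which is the layout a solver mapped with `in_axes=(0, 0)` expects.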

Now we will `jit` and `vmap` the solver and apply it to our set of problems.

In [67]:
batch_nnls = jax.jit(jax.vmap(jaxnnls.solve_nnls_primal, in_axes=(0, 0)))
batch_xs = batch_nnls(Qs, qs)
print(batch_xs)

[[0.         0.         0.38382659 0.15292546 0.        ]
 [1.92070267 1.5044806  0.72937339 0.1729202  0.10214948]
 [0.         0.         0.         0.25669574 1.20627346]
 [0.17402674 0.         0.20067439 0.         1.98034549]
 [0.         0.         0.         0.         0.        ]
 [0.38442547 0.         0.73927296 0.         0.14899361]
 [0.         0.         1.09022179 0.28157659 0.        ]
 [0.53804846 0.55245586 0.5648163  0.41864657 1.42616834]
 [0.2257546  0.         0.         0.24268287 1.05370187]
 [1.19462445 0.         0.         0.         1.31461407]
 [0.05825319 0.         0.         0.025179   0.        ]
 [0.99083519 0.21995431 0.21457618 0.         1.63248801]
 [2.42041109 0.         0.78802546 0.         0.        ]
 [0.97545337 0.78935263 1.41030076 0.9074219  0.        ]
 [1.1603355  0.         0.75510333 0.         0.56652302]
 [0.         0.         0.54660379 0.         0.        ]
 [0.         1.52725322 0.         0.         0.09822676]
 [0.92809165 0

We see that all the solutions are indeed non-negative, as expected.  Now let's check if they match the known solutions.

In [68]:
print(jnp.allclose(xs, batch_xs))

True


## Differentiating an NNLS problem

If we are only looking at the primal solution with `jaxnnls.solve_nnls_primal` we can use automatic differentiation.  For this example we will set up a simple loss function and calculate the gradients of that loss with respect to both `Q` and `q`.

In [82]:
def loss(Q, q, target_kappa=0):
    x = jaxnnls.solve_nnls_primal(Q, q, target_kappa=target_kappa)
    x_bar = jnp.ones_like(x)
    residual = x - x_bar
    return jnp.dot(residual, residual)


loss_and_grad = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))

l, (dl_dQ, dl_dq) = loss_and_grad(Q, q)

In [69]:
print(l)

3.797258930506071


In [70]:
print(dl_dQ)

[[0.00001912 0.00020742 0.00406618 0.00102431 0.00027023]
 [0.00020742 0.00110802 0.07179544 0.01015142 0.00138725]
 [0.00406618 0.07179544 0.19343259 0.24109751 0.09490007]
 [0.00102431 0.01015142 0.24109751 0.05406517 0.01317802]
 [0.00027023 0.00138725 0.09490007 0.01317802 0.00173121]]


In [71]:
print(dl_dq)

[-0.00780425 -0.14497006 -0.19736384 -0.46876957 -0.19184031]


In the above example we set `target_kappa=0`.  This means no smoothing will be applied to the gradients.  In general, when dealing with constrained solvers like this, the gradients can be discontinuous.  In this example we see that we only have non-zero gradient values for the elements of `x` that are non-zero when solved.

If we were aiming to minimize our loss using a gradient descent method, we would only be able to move a subset of our parameters at a time because of this.  By increasing the `target_kappa` value these discontinuities will be smoothed out, providing more useful information for gradient-based optimizers.
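
The effect can be seen in a one-dimensional toy problem. With a scalar `Q > 0`, the constrained solution is `max(q / Q, 0)`, whose derivative with respect to `q` is exactly zero whenever the constraint is active. Relaxing the complementarity condition from `x * z = 0` to `x * z = kappa`, with `z = Q * x - q`, gives a smooth solution instead. This is only an illustration of the idea under that assumption, not the `jaxnnls` implementation; `x_hard` and `x_relaxed` are hypothetical names.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def x_hard(q, Q=1.0):
    # exact 1-D solution: clip the unconstrained minimiser at zero
    return jnp.maximum(q / Q, 0.0)


def x_relaxed(q, Q=1.0, kappa=1e-3):
    # positive root of Q x^2 - q x - kappa = 0, i.e. x * (Q x - q) = kappa
    return (q + jnp.sqrt(q**2 + 4.0 * Q * kappa)) / (2.0 * Q)


q_active = -1.0  # the constraint x >= 0 is active for this q
grad_hard = jax.grad(x_hard)(q_active)
grad_relaxed = jax.grad(x_relaxed)(q_active)
print(grad_hard, grad_relaxed)
```

The hard solution gives a zero gradient at this point, while the relaxed one gives a small positive gradient, so an optimizer would still receive a signal for a parameter pinned at the constraint.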

In [52]:
l_kappa, (dl_dQ_kappa, dl_dq_kappa) = loss_and_grad(Q, q, target_kappa=1e-3)

In [72]:
print(l_kappa)

3.797258930506071


In [73]:
print(dl_dQ_kappa)

[[0.00001912 0.00020742 0.00406618 0.00102431 0.00027023]
 [0.00020742 0.00110802 0.07179544 0.01015142 0.00138725]
 [0.00406618 0.07179544 0.19343259 0.24109751 0.09490007]
 [0.00102431 0.01015142 0.24109751 0.05406517 0.01317802]
 [0.00027023 0.00138725 0.09490007 0.01317802 0.00173121]]


In [74]:
print(dl_dq_kappa)

[-0.00780425 -0.14497006 -0.19736384 -0.46876957 -0.19184031]


We can see that the loss value has not changed, as the smoothing is only applied to the gradients.  As for the two gradients, all the values have become non-zero.  Now if gradient descent were applied, **all** the values would move rather than just a subset of them.

For more information about the smoothing process please refer to the [qpax paper](https://arxiv.org/abs/2406.11749).

## Diagnostic information

In all the examples above we used `jaxnnls.solve_nnls_primal` as we were only interested in the primal solution.  If you want the dual solution or more diagnostic information, the `jaxnnls.solve_nnls` function is available.

In [58]:
x, s, z, converged, number_iterations = jaxnnls.solve_nnls(Q, q)

The outputs are:
- `x`: the primal solution
- `s`: the slack variable (will be the same as `x` if the algorithm converged)
- `z`: the dual solution
- `converged`: flag that is `1` if the algorithm converged and `0` otherwise
- `number_iterations`: the number of steps the algorithm took to converge

```{note}
The code will run a maximum of 50 steps before stopping and reporting it did not converge.
```

```{note}
Automatic differentiation is only available for `jaxnnls.solve_nnls_primal`, not for `jaxnnls.solve_nnls`.
```

In [75]:
print(x)

[0.         0.         0.97650555 0.10741558 0.        ]


In [76]:
print(s)

[0.         0.         0.97650555 0.10741558 0.        ]


In [77]:
print(z)

[0.38611367 0.11502944 0.         0.         0.0997703 ]


In [78]:
print(converged)

1


In [79]:
print(number_iterations)

10
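
The printed values can be checked by hand for consistency. The sketch below copies the `x`, `s`, and `z` arrays from the outputs above and confirms that the slack matches the primal solution and that `x` and `z` are never non-zero in the same position (complementarity).

```python
import jax.numpy as jnp

# values copied from the printed outputs above
x_out = jnp.array([0.0, 0.0, 0.97650555, 0.10741558, 0.0])
s_out = jnp.array([0.0, 0.0, 0.97650555, 0.10741558, 0.0])
z_out = jnp.array([0.38611367, 0.11502944, 0.0, 0.0, 0.0997703])

# the slack should equal the primal solution once converged
slack_matches = bool(jnp.allclose(x_out, s_out))
# each position has either x or z equal to zero
complementary = bool(jnp.all(x_out * z_out == 0.0))
print(slack_matches, complementary)
```

Both checks pass for the values shown, which is consistent with the `converged` flag being `1`.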
