# compute

## Introduction

In DESC, quantities of interest (such as the force error `F`, the magnetic field strength `|B|`, and the magnetic field's vector components) are computed from `Equilibrium` objects using `compute_XXX` functions. The `Equilibrium` object has a method `eq.compute(name="NAME OF QUANTITY TO COMPUTE", grid=Grid)`, which takes the string `name` of the quantity to compute and a `Grid` of coordinates $(\rho,\theta,\zeta)$ at which to evaluate it. The method uses `name` to look up which `compute_XXX` function must be called to calculate that quantity.

These `compute_XXX` functions live in the `desc/compute` folder. In that folder, the `data_index.py` file lists every quantity that `name` can specify, along with information about each quantity and the `compute_XXX` function that should be called to compute it.
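
The name-based lookup described above can be illustrated with a toy registry. This is only a sketch of the general pattern: `toy_data_index`, `toy_compute`, and the two toy compute functions are hypothetical names invented for illustration, and the layout of DESC's actual `data_index.py` may differ.

```python
# Toy illustration of dispatching on a quantity name.
# Every name below is hypothetical and NOT part of DESC.


def compute_double(params, data=None):
    """Toy compute function: stores 2*x under the key 'double'."""
    data = {} if data is None else data
    data["double"] = 2 * params["x"]
    return data


def compute_square(params, data=None):
    """Toy compute function: stores x**2 under the key 'square'."""
    data = {} if data is None else data
    data["square"] = params["x"] ** 2
    return data


# Hypothetical registry: quantity name -> description and computing function.
toy_data_index = {
    "double": {"description": "twice the input", "fun": compute_double},
    "square": {"description": "input squared", "fun": compute_square},
}


def toy_compute(name, params):
    """Look up the function registered for `name`, call it, return the result."""
    entry = toy_data_index[name]
    data = entry["fun"](params)
    return data[name]


result = toy_compute("square", {"x": 3})
```

Keeping the registry in one place means a caller only needs to know the name of a quantity, not which function produces it.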

Each `compute_XXX` function takes a dict argument named `params`, which contains the variables from `{R_lmn, Z_lmn, L_lmn, i_l, c_l, p_l, Psi}` required to calculate the quantity. It also takes the `Transform` objects (in the `transforms` dict argument) and/or `Profile` objects (in the `profiles` dict argument) needed to evaluate the variables in `params` (which are spectral coefficients, except for `Psi`) at the points in real space specified by the desired `Grid` object (which is contained in the `Transform` objects passed).


An example compute function from `_field.py` is shown here for computing the magnetic field magnitude:
```python
def compute_magnetic_field_magnitude(
    params,
    transforms,
    profiles,
    data=None,
    **kwargs,
):
```
Every compute function in DESC has the same function signature:

 - `params` is a dict of the basic parameters needed to compute data, i.e. `{R_lmn, Z_lmn, L_lmn, i_l, c_l, p_l, Psi}`.
   - The possible params are:
     - `R_lmn, Z_lmn, L_lmn`: the Fourier-Zernike spectral coefficients describing the toroidal coordinates R and Z of the flux surfaces and the poloidal stream function $\lambda$ (`L_lmn`).
     - `i_l, c_l, p_l`: the parameters (either spectral coefficients for a `PowerSeriesProfile` or spline values for a `SplineProfile`) for the profiles of rotational transform (`i_l`), net enclosed toroidal current (`c_l`) and pressure (`p_l`). Note that only one of `i_l, c_l` is needed; if both are passed, the rotational transform `i_l` will be used.
     - `Psi`: the total enclosed toroidal flux, a scalar, in Wb.
 - `transforms` is a dict of the transforms (`Transform` objects) needed to transform the spectral coefficients from `params` to their values in real space.
 - `profiles` is a dict of the profiles (`Profile` objects) needed to evaluate the radial profiles of pressure, rotational transform and net enclosed toroidal current.
 - `data` is an optional dictionary. The compute functions store the quantities they calculate in a `data` dictionary and return it. If `data` is passed in, the quantities this function computes are added to that dictionary. This way, a compute function can add to the quantities already calculated by previous compute functions, or it can call other compute functions to calculate preliminary quantities and avoid duplicating code.

Within `params`, `R_lmn`, `Z_lmn` and `L_lmn` are the Fourier-Zernike spectral coefficients of the toroidal coordinates of the flux surfaces `R`, `Z` and of the poloidal stream function $\lambda$. `i_l` holds the coefficients of the rotational transform profile (typically a power series in `rho`), and `Psi` is the enclosed toroidal flux.

The entries of `transforms` are `Transform` objects, which transform the spectral coefficients to their values in real space. Similarly, the rotational transform entry of `profiles` is a `Profile` object, which evaluates `i_l` in real space.

As an example, in the `compute_magnetic_field_magnitude` function, the contravariant components of the magnetic field $B^i$ are required, along with the metric coefficients $g_{ij}$. To calculate these, the function calls:

```python
    data = compute_contravariant_magnetic_field(
        params,
        transforms,
        profiles,
        data=data,
        **kwargs,
    )
    data = compute_covariant_metric_coefficients(
        params,
        transforms,
        profiles,
        data=data,
        **kwargs,
    )
```

in order to populate `data` with these necessary preliminary quantities.


- maybe include example of how to make your own (let's say for a simple thing like B_theta * B_zeta)
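
A minimal sketch of what a user-written compute function for $B_\theta B_\zeta$ might look like, following the signature and `data` pattern described above. The function `compute_covariant_magnetic_field_stub` is a hypothetical stand-in that fills in made-up numbers so the example runs without DESC. In real code it would be replaced by whichever DESC compute function provides the covariant field components. The keys `"B_theta"`, `"B_zeta"` and `"B_theta*B_zeta"` are likewise assumptions chosen for illustration.

```python
import numpy as np


def compute_covariant_magnetic_field_stub(
    params, transforms, profiles, data=None, **kwargs
):
    """Hypothetical stand-in for a DESC compute function (made-up values)."""
    data = {} if data is None else data
    data["B_theta"] = np.array([1.0, 2.0, 3.0])  # shape (num_nodes,)
    data["B_zeta"] = np.array([4.0, 5.0, 6.0])  # shape (num_nodes,)
    return data


def compute_B_theta_times_B_zeta(
    params, transforms, profiles, data=None, **kwargs
):
    """Compute the product B_theta * B_zeta at every node."""
    # First populate `data` with the preliminary quantities we need.
    data = compute_covariant_magnetic_field_stub(
        params, transforms, profiles, data=data, **kwargs
    )
    # Both factors are scalars per node, so plain elementwise multiplication works.
    data["B_theta*B_zeta"] = data["B_theta"] * data["B_zeta"]
    return data


data = compute_B_theta_times_B_zeta(params={}, transforms={}, profiles={})
```

Because the function returns the same `data` dictionary it was given, the preliminary quantities remain available to any compute function called afterwards.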





## Calculating Quantities

Inside the compute function, every quantity is stored in the `data` dictionary under a key equal to the name of the quantity.
As an example, `data['|B|']` contains the magnetic field magnitude.
Each quantity is stored as a JAX array of shape `(num_nodes,)` if it is a scalar, or `(num_nodes,3)` if it is a vector (e.g. `data['B']`, which contains all three components $[B_R, B_{\phi}, B_{Z}]$ of $\mathbf{B}$ at each node). Here `num_nodes` is the number of nodes in the `Grid` object on which the quantity was computed, accessible as `grid.num_nodes`.
The grid dimensions are flattened: if the grid has 10 equispaced gridpoints in each of $(\rho,\theta,\zeta)$, it has $10^3 = 1000$ nodes, and any quantity calculated on it is returned as an array of shape `(1000,)` if it is a scalar and `(1000,3)` if it is a vector.
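
The flattened layout can be illustrated with plain NumPy standing in for JAX arrays. This sketch builds its own tensor-product grid instead of using a DESC `Grid`, and the node ordering it produces is an assumption that need not match DESC's.

```python
import numpy as np

n = 10  # points per coordinate
rho = np.linspace(0.1, 1.0, n)
theta = np.linspace(0.0, 2 * np.pi, n, endpoint=False)
zeta = np.linspace(0.0, 2 * np.pi, n, endpoint=False)

# Tensor-product grid, flattened to one row of (rho, theta, zeta) per node.
mesh = np.meshgrid(rho, theta, zeta, indexing="ij")
nodes = np.stack(mesh, axis=-1).reshape(-1, 3)
num_nodes = nodes.shape[0]

# A scalar quantity has one value per node: shape (num_nodes,).
scalar_quantity = nodes[:, 0] ** 2

# A vector quantity has three components per node: shape (num_nodes, 3).
vector_quantity = np.stack(
    [np.cos(nodes[:, 1]), np.sin(nodes[:, 1]), nodes[:, 2]], axis=-1
)
```
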
### Scalar Algebra
Storing the quantities in arrays like this makes it easy to compute with them. For example, to calculate the magnitude of the pressure gradient $|\nabla p(\rho)| = \sqrt{p'(\rho)^2}|\nabla\rho|$, one simply [writes out the expression](https://github.com/PlasmaControl/DESC/blob/6d03cb015701b27d651bf804d36032c35119c536/desc/compute/_equil.py#L114) after calling the necessary compute functions:

```python
data["|grad(p)|"] = jnp.sqrt(data["p_r"] ** 2) * data["|grad(rho)|"]
```

### Vector Algebra
If calculating a quantity which involves vector algebra, the format of these arrays makes it simple to write out as well. As an example, if calculating the contravariant radial basis vector $\mathbf{e}^{\rho} = \frac{\mathbf{e}^{\theta} \times \mathbf{e}^{\zeta}}{\sqrt{g}}$, one [writes](https://github.com/PlasmaControl/DESC/blob/6d03cb015701b27d651bf804d36032c35119c536/desc/compute/_core.py#L426):
```python
data["e^rho"] = (cross(data["e_theta"], data["e_zeta"]).T / data["sqrt(g)"]).T
```
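Note that the result of the cross product, which has shape `(num_nodes,3)`, is transposed to shape `(3,num_nodes)` before being divided by `data["sqrt(g)"]`, which has shape `(num_nodes,)`, so that the division broadcasts across the three components. The quotient is then transposed back so that the result retains the desired shape of `(num_nodes,3)`.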
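
The shapes at each step can be checked with a small NumPy sketch. Random arrays stand in for the basis vectors and Jacobian, and `np.cross` stands in for DESC's `cross` helper. The version with `np.newaxis` is shown only as an equivalent way of writing the same broadcast, not as the form used in DESC.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes = 5

# Random stand-ins for the real quantities.
e_theta = rng.normal(size=(num_nodes, 3))
e_zeta = rng.normal(size=(num_nodes, 3))
sqrt_g = rng.uniform(1.0, 2.0, size=num_nodes)

crossed = np.cross(e_theta, e_zeta)  # shape (num_nodes, 3)

# Transpose to (3, num_nodes), divide by (num_nodes,), transpose back.
e_rho_transposed = (crossed.T / sqrt_g).T

# Equivalent: give sqrt_g a trailing axis so it has shape (num_nodes, 1).
e_rho_newaxis = crossed / sqrt_g[:, np.newaxis]
```

Both forms divide every component of the vector at a node by the value of `sqrt_g` at that same node.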

### Be Mindful of Shapes

It is important to keep in mind the shapes of the quantities being manipulated to ensure the desired operation is carried out. As another example, the gradient of the magnetic toroidal flux $\nabla \psi = \frac{d\psi}{d\rho}\nabla \rho$ is [calculated as](https://github.com/PlasmaControl/DESC/blob/94d7e43542613b1c901fcd655502312f3e567c26/desc/compute/_core.py#L701):
```python
data["grad(psi)"] = (data["psi_r"] * data["e^rho"].T).T
```
The desired operation here is to multiply `data["psi_r"]`, which is a scalar quantity at each grid point and so has shape `(num_nodes,)`, by `data["e^rho"]`, which is a vector quantity and so has shape `(num_nodes,3)`.
We want the result to have shape `(num_nodes,3)`.
To do so, we must first transpose the vector quantity to shape `(3,num_nodes)`, so that when it is multiplied by the scalar quantity of shape `(num_nodes,)`, the result is broadcast to an array of shape `(3,num_nodes)`.
Without the transpose, the two shapes `(num_nodes,)` and `(num_nodes,3)` would be incompatible with each other.
The second transpose, after the multiplication, ensures that the result has shape `(num_nodes,3)`, as is the convention expected in the code.
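
The failure mode can be reproduced directly in NumPy, whose broadcasting rules JAX follows. The arrays here are simple stand-ins for `data["psi_r"]` and `data["e^rho"]`.

```python
import numpy as np

num_nodes = 4
psi_r = np.arange(1.0, num_nodes + 1)  # shape (num_nodes,)
e_rho = np.ones((num_nodes, 3))  # shape (num_nodes, 3)

# Broadcasting aligns trailing axes, so (4,) against (4, 3) compares 4 with 3.
try:
    psi_r * e_rho
    direct_product_failed = False
except ValueError:
    direct_product_failed = True

# With the transposes, (4,) against (3, 4) lines up along the node axis.
grad_psi = (psi_r * e_rho.T).T  # shape (num_nodes, 3)
```

One caveat: if `num_nodes` happened to equal 3, the direct product would not raise an error but would scale the wrong axis, which is another reason to write the transposes explicitly.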

## What `check_derivs()` does

This function checks that the transforms passed to the compute function contain the derivatives of $R,Z,L,p,\iota$ needed to calculate the quantity guarded by the `if` statement in which it is called.
If they do, it returns `True` and the quantity inside the conditional is computed. If not, it returns `False` and that quantity is not calculated.
This allows a function to be called for a specific quantity that does not need high-order derivatives without having to compute those derivatives anyway, just to avoid an error about missing derivatives for another quantity that was not requested but needs higher-order derivatives.
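
The guarding pattern can be sketched with a toy version. The names `toy_check_derivs`, `required_derivs` and `available` are hypothetical, and the real `check_derivs()` signature and bookkeeping may differ. The sketch only shows how a boolean check lets a compute function skip quantities whose derivatives are unavailable.

```python
# Hypothetical table: derivative orders (d_rho, d_theta, d_zeta) each quantity needs.
required_derivs = {
    "R": [(0, 0, 0)],
    "R_r": [(1, 0, 0)],
    "R_rr": [(2, 0, 0)],
}


def toy_check_derivs(name, available):
    """Return True if every derivative order needed for `name` is available."""
    return all(order in available for order in required_derivs[name])


# Suppose the transform was built with only zeroth and first rho derivatives.
available = {(0, 0, 0), (1, 0, 0)}

data = {}
if toy_check_derivs("R_r", available):
    data["R_r"] = "computed"  # placeholder for the real calculation
if toy_check_derivs("R_rr", available):
    data["R_rr"] = "computed"  # skipped: second derivative is not available
```
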

## `__init__.py`

`arg_order` is defined in this file. This tuple is used in other parts of the code to determine how to parse the state vector `x` into the various arguments that make it up, and also to construct the derivatives of functions of these arguments, such as inside the `_set_derivatives` method of `_Objective` in `objective_funs.py`.

`arg_order` exists because a canonical ordering of the arguments is needed: when all of the arguments are combined into `x` and all of the constraints into `A`, everything must line up correctly. It is also used in some places as shorthand for all of the arguments that could be used by any objective. In those cases it is likely that only the arguments taken by the objectives at hand are needed, so those could be used instead.
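
A minimal sketch of why a fixed ordering matters when packing arguments into a single state vector and splitting it back apart. The tuple `toy_arg_order`, the `dimensions` dict and the helper functions are illustrative assumptions, not DESC's actual `arg_order` or its parsing code.

```python
import numpy as np

# Illustrative ordering only; DESC's actual arg_order may contain other entries.
toy_arg_order = ("R_lmn", "Z_lmn", "L_lmn", "Psi")
dimensions = {"R_lmn": 3, "Z_lmn": 3, "L_lmn": 2, "Psi": 1}


def pack(args):
    """Concatenate the arguments into one state vector, in canonical order."""
    return np.concatenate([np.atleast_1d(args[name]) for name in toy_arg_order])


def unpack(x):
    """Split a state vector back into named arguments, in the same order."""
    sizes = [dimensions[name] for name in toy_arg_order]
    split_points = np.cumsum(sizes)[:-1]
    return dict(zip(toy_arg_order, np.split(x, split_points)))


args = {
    "R_lmn": np.array([1.0, 0.1, 0.01]),
    "Z_lmn": np.array([0.0, 0.2, 0.02]),
    "L_lmn": np.array([0.3, 0.03]),
    "Psi": 1.0,
}
x = pack(args)
recovered = unpack(x)
```

Since `pack` and `unpack` both iterate over the same tuple, any column of a constraint matrix built in that order refers to the same entry of `x`.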

## `compute/utils.py`

This file contains custom vector algebra functions:

 - `dot`
 - `cross`
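
As a rough idea of what such helpers do for arrays of shape `(num_nodes,3)`, here is a NumPy sketch. The names `toy_dot` and `toy_cross` and their signatures are assumptions made for illustration. They are not the implementations in `compute/utils.py`.

```python
import numpy as np


def toy_dot(a, b, axis=-1):
    """Row-wise dot product: (num_nodes, 3) x (num_nodes, 3) -> (num_nodes,)."""
    return np.sum(a * b, axis=axis)


def toy_cross(a, b, axis=-1):
    """Row-wise cross product: (num_nodes, 3) x (num_nodes, 3) -> (num_nodes, 3)."""
    return np.cross(a, b, axis=axis)


# Two nodes, each with unit vectors along x and y.
a = np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
b = np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]])

a_dot_b = toy_dot(a, b)  # zero at every node, since x and y are orthogonal
a_cross_b = toy_cross(a, b)  # unit vector along z at every node
```
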