(linalg_tutorial)=
# Intro to the linear algebra module
Most functions in the linear algebra module are thin wrappers with an API nearly identical to their numpy counterparts. In general, all you need to do is pass the input DataArray and indicate which dimensions
correspond to the matrices. There are only a couple of exceptions, which have their own sections.

In [1]:
from xarray_einstats import linalg, tutorial

We start by generating synthetic data to work with:

In [2]:
da = tutorial.generate_matrices_dataarray(7)
da

The data represents a collection of matrices: `dim` and `dim2` are the matrix dimensions, and the whole array is 4D, holding 30 matrices in total from 10 batches and 3 experiments.

(linalg_tutorial/general)=
## General linalg functions
You can get the trace of all 30 matrices in a single line; you only need the input DataArray and the dimensions corresponding to the matrices:

In [3]:
linalg.trace(da, dims=["dim", "dim2"])

The main feature of the wrappers is that they know the expected shape of the output, so you don't need to take care of it. See how the inverse, which doesn't reduce the matrix dimensions, can be called with the exact same arguments:

In [4]:
linalg.inv(da, dims=["dim", "dim2"])
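For intuition, here is a plain numpy sketch of the same two operations. It assumes the matrix dimensions are the last two axes of a stacked array (numpy has no dimension names, so axis position stands in for `dims`), and the sizes are chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
# 10 batches x 3 experiments of 4x4 matrices; sizes chosen only for illustration
stack = rng.normal(size=(10, 3, 4, 4))

# trace reduces both matrix axes: one value per matrix
traces = np.trace(stack, axis1=-2, axis2=-1)
print(traces.shape)  # (10, 3)

# inv keeps both matrix axes: one 4x4 inverse per matrix
inverses = np.linalg.inv(stack)
print(inverses.shape)  # (10, 3, 4, 4)

# each matrix times its inverse gives the identity
assert np.allclose(stack @ inverses, np.eye(4))
```

Keeping track of which axes disappear and which survive is exactly the bookkeeping the wrappers do for you with named dimensions.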

Even a QR decomposition, which returns multiple matrices (that can even have different shapes), needs only these two arguments to work. Note that batched QR decomposition requires numpy>=1.22.

In [5]:
q, r = linalg.qr(da, dims=["dim", "dim2"])

In [6]:
q

In [7]:
r
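To see how the outputs of a QR decomposition can have different shapes, here is a numpy sketch on a stack of non-square 5x3 matrices. It uses numpy's default `mode="reduced"`; the sizes are illustrative and this says nothing about which mode the wrapper uses:

```python
import numpy as np

rng = np.random.default_rng(1)
# 10 x 3 stack of non-square 5x3 matrices (sizes chosen for illustration)
tall = rng.normal(size=(10, 3, 5, 3))

# batched QR needs numpy>=1.22; the default mode="reduced" is used here
q, r = np.linalg.qr(tall)
print(q.shape, r.shape)  # (10, 3, 5, 3) (10, 3, 3, 3)

# q has orthonormal columns and r is upper triangular
assert np.allclose(np.swapaxes(q, -1, -2) @ q, np.eye(3))
assert np.allclose(np.triu(r), r)
# together they reconstruct every input matrix
assert np.allclose(q @ r, tall)
```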

:::{tip}
Do you always follow the same convention to name your matrix dimensions and feel that even having to repeat them is
unnecessary? Take a look at {func}`xarray_einstats.linalg.get_default_dims` to see how to modify the default dims used by the linalg wrappers.
:::

(linalg_tutorial/matmul)=
## matmul: 1st exception
The general representation of a matrix multiplication is:

$$
\mathcal{M}_1^{N\times K} * \mathcal{M}_2^{K\times M} = \mathcal{M}^{N\times M}
$$ (eq:matmul)

Conceptually, 3 dimensions are involved in the operation, because the 2nd dimension of $\mathcal{M}_1$
must match the 1st dimension of $\mathcal{M}_2$. Moreover, when working with square matrices, $N=M=K$
and there is only 1 dimension.
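The contraction in {eq}`eq:matmul` can be checked with a small numpy sketch on non-square matrices (the sizes $N=2$, $K=3$, $M=5$ are arbitrary): only the shared index $K$ has to match, and it is the one that disappears.

```python
import numpy as np

rng = np.random.default_rng(2)
N, K, M = 2, 3, 5
m1 = rng.normal(size=(N, K))
m2 = rng.normal(size=(K, M))

# "nk,km->nm": the shared index k is summed over, n and m survive
product = np.einsum("nk,km->nm", m1, m2)
print(product.shape)  # (2, 5)
assert np.allclose(product, m1 @ m2)
```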

When working with xarray, however, dimension names can't be repeated, so, as we have already seen, conceptually equivalent dimensions may have different names, e.g. `dim` and `dim2`.

Taking all of this into account, `matmul`'s `dims` argument supports indicating the dimensions in 3 different ways. The following table summarizes the inputs `dims` accepts and how they are interpreted:

| `dims` |  dim_a1 | dim_a2 | dim_b1 | dim_b2 |
|--------|---------|--------|--------|--------|
| `[dim1, dim2]`| dim1 | dim2 | dim1 | dim2 |
| `[dim1, dim2, dim3]` | dim1 | dim2 | dim2 | dim3 |
| `[[dim_a1, dim_a2], [dim_b1, dim_b2]]` | dim_a1 | dim_a2 | dim_b1 | dim_b2 |

where `dim_a1, dim_a2` are the matrix dimensions of the 1st matrix, and `dim_b1, dim_b2` are the matrix dimensions
of the 2nd matrix. Like in {eq}`eq:matmul`, **the dimensions present in the output are `dim_a1, dim_b2`.**
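The table can be read as a small normalization step. The sketch below encodes it in plain Python; `expand_matmul_dims` and `output_dims` are hypothetical helpers written for this page, not functions from `xarray_einstats`:

```python
def expand_matmul_dims(dims):
    """Hypothetical helper: map a matmul ``dims`` spec to
    (dim_a1, dim_a2, dim_b1, dim_b2) following the table above."""
    if len(dims) == 2 and all(isinstance(d, str) for d in dims):
        # [dim1, dim2]: both matrices use the same two names
        dim1, dim2 = dims
        return dim1, dim2, dim1, dim2
    if len(dims) == 3:
        # [dim1, dim2, dim3]: dim2 is the contracted dimension shared by both
        dim1, dim2, dim3 = dims
        return dim1, dim2, dim2, dim3
    if len(dims) == 2 and all(len(d) == 2 for d in dims):
        # [[dim_a1, dim_a2], [dim_b1, dim_b2]]: fully explicit
        (dim_a1, dim_a2), (dim_b1, dim_b2) = dims
        return dim_a1, dim_a2, dim_b1, dim_b2
    raise ValueError(f"unsupported dims specification: {dims!r}")


def output_dims(dims):
    # only dim_a1 and dim_b2 survive; dim_a2 and dim_b1 are contracted
    dim_a1, _, _, dim_b2 = expand_matmul_dims(dims)
    return dim_a1, dim_b2


print(expand_matmul_dims(["batch", "experiment", "dim2"]))
# ('batch', 'experiment', 'experiment', 'dim2')
```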

### List of two elements

This first example uses square matrices, so when doing a matrix multiplication, the two dimensions are shared by both inputs. You only need a list with two strings to indicate how to perform the multiplication:

In [8]:
linalg.matmul(da, da, dims=["dim", "dim2"])
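A numpy analogue of this call, assuming the axis order `(batch, experiment, dim, dim2)` with 4x4 matrices (illustrative sizes): `dim2` of the 1st matrix is contracted with `dim` of the 2nd, and the multiplication is repeated for every `(batch, experiment)` pair, as the explicit loop confirms.

```python
import numpy as np

rng = np.random.default_rng(3)
# axes: (batch, experiment, dim, dim2) with square 4x4 matrices
a = rng.normal(size=(10, 3, 4, 4))
b = rng.normal(size=(10, 3, 4, 4))

# contract dim2 of the 1st matrix with dim of the 2nd,
# looping over the shared batch and experiment axes
batched = a @ b

# the same thing, one matrix product at a time
looped = np.empty_like(batched)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        looped[i, j] = a[i, j] @ b[i, j]

assert np.allclose(batched, looped)
print(batched.shape)  # (10, 3, 4, 4)
```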

### List of three elements

However, the input matrices for matrix multiplication might not be square or might not have the exact same dimension names. As we have seen, what is necessary is for the 2nd dimension of the 1st matrix to match the 1st dimension of the 2nd matrix. This 3-element list of dimensions is arguably the most common way to specify matrix multiplications.

You could interpret the DataArray as a collection of matrices with dimensions `batch, experiment`, or as one with `experiment, dim2` as the matrix dimensions. These two collections of matrices are valid inputs for matrix multiplication.

As there is still one dimension that needs to match, `matmul` can also take a list of 3 dimensions:

In [9]:
linalg.matmul(da, da, dims=["batch", "experiment", "dim2"], out_append="_bis")

Here, `batch` and `dim2` were matrix dimensions in one of the matrices and batch dimensions in the other. While this
might not be very common, `xarray-einstats` checks for dimensions that would end up duplicated in the output and, if necessary, renames them using `out_append` to avoid collisions.
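The following numpy sketch spells out what this call computes. It assumes `dim`, a batch dimension in both inputs, is broadcast (shared), while the colliding `batch` and `dim2` end up as separate axes. Letters are chosen for the sketch: `b`, `e`, `d`, `f` for `batch`, `experiment`, `dim`, `dim2` of the 1st input, and `B`, `g` for `batch` and `dim2` of the 2nd. The output axis order here is arbitrary and need not match the one `xarray-einstats` returns.

```python
import numpy as np

rng = np.random.default_rng(4)
# axes: (batch, experiment, dim, dim2); sizes chosen so every axis differs
x = rng.normal(size=(5, 3, 2, 4))

# 1st operand: matrices over (batch, experiment), looping over (dim, dim2)
first = x.transpose(2, 3, 0, 1)[:, :, None, :, :]   # (dim, dim2, 1, batch, experiment)
# 2nd operand: matrices over (experiment, dim2), looping over (dim, batch)
second = x.transpose(2, 0, 1, 3)[:, None, :, :, :]  # (dim, 1, batch_2, experiment, dim2_2)

# contract experiment; result axes: (dim, dim2, batch_2, batch, dim2_2)
via_matmul = first @ second

# the same computation written with einsum letters
via_einsum = np.einsum("bedf,Bedg->dfBbg", x, x)
assert np.allclose(via_matmul, via_einsum)
print(via_matmul.shape)  # (2, 4, 5, 5, 4)
```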

The same happens when `dim1` and `dim3` have the same name:

In [10]:
linalg.matmul(da, da, dims=["batch", "experiment", "batch"])
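In numpy terms, and assuming `dim` and `dim2` are broadcast batch dimensions in both inputs, this multiplies each `(batch, experiment)` matrix by its own transpose, so both output matrix axes come from `batch`. The sketch below uses illustrative sizes and letters (`B` for the second copy of `batch`):

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.normal(size=(5, 3, 2, 4))  # (batch, experiment, dim, dim2)

# (batch, experiment) matrices times (experiment, batch) matrices, looping over (dim, dim2)
mats = x.transpose(2, 3, 0, 1)                   # (dim, dim2, batch, experiment)
gram = mats @ np.swapaxes(mats, -1, -2)          # (dim, dim2, batch, batch_2)

assert np.allclose(gram, np.einsum("bedf,Bedf->dfbB", x, x))
# both output matrix axes come from "batch", which is why one has to be renamed
assert np.allclose(gram, np.swapaxes(gram, -1, -2))
print(gram.shape)  # (2, 4, 5, 5)
```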

### List of 2 element lists
The 3rd option is the most verbose and explicit, but it is still necessary in some cases to avoid manually renaming dimensions before multiplying matrices.

To see how it works, you'll need a `db` object, with the same shape but different dimension names:

In [11]:
db = da.rename(dim="different_dim", dim2="different_dim2")
db

Now `da` and `db` are compatible and you might want to multiply them; after all, it is the same operation as in the first `matmul` example (you can check the result if running the notebook). However, because of the name mismatch, neither the first nor the second option can be used:

In [12]:
linalg.matmul(da, db, dims=[["dim", "dim2"], ["different_dim", "different_dim2"]])

Whenever the dimension being multiplied/reduced doesn't have the same name in both matrices, you'll need to use this 2+2 dims specification. As in the 3-element list case, `matmul` avoids name clashes:

In [13]:
dc = da.rename(batch="batch_bis")
linalg.matmul(da, dc, dims=[["experiment", "batch"], ["batch_bis", "experiment"]])

(linalg_tutorial/einsum)=
## einsum: 2nd and most notable exception
`einsum` is such a flexible function that it can be intimidating. It covers everything from `sum` operations to {func}`xarray.dot` reductions and, naturally, operations similar to those in `einops`, which is itself inspired by einsum.

The goal of this page is not to be an extensive or in-depth guide to einsum, but to act as a small ladder, from simple operations you can do without einsum up to operations that are only possible with it. Along the way you'll get a good look at the `xarray_einstats` version of `einsum`, which works with named dimensions, and see how most einsum operations translate to its syntax.

If you want to master einsum, however, we refer you to the {func}`numpy.einsum` documentation and the [einops](https://einops.rocks/) package. To make the tutorial easier to follow without understanding einsum beforehand, we provide the equivalent non-einsum code (often multiple operations) inside toggleable note boxes. Keep in mind, though, that the goal of this section is not to teach how to use `einsum`, but to show how to use
`xarray_einstats` to perform einsum operations with named dimensions.

In [14]:
from xarray_einstats import raw_einsum, einsum
import xarray as xr

Start by reducing the `experiment` dimension. Ellipsis, broadcasting and transposition are all handled by xarray and xarray-einstats; you only need to care about the dimensions you want to operate on. Use `[]` to indicate that you want to reduce the dimension (or `->` in the `raw_` syntax):

In [15]:
einsum([["experiment"], []], da)
raw_einsum("experiment->", da)

The same can be done with multiple dimensions:

In [16]:
einsum([["batch", "experiment"], []], da)
raw_einsum("batch experiment->", da)

:::{note}
:class: dropdown

These two calls are respectively equivalent to:

```python
da.sum("experiment")
da.sum(("batch", "experiment"))
```
:::
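For comparison, `numpy.einsum` needs a letter (or an ellipsis) for every axis. Mapping `batch`, `experiment`, `dim`, `dim2` to `b`, `e`, `d`, `f` (letters chosen for this sketch, on an array of illustrative size), the two reductions above read:

```python
import numpy as np

rng = np.random.default_rng(6)
x = rng.normal(size=(10, 3, 4, 4))  # (batch, experiment, dim, dim2)

# reducing experiment: e is dropped from the output subscripts
one = np.einsum("bedf->bdf", x)
assert np.allclose(one, x.sum(axis=1))

# reducing batch and experiment: both b and e are dropped
two = np.einsum("bedf->df", x)
assert np.allclose(two, x.sum(axis=(0, 1)))
```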

`einsum` also takes multiple inputs. In those cases, if there are repeated dimensions in the expressions
corresponding to different inputs and we want to reduce all of them, the output expression can be omitted, just like
you'd do with `numpy.einsum`.

In [17]:
einsum([["experiment"], ["experiment"]], da, da)
raw_einsum("experiment,experiment", da, da)

:::{note}
:class: dropdown

This call combines a product and a summation, and has two equivalents: one using `xarray.dot` (also quite einsum-like), and another using simple arithmetic operations:

```python
xr.dot(da, da, dims="experiment")
(da * da).sum("experiment")
```
:::
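The numpy counterpart, using the same illustrative letters `b`, `e`, `d`, `f` for `batch`, `experiment`, `dim`, `dim2`, also shows why names help: in numpy every axis must be spelled out, and numpy's implicit mode would sum every repeated letter, not just `e`.

```python
import numpy as np

rng = np.random.default_rng(7)
x = rng.normal(size=(10, 3, 4, 4))  # (batch, experiment, dim, dim2)

# e is repeated across the two operands and absent from the output: multiply, then sum
dotted = np.einsum("bedf,bedf->bdf", x, x)
assert np.allclose(dotted, (x * x).sum(axis=1))

# numpy's implicit mode also sums b, d and f here, because they repeat too
assert np.allclose(np.einsum("bedf,bedf", x, x), (x * x).sum())
```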

When no indices are repeated between inputs, _implicit_ and _explicit_ modes give different results, again just like in `numpy.einsum`. After all, `xarray_einstats.einsum` is an interface to it that uses
dimension names and needs no ellipsis.
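The difference between the two modes is easiest to see in plain `numpy.einsum` with two small 1D arrays (values chosen for illustration):

```python
import numpy as np

u = np.array([1.0, 2.0, 3.0])
v = np.array([10.0, 20.0])

# implicit mode: no repeated index, so nothing is summed -> outer product
implicit = np.einsum("i,j", u, v)
print(implicit.shape)  # (3, 2)
assert np.allclose(implicit, np.outer(u, v))

# explicit mode with an empty output: every index is summed -> scalar
explicit = np.einsum("i,j->", u, v)
print(explicit)  # 180.0
assert np.isclose(explicit, u.sum() * v.sum())
```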

**Implicit mode:**

In [18]:
einsum([["experiment"], ["batch"]], da, da)
raw_einsum("experiment,batch", da, da)

This call no longer has a single-operation equivalent: here we perform multiple summations
and multiplications, and also reorder the dimensions. The first time they are encountered, dimensions are mapped to a single letter (the input accepted by einsum) in reverse alphabetical order; after that, the saved mapping is used. Therefore, following the `xarray.apply_ufunc` convention, the default order of the dimensions is:
1. All the omitted dimensions _in the order they appear in the inputs_
2. All dimensions present in the expressions _in the **inverse** order they appear in the expressions_ for the first time

Thus, the output has `dim` and `dim2` first, as they are not present in the expressions, followed by `batch` and finally `experiment`, exactly the inverse of the order in which they appear in the expressions.
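The two ordering rules above can be written down as a short function. `default_output_order` is a hypothetical helper, a sketch of the rules as described on this page (plus the implicit-mode reduction of dimensions repeated across expressions), not `xarray_einstats` code:

```python
def default_output_order(input_dims, expressions):
    """Hypothetical sketch of the implicit-mode ordering rules described above.

    input_dims: one tuple of dimension names per input
    expressions: one list of dimension names per input
    """
    in_expressions = [dim for expr in expressions for dim in expr]
    # 1. dimensions omitted from every expression, in order of appearance in the inputs
    omitted = []
    for dims in input_dims:
        for dim in dims:
            if dim not in in_expressions and dim not in omitted:
                omitted.append(dim)
    # implicit mode: dimensions repeated across expressions are reduced
    kept = [dim for dim in dict.fromkeys(in_expressions) if in_expressions.count(dim) == 1]
    # 2. remaining expression dimensions, in inverse order of first appearance
    return omitted + kept[::-1]


da_dims = ("batch", "experiment", "dim", "dim2")
print(default_output_order([da_dims, da_dims], [["experiment"], ["batch"]]))
# ['dim', 'dim2', 'batch', 'experiment']
```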

:::{note}
:class: dropdown

Even though this computation no longer has a single function/method equivalent, it does have a multiple
operation equivalent:

```python
(da.sum("experiment") * da.sum("batch")).transpose(..., "batch", "experiment")
```
:::

**Explicit mode:**

In [19]:
einsum([["experiment"], ["batch"], []], da, da)
raw_einsum("experiment,batch->", da, da)

:::{note}
:class: dropdown

Again, there is no single-operation equivalent, but there is a multi-operation one:

```python
(da.sum("experiment") * da.sum("batch")).sum(("batch", "experiment"))
```
:::

**Relation to {func}`xarray.dot`**

`xarray.dot` is also a wrapper on `numpy.einsum`, but it takes a single list of dimensions to operate on. This means that neither of the two computations above can be reproduced with it. See for yourself:

In [20]:
xr.dot(da, da, dims=["experiment", "batch"])

`xarray.dot` operates on both dimensions for both inputs. Therefore, to reproduce its results we need to tell `einsum` to operate on both dimensions for both inputs. This is similar to what we did a couple of examples back, but now with two dimensions.

In [21]:
einsum([["batch", "experiment"], ["batch", "experiment"], []], da, da)
raw_einsum("batch experiment,batch experiment->", da, da)

:::{note}
:class: dropdown

Similarly to the example using `dot` before, the equivalent here is a product followed by a sum over the two provided axes.

```python
(da * da).sum(("batch", "experiment"))
```
:::
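A numpy sketch makes the contrast with the earlier explicit-mode example concrete. Going by its multi-operation equivalent, that example gives each input its own copy of the reduced indices (lowercase/uppercase letters below), whereas `xarray.dot` shares them. Letters and positive random data are chosen for the sketch:

```python
import numpy as np

rng = np.random.default_rng(8)
x = rng.uniform(0.1, 1.0, size=(10, 3, 4, 4))  # (batch, experiment, dim, dim2)

# (da.sum("experiment") * da.sum("batch")).sum(("batch", "experiment")):
# each operand gets its own indices, so all four are summed independently
explicit = np.einsum("bEdf,Bedf->df", x, x)
assert np.allclose(explicit, x.sum(axis=(0, 1)) ** 2)

# xr.dot(da, da, dims=["experiment", "batch"]): the same b and e for both operands
dotted = np.einsum("bedf,bedf->df", x, x)
assert np.allclose(dotted, (x * x).sum(axis=(0, 1)))

# the two reductions are genuinely different
assert not np.allclose(explicit, dotted)
```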

**`keep_dims` argument**

`einsum` also has an argument to indicate dimensions that are present in multiple inputs but should be "kept". That is, instead of treating the dimension as the same for all inputs, its occurrence in multiple inputs should be preserved, as if they were actually different dimensions. `einsum` will then rename the repeated dimension names
using the `out_append` argument.

Back to using our DataArray as a collection of matrices, we might want to do a matrix multiplication. That would mean reducing the `dim2` dimensions and _keeping_ both `dim` to get the resulting collection of matrices. You can see how the default `einsum` behaviour only keeps one occurrence of `dim`:

In [22]:
einsum([["dim2"], ["dim2"]], da, da)
raw_einsum("dim2,dim2", da, da)

We need to use the `keep_dims` argument to keep the `dim` dimension of the first DataArray as `dim` and `dim` of the 2nd DataArray as the _new_ `dim2`:

In [23]:
einsum([["dim2"], ["dim2"]], da, da, keep_dims={"dim"}, out_append="_auto{i}")
raw_einsum("dim2,dim2", da, da, keep_dims={"dim"}, out_append="_auto{i}")

Note that here we had `dim, dim2` and `dim, dim2` and we have reduced `dim2` in both arguments. Therefore, we haven't done the matrix multiplication between the two arguments, but the matrix multiplication between the first argument and the _transpose_ of the second. For seamless matrix multiplication, use {func}`xarray_einstats.matmul`.
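The same point in numpy terms, assuming the matrix axes are last (illustrative sizes): reducing the second axis of both operands and keeping both first axes computes the product with the transpose of the second operand.

```python
import numpy as np

rng = np.random.default_rng(9)
a = rng.normal(size=(10, 3, 4, 4))  # (batch, experiment, dim, dim2)

# reduce dim2 (j) in both operands, keep both copies of dim (i and k)
aat = np.einsum("...ij,...kj->...ik", a, a)

# that is a times the transpose of a, not a times a
assert np.allclose(aat, a @ np.swapaxes(a, -1, -2))
assert not np.allclose(aat, a @ a)
```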

:::{important}
`xarray_einstats.einsum` does not support combining dimensions of different names.
:::

The `keep_dims` argument can also be used to perform outer products. In a pure outer product, we don't want to reduce any dimension, so we pass empty lists as the input expressions and the dimension we want to perform the outer product on as `keep_dims`:

In [24]:
einsum([[], []], da, da, keep_dims={"batch"})
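A numpy sketch of this outer product, assuming the dimensions other than `batch` are shared (broadcast) between the two inputs and using `B` for the second copy of `batch`; the axis order is chosen for the sketch:

```python
import numpy as np

rng = np.random.default_rng(10)
x = rng.normal(size=(10, 3, 4, 4))  # (batch, experiment, dim, dim2)

# nothing is reduced; batch is kept twice (b and B), the other axes are shared
outer = np.einsum("bedf,Bedf->bBedf", x, x)
print(outer.shape)  # (10, 10, 3, 4, 4)

# same thing with broadcasting: a new axis for each copy of batch
assert np.allclose(outer, x[:, None] * x[None, :])
```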

In [25]:
%load_ext watermark
%watermark -n -u -v -iv -w -p numpy

Last updated: Fri Jan 20 2023

Python implementation: CPython
Python version       : 3.10.8
IPython version      : 8.8.0

numpy: 1.24.0

xarray_einstats: 0.5.1
xarray         : 2022.12.0

Watermark: 2.3.1

