In [1]:
# Reload imported modules (e.g. nata) automatically before each cell runs,
# so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

import numpy as np
import dask.array as da
import matplotlib.pyplot as plt

from nata.containers import GridDataset
from nata.containers import GridArray
from nata.containers import Axis
from nata.plugins.register import register_container_plugin

# Slice

The grid `.slice()` plugin returns a lower-dimensional slice of the original `GridArray` or `GridDataset`.

In [22]:
# 3D integer grid with shape (8, 4, 3) holding the values 0 .. 95.
grid = GridArray.from_array(np.arange(8 * 4 * 3).reshape(8, 4, 3))
grid


| **GridArray** | |
| ---: | :--- |
| **name**  | unnamed |
| **label** | unlabeled |
| **unit**  | '' |
| **shape** | (8, 4, 3) |
| **dtype** | int64 |
| **time**  | 0.0 |
| **axes**  | Axis(axis0), Axis(axis1), Axis(axis2) |



Either the axis name or its index can be used to select the slicing direction.

In [14]:
# Select the slicing direction by axis *name* and hold it at the value 0.
sliced_grid = grid.slice(value=0, constant="axis0")
sliced_grid.shape

(4, 3)

In [15]:
# Same slice, but the direction is given by axis *index* (0 -> "axis0").
sliced_grid = grid.slice(value=0, constant=0)
sliced_grid.shape

(4, 3)

When using axis indices, negative values are also supported.

In [16]:
# Negative indices count from the end, so -1 removes the last axis.
sliced_grid = grid.slice(value=0, constant=-1)
sliced_grid.shape

(8, 4)

When slicing a `GridDataset`, its time dependence is always preserved.

In [23]:
# Same values as a GridDataset: the leading dimension (length 8) becomes time.
grid = GridDataset.from_array(np.arange(8 * 4 * 3).reshape(8, 4, 3))
grid


| **GridDataset** | |
| ---: | :--- |
| **name**  | unnamed |
| **label** | unlabeled |
| **unit**  | '' |
| **shape** | (8, 4, 3) |
| **dtype** | int64 |
| **time**  | [0 1 2 3 4 5 6 7] |
| **axes**  | Axis(time), Axis(axis0), Axis(axis1) |



In [25]:
# Removing "axis0" leaves the time axis in place.
sliced_grid = grid.slice(value=0, constant="axis0")
sliced_grid


| **GridDataset** | |
| ---: | :--- |
| **name**  | unnamed |
| **label** | unlabeled |
| **unit**  | '' |
| **shape** | (8, 3) |
| **dtype** | int64 |
| **time**  | [0. 0. 0. 0. 0. 0. 0. 0.] |
| **axes**  | Axis(time), Axis(axis1) |



Slices are not allowed along the time axis.

In [30]:
# Slicing along the time axis is expected to raise ValueError. Show the
# library's own message, and fail loudly if no error is raised, so the claim
# made in the text above is actually verified on every run.
try:
    grid.slice(constant="time", value=0)
except ValueError as err:
    print("slice along the time axis `time` is not supported")
    print(f"ValueError: {err}")
else:
    raise AssertionError(
        "expected grid.slice(constant='time', ...) to raise ValueError"
    )

slice along the time axis `time` is not supported


Slices are allowed along a moving axis, as long as the value at which they are taken exists in the corresponding axis at all times.

In [31]:
x = np.arange(10)

# Data: the row 0..9 repeated for each of the 5 time steps.
# Axis "x1" moves: at time step i its coordinates are x + i.
grid = GridDataset.from_array(
    np.tile(x, (5, 1)),
    axes=[
        Axis(np.arange(5), name="t"),
        Axis([x + shift for shift in range(5)], name="x1"),
    ],
)
# x1 = 5 lies inside [i, i + 9] for every i in 0..4, so this slice is valid.
grid.slice(constant="x1", value=5)


| **GridDataset** | |
| ---: | :--- |
| **name**  | unnamed |
| **label** | unlabeled |
| **unit**  | '' |
| **shape** | (5,) |
| **dtype** | int64 |
| **time**  | [0. 0. 0. 0. 0.] |
| **axes**  | Axis(time) |



# Fast Fourier Transform (FFT)

The grid `.fft()` plugin returns the FFT of the original `GridDataset`.

If the `comp` argument is not provided, `.fft()` returns the absolute value of the FFT.

In [None]:
# Five periods of sin(x), sampled at 101 points on [0, 10*pi].
x = np.linspace(0, 10 * np.pi, 101)

grid = GridDataset.from_array(
    np.sin(x),
    name="dataset",
    unit="A",
    indexable_axes=[Axis(x, name="x1", label="x_1", unit="a")],
)
# No `comp` given: the absolute value of the transform is returned.
fft_grid = grid.fft()
fft_grid

The axes over which the FFT is computed are transformed such that the zero-frequency bins are centered.

In [None]:
fig, ax = plt.subplots()

k_axis = fft_grid.axes[0]
ax.plot(k_axis.as_dask(), fft_grid.as_dask())
ax.set_xlim((-3, 3))
ax.set_xlabel(f"${k_axis.label}$ [${k_axis.unit}$]")
ax.set_ylabel(f"${fft_grid.label}$ [${fft_grid.unit}$]")

# Dashed reference lines at +/-1, the wavenumber of sin(x)
# (assuming the transformed axis is an angular wavenumber).
for k_ref in (+1, -1):
    ax.axvline(k_ref, c="k", ls="--")

fig.show()

If `comp` is provided, `.fft()` returns the corresponding component of the FFT.

In [None]:
# Ask for a single component of the transform: its imaginary part.
fft_grid = grid.fft(comp="imag")
fft_grid

In [None]:
fig, ax = plt.subplots()

k_axis = fft_grid.axes[0]
ax.plot(k_axis.as_dask(), fft_grid.as_dask())
ax.set_xlim((-3, 3))
ax.set_xlabel(f"${k_axis.label}$ [${k_axis.unit}$]")
ax.set_ylabel(f"${fft_grid.label}$ [${fft_grid.unit}$]")

# Dashed reference lines at +/-1, the wavenumber of sin(x)
# (assuming the transformed axis is an angular wavenumber).
for k_ref in (+1, -1):
    ax.axvline(k_ref, c="k", ls="--")

fig.show()

When applied to multidimensional grids, the `axes` argument controls the axes over which the FFT is computed. If `axes` is not provided, the FFT is computed over all available grid axes.

In [None]:
x = np.linspace(0, 10 * np.pi, 101)
y = np.linspace(0, 10 * np.pi, 201)

# "ij" indexing keeps x along array axis 0 and y along axis 1, matching the
# order of `indexable_axes` below.
X, Y = np.meshgrid(x, y, indexing="ij")

grid = GridDataset.from_array(
    np.sin(X) + np.sin(2 * Y),
    name="dataset",
    indexable_axes=[
        Axis(x, name="x1", label="x_1", unit="a"),
        Axis(y, name="x2", label="x_2", unit="b"),
    ],
)
# No `axes` given: the transform is taken over both x1 and x2.
fft_grid = grid.fft()
fft_grid

In [None]:
fig, ax = plt.subplots()

# Coordinates of both transformed axes as NumPy arrays (evaluates the dask arrays).
k1 = np.asarray(fft_grid.axes[0].as_dask())
k2 = np.asarray(fft_grid.axes[1].as_dask())

# imshow's `extent` positions the outer *edges* of the image, so pad each range
# by half a grid spacing; otherwise every pixel centre is shifted relative to
# its coordinate. Assumes uniformly spaced axes.
half_dk1 = 0.5 * (k1.max() - k1.min()) / max(k1.size - 1, 1)
half_dk2 = 0.5 * (k2.max() - k2.min()) / max(k2.size - 1, 1)

im = ax.imshow(
    # The grid is stored as (x1, x2); imshow draws rows along y, hence the transpose.
    np.transpose(fft_grid.as_numpy()),
    extent=(
        k1.min() - half_dk1,
        k1.max() + half_dk1,
        k2.min() - half_dk2,
        k2.max() + half_dk2,
    ),
    origin="lower",
)
fig.colorbar(im, ax=ax)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)

ax.set_xlabel(f"${fft_grid.axes[0].label}$ [${fft_grid.axes[0].unit}$]")
ax.set_ylabel(f"${fft_grid.axes[1].label}$ [${fft_grid.axes[1].unit}$]")

fig.show()

In [None]:
fig, ax = plt.subplots()

# Line-out along x2 through the zero-frequency bin of x1. The bin is located
# from the axis values instead of hard-coding the centre index.
k1 = np.asarray(fft_grid.axes[0].as_dask())
i1_zero = int(np.argmin(np.abs(k1)))

ax.plot(fft_grid.axes[1].as_dask(), fft_grid.as_dask()[i1_zero, :])
ax.set_xlim((-3, 3))
ax.set_xlabel(f"${fft_grid.axes[1].label}$ [${fft_grid.axes[1].unit}$]")
ax.set_ylabel(f"${fft_grid.label}$")
# Dashed reference lines at +/-2, the wavenumber of the sin(2*Y) term.
ax.axvline(+2, c="k", ls="--")
ax.axvline(-2, c="k", ls="--")
fig.show()

In [2]:
fig, ax = plt.subplots()

# Line-out along x1 through the zero-frequency bin of x2. The bin is located
# from the axis values instead of hard-coding the centre index.
k2 = np.asarray(fft_grid.axes[1].as_dask())
i2_zero = int(np.argmin(np.abs(k2)))

ax.plot(fft_grid.axes[0].as_dask(), fft_grid.as_dask()[:, i2_zero])
ax.set_xlim((-3, 3))
ax.set_xlabel(f"${fft_grid.axes[0].label}$ [${fft_grid.axes[0].unit}$]")
ax.set_ylabel(f"${fft_grid.label}$")
# Dashed reference lines at +/-1, the wavenumber of the sin(X) term.
ax.axvline(+1, c="k", ls="--")
ax.axvline(-1, c="k", ls="--")
fig.show()

NameError: name 'plt' is not defined

If `axes` is provided, the FFT is only computed over the identified grid axes. Both strings (corresponding to axes names) and integers can be provided in `axes`.

In [None]:
# Transform along "x2" only; "x1" stays in real space.
fft_grid = grid.fft(axes=["x2"])
fft_grid

In [3]:
fig, ax = plt.subplots()

# Coordinates as NumPy arrays: x1 is untransformed, x2 is in wavenumber space.
x1_vals = np.asarray(fft_grid.axes[0].as_dask())
k2 = np.asarray(fft_grid.axes[1].as_dask())

# imshow's `extent` positions the outer *edges* of the image, so pad each range
# by half a grid spacing; otherwise every pixel centre is shifted relative to
# its coordinate. Assumes uniformly spaced axes.
half_dx1 = 0.5 * (x1_vals.max() - x1_vals.min()) / max(x1_vals.size - 1, 1)
half_dk2 = 0.5 * (k2.max() - k2.min()) / max(k2.size - 1, 1)

im = ax.imshow(
    # The grid is stored as (x1, x2); imshow draws rows along y, hence the transpose.
    np.transpose(fft_grid.as_numpy()),
    extent=(
        x1_vals.min() - half_dx1,
        x1_vals.max() + half_dx1,
        k2.min() - half_dk2,
        k2.max() + half_dk2,
    ),
    origin="lower",
    # x1 spans [0, 10*pi] while x2 is zoomed to [-3, 3]; "auto" lets the image
    # fill the axes instead of forcing square pixels into a thin strip.
    aspect="auto",
)
fig.colorbar(im, ax=ax)
ax.set_ylim(-3, 3)

ax.set_xlabel(f"${fft_grid.axes[0].label}$ [${fft_grid.axes[0].unit}$]")
ax.set_ylabel(f"${fft_grid.axes[1].label}$ [${fft_grid.axes[1].unit}$]")

fig.show()

NameError: name 'plt' is not defined