In [1]:
from svetlanna.axes_math import tensor_dot, cast_tensor
from svetlanna.wavefront import mul
import svetlanna as sv
import torch

This example shows how tensor axes work.

Axes are names for the last (trailing) dimensions of a tensor.
They are stored as tuples of strings.
For example, a tensor with shape `(..., N, M, L)` can have the axes tuple `(a, b, c)`, where dimension `a` has `N` points, `b` has `M` points and `c` has `L` points.
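To make the naming convention concrete, here is a minimal NumPy sketch that pairs each axis name with the size of the trailing dimension it labels. The helper `axes_sizes` is hypothetical and not part of svetlanna; it only illustrates that the names refer to the *last* dimensions, while any leading dimensions stay unnamed.

```python
import numpy as np


def axes_sizes(shape, axes):
    """Hypothetical helper: map each axis name to the size of the trailing
    dimension it labels (leading dimensions remain unnamed)."""
    if len(axes) > len(shape):
        raise ValueError("more axis names than tensor dimensions")
    trailing = shape[len(shape) - len(axes):]
    return dict(zip(axes, trailing))


x = np.zeros((7, 2, 3, 4))  # the leading dimension of size 7 is the unnamed "..."
print(axes_sizes(x.shape, ('a', 'b', 'c')))  # {'a': 2, 'b': 3, 'c': 4}
```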


In [2]:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
a_axes = ('a',)
a.shape

torch.Size([2, 3])

Any tensor can be cast to a different axes tuple; the shape of the tensor changes accordingly:

In [3]:
new_axes = ('b', 'a', 'c')
a_casted = cast_tensor(a, a_axes, new_axes)

a_casted.shape  # singleton axes 'b' and 'c' were inserted before and after axis 'a'

torch.Size([2, 1, 3, 1])
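The shape change above can be reproduced with a short NumPy sketch. `cast_array` is a hypothetical analogue of `cast_tensor`, not its actual implementation: it assumes that casting may also reorder the existing named axes (which the later `mul` examples with user-ordered axes suggest) and that every axis in `axes` must appear in `new_axes`.

```python
import numpy as np


def cast_array(arr, axes, new_axes):
    """Hypothetical NumPy analogue of cast_tensor: relabel the trailing
    dimensions of `arr` from `axes` to `new_axes`, reordering them if needed
    and inserting size-1 dimensions for axes that `arr` does not have."""
    missing = [name for name in axes if name not in new_axes]
    if missing:
        raise ValueError(f"axes {missing} are absent from new_axes")
    lead = arr.ndim - len(axes)  # number of unnamed leading dimensions
    sizes = dict(zip(axes, arr.shape[lead:]))
    # Reorder the named dimensions so they follow their order in new_axes.
    order = sorted(range(len(axes)), key=lambda i: new_axes.index(axes[i]))
    arr = arr.transpose(list(range(lead)) + [lead + i for i in order])
    # Insert size-1 dimensions for the axes that were not present.
    new_shape = arr.shape[:lead] + tuple(sizes.get(name, 1) for name in new_axes)
    return arr.reshape(new_shape)


a = np.array([[1, 2, 3], [4, 5, 6]])
print(cast_array(a, ('a',), ('b', 'a', 'c')).shape)  # (2, 1, 3, 1)
```

Inserting size-1 dimensions rather than copying data is what lets the cast tensor broadcast against tensors that do have those axes.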

To multiply tensors along their named axes, use the `tensor_dot` function.
For example, if tensor `A` has axes `(a, b, c)` and tensor `B` has axes `(b,)`, the result has axes `(a, b, c)`.
The product is defined as follows:
$$
A \cdot B = \sum_{i_a, i_b, i_c} (A_{i_a, i_b, i_c} \cdot B_{i_b}) \vec{e_{i_a}}\vec{e_{i_b}}\vec{e_{i_c}}
$$
In other words, each tensor view `A[:,i,:]` is multiplied by `B[i]`.

The main rule: elements along coinciding axes are multiplied element by element. Unlike a tensor contraction, no summation over the shared axes takes place.
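The rule can be checked numerically with a small NumPy sketch (an illustration, not svetlanna code): a literal loop that scales each view `A[:, i, :]` by `B[i]` gives the same result as broadcasting `B` along the shared axis `b`, and it matches the `tensor_dot` output shown in the next cell.

```python
import numpy as np

A = np.array([[[1, 2], [3, 4]]])  # axes ('a', 'b', 'c'), shape (1, 2, 2)
B = np.array([10, 20])            # axes ('b',), shape (2,)

# Literal reading of the formula: scale each view A[:, i, :] by B[i].
looped = np.empty_like(A)
for i in range(A.shape[1]):
    looped[:, i, :] = A[:, i, :] * B[i]

# Same result by broadcasting B along the shared axis 'b'.
broadcast = A * B[None, :, None]

assert np.array_equal(looped, broadcast)
print(broadcast.tolist())  # [[[10, 20], [60, 80]]]
```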
 

In [4]:
A = torch.tensor([[[1, 2], [3, 4]]])
A_axes = ('a', 'b', 'c')

B = torch.tensor([10, 20])
B_axes = ('b', )

C, new_axes = tensor_dot(A, B, A_axes, B_axes)
C, new_axes

(tensor([[[10, 20],
          [60, 80]]]),
 ('a', 'b', 'c'))

When `B` has an axis that is not present in the axes of `A`, the resulting axes are the union of the axes of `A` and `B`:

In [5]:
A = torch.tensor([[1, 2], [3, 4]])
A_axes = ('a', 'b')

B = torch.tensor([10, 20, 30])
B_axes = ('c', )

C, new_axes = tensor_dot(A, B, A_axes, B_axes)
C, new_axes

(tensor([[[ 10,  20,  30],
          [ 20,  40,  60]],
 
         [[ 30,  60,  90],
          [ 40,  80, 120]]]),
 ('a', 'b', 'c'))
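In broadcasting terms, this case is an outer product: `A` gets a new trailing size-1 axis `c`, `B` gets size-1 axes `a` and `b`, and the two are multiplied. The NumPy sketch below (an illustration, not svetlanna's implementation) reproduces the output above and cross-checks it with `np.einsum`.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])  # axes ('a', 'b')
B = np.array([10, 20, 30])      # axes ('c',)

# Append B's new axis 'c' to A, and give B singleton 'a' and 'b' axes.
C = A[:, :, None] * B[None, None, :]
print(C.shape)  # (2, 2, 3)

# The same outer product written with einsum index notation.
assert np.array_equal(C, np.einsum('ab,c->abc', A, B))
```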

The `preserve_a_axis` argument checks whether the resulting axes coincide with the axes of `A`.
With `preserve_a_axis=True`, the previous example fails with an `AssertionError`.
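Putting the pieces together, here is a minimal NumPy sketch of the behaviour described so far. `tensor_dot_sketch` and its helper `cast_array` are hypothetical; they assume the result axes are the axes of `A` followed by any new axes of `B` (consistent with both outputs above), and they mimic `preserve_a_axis` by raising `AssertionError` when `B` would add axes. The real `tensor_dot` may differ in details.

```python
import numpy as np


def cast_array(arr, axes, new_axes):
    # Hypothetical helper: reorder the named trailing dims of `arr` to follow
    # new_axes and insert size-1 dims for the axes it lacks.
    lead = arr.ndim - len(axes)
    sizes = dict(zip(axes, arr.shape[lead:]))
    order = sorted(range(len(axes)), key=lambda i: new_axes.index(axes[i]))
    arr = arr.transpose(list(range(lead)) + [lead + i for i in order])
    return arr.reshape(arr.shape[:lead] + tuple(sizes.get(n, 1) for n in new_axes))


def tensor_dot_sketch(A, B, A_axes, B_axes, preserve_a_axis=False):
    """Hypothetical re-implementation: element-wise product over the union
    of axes, with the axes of A first and new axes of B appended."""
    new_axes = tuple(A_axes) + tuple(n for n in B_axes if n not in A_axes)
    if preserve_a_axis and new_axes != tuple(A_axes):
        raise AssertionError(f"{new_axes} != {tuple(A_axes)}")
    product = cast_array(A, A_axes, new_axes) * cast_array(B, B_axes, new_axes)
    return product, new_axes


A = np.array([[1, 2], [3, 4]])
B = np.array([10, 20, 30])
C, axes = tensor_dot_sketch(A, B, ('a', 'b'), ('c',))
print(C.shape, axes)  # (2, 2, 3) ('a', 'b', 'c')

try:
    tensor_dot_sketch(A, B, ('a', 'b'), ('c',), preserve_a_axis=True)
except AssertionError as err:
    print("AssertionError:", err)
```

The check is written as an explicit `raise` rather than an `assert` statement so that it still fires when Python runs with `-O`.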

Let's define some simulation parameters:

In [6]:
Nx = 50
Ny = 100
Nwl = 4

sim_params = sv.SimulationParameters(
    {
        'W': torch.linspace(-10, 10, Nx),
        'H': torch.linspace(-10, 10, Ny),
        'wavelength': torch.linspace(1, 5, Nwl),
    }
)

# one can see the axes sequence that is used during simulation
sim_params.axes.names

('wavelength', 'H', 'W')
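The wavefront created in the next cell has one dimension per axis, listed in the order given by `sim_params.axes.names`. The plain NumPy sketch below assumes the size of each axis is simply the number of points in its coordinate grid (the `axes_values` dictionary is a hypothetical stand-in for what is passed to `SimulationParameters`), which is consistent with the `torch.Size([4, 100, 50])` shown further down.

```python
import numpy as np

# Hypothetical stand-in for the coordinate grids passed to SimulationParameters.
axes_values = {
    'W': np.linspace(-10, 10, 50),
    'H': np.linspace(-10, 10, 100),
    'wavelength': np.linspace(1, 5, 4),
}
axes_names = ('wavelength', 'H', 'W')  # the order reported by sim_params.axes.names

# The wavefront shape lists the number of points of each axis in that order.
shape = tuple(len(axes_values[name]) for name in axes_names)
print(shape)  # (4, 100, 50)
```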

Suppose that during a simulation one needs to multiply the wavefront by a transmission function `T` defined on the (x, y) grid:

In [7]:
wavefront = sv.Wavefront(torch.rand(sim_params.axes_size(sim_params.axes.names)))

T = torch.rand(Ny, Nx)
T_axis = ('H', 'W')

c1, c1_axis = tensor_dot(
    wavefront,
    T,
    sim_params.axes.names,
    T_axis
)
c1.shape

torch.Size([4, 100, 50])
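In this particular case the axes of `T`, `('H', 'W')`, are exactly the trailing axes of the wavefront, so ordinary broadcasting already lines them up and every wavelength slice is multiplied by the same `T`. A NumPy sketch with random data (an illustration, not svetlanna code):

```python
import numpy as np

rng = np.random.default_rng(0)
wavefront = rng.random((4, 100, 50))  # axes ('wavelength', 'H', 'W')
T = rng.random((100, 50))             # axes ('H', 'W')

# T's axes match the trailing axes of the wavefront, so plain broadcasting
# applies T to every wavelength slice.
c = wavefront * T
assert c.shape == (4, 100, 50)
assert np.allclose(c[2], wavefront[2] * T)
```

Plain broadcasting stops working as soon as the wavefront axes come in a different order, which is the situation the named-axes approach is designed to handle.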

The `mul` function provides a shortcut for computing `c1`:

In [8]:
c2 = mul(
    wavefront,
    T,
    T_axis,
    sim_params
)

assert torch.allclose(c1, c2)

The `mul` function can be used in the `forward` method of an element.

With the axes approach, element code does not need to know which wavefront axes the user provides, and users can add custom wavefront axes in any order:

In [9]:
sim_params2 = sv.SimulationParameters(
    {
        'wavelength': torch.linspace(1, 5, Nwl),
        'H': torch.linspace(-10, 10, Ny),
        'polarization': torch.tensor([0, 1]),
        'W': torch.linspace(-10, 10, Nx),
    }
)

sim_params2.axes.names

('W', 'polarization', 'H', 'wavelength')

In [10]:
# new wavefront
new_wavefront = sv.Wavefront(torch.rand(sim_params2.axes_size(sim_params2.axes.names)))
new_wavefront.shape

torch.Size([50, 2, 100, 4])

The previous code still works; only the wavefront and simulation parameters arguments change:

In [11]:
c3 = mul(
    new_wavefront,
    T,
    T_axis,
    sim_params2
)
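One way to picture what `mul` has to do here is the NumPy sketch below. `mul_sketch` is hypothetical, not the svetlanna implementation: it assumes `T` has only named axes, reorders them to follow the wavefront axes, inserts size-1 dimensions for the remaining axes and multiplies. With the new axes order `('W', 'polarization', 'H', 'wavelength')`, `T` is cast to shape `(50, 1, 100, 1)`.

```python
import numpy as np


def mul_sketch(wavefront, T, T_axes, wavefront_axes):
    """Hypothetical analogue of `mul`: bring T onto the wavefront's axes
    (reordering and adding size-1 dims) and multiply element-wise.
    Assumes T has no unnamed leading dimensions."""
    sizes = dict(zip(T_axes, T.shape))
    # Reorder T's dimensions to the order they take in wavefront_axes.
    order = sorted(range(len(T_axes)), key=lambda i: wavefront_axes.index(T_axes[i]))
    T = T.transpose(order)
    T = T.reshape(tuple(sizes.get(name, 1) for name in wavefront_axes))
    return wavefront * T


rng = np.random.default_rng(0)
T = rng.random((100, 50))  # axes ('H', 'W')

wf1 = rng.random((4, 100, 50))     # axes ('wavelength', 'H', 'W')
wf2 = rng.random((50, 2, 100, 4))  # axes ('W', 'polarization', 'H', 'wavelength')

c1 = mul_sketch(wf1, T, ('H', 'W'), ('wavelength', 'H', 'W'))
c3 = mul_sketch(wf2, T, ('H', 'W'), ('W', 'polarization', 'H', 'wavelength'))
print(c1.shape, c3.shape)  # (4, 100, 50) (50, 2, 100, 4)

# Element (W=j, polarization=p, H=i, wavelength=k) is scaled by T[i, j].
assert np.isclose(c3[7, 1, 30, 2], wf2[7, 1, 30, 2] * T[30, 7])
```

The same call works for both wavefronts because the alignment is driven by axis names rather than by dimension positions.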