In [1]:
import svetlanna as sv
from svetlanna import Wavefront, SimulationParameters
from svetlanna.parameters import OptimizableFloat, OptimizableTensor
from svetlanna.wavefront import mul
from svetlanna.units import ureg
import torch

# Simple element

Consider an optical element that operates as follows:
$$
f(u) = a\left(\hat{W}u\right)^b
$$
where $u$ represents the incident wavefront, $a$ and $b$ are parameters, and $\hat{W}$ is a 2d mask in the (x,y)-plane.
The product $\hat{W}u$ denotes an elementwise multiplication:
$$\left[\hat{W}u\right](x_i, y_i) = \hat{W}(x_i, y_i)u(x_i, y_i)$$

In [2]:
class MyElement(sv.elements.Element):
    """Element that computes ``a * (W * u) ** b`` for an incident wavefront ``u``.

    ``W`` is multiplied with the wavefront along the ``('H', 'W')`` axes via
    ``mul``; the product is raised to the power ``b`` and then scaled by ``a``.
    """

    def __init__(
        self,
        simulation_parameters: SimulationParameters,
        a: OptimizableFloat,
        b: int,
        W: OptimizableTensor
    ) -> None:
        """Store the simulation parameters and register ``a``, ``b`` and ``W``.

        Args:
            simulation_parameters: simulation parameters forwarded to the base
                ``Element`` constructor.
            a: scalar factor applied to the result.
            b: exponent applied to the masked wavefront.
            W: mask multiplied with the incident wavefront.
        """
        super().__init__(simulation_parameters)  # this line is required

        # Every constructor argument goes through `process_parameter`,
        # in the same order as in the signature.
        self.a = self.process_parameter('a', a)
        self.b = self.process_parameter('b', b)
        self.W = self.process_parameter('W', W)

    def forward(self, incident_wavefront: Wavefront) -> Wavefront:
        """Return ``a * (W * incident_wavefront) ** b``."""
        # `mul` is told which wavefront axes the mask axes correspond to.
        masked = mul(
            incident_wavefront, self.W, ('H', 'W'), self.simulation_parameters
        )
        powered = masked ** self.b
        return self.a * powered

Let's discuss each part of the code.
To highlight whether a value can be optimized, the `OptimizableFloat` and `OptimizableTensor` type aliases are available:
* `OptimizableFloat` is used for scalar values
* `OptimizableTensor` is used for tensor values, such as vectors and matrices

Next, parameters should be registered using the `process_parameter` method.
This method performs different actions based on the provided argument:
* if the argument requires gradient calculations (e.g., `torch.nn.Parameter`), it registers the parameter in the `torch.nn.Module` instance
* if the argument is a tensor, it registers it as a buffer
* otherwise, it does nothing.

**Best practice**: always use the `process_parameter` method for every argument passed to the `__init__` method.

The `forward` method must be implemented for any new element.
Apart from multiplication and exponentiation, it includes an elementwise product.
To perform an elementwise product between a wavefront and a tensor, use the `mul` function.
The axes of $W$ must be specified to perform the product.
In this example $W$ is a 2d-mask in (x,y)-plane, therefore its axes are `('H', 'W')`.
The order of the axes names should match the order of the tensor's axes when `.shape` is called.

Using the `mul` function ensures compatibility with future changes in the wavefront's axis order and in the number of axes (e.g., a batch axis, or physical properties like `polarization`, `wavelength`, etc.).

In [3]:
# Simulation axes: 'W' and 'H' are 10-point grids from -1 to 1 cm,
# 'wavelength' holds three values (400, 500 and 600 nm).
sim_params = SimulationParameters(dict(
    W=torch.linspace(-1, 1, 10) * ureg.cm,
    H=torch.linspace(-1, 1, 10) * ureg.cm,
    wavelength=torch.tensor([400, 500, 600]) * ureg.nm,
))

# Wavefront built from a random tensor of shape (3, 10, 10);
# presumably ordered as (wavelength, H, W) - matches the axis sizes above.
wf = Wavefront(torch.rand(3, 10, 10))

In [4]:
# `a` is a constrained parameter, `b` a plain int and `W` a random 10x10 tensor.
el1 = MyElement(
    simulation_parameters=sim_params,
    a=sv.ConstrainedParameter(2., min_value=0., max_value=5),
    b=2,
    W=torch.rand(10, 10),
)

print(el1(wf).shape)  # transmitted wavefront shape

torch.Size([3, 10, 10])


# Wavelength-dependent properties

Consider an optical element whose transmission depends nonlinearly on the wavelength:
$$
f(u) = u \frac{600}{600 + \lambda}
$$
where $\lambda$ is wavelength in nm.

In [5]:
class MyNonlinearElement(sv.elements.Element):
    """Element that scales the wavefront by ``600 / (600 + wavelength in nm)``."""

    def forward(self, incident_wavefront: Wavefront) -> Wavefront:
        """Multiply the incident wavefront by the wavelength-dependent factor."""
        # The factor is rebuilt from the simulation parameters on every call.
        wavelength_nm = self.simulation_parameters.axes.wavelength / ureg.nm
        transmission = 600 / (600 + wavelength_nm)

        # The factor is applied along the 'wavelength' axis of the wavefront.
        scaled = mul(
            incident_wavefront,
            transmission,
            'wavelength',
            self.simulation_parameters,
        )
        return scaled

In this example, only the `forward` method is implemented.
To construct $f(u)$, the wavelength must be obtained from the simulation parameters.
The general approach to accessing any axis from the simulation parameters is: `self.simulation_parameters.axes.<axis name>`.
In this case, the wavelength is retrieved using: `self.simulation_parameters.axes.wavelength`.
Next, the `mul` function is used to perform the multiplication.
This ensures that the code remains functional in different scenarios, such as when a single wavelength is provided in the simulation parameters (`'wavelength': 500 * ureg.nm`) or when more axes are added.
The same code will work without any modifications in these cases.

In [6]:
# This element takes no arguments besides the simulation parameters.
el2 = MyNonlinearElement(sim_params)

print(el2(wf).shape)  # transmitted wavefront shape

torch.Size([3, 10, 10])


# Wavelength-dependent properties, improved

One can notice that `t` is recalculated every time the `forward` method is called.
To reduce computations this variable can be computed once during initialization and registered as a buffer.

In [7]:
class MyNonlinearElementImproved(sv.elements.Element):
    """Wavelength-dependent element that computes its factor once, in ``__init__``.

    The factor ``600 / (600 + wavelength in nm)`` is stored in ``self.t`` and
    reused by every ``forward`` call.
    """

    def __init__(self, simulation_parameters: SimulationParameters) -> None:
        """Precompute the transmission factor from the wavelength axis."""
        super().__init__(simulation_parameters)

        wavelength_nm = self.simulation_parameters.axes.wavelength / ureg.nm
        transmission = 600 / (600 + wavelength_nm)

        # `make_buffer` is used only when the wavelength axis is a tensor;
        # otherwise the value is kept as a plain attribute.
        self.t = (
            self.make_buffer('t', transmission)
            if isinstance(wavelength_nm, torch.Tensor)
            else transmission
        )

    def forward(self, incident_wavefront: Wavefront) -> Wavefront:
        """Multiply the incident wavefront by the stored factor ``self.t``."""
        scaled = mul(
            incident_wavefront,
            self.t,
            'wavelength',
            self.simulation_parameters,
        )
        return scaled

To store any tensor as a buffer, one should use the `make_buffer` method.
This method is only applicable to `torch.Tensor` objects, which is why the `if` condition is necessary.

It is important to note that if the expression depends on a parameter that requires gradient computation, the tensor should not be buffered.
It must be computed during each call to the forward method to ensure proper gradient tracking.

# Specs

In [8]:
from svetlanna.specs import ReprRepr, PrettyReprRepr, NpyFileRepr, ImageRepr
from svetlanna.specs import ParameterSpecs

Currently, there are four types of variable representations: `ReprRepr`, `PrettyReprRepr`, `NpyFileRepr`, `ImageRepr`.

To display specifications (and later save them), one should implement the `to_specs` method.
Representations require a `value` argument, which can be any relevant data, for example, `self.mask.abs()`.
The following example demonstrates this concept using random tensors.

In [9]:
class ElementWithSpecs(sv.elements.Element):
    """Demo element that only exists to show how `to_specs` is implemented."""

    def forward(self, incident_wavefront: Wavefront) -> Wavefront:
        """Placeholder: the forward pass is intentionally left empty."""
        ...

    def to_specs(self) -> list[ParameterSpecs]:
        """Return specs for three demo parameters named 'a', 'b' and 'c'.

        The values are constants and random tensors used purely for
        illustration; they are regenerated on every call.
        """
        # 'a': the same integer shown through two representations.
        specs_a = ParameterSpecs(
            'a',
            (
                ReprRepr(123),
                PrettyReprRepr(123, units='cm'),
            ),
        )

        # 'b': three independent random 100x100 tensors, one per representation.
        specs_b = ParameterSpecs(
            'b',
            (
                ImageRepr(torch.rand((100, 100)).numpy(force=True)),
                NpyFileRepr(torch.rand((100, 100)).numpy(force=True)),
                PrettyReprRepr(torch.rand((100, 100)), units='nm'),
            ),
        )

        # 'c': a single image representation (note the one-element tuple).
        specs_c = ParameterSpecs(
            'c',
            (ImageRepr(torch.rand((100, 100)).numpy(force=True)),),
        )

        return [specs_a, specs_b, specs_c]

In [10]:
# Leaving the element as the last expression of the cell lets Jupyter display its specs.
ElementWithSpecs(sim_params)