In [1]:
%%html
<style>
    .cell-output-ipywidget-background {
        background-color: transparent !important;
    }
    
</style>

In [2]:
import svetlanna
import torch

In [3]:
# Shared simulation grid: 100 samples on [-1, 1] along both W and H, with a
# wavelength of 0.2 (units are whatever svetlanna expects — not shown here).
simulation_parameters = svetlanna.SimulationParameters(
    dict(
        W=torch.linspace(-1, 1, 100),
        H=torch.linspace(-1, 1, 100),
        wavelength=2e-1,
    )
)

In [4]:
svetlanna.elements.DiffractiveLayer(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100)))

In [34]:
# Eight-element demo setup used to exercise the setup widgets below.
# NOTE(review): the two torch.rand masks are drawn without any seed being set
# in the visible cells, so they differ on every run; the order in which the
# elements are created also decides which random values each mask receives.
system = svetlanna.LinearOpticalSetup([
    svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
    svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
    svetlanna.elements.DiffractiveLayer(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100))),
    svetlanna.elements.Aperture(simulation_parameters=simulation_parameters, mask=torch.rand((100, 100))),
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
    svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),
])
# Render the widget that describes the setup (last expression -> cell output).
system.show()

  warn(


LinearOpticalSetupWidget(elements=[{'index': 0, 'type': 'ThinLens', 'specs_html': '<div style="font-family:mon…

In [6]:
input_field = svetlanna.Wavefront.plane_wave(simulation_parameters)
system.show_stepwise_forward(input_field, simulation_parameters, types_to_plot=('I', 'phase', 'Re'))

LinearOpticalSetupStepwiseForwardWidget(elements=[{'index': 0, 'type': 'ThinLens', 'specs_html': '<div style="…

In [7]:
e = svetlanna.specs.specs_writer.write_specs(*system.elements, filename='specs.md')

In [39]:
# Small lens / free-space / lens sub-setup that is nested inside the reservoir.
system1 = svetlanna.LinearOpticalSetup([
    svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
    svetlanna.elements.ThinLens(simulation_parameters=simulation_parameters, focal_length=1),
])

# NOTE(review): all arguments are positional. Judging by the tree output shown
# later in this notebook, the second and third arguments are presumably the
# "Nonlinear element" and the "Delay element" — confirm against the
# SimpleReservoir signature. The meaning of 0.1, 0.2 and 10 cannot be
# determined from this notebook; TODO pass them by keyword once confirmed.
reservoir = svetlanna.elements.reservoir.SimpleReservoir(
    simulation_parameters,
    system1,
    # system1,
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
    0.1,
    0.2,
    10
)
# Last expression of the cell: shows the reservoir's repr as the cell output.
reservoir

In [40]:
# Outer setup that wraps the reservoir between two free-space sections; this
# nested structure is what the HTML tree rendering below is run against.
system2 = svetlanna.LinearOpticalSetup([
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
    reservoir,
    svetlanna.elements.FreeSpace(simulation_parameters=simulation_parameters, distance=1, method='fresnel'),
])

In [41]:
from svetlanna.specs.specs_writer import _ElementsIterator, _ElementInTree
from svetlanna.specs import Specsable
from dataclasses import dataclass

from IPython.core.display import display_html
from jinja2 import Environment, FileSystemLoader, select_autoescape

# Jinja environment used by the widget renderers below. Templates are loaded
# from the "templates" directory, resolved relative to the kernel's current
# working directory.
# NOTE(review): select_autoescape() with default arguments enables escaping
# only for template names ending in .html/.htm/.xml. The templates requested in
# this notebook end in ".html.jinja", so autoescaping appears to be OFF for
# them — confirm this is intended before rendering any untrusted text.
jinja_env = Environment(
    loader=FileSystemLoader("templates"),
    autoescape=select_autoescape()
)


In [46]:


@dataclass(frozen=True, slots=True)
class ElementHTML:
    """Immutable pair of a tree node's role label and its rendered HTML."""
    # Role label copied from the tree node's `subelement_type`; may be None.
    element_type: str | None
    # HTML string returned by the widget renderer chosen for the element.
    html: str


def _render_widget_template(
    template_name: str,
    index: int,
    name: str,
    subelements: list[ElementHTML]
) -> str:
    """Render one widget template from ``jinja_env``.

    Shared implementation behind all ``_*_widget_html_`` renderers, which
    differ only in the template file they use.

    Args:
        template_name: File name of the template inside the loader directory.
        index: Index of the element in the setup tree.
        name: Display name of the element (its class name at the call site).
        subelements: Already-rendered HTML of the element's children.

    Returns:
        The rendered HTML string.
    """
    return jinja_env.get_template(template_name).render(
        index=index, name=name, subelements=subelements
    )


def _widget_html_(
    index: int,
    name: str,
    element_type: str | None,
    subelements: list[ElementHTML]
) -> str:
    """Render the generic fallback widget for an element.

    ``element_type`` is accepted so every renderer shares one call signature;
    it is not forwarded to the template.
    """
    return _render_widget_template(
        'default_widget.html.jinja', index, name, subelements
    )


def _ls_widget_html_(
    index: int,
    name: str,
    element_type: str | None,
    subelements: list[ElementHTML]
) -> str:
    """Render the widget for a ``LinearOpticalSetup``.

    ``element_type`` is accepted so every renderer shares one call signature;
    it is not forwarded to the template.
    """
    return _render_widget_template(
        'linear_setup_widget.html.jinja', index, name, subelements
    )


def _fs_widget_html_(
    index: int,
    name: str,
    element_type: str | None,
    subelements: list[ElementHTML]
) -> str:
    """Render the widget for a ``FreeSpace`` element.

    ``element_type`` is accepted so every renderer shares one call signature;
    it is not forwarded to the template.
    """
    return _render_widget_template(
        'free_space_widget.html.jinja', index, name, subelements
    )


def _rs_widget_html_(
    index: int,
    name: str,
    element_type: str | None,
    subelements: list[ElementHTML]
) -> str:
    """Render the widget for a ``SimpleReservoir`` element.

    ``element_type`` is accepted so every renderer shares one call signature;
    it is not forwarded to the template.
    """
    return _render_widget_template(
        'reservoir_widget.html.jinja', index, name, subelements
    )


def _l_widget_html_(
    index: int,
    name: str,
    element_type: str | None,
    subelements: list[ElementHTML]
) -> str:
    """Render the widget for a ``ThinLens`` element.

    ``element_type`` is accepted so every renderer shares one call signature;
    it is not forwarded to the template.
    """
    return _render_widget_template(
        'lens_widget.html.jinja', index, name, subelements
    )


def _get_widget_html_method(element: Specsable):
    """Pick the callable that renders ``element`` as widget HTML.

    The element's own ``_widget_html_`` attribute is used when present,
    otherwise the generic ``_widget_html_`` renderer. A type-specific renderer
    then replaces that choice for the classes listed below; when several
    classes match, the last match wins.
    """
    # Baseline: the element's own renderer, or the generic fallback.
    chosen = getattr(element, '_widget_html_', _widget_html_)

    # Evaluated in order; every match overwrites the previous choice.
    overrides = (
        (svetlanna.LinearOpticalSetup, _ls_widget_html_),
        (svetlanna.elements.FreeSpace, _fs_widget_html_),
        (svetlanna.elements.SimpleReservoir, _rs_widget_html_),
        (svetlanna.elements.ThinLens, _l_widget_html_),
    )
    for element_class, renderer in overrides:
        if isinstance(element, element_class):
            chosen = renderer

    return chosen


def _subelements_html(subelements: list[_ElementInTree]) -> list[ElementHTML]:
    """Render every node of a tree level, recursing into children first.

    Rendering is best-effort: a node whose renderer raises is left out of the
    result. The failure is reported as a ``RuntimeWarning`` instead of being
    silently discarded, so a missing template or attribute is visible.

    Args:
        subelements: Tree nodes of one level; each node provides ``element``,
            ``element_index``, ``subelement_type`` and ``children``.

    Returns:
        One ``ElementHTML`` per successfully rendered node, in input order.
    """
    import warnings  # local import: keeps this cell independent of the imports cell

    rendered = []

    for subelement in subelements:
        widget_html_method = _get_widget_html_method(subelement.element)
        element_name = subelement.element.__class__.__name__
        try:
            html = widget_html_method(
                index=subelement.element_index,
                name=element_name,
                element_type=subelement.subelement_type,
                # Children are rendered first so the parent can embed them.
                subelements=_subelements_html(subelement.children)
            )
            rendered.append(ElementHTML(subelement.subelement_type, html=html))
        except Exception as error:
            # Keep going with the remaining nodes, but say what was skipped.
            warnings.warn(
                f'Skipping widget HTML for {element_name}: {error!r}',
                RuntimeWarning,
                stacklevel=2,
            )

    return rendered


# Build the element tree of system2 via the specs iterator.
elements = _ElementsIterator(system2, directory='')

# Drain the iterator and every per-element iterable it yields; the values are
# discarded — presumably draining is what populates `elements.tree`
# (TODO confirm against specs_writer). Named throwaway variables are used
# instead of `_` so IPython's last-output variable `_` is not shadowed.
for _unused_first, _unused_second, pending in elements:
    for _unused_item in pending:
        pass

res = _subelements_html(elements.tree)

# _subelements_html leaves out nodes that fail to render, so the list can be
# empty; fail with an explicit message rather than an IndexError on res[0].
if not res:
    raise RuntimeError(
        'No widget HTML was produced for system2; '
        'check the element tree and the templates directory.'
    )

# Wrap the root node's HTML in a scrollable container and display it raw.
containered_html = f'<div style="overflow: auto;display: flex; flex-direction: column; align-items: flex-start; font-family: monospace; max-height: 25rem">{res[0].html}</div>'
display_html(containered_html, raw=True)

In [10]:
e = svetlanna.specs.specs_writer.write_specs(system2, filename='specs.md')

In [11]:
e.tree[0].children[1].children

[_ElementInTree(element=<svetlanna.setup.LinearOpticalSetup object at 0x1556fcbd0>, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Nonlinear element'),
 _ElementInTree(element=<svetlanna.setup.LinearOpticalSetup object at 0x1556fcbd0>, element_index=3, children=[_ElementInTree(element=ThinLens(), element_index=4, children=[], subelement_name='0'), _ElementInTree(element=FreeSpace(), element_index=5, children=[], subelement_name='1'), _ElementInTree(element=ThinLens(), element_index=6, children=[], subelement_name='2')], subelement_name='Delay element')]

In [12]:
# Deliberate stop: raising here halts "Run All" at this cell.
# NOTE(review): presumably meant to keep the unexecuted scratch cells below
# from running — confirm, and consider deleting those cells instead.
raise Exception

Exception: 

In [None]:
e._tree[0].children[1].children[0].element_name

'Nonlinear element'

In [None]:
print('\n'.join([str(i) for i in e._tree]))

ElementInTree(element=<svetlanna.setup.LinearOpticalSetup object at 0x152b1bb10>, element_index=0, children=[ElementInTree(element=FreeSpace(), element_index=1, children=[]), ElementInTree(element=SimpleReservoir(), element_index=2, children=[ElementInTree(element=<svetlanna.setup.LinearOpticalSetup object at 0x13cad6490>, element_index=3, children=[ElementInTree(element=ThinLens(), element_index=4, children=[]), ElementInTree(element=FreeSpace(), element_index=5, children=[]), ElementInTree(element=ThinLens(), element_index=6, children=[])]), ElementInTree(element=<svetlanna.setup.LinearOpticalSetup object at 0x13cad6490>, element_index=7, children=[ElementInTree(element=ThinLens(), element_index=8, children=[]), ElementInTree(element=FreeSpace(), element_index=9, children=[]), ElementInTree(element=ThinLens(), element_index=10, children=[])])]), ElementInTree(element=FreeSpace(), element_index=11, children=[])])


In [None]:
# torch.set_default_dtype(torch.float32)
# Image.fromarray(torch.tensor(a).to(torch.float64).numpy(), mode='L').show()
# Image.fromarray(np.uint8(255*torch.tensor(a).numpy()), mode='L').show() # <- works

In [None]:
import torch

In [None]:
torch.tensor([[1, 1,], [1, 2]]).size() < torch.tensor([1,]).size()

False

In [None]:
import svetlanna
import svetlanna.elements


class A(svetlanna.elements.Element):
    """Minimal element for checking how ``make_buffer`` values follow ``.to()``.

    It stores the W axis through ``make_buffer`` and does no optical
    processing of its own.
    """

    def __init__(self, simulation_parameters: svetlanna.SimulationParameters) -> None:
        """Create the element and store the W axis as buffer ``'a'``."""
        super().__init__(simulation_parameters)

        # The cells that follow compare the device of `self.a` with the device
        # of the original axis after moving the element with `.to(...)`.
        self.a = self.make_buffer('a', self.simulation_parameters.axes.W)

    def forward(self, input_field: svetlanna.Wavefront) -> svetlanna.Wavefront:
        """Return ``input_field`` unchanged.

        The previous body was ``pass``, which returned ``None`` despite the
        ``Wavefront`` return annotation.
        """
        return input_field

In [None]:
# NOTE(review): this rebinds `simulation_parameters`, which earlier cells
# created with a 100x100 grid and wavelength 0.2. After this cell runs,
# re-executing those earlier cells would pair 100x100 masks with this 10x10
# grid; a distinct name would avoid the hidden-state hazard.
simulation_parameters = svetlanna.SimulationParameters(
    {
        'W': torch.linspace(-1, 1, 10),
        'H': torch.linspace(-1, 1, 10),
        'wavelength': 10
    }
)



In [None]:
a = A(simulation_parameters=simulation_parameters)

In [None]:
a.to('mps')

In [None]:
simulation_parameters.axes.W.device

device(type='cpu')

In [None]:
a.a.device

device(type='mps', index=0)