In [2]:
#!pip install femwell

In [1]:
from collections import OrderedDict

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import shapely
import shapely.affinity
from scipy.constants import epsilon_0, speed_of_light
from shapely.ops import clip_by_rect
import skfem
from skfem import Basis, ElementTriP0
from skfem.io.meshio import from_meshio

from femwell.maxwell.waveguide import compute_modes
from femwell.mesh import mesh_from_OrderedDict
from femwell.visualization import plot_domains

import ipywidgets as widgets


# Interactive ipympl backend; the plotting helpers below also set ipympl
# canvas attributes such as `header_visible` / `toolbar_visible`.
%matplotlib widget
# set default figure size
plt.rcParams["figure.figsize"] = (10, 6)


In [2]:
class Geom:
    """Plain attribute container that ``gen_mesh`` fills with polygons, mesh, basis and index data."""


def gen_mesh(sim):
    """Build the waveguide cross-section, mesh it and plot mesh + index map.

    Parameters
    ----------
    sim : Sim
        Container whose ``*_set`` attributes are widgets; only their
        ``.value`` is read (waveguide/slab/window sizes, refractive indices
        and per-region mesh resolutions).

    Returns
    -------
    geom : Geom
        Holds ``core``, ``env``, ``polygons``, ``resolutions``, ``mesh``,
        ``basis0`` (ElementTriP0 basis), ``subdomains`` (sorted names),
        ``index`` and ``epsilon`` (= index**2, one value per element dof).
    fig, axs
        The figure and its two axes: subdomain map (left), index (right).
    """
    geom = Geom()
    wg_width = sim.width_set.value
    wg_thickness = sim.height_set.value
    slab_thickness = sim.slab_thickness_set.value
    sim_width = sim.sim_width_set.value
    sim_height = sim.sim_height_set.value

    # Core and simulation window are both centred on the origin.
    geom.core = shapely.geometry.box(
        -wg_width / 2, -wg_thickness / 2, +wg_width / 2, wg_thickness / 2
    )
    geom.env = shapely.geometry.box(
        -sim_width / 2, -sim_height / 2, sim_width / 2, sim_height / 2
    )

    # Box below the core bottom, cladding above the slab top. The slab (and
    # the cladding, when the slab is thinner than the core) overlaps the core;
    # presumably mesh_from_OrderedDict lets earlier entries take precedence,
    # which is why "core" is listed first -- confirm against femwell docs.
    geom.polygons = OrderedDict(
        core=geom.core,
        box=clip_by_rect(geom.env, -np.inf, -np.inf, np.inf, -wg_thickness / 2),
        clad=clip_by_rect(
            geom.env, -np.inf, -wg_thickness / 2 + slab_thickness, np.inf, np.inf
        ),
    )
    if slab_thickness > 0:
        geom.polygons["slab"] = clip_by_rect(
            geom.env,
            -np.inf,
            -wg_thickness / 2,
            np.inf,
            -wg_thickness / 2 + slab_thickness,
        )

    # "slab" is always present here; with no slab polygon it is expected to be
    # unused by the mesher (the default slab_thickness=0 path runs this way).
    geom.resolutions = dict(
        core={"resolution": sim.res_core_set.value, "distance": 1},
        box={"resolution": sim.res_clad_set.value, "distance": 1},
        clad={"resolution": sim.res_clad_set.value, "distance": 1},
        slab={"resolution": sim.res_slab_set.value, "distance": 1},
    )

    geom.mesh = from_meshio(
        mesh_from_OrderedDict(
            geom.polygons, geom.resolutions, default_resolution_max=10
        )
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 4), subplot_kw={"aspect": "equal"})
    geom.mesh.draw(ax=axs[0])

    geom.basis0 = Basis(geom.mesh, ElementTriP0())

    # sorted(): iteration order of a set of str changes between interpreter
    # runs (hash randomisation), which would shuffle colours and labels.
    geom.subdomains = sorted(
        geom.mesh.subdomains.keys() - {"gmsh:bounding_entities"}
    )
    n_subdomains = len(geom.subdomains)
    subdomain_colors = geom.basis0.zeros() * np.nan  # np.NaN was removed in NumPy 2.0
    for color_idx, subdomain in enumerate(geom.subdomains):
        subdomain_colors[geom.basis0.get_dofs(elements=subdomain)] = color_idx

    # One discrete colour bin per subdomain, centred on the integer ids.
    norm = matplotlib.colors.BoundaryNorm(
        np.arange(n_subdomains + 1) - 0.5, ncolors=256
    )
    geom.basis0.plot(
        subdomain_colors, plot_kwargs={"norm": norm}, ax=axs[0], cmap="rainbow"
    )
    # Attach explicitly to axs[0]; without ``ax`` older matplotlib versions
    # steal space from the current axes (axs[1]).
    fig.colorbar(
        axs[0].collections[-1], ax=axs[0], ticks=list(range(n_subdomains))
    ).ax.set_yticklabels(geom.subdomains)

    index_dict = {
        "core": sim.core_n_set.value,
        "box": sim.bott_clad_n_set.value,
        "clad": sim.top_clad_n_set.value,
    }
    if slab_thickness > 0:
        # The slab uses the core material.
        index_dict["slab"] = sim.core_n_set.value

    geom.epsilon = geom.basis0.zeros()
    geom.index = geom.basis0.zeros()
    for subdomain, n in index_dict.items():
        dofs = geom.basis0.get_dofs(elements=subdomain)
        geom.index[dofs] = n
        geom.epsilon[dofs] = n**2

    geom.basis0.plot(geom.index, colorbar=True, ax=axs[1])
    for ax in axs:
        # Raw strings: "\m" is not a valid escape sequence.
        ax.set_xlabel(r"x ($\mu$m)")
        ax.set_ylabel(r"y ($\mu$m)")
    axs[0].set_title("Mesh")
    axs[1].set_title("Refractive index")
    fig.canvas.header_visible = False

    fig.tight_layout()
    return geom, fig, axs

In [11]:
def plot_small_mode(mode):
    """Draw a compact (about 200x150 px), axis-free |E_t| thumbnail of one mode.

    Parameters
    ----------
    mode
        A mode from femwell's ``compute_modes``; this reads ``basis``, ``E``,
        ``te_fraction`` and ``n_eff``.

    Returns
    -------
    fig, ax
        The thumbnail figure and its axes.

    Notes
    -----
    Calls ``plt.ioff()`` and never turns interactive mode back on, so the
    setting persists for every figure created later in the kernel.
    """
    plt.ioff()
    dpi = plt.rcParams["figure.dpi"]
    fig, ax = plt.subplots(figsize=(200 / dpi, 150 / dpi), subplot_kw=dict(aspect=1))

    basis = mode.basis
    # Split the real part of the field; only the first ("et") part is plotted.
    (et, et_basis), _ = basis.split(mode.E.real)

    # Project the transverse field onto a vector DG-P1 basis and split it into
    # x/y components for plotting.
    plot_basis = et_basis.with_element(
        skfem.ElementVector(skfem.ElementDG(skfem.ElementTriP1()))
    )
    et_xy = plot_basis.project(et_basis.interpolate(et))
    (et_x, et_x_basis), (et_y, _) = plot_basis.split(et_xy)

    # Outline the mesh boundary and every named subdomain.
    basis.mesh.draw(ax=ax, boundaries_only=True)
    for subdomain in basis.mesh.subdomains.keys() - {"gmsh:bounding_entities"}:
        basis.mesh.restrict(subdomain).draw(ax=ax, boundaries_only=True)

    et_x_basis.plot(np.sqrt(et_x**2 + et_y**2), shading="gouraud", ax=ax)

    TE_frac = mode.te_fraction * 100
    TM_frac = 100 - TE_frac
    if TE_frac > 50:
        mode_title = f"TE ({TE_frac:.0f}%)"
    else:
        mode_title = f"TM ({TM_frac:.0f}%)"
    ax.set_title(
        mode_title + " $n_{eff}$ =" + f"{np.real(mode.n_eff):.6g}", fontsize=10
    )

    ax.axis("off")
    fig.canvas.header_visible = False
    fig.canvas.toolbar_visible = False
    # Disable the mouse-coordinate readout.
    ax.format_coord = lambda x, y: ""
    fig.subplots_adjust(left=0, right=1, top=0.8, bottom=0)

    return fig, ax


def gen_row_modes(modes):
    """Build a horizontally scrolling row with one thumbnail panel per mode.

    Closes all open pyplot figures first. Each panel is a VBox whose first
    child is a "Mode k" button (k counted from 1). Its second child is an
    Output widget holding the ``plot_small_mode`` thumbnail.

    Returns
    -------
    widgets.HBox
        The row containing all panels, in the order of ``modes``.
    """
    plt.close("all")
    shared_button_style = {"button_color": "lightgray"}

    def _make_panel(mode_number, mode):
        # Thumbnail rendered into its own Output area.
        thumb_area = widgets.Output(
            layout={"width": "200px", "border": "None", "justify_items": "center"}
        )
        with thumb_area:
            thumb_fig, _thumb_ax = plot_small_mode(mode)
            plt.show(thumb_fig)

        mode_button = widgets.Button(
            description=f"Mode {mode_number}",
            layout={"justify_content": "center"},
            style=shared_button_style,
        )
        panel_layout = {
            "min_width": "200px",
            "align_items": "center",
            "border": "None",
            "overflow": "hidden",
        }
        return widgets.VBox([mode_button, thumb_area], layout=panel_layout)

    panels = [_make_panel(number, mode) for number, mode in enumerate(modes, start=1)]

    row_layout = {
        "width": "1000px",
        "object_fit": "contain",
        "overflow": "scroll hidden",
        "flex_flow": "row",
        "display": "flex",
    }
    return widgets.HBox(panels, layout=row_layout)


def plot_main_mode(mode):
    """Draw a full-size |E_t| plot of one mode with subdomain outlines and colorbar.

    Parameters
    ----------
    mode
        A mode from femwell's ``compute_modes``; this reads ``basis``, ``E``,
        ``te_fraction`` and ``n_eff``.

    Returns
    -------
    fig, ax
        The new figure and its axes.
    """
    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(aspect=1))

    basis = mode.basis
    # Split the real part of the field; only the first ("et") part is plotted.
    (et, et_basis), _ = basis.split(mode.E.real)

    # Project the transverse field onto a vector DG-P1 basis and split it into
    # x/y components for plotting.
    plot_basis = et_basis.with_element(
        skfem.ElementVector(skfem.ElementDG(skfem.ElementTriP1()))
    )
    et_xy = plot_basis.project(et_basis.interpolate(et))
    (et_x, et_x_basis), (et_y, _) = plot_basis.split(et_xy)

    # Outline the mesh boundary and every named subdomain.
    basis.mesh.draw(ax=ax, boundaries_only=True)
    for subdomain in basis.mesh.subdomains.keys() - {"gmsh:bounding_entities"}:
        basis.mesh.restrict(subdomain).draw(ax=ax, boundaries_only=True)

    et_x_basis.plot(np.sqrt(et_x**2 + et_y**2), shading="gouraud", ax=ax)

    TE_frac = mode.te_fraction * 100
    TM_frac = 100 - TE_frac
    if TE_frac > 50:
        mode_title = f"TE ({TE_frac:.0f}%)"
    else:
        mode_title = f"TM ({TM_frac:.0f}%)"
    ax.set_title(mode_title + " $n_{eff}$ =" + f"{np.real(mode.n_eff):.6g}")

    # Use the explicit figure/axes rather than pyplot's current-figure state.
    colorbar = fig.colorbar(ax.collections[-1], ax=ax)
    colorbar.ax.set_title("$|E|$")
    # Raw strings: "\m" is not a valid escape sequence (SyntaxWarning on 3.12+).
    ax.set_xlabel(r"x ($\mu m$)")
    ax.set_ylabel(r"y ($\mu m$)")
    fig.canvas.header_visible = False

    fig.tight_layout()

    return fig, ax


def mode_labels(modes):
    """Return one ``"Mode i: TE (x%)"`` / ``"Mode i: TM (y%)"`` label per mode.

    Parameters
    ----------
    modes : iterable
        Objects exposing ``te_fraction``. The value is presumably in [0, 1],
        since it is scaled by 100 into a percentage.

    Returns
    -------
    list of str
        Labels numbered from 0. A mode is labelled TE when its TE percentage
        exceeds 50. Otherwise it is labelled TM with the complementary
        percentage.
    """
    labels = []
    for ii, mode in enumerate(modes):
        TE_frac = mode.te_fraction * 100
        TM_frac = 100 - TE_frac

        if TE_frac > 50:
            labels.append(f"Mode {ii}: TE ({TE_frac:.0f}%)")
        else:
            labels.append(f"Mode {ii}: TM ({TM_frac:.0f}%)")
    return labels

In [12]:
class Sim:
    """Attribute container for the GUI: the ``*_set`` input widgets and, once computed, ``modes``."""


plt.close("all")

# All GUI inputs live on `sim`; gen_mesh/calc callbacks read their `.value`.
sim = Sim()
style_input = {"description_width": "initial"}
layout_col = widgets.Layout(width="160px")

# --- Waveguide geometry ---
sim.width_set = widgets.BoundedFloatText(
    value=0.5,
    min=0.001,
    max=20.0,
    step=0.001,
    description="WG width",
    layout=layout_col,
)
sim.height_set = widgets.BoundedFloatText(
    value=0.22,
    min=0.001,
    max=20.0,
    step=0.001,
    description="WG height",
    layout=layout_col,
)
sim.slab_thickness_set = widgets.BoundedFloatText(
    value=0.0,
    min=0.0,
    max=20.0,
    step=0.001,
    description="Slab thickness",
    layout=layout_col,
    style=style_input,
)

# --- Refractive indices ---
# All three use step=0.001. The cladding widgets previously had step=0.0,
# which is not a usable spinner increment.
sim.core_n_set = widgets.BoundedFloatText(
    value=3.477, min=1, max=5, step=0.001, description="Core n", layout=layout_col
)
sim.bott_clad_n_set = widgets.BoundedFloatText(
    value=1.444, min=1, max=5, step=0.001, description="Bott clad n", layout=layout_col
)
sim.top_clad_n_set = widgets.BoundedFloatText(
    value=1.444, min=1, max=5, step=0.001, description="Top clad n", layout=layout_col
)

# --- Mesh resolutions (passed to mesh_from_OrderedDict by gen_mesh) ---
sim.res_core_set = widgets.BoundedFloatText(
    value=0.03, min=0.0001, max=20.0, description="Core res", layout=layout_col
)
sim.res_clad_set = widgets.BoundedFloatText(
    value=0.1, min=0.0001, max=20.0, description="Clad res", layout=layout_col
)
sim.res_slab_set = widgets.BoundedFloatText(
    value=0.05, min=0.0001, max=20.0, description="Slab res", layout=layout_col
)

# --- Simulation window ---
sim.sim_width_set = widgets.BoundedFloatText(
    value=3.0, min=0.1, max=30.0, step=0.001, description="Sim width", layout=layout_col
)
sim.sim_height_set = widgets.BoundedFloatText(
    value=2.0,
    min=0.1,
    max=30.0,
    step=0.001,
    description="Sim height",
    layout=layout_col,
)


# Control panel columns: geometry | indices | window | resolutions.
cols = []
cols += [widgets.VBox([sim.width_set, sim.height_set, sim.slab_thickness_set])]
cols += [widgets.VBox([sim.core_n_set, sim.bott_clad_n_set, sim.top_clad_n_set])]
cols += [widgets.VBox([sim.sim_width_set, sim.sim_height_set])]
cols += [widgets.VBox([sim.res_core_set, sim.res_clad_set, sim.res_slab_set])]

button_style = {"button_color": "lightgray"}

# Display order below defines the on-screen layout: inputs, "Calc mesh"
# button, mesh figure area, wavelength/mode-count row, mode panels.
ctrl_panel = widgets.HBox(cols)
display(ctrl_panel)
mesh_btn = widgets.Button(description="Calc mesh", style = button_style)
display(mesh_btn)

# Output area that mesh_btn_press renders the mesh figure into.
mesh_fig_panel = widgets.Output()
display(mesh_fig_panel)


# Presumably in the same length unit as the geometry (um) -- confirm with
# femwell's compute_modes `wavelength` parameter.
wavel_set = widgets.BoundedFloatText(
    value=1.55,
    min=0.001,
    max=20.0,
    step=0.001,
    description="Wavelength",
    layout=layout_col,
)
Nmodes_set = widgets.BoundedIntText(
    value=2, min=1, max=100, description="Num. modes", layout=layout_col
)
calc_btn = widgets.Button(description="Calc modes", style = button_style)
# Progress label; shown by calc_btn_press while solving.
calc_label = widgets.Label("Calculating modes...")

second_row = widgets.HBox([wavel_set, Nmodes_set, calc_btn, calc_label])

# Hidden until a mode calculation starts.
calc_label.layout.display = "none"

display(second_row)

# Placeholder row; calc_btn_press replaces it with gen_row_modes(...) output.
row_modes = widgets.HBox([])
# Large plot of the selected mode (filled by fig_click).
main_plot_panel = widgets.Output(layout={"width": "1000px"})



vert_panel = widgets.VBox([row_modes, main_plot_panel])
display(vert_panel)

# Geometry/mesh produced by the most recent successful mesh_btn_press.
curr_geom = None


def mesh_btn_press(wdgt):
    """Regenerate the mesh from the current widget values and draw it.

    ``wdgt`` (the clicked button, or None for the initial call) is ignored.
    On success the new geometry is stored in the global ``curr_geom``.
    Everything runs inside ``mesh_fig_panel``. If ``gen_mesh`` raises, the
    ipywidgets Output context shows the traceback in that panel, and
    ``curr_geom`` keeps its previous value.
    """
    global curr_geom
    mesh_fig_panel.clear_output()
    with mesh_fig_panel:
        # Inside the Output context so errors from a button callback are
        # visible to the user instead of being lost.
        geom, fig, _axs = gen_mesh(sim)
        curr_geom = geom
        plt.show(fig)


# Rebuild the mesh whenever "Calc mesh" is clicked.
mesh_btn.on_click(mesh_btn_press)

# Build and draw the mesh once for the default widget values so that
# `curr_geom` is populated before the first "Calc modes" click.
mesh_btn_press(None)


def calc_btn_press(wdgt):
    """Solve for modes on the current mesh and rebuild the mode panels.

    ``wdgt`` (the clicked button) is ignored. The function:

    - reads the wavelength and mode-count widgets;
    - stores the result of ``compute_modes`` in ``sim.modes`` and prints each
      effective index;
    - replaces the thumbnail row and shows the first mode in the main panel.

    The "Calculating modes..." label is visible while this runs. It is hidden
    again even if solving or plotting raises.
    """
    global row_modes
    calc_label.layout.display = "initial"
    try:
        sim.modes = compute_modes(
            curr_geom.basis0,
            curr_geom.epsilon,
            wavelength=wavel_set.value,
            num_modes=Nmodes_set.value,
            order=2,
        )

        for mode in sim.modes:
            print(f"Effective refractive index: {mode.n_eff:.4f}")

        row_modes = gen_row_modes(sim.modes)
        vert_panel.children = [row_modes, main_plot_panel]

        # The first child of each panel is its "Mode k" button; bind it to
        # the matching mode index.
        for mode_index, panel in enumerate(row_modes.children):
            panel.children[0].on_click(partial(fig_click, mode_index=mode_index))

        fig_click(None, 0)
    finally:
        calc_label.layout.display = "none"


# Clicking "Calc modes" solves for modes on the current mesh.
calc_btn.on_click(calc_btn_press)

# Used by calc_btn_press; the name is looked up only when the button is
# clicked, so importing it after that definition still works.
from functools import partial


def fig_click(wdgt, mode_index):
    """Replace the large plot in ``main_plot_panel`` with mode ``mode_index``.

    ``wdgt`` (the clicked button, or None) is ignored. The parameter exists
    so the function can serve as an ``on_click`` callback via
    ``functools.partial``.
    """
    main_plot_panel.clear_output()
    with main_plot_panel:
        chosen_mode = sim.modes[mode_index]
        main_fig = plot_main_mode(chosen_mode)[0]
        plt.show(main_fig)

HBox(children=(VBox(children=(BoundedFloatText(value=0.5, description='WG width', layout=Layout(width='160px')…

Button(description='Calc mesh', style=ButtonStyle(button_color='lightgray'))

Output()

HBox(children=(BoundedFloatText(value=1.55, description='Wavelength', layout=Layout(width='160px'), max=20.0, …

VBox(children=(HBox(), Output(layout=Layout(width='1000px'))))

Effective refractive index: 2.4465+0.0000j
Effective refractive index: 1.7710+0.0000j


Effective refractive index: 2.4465+0.0000j
Effective refractive index: 1.7710+0.0000j
Effective refractive index: 1.4930+0.0000j


Effective refractive index: 2.4129+0.0000j
Effective refractive index: 1.7568+0.0000j
Effective refractive index: 1.4718+0.0000j


Effective refractive index: 2.6788+0.0000j
Effective refractive index: 2.5244+0.0000j
Effective refractive index: 2.5155+0.0000j


Effective refractive index: 2.6788+0.0000j
Effective refractive index: 2.5244+0.0000j
Effective refractive index: 2.5155+0.0000j
Effective refractive index: 2.4740+0.0000j
Effective refractive index: 2.4430+0.0000j


Effective refractive index: 2.6788+0.0000j
Effective refractive index: 2.5244+0.0000j
Effective refractive index: 2.5155+0.0000j
Effective refractive index: 2.4740+0.0000j
Effective refractive index: 2.4430+0.0000j
Effective refractive index: 2.3827+0.0000j
Effective refractive index: 2.3236+0.0000j
Effective refractive index: 2.2437+0.0000j
Effective refractive index: 2.1530+0.0000j
Effective refractive index: 2.0478+0.0000j


Effective refractive index: 2.4949+0.0000j
Effective refractive index: 1.8898+0.0000j
Effective refractive index: 1.7628+0.0000j
Effective refractive index: 1.7419+0.0000j
Effective refractive index: 1.6328+0.0000j


Effective refractive index: 2.8030+0.0000j
Effective refractive index: 2.6617+0.0000j
Effective refractive index: 2.4109+0.0000j
Effective refractive index: 2.0027+0.0000j
Effective refractive index: 2.0234+0.0000j


Output(layout=Layout(width='1000px'))