Generating plots that explore the basic, static LOPC model. This notebook is for generating publication-ready plots, which will be saved directly to the LaTeX directory. To follow the process of exploring the data, look at the other notebooks, e.g. [here](../basic_LOPC.ipynb).

In [None]:
# computation
import lumapi
import numpy as np
import warnings
import xarray as xr
import pandas as pd
# import xyzpy as xyz
# from xyzpy.gen.combo_runner import multi_concat
from multilayer_simulator.lumerical_classes import LumericalOscillator, format_stackrt, format_stackfield
from multilayer_simulator.helpers.mixins import convert_wavelength_and_frequency
import dask
from functools import partial
from pathlib import Path
import sys
from tqdm import tqdm
# plotting
import hvplot.xarray
import hvplot.pandas
import holoviews as hv
from holoviews import dim, opts
import colorcet
import panel as pn
import panel.widgets as pnw
from bokeh.io import export_png, export_svg
from selenium.webdriver import Firefox
from selenium.webdriver.firefox.options import Options
from scipy.signal import find_peaks
from bokeh.models import PrintfTickFormatter

In [None]:
# Load the Bokeh backend for HoloViews.
# NOTE(review): an earlier comment here read "use matplotlib because rendering bokeh to svg is
# broken", but it is the bokeh backend that is loaded — confirm which backend is intended for
# the exported figures.
hv.extension("bokeh", inline=False, case_sensitive_completion=True)
pn.config.throttled = True  # don't update interactive plots until mouse is unclicked

# default_color_cycle = hv.Cycle("Colorblind")  # Ruth doesn't like the inclusion of yellow, which is fair enough
default_color_cycle = hv.Cycle(colorcet.glasbey_dark)
default_dash_cycle = hv.Cycle(["solid", "dashed", "dashdot", "dotted", "dotdash"])
universal_opts = dict(fontscale=2, title="")
# Matplotlib-backend sizing (fig_inches/fig_latex are matplotlib options); not applied in this cell.
matplotlib_opts = dict(fig_inches=5, aspect=2, fig_latex=True)
bokeh_opts = dict(width=700, height=300)
# Options shared by every element type below; bokeh_opts win on key clashes.
element_opts = universal_opts | bokeh_opts
opts.defaults(
    opts.Curve(**element_opts, color=default_color_cycle, line_width=1.5),
    opts.Scatter(**element_opts, color=default_color_cycle),
    opts.Image(**element_opts),
    opts.Slope(**element_opts, color=default_color_cycle),
    opts.Area(**element_opts, color=default_color_cycle),
    opts.Overlay(**element_opts),
    opts.Layout(**element_opts),
    opts.GridSpace(**element_opts),
)

# netCDF engine passed to xarray when opening the datasets below
xarray_engine = 'h5netcdf'

In [None]:
# Repository root, reached by walking up from the kernel's current working directory.
# The number of `.parent` steps depends on whether this runs in JupyterLab or Notebook — adjust if paths break.
root = Path.cwd().parent.parent.parent  # depth of parents depends on if this is running in JupyterLab or Notebook

In [None]:
code_path = root / "research"  # directory containing the LOPC package and notebooks

In [None]:
data_path = code_path / "notebooks" / "data"  # simulation outputs, one sub-directory per run

In [None]:
archive_path = root.joinpath("thesis", "LaTeX", "chapters", "methods")  # LaTeX chapter directory

In [None]:
fig_path = archive_path.joinpath("fig_methods")  # where figures for this chapter are written

In [None]:
# Make the project code importable. sys.path holds strings, so compare against str(code_path):
# comparing the Path object directly is always False and re-appends the path on every run.
if str(code_path) not in sys.path:
    sys.path.append(str(code_path))
from LOPC import LOPC
from LOPC.helpers import (
    assign_derived_attrs,
    restack,
    enhancement_factor,
    # combo_length,
    # estimate_combo_run_time,
    linewidth_calculator,
    lopc_data,
    spectrum,
    normalise_over_dim,
    integrate_da,
    sel_or_integrate,
    find_optimum_coords,
    plot_secondary,
    pre_process_for_plots,
    vlines,
    coordinate_string,
    plot_da,
    plot_var,
    plot_optimum_over_dim,
    plot_field,
    visualise_multilayer,
    complex_elements,
    indexer_from_dataset,
    fix_bin_labels,
    mean_and_std,
    max_min_pos,
)

Turn on auto-archiving of cells and HoloViews outputs. See the user guide [here](https://holoviews.org/user_guide/Exporting_and_Archiving.html).

Might need to install `ipympl`.

EDIT: This does not work, but I'm leaving it here so that a future researcher can avoid the rabbit hole I fell down.

In [None]:
# # This is the idiomatic way to record all generated figures with holoviews
# # This does NOT work in JupyterLab: see https://github.com/holoviz/holoviews/issues/3570
# # This also does not work in Jupyter Notebook
# # It's just utterly broken

# hv.archive.auto(root=str(archive_path), export_name="fig_chapter_2") 

In [None]:
# options = Options()
# options.add_argument('-headless')
# web_driver = Firefox(
#     options=options,
#     # firefox_binary=str(Path(r'C:\Users\xv18766\Anaconda3\envs\multilayer_simulator\Library\bin\firefox')),
#     # executable_path=str(Path(r"C:\Users\xv18766\Anaconda3\envs\multilayer_simulator\Scripts\geckodriver"))
# )

# Load/define datasets

## Load LOPC dataset

In [None]:
# # chunks for per-angle plots
# chunks = {
#     "frequency": 256,
#     "excitonic_layer_thickness": 16,
#     "passive_layer_thickness": 32,
#     "theta": 1,
#     "num_periods": 16,
# }

In [None]:
# Dask chunk sizes for plotting or integrating over angle:
# many theta values per chunk, one num_periods value per chunk.
chunks = dict(
    frequency=256,
    excitonic_layer_thickness=16,
    passive_layer_thickness=32,
    theta=16,
    num_periods=1,
)

In [None]:
run_number = 2

# Main LOPC dataset for this run, opened lazily with the chunking above
ds = xr.open_mfdataset(
    data_path / f"run_{run_number}/LOPC.nc",
    chunks=chunks,
    engine=xarray_engine,
    lock=False,
)

# add derived attrs, requesting per-oscillator versions of each R/T/A quantity
ds = assign_derived_attrs(
    ds,
    per_oscillator=["Rs", "Rp", "R", "Ts", "Tp", "T", "As", "Ap", "A"],
)

## Load reference slab dataset

In [None]:
# useful variables: the sorted distinct values of each total-thickness quantity in ds
(
    total_excitonic_thicknesses,
    total_passive_thicknesses,
    total_thicknesses,
) = (
    np.unique(ds[name])
    for name in ("total_excitonic_thickness", "total_passive_thickness", "total_thickness")
)

In [None]:
# Reference-slab dataset for the same run (no explicit chunking)
ref = xr.open_mfdataset(
    data_path / f"run_{run_number}/ref.nc",
    lock=False,
    engine=xarray_engine,
)

Note: `period=False` is an important option because otherwise `period` ends up as a coordinate of `total_excitonic_thickness`, which causes a conflict after binary operations with `ds`.

In [None]:
ref = assign_derived_attrs(
    ref,
    period=False,
    total_excitonic_thickness=False,
    total_passive_thickness=False,
    total_thickness=False,
)

In [None]:
# compressed reference slab without passive layer (CRS):
# a single excitonic slab whose thickness equals the LOPC's total excitonic thickness.
crs_1 = (
    ref.sel(
        remove_last_layer=1,
        passive_layer_thickness=0,
        # keep only the slab thicknesses that occur as a total excitonic thickness in ds
        excitonic_layer_thickness=total_excitonic_thicknesses,
        drop=True,
    )
    .squeeze(drop=True)
    # relabel so crs_1 lines up with ds's total_excitonic_thickness in groupby arithmetic
    .rename(excitonic_layer_thickness="total_excitonic_thickness")
)

In [None]:
# filled reference slab (FRS): a single excitonic slab whose thickness equals
# the LOPC's total thickness.
frs_1 = (
    ref.sel(
        remove_last_layer=1,
        passive_layer_thickness=0,
        # keep only the slab thicknesses that occur as a total thickness in ds
        excitonic_layer_thickness=total_thicknesses,
        drop=True,
    )
    .squeeze(drop=True)
    # relabel so frs_1 lines up with ds's total_thickness in groupby arithmetic
    .rename(excitonic_layer_thickness="total_thickness")
)

### Load derived variables

In [None]:
polarised_attrs = [f"{quantity}{pol}" for quantity in "RTA" for pol in "sp"]  # s/p-polarised R, T, A

In [None]:
# Pre-computed normalised dataset "norm_1" for this run (presumably LOPC / reference slab type 1 — verify).
# Opened with the same engine/lock settings as every other dataset in this notebook.
norm_1 = xr.open_mfdataset(
    data_path / f"run_{run_number}/norm_1.nc",
    chunks=chunks,
    engine=xarray_engine,
    lock=False,
)

In [None]:
# Pre-computed normalised dataset "norm_2" for this run
norm_2 = xr.open_mfdataset(
    data_path / f"run_{run_number}/norm_2.nc",
    lock=False,
    engine=xarray_engine,
    chunks=chunks,
)

In [None]:
norm_2 = assign_derived_attrs(
    norm_2,
    absorption=False,
    unpolarised=False,
    per_oscillator=False,
)

In [None]:
gb_tet = ds[polarised_attrs].groupby('total_excitonic_thickness')

# 'biology' absorptance difference: difference with reference slab type 1, the compressed
# reference slab without passive layer (CRS), matched on total excitonic thickness.
# Note this is a subtraction, not a normalisation (ratio).
diff_1 = gb_tet - crs_1

# add unpolarised and per-oscillator derived attributes to the difference; the geometric
# coordinates are not re-derived
diff_1 = assign_derived_attrs(
    dataset=diff_1,
    unpolarised=True,
    absorption=False,
    period=False,
    total_excitonic_thickness=False,
    total_passive_thickness=False,
    total_thickness=False,
    N_tot=False,
    per_oscillator=["Rs", "Rp", "R", "Ts", "Tp", "T", "As", "Ap", "A"],
)

In [None]:
gb_tt = ds[polarised_attrs].groupby("total_thickness")

# 'stuffed' difference factor: difference with reference slab type 2: filled reference slab
# (each LOPC is compared with the filled slab of the same total thickness via the groupby)
diff_2 = gb_tt - frs_1

# add derived attributes to the difference (flags interpreted by assign_derived_attrs)
diff_2 = assign_derived_attrs(
    dataset=diff_2,
    unpolarised=True,
    absorption=False,
    period=True,  # reset period to only depend on two dims
    total_excitonic_thickness=False,
    total_passive_thickness=False,
    total_thickness=False,
    N_tot=False,
    per_oscillator=["Rs", "Rp", "R", "Ts", "Tp", "T", "As", "Ap", "A"],
)

In [None]:
# Flat-spectrum variant of the LOPC dataset, with derived attributes added
ds_flat_spectrum = assign_derived_attrs(
    xr.open_dataset(
        data_path / f"run_{run_number}/LOPC_flat_spectrum.nc",
        lock=False,
        engine=xarray_engine,
    )
)

In [None]:
# Flat-spectrum variant of the reference-slab dataset; geometric attrs are not derived
# (same reasoning as for `ref`)
ref_flat_spectrum = assign_derived_attrs(
    xr.open_dataset(
        data_path / f"run_{run_number}/ref_flat_spectrum.nc",
        lock=False,
        engine=xarray_engine,
    ),
    period=False,
    total_excitonic_thickness=False,
    total_passive_thickness=False,
    total_thickness=False,
)

In [None]:
# Compressed reference slab without passive layer (flat-spectrum data): a single excitonic
# slab whose thickness equals the LOPC's total excitonic thickness.
crs_1_flat_spectrum = (
    ref_flat_spectrum.sel(
        remove_last_layer=1,
        passive_layer_thickness=0,
        excitonic_layer_thickness=total_excitonic_thicknesses,
        drop=True,
    )
    .squeeze(drop=True)
    # relabel so it lines up with ds_flat_spectrum's total_excitonic_thickness
    .rename(excitonic_layer_thickness="total_excitonic_thickness")
)

In [None]:
# Manual ratio of the flat-spectrum LOPC to the CRS, grouped by total excitonic thickness.
# NOTE(review): this value is overwritten by the enhancement_factor(...) call two cells below,
# so it is effectively dead code — consider removing it.
norm_flat_spectrum = (ds_flat_spectrum.groupby('total_excitonic_thickness')/crs_1_flat_spectrum)#.drop_sel(excitonic_layer_thickness=0)

In [None]:
# Exploratory: show enhancement_factor's docstring (IPython `?` help syntax, not plain Python).
enhancement_factor?

In [None]:
# Flat-spectrum enhancement factor relative to the compressed reference slab
norm_flat_spectrum = enhancement_factor(
    ds=ds_flat_spectrum,
    ref=crs_1_flat_spectrum,
    common_dim="total_excitonic_thickness",
)

## Restacking passive layer thickness to period

In [None]:
# Re-index a dataset from (passive_layer_thickness, excitonic_layer_thickness)
# to (period, excitonic_layer_thickness) using the project's `restack` helper.
restack_plt_to_period = partial(
    restack,
    start_idxs=["passive_layer_thickness", "excitonic_layer_thickness"],
    end_idxs=["period", "excitonic_layer_thickness"],
)

In [None]:
# LOPC dataset indexed by period instead of passive layer thickness
restacked_ds = restack_plt_to_period(ds)

In [None]:
# norm_1 indexed by period instead of passive layer thickness
restacked_norm_1 = restack_plt_to_period(norm_1)

In [None]:
# norm_2 indexed by period instead of passive layer thickness
restacked_norm_2 = restack_plt_to_period(norm_2)

In [None]:
# diff_1 indexed by period instead of passive layer thickness
restacked_diff_1 = restack_plt_to_period(diff_1)

In [None]:
# diff_2 indexed by period instead of passive layer thickness
restacked_diff_2 = restack_plt_to_period(diff_2)

In [None]:
# flat-spectrum LOPC dataset indexed by period instead of passive layer thickness
restacked_ds_flat_spectrum = restack_plt_to_period(ds_flat_spectrum)

In [None]:
# flat-spectrum enhancement factor indexed by period instead of passive layer thickness
restacked_norm_flat_spectrum = restack_plt_to_period(norm_flat_spectrum)

# Plots

## Pre-processing

In [None]:
# First four entries of HoloViews' default colour cycle, named after their hues
blue, red, yellow, green = hv.Cycle.default_cycles['default_colors'][:4]

In [None]:
# Wavelength grid: 256 points from 480 nm to 880 nm, also expressed in metres and as frequencies
wavelengths_in_nanometres = np.linspace(480, 880, num=256)
wavelengths = 1e-9 * wavelengths_in_nanometres
frequencies = convert_wavelength_and_frequency(wavelengths)
# Angle grid: 64 values of theta from 0 to 86 (presumably degrees)
angles = np.linspace(0, 86, num=64)

In [None]:
# Default Lorentz-oscillator parameters used by the simulations
default_oscillator_params = dict(
    N=1e26,
    permittivity=2.2,
    lorentz_resonance_wavelength=680,  # nm (680e-9 m is used as the resonance further down)
    lorentz_linewidth=7.5e13,  # rad/s
)

In [None]:
# Groups of variable names, by polarisation and by quantity
unpolarised_RTA = list("RTA")
s_polarised_RTA = [f"{quantity}s" for quantity in "RTA"]
p_polarised_RTA = [f"{quantity}p" for quantity in "RTA"]
reflectance = [f"R{pol}" for pol in ("s", "p", "")]
transmittance = [f"T{pol}" for pol in ("s", "p", "")]
absorptance = [f"A{pol}" for pol in ("s", "p", "")]
per_oscillator_RTA = [f"{quantity}_per_oscillator" for quantity in "RTA"]

In [None]:
# WARNING: all these datasets will be fundamentally changed after this cell, to the extent that it can't be run twice
# For consistency, keep important calculations in the preceding section!
ds = pre_process_for_plots(ds)
restacked_ds = pre_process_for_plots(restacked_ds)
ref = pre_process_for_plots(ref)
crs_1 = pre_process_for_plots(crs_1)
frs_1 = pre_process_for_plots(frs_1)
norm_1 = pre_process_for_plots(norm_1)
restacked_norm_1 = pre_process_for_plots(restacked_norm_1)
norm_2 = pre_process_for_plots(norm_2)
restacked_norm_2 = pre_process_for_plots(restacked_norm_2)
diff_1 = pre_process_for_plots(diff_1)
restacked_diff_1 = pre_process_for_plots(restacked_diff_1)
# diff_2 / restacked_diff_2 are built analogously to diff_1 / restacked_diff_1, so they are
# pre-processed too; otherwise they would be the only derived datasets left unprocessed.
diff_2 = pre_process_for_plots(diff_2)
restacked_diff_2 = pre_process_for_plots(restacked_diff_2)
ds_flat_spectrum = pre_process_for_plots(ds_flat_spectrum)
restacked_ds_flat_spectrum = pre_process_for_plots(restacked_ds_flat_spectrum)
ref_flat_spectrum = pre_process_for_plots(ref_flat_spectrum)
crs_1_flat_spectrum = pre_process_for_plots(crs_1_flat_spectrum)
norm_flat_spectrum = pre_process_for_plots(norm_flat_spectrum)
restacked_norm_flat_spectrum = pre_process_for_plots(restacked_norm_flat_spectrum)

In [None]:
# HoloViews dimensions carrying the axis labels/units used in the figures;
# a (name, label) tuple sets the display label.
period_dim = hv.Dimension(("period", "Λ"), unit="nm")
wavelength_dim = hv.Dimension(("wavelength", "λ"), unit="nm")
real_index_dim = hv.Dimension("n")
imag_index_dim = hv.Dimension(("k", "ϰ"))

## Useful functions

### Lorentz lines

I want some sort of metric for 'near the resonance' and 'far from the resonance'. The natural unit of distance in this instance is the linewidth. The linewidth is given in rad/s, so some conversions are needed to get the equivalent lines in the plots by wavelength, but the lines are roughly symmetrical around the peak wavelength.

Based on the plots of the refractive index below, I think I will consider 'near' to be 'within two linewidths', and 'far' to be 'at least four linewidths away'.

In [None]:
# resonance_line = hv.VLine(680, label='LO resonance wavelength').opts(line_dash='dotted')

# Convert from rad/s to Hz (angular frequency ω = 2πf)
lorentz_linewidth_frequency = default_oscillator_params["lorentz_linewidth"] / (2*np.pi)

In [None]:
def linewidth_calculator_factory(centre, linewidth):
    """Return ``linewidth_calculator`` with ``centre`` and ``linewidth`` pre-bound.

    The result is a ``functools.partial``; call it with the remaining arguments
    of ``linewidth_calculator`` (presumably an offset measured in linewidths).
    """
    bound_kwargs = dict(centre=centre, linewidth=linewidth)
    return partial(linewidth_calculator, **bound_kwargs)

In [None]:
# Frequency of a line a given number of linewidths from the 680 nm Lorentz resonance
lorentz_line_frequency = linewidth_calculator_factory(
    convert_wavelength_and_frequency(680e-9),
    lorentz_linewidth_frequency,
)


def lorentz_line_wavelength(x=None):
    """Wavelength of the line ``x`` linewidths from the Lorentz resonance.

    Frequency and wavelength are inversely related, so the offset is negated
    before being handed to ``lorentz_line_frequency``; ``None`` is passed through
    unchanged. The result is converted back to a wavelength (presumably metres,
    since callers multiply it by 1e9 to get nm).
    """
    flipped_offset = None if x is None else -x
    return convert_wavelength_and_frequency(lorentz_line_frequency(flipped_offset))

In [None]:
def lorentz_vlines(x=0, scale=1, mode='wavelength', **kwargs):
    if mode == 'wavelength':
        line_func = lorentz_line_wavelength
    elif mode == 'frequency':
        line_func = lorentz_line_frequency
    else:
        raise TypeError(f"mode should be 'wavelength' or 'frequency', not {mode}")
        
    match x:
        case [*xs]:
            line_pos = [line_func(x)/scale for x in xs]
        case x:
            line_pos = line_func(x)/scale
            
    return vlines(line_pos, **kwargs)

### Plotting functions

#### Select a wavelength or wavelength range based on the distance from the resonance in linewidths.

In [None]:
def select_lorentz_line(da, lorentz_line=0, window_radius=0):
    """Select data near a Lorentz line, measured in linewidths from the resonance.

    With ``window_radius == 0`` the single wavelength nearest to ``lorentz_line``
    is selected. Otherwise every wavelength between ``lorentz_line - window_radius``
    and ``lorentz_line + window_radius`` linewidths is kept (label slice).
    Wavelengths are multiplied by 1e9, i.e. ``da``'s wavelength coordinate is
    assumed to be in nm.
    """
    metres_to_nm = 1e9
    if window_radius != 0:
        lower = lorentz_line_wavelength(lorentz_line - window_radius) * metres_to_nm
        upper = lorentz_line_wavelength(lorentz_line + window_radius) * metres_to_nm
        return da.sel(wavelength=slice(lower, upper), method=None)

    target = lorentz_line_wavelength(lorentz_line) * metres_to_nm
    return da.sel(wavelength=target, method="nearest")

#### Plot a comparison of the reflectance and absorptance of the LOPC with that of the reference slab.

In [None]:
# Reflectance styling: blue curves and 'viridis' maps, all on the range [0, 1]
opts_R = [
    opts.Curve(color=blue, ylim=(0, 1)),
    opts.Image(cmap='viridis', clim=(0, 1)),
    opts.QuadMesh(cmap='viridis', clim=(0, 1)),
]


def plot_R(variable="R", dataset=None, label_field="long_name", label_append=None, **hvplot_kwargs):
    """Plot a reflectance-type variable with ``plot_var`` and apply ``opts_R``.

    All arguments are forwarded to ``plot_var`` unchanged; the returned object
    is the one ``plot_var`` produced, with the reflectance options applied.
    """
    styled = plot_var(variable, dataset, label_field, label_append, **hvplot_kwargs)
    styled.opts(*opts_R)
    return styled

# # test
# plot_R(dataset=restacked_ds.sel(period=200, excitonic_layer_thickness=20, num_periods=10).squeeze(), x="wavelength", y="theta").opts(cmap="cividis", clim=(None, None))

In [None]:
# Transmittance styling: yellow curves and 'cividis' maps, all on the range [0, 1]
opts_T = [
    opts.Curve(color=yellow, ylim=(0, 1)),
    opts.Image(cmap='cividis', clim=(0, 1)),
    opts.QuadMesh(cmap='cividis', clim=(0, 1)),
]


def plot_T(variable="T", dataset=None, label_field="long_name", label_append=None, **hvplot_kwargs):
    """Plot a transmittance-type variable with ``plot_var`` and apply ``opts_T``.

    All arguments are forwarded to ``plot_var`` unchanged; the returned object
    is the one ``plot_var`` produced, with the transmittance options applied.
    """
    styled = plot_var(variable, dataset, label_field, label_append, **hvplot_kwargs)
    styled.opts(*opts_T)
    return styled

In [None]:
# Absorptance styling: red curves and 'inferno' maps, all on the range [0, 1]
opts_A = [
    opts.Curve(color=red, ylim=(0, 1)),
    opts.Image(cmap='inferno', clim=(0, 1)),
    opts.QuadMesh(cmap='inferno', clim=(0, 1)),
]


def plot_A(variable="A", dataset=None, label_field="long_name", label_append=None, **hvplot_kwargs):
    """Plot an absorptance-type variable with ``plot_var`` and apply ``opts_A``.

    All arguments are forwarded to ``plot_var`` unchanged; the returned object
    is the one ``plot_var`` produced, with the absorptance options applied.
    """
    styled = plot_var(variable, dataset, label_field, label_append, **hvplot_kwargs)
    styled.opts(*opts_A)
    return styled

In [None]:
def plot_vars_to_funcs(plot_vars):
    """Map each variable name to the function that should plot it.

    Reflectance/transmittance/absorptance names — the bare "R"/"T"/"A", their
    s/p-polarised variants ("Rs", "Tp", ...) and the per-oscillator versions
    ("R_per_oscillator", "As_per_oscillator", ...) — get the matching styled
    plotter (``plot_R``, ``plot_T`` or ``plot_A``). Any other name falls back to
    ``plot_var`` with ``variable`` bound.

    Parameters
    ----------
    plot_vars : iterable of str

    Returns
    -------
    list
        One callable per entry of ``plot_vars``, in the same order.
    """
    styled_plotters = {"R": plot_R, "T": plot_T, "A": plot_A}

    # Bare names use the plotter itself; every variant gets it with `variable` pre-bound.
    lookup = dict(styled_plotters)
    for quantity, plotter in styled_plotters.items():
        variants = [quantity + "s", quantity + "p"]
        variants += [f"{quantity}{pol}_per_oscillator" for pol in ("", "s", "p")]
        for name in variants:
            lookup[name] = partial(plotter, variable=name)

    return [
        lookup[name] if name in lookup else partial(plot_var, variable=name)
        for name in plot_vars
    ]

In [None]:
def plot_RTA(
    period,
    excitonic_layer_thickness,
    num_periods,
    theta,
    title="",
    include=("LOPC", "CRS_1"),
    plot_vars=("R", "T", "A"),
    label_override=None,
    label_append=None,
):
    """Overlay R/T/A spectra (vs wavelength) of an LOPC and its reference slabs.

    Parameters
    ----------
    period, excitonic_layer_thickness, num_periods :
        LOPC geometry, used to select from ``restacked_ds`` and the matching
        reference slabs.
    theta : number or tuple
        Angle to select, or a range to integrate over (see ``sel_or_integrate``).
    title : str
        Prefix of the plot title; the coordinates are appended.
    include : iterable of str
        Any of "LOPC", "CRS_1" (compressed reference slab, solid→dashed) and
        "FRS_1" (filled reference slab, dotted). Curves are always drawn in the
        order LOPC, CRS_1, FRS_1.
    plot_vars : iterable of str
        Variables to plot; mapped to plotters by ``plot_vars_to_funcs``.
    label_override : str, optional
        Replaces the default per-structure legend suffix (" (LOPC)", " (CRS)", " (FRS)").
    label_append : str, optional
        Appended to every legend label.

    Returns
    -------
    hv.Overlay
    """
    label_field = None  # for debugging
    label_append = "" if label_append is None else label_append

    P = period
    t = excitonic_layer_thickness
    N = num_periods

    plot_funcs = [
        partial(func, x="wavelength", label_field=label_field)
        for func in plot_vars_to_funcs(plot_vars)
    ]

    # (include key, default legend suffix, line dash, lazy selection of that structure)
    structures = [
        ("LOPC", " (LOPC)", "solid",
         lambda: restacked_ds.sel(period=P, excitonic_layer_thickness=t, num_periods=N)),
        ("CRS_1", " (CRS)", "dashed",
         lambda: crs_1.sel(total_excitonic_thickness=t * N)),
        # The period already contains the excitonic layer, so the filled slab with the
        # LOPC's total thickness is P * N thick; (P + t) * N would count t twice.
        # NOTE(review): assumes period = passive + excitonic layer thickness.
        ("FRS_1", " (FRS)", "dotted",
         lambda: frs_1.sel(total_thickness=P * N)),
    ]

    curves = []
    for key, default_label, line_dash, select in structures:
        if key not in include:
            continue
        label = (default_label if label_override is None else label_override) + label_append
        selected = sel_or_integrate(select().squeeze(), "theta", theta, normalisation=1)
        curves += [
            plot_func(dataset=selected, label_append=label).opts(line_dash=line_dash)
            for plot_func in plot_funcs
        ]

    overlay = hv.Overlay(curves).opts(
        opts.Curve(
            ylim=(0, 1),
            ylabel="Intensity",
            title=f"{title}{coordinate_string(period=P, excitonic_layer_thickness=t, num_periods=N, theta=theta)}",
        ),
    )

    return overlay


# # test
# display(
#     plot_RTA(200, 40, 20, 0, "test\n", include=["LOPC", "CRS_1", "FRS_1"]).opts(
#         legend_position="right"
#     )
# )

# display(
#     plot_RTA(
#         200,
#         40,
#         20,
#         (10, 50),
#         "test RA only\n",
#         include=["LOPC", "CRS_1", "FRS_1"],
#         plot_vars=["R", "A"],
#     ).opts(opts.Overlay(legend_position="right"))
# )

# display(
#     plot_RTA(
#         200,
#         40,
#         20,
#         75,
#         "test\n",
#         include=["LOPC"],
#         plot_vars=["R_per_oscillator", "A_per_oscillator"],
#         label_append=" test",
#         label_override="OVERRIDDEN",
#     ).opts(opts.Curve(ylim=(None, None)), opts.Overlay(legend_position="right"))
# )

#### Plot a comparison of normal incidence to integrated

In [None]:
def plot_comparison(*comparison_params: tuple[dict, list["opts"]], plot_func=plot_RTA, **shared_params):
    """Call ``plot_func`` once per comparison and style each result.

    Each positional argument is a ``(params, extra_opts)`` pair. ``params`` is
    merged over ``shared_params`` (per-comparison values win), the merged dict
    is passed to ``plot_func`` as keyword arguments, and ``extra_opts`` are then
    applied to the result with ``.opts``.

    Returns
    -------
    list
        One styled plot per comparison, in the order given.
    """
    merged = [(shared_params | params, extra_opts) for params, extra_opts in comparison_params]
    return [plot_func(**kwargs).opts(*extra_opts) for kwargs, extra_opts in merged]

In [None]:
def compare_RTA(*args, opts_cycle=None, plot_func=plot_RTA, **shared_params):
    """Overlay several ``plot_func`` calls that differ only in some parameters.

    Parameters
    ----------
    *args : dict
        One dict of parameters per comparison; each is merged over
        ``shared_params`` (per-comparison values win).
    opts_cycle : sequence of lists of options, optional
        Options applied to successive comparisons, cycled when there are more
        comparisons than entries. Defaults to one line-dash style per comparison
        (solid, dashed, dotted, dotdash, dashdot).
    plot_func : callable
        Plotting function, ``plot_RTA`` by default.
    **shared_params :
        Parameters common to every comparison.

    Returns
    -------
    hv.Overlay
        All comparisons overlaid, legend on the right.
    """
    from itertools import cycle

    default_opts = [[opts.Curve(line_dash=style)] for style in ["solid", "dashed", "dotted", "dotdash", "dashdot"]]
    opts_cycle = default_opts if opts_cycle is None else opts_cycle

    # Cycle the styles: a plain zip would silently drop every comparison beyond len(opts_cycle).
    plots = plot_comparison(*zip(args, cycle(opts_cycle)), plot_func=plot_func, **shared_params)
    overlay = hv.Overlay(plots).opts(opts.Overlay(legend_position="right"))

    return overlay

# # test
# shared_params = {
#     "period": 250,
#     "excitonic_layer_thickness": 70,
#     "num_periods": 30,
#     "include": ["LOPC"],
# }
# compare_RTA({"theta": (0, 75), "label_override": " (integrated)"}, {"theta": 0, "label_override": " (θ = 0)"}, **shared_params)

In [None]:
# LOPC spectra integrated over theta in (0, 45) compared with normal incidence (theta = 0)
compare_RTA_normal_vs_integrated = partial(
    compare_RTA,
    {"theta": (0, 45), "label_override": " (integrated)"},
    {"theta": 0, "label_override": " (θ = 0)"},
    include=["LOPC"],
)

# # test
# shared_params = {
#     "period": 250,
#     "excitonic_layer_thickness": 70,
#     "num_periods": 30,
# }
# compare_RTA_normal_vs_integrated(**shared_params)

#### Plot the RTA of the structures in 2D

In [None]:
def plot_RTA_2D(
    period,
    excitonic_layer_thickness,
    num_periods,
    theta=(0, 75),
    title="",
    include=("LOPC", "CRS_1"),
):
    """Lay out wavelength-vs-theta maps of R, T and A for an LOPC and its reference slabs.

    Parameters
    ----------
    period, excitonic_layer_thickness, num_periods :
        LOPC geometry, used to select from ``restacked_ds`` and the matching
        reference slabs.
    theta : tuple
        ``(start, stop)`` range of theta to show (applied as ``slice(*theta)``).
    title : str
        Prefix of the layout title; the coordinates are appended.
    include : iterable of str
        Any of "LOPC", "CRS_1" (compressed reference slab) and "FRS_1" (filled
        reference slab). Panels are ordered LOPC, CRS, FRS, each as R, T, A.

    Returns
    -------
    hv.Layout
    """
    P = period
    t = excitonic_layer_thickness
    N = num_periods

    # (include key, label used in panel titles, lazy selection of that structure)
    structures = [
        ("LOPC", "LOPC",
         lambda: restacked_ds.sel(period=P, excitonic_layer_thickness=t, num_periods=N)),
        ("CRS_1", "CRS",
         lambda: crs_1.sel(total_excitonic_thickness=t * N)),
        # The period already contains the excitonic layer, so the filled slab with the
        # LOPC's total thickness is P * N thick; (P + t) * N would count t twice.
        # NOTE(review): assumes period = passive + excitonic layer thickness.
        ("FRS_1", "FRS",
         lambda: frs_1.sel(total_thickness=P * N)),
    ]
    # (variable, name used in panel titles, colour map)
    quantities = [
        ("R", "Reflectance", "viridis"),
        ("T", "Transmittance", "cividis"),
        ("A", "Absorptance", "inferno"),
    ]

    plots = []
    for key, structure_label, select in structures:
        if key not in include:
            continue
        selected = select().squeeze().sel(theta=slice(*theta))
        for variable, quantity_label, cmap in quantities:
            plots.append(
                selected[variable]
                .hvplot(
                    kind="image",
                    x="wavelength",
                    y="theta",
                    title=f"{quantity_label} ({structure_label})",
                )
                .opts(opts.Image(cmap=cmap))
            )

    layout = hv.Layout(plots).opts(
        opts.Image(
            clim=(0, 1),
            clabel="Intensity",
        ),
        opts.Layout(
            title=f"{title}{coordinate_string(period=P, excitonic_layer_thickness=t, num_periods=N)}",
        ),
    )

    return layout

# # test
# display(plot_RTA_2D(200, 40, 20, (0, 90), "test\n", include=["LOPC", "CRS_1", "FRS_1"]).opts(opts.Image(frame_width=200)).cols(3))

# display(plot_RTA_2D(200, 40, 20, (10, 50), "test\n", include=["LOPC", "CRS_1", "FRS_1"]).opts(opts.Image(frame_width=200)).cols(3))

#### Plot an enhancement factor.

In [None]:
def plot_ef(
    variable,
    dataset,
    sel=None,
    sel_method=None,
    title="",
    *,
    x="wavelength",
    y=None,
):
    """Plot an enhancement factor, as a curve (1D) or an image (2D).

    Parameters
    ----------
    variable : str
        Name of the variable in ``dataset`` to plot.
    dataset : xarray.Dataset
        Dataset holding the enhancement factor.
    sel : dict, optional
        Coordinates passed to ``.sel`` before plotting; also used in the title.
    sel_method : str, optional
        ``method`` argument for ``.sel`` (e.g. "nearest").
    title : str
        Prefix of the title; the selected coordinates are appended.
    x : str
        Dimension on the x axis.
    y : str, optional
        If given, plot an image with this dimension on the y axis (colour limits
        0.5–1.5); otherwise plot a curve with a dotted reference line at 1.

    Returns
    -------
    HoloViews object (Overlay in the 1D case, Image in the 2D case).
    """
    sel = {} if sel is None else sel
    da = dataset[variable].sel(**sel, method=sel_method).squeeze()
    label = f"{variable} enhancement factor"
    if y is None:
        plot = da.hvplot(x=x, label=label)
        plot *= hv.HLine(1).opts(line_dash="dotted")
    else:
        plot = da.hvplot(
            kind="image",
            x=x,
            y=y,
            label=label,
            clim=(0.5, 1.5),
        )

    # Apply the title to every element type this function can return; previously the
    # 2D (Image) case never received it.
    full_title = f"{title}{coordinate_string(**sel)}"
    plot = plot.opts(
        opts.Curve(title=full_title),
        opts.Overlay(title=full_title),
        opts.Image(title=full_title),
    )

    return plot


# # test
# sel_1 = {"period": 200, "excitonic_layer_thickness": 40, "num_periods": 10, "theta": 30}
# sel_2 = {"period": 200, "excitonic_layer_thickness": 40, "num_periods": 10, "theta": 0}
# sel_3 = {"period": 200, "excitonic_layer_thickness": 40, "num_periods": 10}
# display(
#     (
#         plot_ef("As", restacked_norm_1, sel_1, "nearest", "test\n")
#         + plot_ef("As", restacked_norm_2, sel_2, title="test2\n")
#     ).cols(1)
# )
# display(
#     plot_ef("As", restacked_norm_1, sel_3, title="test3\n", x="theta", y="wavelength").opts(clim=(0, 2), cmap="RdBu_r")
# )

#### Test plot_optimum_over_dim

In [None]:
# foo, bar = plot_optimum_over_dim(restacked_ds.A.sel(theta=0, wavelength=660, method="nearest"), "period", "excitonic_layer_thickness", "num_periods", "max")

In [None]:
# foo, bar = plot_optimum_over_dim(integrate_da(restacked_ds.A, "theta", normalisation=1).sel(wavelength=660, method="nearest"), "period", "excitonic_layer_thickness", "num_periods", "max")

#### Find and plot the min or max over any dimension.

In [None]:
def wrapped_2D_plot(
    variable,
    dataset,
    optimise="max",
    lorentz_line=0,
    window_radius=0,
    theta=0,
    cmap="viridis",
    period_start=None,
    period_stop=None,
    integrate_angle=None,
    extra_plots=("RTA_normal", "RTA_int", "norm_1_normal", "norm_1_int"),
    dim=None,  # automatically assign if dataset recognised
):
    """Map the optimum of ``variable`` and stack supporting plots underneath.

    The data are first reduced to one value per structure: a single angle or
    an angular integral, then a single wavelength or a wavelength integral.
    Next, ``plot_optimum_over_dim`` maps the optimum over ``dim`` against
    excitonic layer thickness and number of periods. Finally, the structure at
    the optimum is shown with the plots requested in ``extra_plots``.

    Parameters
    ----------
    variable : str
        Data variable to optimise, e.g. ``"A"``.
    dataset : str or xarray.Dataset
        The names ``"restacked_ds"``, ``"restacked_norm_1"`` and
        ``"restacked_diff_1"`` select the notebook-level dataset of that name
        (with ``excitonic_layer_thickness == 0`` dropped) and set
        ``dim="period"``. Anything else is indexed directly.
    optimise : str
        ``"max"`` or ``"min"``. Passed to ``plot_optimum_over_dim`` and used in
        the title.
    lorentz_line, window_radius : float
        Passed to ``select_lorentz_line``. With ``window_radius == 0`` the
        selection must leave a single wavelength. Otherwise the selected window
        is integrated over wavelength.
    theta : float
        Angle to select (nearest match), or the lower limit of the angular
        integration when ``integrate_angle`` is given.
    cmap : str
        Colormap for the optimum map.
    period_start, period_stop : float or None
        Period range to keep; ``None`` leaves that end open. If
        ``period_start >= period_stop`` the upper limit is ignored.
    integrate_angle : float or None
        If truthy, integrate over theta from ``theta`` to this value instead of
        selecting a single angle.
    extra_plots : collection of str
        Any of ``"RTA_normal"``, ``"RTA_int"``, ``"norm_1_normal"``,
        ``"norm_1_int"``.
    dim : str or None
        Dimension passed to ``plot_optimum_over_dim``.

    Returns
    -------
    holoviews.Layout
        The requested plots in a single column.
    """
    plots = []

    # Named datasets are looked up lazily, so only the requested one must exist.
    named_datasets = {
        "restacked_ds": lambda: restacked_ds,
        "restacked_norm_1": lambda: restacked_norm_1,
        "restacked_diff_1": lambda: restacked_diff_1,
    }
    dataset_name = str(dataset)
    if dataset_name in named_datasets:
        # the drop_sel is important for avoiding the most common degenerate cases
        dataset = named_datasets[dataset_name]().drop_sel(
            {"excitonic_layer_thickness": 0}
        )
        dim = "period"

    da = dataset[variable]

    if not integrate_angle:
        da = da.sel(theta=theta, method="nearest")
    else:  # integrate_angle must be a float, so that (theta, integrate_angle) is a slice syntax
        da = da.sel(theta=slice(theta, integrate_angle))
        da = integrate_da(da, "theta", weighting=1, normalisation=1)

    # None leaves that end of the range open (comparing None would raise).
    # An empty/inverted range would select no data and break everything
    # downstream, so in that case drop the upper limit.
    period_upper = period_stop
    if (
        period_start is not None
        and period_stop is not None
        and not period_start < period_stop
    ):
        period_upper = None
    da = da.sel(period=slice(period_start, period_upper))
    da = select_lorentz_line(da, lorentz_line=lorentz_line, window_radius=window_radius)

    # positions for lorentz_vlines; 0 is the resonance itself
    vline_locs = [0]

    if window_radius == 0:
        wavelength = float(da.wavelength)
        title = f"{optimise.capitalize()}imum {variable} at {wavelength:.0f} nm"
        if lorentz_line != 0:  # don't put two lines over the resonance
            vline_locs.append(lorentz_line)
    else:
        wavelength_start = float(da.wavelength[0])
        wavelength_stop = float(da.wavelength[-1])
        # normalised integral over the window makes values easier to compare
        da = integrate_da(da, "wavelength", weighting=1, normalisation=1)
        title = (
            f"{optimise.capitalize()}imum integrated {variable} between "
            f"{wavelength_start:.0f} and {wavelength_stop:.0f} nm"
        )
        vline_locs.extend([lorentz_line - window_radius, lorentz_line + window_radius])

    plot_1, optimum_coords = plot_optimum_over_dim(
        da,
        dim=dim,
        x="excitonic_layer_thickness",
        y="num_periods",
        optimise=optimise,
    )

    P = float(optimum_coords["period"])
    t = float(optimum_coords["excitonic_layer_thickness"])
    N = float(optimum_coords["num_periods"])
    try:  # this should work if not integrating over theta
        th = float(optimum_coords["theta"])
    except KeyError:  # theta was integrated out above; pass the range on instead
        th = (theta, integrate_angle)

    lorentz_lines = lorentz_vlines(vline_locs, scale=1e-9, mode="wavelength").opts(
        opts.VLine(line_color=green, line_dash="dotted"),
    )

    # give the resonance line a special colour
    lorentz_lines.VLine.I.opts(opts.VLine(line_color=yellow))

    plot_1.opts(
        opts.QuadMesh(cmap=cmap),
        opts.Points(color="red"),
        opts.Overlay(title=f"{title}\nOptimal period: {P:.0f}"),
    )
    plots.append(plot_1)

    if "RTA_normal" in extra_plots:  # plot RTA at theta=0
        new_plot = plot_RTA(
            period=P, excitonic_layer_thickness=t, num_periods=N, theta=0
        )
        new_plot *= lorentz_lines
        plots.append(new_plot)

    if "RTA_int" in extra_plots:  # plot RTA at theta OR integrating over theta
        new_plot = plot_RTA(
            period=P, excitonic_layer_thickness=t, num_periods=N, theta=th
        )
        new_plot *= lorentz_lines
        plots.append(new_plot)

    if "norm_1_normal" in extra_plots:  # plot enhancement factor at theta=0
        sel = {
            "period": P,
            "excitonic_layer_thickness": t,
            "num_periods": N,
            "theta": 0,
        }
        new_plot = plot_ef(variable="A", dataset=restacked_norm_1, sel=sel)
        new_plot *= lorentz_lines
        plots.append(new_plot)

    if "norm_1_int" in extra_plots:  # plot enhancement factor at theta OR integrating over theta
        if isinstance(th, tuple):
            # Integrating over theta: the integral must be done *before*
            # normalising, so build the enhancement factor from the raw data.
            ds_int = sel_or_integrate(ds, dim="theta", val=th)
            crs_1_int = sel_or_integrate(crs_1, dim="theta", val=th)
            norm = enhancement_factor(
                ds_int,
                ref=crs_1_int,
                common_dim="total_excitonic_thickness",
                method="groupby",
            )
            restacked_norm = restack_plt_to_period(norm)

            sel = {"period": P, "excitonic_layer_thickness": t, "num_periods": N}
            new_plot = plot_ef(variable="A", dataset=restacked_norm, sel=sel)
            new_plot *= lorentz_lines
            sel["theta"] = th
            new_plot = new_plot.opts(opts.Overlay(title=f"{coordinate_string(**sel)}"))
        else:
            sel = {
                "period": P,
                "excitonic_layer_thickness": t,
                "num_periods": N,
                "theta": th,
            }
            new_plot = plot_ef(variable="A", dataset=restacked_norm_1, sel=sel)
            new_plot *= lorentz_lines
        plots.append(new_plot)

    return hv.Layout(plots).cols(1)

#### Plot the E-field, overlaid with the refractive index profile and layer boundaries.

In [None]:
# # sometimes this errors on the first call for some reason
# fdtd = lumapi.FDTD()

# oscillator = LumericalOscillator(fdtd)

# plot_field(
#     680,
#     lumerical_session=fdtd,
#     oscillator=oscillator,
#     ri_lower=1.35,
#     ri_upper=1.6,
#     excitonic_layer_thickness=30,
#     passive_layer_thickness=210,
#     num_periods=10,
# ).opts(opts.VSpan(color='gray'))

# def wrap_plot_field(
#     wavelength, excitonic_layer_thickness, passive_layer_thickness, num_periods
# ):
#     coords = {
#         "λ": wavelength,
#         "Excitonic layer thickness": excitonic_layer_thickness,
#         "Passive layer thickness": passive_layer_thickness,
#         "Number of periods": num_periods,
#     }

#     title = f"{coordinate_string(**coords)}"

#     return plot_field(
#         wavelength=wavelength,
#         lumerical_session=fdtd,
#         oscillator=oscillator,
#         ri_lower=1.35,
#         ri_upper=1.6,
#         excitonic_layer_thickness=excitonic_layer_thickness,
#         passive_layer_thickness=passive_layer_thickness,
#         num_periods=num_periods,
#     ).opts(opts.Curve(title=title, ylim=(0,None)), opts.VSpan(color="gray"))

In [None]:
# # an example of what this can do
# pn.interact(
#     wrap_plot_field,
#     wavelength=(480, 880),
#     excitonic_layer_thickness=(10, 200),
#     passive_layer_thickness=(0, 300),
#     num_periods=(1, 50),
# )

## Refractive index

In [None]:
# Open a Lumerical session only for as long as it takes to evaluate the
# oscillator's refractive index on the simulation frequencies.
with lumapi.FDTD() as fdtd:
    oscillator = LumericalOscillator(session=fdtd)
    oscillator_index = LOPC.LOPC(
        lumerical_session=fdtd,
        oscillator=oscillator,
        **default_oscillator_params,
    ).oscillator.index(frequencies)

### Basic plots

Refractive index at the first and last sampled frequency (the two ends of the simulated wavelength range).

In [None]:
# Refractive index at the first and last sampled frequency
print(*(oscillator_index[i] for i in (0, -1)))

Plot the refractive index of the Lorentz oscillator against wavelength.

In [None]:
(
    (
        hv.Layout(
            complex_elements(
                (wavelengths_in_nanometres, oscillator_index),
                wavelength_dim,
                element=hv.Curve,
                auto_label="group",
                label="Lorentz Oscillator",
            ).values()
        )
        # vertical lines at offsets from the resonance (0 is the resonance itself)
        * lorentz_vlines([-4, -2, -1, 0, 1, 2, 4], scale=1e-9)
    )
    .redim("Curve.Imaginary", y=imag_index_dim)
    .redim("Curve.Real", y=real_index_dim)
    .opts(opts.VLine(line_dash="dotted"), opts.Curve(width=600))
    .cols(1)
)

Plot the refractive index of the Lorentz oscillator against frequency.

In [None]:
(
    (
        hv.Layout(
            complex_elements(
                (frequencies, oscillator_index),
                "f (Hz)",
                element=hv.Curve,
                auto_label="group",
                label="Lorentz Oscillator",
            ).values()
        )
        # same resonance markers as the wavelength plot, placed in frequency
        * lorentz_vlines([-4, -2, -1, 0, 1, 2, 4], mode='frequency')
    )
    .redim("Curve.Imaginary", y=imag_index_dim)
    .redim("Curve.Real", y=real_index_dim)
    .opts(opts.VLine(line_dash="dotted"), opts.Curve(width=600))
    .cols(1)
)

Plot the refractive index of the Lorentz oscillator against wavelength on one axis.

In [None]:
# With auto_label="label", the curves are labelled by their complex_elements
# keys ("Real" / "Imaginary"). The redim specs must therefore use the full
# label; ".Imag" matched nothing.
(
    hv.Overlay(
        complex_elements(
            (wavelengths_in_nanometres, oscillator_index),
            wavelength_dim,
            element=hv.Curve,
            auto_label="label",
            group="Lorentz Oscillator",
        ).values()
    )
    .redim('Curve.Lorentz_Oscillator.Real', y=real_index_dim)
    .redim('Curve.Lorentz_Oscillator.Imaginary', y=imag_index_dim)
    .opts(opts.Curve(width=600))
)

Plot the refractive index of the Lorentz oscillator against wavelength on two axes.

In [None]:
# The curves have group "Lorentz Oscillator" and labels "Real" / "Imaginary".
# Both the redim spec and the hook spec must name group *and* label: a bare
# 'Imaginary' spec is matched against the group and never applies.
(
    hv.Overlay(
        complex_elements(
            (wavelengths_in_nanometres, oscillator_index),
            wavelength_dim,
            element=hv.Curve,
            auto_label="label",
            group="Lorentz Oscillator",
        ).values()
    )
    .redim('Curve.Lorentz_Oscillator.Real', y=real_index_dim)
    .redim('Curve.Lorentz_Oscillator.Imaginary', y=imag_index_dim)
    .opts(
        opts.Curve(width=600),
        # plot_secondary is the hook used in this notebook for a second y axis
        opts.Curve('Lorentz_Oscillator.Imaginary', hooks=[plot_secondary]),
    )
)

#### Measured thylakoid membrane data

In [None]:
# Load the measured real and imaginary refractive index data of the thylakoid
# membrane (tab-separated: wavelength, value; read as float32).
thyl_real_data, thyl_imag_data = (
    np.loadtxt(data_path / filename, delimiter="\t", dtype="f")
    for filename in ("Chl_real.dat", "Chl_img.dat")
)

In [None]:
# Lorentz oscillator index split into n (real part) and kappa (imaginary part)
osc_n_curve = hv.Curve(
    (wavelengths_in_nanometres, oscillator_index.real),
    label="Lorentz Oscillator, n",
    vdims=[real_index_dim],
    kdims=[wavelength_dim],
)
osc_k_curve = hv.Curve(
    (wavelengths_in_nanometres, oscillator_index.imag),
    label="Lorentz Oscillator, ϰ",
    vdims=[imag_index_dim],
    kdims=[wavelength_dim],
)

In [None]:
# Measured thylakoid membrane index; kappa goes on the secondary axis
thyl_n_curve = hv.Curve(
    (thyl_real_data[:, 0], thyl_real_data[:, 1]),
    label="Thylakoid Membrane, n",
    vdims=[real_index_dim],
    kdims=[wavelength_dim],
)
thyl_k_curve = hv.Curve(
    (thyl_imag_data[:, 0], thyl_imag_data[:, 1]),
    label="Thylakoid Membrane, ϰ",
    vdims=[imag_index_dim],
    kdims=[wavelength_dim],
).opts(hooks=[plot_secondary])

In [None]:
# Group the curves by component: n on the primary axis, kappa (dashed) on the
# secondary axis via plot_secondary.
# NOTE(review): .opts() on an overlay appears to update the constituent curves
# in place, so osc_k_curve / thyl_k_curve themselves pick up this styling.
n_curves = (osc_n_curve * thyl_n_curve)
k_curves = (osc_k_curve * thyl_k_curve).opts(opts.Curve(line_dash='dashed', hooks=[plot_secondary]))

In [None]:
# Even though these symbols aren't used, the options are applied to the underlying curves!
# i.e. this recolours osc_n_curve / osc_k_curve wherever else they appear, so
# the result depends on this cell having run before the figure is built.
lo_curves = (osc_n_curve * osc_k_curve).opts(opts.Curve(color=blue))

In [None]:
# Like lo_curves: used mainly for its side effect of recolouring
# thyl_n_curve / thyl_k_curve in place.
tm_curves = (thyl_n_curve * thyl_k_curve).opts(opts.Curve(color=red))

In [None]:
# Preview the membrane curves; clone=True styles a copy, not tm_curves itself
tm_curves.opts(
    opts.Overlay(
        legend_position="top_right",
        legend_opts={"background_fill_alpha": 0.5},
    ),
    opts.Curve(width=800),
    clone=True,
)

In [None]:
# Lorentz oscillator vs thylakoid membrane index, with lines at ±1 from the resonance
fig = (n_curves * k_curves * lorentz_vlines([-1, 1], scale=1e-9)).opts(
    opts.Curve(width=1000),
    opts.VLine(line_color=green, line_dash="dotted", line_width=2),
    opts.Overlay(
        width=800,
        legend_position="right",
        legend_opts={"background_fill_alpha": 0.9},
        fontscale=1,  # <-- this is because secondary axes in holoviews are bad. Try changing it and see!
        # Key on the curve's actual label so a rename can't leave the mapping
        # stale (the old key "Thylakoid Membrane, Real" matched no curve).
        # NOTE(review): confirm the backend renders the LaTeX delimiters in legends.
        legend_labels={thyl_n_curve.label: r"\(n_{TM}\)"},
    ),
    clone=True,
)
fig

In [None]:
# Export the comparison figure as PNG (no toolbar) into the LaTeX figure directory
hv.save(fig, fig_path / "LO_TM_compare", fmt="png", toolbar=None)

#### Interpolated thylakoid membrane data

In [None]:
# Interpolate the measured refractive index data onto the simulation wavelengths
thyl_real, thyl_imag = (
    np.interp(wavelengths_in_nanometres, measured[:, 0], measured[:, 1])
    for measured in (thyl_real_data, thyl_imag_data)
)

# Combine into complex refractive index of the membrane
n_M = thyl_real + 1j * thyl_imag

In [None]:
interp_curves = complex_elements(
    (wavelengths_in_nanometres, n_M),
    auto_label="group",
    label="Thylakoid membrane (interpolated)",
)

# Interpolated curves as lines over the measured points; the imaginary pair is
# routed to the secondary axis via plot_secondary.
(
    interp_curves["Real"]
    * hv.Scatter(thyl_n_curve)
    * (
        interp_curves["Imaginary"].opts(hooks=[plot_secondary])
        * hv.Scatter(thyl_k_curve).opts(hooks=[plot_secondary])
    )
).opts(
    opts.Overlay(legend_position="right"),
    opts.Scatter(color=red),
    opts.Curve(width=800, color=blue),
)

In [None]:
# Interpolate using InterpolatedIndex class
from multilayer_simulator.material import InterpolatedIndex

In [None]:
from scipy import interpolate

In [None]:
# Convert the (wavelength-ordered) data to frequency, reversing the rows first
# (presumably so the frequencies come out ascending).
real_data_frequencies = convert_wavelength_and_frequency(thyl_real_data[::-1, 0] * 1e-9)
real_data_indexes = thyl_real_data[::-1, 1]

# Outside the measured range, hold the first/last measured value instead of raising
thyl_real_index = InterpolatedIndex(
    real_data_frequencies,
    real_data_indexes,
    interpolate.interp1d,
    interp_kwargs={
        'bounds_error': False,
        'fill_value': (real_data_indexes[0], real_data_indexes[-1]),
    },
)

In [None]:
# Sanity check: evaluate the interpolant at the measured points, back in wavelength order
hv.Curve(
    (
        thyl_real_data[:, 0],
        thyl_real_index._index_function_real(real_data_frequencies[::-1]),
    )
)

In [None]:
# Same conversion as for the real part, for the measured kappa data
imag_data_frequencies = convert_wavelength_and_frequency(thyl_imag_data[::-1, 0] * 1e-9)
imag_data_indexes = thyl_imag_data[::-1, 1]

# Outside the measured range, hold the first/last measured value instead of raising
thyl_imag_index = InterpolatedIndex.from_scipy_method(
    imag_data_frequencies,
    imag_data_indexes,
    'interp1d',
    interp_kwargs={
        'bounds_error': False,
        'fill_value': (imag_data_indexes[0], imag_data_indexes[-1]),
    },
)

In [None]:
# Sanity check: evaluate the kappa interpolant at the measured points, back in wavelength order
hv.Curve(
    (
        thyl_imag_data[:, 0],
        thyl_imag_index._index_function_real(imag_data_frequencies[::-1]),
    )
)

In [None]:
def thyl_complex_index(frequencies, component, **kwargs):
    """Return the interpolated complex refractive index n + i*kappa of the membrane.

    ``component`` and ``kwargs`` are accepted but unused, presumably to fit the
    call signature expected by CallableIndex (TODO confirm).
    """
    n = thyl_real_index._index_function_real(frequencies)
    kappa = thyl_imag_index._index_function_real(frequencies)
    return n + 1j * kappa

In [None]:
from multilayer_simulator.material import CallableIndex

# Wrap thyl_complex_index so the membrane index can be evaluated with .index(frequencies)
thyl_index = CallableIndex(thyl_complex_index)

In [None]:
# Real and imaginary parts of the interpolated membrane index on one set of axes
hv.Overlay(
    complex_elements(
        (wavelengths_in_nanometres, thyl_index.index(frequencies)),
        auto_label="label",
    ).values()
).opts(opts.Curve(width=600))

In [None]:
# One panel per component: interpolated curve overlaid with the measured points
index_plot = hv.NdLayout(
    complex_elements(
        (wavelengths_in_nanometres, thyl_index.index(frequencies)), label="Interpolated"
    )
) * hv.NdLayout(
    {
        component: hv.Scatter((measured[:, 0], measured[:, 1]), label="Measured")
        for component, measured in (
            ("Real", thyl_real_data),
            ("Imaginary", thyl_imag_data),
        )
    }
)

In [None]:
index_plot.opts(
    opts.Overlay(legend_position='right'),
    opts.Scatter(color=red),
    opts.Curve(width=600),
).cols(1)

## Calculate $\omega_0$ for a wavelength of 680 nm

In [None]:
# omega_0 = 2*pi*f for a 680 nm wavelength (rad/s if the helper returns Hz)
print("{:e}".format(2 * np.pi * convert_wavelength_and_frequency(680e-9)))

# Cleanup

In [None]:
# hv.archive.export()