Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 106 additions & 52 deletions src/axiomatic/pic_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np # type: ignore
import iklayout # type: ignore
import matplotlib.pyplot as plt # type: ignore
from ipywidgets import interactive, IntSlider # type: ignore
import plotly.graph_objects as go # type: ignore
from typing import List, Optional, Tuple, Dict, Set

from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation
Expand Down Expand Up @@ -106,11 +106,10 @@ def plot_interactive_spectra(
spectra: List[List[List[float]]],
wavelengths: List[float],
spectrum_labels: Optional[List[str]] = None,
slider_index: Optional[List[int]] = None,
vlines: Optional[List[float]] = None,
hlines: Optional[List[float]] = None,
):
"""
""""
Creates an interactive plot of spectra with a slider to select different indices.
Parameters:
-----------
Expand All @@ -119,63 +118,118 @@ def plot_interactive_spectra(
corresponding to the transmission of a single wavelength.
wavelengths : list of float
A list of wavelength values corresponding to the x-axis of the plot.
slider_index : list of int, optional
A list of indices for the slider. Defaults to range(len(spectra[0])).
vlines : list of float, optional
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
hlines : list of float, optional
A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
Returns:
--------
ipywidgets.widgets.interaction.interactive
An interactive widget that allows the user to select different indices using a slider.
Notes:
------
- The function uses matplotlib for plotting and ipywidgets for creating the interactive
slider.
- The y-axis limits are fixed based on the global minimum and maximum values across all
spectra.
- Vertical and horizontal lines can be added to the plot using the `vlines` and `hlines`
parameters.
"""
# Calculate global y-limits across all arrays
y_min = min(min(min(arr2) for arr2 in arr1) for arr1 in spectra)
y_max = max(max(max(arr2) for arr2 in arr1) for arr1 in spectra)
if hlines:
y_min = min(hlines + [y_min])*0.95
y_max = max(hlines + [y_max])*1.05

slider_index = slider_index or list(range(len(spectra[0])))
spectrum_labels = spectrum_labels or [f"Spectrum {i}" for i in range(len(spectra))]
vlines = vlines or []
hlines = hlines or []

# Function to update the plot
def plot_array(index=0):
plt.close("all")
plt.figure(figsize=(8, 4))
for i, array in enumerate(spectra):
plt.plot(wavelengths, array[index], lw=2, label=spectrum_labels[i])
for x_val in vlines:
plt.axvline(
x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})"
) # Add vertical line
for y_val in hlines:
plt.axhline(
y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})"
) # Add vertical line
plt.title(f"Iteration: {index}")
plt.xlabel("X")
plt.ylabel("Y")
plt.ylim(y_min, y_max) # Fix the y-limits
plt.legend()
plt.grid(True)
plt.show()
# Defaults
if spectrum_labels is None:
spectrum_labels = [f"Spectrum {i}" for i in range(len(spectra))]
if vlines is None:
vlines = []
if hlines is None:
hlines = []

# Adjust y-axis range
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
y_min = min(all_vals)
y_max = max(all_vals)
if hlines:
y_min = min(hlines + [y_min]) * 0.95
y_max = max(hlines + [y_max]) * 1.05

# Create hlines and vlines
shapes = []
for xv in vlines:
shapes.append(dict(
type="line",
xref="x", x0=xv, x1=xv,
yref="paper", y0=0, y1=1,
line=dict(color="red", dash="dash")
))
for yh in hlines:
shapes.append(dict(
type="line",
xref="paper", x0=0, x1=1,
yref="y", y0=yh, y1=yh,
line=dict(color="red", dash="dash")
))


# Create frames for each index
slider_index = list(range(len(spectra[0])))
fig = go.Figure()

# Build initial figure for immediate display
init_idx = slider_index[0]
for i, spec in enumerate(spectra):
fig.add_trace(
go.Scatter(
x=wavelengths,
y=spec[init_idx],
mode="lines",
name=spectrum_labels[i]
)
)
# Build frames for animation
frames = []
for idx in slider_index:
frame_data = []
for i, spec in enumerate(spectra):
frame_data.append(
go.Scatter(
x=wavelengths,
y=spec[idx],
mode="lines",
name=spectrum_labels[i]
)
)
frames.append(
go.Frame(
data=frame_data,
name=str(idx),
)
)

slider = IntSlider(
value=0, min=0, max=len(spectra[0]) - 1, step=1, description="Index"
fig.frames = frames


# Create transition steps
steps = []
for idx in slider_index:
steps.append(dict(
method="animate",
args=[
[str(idx)],
{
"mode": "immediate",
"frame": {"duration": 0, "redraw": True},
"transition": {"duration": 0}
}
],
label=str(idx),
))

# Create the slider
sliders = [dict(
active=0,
currentvalue={"prefix": "Index: "},
pad={"t": 50},
steps=steps
)]

# Create the layout
fig.update_layout(
xaxis_title="Wavelength",
yaxis_title="Transmission",
shapes=shapes,
sliders=sliders,
yaxis=dict(range=[y_min, y_max]),
)
return interactive(plot_array, index=slider)

fig.show()


def plot_parameter_history(parameters: List[Parameter], parameter_history: List[dict]):
Expand Down
Loading