diff --git a/src/axiomatic/pic_helpers.py b/src/axiomatic/pic_helpers.py index 9d3e56a..4ce530d 100644 --- a/src/axiomatic/pic_helpers.py +++ b/src/axiomatic/pic_helpers.py @@ -2,7 +2,7 @@ import numpy as np # type: ignore import iklayout # type: ignore import matplotlib.pyplot as plt # type: ignore -from ipywidgets import interactive, IntSlider # type: ignore +import plotly.graph_objects as go # type: ignore from typing import List, Optional, Tuple, Dict, Set from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation @@ -106,11 +106,10 @@ def plot_interactive_spectra( spectra: List[List[List[float]]], wavelengths: List[float], spectrum_labels: Optional[List[str]] = None, - slider_index: Optional[List[int]] = None, vlines: Optional[List[float]] = None, hlines: Optional[List[float]] = None, ): - """ + """" Creates an interactive plot of spectra with a slider to select different indices. Parameters: ----------- @@ -119,63 +118,118 @@ def plot_interactive_spectra( corresponding to the transmission of a single wavelength. wavelengths : list of float A list of wavelength values corresponding to the x-axis of the plot. - slider_index : list of int, optional - A list of indices for the slider. Defaults to range(len(spectra[0])). vlines : list of float, optional A list of x-values where vertical lines should be drawn. Defaults to an empty list. hlines : list of float, optional A list of y-values where horizontal lines should be drawn. Defaults to an empty list. - Returns: - -------- - ipywidgets.widgets.interaction.interactive - An interactive widget that allows the user to select different indices using a slider. - Notes: - ------ - - The function uses matplotlib for plotting and ipywidgets for creating the interactive - slider. - - The y-axis limits are fixed based on the global minimum and maximum values across all - spectra. - - Vertical and horizontal lines can be added to the plot using the `vlines` and `hlines` - parameters. """ - # Calculate global y-limits across all arrays - y_min = min(min(min(arr2) for arr2 in arr1) for arr1 in spectra) - y_max = max(max(max(arr2) for arr2 in arr1) for arr1 in spectra) - if hlines: - y_min = min(hlines + [y_min])*0.95 - y_max = max(hlines + [y_max])*1.05 - - slider_index = slider_index or list(range(len(spectra[0]))) - spectrum_labels = spectrum_labels or [f"Spectrum {i}" for i in range(len(spectra))] - vlines = vlines or [] - hlines = hlines or [] - # Function to update the plot - def plot_array(index=0): - plt.close("all") - plt.figure(figsize=(8, 4)) - for i, array in enumerate(spectra): - plt.plot(wavelengths, array[index], lw=2, label=spectrum_labels[i]) - for x_val in vlines: - plt.axvline( - x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})" - ) # Add vertical line - for y_val in hlines: - plt.axhline( - y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})" - ) # Add vertical line - plt.title(f"Iteration: {index}") - plt.xlabel("X") - plt.ylabel("Y") - plt.ylim(y_min, y_max) # Fix the y-limits - plt.legend() - plt.grid(True) - plt.show() + # Defaults + if spectrum_labels is None: + spectrum_labels = [f"Spectrum {i}" for i in range(len(spectra))] + if vlines is None: + vlines = [] + if hlines is None: + hlines = [] + + # Adjust y-axis range + all_vals = [val for spec in spectra for iteration in spec for val in iteration] + y_min = min(all_vals) + y_max = max(all_vals) + if hlines: + y_min = min(hlines + [y_min]) * 0.95 + y_max = max(hlines + [y_max]) * 1.05 + + # Create hlines and vlines + shapes = [] + for xv in vlines: + shapes.append(dict( + type="line", + xref="x", x0=xv, x1=xv, + yref="paper", y0=0, y1=1, + line=dict(color="red", dash="dash") + )) + for yh in hlines: + shapes.append(dict( + type="line", + xref="paper", x0=0, x1=1, + yref="y", y0=yh, y1=yh, + line=dict(color="red", dash="dash") + )) + + + # Create frames for each index + slider_index = list(range(len(spectra[0]))) + fig = go.Figure() + + # Build initial figure for immediate display + init_idx = slider_index[0] + for i, spec in enumerate(spectra): + fig.add_trace( + go.Scatter( + x=wavelengths, + y=spec[init_idx], + mode="lines", + name=spectrum_labels[i] + ) + ) + # Build frames for animation + frames = [] + for idx in slider_index: + frame_data = [] + for i, spec in enumerate(spectra): + frame_data.append( + go.Scatter( + x=wavelengths, + y=spec[idx], + mode="lines", + name=spectrum_labels[i] + ) + ) + frames.append( + go.Frame( + data=frame_data, + name=str(idx), + ) + ) - slider = IntSlider( - value=0, min=0, max=len(spectra[0]) - 1, step=1, description="Index" + fig.frames = frames + + + # Create transition steps + steps = [] + for idx in slider_index: + steps.append(dict( + method="animate", + args=[ + [str(idx)], + { + "mode": "immediate", + "frame": {"duration": 0, "redraw": True}, + "transition": {"duration": 0} + } + ], + label=str(idx), + )) + + # Create the slider + sliders = [dict( + active=0, + currentvalue={"prefix": "Index: "}, + pad={"t": 50}, + steps=steps + )] + + # Create the layout + fig.update_layout( + xaxis_title="Wavelength", + yaxis_title="Transmission", + shapes=shapes, + sliders=sliders, + yaxis=dict(range=[y_min, y_max]), ) - return interactive(plot_array, index=slider) + + fig.show() def plot_parameter_history(parameters: List[Parameter], parameter_history: List[dict]):