In [22]:
%matplotlib widget

from pathlib import Path
import dataclasses

import matplotlib.pyplot as plt
import numpy as np
import numpy.random
import scipy.signal, scipy.fft
import h5py
import ipywidgets as widgets


In [23]:
@dataclasses.dataclass(eq=False)
class BES_ELM_Labeling_App:
    data_hdf5_file: str|Path = '/home/smithdr/ml/elm_data/step_4_shot_partial_data/data_v1.hdf5'
    input_csv_file: str|Path = 'input.csv'
    output_csv_file: str|Path = 'output.csv'

    def __post_init__(self):

        self.data_hdf5_file = Path(self.data_hdf5_file)
        self.input_csv_file = Path(self.input_csv_file)
        self.output_csv_file = Path(self.output_csv_file)

        # prepare MPL figure
        with plt.ioff():
            self.fig, self.axes = plt.subplots(nrows=5, ncols=1, sharex=True, figsize=(8.5,6.5))
        self.axes[-1].set_xlabel('Time (ms)')
        self.axes[-1].set_ylabel('Frequency (kHz)')
        self.fig.suptitle('Shot')
        self.fig.tight_layout(pad=0.5)
        self.canvas = self.fig.canvas
        self.canvas.header_visible = False
        self.canvas.footer_visible = False
        self.canvas.toolbar_visible = False
        self.canvas.mpl_connect('button_press_event', self.on_mouse_click_callback)
        self.canvas.mpl_connect('motion_notify_event', self.on_mouse_move_callback)
        self.toolbar = self.canvas.toolbar

        self.shot = None
        self.next_shot = self._yield_shot()
        self.nfft = 128
        self.mouse_lines = []
        self.marker_lines = []
        self.click_start = None
        self.click_end = None
        self.vspans = [ [] for ax in self.axes ]

        # prepare ipywidgets
        def_layout = {'width':'80%', 'margin':'5px'}
        self.next_shot_button = widgets.Button(description='Load new shot', layout=def_layout)
        self.next_shot_button.on_click(self.regenerate_figure)
        self.reset_button = widgets.Button(description='Reset view', layout=def_layout)
        self.reset_button.on_click(self.toolbar.home)
        self.back_button = widgets.Button(description='Previous view', layout=def_layout)
        self.back_button.on_click(self.toolbar.back)
        self.zoom_in_button = widgets.Button(description='Zoom in', layout=def_layout)
        self.zoom_in_button.on_click(self.zoom_in_callback)
        self.zoom_out_button = widgets.Button(description='Zoom out', layout=def_layout)
        self.zoom_out_button.on_click(self.zoom_out_callback)
        self.pan_button = widgets.Button(description='Pan view', layout=def_layout)
        self.pan_button.on_click(self.pan_callback)
        self.pan_left_button = widgets.Button(description='Pan left', layout=def_layout)
        self.pan_left_button.on_click(self.pan_left_callback)
        self.pan_right_button = widgets.Button(description='Pan right', layout=def_layout)
        self.pan_right_button.on_click(self.pan_right_callback)
        self.zoom_button = widgets.Button(description='Zoom selection', layout=def_layout)
        self.zoom_button.on_click(self.zoom_callback)
        self.autoy_button = widgets.Button(description='Autoscale y', layout=def_layout)
        self.autoy_button.on_click(self.autoy_callback)
        self.undo_button = widgets.Button(
            description='Undo selection', 
            button_style='danger',
            layout={'width':'80%', 'margin':'5px 5px 15px 5px'},
        )
        self.undo_button.on_click(self.undo_callback)
        self.mode_selection = widgets.RadioButtons(
            options=['Manual', 'Auto'],
            description='Selection mode:',
            layout={'width': '80%'}
        )
        self.mode_selection.observe(self.selection_mode_callback, names='index')
        self.segment_button = widgets.Button(
            description='Segment ELM(s)', 
            layout=def_layout,
            disabled=True,
        )
        self.segment_button.on_click(self.segment_callback)
        self.status_label = widgets.Label(value='State: GUI launched')
        self.controls = widgets.VBox(
            layout = {'justify_content':'center', 'align_items':'center'},
            children = [
                self.next_shot_button,
                self.reset_button,
                self.back_button,
                self.pan_button,
                self.pan_left_button,
                self.pan_right_button,
                self.zoom_button,
                self.zoom_in_button,
                self.zoom_out_button,
                self.autoy_button,
                self.undo_button,
                self.mode_selection,
                self.segment_button,
            ]
        )

        # app states
        self.is_this_click_elm_cycle_start = True
        self.is_pan_active = False
        self.is_zoom_active = False

        # load first shot
        self.regenerate_figure()

    def _yield_shot(self):
        with h5py.File(self.data_hdf5_file) as h5root:
            shots = [int(group_name) for group_name in h5root if not group_name.startswith('config')]
        numpy.random.default_rng().shuffle(shots)
        for shot in shots:
            yield shot

    def pan_callback(self, *_):
        self.toolbar.pan()
        self.status()

    def zoom_callback(self, *_):
        self.toolbar.zoom()
        self.status()

    def selection_mode_callback(self, change):
        self.segment_button.button_style = 'info' if change.new else ''
        self.segment_button.disabled = not bool(change.new)

    def zoom_in_callback(self, *_):
        xlim = self.axes[-1].get_xlim()
        x_middle, x_range = np.mean(xlim), xlim[1]-xlim[0]
        x_range /= 2
        self.axes[-1].set_xlim(np.array([-1,1])*x_range/2 + x_middle)

    def zoom_out_callback(self, *_):
        xlim = self.axes[-1].get_xlim()
        x_middle, x_range = np.mean(xlim), xlim[1]-xlim[0]
        x_range *= 2
        self.axes[-1].set_xlim(np.array([-1,1])*x_range/2 + x_middle)

    def pan_right_callback(self, *_):
        xlim = self.axes[-1].get_xlim()
        x_middle, x_range = np.mean(xlim), xlim[1]-xlim[0]
        x_middle += x_range/2
        self.axes[-1].set_xlim(np.array([-1,1])*x_range/2 + x_middle)

    def pan_left_callback(self, *_):
        xlim = self.axes[-1].get_xlim()
        x_middle, x_range = np.mean(xlim), xlim[1]-xlim[0]
        x_middle -= x_range/2
        self.axes[-1].set_xlim(np.array([-1,1])*x_range/2 + x_middle)

    def autoy_callback(self, b):
        for ax in self.axes[:-1]:
            ax.relim(visible_only=True)
            ax.autoscale(axis='y', enable=True)
        self.axes[-1].set_ylim(0,100)

    def segment_callback(self, b):
        pass

    def on_mouse_click_callback(self, mouse_event):
        if self.toolbar.mode.startswith(('pan','zoom')) or len(self.mouse_lines)==0:
            return
        if self.is_this_click_elm_cycle_start:
            self.click_start = mouse_event.xdata
            for line in self.marker_lines:
                line.set_xdata([mouse_event.xdata, mouse_event.xdata])
        else:
            self.click_end = mouse_event.xdata
            for line in self.marker_lines:
                line.set_xdata([np.nan,np.nan])
            for i, ax in enumerate(self.axes):
                self.vspans[i].append(
                    ax.axvspan(xmin=self.click_start, xmax=self.click_end, alpha=0.1, color='m')
                )
            self.click_start = self.click_end = None
        self.is_this_click_elm_cycle_start = not self.is_this_click_elm_cycle_start
        self.status()

    def undo_callback(self, *args):
        for line, vspans in zip(self.marker_lines, self.vspans):
            if np.nan in line.get_xdata():
                vspan = vspans.pop(-1)
                vspan.remove()
            else:
                line.set_xdata([np.nan,np.nan])
        self.click_start = self.click_end = None
        self.is_this_click_elm_cycle_start = True
        self.status()

    def on_mouse_move_callback(self, mouse_event):
        if self.toolbar.mode.startswith(('pan','zoom')):
            return
        for line in self.mouse_lines:
            line.set_xdata([mouse_event.xdata, mouse_event.xdata])

    def status(self):
        self.zoom_button.button_style = 'info' if self.toolbar.mode.startswith('zoom') else ''
        self.pan_button.button_style = 'info' if self.toolbar.mode.startswith('pan') else ''
        if self.toolbar.mode.startswith('pan'):
            self.status_label.value = 'State: pan mode'
        elif self.toolbar.mode.startswith('zoom'):
            self.status_label.value = 'State: zoom mode'
        elif self.is_this_click_elm_cycle_start:
            self.status_label.value = 'State: click ELM cycle start'
        else:
            self.status_label.value = 'State: click ELM cycle end'

    def print_group(self):
        with h5py.File(self.data_hdf5_file) as h5root:
            group = h5root[str(self.shot)]
            for key, value in group.attrs.items():
                if isinstance(value, np.ndarray):
                    print(f'  Attribute {key}:', value.shape, value.dtype)
                else:
                    print(f'  Attribute {key}:', value)
            for key, value in group.items():
                if 'time' in key:
                    print(f'  Dataset {key}:', value.shape, value.dtype, f'Rate (kHz): {1/np.diff(value[:101]).mean():.1f}')
                else:
                    print(f'  Dataset {key}:', value.shape, value.dtype)

    def regenerate_figure(self, *args):
        self.shot = next(self.next_shot)
        self.fig.suptitle(f'Shot {self.shot}')
        self.status_label.value = 'State: loading shot...'
        with h5py.File(self.data_hdf5_file) as h5root:
            group = h5root[str(self.shot)]
            # self.print_group()
            self.mouse_lines = []
            self.vspans = [ [] for ax in self.axes ]
            for ax in self.axes:
                ax.clear()
            self.axes[0].plot(np.array(group['ip_time']), np.array(group['ip'])/1e6, label='Ip (MA)')
            self.axes[0].plot(np.array(group['pinj_time']), np.array(group['pinj'])/1e3/10, label='PINJ/10 (MW)')
            self.axes[0].plot(np.array(group['pinj_time']), np.array(group['pinj_15r'])/1e6, label='PINJ_15R (MW)')
            self.axes[0].plot(np.array(group['pinj_time']), np.array(group['pinj_15l'])/1e6, label='PINJ_15L (MW)')
            self.axes[0].set_ylabel('Ip and Pnbi')
            denv_time = np.array(group['denv3f_time'])
            denv_names = ['denv3f']
            denv_signals = np.array([group[denv_channel] for denv_channel in denv_names])
            self.axes[1].plot(denv_time, denv_signals.T, label=denv_names[0])
            self.axes[1].set_ylabel('Line avg ne')
            fs_time = np.array(group['FS_time'])
            fs_names = ['FS03','FS04','FS05']
            # fs_names = ['FS03','FS04']
            fs_signals = np.array([group[fs_channel] for fs_channel in fs_names])
            self.axes[2].plot(fs_time, fs_signals.T, label=fs_names,)
            self.axes[2].set_ylabel('Da (au)')
            bes_time = np.array(group['bes_time'])
            bes_signals = np.array(group['bes_signals'])
            self.axes[3].plot(
                bes_time, 
                bes_signals[1::-1,:].T,
                label=['BES 23', 'BES 21'],
            )
            self.axes[3].set_ylabel('BES (V)')
            with scipy.fft.set_workers(4):
                f, _, Sxx = scipy.signal.spectrogram(  # f in kHz
                    x=bes_signals[1,:],
                    fs=(bes_time.size-1) / (bes_time[-1]-bes_time[0]),  # kHz
                    window='hann',
                    nperseg=self.nfft,
                    noverlap=self.nfft/2,
                )
            Sxx = np.log10(Sxx+1e-9)
            self.axes[4].imshow(
                Sxx, 
                vmax=Sxx.max()-3,
                vmin=Sxx.max()-6,
                aspect='auto',
                origin='lower',
                extent=[bes_time[0], bes_time[-1], f[0], f[-1]],
            )
            self.axes[4].set_ylim(0, 100)
            self.axes[4].set_ylabel('Frequency (kHz)')
            self.axes[4].set_xlabel('Time (ms)')
            self.mouse_lines = [ax.axvline(x=np.nan, ls='--', c='m') for ax in self.axes]
            self.marker_lines = [ax.axvline(x=np.nan, c='m') for ax in self.axes]
            ip_end_index = np.flatnonzero(np.array(group['ip'])>100e3)[-1]
            t_end = group['ip_time'][ip_end_index] + 500
            self.axes[0].set_xlim([0, t_end])
            for ax in self.axes[:-1]:
                ax.legend(fontsize='small', loc='upper right', labelspacing=0.2)
            self.status()


In [24]:
# Start from a clean slate so re-running this cell does not leak old figures or widgets.
plt.close('all')
widgets.Widget.close_all()

app = BES_ELM_Labeling_App()

# Controls in a 160px left sidebar, interactive figure in the centre, status line as footer.
widgets.AppLayout(
    center=app.fig.canvas,
    left_sidebar=app.controls,
    footer=app.status_label,
    pane_heights=[0, 1, '40px'],
    pane_widths=['160px', 1, 0],
)

AppLayout(children=(Label(value='State: click ELM cycle start', layout=Layout(grid_area='footer')), VBox(child…