In [None]:
# Imports
import time
import copy
import yaml
import numpy as np
import pandas as pd
from enum import Enum
import os

import panel as pn
import param

import taurus
import PyTango

from online_data.scan_monitor.scan_acquisition import callback as scanmon_callback
from online_data.scan_monitor.scan_acquisition import initialize_scan_monitor
from online_data.scan_monitor.scan_acquisition import unregister_all as scanmon_unregister_all

from online_data.scan_monitor.utils import getDeviceNamesByClass
from taurus.core.util.codecs import CodecFactory

from peakutils import baseline

import holoviews as hv
from holoviews import opts
from holoviews import streams
from holoviews.streams import Pipe
from bokeh.plotting import show
from holoviews import opts

hv.extension('bokeh')
pn.extension('texteditor')
#pn.extension()

In [None]:
# Start the scan monitor; functions decorated with @scanmon_callback (see
# on_scandata_received) are presumably invoked with the incoming scan messages — confirm.
# The returned handle is kept in scan_jobs.
scan_jobs = initialize_scan_monitor()

# Hardware Settings #

In [None]:
TIME_SCALING_FACTOR = 1e-3 # sample index -> time in microseconds (assumes 1 ns per digitizer sample — TODO confirm)

# Base Settings #

In [None]:
# Use a logarithmic y axis for the ToF plot (passed as logy= in make_tof_histo_plot).
TOF_YLOG = False

## Globals ##

In [None]:
# x axis of the ToF plot: sample index scaled to time; replaced by masses once a
# mass calibration is applied.
digitization_pts = np.arange(2000, dtype = float)*TIME_SCALING_FACTOR


mass_calibration_data = None # otherwise it is tuple(A, t0), m = A*(t - t0)**2
mass_calibration_pairs = None # otherwise it is dict with keys 't1', 'm1', 't2', 'm2' (the two calibration points)

settings_filename = 'digitizer_multitab_settings'  # read/written as '<settings_filename>.yaml'

class States(Enum):
    """Acquisition state of the GUI."""
    ON = 1        # not scanning
    SCANNING = 2  # set when a 'data_desc' message announces a scan

scan_info = {}   # metadata of the current scan, filled from the 'data_desc' message
scan_step = 0

State = States.ON

# One record per acquired point: {roi_name: {'tof_roi', 'roi_tof_histo', 'roi_color'}}.
scan_data = []
scan_table = pd.DataFrame()        # per-point movable positions and ROI integrals
df_mov_centroids = pd.DataFrame()  # per-ROI centroid of each movable

# 'full_range' is the built-in ROI that holds the whole trace. Its histogram uses the
# same 'roi_tof_histo' key that every other cell reads and writes.
full_range_roi_init_data = {
    'full_range':{
        'roi_tof_histo': np.zeros((len(digitization_pts)), dtype=float),
        'tof_roi': (0,-1),
        'roi_color': '#99ef78',
    }
}

scan_data.append(full_range_roi_init_data)  # always at least one record

# Streams feeding the DynamicMap plots: the scan table and the latest ToF trace.
pipe_roi_scan = Pipe(data=pd.DataFrame())
pipe_tof = Pipe(data=[])

pipe_roi_scan.send(scan_table)
pipe_tof.send(np.zeros((len(digitization_pts)),) )


## Init data sources ##

In [None]:
# Connect to the first device of Tango class 'MacroServer'.
mcs = taurus.Device(getDeviceNamesByClass('MacroServer')[0])

# Environment[0] names the encoding used to decode the whole Environment value
# (NOTE(review): presumably a (format, payload) pair — confirm).
dec_format = mcs.Environment[0]
codec = CodecFactory().getCodec(dec_format)
# The decoded environment's ['new']["ActiveMntGrp"] entry names the active measurement group.
active_mg=taurus.Device(codec.decode(mcs.Environment)[1]['new']["ActiveMntGrp"])
# Full names of every element of the active measurement group are the selectable data sources.
data_sources = [taurus.Device(alias).fullname for alias in active_mg.ElementList]

data_source_select_widget = pn.widgets.Select(name='Data source', options=data_sources)

## Utilities ##

In [None]:
def get_movable_label():
    """Return the x-axis label for scan plots.

    While a scan is running this is the name of the first scanned movable;
    otherwise the generic label 'scan step'.
    """
    if State != States.SCANNING:
        return 'scan step'
    return scan_info['movable_names'][0]

# Mass calibration functions ##########
# Defined in this cell because load_settings() below calls mass_calibrate() when the cell runs.

def calculate_mass_from_tof(time_pts, calibration=None):
    """Convert time-of-flight values to masses with m = A * (t - t0)**2.

    Parameters
    ----------
    time_pts : float or numpy.ndarray
        Time-of-flight value(s), in the units the calibration was made in.
    calibration : tuple(A, t0), optional
        Calibration constants to apply. Defaults to the global
        ``mass_calibration_data``.

    Returns
    -------
    The masses, or ``time_pts`` unchanged when no calibration is available.
    """
    if calibration is None:
        calibration = mass_calibration_data
    if calibration is not None:
        a_coeff, t0 = calibration
        return a_coeff * pow((time_pts - t0), 2)
    return time_pts

def mass_calibrate(t1,m1,t2,m2):
    """Fit m = A * (t - t0)**2 through two (time, mass) points and rebuild the mass axis.

    Sets the global ``mass_calibration_data`` to ``(A, t0)``. It then recomputes
    ``digitization_pts`` as calibrated masses. Samples before the parabola minimum
    (t < t0) get negated masses, so the axis is monotonic.

    Raises
    ------
    ValueError
        If a mass is negative, the two masses are equal, or the two times are equal.
        In these cases no global state is modified.
    """
    global mass_calibration_data, digitization_pts

    if m1 < 0 or m2 < 0:
        raise ValueError('mass calibration needs non-negative masses')
    mass1_sqrt = pow(m1, 0.5)
    mass2_sqrt = pow(m2, 0.5)
    if mass1_sqrt == mass2_sqrt:
        raise ValueError('mass calibration needs two different masses')
    if t1 == t2:
        raise ValueError('mass calibration needs two different times')

    # sqrt(m) = sqrt(A) * (t - t0) is linear in t, so t0 is where the line through the two
    # points crosses zero and sqrt(A) is its slope. This form stays finite when a mass is 0.
    t0 = (mass2_sqrt*t1 - mass1_sqrt*t2)/(mass2_sqrt - mass1_sqrt)
    A = pow((mass2_sqrt - mass1_sqrt)/(t2 - t1), 2)

    mass_calibration_data = (A, t0)

    digitization_pts = np.arange(len(digitization_pts))*TIME_SCALING_FACTOR
    digitization_pts = calculate_mass_from_tof(digitization_pts)

    min_mass_index = np.argmin(digitization_pts)
    if min_mass_index > 0:
        digitization_pts[:min_mass_index] = -digitization_pts[:min_mass_index]
   

# save/load settings functions########################################

def load_settings(settings_filename):
    """Restore ROIs and the mass calibration from ``<settings_filename>.yaml``.

    ROIs are added to the latest record in ``scan_data``, each with a zeroed
    histogram. Saved calibration points are stored in ``mass_calibration_pairs`` and
    re-applied through mass_calibrate(). A missing or empty file is ignored.

    Returns
    -------
    None
    """
    global scan_data, mass_calibration_pairs, digitization_pts

    settings_path = f'{settings_filename}.yaml'
    if not os.path.exists(settings_path):
        return
    with open(settings_path, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects. It is kept
        # because save_settings() writes 'tof_roi' tuples as !!python/tuple tags that
        # yaml.safe_load rejects. Never point this at a file from an untrusted source.
        loaded_settings_data = yaml.load(f, Loader=yaml.Loader)
    if not loaded_settings_data:
        return

    for r_name, roi in (loaded_settings_data.get('roi') or {}).items():
        # Same length as digitization_pts, like every other ROI histogram.
        roi['roi_tof_histo'] = np.zeros(len(digitization_pts), dtype=float)
        scan_data[-1][r_name] = roi

    if 'mass_calibration_pairs' in loaded_settings_data:
        mass_calibration_pairs = loaded_settings_data['mass_calibration_pairs']
        mass_calibrate(mass_calibration_pairs['t1'], mass_calibration_pairs['m1'],
                       mass_calibration_pairs['t2'], mass_calibration_pairs['m2'])
    


def save_settings(Event):
    """Write the current ROIs and mass-calibration points to ``<settings_filename>.yaml``.

    Used as a button ``on_click`` callback; ``Event`` is ignored. The built-in
    'full_range' ROI and the per-ROI histograms ('roi_tof_histo') are runtime data and
    are not saved. The calibration points are included only when a calibration is set.
    """
    global scan_data, mass_calibration_pairs, settings_filename

    # Copy each ROI without its histogram; ROIs that lack one are saved as they are.
    roi_dict = {
        roi_name: {key: value for key, value in roi.items() if key != 'roi_tof_histo'}
        for roi_name, roi in scan_data[-1].items()
        if roi_name != 'full_range'
    }

    settings_dict = {'roi': roi_dict}
    if mass_calibration_pairs is not None:
        settings_dict['mass_calibration_pairs'] = mass_calibration_pairs

    with open(f'{settings_filename}.yaml', 'w') as f:
        yaml.dump(settings_dict, f, sort_keys=False)
    
load_settings(settings_filename) # restore saved ROIs / mass calibration at start-up (no-op if the yaml file is missing)

In [None]:
#save_settings(settings_filename)

### Image & spectrum panel ###

In [None]:
#def on_roi_select_change(Event):
#    pass 

# Set by the "Update plot range" button; make_tof_histo_plot consumes it once to re-fit the axes.
update_tof_plot_range = False

def reset_button_callback(Event):
    """Discard all accumulated records and scan results, keeping the ROI definitions.

    ``Event`` (the button click) is ignored. The cleared scan table is pushed to the scan
    plot and the centroid table is emptied, so the display no longer shows the discarded data.
    """
    global scan_data, scan_table, df_mov_centroids

    # Deep copy: the new record must not share histogram arrays with discarded records.
    new_record = copy.deepcopy(scan_data[-1])
    for roi_record in new_record.values():
        roi_record['roi_tof_histo'] = np.zeros(len(digitization_pts), dtype=float)
    scan_data = [new_record]

    scan_table = pd.DataFrame()
    df_mov_centroids = pd.DataFrame()
    df_widget_mov_centroids.value = df_mov_centroids
    pipe_roi_scan.send(scan_table)
    
# ROI selector: every ROI of the latest record is offered and initially selected.
roi_select = pn.widgets.MultiChoice(
    name='ToF/mass ROI',
    value=list(scan_data[-1].keys()),
    options=list(scan_data[-1].keys()),
)

def update_tof_range(Event):
    """Ask make_tof_histo_plot to re-fit its axes on the next redraw (button callback)."""
    global update_tof_plot_range
    update_tof_plot_range = True

# Action buttons, each wired to its click handler right after creation.
reset_button = pn.widgets.Button(name='Reset data', button_type='primary')
reset_button.on_click(reset_button_callback)

save_settings_button = pn.widgets.Button(name='Save settings', button_type='primary')
save_settings_button.on_click(save_settings)

update_tof_plot_range_button = pn.widgets.Button(name='Update plot range', button_type='primary')
update_tof_plot_range_button.on_click(update_tof_range)

# Signal pre-processing toggles, read by on_scandata_received for every new trace.
baseline_checkbox = pn.widgets.Checkbox(name='Baseline correction')
invert_signal_checkbox = pn.widgets.Checkbox(name='Invert amplitude')

### mass calibration panel ###

In [None]:
# Centroid table (rows: movables, columns: ROIs), refreshed by process_scan_roi_data.
# Created once; the duplicate construction that immediately replaced it has been removed.
df_widget_mov_centroids = pn.widgets.DataFrame(pd.DataFrame(), height=800, frozen_columns=1, autosize_mode='none')

# Two (time, mass) reference points for the calibration m = A*(t - t0)**2.
time_point1 = pn.widgets.FloatInput(name='Time 1', value=0., start=0,)
time_point2 = pn.widgets.FloatInput(name='Time 2', value=1., start=0,)
mass_point1 = pn.widgets.FloatInput(name='Mass 1', value=0., start=0,)
mass_point2 = pn.widgets.FloatInput(name='Mass 2', value=1., start=0,)


def mass_calibrate_callback(Event):
    """Calibrate the mass axis from the two (time, mass) points entered in the panel.

    ``Event`` (the button click) is ignored. The points are stored in
    ``mass_calibration_pairs`` (which save_settings persists) only after mass_calibrate()
    returns. A failed calibration therefore never leaves unusable points to be saved.
    The resulting constants are shown as "A  t0".
    """
    global mass_calibration_pairs
    t1, m1 = time_point1.value, mass_point1.value
    t2, m2 = time_point2.value, mass_point2.value

    mass_calibrate(t1, m1, t2, m2)
    mass_calibration_pairs = {'t1': t1, 'm1': m1, 't2': t2, 'm2': m2}

    mass_calib_string.value = f'{mass_calibration_data[0]}  {mass_calibration_data[1]}'
    


def reset_calibration_callback(Event):
    """Drop the mass calibration and restore the plain time-of-flight axis.

    ``Event`` (the button click) is ignored.
    """
    global mass_calibration_data, mass_calibration_pairs, digitization_pts
    mass_calibration_data = None
    # Forget the reference points too. Otherwise save_settings would persist them and
    # load_settings would re-apply the removed calibration.
    mass_calibration_pairs = None
    # Same time axis the acquisition path builds: sample index scaled to microseconds.
    digitization_pts = np.arange(len(digitization_pts), dtype=float) * TIME_SCALING_FACTOR
    mass_calib_string.value = 'None'

mass_calibrate_button = pn.widgets.Button(name='Calibrate ToF spectra', button_type='primary')
mass_calibrate_button.on_click(mass_calibrate_callback)

reset_calibration_button = pn.widgets.Button(name='Reset mass calibration', button_type='primary')
reset_calibration_button.on_click(reset_calibration_callback)

# Show a calibration restored by load_settings (if any) in the same "A  t0" / 'None'
# text format the calibrate/reset callbacks write, rather than as a raw tuple.
mass_calib_string = pn.widgets.StaticText(
    name='Calibration const.',
    value='None' if mass_calibration_data is None
    else f'{mass_calibration_data[0]}  {mass_calibration_data[1]}')



### ROI management panel ###

In [None]:
# Lower/upper ROI bounds (ToF, or mass once calibrated); both start at 0 and are non-negative.
_roi_bound_kwargs = dict(value=0., start=0)
m1_point = pn.widgets.FloatInput(name='m1', **_roi_bound_kwargs)
m2_point = pn.widgets.FloatInput(name='m2', **_roi_bound_kwargs)

# Name and display colour of the ROI to add or remove.
text_input_roi_name = pn.widgets.TextInput(
    name='ROI name',
    placeholder='Enter the name of ROI here...',
)
colorpicker = pn.widgets.ColorPicker(name='ROI color', value='#99ef78')

def add_roi_callback(Event):
    """Add (or redefine) the ROI named in the text input and select it.

    The ROI covers the open interval between the m1 and m2 inputs. The bounds are
    ordered, so entering them high-to-low still selects data. It is stored in the
    latest record of ``scan_data``. An empty name, a missing bound, or the reserved
    'full_range' name is ignored. ``Event`` (the button click) is ignored.
    """
    roi_name = text_input_roi_name.value
    if roi_name == '' or roi_name == 'full_range':
        return
    if m1_point.value is None or m2_point.value is None:
        return

    lower, upper = sorted((m1_point.value, m2_point.value))
    scan_data[-1][roi_name] = {
        'tof_roi': (lower, upper),
        # Same length as digitization_pts, like every other ROI histogram.
        'roi_tof_histo': np.zeros(len(digitization_pts), dtype=float),
        'roi_color': colorpicker.value,
    }

    if roi_name not in roi_select.options:
        options_list = list(roi_select.options) + [roi_name]
        roi_select.options = []
        roi_select.options = options_list
        roi_select.value = [roi_name]
    

def remove_roi_by_name_callback(Event):
    """Remove the ROI named in the text input from the latest record and the ROI selector.

    The built-in 'full_range' ROI cannot be removed. A name that is not in the latest
    record no longer raises KeyError; it is still cleared from the selector.
    ``Event`` (the button click) is ignored.
    """
    roi_name = text_input_roi_name.value
    if roi_name == 'full_range':
        return

    scan_data[-1].pop(roi_name, None)

    # Remember the selection without the removed ROI; clear it while options change.
    value_list = [name for name in roi_select.value if name != roi_name]
    if roi_name in roi_select.value:
        roi_select.value = []

    if roi_name in roi_select.options:
        options_list = [name for name in roi_select.options if name != roi_name]
        roi_select.options = []
        roi_select.options = options_list

    # Restore the remaining selection if it was cleared above.
    if roi_select.value == [] and value_list != []:
        roi_select.value = value_list
  


# Buttons of the ROI management panel.
add_roi_button = pn.widgets.Button(name='Add ROI', button_type='primary')
remove_roi_by_name_button = pn.widgets.Button(name='Remove ROI by name', button_type='primary')

# Wire each button to its handler.
add_roi_button.on_click(add_roi_callback)
remove_roi_by_name_button.on_click(remove_roi_by_name_callback)




In [None]:
def process_scan_roi_data(in_table):
    """Compute the ROI-intensity-weighted centroid of every movable for every ROI.

    centroid = nansum(I_roi * x_movable) / nansum(I_roi) over the points recorded so far.
    NaN is used when the ROI integral is not positive. Results are stored in the global
    ``df_mov_centroids`` (rows: movables, columns: ROIs) and shown in
    ``df_widget_mov_centroids``.

    Parameters
    ----------
    in_table : pandas.DataFrame
        Scan table with one column per movable and per ROI (built in on_scandata_received).
    """
    global df_mov_centroids

    if df_mov_centroids.empty:
        df_mov_centroids = pd.DataFrame(index=scan_info['movable_names'],
                                        columns=scan_info['roi_names'], dtype=float)

    for roi_name in scan_info['roi_names']:
        # Table columns are object dtype with NaN for missing points; coerce to floats.
        intensities = pd.to_numeric(in_table[roi_name], errors='coerce')
        integral_val = np.nansum(intensities)
        for movable_name in scan_info['movable_names']:
            if integral_val > 0:
                positions = pd.to_numeric(in_table[movable_name], errors='coerce')
                centroid = np.nansum(intensities * positions) / integral_val
            else:
                centroid = float('nan')
            # .loc writes into the frame itself; chained df[col][row] = ... may only
            # modify a temporary copy under pandas copy-on-write.
            df_mov_centroids.loc[movable_name, roi_name] = centroid

    # Assign a copy so the widget always receives a new object (NOTE(review): guards
    # against identity-based change detection — confirm whether needed).
    df_widget_mov_centroids.value = df_mov_centroids.copy()
 

In [None]:
# Retrieve scan information from Sardana: scan-monitor messages are handled by on_scandata_received.

def _preprocess_trace(raw_trace):
    """Return the trace as floats with the optional GUI corrections applied.

    Inverts the sign ("Invert amplitude") and subtracts a peakutils baseline
    ("Baseline correction") when the corresponding checkboxes are ticked. The acquired
    record array is never modified in place. An in-place ``-=`` would alter the record
    and fail for integer data.
    """
    trace = np.asarray(raw_trace, dtype=float)
    if invert_signal_checkbox.value:
        trace = -trace
    if baseline_checkbox.value:
        trace = trace - baseline(trace)
    return trace


def _compute_digitization_pts(n_points):
    """Return the x axis (time, or mass when calibrated) for an ``n_points``-long trace.

    Uses the same construction as mass_calibrate(). When a calibration is active, samples
    before the parabola minimum (t < t0) get negated masses. The axis therefore stays
    monotonic and positive ROI bounds do not pick up early-time samples.
    """
    axis = calculate_mass_from_tof(np.arange(n_points, dtype=float) * TIME_SCALING_FACTOR)
    if mass_calibration_data is not None and axis.size > 0:
        min_mass_index = np.argmin(axis)
        if min_mass_index > 0:
            axis[:min_mass_index] = -axis[:min_mass_index]
    return axis


@scanmon_callback
def on_scandata_received(in_data):
    """Handle one scan-monitor message, dispatched on ``in_data['type']``.

    - 'data_desc': a scan starts. Collect scan metadata and movable names, build an empty
      scan table, switch to ``States.SCANNING`` and restart ``scan_data`` with one fresh record.
    - 'record_data': one scan point. Store movable positions, pre-process the selected
      trace, append a new record, integrate every ROI into the scan table, then refresh
      the ToF plot, the centroids and the scan plot.
    - 'record_end': return to ``States.ON`` once every row of the scan table is filled.
    - 'record_single': a single acquisition; only the ToF plot is refreshed.
    """
    global State
    global scan_info
    global scan_table
    global df_mov_centroids
    global scan_step
    global scan_data
    global digitization_pts

    msg_type = in_data['type']

    if msg_type == 'data_desc':
        data_desc = in_data['data']
        scan_info['total_scan_intervals'] = data_desc['total_scan_intervals']
        scan_info['serial_number'] = data_desc['serialno']
        scan_info['ref_moveables'] = data_desc['ref_moveables']

        # Columns that belong to an instrument and are not outputs are the scanned movables.
        movable_names = [
            column_descr['name'] for column_descr in data_desc['column_desc']
            if 'instrument' in column_descr and 'output' not in column_descr
        ]
        scan_info['movable_names'] = movable_names

        roi_names = [roi_name for roi_name in scan_data[-1] if roi_name != 'full_range']
        scan_info['roi_names'] = roi_names

        table_columns = ['point_nb'] + movable_names + roi_names
        scan_table = pd.DataFrame(index=range(scan_info['total_scan_intervals'] + 1), columns=table_columns)

        scan_step = 0
        df_mov_centroids = pd.DataFrame()
        State = States.SCANNING

        # Restart the record history with one fresh record. A deep copy is needed
        # because the new record must not share histogram dicts with the old ones.
        new_record = copy.deepcopy(scan_data[-1])
        for roi_record in new_record.values():
            roi_record['roi_tof_histo'] = np.zeros(len(digitization_pts), dtype=float)
        scan_data = [new_record]
        return

    if msg_type == 'record_data':
        point_nb = in_data['data']['point_nb']
        # .loc writes into scan_table itself; chained scan_table[col][row] = ... may only
        # modify a temporary copy under pandas copy-on-write.
        scan_table.loc[point_nb, 'point_nb'] = point_nb
        for mov_name in scan_info['movable_names']:
            scan_table.loc[point_nb, mov_name] = in_data['data'][mov_name]

        records = in_data['records']
        last_record = records[list(records.keys())[-1]]
        if data_source_select_widget.value not in last_record.data.keys():
            return

        last_1d_data = _preprocess_trace(last_record.data[data_source_select_widget.value])
        digitization_pts = _compute_digitization_pts(len(last_1d_data))

        # Deep copy: a shallow copy shares the per-ROI dicts with every earlier record,
        # so each new point would overwrite the stored history.
        new_record = copy.deepcopy(scan_data[-1])
        for roi_record in new_record.values():
            roi_record['roi_tof_histo'] = np.zeros(len(digitization_pts), dtype=float)
        new_record['full_range']['roi_tof_histo'] = last_1d_data
        scan_data.append(new_record)

        # Integrate each ROI over the open interval (lower, upper) of the x axis.
        for roi_key, roi_def in new_record.items():
            if roi_key == 'full_range':
                continue
            lower, upper = roi_def['tof_roi'][0], roi_def['tof_roi'][1]
            in_roi = (digitization_pts > lower) & (digitization_pts < upper)
            scan_table.loc[point_nb, roi_key] = np.sum(last_1d_data[in_roi])

        pipe_tof.send(last_1d_data)

        process_scan_roi_data(scan_table)
        pipe_roi_scan.send(scan_table)
        return

    if msg_type == 'record_end':
        # Leave the scanning state only when every scan point has been filled in
        # (or when the table was cleared by "Reset data").
        if scan_table.empty or not scan_table.iloc[-1].isnull().any():
            State = States.ON
        return

    # Only a single record is sent from the ctn macro.
    if msg_type == 'record_single':
        last_record = in_data['records']
        if data_source_select_widget.value not in last_record.data.keys():
            return

        last_1d_data = _preprocess_trace(last_record.data[data_source_select_widget.value])
        digitization_pts = _compute_digitization_pts(len(last_1d_data))

        pipe_tof.send(last_1d_data)


        
        

## Panel GUI, server start ##

In [None]:
def _padded_limits(vmin, vmax, margin=0.1):
    """Return (low, high) axis limits padded by ``margin`` times the data span.

    A degenerate span (vmin == vmax) is widened to +-1 around the value.
    """
    if vmax == vmin:
        return (vmax - 1, vmax + 1,)
    span = vmax - vmin
    return (vmin - span * margin, vmax + span * margin,)


def make_roi_scan_plot(data):
    """Plot the integrated intensity of each selected ROI over the scan.

    DynamicMap callback fed by ``pipe_roi_scan``. While scanning, the x axis is the first
    scanned movable; otherwise it is the point number. Selected ROIs without a column in
    ``data`` are skipped instead of raising KeyError (e.g. an ROI added after the scan
    started). An empty overlay is returned when there is nothing to plot.

    Parameters
    ----------
    data : pandas.DataFrame
        The scan table built in on_scandata_received.
    """
    roi_names = [name for name in roi_select.value
                 if name != 'full_range' and name in data.columns and name in scan_data[-1]]

    if not roi_names or data.empty:
        hv_curve = hv.Curve([]).opts(xlabel='scan num.', ylabel='Amplitude, [au]', width=1200, height=800,
                                     show_grid=True, tools=['hover'], axiswise=True, framewise=True,)
        return hv.Overlay([hv_curve, hv_curve],).opts(axiswise=True, framewise=True,)

    x_column = scan_info['movable_names'][0] if State == States.SCANNING else 'point_nb'
    # The scan table has object columns with NaN for missing points; coerce to floats.
    x_vals = pd.to_numeric(data[x_column], errors='coerce').to_numpy(dtype=float)
    roi_values = data[roi_names].apply(pd.to_numeric, errors='coerce')

    xlimits = _padded_limits(np.nanmin(x_vals), np.nanmax(x_vals))
    ylimits = _padded_limits(np.nanmin(roi_values.to_numpy(dtype=float)),
                             np.nanmax(roi_values.to_numpy(dtype=float)))

    colors = {name: roi['roi_color'] for name, roi in scan_data[-1].items()}
    xlabel = get_movable_label()

    if 'serial_number' in scan_info and State == States.SCANNING:
        figure_title = f"Scan: {scan_info['serial_number']}"
    else:
        figure_title = ''

    # Options shared by the line layer and the marker layer of every ROI.
    common_opts = dict(xlabel=xlabel, ylabel='Amplitude, [au]', width=600, height=450,
                       show_grid=True, tools=['hover'], xlim=xlimits, ylim=ylimits,
                       axiswise=True, framewise=True)
    curves = [
        hv.Curve((x_vals, roi_values[name])).opts(color=colors[name], **common_opts)
        * hv.Scatter((x_vals, roi_values[name])).opts(color=colors[name], size=10, **common_opts)
        for name in roi_names
    ]

    return hv.Overlay(curves).opts(title=figure_title, axiswise=True, framewise=True,)

    
    

def make_tof_histo_plot(data):
    """Plot the latest ToF (or mass) trace with the selected ROIs shaded.

    DynamicMap callback fed by ``pipe_tof``. ``data`` is plotted against the global
    ``digitization_pts``. While scanning, the title shows the scan number and the current
    position of the first movable. The axes are re-fitted (framewise) only once after
    the "Update plot range" button sets ``update_tof_plot_range``.
    """
    global update_tof_plot_range

    figure_title = ''
    if 'serial_number' in scan_info and State == States.SCANNING:
        figure_title = f"Scan: {scan_info['serial_number']}"
        # Rows of points not yet acquired are NaN; use the latest recorded point only.
        recorded_points = (pd.to_numeric(scan_table['point_nb'], errors='coerce').dropna()
                           if 'point_nb' in scan_table else pd.Series(dtype=float))
        if not recorded_points.empty:
            current_step = int(recorded_points.max())
            mov_val = scan_table.loc[current_step, get_movable_label()]
            figure_title = f"Scan: {scan_info['serial_number']},  {get_movable_label()} at: {mov_val}"

    last_roi_data = scan_data[-1]
    selected_rois = [(last_roi_data[name]['tof_roi'], last_roi_data[name]['roi_color'])
                     for name in roi_select.value if name != 'full_range' and name in last_roi_data]

    xlabel = 'ToF (µs)' if mass_calibration_data is None else 'Mass (amu)'
    ylabel = 'Normalized amplitude'

    # One-shot request from the "Update plot range" button.
    framewise = bool(update_tof_plot_range)
    update_tof_plot_range = False

    # NaN-aware extremes: Python min()/max() give arbitrary results when NaN is present.
    y_min, y_max = np.nanmin(data), np.nanmax(data)

    tof_hist_plot_log = hv.Curve((digitization_pts, data)).opts(xlabel=xlabel, ylabel=ylabel,
                                                                axiswise=True, framewise=framewise,
                                                                height=800, width=1200, show_grid=True,
                                                                tools=['hover'], ylim=(y_min, y_max),
                                                                xlim=(digitization_pts[0],
                                                                digitization_pts[-1]), logy=TOF_YLOG)

    # One translucent rectangle per selected ROI, spanning the full data height.
    mass_rois = [hv.Rectangles((bounds[0], y_min, bounds[1], y_max)).opts(alpha=0.3, color=color)
                 for bounds, color in selected_rois]
    return hv.Overlay([tof_hist_plot_log] + mass_rois).opts(title=figure_title)


# Settings tabs (left) and visualisation tabs (right) of the served page.
main_panel = pn.Column(data_source_select_widget, reset_button, save_settings_button)

mass_calibrate_panel = pn.Row(pn.Column(time_point1, mass_point1, time_point2, mass_point2,
                                        mass_calibrate_button, reset_calibration_button, mass_calib_string),)

roi_panel = pn.Column(m1_point, m2_point, text_input_roi_name, colorpicker, add_roi_button,
                      remove_roi_by_name_button)

settings_tabs = pn.Tabs(('Main', main_panel),
                        ('Mass calibration', mass_calibrate_panel),
                        ('ROI settings', roi_panel))

# Each DynamicMap re-renders whenever its Pipe is sent new data.
viz_panel = pn.Tabs(('Plots', pn.Column(pn.Row(roi_select, update_tof_plot_range_button, baseline_checkbox, invert_signal_checkbox),
                                        hv.DynamicMap(make_tof_histo_plot, streams=[pipe_tof]))),
                    ('Scan', pn.Column(hv.DynamicMap(make_roi_scan_plot, streams=[pipe_roi_scan]))),
                    ("Movable centroids table", df_widget_mov_centroids))

page = pn.Row(settings_tabs, viz_panel)

# Serve the page; pass port=... to pn.serve to pin the server port.
pn.serve(page)

In [None]:
# Display the current scan table: point number, movable positions and ROI integrals per scan point.
scan_table