In [None]:
%matplotlib widget

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable

from databroker import DataBroker as db
from sixtools.rixs_wrapper import make_scan, calibrate, interp_robust
from sixtools.plotting_functions import plot_frame, plot_scan
from sixtools.GUI_components import markers
from rixs.plotting_functions import plot_pcolor

from ipywidgets import interact, interactive, fixed, HBox, VBox,Label, Layout, FloatText, IntSlider, Dropdown, SelectMultiple
from IPython.display import clear_output, display

In [None]:
%%capture
# Pre-create the five numbered figures used by the rest of the notebook.
# This cell runs under %%capture, so nothing is rendered here; the later
# cells re-open each figure by its number and display its canvas themselves.
fig_dark, fig_frames = [plt.figure(num=num, figsize=(10, 4), clear=True)
                        for num in (0, 1)]
fig_scan, fig_cal, fig_map = [plt.figure(num=num, clear=True)
                              for num in (2, 3, 4)]

## Define and execute processing

In [None]:
# Processing options for each detector region. The selected inner dict is
# unpacked as keyword arguments into make_scan below.
process_dicts = {'low_2theta': {'light_ROI': [slice(175, 1500), slice(4130, 4700)],
                                'curvature': np.array([0., 0., 0.]),
                                'bins': None,
                                'background': None},
                 'high_2theta': {'light_ROI': [slice(175, 1500), slice(1153, 1531)],
                                 'curvature': np.array([0., 0., 0.]),
                                 'bins': None,
                                 'background': None}
                }

# Every ROI is handed to plot_frame when the raw frames are displayed.
light_ROIs = [d['light_ROI'] for d in process_dicts.values()]

# ids
scan_ids = list(range(22322, 22341+1))
dark_scan_ids = [22343, 22370]
scanned_motor = 'pgm_en'

# Process darks: average the frames inside each image stack, then average
# those per-stack means. The normalisation uses the number of stacks actually
# read (not the number of headers), so the result stays a proper average even
# when a dark scan contains more than one event.
dark_headers = db[dark_scan_ids]
dark_stack_means = [np.mean(ImageStack, axis=0)
                    for h in dark_headers
                    for ImageStack in h.data('rixscam_image')]
if not dark_stack_means:
    raise ValueError("No 'rixscam_image' data found in dark scans {}".format(dark_scan_ids))
dark_frame = sum(dark_stack_means)/len(dark_stack_means)

# Process data. The headers are fetched once and reused for every query.
scan_headers = db[scan_ids]
frames = np.array([ImageStack for header in scan_headers for ImageStack in header.data('rixscam_image')])
scan = make_scan(scan_headers, **process_dicts['high_2theta'])

# extraction of motor_values is clumsy now
# The motor position of each scan is the average of the baseline readings
# with index 1 and 2 (presumably taken before and after the scan -- TODO confirm).
before_after_values = db.get_table(scan_headers, stream_name="baseline", fields=[scanned_motor])[scanned_motor]
motor_values = (before_after_values[before_after_values.index == 1].values
                + before_after_values[before_after_values.index == 2].values)/2

# NOTE(review): labels pair scan_ids with motor_values one-to-one, and the
# frame viewer pairs scan_ids with entries of `frames` the same way; this
# assumes one event per scan -- confirm.
event_labels = ["#{} {}={:.1f}".format(scan_id, scanned_motor, motor)
                for scan_id, motor in zip(scan_ids, motor_values)]

## View dark image

In [None]:
fig_dark = plt.figure(num=0, figsize=(10, 4), clear=True)
ax_dark = fig_dark.add_subplot(111)

art_dark, _, cb_dark = plot_frame(ax_dark, dark_frame, light_ROIs=light_ROIs)

ax_dark.set_title('Dark frame')


def update_dark(vmin, vmax):
    """Apply new colour limits to the dark-frame image and redraw the figure.

    Parameters
    ----------
    vmin, vmax : float
        Lower and upper colour limits to apply.
    """
    # Only the image is updated. cb_dark is presumably a Matplotlib colorbar
    # (TODO confirm against plot_frame); those follow their mappable
    # automatically, and Colorbar.set_clim no longer exists in Matplotlib >= 3.3.
    art_dark.set_clim(vmin, vmax)
    fig_dark.canvas.draw()
    fig_dark.canvas.flush_events()


# Start the text boxes at the limits currently shown, so that editing one box
# does not silently apply 0 for the other.
vmin_dark, vmax_dark = art_dark.get_clim()
vmin_widget0 = FloatText(value=float(vmin_dark), description='vmin')
vmax_widget0 = FloatText(value=float(vmax_dark), description='vmax')

# interactive() wires the widgets to update_dark; the widgets themselves are
# laid out manually below.
interactive(update_dark, vmin=vmin_widget0, vmax=vmax_widget0)
display(VBox([fig_dark.canvas, HBox([vmin_widget0, vmax_widget0])]))

## Examine frames

In [None]:
fig_frames = plt.figure(num=1, figsize=(10, 4), clear=True)
ax_frames = fig_frames.add_subplot(111)

art_frames, _, cb_frames = plot_frame(ax_frames, frames[0,0], light_ROIs=light_ROIs)

ax_frames.set_title("Frame {}_{}".format(event_labels[0], 0))


def update_frame(scan_id, frameid, vmin, vmax):
    """Show one frame of the chosen scan with the requested colour limits.

    Parameters
    ----------
    scan_id : int
        One of the entries of scan_ids; selects the event to display.
    frameid : int
        Index of the frame within that event.
    vmin, vmax : float
        Lower and upper colour limits to apply.
    """
    # A single lookup gives both the image stack and its label.
    idx = scan_ids.index(scan_id)
    art_frames.set_data(frames[idx][frameid])
    ax_frames.set_title("Frame {}_{}".format(event_labels[idx], frameid))

    # Only the image is updated. cb_frames is presumably a Matplotlib colorbar
    # (TODO confirm against plot_frame); those follow their mappable
    # automatically, and Colorbar.set_clim no longer exists in Matplotlib >= 3.3.
    art_frames.set_clim(vmin, vmax)
    fig_frames.canvas.draw()
    fig_frames.canvas.flush_events()


scanid_widget = Dropdown(options=scan_ids)
frameid_widget = IntSlider(min=0, max=frames.shape[1]-1)

# Start the colour-limit boxes at the limits currently shown, so that moving
# the scan/frame selectors does not apply a 0..0 range.
vmin_frames, vmax_frames = art_frames.get_clim()
vmin_widget1 = FloatText(value=float(vmin_frames), description='vmin')
vmax_widget1 = FloatText(value=float(vmax_frames), description='vmax')

grab = interactive(update_frame, scan_id=scanid_widget, frameid=frameid_widget, vmin=vmin_widget1, vmax=vmax_widget1)

display(VBox([fig_frames.canvas, HBox([scanid_widget, frameid_widget]), HBox([vmin_widget1, vmax_widget1])]))

## Plot spectra

In [None]:
fig_scan = plt.figure(num=2, clear=True)
ax_scan = fig_scan.add_subplot(111)

# One entry per plotted spectrum, in plotting order (events outer, spectra
# inner). update_scan walks this list in the same order to recover the marker
# and colour each spectrum was first drawn with.
artists_scan = []
for event, event_label in zip(scan, event_labels):
    for i, S in enumerate(event):
        art = ax_scan.plot(S[:, 0], S[:, 1], marker=next(markers),
                           markersize=4,
                           label="{}_{}".format(event_label, i))
        artists_scan.append(art)

ax_scan.set_xlabel('pixels')
ax_scan.set_ylabel('I')
ax_scan.set_title('Raw spectra')
ax_scan.legend(fontsize=7)


def update_scan(choose_labels):
    """Redraw the raw-spectra axes showing only the selected events.

    Parameters
    ----------
    choose_labels : sequence of str
        Event labels (entries of event_labels) whose spectra should be drawn.
    """
    # Work on a snapshot so removing artists does not disturb the iteration.
    for art in list(ax_scan.lines) + list(ax_scan.collections):
        art.remove()

    styles = iter(artists_scan)
    for event, event_label in zip(scan, event_labels):
        for i, S in enumerate(event):
            # Advance the style iterator for every spectrum, drawn or not, so
            # each spectrum is matched with its own original artist even when
            # an event holds more than one spectrum.
            original = next(styles)[0]
            if event_label not in choose_labels:
                continue
            ax_scan.plot(S[:, 0], S[:, 1], marker=original.get_marker(),
                         markersize=4, color=original.get_color(),
                         label="{}_{}".format(event_label, i))
    ax_scan.legend(fontsize=7)
    fig_scan.canvas.draw()


choose_label_widget0 = SelectMultiple(options=event_labels, value=event_labels, description=' ', rows=15)
interactive(update_scan, choose_labels=choose_label_widget0)

display(HBox([VBox([Label('Choose spectra'), choose_label_widget0], layout=Layout(align_items='center')),
              fig_scan.canvas], layout=Layout(align_items='center')))

## Calibrate spectra

In [None]:
# For every spectrum, take the x value (column 0) at which the intensity
# (column 1) is largest; these are passed to calibrate as the elastic positions.
elastics = np.array([
    [spectrum[spectrum[:, 1].argmax(), 0] for spectrum in event]
    for event in scan
])

# Average the two baseline ring-current readings (index 1 and 2) of each scan.
table = db.get_table(db[scan_ids], stream_name='baseline')
ring = table['ring_curr']
ring_first = ring[ring.index == 1].values
ring_second = ring[ring.index == 2].values
I0s = (ring_first + ring_second)/2

# I0s gets a trailing axis so it has one row per scan.
cal_scan = calibrate(scan,
                     elastics=elastics,
                     energy_per_pixel=-0.018,
                     I0s=I0s[:, np.newaxis])

## Plot calibrated spectra

In [None]:
fig_cal = plt.figure(num=3, clear=True)
ax_cal = fig_cal.add_subplot(111)

# One entry per plotted spectrum, in plotting order (events outer, spectra
# inner). update_cal walks this list in the same order to recover the marker
# and colour each spectrum was first drawn with.
artists_cal = []
for event, event_label in zip(cal_scan, event_labels):
    for i, S in enumerate(event):
        art = ax_cal.plot(S[:, 0], S[:, 1], marker=next(markers),
                          markersize=4,
                          label="{}_{}".format(event_label, i))
        artists_cal.append(art)

# After calibration the x column is no longer in pixels: the map cell
# interpolates it directly onto the energy-loss grid.
ax_cal.set_xlabel('Energy loss')
ax_cal.set_ylabel('I')
ax_cal.set_title('Calibrated spectra')
ax_cal.legend(fontsize=7)


def update_cal(choose_labels):
    """Redraw the calibrated-spectra axes showing only the selected events.

    Parameters
    ----------
    choose_labels : sequence of str
        Event labels (entries of event_labels) whose spectra should be drawn.
    """
    # Work on a snapshot so removing artists does not disturb the iteration.
    for art in list(ax_cal.lines) + list(ax_cal.collections):
        art.remove()

    styles = iter(artists_cal)
    for event, event_label in zip(cal_scan, event_labels):
        for i, S in enumerate(event):
            # Advance the style iterator for every spectrum, drawn or not, so
            # each spectrum is matched with its own original artist even when
            # an event holds more than one spectrum.
            original = next(styles)[0]
            if event_label not in choose_labels:
                continue
            ax_cal.plot(S[:, 0], S[:, 1], marker=original.get_marker(),
                        markersize=4, color=original.get_color(),
                        label="{}_{}".format(event_label, i))
    ax_cal.legend(fontsize=7)
    fig_cal.canvas.draw()


choose_label_widget1 = SelectMultiple(options=event_labels, value=event_labels, description=' ', rows=15)
interactive(update_cal, choose_labels=choose_label_widget1)

display(HBox([VBox([Label('Choose spectra'), choose_label_widget1], layout=Layout(align_items='center')),
              fig_cal.canvas], layout=Layout(align_items='center')))

## Construct map

In [None]:
# Rows of the map are indexed by the scanned-motor value of each scan.
energies = motor_values
# Common grid onto which every spectrum is resampled.
energy_loss = np.linspace(-0.5, 10, 1000)

# Average the spectra of each event (axis 1 of cal_scan), then resample each
# averaged spectrum onto the common grid; one map row per event.
map_rows = []
for mean_spectrum in cal_scan.mean(1):
    map_rows.append(interp_robust(energy_loss, mean_spectrum[:, 0], mean_spectrum[:, 1]))
RIXSmap = np.array(map_rows)

## Plot map

In [None]:
fig_map = plt.figure(num=4, clear=True)
ax_map = fig_map.add_subplot(111)
divider = make_axes_locatable(ax_map)
cax_map = divider.append_axes("right", size="2%", pad=0.1)

# Default colour limits: the 1st and 99th percentiles of the map, computed
# once with the NaN-aware percentile and shared by the plot and the widgets.
map_vmin, map_vmax = np.nanpercentile(RIXSmap, [1, 99])

art_map = ax_map.pcolor(energy_loss, energies, RIXSmap, vmin=map_vmin,
                        vmax=map_vmax)

cb_map = plt.colorbar(art_map, cax=cax_map)
cb_map.set_label('I')
ax_map.set_xlabel('Energy loss')
ax_map.set_ylabel('Incident energy (eV)')


def update_map(vmin, vmax):
    """Apply new colour limits to the RIXS map and redraw the figure.

    Parameters
    ----------
    vmin, vmax : float
        Lower and upper colour limits to apply.
    """
    # Only the mappable is updated: the colorbar follows it automatically,
    # and Colorbar.set_clim no longer exists in Matplotlib >= 3.3.
    art_map.set_clim(vmin, vmax)
    fig_map.canvas.draw()
    fig_map.canvas.flush_events()


vin_widget3 = FloatText(value=float(map_vmin), description='vmin')
vmax_widget3 = FloatText(value=float(map_vmax), description='vmax')

grab = interactive(update_map, vmin=vin_widget3, vmax=vmax_widget3)

display(fig_map.canvas, vin_widget3, vmax_widget3)