In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable

from databroker import DataBroker as db
from sixtools.rixs_wrapper import make_scan, calibrate, interp_robust
from sixtools.plotting_functions import plot_frame, plot_scan
from rixs.plotting_functions import plot_pcolor

import ipywidgets
from ipywidgets import interact, interactive, fixed
from IPython.display import clear_output, display

%matplotlib widget

## Define plot figures

In [None]:
%matplotlib widget

def update_plot(num, plot_output, image=False, **kwargs):
    """Create matplotlib figure ``num`` inside ``plot_output`` and return its handles.

    Parameters
    ----------
    num : int
        Figure number, forwarded to ``plt.subplots(num=...)``.
    plot_output : ipywidgets.Output
        Used as a context manager while the figure is created (presumably so
        the ipympl canvas is rendered in that widget when it is displayed).
    image : bool, optional
        If truthy, also append a narrow axes (2% wide, ticks removed) to the
        right of the main axes, intended for a colorbar, and return it.
    **kwargs
        Forwarded to ``plt.subplots`` (e.g. ``figsize``). ``clear`` defaults
        to True so re-running the cell resets figure ``num`` instead of
        stacking a second set of axes on top of the existing one.

    Returns
    -------
    tuple
        ``(fig, ax)`` when ``image`` is falsy, otherwise ``(fig, ax, cax)``.
    """
    kwargs.setdefault('clear', True)
    with plot_output:
        fig, ax = plt.subplots(num=num, **kwargs)
        # Truthiness check: the previous `is True` / `is False` tests returned
        # None for values such as image=1.
        if not image:
            return fig, ax
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.1)
        cax.set_xticks([])
        cax.set_yticks([])
        return fig, ax, cax

# Every figure lives in its own Output widget and is created inside it, so a
# later `display(<output widget>)` shows that figure. Figure numbers 0-4 are
# assigned explicitly, one per figure.
plot_output_dark = ipywidgets.Output()
fig_dark, ax_dark, cax_dark = update_plot(num=0, plot_output=plot_output_dark,
                                          image=True, figsize=(10, 4))

plot_output_frames = ipywidgets.Output()
fig_frames, ax_frames, cax_frames = update_plot(num=1, plot_output=plot_output_frames,
                                                image=True, figsize=(10, 4))

# Line plots: no colorbar axes needed.
plot_output_scan = ipywidgets.Output()
fig_scan, ax_scan = update_plot(num=2, plot_output=plot_output_scan)

plot_output_cal_scan = ipywidgets.Output()
fig_cal_scan, ax_cal_scan = update_plot(num=3, plot_output=plot_output_cal_scan)

# Image plot with a colorbar axes, default figure size.
plot_output_map = ipywidgets.Output()
fig_map, ax_map, cax_map = update_plot(num=4, plot_output=plot_output_map, image=True)

## Define processing

In [None]:
# Reduction settings per spectrometer configuration; each inner dict is
# unpacked as keyword arguments to make_scan. The two light_ROI entries share
# the same first-axis range and differ only along the second axis (presumably
# detector columns -- confirm against sixtools.rixs_wrapper.make_scan).
# 'curvature' is all zeros here; 'background' stays None until the dark-frame
# cell fills it in.
process_dicts = {'low_2theta': {'light_ROI': [slice(175, 1500), slice(4130, 4700)],
                                'curvature': np.array([0., 0., 0.]),
                                'bins': None,
                                'background': None},
                 'high_2theta': {'light_ROI': [slice(175, 1500), slice(1153, 1531)],
                                 'curvature': np.array([0., 0., 0.]),
                                 'bins': None,
                                 'background': None}
                 }

# Regions of interest of every configuration, in dict order; passed to
# plot_frame when drawing detector frames.
light_ROIs = [settings['light_ROI'] for settings in process_dicts.values()]

# Scan ids used by the rest of the notebook. They must be filled in before
# anything below can run (ask Mark for them, as we have minimal publicly
# sharable data at the moment). Defining them here, rather than relying on a
# variable left over in the kernel, keeps Restart & Run All reproducible.
scan_ids = None       # ids of the scans reduced into spectra below
dark_scan_ids = None  # ids of the scans averaged into the dark/background frame

if scan_ids is None or dark_scan_ids is None:
    raise RuntimeError("Set scan_ids and dark_scan_ids above (ask Mark for them);"
                       " we have minimal publicly sharable data at the moment.")

## Make dark images

In [None]:

dark_headers = db[dark_scan_ids]

# Mean dark frame: average each image stack over axis 0, then average those
# per-stack means. Divide by the number of stacks (not the number of scans)
# so the result is still a mean when a dark scan holds several events.
dark_frame_sum = 0
n_dark_stacks = 0
for dark_header in dark_headers:
    for image_stack in dark_header.data('rixscam_image'):
        dark_frame_sum = dark_frame_sum + np.mean(image_stack, axis=0)
        n_dark_stacks += 1
if n_dark_stacks == 0:
    raise ValueError("No 'rixscam_image' data found for dark_scan_ids={!r}".format(dark_scan_ids))
dark_frame = dark_frame_sum / n_dark_stacks

# Clear first so re-running the cell does not stack images; draw into the
# colorbar axes created for this figure (as get_plot_frame does for frames).
ax_dark.clear()
plot_frame(ax_dark, dark_frame, light_ROIs, cax=cax_dark)
ax_dark.set_title('Dark Frame')

# The same dark frame is the background for every processing configuration.
for process_dict in process_dicts.values():
    process_dict['background'] = dark_frame

display(plot_output_dark)

## Examine frames

In [None]:
def get_plot_frame(ax, scan_id, light_ROIs=None, cax=None, **kwargs):
    """Clear ``ax`` and draw the mean detector frame of scan ``scan_id`` on it.

    Used as the ``ipywidgets.interactive`` callback, so it redraws from
    scratch on every call.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on; cleared first.
    scan_id
        Key passed to ``db[...]`` to fetch the scan header.
    light_ROIs : list, optional
        Regions forwarded to ``plot_frame``. ``None`` (the default) is
        forwarded as a fresh empty list, avoiding a shared mutable default.
    cax : matplotlib Axes, optional
        Colorbar axes forwarded to ``plot_frame``.
    **kwargs
        Extra keyword arguments for ``plot_frame`` (e.g. ``vmin``, ``vmax``).
    """
    if light_ROIs is None:
        light_ROIs = []
    ax.clear()
    # Best-effort: also clear any child that has a clear() method; children
    # without one raise AttributeError and are skipped.
    for child in ax.get_children():
        try:
            child.clear()
        except AttributeError:
            pass
    header = db[scan_id]
    # Only the first 'rixscam_image' stack of the scan is used, averaged over axis 0.
    frame = np.mean(next(header.data('rixscam_image')), axis=0)
    plot_frame(ax, frame, light_ROIs=light_ROIs, cax=cax, **kwargs)
    ax.set_title('Scan no {}'.format(scan_id))


# Controls for get_plot_frame: which scan to show and the colour limits.
# Descriptions label the widgets so vmin and vmax can be told apart.
# NOTE(review): 200/260 are the initial colour limits; confirm they suit the detector data.
scanid_widget = ipywidgets.Dropdown(options=scan_ids, description='scan_id')
vin_widget = ipywidgets.FloatText(value=200., description='vmin')
vmax_widget = ipywidgets.FloatText(value=260., description='vmax')

# interactive() re-calls get_plot_frame whenever one of these widgets
# changes; the widgets themselves (not the interactive box) are displayed below.
grab = interactive(get_plot_frame, ax=fixed(ax_frames), scan_id=scanid_widget,
                   light_ROIs=fixed(light_ROIs), cax=fixed(cax_frames),
                   vmin=vin_widget, vmax=vmax_widget)

display(plot_output_frames, scanid_widget, vin_widget, vmax_widget)

## Make spectra

In [None]:
# Reduce every scan with the 'high_2theta' settings (light ROI, curvature,
# bins and the dark-frame background).
scan_headers = db[scan_ids]
scan = make_scan(scan_headers, **process_dicts['high_2theta'])

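## Retrieve energies
Note that this was not a formal scan, so retrieving the incident energies is somewhat inelegant for now: each scan's energy is taken as the mean of the two `pgm_en` readings in its baseline stream.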

In [None]:
pgm_en = db.get_table(db[scan_ids], stream_name="baseline", fields=['pgm_en'])['pgm_en']

# Each scan's energy is the mean of its baseline rows indexed 1 and 2
# (presumably the table index restarts for every scan -- confirm). Check the
# counts so a missing reading cannot silently broadcast or misalign
# energies with scans.
pgm_en_row1 = pgm_en[pgm_en.index == 1].values
pgm_en_row2 = pgm_en[pgm_en.index == 2].values
if not len(pgm_en_row1) == len(pgm_en_row2) == len(scan_ids):
    raise ValueError("Expected one pgm_en baseline reading at index 1 and at index 2 per scan;"
                     " got {} and {} for {} scans".format(len(pgm_en_row1), len(pgm_en_row2),
                                                          len(scan_ids)))
energies = (pgm_en_row1 + pgm_en_row2) / 2

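## Take mean over equivalent energies

Energies are rounded to one decimal place, and spectra from scans with the same rounded energy are averaged together. The eventual data structure will allow this to be done much more elegantly.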

In [None]:
# Energies are rounded to this many decimals; scans whose rounded energies
# match are treated as equivalent and averaged together.
ENERGY_DECIMALS = 1
energies = np.round(energies, decimals=ENERGY_DECIMALS)
unique_energies = np.unique(energies)

# Half a rounding step (0.05 for one decimal): matches exactly the scans that
# rounded to E while tolerating floating-point representation error.
energy_tol = 0.5 * 10.0 ** -ENERGY_DECIMALS
# One boolean mask over scans per unique energy, computed once and reused.
equiv_E_masks = [np.abs(energies - E) < energy_tol for E in unique_energies]

# Average over axis 0 of `scan`, i.e. over the scans sharing an energy.
scan_mean_equiv_E = np.array([np.mean(scan[mask, :, :, :], axis=0)
                              for mask in equiv_E_masks])

scan_ids_array = np.array(scan_ids)
scan_ids_equiv_E = [scan_ids_array[mask] for mask in equiv_E_masks]

## Plot spectra

In [None]:
# Legend labels: each unique incident energy formatted with one decimal.
event_labels = list(map("{:.1f}".format, unique_energies))

# Redraw from scratch so re-running the cell does not overlay old lines.
ax_scan.cla()
plot_scan(ax_scan, scan_mean_equiv_E,
          event_labels=event_labels,
          xlabel='Pixel',
          marker='.', markersize=0.5, linestyle='-',
          legend_kw={'ncol': 2})

display(plot_output_scan)

## Calibrate spectra

In [None]:
def _elastic_pixel(spectrum):
    """Return ``spectrum[:, 0]`` at the row where ``spectrum[:, 1]`` is largest.

    Column 1 is presumably the intensity and column 0 the pixel position, so
    this picks the elastic-peak position -- confirm against sixtools.
    """
    peak_row = np.argmax(spectrum[:, 1])
    return spectrum[peak_row, 0]


# Elastic-peak position for every spectrum, grouped by incident energy.
elastics = np.array([[_elastic_pixel(spectrum) for spectrum in energy_block]
                     for energy_block in scan_mean_equiv_E])

# NOTE(review): -0.01 is passed straight to calibrate(); presumably the energy
# step per pixel -- confirm against sixtools.rixs_wrapper.calibrate.
cal_scan = calibrate(scan_mean_equiv_E, elastics, -0.01)

ax_cal_scan.cla()
plot_scan(ax_cal_scan, cal_scan,
          event_labels=event_labels,
          marker='.', markersize=0.5, linestyle='-',
          legend_kw={'ncol': 2})

display(plot_output_cal_scan)

In [None]:
# Common energy-loss grid (1000 points from -0.5 to 10) for the RIXS map.
energy_loss = np.linspace(-0.5, 10, 1000)

# Average the calibrated spectra over axis 1, then interpolate each result
# (column 0 = x, column 1 = y) onto the common grid: one map row per entry
# along cal_scan's first axis.
cal_scan_axis1_mean = cal_scan.mean(1)
RIXSmap = np.array([interp_robust(energy_loss, spectrum[:, 0], spectrum[:, 1])
                    for spectrum in cal_scan_axis1_mean])

def plot_map(ax, cax, vmin=0, vmax=2000):
    """Draw the module-level ``RIXSmap`` as a pcolor image with a colorbar.

    Reads the globals ``energy_loss`` (x), ``unique_energies`` (y) and
    ``RIXSmap`` (colour values). Used as an ``ipywidgets.interactive``
    callback, so both axes are cleared first; otherwise every widget change
    would stack another mesh and colorbar on the old ones.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes for the map.
    cax : matplotlib Axes
        Axes for the colorbar.
    vmin, vmax : float
        Colour limits passed to ``pcolor``.
    """
    ax.clear()
    cax.clear()
    art = ax.pcolor(energy_loss, unique_energies, RIXSmap, vmin=vmin, vmax=vmax)
    cb = plt.colorbar(art, cax=cax)
    cb.set_label('I')
    ax.set_xlabel('Energy loss')
    ax.set_ylabel('Incident energy (eV)')

# Colour-limit controls for the RIXS map, initialised to the 1st and 99th
# percentiles of RIXSmap (ignoring NaNs) so outliers do not set the scale.
# Descriptions label the widgets so vmin and vmax can be told apart.
vin_widget2 = ipywidgets.FloatText(value=np.nanpercentile(RIXSmap, 1), description='vmin')
vmax_widget2 = ipywidgets.FloatText(value=np.nanpercentile(RIXSmap, 99), description='vmax')

# interactive() re-calls plot_map whenever either widget value changes.
grab = interactive(plot_map, ax=fixed(ax_map), cax=fixed(cax_map),
                   vmin=vin_widget2, vmax=vmax_widget2)

display(plot_output_map, vin_widget2, vmax_widget2)