In [None]:
%matplotlib widget

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable

from databroker import DataBroker as db
from sixtools.rixs_wrapper import make_scan, calibrate, interp_robust
from sixtools.plotting_functions import plot_frame, plot_scan
from sixtools.GUI_components import markers
from rixs.process2d import image_to_photon_events, fit_curvature, get_curvature_offsets, estimate_elastic_pos
from rixs.plotting_functions import plot_pcolor

from ipywidgets import interact, interactive, fixed, HBox, VBox,Label, Layout, FloatText, IntSlider, Dropdown, SelectMultiple
from IPython.display import clear_output, display

In [None]:
%%capture
# Create figure 0 (frame viewer) and figure 1 (curvature) up front; %%capture on
# this cell swallows any output produced while doing so. Later cells fetch the
# same figures again by number with clear=True.
fig_frames, ax_frames = plt.subplots(figsize=(10, 4), clear=True, num=0)
fig_curvature, ax_curvature = plt.subplots(clear=True, num=1)

## Define processing settings and load frames

In [None]:
# Per-region processing settings. In this notebook only 'light_ROI' is read
# (passed to plot_frame below); 'curvature', 'bins' and 'background' hold
# placeholder values (zeros / None).
process_dicts = {
    'low_2theta': {
        'light_ROI': [slice(175, 1500), slice(4130, 4700)],
        'curvature': np.array([0., 0., 0.]),
        'bins': None,
        'background': None,
    },
    'high_2theta': {
        'light_ROI': [slice(175, 1500), slice(900, 1300)],
        'curvature': np.array([0., 0., 0.]),
        'bins': None,
        'background': None,
    },
}

# The light ROI of each region, in dict order.
light_ROIs = [settings['light_ROI'] for settings in process_dicts.values()]

# Scan numbers to load (a range, currently containing a single scan).
scan_ids = list(range(34354, 34354 + 1))

# Collect every 'rixscam_image' entry of every scan into one array: axis 0 indexes
# events across all scans, axis 1 the frames within an event (as used below).
frames = np.array([
    image_stack
    for scan_header in db[scan_ids]
    for image_stack in scan_header.data('rixscam_image')
])
# One label per event: its position along axis 0 of ``frames``, as a string.
event_labels = [str(position) for position in range(len(frames))]

## Examine frames

In [None]:
# Re-use figure 0 (cleared) for the frame viewer and show frame 0 of event 0.
# The light ROIs are handed to plot_frame; of its three return values the first
# and third (presumably the image artist and its colour bar) are kept, because
# update_frame below redraws them.
fig_frames, ax_frames = plt.subplots(figsize=(10, 4), clear=True, num=0)
art_frames, _, cb_frames = plot_frame(
    ax_frames,
    frames[0, 0],
    light_ROIs=light_ROIs,
)
ax_frames.set_title("Frame {}_{}".format(event_labels[0], 0))


def update_frame(scan_id, frameid, vmin, vmax):
    """Redraw the frame viewer with one frame of the event paired with ``scan_id``.

    Callback for the ``interactive`` widget below. It updates the module-level
    ``art_frames`` / ``ax_frames`` / ``cb_frames`` / ``fig_frames`` in place.

    Parameters
    ----------
    scan_id : int
        Scan number selected in the dropdown (one of ``scan_ids``).
    frameid : int
        Index of the frame within the selected event (axis 1 of ``frames``).
    vmin, vmax : float
        Colour-scale limits. They are applied only when ``vmin < vmax``;
        otherwise the current limits are kept. The ``FloatText`` boxes default
        to 0.0, and equal limits collapse the colour scale while inverted ones
        are rejected by matplotlib's ``Normalize``.

    Raises
    ------
    StopIteration
        If ``scan_id`` is not paired with any event.
    """
    # NOTE(review): events (axis 0 of ``frames``) are paired with ``scan_ids``
    # positionally. This only lines up when every scan contributes exactly one
    # 'rixscam_image' event.
    event, event_label = next(
        (candidate, label)
        for candidate, label, sid in zip(frames, event_labels, scan_ids)
        if sid == scan_id
    )
    art_frames.set_data(event[frameid])
    ax_frames.set_title("Frame {}_{}".format(event_label, frameid))

    if vmin < vmax:
        art_frames.set_clim(vmin, vmax)
        cb_frames.set_clim(vmin, vmax)

    fig_frames.canvas.draw()
    fig_frames.canvas.flush_events()

# Widgets that drive update_frame. The colour-limit boxes are seeded with the
# limits the first frame was drawn with, instead of FloatText's default 0.0, so
# the viewer does not open with a collapsed (vmin == vmax == 0) colour scale.
# NOTE(review): assumes art_frames exposes matplotlib's ScalarMappable.get_clim
# (it already supports set_data/set_clim in update_frame) and that the limits
# are set, i.e. not None.
clim_low, clim_high = art_frames.get_clim()
scanid_widget = Dropdown(options=scan_ids, description='scan')
frameid_widget = IntSlider(min=0, max=frames.shape[1] - 1, description='frame')
vmin_widget1 = FloatText(value=float(clim_low), description='vmin')
vmax_widget1 = FloatText(value=float(clim_high), description='vmax')

# ``interactive`` calls update_frame with the widget values whenever they change.
grab = interactive(update_frame, scan_id=scanid_widget, frameid=frameid_widget,
                   vmin=vmin_widget1, vmax=vmax_widget1)

# Canvas on top, then selectors on one row and colour limits on the next.
display(VBox([fig_frames.canvas,
              HBox([scanid_widget, frameid_widget]),
              HBox([vmin_widget1, vmax_widget1])]))

## Fit curvature

In [None]:
# Average every frame of every event into a single 2-D image for the curvature fit.
image = frames.mean(axis=(0, 1))

fig_curvature, ax_curvature = plt.subplots(num=1, clear=True)
ROI = [slice(175, 1500, None), slice(1000, 1200, None)]
# Index with a tuple: NumPy >= 1.23 no longer interprets a *list* of slices as a
# multidimensional index and raises instead.
image_roi = image[tuple(ROI)]
# Constant background: mean of the first 600 rows of the ROI. It is taken from
# the same averaged image that is analysed, so signal and background are
# estimated consistently (a single frame would be noisier).
BG = np.mean(image_roi[:600, :])

photon_events = image_to_photon_events(image_roi - BG)
# These binned offsets are only plotted for inspection; the fit below uses its
# own, different binning.
x_centers, offsets = get_curvature_offsets(photon_events, bins=(2000, 13))
elastic_y_value = estimate_elastic_pos(photon_events)

# NOTE(review): the meaning of the [0., 800.] starting values is defined by
# rixs.process2d.fit_curvature -- confirm against its documentation.
curvature = fit_curvature(photon_events, np.array([0., 800.]), bins=(1000, 10))

art, cb_art = plot_pcolor(ax_curvature, photon_events)
ax_curvature.plot(x_centers, offsets + elastic_y_value, 'r*')
ax_curvature.set_title('Photon events in ROI with curvature offsets')
ax_curvature.set_xlabel('x')
ax_curvature.set_ylabel('y')
cb_art.set_label('I')

display(fig_curvature.canvas)

print("Not a great example as curvature is {}".format(curvature))