In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable

from databroker import DataBroker as db
from sixtools.rixs_wrapper import make_scan, calibrate, interp_robust
from sixtools.plotting_functions import plot_frame, plot_scan
from rixs.plotting_functions import plot_pcolor

from lmfit.models import GaussianModel, ConstantModel, LinearModel

import ipywidgets
from ipywidgets import interact, interactive, fixed
from IPython.display import clear_output, display

%matplotlib widget

## Define plot figures

In [None]:
%matplotlib widget

# Each figure is created inside its own ipywidgets.Output context; later cells
# draw into the axes defined here and call display() on the matching Output.

# Figure 1: a single axes for detector frames, with a narrow axes appended on
# its right-hand side (passed to plot_frame as `cax`).
plot_output_frames = ipywidgets.Output()
with plot_output_frames:
    fig_frames, ax_frames = plt.subplots(num=1, figsize=(10, 4))
    divider = make_axes_locatable(ax_frames)
    cax_frames = divider.append_axes("right", size="2%", pad=0.1)
    # Hide the ticks on the appended axes while it is still empty.
    cax_frames.set_xticks([])
    cax_frames.set_yticks([])

# Figure 2: 3x2 grid. The two tall rows hold spectra and fit values; the short
# bottom row (height ratio 0.2) is used for the shared legend.
plot_output_fits = ipywidgets.Output()
with plot_output_fits:
    fig_fits, axs_fits = plt.subplots(3, 2, num=2, figsize=(10, 8), gridspec_kw=dict(height_ratios=[1,1,0.2]))
    
# Figure 3: a single axes for the calibration plot.
plot_output_ecal = ipywidgets.Output()
with plot_output_ecal:
    fig_ecal, ax_ecal = plt.subplots(num=3)

## Define and execute data processing

In [None]:
# Processing settings for the two detector regions. Each inner dict is unpacked
# into make_scan() as keyword arguments.
# NOTE(review): 'light_ROI' looks like [row slice, column slice] on the detector
# image -- confirm against sixtools.rixs_wrapper.make_scan.
process_dicts = {
    'low_2theta': {
        'light_ROI': [slice(175, 1500), slice(3650, 4200)],
        'curvature': np.array([0., 0., 0.]),
        'bins': None,
        'background': None,
    },
    'high_2theta': {
        'light_ROI': [slice(175, 1500), slice(1020, 1450)],
        'curvature': np.array([0., 0., 0.]),
        'bins': None,
        'background': None,
    },
}

# ROI of every region, in the insertion order of process_dicts.
light_ROIs = [settings['light_ROI'] for settings in process_dicts.values()]

scan_ids = list(range(21110, 21118))
scanned_motor = 'pgm_en'

# Collect every 'rixscam_image' entry of every selected header into one array.
frames = np.array([
    image_stack
    for header in db[scan_ids]
    for image_stack in header.data('rixscam_image')
])

scan_high_2theta = make_scan(db[scan_ids], **process_dicts['high_2theta'])
scan_low_2theta = make_scan(db[scan_ids], **process_dicts['low_2theta'])

# Baseline readings of the scanned motor; the value used for each scan is the
# average of the readings stored under index 1 and index 2
# (presumably taken before and after the scan -- confirm).
before_after_values = db.get_table(
    db[scan_ids], stream_name="baseline", fields=[scanned_motor]
)[scanned_motor]
motor_values = (
    before_after_values[before_after_values.index == 1].values
    + before_after_values[before_after_values.index == 2].values
)/2

# Composite model fitted to every spectrum by fitit().
fit_model = GaussianModel() + ConstantModel()

def fitit(S):
    """Fit the module-level ``fit_model`` (Gaussian + constant) to one spectrum.

    Parameters
    ----------
    S : array
        Two-column array; column 0 is used as x and column 1 as y.

    Returns
    -------
    The result object returned by ``fit_model.fit``.
    """
    x_values = S[:, 0]
    y_values = S[:, 1]
    # The mean of y serves as the starting value of the constant term and is
    # subtracted before guessing the Gaussian parameters.
    offset = np.mean(y_values)
    start_params = GaussianModel().guess(y_values - offset, x=x_values)
    start_params.add('c', value=offset)
    return fit_model.fit(y_values, x=x_values, params=start_params)

# Fit every spectrum of every event; the nesting mirrors the scans
# (outer list: events, inner list: spectra within that event).
fits_high_2theta = [list(map(fitit, event)) for event in scan_high_2theta]
fits_low_2theta = [list(map(fitit, event)) for event in scan_low_2theta]

## Examine frames

In [None]:
def get_plot_frame(ax, scan_id, frame_index, light_ROIs=None, cax=None, **kwargs):
    """Redraw one detector frame on ``ax``.

    Reads the module-level ``frames`` array and ``scan_ids`` list to look up
    the frame that belongs to ``scan_id``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw the frame on; it is cleared first.
    scan_id : int
        Scan to show; must be an entry of the module-level ``scan_ids``.
    frame_index : int
        Index of the frame within the selected scan.
    light_ROIs : list, optional
        Regions of interest forwarded to ``plot_frame``. Defaults to an
        empty list.
    cax : matplotlib axes, optional
        Axes forwarded to ``plot_frame`` as ``cax``; cleared first when given.
    **kwargs
        Forwarded unchanged to ``plot_frame`` (e.g. ``vmin``, ``vmax``).

    Raises
    ------
    ValueError
        If ``scan_id`` is not present in ``scan_ids``.
    """
    # Avoid a shared mutable default argument.
    if light_ROIs is None:
        light_ROIs = []

    ax.clear()
    # Also clear any child artist that offers a clear() method.
    for child in ax.get_children():
        try:
            child.clear()
        except AttributeError:
            pass
    if cax is not None:
        cax.clear()

    selected = np.asarray(scan_ids) == scan_id
    if not selected.any():
        raise ValueError('scan_id {} is not in scan_ids'.format(scan_id))
    # Boolean mask over the scan axis; [0] takes the first matching scan.
    frame = frames[selected, frame_index][0]

    plot_frame(ax, frame, light_ROIs=light_ROIs, cax=cax, **kwargs)
    ax.set_title('Scan no {}'.format(scan_id))


# Controls for the frame viewer: which scan, which frame within the scan, and
# the two values forwarded to get_plot_frame as vmin / vmax.
scanid_widget = ipywidgets.Dropdown(options=scan_ids)
frame_index_widget = ipywidgets.IntSlider(min=0, value=0, max=frames.shape[1]-1)
vin_widget = ipywidgets.FloatText(value=200.)
vmax_widget = ipywidgets.FloatText(value=260.)

# Re-run get_plot_frame whenever one of the widgets changes; the axes and the
# ROI list are passed through unchanged via fixed().
grab = interactive(get_plot_frame, ax=fixed(ax_frames), scan_id=scanid_widget, frame_index=frame_index_widget, light_ROIs=fixed(light_ROIs),
         cax=fixed(cax_frames),
         vmin=vin_widget, vmax=vmax_widget)

# `grab` itself is never displayed (only its child widgets are), so nothing
# triggers the first call. Run it once here so the figure is not empty until
# the user touches a control.
grab.update()

display(plot_output_frames, scanid_widget, frame_index_widget, vin_widget, vmax_widget)

## Plot spectra

In [None]:
def _plot_spectra(ax, scan, fits, motor_values, scan_ids, title):
    """Plot every spectrum in ``scan`` as points and its best fit as a line of the same color."""
    for event, fit_event, motor_value, scan_id in zip(scan, fits, motor_values, scan_ids):
        for i, (S, fit) in enumerate(zip(event, fit_event)):
            points, = ax.plot(S[:,0], S[:,1], '.', label='#{}_{} {:.1f}'.format(scan_id, i, motor_value))
            ax.plot(S[:,0], fit.best_fit, '-', color=points.get_color())

    ax.set_xlabel('pixels')
    ax.set_ylabel('I')
    ax.set_title(title)


def _plot_fit_values(ax, fits, motor_values, x_label, title):
    """Plot fitted centers (left axis, blue) and FWHM (new right twin axis, red) against the motor value."""
    axr = ax.twinx()
    # Gaussian FWHM = 2*sqrt(2*ln 2) * sigma
    sigma_to_fwhm = 2*np.sqrt(2*np.log(2))
    for fit_event, motor_value in zip(fits, motor_values):
        for fit in fit_event:
            ax.plot(motor_value, fit.best_values['center'], 'bo')
            axr.plot(motor_value, fit.best_values['sigma']*sigma_to_fwhm, 'rs')

    ax.set_xlabel(x_label)
    ax.set_ylabel('center', color='b')
    # tick_params keeps the label color when ticks are regenerated (e.g. on zoom).
    ax.tick_params(axis='y', labelcolor='b')
    axr.set_ylabel('fwhm', color='r')
    axr.tick_params(axis='y', labelcolor='r')
    ax.set_title(title)
    return axr


# Reset the figure so that the cell can be re-run: clear the grid axes and
# drop twin axes left over from a previous run (they are not part of axs_fits
# and would otherwise pile up on top of each other).
grid_axes = list(axs_fits.ravel())
for grid_ax in grid_axes:
    grid_ax.clear()
for stale_ax in fig_fits.axes:
    if not any(stale_ax is grid_ax for grid_ax in grid_axes):
        stale_ax.remove()

# Spectra with their fits (top row)
_plot_spectra(axs_fits[0,0], scan_high_2theta, fits_high_2theta, motor_values, scan_ids, 'High 2theta')
_plot_spectra(axs_fits[0,1], scan_low_2theta, fits_low_2theta, motor_values, scan_ids, 'Low 2theta')

# Shared legend in the bottom-left axes, built from the high 2theta handles
axs_fits[-1,0].legend(*axs_fits[0,0].get_legend_handles_labels(), ncol=5, loc=(0,0), fontsize=8)
axs_fits[-1,-1].axis('off')
axs_fits[-1,0].axis('off')

# Fitted center and FWHM versus the scanned motor (middle row)
_plot_fit_values(axs_fits[1,0], fits_high_2theta, motor_values, scanned_motor, 'High 2theta')
_plot_fit_values(axs_fits[1,1], fits_low_2theta, motor_values, scanned_motor, 'Low 2theta')

fig_fits.subplots_adjust(wspace=0.4, hspace=0.5)

display(plot_output_fits)

In [None]:
# Calibration: fit a straight line to the fitted peak center (in pixels)
# as a function of the scanned motor value, separately for each region.
ax_ecal.clear()

ecal_model = LinearModel()

# Make a mean over the frames at each event
centers_high_2theta = np.array([[fit.best_values['center'] for fit in event] for event in fits_high_2theta]).mean(1)
centers_low_2theta = np.array([[fit.best_values['center'] for fit in event] for event in fits_low_2theta]).mean(1)

result_high_2theta = ecal_model.fit(centers_high_2theta, x=motor_values)
result_low_2theta = ecal_model.fit(centers_low_2theta, x=motor_values)

# Data points and fitted line share one color per region.
art = ax_ecal.plot(motor_values, centers_high_2theta, 'o', label='high 2theta')[0]
ax_ecal.plot(motor_values, result_high_2theta.best_fit, '-', color=art.get_color())

art = ax_ecal.plot(motor_values, centers_low_2theta, 'o', label='low 2theta')[0]
ax_ecal.plot(motor_values, result_low_2theta.best_fit, '-', color=art.get_color())

ax_ecal.set_xlabel(scanned_motor)
ax_ecal.set_ylabel('Central pixel')
ax_ecal.set_title('Calibration')
ax_ecal.legend()

# The slope is in pixels per motor unit, so 1000/slope is the printed gradient
# (meV/pix assumes the motor is in eV -- confirm).
# Each label is paired with the slope of the matching region.
print("high 2theta gradient={:.3f} meV/pix\nlow 2theta gradient={:.3f} meV/pix".format(
    1000./result_high_2theta.best_values['slope'],
    1000./result_low_2theta.best_values['slope']))

display(plot_output_ecal)