# Processing SST1 RSoXS Data

## Pip install, then restart the kernel

In [None]:
# Only needs to be run once per session, restart kernel after running

# %pip install pyhyperscattering  # to use pip published package
!pip install -e /nsls2/users/alevin/repos/pyhyper_toneygroup_fork/PyHyperScattering  # to use pip to install via directory
!pip install zarr  # fixed an error with xr.DataSet.to_zarr() method, though not sure if this is really needed
# !pip install xarray==2023.4.0
# !pip install -e /nsls2/users/alevin/repos/xarray  # install a more recent xarray, really just wanted to xr.DataArray.to_zarr method

## Imports

In [None]:
## The autoreload IPython magic command reloads all modules before code is ran
%load_ext autoreload

In [None]:
## Imports
import PyHyperScattering as phs
import pathlib
import sys
import json
import datetime
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib_inline.backend_inline import set_matplotlib_formats

## Make the local helper library importable.
## NOTE(review): the star import presumably supplies load_stacked_pol, sample_guide, detector_guide,
## bcxy_2022_2, plot_mask_files, apply_q_labels, save_zarr, integrate_stacked_pol, ... — confirm.
sys.path.append('/nsls2/users/alevin/local_lib')
from andrew_rsoxs_fxns import *

## Notebook-wide setup
set_matplotlib_formats('svg')  # render inline figures as SVG
# c = from_profile('rsoxs')
print(f'Using PyHyperScattering Version: {phs.__version__}')
# RSoXS databroker loader using chunked (Dask-backed) loading
rsoxsload = phs.load.SST1RSoXSDB(corr_mode='None', use_chunked_loading=True)

## Directory layout: user folder for masks and helper JSON, proposal folder for large outputs
userPath = pathlib.Path('/nsls2/users/alevin')
notebookPath = pathlib.Path.cwd()
maskPath = userPath / 'masks'
jsonPath = userPath / 'local_lib'
propPath = pathlib.Path('/nsls2/data/sst/proposals/2022-2/pass-309180')
zarrPath = propPath / 'zarr_datasets'
exportPath = propPath / 'processed_data'

## RSoXS image colormap: 'terrain', with bad (masked/invalid) values drawn in purple
cm = plt.cm.terrain.copy()
cm.set_bad('purple')

## Loading raw data from databroker

In [None]:
## Search databroker for this institution's full carbon scans and summarize them,
## indexing the summary by scan id (optional, but the selection below relies on it)
runs_sum_df = (
    rsoxsload
    .summarize_run(institution='CUBLDER', plan='full_carbon_scan_nd')
    .set_index('scan_id')
)
display(runs_sum_df)

In [None]:
## Slice the run summary down to the sample of interest in the chosen beamtime cycle.
## Both conditions are combined into a single boolean mask on the same frame. Chaining
## `.loc[mask1].loc[mask2]` relies on realigning mask2 (built on the unfiltered frame)
## onto the filtered one, which is fragile (e.g. with duplicate scan ids).
cycle_of_interest = '2022-2'
sample_of_interest = 'andrew18'
runs_of_interest = runs_sum_df.loc[
    (runs_sum_df['cycle'] == cycle_of_interest)
    & (runs_sum_df['sample_id'] == sample_of_interest)
]
scans = sorted(runs_of_interest.index)  # scan ids, ascending
display(runs_of_interest)

In [None]:
### Run this for samples without error.
### Scans are sorted ascending: the first pair is loaded as SAXS, the second pair as WAXS
### (NOTE(review): presumably one scan per polarization in each pair — confirm).
raw_saxs = load_stacked_pol(rsoxsload, scans[0], scans[1])
raw_waxs = load_stacked_pol(rsoxsload, scans[2], scans[3])

# ### Explicitly select scan_ids:
# raw_saxs = load_stacked_pol(rsoxsload, 43157, 43158)
# raw_waxs = load_stacked_pol(rsoxsload, 43213, 43214)

## Tag each dataset with its human-readable blend name
for raw_da in (raw_saxs, raw_waxs):
    raw_da.attrs['blend_name'] = sample_guide[raw_da.sample_name]
display(raw_saxs, raw_waxs)

## Draw/check masks & beamcenters for transforming to q-space
### 1. Check raw images at a selected energy for all loaded scan configurations:

In [None]:
## Plot raw SAXS and WAXS images for a quick visual check (helper from andrew_rsoxs_fxns).
## NOTE(review): by its name, presumably shows pol 0 and pol 90; the energy shown is chosen inside the helper — confirm.
saxs_waxs_p00_p90_plot(raw_saxs, raw_waxs)

### 2. Draw masks

In [None]:
# ## SAXS:
# ## Draw a mask on the pol=0 image nearest 275 eV; uncomment draw.ui() to open the
# ## (presumably interactive) drawing tool.
# saxs_mask_img = raw_saxs.sel(pol=0, energy=275, method='nearest').compute()
# draw = phs.IntegrationUtils.DrawMask(saxs_mask_img)
# # draw.ui()

In [None]:
# ## Save saxs drawn mask
# ## Written to maskPath as SAXS_<sample_name>.json
# draw.save(maskPath.joinpath(f'SAXS_{raw_saxs.sample_name}.json'))

In [None]:
# ## Repeat for WAXS mask:
# ## Same as the SAXS cell, using the pol=0 WAXS image nearest 275 eV.
# waxs_mask_img = raw_waxs.sel(pol=0, energy=275, method='nearest').compute()
# draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)
# # draw.ui()

In [None]:
# ## Save the drawn WAXS mask, named after the WAXS dataset's sample
# ## (written to maskPath as WAXS_<sample_name>.json)
# draw.save(maskPath.joinpath(f'WAXS_{raw_waxs.sample_name}.json'))

In [None]:
### Check masks on file
## Mask-check images: the pol=0 frame nearest 275 eV for each detector
mask_check_sel = dict(pol=0, energy=275, method='nearest')
saxs_mask_img = raw_saxs.sel(**mask_check_sel).compute()
waxs_mask_img = raw_waxs.sel(**mask_check_sel).compute()
draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)

### Load masks (presumably the SAXS_/WAXS_<sample_name>.json files in maskPath) and plot them over the images
saxs_mask, waxs_mask = plot_mask_files(
    draw, maskPath, raw_waxs.sample_name,
    saxs_img=saxs_mask_img, waxs_img=waxs_mask_img,
)
# plot_one_mask_file(draw, maskPath, raw_waxs.sample_name, img=raw_waxs.sel(pol=0, energy=275, method='nearest'))

### 3. Check and save beamcenters before converting to q-space

In [None]:
def _setup_integrator(raw, mask, mask_img, bcx, bcy, label, img_max):
    """Build a PFEnergySeriesIntegrator for one detector, record its beamcenter, and plot a check.

    Parameters
    ----------
    raw : xr.DataArray
        Raw stacked-polarization data. Its pol=0 slice is the geometry template. Its attrs
        are updated in place with beamcenter_x/y and the resulting poni1/poni2.
    mask :
        Mask loaded from the mask file for this detector.
    mask_img :
        Image the beamcenter/mask check is drawn over.
    bcx, bcy :
        Beamcenter (pixel coordinates) to apply.
    label : str
        Detector label ('SAXS' or 'WAXS') used in the printed summary.
    img_max : float
        Upper intensity limit passed to the check plot.

    Returns
    -------
    The configured integrator.
    """
    integ = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr=raw.sel(pol=0))
    integ.mask = mask
    integ.ni_beamcenter_x = bcx
    integ.ni_beamcenter_y = bcy
    # beamcenter_x/y attrs are what the beamcenter-JSON cell saves; poni1/poni2 are the
    # geometry values derived from the beamcenter just set
    raw.attrs['beamcenter_x'] = bcx
    raw.attrs['beamcenter_y'] = bcy
    raw.attrs['poni1'] = integ.poni1
    raw.attrs['poni2'] = integ.poni2
    print(f'{label} Beamcenter: \n'
          f'poni1: {integ.poni1}, poni2: {integ.poni2} \n'
          f'ni_beamcenter_y: {integ.ni_beamcenter_y}, ni_beamcenter_x: {integ.ni_beamcenter_x}')

    ## Plot check, zoomed to +/-200 px around the beamcenter
    phs.IntegrationUtils.Check.checkAll(integ, mask_img, img_max=img_max, alpha=0.4)
    plt.xlim(integ.ni_beamcenter_x-200, integ.ni_beamcenter_x+200)
    plt.ylim(integ.ni_beamcenter_y-200, integ.ni_beamcenter_y+200)
    plt.gcf().set(dpi=120)
    plt.show()
    return integ


## SAXS
SAXSinteg = _setup_integrator(raw_saxs, saxs_mask, saxs_mask_img,
                              bcxy_2022_2['saxs_bcx'], bcxy_2022_2['saxs_bcy'],
                              label='SAXS', img_max=1e3)

## WAXS
WAXSinteg = _setup_integrator(raw_waxs, waxs_mask, waxs_mask_img,
                              bcxy_2022_2['waxs_bcx'], bcxy_2022_2['waxs_bcy'],
                              label='WAXS', img_max=7e3)

In [None]:
# ## Tweaking if needed (optional; only when the check plots above look off).
# ## Whichever option is used, keep raw_*.attrs['beamcenter_x'/'beamcenter_y'] in sync with
# ## poni1/poni2: the beamcenter-JSON cell below saves the beamcenter attrs.

# ## Option A: manual tweaking
# ## SAXS Tweaking & Plot Check
# saxs_new_bcx = 488
# saxs_new_bcy = 515
# SAXSinteg.ni_beamcenter_x = saxs_new_bcx
# SAXSinteg.ni_beamcenter_y = saxs_new_bcy
# raw_saxs.attrs['beamcenter_x'] = saxs_new_bcx
# raw_saxs.attrs['beamcenter_y'] = saxs_new_bcy
# raw_saxs.attrs['poni1'] = SAXSinteg.poni1
# raw_saxs.attrs['poni2'] = SAXSinteg.poni2

# print('SAXS Beamcenter Tweaking: \n'
#       f'poni1: {SAXSinteg.poni1}, poni2: {SAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {SAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {SAXSinteg.ni_beamcenter_x}')

# phs.IntegrationUtils.Check.checkAll(SAXSinteg, saxs_mask_img, img_max=1e3, alpha=0.6)
# plt.xlim(SAXSinteg.ni_beamcenter_x-200, SAXSinteg.ni_beamcenter_x+200)
# plt.ylim(SAXSinteg.ni_beamcenter_y-200, SAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()

# ## WAXS Tweaking & Plot Check
# waxs_new_bcx = 396.3
# waxs_new_bcy = 553
# WAXSinteg.ni_beamcenter_x = waxs_new_bcx
# WAXSinteg.ni_beamcenter_y = waxs_new_bcy
# raw_waxs.attrs['beamcenter_x'] = waxs_new_bcx
# raw_waxs.attrs['beamcenter_y'] = waxs_new_bcy
# raw_waxs.attrs['poni1'] = WAXSinteg.poni1
# raw_waxs.attrs['poni2'] = WAXSinteg.poni2

# print('WAXS Beamcenter Tweaking: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=5e3, alpha=0.6, guide1=40)
# plt.xlim(WAXSinteg.ni_beamcenter_x-200, WAXSinteg.ni_beamcenter_x+200)
# plt.ylim(WAXSinteg.ni_beamcenter_y-200, WAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()


# ## Option B: Pete D.'s (very slightly modified) beamcentering script:
# # phs.BeamCentering.CenteringAccessor.refine_geometry

# ## SAXS
# res_saxs = raw_saxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.002, q_max=0.006)
# # res_saxs = raw_saxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.002, q_max=0.006, chi_min=-180, chi_max=60)
# # res_saxs = raw_saxs.sel(pol=0).util.refine_geometry(energy=280, q_min=0.002, q_max=0.008, mask=saxs_mask)
# raw_saxs.attrs['poni1'] = res_saxs.x[0]
# raw_saxs.attrs['poni2'] = res_saxs.x[1]
# SAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = raw_saxs.sel(pol=0))
# SAXSinteg.mask = saxs_mask
# # Sync the pixel beamcenter attrs with the refined poni so the JSON cell saves the refined values
# raw_saxs.attrs['beamcenter_x'] = SAXSinteg.ni_beamcenter_x
# raw_saxs.attrs['beamcenter_y'] = SAXSinteg.ni_beamcenter_y

# ## SAXS Plot check
# print('SAXS Beamcenter Post-optimization: \n'
#       f'poni1: {SAXSinteg.poni1}, poni2: {SAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {SAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {SAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(SAXSinteg, saxs_mask_img, img_max=1e3, alpha=0.6)
# plt.xlim(SAXSinteg.ni_beamcenter_x-200, SAXSinteg.ni_beamcenter_x+200)
# plt.ylim(SAXSinteg.ni_beamcenter_y-200, SAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()

# ## WAXS
# # res_waxs = raw_waxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.02, q_max=0.06, chi_min=-10, chi_max=70)
# res_waxs = raw_waxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.02, q_max=0.06)
# raw_waxs.attrs['poni1'] = res_waxs.x[0]
# raw_waxs.attrs['poni2'] = res_waxs.x[1]
# WAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = raw_waxs.sel(pol=0))
# WAXSinteg.mask = waxs_mask
# # Sync the pixel beamcenter attrs with the refined poni so the JSON cell saves the refined values
# raw_waxs.attrs['beamcenter_x'] = WAXSinteg.ni_beamcenter_x
# raw_waxs.attrs['beamcenter_y'] = WAXSinteg.ni_beamcenter_y

# ## WAXS Plot check
# print('WAXS Beamcenter Post-optimization: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=5e3, alpha=0.6)
# plt.xlim(WAXSinteg.ni_beamcenter_x-200, WAXSinteg.ni_beamcenter_x+200)
# plt.ylim(WAXSinteg.ni_beamcenter_y-200, WAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()

In [None]:
### Write beamcenters to the saved .json file if content with them.
### Entries are keyed '<DETECTOR>_<sample_name>' and merged into any existing entries.

# float() so numpy scalar types (e.g. numpy ints), which json cannot serialize, are stored as plain numbers
beamcenters_dict = {
    f'SAXS_{raw_saxs.sample_name}': {'bcx': float(raw_saxs.beamcenter_x), 'bcy': float(raw_saxs.beamcenter_y)},
    f'WAXS_{raw_waxs.sample_name}': {'bcx': float(raw_waxs.beamcenter_x), 'bcy': float(raw_waxs.beamcenter_y)},
}

beamcenters_file = jsonPath.joinpath('beamcenters_dict.json')

# Start from the existing file (keeps other samples' beamcenters); create it on first use
if beamcenters_file.exists():
    with open(beamcenters_file, 'r') as f:
        dic = json.load(f)
else:
    dic = {}

dic.update(beamcenters_dict)

# Serialize before opening for write: opening with 'w' truncates the file, and json.dump
# streams into it, so a serialization error mid-dump would otherwise wipe saved beamcenters.
serialized = json.dumps(dic)
with open(beamcenters_file, 'w') as f:
    f.write(serialized)

## Export data
These files are large, so save them to the proposal folder rather than the user folder.


### 1. Apply qx, qy labels and save .zarr stores

In [None]:
### Now that we know our beamcenters are accurate, we can apply correct q axis labels
raw_waxs = apply_q_labels(raw_waxs)
raw_saxs = apply_q_labels(raw_saxs)

### Load energy lists for facet plots
energies = raw_waxs.energy.data
# NOTE(review): index window [16:96] (80 energies) is taken as the resonant region of this scan's energy list — confirm
resonant_energies = energies[16:96]

# GIF frames: every 3rd of the first 16 energies, every 2nd of the last 31, plus all resonant energies, sorted
gif_energies = np.sort(np.concatenate([energies[0:16:3], energies[-31::2], resonant_energies]))

### Set variables for naming purposes (taken from the WAXS dataset)
sample_name = sample_guide[raw_waxs.sample_name]
detector = detector_guide[raw_waxs.detector]

sampPath = exportPath.joinpath(f'{detector}_{sample_name}')
sampPath.mkdir(parents=True, exist_ok=True)

In [None]:
### Save zarr store/directory
## Writes the raw (q-labelled) SAXS and WAXS data to zarrPath.
## NOTE(review): the naming is set inside save_zarr; the loading cells below glob for 'raw*<tag>*SAXS*' / 'raw*<tag>*WAXS*'.

save_zarr(raw_saxs, raw_waxs, zarrPath)

In [None]:
# ### Generate WAXS facet plots: 10 figures of 8 energies each, per polarization
# ## NOTE(review): this cell previously used `facetPath` and `energy_list`, which are not defined in
# ## this notebook; exportPath and resonant_energies (80 values = 10 x 8) are used instead — confirm intent.
# sample_name = sample_guide[raw_waxs.sample_name]
# scan_id = raw_waxs.sampleid
# detector = detector_guide[raw_waxs.detector]

# scanPath = exportPath.joinpath(f'{scan_id}_{sample_name}_{detector}')
# scanPath.mkdir(parents=True, exist_ok=True)

# for pol in (0, 90):
#     # Create/select folder for this polarization's plots (once per polarization):
#     imgsPath = scanPath.joinpath(f'_qxqy_frames_{detector}_{int(pol):0>2}deg')
#     imgsPath.mkdir(parents=True, exist_ok=True)

#     for num in range(10):
#         grid = raw_waxs.sel(pol=pol, energy=resonant_energies[8*num:8*num+8], method='nearest').plot.imshow(x='qx', y='qy',
#                     norm=LogNorm(1e1, 5e3), cmap=cm, interpolation='antialiased', col='energy', col_wrap=4)
#         grid.set_xlabels('qx [1/Å]')
#         grid.set_ylabels('qy [1/Å]')

#         plt.savefig(imgsPath.joinpath(f'{sample_name}_{detector}_{int(pol):0>2}_f{num}.svg'))
#         plt.close()  # don't keep 20 large figures open

### 2. Convert to chi-q space & save .zarr stores

In [None]:
## Convert raw detector images to chi-q space for each detector, using the integrators
## (mask + beamcenter) configured in the beamcenter cells above
integ_saxs = integrate_stacked_pol(SAXSinteg, raw_saxs)
integ_waxs = integrate_stacked_pol(WAXSinteg, raw_waxs)
display(integ_saxs, integ_waxs)

In [None]:
### Save zarr store/directory
## prefix='integ_qchi' distinguishes these stores from the raw ones (loaded below via 'integ*<tag>*' globs)
save_zarr(integ_saxs, integ_waxs, zarrPath, prefix='integ_qchi')

In [None]:
## List saved raw SAXS zarr stores with the 'w18' tag (presumably a quick check that the save worked).
## NOTE(review): the loading cell below uses the 'w11' tag — confirm which sample is intended.
sorted(zarrPath.glob('raw*w18*SAXS*'))

In [None]:
### How you would load data back from the saved zarr stores.
## NOTE(review): the listing cell above checks 'w18' stores but this loads 'w11' — confirm the intended tag.
zarr_sample_tag = 'w11'


def _open_first_zarr(pattern):
    """Open the first (sorted) zarr store in zarrPath matching the glob `pattern`.

    Raises FileNotFoundError naming the pattern when nothing matches, instead of a bare IndexError.
    """
    matches = sorted(zarrPath.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No zarr store matching {pattern!r} in {zarrPath}')
    return xr.open_zarr(matches[0])


loaded_raw_saxs = _open_first_zarr(f'raw*{zarr_sample_tag}*SAXS*').saxs
loaded_raw_waxs = _open_first_zarr(f'raw*{zarr_sample_tag}*WAXS*').waxs
loaded_integ_saxs = _open_first_zarr(f'integ*{zarr_sample_tag}*SAXS*').saxs
loaded_integ_waxs = _open_first_zarr(f'integ*{zarr_sample_tag}*WAXS*').waxs

In [None]:
### Load energy lists for facet plots from the loaded data that is plotted below,
### so this cell does not depend on raw_waxs still being in memory
energies = loaded_integ_waxs.energy.data
# NOTE(review): index window [16:96] is taken as the resonant region — confirm for this energy list
resonant_energies = energies[16:96]
# Every 3rd of the first 16 energies, every 2nd of the last 31, plus all resonant energies, sorted
gif_energies = np.sort(np.concatenate([energies[0:16:3], energies[-31::2], resonant_energies]))

pol = 0

# #### View facet plot to verify data:
# loaded_integ_waxs.sel(pol=pol, energy=gif_energies[:-6:6], method='nearest').plot.imshow(xscale='log', xlim=(1e-2, 2e-1),
#                         norm=LogNorm(1e1, 5e3), cmap=cm, interpolation='antialiased', col='energy', col_wrap=4)
# loaded_integ_saxs.sel(pol=pol, energy=gif_energies[:-6:6], method='nearest').plot.imshow(xscale='log', xlim=(1e-3, 1e-2),
#                         norm=LogNorm(1e1, 5e3), cmap=cm, interpolation='antialiased', col='energy', col_wrap=4)
