# Processing SST1 RSoXS Data

## Pip install and restart kernel 

In [None]:
# Only needs to be run once per session, restart kernel after running
# Editable (-e) install from a local PyHyperScattering checkout; the commented line is the PyPI alternative.
# NOTE(review): IPython's `%pip` magic installs into the running kernel's environment, while `!pip`
# uses whichever pip is first on PATH — consider `%pip` if those differ.

# %pip install pyhyperscattering==0.2.1  # to use pip published package
!pip install -e /nsls2/users/alevin/repos/PyHyperScattering  # to use pip to install via directory

In [None]:
# Install/upgrade pre-release tiled and databroker (per the inline note, needed to fix a tiled/databroker error in SST1RSoXSDB).
!pip install --pre --upgrade tiled[all] databroker  # bottleneck # needed to fix tiled/databroker error in SST1RSoXSDB

## Imports

In [None]:
## Imports
import PyHyperScattering as phs
import pathlib
import sys
import json
import datetime
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from tqdm.auto import tqdm
import dask.array as da
from dask.diagnostics import ProgressBar
# from tiled.client import from_profile, from_uri
# from matplotlib_inline.backend_inline import set_matplotlib_formats

sys.path.append('/nsls2/users/alevin/local_lib')
from andrew_rsoxs_fxns import *

## Some setup functions
# set_matplotlib_formats('svg')
print(f'Using PyHyperScattering Version: {phs.__version__}')

## Directory layout: user area (masks, helper json) and proposal area (processed outputs)
userPath = pathlib.Path('/nsls2/users/alevin')
maskPath = userPath / 'masks'
jsonPath = userPath / 'local_lib'
# propPath = pathlib.Path('/nsls2/data/sst/proposals/2022-2/pass-309180')
propPath = pathlib.Path('/nsls2/data/sst/proposals/2023-2/pass-311130')
outPath = propPath / 'processed_data'
zarrsPath = outPath / 'zarrs'

## RSoXS colormap for later plots: a private copy of turbo that draws "bad" (masked/invalid) values in black
cmap = plt.cm.turbo.copy()
cmap.set_bad('black')

## Load from local file

In [None]:
# The published PyHyperScattering loader (not the local SST1RSoXSLoader class defined in a later cell).
local_loader = phs.load.SST1RSoXSLoader(
    corr_mode='none',
)

In [None]:
scan_id = '65817'
# Per-scan directory of exported files under the proposal folder
filepath = propPath / 'TRMSN_NEXAFS_and_RSoXS' / scan_id
filepath

In [None]:
# Names of every entry in the scan directory, in iterdir() order
list(map(lambda entry: entry.name, filepath.iterdir()))

In [None]:
from PIL import Image
from PyHyperScattering.FileLoader import FileLoader
import os
import pathlib
import xarray as xr
import pandas as pd
import datetime
import warnings
import json
#from pyFAI import azimuthalIntegrator
import numpy as np


class SST1RSoXSLoader(FileLoader):
    '''
    Loader for TIFF files from NSLS-II SST1 RSoXS instrument

    Per-image metadata is assembled by ``loadMd`` from sidecar files: the first ``*.json`` and
    ``*baseline.csv`` found in the image's directory, plus ``{scan_id}*primary.csv`` in the
    parent directory.
    '''
    file_ext = '(.*?)primary(.*?).tiff'
    md_loading_is_quick = True
    # Detector pixel sizes in mm (loadMd divides by 1000 to store pixel1/pixel2 in metres).
    pix_size_1 = 0.06
    pix_size_2 = 0.06

    def __init__(self,corr_mode=None,user_corr_func=None,dark_pedestal=100,exposure_offset=0,constant_md=None,):
        '''
        Args:
            corr_mode (str): origin to use for the intensity correction.  Can be 'expt','i0','expt+i0','user_func','old',or 'none'.
                Any unrecognized value behaves like 'none' (correction factor of 1).
            user_corr_func (callable): takes the header dictionary and returns the value of the correction.
            dark_pedestal (numeric): value to subtract(/add, if negative) to the whole image.  this should match the instrument setting for suitcased tiffs, typically 100.
            exposure_offset (numeric): value to add to the exposure time.  Measured at 2ms with the piezo shutter in Dec 2019 by Jacob Thelen, NIST.
                NOTE: stored on the instance but not currently applied anywhere in this class.
            constant_md (dict or None): values to insert into every metadata load. None (the default)
                means an empty dict; None is used instead of ``{}`` so instances never share one mutable default.
        '''
        if corr_mode is None:
            warnings.warn("Correction mode was not set, not performing *any* intensity corrections.  Are you sure this is "+
                          "right? Set corr_mode to 'none' to suppress this warning.",stacklevel=2)
            self.corr_mode = 'none'
        else:
            self.corr_mode = corr_mode

        self.constant_md = {} if constant_md is None else constant_md

        self.dark_pedestal = dark_pedestal
        self.user_corr_func = user_corr_func
        self.exposure_offset = exposure_offset

    def loadSingleImage(self,filepath,coords=None, return_q=False,image_slice=None,use_cached_md=False,**kwargs):
        '''
        HELPER FUNCTION that loads a single image and returns an xarray with either pix_x / pix_y dimensions (if return_q == False) or qx / qy (if return_q == True)

        Args:
            filepath (Pathlib.path): path of the file to load
            coords (dict-like): coordinate values to inject into the metadata
            return_q (bool): return qx / qy coords.  If false, returns pixel coords.
            image_slice: not supported; must be None.
            use_cached_md: not supported; must be False.

        Returns:
            xr.DataArray: (image - dark_pedestal) / correction, with the header dict as attrs.

        Raises:
            NotImplementedError: if image_slice or use_cached_md is requested.
        '''
        if len(kwargs.keys())>0:
            warnings.warn(f'Loader does not support features for kwargs: {kwargs.keys()}',stacklevel=2)

        if image_slice is not None:
            raise NotImplementedError('Image slicing is not supported for SST1')
        if use_cached_md != False:
            raise NotImplementedError('Caching of metadata is not supported for SST1')

        # Read the pixels inside a context manager so the file handle is released deterministically.
        with Image.open(filepath) as img:
            raw = np.array(img)

        headerdict = self.loadMd(filepath)
        # two steps in this pre-processing stage:
        #     (1) get and apply the right scalar correction term to the image
        #     (2) find and subtract the right dark
        if coords is not None:
            headerdict.update(coords)

        #step 1: correction term

        if self.corr_mode == 'expt':
            corr = headerdict['exposure'] #(headerdict['AI 3 Izero']*expt)
        elif self.corr_mode == 'i0':
            corr = headerdict['AI 3 Izero']
        elif self.corr_mode == 'expt+i0':
            corr = headerdict['exposure'] * headerdict['AI 3 Izero']
        elif self.corr_mode == 'user_func':
            corr = self.user_corr_func(headerdict)
        elif self.corr_mode == 'old':
            corr = headerdict['AI 6 BeamStop'] * 2.4e10/ headerdict['Beamline Energy'] / headerdict['AI 3 Izero']
            #this term is a mess...  @TODO check where it comes from
        else:
            corr = 1

        if(corr<0):
            warnings.warn(f'Correction value is negative: {corr} with headers {headerdict}.',stacklevel=2)
            corr = abs(corr)

        # # step 2: dark subtraction
        # this is already done in the suitcase, but we offer the option to add/subtract a pedestal.
        # Promote integer images to float64 first: with unsigned dtypes (e.g. uint16) pixels below the
        # pedestal would otherwise wrap around to huge values. Float images keep their dtype.
        if np.issubdtype(raw.dtype, np.integer):
            raw = raw.astype(np.float64)
        image_data = (raw-self.dark_pedestal)/corr
        if return_q:
            # q per pixel (1/Å), small-angle approximation: 2*pi * pixel[m] / sdd[m] / wavelength[Å]
            qpx = 2*np.pi*(self.pix_size_1/1000)/(headerdict['sdd']/1000)/(headerdict['wavelength']*1e10)
            n_rows, n_cols = image_data.shape[-2:]
            # Rows run along y and columns along x, matching loadMd (poni1 <- beamcenter_y, poni2 <- beamcenter_x).
            qx = (np.arange(1,n_cols+1)-headerdict['beamcenter_x'])*qpx
            qy = (np.arange(1,n_rows+1)-headerdict['beamcenter_y'])*qpx
            # now, match up the dims and coords
            return xr.DataArray(image_data,dims=['qy','qx'],coords={'qy':qy,'qx':qx},attrs=headerdict)
        else:
            # dim order changed by ktoth17 to reflect SST1RSoXSDB.py. See Issue #34 for more details.
            return xr.DataArray(image_data,dims=['pix_y','pix_x'],attrs=headerdict)

    def read_json(self,jsonfile):
        '''
        Read sample name, detector configuration and geometry from a scan's json metadata file.

        For SAXS/WAXS, beam center and sdd come from the json unless the measurement time falls in one
        of the hard-coded date windows below (which override the json); values still None afterwards
        are replaced by nominal defaults. For any other detector, rsoxs_config is 'unknown', a warning
        is issued, and no geometry keys are set.

        Args:
            jsonfile (path-like): path to the json file.

        Returns:
            dict: 'sample_name', 'rsoxs_config' and, for SAXS/WAXS, 'beamcenter_x', 'beamcenter_y', 'sdd'.
        '''
        json_dict = {}
        with open(jsonfile) as f:
            md = json.load(f)
        meas_time = datetime.datetime.fromtimestamp(md['time'])
        json_dict['sample_name'] = md['sample_name']

        if md['RSoXS_Main_DET'] == 'SAXS':
            json_dict['rsoxs_config'] = 'saxs'
            # discrepancy between what is in .json and actual
            if datetime.datetime(2020,12,1) < meas_time < datetime.datetime(2021,1,15):
                json_dict['beamcenter_x'] = 489.86
                json_dict['beamcenter_y'] = 490.75
                json_dict['sdd'] = 521.8
            elif datetime.datetime(2020,11,16) < meas_time < datetime.datetime(2020,12,1):
                json_dict['beamcenter_x'] = 371.52
                json_dict['beamcenter_y'] = 491.17
                json_dict['sdd'] = 512.12
            elif datetime.datetime(2022,5,1) < meas_time < datetime.datetime(2022,7,7):
                # these params determined by Camille from Igor
                json_dict['beamcenter_x'] = 498 # not the best estimate; I didn't have great data
                json_dict['beamcenter_y'] = 498
                json_dict['sdd'] = 512.12 # GUESS; SOMEONE SHOULD CONFIRM WITH A BCP MAYBE??
            else:
                json_dict['beamcenter_x'] = md['RSoXS_SAXS_BCX']
                json_dict['beamcenter_y'] = md['RSoXS_SAXS_BCY']
                json_dict['sdd'] = md['RSoXS_SAXS_SDD']

        elif md['RSoXS_Main_DET'] in ('WAXS', 'waxs_det'):
            json_dict['rsoxs_config'] = 'waxs'
            if datetime.datetime(2020,11,16) < meas_time < datetime.datetime(2021,1,15):
                json_dict['beamcenter_x'] = 400.46
                json_dict['beamcenter_y'] = 530.99
                json_dict['sdd'] = 38.745
            elif datetime.datetime(2022,5,1) < meas_time < datetime.datetime(2022,7,7):
                # these params determined by Camille from Igor
                json_dict['beamcenter_x'] = 397.91
                json_dict['beamcenter_y'] = 549.76
                json_dict['sdd'] = 34.5 # GUESS; SOMEONE SHOULD CONFIRM WITH A BCP MAYBE??
            else:
                json_dict['beamcenter_x'] = md['RSoXS_WAXS_BCX'] # 399 #
                json_dict['beamcenter_y'] = md['RSoXS_WAXS_BCY'] # 526
                json_dict['sdd'] = md['RSoXS_WAXS_SDD']

        else:
            json_dict['rsoxs_config'] = 'unknown'
            warnings.warn('RSoXS_Config is neither SAXS or WAXS. Check json file',stacklevel=2)
            # No geometry is known for an unknown configuration, so skip the SAXS/WAXS defaults below.
            return json_dict

        # Fall back to nominal geometry when the json leaves these fields empty.
        if json_dict['sdd'] is None:
            warnings.warn('sdd is None, reverting to default values. Check json file',stacklevel=2)
            if json_dict['rsoxs_config'] == 'waxs':
                json_dict['sdd'] = 38.745
            elif json_dict['rsoxs_config'] == 'saxs':
                json_dict['sdd'] = 512.12
        if json_dict['beamcenter_x'] is None:
            warnings.warn('beamcenter_x/y is None, reverting to default values. Check json file',stacklevel=2)
            if json_dict['rsoxs_config'] == 'waxs':
                json_dict['beamcenter_x'] = 400.46
                json_dict['beamcenter_y'] = 530.99
            elif json_dict['rsoxs_config'] == 'saxs':
                json_dict['beamcenter_x'] = 371.52
                json_dict['beamcenter_y'] = 491.17
        return json_dict

    def read_baseline(self,baseline_csv):
        '''
        Read sample stage positions from the first row of a baseline csv, rounded to 4 decimals.

        Returns:
            dict: 'sam_x', 'sam_y', 'sam_z', 'sam_th'.
        '''
        baseline_dict = {}
        df_baseline = pd.read_csv(baseline_csv)
        baseline_dict['sam_x'] = round(df_baseline['RSoXS Sample Outboard-Inboard'][0],4)
        baseline_dict['sam_y'] = round(df_baseline['RSoXS Sample Up-Down'][0],4)
        baseline_dict['sam_z'] = round(df_baseline['RSoXS Sample Downstream-Upstream'][0],4)
        baseline_dict['sam_th'] = round(df_baseline['RSoXS Sample Rotation'][0],4)

        return baseline_dict

    def read_shutter_toggle(self, shutter_csv):
        '''
        Estimate the mean shutter-open duration from a shutter-toggle csv.

        Rows with 'RSoXS Shutter Toggle' == 1 are taken as openings and the row immediately after each
        as the matching close. Returns the mean of (close - open) 'time' values, rounded to 1 decimal,
        in the csv's time units (read_primary multiplies by 1000 to get ms, so presumably seconds).
        '''
        shutter_data = pd.read_csv(shutter_csv)
        # when shutter opens
        start_time = shutter_data['time'][shutter_data['RSoXS Shutter Toggle']==1]
        # when shutter closes
        end_time = shutter_data['time'][start_time.index + 1]
        # average over all images and round to nearest decimal
        shutter_exposure = np.round(np.mean(end_time.values - start_time.values),1)
        return shutter_exposure

    def read_primary(self,primary_csv,seq_num, cwd):
        '''
        Read exposure (ms), energy and polarization for image ``seq_num`` from the primary csv.

        If the csv has no 'RSoXS Shutter Opening Time (ms)' column, the exposure is estimated from
        the first '*Shutter Toggle*' file in ``cwd`` and a warning is issued.
        '''
        primary_dict = {}
        df_primary = pd.read_csv(primary_csv)
        try:
            primary_dict['exposure'] = df_primary['RSoXS Shutter Opening Time (ms)'][seq_num]
        except KeyError:
            shutter_fname = list(cwd.glob('*Shutter Toggle*'))
            primary_dict['exposure'] = self.read_shutter_toggle(shutter_fname[0])*1000 # keep in ms
            warnings.warn('No exposure time found in primary csv. Calculating from Shutter Toggle csv', stacklevel=2)

        primary_dict['energy'] = round(df_primary['en_energy_setpoint'][seq_num],4)
        primary_dict['polarization'] = df_primary['en_polarization_setpoint'][seq_num]

        return primary_dict

    def loadMd(self,filepath):
        '''
        Assemble the header dict for one image from its json, baseline and primary sidecar files.

        The file name is expected to look like ``{scan_id}-...-{seq_num}.tiff``: the first '-' field
        is the scan id and the last field (minus its 5-character '.tiff' suffix) is the sequence number
        used to index the primary csv. Geometry keys (dist, pixel1/2, poni1/2, rot1-3) are added, then
        ``constant_md`` overrides anything.
        '''
        # get sequence number of image for primary csv
        fname = os.path.basename(filepath)
        split_fname = fname.split('-')
        seq_num = int(split_fname[-1][:-5])
        scan_id = split_fname[0]

        # This allows for passing just the filename without the full path
        dirPath = os.path.dirname(filepath)
        cwd = pathlib.Path(dirPath) if dirPath else pathlib.Path('.').absolute()

        json_fname = list(cwd.glob('*.json'))
        json_dict = self.read_json(json_fname[0])

        baseline_fname = list(cwd.glob('*baseline.csv'))
        baseline_dict = self.read_baseline(baseline_fname[0])

        # The primary csv lives one directory above the image directory.
        primary_path = pathlib.Path(os.path.dirname(cwd))
        primary_fname = list(primary_path.glob(f'{scan_id}*primary.csv'))
        primary_dict = self.read_primary(primary_fname[0],seq_num, cwd)

        headerdict = {**primary_dict,**baseline_dict,**json_dict}

        # hc = 1.239842e-6 eV*m, so wavelength is in metres for energy in eV
        headerdict['wavelength'] = 1.239842e-6 / headerdict['energy']
        headerdict['seq_num'] = seq_num
        headerdict['sampleid'] = scan_id
        try:
            headerdict['dist'] = headerdict.get('sdd') / 1000  # mm -> m
        except TypeError:
            warnings.warn(f"Could not compute 'dist' from sdd={headerdict.get('sdd')!r}; leaving it unset.",stacklevel=2)

        headerdict['pixel1'] = self.pix_size_1 / 1000
        headerdict['pixel2'] = self.pix_size_2 / 1000

        bcx = headerdict.get('beamcenter_x')
        bcy = headerdict.get('beamcenter_y')
        if bcx is None or bcy is None:
            warnings.warn('Beam center unavailable; setting poni1/poni2 to None.',stacklevel=2)
            headerdict['poni1'] = None
            headerdict['poni2'] = None
        else:
            headerdict['poni1'] = bcy * headerdict['pixel1']
            headerdict['poni2'] = bcx * headerdict['pixel2']

        headerdict['rot1'] = 0
        headerdict['rot2'] = 0
        headerdict['rot3'] = 0

        headerdict.update(self.constant_md)
        return headerdict

    def peekAtMd(self,filepath):
        '''Return the metadata for ``filepath`` (identical to loadMd for this loader).'''
        return self.loadMd(filepath)


In [None]:
# local_loader = SST1RSoXSLoader(corr_mode='None')
# NOTE(review): rebinding `da` here shadows the `import dask.array as da` alias from the imports cell.
series_dims = ['energy', 'polarization']
da = local_loader.loadFileSeries(filepath, dims=series_dims)
da

In [None]:
# Unstack the stacked 'system' dimension back into its component dimensions
da = da.unstack(dim='system')
# da = da.where(da>1e-3)
da

In [None]:
# cmin = float(da.quantile(0.1))
# cmax = float(da.quantile(0.9))

# da.sel(polarization=0, energy=285, method='nearest').plot.imshow(norm=LogNorm(1e1, 1e4), cmap=cmap, interpolation='nearest')

energies = [270, 280, 282, 283, 284, 285, 286, 290]

# Polarization ~90°, one facet per requested energy (nearest available)
pol90 = da.sel(polarization=90, method='nearest')
subset = pol90.sel(energy=energies, method='nearest')
fg = subset.plot.imshow(figsize=(18, 6), col='energy', col_wrap=4,
                        norm=LogNorm(1, 1e4), cmap=cmap, interpolation='nearest')
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
for ax in fg.axs.flat:
    ax.set(aspect='equal')

plt.show()

## Loading raw data from databroker

In [None]:
# %%time

# Disabled: build the tiled catalog explicitly. It is only needed for the commented `catalog=c`
# variant of db_loader in the next cell.
# # Define catalogs:
# # c = from_profile("rsoxs", structure_clients='dask')
# c = from_uri('https://tiled.nsls2.bnl.gov/', structure_clients='numpy')['rsoxs']['raw']
# print(c)

In [None]:
# db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=40)  # initialize rsoxs databroker loader w/ Dask
# NOTE(review): the file loader uses lowercase 'none' to disable corrections; confirm whether SST1RSoXSDB
# treats the string 'None' the same way.
db_loader = phs.load.SST1RSoXSDB(
    corr_mode='None',
    use_chunked_loading=True,
    dark_pedestal=40,
)  # initialize rsoxs databroker loader w/ Dask

In [None]:
## Search for and summarize runs:
run_query = dict(institution='CUBLDER', cycle='2023-2', sample_id='TRMSN', project='TRMSN')
runs_sum_df = db_loader.summarize_run(**run_query, debugWarnings=False)
# runs_sum_df = db_loader.summarize_run(institution='CUBLDER',cycle='2022-2', sample='andrew*', plan='full*', debugWarnings=False)
# runs_sum_df = runs_sum_df.set_index('scan_id')  # optional, set index to scan id
plan_names = runs_sum_df['plan'].unique()
print(plan_names)
display(runs_sum_df)

In [None]:
## Slice output dataframe for samples of interest
# runs_of_interest = runs_sum_df.loc[runs_sum_df['cycle']=='2022-2'] #.loc[runs_sum_df['sample_id']=='andrew7']
df = runs_sum_df
# Carbon-edge RSoXS runs with exactly 228 images
keep = (df['plan'] == 'rsoxs_carbon') & (df['num_Images'] == 228)
runs_of_interest = df.loc[keep]
# runs_of_interest = df[(df['plan']=='rsoxs_nitrogen') & (df['num_Images']==114)]
# runs_of_interest = df[(df['plan']=='nexafs_carbon')]

display(runs_of_interest)

In [None]:
da_rows = []
# for scan_id in tqdm([65808, 65810]):
for scan_id in tqdm(runs_of_interest['scan_id']):
    # Use a dedicated name: rebinding `da` would clobber the `import dask.array as da` alias.
    run_da = db_loader.loadRun(scan_id, dims=['energy', 'polarization'])
    # Add a length-1 'scan_id' dimension (labelled by the run's sampleid attr) so runs can be concatenated,
    # and carry sample id / name along it as coordinates.
    run_da = run_da.expand_dims({'scan_id': [run_da.sampleid]})
    run_da = run_da.assign_coords(sample_id=('scan_id', [run_da.start['sample_id']]),
                                  sample_name=('scan_id', [run_da.sample_name]))
    da_rows.append(run_da)

DA = xr.concat(da_rows, 'scan_id')

In [None]:
# Inspect the concatenated (scan_id, ...) DataArray
DA

In [None]:
# `import dask.array as da` binds only `da` (and later cells rebind `da`), so bind `dask` explicitly here.
import dask.array

# Convert attrs that are dask arrays (-> computed values), dicts or datetimes (-> str), presumably so the
# attrs can be written by the to_zarr call in the next cell. Each converted key is printed with its original type.
for k, v in list(DA.attrs.items()):
    if isinstance(v, dask.array.Array):
        DA.attrs[k] = v.compute()
        print(f'{k:<20}  |  {type(v)}')
    elif isinstance(v, (dict, datetime.datetime)):
        DA.attrs[k] = str(v)
        print(f'{k:<20}  |  {type(v)}')

In [None]:
raw_zarr_path = zarrsPath / 'raw_rsoxs_carbon.zarr'
DS = DA.to_dataset(name='rsoxs_carbon')

# mode='w' overwrites any existing store; ProgressBar reports the dask computation while writing
with ProgressBar():
    DS.to_zarr(raw_zarr_path, mode='w')

In [None]:
# `da` may have been rebound to a DataArray by earlier cells, so don't rely on it for the dask type.
import dask.array

DS_loaded = xr.open_zarr(zarrsPath.joinpath('raw_rsoxs_carbon.zarr'))
DA = DS_loaded['rsoxs_carbon']

# Compute any dask-backed (lazy) coordinates; collect them first rather than mutating coords while iterating.
lazy_coords = {name: coord.compute() for name, coord in DA.coords.items()
               if isinstance(coord.data, dask.array.Array)}
DA = DA.assign_coords(lazy_coords)

DA

In [None]:
# sample_id coordinate value for scan 65810, as a plain string
str(DA.sel(scan_id=65810)['sample_id'].data)

In [None]:
%%time 

sliced_DA = DA.sel(scan_id=65808)

energies = [270, 280, 282, 283, 284, 285, 286, 290]
pol = 0

# Nearest polarization/energies, then crop to a fixed pixel region of interest
roi = dict(pix_x=slice(160, 780), pix_y=slice(240, 800))
frames = (sliced_DA
          .sel(polarization=pol, method='nearest')
          .sel(energy=energies, method='nearest')
          .sel(**roi))
fg = frames.plot.imshow(figsize=(18, 6), col='energy', col_wrap=4,
                        norm=LogNorm(3e1, 1e3), cmap=cmap)
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
fg.fig.suptitle(f'{str(sliced_DA.sample_name.data)}, {str(sliced_DA.sample_id.data)}, pol = {pol}°', y=1.02)
for ax in fg.axs.flat:
    ax.set(aspect='equal')

plt.show()

## Draw/check masks & beamcenters for transforming to q-space
### 1. Check raw images at a selected energy for all loaded scan configurations:

In [None]:
# NOTE(review): raw_saxs / raw_waxs are not defined anywhere in this notebook as shown; presumably they
# come from an earlier workflow or from andrew_rsoxs_fxns — confirm before running this section.
saxs_waxs_p00_p90_plot(raw_saxs, raw_waxs)

### 2. Draw masks

In [None]:
# Disabled: draw a SAXS mask on the pol=0, ~275 eV frame (uncomment draw.ui(), which presumably opens the drawing UI).
# ## SAXS:
# saxs_mask_img = raw_saxs.sel(pol=0, energy=275, method='nearest').compute()
# draw = phs.IntegrationUtils.DrawMask(saxs_mask_img)
# # draw.ui()

In [None]:
# Disabled: save the drawn SAXS mask as maskPath/SAXS_<sample_name>.json
# ## Save saxs drawn mask
# draw.save(maskPath.joinpath(f'SAXS_{raw_saxs.sample_name}.json'))

In [None]:
# Disabled: same as the SAXS mask-drawing step, but on the WAXS frame.
# ## Repeat for WAXS mask:
# waxs_mask_img = raw_waxs.sel(pol=0, energy=275, method='nearest').compute()
# draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)
# # draw.ui()

In [None]:
# ## Save waxs drawn mask
# NOTE: uses raw_waxs.sample_name so the file name matches the name the mask-loading cell uses for WAXS.
# draw.save(maskPath.joinpath(f'WAXS_{raw_waxs.sample_name}.json'))

In [None]:
### Check masks on file
# Same frame selection (pol=0, nearest to 275 eV) for both detectors
mask_frame_sel = dict(pol=0, energy=275, method='nearest')
saxs_mask_img = raw_saxs.sel(**mask_frame_sel).compute()
waxs_mask_img = raw_waxs.sel(**mask_frame_sel).compute()
draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)


### Load masks
saxs_mask, waxs_mask = plot_mask_files(draw, maskPath, raw_waxs.sample_name,
                                       saxs_img=saxs_mask_img, waxs_img=waxs_mask_img)
# plot_one_mask_file(draw, maskPath, raw_waxs.sample_name, img=raw_waxs.sel(pol=0, energy=275, method='nearest'))

### 3. Check and save beamcenters before converting to q-space

In [None]:
def _apply_and_check_beamcenter(raw, mask, mask_img, bcx, bcy, header, img_max):
    '''
    Build a template_xr PFEnergySeriesIntegrator from ``raw.sel(pol=0)``, apply ``mask`` and the beam
    center (bcx, bcy), copy the beam center and resulting poni1/poni2 into ``raw.attrs``, print them
    after ``header``, and show a checkAll overlay zoomed to +/-200 px around the beam center.

    Returns:
        the configured integrator.
    '''
    integ = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = raw.sel(pol=0))
    integ.mask = mask
    integ.ni_beamcenter_x = bcx
    integ.ni_beamcenter_y = bcy
    raw.attrs.update(beamcenter_x=bcx, beamcenter_y=bcy)
    raw.attrs['poni1'] = integ.poni1
    raw.attrs['poni2'] = integ.poni2
    print(header +
          f'poni1: {integ.poni1}, poni2: {integ.poni2} \n'
          f'ni_beamcenter_y: {integ.ni_beamcenter_y}, ni_beamcenter_x: {integ.ni_beamcenter_x}')

    ## Plot check
    phs.IntegrationUtils.Check.checkAll(integ, mask_img, img_max=img_max, alpha=0.4)
    center_x, center_y = integ.ni_beamcenter_x, integ.ni_beamcenter_y
    plt.xlim(center_x-200, center_x+200)
    plt.ylim(center_y-200, center_y+200)
    plt.gcf().set(dpi=120)
    plt.show()
    return integ


## SAXS
SAXSinteg = _apply_and_check_beamcenter(raw_saxs, saxs_mask, saxs_mask_img,
                                        bcxy_2022_2['saxs_bcx'], bcxy_2022_2['saxs_bcy'],
                                        'SAXS Beamcenter: \n', img_max=1e3)

## WAXS
WAXSinteg = _apply_and_check_beamcenter(raw_waxs, waxs_mask, waxs_mask_img,
                                        bcxy_2022_2['waxs_bcx'], bcxy_2022_2['waxs_bcy'],
                                        'WAXS Beamcenter: \n', img_max=7e3)

In [None]:
# Disabled alternatives for refining beam centers: (a) manually tweak bcx/bcy, or (b) run the
# refine_geometry optimizer. Both update raw_*.attrs (beam center / poni) and re-plot the checkAll overlay.
# ## Tweaking if needed:
# ## SAXS Tweaking & Plot Check
# saxs_new_bcx = 488
# saxs_new_bcy = 515
# SAXSinteg.ni_beamcenter_x = saxs_new_bcx
# SAXSinteg.ni_beamcenter_y = saxs_new_bcy
# raw_saxs.attrs['beamcenter_x'] = saxs_new_bcx
# raw_saxs.attrs['beamcenter_y'] = saxs_new_bcy
# raw_saxs.attrs['poni1'] = SAXSinteg.poni1
# raw_saxs.attrs['poni2'] = SAXSinteg.poni2

# print('SAXS Beamcenter Tweaking: \n'
#       f'poni1: {SAXSinteg.poni1}, poni2: {SAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {SAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {SAXSinteg.ni_beamcenter_x}')

# phs.IntegrationUtils.Check.checkAll(SAXSinteg, saxs_mask_img, img_max=1e3, alpha=0.6)
# plt.xlim(SAXSinteg.ni_beamcenter_x-200, SAXSinteg.ni_beamcenter_x+200)
# plt.ylim(SAXSinteg.ni_beamcenter_y-200, SAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()

# ## WAXS Tweaking & Plot Check
# waxs_new_bcx = 396.3
# waxs_new_bcy = 553
# WAXSinteg.ni_beamcenter_x = waxs_new_bcx
# WAXSinteg.ni_beamcenter_y = waxs_new_bcy
# raw_waxs.attrs['beamcenter_x'] = waxs_new_bcx
# raw_waxs.attrs['beamcenter_y'] = waxs_new_bcy
# raw_waxs.attrs['poni1'] = WAXSinteg.poni1
# raw_waxs.attrs['poni2'] = WAXSinteg.poni2

# print('WAXS Beamcenter Tweaking: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=5e3, alpha=0.6, guide1=40)
# plt.xlim(WAXSinteg.ni_beamcenter_x-200, WAXSinteg.ni_beamcenter_x+200)
# plt.ylim(WAXSinteg.ni_beamcenter_y-200, WAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()


# ## Using Pete D.'s (very slightly modified) beamcentering script:
# # phs.BeamCentering.CenteringAccessor.refine_geometry

# ## SAXS
# res_saxs = raw_saxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.002, q_max=0.006)
# # res_saxs = raw_saxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.002, q_max=0.006, chi_min=-180, chi_max=60)
# # res_saxs = raw_saxs.sel(pol=0).util.refine_geometry(energy=280, q_min=0.002, q_max=0.008, mask=saxs_mask)
# raw_saxs.attrs['poni1'] = res_saxs.x[0]
# raw_saxs.attrs['poni2'] = res_saxs.x[1]
# SAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = raw_saxs.sel(pol=0))
# SAXSinteg.mask = saxs_mask

# ## SAXS Plot check
# print('SAXS Beamcenter Post-optimization: \n'
#       f'poni1: {SAXSinteg.poni1}, poni2: {SAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {SAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {SAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(SAXSinteg, saxs_mask_img, img_max=1e3, alpha=0.6)
# plt.xlim(SAXSinteg.ni_beamcenter_x-200, SAXSinteg.ni_beamcenter_x+200)
# plt.ylim(SAXSinteg.ni_beamcenter_y-200, SAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()

# ## WAXS
# # res_waxs = raw_waxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.02, q_max=0.06, chi_min=-10, chi_max=70)
# res_waxs = raw_waxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.02, q_max=0.06)
# raw_waxs.attrs['poni1'] = res_waxs.x[0]
# raw_waxs.attrs['poni2'] = res_waxs.x[1]
# WAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = raw_waxs.sel(pol=0))
# WAXSinteg.mask = waxs_mask

# ## WAXS Plot check
# print('WAXS Beamcenter Post-optimization: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=5e3, alpha=0.6)
# plt.xlim(WAXSinteg.ni_beamcenter_x-200, WAXSinteg.ni_beamcenter_x+200)
# plt.ylim(WAXSinteg.ni_beamcenter_y-200, WAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()

In [None]:
### Write beamcenters to saved .json file if content with them:
bc_json_path = jsonPath.joinpath('beamcenters_dict.json')

# Cast to plain floats so numpy scalar types (e.g. int64/float32) don't break json.dump.
beamcenters_dict = {
    f'SAXS_{raw_saxs.sample_name}': {'bcx': float(raw_saxs.beamcenter_x), 'bcy': float(raw_saxs.beamcenter_y)},
    f'WAXS_{raw_waxs.sample_name}': {'bcx': float(raw_waxs.beamcenter_x), 'bcy': float(raw_waxs.beamcenter_y)}
}

# Merge into the existing file so other samples' entries are kept; start empty if the file doesn't exist yet.
try:
    with open(bc_json_path, 'r') as f:
        dic = json.load(f)
except FileNotFoundError:
    dic = {}

dic.update(beamcenters_dict)

with open(bc_json_path, 'w') as f:
    json.dump(dic, f)

## Export data
These files are large, so save them under the proposal folder (`outPath` / `zarrsPath`) rather than the user folder.


### 1. Apply qx,qy labels, save .zarr stores

In [None]:
### Now that we know our beamcenters are accurate, we can apply correct q axis labels
raw_waxs = apply_q_labels(raw_waxs)
raw_saxs = apply_q_labels(raw_saxs)

### Load energy lists for facet plots
energies = raw_waxs.energy.data
# Hard-coded index window; presumably the resonant region for this energy scan — confirm for other scans.
resonant_energies = energies[16:96]

# Every 3rd of the first 16 energies and every 2nd of the last 31, plus all of energies[16:96], sorted.
gif_energies = np.append(energies[0:16:3], energies[-31::2])
gif_energies = np.sort(np.append(gif_energies, resonant_energies))

### Set variables for naming purposes
sample_name = sample_guide[raw_waxs.sample_name]
detector = detector_guide[raw_waxs.detector]

# sampPath = exportPath.joinpath(f'{detector}_{sample_name}')
# sampPath.mkdir(parents=True, exist_ok=True)

In [None]:
# Disabled: list the zarr stores on disk. The setup cell defines the store directory as `zarrsPath`.
# sorted([f.name for f in zarrsPath.iterdir()])

In [None]:
### Save zarr store/directory 
# Store directory defined in the setup cell (`zarrsPath`; there is no `zarrPath` defined in this notebook).
save_zarr((raw_saxs, raw_waxs), zarrsPath)

In [None]:
# Disabled: save WAXS qx/qy facet plots (8 energies per figure) as SVGs per polarization.
# NOTE(review): `energy_list`, `facetPath` and `cm` are not defined in this notebook as shown; the colormap
# defined in the setup cell is `cmap` — update these names before re-enabling.
# ### Generate WAXS facet plots
# sample_name = sample_guide[raw_waxs.sample_name]
# scan_id = raw_waxs.sampleid
# detector = detector_guide[raw_waxs.detector]

# scanPath = facetPath.joinpath(f'{scan_id}_{sample_name}_{detector}')
# scanPath.mkdir(parents=True, exist_ok=True)

# for pol in (0, 90):
#     for num in range(10):
#         grid = raw_waxs.sel(pol=pol, energy=energy_list[8*num:8*num+8], method='nearest').plot.imshow(x='qx', y='qy',
#                     norm=LogNorm(1e1, 5e3), cmap=cm, interpolation='antialiased', col='energy', col_wrap=4)
#         grid.set_xlabels('qx [1/Å]')
#         grid.set_ylabels('qy [1/Å]') 

#         # Create/select folder for scan to save plots:
#         imgsPath = scanPath.joinpath(f'_qxqy_frames_{detector}_{int(pol):0>2}deg')
#         imgsPath.mkdir(parents=True, exist_ok=True)

#         plt.savefig(imgsPath.joinpath(f'{sample_name}_{detector}_{int(pol):0>2}_f{num}.svg'))

### 2. Convert to chi-q space & save .zarr stores

In [None]:
# Convert each detector's raw stack to chi-q space with its configured integrator (SAXS first, then WAXS)
integ_saxs, integ_waxs = (integrate_stacked_pol(integrator, raw)
                          for integrator, raw in ((SAXSinteg, raw_saxs), (WAXSinteg, raw_waxs)))
display(integ_saxs, integ_waxs)

In [None]:
### Save zarr store/directory 
# Store directory defined in the setup cell (`zarrsPath`; there is no `zarrPath` defined in this notebook).
save_zarr((integ_saxs, integ_waxs), zarrsPath, prefix='integ_qchi')

In [None]:
# Raw SAXS stores for sample pattern 'w18', looked up in the zarrsPath store directory from the setup cell
sorted(zarrsPath.glob('raw*w18*SAXS*'))

In [None]:
### How you would load data:
# Each open takes the first (sorted) store matching the pattern in the zarrsPath directory from the setup cell.
loaded_raw_saxs = xr.open_zarr(sorted(zarrsPath.glob('raw*w11*SAXS*'))[0]).saxs
loaded_raw_waxs = xr.open_zarr(sorted(zarrsPath.glob('raw*w11*WAXS*'))[0]).waxs
loaded_integ_saxs = xr.open_zarr(sorted(zarrsPath.glob('integ*w11*SAXS*'))[0]).saxs
loaded_integ_waxs = xr.open_zarr(sorted(zarrsPath.glob('integ*w11*WAXS*'))[0]).waxs

In [None]:
### Load energy lists for facet plots
energies = raw_waxs.energy.data
resonant_energies = energies[16:96]

# Every 3rd of the first 16 energies and every 2nd of the last 31, plus all of energies[16:96], sorted.
gif_energies = np.append(energies[0:16:3], energies[-31::2])
gif_energies = np.sort(np.append(gif_energies, resonant_energies))

pol=0

# #### View facet plot to verify data:
# loaded_integ_waxs.sel(pol=pol, energy=gif_energies[:-6:6], method='nearest').plot.imshow(xscale='log', xlim=(1e-2, 2e-1),
#                         norm=LogNorm(1e1, 5e3), cmap=cmap, interpolation='antialiased', col='energy', col_wrap=4)
# loaded_integ_saxs.sel(pol=pol, energy=gif_energies[:-6:6], method='nearest').plot.imshow(xscale='log', xlim=(1e-3, 1e-2),
#                         norm=LogNorm(1e1, 5e3), cmap=cmap, interpolation='antialiased', col='energy', col_wrap=4)
