# CMS GIWAXS raw data processing & exporting notebook - time resolved GIWAXS series measurements
In this notebook you export xr.Datasets, stored as .zarr stores, containing all of your raw,
remeshed (reciprocal space), and caked CMS GIWAXS data. When a saved zarr store is loaded back (e.g. with `xr.open_zarr`), its data come back as lazily loaded dask arrays!

In [None]:
# ### Kernel updates if needed, remember to restart kernel after running this cell!:
# !pip install -e /nsls2/users/alevin/repos/PyHyperScattering  # to use pip to install via directory

## Imports

In [None]:
### Imports:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import xarray as xr
import PyHyperScattering as phs
import pygix
import gc
from tqdm.auto import tqdm  # progress bar loader!

# Report which PyHyperScattering build the kernel imported (handy after an editable reinstall)
print(f'Using PyHyperScattering Version: {phs.__version__}')

# Shared colormap for every image plot: a copy of 'turbo' (so the registered colormap is not
# modified) with masked / bad pixels drawn in black.
cmap = plt.cm.turbo.with_extremes(bad='black')

## Defining some objects

### Define & check paths

In [None]:
# pathlib keeps paths readable and checkable, and the loadSeries function later on needs Path objects.
# Replace these paths with the ones relevant to your data; call `.exists()` on any of them to confirm
# it was defined correctly.
userPath = pathlib.Path('/nsls2/users/alevin')  # users area: good for small personal items (100 GB limit)
propPath = pathlib.Path('/nsls2/data/cms/proposals/2023-2/pass-311415')  # proposals area: good for large data (>1 TB space?)
dataPath = propPath / 'KWhite5'
maskponiPath = userPath / 'giwaxs_suite' / 'beamline_data' / 'maskponi'
outPath = propPath / 'AL_processed_data'

# Select poni & mask filepaths
poniFile = maskponiPath / 'LaB6_fixed_rot_x517.2.poni'
maskFile = maskponiPath / 'LaB6.json'

# Create pg Transform objects with the above information (no energy is specified, because the
# poni energy is correct).
# NOTE(review): later cells still reference qpara_transformer / qperp_transformer.
# qpara_transformer = phs.GIWAXS.Transform(poniPath=poniFile, maskPath=maskFile, inplane_config='q_para')
# qperp_transformer = phs.GIWAXS.Transform(poniPath=poniFile, maskPath=maskFile, inplane_config='q_perp')

In [None]:
# Show every entry name under dataPath, alphabetically (displayed as the cell output)
sorted(entry.name for entry in dataPath.iterdir())

# Misc sample notes:
# pm7_S1_95tol5cpme_14_100_18_85_75_009 is the first sample that I did with the different exposure series
# all previous samples are just 900 frames of 0.1s

In [None]:
sample = 'pybtz_CBCNp5_15_200_40_60_60_014'
# Files for this sample live in its maxs/raw subfolder
samplePath = dataPath / sample / 'maxs' / 'raw'
sorted(entry.name for entry in samplePath.iterdir())  # alphabetical listing as the cell output

In [None]:
# # Generate sets for samples with single scan id per series scan
# # Customize as per your data, in this example for my above selected blade coated sample, I have 3 sets I am interested in:
# # 1) The series set, here all with the same scan id and same exposure time
# # 2) qpara set, what I named my single image scans post-series measurement
# # 3) qperp set, what I named my single image scans after rotating 90 degrees in plane 

# # Choose series scan id(s)
# series_ids = ['1117471']

# # Create separate sets for single vs series measurements, customize per your data:
# qperp_set = set(samplePath.glob('*qperp*'))  # only my qperp samples have qperp in the name
# series_set = set(samplePath.glob(f'*{series_ids[0]}*'))  # all my series scans have the same id
# singles_set = set(samplePath.iterdir()).difference(series_set)  # the total single image scans is just the difference between all the scans and then series set
# qpara_set = singles_set.difference(qperp_set)  # qpara set is the singles set minus the qperp set

# # # Check content of sets
# # print('qperp images:')
# # display(sorted([f.name for f in qperp_set]))

# # print('\nqpara images:')
# # display(sorted([f.name for f in qpara_set]))

# # print('\nimage series:')
# # display(sorted([f.name for f in series_set]))

In [None]:
# Build file sets for a sample whose series measurement spans several scan ids
# (some series are split into different scan ids because the exposure time was changed).

# Choose series scan id(s)
series_ids = ['1118329', '1118330', '1118331']

# One set per scan id: every file whose name contains that id
exp0p1_set = set(samplePath.glob(f'*{series_ids[0]}*'))
exp0p5_set = set(samplePath.glob(f'*{series_ids[1]}*'))
exp2p0_set = set(samplePath.glob(f'*{series_ids[2]}*'))
qperp_set = set(samplePath.glob('*qperp*'))

# Combine all series scan ids first, then subtract them from the full file list;
# what remains are single-image scans, and the non-qperp singles form the qpara set.
series_set = exp0p1_set | exp0p5_set | exp2p0_set
singles_set = set(samplePath.iterdir()) - series_set
qpara_set = singles_set - qperp_set

# Check content of sets
for heading, file_set in (('qperp images:', qperp_set),
                          ('\nqpara images:', qpara_set),
                          ('\nimage series:', series_set)):
    print(heading)
    display(sorted(f.name for f in file_set))

### Define metadata naming schemes & initialize loaders

In [None]:
# My example metadata filename naming schemes:
# Make sure the length of each list lines up with your filenames split by underscore (or however you split them)!

# Metadata naming schemes for the pybtz samples.
# All three schemes share the same leading sample fields and the same trailing scan fields;
# they differ only in the extra fields inserted around those groups.
_sample_fields = ['material', 'solvent', 'concentration', 'gap_height', 'blade_speed',
                  'solution_temperature', 'stage_temperature', 'sample_number']
_scan_fields = ['time_start', 'x_position_offset', 'incident_angle', 'exposure_time', 'scan_id']

# For nonrotated, qpara images:
qpara_md_naming_scheme = [*_sample_fields, *_scan_fields, 'detector']

# For rotated, qperp images (extra 'in-plane_orientation' field before the scan fields):
qperp_md_naming_scheme = [*_sample_fields, 'in-plane_orientation', *_scan_fields, 'detector']

# For in situ series images (extra 'series_number' field before the detector):
in_situ_md_naming_scheme = [*_sample_fields, *_scan_fields, 'series_number', 'detector']

# # Metadata naming schemes for the pm6 & pm7 S1 samples
# # For nonrotated, qpara images:
# qpara_md_naming_scheme = ['material', 'material description', 'solvent', 'concentration', 'gap_height', 'blade_speed',
#                     'solution_temperature', 'stage_temperature', 'sample_number', 'time_start',
#                     'x_position_offset', 'incident_angle', 'exposure_time', 'scan_id', 'detector']

# # For rotated, qperp images:
# qperp_md_naming_scheme = ['material', 'material description', 'solvent', 'concentration', 'gap_height', 'blade_speed',
#                     'solution_temperature', 'stage_temperature', 'sample_number', 'in-plane_orientation',
#                     'time_start', 'x_position_offset', 'incident_angle', 'exposure_time', 'scan_id', 'detector']

# # For in situ series images:
# in_situ_md_naming_scheme = ['material', 'material description', 'solvent', 'concentration', 'gap_height', 'blade_speed',
#                     'solution_temperature', 'stage_temperature', 'sample_number', 'time_start',
#                     'x_position_offset', 'incident_angle', 'exposure_time', 'scan_id', 
#                     'series_number', 'detector']

# Initialize one CMSGIWAXSLoader per filename naming scheme (qpara singles, qperp singles, series)
qpara_loader, qperp_loader, series_loader = (
    phs.load.CMSGIWAXSLoader(md_naming_scheme=scheme)
    for scheme in (qpara_md_naming_scheme, qperp_md_naming_scheme, in_situ_md_naming_scheme)
)

## Data processing

### Single image scans outside of series measurement
Uses the same single_images_to_dataset function as the single image processing example notebook.
Break up the sets below according to your data.

#### qperp set

In [None]:
# Use the single_images_to_dataset function to pygix-transform all raw files in an indexable list.
# Files must be pathlib Paths and be raw .tiff detector images.
# It generates raw, recip (cartesian), and caked (polar) datasets containing DataArrays of all samples,
# optionally saved as zarr stores (savePath / savename arguments below).

# The qperp transformer definition near the top of the notebook is commented out, so build it here;
# without this line the call below raises NameError on `qperp_transformer`.
# No energy is passed: per the original note, the energy in the poni file is correct.
qperp_transformer = phs.GIWAXS.Transform(poniPath=poniFile, maskPath=maskFile, inplane_config='q_perp')

# Set a savename from the first and last underscore-separated fields of the sample name
material = sample.split('_')[0]
sample_num = sample.split('_')[-1]
savename = f'{material}_qperp_{sample_num}'

# Run function
raw_DS, recip_DS, caked_DS = phs.GIWAXS.single_images_to_dataset(sorted(qperp_set), qperp_loader, qperp_transformer,
                                                                 savePath = outPath.joinpath('qperp_zarrs'),
                                                                 savename = savename)

In [None]:
# # Example of a quick plot check if desired here:
# for DA in tqdm(recip_DS.data_vars.values()):   
#     ax = DA.sel(q_perp=slice(-1.1, 2.1), q_z=slice(-0.05, 2.4)).plot.imshow(cmap=cmap, norm=LogNorm(1e1, 1e4), figsize=(8,4))
#     ax.axes.set(aspect='equal', title=f'{DA.material}, incident angle: {DA.incident_angle}, scan id: {DA.scan_id}')
#     plt.show()
#     plt.close('all')

#### qpara set

In [None]:
# Alternative configuration that also applies the pyhyper mask and a fixed incident angle:
# qpara_recip_integrator = phs.integrate.PGGeneralIntegrator(maskmethod = 'pyhyper',
#                                                            maskpath = maskFile,
#                                                            geomethod = 'ponifile',
#                                                            ponifile = poniFile,
#                                                            incident_angle = 0.12,
#                                                            output_space = 'recip')

# Use the first qpara image (sorted by filename) as the template for both integrators.
# NOTE(review): maskFile is not passed below, so presumably no mask is applied — confirm this is intended.
raw_DA = qpara_loader.loadSingleImage(sorted(qpara_set)[0])


def _make_qpara_integrator(output_space):
    """Build a ponifile-based PGGeneralIntegrator for *output_space*, templated on raw_DA."""
    return phs.integrate.PGGeneralIntegrator(geomethod='ponifile',
                                             ponifile=poniFile,
                                             output_space=output_space,
                                             template_xr=raw_DA)


qpara_recip_integrator = _make_qpara_integrator('recip')
qpara_caked_integrator = _make_qpara_integrator('caked')

In [None]:
# Spot check: remesh the last qpara image (in sorted filename order) into reciprocal space
last_qpara_file = sorted(qpara_set)[-1]
raw_DA = qpara_loader.loadSingleImage(last_qpara_file)

# The first two characters of the incident-angle string are dropped before converting to float
# (NOTE(review): presumably a filename prefix such as 'th' — confirm against the filenames).
qpara_recip_integrator.incident_angle = float(raw_DA.incident_angle[2:])
recip_DA = qpara_recip_integrator.integrateSingleImage(raw_DA)
recip_DA

In [None]:
# Spot check: cake the last qpara image (in sorted filename order) into polar coordinates
raw_DA = qpara_loader.loadSingleImage(sorted(qpara_set)[-1])

# Same angle parsing as the recip check: skip the first two characters, then convert to float
angle_text = raw_DA.incident_angle
qpara_caked_integrator.incident_angle = float(angle_text[2:])

caked_DA = qpara_caked_integrator.integrateSingleImage(raw_DA)
caked_DA

In [None]:
# # Use the single_images_to_dataset function to pygix transform all raw files in an indexable list

# # Set a savename
# material = sample.split('_')[0]
# sample_num = sample.split('_')[-1]
# savename = f'{material}_qpara_{sample_num}'

# # Run function
# raw_DS, recip_DS, caked_DS = phs.GIWAXS.single_images_to_dataset(sorted(qpara_set), qpara_loader, qpara_transformer,
#                                                                  savePath = outPath.joinpath('qpara_zarrs'),
#                                                                  savename = savename)

In [None]:
# # Example of a quick plot check if desired here:
# for DA in tqdm(list(recip_DS.data_vars.values())[0:10]):   
#     ax = DA.sel(q_para=slice(-1.1, 2.1), q_z=slice(-0.05, 2.4)).plot.imshow(cmap=cmap, norm=LogNorm(1e1, 1e4), figsize=(8,4))
#     ax.axes.set(aspect='equal', title=f'{DA.material}, incident angle: {DA.incident_angle}, scan id: {DA.scan_id}')
#     plt.show()
#     plt.close('all')

### Series measurement processing

#### Save each series as its own DataSet
For some samples the series is broken up into different scan ids and exposure times, so I opted to save each DataArray as its own zarr dataset. Later, in the plotting notebook, I load the DataArrays, normalize them by exposure time, and concatenate them along the time dimension. This should eventually be refined into a function like "series_to_datasets", but it isn't super urgent IMO, as the code is pretty straightforward and more flexible this way:

In [None]:
# from pyFAI.io.ponifile import PoniFile
# PoniFile(data=str(poniFile))

In [None]:
# Load the first scan id of the series (the 0.1 s exposure files) with loadSeries
exp0p1_files = sorted(exp0p1_set)
s_DA = series_loader.loadSeries(exp0p1_files)
s_DA

In [None]:
# # Choose series scan id(s)
# series_ids = ['1118329', '1118330', '1118331']

# # Create separate sets for single vs series measurements, customize per your data:
# # I had 3 different scan ids in one series measurement, so I combine them all first 
# # before subtracting them from the total file list
# exp0p1_set = set(samplePath.glob(f'*{series_ids[0]}*')) 

In [None]:
# Load the first scan id of the series (0.1 s exposures) with the generic file-series loader,
# then replace the integer frame index with an elapsed-time coordinate.
# The filter is derived from series_ids so it stays in sync with the file sets defined above.
updated_DA = series_loader.loadFileSeries(basepath=samplePath, dims=['series_number'],
                                          file_filter=f'0.10s_{series_ids[0]}')
# updated_DA

time_start = 0  # time offset of this scan within the full series (s)
# exposure_time is a string whose last character is stripped (presumably a unit suffix, as in '0.10s').
# No rounding: np.round(..., 1) would turn e.g. 0.05 s into 0.0 s and 0.25 s into 0.2 s.
exp_time = float(updated_DA.exposure_time[:-1])

fs_DA = updated_DA.unstack('system')
# Time stamp each frame at the end of its exposure: n*exp_time + exp_time + time_start
fs_DA = fs_DA.assign_coords({'time': ('series_number', fs_DA.series_number.data.astype('float')*exp_time+exp_time+time_start)})
fs_DA = fs_DA.swap_dims({'series_number': 'time'})
fs_DA

In [None]:
# # Facet plot of selected times
# # cmin = float(fs_DA.sel(time=10, method='nearest').compute().quantile(0.01))
# # cmax = float(fs_DA.sel(time=10, method='nearest').compute().quantile(0.99))

# cmin=1
# cmax=10
# times = np.linspace(0.1, 10, 8)

# sliced_DA = s_DA.sel(time=times, method='nearest')
# fg = sliced_DA.plot.imshow(figsize=(18, 6), col='time', col_wrap=4, norm=plt.Normalize(cmin, cmax), cmap=cmap, origin='upper')
# fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
# for axes in fg.axs.flatten():
#     axes.set(aspect='equal')

# plt.show()

In [None]:
# # Facet plot of selected times
# # cmin = float(fs_DA.sel(time=10, method='nearest').compute().quantile(0.01))
# # cmax = float(fs_DA.sel(time=10, method='nearest').compute().quantile(0.99))

# cmin=1
# cmax=10
# times = np.linspace(0.1, 10, 8)

# sliced_DA = fs_DA.sel(time=times, method='nearest')
# fg = sliced_DA.plot.imshow(figsize=(18, 6), col='time', col_wrap=4, norm=plt.Normalize(cmin, cmax), cmap=cmap, origin='upper')
# fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
# for axes in fg.axs.flatten():
#     axes.set(aspect='equal')

# plt.show()

In [None]:
# for time in fs_DA.time.values[:2]:
#     sliced_DA = fs_DA.sel(time=[time], method='nearest')
#     display(sliced_DA)
#     recip_DA = qpara_recip_integrator.integrateSingleImage(sliced_DA)
#     display(recip_DA)

In [None]:
# Integrate the whole time stack into reciprocal space.
# Set the incident angle from the series' own metadata instead of relying on whatever value an
# earlier single-image cell left on the integrator (or its default, if those cells were skipped).
# The first two characters of the metadata string are stripped, as for the single images.
qpara_recip_integrator.incident_angle = float(fs_DA.incident_angle[2:])
recip_DA = qpara_recip_integrator.integrateImageStack(fs_DA)
recip_DA

In [None]:
# Integrate the whole time stack into caked (polar) space.
# Set the incident angle from the series' own metadata instead of relying on whatever value an
# earlier single-image cell left on the integrator (or its default, if those cells were skipped).
# The first two characters of the metadata string are stripped, as for the single images.
qpara_caked_integrator.incident_angle = float(fs_DA.incident_angle[2:])
caked_DA = qpara_caked_integrator.integrateImageStack(fs_DA)
caked_DA

In [None]:
# Facet plot of the reciprocal-space stack at eight evenly spaced times between 0.1 and 10
# To autoscale the color limits instead, something like:
# cmin = float(recip_DA.sel(time=10, method='nearest').compute().quantile(0.01))
# cmax = float(recip_DA.sel(time=10, method='nearest').compute().quantile(0.99))
cmin, cmax = 1, 10
times = np.linspace(0.1, 10, 8)

sliced_DA = recip_DA.sel(time=times, method='nearest')
fg = sliced_DA.plot.imshow(figsize=(18, 6), col='time', col_wrap=4,
                           norm=plt.Normalize(cmin, cmax), cmap=cmap)
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)

# Equal aspect on every facet
for ax in fg.axs.flat:
    ax.set(aspect='equal')

plt.show()

In [None]:
# Facet plot of the caked (polar) stack at eight evenly spaced times between 0.1 and 10
# To autoscale the color limits instead, something like:
# cmin = float(caked_DA.sel(time=10, method='nearest').compute().quantile(0.01))
# cmax = float(caked_DA.sel(time=10, method='nearest').compute().quantile(0.99))
cmin, cmax = 1, 10
times = np.linspace(0.1, 10, 8)

sliced_DA = caked_DA.sel(time=times, method='nearest')
fg = sliced_DA.plot.imshow(figsize=(18, 6), col='time', col_wrap=4,
                           norm=plt.Normalize(cmin, cmax), cmap=cmap)
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
# Equal aspect is left disabled for the caked plots:
# for axes in fg.axs.flatten():
#     axes.set(aspect='equal')

plt.show()

In [None]:
# # For a series set of just one scan ID

# DA = series_loader.loadSeries(sorted(series_set))
# recip_DA, caked_DA = qpara_transformer.pg_convert_series(DA)

# # Transform DataArrays into DataSets & save as .zarr stores
# raw_DS = DA.to_dataset(name='DA')
# recip_DS = recip_DA.to_dataset(name='DA')
# caked_DS = caked_DA.to_dataset(name='DA')

# # Specify a suffix for saving the raw, recip, and caked DataSets. 
# suffix = f'{DA.scan_id}_{DA.material}_0to90s_qpara_{DA.sample_number}'

# raw_DS.to_zarr(outPath.joinpath('series_zarrs', f'raw_{suffix}.zarr'), mode='w')
# recip_DS.to_zarr(outPath.joinpath('series_zarrs', f'recip_{suffix}.zarr'), mode='w')
# caked_DS.to_zarr(outPath.joinpath('series_zarrs', f'caked_{suffix}.zarr'), mode='w')

In [None]:
# # For a series set for multiple scan IDs / time starts

# time_starts = [0, 10, 90]
# time_ranges = ['0to10', '10to90', '90to180']
# for i, series in enumerate((exp0p1_set, exp0p5_set, exp2p0_set)):
#     # Select the first element of the sorted set outside of the for loop to initialize the xr.DataSet
#     DA = series_loader.loadSeries(sorted(series), time_start= time_starts[i])
#     recip_DA, caked_DA = qpara_transformer.pg_convert_series(DA)

#     # Transform DataArrays into DataSets & save as .zarr stores
#     raw_DS = DA.to_dataset(name='DA')
#     recip_DS = recip_DA.to_dataset(name='DA')
#     caked_DS = caked_DA.to_dataset(name='DA')

#     # Specify a suffix for saving the raw, recip, and caked DataSets. 
#     suffix = f'{DA.scan_id}_{DA.material}_{time_ranges[i]}s_qpara_{DA.sample_number}'

#     raw_DS.to_zarr(outPath.joinpath('series_zarrs', f'raw_{suffix}.zarr'), mode='w')
#     recip_DS.to_zarr(outPath.joinpath('series_zarrs', f'recip_{suffix}.zarr'), mode='w')
#     caked_DS.to_zarr(outPath.joinpath('series_zarrs', f'caked_{suffix}.zarr'), mode='w')