# CMS ex situ GIWAXS 2023C3

# Wenhan Samples CMS GIWAXS raw data processing & exporting notebook
In this notebook you export xr.Dataset objects, saved as .zarr stores, containing all your raw,
remeshed (reciprocal space), and caked CMS GIWAXS data. Reopening a saved store (e.g. with `xr.open_zarr`) returns lazily loaded, dask-backed arrays.

In [None]:
### Kernel updates if needed, remember to restart kernel after running this cell!:
!pip install -e /nsls2/users/alevin/repos/PyHyperScattering  # to use pip to install via directory

## Imports

In [None]:
### Imports:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import xarray as xr
import PyHyperScattering as phs
import pygix
import gc
from tqdm.auto import tqdm  # progress bar loader!

# Report which PyHyperScattering release this notebook is running against
print('Using PyHyperScattering Version:', phs.__version__)

## Defining some objects

### Define & check paths

In [None]:
# pix_size = 0.000172
# 668 * pix_size

In [None]:
# pathlib is used for readability & checkability, and it is also necessary for the loadSeries function later on.
# Replace the paths with the ones relevant to your data; use the ".exists()" method to make sure you defined a path correctly.
propPath = pathlib.Path('/nsls2/data/cms/proposals/2023-3/pass-311415')
dataPath = propPath.joinpath('AL_processed_data/WOu/waxs')
rawPath = dataPath.joinpath('raw')
samplesPath = dataPath.joinpath('stitched')

notebookPath = pathlib.Path.cwd()
poniFile = notebookPath.joinpath('WO_bcx_bcy.poni')
maskFile = notebookPath.joinpath('blank.json')

# outPath = propPath.joinpath('AL_processed_data')

# Alternative poni & mask filepaths:
# poniFile = maskponiPath.joinpath('CeO2_2023-12-03_y673_x464p3.poni')
# maskFile = maskponiPath.joinpath('pilatus1m_vertical_gaps_only.json')

# # Create pg Transform objects with the above information:
# # Can set the energy to overwrite default poni energy, this MUST be correct for your samples!
# pg_transformer = phs.GIWAXS.Transform(poniPath=poniFile, maskPath=None, energy=12.7)

# Colormap: take a copy so that setting the "bad" (NaN / masked) colour does not
# modify matplotlib's globally registered 'turbo' colormap used by other plots.
cmap = plt.cm.turbo.copy()
cmap.set_bad('black')

In [None]:
def poni_centers(poniFile, pix_size=0.000172):
    """
    Return the PONI values and the corresponding pixel positions.

    The default pixel size is 172 microns (Pilatus 1M), given in the same units
    as the Poni1/Poni2 entries of the file (0.000172).

    Inputs: poniFile as a pathlib path object (a str path also works) to the poni file
            pix_size: detector pixel size, in the same units as Poni1/Poni2
    Outputs: ((poni1, y_center), (poni2, x_center))

    Raises ValueError if the file has no Poni1 or Poni2 entry.
    """
    poni_path = pathlib.Path(poniFile)

    # Parse "Key: value" lines by key instead of relying on fixed line numbers, so
    # extra header/comment lines or several spaces after the colon don't break it.
    entries = {}
    with poni_path.open('r', encoding='utf-8') as f:
        for line in f:
            if line.lstrip().startswith('#'):
                continue  # comment lines may themselves contain colons
            key, sep, value = line.partition(':')  # split on the first colon only
            if sep:
                entries[key.strip()] = value.strip()

    try:
        poni1 = float(entries['Poni1'])
        poni2 = float(entries['Poni2'])
    except KeyError as err:
        raise ValueError(f'{poni_path} has no {err.args[0]} entry') from err

    y_center = poni1 / pix_size
    x_center = poni2 / pix_size

    return ((poni1, y_center), (poni2, x_center))

# Display the Poni1/Poni2 values and the corresponding pixel positions for the selected poni file
poni_centers(poniFile)

### Define metadata naming scheme & initialize loaders

In [None]:
# Show the names of every entry in the stitched-data folder, sorted by path
[f.name for f in sorted(samplesPath.glob('*'))]

In [None]:
# Collect every entry of the stitched-data folder, sorted by path
WO_set = sorted(samplesPath.glob('*'))

In [None]:
# Keep only files whose names split into exactly 9 underscore-separated fields
# (i.e. contain 8 underscores), one field per entry of the naming scheme below
WO_set = [path for path in WO_set if path.name.count('_') == 8]
[path.name for path in WO_set]

In [None]:
# Ex situ metadata filename naming scheme: 9 fields, matching the 9-field
# filename filter above
WO_md_naming_scheme = [
    'project',
    'sampleid',
    'detector_pos',
    'sample_pos',
    'incident_angle',
    'exposure_time',
    'scan_id',
    'detector',
    'image_type',
]

# Initialize a CMSGIWAXSLoader object that uses the naming scheme above
WO_loader = phs.load.CMSGIWAXSLoader(md_naming_scheme=WO_md_naming_scheme)

## Data processing
Break this section up however makes sense for your data.

In [None]:
# Optional: load a calibration image (e.g. to draw or check a mask in the next cell).
# All three lines must be uncommented together: the last line needs `loader` and
# `calibPath`, so leaving it live on its own raised a NameError.
# loader = phs.load.CMSGIWAXSLoader()
# calibPath = pathlib.Path('/nsls2/data/cms/proposals/2023-2/pass-311415/KWhite5/maxs/raw/LaB6_5.6m_12.7keV_4250.1s_x0.001_th0.120_10.00s_1118442_maxs.tiff')
# calib_DA = loader.loadSingleImage(calibPath)  # Loads the file specified at calibPath into an xr.DataArray object

In [None]:
# # Load a mask as np.array (can use any method)
# draw = phs.IntegrationUtils.DrawMask(calib_DA)
# draw.load(maskFile)

# mask = draw.mask  # Loads mask as numpy array

# # Show np.array mask:
# plt.imshow(mask)
# plt.colorbar()

# Initialize a transformer from the poni geometry file and the JSON mask file:
transformer = phs.GIWAXS.Transform(
    poniPath=poniFile,
    maskPath=maskFile,
)

In [None]:
# Build the raw, reciprocal-space (remeshed) and caked datasets from the selected files
WO_raw_DS, WO_recip_DS, WO_caked_DS = phs.GIWAXS.single_images_to_dataset(
    WO_set,
    WO_loader,
    transformer,
)

In [None]:
# Display the raw-image dataset
WO_raw_DS

In [None]:
# Display the reciprocal-space (remeshed) dataset
display(WO_recip_DS)

In [None]:
# Display the caked dataset
display(WO_caked_DS)

In [None]:
def select_attrs(data_arrays_iterable, selected_attrs_dict):
    """
    Selects data arrays whose attributes match the specified values.

    Parameters:
    data_arrays_iterable: Iterable of xarray.DataArray objects (anything with an
                          ``attrs`` mapping works).
    selected_attrs_dict: Dictionary where keys are attribute names and values are
                         the attributes' desired values: either a collection of
                         accepted values (list, tuple, set, ...) or a single value.
                         A single string is matched exactly, not as a substring.
                         An empty dictionary selects every data array.

    Returns:
    List of xarray.DataArray objects, in their original order, that match every
    specified attribute.

    Raises:
    KeyError: if a data array that is still a candidate lacks a requested attribute.
    """
    # Normalise each requested value to a list once, up front:
    #  - a bare str/bytes would otherwise be searched with substring semantics
    #    ('WO1' in 'WO12' is True), silently selecting the wrong samples;
    #  - a single non-iterable value (e.g. a number) would raise TypeError with `in`;
    #  - a one-shot iterator would be exhausted after the first membership test.
    accepted = {}
    for attr_name, attr_values in selected_attrs_dict.items():
        if isinstance(attr_values, (str, bytes)):
            accepted[attr_name] = [attr_values]
        else:
            try:
                accepted[attr_name] = list(attr_values)
            except TypeError:
                accepted[attr_name] = [attr_values]

    # all() short-circuits, so later attributes are only looked up on arrays that
    # matched the earlier ones.
    return [
        da for da in data_arrays_iterable
        if all(da.attrs[attr_name] in values for attr_name, values in accepted.items())
    ]

In [None]:
# Quick-look plots of the raw detector images.
# Fill selected_attrs_dict to restrict the plots (e.g. {'sampleid': ['WO1']});
# an empty dict plots every image.
selected_attrs_dict = {}
selected_DAs = select_attrs(WO_raw_DS.data_vars.values(), selected_attrs_dict)

for DA in tqdm(selected_DAs):
    # Region to plot; slice here if only part of the image is of interest
    sliced_DA = DA

    # Colour limits from the plotted data: 5th percentile (floored at 1) to the
    # 99.9th percentile. Compute once rather than once per quantile.
    plot_values = sliced_DA.compute()
    cmin = max(float(plot_values.quantile(0.05)), 1)
    cmax = float(plot_values.quantile(0.999))

    # For a log colour scale use norm=LogNorm(cmin, cmax) instead
    ax = sliced_DA.plot.imshow(origin='upper', cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(5.5,3.3))

    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
    # Raw images are not in reciprocal space, so keep xarray's default axis labels
    # (the dimension names, presumably detector pixels) instead of q labels.
    ax.axes.set(title=f'{DA.sampleid}, incident angle: {DA.incident_angle}, scan id: {DA.scan_id}',
                aspect='equal')
    ax.figure.set(tight_layout=True, dpi=130)

    # ax.figure.savefig(outPath.joinpath('raw_plots', f'{DA.sampleid}_{DA.incident_angle}_{DA.scan_id}.png'), dpi=120)
    plt.show()
    plt.close('all')

In [None]:
# Quick-look plots of the remeshed reciprocal-space images.
# Fill selected_attrs_dict to restrict the plots (e.g. {'sampleid': ['WO1']});
# an empty dict plots every image.
selected_attrs_dict = {}
selected_DAs = select_attrs(WO_recip_DS.data_vars.values(), selected_attrs_dict)

for DA in tqdm(selected_DAs):
    # Slice data for the plotted q ranges (rename q_xy/q_z if your dimensions are named differently)
    sliced_DA = DA.sel(q_xy=slice(-1, 3), q_z=slice(0, None))

    # Colour limits from the region actually plotted: 5th percentile (floored at 1)
    # to the 99.9th percentile. Compute once rather than once per quantile.
    plot_values = sliced_DA.compute()
    cmin = max(float(plot_values.quantile(0.05)), 1)
    cmax = float(plot_values.quantile(0.999))

    # For a log colour scale use norm=LogNorm(cmin, cmax) instead
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), interpolation='antialiased', figsize=(5.5,3.3))

    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
    ax.axes.set(title=f'{DA.sampleid}, incident angle: {DA.incident_angle}, scan id: {DA.scan_id}',
                aspect='equal', xlabel='q$_{xy}$ [Å$^{-1}$]', ylabel='q$_z$ [Å$^{-1}$]')
    ax.figure.set(tight_layout=True, dpi=130)

    # ax.figure.savefig(outPath.joinpath('recip_plots', f'{DA.sampleid}_{DA.incident_angle}_{DA.scan_id}.png'), dpi=120)
    plt.show()
    plt.close('all')

In [None]:
# Polar (caked) plots restricted to a chi range
chi_min = -90
chi_max = 90

selected_attrs_dict = {}
selected_DAs = select_attrs(WO_caked_DS.data_vars.values(), selected_attrs_dict)

for DA in tqdm(selected_DAs):
    # Slice dataarray to select plotting region
    sliced_DA = DA.sel(chi=slice(chi_min, chi_max))

    # Colour limits from the region actually plotted: 5th percentile (floored at 1)
    # to the 99.9th percentile. Compute once rather than once per quantile.
    plot_values = sliced_DA.compute()
    cmin = max(float(plot_values.quantile(0.05)), 1)
    cmax = float(plot_values.quantile(0.999))

    # Plot sliced dataarray; interpolation='antialiased' smooths the image
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(5,4), interpolation='antialiased')
    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)  # set colorbar label & parameters
    # incident_angle is assumed to look like 'th0.120'; [2:] drops the 'th' prefix.
    # The chi label is a raw string because '\c' is not a valid escape sequence.
    ax.axes.set(title=f'Polar Plot: {DA.sampleid}, {float(DA.incident_angle[2:])}° Incidence',
                xlabel='q$_r$ [Å$^{-1}$]', ylabel=r'$\chi$ [°]')
    ax.figure.set(tight_layout=True, dpi=130)  # Adjust figure dpi & plotting style

    plt.show()  # Comment to mute plotting output

    # Uncomment below line and set savepath/savename for saving plots
    # ax.figure.savefig(outPath.joinpath('PM6-Y6set_waxs', f'polar-2D_{DA.sampleid}_{chi_min}to{chi_max}chi_{DA.incident_angle}.png'), dpi=150)
    plt.close('all')

In [None]:
# IPython magic: list the variables of type PosixPath currently defined in the notebook
whos PosixPath

In [None]:
# Folder for the exported zarr stores, next to the 'waxs' data folder
savePath = dataPath.parent / 'zarrs'

In [None]:
# Export the reciprocal-space dataset; mode='w' writes a fresh store, replacing any existing one
WO_recip_DS.to_zarr(savePath / 'WO_recip_stitched.zarr', mode='w')

In [None]:
# Export the raw-image dataset; mode='w' writes a fresh store, replacing any existing one
WO_raw_DS.to_zarr(savePath / 'WO_raw_stitched.zarr', mode='w')

In [None]:
# Export the caked dataset; mode='w' writes a fresh store, replacing any existing one
WO_caked_DS.to_zarr(savePath / 'WO_caked_stitched.zarr', mode='w')