# CMS ex situ GIWAXS 2024C2

# CMS GIWAXS raw data processing & exporting notebook
In this notebook you export xarray Datasets (`xr.Dataset`), saved as .zarr stores, containing your raw,
remeshed (reciprocal-space), and caked CMS GIWAXS data. Note that a saved .zarr store is loaded back as dask-backed arrays (e.g. with `xr.open_zarr`).

## Imports

In [None]:
### Imports:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import xarray as xr
import PyHyperScattering as phs
import pygix
import gc
from tqdm.auto import tqdm  # progress bar loader!

# Record the PyHyperScattering version in the notebook output, since the loader,
# integrator, and utility APIs used below all come from it
print(f'Using PyHyperScattering Version: {phs.__version__}')

## Defining some objects

### Define & check paths

In [None]:
# Stray inspection cell: maskponiPath is only assigned in the next cell, so a bare
# `maskponiPath` here raises NameError on a fresh top-to-bottom run ("Run All").
# Look it up defensively instead: shows the path if already defined, otherwise None.
globals().get('maskponiPath')

In [None]:
# I like pathlib for its readability & checkability, it's also necessary for the loadSeries function later on
# Replace the paths with the ones relevant to your data, you can use the ".exists()" method to make sure you defined a path correctly
propPath = pathlib.Path('/nsls2/data/cms/proposals/2023-3/pass-311415')
# samplesPath = propPath.joinpath('AL_2024C2/waxs/stitched')
samplesPath = propPath.joinpath('AL_2024C2/stitched')

maskponiPath = propPath.joinpath('AL_processed_data/maskponi')  # place for pyhyper-drawn masks and poni files

# outPath = propPath.joinpath('AL_processed_data')

# Select poni & mask filepaths
poniFile = maskponiPath.joinpath('CeO2_2023-08-20_y667_x461.poni')
# maskFile = maskponiPath.joinpath('blank.json')
# maskFile = maskponiPath.joinpath('pilatus1m_vertical_gaps_only.json')
maskFile = maskponiPath.joinpath('pilatus_1m_stitched_vertical_gap_silicon_peak.edf')

# Colormap: work on a copy so that set_bad() does not mutate matplotlib's shared,
# globally registered 'turbo' colormap (which would silently change every other
# plot using 'turbo' and triggers a deprecation warning in some matplotlib versions).
# Invalid/masked pixels are drawn black.
cmap = plt.cm.turbo.copy()
cmap.set_bad('black')

In [None]:
def select_attrs(data_arrays_iterable, selected_attrs_dict, copy=True):
    """
    Selects data arrays whose attributes match the specified values.

    All (attribute name, allowed values) pairs must match, i.e. the filters
    are combined with a logical AND.

    Parameters:
    data_arrays_iterable: Iterable of xarray.DataArray objects.
    selected_attrs_dict: Dictionary where keys are attribute names and 
                         values are the attributes' desired values. A value
                         may be a collection of allowed values (e.g. a list)
                         or a single string, which is matched exactly (not
                         as a substring).
    copy: If True (default), return ``da.copy()`` of each matching array so
          the originals cannot be modified through the result.

    Returns:
    List of xarray.DataArray objects that match the specified attributes.

    Raises:
    KeyError: if a data array still under consideration lacks a requested
              attribute.
    """
    # A bare string would make `in` a substring test ('AL2' in 'AL22' is True),
    # silently selecting the wrong samples, so wrap it as a single allowed value.
    filters = {
        attr_name: (attr_values,) if isinstance(attr_values, str) else attr_values
        for attr_name, attr_values in selected_attrs_dict.items()
    }

    # all() short-circuits in dict order, so attributes are only looked up on
    # arrays that passed the earlier filters (same as the original pass-by-pass loop).
    matches = [
        da for da in data_arrays_iterable
        if all(da.attrs[attr_name] in attr_values
               for attr_name, attr_values in filters.items())
    ]

    # Copy once after filtering, rather than once per filter pass.
    return [da.copy() for da in matches] if copy else matches

In [None]:
# rawPath = propPath.joinpath('AL_2024C2/raw')

# # Loop through each .tiff file in the directory
# for file_path in rawPath.glob("*"):
#     # Convert the file name to a string
#     new_name = str(file_path.name)
    
#     # Replace the desired substrings
#     new_name = new_name.replace("_x-0.000_", "_x0.000_")
#     new_name = new_name.replace("_th0.119_", "_th0.120_")
#     new_name = new_name.replace("_AL9_", "_AL09_")
#     new_name = new_name.replace("_leftover_", "_")
    
#     # Define the new file path with the updated name
#     new_file_path = file_path.with_name(new_name)
    
#     # Rename the file
#     file_path.rename(new_file_path)

# print("Filenames have been updated.")


In [None]:

# # # Define the directory and the string to be removed
# # directory = samplesPath
# # string_to_remove = '_leftover'

# # # Iterate over all files in the directory
# # for file_path in directory.iterdir():
# #     if file_path.is_file() and string_to_remove in file_path.name:
# #         # Create the new file name by removing the specific string
# #         new_name = file_path.name.replace(string_to_remove, '')
# #         new_file_path = file_path.with_name(new_name)
# #         # Rename the file
# #         file_path.rename(new_file_path)


# # Define the directory and the string to be removed
# directory = samplesPath
# string_to_remove = '_th0.119'

# # Iterate over all files in the directory
# for file_path in directory.iterdir():
#     if file_path.is_file() and string_to_remove in file_path.name:
#         # Create the new file name by removing the specific string
#         new_name = file_path.name.replace(string_to_remove, '_th0.120')
#         new_file_path = file_path.with_name(new_name)
#         # Rename the file
#         file_path.rename(new_file_path)



In [None]:
# Scratch check: 667 px * 172 µm (Pilatus pixel size) in meters, presumably to compare with
# the "y667" in the PONI filename and the Poni1 value reported by poni_centers
667 * 0.000172

In [None]:
# Scratch check: 461 px * 172 µm (Pilatus pixel size) in meters, presumably to compare with
# the "x461" in the PONI filename and the Poni2 value reported by poni_centers
461 * 0.000172

In [None]:
def poni_centers(poniFile, pix_size=0.000172):
    """
    Returns poni center value and the corresponding pixel position. Default pixel size is 172 microns (Pilatus 1M)

    Inputs:
        poniFile: path to a pyFAI .poni file (pathlib.Path or str)
        pix_size: detector pixel size, in the same length unit as the file's
                  Poni1/Poni2 entries (meters for pyFAI .poni files)
    Outputs: ((poni1, y_center), (poni2, x_center)), with the centers in pixels

    Raises:
        ValueError: if the file has no Poni1 or Poni2 entry.
    """
    # Find the entries by key rather than by fixed line number: the line layout
    # differs between PONI file versions (e.g. version 1 lists PixelSize1/2
    # before Poni1/Poni2), and split(' ') breaks on repeated spaces.
    values = {}
    with pathlib.Path(poniFile).open('r') as f:
        for line in f:
            key, sep, value = line.partition(':')
            key = key.strip().lower()
            if sep and key in ('poni1', 'poni2'):
                values[key] = float(value)  # float() tolerates surrounding whitespace

    missing = [k for k in ('poni1', 'poni2') if k not in values]
    if missing:
        raise ValueError(f"{poniFile} has no {' / '.join(missing)} entry")

    poni1 = values['poni1']
    poni2 = values['poni2']

    y_center = poni1 / pix_size
    x_center = poni2 / pix_size

    return ((poni1, y_center), (poni2, x_center))

# Display ((poni1, y_center), (poni2, x_center)) for the selected PONI file; centers in pixels
poni_centers(poniFile)

### Define metadata naming scheme & initialize loaders

In [None]:
# List every filename in samplesPath (sorted), e.g. to check them against md_naming_scheme below
[f.name for f in sorted(samplesPath.glob('*'))]

In [None]:
# Display the sample data directory
samplesPath

In [None]:
# Ex situ metadata filename naming scheme: the ordered metadata field names that the
# loader is given for each data file's name
md_naming_scheme = ('project sample_ID detector_pos sample_pos '
                    'incident_angle exposure_time scan_id detector').split()

# Initialize the CMSGIWAXSLoader with the naming scheme above
loader = phs.load.CMSGIWAXSLoader(md_naming_scheme=md_naming_scheme)
# loader = phs.load.CMSGIWAXSLoader()

## Data processing
Break this section up however makes sense for your data.

### initialize integrators

In [None]:
# Both integrators share the same geometry (PONI file) and mask; they differ only in
# output space: 'recip' for q_xy/q_z maps, 'caked' for polar (qr, chi) maps.
shared_integrator_kwargs = dict(geomethod='ponifile',
                                ponifile=poniFile,
                                maskmethod='edf',
                                maskpath=maskFile)
# Beam energy assigned to each integrator; presumably eV (13.5 keV, matching the
# wavelength used in the Yoneda check further down) — confirm against the PGGeneralIntegrator docs
beam_energy = 13.5e3

recip_integrator = phs.integrate.PGGeneralIntegrator(output_space='recip', **shared_integrator_kwargs)
recip_integrator.energy = beam_energy

caked_integrator = phs.integrate.PGGeneralIntegrator(output_space='caked', **shared_integrator_kwargs)
caked_integrator.energy = beam_energy

### Generate, check, and save: recip Dataset

In [None]:
# Use the single_images_to_dataset utility function to pygix transform all raw files in an indexable list
# Located in the IntegrationUtils script, CMSGIWAXS class:

# Initialize CMSGIWAXS util object with every file in samplesPath whose name starts with 'CD'
# (sorted), the metadata loader, and the reciprocal-space integrator
util = phs.util.IntegrationUtils.CMSGIWAXS(sorted(samplesPath.glob('CD*')), loader, recip_integrator)
raw_DS, recip_DS = util.single_images_to_dataset()  # run function; unpacked as (raw, reciprocal-space) Datasets
display(recip_DS)

In [None]:
# Corrected with flipped bar
# sample_ID -> human-readable sample description; used in plot titles via sn[DA.sample_ID],
# so every sample_ID present in the datasets must be a key here (otherwise KeyError).
# NOTE(review): "flipped bar" presumably refers to the sample bar orientation at load time — confirm.
sn = {
    "AL01": "Y6",
    "AL02": "Y6:PVK 1:1",
    "AL03": "Y6:PVK 1:9",
    "AL04": "A1",
    "AL05": "A1:PVK 1:1",
    "AL06": "A1:PVK 1:9",
    "AL07": "A2",
    "AL08": "A2:PVK 1:1",
    "AL09": "A2:PVK 1:9",
    "AL10": "A3",
    "AL11": "A3:PVK 1:1",
    "AL12": "A3:PVK 1:9",
    "AL13": "Y6 CF:CB 4:1",
    "AL14": "Y6 CF:CB 2:3",
    "AL15": "Y6 CF:CB 2:3 + 0.5% CN",
    "AL16": "PM6 CF:CB 4:1",
    "AL17": "PM6 CF:CB 2:3",
    "AL18": "PM6 CF:CB 2:3 + 0.5% CN",
    "AL19": "PM6:Y6 CF:CB 4:1",
    "AL20": "PM6:Y6 CF:CB 2:3",
    "AL21": "PM6:Y6 CF:CB 2:3 + 0.5% CN",
    "AL22": "PM6 CB",
    "AL23": "Y6 CB",
    "AL24": "Y6BO CB",
    "AL25": "PM6 CB + 1% CN",
    "AL26": "PM6 CB + 5% CN",
    "AL27": "Y6 CB + 0.5% CN",
    
    "AL28": "PM6 CF + 1% CN",
    "AL29": "Y6BO CF",
    "AL30": "Y6 CF",
    "AL31": "PM6 CF",
    "AL32": "PM6:Y6BO CB + 0.5% CN",
    "AL33": "PM6:Y6 CB + 0.5% CN",
    "AL34": "PM6:Y6BO CB",
    "AL35": "PM6:Y6 CB",
    "AL36": "PM6 CB + 0.5% CN",
    "AL37": "Y6BO CB + 0.5% CN",
    
    "AL38": "PM6 CF + 5% CN",
    "AL39": "Y6 CF + 0.5% CN",
    "AL40": "Y6BO CF + 0.5% CN",
    "AL41": "PM6 CF + 0.5% CN",
    "AL42": "PM6:Y6 CF",
    "AL43": "PM6:Y6BO CF",
    "AL44": "PM6:Y6 CF + 0.5% CN",
    "AL45": "PM6:Y6BO CF + 0.5% CN"
}


In [None]:
# Example of a quick plot check if desired here:
# One reciprocal-space image per data variable in recip_DS. Color limits are the 5% / 99.92%
# quantiles of the whole (uncropped) image, while only the q window below is displayed.
for DA in tqdm(list(recip_DS.data_vars.values())):   
# for DA in tqdm(selected_DAs):
    cmin = DA.quantile(0.05)
    cmax = DA.quantile(0.9992)
    
    # Title looks up sn[DA.sample_ID], so every sample_ID must be a key of sn
    ax = DA.sel(q_xy=slice(-1.1, 2.1), q_z=slice(-0.05, 2.4)).plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(8,4))
    ax.axes.set(aspect='equal', title=f'{DA.sample_ID}: {sn[DA.sample_ID]},\n {DA.incident_angle}, {DA.sample_pos}, id: {DA.scan_id}')
    plt.show()
    plt.close('all')  # release each figure before plotting the next one

In [None]:
# Use the single_images_to_dataset utility function to pygix transform all raw files in an indexable list
# Located in the IntegrationUtils script, CMSGIWAXS class:

# Initialize CMSGIWAXS util object, same file list and loader as the recip cell but with the caked integrator
# NOTE: this rebinds raw_DS (and util), replacing the values produced by the recip cell
util = phs.util.IntegrationUtils.CMSGIWAXS(sorted(samplesPath.glob('CD*')), loader, caked_integrator)
raw_DS, caked_DS = util.single_images_to_dataset()  # run function; unpacked as (raw, caked) Datasets
display(caked_DS)

In [None]:
# Example of a quick plot check if desired here:
# One caked (qr vs chi) image per data variable in caked_DS. Color limits are quantiles taken
# only over qr >= 0.25, presumably to keep the low-qr region from dominating the color scale.


for DA in tqdm(list(caked_DS.data_vars.values())):   
# for DA in tqdm(selected_DAs):
    cmin = DA.sel(qr=slice(0.25,None)).quantile(0.05)
    cmax = DA.sel(qr=slice(0.25,None)).quantile(0.9992)
    
    # Display qr up to 2.1 over the full chi range; title uses the sn lookup
    ax = DA.sel(qr=slice(0, 2.1), chi=slice(None, None)).plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(8,4))
    ax.axes.set(title=f'{DA.sample_ID}: {sn[DA.sample_ID]},\n {DA.incident_angle}, {DA.sample_pos}, id: {DA.scan_id}')
    plt.show()
    plt.close('all')  # release each figure before plotting the next one

In [None]:
### Apply a sin chi correction
# Multiply every caked image by sin(|chi|) (chi in degrees). The products are written
# back into a shallow copy of caked_DS, so each variable keeps its attrs and caked_DS
# itself is left untouched.
sin_chi_DA = np.sin(np.radians(np.abs(caked_DS.chi)))

corr_DS = caked_DS.copy()
# corr_DS = corr_DS * sin_chi_DA  # This works mathematically, but does not preserve attributes
for name, caked_DA in caked_DS.data_vars.items():
    # Replacing .values swaps the data of corr_DS's variable and leaves its attrs as they were
    corr_DS[name].values = (caked_DA * sin_chi_DA).values

corr_DS

In [None]:
%matplotlib widget
plt.close('all')

# Polar plots, for sin(chi) intensities
# Set chi range: Full range
chi_min = -90
chi_max = 90
q_min = 0.05
q_max = 2.04

# selected_attrs_dict = {'sample_ID': ['AL22', 'AL36', 'AL25', 'AL26', 
#                                      'AL31', 'AL41', 'AL28', 'AL38']}
selected_attrs_dict = {'sample_ID': ['AL22', 'AL31'], 'incident_angle': ['th0.110']}

# sin(chi)-corrected caked images whose attrs match selected_attrs_dict
selected_DAs = select_attrs(corr_DS.data_vars.values(), selected_attrs_dict)

for DA in tqdm(selected_DAs):
    # Slice dataarray to select plotting region 
    sliced_DA = DA.sel(chi=slice(chi_min,chi_max), qr=slice(q_min, q_max))
    
    # Set color limits from quantiles of the plotted region. Load the slice once and
    # reuse it for both quantiles instead of calling .compute() twice.
    loaded_DA = sliced_DA.compute()
    real_min = float(loaded_DA.quantile(0.01))
    cmin = 1 if real_min < 1 else real_min  # floor the lower color limit at 1

    cmax = float(loaded_DA.quantile(0.995))
    
    # Plot sliced dataarray
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(5,4))  # plot, optional parameter interpolation='antialiased' for image smoothing
    # Raw strings for the LaTeX labels: '\c' in '$\chi$' is an invalid escape sequence in a
    # normal string literal (DeprecationWarning, and a SyntaxWarning from Python 3.12 on).
    ax.axes.set(title=rf'Polar Plot: {sn[DA.sample_ID]}, {float(DA.incident_angle[2:])}° Incidence, sin($\chi$) Corrected')
    ax.colorbar.set_label(r'Intensity * sin($\chi$) [arb. units]', rotation=270, labelpad=15)  # set colorbar label & parameters 
    ax.axes.set(xlabel='q$_r$ [Å$^{-1}$]', ylabel=r'$\chi$ [°]')  # set axis labels
    ax.figure.set(tight_layout=True, dpi=130)  # Adjust figure dpi & plotting style
    
    plt.show()  # Comment to mute plotting output
    # plt.close('all')

In [None]:
# Output directory for processed data; the cell displays whether it already exists
# (the save cells below create only its 'zarrs' subfolder, not outPath itself)
outPath = propPath.joinpath('AL_2024C2/processed_data')
outPath.exists()

In [None]:
# Saving dataset with xarray's to_zarr() method:
# General structure below:

# Set where to save file and what to name it
# NOTE: mkdir() is called without parents=True, so outPath itself must already exist
savePath = outPath.joinpath('zarrs')
savePath.mkdir(exist_ok=True)
savename = 'recip_DS.zarr'

# Save it; mode='w' overwrites an existing store at this path
recip_DS.to_zarr(savePath.joinpath(savename), mode='w')

In [None]:
# Saving dataset with xarray's to_zarr() method:
# General structure below:

# Set where to save file and what to name it
# NOTE: mkdir() is called without parents=True, so outPath itself must already exist
savePath = outPath.joinpath('zarrs')
savePath.mkdir(exist_ok=True)
savename = 'caked_DS.zarr'

# Save it; mode='w' overwrites an existing store at this path.
# NOTE(review): this saves the uncorrected caked_DS, not the sin(chi)-corrected corr_DS — confirm intended
caked_DS.to_zarr(savePath.joinpath(savename), mode='w')

### Generate, check, and save: caked Dataset

#### Yoneda check:

In [None]:
# NOTE(review): yoneda_angles is only assigned in the qz cell further down, so this
# raises NameError on a fresh top-to-bottom run
yoneda_angles

In [None]:
# Scratch q_z calculation: 4*pi/lambda * sin(angle) in 1/m, then converted to 1/Å.
# NOTE(review): `angles` is not defined anywhere in this notebook and `wavelength` is only
# assigned in a later cell, so this cell raises NameError as written — confirm which angle
# array was intended (the qz() function below covers the Yoneda case).
qz_inv_meters = ((4 * np.pi) / (wavelength)) * (np.sin(np.deg2rad(angles)))
qz_inv_angstroms = qz_inv_meters / 1e10

In [None]:
# Display the reciprocal-space Dataset (e.g. to check q_z coordinates for the Yoneda check)
recip_DS

In [None]:
def qz(wavelength, alpha_crit, alpha_incidents):
    """
    Return the q_z [1/Å] where the Yoneda peak is expected for each incident angle.

    Uses the exact grazing-incidence relation
        q_z = (2*pi/wavelength) * (sin(alpha_i) + sin(alpha_f))
    with the exit angle alpha_f equal to the critical angle. The previous form,
    (4*pi/wavelength) * sin((alpha_i + alpha_c)/2), drops a cos((alpha_i - alpha_c)/2)
    factor; it agrees closely at these small angles but is only an approximation.

    Parameters:
    wavelength: X-ray wavelength in meters.
    alpha_crit: critical angle in degrees.
    alpha_incidents: incident angle(s) in degrees; scalar or numpy array.

    Returns:
    q_z in inverse angstroms, broadcast over alpha_crit and alpha_incidents.
    """
    k_inv_meters = (2 * np.pi) / wavelength  # wavevector magnitude in 1/m
    qz_inv_meters = k_inv_meters * (np.sin(np.deg2rad(alpha_incidents)) + np.sin(np.deg2rad(alpha_crit)))
    qz_inv_angstroms = qz_inv_meters / 1e10  # 1/m -> 1/Å
    return qz_inv_angstroms


# X-ray wavelength in meters (these values correspond to lambda = hc/E for the noted energies)
# wavelength = 9.762535309700809e-11  # 12.7 keV
wavelength = 9.184014698755575e-11  # 13.5 keV

alpha_crit = 0.11  # organic film critical angle
alpha_incidents = np.array([0.08, 0.11, 0.12, 0.15])  # incident angles in degrees (qz() applies np.deg2rad)

# Sum of each incident angle and the critical angle, in degrees
yoneda_angles = alpha_incidents + alpha_crit

# Display the expected Yoneda-peak q_z [1/Å] for each incident angle
qz(wavelength, alpha_crit, alpha_incidents)

In [None]:
def select_attrs(data_arrays_iterable, selected_attrs_dict, copy=False):
    """
    Selects data arrays whose attributes match the specified values.

    All (attribute name, allowed values) pairs must match, i.e. the filters
    are combined with a logical AND.

    Parameters:
    data_arrays_iterable: Iterable of xarray.DataArray objects.
    selected_attrs_dict: Dictionary where keys are attribute names and 
                         values are the attributes' desired values. A value
                         may be a collection of allowed values (e.g. a list)
                         or a single string, which is matched exactly (not
                         as a substring).
    copy: If True, return ``da.copy()`` of each matching array; the default
          (False) returns the original objects, as this definition always has.

    Returns:
    List of xarray.DataArray objects that match the specified attributes.

    Raises:
    KeyError: if a data array still under consideration lacks a requested
              attribute.
    """
    # A bare string would make `in` a substring test ('AL2' in 'AL22' is True),
    # silently selecting the wrong samples, so wrap it as a single allowed value.
    filters = {
        attr_name: (attr_values,) if isinstance(attr_values, str) else attr_values
        for attr_name, attr_values in selected_attrs_dict.items()
    }

    # all() short-circuits in dict order, so attributes are only looked up on
    # arrays that passed the earlier filters (same as the original pass-by-pass loop).
    matches = [
        da for da in data_arrays_iterable
        if all(da.attrs[attr_name] in attr_values
               for attr_name, attr_values in filters.items())
    ]

    return [da.copy() for da in matches] if copy else matches

In [None]:
%matplotlib inline

In [None]:
plt.close('all')

# --- 2D reciprocal-space (cartesian q_xy vs q_z) plots of the selected images ---
# Plotting window, inverse angstroms
qxy_min, qxy_max = -1.1, 2.1
qz_min, qz_max = -0.2, 2.2

selected_attrs_dict = {'sample_ID': ['AL26']}
# selected_attrs_dict = {}

selected_DAs = select_attrs(recip_DS.data_vars.values(), selected_attrs_dict)
for DA in tqdm(selected_DAs):
    # Crop to the plotting window (rename q_xy / q_z here if your dimensions are named differently)
    sliced_DA = DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))

    # Color limits from quantiles of the cropped data, loaded once; lower limit floored at 1
    loaded_DA = sliced_DA.compute()
    real_min = float(loaded_DA.quantile(0.05))
    cmin = max(real_min, 1)
    cmax = float(loaded_DA.quantile(0.997))

    # Plot
    ax = sliced_DA.plot.imshow(cmap=cmap,
                               norm=plt.Normalize(cmin, cmax),
                               interpolation='antialiased',
                               figsize=(5.5, 3.3))
    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
    # ax.axes.set(aspect='equal', title=f'Cartesian Plot: {DA.material} {DA.solvent} {DA.rpm}, {float(DA.incident_angle[2:])}° Incidence',
    #             xlabel='q$_{xy}$ [Å$^{-1}$]', ylabel='q$_z$ [Å$^{-1}$]')
    ax.axes.set(aspect='equal',
                title=f'Cartesian Plot: {sn[DA.sample_ID]}, {float(DA.incident_angle[2:])}° Incidence',
                xlabel='q$_{xy}$ [Å$^{-1}$]',
                ylabel='q$_z$ [Å$^{-1}$]')
    ax.figure.set(tight_layout=True, dpi=130)

    # ax.figure.savefig(savePath.joinpath(f'{DA.material}-{DA.solvent}-{DA.rpm}_qxy{qxy_min}to{qxy_max}_qz{qz_min}to{qz_max}_{DA.incident_angle}.png'), dpi=150)
    # ax.figure.savefig(savePath.joinpath(f'{DA.material}-{DA.solvent}_qxy{qxy_min}to{qxy_max}_qz{qz_min}to{qz_max}_{DA.incident_angle}.png'), dpi=150)

    plt.show()
    # plt.close('all')

In [None]:
# Yoneda peak linecut check
# Sum the low-q_z band over q_xy and overlay one q_z profile per selected image,
# labeled by incident angle. Uses whatever selected_attrs_dict was set most recently.
qxy_min = 0.22
qxy_max = 2
qz_min = -0.02
qz_max = 0.06

# The reciprocal-space Dataset built earlier in this notebook is recip_DS; the previous
# name fixed_recip_DS is not defined anywhere here and raised NameError.
selected_DAs = select_attrs(recip_DS.data_vars.values(), selected_attrs_dict)
for DA in tqdm(selected_DAs):
    # Slice data for selected q ranges (will need to rename q_xy if dimensions are differently named)
    sliced_DA = DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))
    qz_integrated_DA = sliced_DA.sum('q_xy')  # intensity profile along q_z
    
    # Plot
    qz_integrated_DA.plot.line(label=DA.incident_angle)
    
plt.legend()
plt.grid(visible=True, which='major', axis='x')
plt.show()

In [None]:
chi_min = 60
chi_max = None

# The caked Dataset built earlier in this notebook is caked_DS; the previous name
# fixed_caked_DS is not defined anywhere here and raised NameError.
selected_DAs = select_attrs(caked_DS.data_vars.values(), selected_attrs_dict)
for DA in tqdm(selected_DAs):
    # Slice dataarray to select plotting region 
    sliced_DA = DA.sel(chi=slice(chi_min,chi_max))
    
    # Color limits from quantiles of the whole (unsliced) image; load it once and reuse it
    loaded_DA = DA.compute()
    # real_min = float(DA.sel(q_xy=slice(-0.5, -0.1), q_z=slice(0.1, 0.4)).compute().quantile(1e-3))
    real_min = float(loaded_DA.quantile(0.05))
    cmin = 1 if real_min < 1 else real_min  # floor the lower color limit at 1
    
    # cmax = float(DA.sel(q_xy=slice(-0.5, -0.1), q_z=slice(0.1, 2)).compute().quantile(1))   
    cmax = float(loaded_DA.quantile(0.999))  
    
    # Plot sliced dataarray. The upper color limit is the computed cmax (it was hard-coded
    # to 10 while cmax went unused); replace cmax with a fixed value if one is really wanted.
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(5,4), interpolation='antialiased')  # plot, optional parameter interpolation='antialiased' for image smoothing
    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)  # set colorbar label & parameters 
    # Title uses the sample description lookup like the other plots: md_naming_scheme has no
    # 'material'/'solvent' fields. Raw string for the chi label: '\c' is an invalid escape sequence.
    ax.axes.set(title=f'Polar Plot: {sn[DA.sample_ID]}, {float(DA.incident_angle[2:])}° Incidence',
                xlabel='q$_r$ [Å$^{-1}$]', ylabel=r'$\chi$ [°]')  # set title, axis labels, misc
    ax.figure.set(tight_layout=True, dpi=130)  # Adjust figure dpi & plotting style
    
    plt.show()  # Comment to mute plotting output
    
    # Uncomment below line and set savepath/savename for saving plots, I usually like to check 
    # ax.figure.savefig(outPath.joinpath('PM6-Y6set_waxs', f'polar-2D_{DA.sample_ID}_{chi_min}to{chi_max}chi_{DA.incident_angle}.png'), dpi=150)
    plt.close('all')

In [None]:
# NOTE(review): fixed_recip_DS is not defined anywhere in this notebook (apparently left over from
# an earlier version), so this cell raises NameError as written; recip_DS is saved above.
fixed_recip_DS.to_zarr(savePath.joinpath('fix_recip_stitched.zarr'), mode='w')

In [None]:
# NOTE(review): fixed_raw_DS is not defined anywhere in this notebook, so this cell raises
# NameError as written; the raw Dataset produced above is named raw_DS.
fixed_raw_DS.to_zarr(savePath.joinpath('fix_raw_stitched.zarr'), mode='w')

In [None]:
# NOTE(review): fixed_caked_DS is not defined anywhere in this notebook, so this cell raises
# NameError as written; caked_DS is saved above.
fixed_caked_DS.to_zarr(savePath.joinpath('fix_caked_stitched.zarr'), mode='w')

In [None]:
# NOTE(review): variable_rpm_set, variable_rpm_loader and transformer are not defined in this
# notebook, and this calls phs.GIWAXS.single_images_to_dataset with three return values, while
# the cells above use phs.util.IntegrationUtils.CMSGIWAXS(...).single_images_to_dataset() with
# two. This looks like a leftover from an earlier notebook — confirm before running.
variable_raw_DS, variable_recip_DS, variable_caked_DS = phs.GIWAXS.single_images_to_dataset(variable_rpm_set, variable_rpm_loader, transformer)

In [None]:
# Save the variable-set reciprocal-space Dataset; depends on the leftover cell above succeeding
variable_recip_DS.to_zarr(savePath.joinpath('var_recip_stitched.zarr'), mode='w')

In [None]:
# Save the variable-set raw Dataset; depends on the leftover cell above succeeding
variable_raw_DS.to_zarr(savePath.joinpath('var_raw_stitched.zarr'), mode='w')

In [None]:
# Save the variable-set caked Dataset; depends on the leftover cell above succeeding
variable_caked_DS.to_zarr(savePath.joinpath('var_caked_stitched.zarr'), mode='w')