# CMS ex situ GIWAXS 2023C3

# CMS GIWAXS raw data processing & exporting notebook
In this notebook you output xr.Datasets, stored as .zarr stores, containing all of your raw,
remeshed (reciprocal space), and caked CMS GIWAXS data. Datasets loaded back from a zarr store (e.g. with `xr.open_zarr`) are backed by dask arrays.

In [23]:
# ### Kernel updates if needed, remember to restart kernel after running this cell!:
# !pip install -e /nsls2/users/alevin/repos/PyHyperScattering  # to use pip to install via directory

## Imports

In [24]:
### Imports:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import xarray as xr
# import PyHyperScattering as phs
import pygix
import fabio
import gc
from tqdm.auto import tqdm  # progress bar loader!

from GIWAXS import Transform, single_images_to_dataset
from CMSGIWAXSLoader import CMSGIWAXSLoader
# print(f'Using PyHyperScattering Version: {phs.__version__}')

## Defining some objects

### Define & check paths

In [25]:
# pix_size = 0.000172
# 673 * pix_size

In [26]:
# 0.1097360 / 0.000172

In [27]:
# pathlib keeps path handling readable and lets you check each path with `.exists()`.
# Replace the paths below with the ones relevant to your data.
suitePath = pathlib.Path('/Users/andrew/Library/CloudStorage/OneDrive-UCB-O365/research/data_analysis/giwaxs_suite')
rawPath = suitePath.joinpath('raw_data/PS-Y6_ALS')
dataPath = rawPath.joinpath('all_data')
outPath = suitePath.joinpath('processed_data/PS-Y6_ALS')

# The zarr-saving cells further down write to `savePath`, which was otherwise never defined;
# point it at the processed-data folder (change it here to save somewhere else).
savePath = outPath

# Select poni & mask filepaths
poniFile = rawPath.joinpath('10keV_2024-03-22_lo.poni')
poniFile_hi = rawPath.joinpath('10keV_2024-03-22_hi.poni')
maskFile = rawPath.joinpath('Pilatus2M_mask.edf')

# Create pg Transform object with the above information.
# The energy overrides the default poni energy, so it MUST be correct for your samples!
# NOTE: pg_transformer is re-created further down (with a mask sized for the stitched images)
# before it is used for conversion.
# pg_transformer = Transform(poniPath=poniFile, maskPath=maskFile, energy=10)
pg_transformer = Transform(poniPath=poniFile, maskPath=None, energy=10)

# Colormap: turbo, with "bad" (masked/NaN) pixels drawn in black
cmap = plt.cm.turbo.copy()
cmap.set_bad('black')

In [28]:
def poni_centers(poniFile, pix_size=0.000172):
    """
    Return the Poni1/Poni2 values of a poni file and the corresponding pixel positions.

    The default pixel size is 172 microns (0.000172, the Pilatus pixel pitch), in the same
    length unit as the poni values. Poni1/Poni2 are looked up by their "Poni1:" / "Poni2:"
    keys rather than by fixed line numbers, so extra header/comment lines or different
    spacing in the file do not change which values are read.

    Inputs: poniFile as pathlib path object to the poni file
            pix_size: pixel size, same length unit as the poni values
    Outputs: ((poni1, y_center), (poni2, x_center))
    Raises: ValueError if the file has no Poni1 or Poni2 entry
    """
    values = {}
    with poniFile.open('r') as f:
        for line in f:
            # "Key: value" lines; partition keeps everything after the first colon as the value
            key, sep, value = line.partition(':')
            key = key.strip().lower()
            if sep and key in ('poni1', 'poni2') and key not in values:
                values[key] = float(value.strip())

    missing = [name for name in ('poni1', 'poni2') if name not in values]
    if missing:
        raise ValueError(f'{poniFile} has no entry for: {", ".join(missing)}')

    poni1 = values['poni1']
    poni2 = values['poni2']

    # Convert distances on the detector into pixel positions
    y_center = poni1 / pix_size
    x_center = poni2 / pix_size

    return ((poni1, y_center), (poni2, x_center))

# Show the poni values and pixel centers for the lo and then the hi detector position
for poni_path in (poniFile, poniFile_hi):
    display(poni_centers(poni_path))

((0.21509344682470327, 1250.5432954924609),
 (0.12267179313842716, 713.2080996420184))

((0.2202295667253362, 1280.4044577054428),
 (0.12265961322468676, 713.1372861900393))

### Define metadata naming scheme & initialize loaders

In [29]:
# List the data filenames (sorted) to check them against the metadata naming scheme below
data_files = sorted(dataPath.glob('*'))
[data_file.name for data_file in data_files]

['sam1_th0.140_10kev_10s_hi_2m.edf',
 'sam1_th0.140_10kev_10s_lo_2m.edf',
 'sam1_th0.140_10kev_30s_hi_2m.edf',
 'sam1_th0.140_10kev_30s_lo_2m.edf',
 'sam2_th0.140_10kev_10s-2_hi_2m.edf',
 'sam2_th0.140_10kev_10s-2_lo_2m.edf',
 'sam2_th0.140_10kev_10s_hi_2m.edf',
 'sam2_th0.140_10kev_10s_lo_2m.edf',
 'sam3_th0.140_10kev_10s_hi_2m.edf',
 'sam3_th0.140_10kev_10s_lo_2m.edf',
 'sam3_th0.140_10kev_30s_hi_2m.edf',
 'sam3_th0.140_10kev_30s_lo_2m.edf',
 'sam4_th0.140_10kev_2s_hi_2m.edf',
 'sam4_th0.140_10kev_2s_lo_2m.edf',
 'sam4_th0.140_10kev_30s_hi_2m.edf',
 'sam4_th0.140_10kev_30s_lo_2m.edf',
 'sam5_th0.140_10kev_10s_hi_2m.edf',
 'sam5_th0.140_10kev_10s_lo_2m.edf',
 'sam5_th0.140_10kev_30s-2_hi_2m.edf',
 'sam5_th0.140_10kev_30s-2_lo_2m.edf',
 'sam5_th0.140_10kev_30s_hi_2m.edf',
 'sam5_th0.140_10kev_30s_lo_2m.edf',
 'sam6_th0.140_10kev_1s_hi_2m.edf',
 'sam6_th0.140_10kev_1s_lo_2m.edf',
 'sam6_th0.140_10kev_2s-2_hi_2m.edf',
 'sam6_th0.140_10kev_2s-2_lo_2m.edf',
 'sam6_th0.140_10kev_2s_hi_2m.ed

In [30]:
# Ex situ filename fields, in order. For example, 'sam1_th0.140_10kev_10s_hi_2m.edf' presumably maps to
# sample_id='sam1', incident_angle='th0.140', energy='10kev', exposure_time='10s',
# detector_pos='hi', detector='2m'.
md_naming_scheme = [
    'sample_id',
    'incident_angle',
    'energy',
    'exposure_time',
    'detector_pos',
    'detector',
]

# Initialize the CMSGIWAXSLoader with the naming scheme above
loader = CMSGIWAXSLoader(md_naming_scheme=md_naming_scheme)

## Data processing
Break this section up however makes sense for your data

### Check raw data

In [31]:
# Load every raw image and sort it by its detector_pos attribute.
# NOTE(review): 'hi' images are appended to lo_DAs and 'lo' images to hi_DAs. The stitching
# cell below uses lo_DAs as the base image and shifts hi_DAs by +hi_lo_shift pixels. The poni
# outputs above give the hi position a ~30 px larger y_center (~1280 vs ~1250), so the
# stitching appears to rely on this swap. Confirm before renaming, and change both cells together.
files = sorted(dataPath.glob('*'))  # define list of files

lo_DAs = []
hi_DAs = []
destination_by_pos = {'hi': lo_DAs, 'lo': hi_DAs}
for file in tqdm(files):
    DA = loader.loadSingleImage(file)
    destination = destination_by_pos.get(DA.detector_pos)
    if destination is not None:
        destination.append(DA)

    # # Plot
    # cmin = DA.quantile(0.4)
    # cmax = DA.quantile(0.99)
    # ax = DA.plot.imshow(norm=plt.Normalize(cmin,cmax), origin='upper', cmap=cmap)
    # ax.axes.set_title(f'{DA.sample_id}: {DA.exposure_time}, {DA.detector_pos}')
    # plt.show()
    # plt.close('all')

  0%|          | 0/28 [00:00<?, ?it/s]

### Perform stitching if necessary

In [37]:
# Boolean detector mask: True wherever the mask file has a nonzero value
mask_arr = fabio.open(maskFile).data.astype(bool)

# Pixel offset between the hi and lo detector positions; matches the ~30 px difference
# between the hi and lo poni y-centers shown above (~1280 vs ~1250)
hi_lo_shift = 30

print(mask_arr.shape)

(1679, 1475)


In [52]:
# Run stitching: fill the masked (detector-gap) pixels of each base image with the
# matching pixels of its partner image taken at the other detector position.

# Rows kept after the partner is shifted by hi_lo_shift. The upper bound assumes pix_y
# is 0-indexed, so the last row label is mask_arr.shape[0] - 1 (1678 for this detector).
y_first = hi_lo_shift
y_last = mask_arr.shape[0] - 1
gap_fill_value = -1  # sentinel marking base pixels to take from the partner image

# Images are paired purely by sorted filename order. Fail loudly instead of silently
# stitching frames from different samples/exposures when a file is missing.
assert len(lo_DAs) == len(hi_DAs), f'{len(lo_DAs)} lo_DAs vs {len(hi_DAs)} hi_DAs'

stitched_DAs = []
for base_DA, partner_DA in zip(lo_DAs, hi_DAs):
    for attr in ('sample_id', 'exposure_time'):
        base_val, partner_val = base_DA.attrs.get(attr), partner_DA.attrs.get(attr)
        assert base_val == partner_val, f'mismatched stitching pair ({attr}): {base_val} vs {partner_val}'

    # Work on copies so the loaded images stay untouched
    lo_DA = base_DA.copy()
    hi_DA = partner_DA.copy()

    # Shift the partner's pixel rows onto the base image's grid, then crop to the shared rows
    hi_DA['pix_y'] = hi_DA.pix_y + hi_lo_shift
    hi_DA = hi_DA.sel(pix_y=slice(y_first, y_last))

    # Flag the base image's masked pixels, then crop to the same rows
    lo_DA.data[mask_arr] = gap_fill_value
    lo_DA = lo_DA.sel(pix_y=slice(y_first, y_last))

    # Keep every unflagged base pixel; take flagged pixels from the shifted partner
    stitched_DA = lo_DA.where(lo_DA != gap_fill_value, hi_DA)
    stitched_DAs.append(stitched_DA)

In [None]:
# Plot each stitched image with a 40th-99th percentile color scale to check the stitching
for stitched_DA in stitched_DAs:
    cmin, cmax = stitched_DA.quantile(0.4), stitched_DA.quantile(0.99)
    img = stitched_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), origin='upper', cmap=cmap)
    img.axes.set_title(f'{stitched_DA.sample_id}: {stitched_DA.exposure_time}')
    plt.show()
    plt.close('all')

### Apply rotation corrections

### pygix-backed reduction

In [70]:
# Mask shape after dropping the first 30 rows, i.e. the rows cropped away during stitching
mask_arr[30:].shape

(1649, 1475)

In [90]:
# Transformer for the cropped, stitched images. It gets an all-False mask of the cropped shape,
# so no pixels are masked (presumably because stitching already filled the detector gaps).
# Swap in the commented line to mask the gap pixels instead.
# pg_transformer = Transform(poniFile, mask_arr[30:, :], energy=10)
no_mask = np.zeros(mask_arr[30:, :].shape, dtype=mask_arr.dtype)
pg_transformer = Transform(poniFile, no_mask, energy=10)

In [91]:
# Convert every stitched image with pg_transformer and collect the raw, reciprocal-space
# (q_xy, q_z), and caked images into three xr.Datasets.
#
# Each Dataset variable needs a unique name per scan. sample_id alone repeats (e.g. sam1 has
# both 10s and 30s exposures), and assigning to an existing name silently overwrote the
# earlier scan, so the exposure_time field is appended.
def scan_name(data_array):
    """Dataset variable name for one scan: '<sample_id>_<exposure_time>'."""
    return f'{data_array.sample_id}_{data_array.exposure_time}'

# The first image initializes the Datasets and fixes the reference coordinates
DA = stitched_DAs[0]
recip_DA, caked_DA = pg_transformer.pg_convert(DA)

# Save coordinates for interpolating the other DataArrays onto the same grid
recip_coords = recip_DA.coords
caked_coords = caked_DA.coords

raw_DS = DA.to_dataset(name=scan_name(DA))
recip_DS = recip_DA.to_dataset(name=scan_name(DA))
caked_DS = caked_DA.to_dataset(name=scan_name(DA))

# Populate the Datasets with the remaining images
for DA in tqdm(stitched_DAs[1:], desc='Transforming Raw Data'):
    name = scan_name(DA)
    if name in raw_DS:
        raise ValueError(f'Duplicate scan name {name!r}; extend scan_name() so scans stay distinct')

    recip_DA, caked_DA = pg_transformer.pg_convert(DA)

    # Interpolate onto the first image's coordinates so all variables share one grid
    recip_DA = recip_DA.interp(recip_coords)
    caked_DA = caked_DA.interp(caked_coords)

    raw_DS[name] = DA
    recip_DS[name] = recip_DA
    caked_DS[name] = caked_DA

# # Save zarr stores if selected
# if savePath and savename:
#     print('Saving zarrs...')
#     savePath = pathlib.Path(savePath)
#     raw_DS.to_zarr(savePath.joinpath(f'raw_{savename}.zarr'), mode='w')
#     recip_DS.to_zarr(savePath.joinpath(f'recip_{savename}.zarr'), mode='w')
#     caked_DS.to_zarr(savePath.joinpath(f'caked_{savename}.zarr'), mode='w')
#     print('Saved!')
# else:
#     print('No save path or no filename specified, not saving zarrs... ')

Transforming Raw Data:   0%|          | 0/13 [00:00<?, ?it/s]

In [None]:
# Plot check recip DAs over a fixed q window
qxy_min, qxy_max = -0.65, 2
qz_min, qz_max = 0, 2.2

for recip_image in recip_DS.data_vars.values():
    sliced_DA = recip_image.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))

    # Color scale from the 18th to the 99th intensity percentile of the plotted window
    cmin, cmax = sliced_DA.quantile(0.18), sliced_DA.quantile(0.99)
    img = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
    img.axes.set_title(f'{sliced_DA.sample_id}: {sliced_DA.exposure_time}')
    img.axes.set(aspect='equal')
    plt.show()
    plt.close('all')

#### Yoneda check:

In [None]:
# `yoneda_angles` (alpha_incidents + alpha_crit, in degrees) is computed in the qz cell further
# down. Displaying it unguarded here raised NameError on a fresh Restart & Run All.
if 'yoneda_angles' in globals():
    display(yoneda_angles)
else:
    print('yoneda_angles is not defined yet; run the qz cell below first.')

In [None]:
# Scratch version of the qz() calculation defined in the next cell. It referenced `angles`,
# which is not defined anywhere in this notebook (NameError on every run), and it used
# sin(angle) rather than qz()'s half-angle form. It is kept only as a commented reference:
# qz_inv_meters = ((4 * np.pi) / (wavelength)) * (np.sin(np.deg2rad(angles)))
# qz_inv_angstroms = qz_inv_meters / 1e10

In [None]:
def qz(wavelength, alpha_crit, alpha_incidents):
    """
    q_z at the half-angle (alpha_incidents + alpha_crit) / 2, in inverse Angstroms.

    Computes (4*pi / wavelength) * sin((alpha_incidents + alpha_crit) / 2) and converts m^-1 to
    Angstrom^-1. Presumably this is the small-angle Yoneda position (exit angle equal to the
    critical angle); confirm for your geometry.

    Inputs: wavelength in meters; alpha_crit and alpha_incidents in degrees
            (alpha_incidents may be a scalar or a numpy array)
    Output: q_z in inverse Angstroms, shaped like alpha_incidents
    """
    half_angle_rad = np.deg2rad((alpha_incidents + alpha_crit) / 2)
    q_per_meter = (4 * np.pi) / wavelength * np.sin(half_angle_rad)
    return q_per_meter / 1e10


# Beam energy for the Yoneda estimate. The previously hard-coded wavelength
# (9.762535e-11 m) corresponds to 12.7 keV, but the images processed in this notebook are
# 10 keV ('10kev' in the filenames, energy=10 in Transform). Set this to the energy of the
# data you are checking.
energy_keV = 10
HC_KEV_ANGSTROM = 12.398419843320026  # h*c in keV*Angstrom
wavelength = HC_KEV_ANGSTROM / energy_keV * 1e-10  # meters

alpha_crit = 0.11  # organic film critical angle (degrees)
# NOTE(review): the images above were taken at th0.140; add 0.14 if you want its Yoneda position.
alpha_incidents = np.array([0.08, 0.1, 0.12, 0.15])  # degrees

yoneda_angles = alpha_incidents + alpha_crit

qz(wavelength, alpha_crit, alpha_incidents)

In [None]:
def _as_allowed_values(attr_values):
    """Normalize one requested attribute value, or a collection of values, to a tuple."""
    # A bare string would otherwise be matched as a substring ('PM' in 'PM6' is True)
    if isinstance(attr_values, (str, bytes)):
        return (attr_values,)
    try:
        # Materialize once so one-shot iterators (e.g. generators) apply to every DataArray
        return tuple(attr_values)
    except TypeError:
        # Non-iterable scalar such as an int or a float
        return (attr_values,)


def select_attrs(data_arrays_iterable, selected_attrs_dict):
    """
    Selects data arrays whose attributes match the specified values.

    Parameters:
    data_arrays_iterable: Iterable of xarray.DataArray objects (anything with an `.attrs` dict).
    selected_attrs_dict: Dictionary where keys are attribute names and values are the
                         accepted value(s) for that attribute: either a single value
                         (e.g. 'PM6' or 0.14, matched exactly) or a collection of
                         accepted values (e.g. ['PM6', 'Y6']). An empty dict selects everything.

    Returns:
    List of xarray.DataArray objects whose attributes match every specified entry.

    Raises:
    KeyError if a data array being checked lacks one of the requested attributes.
    """
    sublist = list(data_arrays_iterable)

    for attr_name, attr_values in selected_attrs_dict.items():
        allowed = _as_allowed_values(attr_values)
        sublist = [da for da in sublist if da.attrs[attr_name] in allowed]

    return sublist

In [None]:
# 2D reciprocal space cartesian plots of the selected scans
# NOTE(review): fixed_recip_DS is not created anywhere in this notebook, and the stitched data
# here has no 'material'/'solvent' attributes. This cell looks carried over from another dataset.
qxy_min, qxy_max = -1.1, 2.1
qz_min, qz_max = -0.01, 2.2

selected_attrs_dict = {'material': ['PM6'], 'solvent': ['CBCN']}
# selected_attrs_dict = {}

selected_DAs = select_attrs(fixed_recip_DS.data_vars.values(), selected_attrs_dict)
for DA in tqdm(selected_DAs):
    # Slice data for selected q ranges (will need to rename q_xy if dimensions are differently named)
    sliced_DA = DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))

    # Color limits: 5th percentile (never below 1) up to the 99.7th percentile
    computed_DA = sliced_DA.compute()
    real_min = float(computed_DA.quantile(0.05))
    cmin = max(real_min, 1)
    cmax = float(computed_DA.quantile(0.997))

    img = sliced_DA.plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), interpolation='antialiased', figsize=(5.5,3.3))
    img.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
    img.axes.set(aspect='equal', title=f'Cartesian Plot: {DA.material} {DA.solvent}, {float(DA.incident_angle[2:])}° Incidence',
                 xlabel='q$_{xy}$ [Å$^{-1}$]', ylabel='q$_z$ [Å$^{-1}$]')
    img.figure.set(tight_layout=True, dpi=130)

    # img.figure.savefig(savePath.joinpath(f'{DA.material}-{DA.solvent}_qxy{qxy_min}to{qxy_max}_qz{qz_min}to{qz_max}_{DA.incident_angle}.png'), dpi=150)

    plt.show()
    plt.close('all')

In [None]:
# Yoneda peak linecut check: intensity vs q_z summed over a q_xy window, one line per scan
qxy_min, qxy_max = 0.22, 2
qz_min, qz_max = -0.02, 0.06

selected_DAs = select_attrs(fixed_recip_DS.data_vars.values(), selected_attrs_dict)
for DA in tqdm(selected_DAs):
    # Slice data for selected q ranges (will need to rename q_xy if dimensions are differently named)
    q_window = DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))
    qz_integrated_DA = q_window.sum('q_xy')
    qz_integrated_DA.plot.line(label=DA.incident_angle)

plt.legend()
plt.grid(visible=True, which='major', axis='x')
plt.show()

In [None]:
# Polar (caked) plots of the selected scans, restricted to chi in [chi_min, chi_max]
# NOTE(review): fixed_caked_DS is not created anywhere in this notebook.
chi_min = 60
chi_max = None
# The color scale used to ignore the computed cmax and clip at a hard-coded 10. Set
# cmax_override to a number to clip manually, or leave None to use the 99.9th percentile.
cmax_override = None

selected_DAs = select_attrs(fixed_caked_DS.data_vars.values(), selected_attrs_dict)
for DA in tqdm(selected_DAs):
    # Slice dataarray to select plotting region
    sliced_DA = DA.sel(chi=slice(chi_min, chi_max))

    # Color limits from the full (unsliced) image: 5th percentile floored at 1, 99.9th percentile
    computed_DA = DA.compute()
    real_min = float(computed_DA.quantile(0.05))
    cmin = 1 if real_min < 1 else real_min
    cmax = float(computed_DA.quantile(0.999)) if cmax_override is None else cmax_override

    # Plot sliced dataarray; interpolation='antialiased' smooths the image
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=plt.Normalize(cmin, cmax), figsize=(5,4), interpolation='antialiased')
    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
    # Raw string: '\c' in a normal string literal is an invalid escape sequence
    ax.axes.set(title=f'Polar Plot: {DA.material} {DA.solvent}, {float(DA.incident_angle[2:])}° Incidence',
                xlabel='q$_r$ [Å$^{-1}$]', ylabel=r'$\chi$ [°]')
    ax.figure.set(tight_layout=True, dpi=130)  # Adjust figure dpi & plotting style

    plt.show()  # Comment to mute plotting output

    # Uncomment below line and set savepath/savename for saving plots
    # ax.figure.savefig(outPath.joinpath('PM6-Y6set_waxs', f'polar-2D_{DA.sample_id}_{chi_min}to{chi_max}chi_{DA.incident_angle}.png'), dpi=150)
    plt.close('all')

In [None]:
# Save the fixed-set reciprocal-space Dataset as a zarr store (mode='w' overwrites an existing store)
fixed_recip_DS.to_zarr(savePath / 'fix_recip_stitched.zarr', mode='w')

In [None]:
# Save the fixed-set raw Dataset as a zarr store (mode='w' overwrites an existing store)
fixed_raw_DS.to_zarr(savePath / 'fix_raw_stitched.zarr', mode='w')

In [None]:
# Save the fixed-set caked Dataset as a zarr store (mode='w' overwrites an existing store)
fixed_caked_DS.to_zarr(savePath / 'fix_caked_stitched.zarr', mode='w')

In [None]:
# Build raw / recip / caked Datasets for the variable-rpm image set.
# The PyHyperScattering import (`phs`) is commented out in the imports cell, so
# `phs.GIWAXS.single_images_to_dataset` raised NameError. Use the single_images_to_dataset
# imported directly from the local GIWAXS module instead.
# NOTE(review): variable_rpm_set, variable_rpm_loader, and transformer are not defined in this notebook.
variable_raw_DS, variable_recip_DS, variable_caked_DS = single_images_to_dataset(variable_rpm_set, variable_rpm_loader, transformer)

In [None]:
# Save the variable-rpm reciprocal-space Dataset as a zarr store (mode='w' overwrites an existing store)
variable_recip_DS.to_zarr(savePath / 'var_recip_stitched.zarr', mode='w')

In [None]:
# Save the variable-rpm raw Dataset as a zarr store (mode='w' overwrites an existing store)
variable_raw_DS.to_zarr(savePath / 'var_raw_stitched.zarr', mode='w')

In [None]:
# Save the variable-rpm caked Dataset as a zarr store (mode='w' overwrites an existing store)
variable_caked_DS.to_zarr(savePath / 'var_caked_stitched.zarr', mode='w')