# Ex situ GIWAXS processing

# GIWAXS raw data processing & exporting notebook
In this notebook you output `xr.Dataset` objects, saved as .zarr stores, containing all of your raw,
remeshed (reciprocal space), and caked CMS GIWAXS data. Saving as a zarr store automatically converts each array to a dask array.

## Imports

In [None]:
### Imports:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import xarray as xr
import PyHyperScattering as phs  # provides the CMSGIWAXSLoader used in this notebook
import fabio  # opens the raw detector images (fabio.open)
import gc
from tqdm.auto import tqdm  # progress bar loader!

# Report the PyHyperScattering version in use; the message tells the user that
# warnings/errors printed by the imports above can be ignored.
print(f"Don't worry about the above warnings/errors... using PyHyperScattering version: {phs.__version__}!")

## Defining some objects

### Define & check paths

In [None]:
# Sample handled by this notebook; used for the data folder and the exported file names
sample_name = 'PM7_5CNCF'

# pathlib keeps paths readable and easy to verify, and it is also needed by the loadSeries function later on.
# Swap in the paths for your own data; calling ".exists()" on any of them confirms it was defined correctly.
suitePath = pathlib.Path('/Users/andrew/Library/CloudStorage/OneDrive-UCB-O365/research/data_analysis/giwaxs_suite')
xenocsPath = suitePath / 'raw_data/xenocs'
rawPath = xenocsPath / '2024_04_12'
dataPath = rawPath / sample_name
outPath = suitePath / 'processed_data/xenocs'

# Poni & mask file selection (alternatives are kept commented out)
# poniFile = xenocsPath / 'xenocs_100sdd.poni'
poniFile = xenocsPath / 'xenocs_120sdd.poni'
# maskFile = rawPath / 'Pilatus2M_mask.edf'

# Colormap: work on a copy of turbo and draw bad (masked/invalid) values in black
cmap = plt.cm.turbo.copy()
cmap.set_bad('black')

In [None]:
# Detector pixel sizes, in meters
pilatus1m = 0.000172  # 172 microns
eiger2_1m = 0.000075  # 75 microns


def poni_centers(poniFile, pix_size=eiger2_1m):
    """
    Returns the poni values of a poni file and the pixel positions they correspond to.

    The values are looked up by their "Poni1:" / "Poni2:" keys (case-insensitive) rather than
    by line number, so extra or missing header lines and variable spacing do not break parsing.
    If a key appears more than once, the last occurrence wins.

    Inputs: poniFile as pathlib path object to the poni file
            pix_size as the detector pixel size, in the same length unit as the poni values
            (default: eiger2_1m, 75 microns)
    Outputs: ((poni1, y_center), (poni2, x_center)), where each center is poni / pix_size
    Raises: ValueError if the file has no Poni1 or Poni2 entry
    """
    wanted = ('poni1', 'poni2')
    found = {}

    with poniFile.open('r') as f:
        for line in f:
            key, sep, value = line.partition(':')
            key = key.strip().lower()
            tokens = value.split()
            # Only "key: value" lines for the two poni entries matter; commented lines
            # such as "# Poni1: ..." do not match because the key keeps its leading "#".
            if sep and key in wanted and tokens:
                found[key] = float(tokens[0])

    missing = [key for key in wanted if key not in found]
    if missing:
        raise ValueError(f"Missing {', '.join(missing)} entry in poni file: {poniFile}")

    poni1 = found['poni1']
    poni2 = found['poni2']

    # poni1 is paired with the y pixel position and poni2 with the x pixel position
    y_center = poni1 / pix_size
    x_center = poni2 / pix_size

    return ((poni1, y_center), (poni2, x_center))

# Show the poni values and the pixel centers computed with the default pixel size (eiger2_1m)
display(poni_centers(poniFile))

In [None]:
# Names of everything directly inside the raw-data folder
[entry.name for entry in rawPath.glob('*')]

### Define metadata naming scheme & initialize loaders

In [None]:
# Gather the .edf images directly inside the raw-data folder, in sorted path order
# files = sorted(dataPath.glob('*vd*.edf'))
# files = sorted(dataPath.glob('*.edf'))
files = list(rawPath.glob('*.edf'))
files.sort()

# Show just the file names as a quick check
[edf_file.name for edf_file in files]

In [None]:
# Set ex situ metadata filename naming schemes:
# md_naming_scheme = ['material', 'solvent', 'misc', 'stitched', 'scan_ids']
# md_naming_scheme = ['material', 'anneal', 'solvent', 'misc', 'scan_id']
# NOTE(review): both naming schemes are commented out and none is passed to the loader below,
# while later cells read DA.material / DA.solvent — confirm the loader still provides those attributes.


# Initialize the CMSGIWAXSLoader object (created with its default arguments here)
loader = phs.load.CMSGIWAXSLoader()

## Data processing
Break this section up in whatever way makes sense for your data.

### Check raw data

In [None]:
# Single file: load the first image of the sorted file list as a check of the raw data

# file = sorted(dataPath.glob('*.edf'))[0]
if not files:
    # Fail with an actionable message instead of a bare IndexError from files[0]
    raise FileNotFoundError('No .edf files were found: `files` is empty, check the path and glob pattern')
file = files[0]

# Load the image with the loader; the bare name on the last line displays the result in the notebook
DA = loader.loadSingleImage(file)
DA

### Apply rotation corrections if necessary

In [None]:
# from scipy import ndimage

In [None]:
# # Interactively plot data of selected sample to identify point coordinates
# plt.close('all')
# DA = stitched_DAs[0].copy()

# cmin=DA.quantile(0.2)
# cmax=DA.quantile(0.99)
# ax = DA.plot.imshow(norm=plt.Normalize(cmin,cmax), origin='upper', cmap=cmap)
# ax.axes.set_title(f'{DA.sample_id}: {DA.exposure_time}')    
# plt.show()

In [None]:
# pivot_poni = (713, 1251)
# def rotateImage(img, angle, pivot):
#     padX = [img.shape[1] - pivot[0], pivot[0]]
#     padY = [img.shape[0] - pivot[1], pivot[1]]
#     imgP = np.pad(img, [padY, padX], 'constant')
#     imgR = ndimage.rotate(imgP, angle, reshape=False)
#     return imgR[padY[0] : -padY[1], padX[0] : -padX[1]]

In [None]:
# # Apply rotation corrections based on lines drawn between points

# rot_corr_DAs = []
# for DA in tqdm(stitched_DAs, desc='Rotating...'):  
#     # Get line points from dictionary
#     p1, p2 = bottom_line_points[f'{DA.sample_id}_{DA.exposure_time}']

#     # Calculate the angle from points
#     dx = p2[0] - p1[0]
#     dy = p2[1] - p1[1]
#     angle_radians = np.arctan2(dy, dx)
#     angle_degrees = np.degrees(angle_radians)

#     # Rotate image & save into list
#     rot_corr_DA = xr.apply_ufunc(rotateImage, DA, angle_degrees, pivot_poni)
#     # rot_corr_DA = xr.apply_ufunc(ndimage.rotate, DA, angle_degrees, (1, 0), False, None, 3, 'constant')
#     # rot_corr_DA = xr.apply_ufunc(ndimage.rotate, DA, 0, (1, 0), False)
#     rot_corr_DA.attrs = DA.attrs
#     rot_corr_DAs.append(rot_corr_DA)
    
# # Plot check
# for rot_corr_DA in rot_corr_DAs:
#     cmin = rot_corr_DA.quantile(0.2)
#     cmax = rot_corr_DA.quantile(0.99)
#     rot_corr_DA.plot.imshow(norm=plt.Normalize(cmin,cmax), origin='upper', cmap=cmap)
#     plt.show()
#     plt.close('all')

### pygix-backed reduction

In [None]:
# Scratch check: display a small array, then the element-wise average of it and a copy of itself
arr = np.array([[1, 2], [3, 4]])
arr2 = np.copy(arr)
display(arr)
averaged = (arr + arr2) / 2
display(averaged)

In [None]:
# Show the path of the image file selected above
file

In [None]:
# Open the image here rather than reading `fab_data`, which is only assigned in the next cell;
# running the notebook top to bottom would otherwise raise a NameError on this line.
fabio.open(file).data.shape

In [None]:
# Open the raw image with fabio and show its header
fab_data = fabio.open(file)
fab_data.header

In [None]:
# Build a boolean array by casting the image data itself to bool and print its shape.
# NOTE(review): this reads the data image (`file`), not a dedicated mask file; the Transform cell
# below only uses this array through np.zeros_like(mask_arr), i.e. for its shape and dtype.
mask_arr = fabio.open(file).data.astype('bool')
print(mask_arr.shape)

In [None]:
# Create the transformer from the poni file, a mask array and the X-ray energy.
# The active line passes an all-False array (np.zeros_like) as the mask — presumably "mask nothing"; confirm.
# NOTE(review): `Transform` is not imported or defined in any visible cell (only `phs` is imported);
# confirm where it comes from, otherwise this line raises a NameError.
# NOTE(review): energy=8.04 is presumably in keV — confirm against Transform's documentation.
# pg_transformer = Transform(poniFile, mask_arr[30:, :], energy=8.04)
pg_transformer = Transform(poniFile, np.zeros_like(mask_arr), energy=8.04)

In [None]:
# Convert one DataArray into its reciprocal-space and caked versions
# DA = rot_corr_DAs[0]
DA.attrs['incident_angle'] = 'th0.15'  # set on the raw image before pg_convert is called

recip_DA, caked_DA = pg_transformer.pg_convert(DA)

# # Alternative naming that also includes the anneal attribute:
# raw_DS = DA.to_dataset(name=f'{DA.material}_{DA.anneal}_{DA.solvent}')
# recip_DS = recip_DA.to_dataset(name=f'{DA.material}_{DA.anneal}_{DA.solvent}')
# caked_DS = caked_DA.to_dataset(name=f'{DA.material}_{DA.anneal}_{DA.solvent}')

# Wrap each DataArray in its own Dataset; the variable is named "<material>_<solvent>" from the raw DA's attributes
var_name = f'{DA.material}_{DA.solvent}'
raw_DS, recip_DS, caked_DS = (
    data_array.to_dataset(name=var_name) for data_array in (DA, recip_DA, caked_DA)
)

In [None]:
# # Select the first element of the sorted set outside of the for loop to initialize the xr.DataSet
# # DA = stitched_DAs[0]
# DA = rot_corr_DAs[0]
# recip_DA, caked_DA = pg_transformer.pg_convert(DA)

# # Save coordinates for interpolating other dataarrays 
# recip_coords = recip_DA.coords
# caked_coords = caked_DA.coords

# # Create a DataSet, each DataArray will be named according to it's scan id
# raw_DS = DA.to_dataset(name=f'{DA.sample_id}_{DA.exposure_time}')
# recip_DS = recip_DA.to_dataset(name=f'{DA.sample_id}_{DA.exposure_time}')
# caked_DS = caked_DA.to_dataset(name=f'{DA.sample_id}_{DA.exposure_time}')

# # Populate the DataSet with 
# # for DA in tqdm(stitched_DAs[1:], desc=f'Transforming Raw Data'):
# for DA in tqdm(rot_corr_DAs[1:], desc=f'Transforming Raw Data'):
#     recip_DA, caked_DA = pg_transformer.pg_convert(DA)

#     recip_DA = recip_DA.interp(recip_coords)
#     caked_DA = caked_DA.interp(caked_coords)    

#     raw_DS[f'{DA.sample_id}_{DA.exposure_time}'] = DA
#     recip_DS[f'{DA.sample_id}_{DA.exposure_time}'] = recip_DA    
#     caked_DS[f'{DA.sample_id}_{DA.exposure_time}'] = caked_DA

In [None]:
# IPython magic: switch matplotlib to the interactive "widget" backend for the following plots
%matplotlib widget

In [None]:
# Close every open matplotlib figure before plotting
plt.close('all')

In [None]:
# Show the output directory that the images and arrays below are written under
outPath

In [None]:
# q-range shown in the reciprocal-space check plots (inverse angstroms, matching the axis labels below)
qxy_min = -0.65
qxy_max = 2
qz_min = 0
qz_max = 2.2

# Create the image folder once, including any missing parent folders, instead of once per plot
savePath = outPath.joinpath('recip_images')
savePath.mkdir(parents=True, exist_ok=True)

# Plot check recip DAs
for DA in recip_DS.data_vars.values():
    sliced_DA = DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))

    # Color limits from the 10% and 99.3% quantiles of the sliced data
    # (cmin/cmax are reused by the .npy reload check later in the notebook)
    cmin = sliced_DA.quantile(0.1)
    cmax = sliced_DA.quantile(0.993)
    ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap, figsize=(5, 3.5))
    ax.axes.set_title(f'{sliced_DA.material} {sliced_DA.solvent}')
    # Raw strings: '\A' inside a normal string literal is an invalid escape sequence
    ax.axes.set(aspect='equal', xlabel=r'$Q_{xy}$ $[\AA^{-1}]$', ylabel=r'$Q_{z}$ $[\AA^{-1}]$')
    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)

    # Save (the file name is given without an extension, so the format is left to matplotlib)
    ax.figure.savefig(savePath.joinpath(f'{sliced_DA.material}_{sliced_DA.solvent}'), dpi=120)

    plt.show()
    plt.close('all')

In [None]:
# Crop the transformed image to the q-window defined for the check plots
sliced_recip_DA = recip_DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))

# Show the shapes of the intensity array and of its q_xy and q_z coordinate arrays, in that order
for exported in (sliced_recip_DA.data, sliced_recip_DA.q_xy.data, sliced_recip_DA.q_z.data):
    display(exported.shape)

In [None]:
# Export the cropped reciprocal-space image and its q axes as .npy files in a per-sample folder
npysPath = outPath.joinpath('npys')
savePath = npysPath.joinpath(f'{sample_name}')
# parents=True also creates 'npys' (and outPath) when they do not exist yet; without it,
# mkdir raises FileNotFoundError on a fresh output tree
savePath.mkdir(parents=True, exist_ok=True)

np.save(savePath.joinpath(f'data_{sample_name}.npy'), sliced_recip_DA.data)
np.save(savePath.joinpath(f'qxy_{sample_name}.npy'), sliced_recip_DA.q_xy.data)
np.save(savePath.joinpath(f'qz_{sample_name}.npy'), sliced_recip_DA.q_z.data)

In [None]:
# Close every open matplotlib figure
plt.close('all')

In [None]:
# Reload the exported intensity array and plot it as a round-trip check.
# cmin/cmax are whatever the last iteration of the recip plot-check loop left behind,
# so that loop must have run before this cell.
data = np.load(savePath.joinpath(f'data_{sample_name}.npy'))
plt.imshow(data, norm=plt.Normalize(cmin,cmax), origin='lower')
plt.show()

### Yoneda check:

In [None]:
def qz(wavelength, alpha_crit, alpha_incidents):
    """
    Returns q_z in inverse angstroms for the half-angle (alpha_incidents + alpha_crit) / 2.

    Inputs: wavelength in meters (the value is computed in 1/m, then divided by 1e10)
            alpha_crit and alpha_incidents as angles in degrees (scalars or numpy arrays)
    Outputs: (4 * pi / wavelength) * sin(half-angle), in inverse angstroms
    """
    half_angle = np.deg2rad((alpha_incidents + alpha_crit) / 2)
    # Alternative form kept for reference:
    # (4 * pi / wavelength) * (sin(deg2rad(alpha_crit)) + sin(deg2rad(alpha_incidents)))
    prefactor = (4 * np.pi) / (wavelength)
    return prefactor * np.sin(half_angle) / 1e10


# X-ray wavelength in meters (qz divides its 1/m result by 1e10 to get inverse angstroms)
# wavelength = 9.762535309700809e-11  # 12.7 keV
wavelength = 1.2398419843320025e-10  # 10 keV
# NOTE(review): the transformer earlier in the notebook is created with energy=8.04 —
# confirm which energy applies to this data set.

# Angles in degrees (qz applies np.deg2rad to them)
alpha_crit = 0.11  # organic film critical angle
alpha_incidents = np.array([0.14])
# NOTE(review): the raw image is tagged with incident_angle 'th0.15' earlier — confirm 0.14 is intended.

yoneda_angles = alpha_incidents + alpha_crit  # not used by the qz call below

# Show the q_z value(s) computed for the angles above
qz(wavelength, alpha_crit, alpha_incidents)

In [None]:
# %matplotlib widget

In [None]:
# ax = recip_DS['sam6_2s'].plot.imshow(norm=plt.Normalize(cmin,cmax), cmap=cmap)
# ax.axes.set_title(f'{sliced_DA.sample_id}: {sliced_DA.exposure_time}')   
# ax.axes.set(aspect='equal')
# plt.show()

In [None]:
# def select_attrs(data_arrays_iterable, selected_attrs_dict):
#     """
#     Selects data arrays whose attributes match the specified values.

#     Parameters:
#     data_arrays_iterable: Iterable of xarray.DataArray objects.
#     selected_attrs_dict: Dictionary where keys are attribute names and 
#                          values are the attributes' desired values.

#     Returns:
#     List of xarray.DataArray objects that match the specified attributes.
#     """    
#     sublist = list(data_arrays_iterable)
    
#     for attr_name, attr_values in selected_attrs_dict.items():
#         sublist = [da for da in sublist if da.attrs[attr_name] in attr_values]
                
#     return sublist

In [None]:
# # Yoneda peak linecut check
# qxy_min = 0.22
# qxy_max = 2
# qz_min = -0.02
# qz_max = 0.06

# selected_DAs = select_attrs(fixed_recip_DS.data_vars.values(), selected_attrs_dict)
# for DA in tqdm(selected_DAs):
#     # Slice data for selected q ranges (will need to rename q_xy if dimensions are differently named)
#     sliced_DA = DA.sel(q_xy=slice(qxy_min, qxy_max), q_z=slice(qz_min, qz_max))
#     qz_integrated_DA = sliced_DA.sum('q_xy')
    
#     # Plot
#     qz_integrated_DA.plot.line(label=DA.incident_angle)
    
# plt.legend()
# plt.grid(visible=True, which='major', axis='x')
# plt.show()

### Save zarrs

In [None]:
# Folder that holds the zarr stores (rebinds savePath, which earlier cells used for other folders)
savePath = outPath / 'zarrs'

In [None]:
# Write the reciprocal-space dataset to its zarr store with mode='w'
recip_DS.to_zarr(savePath / 'recip_stitched.zarr', mode='w')

In [None]:
# Write the raw dataset to its zarr store with mode='w'
raw_DS.to_zarr(savePath / 'raw_stitched.zarr', mode='w')

In [None]:
# Write the caked dataset to its zarr store with mode='w'
caked_DS.to_zarr(savePath / 'caked_stitched.zarr', mode='w')