# Processing notebook 

## Imports

In [None]:
import gc
import os
import pathlib

import fabio
import matplotlib.pyplot as plt
import numpy as np
import PyHyperScattering as phs
import xarray as xr
from matplotlib.colors import LogNorm
from PIL import Image
from smi_analysis import SMI_beamline
from tqdm.auto import tqdm

print(phs.__version__)

## Define paths & functions

In [None]:
# Maps a film sample number to its sample name. The films cell looks up the
# 'sample_number' field parsed from each file name in this dictionary.
sample_name_dict = {
    14: 'PM6_CB',
    17: 'PM6_1CN-CB',
    18: 'PM6_5CN-CB',
    21: 'PM6_p5CN-CB',
    22: 'PM6-Y6_CB',
    23: 'PM6-Y6BO_CB',
    26: 'PM6_CF',
    29: 'PM6_1CN-CF',
    30: 'PM6_5CN-CF',
    33: 'PM6_p5CN-CF',
    34: 'PM6-Y6_CF',
    35: 'PM6-Y6BO_CF',
    1: 'BareSiN_01',
    3: 'BareSiN_03'
}

In [None]:
# Root folder of the proposal; every raw-data path below is built from it.
propPath = pathlib.Path('/nsls2/data/smi/proposals/2024-1/pass-314903')

# raw_01 ... raw_04 folders
rawPaths1, rawPaths2, rawPaths3, rawPaths4 = (
    propPath / f'raw_0{run_number}' for run_number in range(1, 5)
)

# '1M' detector subfolder of each raw folder
saxsPath1, saxsPath2, saxsPath3, saxsPath4 = (
    raw_folder / '1M' for raw_folder in (rawPaths1, rawPaths2, rawPaths3, rawPaths4)
)

# '900KW' detector subfolder of each raw folder
waxsPath1, waxsPath2, waxsPath3, waxsPath4 = (
    raw_folder / '900KW' for raw_folder in (rawPaths1, rawPaths2, rawPaths3, rawPaths4)
)

# analysisPath = pathlib.Path('/nsls2/users/alevin/rsoxs_suite/sst1_notebooks/SMI_tender_scattering/analysis_02')
# reducedPath = analysisPath.joinpath('reduced_waxs')

In [None]:
# solnPaths = [saxsPath1, saxsPath3]
# filmPaths = [saxsPath2, saxsPath4]
# # for saxsPath in solnPaths:
# #     all_saxs = set(saxsPath.glob('*.tif'))
# #     test_saxs = set(saxsPath.glob('test*'))
# #     display(sorted(set([f.name[:f.name.find('_sdd1.8')] for f in all_saxs.difference(test_saxs)])))

### SMI loading function

In [None]:
# Reference images from the '1M' folder of another proposal, used to build
# `idx_mask` (extra masked pixels applied inside saxs_SMI_numpy_loading).
path = '/nsls2/data/smi/legacy/results/data/%s/%s/1M/'%('2024_1', '314483_Freychet_12')

# List the directory once; the listing is the same for every loop below.
ref_files = sorted(os.listdir(path))

# Unique file-name prefixes (everything before 'pos') of the selected scans.
sam, sam1 = [], []
for file in ref_files:
    if 'wa0' in file and 'ai1.60' in file and '2800.00eV' in file:
        idx = file.find('pos')
        if file[:idx] not in sam:
            sam = sam + [file[:idx]]

# One independent list per prefix (no aliasing between the inner lists).
all_dat = [[] for _ in sam]
all_da = [[] for _ in sam]

for j, sa in enumerate(sam):
    for file in ref_files:
        if sa in file and 'tif' in file and 'ai1.60' in file:
            all_dat[j] = all_dat[j] + [file]

# Fail with a clear message instead of an IndexError/NameError further down.
if not all_dat or not all_dat[0]:
    raise FileNotFoundError(f'No reference tiffs matching the filters were found in {path}')

# Sum every image belonging to the first prefix.
for i, all_d in enumerate(all_dat[0]):
    img=fabio.open(os.path.join(path, all_d)).data
    if i==0:
        img_sum=np.zeros(np.shape(img))
    img_sum += img

# Indices of all pixels whose summed count is positive.
idx_mask = np.where(img_sum>0)

In [None]:
def saxs_SMI_numpy_loading(path, filename_list, energy=2.450, sdd=1800,
                           center=(354, 560), bs_pos=((354, 535),), mask_idx=None):
    """
    Reduce raw SAXS tiffs (Pilatus 1M detector) one file at a time with smi_analysis.

    Adapted from Guillaume's SMI notebooks. Every file is loaded into its own
    ``SMI_geometry`` object, masked, converted with ``stitching_data`` and
    caked with ``caking``.

    Parameters
    ----------
    path : str or path-like
        Folder containing the tiffs.
    filename_list : list of str
        File names (not full paths) inside ``path``.
    energy : float, optional
        Energy in keV used for the wavelength (``12.398 / energy`` in Angstrom).
        The same value is applied to every file; the energy encoded in the
        file name is not parsed.
    sdd : float, optional
        Sample-detector distance passed to ``SMI_geometry``
        (presumably mm -- TODO confirm).
    center : sequence of two numbers, optional
        Beam center passed to ``SMI_geometry``.
    bs_pos : sequence of two-number sequences, optional
        Beamstop position(s) passed to ``SMI_geometry``.
    mask_idx : index array or tuple of index arrays, optional
        Extra pixels to set to True in the detector mask. Defaults to the
        notebook-level ``idx_mask`` computed in the preceding cell.

    Returns
    -------
    filename_wa0_list, recip_list, recip_extents, caked_list, caked_extents
        Parallel lists with one entry per file: the file name, ``img_st``,
        ``[qp[0], qp[-1], qz[0], qz[-1]]``, ``cake`` and
        ``[q_cake[0], q_cake[-1], chi_cake[0], chi_cake[-1]]``.
    """
    geometry = 'Transmission'
    bs_kind = 'pindiode'
    detector = 'Pilatus1m'
    wav = 1E-10 * (12.398/energy)

    # Fall back to the notebook-level mask only when none was passed in.
    if mask_idx is None:
        mask_idx = idx_mask

    filename_wa0_list = []
    recip_list = []
    recip_extents = []
    caked_list = []
    caked_extents = []
    for dat in tqdm(filename_list, desc='Processing tiffs'):
        # A fresh geometry object per file; fresh lists are passed so one
        # call cannot affect the next through shared arguments.
        smi_geo = SMI_beamline.SMI_geometry(geometry = geometry,
                                            detector = detector,
                                            sdd = sdd,
                                            wav = wav,
                                            alphai = 0,
                                            center = list(center),
                                            bs_pos = [list(pos) for pos in bs_pos],
                                            det_angles = [0],
                                            bs_kind = bs_kind)

        smi_geo.open_data(path, [dat], optional_mask='tender')
        smi_geo.masks[0][560:, 337:350]=True

        # NOTE(review): the same file that open_data just loaded is read again
        # and added on top of imgs[0]. This looks like a leftover from code that
        # summed several files per position -- confirm whether the extra add is
        # intended. Kept as-is so existing results do not change.
        smi_geo.imgs[0] += fabio.open(os.path.join(path, dat)).data

        # Extra masked pixels plus hand-picked rectangular regions.
        smi_geo.masks[0][mask_idx]=True
        smi_geo.masks[0][570:, 337:342]=True

        smi_geo.masks[0][835:, 488:616]=True
        smi_geo.masks[0][372:414, 857:920]=True
        smi_geo.masks[0][370:410, 560:600]=True

        smi_geo.stitching_data(interp_factor=3, flag_scale=False)
        smi_geo.caking()

        filename_wa0_list.append(dat)
        recip_list.append(smi_geo.img_st)
        recip_extents.append([smi_geo.qp[0], smi_geo.qp[-1], smi_geo.qz[0], smi_geo.qz[-1]])

        caked_list.append(smi_geo.cake)
        caked_extents.append([smi_geo.q_cake[0], smi_geo.q_cake[-1], smi_geo.chi_cake[0], smi_geo.chi_cake[-1]])

    return filename_wa0_list, recip_list, recip_extents, caked_list, caked_extents

## Load data

In [None]:
# import warnings
# warnings.filterwarnings('ignore')

### Films

In [None]:
# Split every file name in the films folder on '_' once, then collect the
# distinct values of the third field (sample number) and fourth field (rotation).
film_name_fields = [entry.name.split('_') for entry in saxsPath4.glob('*')]
unique_sample_numbers = sorted({fields[2] for fields in film_name_fields})
unique_sample_rotations = sorted({fields[3] for fields in film_name_fields})

In [None]:
# All tiffs in the films folder (the 'test*' filter is not applied here).
all_saxs = {tif_path for tif_path in saxsPath4.glob('*.tif')}
# Label = file name from its 4th character up to '_sdd1.8'.
sample_names = sorted({tif_path.name[3:tif_path.name.find('_sdd1.8')] for tif_path in all_saxs})
sample_names

In [None]:
# filename_sublists

In [None]:
# import warnings
# warnings.filterwarnings('ignore')

In [None]:
# Field names, in order, of the '_'-separated film file names
waxs_naming_scheme = ['project', 'sample_type', 'sample_number', 'rotation_from_normal', 'set_sdd', 'energy', 'waxs_det_position',
                      'bpm', 'id', 'misc', 'detector']
md_naming_scheme = waxs_naming_scheme.copy()

# Save sample zarr, load later to concatenate full zarr
sampleZarrsPath = propPath.joinpath('processed_data/zarrs/saxs_core_films_trexs_sample_zarrs_v2')


def _film_image_to_DA(image, extent, dims, coord_names, attrs):
    """Wrap one reduced film image in a DataArray ready to concatenate along energy.

    image       : 2D array (rows, columns)
    extent      : [x_first, x_last, y_first, y_last] from saxs_SMI_numpy_loading
    dims        : (row_dim, column_dim) names of the integer pixel/index dimensions
    coord_names : (column_coord, row_coord) names of the physical coordinates
    attrs       : metadata parsed from the file name; must contain 'energy',
                  'sample_number' and 'rotation_from_normal'
    """
    row_dim, col_dim = dims
    col_coord, row_coord = coord_names
    DA = xr.DataArray(data=image, dims=[row_dim, col_dim], attrs=attrs)
    DA = DA.assign_coords({
        col_dim: DA[col_dim].data,
        row_dim: DA[row_dim].data,
        col_coord: (col_dim, np.linspace(extent[0], extent[1], DA.sizes[col_dim])),
        # rows run from the last y extent value to the first
        row_coord: (row_dim, np.linspace(extent[3], extent[2], DA.sizes[row_dim])),
    })
    # Single-valued dimensions, concatenated in later steps
    return DA.expand_dims({
        'energy': [float(attrs['energy'][:-2])],  # strip the 2-character unit suffix
        'sample_name': [sample_name_dict[int(float(attrs['sample_number']))]],
        # strip 3 leading and 3 trailing characters around the number
        'theta': [90 - float(attrs['rotation_from_normal'][3:-3])],
    })


for sample_number in tqdm(unique_sample_numbers[1:], desc='Samples'):
    for sample_rotation in tqdm(unique_sample_rotations, desc='Angles'):

        # Select files for a given sample and rotation. The two fields are
        # adjacent in the file name, so match them exactly: a loose
        # '*Trmsn_1*' pattern would also pick up samples 14, 17, 18, ...
        file_pattern = f'*Trmsn_{sample_number}_{sample_rotation}_*'
        all_paths = set(saxsPath4.glob(file_pattern))
        dmg_paths = set(saxsPath4.glob(file_pattern + 'damage*'))

        # For now only select first scans, not the damage test repeats
        filename_list = [f.name for f in sorted(all_paths.difference(dmg_paths))]
        if not filename_list:
            # Nothing measured for this sample/rotation pair; xr.concat would
            # raise on an empty list.
            continue

        # Run SMI loading code (this produces some fabio and divide by zero errors)
        names_list, recip_list, recip_extents, caked_list, caked_extents = saxs_SMI_numpy_loading(saxsPath4, filename_list)

        # Construct xarrays with full values along detector dimensions and the energy dimension
        recip_DA_rows = []
        caked_DA_rows = []
        zipped_lists = zip(names_list, recip_list, recip_extents, caked_list, caked_extents)
        for filename, recip_arr, recip_extent, caked_arr, caked_extent in zipped_lists:
            md_list = filename.split('_')
            attr_dict = {md_item: md_list[i] for i, md_item in enumerate(md_naming_scheme)}

            recip_DA_rows.append(_film_image_to_DA(
                recip_arr, recip_extent, ('pix_y', 'pix_x'), ('q_x', 'q_y'), attr_dict))
            caked_DA_rows.append(_film_image_to_DA(
                caked_arr, caked_extent, ('index_y', 'index_x'), ('q_r', 'chi'), attr_dict))

        recip_DA = xr.concat(recip_DA_rows, 'energy')
        caked_DA = xr.concat(caked_DA_rows, 'energy')

        # One zarr store per sample and theta for each of recip / caked
        recip_samp_zarr_name = 'recip_'+recip_DA.sample_name.values[0]+'_'+str(int(recip_DA.theta.values[0]))+'deg.zarr'
        recip_DS = recip_DA.to_dataset(name='flatfield_corr')
        recip_DS.to_zarr(sampleZarrsPath.joinpath(recip_samp_zarr_name), mode='w')

        caked_samp_zarr_name = 'caked_'+caked_DA.sample_name.values[0]+'_'+str(int(caked_DA.theta.values[0]))+'deg.zarr'
        caked_DS = caked_DA.to_dataset(name='flatfield_corr')
        caked_DS.to_zarr(sampleZarrsPath.joinpath(caked_samp_zarr_name), mode='w')

        # Release the per-sample arrays before the next iteration
        gc.collect()

In [None]:
recip_DA

### Solutions

In [None]:
# Tiffs in the solutions folder, and every entry whose name starts with 'test'.
all_saxs = {tif_path for tif_path in saxsPath1.glob('*.tif')}
test_saxs = {test_path for test_path in saxsPath1.glob('test*')}
# Label = file name from its 4th character up to '_sdd1.8', test entries excluded.
sample_names = sorted({tif_path.name[3:tif_path.name.find('_sdd1.8')] for tif_path in all_saxs - test_saxs})
sample_names

In [None]:
# One list of files per sample. Lists that do not hold exactly 63 files are
# reduced to the first file found for each energy string.
saxs_63 = []

print('Before:')
for sample_name in sample_names[:-1]:
    candidates = sorted(saxsPath1.glob(f'CM_{sample_name}_sdd*'))
    print(len(candidates))

    if len(candidates) != 63:
        seen_energies = set()
        first_per_energy = []
        for tif_path in candidates:
            # Energy string = last '_'-separated piece before 'eV'
            energy_tag = tif_path.name[:tif_path.name.find('eV')].split('_')[-1]
            if energy_tag not in seen_energies:
                seen_energies.add(energy_tag)
                first_per_energy.append(tif_path)
        candidates = first_per_energy

    saxs_63.append(candidates)

print('After (in saxs_63):')
for kept_for_sample in saxs_63:
    print(len(kept_for_sample))

In [None]:
# Rebuild sample_names from the first file of every list in saxs_63, so the
# names line up index-for-index with saxs_63.
sample_names = [
    sample_files[0].name[3:sample_files[0].name.find('_sdd1.8')]
    for sample_files in saxs_63
]

display(sample_names)

In [None]:
# Make file sets & define unique sample names
# For every solution sample: reduce its files, build recip/caked DataArrays
# concatenated along energy, and write one zarr store per sample.
for i, sample_name in enumerate(tqdm(sample_names, desc='Samples')):
    # Select files for a given sample and rotation
    # all_paths = set(saxsPath.glob(f'*_{sample_name}_*'))
    all_paths = saxs_63[i]
    
    # Select first scans
    filename_list = [f.name for f in sorted(all_paths)]

    # # Make sublists to stitch two waxs positions together
    # group_size = 2
    # filename_sublists = [filename_list[i:i + group_size] for i in range(0, len(filename_list), group_size)]

    # Run SMI loading code (this produces some fabio and divide by zero errors)
    names_list, recip_list, recip_extents, caked_list, caked_extents = saxs_SMI_numpy_loading(
        saxsPath1, filename_list)

    # # Define naming scheme:
    # saxs_naming_scheme = ['project', 'sample_type', 'sample_number', 'rotation_from_normal', 'set_sdd', 'energy', 'waxs_det_position',
    #                       'bpm', 'id', 'misc', 'detector']
    # md_naming_scheme = saxs_naming_scheme.copy()


    # Construct xarrays with full values along detector dimensions and the energy dimension
    # They contain sample name as well, as a single value to be concatenated in later steps
    recip_DA_rows = []
    caked_DA_rows = []
    zipped_lists = zip(names_list, recip_list, recip_extents, caked_list, caked_extents)
    for filename, recip_arr, recip_extent, caked_arr, caked_extent in zipped_lists:
        # print(filename)
        # print(recip_arr.shape)
        # print(recip_extent)
        # print(caked_arr.shape)
        # print(caked_extent)

        attr_dict = {}
        attr_dict['filename'] = filename
        # NOTE: rebinds the outer loop variable `sample_name`; the value left
        # after this inner loop is what names the zarr stores below.
        sample_name = filename[3:filename.find('_sdd')]
        # Energy = last '_'-separated piece before 'eV' (also rebinds the
        # notebook-level name `energy`).
        energy = float(filename[:filename.find('eV')].split('_')[-1])
        # md_list = filename.split('_')
        # for i, md_item in enumerate(md_naming_scheme):
        #     attr_dict[md_item] = md_list[i]

        recip_DA = xr.DataArray(data = recip_arr, 
                                dims = ['pix_y', 'pix_x'],
                                attrs = attr_dict)
        # q_x spans extent[0]..extent[1]; q_y runs from extent[3] down to extent[2]
        recip_DA = recip_DA.assign_coords({
            'pix_x': recip_DA.pix_x.data,
            'pix_y': recip_DA.pix_y.data,
            'q_x': ('pix_x', np.linspace(recip_extent[0], recip_extent[1], len(recip_DA.pix_x.data))),
            'q_y': ('pix_y', np.linspace(recip_extent[3], recip_extent[2], len(recip_DA.pix_y.data)))
        })
        recip_DA = recip_DA.expand_dims({
            'energy': [energy],
            'sample_name': [sample_name]
        })
        recip_DA_rows.append(recip_DA)

        caked_DA = xr.DataArray(data = caked_arr, 
                                dims = ['index_y', 'index_x'],
                                attrs = attr_dict)
        # q_r spans extent[0]..extent[1]; chi runs from extent[3] down to extent[2]
        caked_DA = caked_DA.assign_coords({
            'index_x': caked_DA.index_x.data,
            'index_y': caked_DA.index_y.data,
            'q_r': ('index_x', np.linspace(caked_extent[0], caked_extent[1], len(caked_DA.index_x.data))),
            'chi': ('index_y', np.linspace(caked_extent[3], caked_extent[2], len(caked_DA.index_y.data)))
        }) 
        caked_DA = caked_DA.expand_dims({
            'energy': [energy],
            'sample_name': [sample_name]
        })
        caked_DA_rows.append(caked_DA)

    recip_DA = xr.concat(recip_DA_rows, 'energy')
    caked_DA = xr.concat(caked_DA_rows, 'energy')

    # Save sample zarr, load later to concatenate full zarr
    # NOTE(review): unlike the films cell, this folder is not under
    # 'processed_data/zarrs/' -- confirm that is intended.
    sampleZarrsPath = propPath.joinpath('processed_data/saxs_solution_trexs_sample_zarrs_v2')

    recip_samp_zarr_name = 'recip_'+sample_name+'.zarr'
    recip_DS = recip_DA.to_dataset(name='flatfield_corr')
    recip_DS.to_zarr(sampleZarrsPath.joinpath(recip_samp_zarr_name), mode='w')

    caked_samp_zarr_name = 'caked_'+sample_name+'.zarr'
    caked_DS = caked_DA.to_dataset(name='flatfield_corr')
    caked_DS.to_zarr(sampleZarrsPath.joinpath(caked_samp_zarr_name), mode='w')

## Draw masks & check data and beamcenters

In [None]:
# Define paths (propPath is re-defined so this section can start from here)
propPath = pathlib.Path('/nsls2/data/smi/proposals/2024-1/pass-314903')
# Folder the drawn mask is saved into further down
outPath = propPath.joinpath('processed_data/trexs_plots')
# NOTE(review): the films cell above writes to
# '.../saxs_core_films_trexs_sample_zarrs_v2'; this reads the folder without
# the '_v2' suffix -- confirm which version is meant.
sampleZarrsPath = propPath.joinpath('processed_data/zarrs/saxs_core_films_trexs_sample_zarrs')

# rawPaths = propPath.joinpath('raw_04')
# waxsPath = rawPaths.joinpath('900KW')

In [None]:
sampleZarrsPath.exists()

In [None]:
# Sample name = 2nd and 3rd '_'-separated fields of each store name
# (the 1st field is the 'recip'/'caked' prefix written by the films cell).
unique_sample_names = sorted(set(['_'.join(f.name.split('_')[1:3]) for f in sampleZarrsPath.glob('*')]))
unique_sample_names

In [None]:
# One dataset per sample (its stores concatenated along theta), collected for
# the final concatenation along sample_name. Only 'recip_' stores are opened;
# 'caked_' stores are not loaded in this cell.
recip_DS_rows = []
for sample_name in tqdm(unique_sample_names):
    recip_stores = [
        store for store in sorted(sampleZarrsPath.glob(f'*{sample_name}*'))
        if 'recip_' in store.name
    ]
    per_theta_datasets = [xr.open_zarr(store) for store in recip_stores]
    recip_DS_rows.append(xr.concat(per_theta_datasets, 'theta'))

recip_DS = xr.concat(recip_DS_rows, 'sample_name')

In [None]:
# Rechunk: one sample per chunk, with the listed chunk sizes along pix_y,
# pix_x and energy.
recip_DS = recip_DS.chunk({'sample_name':1, 'pix_y': 3129, 'pix_x': 2943, 'energy':63,})
# caked_DS = caked_DS.chunk({'sample_name':1, 'index_y':500,'index_x':500,'energy':63})

In [None]:
# recip_DS = recip_DS.swap_dims({'pix_y':'q_y', 'pix_x':'q_x'})
# caked_DS = caked_DS.swap_dims({'index_y':'chi', 'index_x':'q_r'})
recip_DS

In [None]:
def make_para_perp_DAs(DS, sample_name, theta=90, chi_width=90):
    """Return the (para, perp) chi slices of one sample's 'flatfield_corr' data.

    The sample is selected by ``sample_name`` and ``theta``; the para slice is
    centered at chi = 0 and the perp slice at chi = -90, each taken through the
    ``rsoxs.slice_chi`` accessor with ``chi_width / 2`` as its width argument.

    Note: this name is defined again later in the notebook with a different
    signature; the later definition replaces this one once its cell has run.
    """
    half_width = chi_width / 2
    selected = DS.sel(sample_name=sample_name)['flatfield_corr'].sel(theta=theta)
    return tuple(
        selected.rsoxs.slice_chi(chi_center, chi_width=half_width)
        for chi_center in (0, -90)
    )

In [None]:
# Set colormap: work on a copy of turbo and give 'bad' (masked/invalid) values
# the same color as under-range values.
cmap = plt.cm.turbo.copy()
cmap.set_bad(cmap.get_under())

# # Choose a sample dataarray:
# bare_sin_DA = DS.sel(sample_name='BareSiN_1mm')
# print(DS.sample_name.values)
# sample_name = 'Y6_CB_2500'

### 1. Check raw images at a selected energy for all loaded scan configurations:

In [None]:
# DS = recip_DS

# sample_name = 'PM6_5CN-CB'
# sample_DA = DS['raw_intensity'].sel(sample_name=sample_name)
# # sample_DA = loaded_DS['raw_intensity'].sel(sample_name=sample_name)


# fg = sample_DA.sel(polarization=pol, method='nearest').sel(energy=energies, method='nearest').sel(
#             pix_x=slice(160, 780), pix_y=slice(240, 800)).plot.imshow(figsize=(18, 6),
#                 col='energy', col_wrap=4, norm=LogNorm(3e1, 1e4), cmap=cmap, x='qx', y='qy')
# fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
# fg.fig.suptitle(f'{str(sample_DA.sample_name.data)},  Polarization = {pol}°', y=1.02)
# for axes in fg.axs.flatten():
#     axes.set(aspect='equal')

# plt.show()

In [None]:
# sample_name = 'PM6-Y6BO_CF'
# corr_sample_DA = DS['corr_intensity'].sel(sample_name=sample_name)

# # energies = [270, 280, 282, 283, 284, 285, 286, 290]
# energies = np.round(np.linspace(280, 290, 8), 1)  # carbon
# # energies = np.round(np.linspace(380, 440, 8), 1)  # nitrogen

# pol = 0

# fg = corr_sample_DA.sel(polarization=pol, method='nearest').sel(energy=energies, method='nearest').sel(
#             pix_x=slice(160, 780), pix_y=slice(240, 800)).plot.imshow(figsize=(18, 6), x='qx', y='qy',
#                 col='energy', col_wrap=4, norm=LogNorm(5e8, 1e11), cmap=cmap)
# fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
# fg.fig.suptitle(f'{str(sample_name)},  Polarization = {pol}°', y=1.02)
# for axes in fg.axs.flatten():
#     axes.set(aspect='equal')

# plt.show()

### 2. Draw masks

In [None]:
# Select the sample whose image is used for drawing the mask
# (the DrawMask object itself is created in a later cell):
sample_name = 'BareSiN_01'
sample_DA = recip_DS['flatfield_corr'].sel(sample_name=sample_name)

In [None]:
sample_DA

In [None]:
## Mask drawing
# NOTE(review): variable names say 'waxs', but recip_DS was loaded from the
# saxs film zarrs -- confirm and rename if this is the SAXS mask.
# Nearest available image to theta = 90, energy = 2550, loaded into memory.
waxs_mask_img = sample_DA.sel(theta=90, energy=2550, method='nearest').compute()
# draw = phs.IntegrationUtils.DrawMask(waxs_mask_img, clim=(30, 1e3))
draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)

# Show the interactive drawing UI
draw.ui()

In [None]:
outPath

In [None]:
## Save the drawn saxs mask (this cell only saves; nothing is loaded here)
draw.save(outPath.joinpath('SMI_saxs_mask_v1_test.json'))

### 3. Check beamcenters

In [None]:
# Select the sample used for the beamcenter check.
# NOTE(review): `DS` is not assigned anywhere in the visible part of this
# notebook (the only `DS = ...` lines are commented out), and the datasets
# built above hold 'flatfield_corr' rather than 'raw_intensity'. This cell
# looks carried over from another workflow -- confirm before running.
sample_name = 'PM6-Y6BO_CF'
sample_DA = DS['raw_intensity'].sel(sample_name=sample_name)

In [None]:
# sample_DA.attrs['beamcenter_x'] = 450
# sample_DA.attrs['beamcenter_y'] = 510

# Energy of the image used for the check (nearest available value is selected)
energy = 250
# energy = 400
# energy = 532

waxs_mask_img = sample_DA.sel(polarization=0, energy=energy, method='nearest').compute()
draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)

# Load masks:
# NOTE(review): `maskPath` is not defined in the visible part of this notebook.
draw.load(maskPath.joinpath('2023C3_full_length_masks.json'))
waxs_mask = draw.mask

# Check masks: image on a log color scale with the mask overlaid semi-transparently
ax = waxs_mask_img.plot.imshow(norm=LogNorm(3e1, 1e4), cmap=cmap)
ax.axes.imshow(waxs_mask, alpha=0.5, origin='lower')
# ax.axes.imshow(WAXSinteg.mask, alpha=0.5, origin='lower')
plt.show()

In [None]:
# Initialize PFEnergySeriesIntegrator object & check beamcenter & masks
# WAXS
# The integrator takes its geometry from the polarization = 0 slice of sample_DA.
WAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr=sample_DA.sel(polarization=0))
WAXSinteg.mask = waxs_mask
# Beamcenter values are read from the check image
WAXSinteg.ni_beamcenter_x = waxs_mask_img.beamcenter_x
WAXSinteg.ni_beamcenter_y = waxs_mask_img.beamcenter_y
print('WAXS Beamcenter: \n'
      f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
      f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')

# Plot check, zoomed to +/- 250 around the beamcenter
phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=1e4, alpha=0.5)
plt.xlim(WAXSinteg.ni_beamcenter_x-250, WAXSinteg.ni_beamcenter_x+250)
plt.ylim(WAXSinteg.ni_beamcenter_y-250, WAXSinteg.ni_beamcenter_y+250)
plt.gcf().set(dpi=120)
plt.show()

In [None]:
# ## Tweaking if needed:

# ## WAXS Tweaking & Plot Check
# waxs_new_bcx = 396.3
# waxs_new_bcy = 553
# WAXSinteg.ni_beamcenter_x = waxs_new_bcx
# WAXSinteg.ni_beamcenter_y = waxs_new_bcy
# raw_waxs.attrs['beamcenter_x'] = waxs_new_bcx
# raw_waxs.attrs['beamcenter_x'] = waxs_new_bcx
# raw_waxs.attrs['poni1'] = WAXSinteg.poni1
# raw_waxs.attrs['poni2'] = WAXSinteg.poni2

# print('WAXS Beamcenter Tweaking: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=5e3, alpha=0.6, guide1=40)
# plt.xlim(WAXSinteg.ni_beamcenter_x-200, WAXSinteg.ni_beamcenter_x+200)
# plt.ylim(WAXSinteg.ni_beamcenter_y-200, WAXSinteg.ni_beamcenter_y+200)
# plt.gcf().set(dpi=120)
# plt.show()


## Using Pete D.'s (very slightly modified) beamcentering script:
# phs.BeamCentering.CenteringAccessor.refine_geometry

# ## WAXS
# # res_waxs = raw_waxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.02, q_max=0.06, chi_min=-10, chi_max=70)
# res_waxs = sample_DA.sel(polarization=0).util.refine_geometry(energy=270, q_min=0.02, q_max=0.06)
# sample_DA.attrs['poni1'] = res_waxs.x[0]
# sample_DA.attrs['poni2'] = res_waxs.x[1]
# WAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = sample_DA.sel(polarization=0))
# WAXSinteg.mask = waxs_mask

# ## WAXS Plot check
# print('WAXS Beamcenter Post-optimization: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=1e5, alpha=0.4)
# plt.xlim(WAXSinteg.ni_beamcenter_x-100, WAXSinteg.ni_beamcenter_x+100)
# plt.ylim(WAXSinteg.ni_beamcenter_y-100, WAXSinteg.ni_beamcenter_y+100)
# plt.gcf().set(dpi=120)
# plt.show()

In [None]:
# ### Write beamcenters to saved .json file if content with them:

# beamcenters_dict = {
#     f'WAXS_2023C2': {'bcx':sample_DA.beamcenter_x, 'bcy':sample_DA.beamcenter_y}
# }

# # Check if the file exists, if not, create an empty JSON file
# jsonFile = jsonPath.joinpath('beamcenters_dict.json')
# if not jsonFile.exists():
#     with jsonFile.open('w') as f:
#         json.dump({}, f)

# # Now, read the existing or empty JSON file
# with jsonFile.open('r') as f:
#     dic = json.load(f)

# dic.update(beamcenters_dict)

# # Write the updated dictionary back to the JSON file
# with jsonFile.open('w') as f:
#     json.dump(dic, f)

In [None]:
# # Make mask DataArray:
# mask_DA = xr.DataArray(data=waxs_mask, dims=['pix_y', 'pix_x'])

# # Create Dataset of rsoxs_carbon and add the mask as a data variable
# DS = DA.to_dataset()
# DS['mask'] = mask_DA
# display(DS)

## Convert to chi-q space & save zarrs

In [None]:
# Integrate whole cartesian dataset!
# For every sample: integrate the 'corr_intensity' stack at polarization 0 and
# 90 with WAXSinteg, join the two along 'polarization', then join all samples
# along 'sample_name'.
# NOTE(review): `DS` is not assigned in the visible part of this notebook.
polar_DS_sample_rows = []
for sample_name in tqdm(DS.sample_name.data):
# for sample_name in tqdm(['BareSiN', 'A3_3000_dSiN_01', 'BareAlO', 'PM6_3000_dSiN', 'PM6-Y7_3000_dSiN']):
    polar_DS = xr.Dataset()
    for intensity in ['corr_intensity']:
        polar_DA_polarization_rows = []
        for pol in [0, 90]:
            cart_DA = DS[intensity].sel(polarization=pol, sample_name=sample_name)
            polar_DA = WAXSinteg.integrateImageStack_dask(cart_DA)
            # polar_DA = WAXSinteg.integrateImageStack(cart_DA)
            polar_DA = polar_DA.expand_dims({'polarization': [pol]})
            polar_DA_polarization_rows.append(polar_DA)
        
        polar_DS[intensity] = xr.concat(polar_DA_polarization_rows, dim='polarization')

    polar_DS = polar_DS.expand_dims({'sample_name':[sample_name]})
    polar_DS_sample_rows.append(polar_DS)
    # Set after the append, but on the same object that was appended
    polar_DS.attrs['name'] = DS.name
    
    # polar_DS.to_netcdf(zarrsPath.joinpath('polar_rsoxs_carbon_ncs', f'{sample_name}.nc'), format='netCDF4', engine='h5netcdf')
    
polar_DS = xr.concat(polar_DS_sample_rows, dim='sample_name')

In [None]:
def make_para_perp_DAs(DS, sample_name, intensity_type, pol, qlims, chi_width):
    """Return the (para, perp) chi slices of one sample for a given polarization.

    Parameters
    ----------
    DS : dataset indexable by ``sample_name``, holding ``'<intensity_type>_intensity'``
    sample_name : value selected along ``sample_name``
    intensity_type : str
        Prefix of the data variable, e.g. ``'corr'`` selects ``'corr_intensity'``.
    pol : {0, 90}
        Polarization to select. It also decides which chi center is para and
        which is perp: para is at 180 and perp at 90 for ``pol == 0``, and the
        two are swapped for ``pol == 90``.
    qlims : (q_min, q_max)
        Bounds of the ``q`` slice.
    chi_width : number
        ``chi_width / 2`` is passed as the width argument of ``rsoxs.slice_chi``.

    Returns
    -------
    (para_DA, perp_DA)

    Raises
    ------
    ValueError
        If ``pol`` is neither 0 nor 90 (previously this ended in an
        UnboundLocalError at the return statement).

    Note: this name is defined more than once in the notebook; whichever cell
    ran last wins.
    """
    # Validate first so a bad value fails clearly and before any data access.
    if pol not in (0, 90):
        raise ValueError(f'pol must be 0 or 90, got {pol!r}')

    # select dataarray to plot
    DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity']
    sliced_DA = DA.sel(polarization=pol, q=slice(qlims[0],qlims[1]))

    # calculate ISI dataarrays
    para_chi, perp_chi = (180, 90) if pol == 0 else (90, 180)
    para_DA = sliced_DA.rsoxs.slice_chi(para_chi, chi_width=(chi_width/2))
    perp_DA = sliced_DA.rsoxs.slice_chi(perp_chi, chi_width=(chi_width/2))

    return para_DA, perp_DA


# # make selection
# sample_name = 'BareSiN'
# edge = 'carbon'
# intensity_type = 'corr'
# pol = 0
# qlims = (0.01, 0.08)
# chi_width = 30

# para_DA, perp_DA = make_para_perp_DAs(polar_DS, sample_name, intensity_type, pol, qlims, chi_width)  

# # slice ISI data
# para_ISI = para_DA.interpolate_na(dim='q').mean('chi').sum('q')
# perp_ISI = perp_DA.interpolate_na(dim='q').mean('chi').sum('q')

# # plot
# fig, ax = plt.subplots()
# para_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='para', yscale='log')
# perp_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='perp', yscale='log')
# fig.suptitle('Integrated Scattering Intensity (ISI)', fontsize=14)
# ax.set(title=f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', xlabel='Photon Energy [eV]', ylabel='Double-Norm-Corrected Intensity [arb. units]')
# ax.legend()
# plt.show()

In [None]:
# # polar_sample_DS = polar_DS_sample_rows[0]
# for polar_sample_DS in tqdm(polar_DS_sample_rows):
#     # display(polar_sample_DS)
#     sample_name = polar_sample_DS.sample_name.values[0]
#     print(sample_name)
#     polar_sample_DS.to_netcdf(zarrsPath.joinpath('polar_rsoxs_carbon_ncs', f'{sample_name}.nc'), format='netCDF4', engine='h5netcdf')

In [None]:
polar_DS

In [None]:
# Rechunk: one sample per chunk, with the listed chunk sizes along energy and polarization
polar_DS = polar_DS.chunk({'sample_name':1, 'energy':56, 'polarization':2})

# Carry the diode variables over from the cartesian dataset
# NOTE(review): `DS` is not assigned in the visible part of this notebook.
polar_DS['samp_diode'] = DS['samp_diode']
polar_DS['smoothed_diode'] = DS['smoothed_diode']


polar_DS

In [None]:
sample_names = polar_DS.sample_name.values

# Write the first sample with mode='w' to create (or overwrite) the store ...
# NOTE(review): `zarrsPath` is not defined in the visible part of this notebook.
polar_DS.sel(sample_name=[sample_names[0]]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}_rechunked-v2.zarr'), mode='w')

# ... then append the remaining samples one at a time along 'sample_name'.
for sample_name in tqdm(sample_names[1:], desc='Samples...'):
    polar_DS.sel(sample_name=[sample_name]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}_rechunked-v2.zarr'), mode='a', append_dim='sample_name')

In [None]:
# Glob string matching every .nc file in the 'polar_rsoxs_carbon_ncs' folder
netcdf_paths = str(zarrsPath.joinpath('polar_rsoxs_carbon_ncs')) + '/*.nc'
netcdf_paths

In [None]:
# Replace the in-memory polar_DS with the dataset opened from the netCDF files
polar_DS = xr.open_mfdataset(netcdf_paths)
polar_DS

In [None]:
polar_DS

In [None]:
polar_DS.sample_name.values

In [None]:
# make selection
sample_name = 'PM6_CB_3000'
edge = 'carbon'
intensity_type = 'corr'
pol = 0
qlims = (0.01, 0.08)
chi_width = 90

# Uses the six-argument make_para_perp_DAs(DS, sample_name, intensity_type, pol,
# qlims, chi_width); that definition must be the one most recently run.
para_DA, perp_DA = make_para_perp_DAs(polar_DS, sample_name, intensity_type, pol, qlims, chi_width)   

# Select AR data: (para - perp) / (para + perp), each averaged over chi first
ar_DA = (para_DA.mean('chi') - perp_DA.mean('chi')) / (para_DA.mean('chi') + perp_DA.mean('chi'))

# Plot the 282-292 energy range with a fixed color range of -0.6 to 0.6
ax = ar_DA.sel(energy=slice(282,292)).plot(figsize=(8,5), norm=plt.Normalize(-0.6,0.6))
ax.figure.suptitle('Anisotropy Ratio (AR) Map', fontsize=14, x=0.43)
ax.axes.set(title=f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', ylabel='Photon Energy [eV]', xlabel='q [$Å^{-1}$]')
ax.colorbar.set_label('AR [arb. units]', rotation=270, labelpad=12)
plt.show()

In [None]:
# display(polar_DS.sample_name.values)
# sample_name = 'PM6-Y6_3000_dSiN'

In [None]:
# Energy of the image to show (nearest available value is selected)
energy = 285
# energy = 400
# energy = 530
# One log-scaled image per sample: polarization 90, q from 0 to 0.1
for sample_name in polar_DS.sample_name.values:
    polar_DS['corr_intensity'].sel(sample_name=sample_name, polarization=90, q=slice(0,0.1)).sel(
        energy=energy, method='nearest').plot.imshow(norm=LogNorm(2e9, 1e11), cmap=cmap)
    plt.show()

In [None]:
# 1. Get energy values
energy_values = polar_DS.energy.values
# Split them into consecutive groups of 5 (the last group may be shorter)
energy_slices = [energy_values[i:i+5] for i in range(0, len(energy_values), 5)]
energy_slices

In [None]:
polar_DS

In [None]:
# Combine sample_name and polarization into a single 'system' dimension, then
# reset the resulting index on 'system'.
stacked_polar_DS = polar_DS.stack(system=('sample_name', 'polarization')).reset_index('system')
stacked_polar_DS

In [None]:
# Save the first part of the dataset to initialize the Zarr store
# NOTE(review): `zarrsPath` is not defined in the visible part of this notebook.
first_system = stacked_polar_DS.isel(system=slice(0, 1))
first_system.to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='w')

# Iterate over the rest of the systems and append to the Zarr store
# (one 'system' entry per write, appended along 'system')
for i in tqdm(range(1, len(stacked_polar_DS.system)), desc='Samples...'):
    subset = stacked_polar_DS.isel(system=slice(i, i+1))
    subset.to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='a', append_dim='system')

In [None]:
# NOTE(review): this cell writes to the same store path as the previous cell,
# and mode='w' replaces what was written there. It stores polarization 0 of the
# first sample, then appends polarization 90 of every *other* sample along
# 'polarization', so the appended slabs belong to different samples than the
# first write. This looks like an unfinished experiment -- confirm before running.
sample_names = polar_DS.sample_name.values

polar_DS.sel(sample_name=[sample_names[0]], polarization=[0]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='w')

for sample_name in tqdm(sample_names[1:], desc='Samples...'):
    # polar_DS.sel(sample_name=[sample_name], polarization=[0]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='a', append_dim='sample_name')
    polar_DS.sel(sample_name=[sample_name], polarization=[90]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='a', append_dim='polarization')

In [None]:
# # import zarr

# # sample_names = polar_DS.sample_name.values
# # energy_values = polar_DS.energy.values
# # energy_slices = [energy_values[i:i+5] for i in range(0, len(energy_values), 5)]

# # root_store = zarr.open_group(zarrsPath.joinpath(f'polar_{polar_DS.name}_regions.zarr').as_posix(), mode='a')

# # for sample in sample_names:
# #     # Make sure there's a group for this sample in the Zarr store
# #     if sample not in root_store.group_keys():
# #         root_store.create_group(sample)
    
# #     subset_by_sample = polar_DS.sel(sample_name=sample)
    
# #     for idx, energy_slice in enumerate(energy_slices):
# #         final_subset = subset_by_sample.sel(energy=energy_slice)
        
# #         # Save to the specific group and energy slice within the Zarr store
# #         final_subset.to_zarr(root_store[sample], mode='a', append_dim='energy', consolidated=True)

# sample_names = polar_DS.sample_name.values
# energy_values = polar_DS.energy.values
# energy_slices = [energy_values[i:i+5] for i in range(0, len(energy_values), 5)]

# main_zarr_path = zarrsPath.joinpath(f'polar_{polar_DS.name}_regions.zarr')

# for sample in tqdm(sample_names, desc='Samples...'):
#     subset_by_sample = polar_DS.sel(sample_name=sample)
    
#     # Define path for the sample within the main Zarr store
#     sample_path = main_zarr_path.joinpath(sample)
    
#     for idx, energy_slice in enumerate(energy_slices):
#         final_subset = subset_by_sample.sel(energy=energy_slice)
        
#         # Save to the specific sample path and energy slice within the Zarr store
#         if idx==0:
#             final_subset.to_zarr(sample_path, mode='w', consolidated=True)
#         else:
#             final_subset.to_zarr(sample_path, mode='a', append_dim='energy', consolidated=True)



In [None]:
def make_para_perp_DAs(datasets, sample_name, edge, intensity_type, pol, qlims, chi_width):
    """Return the chi cuts parallel and perpendicular to the polarization.

    Parameters
    ----------
    datasets : mapping
        Looked up with the key ``f'polar_{edge}'``; the value must support
        ``.sel(sample_name=...)`` followed by item access on a variable name.
    sample_name : label
        Passed to ``.sel(sample_name=...)``.
    edge : str
        Suffix of the dataset key (``f'polar_{edge}'``).
    intensity_type : str
        Prefix of the data variable name (``f'{intensity_type}_intensity'``).
    pol : {0, 90}
        Polarization value to select. It also decides which chi cut is
        labelled parallel: chi=180 for ``pol == 0``, chi=90 for ``pol == 90``.
    qlims : sequence of two numbers
        Lower and upper bound of the ``q`` slice.
    chi_width : number
        Halved before being passed to ``rsoxs.slice_chi``.
        NOTE(review): presumably ``slice_chi`` treats its ``chi_width`` as a
        half-width, making the full cut ``chi_width`` wide -- confirm.

    Returns
    -------
    tuple
        ``(para_DA, perp_DA)``.

    Raises
    ------
    ValueError
        If ``pol`` is neither 0 nor 90.
    """
    # select dataarray to plot
    DS = datasets[f'polar_{edge}']
    DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity']
    sliced_DA = DA.sel(polarization=pol, q=slice(qlims[0], qlims[1]))

    half_width = chi_width / 2

    # pick the chi cuts; the parallel/perpendicular labels swap with polarization
    if pol == 0:
        para_DA = sliced_DA.rsoxs.slice_chi(180, chi_width=half_width)
        perp_DA = sliced_DA.rsoxs.slice_chi(90, chi_width=half_width)
    elif pol == 90:
        perp_DA = sliced_DA.rsoxs.slice_chi(180, chi_width=half_width)
        para_DA = sliced_DA.rsoxs.slice_chi(90, chi_width=half_width)
    else:
        # Without this branch the return below would fail with an
        # UnboundLocalError that says nothing about the actual cause.
        raise ValueError(f'pol must be 0 or 90, got {pol!r}')

    return para_DA, perp_DA

# Load the rsoxs datasets into a dictionary keyed by dataset name
rsoxs_datasets = {}
key = 'polar_regions'
key_start, key_end = key.split('_')

# Sort the matches so the chosen store does not depend on glob's arbitrary
# ordering, and fail with a readable message when nothing matches.
zarr_matches = sorted(zarrsPath.glob(f'{key_start}*{key_end}.zarr'))
if not zarr_matches:
    raise FileNotFoundError(
        f"No zarr store matching '{key_start}*{key_end}.zarr' found in {zarrsPath}"
    )
zarrPath = zarr_matches[0]
rsoxs_datasets[key] = xr.open_zarr(zarrPath)

# Compute any dask coordinates
# NOTE(review): `da` is presumably `dask.array`, imported in an earlier cell -- confirm.
# Iterate over a snapshot because the loop body reassigns entries of the same coords mapping.
for coord_name, coord_data in list(rsoxs_datasets[key].coords.items()):
    if isinstance(coord_data.data, da.Array):
        rsoxs_datasets[key].coords[coord_name] = coord_data.compute()

rsoxs_datasets[key]

In [None]:
# Selection parameters shared by every map drawn below
edge = 'regions'
intensity_type = 'corr'
qlims = (0.01, 0.08)
chi_width = 30

ar_sample_names = rsoxs_datasets[f'polar_{edge}'].sample_name.data
for sample_name in tqdm(ar_sample_names):
    for pol in [0, 90]:
        ### Select para & perp DataArrays
        para_DA, perp_DA = make_para_perp_DAs(
            rsoxs_datasets, sample_name, edge, intensity_type, pol, qlims, chi_width
        )

#         ### ISI:
#         # Slice ISI data
#         para_ISI = para_DA.interpolate_na(dim='q').mean('chi').sum('q')
#         perp_ISI = perp_DA.interpolate_na(dim='q').mean('chi').sum('q')

#         # Plot
#         fig, ax = plt.subplots()
#         para_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='para', yscale='log')
#         perp_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='perp', yscale='log')
#         fig.suptitle('Integrated Scattering Intensity (ISI)', fontsize=14)
#         ax.set(title=f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', xlabel='Photon Energy [eV]', ylabel='Double-Norm-Corrected Intensity [arb. units]')
#         ax.legend()
#         fig.savefig(plotsPath.joinpath('isi', f'{sample_name}_{edge}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)
#         plt.close('all')

#         ### Linecut Maps:
#         fig, axs = plt.subplots(1, 2, figsize=(11,5))

#         para_DA.mean('chi').sel(energy=slice(282,290)).plot(ax=axs[0], cmap=cmap, norm=LogNorm(1e9, 1e11), add_colorbar=False)
#         perp_DA.mean('chi').sel(energy=slice(282,290)).plot(ax=axs[1], cmap=cmap, norm=LogNorm(1e9, 1e11), add_colorbar=False)

#         sm = plt.cm.ScalarMappable(cmap=cmap, norm=LogNorm(2e10, 1e12)) # Create a ScalarMappable object with the colormap and normalization & add the colorbar to the figure
#         cax = axs[1].inset_axes([1.03, 0, 0.05, 1])
#         cbar = fig.colorbar(sm, cax=cax, orientation='vertical')
#         cbar.set_label(label='Intensity [arb. units]', labelpad=12)
#         fig.suptitle(f'Linecut Maps: {sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', fontsize=14)
#         fig.set(tight_layout=True)
#         axs[0].set(title='Parallel to $E_p$', ylabel='Photon energy [eV]', xlabel='q [$Å^{-1}$]')
#         axs[1].set(title='Perpendicular to $E_p$ ', ylabel=None, xlabel='q [$Å^{-1}$]')
#         fig.savefig(plotsPath.joinpath('linecut_maps', f'{sample_name}_{edge}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)
#         plt.close('all')

        ### AR Maps:
        # AR = (para - perp) / (para + perp), built from the chi-averaged cuts
        para_chi_mean = para_DA.mean('chi')
        perp_chi_mean = perp_DA.mean('chi')
        ar_DA = (para_chi_mean - perp_chi_mean) / (para_chi_mean + perp_chi_mean)

        # Plot the 282-292 energy window on a fixed, symmetric colour scale
        # NOTE(review): `ax` holds whatever `.plot()` returns; given the
        # `.figure` / `.axes` / `.colorbar` access it is presumably a
        # matplotlib artist rather than an Axes -- confirm.
        ar_window = ar_DA.sel(energy=slice(282, 292))
        ax = ar_window.plot(figsize=(8, 5), norm=plt.Normalize(-0.6, 0.6))
        ax.figure.suptitle('Anisotropy Ratio (AR) Map', fontsize=14, x=0.43)
        ax.axes.set_title(f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°')
        ax.axes.set_ylabel('Photon Energy [eV]')
        ax.axes.set_xlabel('q [$Å^{-1}$]')
        ax.colorbar.set_label('AR [arb. units]', rotation=270, labelpad=12)
        # ax.figure.savefig(plotsPath.joinpath('ar_maps', f'{sample_name}_{edge}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)
        plt.show()
        plt.close('all')

## Quick DataArray contents checking

In [None]:
# Quick look at `recip_DA` at the energy nearest to `energy`
energy = 2477

# Compute the selected frame once so that both quantiles and the plot share
# the same in-memory data instead of each evaluating the selection again.
sliced_DA = recip_DA.squeeze().sel(energy=energy, method='nearest').compute()

# Colour limits taken from the 0.15 and 0.995 quantiles of the frame
cmin = sliced_DA.quantile(0.15)
cmax = sliced_DA.quantile(0.995)
ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=plt.cm.turbo, x='q_x', y='q_y')
# ax.axes.set(title=f'{sample_name}: Energy = {energy}')
plt.show()
plt.close('all')

In [None]:
# Make q_r and chi the dimensions of the caked data in place of index_x / index_y
swap_caked_DA = caked_DA.swap_dims(index_x='q_r', index_y='chi')

In [None]:
# Quick look at `swap_caked_DA` at the energy nearest to `energy`
energy = 2477

# Compute the selected frame once so that both quantiles and the plot share
# the same in-memory data instead of each evaluating the selection again.
sliced_DA = swap_caked_DA.squeeze().sel(energy=energy, method='nearest').compute()

# Colour limits taken from the 0.15 and 0.995 quantiles of the frame
cmin = sliced_DA.quantile(0.15)
cmax = sliced_DA.quantile(0.995)
ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=plt.cm.turbo, xscale='log')
# ax.axes.set(title=f'{sample_name}: Energy = {energy}')
plt.show()
plt.close('all')

## Misc cells:

In [None]:
# trmsn_35_tot = sorted(reducedPath.glob('*tot*Trmsn_35*.txt'))
# trmsn_35_ver = sorted(reducedPath.glob('*ver*Trmsn_35*.txt'))
# trmsn_35_hor = sorted(reducedPath.glob('*hor*Trmsn_35*.txt'))

# len([f.name for f in trmsn_35_tot])

In [None]:
# for file in trmsn_35_tot:
#     pr = np.loadtxt(file)
#     plt.plot(pr[:, 1]-0.9*np.mean(pr[1100:1250, 1]))
#     plt.show()
#     plt.close()