# Processing notebook 

## Imports

In [None]:
import pathlib
import os
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from PIL import Image
from tqdm.auto import tqdm 

import fabio
from smi_analysis import SMI_beamline

## Define paths & functions

In [None]:
# Lookup table: integer sample number -> sample name.
# The film-processing loop indexes this with the sample-number field parsed
# from each filename to label the saved data.
sample_name_dict = {
    # '..._CB' entries
    14: 'PM6_CB',
    17: 'PM6_1CN-CB',
    18: 'PM6_5CN-CB',
    21: 'PM6_p5CN-CB',
    22: 'PM6-Y6_CB',
    23: 'PM6-Y6BO_CB',
    # '..._CF' entries
    26: 'PM6_CF',
    29: 'PM6_1CN-CF',
    30: 'PM6_5CN-CF',
    33: 'PM6_p5CN-CF',
    34: 'PM6-Y6_CF',
    35: 'PM6-Y6BO_CF',
    # Bare substrates
    1: 'BareSiN_01',
    3: 'BareSiN_03',
}

In [None]:
# Root of the proposal directory; every raw-data folder below is a child of it.
propPath = pathlib.Path('/nsls2/data/smi/proposals/2024-1/pass-314903')

# Each raw_0N folder has one sub-folder per detector:
# '1M' (referred to as SAXS in this notebook) and '900KW' (referred to as WAXS).
rawPaths1 = propPath / 'raw_01'
saxsPath1 = rawPaths1 / '1M'
waxsPath1 = rawPaths1 / '900KW'

rawPaths2 = propPath / 'raw_02'
saxsPath2 = rawPaths2 / '1M'
waxsPath2 = rawPaths2 / '900KW'

rawPaths3 = propPath / 'raw_03'
saxsPath3 = rawPaths3 / '1M'
waxsPath3 = rawPaths3 / '900KW'

rawPaths4 = propPath / 'raw_04'
saxsPath4 = rawPaths4 / '1M'
waxsPath4 = rawPaths4 / '900KW'

# analysisPath = pathlib.Path('/nsls2/users/alevin/rsoxs_suite/sst1_notebooks/SMI_tender_scattering/analysis_02')
# reducedPath = analysisPath.joinpath('reduced_waxs')

In [None]:
# solnPaths = [saxsPath1, saxsPath3]
# filmPaths = [saxsPath2, saxsPath4]
# # for saxsPath in solnPaths:
# #     all_saxs = set(saxsPath.glob('*.tif'))
# #     test_saxs = set(saxsPath.glob('test*'))
# #     display(sorted(set([f.name[:f.name.find('_sdd1.8')] for f in all_saxs.difference(test_saxs)])))

### SMI loading function

In [None]:
def saxs_SMI_numpy_loading(path, filename_list, energy=2.450):
    """
    Reduce raw SAXS tiffs one file at a time with smi_analysis.

    Adapted from Guillaume's SMI notebooks. For every file a new
    ``SMI_beamline.SMI_geometry`` is built, the file is opened with the
    'tender' optional mask, an extra detector region is masked, and the image
    is stitched and caked.

    Parameters
    ----------
    path : path-like
        Directory holding the raw files; forwarded unchanged to ``open_data``.
    filename_list : iterable of str
        File names to process, in order. Each one is opened on its own.
    energy : float, optional
        Photon energy in keV used to compute the wavelength (default 2.450,
        the value that used to be hard-coded). The same value is applied to
        every file in ``filename_list``.

    Returns
    -------
    tuple of five index-aligned lists
        filename_wa0_list : the processed file names
        recip_list        : stitched images (``img_st``)
        recip_extents     : ``[qp[0], qp[-1], qz[0], qz[-1]]`` per file
        caked_list        : caked images (``cake``)
        caked_extents     : ``[q_cake[0], q_cake[-1], chi_cake[0], chi_cake[-1]]`` per file
        All five lists are empty when ``filename_list`` is empty.
    """
    # Materialise once so generators are accepted and the empty case can be
    # answered without touching tqdm or smi_analysis.
    filename_list = list(filename_list)

    filename_wa0_list = []
    recip_list = []
    recip_extents = []
    caked_list = []
    caked_extents = []
    if not filename_list:
        return filename_wa0_list, recip_list, recip_extents, caked_list, caked_extents

    # Beamline configuration
    geometry = 'Transmission'
    # 12.398 / E[keV] gives the wavelength in Angstrom; 1E-10 converts it to metres.
    # NOTE(review): one wavelength is used for all files even though the filenames
    # carry a per-file energy -- confirm this is intended for energy scans.
    wav = 1E-10 * (12.398/energy)
    bs_kind = 'pindiode'

    # SAXS detector configuration
    detector_saxs = 'Pilatus1m'
    sdd_saxs = 1800          # presumably mm (filenames contain 'sdd1.8') -- TODO confirm
    center_saxs = [354, 560]
    bs_pos_saxs = [[354, 535]]

    for dat in tqdm(filename_list, desc='Processing tiffs'):
        # A fresh geometry object per file keeps the files fully independent.
        SMI_saxs = SMI_beamline.SMI_geometry(geometry = geometry,
                                             detector = detector_saxs,
                                             sdd = sdd_saxs,
                                             wav = wav,
                                             alphai = 0,
                                             center = center_saxs,
                                             bs_pos = bs_pos_saxs,
                                             det_angles = [0],
                                             bs_kind = bs_kind)

        SMI_saxs.open_data(path, [dat], optional_mask='tender')
        # Extra masked strip on the first (only) detector mask
        # (presumably a beamstop shadow -- confirm).
        SMI_saxs.masks[0][560:, 337:350]=True

        SMI_saxs.stitching_data(interp_factor=3, flag_scale=False)
        SMI_saxs.caking()

        filename_wa0_list.append(dat)
        recip_list.append(SMI_saxs.img_st)
        recip_extents.append([SMI_saxs.qp[0], SMI_saxs.qp[-1], SMI_saxs.qz[0], SMI_saxs.qz[-1]])

        caked_list.append(SMI_saxs.cake)
        caked_extents.append([SMI_saxs.q_cake[0], SMI_saxs.q_cake[-1], SMI_saxs.chi_cake[0], SMI_saxs.chi_cake[-1]])

    return filename_wa0_list, recip_list, recip_extents, caked_list, caked_extents

## Load data

In [None]:
# import warnings
# warnings.filterwarnings('ignore')

### Films

In [None]:
# Sample number and rotation are the 3rd and 4th underscore-separated fields
# of every entry name in saxsPath4; collect the distinct values, sorted as strings.
unique_sample_numbers = sorted({entry.name.split('_')[2] for entry in saxsPath4.glob('*')})
unique_sample_rotations = sorted({entry.name.split('_')[3] for entry in saxsPath4.glob('*')})

In [None]:
# Every tiff in the film SAXS folder
all_saxs = set(saxsPath4.glob('*.tif'))
# test_saxs = set(saxsPath4.glob('test*'))

# Label = filename text after the 3-character prefix and before '_sdd1.8'
sample_names = sorted({tif.name[3:tif.name.find('_sdd1.8')] for tif in all_saxs})
sample_names

In [None]:
# filename_sublists

In [None]:
# import warnings
# warnings.filterwarnings('ignore')

In [None]:
# Preview the sample numbers the film-processing loop iterates over
# (the first sorted entry is skipped there as well).
unique_sample_numbers[1:]

In [None]:
# `dmg_paths` is first assigned inside the film-processing loop further down,
# so evaluating it here raises NameError on a fresh kernel (Restart & Run All).
# Uncomment to inspect it only after that loop has run.
# dmg_paths

In [None]:
# Reduce every (sample, rotation) film scan series and save one reciprocal-space
# zarr and one caked zarr per combination.

# Define naming scheme: underscore-separated filename fields, in order
waxs_naming_scheme = ['project', 'sample_type', 'sample_number', 'rotation_from_normal', 'set_sdd', 'energy', 'waxs_det_position',
                      'bpm', 'id', 'misc', 'detector']
md_naming_scheme = waxs_naming_scheme.copy()

# Save sample zarr, load later to concatenate full zarr
sampleZarrsPath = propPath.joinpath('processed_data/zarrs/saxs_core_films_trexs_sample_zarrs')

for sample_number in tqdm(unique_sample_numbers[1:], desc='Samples'):
    for sample_rotation in tqdm(unique_sample_rotations, desc='Angles'):

        # Select files for a given sample and rotation.
        # Both values are whole, adjacent underscore-delimited filename fields, so
        # match them with explicit delimiters: a loose '*Trmsn_1*' pattern would
        # also pick up samples 14, 17, 18, ...
        scan_pattern = f'*Trmsn_{sample_number}_{sample_rotation}_*'
        all_paths = set(saxsPath4.glob(scan_pattern))
        dmg_paths = set(saxsPath4.glob(f'{scan_pattern}damage*'))

        # For now only select first scans, not the damage test repeats
        filename_list = [f.name for f in sorted(all_paths.difference(dmg_paths))]
        if not filename_list:
            # No first-scan files for this combination; xr.concat would fail
            # on an empty list, so there is nothing to reduce or save.
            continue

        # Run SMI loading code (this produces some fabio and divide by zero errors)
        names_list, recip_list, recip_extents, caked_list, caked_extents = saxs_SMI_numpy_loading(saxsPath4, filename_list)

        # Construct xarrays with full values along detector dimensions and the energy dimension
        # They contain sample name and theta value as well, as single values to be concatenated in later steps
        recip_DA_rows = []
        caked_DA_rows = []
        zipped_lists = zip(names_list, recip_list, recip_extents, caked_list, caked_extents)
        for filename, recip_arr, recip_extent, caked_arr, caked_extent in zipped_lists:
            # Filename fields become the DataArray attributes
            md_list = filename.split('_')
            attr_dict = {md_item: md_list[i] for i, md_item in enumerate(md_naming_scheme)}

            # Values of the three length-1 dimensions added to both arrays
            energy_value = float(attr_dict['energy'][:-2])   # drop the 2-character unit suffix
            film_name = sample_name_dict[int(attr_dict['sample_number'])]
            # Drop 3 leading and 3 trailing characters around the angle
            # (presumably 'rot' / 'deg' -- confirm against the filenames).
            theta_value = 90 - float(attr_dict['rotation_from_normal'][3:-3])

            recip_DA = xr.DataArray(data = recip_arr,
                                    dims = ['pix_y', 'pix_x'],
                                    attrs = attr_dict)
            recip_DA = recip_DA.assign_coords({
                'pix_x': recip_DA.pix_x.data,
                'pix_y': recip_DA.pix_y.data,
                'q_x': ('pix_x', np.linspace(recip_extent[0], recip_extent[1], len(recip_DA.pix_x.data))),
                # q_y runs from the last extent value to the third one (i.e. reversed)
                'q_y': ('pix_y', np.linspace(recip_extent[3], recip_extent[2], len(recip_DA.pix_y.data)))
            })
            recip_DA = recip_DA.expand_dims({
                'energy': [energy_value],
                'sample_name': [film_name],
                'theta': [theta_value]
            })
            recip_DA_rows.append(recip_DA)

            caked_DA = xr.DataArray(data = caked_arr,
                                    dims = ['index_y', 'index_x'],
                                    attrs = attr_dict)
            caked_DA = caked_DA.assign_coords({
                'index_x': caked_DA.index_x.data,
                'index_y': caked_DA.index_y.data,
                'q_r': ('index_x', np.linspace(caked_extent[0], caked_extent[1], len(caked_DA.index_x.data))),
                'chi': ('index_y', np.linspace(caked_extent[3], caked_extent[2], len(caked_DA.index_y.data)))
            })
            caked_DA = caked_DA.expand_dims({
                'energy': [energy_value],
                'sample_name': [film_name],
                'theta': [theta_value]
            })
            caked_DA_rows.append(caked_DA)

        # One array per (sample, rotation), stacked along energy
        recip_DA = xr.concat(recip_DA_rows, 'energy')
        caked_DA = xr.concat(caked_DA_rows, 'energy')

        # NOTE(review): the variable is stored as 'flatfield_corr' although no flatfield
        # step is visible in this notebook -- the name is kept for downstream readers.
        recip_samp_zarr_name = 'recip_'+recip_DA.sample_name.values[0]+'_'+str(int(recip_DA.theta.values[0]))+'deg.zarr'
        recip_DS = recip_DA.to_dataset(name='flatfield_corr')
        recip_DS.to_zarr(sampleZarrsPath.joinpath(recip_samp_zarr_name), mode='w')

        caked_samp_zarr_name = 'caked_'+caked_DA.sample_name.values[0]+'_'+str(int(caked_DA.theta.values[0]))+'deg.zarr'
        caked_DS = caked_DA.to_dataset(name='flatfield_corr')
        caked_DS.to_zarr(sampleZarrsPath.joinpath(caked_samp_zarr_name), mode='w')

In [None]:
# Display the reciprocal-space DataArray left over from the last iteration of the film loop.
recip_DA

### Solutions

In [None]:
# Every tiff in the solution SAXS folder, minus anything whose name starts with 'test'
all_saxs = set(saxsPath1.glob('*.tif'))
test_saxs = set(saxsPath1.glob('test*'))

# Label = filename text after the 3-character prefix and before '_sdd1.8'
sample_names = sorted({tif.name[3:tif.name.find('_sdd1.8')] for tif in all_saxs - test_saxs})
sample_names

In [None]:
# Build one file list per sample (all but the last name in sample_names).
# A sample with exactly 63 files is kept as-is; any other count is reduced to
# the first file, in sorted order, found for each energy string.
saxs_63 = []

print('Before:')
for sample_name in sample_names[:-1]:
    files = sorted(saxsPath1.glob(f'CM_{sample_name}_sdd*'))
    files_number = len(files)
    print(files_number, '\n')

    if files_number == 63:
        saxs_63.append(files)
        continue

    seen_energies = set()
    kept_files = []
    for file in files:
        # Energy string = last underscore-separated token before 'eV'
        file_energy = file.name[:file.name.find('eV')].split('_')[-1]
        if file_energy not in seen_energies:
            seen_energies.add(file_energy)
            kept_files.append(file)

    saxs_63.append(kept_files)

print('After (in saxs_63):')
for kept in saxs_63:
    print(len(kept))

In [None]:
# Update sample names
# Re-derive them from the first file of each list so that sample_names is
# index-aligned with saxs_63.

sample_names = []
for folder in saxs_63:
    first_filename = folder[0].name
    sample_names.append(first_filename[3:first_filename.find('_sdd1.8')])

display(sample_names)

In [None]:
# Reduce each solution sample (one file list per entry of saxs_63) and save
# one reciprocal-space zarr and one caked zarr per sample.

# Save sample zarr, load later to concatenate full zarr
sampleZarrsPath = propPath.joinpath('processed_data/saxs_solution_trexs_sample_zarrs')

for i, sample_name in enumerate(tqdm(sample_names, desc='Samples')):
    # saxs_63 is index-aligned with sample_names
    all_paths = saxs_63[i]

    # Select first scans
    filename_list = [f.name for f in sorted(all_paths)]

    # Run SMI loading code (this produces some fabio and divide by zero errors)
    # (the loader defined in this notebook is saxs_SMI_numpy_loading)
    names_list, recip_list, recip_extents, caked_list, caked_extents = saxs_SMI_numpy_loading(
        saxsPath1, filename_list)

    # Construct xarrays with full values along detector dimensions and the energy dimension
    # They contain the sample name as well, as a single value to be concatenated in later steps
    recip_DA_rows = []
    caked_DA_rows = []
    zipped_lists = zip(names_list, recip_list, recip_extents, caked_list, caked_extents)
    for filename, recip_arr, recip_extent, caked_arr, caked_extent in zipped_lists:
        attr_dict = {'filename': filename}
        # Parsed per file under separate names so the outer loop's `sample_name`
        # is not overwritten.
        file_sample_name = filename[3:filename.find('_sdd')]
        # Energy = last underscore-separated token before 'eV'
        file_energy = float(filename[:filename.find('eV')].split('_')[-1])

        recip_DA = xr.DataArray(data = recip_arr,
                                dims = ['pix_y', 'pix_x'],
                                attrs = attr_dict)
        recip_DA = recip_DA.assign_coords({
            'pix_x': recip_DA.pix_x.data,
            'pix_y': recip_DA.pix_y.data,
            'q_x': ('pix_x', np.linspace(recip_extent[0], recip_extent[1], len(recip_DA.pix_x.data))),
            # q_y runs from the last extent value to the third one (i.e. reversed)
            'q_y': ('pix_y', np.linspace(recip_extent[3], recip_extent[2], len(recip_DA.pix_y.data)))
        })
        recip_DA = recip_DA.expand_dims({
            'energy': [file_energy],
            'sample_name': [file_sample_name]
        })
        recip_DA_rows.append(recip_DA)

        caked_DA = xr.DataArray(data = caked_arr,
                                dims = ['index_y', 'index_x'],
                                attrs = attr_dict)
        caked_DA = caked_DA.assign_coords({
            'index_x': caked_DA.index_x.data,
            'index_y': caked_DA.index_y.data,
            'q_r': ('index_x', np.linspace(caked_extent[0], caked_extent[1], len(caked_DA.index_x.data))),
            'chi': ('index_y', np.linspace(caked_extent[3], caked_extent[2], len(caked_DA.index_y.data)))
        })
        caked_DA = caked_DA.expand_dims({
            'energy': [file_energy],
            'sample_name': [file_sample_name]
        })
        caked_DA_rows.append(caked_DA)

    # One array per sample, stacked along energy
    recip_DA = xr.concat(recip_DA_rows, 'energy')
    caked_DA = xr.concat(caked_DA_rows, 'energy')

    # NOTE(review): the variable is stored as 'flatfield_corr' although no flatfield
    # step is visible in this notebook -- the name is kept for downstream readers.
    recip_samp_zarr_name = 'recip_'+sample_name+'.zarr'
    recip_DS = recip_DA.to_dataset(name='flatfield_corr')
    recip_DS.to_zarr(sampleZarrsPath.joinpath(recip_samp_zarr_name), mode='w')

    caked_samp_zarr_name = 'caked_'+sample_name+'.zarr'
    caked_DS = caked_DA.to_dataset(name='flatfield_corr')
    caked_DS.to_zarr(sampleZarrsPath.joinpath(caked_samp_zarr_name), mode='w')

## Quick DataArray contents checking

In [None]:
# Plot the reciprocal-space image at the energy nearest to the requested value
energy = 2477

sliced_DA = recip_DA.squeeze().sel(energy=energy, method='nearest')

# Colour limits: 15 % and 99.5 % intensity quantiles of the selected image
loaded_slice = sliced_DA.compute()
cmin = loaded_slice.quantile(0.15)
cmax = loaded_slice.quantile(0.995)
color_norm = plt.Normalize(cmin, cmax)

ax = sliced_DA.plot.imshow(x='q_x', y='q_y', cmap=plt.cm.turbo, norm=color_norm)
# ax.axes.set(title=f'{sample_name}: Energy = {energy}')
plt.show()
plt.close('all')

In [None]:
# Make the q_r / chi coordinates the dimensions of the caked array
# in place of the integer index dimensions.
caked_dim_swap = {'index_x': 'q_r', 'index_y': 'chi'}
swap_caked_DA = caked_DA.swap_dims(caked_dim_swap)

In [None]:
# Plot the caked image at the energy nearest to the requested value
energy = 2477

sliced_DA = swap_caked_DA.squeeze().sel(energy=energy, method='nearest')

# Colour limits: 15 % and 99.5 % intensity quantiles of the selected image
loaded_slice = sliced_DA.compute()
cmin = loaded_slice.quantile(0.15)
cmax = loaded_slice.quantile(0.995)
color_norm = plt.Normalize(cmin, cmax)

ax = sliced_DA.plot.imshow(xscale='log', cmap=plt.cm.turbo, norm=color_norm)
# ax.axes.set(title=f'{sample_name}: Energy = {energy}')
plt.show()
plt.close('all')

## Misc cells:

In [None]:
# trmsn_35_tot = sorted(reducedPath.glob('*tot*Trmsn_35*.txt'))
# trmsn_35_ver = sorted(reducedPath.glob('*ver*Trmsn_35*.txt'))
# trmsn_35_hor = sorted(reducedPath.glob('*hor*Trmsn_35*.txt'))

# len([f.name for f in trmsn_35_tot])

In [None]:
# for file in trmsn_35_tot:
#     pr = np.loadtxt(file)
#     plt.plot(pr[:, 1]-0.9*np.mean(pr[1100:1250, 1]))
#     plt.show()
#     plt.close()