# 2024C3 SMI SAXS TReXS processing notebook 
This notebook is for processing tender resonant X-ray scattering data collected on the SAXS (Pilatus 1M) detector at the SMI beamline. 
Potentially still needs a better flatfield/mask for bad SAXS detector pixels.

**Copy this notebook (along with all other notebooks in this folder) to your own user folder, don't change this one.**

## Imports

In [None]:
import pathlib
import os
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from PIL import Image
from tqdm.auto import tqdm 

import fabio
from smi_analysis import SMI_beamline

## Define paths & functions

In [None]:
# Proposal root on the /nsls2 filesystem and the raw-data folders for each
# detector: '1M' holds the SAXS tiffs, '900KW' the WAXS tiffs.
propPath = pathlib.Path('/nsls2/data/smi/proposals/2024-3/pass-316856')

rawPaths = propPath / 'raw_05'
saxsPath = rawPaths / '1M'
waxsPath = rawPaths / '900KW'

### SMI loading function & setup

In [None]:
# Build a bad-pixel index for the SAXS (Pilatus 1M) detector from reference
# data: sum every 'wa0' tiff of the first matching sample and flag each pixel
# whose summed value exceeds `bad_pixel_threshold`. The resulting `idx_mask`
# (a tuple of index arrays from np.where) is masked inside
# saxs_SMI_numpy_loading.
# NOTE(review): the reference data come from another cycle/proposal
# (2024_2, 314483_Freychet_01) at 1.8 m — confirm it is still representative.
path = '/nsls2/data/smi/legacy/results/data/%s/%s/1M/' % ('2024_2', '314483_Freychet_01')
bad_pixel_threshold = 100

files_in_path = sorted(os.listdir(path))

# Unique sample prefixes (filename text before '1.8m_') of the reference files
sam = []
for file in files_in_path:
    if 'wa0' in file and 'AHPP25-0p25_' in file:
        idx = file.find('1.8m_')
        if idx == -1:
            # find() returns -1 here; file[:-1] would become a bogus prefix
            continue
        if file[:idx] not in sam:
            sam.append(file[:idx])

# wa0 tiffs belonging to each sample prefix (independent lists per sample)
all_dat = [
    [file for file in files_in_path if sa in file and 'tif' in file and 'wa0' in file]
    for sa in sam
]

if not all_dat or not all_dat[0]:
    raise FileNotFoundError(
        f'No reference wa0 tiffs found in {path} to build the bad-pixel mask'
    )

img_sum = None
for all_d in all_dat[0]:
    img = fabio.open(os.path.join(path, all_d)).data
    if img_sum is None:
        img_sum = np.zeros(np.shape(img))
    img_sum += img

idx_mask = np.where(img_sum > bad_pixel_threshold)

In [None]:
def _energy_kev_from_filename(filename):
    """Return the beam energy in keV parsed from an SMI tiff filename.

    The energy is the '_'-delimited token that ends right before the first
    'eV' (e.g. '..._2450.00eV_...' -> 2.45 keV). This is the same parsing the
    zarr-building loop uses for the energy coordinate.

    Raises
    ------
    ValueError
        If the filename contains no 'eV' or the token is not numeric.
    """
    idx = filename.find('eV')
    if idx == -1:
        raise ValueError(f'No energy ("...eV") found in filename {filename!r}')
    return 0.001 * float(filename[:idx].split('_')[-1])


def saxs_SMI_numpy_loading(path, filename_list, sdd=3000, bad_pixel_idx=None):
    """
    Reduce raw SMI SAXS (Pilatus 1M) tiffs to stitched and caked images.

    Adapted from Guillaume's SMI notebooks. Each file is processed on its own
    with smi_analysis:
    1. build a transmission geometry at the energy parsed from the filename;
    2. open the tiff with the 'tender' optional mask;
    3. add the bad-pixel mask and fixed detector rectangles;
    4. scale the image by 10;
    5. stitch (interp_factor=3) and cake.

    Parameters
    ----------
    path : str or path-like
        Directory containing the tiffs.
    filename_list : list of str
        Tiff filenames inside ``path``; each must contain an '_<energy>eV' token.
    sdd : float, optional
        Sample-detector distance passed to SMI_geometry. Defaults to 3000, the
        previously hard-coded value; 1800 was the commented-out alternative.
    bad_pixel_idx : index tuple, optional
        Pixels to mask in every image. Defaults to the notebook-global
        ``idx_mask`` built in the mask cell.

    Returns
    -------
    filenames, recip_list, recip_extents, caked_list, caked_extents
        filenames : the processed filenames.
        recip_list : stitched images, with extents [qp[0], qp[-1], qz[0], qz[-1]].
        caked_list : caked images, with extents
            [q_cake[0], q_cake[-1], chi_cake[0], chi_cake[-1]].
    """
    if bad_pixel_idx is None:
        bad_pixel_idx = idx_mask  # notebook-global from the mask cell

    geometry = 'Transmission'
    bs_kind = 'pindiode'
    detector = 'Pilatus1m'
    center = [477, 564]
    bs_pos = [[467, 546]]

    # Fixed detector regions to exclude, as (row slice, column slice).
    # NOTE(review): presumably beamstop/gap/bad regions — confirm.
    extra_masks = [
        (slice(835, None), slice(488, 616)),
        (slice(350, 414), slice(857, 920)),
        (slice(370, 410), slice(550, 620)),
        (slice(570, None), slice(452, 460)),
        (slice(600, 630), slice(800, 860)),
        (slice(700, None), slice(463, 470)),
        (slice(600, 630), slice(600, 630)),
    ]

    filenames = []
    recip_list = []
    recip_extents = []
    caked_list = []
    caked_extents = []
    for dat in tqdm(filename_list, desc='Processing tiffs'):
        # wavelength in m from energy in keV (12.398 keV*Angstrom)
        wav = 1E-10 * (12.398 / _energy_kev_from_filename(dat))

        SMI_saxs = SMI_beamline.SMI_geometry(geometry=geometry,
                                             detector=detector,
                                             sdd=sdd,
                                             wav=wav,
                                             alphai=0,
                                             center=center,
                                             bs_pos=bs_pos,
                                             det_angles=[0],
                                             bs_kind=bs_kind)
        SMI_saxs.open_data(path, [dat], optional_mask='tender')

        mask = SMI_saxs.masks[0]
        mask[bad_pixel_idx] = True
        for rows, cols in extra_masks:
            mask[rows, cols] = True

        # NOTE(review): constant x10 intensity scaling kept from the original
        # notebook; its purpose is not visible here.
        for i, _ in enumerate(SMI_saxs.imgs):
            SMI_saxs.imgs[i] *= 10

        SMI_saxs.stitching_data(interp_factor=3)
        SMI_saxs.caking()

        filenames.append(dat)
        recip_list.append(SMI_saxs.img_st)
        recip_extents.append([SMI_saxs.qp[0], SMI_saxs.qp[-1], SMI_saxs.qz[0], SMI_saxs.qz[-1]])
        caked_list.append(SMI_saxs.cake)
        caked_extents.append([SMI_saxs.q_cake[0], SMI_saxs.q_cake[-1],
                              SMI_saxs.chi_cake[0], SMI_saxs.chi_cake[-1]])

    return filenames, recip_list, recip_extents, caked_list, caked_extents

## Load data & save zarrs

In [None]:
# Quick look at all raw SAXS tiffs for this proposal. The listing is built
# directly so this cell does not depend on `all_saxs`, which is only defined
# in the next cell.
sorted(saxsPath.glob('*.tif'))

In [None]:
# Discover unique sample names from the raw SAXS tiffs, excluding 'test*'
# files. A sample name is the filename text from character 3 up to the sdd tag.
# NOTE(review): the first 3 characters are assumed to be a non-sample prefix.
sdd_tag = '_sdd3.0'  # previously '_sdd1.8'

all_saxs = set(saxsPath.glob('*.tif'))
test_saxs = set(saxsPath.glob('test*'))
sample_names = sorted({
    f.name[3:f.name.find(sdd_tag)]
    for f in all_saxs - test_saxs
    # find() would return -1 for files without the tag and yield a bogus slice
    if sdd_tag in f.name
})
sample_names

In [None]:
# Hand-picked subset of samples to process; this replaces the list
# discovered from the raw tiffs above.
# Earlier subset: ['TEGDME_neat', 'Li2S_TEGDME_reredo']
sample_names = [
    'Li2S8_static',
]

In [None]:
# Preview the tiffs the processing loop below will pick up for each sample
# (same glob as the loop). Computed directly because `filename_list` is only
# defined inside that loop.
{name: sorted(f.name for f in saxsPath.glob(f'*_{name}_*wa20*')) for name in sample_names}

In [None]:
# Reduce each selected sample to reciprocal-space and caked images and save one
# zarr of each per sample; the plotting notebook loads and concatenates them.


def _detector_dataarray(arr, extent, dims, coords, energy, sample_name, attrs):
    """Wrap one reduced 2-D image as a DataArray ready for concatenation.

    Parameters
    ----------
    arr : 2-D array
        Image from saxs_SMI_numpy_loading; axis 0 becomes dims[0], axis 1 dims[1].
    extent : sequence of 4 floats
        [x_first, x_last, y_first, y_last] for the image.
    dims : (str, str)
        (y_dim, x_dim) names of the row/column dimensions (integer index coords).
    coords : (str, str)
        (y_coord, x_coord) names of the physical coordinates along them. x runs
        linearly from extent[0] to extent[1]; y runs from extent[3] to extent[2].
    energy, sample_name
        Values of the added length-1 'energy' and 'sample_name' dimensions.
    attrs : dict
        Attributes attached to the DataArray.
    """
    y_dim, x_dim = dims
    y_coord, x_coord = coords
    da = xr.DataArray(data=arr, dims=[y_dim, x_dim], attrs=attrs)
    n_y, n_x = da.sizes[y_dim], da.sizes[x_dim]
    da = da.assign_coords({
        x_dim: np.arange(n_x),
        y_dim: np.arange(n_y),
        x_coord: (x_dim, np.linspace(extent[0], extent[1], n_x)),
        y_coord: (y_dim, np.linspace(extent[3], extent[2], n_y)),
    })
    return da.expand_dims({'energy': [energy], 'sample_name': [sample_name]})


# Output folder for the per-sample zarrs
# sampleZarrsPath = propPath.joinpath('processed_data/zarrs/saxs_polysulfide_solutions_zarrs_v1')
sampleZarrsPath = propPath.joinpath('processed_data/zarrs/saxs_Li2S8_static_solution_zarrs_v1')
sampleZarrsPath.mkdir(parents=True, exist_ok=True)

for sample_name in tqdm(sample_names, desc='Samples'):
    # wa20 tiffs for this sample. The glob also matches samples whose names
    # merely contain `_<sample_name>_` (e.g. 'Li2S8' vs 'Li2S8_static'), so keep
    # only files whose parsed name (chars 3 up to '_sdd') matches exactly.
    candidates = sorted(saxsPath.glob(f'*_{sample_name}_*wa20*'))
    filename_list = [
        f.name for f in candidates
        if f.name[3:f.name.find('_sdd')] == sample_name
    ]
    if not filename_list:
        # nothing to concatenate; xr.concat would fail on an empty list
        print(f'No wa20 tiffs found for sample {sample_name!r}; skipping.')
        continue

    # Run SMI loading code (this produces some fabio and divide by zero errors)
    names_list, recip_list, recip_extents, caked_list, caked_extents = saxs_SMI_numpy_loading(saxsPath, filename_list)

    # One DataArray per file, each with length-1 energy and sample_name dims,
    # concatenated along energy below.
    recip_DA_rows = []
    caked_DA_rows = []
    for filename, recip_arr, recip_extent, caked_arr, caked_extent in zip(
            names_list, recip_list, recip_extents, caked_list, caked_extents):
        attr_dict = {'filename': filename}
        # energy coordinate in eV, e.g. '..._2450.00eV_...' -> 2450.0
        energy = float(filename[:filename.find('eV')].split('_')[-1])

        recip_DA_rows.append(_detector_dataarray(
            recip_arr, recip_extent, ('pix_y', 'pix_x'), ('q_y', 'q_x'),
            energy, sample_name, attr_dict))
        caked_DA_rows.append(_detector_dataarray(
            caked_arr, caked_extent, ('index_y', 'index_x'), ('chi', 'q_r'),
            energy, sample_name, attr_dict))

    recip_DA = xr.concat(recip_DA_rows, 'energy')
    caked_DA = xr.concat(caked_DA_rows, 'energy')

    # Save sample zarrs (mode='w' overwrites earlier output of the same name)
    recip_DS = recip_DA.to_dataset(name='flatfield_corr')
    recip_DS.to_zarr(sampleZarrsPath.joinpath(f'recip_{sample_name}.zarr'), mode='w')

    caked_DS = caked_DA.to_dataset(name='flatfield_corr')
    caked_DS.to_zarr(sampleZarrsPath.joinpath(f'caked_{sample_name}.zarr'), mode='w')

#### Now switch to the plotting notebook!