# 2024C2: Processing SST1 RSoXS Data, loaded from tiled DB
Use the default NSLS-II JupyterHub Python environment.

## Imports

In [None]:
# Imports
import PyHyperScattering as phs
import pathlib
import sys
import ast
import json
import datetime
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from tqdm.auto import tqdm
import dask.array as da
from tiled.client import from_profile, from_uri

sys.path.append('/nsls2/users/alevin/local_lib')
from andrew_rsoxs_fxns import *

print(f'Using PyHyperScattering Version: {phs.__version__}')

# Set colormap
# Work on a copy so the globally registered 'turbo' colormap is not modified, then draw
# masked/invalid ("bad", e.g. NaN) pixels with the colormap's "under" color.
cmap = plt.cm.turbo.copy()
cmap.set_bad(cmap.get_under())

## Define paths & short functions

In [None]:
# Define directory paths: user home and the 2024-2 proposal directory
userPath = pathlib.Path('/nsls2/users/alevin')
propPath = pathlib.Path('/nsls2/data/sst/proposals/2024-2/pass-313412')

# All processed outputs (masks, zarr stores) live under <proposal>/processed_data
outPath = propPath / 'processed_data'
maskPath = outPath / 'masks'
zarrsPath = outPath / 'rsoxs_zarrs'

In [None]:
# # Some user defined functions for loading metadata
# def load_monitors(loader, run, dims=['energy', 'polarization']):
#     md = loader.loadMd(run)
#     monitors = loader.loadMonitors(run)
#     dims_to_join = []
#     dim_names_to_join = []
#     for dim in dims:
#         dims_to_join.append(md[dim].compute())
#         dim_names_to_join.append(dim)  
#     index = pd.MultiIndex.from_arrays(dims_to_join, names=dim_names_to_join)
#     monitors_remeshed = monitors.rename({'time':'system'}).reset_index('system').assign_coords(system=index).unstack('system')
#     # monitors_remeshed = monitors.rename_vars({'time_bins':'time'}).rename({'time':'system'}).reset_index('system').assign_coords(system=index).unstack('system')
#     return monitors_remeshed

# def load_diode(loader, run):
#     monitors = loader.loadMonitors(run)
#     energies = monitors['energy_readback']
    
#     monitors = monitors.swap_dims({'time':'energy_readback'}).rename({'energy_readback':'energy'})  #.drop_vars('time_bins')

#     # monitors = monitors.rename_vars({'time_bins':'energy'}).reset_index('energy').assign_coords(energy=energies.data)
#     # monitors = monitors.rename_vars({'time_bins':'time'}).rename({'time':'energy'}).reset_index('energy').assign_coords(energy=energies.data)
#     # monitors = monitors.rename({'time':'energy'}).reset_index('energy').assign_coords(energy=energies.data)
    
#     polarization = float(round(run['baseline']['data']['en_polarization'][0].compute()))
#     monitors = monitors.expand_dims({'polarization': [polarization]})
#     # monitors.attrs['diode_scan_id'] = run.start['scan_id']
#     monitors = monitors.dropna(dim='energy').groupby('energy').mean()
    
#     return monitors

In [None]:
# Some user defined functions for loading metadata
def load_monitors_dask(loader, run, dims=('energy', 'polarization')):
    """Return the monitor streams for ``run`` exactly as ``loader.loadMonitors`` gives them.

    Parameters
    ----------
    loader : object exposing ``loadMonitors(run)`` (a PyHyperScattering loader here).
    run : run object passed straight through to ``loader.loadMonitors``.
    dims : sequence of str, optional
        Metadata dims the monitors would be remeshed onto. Remeshing is currently
        disabled, so this argument is accepted for backward compatibility only.
        The default is a tuple so it cannot be mutated across calls.

    Returns
    -------
    The object returned by ``loader.loadMonitors(run)``, unmodified.
    """
    # Remeshing onto `dims` is disabled. To restore it, load md = loader.loadMd(run),
    # build pd.MultiIndex.from_arrays([md[d].compute() for d in dims], names=list(dims))
    # and apply:
    # monitors.rename({'time':'system'}).reset_index('system').assign_coords(system=index).unstack('system')
    # The MultiIndex used to be built (with an eager .compute() per dim) and then discarded;
    # that dead work is no longer done.
    return loader.loadMonitors(run)

def load_diode_dask(loader, lab_pol, run):
    """Load a diode/monitor scan and re-index it by energy for one lab polarization.

    The monitors' ``time`` dimension is renamed to ``energy`` and labeled with the
    ``energy_readback`` values. A length-1 ``pol`` dimension holding ``lab_pol`` is
    added, together with a ``uid`` coordinate from ``run.start['uid']``. Energy points
    with missing values are dropped, and repeated energies are averaged.

    Parameters
    ----------
    loader : object exposing ``loadMonitors(run)`` (a PyHyperScattering loader here).
    lab_pol : lab-frame polarization value used as the ``pol`` coordinate.
    run : run accepted by ``loader.loadMonitors``; must provide ``run.start['uid']``.

    Returns
    -------
    The processed monitors object, with ``pol`` and ``energy`` dimensions.
    """
    raw = loader.loadMonitors(run)
    readback = raw['energy_readback']

    # Relabel the time axis with the measured (readback) energies
    by_energy = raw.rename({'time': 'energy'})
    by_energy = by_energy.reset_index('energy')
    by_energy = by_energy.assign_coords(energy=readback.data)

    # Tag with this scan's polarization and uid so scans can be concatenated along 'pol'
    tagged = by_energy.expand_dims({'pol': [lab_pol]})
    tagged = tagged.assign_coords({'uid': ('pol', [run.start['uid']])})

    # Drop incomplete energy points, then average duplicates of the same energy
    return tagged.dropna(dim='energy').groupby('energy').mean()


def load_monitors_np(loader, run, dims=('energy', 'polarization')):
    """Return the monitor streams for ``run`` exactly as ``loader.loadMonitors`` gives them.

    Numpy-client counterpart of ``load_monitors_dask``. The only difference there was
    whether the metadata dims were ``.compute()``-d.

    Parameters
    ----------
    loader : object exposing ``loadMonitors(run)`` (a PyHyperScattering loader here).
    run : run object passed straight through to ``loader.loadMonitors``.
    dims : sequence of str, optional
        Metadata dims the monitors would be remeshed onto. Remeshing is currently
        disabled, so this argument is accepted for backward compatibility only.
        The default is a tuple so it cannot be mutated across calls.

    Returns
    -------
    The object returned by ``loader.loadMonitors(run)``, unmodified.
    """
    # Remeshing onto `dims` is disabled. To restore it, load md = loader.loadMd(run),
    # build pd.MultiIndex.from_arrays([md[d] for d in dims], names=list(dims))
    # and apply:
    # monitors.rename({'time':'system'}).reset_index('system').assign_coords(system=index).unstack('system')
    # The MultiIndex used to be built and then discarded; that dead work is no longer done.
    return loader.loadMonitors(run)

def load_diode_np(loader, lab_pol, run):
    """Load a diode/monitor scan and re-index it by energy for one lab polarization.

    Steps: rename ``time`` to ``energy``, label it with the ``energy_readback`` values,
    add a length-1 ``pol`` dimension (value ``lab_pol``) with a ``uid`` coordinate from
    ``run.start['uid']``, drop energy points with missing values, and average repeated
    energies.

    Parameters
    ----------
    loader : object exposing ``loadMonitors(run)`` (a PyHyperScattering loader here).
    lab_pol : lab-frame polarization value used as the ``pol`` coordinate.
    run : run accepted by ``loader.loadMonitors``; must provide ``run.start['uid']``.

    Returns
    -------
    The processed monitors object, with ``pol`` and ``energy`` dimensions.
    """
    monitors = loader.loadMonitors(run)
    readback = monitors['energy_readback']
    return (
        monitors
        .rename({'time': 'energy'})
        .reset_index('energy')
        .assign_coords(energy=readback.data)
        .expand_dims({'pol': [lab_pol]})
        .assign_coords({'uid': ('pol', [run.start['uid']])})
        .dropna(dim='energy')
        .groupby('energy')
        .mean()
    )

## Load from local file

In [None]:
# File-based (non-databroker) loader. NOTE(review): corr_mode='none' presumably disables
# intensity correction — confirm against the PyHyperScattering docs.
local_loader = phs.load.SST1RSoXSLoader(corr_mode='none')

In [None]:
# scan_id = '65802'
# # scan_id = '34427'
# filepath = samplePath.joinpath(scan_id)
# filepath

In [None]:
# List the file names in the scan directory.
# NOTE(review): `filepath` (and `samplePath`) are only defined in the commented-out cell above,
# so this raises NameError unless that cell is restored.
[f.name for f in filepath.iterdir()]

In [None]:
# local_loader = SST1RSoXSLoader(corr_mode='None')
da = local_loader.loadFileSeries(filepath, dims=['energy', 'polarization'])
da

In [None]:
da = da.unstack('system')
# da = da.where(da>1e-3)
da

In [None]:
# cmin = float(da.quantile(0.1))
# cmax = float(da.quantile(0.9))

# da.sel(polarization=0, energy=285, method='nearest').plot.imshow(norm=LogNorm(1e1, 1e4), cmap=cmap, interpolation='nearest')

energies = [270, 280, 282, 283, 284, 285, 286, 290]

fg = da.sel(polarization=90, method='nearest').sel(energy=energies, method='nearest').plot.imshow(figsize=(18, 6),
                col='energy', col_wrap=4, norm=LogNorm(1, 1e4), cmap=cmap, interpolation='nearest')
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
for axes in fg.axs.flatten():
    axes.set(aspect='equal')

plt.show()

## Load raw data from databroker & save zarrs

In [None]:
# # Define catalog(s):
# c = from_profile("rsoxs", structure_clients='dask')
# # c = from_uri('https://tiled.nsls2.bnl.gov/', structure_clients='numpy')['rsoxs']['raw']
# # print(c)

In [None]:
# # Define loader(s):
# db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=40)  # initialize rsoxs databroker loader w/ Dask
# # db_loader = phs.load.SST1RSoXSDB(corr_mode='none', use_chunked_loading=True, dark_pedestal=40)  # initialize rsoxs databroker loader w/ Dask

In [None]:
## Search for and summarize runs:
# Define catalog(s):
# Tiled "rsoxs" profile with dask structure clients, so array reads are lazy (dask-backed).
c = from_profile("rsoxs", structure_clients='dask')
# Databroker loader over `c` with dark_pedestal=0 (the commented line below used 80 for 1180 eV images).
# NOTE(review): corr_mode='None' is capitalized differently from the local loader's 'none' — confirm
# which spelling PyHyperScattering expects.
db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=0)  # initialize rsoxs databroker loader w/ Dask
# db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=80)  # used for 1180 eV images


# DataFrame summarizing the matching 2024-2 runs (later cells filter it on 'plan', 'scan_id', 'num_Images').
# NOTE(review): sample_id='' presumably means "no sample filter" — confirm.
runs_sum_df = db_loader.summarize_run(institution='CUBLDER', cycle='2024-2', sample_id='', project='TRMSN', plan='rsoxs')

# # runs_sum_df = runs_sum_df.set_index('scan_id')  # optional, set index to scan id
# print(runs_sum_df['plan'].unique())
# display(runs_sum_df)

In [None]:
# Show every row of the run summary; the max_rows override only applies inside this `with`.
with pd.option_context('display.max_rows', None):
    display(runs_sum_df)

In [None]:
db_loader.searchCatalog??

In [None]:
## Slice output dataframe for samples of interest
# Keep only runs whose 'plan' equals plan_of_interest. The commented alternatives are
# scan_id / num_Images filters used for other subsets of this beamtime.
# plan_of_interest = 'rsoxs_carbon'
plan_of_interest = 'rsoxs_1180'
# plan_of_interest = 'rsoxs_[350, 305, 292, 287, 282, 270, 250]'

df = runs_sum_df
runs_of_interest = df[(df['plan']==plan_of_interest)]
# runs_of_interest = df[(df['plan']==plan_of_interest) & ((df['scan_id']<80033) | (df['scan_id']>80046))]  # 80033 - 80046 are bad scans for rsoxs_1180 (beam not transmitting through)
# runs_of_interest = df[(df['plan']==plan_of_interest) & ((df['scan_id']<80033) | (df['scan_id']>80092)) & (df['num_Images']==112)]  # normal incidence rsoxs_carbon
# runs_of_interest = df[((df['scan_id']==80332) | (df['scan_id']==80333))]  # repeat normal rsoxs_carbn
# runs_of_interest = df[(df['plan']==plan_of_interest) & ((df['scan_id']<80033) | (df['scan_id']>80092)) & (df['num_Images']==168)]  # tilted incidence rsoxs_carbon

# runs_of_interest = df[(df['plan']==plan_of_interest) & (df['num_Images']==228)]
# runs_of_interest = runs_of_interest.drop(index=31)
# runs_of_interest = df[(df['plan']==plan_of_interest) & (df['num_Images']==80)]

with pd.option_context('display.max_rows', None):
    display(runs_of_interest)

In [None]:
# NOTE(review): these cells read `raw_int_DA_rows`, which is built by the run-loading loop further
# down; they only work after that cell has been executed.
raw_int_DA_rows[0].squeeze()

In [None]:
[DA.squeeze().swap_dims({'polarization':'datetime'}) for DA in raw_int_DA_rows][0]

In [None]:
# Make 'datetime' (a coordinate along 'polarization') the dimension for each run, then stack all runs along it
full_DA = xr.concat([DA.squeeze().swap_dims({'polarization':'datetime'}) for DA in raw_int_DA_rows], dim='datetime')
full_DA

In [None]:
full_DA

In [None]:
def sel_coords(DA, dim, coord, val):
    """Select ``val`` along the non-dimension coordinate ``coord`` that lives on ``dim``.

    ``coord`` is temporarily swapped in as the dimension so it can be indexed with
    ``.sel``, then swapped back so the result is again indexed by ``dim``.

    Parameters
    ----------
    DA : xarray object with dimension ``dim`` and a coordinate ``coord`` along it.
    dim : name of the existing dimension (e.g. 'datetime').
    coord : name of the coordinate to select on (e.g. 'polarization').
    val : label(s) to select.

    Returns
    -------
    The selection, indexed by ``dim``.

    The selection now honors ``coord``; it used to be hard-coded to ``polarization``.
    NOTE(review): the final swap back needs ``coord`` to still be a dimension after
    ``.sel``. This holds for list-like ``val``, or for a scalar matching repeated
    labels — confirm for your data.
    """
    return DA.swap_dims({dim: coord}).sel({coord: val}).swap_dims({coord: dim})
    
# Frames taken at lab polarization 0, still indexed by acquisition 'datetime'
sel_DA = sel_coords(full_DA, 'datetime', 'polarization', 0)

In [None]:
# Plot every polarization-0 detector frame, one figure per acquisition datetime.
# The loop variable is `dt_val`, not `datetime`, so it does not shadow the `datetime`
# module that other cells use (datetime.datetime.fromtimestamp / isinstance checks).
sel_DAs = sel_coords(full_DA, 'datetime', 'polarization', 0)
for dt_val in sel_DAs.datetime.values:
    sel_DA = sel_DAs.sel(datetime=dt_val)
    sliced_DA = sel_DA.squeeze().compute() #.sel(polarization=0).compute()
    # Lower color limit is fixed at 30; only the upper limit (1 - 1e-5 quantile) comes from the data
    cmax = float(sliced_DA.quantile(1 - 1e-5))
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=LogNorm(30, cmax), interpolation='antialiased')
    ax.axes.set(aspect='equal', title=f'{sliced_DA.datetime.values} | {sliced_DA.polarization.values.round()}° lab pol')
    ax.figure.set(dpi=120)
    plt.show()
    plt.close('all')

In [None]:
db_loader.loadRun??

In [None]:
# Inspect acquisition times of the images in `run`.
# NOTE(review): no active cell defines `run` or `md` before this point. `run` is left over from the
# run-loading loop below, and `md` presumably should be `db_loader.loadMd(run)` — confirm.
data = run["primary"]["data"].read()[md["detector"] + "_image"]
# Image timestamps rounded to whole seconds, then rendered as naive local-time strings
timestamps = data.time.values.round()
datetimes = np.array(list(map(lambda x: str(datetime.datetime.fromtimestamp(x)), timestamps)))
datetimes

In [None]:
print(timestamps)
print(datetimes)

In [None]:
# Round-trip check on one example epoch timestamp
dt_value = datetime.datetime.fromtimestamp(1.72120908e+09)
dt_value

In [None]:
dt_value.timestamp()

In [None]:
str(dt_value)

In [186]:
db_loader.loadRun??

Signature:
db_loader.loadRun(
    run,
    dims=None,
    coords={},
    return_dataset=False,
    useMonitorShutterThinning=True,
)
Source:   
    def loadRun(
        self,
        run,
        dims=None,
        coords={},
        return_dataset=False,
[output truncated]

In [None]:
# Load every run of interest, unstack 'system' into energy x polarization, and attach
# per-image acquisition times and sample identifiers as coordinates.
import datetime  # re-bind the module in case another cell reused `datetime` as a variable name

raw_int_DA_rows = []
# samp_au_DA_rows = []
# monitors_rows = []

for scan_id in tqdm(runs_of_interest['scan_id'][:]):
    run = c[scan_id]

    # Per-run metadata; provides the detector name used to locate the image stream.
    # This used to read a global `md` that no active cell in this notebook defines.
    md = db_loader.loadMd(run)

    # Get the timestamps & times to add as coordinates
    data = run["primary"]["data"].read()[md["detector"] + "_image"]
    timestamps = data.time.values.round()
    datetimes = np.array([str(datetime.datetime.fromtimestamp(ts)) for ts in timestamps])

    raw_int_DA = db_loader.loadRun(run, dims=['energy', 'polarization'])

#     # New addition needed for 2023C3 unstacking system into energy & polarization
#     # Convert 'system' MultiIndex to DataFrame
#     index = pd.DataFrame(raw_int_DA['system'].values.tolist(), columns=['energy', 'polarization'])

#     # Add the energy and polarization as new coordinates
#     raw_int_DA = raw_int_DA.assign_coords(energy=('system', index['energy']))
#     raw_int_DA = raw_int_DA.assign_coords(polarization=('system', index['polarization']))

    # Unstack data
    raw_int_DA = raw_int_DA.unstack('system')

    # Back to 2022C2 code
    sample_id = raw_int_DA.start['sample_id']
    sample_name = raw_int_DA.sample_name

    # NOTE(review): timestamps/datetimes have one entry per image but are attached along
    # 'polarization'. This only lines up when each run has a single energy — confirm for
    # multi-energy plans.
    raw_int_DA = raw_int_DA.expand_dims({'scan_id': [raw_int_DA.sampleid]})
    raw_int_DA = raw_int_DA.assign_coords(sample_id=('scan_id', [sample_id]),
                                          sample_name=('scan_id', [sample_name]),
                                          timestamp=('polarization', timestamps),
                                          datetime=('polarization', datetimes))
    raw_int_DA_rows.append(raw_int_DA)

#    # return to proper normalizations later on...
#     monitors = load_monitors_dask(db_loader, run, dims=['energy', 'polarization'])
    
#     monitors = monitors.expand_dims({'scan_id': [raw_int_DA.sampleid]})
#     monitors = monitors.assign_coords(sample_id=('scan_id', [sample_id]),
#                                 sample_name=('scan_id', [sample_name]))
    
#     monitors_rows.append(monitors)
    
#     samp_au_DA = monitors['RSoXS Au Mesh Current']
#     samp_au_DA = samp_au_DA.compute().interpolate_na(dim='energy')
#     samp_au_DA_rows.append(samp_au_DA)
    
#     # samp_au_DA = monitors['RSoXS Au Mesh Current']
#     # samp_au_DA = samp_au_DA.expand_dims({'scan_id': [raw_int_DA.sampleid]})
#     # samp_au_DA = samp_au_DA.assign_coords(sample_id=('scan_id', [sample_id]),
#     #                             sample_name=('scan_id', [sample_name]))
#     # samp_au_DA = samp_au_DA.compute().interpolate_na(dim='energy')
#     # samp_au_DA_rows.append(samp_au_DA)

# DS = xr.concat(raw_int_DA_rows, 'scan_id').to_dataset(name='raw_intensity')
# # DS['sample_au_mesh'] = xr.concat(samp_au_DA_rows, 'scan_id')

# DS.attrs['name'] = plan_of_interest
# DS = DS.swap_dims({'scan_id':'sample_name'})

In [None]:
# sliced_DA = raw_int_DA.squeeze().sel(polarization=0).compute()
# cmin, cmax = sliced_DA.quantile([0.1,1-1e-5])
# ax = sliced_DA.plot.imshow(cmap=cmap, norm=LogNorm(50,cmax), interpolation='antialiased')
# ax.axes.set(aspect='equal')
# ax.figure.set(dpi=120)
# plt.show()

In [None]:
# Stack all runs along 'datetime', which each run takes over from 'polarization' via swap_dims
full_DA = xr.concat([DA.squeeze().swap_dims({'polarization':'datetime'}) for DA in raw_int_DA_rows], dim='datetime')
full_DA

In [None]:
# Re-import the module. A plotting loop earlier in this notebook may have used `datetime` as a
# loop variable, rebinding the name; the attrs cleanup below needs `datetime.datetime`.
import datetime

In [None]:
full_DA.attrs

In [None]:
# checks for non-serializable data types in the attributes of the raw_intensity and makes serializable
# Dask arrays are computed; dicts and datetimes are stringified. Each converted key is printed.
# Import locally so this works even if the notebook-global names `da` (dask.array) or
# `datetime` (the module) were rebound by other cells.
import datetime
import dask.array

for k, v in list(full_DA.attrs.items()):
    if isinstance(v, dask.array.Array):
        full_DA.attrs[k] = v.compute()
        print(f'{k:<20}  |  {type(v)}')
    elif isinstance(v, (dict, datetime.datetime)):
        full_DA.attrs[k] = str(v)
        print(f'{k:<20}  |  {type(v)}')

In [None]:
def sel_coords(DA, dim, coord, val):
    """Select ``val`` along the non-dimension coordinate ``coord`` that lives on ``dim``.

    ``coord`` is temporarily swapped in as the dimension so it can be indexed with
    ``.sel``, then swapped back so the result is again indexed by ``dim``.

    Parameters
    ----------
    DA : xarray object with dimension ``dim`` and a coordinate ``coord`` along it.
    dim : name of the existing dimension (e.g. 'datetime').
    coord : name of the coordinate to select on (e.g. 'polarization').
    val : label(s) to select.

    Returns
    -------
    The selection, indexed by ``dim``.

    The selection now honors ``coord``; it used to be hard-coded to ``polarization``.
    NOTE(review): the final swap back needs ``coord`` to still be a dimension after
    ``.sel``. This holds for list-like ``val``, or for a scalar matching repeated
    labels — confirm for your data.
    """
    return DA.swap_dims({dim: coord}).sel({coord: val}).swap_dims({coord: dim})
    
# Frames taken at lab polarization 0, still indexed by acquisition 'datetime'
sel_DA = sel_coords(full_DA, 'datetime', 'polarization', 0)

In [None]:
# Plot every polarization-0 detector frame, one figure per acquisition datetime.
sel_DAs = sel_coords(full_DA, 'datetime', 'polarization', 0)
for dt_val in sel_DAs.datetime.values:
    sel_DA = sel_DAs.sel(datetime=dt_val)
    sliced_DA = sel_DA.squeeze().compute() #.sel(polarization=0).compute()
    # Lower color limit is fixed at 30; only the upper limit (1 - 1e-5 quantile) comes from the data
    cmax = float(sliced_DA.quantile(1 - 1e-5))
    ax = sliced_DA.plot.imshow(cmap=cmap, norm=LogNorm(30, cmax), interpolation='antialiased')
    ax.axes.set(aspect='equal', title=f'{sliced_DA.datetime.values} | {sliced_DA.polarization.values.round()}° lab pol')
    ax.figure.set(dpi=120)
    plt.show()
    plt.close('all')

In [None]:
zarrsPath

In [None]:
plan_of_interest

In [None]:
# Write the timestamped stack as a single-variable ('raw_intensity') zarr store; mode='w' overwrites
# an existing store. Attribute values must be serializable first (see the attrs cleanup cell above).
full_DA.to_dataset(name='raw_intensity').to_zarr(zarrsPath.joinpath(f'cartesian_raw_timestamped_{plan_of_interest}_v1.zarr'), mode='w')

In [None]:
# NOTE(review): `DS` is only created in commented-out code at the end of the run-loading cell;
# restore that code before running these cells.
DS = DS.sortby('sample_name')
DS

In [None]:
# Beam center (presumably in detector pixels) refined for the 2024C1 cycle.
# NOTE(review): reused here for 2024C2 data — confirm it still applies.
bcxy_2024C1 = {'waxs_bcx': 456.25, 'waxs_bcy': 506.19}  # confident for 2024C1, by refining around Y6BO p5CN-CF diffraction peaks

DS['raw_intensity'].attrs['beamcenter_x'] = bcxy_2024C1['waxs_bcx']
DS['raw_intensity'].attrs['beamcenter_y'] = bcxy_2024C1['waxs_bcy']

# apply_q_labels presumably comes from the andrew_rsoxs_fxns star import and adds the qx/qy
# coordinates used by the plotting cells — TODO confirm
DS['raw_intensity'] = apply_q_labels(DS['raw_intensity'])

# DS = DS.chunk({'sample_name':1, 'energy':56, 'polarization':2, 'pix_x':1026, 'pix_y':1024})
DS

In [None]:
# # Load carbon diode dataset via tiled databroker:
# carbon_diode_scan_pols = [      20.0,       55.0,       90.0,      52.38,        0.0,      45.56]
# carbon_diode_uids =      ['cbb1dae5', '00accfe3', 'af3255b3', '25042aca', '7e026642', '153238eb'] 

# diode_monitors_list = []
# for lab_pol, scan_uid in zip(carbon_diode_scan_pols, carbon_diode_uids):
#     run = c[scan_uid]
#     diode_monitors = load_diode_np(db_loader, lab_pol, run)
#     diode_monitors_list.append(diode_monitors)

# energies = DS.energy.values  # carbon

# # interp_diode_monitors_list = [diode_DS.interp({'energy':energies}) for diode_DS in diode_monitors_list] 
# interp_diode_monitors_list = []
# for diode_DS in tqdm(diode_monitors_list):
#     diode_DS = diode_DS.interp({'energy':energies})
#     interp_diode_monitors_list.append(diode_DS)
    
# carbon_diode_DS = xr.concat(interp_diode_monitors_list, dim='pol')

In [None]:
# carbon_diode_DS.sel(pol=[0, 45.56,90]).rename({'pol':'polarization'}).assign_coords({'polarization': ('polarization', [0, 45, 90])})

In [None]:
# For tilted incidence / 3 polarizations
# Take the diode scans at lab pols 0, 45.56 and 90, relabel them 0/45/90 to match DS,
# then attach the calibration Au-mesh and WAXS-beamstop currents to DS.
calib_DS = (carbon_diode_DS
            .sel(pol=[0, 45.56, 90])
            .rename({'pol': 'polarization'})
            .assign_coords({'polarization': ('polarization', [0, 45, 90])}))
DS['calib_au_mesh'] = calib_DS['RSoXS Au Mesh Current']
DS['calib_diode'] = calib_DS['WAXS Beamstop']
DS

In [None]:
# For 2 polarizations
# Diode scans at lab pols 0 and 90 already carry the labels DS uses; just rename the dim.
calib_DS = carbon_diode_DS.sel(pol=[0, 90]).rename({'pol': 'polarization'})
DS['calib_au_mesh'] = calib_DS['RSoXS Au Mesh Current']
DS['calib_diode'] = calib_DS['WAXS Beamstop']
DS

In [None]:
# # for rsoxs_1180

# for sample_name in DS.sample_name.values:
#     DA = DS['raw_intensity'].sel(sample_name=sample_name, polarization=90).squeeze()
#     DA = DA.where(DA>0)
#     # cmin = DA.compute().min()
#     cmin = DA.compute().quantile(0.0001)
#     cmax = DA.compute().quantile(0.995)
#     ax = DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap, x='qx', y='qy')
#     # ax = DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
#     ax.axes.set(title=f'{str(DA.sample_name.values)}, Energy = {int(DA.energy.values)} eV')
#     plt.show()
#     plt.close('all')

In [None]:
# Select Dataset
edge = 'carbon'  # only used in the commented-out savefig filename below
# bcx = DS['raw_intensity'].beamcenter_x
# bcy = DS['raw_intensity'].beamcenter_y


# Select Plotting Parameters
# (polarization comes from the loop below; a standalone `pol = 90` here was always overwritten)
energy = 285
# energy=400
# pix_size = 500
# pix_x_slice = slice(bcx-(pix_size/2), bcx+(pix_size/2))
# pix_y_slice = slice(bcy-(pix_size/2), bcy+(pix_size/2))

# qx/qy window (Å^-1) whose positive pixels set the log color limits
q_window = slice(0.009, 0.08)

# Select DataArray
# sample_name = 'PM6-Y6_3000_dSiN'
for pol in [0, 90]:
    # for DS in tqdm(DS_sample_rows, desc=f'Pol = {pol}°'):
    for sample_name in tqdm(DS.sample_name.values, desc=f'Pol = {pol}°'):
        intensity_type = 'raw'
        DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity'].squeeze()

        # Plot on q axes: swap the pixel dims for their qx/qy coordinates
        sliced_DA = DA.sel(polarization=pol).sel(energy=energy, method='nearest').swap_dims({'pix_x':'qx', 'pix_y':'qy'})
        # Select, mask and compute the color-limit window once, then reuse it for both quantiles
        clim_DA = sliced_DA.sel(qx=q_window, qy=q_window)
        clim_DA = clim_DA.where(clim_DA > 0).compute()
        cmin = float(clim_DA.quantile(0.01))
        cmax = float(clim_DA.quantile(0.995))

        ax = sliced_DA.plot.imshow(figsize=(5.5,4.5), cmap=cmap, norm=LogNorm(cmin,cmax))
        ax.figure.suptitle(f'Photon Energy = {np.round(energy, 1)} eV', fontsize=14, y=0.96)
        ax.figure.set_tight_layout(True)
        ax.axes.set(aspect='equal', title=f'{sample_name}, Polarization = {pol}°', xlabel='q$_x$ [$Å^{-1}$]', ylabel='q$_y$ [$Å^{-1}$]')
        ax.colorbar.set_label('Raw Intensity [arb. units]', rotation=270, labelpad=12)
        # ax.figure.savefig(plotsPath.joinpath('detector_movies_carbon_v2', f'{sample_name}_{edge}_{intensity_type}_pol{pol}deg.jpeg'), dpi=120)
        plt.show()
        plt.close('all')

In [None]:
# Plot each sample's Au-mesh current, one line per polarization.
# NOTE(review): 'sample_au_mesh' is only added to DS in commented-out code in the run-loading cell.
for sample_name in DS.sample_name.values:
    DS['sample_au_mesh'].sel(sample_name=sample_name).plot(hue='polarization')
    # (DS['calib_au_mesh']/(DS['sample_au_mesh'].sel(sample_name=sample_name))).plot(hue='polarization')
    plt.show()
    
# # DS['calib_au_mesh'].plot(hue='polarization')
# # plt.show

In [None]:
# For each monitor dataset, build one beamstop curve that at every energy takes the value
# from whichever polarization (0 or 90) has the flatter slope there. Despite the name,
# this is a pointwise selection, not an average.
# NOTE(review): `monitors_rows` is only populated in commented-out code in the run-loading cell.
averaged_beamstop_rows = []

for ds in tqdm(monitors_rows):
    beamstop = ds['WAXS Beamstop']
    pol0 = beamstop.sel(polarization=0)
    pol90 = beamstop.sel(polarization=90)

    # Magnitude of d(signal)/d(energy) for each polarization, computed eagerly
    slope0 = abs(pol0.compute().differentiate('energy'))
    slope90 = abs(pol90.compute().differentiate('energy'))

    # True where pol 0 has the smaller slope magnitude and its slope is defined
    use_pol0 = (slope0 < slope90) & ~np.isnan(slope0)

    averaged_beamstop = xr.where(use_pol0, pol0, pol90).rename('averaged_beamstop')
    averaged_beamstop_rows.append(averaged_beamstop)

In [None]:
# One figure per beamstop curve built in the previous cell
for averaged_beamstop in averaged_beamstop_rows:
    averaged_beamstop.plot()
    plt.show()
    plt.close('all')

In [None]:
window_size = 3  # This is the window size for the smoothing - you'll need to adjust it for your data

smoothed_beamstop_rows = []

for averaged_beamstop in tqdm(averaged_beamstop_rows):
    # Centered rolling mean along energy. With min_periods=1 the edges are averaged over
    # however many points fall inside the window instead of becoming NaN, at the cost of
    # slightly weaker smoothing at the ends. (A first rolling mean without min_periods was
    # computed here and immediately overwritten; it has been removed.)
    smoothed_beamstop = averaged_beamstop.rolling(energy=window_size, center=True, min_periods=1).mean()
    smoothed_beamstop_rows.append(smoothed_beamstop)

In [None]:
# One figure per smoothed beamstop curve from the previous cell. The curves live in
# smoothed_beamstop_rows; no cell ever adds a 'smoothed_beamstop' variable to DS.
for smoothed_beamstop in smoothed_beamstop_rows:
    smoothed_beamstop.plot()
    plt.show()
    plt.close('all')

In [None]:
# Corrected intensity = raw / sample Au-mesh current, scaled by the calibration Au-mesh / diode ratio
# (presumably an I0 normalization plus flux calibration — confirm).
# NOTE(review): 'sample_au_mesh' is not assigned in any active cell; restore it first.
DS['corr_intensity'] = ((DS['raw_intensity'] / DS['sample_au_mesh'])
                        * (DS['calib_au_mesh'] / DS['calib_diode']))

DS

In [None]:
# checks for non-serializable data types in the attributes of the raw_intensity and makes serializable
# Dask arrays are computed; dicts and datetimes are stringified. Each converted key is printed.
# Import locally so this works even if the notebook-global names `da` (dask.array) or
# `datetime` (the module) were rebound by other cells.
import datetime
import dask.array

for k, v in list(DS['raw_intensity'].attrs.items()):
    if isinstance(v, dask.array.Array):
        DS['raw_intensity'].attrs[k] = v.compute()
        print(f'{k:<20}  |  {type(v)}')
    elif isinstance(v, (dict, datetime.datetime)):
        DS['raw_intensity'].attrs[k] = str(v)
        print(f'{k:<20}  |  {type(v)}')

In [None]:
# # NetCDFs

# # cartesian_sample_DS = cartesian_DS_sample_rows[0]
# # for cartesian_sample_DS in tqdm(cartesian_DS_sample_rows):
# sample_names = DS.sample_name.values

# for sample_name in tqdm(sample_names):
#     cartesian_sample_DS = DS.sel(sample_name=[sample_name])
#     cartesian_sample_DS.to_netcdf(zarrsPath.joinpath('cartesian_rsoxs_carbon_ncs', f'{sample_name}.nc'), format='netCDF4', engine='h5netcdf')

In [None]:
# netcdf_paths = str(zarrsPath.joinpath('cartesian_rsoxs_carbon_ncs')) + '/*.nc'
# netcdf_paths

In [None]:
# DS = xr.open_mfdataset(netcdf_paths)
# DS

In [None]:
DS

In [None]:
zarrsPath.exists()

In [None]:
zarrsPath

In [None]:
# encoding = {var: {'chunks': DS[var].shape} for var in DS.variables}

In [None]:
# NOTE(review): `encoding` is only defined in the commented-out line above; this raises NameError as-is.
encoding

In [None]:
# DS.to_zarr(zarrsPath.joinpath('cartesian_rsoxs_carbon_v1.zarr'), mode='w')  # too big for carbon?

In [None]:
zarrsPath

In [None]:
# Write DS into one zarr store, one sample at a time (keeps each write small): the first
# sample creates/overwrites the store, and later samples are appended along 'sample_name'.
# An empty sample list now writes nothing instead of raising IndexError.
plan_of_interest = 'rsoxs_carbon'
sample_names = DS.sample_name.values
zarr_path = zarrsPath.joinpath(f'cartesian_raw_{plan_of_interest}_v2.zarr')

for i, sample_name in enumerate(tqdm(sample_names, desc='Samples...')):
    sample_DS = DS.sel(sample_name=[sample_name])
    if i == 0:
        sample_DS.to_zarr(zarr_path, mode='w')
    else:
        sample_DS.to_zarr(zarr_path, mode='a', append_dim='sample_name')

# DS.to_zarr(zarrsPath.joinpath(f'cartesian_{plan_of_interest}.zarr'), mode='w')

In [None]:
# with ProgressBar():
#     DS.to_zarr(zarrsPath.joinpath(f'cartesian_{plan_of_interest}.zarr'))