# 2024C1: Processing SST1 RSoXS Data

In [None]:
# # # Only needs to be run once per session, restart kernel after running

# # %pip install pyhyperscattering==0.2.1  # to use pip published package
# !pip install -e /nsls2/users/alevin/repos/PyHyperScattering  # to use pip to install via directory
# # !pip install --pre --upgrade tiled[all] databroker  # bottleneck # needed to fix tiled/databroker error in SST1RSoXSDB

## Imports

In [None]:
# Imports
import PyHyperScattering as phs
import pathlib
import sys
import ast
import json
import datetime
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from tqdm.auto import tqdm
import dask.array as da
from tiled.client import from_profile, from_uri

sys.path.append('/nsls2/users/alevin/local_lib')
from andrew_rsoxs_fxns import *

print(f'Using PyHyperScattering Version: {phs.__version__}')

# Set colormap
cmap = plt.cm.turbo.copy()
cmap.set_bad(cmap.get_under())

## Define paths & short functions

In [None]:
# Directory layout: user home, the proposal directory for this cycle, and the
# processed-data subfolders that the rest of the notebook reads from / writes to.
userPath = pathlib.Path('/nsls2/users/alevin')

# Proposal directories used in earlier cycles (kept for reference):
# propPath = pathlib.Path('/nsls2/data/sst/proposals/2022-2/pass-309180')
# propPath = pathlib.Path('/nsls2/data/sst/proposals/2023-2/pass-311130')
# propPath = pathlib.Path('/nsls2/data/sst/proposals/2023-3/pass-313412')
propPath = pathlib.Path('/nsls2/data/sst/proposals/2024-1/pass-313412')

outPath = propPath / 'processed_data'
# jsonPath = outPath / 'local_config'
maskPath = outPath / 'masks'
zarrsPath = outPath / 'rsoxs_zarrs'

In [None]:
# # Some user defined functions for loading metadata
# def load_monitors(loader, run, dims=['energy', 'polarization']):
#     md = loader.loadMd(run)
#     monitors = loader.loadMonitors(run)
#     dims_to_join = []
#     dim_names_to_join = []
#     for dim in dims:
#         dims_to_join.append(md[dim].compute())
#         dim_names_to_join.append(dim)  
#     index = pd.MultiIndex.from_arrays(dims_to_join, names=dim_names_to_join)
#     monitors_remeshed = monitors.rename({'time':'system'}).reset_index('system').assign_coords(system=index).unstack('system')
#     # monitors_remeshed = monitors.rename_vars({'time_bins':'time'}).rename({'time':'system'}).reset_index('system').assign_coords(system=index).unstack('system')
#     return monitors_remeshed

# def load_diode(loader, run):
#     monitors = loader.loadMonitors(run)
#     energies = monitors['energy_readback']
    
#     monitors = monitors.swap_dims({'time':'energy_readback'}).rename({'energy_readback':'energy'})  #.drop_vars('time_bins')

#     # monitors = monitors.rename_vars({'time_bins':'energy'}).reset_index('energy').assign_coords(energy=energies.data)
#     # monitors = monitors.rename_vars({'time_bins':'time'}).rename({'time':'energy'}).reset_index('energy').assign_coords(energy=energies.data)
#     # monitors = monitors.rename({'time':'energy'}).reset_index('energy').assign_coords(energy=energies.data)
    
#     polarization = float(round(run['baseline']['data']['en_polarization'][0].compute()))
#     monitors = monitors.expand_dims({'polarization': [polarization]})
#     # monitors.attrs['diode_scan_id'] = run.start['scan_id']
#     monitors = monitors.dropna(dim='energy').groupby('energy').mean()
    
#     return monitors

In [None]:
# Some user defined functions for loading metadata
def load_monitors_dask(loader, run, dims=['energy', 'polarization']):
    """Load a run's monitors and re-index them from 'time' onto the scan dimensions.

    Parameters
    ----------
    loader : object providing ``loadMd(run)`` and ``loadMonitors(run)``.
    run : run handle passed straight through to the loader.
    dims : names of the metadata entries whose values become the levels of
        the new index. Each ``md[dim]`` must expose ``.compute()``
        (presumably dask-backed -- confirm against the loader).

    Returns
    -------
    The monitors object with its 'time' dimension renamed to 'system',
    given a MultiIndex built from ``dims``, and unstacked along 'system'.
    """
    md = loader.loadMd(run)
    monitors = loader.loadMonitors(run)

    # One materialized array per requested dimension, in the order given.
    level_values = [md[dim].compute() for dim in dims]
    system_index = pd.MultiIndex.from_arrays(level_values, names=list(dims))

    # 'time_bins' is renamed to 'time' first so that the 'time' -> 'system'
    # rename below does not leave a variable called 'time_bins' behind.
    on_system = monitors.rename_vars({'time_bins':'time'}).rename({'time':'system'})
    on_system = on_system.reset_index('system').assign_coords(system=system_index)
    return on_system.unstack('system')

def load_diode_dask(loader, lab_pol, run):
    """Load a diode scan's monitors, indexed by energy and tagged with one polarization.

    Parameters
    ----------
    loader : object providing ``loadMonitors(run)``.
    lab_pol : value used as the single entry of the new 'pol' dimension.
    run : run handle; ``run.start['uid']`` is stored as a 'uid' coordinate on 'pol'.

    Returns
    -------
    The monitors object re-indexed on 'energy' (values taken from the
    'energy_readback' monitor), with NaN energies dropped and entries that
    share an energy averaged together.
    """
    monitors = loader.loadMonitors(run)
    energy_readback = monitors['energy_readback']

    # Swap the 'time' axis for an 'energy' axis holding the readback values.
    by_energy = monitors.rename_vars({'time_bins':'time'}).rename({'time':'energy'})
    by_energy = by_energy.reset_index('energy')
    by_energy = by_energy.assign_coords(energy=energy_readback.data)

    # Tag the scan with its polarization and originating run uid.
    by_energy = by_energy.expand_dims({'pol': [lab_pol]})
    by_energy = by_energy.assign_coords({'uid': ('pol', [run.start['uid']])})

    # Drop NaNs along 'energy', then average entries sharing the same energy value.
    return by_energy.dropna(dim='energy').groupby('energy').mean()


def load_monitors_np(loader, run, dims=['energy', 'polarization']):
    """Load a run's monitors and re-index them from 'time' onto the scan dimensions.

    Same reshaping as ``load_monitors_dask``, except the metadata entries
    ``md[dim]`` are used as-is (no ``.compute()`` call).

    Parameters
    ----------
    loader : object providing ``loadMd(run)`` and ``loadMonitors(run)``.
    run : run handle passed straight through to the loader.
    dims : names of the metadata entries whose values become the levels of
        the new index.

    Returns
    -------
    The monitors object with its 'time' dimension renamed to 'system',
    given a MultiIndex built from ``dims``, and unstacked along 'system'.
    """
    md = loader.loadMd(run)
    monitors = loader.loadMonitors(run)

    # One array per requested dimension, in the order given.
    level_values = [md[dim] for dim in dims]
    system_index = pd.MultiIndex.from_arrays(level_values, names=list(dims))

    on_system = monitors.rename_vars({'time_bins':'time'}).rename({'time':'system'})
    on_system = on_system.reset_index('system').assign_coords(system=system_index)
    return on_system.unstack('system')

def load_diode_np(loader, lab_pol, run):
    """Load a diode scan's monitors, indexed by energy and tagged with one polarization.

    Performs the same sequence of operations as ``load_diode_dask``.

    Parameters
    ----------
    loader : object providing ``loadMonitors(run)``.
    lab_pol : value used as the single entry of the new 'pol' dimension.
    run : run handle; ``run.start['uid']`` is stored as a 'uid' coordinate on 'pol'.

    Returns
    -------
    The monitors object re-indexed on 'energy' (values taken from the
    'energy_readback' monitor), with NaN energies dropped and entries that
    share an energy averaged together.
    """
    monitors = loader.loadMonitors(run)
    energy_readback = monitors['energy_readback']

    # Swap the 'time' axis for an 'energy' axis holding the readback values.
    by_energy = monitors.rename_vars({'time_bins':'time'}).rename({'time':'energy'})
    by_energy = by_energy.reset_index('energy')
    by_energy = by_energy.assign_coords(energy=energy_readback.data)

    # Tag the scan with its polarization and originating run uid.
    by_energy = by_energy.expand_dims({'pol': [lab_pol]})
    by_energy = by_energy.assign_coords({'uid': ('pol', [run.start['uid']])})

    # Drop NaNs along 'energy', then average entries sharing the same energy value.
    return by_energy.dropna(dim='energy').groupby('energy').mean()

## Load from local file

In [None]:
local_loader = phs.load.SST1RSoXSLoader(corr_mode='none')

In [None]:
# scan_id = '65802'
# # scan_id = '34427'
# filepath = samplePath.joinpath(scan_id)
# filepath

In [None]:
# List the names of the entries inside the scan directory.
# NOTE(review): `filepath` is only assigned in the commented-out cell above
# (which itself needs `samplePath`); on a fresh kernel this raises NameError
# until that cell is restored and `samplePath` is defined.
[f.name for f in filepath.iterdir()]

In [None]:
# Load every image of the scan directory into one array, indexed by energy and polarization.
# NOTE(review): assigning to `da` rebinds the `dask.array` alias created by
# `import dask.array as da` in the imports cell; later cells that use
# `da.core.Array` / `da.Array` will fail after this cell has run.
# Renaming this variable (here and in the cells below that read it) would avoid that.
# local_loader = SST1RSoXSLoader(corr_mode='None')
da = local_loader.loadFileSeries(filepath, dims=['energy', 'polarization'])
da

In [None]:
da = da.unstack('system')
# da = da.where(da>1e-3)
da

In [None]:
# Facet plot of the locally loaded images: one panel per requested photon
# energy (nearest available), at the polarization nearest 90.

# cmin = float(da.quantile(0.1))
# cmax = float(da.quantile(0.9))

# da.sel(polarization=0, energy=285, method='nearest').plot.imshow(norm=LogNorm(1e1, 1e4), cmap=cmap, interpolation='nearest')

energies = [270, 280, 282, 283, 284, 285, 286, 290]

pol90_images = da.sel(polarization=90, method='nearest').sel(energy=energies, method='nearest')
fg = pol90_images.plot.imshow(
    figsize=(18, 6),
    col='energy',
    col_wrap=4,
    norm=LogNorm(1, 1e4),
    cmap=cmap,
    interpolation='nearest',
)
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)

# Square pixels in every panel.
for axes in fg.axs.flat:
    axes.set(aspect='equal')

plt.show()

## Load raw data from databroker & save zarrs

In [None]:
# # Define catalog(s):
# c = from_profile("rsoxs", structure_clients='dask')
# # c = from_uri('https://tiled.nsls2.bnl.gov/', structure_clients='numpy')['rsoxs']['raw']
# # print(c)

In [None]:
# # Define loader(s):
# db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=40)  # initialize rsoxs databroker loader w/ Dask
# # db_loader = phs.load.SST1RSoXSDB(corr_mode='none', use_chunked_loading=True, dark_pedestal=40)  # initialize rsoxs databroker loader w/ Dask

In [None]:
## Search for and summarize runs:
# Connect to the "rsoxs" tiled profile with dask structure clients and build
# the databroker loader on top of that catalog.
# Define catalog(s):
c = from_profile("rsoxs", structure_clients='dask')
# NOTE(review): corr_mode is spelled 'None' here but 'none' for the local
# SST1RSoXSLoader earlier in the notebook -- confirm both spellings are accepted.
db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=0)  # initialize rsoxs databroker loader w/ Dask
# db_loader = phs.load.SST1RSoXSDB(corr_mode='None', catalog=c, catalog_kwargs={}, dark_pedestal=80)  # used for 1180 eV images


# Summary table of runs matching these search fields; later cells filter it
# by its 'scan_id', 'plan' and 'num_Images' columns.
runs_sum_df = db_loader.summarize_run(institution='CUBLDER', cycle='2024-1', sample_id='', project='TRMSN', plan='rsoxs')

# # runs_sum_df = runs_sum_df.set_index('scan_id')  # optional, set index to scan id
# print(runs_sum_df['plan'].unique())
# display(runs_sum_df)

In [None]:
with pd.option_context('display.max_rows', None):
    display(runs_sum_df)

In [None]:
## Slice output dataframe for samples of interest
plan_of_interest = 'rsoxs_carbon'

df = runs_sum_df

# Alternative selections used for other scan sets (kept for reference):
# runs_of_interest = df[(df['plan']==plan_of_interest) & ((df['scan_id']<80033) | (df['scan_id']>80046))]  # 80033 - 80046 are bad scans for rsoxs_1180 (beam not transmitting through)
# runs_of_interest = df[(df['plan']==plan_of_interest) & ((df['scan_id']<80033) | (df['scan_id']>80092)) & (df['num_Images']==112)]  # normal incidence rsoxs_carbon
# runs_of_interest = df[(df['plan']==plan_of_interest) & ((df['scan_id']<80033) | (df['scan_id']>80092)) & (df['num_Images']==168)]  # tilted incidence rsoxs_carbon
# runs_of_interest = df[(df['plan']==plan_of_interest) & (df['num_Images']==114)]
# runs_of_interest = df[(df['plan']==plan_of_interest) & (df['num_Images']==80)]

# Active selection: the two repeat normal-incidence rsoxs_carbon scans.
repeat_scan_ids = [80332, 80333]
runs_of_interest = df[df['scan_id'].isin(repeat_scan_ids)]

with pd.option_context('display.max_rows', None):
    display(runs_of_interest)

In [None]:
# Load every selected run from the databroker and assemble one Dataset:
#   - raw_int_DA_rows : raw detector intensity per run
#   - samp_au_DA_rows : 'RSoXS Au Mesh Current' monitor per run
#   - monitors_rows   : the full monitors object per run (read again by the beamstop cells below)
raw_int_DA_rows = []
samp_au_DA_rows = []
monitors_rows = []

for scan_id in tqdm(runs_of_interest['scan_id'][:]):
    run = c[scan_id]
    raw_int_DA = db_loader.loadRun(run, dims=['energy', 'polarization'])

#     # New addition needed for 2023C3 unstacking system into energy & polarization
#     # Convert 'system' MultiIndex to DataFrame
#     index = pd.DataFrame(raw_int_DA['system'].values.tolist(), columns=['energy', 'polarization'])

#     # Add the energy and polarization as new coordinates
#     raw_int_DA = raw_int_DA.assign_coords(energy=('system', index['energy']))
#     raw_int_DA = raw_int_DA.assign_coords(polarization=('system', index['polarization']))

    # Unstack data: split the stacked 'system' dimension back into its levels
    raw_int_DA = raw_int_DA.unstack('system')

    # Back to 2022C2 code
    # Both values are read from the loaded array (attribute-style access).
    sample_id = raw_int_DA.start['sample_id']
    sample_name = raw_int_DA.sample_name

    # Add a length-1 'scan_id' dimension so the runs can be concatenated later.
    # NOTE(review): the 'scan_id' coordinate value comes from `raw_int_DA.sampleid`,
    # not from the loop variable `scan_id` -- presumably that attribute holds the scan id; confirm.
    raw_int_DA = raw_int_DA.expand_dims({'scan_id': [raw_int_DA.sampleid]})
    raw_int_DA = raw_int_DA.assign_coords(sample_id=('scan_id', [sample_id]),
                          sample_name=('scan_id', [sample_name]))
    raw_int_DA_rows.append(raw_int_DA)

    # Monitors for the same run, re-indexed onto energy/polarization and
    # labelled with the same scan_id / sample coordinates as the images.
    monitors = load_monitors_dask(db_loader, run, dims=['energy', 'polarization'])

    monitors = monitors.expand_dims({'scan_id': [raw_int_DA.sampleid]})
    monitors = monitors.assign_coords(sample_id=('scan_id', [sample_id]),
                                sample_name=('scan_id', [sample_name]))

    monitors_rows.append(monitors)

    # Au mesh current for this run; NaN gaps are filled by interpolating along energy.
    samp_au_DA = monitors['RSoXS Au Mesh Current']
    samp_au_DA = samp_au_DA.compute().interpolate_na(dim='energy')
    samp_au_DA_rows.append(samp_au_DA)

    # samp_au_DA = monitors['RSoXS Au Mesh Current']
    # samp_au_DA = samp_au_DA.expand_dims({'scan_id': [raw_int_DA.sampleid]})
    # samp_au_DA = samp_au_DA.assign_coords(sample_id=('scan_id', [sample_id]),
    #                             sample_name=('scan_id', [sample_name]))
    # samp_au_DA = samp_au_DA.compute().interpolate_na(dim='energy')
    # samp_au_DA_rows.append(samp_au_DA)

# Concatenate the per-run pieces along 'scan_id' into one Dataset.
DS = xr.concat(raw_int_DA_rows, 'scan_id').to_dataset(name='raw_intensity')
DS['sample_au_mesh'] = xr.concat(samp_au_DA_rows, 'scan_id')

# Record the plan name and make 'sample_name' the indexing dimension instead of 'scan_id'.
DS.attrs['name'] = plan_of_interest
DS = DS.swap_dims({'scan_id':'sample_name'})

In [None]:
DS = DS.sortby('sample_name')
DS

In [None]:
# Beamcenter (detector pixel coordinates) used for every image in this cycle.
bcxy_2024C1 = {'waxs_bcx': 456.25, 'waxs_bcy': 506.19}  # confident for 2024C1, by refining around Y6BO p5CN-CF diffraction peaks

# Store the beamcenter on the raw-intensity attrs, then let apply_q_labels
# (from andrew_rsoxs_fxns) return the labelled array that replaces it.
DS['raw_intensity'].attrs.update(
    beamcenter_x=bcxy_2024C1['waxs_bcx'],
    beamcenter_y=bcxy_2024C1['waxs_bcy'],
)
DS['raw_intensity'] = apply_q_labels(DS['raw_intensity'])

# DS = DS.chunk({'sample_name':1, 'energy':56, 'polarization':2, 'pix_x':1026, 'pix_y':1024})
DS

In [None]:
# Load carbon diode dataset via tiled databroker:
# each lab polarization below is paired (by position) with the uid of its diode scan.
carbon_diode_scan_pols = [      20.0,       55.0,       90.0,      52.38,        0.0,      45.56]
carbon_diode_uids =      ['cbb1dae5', '00accfe3', 'af3255b3', '25042aca', '7e026642', '153238eb']

diode_monitors_list = [
    load_diode_np(db_loader, lab_pol, c[scan_uid])
    for lab_pol, scan_uid in zip(carbon_diode_scan_pols, carbon_diode_uids)
]

# Energy grid of the sample dataset; every diode scan is interpolated onto it.
energies = DS.energy.values  # carbon

interp_diode_monitors_list = [
    diode_DS.interp({'energy':energies}) for diode_DS in tqdm(diode_monitors_list)
]

# Stack the interpolated scans along the polarization dimension.
carbon_diode_DS = xr.concat(interp_diode_monitors_list, dim='pol')

In [None]:
carbon_diode_DS

In [None]:
carbon_diode_DS.sel(pol=[0, 45.56,90]).rename({'pol':'polarization'}).assign_coords({'polarization': ('polarization', [0, 45, 90])})

In [None]:
# For tilted incidence / 3 polarizations
# Pick the diode scans at pol 0, 45.56 and 90, rename the dimension to
# 'polarization', and relabel it 0 / 45 / 90 so it lines up with the sample data.
tilted_calib_DS = (
    carbon_diode_DS
    .sel(pol=[0, 45.56,90])
    .rename({'pol':'polarization'})
    .assign_coords({'polarization': ('polarization', [0, 45, 90])})
)
DS['calib_au_mesh'] = tilted_calib_DS['RSoXS Au Mesh Current']
DS['calib_diode'] = tilted_calib_DS['WAXS Beamstop']
DS

In [None]:
# # For 2 polarizations
# DS['calib_au_mesh'] = carbon_diode_DS.sel(pol=[0,90]).rename({'pol':'polarization'})['RSoXS Au Mesh Current']
# DS['calib_diode'] = carbon_diode_DS.sel(pol=[0,90]).rename({'pol':'polarization'})['WAXS Beamstop']
# DS

In [None]:
# # for rsoxs_1180

# for sample_name in DS.sample_name.values:
#     DA = DS['raw_intensity'].sel(sample_name=sample_name, polarization=90).squeeze()
#     DA = DA.where(DA>0)
#     # cmin = DA.compute().min()
#     cmin = DA.compute().quantile(0.0001)
#     cmax = DA.compute().quantile(0.995)
#     ax = DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap, x='qx', y='qy')
#     # ax = DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
#     ax.axes.set(title=f'{str(DA.sample_name.values)}, Energy = {int(DA.energy.values)} eV')
#     plt.show()
#     plt.close('all')

In [None]:
# Plot one detector image per (polarization, sample) at a single photon energy.

# Select Dataset
edge = 'carbon'
# bcx = DS['raw_intensity'].beamcenter_x
# bcy = DS['raw_intensity'].beamcenter_y


# Select Plotting Parameters
energy = 285
# energy=400
intensity_type = 'raw'  # which '<type>_intensity' variable of DS to plot
q_roi = slice(0.009, 0.08)  # qx/qy window used only to choose the colour limits
# pix_size = 500
# pix_x_slice = slice(bcx-(pix_size/2), bcx+(pix_size/2))
# pix_y_slice = slice(bcy-(pix_size/2), bcy+(pix_size/2))

# Colorbar text must describe what is actually plotted; previously the
# 'Double-Norm-Corrected' label was used even for raw intensity.
cbar_labels = {
    'raw': 'Raw Intensity [arb. units]',
    'corr': 'Double-Norm-Corrected Intensity [arb. units]',
}
cbar_label = cbar_labels.get(intensity_type, f'{intensity_type} Intensity [arb. units]')

# Select DataArray
# sample_name = 'PM6-Y6_3000_dSiN'
for pol in [0, 45, 90]:
    # for DS in tqdm(DS_sample_rows, desc=f'Pol = {pol}°'):
    for sample_name in tqdm(DS.sample_name.values, desc=f'Pol = {pol}°'):
        DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity'].squeeze()

        # Single image at this polarization / nearest energy, indexed by qx/qy
        sliced_DA = DA.sel(polarization=pol).sel(energy=energy,method='nearest').swap_dims({'pix_x':'qx', 'pix_y':'qy'})

        # Colour limits from the positive pixels inside the q window.
        # The window is selected and computed once and reused for both quantiles.
        roi = sliced_DA.sel(qx=q_roi, qy=q_roi)
        positive_roi = roi.where(roi>0).compute()
        cmin = float(positive_roi.quantile(0.05))
        cmax = float(positive_roi.quantile(0.995))

        # Plot
        ax = sliced_DA.plot.imshow(figsize=(5.5,4.5), cmap=cmap, norm=LogNorm(cmin,cmax))
        ax.figure.suptitle(f'Photon Energy = {np.round(energy, 1)} eV', fontsize=14, y=0.96)
        ax.figure.set_tight_layout(True)
        ax.axes.set(aspect='equal', title=f'{sample_name}, Polarization = {pol}°', xlabel='q$_x$ [$Å^{-1}$]', ylabel='q$_y$ [$Å^{-1}$]')
        ax.colorbar.set_label(cbar_label, rotation=270, labelpad=12)
        # ax.figure.savefig(plotsPath.joinpath('detector_movies_carbon_v2', f'{sample_name}_{edge}_{intensity_type}_pol{pol}deg.jpeg'), dpi=120)
        plt.show()
        plt.close('all')

In [None]:
for sample_name in DS.sample_name.values:
    DS['sample_au_mesh'].sel(sample_name=sample_name).plot(hue='polarization')
    # (DS['calib_au_mesh']/(DS['sample_au_mesh'].sel(sample_name=sample_name))).plot(hue='polarization')
    plt.show()
    
# # DS['calib_au_mesh'].plot(hue='polarization')
# # plt.show

In [None]:
# Build one beamstop curve per run by choosing, at every energy, the value
# from whichever polarization (0 or 90) has the smaller absolute energy derivative.
# Despite the variable name, nothing is averaged here -- values are selected.
averaged_beamstop_rows = []

for ds in tqdm(monitors_rows):
    beamstop_p1 = ds['WAXS Beamstop'].sel(polarization=0)
    beamstop_p2 = ds['WAXS Beamstop'].sel(polarization=90)

    # |d(beamstop)/d(energy)| for each polarization (computed eagerly)
    abs_deriv_p1 = abs(beamstop_p1.compute().differentiate('energy'))
    abs_deriv_p2 = abs(beamstop_p2.compute().differentiate('energy'))

    # True where the pol-0 curve changes less than the pol-90 curve and its derivative is not NaN
    condition = (abs_deriv_p1 < abs_deriv_p2) & ~np.isnan(abs_deriv_p1)

    # Take pol 0 where the condition holds, pol 90 everywhere else
    averaged_beamstop = xr.where(condition, beamstop_p1, beamstop_p2).rename('averaged_beamstop')
    averaged_beamstop_rows.append(averaged_beamstop)

    # Now 'averaged_beamstop' is the new data variable with values from the curve that has the least instantaneous change at each energy point


In [None]:
for averaged_beamstop in averaged_beamstop_rows:
    averaged_beamstop.plot()
    plt.show()
    plt.close('all')

In [None]:
# Smooth each selected beamstop curve with a centred rolling mean along energy.
window_size = 3  # This is the window size for the smoothing - you'll need to adjust it for your data

smoothed_beamstop_rows = []

for averaged_beamstop in tqdm(averaged_beamstop_rows):
    # min_periods=1 lets the mean be taken even where the window holds a single
    # value, so no NaNs are introduced at the ends of the energy axis; this can
    # affect the smoothing at the edges of the data.
    # (A rolling mean without min_periods used to be computed first and was
    # immediately overwritten; that dead computation has been removed.)
    smoothed_beamstop = averaged_beamstop.rolling(energy=window_size, center=True, min_periods=1).mean()
    smoothed_beamstop_rows.append(smoothed_beamstop)


In [None]:
# Plot every smoothed beamstop curve.
# The curves live in `smoothed_beamstop_rows`; no cell assigns a
# 'smoothed_beamstop' variable to DS, so indexing DS for it raised KeyError.
for smoothed_beamstop in smoothed_beamstop_rows:
    smoothed_beamstop.plot()
    plt.show()
    plt.close('all')

In [None]:
DS['corr_intensity'] = ((DS['raw_intensity'] / DS['sample_au_mesh'])
                        * (DS['calib_au_mesh'] / DS['calib_diode']))

DS

In [None]:
# Checks for non-serializable data types in the attributes of raw_intensity and
# makes them serializable: dask arrays are computed, dicts and datetimes are
# converted to their string form. Every converted attribute is printed with its
# original type.

# Local alias: the notebook-level name `da` (from `import dask.array as da`) is
# rebound to a DataArray by the local-file cells, so it cannot be relied on here.
import dask.array as dask_array

# Iterate over a snapshot because the attrs are rewritten inside the loop.
for k, v in list(DS['raw_intensity'].attrs.items()):
    if isinstance(v, dask_array.Array):
        DS['raw_intensity'].attrs[k] = v.compute()
    elif isinstance(v, (dict, datetime.datetime)):
        DS['raw_intensity'].attrs[k] = str(v)
    else:
        continue
    print(f'{k:<20}  |  {type(v)}')

In [None]:
# # NetCDFs

# # cartesian_sample_DS = cartesian_DS_sample_rows[0]
# # for cartesian_sample_DS in tqdm(cartesian_DS_sample_rows):
# sample_names = DS.sample_name.values

# for sample_name in tqdm(sample_names):
#     cartesian_sample_DS = DS.sel(sample_name=[sample_name])
#     cartesian_sample_DS.to_netcdf(zarrsPath.joinpath('cartesian_rsoxs_carbon_ncs', f'{sample_name}.nc'), format='netCDF4', engine='h5netcdf')

In [None]:
# netcdf_paths = str(zarrsPath.joinpath('cartesian_rsoxs_carbon_ncs')) + '/*.nc'
# netcdf_paths

In [None]:
# DS = xr.open_mfdataset(netcdf_paths)
# DS

In [None]:
DS

In [None]:
zarrsPath.exists()

In [None]:
zarrsPath

In [None]:
# encoding = {var: {'chunks': DS[var].shape} for var in DS.variables}

In [None]:
# NOTE(review): `encoding` is only defined by the commented-out line in the
# cell above; on a fresh kernel this cell raises NameError.
encoding

In [None]:
# DS.to_zarr(zarrsPath.joinpath('cartesian_rsoxs_carbon_v1.zarr'), mode='w')  # too big for carbon?

In [None]:
zarrsPath

In [None]:
# Write the cartesian dataset to a single zarr store, one sample at a time:
# the first sample creates (or overwrites) the store, the rest are appended
# along 'sample_name'.
plan_of_interest = 'rsoxs_carbon_tilted'
sample_names = DS.sample_name.values

cartesian_zarr_path = zarrsPath.joinpath(f'cartesian_{plan_of_interest}_v1.zarr')

first_sample_DS = DS.sel(sample_name=[sample_names[0]])
first_sample_DS.to_zarr(cartesian_zarr_path, mode='w')

for sample_name in tqdm(sample_names[1:], desc='Samples...'):
    one_sample_DS = DS.sel(sample_name=[sample_name])
    one_sample_DS.to_zarr(cartesian_zarr_path, mode='a', append_dim='sample_name')

# DS.to_zarr(zarrsPath.joinpath(f'cartesian_{plan_of_interest}.zarr'), mode='w')

In [None]:
# with ProgressBar():
#     DS.to_zarr(zarrsPath.joinpath(f'cartesian_{plan_of_interest}.zarr'))

## Load data from saved zarrs

In [None]:
# Open the saved cartesian zarr store and mask out non-positive values.

# Local alias: the notebook-level name `da` (from `import dask.array as da`) is
# rebound to a DataArray by the local-file cells, so it cannot be relied on here.
import dask.array as dask_array

plan_of_interest = 'rsoxs_carbon_tilted'
loaded_DS = xr.open_zarr(zarrsPath.joinpath(f'cartesian_{plan_of_interest}_v1.zarr'))  #.compute()
loaded_DS = loaded_DS.where(loaded_DS>0)

# Compute any dask coordinates
# (iterate over a snapshot because coordinates are reassigned inside the loop)
for coord_name, coord_data in list(loaded_DS.coords.items()):
    if isinstance(coord_data.data, dask_array.Array):
        loaded_DS.coords[coord_name] = coord_data.compute()

# Subtract bare SiN raw and double-norm-corrected intensities
# sin_sub_DS = DS.copy()
# sin_sub_DS['raw_intensity'] = DS['raw_intensity'] - DS['raw_intensity'].sel(sample_name='BareSiN')
# sin_sub_DS['raw_intensity'].attrs = DS['raw_intensity'].attrs
# sin_sub_DS['corr_intensity'] = DS['corr_intensity'] - DS['corr_intensity'].sel(sample_name='BareSiN')
# sin_sub_DS.attrs['name'] = 'rsoxs_carbon_SiN_subtracted'

# DS['sin_sub_raw_intensity'] = DS['raw_intensity'] - DS['raw_intensity'].sel(sample_name='BareSiN_1mm')
# DS['sin_sub_corr_intensity'] = DS['corr_intensity'] - DS['corr_intensity'].sel(sample_name='BareSiN_1mm')

display(loaded_DS)

# start_dict = ast.literal_eval(DA.start)

In [None]:
# # Used for rsoxs 1180 eV

# for sample_name in loaded_DS.sample_name.values:
#     DA = loaded_DS['raw_intensity'].sel(sample_name=sample_name, polarization=90).squeeze()
#     cmin = float(DA.compute().quantile(0.0001))
#     cmax = float(DA.compute().quantile(0.995))
#     ax = DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap, x='qx', y='qy')
#     ax.axes.set(title=f'{str(DA.sample_name.values)}, Energy = {int(DA.energy.values)} eV')
#     plt.show()
#     plt.close('all')

## Draw masks & check data and beamcenters

In [None]:
# Set colormap
cmap = plt.cm.turbo.copy()
cmap.set_bad(cmap.get_under())

# # Choose a sample dataarray:
# bare_sin_DA = DS.sel(sample_name='BareSiN_1mm')
# print(DS.sample_name.values)
# sample_name = 'Y6_CB_2500'

### 1. Check raw images at a selected energy for all loaded scan configurations:

In [None]:
# Facet plot of the raw intensity for one sample: one panel per requested
# energy (nearest available), cropped to a fixed detector pixel window.
# sample_name = 'PM6-Y6_p5CN-2CF-3CB'
sample_name = 'PM6-Y6BO_p5CN-CB_rot'
# sample_DA = DS['raw_intensity'].sel(sample_name=sample_name)
sample_DA = loaded_DS['raw_intensity'].sel(sample_name=sample_name)

# energies = [270, 280, 282, 283, 284, 285, 286, 290]
energies = np.round(np.linspace(280, 290, 8), 1)  # carbon
# energies = np.round(np.linspace(380, 440, 8), 1)  # nitrogen

pol = 0
# pol = 90

detector_window = dict(pix_x=slice(160, 780), pix_y=slice(240, 800))
panel_DA = (
    sample_DA
    .sel(polarization=pol, method='nearest')
    .sel(energy=energies, method='nearest')
    .sel(**detector_window)
)

fg = panel_DA.plot.imshow(
    figsize=(18, 6),
    col='energy',
    col_wrap=4,
    norm=LogNorm(3e1, 1e4),
    cmap=cmap,
    x='qx',
    y='qy',
)
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
fg.fig.suptitle(f'{str(sample_DA.sample_name.data)},  Polarization = {pol}°', y=1.02)

# Square pixels in every panel.
for axes in fg.axs.flat:
    axes.set(aspect='equal')

plt.show()

In [None]:
# Facet plot of the corrected intensity for one sample: one panel per requested
# energy (nearest available), cropped to a fixed detector pixel window.
# sample_name = 'PM6-Y6_p5CN-2CF-3CB'
sample_name = 'PM6-Y6BO_p5CN-CB_rot'
corr_sample_DA = loaded_DS['corr_intensity'].sel(sample_name=sample_name)

# energies = [270, 280, 282, 283, 284, 285, 286, 290]
energies = np.round(np.linspace(280, 290, 8), 1)  # carbon
# energies = np.round(np.linspace(380, 440, 8), 1)  # nitrogen

pol = 45

detector_window = dict(pix_x=slice(160, 780), pix_y=slice(240, 800))
corr_panel_DA = (
    corr_sample_DA
    .sel(polarization=pol, method='nearest')
    .sel(energy=energies, method='nearest')
    .sel(**detector_window)
)

fg = corr_panel_DA.plot.imshow(
    figsize=(18, 6),
    x='qx',
    y='qy',
    col='energy',
    col_wrap=4,
    norm=LogNorm(5e8, 1e11),
    cmap=cmap,
)
fg.cbar.set_label('Intensity [arb. units]', rotation=270, labelpad=15)
fg.fig.suptitle(f'{str(sample_name)},  Polarization = {pol}°', y=1.02)

# Square pixels in every panel.
for axes in fg.axs.flat:
    axes.set(aspect='equal')

plt.show()

### 2. Draw masks

In [None]:
# Define example image for mask & initialize phs DrawMask object:
sample_name = 'Y6BO_p5CN-CF'
sample_DA = DS['raw_intensity'].sel(sample_name=sample_name)

In [None]:
## WAXS mask:
waxs_mask_img = sample_DA.sel(polarization=0, energy=270, method='nearest').compute()
# draw = phs.IntegrationUtils.DrawMask(waxs_mask_img, clim=(30, 1e3))
# draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)
# draw.ui()

In [None]:
# ## Save and load saxs drawn mask
# draw.save(maskPath.joinpath('2023C3_full_length_masks.json'))

### 3. Check beamcenters

In [None]:
# Define example image for mask & initialize phs DrawMask object:
sample_name = 'PM6-Y6BO_p5CN-CB_rot'
sample_DA = loaded_DS['raw_intensity'].sel(sample_name=sample_name)
# sample_DA = sample_DA.where(sample_DA>0)

In [None]:
# Overlay the saved WAXS detector mask on one image to check that it lines up.
# sample_DA.attrs['beamcenter_x'] = 450
# sample_DA.attrs['beamcenter_y'] = 510

energy = 250
# energy = 400
# energy = 532

# Image at polarization 0 and the energy nearest to `energy`
waxs_mask_img = sample_DA.sel(polarization=0, energy=energy, method='nearest').compute()
draw = phs.IntegrationUtils.DrawMask(waxs_mask_img)

# Load masks:
draw.load(maskPath / 'WAXS_detector.json')
waxs_mask = draw.mask

# Check masks: linear colour scale clipped to the 0.01st-99.5th percentile range
cmin, cmax = (float(waxs_mask_img.quantile(q)) for q in (0.0001, 0.995))

ax = waxs_mask_img.plot.imshow(norm=plt.Normalize(cmin,cmax), cmap=cmap)
ax.axes.imshow(waxs_mask, alpha=0.5, origin='lower')
# ax.axes.imshow(WAXSinteg.mask, alpha=0.5, origin='lower')
plt.show()

In [None]:
# Initialize PFEnergySeriesIntegrator object & check beamcenter & masks
# WAXS
# The integrator takes its geometry from the polarization-0 slice of sample_DA
# (geomethod='template_xr'); the mask and beamcenter are then set explicitly.
# WAXSinteg = phs.integrate.PFGeneralIntegrator(geomethod='template_xr', template_xr=sample_DA.sel(polarization=0))
WAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr=sample_DA.sel(polarization=0))
WAXSinteg.mask = waxs_mask
# Beamcenter values are read from the attrs of the image used for the mask check.
WAXSinteg.ni_beamcenter_x = waxs_mask_img.beamcenter_x
WAXSinteg.ni_beamcenter_y = waxs_mask_img.beamcenter_y
print('WAXS Beamcenter: \n'
      f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
      f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')

# Visual check of beamcenter and mask on the same image, using the same
# quantile-based linear colour limits as the mask overlay cell.
phs.IntegrationUtils.Check.checkAll(WAXSinteg, 
                                    waxs_mask_img, 
                                    img_scaling='linear', 
                                    img_min=waxs_mask_img.quantile(0.0001), 
                                    img_max=waxs_mask_img.quantile(0.995), 
                                    d_inner=215,
                                    d_outer=310,
                                    alpha=0.4)

In [None]:
# ## Tweaking if needed:

# # ## WAXS Tweaking & Plot Check
# # waxs_new_bcx = 396.3
# # waxs_new_bcy = 553
# # WAXSinteg.ni_beamcenter_x = waxs_new_bcx
# # WAXSinteg.ni_beamcenter_y = waxs_new_bcy
# # raw_waxs.attrs['beamcenter_x'] = waxs_new_bcx
# # raw_waxs.attrs['beamcenter_x'] = waxs_new_bcx
# # raw_waxs.attrs['poni1'] = WAXSinteg.poni1
# # raw_waxs.attrs['poni2'] = WAXSinteg.poni2

# # print('WAXS Beamcenter Tweaking: \n'
# #       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
# #       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# # phs.IntegrationUtils.Check.checkAll(WAXSinteg, waxs_mask_img, img_max=5e3, alpha=0.6, guide1=40)
# # plt.xlim(WAXSinteg.ni_beamcenter_x-200, WAXSinteg.ni_beamcenter_x+200)
# # plt.ylim(WAXSinteg.ni_beamcenter_y-200, WAXSinteg.ni_beamcenter_y+200)
# # plt.gcf().set(dpi=120)
# # plt.show()


# # Using Pete D.'s (very slightly modified) beamcentering script:
# phs.BeamCentering.CenteringAccessor.refine_geometry

# ## WAXS
# # res_waxs = raw_waxs.sel(pol=0).util.refine_geometry(energy=275, q_min=0.02, q_max=0.06, chi_min=-10, chi_max=70)
# res_waxs = sample_DA.sel(polarization=0).util.refine_geometry(energy=1180, q_min=0.28, q_max=0.35)
# sample_DA.attrs['poni1'] = res_waxs.x[0]
# sample_DA.attrs['poni2'] = res_waxs.x[1]
# WAXSinteg = phs.integrate.PFEnergySeriesIntegrator(geomethod='template_xr', template_xr = sample_DA.sel(polarization=0))
# WAXSinteg.mask = waxs_mask

# ## WAXS Plot check
# print('WAXS Beamcenter Post-optimization: \n'
#       f'poni1: {WAXSinteg.poni1}, poni2: {WAXSinteg.poni2} \n'
#       f'ni_beamcenter_y: {WAXSinteg.ni_beamcenter_y}, ni_beamcenter_x: {WAXSinteg.ni_beamcenter_x}')
# phs.IntegrationUtils.Check.checkAll(WAXSinteg, 
#                                     waxs_mask_img, 
#                                     img_scaling='linear', 
#                                     img_min=waxs_mask_img.quantile(0.01), 
#                                     img_max=waxs_mask_img.quantile(0.99), 
#                                     d_inner=215,
#                                     d_outer=310,
#                                     alpha=0.4)
# plt.xlim(WAXSinteg.ni_beamcenter_x-250, WAXSinteg.ni_beamcenter_x+250)
# plt.ylim(WAXSinteg.ni_beamcenter_y-250, WAXSinteg.ni_beamcenter_y+250)
# plt.gcf().set(dpi=120)
# plt.show()
# plt.close('all')

In [None]:
# ### Write beamcenters to saved .json file if content with them:

# beamcenters_dict = {
#     f'WAXS_2023C2': {'bcx':sample_DA.beamcenter_x, 'bcy':sample_DA.beamcenter_y}
# }

# # Check if the file exists, if not, create an empty JSON file
# jsonFile = jsonPath.joinpath('beamcenters_dict.json')
# if not jsonFile.exists():
#     with jsonFile.open('w') as f:
#         json.dump({}, f)

# # Now, read the existing or empty JSON file
# with jsonFile.open('r') as f:
#     dic = json.load(f)

# dic.update(beamcenters_dict)

# # Write the updated dictionary back to the JSON file
# with jsonFile.open('w') as f:
#     json.dump(dic, f)

## Convert to chi-q space & save zarrs

In [None]:
loaded_DS

In [None]:
# Integrate whole cartesian dataset!
# For every sample, each polarization's corrected image stack is integrated
# with WAXSinteg, the polarizations are concatenated, and finally all samples
# are concatenated into one polar Dataset.
polar_DS_sample_rows = []
for sample_name in tqdm(loaded_DS.sample_name.data):
    polar_DA_polarization_rows = []
    for pol in [0, 45, 90]:
        cart_DA = loaded_DS['corr_intensity'].sel(polarization=pol, sample_name=sample_name)  #.compute()
        # Alternatives: integrateImageStack, integrateSingleImage, integrateImageStack_legacy
        polar_DA = WAXSinteg.integrateImageStack_dask(cart_DA, chunksize=1)
        # Add a length-1 'polarization' dimension so the results can be concatenated along it
        polar_DA_polarization_rows.append(polar_DA.expand_dims({'polarization': [pol]}))

    polar_DS = xr.Dataset()
    polar_DS['corr_intensity'] = xr.concat(polar_DA_polarization_rows, dim='polarization')
    polar_DS = polar_DS.expand_dims({'sample_name':[sample_name]})
    # Carry the dataset name over from the loaded cartesian dataset
    polar_DS.attrs['name'] = loaded_DS.name
    polar_DS_sample_rows.append(polar_DS)

    # polar_DS.to_netcdf(zarrsPath.joinpath('polar_rsoxs_carbon_ncs', f'{sample_name}.nc'), format='netCDF4', engine='h5netcdf')

polar_DS = xr.concat(polar_DS_sample_rows, dim='sample_name')

In [None]:
polar_DS

In [None]:
def make_para_perp_DAs(DS, sample_name, intensity_type, pol, qlims, chi_width):
    """Return the (parallel, perpendicular) chi-sector cuts for one sample and polarization.

    Parameters
    ----------
    DS : dataset-like
        Must support ``DS.sel(sample_name=...)[f'{intensity_type}_intensity']``.
    sample_name : str
        Value selected along ``sample_name``.
    intensity_type : str
        Prefix of the data variable, e.g. ``'corr'`` selects ``'corr_intensity'``.
    pol : {0, 90}
        Polarization to select. It also decides which sector is labelled parallel.
    qlims : sequence of two numbers
        Lower and upper bound of the ``q`` slice.
    chi_width : number
        Half of this value is passed as ``chi_width`` to ``.rsoxs.slice_chi``.

    Returns
    -------
    tuple
        ``(para_DA, perp_DA)``. For ``pol == 0`` the cut centred at chi=180 is the
        parallel one and the cut centred at chi=90 the perpendicular one; for
        ``pol == 90`` the two are swapped.

    Raises
    ------
    ValueError
        If ``pol`` is neither 0 nor 90. Without this check, such a value reached the
        ``return`` with both results unassigned and failed with ``UnboundLocalError``.
    """
    if pol not in (0, 90):
        raise ValueError(f'pol must be 0 or 90 to assign para/perp sectors, got {pol!r}')

    # select dataarray to plot
    DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity']
    sliced_DA = DA.sel(polarization=pol, q=slice(qlims[0], qlims[1]))

    # The same two sectors are cut for either polarization; only their labels differ
    half_width = chi_width / 2
    cut_180 = sliced_DA.rsoxs.slice_chi(180, chi_width=half_width)
    cut_90 = sliced_DA.rsoxs.slice_chi(90, chi_width=half_width)

    if pol == 0:
        return cut_180, cut_90
    return cut_90, cut_180

In [None]:
# Select polar dataset, to qr linecuts: para, perp, and full
# For every sample, the para/perp chi-sector cuts taken at polarization 0 and 90 are averaged
# and drawn side by side as q-vs-energy maps.
edge = 'carbon'
DS = polar_DS.copy()

chi_width = 90
q_slice = slice(0.009,0.08)
e_slice = slice(282, 292)

# make selection
# sample_name = 'PM6-Y6_3000_dSiN'
intensity_type = 'corr'
pol = 0

# Color limits are identical for every sample, so they are set once outside the loop
cmin = 3e8
cmax = 1e11

# for sample_name in tqdm(filtered_selected_samples):
for sample_name in tqdm(DS.sample_name.values[:]):
    pol_paras = []
    pol_perps = []
    # The loop variable is deliberately not called `pol`, so the selection value above is not
    # overwritten for cells that run afterwards.
    for cut_pol in [0, 90]:
        para_DA, perp_DA = make_para_perp_DAs(DS, sample_name, intensity_type, cut_pol, (q_slice.start,q_slice.stop), chi_width)
        # Relabel each sector's chi axis as 0..chi_width so cuts taken around different chi
        # centres share one axis and can be combined
        pol_paras.append(para_DA.assign_coords({'chi': np.linspace(0, chi_width, len(para_DA.chi.values))}))
        pol_perps.append(perp_DA.assign_coords({'chi': np.linspace(0, chi_width, len(perp_DA.chi.values))}))

    # Put the polarization-0 cuts on the chi grid of the polarization-90 cuts before averaging
    pol_paras[0] = pol_paras[0].interp({'chi': pol_paras[1].chi.values})
    pol_perps[0] = pol_perps[0].interp({'chi': pol_perps[1].chi.values})

    para_DA = (pol_paras[0] + pol_paras[1])/2
    perp_DA = (pol_perps[0] + pol_perps[1])/2

    para_DA = para_DA.assign_coords({'polarization':'avg'})
    perp_DA = perp_DA.assign_coords({'polarization':'avg'})

    para_DA = para_DA.interpolate_na(dim='q')
    perp_DA = perp_DA.interpolate_na(dim='q')

    # Text shown in the figure title ('avg'); kept separate from the numeric `pol` selection
    pol_label = str(para_DA.polarization.values)

#     # Fl correction?
#     scale_factor = float(para_DA.sel(energy=slice(285,305), q=slice(2e-2,6e-2)).mean('chi').mean('energy').integrate('q'))
#     manual_scale_factor = manual_scale_factors_v4[sample_name]
#     trmsn90_corr = (manual_scale_factor * scale_factor * trmsn90_bkgs_DA.sel(sample_name=sample_name))

#     para_DA = para_DA - trmsn90_corr
#     perp_DA = perp_DA - trmsn90_corr

    # Plot
    fig, axs = plt.subplots(1, 2, figsize=(11,5))

    para_slice = para_DA.mean('chi').sel(q=q_slice, energy=e_slice)
    perp_slice = perp_DA.mean('chi').sel(q=q_slice, energy=e_slice)

    # cmin = para_slice.compute().quantile(0.01)
    # cmax = para_slice.compute().quantile(0.995)

    para_slice.plot(ax=axs[0], cmap=cmap, norm=LogNorm(cmin, cmax), add_colorbar=False)
    perp_slice.plot(ax=axs[1], cmap=cmap, norm=LogNorm(cmin, cmax), add_colorbar=False)

    # Add one shared colorbar, drawn in an inset axis to the right of the second panel
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=LogNorm(cmin, cmax))
    cax = axs[1].inset_axes([1.03, 0, 0.05, 1])
    cbar = fig.colorbar(sm, cax=cax, orientation='vertical')
    cbar.set_label(label='Double-Norm-Corrected Intensity [arb. units]', labelpad=12, rotation=270)

    fig.suptitle(f'Linecut Maps: {sample_name}, Polarization = {pol_label}°, Chi Width = {chi_width}°', fontsize=14)
    # fig.suptitle(f'Linecut Maps: {sample_name} Bare SiN Subtracted, Polarization = {pol}°, Chi Width = {chi_width}°', fontsize=14)
    # fig.suptitle(f'Linecut Maps: {sample_name} Bare SiN & Fluorescence Subtracted, Polarization = {pol}°, Chi Width = {chi_width}°', fontsize=14)

    fig.set(tight_layout=True)

    axs[0].set(xscale='log', title='Parallel to $E_p$', ylabel='Photon energy [eV]', xlabel='q [$Å^{-1}$]')
    axs[1].set(xscale='log', title='Perpendicular to $E_p$ ', ylabel=None, xlabel='q [$Å^{-1}$]')

    # fig.savefig(plotsPath.joinpath('linecut_maps_carbon_v2', f'{sample_name}_{edge}_{energy_min}-{energy_max}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)

    plt.show()
    plt.close('all')

In [None]:
# Quick look at the raw cartesian detector image of every sample at polarization 90
for sample_name in loaded_DS.sample_name.values:
# for sample_name in ['PM6_5CN-CB', 'PM6_5CN-CB_rot']:
    # Compute the selected image once; both quantiles and the plot reuse the in-memory
    # result instead of recomputing it for every call.
    DA = loaded_DS['raw_intensity'].sel(sample_name=sample_name, polarization=90).squeeze().compute()
    # Clip the color scale to the 0.01 % .. 99.5 % quantile range of the image
    cmin = float(DA.quantile(0.0001))
    cmax = float(DA.quantile(0.995))
    ax = DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap, x='qx', y='qy')
    ax.axes.set(title=f'{str(DA.sample_name.values)}, Energy = {int(DA.energy.values)} eV')
    plt.show()
    plt.close('all')

In [None]:
# Plot the polar 'corr_intensity' frame nearest to `energy` for every sample (polarization 0)
energy = 284.4
# energy = 400
# energy = 530
for sample_name in polar_DS.sample_name.values:
# for sample_name in ['PM6_5CN-CB', 'PM6-Y6_p5CN-2CF-3CB']:
    # Compute the selected frame once; the quantiles and the plot all reuse the in-memory
    # result instead of recomputing it for every call.
    sliced_DA = polar_DS['corr_intensity'].sel(sample_name=sample_name, polarization=0).sel(energy=energy, method='nearest').compute()
    # Plain floats for the color limits, clipped to the 0.01 % .. 99.5 % quantile range
    cmin = float(sliced_DA.quantile(0.0001))
    cmax = float(sliced_DA.quantile(0.995))
    ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
    ax.axes.set(title=f'{sample_name}, {int(sliced_DA.energy.values)} eV')
    plt.show()
    plt.close('all')

In [None]:
# Plot the polar 'raw_intensity' frame nearest to `energy` for two hand-picked samples (polarization 0)
# NOTE(review): the integration cell in this notebook only fills 'corr_intensity'; confirm that
# polar_DS actually carries 'raw_intensity' at this point.
energy = 284.4
# energy = 400
# energy = 530
# for sample_name in polar_DS.sample_name.values:
for sample_name in ['PM6_5CN-CB', 'PM6-Y6_p5CN-2CF-3CB']:
    # Compute the selected frame once; the quantiles and the plot all reuse the in-memory
    # result instead of recomputing it for every call.
    sliced_DA = polar_DS['raw_intensity'].sel(sample_name=sample_name, polarization=0).sel(energy=energy, method='nearest').compute()
    # Plain floats for the color limits, clipped to the 0.01 % .. 99.5 % quantile range
    cmin = float(sliced_DA.quantile(0.0001))
    cmax = float(sliced_DA.quantile(0.995))
    ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
    ax.axes.set(title=f'{sample_name}, {int(sliced_DA.energy.values)} eV')
    plt.show()
    plt.close('all')

In [None]:
# Re-open the saved tilted-carbon polar store from disk
loaded_polar_DS = xr.open_zarr(zarrsPath / 'polar_rsoxs_carbon_tilted_v1.zarr')

In [None]:
# Show the dataset opened from the zarr store as the cell output
loaded_polar_DS

In [None]:
# Plot the reloaded polar 'corr_intensity' frame nearest to `energy` for every sample and polarization
energy = 285
# energy = 400
# energy = 530
for sample_name in loaded_polar_DS.sample_name.values:
    for pol in [0,45,90]:
        # Compute the selected frame once; the quantiles and the plot all reuse the
        # in-memory result instead of recomputing it for every call.
        sliced_DA = loaded_polar_DS['corr_intensity'].sel(sample_name=sample_name, polarization=pol).sel(energy=energy, method='nearest').compute()
        # Plain floats for the color limits, clipped to the 0.01 % .. 99.5 % quantile range
        cmin = float(sliced_DA.quantile(0.0001))
        cmax = float(sliced_DA.quantile(0.995))
        ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
        ax.axes.set(title=f'{sample_name}, pol={pol}°, {int(sliced_DA.energy.values)} eV')
        plt.show()
        plt.close('all')

In [None]:
def make_para_perp_DAs(DS, sample_name, intensity_type, pol, qlims, chi_width):
    """Return the (parallel, perpendicular) chi-sector cuts for one sample and polarization.

    NOTE: a function with this name and signature is also defined in an earlier cell;
    running this cell rebinds the name.

    Parameters
    ----------
    DS : dataset-like
        Must support ``DS.sel(sample_name=...)[f'{intensity_type}_intensity']``.
    sample_name : str
        Value selected along ``sample_name``.
    intensity_type : str
        Prefix of the data variable, e.g. ``'corr'`` selects ``'corr_intensity'``.
    pol : {0, 90}
        Polarization to select. It also decides which sector is labelled parallel.
    qlims : sequence of two numbers
        Lower and upper bound of the ``q`` slice.
    chi_width : number
        Half of this value is passed as ``chi_width`` to ``.rsoxs.slice_chi``.

    Returns
    -------
    tuple
        ``(para_DA, perp_DA)``. For ``pol == 0`` the cut centred at chi=180 is the
        parallel one and the cut centred at chi=90 the perpendicular one; for
        ``pol == 90`` the two are swapped.

    Raises
    ------
    ValueError
        If ``pol`` is neither 0 nor 90. Without this check, such a value reached the
        ``return`` with both results unassigned and failed with ``UnboundLocalError``.
    """
    if pol not in (0, 90):
        raise ValueError(f'pol must be 0 or 90 to assign para/perp sectors, got {pol!r}')

    # select dataarray to plot
    DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity']
    sliced_DA = DA.sel(polarization=pol, q=slice(qlims[0], qlims[1]))

    # The same two sectors are cut for either polarization; only their labels differ
    half_width = chi_width / 2
    cut_180 = sliced_DA.rsoxs.slice_chi(180, chi_width=half_width)
    cut_90 = sliced_DA.rsoxs.slice_chi(90, chi_width=half_width)

    if pol == 0:
        return cut_180, cut_90
    return cut_90, cut_180


# # make selection
# sample_name = 'BareSiN'
# edge = 'carbon'
# intensity_type = 'corr'
# pol = 0
# qlims = (0.01, 0.08)
# chi_width = 30

# para_DA, perp_DA = make_para_perp_DAs(polar_DS, sample_name, intensity_type, pol, qlims, chi_width)  

# # slice ISI data
# para_ISI = para_DA.interpolate_na(dim='q').mean('chi').sum('q')
# perp_ISI = perp_DA.interpolate_na(dim='q').mean('chi').sum('q')

# # plot
# fig, ax = plt.subplots()
# para_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='para', yscale='log')
# perp_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='perp', yscale='log')
# fig.suptitle('Integrated Scattering Intensity (ISI)', fontsize=14)
# ax.set(title=f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', xlabel='Photon Energy [eV]', ylabel='Double-Norm-Corrected Intensity [arb. units]')
# ax.legend()
# plt.show()

In [None]:
# # polar_sample_DS = polar_DS_sample_rows[0]
# for polar_sample_DS in tqdm(polar_DS_sample_rows):
#     # display(polar_sample_DS)
#     sample_name = polar_sample_DS.sample_name.values[0]
#     print(sample_name)
#     polar_sample_DS.to_netcdf(zarrsPath.joinpath('polar_rsoxs_carbon_ncs', f'{sample_name}.nc'), format='netCDF4', engine='h5netcdf')

In [None]:
# Show polar_DS as the cell output before it is re-chunked
polar_DS

In [None]:
# Re-chunk so that each chunk spans one sample and one polarization, with energy in blocks of 56
polar_DS = polar_DS.chunk(dict(sample_name=1, energy=56, polarization=1))

# polar_DS['samp_diode'] = DS['samp_diode']
# polar_DS['smoothed_diode'] = DS['smoothed_diode']


polar_DS

In [None]:
# Show the output directory that the zarr stores are written to
zarrsPath

In [None]:
# Write the whole polar dataset to one zarr store; mode='w' replaces an existing store.
# Store names used for other datasets are kept below for reference.

# # rsoxs_1180
# polar_DS.to_zarr(zarrsPath.joinpath('polar_rsoxs_1180.zarr'), mode='w')

# # rsoxs_carbon
# polar_DS.to_zarr(zarrsPath.joinpath('polar_rsoxs_carbon_v2.zarr'), mode='w')

# rsoxs_carbon_tilted
polar_DS.to_zarr(zarrsPath / 'polar_rsoxs_carbon_tilted_v1.zarr', mode='w')

In [None]:
# Write polar_DS to a '_v2' zarr store one sample at a time
sample_names = polar_DS.sample_name.values
per_sample_store = zarrsPath / f'polar_{polar_DS.name}_v2.zarr'

# The first sample creates (or replaces) the store ...
polar_DS.sel(sample_name=[sample_names[0]]).to_zarr(per_sample_store, mode='w')

# ... and every remaining sample is appended along 'sample_name'
for sample_name in tqdm(sample_names[1:], desc='Samples...'):
    sample_slab = polar_DS.sel(sample_name=[sample_name])
    sample_slab.to_zarr(per_sample_store, mode='a', append_dim='sample_name')

In [None]:
# Glob pattern (as a string) matching every .nc file in the per-sample NetCDF folder
netcdf_paths = f"{zarrsPath.joinpath('polar_rsoxs_carbon_ncs')}/*.nc"
netcdf_paths

In [None]:
# Open all files matching the netcdf_paths glob as a single dataset.
# NOTE(review): this rebinds polar_DS, so cells below work on the NetCDF-backed data rather
# than on the dataset integrated earlier in the notebook — confirm that is intended.
polar_DS = xr.open_mfdataset(netcdf_paths)
polar_DS

In [None]:
# Show polar_DS as the cell output
polar_DS

In [None]:
# List the sample names available in polar_DS
polar_DS.sample_name.values

In [None]:
# make selection
# Every parameter the cell uses is defined here, so the cell does not depend on values left
# behind by other cells (`pol` and `chi_width` are reassigned elsewhere in the notebook).
sample_name = 'Y6BO_p5CN-CF'
intensity_type = 'corr'
pol = 0
qlims = (0.01, 0.08)
chi_width = 90

para_DA, perp_DA = make_para_perp_DAs(polar_DS, sample_name, intensity_type, pol, qlims, chi_width)

# Select AR data: (para - perp) / (para + perp) of the chi-averaged cuts, each mean taken once
para_mean = para_DA.mean('chi')
perp_mean = perp_DA.mean('chi')
ar_DA = (para_mean - perp_mean) / (para_mean + perp_mean)

# Plot
ax = ar_DA.sel(energy=slice(282,292)).plot(figsize=(8,5), norm=plt.Normalize(-0.6,0.6))
ax.figure.suptitle('Anisotropy Ratio (AR) Map', fontsize=14, x=0.43)
ax.axes.set(title=f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', ylabel='Photon Energy [eV]', xlabel='q [$Å^{-1}$]')
ax.colorbar.set_label('AR [arb. units]', rotation=270, labelpad=12)
plt.show()

In [None]:
# make selection
sample_name = 'Y6BO_p5CN-CF'
edge = 'carbon'
intensity_type = 'corr'
pol = 0
qlims = (0.01, 0.08)
chi_width = 90

para_DA, perp_DA = make_para_perp_DAs(polar_DS, sample_name, intensity_type, pol, qlims, chi_width)

# Select AR data: normalised difference of the chi-averaged parallel and perpendicular cuts
para_mean = para_DA.mean('chi')
perp_mean = perp_DA.mean('chi')
ar_DA = (para_mean - perp_mean) / (para_mean + perp_mean)

# Plot the 282-292 eV window on a fixed, symmetric color scale
ar_window = ar_DA.sel(energy=slice(282,292))
ax = ar_window.plot(figsize=(8,5), norm=plt.Normalize(-0.6,0.6))
ax.figure.suptitle('Anisotropy Ratio (AR) Map', fontsize=14, x=0.43)
ax.axes.set_title(f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°')
ax.axes.set_ylabel('Photon Energy [eV]')
ax.axes.set_xlabel('q [$Å^{-1}$]')
ax.colorbar.set_label('AR [arb. units]', rotation=270, labelpad=12)
plt.show()

In [None]:
# display(polar_DS.sample_name.values)
# sample_name = 'PM6-Y6_3000_dSiN'

In [None]:
# Plot the polar 'raw_intensity' frame nearest to `energy` for every sample (polarization 90)
energy = 285
# energy = 400
# energy = 530
for sample_name in polar_DS.sample_name.values:
    # Compute the selected frame once; the quantiles and the plot all reuse the in-memory
    # result instead of recomputing it for every call.
    sliced_DA = polar_DS['raw_intensity'].sel(sample_name=sample_name, polarization=90).sel(energy=energy, method='nearest').compute()
    # Plain floats for the color limits, clipped to the 0.01 % .. 99.5 % quantile range
    cmin = float(sliced_DA.quantile(0.0001))
    cmax = float(sliced_DA.quantile(0.995))
    ax = sliced_DA.plot.imshow(norm=plt.Normalize(cmin, cmax), cmap=cmap)
    # Title each figure like the other image cells so it can be read on its own
    ax.axes.set(title=f'{sample_name}, {int(sliced_DA.energy.values)} eV')
    plt.show()
    plt.close('all')

In [None]:
# 1. Get energy values and split them into consecutive groups of five
#    (the last group is shorter when the count is not a multiple of five)
energy_values = polar_DS.energy.values
group_starts = range(0, len(energy_values), 5)
energy_slices = [energy_values[start:start + 5] for start in group_starts]
energy_slices

In [None]:
# Show polar_DS as the cell output before stacking it
polar_DS

In [None]:
# Combine 'sample_name' and 'polarization' into a single 'system' dimension, then reset the
# index that stacking creates on it.
# NOTE(review): presumably the reset is needed so the result can be written to zarr — confirm.
stacked_polar_DS = polar_DS.stack(system=('sample_name', 'polarization'))
stacked_polar_DS = stacked_polar_DS.reset_index('system')
stacked_polar_DS

In [None]:
# Write the stacked dataset to zarr one 'system' entry at a time
stacked_store = zarrsPath / f'polar_{polar_DS.name}.zarr'

# The first entry initialises (or replaces) the store
first_system = stacked_polar_DS.isel(system=slice(0, 1))
first_system.to_zarr(stacked_store, mode='w')

# Every later entry is appended along 'system'
n_systems = len(stacked_polar_DS.system)
for i in tqdm(range(1, n_systems), desc='Samples...'):
    subset = stacked_polar_DS.isel(system=slice(i, i+1))
    subset.to_zarr(stacked_store, mode='a', append_dim='system')

In [None]:
# Writes polar_DS to f'polar_{polar_DS.name}.zarr' in pieces.
# NOTE(review): this is the same store path that the stacked 'system' write in this notebook
# uses, and mode='w' below replaces whatever is there — confirm the overwrite is intended.
sample_names = polar_DS.sample_name.values

# Initialise the store with the first sample at polarization 0 only
polar_DS.sel(sample_name=[sample_names[0]], polarization=[0]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='w')

# NOTE(review): the loop appends the polarization-90 data of the *remaining* samples along
# 'polarization', although the store was initialised with a different sample at polarization 0.
# That puts data from different samples next to each other on the polarization axis, and no
# polarization-0 data of the remaining samples is written. This looks like an unfinished
# experiment (see the commented alternative that appends along 'sample_name') — confirm the
# intended layout before relying on this store.
for sample_name in tqdm(sample_names[1:], desc='Samples...'):
    # polar_DS.sel(sample_name=[sample_name], polarization=[0]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='a', append_dim='sample_name')
    polar_DS.sel(sample_name=[sample_name], polarization=[90]).to_zarr(zarrsPath.joinpath(f'polar_{polar_DS.name}.zarr'), mode='a', append_dim='polarization')

In [None]:
# # import zarr

# # sample_names = polar_DS.sample_name.values
# # energy_values = polar_DS.energy.values
# # energy_slices = [energy_values[i:i+5] for i in range(0, len(energy_values), 5)]

# # root_store = zarr.open_group(zarrsPath.joinpath(f'polar_{polar_DS.name}_regions.zarr').as_posix(), mode='a')

# # for sample in sample_names:
# #     # Make sure there's a group for this sample in the Zarr store
# #     if sample not in root_store.group_keys():
# #         root_store.create_group(sample)
    
# #     subset_by_sample = polar_DS.sel(sample_name=sample)
    
# #     for idx, energy_slice in enumerate(energy_slices):
# #         final_subset = subset_by_sample.sel(energy=energy_slice)
        
# #         # Save to the specific group and energy slice within the Zarr store
# #         final_subset.to_zarr(root_store[sample], mode='a', append_dim='energy', consolidated=True)

# sample_names = polar_DS.sample_name.values
# energy_values = polar_DS.energy.values
# energy_slices = [energy_values[i:i+5] for i in range(0, len(energy_values), 5)]

# main_zarr_path = zarrsPath.joinpath(f'polar_{polar_DS.name}_regions.zarr')

# for sample in tqdm(sample_names, desc='Samples...'):
#     subset_by_sample = polar_DS.sel(sample_name=sample)
    
#     # Define path for the sample within the main Zarr store
#     sample_path = main_zarr_path.joinpath(sample)
    
#     for idx, energy_slice in enumerate(energy_slices):
#         final_subset = subset_by_sample.sel(energy=energy_slice)
        
#         # Save to the specific sample path and energy slice within the Zarr store
#         if idx==0:
#             final_subset.to_zarr(sample_path, mode='w', consolidated=True)
#         else:
#             final_subset.to_zarr(sample_path, mode='a', append_dim='energy', consolidated=True)



In [None]:
def make_para_perp_DAs(datasets, sample_name, edge, intensity_type, pol, qlims, chi_width):
    """Return the (parallel, perpendicular) chi-sector cuts for one sample and polarization.

    NOTE: this definition takes a mapping of datasets plus an ``edge`` key and rebinds a
    name that earlier cells define with a different (six-argument) signature.

    Parameters
    ----------
    datasets : mapping
        Must contain the key ``f'polar_{edge}'``; its value must support
        ``.sel(sample_name=...)[f'{intensity_type}_intensity']``.
    sample_name : str
        Value selected along ``sample_name``.
    edge : str
        Suffix of the dataset key, e.g. ``'regions'`` selects ``datasets['polar_regions']``.
    intensity_type : str
        Prefix of the data variable, e.g. ``'corr'`` selects ``'corr_intensity'``.
    pol : {0, 90}
        Polarization to select. It also decides which sector is labelled parallel.
    qlims : sequence of two numbers
        Lower and upper bound of the ``q`` slice.
    chi_width : number
        Half of this value is passed as ``chi_width`` to ``.rsoxs.slice_chi``.

    Returns
    -------
    tuple
        ``(para_DA, perp_DA)``. For ``pol == 0`` the cut centred at chi=180 is the
        parallel one and the cut centred at chi=90 the perpendicular one; for
        ``pol == 90`` the two are swapped.

    Raises
    ------
    ValueError
        If ``pol`` is neither 0 nor 90. Without this check, such a value reached the
        ``return`` with both results unassigned and failed with ``UnboundLocalError``.
    """
    if pol not in (0, 90):
        raise ValueError(f'pol must be 0 or 90 to assign para/perp sectors, got {pol!r}')

    # select dataarray to plot
    DS = datasets[f'polar_{edge}']
    DA = DS.sel(sample_name=sample_name)[f'{intensity_type}_intensity']
    sliced_DA = DA.sel(polarization=pol, q=slice(qlims[0], qlims[1]))

    # The same two sectors are cut for either polarization; only their labels differ
    half_width = chi_width / 2
    cut_180 = sliced_DA.rsoxs.slice_chi(180, chi_width=half_width)
    cut_90 = sliced_DA.rsoxs.slice_chi(90, chi_width=half_width)

    if pol == 0:
        return cut_180, cut_90
    return cut_90, cut_180

# load dictionary of rsoxs datasets
rsoxs_datasets = {}
key = 'polar_regions'
key_parts = key.split('_')
key_start = key_parts[0]
key_end = key_parts[1]

# Find the store whose name starts with key_start and ends with f'{key_end}.zarr'.
# Matches are sorted so the choice is deterministic when several stores fit, and a missing
# store is reported with the pattern and folder instead of a bare IndexError.
zarr_matches = sorted(zarrsPath.glob(f'{key_start}*{key_end}.zarr'))
if not zarr_matches:
    raise FileNotFoundError(f"No zarr store matching '{key_start}*{key_end}.zarr' found in {zarrsPath}")
zarrPath = zarr_matches[0]
rsoxs_datasets[key] = xr.open_zarr(zarrPath)

# Compute any dask coordinates
# Iterate over a snapshot of the items because the coordinates are reassigned inside the loop
for coord_name, coord_data in list(rsoxs_datasets[key].coords.items()):
    if isinstance(coord_data.data, da.Array):
        rsoxs_datasets[key].coords[coord_name] = coord_data.compute()

rsoxs_datasets[key]

In [None]:
# make selection
edge = 'regions'
intensity_type = 'corr'
qlims = (0.01, 0.08)
chi_width = 30

# One anisotropy-ratio map per sample and polarization
for sample_name in tqdm(rsoxs_datasets[f'polar_{edge}'].sample_name.data):
    for pol in [0, 90]:
        ### Select para & perp DataArrays
        para_DA, perp_DA = make_para_perp_DAs(rsoxs_datasets, sample_name, edge, intensity_type, pol, qlims, chi_width)

#         ### ISI:
#         # Slice ISI data
#         para_ISI = para_DA.interpolate_na(dim='q').mean('chi').sum('q')
#         perp_ISI = perp_DA.interpolate_na(dim='q').mean('chi').sum('q')

#         # Plot
#         fig, ax = plt.subplots()
#         para_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='para', yscale='log')
#         perp_ISI.sel(energy=slice(280,290)).plot.line(ax=ax, label='perp', yscale='log')
#         fig.suptitle('Integrated Scattering Intensity (ISI)', fontsize=14)
#         ax.set(title=f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', xlabel='Photon Energy [eV]', ylabel='Double-Norm-Corrected Intensity [arb. units]')
#         ax.legend()
#         fig.savefig(plotsPath.joinpath('isi', f'{sample_name}_{edge}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)
#         plt.close('all')

#         ### Linecut Maps:
#         fig, axs = plt.subplots(1, 2, figsize=(11,5))

#         para_DA.mean('chi').sel(energy=slice(282,290)).plot(ax=axs[0], cmap=cmap, norm=LogNorm(1e9, 1e11), add_colorbar=False)
#         perp_DA.mean('chi').sel(energy=slice(282,290)).plot(ax=axs[1], cmap=cmap, norm=LogNorm(1e9, 1e11), add_colorbar=False)

#         sm = plt.cm.ScalarMappable(cmap=cmap, norm=LogNorm(2e10, 1e12)) # Create a ScalarMappable object with the colormap and normalization & add the colorbar to the figure
#         cax = axs[1].inset_axes([1.03, 0, 0.05, 1])
#         cbar = fig.colorbar(sm, cax=cax, orientation='vertical')
#         cbar.set_label(label='Intensity [arb. units]', labelpad=12)
#         fig.suptitle(f'Linecut Maps: {sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°', fontsize=14)
#         fig.set(tight_layout=True)
#         axs[0].set(title='Parallel to $E_p$', ylabel='Photon energy [eV]', xlabel='q [$Å^{-1}$]')
#         axs[1].set(title='Perpendicular to $E_p$ ', ylabel=None, xlabel='q [$Å^{-1}$]')
#         fig.savefig(plotsPath.joinpath('linecut_maps', f'{sample_name}_{edge}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)
#         plt.close('all')

        ### AR Maps:
        # Normalised difference of the chi-averaged parallel and perpendicular cuts
        para_mean = para_DA.mean('chi')
        perp_mean = perp_DA.mean('chi')
        ar_DA = (para_mean - perp_mean) / (para_mean + perp_mean)

        # Plot the 282-292 eV window on a fixed, symmetric color scale
        ar_window = ar_DA.sel(energy=slice(282,292))
        ax = ar_window.plot(figsize=(8,5), norm=plt.Normalize(-0.6, 0.6))
        ax.figure.suptitle('Anisotropy Ratio (AR) Map', fontsize=14, x=0.43)
        ax.axes.set_title(f'{sample_name}, Polarization = {pol}°, Chi Width = {chi_width}°')
        ax.axes.set_ylabel('Photon Energy [eV]')
        ax.axes.set_xlabel('q [$Å^{-1}$]')
        ax.colorbar.set_label('AR [arb. units]', rotation=270, labelpad=12)
        # ax.figure.savefig(plotsPath.joinpath('ar_maps', f'{sample_name}_{edge}_{intensity_type}_chiWidth-{chi_width}deg_pol{pol}deg.png'), dpi=120)
        plt.show()
        plt.close('all')