# 2024C3 SMI WAXS TReXS plotting notebook

## Setup

### Imports (ignore any warnings)

In [None]:
import PyHyperScattering as phs
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import xarray as xr
from tqdm.auto import tqdm 
import subprocess
import io
import gc
print(phs.__version__)

### Define paths & objects/functions

In [None]:
# Define paths
# Everything is resolved relative to the proposal directory on the beamline filesystem.
propPath = pathlib.Path('/nsls2/data/smi/proposals/2024-3/pass-316856')
outPath = propPath / 'processed_data'
refsPath = outPath / 'references'

# One directory of zarr stores per measurement batch
sampleZarrsPath1 = outPath / 'zarrs/waxs_last_morning'
sampleZarrsPath2 = outPath / 'zarrs/waxs_polysulfide_solutions_zarrs_v1'
sampleZarrsPath3 = outPath / 'zarrs/waxs_Li2S8_static_solution_zarrs_v1'
# print(sampleZarrsPath.exists())

# Set a colormap for later: "bad" (masked/NaN) pixels reuse the colormap's "under" color
cmap = plt.cm.turbo.copy()
under_color = cmap.get_under()
cmap.set_bad(under_color)

### Rclone copy statement(s) for saving out data

`rclone --dry-run copy -P /nsls2/data/smi/proposals/2024-3/pass-314903/processed_data/trexs_plots remote:research/data_analysis/rsoxs_suite/trexs_plots/2024C3 --exclude '*checkpoint*'`

### Load whole dataset from zarr(s)

In [None]:
# Display names (matplotlib mathtext) keyed by the sample name parsed from the zarr store names.
sn = dict(
    Li2S_TEGDME_reredo='Li$_2$S TEGDME flow solution',
    TEGDME_neat='TEGDME flow solvent',
    Li2S8_static='Li$_2$S$_8$ TEGDME static solution',
    Li2S_powder_2='Li$_2$S powder',
)

In [None]:
# Gather every entry of the three zarr directories, sorted within each directory
# and kept in directory order (1, 2, 3).
zarr_dirs = (sampleZarrsPath1, sampleZarrsPath2, sampleZarrsPath3)
sample_zarrs = [store for zarr_dir in zarr_dirs for store in sorted(zarr_dir.glob('*'))]

# Cell output: the store names that were found
[store.name for store in sample_zarrs]

In [None]:
# Check zarr sample names
# The sample name is taken to be everything between the first and the last
# underscore-separated token of each store name.
name_tokens = (store.name.split('_') for store in sample_zarrs)
unique_sample_names = sorted({'_'.join(tokens[1:-1]) for tokens in name_tokens})
unique_sample_names

In [None]:
# Load each caked & recip zarr

# Opened stores are bucketed by the tag found in the store name. 'recip_' is
# tested before 'caked_', so a name containing both tags is treated as recip;
# names containing neither are skipped.
rows_by_tag = {'recip_': [], 'caked_': []}
for sample_zarr in tqdm(sample_zarrs):
    sample_name = '_'.join(sample_zarr.name.split('_')[1:-1])  # parsed here, not used further in this cell

    tag = next((candidate for candidate in rows_by_tag if candidate in sample_zarr.name), None)
    if tag is not None:
        rows_by_tag[tag].append(xr.open_zarr(sample_zarr))

recip_DS_rows = rows_by_tag['recip_']
caked_DS_rows = rows_by_tag['caked_']

# Join the per-store datasets along the 'sample_name' dimension
recip_DS = xr.concat(recip_DS_rows, 'sample_name')
caked_DS = xr.concat(caked_DS_rows, 'sample_name')

# Rechunk appropriately
recip_chunks = {'sample_name': 1, 'theta': 1, 'pix_y': 1000, 'pix_x': 921, 'energy': 63}
caked_chunks = {'sample_name': 1, 'theta': 1, 'index_y': 500, 'index_x': 500, 'energy': 63}
recip_DS = recip_DS.chunk(recip_chunks)
caked_DS = caked_DS.chunk(caked_chunks)

# Reassign caked dataset to use q and chi instead of indices
caked_DS = caked_DS.swap_dims({'index_y': 'chi'})

# some weirdness around q for caked data, shouldn't change with energy but there are float changes,
# so the energy-averaged q_r becomes the single 'q' coordinate and q_r itself is dropped
q_r_coords = caked_DS.q_r.mean('energy')
caked_DS = caked_DS.assign_coords({'q': ('index_x', q_r_coords.data)})
caked_DS = caked_DS.swap_dims({'index_x': 'q'})
caked_DS = caked_DS.drop_vars('q_r')

# Show loaded datasets:
display(recip_DS)
display(caked_DS)

In [None]:
# # Check zarr sample names
# unique_sample_names = sorted(set(['_'.join(f.name.split('_')[1:-1]) for f in sampleZarrsPath1.glob('*')]))
# unique_sample_names

In [None]:
# # Load each caked & recip zarr

# recip_DS_rows = []
# caked_DS_rows = []
# for sample_name in tqdm(unique_sample_names):
#     sample_zarrs = sorted(sampleZarrsPath1.glob(f'*{sample_name}*'))
#     # display(sorted([f.name for f in sample_zarrs]))
    
#     samp_recip_DS_rows = []
#     samp_caked_DS_rows = []
#     for sample_zarr in sample_zarrs:
#         if 'recip_' in sample_zarr.name:
#             recip_DS = xr.open_zarr(sample_zarr)
#             samp_recip_DS_rows.append(recip_DS)
#         elif 'caked_' in sample_zarr.name:
#             caked_DS = xr.open_zarr(sample_zarr)
#             samp_caked_DS_rows.append(caked_DS)
            
#     recip_DS = xr.concat(samp_recip_DS_rows, 'theta')
#     recip_DS_rows.append(recip_DS)
    
#     caked_DS = xr.concat(samp_caked_DS_rows, 'theta')
#     caked_DS_rows.append(caked_DS)
    
# recip_DS = xr.concat(recip_DS_rows, 'sample_name')
# caked_DS = xr.concat(caked_DS_rows, 'sample_name')

# # Rechunk appropriately
# recip_DS = recip_DS.chunk({'sample_name':1, 'theta':1, 'pix_y': 1000, 'pix_x': 921, 'energy':63,})
# caked_DS = caked_DS.chunk({'sample_name':1, 'theta':1, 'index_y':500,'index_x':500,'energy':63})

# # Reassign caked dataset to use q and chi instead of indices
# caked_DS = caked_DS.swap_dims({'index_y':'chi'})

# q_r_coords = caked_DS.q_r.mean('energy')  # some weirdness around q for caked data, shouldn't change with energy but there are float changes
# caked_DS = caked_DS.assign_coords({'q':('index_x', q_r_coords.data)}).swap_dims({'index_x':'q'}).drop_vars('q_r')

# # Show loaded datasets:
# display(recip_DS)
# display(caked_DS)

### Load pin diode data for intensity normalization (flux changes with energy)

In [None]:
# Read the diode reference table as a plain 2D float array and report its shape
diode_refs_file = refsPath / 'energy_bpm3s_bpm2s_pds_diode_refs.txt'
diode_data = np.loadtxt(diode_refs_file)
print(diode_data.shape)

# %matplotlib inline
# plt.close('all')
# plt.plot(diode_data[:,0], diode_data[:,1]/diode_data[:,1].max(), label='bpm3s')
# plt.plot(diode_data[:,0], diode_data[:,2]/diode_data[:,2].max(), label='bpm2s')
# plt.plot(diode_data[:,0], diode_data[:,3]/diode_data[:,3].max(), label='pds')
# plt.legend()
# plt.show()

In [None]:
# Put into an Xarray dataset for quick interpolation to match loaded data energies
# Column 0 is the energy axis (scaled by 1000 here; presumably keV -> eV, confirm against the file);
# columns 1..3 are the three diode channels named below.
diode_names = ['bpm3s', 'bpm2s', 'pds']

diode_DS = xr.Dataset()
for column, diode_name in enumerate(diode_names, start=1):
    diode_DS[diode_name] = xr.DataArray(
        data=diode_data[:, column],
        coords={'energy': diode_data[:, 0] * 1000},
        dims='energy',
    )

# Interpolate energies to match loaded data energies
target_energies = recip_DS.energy.values
diode_DS = diode_DS.interp({'energy': target_energies})

diode_DS

### Apply -3 eV (SMI) energy shift to diode and recip/caked datasets:

In [None]:
# Apply -3 eV (SMI) energy shift to diode and recip/caked datasets.
#
# The shift is applied in place, so running this cell twice used to subtract
# 6 eV. Each dataset is now tagged in its attrs once shifted, and a tagged
# dataset is skipped; reloading the data produces untagged datasets again.
SMI_ENERGY_OFFSET_EV = 3  # subtracted from every energy coordinate (SMI offset)

for _ds in (diode_DS, recip_DS, caked_DS):
    if _ds.attrs.get('smi_energy_shift_applied', False):
        continue  # already shifted by an earlier run of this cell
    _ds['energy'] = _ds.energy - SMI_ENERGY_OFFSET_EV  # Shift 3 lower energy (SMI offset)
    _ds.attrs['smi_energy_shift_applied'] = True

## Cartesian plots

### 2D detector plots 

In [None]:
# Define, then check plotter function
def plotter(DA, energy, cmin, cmax):
    """
    Draw one detector frame as a Q_x / Q_y image with a labelled colorbar.

    Relies on notebook globals: `cmap` (image colormap), `sn` (sample display
    names) and `sample_name` (key into `sn`, used for the title).

    Inputs:
    DA (xr.DataArray): 2D frame to plot against its 'q_x' and 'q_y' coordinates
    energy (float): photon energy [eV] written into the title
    cmin (float): lower color limit
    cmax (float): upper color limit

    Outputs:
    The object returned by DA.plot.imshow (its .figure, .axes and .colorbar are used here)
    """
    ax = DA.plot.imshow(figsize=(5.5,4.5), x='q_x', y='q_y', cmap=cmap, norm=plt.Normalize(cmin,cmax))
    ax.figure.set_tight_layout(True)
    # The energy coordinates of the datasets already carry the -3 eV SMI shift
    # (applied in the "Apply -3 eV" cell), so the value is shown as given;
    # subtracting 3 again here mislabelled every frame by 3 eV.
    ax.axes.set(aspect='equal', title=f'{sn[sample_name]}: Energy = {np.round(energy, 1)} eV', xlabel='Q$_x$ [Å$^{-1}$]', ylabel='Q$_y$ [Å$^{-1}$]')
    ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=12)

    return ax

In [None]:
%matplotlib inline
plt.close('all')

# Select Dataset
DS = recip_DS

# Select Plotting Parameters
energy = 2490

# Select DataArray
for sample_name in tqdm(unique_sample_names):
    for theta in [90]:
        DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr']
        # Plot
        sliced_DA = DA.sel(energy=energy,method='nearest')
        cmin = float(sliced_DA.compute().quantile(0.18))
        cmax = float(sliced_DA.compute().quantile(0.996))
        im = plotter(sliced_DA, energy, cmin, cmax)
        
        # savePath = outPath.joinpath('polysulfides_plots/flow_cell/waxs/recip_plots_v1')
        savePath = outPath.joinpath('polysulfides_plots/static_cell/waxs/recip_plots_v1')
        savePath.mkdir(exist_ok=True)
        im.figure.savefig(savePath.joinpath(f'{sample_name}_energy{energy-3}eV.png'), dpi=120)
    
        plt.show()
        # plt.close('all')

#### Set NEXAFS ROIs based on interactive detector plot(s)

In [None]:
# # (pix_x_min, pix_x_max, pix_y_min, pix_y_max)
# nf_rois_dict = {
#     'ROI_1': (491, 555, 674, 710),
#     'ROI_2': (233, 296, 495, 530),
#     'ROI_3': (800, 860, 674, 710),
#     'ROI_4': (387, 472, 150, 175),
# }

#### Save detector energy mp4 movies

In [None]:
def da_to_mp4(DA, dim, output_path, plotter, frame_rate=9, quality=17, cmin_quantile=0.1, cmax_quantile=0.99, clim_style='fixed'):
    """
    Generate mp4 video of images along a specified dimension (e.g. energy, time).
    Requires subprocess import and an `ffmpeg` executable that can be launched by name.

    Inputs:
    DA (xr.DataArray): DataArray to generate mp4 from
    dim (str): dimension to generate frames along
    output_path (str or pathlib.Path): path to generated mp4 (includes mp4 filename)
    plotter (function): called as plotter(frame_DA, value, cmin, cmax) for each frame;
                        must return an object with a .figure attribute
    frame_rate (int, default=9): frame rate of mp4 generated
    quality (int, default=17): 'crf' quality value; lower is better, 17 is often considered visually lossless
    cmin_quantile (float, default=0.1): cmin quantile
    cmax_quantile (float, default=0.99): cmax quantile
    clim_style (str, default='fixed'): 'fixed' or 'by_frame', decide whether color limits should change
                                       with each frame or remain fixed based on whole dataset

    Outputs:
    mp4 movie file where specified in output path

    Raises:
    ValueError: if clim_style is neither 'fixed' nor 'by_frame'
    """
    # Validate before launching ffmpeg; previously an unknown style only failed
    # inside the frame loop with an unbound cmin.
    if clim_style not in ('fixed', 'by_frame'):
        raise ValueError(f"clim_style must be 'fixed' or 'by_frame', got {clim_style!r}")

    if clim_style=='fixed':
        # Load once: the same in-memory array serves both quantiles and every frame
        DA = DA.compute()
        cmin = float(DA.quantile(cmin_quantile))
        cmax = float(DA.quantile(cmax_quantile))

    # FFmpeg command. This is set up to accept data from the pipe and use it as input, with PNG format.
    # It will then output an H.264 encoded MP4 video.
    cmd = [
        'ffmpeg',
        '-y',  # Overwrite output file if it exists
        '-f', 'image2pipe',
        '-vcodec', 'png',
        '-r', str(frame_rate),  # Frame rate
        '-i', '-',  # The input comes from a pipe
        '-vcodec', 'libx264',
        '-pix_fmt', 'yuv420p',
        '-crf', str(quality),  # Set the quality (lower is better, 17 is often considered visually lossless)
        str(output_path)
    ]

    # Start the subprocess
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Loop through the chosen dimension and send frames to FFmpeg
    for value in tqdm(DA[dim].values, desc=f'Building MP4'):
        # Make & customize plot
        sliced_DA = DA.sel({dim:value}, method='nearest')
        if clim_style=='by_frame':
            # Limits come from THIS frame only (the whole array was used before,
            # which made 'by_frame' identical to 'fixed').
            sliced_DA = sliced_DA.compute()
            cmin = float(sliced_DA.quantile(cmin_quantile))
            cmax = float(sliced_DA.quantile(cmax_quantile))
        ax = plotter(sliced_DA, value, cmin, cmax)

        buf = io.BytesIO()
        ax.figure.savefig(buf, format='png')
        buf.seek(0)

        # Write the PNG buffer data to the process
        proc.stdin.write(buf.getvalue())
        plt.close('all')

    # Finish the subprocess (communicate() closes stdin, which ends the input stream)
    out, err = proc.communicate()
    if proc.returncode != 0:
        print(f"Error: {err}")
    

In [None]:
# Select Dataset
DS = recip_DS

# Output folder: created up front, because ffmpeg does not create missing directories
savePath = outPath.joinpath('trexs_plots/waxs_core_films_trexs_plots')
savePath.mkdir(parents=True, exist_ok=True)

# Select DataArray
for sample_name in tqdm(unique_sample_names[:]):
    for theta in [90]:
        DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr']

        output_path = savePath.joinpath(f'{sample_name}_recip.mp4')

        # One movie per sample, one frame per energy
        da_to_mp4(DA, 'energy', output_path, plotter)

### NEXAFS

In [None]:
# Detector ROI boxes in pixel coordinates, each given as
# (pix_x_min, pix_x_max, pix_y_min, pix_y_max)
nf_rois_dict = dict(
    ROI_1=(491, 555, 674, 710),
    ROI_2=(233, 296, 495, 530),
    ROI_3=(800, 860, 674, 710),
    # ROI_4=(387, 472, 150, 175),
)

In [None]:
%matplotlib inline
plt.close('all')

# for nexafs ROI plotting!
DS = recip_DS.copy()

# make selection(s)
# e_slice = slice(None, None)
e_slice = slice(2460, 2490)

for sample_name in tqdm(unique_sample_names[:]):
    for theta in [90]:
        fig, ax = plt.subplots(figsize=(6.5, 3.5), tight_layout=True)
        for roi, extents_tuple in nf_rois_dict.items():
            pix_x_min, pix_x_max, pix_y_min, pix_y_max = extents_tuple
            
            DA = DS['flatfield_corr'].sel(sample_name=sample_name, theta=theta)

            # Integrate ROI box:
            areas_DA = DA.sel(pix_x=slice(pix_x_min, pix_x_max),pix_y=slice(pix_y_min, pix_y_max)).integrate('pix_x').integrate('pix_y')
            
            # Divide areas DA by diode data
            areas_DA = areas_DA / diode_DS['pds']
            
            # Quick plot corrections: Subtract pre-edge mean, divide post-edge mean
            areas_DA = areas_DA - areas_DA.sel(energy=slice(2448, 2462)).mean('energy')  # subtract pre-edge
            # areas_DA = areas_DA / areas_DA.sel(energy=slice(2505, 2535)).mean('energy')  # divide post-edge
            areas_DA = areas_DA / areas_DA.sel(energy=slice(2495, 2505)).mean('energy')  # divide post-edge

            # Plot        
            areas_DA.sel(energy=e_slice).plot.line(ax=ax, label=f'{roi}')

        ax.set_title(f'FL NEXAFS: {sample_name}')
        ax.set(ylabel=f'Integrated Intensity [arb. units]', xlabel='Energy [eV]')

        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.xaxis.grid(True, which='both')
        ax.legend(title='NEXAFS ROIs')

        # savePath = outPath.joinpath('trexs_plots/waxs_core_films_trexs_plots/nexafs_rois_vtesting')
        # savePath.mkdir(exist_ok=True)
        # fig.savefig(savePath.joinpath(
        #     f'{sample_name}_theta-{theta}deg_chiWidth-{chi_width}deg_q-{q_slice.start}-{q_slice.stop}_energy{e_slice.start}-{e_slice.stop}.png'), dpi=120)

        plt.show()
        # plt.close('all')

In [None]:
# Display the parsed sample names for reference when choosing the slice in the next cell
unique_sample_names

In [None]:
# Display the output root that the save paths below are built from
outPath

In [None]:
%matplotlib inline
plt.close('all')

# Average ROIs and make nicer NEXAFS plot to save out
DS = recip_DS.copy()

# make selection(s)
# e_slice = slice(None, None)
# e_slice = slice(2445, 2535)
e_slice = slice(2455, 2495)
# e_slice = slice(2465, 2480)

fig, ax = plt.subplots(figsize=(7, 4), tight_layout=True, dpi=150)
for sample_name in tqdm(unique_sample_names[:-1]):
    # fig, ax = plt.subplots(figsize=(5, 3), tight_layout=True, dpi=150)
    for theta in [90]:
        areas_DS = xr.Dataset()
        for roi, extents_tuple in nf_rois_dict.items():
            pix_x_min, pix_x_max, pix_y_min, pix_y_max = extents_tuple

            DA = DS['flatfield_corr'].sel(sample_name=sample_name, theta=theta)

            # Integrate ROI box:
            areas_DA = DA.sel(pix_x=slice(pix_x_min, pix_x_max),pix_y=slice(pix_y_min, pix_y_max)).integrate('pix_x').integrate('pix_y')

            # Divide areas DA by diode data
            areas_DA = areas_DA / diode_DS['pds']

            # Quick plot corrections: Subtract pre-edge mean, divide post-edge mean
            areas_DA = areas_DA - areas_DA.sel(energy=slice(2448, 2462)).mean('energy')  # subtract pre-edge
            # areas_DA = areas_DA / areas_DA.sel(energy=slice(2505, 2535)).mean('energy')  # divide post-edge
            areas_DA = areas_DA / areas_DA.sel(energy=slice(2495, 2505)).mean('energy')  # divide post-edge

            areas_DS[roi] = areas_DA

        avg_areas_DA = areas_DS.to_array(dim='ROI').mean('ROI')

        # Plot    
        avg_areas_DA.sel(energy=e_slice).plot.line(ax=ax, label=sn[sample_name])

# ax.set_title(f'Fluorescence NEXAFS: {sn[sample_name]}')
ax.set_title(f'Fluorescence NEXAFS')
ax.set(ylabel=f'Intensity [arb. units]', xlabel='Energy [eV]')
ax.legend()

ax.xaxis.set_minor_locator(MultipleLocator(2))
ax.xaxis.set_major_locator(MultipleLocator(4))
# ax.xaxis.set_major_locator(MultipleLocator(2))
ax.xaxis.grid(True, which='both')

# savePath = outPath.joinpath('polysulfides_plots/Li2S_powder/fl_nexafs')
# savePath = outPath.joinpath('polysulfides_plots/flow_cell/waxs/fl_nexafs')
# savePath = outPath.joinpath('polysulfides_plots/static_cell/waxs/fl_nexafs')
savePath = outPath.joinpath('polysulfides_plots/together/waxs/fl_nexafs')
savePath.mkdir(exist_ok=True)
fig.savefig(savePath.joinpath(
    # f'{sample_name}_theta-{theta}-deg_4-rois-avg-nexafs_energy-{e_slice.start}-{e_slice.stop}.png'), dpi=150)
    f'combined-samples_theta-{theta}-deg_4-rois-avg-nexafs_energy-{e_slice.start}-{e_slice.stop}.png'), dpi=150)

plt.show()
# plt.close('all')

In [None]:
# Display the output root again (where the figure above was saved under)
outPath

## Polar plots

### 2D polar plots

In [None]:
# Select Dataset
DS = caked_DS.copy()


# Select Plotting Parameters
energy = 2470
chi_slice = slice(-125,25)
# chi_slice = slice(-150,50)
# chi_slice = slice(None,None)

qr_slice = slice(None,0.7)
# qr_slice = slice(None,None)

# Select DataArray
# sample_name = 'PM6-Y6_3000_dSiN'
for sample_name in tqdm(unique_sample_names):
    for theta in [90]:
        DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr']

        # The load cell renames the radial dimension to 'q' and drops 'q_r', so the
        # selection and the x axis use 'q' (the former 'q_red' / 'q_r' names no
        # longer exist on caked_DS). Load the frame once for quantiles and plot.
        sliced_DA = DA.sel(energy=energy,method='nearest').sel(chi=chi_slice, q=qr_slice).compute()
        cmin = float(sliced_DA.quantile(0.1))
        cmax = float(sliced_DA.quantile(0.98))

        # Plot
        ax = sliced_DA.plot.imshow(figsize=(5.5,4.5), x='q', y='chi', cmap=cmap, norm=plt.Normalize(cmin,cmax))
        ax.figure.suptitle(f'Photon Energy = {np.round(energy, 1)} eV', fontsize=14, y=0.96)
        ax.figure.set_tight_layout(True)
        ax.axes.set(title=f'{sample_name}, $\\theta$ = {theta}°', xlabel='q$_r$ [$Å^{-1}$]', ylabel='$\\chi$ [°]')
        ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=12)
        # ax.figure.savefig(outPath.joinpath('waxs_detector_movies_v1', f'{sample_name}_{theta}degth.png'), dpi=120)
        plt.show()
        plt.close('all')

#### Save polar plot movies

In [None]:
# Select Dataset
DS = caked_DS.copy()

# plotting parameters
chi_slice = slice(-150,50)
# chi_slice = slice(None,None)

qr_slice = slice(None,0.7)
# qr_slice = slice(None,None)

# Output folder for both the movies and the first-frame PNGs; created up front
# because ffmpeg does not create missing directories.
movie_dir = outPath.joinpath('trexs_plots/caked_waxs_detector_movies_v1')
movie_dir.mkdir(parents=True, exist_ok=True)

# Select DataArray
for sample_name in tqdm(unique_sample_names):
# for sample_name in tqdm(['PM6_1CN-CB']):
    for theta in [90, 55, 35]:
        # Load once: color limits are fixed over the whole (unsliced) array,
        # and every frame is then cut from the in-memory copy.
        DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr'].compute()
        cmin = float(DA.quantile(0.15))
        cmax = float(DA.quantile(0.995))

        output_path = movie_dir.joinpath(f'{sample_name}_{theta}degth.mp4')

        # FFmpeg command. This is set up to accept data from the pipe and use it as input, with PNG format.
        # It will then output an H.264 encoded MP4 video.
        cmd = [
            'ffmpeg',
            '-y',  # Overwrite output file if it exists
            '-f', 'image2pipe',
            '-vcodec', 'png',
            '-r', '15',  # Frame rate
            '-i', '-',  # The input comes from a pipe
            '-vcodec', 'libx264',
            '-pix_fmt', 'yuv420p',
            '-crf', '17',  # Set the quality (lower is better, 17 is often considered visually lossless)
            str(output_path)
        ]

        # Start the subprocess
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        # Loop through the energy dimension and send frames to FFmpeg
        for i, energy in enumerate(tqdm(DA.energy.values, desc=f'Making the {sample_name} {theta}° movie')):
            # Make & customize plot.
            # caked_DS is already indexed by 'chi' and 'q' (see the load cell), so the
            # former swap_dims from index_y/index_x is gone; it would raise on this dataset.
            sliced_DA = DA.sel(energy=energy,method='nearest').sel(chi=chi_slice, q=qr_slice)

            ax = sliced_DA.plot.imshow(figsize=(5.5,4.5), x='q', y='chi', cmap=cmap, norm=plt.Normalize(cmin,cmax))
            ax.figure.suptitle(f'Photon Energy = {np.round(energy, 1)} eV', fontsize=14, y=0.96)
            ax.figure.set_tight_layout(True)
            ax.axes.set(title=f'{sample_name}, $\\theta$ = {theta}°', xlabel='q$_r$ [$Å^{-1}$]', ylabel='$\\chi$ [°]')
            ax.colorbar.set_label('Intensity [arb. units]', rotation=270, labelpad=12)

            # Save figure if first frame:
            if i==0:
                ax.figure.savefig(movie_dir.joinpath(f'{sample_name}_{theta}degth.png'), dpi=120)

            buf = io.BytesIO()
            ax.figure.savefig(buf, format='png')
            buf.seek(0)

            # Write the PNG buffer data to the process
            proc.stdin.write(buf.getvalue())
            plt.close('all')

        # Finish the subprocess
        out, err = proc.communicate()
        if proc.returncode != 0:
            print(f"Error: {err}")


### I vs Q linecuts

In [None]:
# Notebook-level settings read by qr_linecut_plotter below
chi_width = 90
q_slice = slice(0.11,0.65)
energy = 2445
# bad_qr_slices = [
#     slice(0.34, 0.37), 
#     slice(0.49, 0.55)
# ]

bad_qr_slices = []

def qr_linecut_plotter(DA, energy):
    """
    Plot chi-averaged I vs Q linecuts (para, perp and their average) at one energy.

    Relies on notebook globals: bad_qr_slices (q_red ranges to blank out), q_slice
    (plotted q_red range), chi_width and sample_name (title text only).

    NOTE(review): make_para_perp_DAs is not defined in the visible part of this
    notebook, and other cells call it with four arguments (DS, sample_name, theta,
    chi_width) while this function passes one. Confirm which signature is current.
    NOTE(review): this function indexes 'q_red', whereas the load cell names the
    caked radial dimension 'q'. Confirm the coordinate name before use.

    Inputs:
    DA (xr.DataArray): sliced DA just to plot (left unmodified)
    energy (float): energy to plot; the nearest available energy is selected

    Outputs:
    (fig, ax): the matplotlib figure and axes that were drawn
    """
    # Remove bad q ranges. Work on a copy so the caller's array (and anything
    # sharing its data) is not overwritten with NaNs as a side effect.
    if bad_qr_slices:
        DA = DA.copy(deep=True)
        for slice_to_nan in bad_qr_slices:
            DA.loc[{'q_red': slice_to_nan}] = np.nan
            # DA = DA.interpolate_na(dim='q_red')

    # Make para & perp DAs:
    para_DA, perp_DA = make_para_perp_DAs(DA)

    avg_DA_mean = (para_DA.mean('chi') + perp_DA.mean('chi')) / 2
    para_DA_mean = para_DA.mean('chi')
    perp_DA_mean = perp_DA.mean('chi')

    # Plot
    regions = ['para', 'perp', 'full']
    # colors = plt.cm.Dark2(np.linspace(0, 1, 8))
    colors = plt.cm.viridis(np.linspace(0, 0.85, 3))

    fig, ax = plt.subplots(figsize=(6,3.5), tight_layout=True, dpi=120)

    # for j, energy in enumerate(energies):
    p2, = (para_DA_mean.sel(q_red=q_slice).sel(energy=energy, method='nearest')
     .plot.line(ax=ax, color=colors[0], yscale='linear', xscale='linear', label='Para'))
    p3, = (perp_DA_mean.sel(q_red=q_slice).sel(energy=energy, method='nearest')
     .plot.line(ax=ax, color=colors[2], yscale='linear', xscale='linear', label='Perp'))
    p1, = (avg_DA_mean.sel(q_red=q_slice).sel(energy=energy, method='nearest')
     .plot.line(ax=ax, color=colors[1], yscale='linear', xscale='linear', label='Avg'))

    ax.set_title(f'I vs Q ({chi_width}° wedges): {sample_name}, Energy = {energy:.2f} eV')
    # The y axis ends up logarithmic: this call overrides the 'linear' passed to plot.line
    ax.set(ylabel='Intensity [arb. units]', xlabel='Q [$Å^{-1}$]', yscale='log')

    # Legend order: Para, Avg, Perp
    lines= [p2,p1,p3]
    # '\\chi' yields the same text as before; the bare '\c' was an invalid escape sequence
    ax.legend(loc='lower left', title='$\\chi$ regions', handles=lines, labels=[l.get_label() for l in lines])

    return fig, ax

In [None]:
%matplotlib inline
plt.close('all')

# Select necessary slices

DS = caked_DS.copy()
chi_width = 90
q_slice = slice(0.15,0.65)
energy = 2445
theta = 90

for sample_name in unique_sample_names:
    DA = DS['flatfield_corr'].sel(sample_name=sample_name, theta=theta)
    fig, ax = qr_linecut_plotter(DA, energy)

    plt.show()
    plt.close('all')

In [None]:
def da_to_linecut_mp4(DA, dim, output_path, plotter, frame_rate=15, quality=17):
    """
    Build an mp4 from line plots, one frame per coordinate value along `dim`.
    Frames are rendered as PNGs in memory and piped to an `ffmpeg` subprocess.

    Inputs:
    DA (xr.DataArray): DataArray to generate mp4 from
    dim (str): dimension whose coordinate values define the frames
    output_path (str or pathlib.Path): path to generated mp4 (includes mp4 filename)
    plotter (function): called as plotter(DA, value) with the full DA and the current
                        coordinate value; must return (fig, ax)
    frame_rate (int, default=15): frame rate of mp4 generated
    quality (int, default=17): 'crf' quality value; lower is better, 17 is often considered visually lossless

    Outputs:
    mp4 movie file where specified in output path
    """
    # ffmpeg reads a stream of PNG images from stdin and encodes them as H.264
    input_args = [
        '-f', 'image2pipe',
        '-vcodec', 'png',
        '-r', str(frame_rate),  # Frame rate
        '-i', '-',  # The input comes from a pipe
    ]
    output_args = [
        '-vcodec', 'libx264',
        '-pix_fmt', 'yuv420p',
        '-crf', str(quality),  # lower is better, 17 is often considered visually lossless
        str(output_path),
    ]
    ffmpeg_cmd = ['ffmpeg', '-y', *input_args, *output_args]  # -y: overwrite existing output

    encoder = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # One frame per coordinate value; the plotter does its own slicing
    frame_values = DA[dim].values
    for value in tqdm(frame_values, desc=f'Building MP4'):
        fig, ax = plotter(DA, value)

        # Render to an in-memory PNG and hand the bytes to ffmpeg
        with io.BytesIO() as png_buffer:
            fig.savefig(png_buffer, format='png')
            png_buffer.seek(0)
            encoder.stdin.write(png_buffer.getvalue())

        # Release figure memory between frames
        plt.close('all')
        gc.collect()

    # communicate() closes stdin, waits for ffmpeg to exit and collects its output
    out, err = encoder.communicate()
    if encoder.returncode != 0:
        print(f"Error: {err}")

In [None]:
# plt.close('all')

# Select necessary slices
# (chi_width, q_slice and sample_name are read as globals by qr_linecut_plotter)
DS = caked_DS.copy()
chi_width = 90
q_slice = slice(0.15,0.65)
# energy = 2470

# Output folder: created up front, because ffmpeg does not create missing directories
savePath = outPath.joinpath('trexs_plots/waxs_core_films_trexs_plots')
movie_dir = savePath.joinpath('qr_linecut_movies_v1')
movie_dir.mkdir(parents=True, exist_ok=True)

for sample_name in tqdm(unique_sample_names[:]):
    for theta in [35, 55]:
        DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr']

        output_path = movie_dir.joinpath(f'{sample_name}_{theta}degth.mp4')

        # One movie per sample and theta, one frame per energy
        da_to_linecut_mp4(DA, 'energy', output_path, qr_linecut_plotter)

### I vs Chi linecuts

In [None]:
# make selection, subtracted flat line to compare peak intensities overlayed
# Overview: for each selected sample, bin the caked data in chi, draw a straight
# line in q_red between the intensities at q_min and q_max (per energy and chi bin),
# subtract that line, and collect the results in sub_DA.
# NOTE(review): this cell indexes 'q_red', while the load cell names the caked
# radial dimension 'q' — confirm the coordinate name before running.
DS = caked_DS.copy()

# make selection
chi_slice = slice(-125, 25)
# bad_chi_slices = [slice(-110, -102), slice(-77, -68), slice(-59,-56), slice(-33,-29), slice(-25,-21), slice(-5,-2)]
bad_chi_slices = []

# # # y_fit around lamella peak:
# peak_name = 'Lamella'
# q_min = 0.22
# q_max = 0.33

# y_fit around mystery lump:
# (peak_name, q_min and q_max are also read as globals by chi_linecut_plotter)
peak_name = '0p4 mystery'
q_min = 0.38
q_max = 0.43

# # y_fit around mystery lump:
# q_min = 0.57
# q_max = 0.63

# # y_fit around mystery lump:
# q_min = 0.18
# q_max = 0.63

# Select DataArray
samp_sub_DAs = []
for sample_name in tqdm(unique_sample_names[1:2]):
    # for theta in [90]:
    theta = 90
    DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr']
    # Keep only values above 1 and below 3% of the maximum; everything else becomes NaN
    DA = DA.where(DA>1).where(DA<(DA.max()*0.03))
    # Blank out the listed chi ranges, then fill all NaNs by interpolating along chi
    for slice_to_nan in bad_chi_slices:
        DA.loc[{'chi': slice_to_nan}] = np.nan
    DA = DA.interpolate_na(dim='chi')

    # Average chi into 12 bins over the selected chi range (adds the 'chi_bins' dimension)
    DA = DA.sel(chi=chi_slice).groupby_bins('chi', 12).mean('chi')

    # # Subtract pre-peak flat line
    # DA = DA - DA.sel(q_red=slice(0.2,0.23)).mean('q_red')

    # Draw y_fits
    # Anchor intensities: mean over a +/-0.005 q_red window around q_min and q_max
    points_x = [q_min, q_max]
    points_y = [DA.sel(q_red=slice(points_x[0]-0.005, points_x[0]+0.005)).mean('q_red'), 
                DA.sel(q_red=slice(points_x[1]-0.005, points_x[1]+0.005)).mean('q_red')]
    # Slope and intercept of the straight line through the two anchor points
    m_DA = (points_y[1]-points_y[0])/(points_x[1]-points_x[0])
    b_DA = points_y[1] - (m_DA*points_x[1])

    # Assemble the fitted lines into an array shaped (energy, chi_bins, q_red).
    # assumes DA's dimensions are ordered (energy, chi_bins, q_red) and that m_DA/b_DA
    # iterate over energy for a fixed chi bin — TODO confirm
    y_fits = np.empty((len(DA.energy), 0, len(DA.q_red)), float)
    for chi_bin in tqdm(DA.chi_bins, desc='Fitting energies in each chi bins'):
        bin_y_fits = np.empty((0, len(DA.q_red)), float)
        for m, b in zip(m_DA.sel(chi_bins=chi_bin).compute().data, b_DA.sel(chi_bins=chi_bin).compute().data):
            # m*q_red + b evaluated at every q_red value, as one row
            y_fit = np.polyval([m, b], DA.q_red).reshape(1, len(DA.q_red))
            bin_y_fits = np.append(bin_y_fits, y_fit, axis=0)

        bin_y_fits = bin_y_fits.reshape(len(DA.energy), 1, len(DA.q_red))
        y_fits = np.append(y_fits, bin_y_fits, axis=1)

    # Reuse DA's coordinates for the fitted lines, then subtract them from the data
    y_fits_DA = DA.copy()
    y_fits_DA.data = y_fits

    sub_DA = DA - y_fits_DA
    samp_sub_DAs.append(sub_DA)
        
# Join the per-sample results along 'sample_name'
sub_DA = xr.concat(samp_sub_DAs, 'sample_name')

In [None]:
# Define plotter function per energy slice
def chi_linecut_plotter(sub_DA, energy):
    """
    Plot q-integrated intensity versus chi bin for every sample in sub_DA at one energy.

    Reads notebook globals: q_min / q_max (q_red integration window), colors (indexed
    by sample position), peak_name and theta (title text only).

    Inputs:
    sub_DA (xr.DataArray): data with the fitted line already subtracted; must carry
                           'sample_name', 'q_red' and 'energy'
    energy (float): exact energy coordinate value to select

    Outputs:
    (fig, ax): the matplotlib figure and axes that were drawn
    """
    # Integrate the selected q_red window at the requested energy
    q_window = slice(q_min, q_max)
    integrated_DA = sub_DA.sel(q_red=q_window, energy=energy).integrate('q_red')

    fig, ax = plt.subplots(figsize=(4.5,3.5), tight_layout=True)

    # One line per sample, colored by its position along sample_name
    for idx, name in enumerate(sub_DA.sample_name.values):
        sample_curve = integrated_DA.sel(sample_name=name)
        # sample_curve = sample_curve - float(sample_curve.sel(chi_bins=slice(-60,-40)).mean('chi_bins'))
        sample_curve.plot.line(ax=ax, color=colors[idx], label=name)

    ax.set_title(f'{peak_name} pole figure, $\\theta$ = {theta}°, Energy = {energy:.2f} eV')
    ax.set(ylabel='Chi-binned q-integrated intensity [arb. units]',
           xlabel='Chi value [°]')

    # Major ticks every 90°, minor every 30°, grid lines on both
    x_axis = ax.xaxis
    x_axis.set_major_locator(MultipleLocator(90))
    x_axis.set_minor_locator(MultipleLocator(30))
    x_axis.grid(True, which='both')
    ax.legend(loc='upper right')

    return fig, ax

# Plot
# One color per sample (read inside chi_linecut_plotter)
n_samples = len(sub_DA.sample_name)
colors = plt.cm.viridis(np.linspace(0,0.85,n_samples))

# Snap the requested energies to the nearest available coordinate values
requested_energies = [2445, 2470.2, 2472, 2474, 2476, 2477, 2478, 2484, 2550]
energies = sub_DA.energy.sel(energy=requested_energies, method='nearest').data

for energy in tqdm(energies):
    fig, ax = chi_linecut_plotter(sub_DA, energy)
    plt.show()
    plt.close('all')

In [None]:
savePath = outPath.joinpath('trexs_plots/waxs_core_films_trexs_plots')
# output_path = savePath.joinpath('chi_linecut_movies_vtesting', f'PM6_CNCB_series_90degth.mp4')
# output_path = savePath.joinpath('chi_linecut_movies_vtesting', f'PM6_CNCF_series_90degth.mp4')
output_path = savePath.joinpath('chi_linecut_movies_vtesting', f'PM6_0p4_peak_CNCB_series_90degth.mp4')

# ffmpeg does not create missing directories, so make sure the target folder exists
output_path.parent.mkdir(parents=True, exist_ok=True)

# One frame per energy of the line-subtracted data
da_to_linecut_mp4(sub_DA, 'energy', output_path, chi_linecut_plotter)

In [None]:
# make selection, subtracted flat line to compare peak intensities overlayed
# NOTE(review): this cell indexes 'q_red', while the load cell names the caked
# radial dimension 'q' — confirm the coordinate name before running.
DS = caked_DS

# make selection
q_slice = slice(0.25, 0.33)   # peak slice here
chi_slice = slice(-125, 25)
bad_chi_slices = [slice(-110, -102), slice(-77, -68), slice(-59,-56), slice(-33,-29), slice(-25,-21), slice(-5,-2)]

# Colormap for the per-energy lines. It gets its own name so the notebook-wide
# `cmap` (the image colormap read by plotter and the imshow cells) is not replaced.
line_cmap = plt.cm.viridis.copy()

# Select DataArray
for sample_name in tqdm(unique_sample_names[2:5]):
    for theta in [90]:
        DA = DS.sel(sample_name=sample_name, theta=theta)['flatfield_corr']
        # Keep only values above 1 and below 3% of the maximum; everything else becomes NaN
        DA = DA.where(DA>1).where(DA<(DA.max()*0.03))
        # Blank out the listed chi ranges, then fill all NaNs by interpolating along chi
        for slice_to_nan in bad_chi_slices:
            DA.loc[{'chi': slice_to_nan}] = np.nan
        DA = DA.interpolate_na(dim='chi')

        # Average chi into 20 bins over the selected chi range
        DA = DA.sel(chi=chi_slice).groupby_bins('chi', 20).mean('chi')

        # Subtract pre-peak flat line
        DA = DA - DA.sel(q_red=slice(0.2,0.23)).mean('q_red')

        # Plot
        energies = DA.energy.sel(energy=[2445, 2470.2, 2472, 2474, 2476, 2477, 2478, 2484, 2550], method='nearest').data

        colors = line_cmap(np.linspace(0, 1, len(energies)))

        fig, ax = plt.subplots(figsize=(4,3), tight_layout=True)

        for i, energy in enumerate(energies):
            # sliced_DA = (DA.sel(q_red=q_slice, energy=energy) - 
            #              DA.sel(q_red=slice(0.2,0.23), energy=energy).mean('q_red'))
            # sliced_DA_snap = sliced_DA.copy()
            sliced_DA = DA.sel(q_red=q_slice, energy=energy)
            sliced_DA = sliced_DA.integrate('q_red')
            # Normalize each curve by its mean over the chi bins between -60 and -40
            sliced_DA = sliced_DA / float(sliced_DA.sel(chi_bins=slice(-60,-40)).mean('chi_bins'))

            sliced_DA.plot.line(ax=ax, color=colors[i], label=energy)

        ax.set_title(f'{sample_name}, $\\theta$={theta}°')
        ax.set(ylabel='Chi-binned integrated intensity [arb. units]')

        ax.xaxis.set_major_locator(MultipleLocator(90))
        ax.xaxis.set_minor_locator(MultipleLocator(45))
        ax.xaxis.grid(True, which='both')

        plt.show()
        plt.close('all')

In [None]:
# make selection, subtracted flat line to compare peak intensities overlayed
# NOTE(review): make_para_perp_DAs is not defined in the visible part of this
# notebook; the 'q_red' / 'chi' names used below are whatever it returns — confirm.
DS = caked_DS.copy()

# make selection
q_slice = slice(0.23, 0.32)   # peak slice here
chi_width = 90
e_slice = slice(2470, 2485)

# Output folder, created with any missing parents so a fresh output tree works
savePath = outPath.joinpath('trexs_plots/waxs_core_films_trexs_plots/peakarea-vs-energy_v2')
savePath.mkdir(parents=True, exist_ok=True)

for sample_name in tqdm(unique_sample_names[:]):
    for theta in [90, 55, 35]:
        para_DA, perp_DA = make_para_perp_DAs(DS, sample_name, theta, chi_width) 

        # Keep only values above 0.4 and below 3% of the maximum; everything else becomes NaN
        para_DA = para_DA.where(para_DA>0.4).where(para_DA<(para_DA.max()*0.03))
        perp_DA = perp_DA.where(perp_DA>0.4).where(perp_DA<(perp_DA.max()*0.03))

        # Subtract pre-peak flat line
        para_DA = para_DA - para_DA.sel(q_red=slice(0.2,0.23)).mean('q_red')
        perp_DA = perp_DA - perp_DA.sel(q_red=slice(0.2,0.23)).mean('q_red')

        # Mean/integrate chi/q:
        para_areas_DA = para_DA.sel(q_red=q_slice).mean('chi').integrate('q_red')
        perp_areas_DA = perp_DA.sel(q_red=q_slice).mean('chi').integrate('q_red')

        # Plot        
        fig, ax = plt.subplots(figsize=(5,4), tight_layout=True)
        para_areas_DA.sel(energy=e_slice).plot.line(ax=ax, label='Para')
        perp_areas_DA.sel(energy=e_slice).plot.line(ax=ax, label='Perp')

        fig.suptitle(f'Lamella peak area vs photon energy: {sample_name}', x=0.53, y=0.95)

        # '\\chi' avoids the invalid '\c' escape; the width follows chi_width
        # instead of a hard-coded 90 so the title cannot drift from the data.
        ax.set_title(f'$\\theta$ = {theta}°, $\\chi$ width = {chi_width}°')
        ax.set(ylabel=f'Integrated q ({q_slice.start}, {q_slice.stop}) [arb. units]', xlabel='Energy [eV]')

        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.xaxis.grid(True, which='both')
        ax.legend()

        fig.savefig(savePath.joinpath(
            f'{sample_name}_theta-{theta}deg_chiWidth-{chi_width}deg_q-{q_slice.start}-{q_slice.stop}_energy{e_slice.start}-{e_slice.stop}.png'), dpi=120)

        # plt.show()
        plt.close('all')

In [None]:
# make selection
# NOTE(review): make_para_perp_DAs is not defined in the visible part of this
# notebook; the 'q_red' / 'chi' names used below are whatever it returns — confirm.
DS = caked_DS.copy()

# make selection
q_slice = slice(0.1, 0.5)
chi_width = 90
# energy_slice = slice(2470, 2485)

# Colormap for the per-energy lines. It gets its own name so the notebook-wide
# `cmap` (turbo with its "bad" color set, read by plotter and the imshow cells)
# is not replaced by a fresh copy.
line_cmap = plt.cm.turbo.copy()

for sample_name in tqdm(unique_sample_names[:]):
    for theta in [90]:
        para_DA, perp_DA = make_para_perp_DAs(DS, sample_name, theta, chi_width) 

        # Keep only values above 0.4 and below 3% of the maximum; everything else becomes NaN
        para_DA = para_DA.where(para_DA>0.4).where(para_DA<(para_DA.max()*0.03))  #.interpolate_na(dim='chi')
        # para_DA.sel(energy=2477.2,method='nearest').sel(q_red=slice(0.05, 0.7)).plot.imshow()
        # plt.title('para')
        # plt.show()

        perp_DA = perp_DA.where(perp_DA>0.4).where(perp_DA<(perp_DA.max()*0.03))  #.interpolate_na(dim='chi')
        # perp_DA.sel(energy=2477.2,method='nearest').sel(q_red=slice(0.05, 0.7)).plot.imshow()
        # plt.title('perp')
        # plt.show()
        # plt.close('all')

        # Plot
        # Snap the requested energies to the nearest available coordinate values
        energies = para_DA.energy.sel(energy=[2445, 2470.2, 2472, 2474, 2476, 2477, 2478, 2484, 2550], method='nearest').data

        colors = line_cmap(np.linspace(0, 1, len(energies)))

        # Left panel: para, right panel: perp; one line per energy in each
        fig, axs = plt.subplots(ncols=2,figsize=(8,4), tight_layout=True)

        for j, energy in enumerate(energies):
            (para_DA.sel(q_red=q_slice, energy=energy).mean('chi')
             .plot.line(ax=axs[0], color=colors[j], yscale='linear', xscale='linear', label=energy))
            (perp_DA.sel(q_red=q_slice, energy=energy).mean('chi')
             .plot.line(ax=axs[1], color=colors[j], yscale='linear', xscale='linear', label=energy))

        # '\\chi' avoids the invalid '\c' escape; the width follows chi_width
        fig.suptitle(f'IvsQ, $\\theta$ = {theta}°, $\\chi$ width = {chi_width}°: {sample_name}', x=0.47)

        axs[0].set(title=f'Parallel to E$_p$', ylabel='Intensity [arb. units]', xlabel='Q [$Å^{-1}$]')
        axs[1].set(title=f'Perpendicular to E$_p$', ylabel='Intensity [arb. units]', xlabel='Q [$Å^{-1}$]')
        axs[1].legend(title='Energy [eV]', loc=(1.05,0.1))

        # fig.savefig(outPath.joinpath('trexs_plots/I_cuts_v1', 
        #     f'{sample_name}_theta-{theta}deg_chiWidth-{chi_width}deg_q-{q_slice.start}-{q_slice.stop}.png'), dpi=120)

        plt.show()
        plt.close('all')