In [None]:
import os
from pathlib import Path

from sed import SedProcessor
import sed
import numpy as np

%matplotlib inline
# %matplotlib ipympl
import matplotlib.pyplot as plt

In [None]:
# Switch to the interactive ipympl ("widget") backend; this replaces the
# `%matplotlib inline` backend selected in the import cell.
%matplotlib widget

In [None]:
# Configuration: the HEXTOF tutorial config shipped with sed plus the data locations.
local_path = Path(sed.__file__).parent.parent/'tutorial/'
config_file = local_path/'hextof_config.yaml'
# Explicit error instead of `assert`, which is skipped when Python runs with -O.
if not config_file.exists():
    raise FileNotFoundError(f"sed tutorial config not found: {config_file}")

# Data locations. The defaults are the original (user-specific) paths; set the
# SED_DATA_RAW_DIR / SED_DATA_PARQUET_DIR environment variables to run elsewhere
# without editing the notebook.
config = {"core": {"paths": {
    "data_raw_dir": os.environ.get(
        "SED_DATA_RAW_DIR",
        "/asap3/flash/gpfs/pg2/2023/data/11019101/raw/hdf/offline/fl1user3",
    ),
    "data_parquet_dir": os.environ.get(
        "SED_DATA_PARQUET_DIR",
        "/home/agustsss/temp/sed_parquet/",
    ),
}}}

# Chessy run (44762) for spatial calibration

In [None]:
# Load the chessy calibration run (44762) without metadata collection, then call
# add_jitter() with no arguments before any binning.
sp_chessy = SedProcessor(
    runs=[44762],
    config=config,
    user_config=config_file,
    system_config={},
    collect_metadata=False,
)
sp_chessy.add_jitter()

In [None]:
# Detector window for the chessy image: 240 bins per axis over [420, 900] px,
# i.e. 2 px per bin. The entries of `bins` and `ranges` follow the order of `axes`.
axes = ["dldPosY", "dldPosX"]
bins = [240, 240]
ranges = [[420, 900], [420, 900]]

In [None]:
# Histogram the chessy events on the detector window defined above.
res_chessy = sp_chessy.compute(axes=axes, bins=bins, ranges=ranges)

In [None]:
# Chessy image on its own figure; robust=True clips the colour scale to percentiles.
fig_chessy, ax_chessy = plt.subplots()
res_chessy.plot(ax=ax_chessy, robust=True)

In [None]:
# Two features picked by eye on the chessy image, as (horizontal, vertical) plot
# coordinates in detector px. Their separation is taken to be 200 µm.
# NOTE(review): confirm the 200 µm against the chessy pattern pitch.
p1 = (562, 707)
p2 = (756, 754)
plt.plot([p1[0], p2[0]], [p1[1], p2[1]], 'r-')
dx, dy = p1[0] - p2[0], p1[1] - p2[1]
distance = np.sqrt(dx**2 + dy**2)
# Calibration factor: µm per detector pixel.
px_to_um = 200/distance
px_to_um

In [None]:
# Length in µm of a 300 px span with the chessy calibration.
300 * px_to_um

In [None]:
# Add calibrated position columns (µm) next to the raw detector-pixel columns.
for raw_col, um_col in (('dldPosX', 'posx'), ('dldPosY', 'posy')):
    sp_chessy.dataframe[um_col] = sp_chessy.dataframe[raw_col] * px_to_um


# Loading data from run 44798
* Optical spot profile, field of view (FoV) = 450 µm
* Transmission = 1.0

In [None]:
# Load the optical spot-profile run 44798 and call add_jitter() before binning.
sp98 = SedProcessor(
    runs=[44798],
    config=config,
    user_config=config_file,
    system_config={},
    collect_metadata=False,
)
sp98.add_jitter()

In [None]:
# Same detector window as the chessy image: 240 bins over [420, 900] px per axis.
axes = ["dldPosY", "dldPosX"]
bins = [240] * 2
ranges = [[420, 900], [420, 900]]
res98 = sp98.compute(axes=axes, bins=bins, ranges=ranges)

In [None]:
# Spot image of run 44798 (the histogram computed in the previous cell).
plt.figure()
res98.plot(robust=True)


# Fitting the spot profile with lmfit

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import lmfit


In [None]:
def gaussian2D(
        x: np.ndarray,
        y: np.ndarray,
        amplitude: float,
        xo: float,
        yo: float,
        sigma_x: float,
        sigma_y: float,
        theta: float,
        offset: float,
    ) -> np.ndarray:
    """Rotated 2D Gaussian with a constant offset (model function for lmfit).

    f(x, y) = offset + amplitude * exp(-(a (x-xo)^2 + 2 b (x-xo)(y-yo) + c (y-yo)^2))

    where a, b, c are built from ``sigma_x``, ``sigma_y`` and ``theta``. With this
    form the ``sigma_x`` principal axis points along (cos θ, -sin θ) and the
    ``sigma_y`` axis along (sin θ, cos θ). That is, the ellipse is rotated clockwise
    by ``theta`` in a frame where y points up.

    Args:
        x (np.ndarray): x values (scalars also work); broadcast against ``y``.
        y (np.ndarray): y values.
        amplitude (float): peak height above ``offset``.
        xo (float): center x.
        yo (float): center y.
        sigma_x (float): standard deviation along the first principal axis.
        sigma_y (float): standard deviation along the second principal axis.
        theta (float): rotation angle in radians.
        offset (float): constant background.

    Returns:
        np.ndarray: the Gaussian evaluated at (x, y). If either sigma is not
        positive, zeros shaped like ``x`` are returned instead (without the offset).
    """
    xo = float(xo)
    yo = float(yo)
    if sigma_x <= 0 or sigma_y <= 0:
        return np.zeros_like(x)
    a = (np.cos(theta)**2)/(2*sigma_x**2) + (np.sin(theta)**2)/(2*sigma_y**2)
    b = -(np.sin(2*theta))/(4*sigma_x**2) + (np.sin(2*theta))/(4*sigma_y**2)
    c = (np.sin(theta)**2)/(2*sigma_x**2) + (np.cos(theta)**2)/(2*sigma_y**2)
    return offset + amplitude*np.exp( - (a*((x-xo)**2) + 2*b*(x-xo)*(y-yo)
                        + c*((y-yo)**2)))

def sigma_to_fwhm(sigma):
    """Convert a Gaussian standard deviation to its full width at half maximum.

    FWHM = 2 * sqrt(2 ln 2) * sigma  (about 2.3548 * sigma). Works element-wise on arrays.
    """
    factor = 2*np.sqrt(2*np.log(2))
    return factor*sigma

from sed.utilities.optical import effective_gaussian_area, fluence 


In [None]:
from typing import Dict, Optional, Sequence
import xarray as xr
from lmfit import Model

def fit_spot_size(
        data: xr.DataArray,
        params: Optional[Dict[str, float]] = None,
        photoemission_order: int = 1,
        pulse_energy: Optional[float] = None,
        plot: bool = True,
        figsize=(6,6),
        normalize: bool = False,
    ) -> lmfit.model.ModelResult:
    """ Fit a 2D Gaussian to the spot size data.

    The first dimension of ``data`` is the model's ``x`` axis and the second its
    ``y`` axis. So ``xo``/``sigma_x`` refer to ``data.dims[0]`` and ``yo``/``sigma_y``
    to ``data.dims[1]``. Coordinates are assumed to be ascending (as from binning).

    Args:
        data (xr.DataArray): 2D DataArray containing the spot image. Units are taken
            from the ``units`` attribute of each coordinate (default ``'px'``).
        params (dict, optional): Initial values for any of the ``gaussian2D``
            parameters (``amplitude``, ``xo``, ``yo``, ``sigma_x``, ``sigma_y``,
            ``theta``, ``offset``). They override the data-derived defaults; the
            default bounds are kept.
        photoemission_order (int): The order of photoemission. 1 for linear, 2 for
            quadratic. Used for the effective area shown in the plot.
        pulse_energy (float, optional): Pulse energy shown in the plot together with
            the derived fluence. It is printed as ``1000 * pulse_energy`` nJ, i.e. the
            value is taken to be in µJ.
        plot (bool): Whether to plot the result.
        figsize (tuple): Figure size passed to ``plt.figure``.
        normalize (bool): Whether to normalize the data. If true, the data (and its
            Poisson error estimate) are divided by the maximum of the data.

    Returns:
        The fitted model result.

    Raises:
        ValueError: if ``data`` is not 2D, the two axes have different units,
            ``normalize`` is requested for data without a positive maximum, or
            ``params`` names an unknown parameter.
    """
    if data.ndim != 2:
        raise ValueError('The data should be a 2D array.')
    dims = data.dims

    ux = data.coords[dims[0]].attrs.get('units', 'px')
    uy = data.coords[dims[1]].attrs.get('units', 'px')
    if ux != uy:
        raise ValueError('The units of the x and y axes should be the same.')
    x = data.coords[dims[0]].values
    y = data.coords[dims[1]].values
    # 'ij' indexing gives X/Y the same (len(x), len(y)) shape as data.values, so
    # z[i, j] is compared with the model at (x[i], y[j]). The default 'xy'
    # indexing would transpose the grid relative to the data.
    X, Y = np.meshgrid(x, y, indexing='ij')
    z = np.asarray(data.values, dtype=float)
    # Poisson-like error estimate from the raw counts; +1 avoids zero errors.
    error = np.sqrt(z + 1)
    if normalize:
        # nanmax: masked images (e.g. from DataArray.where) contain NaNs.
        scale = np.nanmax(z)
        if not scale > 0:
            raise ValueError('Cannot normalize data whose maximum is not positive.')
        z = z / scale
        # Rescale the errors too, so the weights stay consistent with the data.
        error = error / scale

    model = Model(gaussian2D, independent_vars=['x', 'y'], nan_policy='omit')
    fit_params = model.make_params()

    x_span = np.abs(x.max() - x.min())
    y_span = np.abs(y.max() - y.min())
    fit_params['amplitude'].set(value=np.nanmax(z), min=0)
    fit_params['xo'].set(value=x.mean(), min=x.min(), max=x.max())
    fit_params['yo'].set(value=y.mean(), min=y.min(), max=y.max())
    fit_params['sigma_x'].set(value=x_span/2, min=x_span/100)
    fit_params['sigma_y'].set(value=y_span/2, min=y_span/100)
    fit_params['theta'].set(value=0, min=0, max=np.pi)
    fit_params['offset'].set(value=0)

    # User-supplied initial guesses override the data-derived defaults.
    for name, value in (params or {}).items():
        if name not in fit_params:
            raise ValueError(
                f'Unknown parameter {name!r}; expected one of {list(fit_params)}.'
            )
        fit_params[name].set(value=value)

    result = model.fit(z, x=X, y=Y, params=fit_params, weights=1/error)

    if plot:
        _plot_spot_fit(
            result, x, y, z, X, Y, dims, ux, photoemission_order, pulse_energy, figsize
        )
    return result


def _plot_spot_fit(result, x, y, z, X, Y, dims, units, photoemission_order,
                   pulse_energy, figsize):
    """Plot the spot image, its profiles, the FWHM cross and a parameter report.

    ``z``, ``X`` and ``Y`` are indexed ``[i, j] -> (x[i], y[j])`` ('ij' meshgrid).
    Returns the created figure.
    """
    # Evaluate the fitted model on the full grid so it has exactly the shape of z.
    best_fit = result.eval(x=X, y=Y)

    # No layout engine: the axes are positioned by hand with add_axes, which is not
    # compatible with tight_layout.
    fig = plt.figure(figsize=figsize)
    # [left, bottom, width, height]
    img_ax = fig.add_axes([0.1, 0.3, 0.5, 0.5], xticklabels=[], yticklabels=[])
    xproj_ax = fig.add_axes([0.1, 0.1, 0.5, 0.2], yticklabels=[])
    yproj_ax = fig.add_axes([0.6, 0.3, 0.2, 0.5], xticklabels=[])

    yproj_ax.yaxis.set_label_position("right")
    yproj_ax.yaxis.set_ticks_position("right")
    xproj_ax.set_xlabel(f"{dims[0]} [{units}]")
    yproj_ax.set_ylabel(f"{dims[1]} [{units}]")
    for ax in (img_ax, yproj_ax, xproj_ax):
        ax.tick_params(
            axis="both",
            direction="in",
            bottom=True,
            top=True,
            left=True,
            right=True,
            which="both",
        )

    # z is indexed [x, y], but imshow draws rows vertically. Transpose so x runs
    # horizontally, matching the x profile below and the FWHM cross.
    # aspect='auto' lets the image fill its axes so it lines up with the profiles.
    img_ax.imshow(
        z.T, cmap="terrain", origin='lower',
        extent=[x.min(), x.max(), y.min(), y.max()], aspect='auto',
    )

    # Peak profiles (max over the other axis); nanmax tolerates masked pixels.
    xproj_ax.plot(x, np.nanmax(z, axis=1))
    xproj_ax.plot(x, best_fit.max(axis=1))
    xproj_ax.set_xlim(x.min(), x.max())
    yproj_ax.plot(-np.nanmax(z, axis=0), y)
    yproj_ax.plot(-best_fit.max(axis=0), y)
    yproj_ax.set_ylim(y.min(), y.max())

    fwhm_x = sigma_to_fwhm(result.params['sigma_x'].value)
    fwhm_y = sigma_to_fwhm(result.params['sigma_y'].value)
    x0, y0 = result.params['xo'].value, result.params['yo'].value
    theta = result.params['theta'].value
    # In gaussian2D the sigma_x principal axis points along (cos θ, -sin θ) and the
    # sigma_y axis along (sin θ, cos θ); draw the FWHM cross along those directions.
    for half_length, angle in ((0.5 * fwhm_x, -theta), (0.5 * fwhm_y, np.pi / 2 - theta)):
        dx, dy = half_length * np.cos(angle), half_length * np.sin(angle)
        img_ax.plot([x0 + dx, x0 - dx], [y0 + dy, y0 - dy], 'k-')

    eff_area = effective_gaussian_area(
        (fwhm_x, fwhm_y), photoemission_order=photoemission_order, sigma_is_fwhm=True
    )
    report_label = "Parameters:\n"
    report_label += f"Gaussian FWHM\nx={fwhm_x:.2f} {units} | y={fwhm_y:.2f} {units}\n"
    report_label += f"Eff. Area: {eff_area:,.2f} {units}$^2$\n"
    report_label += f'center: {x0:.2f} {units}, {y0:.2f} {units}\n'

    if pulse_energy is not None:
        # 1000 * pulse_energy is printed as nJ, so pulse_energy is taken to be in µJ.
        # NOTE(review): fluence() presumably expects eff_area in µm², so the value is
        # only meaningful when the axes are calibrated to µm. Confirm in sed.utilities.optical.
        report_label += f"Pulse energy: {1000*pulse_energy:.3f} nJ\n"
        fl = fluence(pulse_energy, eff_area)
        report_label += f"Fluence: {fl:.3f} mJ/cm²\n"

    # Figure coordinates put the report in the empty lower-right corner, independent
    # of the data range and sign.
    fig.text(0.62, 0.28, report_label, color="Black", va="top")
    return fig


In [None]:
def circle_mask(x, y, xc, yc, r):
    """Boolean mask that is True strictly inside the circle of radius ``r`` at (xc, yc).

    Works element-wise (and with broadcasting) on array inputs.
    """
    dx = x - xc
    dy = y - yc
    return dx**2 + dy**2 < r**2
def apply_circle_mask(r, r_mask, xc=None, yc=None, x_dim='dldPosX', y_dim='dldPosY'):
    """Keep only the part of a 2D binned DataArray that lies inside a circle.

    Args:
        r (xr.DataArray): binned data with coordinates named ``x_dim`` and ``y_dim``.
        r_mask (float): circle radius, in the units of those coordinates.
        xc (float, optional): circle center along ``x_dim``; defaults to the mean of
            that coordinate (the middle of the range for evenly spaced bins).
        yc (float, optional): circle center along ``y_dim``; same default rule.
        x_dim (str): name of the coordinate used as x (default ``'dldPosX'``).
        y_dim (str): name of the coordinate used as y (default ``'dldPosY'``).

    Returns:
        xr.DataArray: ``r.where(...)`` of the in-circle mask, so points outside the
        circle are replaced by ``where``'s fill value (NaN by default).
    """
    x = r[x_dim]
    y = r[y_dim]
    if xc is None:
        xc = x.mean()
    if yc is None:
        yc = y.mean()
    return r.where(circle_mask(x, y, xc, yc, r_mask))


In [None]:
# Convert the detector-pixel coordinates of run 44798 to µm with the chessy
# calibration before labelling them as µm. FWHM, effective area and fluence are
# then reported in real units. A new array is built, so res98 itself stays in px.
res98_um = res98.assign_coords(
    dldPosX=('dldPosX', res98['dldPosX'].values * px_to_um, {'units': 'µm'}),
    dldPosY=('dldPosY', res98['dldPosY'].values * px_to_um, {'units': 'µm'}),
)
# NOTE(review): pulse_energy=1 (µJ, per fit_spot_size's nJ report) looks like a placeholder.
fr98 = fit_spot_size(res98_um, pulse_energy=1, plot=True)

In [None]:
# Load runs 44798 and 44799 with the `config` / `config_file` defined in the
# configuration cell. The data paths are not redefined here, so they are set in one place.
sp98 = SedProcessor(runs=[44798], config=config, user_config=config_file, system_config={}, collect_metadata=False)
sp98.add_jitter()
sp99 = SedProcessor(runs=[44799], config=config, user_config=config_file, system_config={}, collect_metadata=False)
sp99.add_jitter()

In [None]:
# Bin both runs on the same detector window: 240 bins over [420, 900] px per axis.
axes = ["dldPosY", "dldPosX"]
bins = [240, 240]
ranges = [[420, 900], [420, 900]]
res98, res99 = (sp.compute(bins=bins, axes=axes, ranges=ranges) for sp in (sp98, sp99))


In [None]:
# Side-by-side spot images of runs 44798 and 44799, titled so the panels can be told apart.
fig, ax = plt.subplots(1, 2)
for panel, (run, image) in zip(ax, ((44798, res98), (44799, res99))):
    image.plot(ax=panel, robust=True)
    panel.set_title(f"run {run}")

In [None]:
# NOTE(review): run 44798 is fitted on max-normalised data and 44799 on raw counts,
# so their fitted amplitudes/offsets are not directly comparable. Both arrays still
# have detector-pixel axes (no 'units' attr -> 'px'), so the reported area and
# fluence are in px units. TODO confirm this is intended.
fr98 = fit_spot_size(res98, pulse_energy=1, normalize=True, plot=True)
fr99 = fit_spot_size(res99, pulse_energy=1, normalize=False, plot=True)

In [None]:
# Display the lmfit ModelResult returned by fit_spot_size for run 44799.
fr99