# Prelude 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tkinter.filedialog import askopenfilename, askopenfilenames
import xarray as xr
%matplotlib widget
# allow multiple outputs in one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

# Define various functions that are useful for data visualization.
# It might be nice to turn them into a library to import in the future.

def heatmap_interactive(_x, _y, _data, _title, _cmap='jet', _symlog=False):
    """Plot an interactive 2D map with linked kinetic and spectral cuts.

    Layout: the main panel (bottom left) is a ``pcolormesh`` of ``_data`` versus
    delay and wavelength. The top panel shows the kinetic trace at the wavelength
    under the cursor, and the right panel shows the spectrum at the delay under
    the cursor. Zooming or panning the main panel copies its limits to the side
    panels and sets symmetric colour limits (+/- the max |value| in view).

    Parameters
    ----------
    _x : array-like, shape (n_time,)
        Delay axis in ps, sorted in ascending order (numpy array or xarray coordinate).
    _y : array-like, shape (n_wavelength,)
        Wavelength axis in nm, sorted in ascending order.
    _data : array-like, shape (n_wavelength, n_time)
        Values to display; rows follow ``_y`` and columns follow ``_x``.
    _title : str
        Figure title.
    _cmap : str, optional
        Matplotlib colormap for the main panel (default ``'jet'``).
    _symlog : bool, optional
        If True, use a symmetric-log delay axis, linear within +/- 0.01.
    """
    # Plain numpy arrays: xarray inputs would otherwise carry coordinates into every
    # comparison/indexing done in the callbacks below.
    _x = np.asarray(_x)
    _y = np.asarray(_y)
    _data = np.asarray(_data)

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle(_title)
    gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.5], height_ratios=[0.5, 1], hspace=0.2, wspace=0.2)

    # Main 2D map
    ax_main = plt.subplot(gs[1, 0])
    main_plot = ax_main.pcolormesh(_x, _y, _data, cmap=_cmap)
    ax_main.set(xlabel='Delay / ps', ylabel='Wavelength / nm')
    if _symlog:
        # mixed log-lin scale, linear within +/- linthresh
        ax_main.set_xscale('symlog', linthresh=0.01)
    # set axis range to min and max values
    ax_main.set_xlim(_x[0], _x[-1])
    ax_main.set_ylim(_y[0], _y[-1])

    # Kinetic cut (top panel)
    ax_kin = plt.subplot(gs[0, 0])
    line_kin, = ax_kin.plot(_x, np.zeros(_x.shape))
    ax_kin.plot([_x[0], _x[-1]], [0, 0], color="0.6")  # zero reference line
    ax_kin.set_xlim(_x[0], _x[-1])
    if _symlog:
        ax_kin.set_xscale('symlog', linthresh=0.01)

    # Spectral cut (right panel)
    ax_spec = plt.subplot(gs[1, 1])
    line_spec, = ax_spec.plot(np.zeros(_y.shape), _y)
    ax_spec.plot([0, 0], [_y[0], _y[-1]], color="0.6")  # zero reference line
    ax_spec.set_ylim(_y[0], _y[-1])

    def cell_lower_bounds(values):
        """Lower edge of each pcolormesh cell: the first value, then midpoints of neighbours.

        Example: the cell of 100 ps starts at 97.5 ps if the previous delay is 95 ps.
        """
        bounds = np.empty(values.shape, dtype=float)  # float so integer axes keep .5 midpoints
        bounds[0] = values[0]
        bounds[1:] = (values[1:] + values[:-1]) / 2
        return bounds

    nm_lower_bounds = cell_lower_bounds(_y)
    time_lower_bounds = cell_lower_bounds(_x)

    def value_to_index(value, lower_bounds):
        """Index of the last lower bound strictly below ``value`` (the cell under the cursor)."""
        # side='left' counts the bounds < value; the clip maps a cursor exactly on the
        # first sample to 0 instead of an empty selection.
        idx = np.searchsorted(lower_bounds, value, side='left') - 1
        return int(np.clip(idx, 0, len(lower_bounds) - 1))

    def padded_limits(values, axis_values, bounds):
        """Return (min, max) of the finite ``values`` visible within ``bounds``, padded by 10 % of their span.

        Returns None when nothing finite is visible or the visible values are constant,
        so the caller keeps the current limits instead of failing or setting a singular range.
        """
        lo, hi = min(bounds), max(bounds)
        visible = values[(axis_values >= lo) & (axis_values <= hi) & np.isfinite(values)]
        if visible.size == 0:
            return None
        vmin, vmax = visible.min(), visible.max()
        span = vmax - vmin
        if span == 0:
            return None
        return vmin - 0.1 * span, vmax + 0.1 * span

    def mouse_move(event):
        # Only the main panel maps (xdata, ydata) to (delay, wavelength); on the side
        # panels one of the two coordinates is a signal amplitude.
        if event.inaxes is not ax_main or event.xdata is None or event.ydata is None:
            return
        x, y = event.xdata, event.ydata
        if not (_x[0] <= x <= _x[-1] and _y[0] <= y <= _y[-1]):
            return

        # update spectral slice and rescale to the visible wavelength range
        new_spec = _data[:, value_to_index(x, time_lower_bounds)]
        line_spec.set_xdata(new_spec)
        limits = padded_limits(new_spec, _y, ax_spec.get_ylim())
        if limits is not None:
            ax_spec.set_xlim(*limits)

        # update kinetic slice and rescale to the visible delay range
        new_kin = _data[value_to_index(y, nm_lower_bounds), :]
        line_kin.set_ydata(new_kin)
        limits = padded_limits(new_kin, _x, ax_kin.get_xlim())
        if limits is not None:
            ax_kin.set_ylim(*limits)

        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', mouse_move)

    def get_maxvalue(_xlim, _ylim, _xvals, _yvals, _data_array):
        """Max |value| of ``_data_array`` inside the given x/y limits, or 0 if the window is empty."""
        y_filter = (_yvals >= min(_ylim)) & (_yvals <= max(_ylim))
        x_filter = (_xvals >= min(_xlim)) & (_xvals <= max(_xlim))
        if not (y_filter.any() and x_filter.any()):
            return 0
        return np.nanmax(np.abs(_data_array[np.ix_(y_filter, x_filter)]))

    def rescale_colors(event_ax):
        # symmetric colour limits so that zero stays at the centre of the colormap
        new_max = get_maxvalue(event_ax.get_xlim(), event_ax.get_ylim(), _x, _y, _data)
        if new_max > 0:
            main_plot.set_clim(vmin=-new_max, vmax=new_max)

    def on_xlims_change(event_ax):
        ax_kin.set_xlim(event_ax.get_xlim())
        rescale_colors(event_ax)

    def on_ylims_change(event_ax):
        ax_spec.set_ylim(event_ax.get_ylim())
        rescale_colors(event_ax)

    ax_main.callbacks.connect('xlim_changed', on_xlims_change)
    ax_main.callbacks.connect('ylim_changed', on_ylims_change)

    plt.show(block=False)
 
# Find the index of the element closest to a given value. Very useful, for instance,
# to find the index of the nearest wavelength when cutting the spectral region.

def find_closest(A, target):
    """Return the index of the element of ``A`` closest to ``target``.

    Parameters
    ----------
    A : array-like
        1D values sorted in ascending order. Lists and xarray coordinates are
        accepted; they are converted with ``np.asarray``.
    target : scalar or numpy array
        Value(s) to locate in ``A``.

    Returns
    -------
    numpy integer or ndarray of integers
        Index into ``A`` (one per element when ``target`` is an array). A target
        exactly halfway between two elements maps to the higher index; targets
        outside the range of ``A`` map to the first or last index.

    Raises
    ------
    ValueError
        If ``A`` is empty.
    """
    A = np.asarray(A)
    if A.size == 0:
        raise ValueError("find_closest: A must not be empty")
    # Insertion point of target, clipped so both neighbours A[idx-1] and A[idx] exist.
    # For a single-element A, np.clip with a_min > a_max yields a_max (= 0).
    idx = np.clip(np.searchsorted(A, target), 1, len(A) - 1)
    left = A[idx - 1]
    right = A[idx]
    # step back to the left neighbour when it is strictly closer
    idx = idx - (target - left < right - target)
    # only changes anything for a single-element A, where idx may have become -1
    return np.clip(idx, 0, len(A) - 1)

In [ ]:
from pyglotaran_extras.io.utils import result_dataset_mapping

def heatmap_interactive_fit(_result, dataset_name, _cmap='jet', _symlog=False):
    """Interactive heatmap of one fitted dataset with linked data/fit cuts.

    The main panel shows the measured data (wavelength vs. delay) of
    ``dataset_name``, with the IRF centre (``center_dispersion_1``) overlaid when
    the result dataset provides it. Moving the mouse over the main panel updates
    the top panel with the measured and fitted kinetic traces at the hovered
    wavelength, and the right panel with the measured spectrum at the hovered
    delay. Zooming the main panel copies its limits to the side panels and sets
    symmetric colour limits (+/- the max |value| in view).

    Parameters
    ----------
    _result : optimization result
        Anything ``result_dataset_mapping`` accepts (e.g. the output of ``project.optimize``).
    dataset_name : str
        Name of the dataset inside the result.
    _cmap : str, optional
        Matplotlib colormap for the main panel (default ``'jet'``).
    _symlog : bool, optional
        If True, use a symmetric-log delay axis, linear within +/- 0.01.

    Notes
    -----
    The ``time`` and ``spectral`` coordinates are assumed to be sorted in
    ascending order; the cursor-to-index lookup relies on it.
    """
    # Extract values from the result passed in (not from a global variable)
    result_map = result_dataset_mapping(_result)
    res_ds = result_map[dataset_name]
    _x = np.asarray(res_ds.time.values)
    _y = np.asarray(res_ds.spectral.values)
    # Orient by dimension name: _data as (spectral, time) for the map,
    # _fit as (time, spectral) so _fit[:, i] is the fitted kinetic at wavelength i.
    _data = res_ds.data.transpose('spectral', 'time').values
    _fit = res_ds.fitted_data.transpose('time', 'spectral').values
    # NOTE(review): center_dispersion_1 is presumably only present when the model
    # has IRF dispersion; skip the overlay instead of failing when it is missing.
    _irf = res_ds.center_dispersion_1.values if 'center_dispersion_1' in res_ds else None

    # Main heatmap plot
    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.5], height_ratios=[0.5, 1], hspace=0.2, wspace=0.2)
    ax_main = plt.subplot(gs[1, 0])
    main_plot = ax_main.pcolormesh(_x, _y, _data, cmap=_cmap)
    if _irf is not None:
        ax_main.plot(_irf, _y, 'k', label='IRF')
    ax_main.set(xlabel='Delay / ps', ylabel='Wavelength / nm')
    if _symlog:
        # mixed log-lin scale, linear within +/- linthresh
        ax_main.set_xscale('symlog', linthresh=0.01)
    # set axis range to min and max values
    ax_main.set_xlim(_x[0], _x[-1])
    ax_main.set_ylim(_y[0], _y[-1])
    if _irf is not None:
        ax_main.legend()

    # Kinetic plot (measured data and fit)
    ax_kin = plt.subplot(gs[0, 0])
    line_kin, = ax_kin.plot(_x, np.zeros(_x.shape), label='data')
    fit_kin, = ax_kin.plot(_x, np.zeros(_x.shape), '--r', label='Fit')
    ax_kin.plot([_x[0], _x[-1]], [0, 0], color="0.6")  # zero reference line
    ax_kin.set_xlim(_x[0], _x[-1])
    ax_kin.legend()
    if _symlog:
        ax_kin.set_xscale('symlog', linthresh=0.01)

    # Spectral plot
    ax_spec = plt.subplot(gs[1, 1])
    line_spec, = ax_spec.plot(np.zeros(_y.shape), _y)
    ax_spec.plot([0, 0], [_y[0], _y[-1]], color="0.6")  # zero reference line
    ax_spec.set_ylim(_y[0], _y[-1])

    def cell_lower_bounds(values):
        """Lower edge of each pcolormesh cell: the first value, then midpoints of neighbours.

        Example: the cell of 100 ps starts at 97.5 ps if the previous delay is 95 ps.
        """
        bounds = np.empty(values.shape, dtype=float)  # float so integer axes keep .5 midpoints
        bounds[0] = values[0]
        bounds[1:] = (values[1:] + values[:-1]) / 2
        return bounds

    nm_lower_bounds = cell_lower_bounds(_y)
    time_lower_bounds = cell_lower_bounds(_x)

    def value_to_index(value, lower_bounds):
        """Index of the last lower bound strictly below ``value`` (the cell under the cursor)."""
        # side='left' counts the bounds < value; the clip maps a cursor exactly on the
        # first sample to 0 instead of an empty selection.
        idx = np.searchsorted(lower_bounds, value, side='left') - 1
        return int(np.clip(idx, 0, len(lower_bounds) - 1))

    def padded_limits(values, axis_values, bounds):
        """Return (min, max) of the finite ``values`` visible within ``bounds``, padded by 10 % of their span.

        Returns None when nothing finite is visible or the visible values are constant,
        so the caller keeps the current limits instead of failing or setting a singular range.
        """
        lo, hi = min(bounds), max(bounds)
        visible = values[(axis_values >= lo) & (axis_values <= hi) & np.isfinite(values)]
        if visible.size == 0:
            return None
        vmin, vmax = visible.min(), visible.max()
        span = vmax - vmin
        if span == 0:
            return None
        return vmin - 0.1 * span, vmax + 0.1 * span

    def mouse_move(event):
        # Only the main panel maps (xdata, ydata) to (delay, wavelength); on the side
        # panels one of the two coordinates is a signal amplitude.
        if event.inaxes is not ax_main or event.xdata is None or event.ydata is None:
            return
        x, y = event.xdata, event.ydata
        if not (_x[0] <= x <= _x[-1] and _y[0] <= y <= _y[-1]):
            return

        # update spectral slice and rescale to the visible wavelength range
        new_spec = _data[:, value_to_index(x, time_lower_bounds)]
        line_spec.set_xdata(new_spec)
        limits = padded_limits(new_spec, _y, ax_spec.get_ylim())
        if limits is not None:
            ax_spec.set_xlim(*limits)

        # update measured and fitted kinetic slices; rescale on the measured data
        wl_idx = value_to_index(y, nm_lower_bounds)
        new_kin = _data[wl_idx, :]
        line_kin.set_ydata(new_kin)
        fit_kin.set_ydata(_fit[:, wl_idx])
        limits = padded_limits(new_kin, _x, ax_kin.get_xlim())
        if limits is not None:
            ax_kin.set_ylim(*limits)

        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', mouse_move)

    def get_maxvalue(_xlim, _ylim, _xvals, _yvals, _data_array):
        """Max |value| of ``_data_array`` inside the given x/y limits, or 0 if the window is empty."""
        y_filter = (_yvals >= min(_ylim)) & (_yvals <= max(_ylim))
        x_filter = (_xvals >= min(_xlim)) & (_xvals <= max(_xlim))
        if not (y_filter.any() and x_filter.any()):
            return 0
        return np.nanmax(np.abs(_data_array[np.ix_(y_filter, x_filter)]))

    def rescale_colors(event_ax):
        # symmetric colour limits so that zero stays at the centre of the colormap
        new_max = get_maxvalue(event_ax.get_xlim(), event_ax.get_ylim(), _x, _y, _data)
        if new_max > 0:
            main_plot.set_clim(vmin=-new_max, vmax=new_max)

    def on_xlims_change(event_ax):
        ax_kin.set_xlim(event_ax.get_xlim())
        rescale_colors(event_ax)

    def on_ylims_change(event_ax):
        ax_spec.set_ylim(event_ax.get_ylim())
        rescale_colors(event_ax)

    ax_main.callbacks.connect('xlim_changed', on_xlims_change)
    ax_main.callbacks.connect('ylim_changed', on_ylims_change)

    plt.show(block=False)

# Import the data file(s) 

## Option 1: Import the data from text file 

In [ ]:
# Pump wavelength (in nm). The signal at the detector pixel closest to this
# wavelength (pump scatter) is used to fix the sign of each scan below.
Pump = 400

# Wavelength calibration: lambda = A * pixel + B
pixels = np.arange(0, 2048)
calibration = (0.2380, 301.8)  # A and B linear calibration parameters
lambda_values = calibration[0] * pixels + calibration[1]

# Pump is a wavelength, not a pixel index: look up the pixel closest to it
pump_idx = find_closest(lambda_values, Pump)

# Get time vector data file
time_file = askopenfilename(filetypes=[("Text files", "*.txt")], title="Select Time vector data")
if not time_file:
    raise ValueError("No time vector file selected")
time = np.loadtxt(time_file)

# Get TA scan files
ta_scan_files = askopenfilenames(filetypes=[("Text files", "*.txt")], title="Select TA scan files")
if not ta_scan_files:
    raise ValueError("No TA scan file selected")

# Stack all scans in a 3D array (lambda, time, scan); a single file is a stack of one
Full_Data = np.zeros((len(lambda_values), len(time), len(ta_scan_files)))
for n, file in enumerate(ta_scan_files):
    data = np.loadtxt(file)
    if data.shape != Full_Data.shape[:2]:
        raise ValueError(
            f"{file}: expected shape {Full_Data.shape[:2]} (wavelength, time), got {data.shape}"
        )
    Full_Data[:, :, n] = data

# Flip the sign of every (time, scan) column whose pump scatter is not negative,
# so that all scans share the same sign convention before averaging.
positive_scatter = Full_Data[pump_idx, :, :] > 0
Full_Data[:, positive_scatter] *= -1

scan = np.mean(Full_Data, axis=2)
# Now 'scan' contains the processed (wavelength, time) data

dataset = xr.Dataset(
    {
        "data": (["time", "spectral"], np.transpose(scan))
    },
    coords={
        "time": time,
        "spectral": lambda_values
    }
)

# Print the dataset
print(dataset)

## Option 2: Import the data from NetCDF file 

Import the data this way if you used the "TA Analysis" code for an initial treatment of the data and want to import the resulting xarray dataset.

The code "TA Analysis" can be found on Github at the following link:
https://github.com/SchlauCohenLab/Ensemble_TA_Analysis/tree/main

In [None]:
dataset_file = askopenfilename(filetypes=[("NetCDF files", "*.nc")], title="Select .nc data file")
if not dataset_file:
    # askopenfilename returns an empty string when the dialog is cancelled
    raise ValueError("No .nc data file selected")
dataset = xr.load_dataset(dataset_file)

print(dataset)

# Initial look at the data

In [None]:
from pyglotaran_extras import plot_data_overview

# Overview plot of the loaded data with a lin-log time axis (linthresh=1);
# the trailing ';' keeps the return value from being echoed.
plot_data_overview(
    dataset,
    linlog=True,
    linthresh=1,
    figsize=(12, 8),
);

In [None]:
# Interactive map of the full dataset, oriented as (spectral, time)
heatmap_interactive(
    dataset.time,
    dataset.spectral,
    dataset['data'].transpose('spectral', 'time'),
    'Averaged scan plot',
    _symlog=False,
)

View of a kinetic trace (signal vs. time at a single wavelength)

In [ ]:
# Kinetic trace at the wavelength closest to 518.1 nm, up to 10 ps
plot_data = (
    dataset.data
    .sel(spectral=[518.1], method="nearest")
    .sel(time=slice(None, 10))
)
plot_data.plot.line(x="time", aspect=2, size=5);

Spectra at selected delays

In [ ]:
# Spectra at the delays closest to 10 and 25 ps, restricted to 409-755 nm
plot_data = (
    dataset.data
    .sel(time=[10, 25], method="nearest")
    .sel(spectral=slice(409, 755))
)
# note: plot.line returns the list of line artists, stored here under the name `ax`
ax = plot_data.plot.line(x="spectral", aspect=2, size=5)
plt.grid(True)
plt.show()

View of two normalized kinetic traces at different wavelengths

In [ ]:
wv1 = 525  # in nm
wv2 = 580  # in nm

# Kinetic traces at the wavelengths closest to wv1 / wv2, between -2 and 1 ps
plot_data_1 = dataset.data.sel(spectral=wv1, method="nearest").sel(time=slice(-2, 1))
plot_data_2 = dataset.data.sel(spectral=wv2, method="nearest").sel(time=slice(-2, 1))

# Create a plot
plt.figure(figsize=(10, 5))

# Normalize by the largest absolute value: dividing by the plain maximum inverts
# or blows up traces that are mostly negative, while |max| keeps sign and shape.
plt.plot(plot_data_1.time, plot_data_1 / np.abs(plot_data_1).max(), label=f'{wv1} nm')
plt.plot(plot_data_2.time, plot_data_2 / np.abs(plot_data_2).max(), label=f'{wv2} nm')
plt.legend()
plt.xlabel('Time in ps')
plt.ylabel('Normalized signal')
plt.grid(True)
plt.show()

# Cut WL and time (If needed)

In [ ]:
# Select the data at the specific time point
data_at_specific_time = dataset.sel(time=20)

# Plot the data against the spectral dimension
plt.figure(figsize=(10, 6))
data_at_specific_time["data"].plot()
plt.xlabel("Spectral (nm)")
plt.ylabel("Data Value")
plt.grid(True)
plt.show()

In [ ]:
# Indices of the delay (ps) and wavelength (nm) bounds to keep; both ends are inclusive.
# NOTE(review): both wavelength bounds are 700 nm, which keeps a single wavelength —
# adjust them to the intended spectral window.
time_min = find_closest(dataset.time.values, 0.5)
time_max = find_closest(dataset.time.values, 10)
spectral_min = find_closest(dataset.spectral.values, 700)
spectral_max = find_closest(dataset.spectral.values, 700)

# Cut the currently loaded `dataset` (works for both Option 1 and Option 2);
# the +1 makes the upper bound inclusive so identical bounds never give an empty cut.
ds = (
    dataset[["data"]]
    .isel(time=slice(time_min, time_max + 1), spectral=slice(spectral_min, spectral_max + 1))
    .transpose("time", "spectral")
)

# Print the dataset
print(ds)

In [ ]:
# Overview of the cut dataset; the trailing ';' suppresses the echoed return value,
# as in the other overview cells.
plot_data_overview(ds, linlog=False, linthresh=1, figsize=(12, 8));

In [ ]:
# Interactive map of the cut dataset, oriented as (spectral, time)
heatmap_interactive(
    ds.time,
    ds.spectral,
    ds['data'].transpose('spectral', 'time'),
    'Averaged scan plot',
    _symlog=False,
)

# Project and optimization

In [ ]:
from glotaran.project import Project

# Open the pyglotaran project "Cu(bcp)2" that holds the models/parameters used below
project = Project.open("Cu(bcp)2")

In [ ]:
# Add the cut dataset `ds` to the project under the name "bcp_int"
# (the dataset name used by the result cells below)
project.import_data(ds, dataset_name="bcp_int")  # import the data

In [ ]:
# Show the model that is validated and optimized below ("bcp_model"),
# not an unrelated model definition
project.show_model_definition("bcp_model")

In [ ]:
# Show the parameter set that is validated and optimized below ("bcp_parameters")
project.show_parameters_definition("bcp_parameters")

In [ ]:
# Check model "bcp_model" against parameter set "bcp_parameters" before optimizing;
# the returned value is displayed as the cell output
project.validate("bcp_model","bcp_parameters")

In [ ]:
# Fit "bcp_model" with "bcp_parameters" to the project data,
# capping the optimizer at 500 function evaluations
result = project.optimize(
    model_name="bcp_model",
    parameters_name="bcp_parameters",
    maximum_number_function_evaluations=500,
)

In [ ]:
# Display the optimized parameters (bare trailing expression -> shown as cell output)
result.optimized_parameters

In [ ]:
# Interactive data/fit explorer for the "bcp_int" dataset of the optimization result
heatmap_interactive_fit(
    result,
    dataset_name="bcp_int",
    _symlog=False,
)

In [ ]:
# Display `lifetime_decay` of the "bcp_int" result dataset
# (presumably the fitted decay lifetimes — confirm against the pyglotaran result docs)
result.data["bcp_int"].lifetime_decay

In [ ]:
from pyglotaran_extras.plotting.plot_traces import plot_data_and_fits

# One panel comparing data and fit at 550 nm
axes_shape = (1, 1)
fig, axes = plt.subplots(nrows=axes_shape[0], ncols=axes_shape[1], figsize=(11, 8))

plot_data_and_fits(result, wavelength=550, axis=axes)

In [ ]:
result_dataset = result.data["bcp_int"]

# Residual left singular vector with index 0, plotted versus time
residual_left = result_dataset.residual_left_singular_vectors.sel(left_singular_value_index=0)
# trailing ';' suppresses the echoed list of line artists, as in the other plotting cells
residual_left.plot.line(x="time", aspect=2, size=5);

In [ ]:
from pyglotaran_extras import select_plot_wavelengths
from pyglotaran_extras import plot_fitted_traces
from pyglotaran_extras.plotting.style import PlotStyle

# Wavelengths (nm) at which data and fit are compared, one per panel of the 4x2 grid
wavelengths = [495, 518, 540, 560, 585, 600, 630, 650]
fig3tr, axes = plot_fitted_traces(
    result,
    wavelengths,
    axes_shape=(4, 2),
    linlog=True,
    linthresh=1,
    cycler=PlotStyle().data_cycler_solid_dashed,
    figsize=(10, 10),
)

# Rename 'spectral' to 'Wavelength' in each panel title, append the nm unit,
# and relabel the time axis
for ax in axes.ravel():
    ax_title = ax.get_title()
    ax.set_xlabel("Time (ps)")
    ax.set_title(ax_title.replace('spectral', 'Wavelength') + r"$\,$nm")

In [ ]:
from pyglotaran_extras import plot_overview

# Overview of the fit result: lin-log time axis (linthresh=1), one residual SVD vector
fig, axes = plot_overview(
    result,
    linlog=True,
    linthresh=1,
    nr_of_residual_svd_vectors=1,
    figsize=(10, 14),
)

# Integrate data in one dimension (If needed)

In [ ]:
# Integrate the signal over the spectral axis, then re-add a one-point spectral
# coordinate labelled 550.0 (a spectral dimension is needed for the optimization)
Data_integrated = ds["data"].integrate("spectral")
ds_int = Data_integrated.expand_dims({"spectral": [550.0]})

In [ ]:
from pyglotaran_extras import plot_data_overview

# Overview of the spectrally integrated data on a linear time axis
plot_data_overview(
    ds_int,
    linlog=False,
    linthresh=1,
    figsize=(10, 8),
);

# Single Wavelength (If needed)

In [ ]:
# Single-wavelength cut: the wavelength closest to 550 nm, kept as a length-1
# spectral dimension (list selector); displayed as the cell output
SW =  dataset.data.sel(spectral=[550], method="nearest") 
SW

In [ ]:
# Overview of the single-wavelength cut `SW` from the previous cell
# (`SW_550nm` is never defined in this notebook)
plot_data_overview(SW);