# Code to visualize and correct TA data 

Code to import and visualize the TA data. You can also chirp-correct the data, remove the coherence using the solvent scan as a reference, and smooth the data.

Arthur Vard - varthur@mit.edu

# Prelude

In [None]:
import numpy as np
import tkinter as tk
from tkinter import Tk, filedialog
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.io
import xarray as xr
from tkinter.filedialog import askopenfilename, askopenfilenames
from matplotlib.ticker import MaxNLocator
import os
%matplotlib widget

# allow multiple outputs in one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

def heatmap_interactive(_x, _y, _data, _title, _cmap='jet', _symlog=False):
    """Show a 2D map with a kinetic and a spectral slice that follow the mouse.

    The figure has three panels: the 2D map (bottom left), the horizontal
    slice through the row under the cursor (top) and the vertical slice through
    the column under the cursor (right). Zooming the 2D map re-centres its
    colour scale symmetrically around zero for the visible region and
    propagates the new limits to the slice panels.

    Parameters
    ----------
    _x : 1D array-like
        Values along the horizontal axis (labelled 'Delay / ps').
    _y : 1D array-like
        Values along the vertical axis (labelled 'Wavelength / nm').
    _data : 2D array-like
        Values to display, indexed as ``_data[i_y, i_x]``.
    _title : str
        Title written at the top of the figure.
    _cmap : str
        Name of the matplotlib colormap used for the 2D map.
    _symlog : bool
        If True, the horizontal axis of the map and of the kinetic panel uses a
        mixed log-lin scale with a linear threshold of 0.01.

    Notes
    -----
    The hover test and the slice lookup compare the cursor position against
    ``_x[0]``/``_x[-1]`` and ``_y[0]``/``_y[-1]``, so both axes are expected to
    be sorted in ascending order.
    """
    # Work on plain arrays so boolean masks behave identically for numpy and
    # xarray inputs.
    x_vals = np.asarray(_x)
    y_vals = np.asarray(_y)
    data = np.asarray(_data)

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle(_title)
    gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.5], height_ratios=[0.5, 1], hspace=0.2, wspace=0.2)

    # --- main 2D map -------------------------------------------------------
    ax_main = fig.add_subplot(gs[1, 0])
    main_plot = ax_main.pcolormesh(x_vals, y_vals, data, cmap=_cmap)
    ax_main.set(xlabel='Delay / ps', ylabel='Wavelength / nm')
    # set mixed log-lin scale with threshold value linthresh
    if _symlog:
        ax_main.set_xscale('symlog', linthresh=0.01)
    # set axis range to min and max values
    ax_main.set_xlim(x_vals[0], x_vals[-1])
    ax_main.set_ylim(y_vals[0], y_vals[-1])

    # --- kinetic slice (top panel) -----------------------------------------
    ax_kin = fig.add_subplot(gs[0, 0])
    line_kin, = ax_kin.plot(x_vals, np.zeros(x_vals.shape))
    ax_kin.plot([x_vals[0], x_vals[-1]], [0, 0], color="0.6")  # zero line
    ax_kin.set_xlim(x_vals[0], x_vals[-1])
    if _symlog:
        ax_kin.set_xscale('symlog', linthresh=0.01)

    # --- spectral slice (right panel) --------------------------------------
    ax_spec = fig.add_subplot(gs[1, 1])
    line_spec, = ax_spec.plot(np.zeros(y_vals.shape), y_vals)
    ax_spec.plot([0, 0], [y_vals[0], y_vals[-1]], color="0.6")  # zero line
    ax_spec.set_ylim(y_vals[0], y_vals[-1])

    # This lower bounds list is necessary because the blocks in the 2D-plot cover a certain range
    def create_lower_bounds(_value_list):
        # The first lower bound is the first value itself; every other one is
        # the midpoint between a value and its predecessor.
        # example: lower bound for 100 ps is 97.5 ps if the value prior is 95 ps,
        # and 75 ps if the value prior is 50 ps.
        result = np.array(_value_list, dtype=float)
        result[1:] = (result[1:] + result[:-1]) / 2
        return result

    nm_lower_bounds = create_lower_bounds(y_vals)
    time_lower_bounds = create_lower_bounds(x_vals)

    def last_index_below(_value, _bounds):
        # Index of the last lower bound strictly below _value. Falls back to 0
        # when the cursor sits exactly on the first value, where no bound is
        # strictly smaller.
        hits = np.flatnonzero(_value > _bounds)
        return hits[-1] if hits.size else 0

    def nm_to_index(_nm):
        return last_index_below(_nm, nm_lower_bounds)

    def time_to_index(_time):
        return last_index_below(_time, time_lower_bounds)

    def padded_limits(_values):
        # Min/max of the visible samples widened by 10 % of their range;
        # None when nothing is visible.
        if _values.size == 0:
            return None
        low, high = _values.min(), _values.max()
        pad = 0.1 * (high - low)
        return low - pad, high + pad

    def mouse_move(event):
        # Only follow the cursor while it is over the 2D map; coordinates
        # reported for the slice panels are in different units.
        if event.inaxes is not ax_main:
            return
        x, y = event.xdata, event.ydata
        if x is None or y is None:
            return
        if not (x_vals[0] <= x <= x_vals[-1] and y_vals[0] <= y <= y_vals[-1]):
            return

        # update spectra slice and rescale to the visible wavelength window
        new_spec = data[:, time_to_index(x)]
        line_spec.set_xdata(new_spec)
        spec_low, spec_high = ax_spec.get_ylim()
        spec_limits = padded_limits(new_spec[(y_vals >= spec_low) & (y_vals <= spec_high)])
        if spec_limits is not None:
            ax_spec.set_xlim(*spec_limits)

        # update kinetic slice and rescale to the visible delay window
        new_kin = data[nm_to_index(y), :]
        line_kin.set_ydata(new_kin)
        kin_low, kin_high = ax_kin.get_xlim()
        kin_limits = padded_limits(new_kin[(x_vals >= kin_low) & (x_vals <= kin_high)])
        if kin_limits is not None:
            ax_kin.set_ylim(*kin_limits)

        # redraw figure
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', mouse_move)

    # find max absolute value of 2D data in the specified zoom mode of the plot
    def get_maxvalue(_xlim, _ylim, _xvals, _yvals, _data_array):
        y_filter = (_yvals >= _ylim[0]) & (_yvals <= _ylim[1])
        x_filter = (_xvals >= _xlim[0]) & (_xvals <= _xlim[1])
        if y_filter.any() and x_filter.any():
            return np.amax(np.abs(_data_array[np.ix_(y_filter, x_filter)]))
        return 0

    def rescale_colors(event_ax):
        # Symmetric colour limits around zero for the currently visible region.
        new_max = get_maxvalue(event_ax.get_xlim(), event_ax.get_ylim(), x_vals, y_vals, data)
        if new_max > 0:
            main_plot.set_clim(vmin=-new_max, vmax=new_max)

    def on_xlims_change(event_ax):
        ax_kin.set_xlim(event_ax.get_xlim())
        rescale_colors(event_ax)

    def on_ylims_change(event_ax):
        ax_spec.set_ylim(event_ax.get_ylim())
        rescale_colors(event_ax)

    ax_main.callbacks.connect('xlim_changed', on_xlims_change)
    ax_main.callbacks.connect('ylim_changed', on_ylims_change)

    plt.show(block=False)

# Import data

In [None]:
# Pump wavelength (in nm)
Pump = 400

# Wavelength calibration: lambda = A * pixel + B
pixels = np.arange(0, 2048)
calibration = [0.2380, 301.8]  # A and B linear calibration parameters
lambda_values = calibration[0] * pixels + calibration[1]

# Detector row closest to the pump wavelength. `Pump` is a wavelength in nm,
# so it must be translated through the calibration before it can index the
# first (pixel) axis of the scan arrays.
pump_pixel = int(np.argmin(np.abs(lambda_values - Pump)))

# Get time vector data file
time_file = askopenfilename(filetypes=[("Text files", "*.txt")], title="Select Time vector data")
if not time_file:
    # The dialog returns an empty value when it is cancelled.
    raise FileNotFoundError("No time vector file selected")
time = np.loadtxt(time_file)
# time = time / 1000  # convert from fs to ps

# Get TA scan files
ta_scan_files = askopenfilenames(filetypes=[("Text files", "*.txt")], title="Select TA scan files")
if not ta_scan_files:
    raise FileNotFoundError("No TA scan file selected")

if len(ta_scan_files) > 1:
    # Stack all scans into one 3D array (lambda, time, scan); np.stack refuses
    # files whose shapes disagree instead of silently broadcasting them.
    Full_Data = np.stack([np.loadtxt(scan_file) for scan_file in ta_scan_files], axis=2)

    # Flip the sign of every (time, scan) column whose pump scatter is not negative.
    pump_is_positive = Full_Data[pump_pixel] > 0
    Full_Data *= np.where(pump_is_positive, -1.0, 1.0)

    scan = np.mean(Full_Data, axis=2)
else:
    # NOTE(review): a single scan is loaded as-is, without the sign flip applied
    # to multi-scan data — confirm this is intended.
    scan = np.loadtxt(ta_scan_files[0])
# Now 'scan' contains the processed data, indexed as (lambda, time)

# Create an xarray dataset to manipulate the data (much easier)
dataset = xr.Dataset(
    {
        "data": (["time", "spectral"], np.transpose(scan))
    },
    coords={
        "time": time,
        "spectral": lambda_values
    }
)

# Print the dataset
print(dataset)

# Heat map visualization of the data

In [None]:
# Interactive 2D map of the averaged sample scan on a linear delay axis.
heatmap_interactive(
    time,
    lambda_values,
    scan,
    'Averaged scan plot',
    _symlog=False,
)

# Look at a specific TA trace or spectrum

## TA trace

In [None]:
# Kinetic trace at the wavelength closest to 597, limited to time values up to 10.
trace_near_597 = dataset.data.sel(spectral=[597], method="nearest")
plot_data = trace_near_597.sel(time=slice(None, 10))
# The trailing semicolon keeps the notebook from echoing the returned line objects.
plot_data.plot.line(x="time", aspect=2, size=5);

## TA spectrum

In [None]:
# Spectra at the time points closest to 10 and 25, cut to the 409-755 spectral window.
spectra_at_delays = dataset.data.sel(time=[10, 25], method="nearest")
plot_data = spectra_at_delays.sel(spectral=slice(409, 755))
ax = plot_data.plot.line(x="spectral", aspect=2, size=5)
plt.grid(True)
plt.show()

# Import solvent data

In [None]:
# Get time vector data file
time_file = askopenfilename(filetypes=[("Text files", "*.txt")], title="Select Time vector data")
if not time_file:
    # The dialog returns an empty value when it is cancelled.
    raise FileNotFoundError("No time vector file selected")
time = np.loadtxt(time_file)
# time = time / 1000  # convert from fs to ps

# Get solvent scan files
ta_scan_files = askopenfilenames(filetypes=[("Text files", "*.txt")], title="Select solvent scan file(s)")
if not ta_scan_files:
    raise FileNotFoundError("No solvent scan file selected")

# Detector row closest to the pump wavelength. `Pump` is a wavelength in nm,
# so it must be translated through the calibration before it can index the
# first (pixel) axis of the scan arrays.
pump_pixel = int(np.argmin(np.abs(lambda_values - Pump)))

if len(ta_scan_files) > 1:
    # Stack all scans into one 3D array (lambda, time, scan); np.stack refuses
    # files whose shapes disagree instead of silently broadcasting them.
    Full_Data = np.stack([np.loadtxt(scan_file) for scan_file in ta_scan_files], axis=2)

    # Flip the sign of every (time, scan) column whose pump scatter is not negative.
    pump_is_positive = Full_Data[pump_pixel] > 0
    Full_Data *= np.where(pump_is_positive, -1.0, 1.0)

    scan_solvent = np.mean(Full_Data, axis=2)
else:
    # NOTE(review): a single scan is loaded as-is, without the sign flip applied
    # to multi-scan data — confirm this is intended.
    scan_solvent = np.loadtxt(ta_scan_files[0])
# Now 'scan_solvent' contains the processed solvent data, indexed as (lambda, time)

# Create an xarray dataset to manipulate the data (much easier)
dataset_solvent = xr.Dataset(
    {
        "data": (["time", "spectral"], np.transpose(scan_solvent))
    },
    coords={
        "time": time,
        "spectral": lambda_values
    }
)

# Print the dataset
print(dataset_solvent)

In [None]:
# Interactive 2D map of the solvent scan; the map expects (spectral, time) ordering.
solvent_map = dataset_solvent['data'].transpose('spectral', 'time')
heatmap_interactive(
    dataset_solvent.time,
    dataset_solvent.spectral,
    solvent_map,
    'Averaged scan plot',
    _symlog=False,
)

Visualization of the solvent and sample kinetic traces to see the effect of the subtraction

In [None]:
# Wavelength (in nm) at which the solvent and sample traces are compared
wv = 597

# Trace closest to wv in each dataset, limited to time values up to 10
plot_data_solvent = dataset_solvent.data.sel(spectral=[wv], method="nearest").sel(time=slice(None, 10))
plot_data = dataset.data.sel(spectral=[wv], method="nearest").sel(time=slice(None, 10))

# Overlay both traces in one figure
plt.figure(figsize=(10, 5))
plt.plot(plot_data_solvent.time, plot_data_solvent, label='Solvent Data')
plt.plot(plot_data.time, plot_data, label='Cu(dmp) Data')

# Labels, title, legend and grid
plt.xlabel('Time in ps')
plt.ylabel('OD')
plt.title(f'Solvent Data vs. Cu(dmp)2 at {wv} nm')
plt.legend()
# Pass a real boolean: a string such as "On" only works because it is truthy.
plt.grid(True)
plt.show()

# Chirp correction from the solvent using a Fit 

Import the necessary libraries and define fit functions that can be useful

In [None]:
from scipy.optimize import curve_fit

def poly2(x, a, b, c):
    """Evaluate the second-order polynomial a*x**2 + b*x + c at x."""
    quadratic_term = a * x**2
    linear_term = b * x
    return quadratic_term + linear_term + c


def poly4(x, a, b, c, d, e):
    """Evaluate the fourth-order polynomial a*x**4 + b*x**3 + c*x**2 + d*x + e at x."""
    quartic_term = a * x**4
    cubic_term = b * x**3
    quadratic_term = c * x**2
    linear_term = d * x
    # Terms are summed from highest to lowest order.
    return quartic_term + cubic_term + quadratic_term + linear_term + e

def lin(x, a, b):
    """Evaluate the straight line a*x + b at x."""
    slope_term = a * x
    return slope_term + b

In [None]:
# Spectral window (in nm) over which the chirp is determined
lower_spectral_bound = 470
higher_spectral_bound = 706
solvent_window = dataset_solvent.data.sel(spectral=slice(lower_spectral_bound, higher_spectral_bound))
spectral_region = solvent_window.spectral

# For every wavelength of the window, the time at which the solvent signal is largest.
# The solvent dataset's own time coordinate is used so the lookup cannot pick up a
# `time` vector left over from another cell.
solvent_time = dataset_solvent.time.values
time_coherence = solvent_time[solvent_window.argmax(dim='time').values]

# Fit the position of the maximum against wavelength with a 4th-order polynomial
params_l, _ = curve_fit(poly4, spectral_region, time_coherence)
Fit_coherence = poly4(spectral_region, *params_l)

# Subtraction: one shifted time axis per wavelength, shape (spectral, time)
time_chirp_corrected = solvent_time[np.newaxis, :] - np.asarray(Fit_coherence)[:, np.newaxis]

plt.figure()
plt.plot(spectral_region, time_coherence, 'o', label='Data')
plt.plot(spectral_region, Fit_coherence, '-', label='Fit')
plt.xlabel('Wavelength (in nm)')
plt.ylabel('Position of the coherence max (in ps)')
plt.legend()
# Pass a real boolean: a string such as "On" only works because it is truthy.
plt.grid(True)
plt.show()

## Chirp Correct the solvent data 

In [None]:
from scipy.interpolate import interp1d

# Work on an independent copy so that dataset_solvent itself is left untouched
chirp_corrected_solvent = dataset_solvent.copy(deep=True)

# Each wavelength of the fitted region is paired with its own shifted time axis
for spectral_val, time_chirp in zip(spectral_region, time_chirp_corrected):
    data_slice = chirp_corrected_solvent.data.sel(spectral=spectral_val)

    # The trace is treated as sampled on the shifted axis and read back on the
    # common `time` grid; points outside the sampled range are extrapolated.
    resample = interp1d(time_chirp, data_slice, bounds_error=False, fill_value="extrapolate")

    # Write the resampled trace back into the corresponding column of the copy
    chirp_corrected_solvent.data.loc[:, spectral_val] = resample(time)

# Set NaN values to 0
cleaned_values = np.nan_to_num(chirp_corrected_solvent.data.values)
chirp_corrected_solvent.data.values = cleaned_values

In [None]:
# Interactive 2D map of the chirp-corrected solvent scan, passed in (spectral, time) ordering.
corrected_solvent_map = chirp_corrected_solvent['data'].transpose('spectral', 'time')
heatmap_interactive(
    chirp_corrected_solvent.time,
    chirp_corrected_solvent.spectral,
    corrected_solvent_map,
    'Averaged scan plot',
    _symlog=False,
)


# Chirp correct the data using the solvent fit 

In [None]:
# Work on an independent copy so that the original sample dataset is left untouched
chirp_corrected_data = dataset.copy(deep=True)

# Each wavelength of the fitted region is paired with its own shifted time axis
for spectral_val, time_chirp in zip(spectral_region, time_chirp_corrected):
    data_slice = chirp_corrected_data.data.sel(spectral=spectral_val)

    # The trace is treated as sampled on the shifted axis and read back on the
    # common `time` grid; points outside the sampled range are extrapolated.
    resample = interp1d(time_chirp, data_slice, bounds_error=False, fill_value="extrapolate")

    # Write the resampled trace back into the corresponding column of the copy
    chirp_corrected_data.data.loc[:, spectral_val] = resample(time)


In [None]:
# Interactive 2D map of the chirp-corrected sample scan, passed in (spectral, time) ordering.
corrected_data_map = chirp_corrected_data['data'].transpose('spectral', 'time')
heatmap_interactive(
    chirp_corrected_data.time,
    chirp_corrected_data.spectral,
    corrected_data_map,
    'Averaged scan plot',
    _symlog=False,
)


# Remove the coherence from the sample by subtracting the solvent scan

## Initial look at both chirp corrected traces 

In [None]:
# Wavelength (in nm) at which the chirp-corrected traces are compared
wv = 515

# Trace closest to wv in each chirp-corrected dataset, limited to time values up to 30
plot_data_solvent = chirp_corrected_solvent.data.sel(spectral=[wv], method="nearest").sel(time=slice(None, 30))
plot_data = chirp_corrected_data.data.sel(spectral=[wv], method="nearest").sel(time=slice(None, 30))

# Overlay both traces in one figure
plt.figure(figsize=(10, 5))
plt.plot(plot_data_solvent.time, plot_data_solvent, label='Solvent Data')
plt.plot(plot_data.time, plot_data, label='Cu(dmp) Data')

# Labels, title, legend and grid
plt.xlabel('Time in ps')
plt.ylabel('OD')
plt.title(f'Chriped corrected Solvent Data vs. Cu(dmp)2 at {wv} nm')
plt.legend()
# Pass a real boolean: a string such as "On" only works because it is truthy.
plt.grid(True)
plt.show()

## Subtraction

In [None]:
# Subtract the chirp-corrected solvent response from the chirp-corrected sample data.
# NOTE(review): xarray arithmetic presumably aligns both datasets on their shared
# "time"/"spectral" labels, so the two scans need identical coordinates for every
# point to survive the subtraction — confirm against the xarray alignment rules.
processed_data = chirp_corrected_data - chirp_corrected_solvent

In [None]:
# Wavelength (in nm) at which the solvent-subtracted trace is inspected
wv = 597

# Trace closest to wv, limited to time values up to 80
plot_data = processed_data.data.sel(spectral=[wv], method="nearest").sel(time=slice(None, 80))

plt.figure(figsize=(10, 5))
plt.plot(plot_data.time, plot_data, label='Cu(dmp) Data')

# Labels, title, legend and grid
plt.xlabel('Time in ps')
plt.ylabel('OD')
plt.title(f'Chriped corrected & substracted Data at {wv} nm')
plt.legend()
# Pass a real boolean: a string such as "On" only works because it is truthy.
plt.grid(True)
plt.show()

# Smoothing the data

In [None]:
# Moving-average kernel: window_size equal weights that add up to one
window_size = 3
window = np.full(window_size, 1.0 / window_size)

def smooth_data(data, window):
    """Smooth `data` along its first axis with the 1D kernel `window`.

    Every 1D slice along axis 0 is convolved with `window` (``mode='same'``),
    so the output has the same shape as `data`. That mode pads the slice with
    zeros, which would pull the first and last samples toward zero. To avoid
    this, each output sample is rescaled by the ratio between the full kernel
    weight and the kernel weight that actually overlaps the data. Samples away
    from the edges are unchanged by this rescaling.

    Parameters
    ----------
    data : array-like
        Values to smooth; smoothing runs along axis 0.
    window : 1D array-like
        Kernel weights, e.g. ``np.ones(n) / n`` for a moving average.

    Returns
    -------
    numpy.ndarray
        Smoothed values, same shape as `data`.

    Raises
    ------
    ValueError
        If `window` is not a non-empty 1D array or is longer than axis 0 of
        `data` (the convolution would then change the output length).
    """
    data = np.asarray(data)
    window = np.asarray(window)
    n_samples = data.shape[0]
    if window.ndim != 1 or window.size == 0 or window.size > n_samples:
        raise ValueError(
            "window must be a non-empty 1D array no longer than the first axis of data"
        )

    smoothed = np.apply_along_axis(
        lambda trace: np.convolve(trace, window, mode='same'), axis=0, arr=data
    )

    # Kernel weight overlapping real samples at each position: equal to
    # window.sum() in the interior, smaller next to the edges.
    coverage = np.convolve(np.ones(n_samples), window, mode='same')
    scale = np.ones_like(coverage, dtype=float)
    # Positions with zero overlapping weight are left unscaled.
    np.divide(window.sum(), coverage, out=scale, where=coverage != 0)

    # Broadcast the per-sample factor over any trailing axes.
    return smoothed * scale.reshape((-1,) + (1,) * (data.ndim - 1))


In [None]:
# Smooth the solvent-subtracted values; the array's first axis is "time"
raw_values = processed_data.data.values
smoothed_data = smooth_data(raw_values, window)

# Wrap the smoothed array in a new dataset that reuses the original coordinates
smoothed_coords = {
    "time": processed_data.time.values,
    "spectral": processed_data.spectral.values,
}
smoothed_dataset = xr.Dataset(
    data_vars={"data": (["time", "spectral"], smoothed_data)},
    coords=smoothed_coords,
)


In [None]:
# Wavelength (in nm) at which the fully processed trace is inspected
wv = 520

# Trace closest to wv, limited to time values up to 80
plot_data = smoothed_dataset.data.sel(spectral=[wv], method="nearest").sel(time=slice(None, 80))

plt.figure(figsize=(10, 5))
plt.plot(plot_data.time, plot_data, label='Cu(dmp) Data')

# Labels, title, legend and grid
plt.xlabel('Time in ps')
plt.ylabel('OD')
plt.title(f'Fully processed Data at {wv} nm')
plt.legend()
# Pass a real boolean: a string such as "On" only works because it is truthy.
plt.grid(True)
plt.show()

In [None]:
# Interactive 2D map of the fully processed data, passed in (spectral, time) ordering.
smoothed_map = smoothed_dataset['data'].transpose('spectral', 'time')
heatmap_interactive(
    smoothed_dataset.time,
    smoothed_dataset.spectral,
    smoothed_map,
    'Averaged scan plot',
    _symlog=False,
)

# Save xarray for pyglotaran 

In [None]:
# Write the fully processed dataset to a NetCDF file (path relative to the working directory)
output_path = "Cu(dmp)2_fully_processed.nc"
smoothed_dataset.to_netcdf(output_path)