In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.io as sio
from scipy.signal import correlate
from scipy.interpolate import interp1d
from scipy.ndimage import uniform_filter1d
import h5py
from scipy.stats import pearsonr

In [None]:
def extract_data(folder):
    """Extract HeNe / white-light traces and metadata from .mat files in a folder.

    Every ``*.mat`` file directly inside ``folder`` is loaded with
    ``scipy.io.loadmat``. If that raises ``NotImplementedError`` (MATLAB v7.3,
    HDF5-based file) the file is read with ``h5py`` instead.

    Parameters
    ----------
    folder : str
        Directory to scan (not recursive).

    Returns
    -------
    dict
        For each file ``<name>.mat``:

        * ``'<name>_HN_data'`` and ``'<name>_WL_data'`` -- columns 0 and 1 of
          the file's ``dataArray`` variable. Only present when ``dataArray``
          exists, is 2-D and has at least two columns.
        * ``'<name>_meta'`` -- dict holding whichever of ``runTime``,
          ``xCenCorrection``, ``xLinear`` and ``xLinearStage`` the file has
          (always present, possibly empty).
    """
    data_dict = {}

    # sorted() makes the processing/insertion order independent of the filesystem
    for filename in sorted(os.listdir(folder)):
        if not filename.endswith('.mat'):
            continue

        file_path = os.path.join(folder, filename)
        file_key = os.path.splitext(filename)[0]

        try:
            # Try loading with scipy.io (for most MATLAB files)
            mat_data = sio.loadmat(file_path)
            # Drop loadmat's bookkeeping entries ('__header__', '__version__', ...)
            clean_data = {k: v for k, v in mat_data.items() if not k.startswith('__')}

        except NotImplementedError:
            # If it's a v7.3 (HDF5-based) file, use h5py
            clean_data = {}

            def recursively_load(name, obj, store=clean_data):
                if isinstance(obj, h5py.Dataset):
                    # h5py exposes MATLAB arrays with their axes reversed
                    # (an N x 2 matrix shows up as 2 x N). Transpose so both
                    # loaders hand back the same layout; otherwise the column
                    # extraction below would pick 2 samples instead of a column.
                    store[name] = np.array(obj).T

            with h5py.File(file_path, 'r') as f:
                f.visititems(recursively_load)

        # Extract 'dataArray' key from file structure
        if 'dataArray' in clean_data:
            arr = clean_data['dataArray']
            if arr.ndim == 2 and arr.shape[1] >= 2:
                data_dict[f'{file_key}_HN_data'] = arr[:, 0]
                data_dict[f'{file_key}_WL_data'] = arr[:, 1]

        # Metadata stored here in case it is needed later
        meta = {}
        for key in ['runTime', 'xCenCorrection', 'xLinear', 'xLinearStage']:
            if key in clean_data:
                meta[key] = clean_data[key]
        data_dict[f'{file_key}_meta'] = meta

    return data_dict

In [None]:
def nfft(a, axis=0):
    """Shift-wrapped forward FFT along ``axis``.

    Applies ``fftshift``, then ``fft``, then ``ifftshift`` -- the NumPy
    spelling of MATLAB's ``ifftshift(fft(fftshift(a)))``.
    """
    pre_shifted = np.fft.fftshift(a, axes=axis)
    transformed = np.fft.fft(pre_shifted, axis=axis)
    return np.fft.ifftshift(transformed, axes=axis)


def nifft(a, axis=0):
    """Shift-wrapped inverse FFT along ``axis``.

    Applies ``ifftshift``, then ``ifft``, then ``fftshift`` -- the NumPy
    spelling of MATLAB's ``fftshift(ifft(ifftshift(a)))``.
    """
    pre_shifted = np.fft.ifftshift(a, axes=axis)
    transformed = np.fft.ifft(pre_shifted, axis=axis)
    return np.fft.fftshift(transformed, axes=axis)

In [None]:
def smooth(data, window_size):
    """Moving-average (boxcar) filter of width ``window_size``.

    Delegates to ``scipy.ndimage.uniform_filter1d`` with ``mode='nearest'``,
    i.e. the input is extended at both ends by repeating the edge sample.
    """
    filter_options = {'size': window_size, 'mode': 'nearest'}
    return uniform_filter1d(data, **filter_options)


def calibrateDelayAxis(heneData, xL):
    """
    Placeholder for the HeNe-based delay-axis calibration.

    Currently a pass-through: the linear axis ``xL`` is handed back untouched
    (the very same object) and ``heneData`` is not looked at. The original
    MATLAB routine likely performs an interferometric calibration here; port
    it into this function if a non-linear axis correction is required.
    """
    # No calibration implemented yet -- the "calibrated" axis is the input axis.
    calibrated_axis = xL
    return calibrated_axis


def removeLinearTerm(w, phase):
    """Subtract the best-fit straight line (as a function of ``w``) from ``phase``.

    The line is fitted with ``np.polyfit`` using only the non-NaN phase
    samples and then subtracted at every sample, so NaNs stay NaN. With fewer
    than two non-NaN samples the input ``phase`` is returned as-is.
    """
    finite = ~np.isnan(phase)
    if np.count_nonzero(finite) < 2:
        return phase

    # Degree-1 least-squares fit: phase ~ slope * w + intercept
    slope, intercept = np.polyfit(w[finite], phase[finite], 1)

    return phase - slope * w - intercept


def calculateZeroDelayCorrection(w, dw, xL, wl_data):
    """Estimate the zero-delay position from a white-light trace.

    Steps:
      1. Subtract the median from ``wl_data`` and transform with ``nfft``;
         bins where ``w < 0`` are zeroed.
      2. Compute the power-weighted mean frequency and circularly shift the
         spectrum by that many bins (``round(wCen / dw)``).
      3. Smooth the shifted spectrum, transform back with ``nifft``, take
         ``|.|**2`` and smooth again (5-sample windows both times).
      4. Return the centre of mass of ``xL`` weighted by that result.

    Returns
    -------
    xCen : centre-of-mass position on the ``xL`` axis
    wlSpec : the one-sided, shifted spectrum from step 2
    """
    # Step 1: one-sided spectrum of the median-subtracted trace
    centred_trace = wl_data - np.median(wl_data)
    wlSpec = nfft(centred_trace)
    wlSpec[w < 0] = 0

    # Step 2: centre of mass in the frequency domain, then shift by it
    power = np.abs(wlSpec)**2
    wCen = np.sum(w * power) / np.sum(power)
    shift_px = int(np.round(wCen / dw))
    wlSpec = np.roll(wlSpec, -shift_px)

    # Step 3: smoothed |.|**2 of the back-transformed, smoothed spectrum
    back_transformed = nifft(smooth(wlSpec, 5))
    weights = smooth(np.abs(back_transformed)**2, 5)

    # Step 4: weighted mean position on the delay axis
    xCen = np.sum(xL * weights) / np.sum(weights)

    return xCen, wlSpec

In [None]:
def WLI_processData(dataArray, xL, start_idx=None, end_idx=None):
    """
    Process WLI data to extract spectrum and phase

    Parameters:
    -----------
    dataArray : 2D array with columns [HeNe, WhiteLight]
    xL : Linear delay axis, one value per row of dataArray (treated as mm,
         since the speed of light is taken as 0.3 mm/ps)
    start_idx, end_idx : Optional indices to slice the data; either bound may
         be given on its own (the other end is then left open)

    Returns:
    --------
    xCen : Zero delay correction
    spectrum : Normalised spectrum inside the 0.2-1.5 micron window
               (column 0: HeNe, 5-point smoothed; column 1: white light)
    wLPhase : Unwrapped white light phase with the linear term removed;
              NaN where the normalised WL spectrum is below 0.1
    lambda_vals : Wavelength array in microns (same window)
    w_vals : Angular-frequency values matching lambda_vals

    Raises:
    -------
    ValueError : if the (sliced) delay axis spans no range, so that no
                 frequency axis can be built
    """

    # Slice data if any index is provided (a missing bound leaves that end open)
    if start_idx is not None or end_idx is not None:
        dataArray = dataArray[start_idx:end_idx, :]
        xL = xL[start_idx:end_idx]

    N = len(dataArray)
    heneData = dataArray[:, 0]

    # Calibrate delay axis (simplified - keeping linear for now)
    xNL = calibrateDelayAxis(heneData, xL)

    # Interpolate data onto calibrated axis (zeros outside the sampled range)
    f_hene = interp1d(xNL, heneData, kind='cubic', fill_value=0, bounds_error=False)
    f_wl = interp1d(xNL, dataArray[:, 1], kind='cubic', fill_value=0, bounds_error=False)
    dataArray_interp = np.column_stack([f_hene(xL), f_wl(xL)])

    # Calculate frequency axis
    c_mm_ps = 0.3  # Speed of light in mm/ps
    t = 2 * xL / c_mm_ps  # delay in ps (factor 2 presumably for the double pass - confirm)

    # Frequency resolution from the full scan duration. The peak-to-peak span
    # of t is used (not of |t|): the two agree for single-signed axes, but
    # |t| under-measures the span when the delay axis crosses zero.
    t_span = np.max(t) - np.min(t)
    if not t_span > 0:
        raise ValueError("Delay axis must span a non-zero range to build a frequency axis")
    dw = 2 * np.pi / t_span

    # Centred axis -N/2*dw ... built from integer indices so it always has
    # exactly N samples spaced by dw (a float-step arange can come out one
    # sample long or short).
    w = (np.arange(N) - N / 2) * dw

    # Convert to wavelength (in microns). Only positive frequencies map to a
    # physical wavelength; everything else is set to inf so it drops out of
    # the window below without dividing by zero.
    positive = w > 0
    lambda_full = np.full(N, np.inf)
    lambda_full[positive] = (2 * np.pi * c_mm_ps / w[positive]) * 1000
    idx = (lambda_full > 0.2) & (lambda_full < 1.5)
    lambda_vals = lambda_full[idx]

    # Calculate zero delay correction
    xCen = calculateZeroDelayCorrection(w, dw, xL, dataArray_interp[:, 1])[0]

    # Process field - subtract baseline taken from the row at N//4
    # (per the original port: stands in for MATLAB's dataArray(256,:) at N=1024)
    field = dataArray_interp - dataArray_interp[N//4, :]
    # Zero samples more than 0.03 below the baseline
    # (threshold carried over unchanged - TODO confirm its origin)
    field[field < -0.03] = 0
    field = nfft(field, axis=0)

    # Calculate spectrum
    spectrum = np.abs(field)**2

    # Extract white light phase
    wLPhase = np.unwrap(np.angle(field[idx, 1]))

    # Smooth (HeNe column only) and normalize each column to [0, 1]
    spectrum = spectrum[idx, :]
    spectrum[:, 0] = smooth(spectrum[:, 0], 5)
    spectrum = (spectrum - np.min(spectrum, axis=0)) / (np.max(spectrum, axis=0) - np.min(spectrum, axis=0))

    # Mask low-intensity regions in phase
    phaseBlank = spectrum[:, 1] < 0.1
    wLPhase[phaseBlank] = np.nan

    # Remove linear term from phase
    wLPhase = removeLinearTerm(w[idx], wLPhase)

    # Smooth phase (only non-NaN values)
    valid_idx = ~np.isnan(wLPhase)
    if np.sum(valid_idx) > 5:
        wLPhase[valid_idx] = smooth(wLPhase[valid_idx], 5)

    return xCen, spectrum, wLPhase, lambda_vals, w[idx]

In [None]:
def fit_phase_polynomial(lambda_vals, phase, degree=2):
    """
    Least-squares polynomial fit of phase versus wavelength

    Parameters:
    -----------
    lambda_vals : Wavelength array
    phase : Phase array; NaN entries are left out of the fit
    degree : Order of the polynomial (default=2, i.e. quadratic)

    Returns:
    --------
    coeffs : np.polyfit coefficients (highest power first), or None when
             fewer than degree + 1 non-NaN phase samples exist
    phase_fit : The polynomial evaluated at every entry of lambda_vals
                (including where phase was NaN), or None in the same case
    """
    usable = np.logical_not(np.isnan(phase))

    # A degree-d polynomial needs at least d + 1 points
    points_needed = degree + 1
    if np.sum(usable) < points_needed:
        return None, None

    coeffs = np.polyfit(lambda_vals[usable], phase[usable], degree)

    return coeffs, np.polyval(coeffs, lambda_vals)

In [None]:
def plot_spectrum_and_phase(lambda_vals, spectrum, phase, phase_fit=None, title="WLI Data"):
    """
    Draw the spectra (left y-axis) and the phase (right y-axis) against wavelength

    Parameters:
    -----------
    lambda_vals : Wavelength array in microns (shared x-axis)
    spectrum : Spectrum array (column 0: HeNe, column 1: WL)
    phase : Phase array, drawn as markers on the right axis
    phase_fit : Optional fitted phase, drawn as a solid line when given
    title : Plot title

    Returns:
    --------
    fig : The created figure
    (ax1, ax2) : Left (spectrum) and right (phase) axes
    """
    axis_label_style = {'fontweight': 'bold', 'fontsize': 14}
    spectrum_color = 'tab:blue'
    phase_color = 'tab:green'

    fig, ax1 = plt.subplots(figsize=(12, 8))

    # Left axis: the two spectra
    ax1.set_xlabel('λ [μm]', **axis_label_style)
    ax1.set_ylabel('Spectrum [a.u.]', color=spectrum_color, **axis_label_style)
    spectrum_traces = [('HeNe Spectrum', 'red'), ('WL Spectrum', 'black')]
    for column, (trace_label, trace_color) in enumerate(spectrum_traces):
        ax1.plot(lambda_vals, spectrum[:, column], label=trace_label, color=trace_color, alpha=0.7)
    ax1.tick_params(axis='y', labelcolor=spectrum_color)
    ax1.tick_params(axis='both', which='both', direction='in', top=True, right=True)
    ax1.grid(True, alpha=0.3)

    # Right axis: measured phase, plus the fit when one is supplied
    ax2 = ax1.twinx()
    ax2.set_ylabel('Phase [rad]', color=phase_color, **axis_label_style)
    ax2.plot(lambda_vals, phase, 'o', label='WL Phase', color=phase_color, markersize=3, alpha=0.6)
    if phase_fit is not None:
        ax2.plot(lambda_vals, phase_fit, '-', label='Phase Fit', color='darkgreen', linewidth=2)
    ax2.tick_params(axis='y', labelcolor=phase_color)

    # One legend covering the artists of both axes
    handles, labels = [], []
    for axis in (ax1, ax2):
        axis_handles, axis_labels = axis.get_legend_handles_labels()
        handles = handles + axis_handles
        labels = labels + axis_labels
    ax1.legend(handles, labels, loc='upper right', fontsize=10)

    plt.title(title, fontweight='bold', fontsize=16)
    plt.tight_layout()

    return fig, (ax1, ax2)

In [None]:
# Main processing script
# Loads every dataset in `path`, processes the [start_idx, end_idx) window of
# each one, plots spectrum + phase with a quadratic phase fit, and finally
# plots the average over all datasets that could be fitted.
if __name__ == "__main__":
    # Path to data
    path = '../Team32025/WLI_With_Glass_20251015/'
    data = extract_data(path)

    # Sample window of each trace that is processed
    start_idx = 700
    end_idx = 1550

    # Storage for averaged results
    all_spectrums = []
    all_phases = []
    all_lambdas = []

    # One dataset per '<name>_HN_data' key written by extract_data
    hn_suffix = '_HN_data'
    hn_keys = sorted(k for k in data if k.endswith(hn_suffix))

    print(f"Processing data from index {start_idx} to {end_idx}")
    print(f"Found {len(hn_keys)} datasets\n")

    # Process each dataset
    for key in hn_keys:
        base_key = key[:-len(hn_suffix)]
        wl_key = f'{base_key}_WL_data'
        meta_key = f'{base_key}_meta'

        if wl_key not in data:
            continue

        print(f"Processing: {base_key}")

        # Get xLinear from metadata or create default
        if meta_key in data and 'xLinear' in data[meta_key]:
            xL = data[meta_key]['xLinear'].flatten()
        else:
            # Create default linear axis. It is only a placeholder, so the
            # derived wavelength axis is not physically calibrated - say so.
            N = len(data[key])
            xL = np.linspace(0, 1, N)
            print("  Note: no 'xLinear' metadata found; using a default 0-1 delay axis (wavelengths uncalibrated)")

        # Create data array
        dataArray = np.column_stack([data[key], data[wl_key]])

        # Process the data (best effort: a failing dataset is reported and skipped)
        try:
            xCen, spectrum, wLPhase, lambda_vals, w_vals = WLI_processData(
                dataArray, xL, start_idx, end_idx
            )

            # Fit phase with 2nd degree polynomial
            coeffs, phase_fit = fit_phase_polynomial(lambda_vals, wLPhase, degree=2)

            if coeffs is not None:
                print(f"  Zero delay correction: {xCen:.6f}")
                print(f"  Polynomial coefficients: {coeffs}")

                # Store for averaging
                all_spectrums.append(spectrum)
                all_phases.append(wLPhase)
                all_lambdas.append(lambda_vals)

                # Plot
                fig, axes = plot_spectrum_and_phase(
                    lambda_vals, spectrum, wLPhase, phase_fit,
                    title=f"{base_key} - Spectrum and Phase"
                )
                plt.show()
                # Release the figure so one-per-dataset plots do not pile up in memory
                plt.close(fig)
            else:
                print("  Warning: Could not fit polynomial (insufficient valid phase points)")

        except Exception as e:
            print(f"  Error processing {base_key}: {e}")

        print()

    # Calculate and plot averaged results
    if len(all_spectrums) > 0:
        print("\nCalculating averaged spectrum and phase...")

        # Ensure all have same wavelength axis (interpolate if needed);
        # the first dataset's axis serves as the common grid
        lambda_common = all_lambdas[0]

        spectrums_interp = []
        phases_interp = []

        for spec, phase, lam in zip(all_spectrums, all_phases, all_lambdas):
            if not np.array_equal(lam, lambda_common):
                # Interpolate to common wavelength axis (NaN outside this dataset's range)
                f_spec_hene = interp1d(lam, spec[:, 0], bounds_error=False, fill_value=np.nan)
                f_spec_wl = interp1d(lam, spec[:, 1], bounds_error=False, fill_value=np.nan)
                f_phase = interp1d(lam, phase, bounds_error=False, fill_value=np.nan)

                spec_interp = np.column_stack([f_spec_hene(lambda_common), f_spec_wl(lambda_common)])
                phase_interp = f_phase(lambda_common)
            else:
                spec_interp = spec
                phase_interp = phase

            spectrums_interp.append(spec_interp)
            phases_interp.append(phase_interp)

        # Calculate mean (ignoring NaNs)
        spectrum_avg = np.nanmean(spectrums_interp, axis=0)
        phase_avg = np.nanmean(phases_interp, axis=0)

        # Fit averaged phase
        coeffs_avg, phase_fit_avg = fit_phase_polynomial(lambda_common, phase_avg, degree=2)

        print(f"Averaged {len(all_spectrums)} datasets")
        print(f"Average phase polynomial coefficients: {coeffs_avg}")

        # Plot averaged results
        fig, axes = plot_spectrum_and_phase(
            lambda_common, spectrum_avg, phase_avg, phase_fit_avg,
            title="Averaged Spectrum and Phase"
        )
        plt.show()