This file contains optimization code to fit a series of multi-exponential decays to patient data from NMR scans of patients' skeletal muscles. It is designed so that, once the imports and function definitions have been run, every subsequent block can be run in succession without the functions needing to be recompiled.

The core idea behind the optimization process is that Newton/quasi-Newton optimization methods work very well for exponential-decay fitting, because they cope well with long, narrow, trough-like (ill-conditioned) loss landscapes, e.g.

    f(x, y) = x^2 + y^2/1000

However, there is an issue here because Newton methods are local rather than global, so we use a sampling method to choose starting points. To do this, we start with one decay, where the loss landscape typically has only a single minimum, and then use the result of the mono-exponential fit to produce a number of candidate starting points for two decays. We then repeat this process for three decays using the information from the bi-exponential best fit.

The specific optimization algorithm used is BFGS, a quasi-Newton method. In theory it is less effective than the pure Newton method in certain circumstances; however, it is more robust to starting conditions, so we are using it for the moment.

The essence of this local sampling method is that, by sampling enough starting points for local optimizations, we should always be able to find the global minimum. I will also note here that, due to the physical nature of our measurements, all amplitudes and decay constants must be positive by definition.

In [None]:
## Import all the necessary libraries

import pandas as pd

from jax import random, jit
from jax import numpy as jnp
from jax.numpy.fft import fft, ifft
import glob
from matplotlib import pyplot as plt
from jaxopt import BFGS
jnp_float = jnp.float32
from time import time
from tqdm import tqdm

from contextlib import contextmanager
import os
import sys

from scipy.stats import shapiro

In [None]:
## Toggle for an experimental correction that was trialled to compensate for
## issues in our data processing.
##
## When truthy, fit_multi_exponential subtracts the mean of the tail of the
## fitted curve (decay[-1000:-10]) from its output: a "parameter-less" offset
## adjustment applied before the loss is computed. The flag is read inside
## jit-compiled code, so its value is captured when that code is traced; set
## it before any fitting function is first called.
use_adjustment = False

In [None]:
## function definitions

# function to allow print suppression because the BFGS solver in jaxopt returns warnings as print statements
@contextmanager
def suppress_print():
    """Temporarily redirect ``sys.stdout`` to the null device.

    Used around the jaxopt BFGS update loops, whose warnings are emitted as
    print statements. The ``sys.stdout`` that was active on entry is always
    restored on exit, even if the body raises.

    Only the null-device handle opened here is closed on exit. Whatever object
    ``sys.stdout`` happens to point to when the block ends (for example, if the
    body redirected it again) is left open.
    """
    original_stdout = sys.stdout
    with open(os.devnull, "w") as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            # Restore before the `with` closes devnull so nothing can write to a closed file.
            sys.stdout = original_stdout

# exponential decay function
@jit
def exponential(t, A1, tau1):
    """Evaluate the single exponential decay ``A1 * exp(-t / tau1)``.

    Args:
        t: time point(s); scalars and arrays broadcast together.
        A1: amplitude at ``t = 0``.
        tau1: decay time constant, in the same units as ``t``.
    """
    exponent = -t / tau1
    return A1 * jnp.exp(exponent)

def multi_exponential(t, *params):
    """Sum of exponential decays minus a constant offset.

    ``params`` is laid out as ``(A1, tau1, A2, tau2, ..., offset)``. Every
    argument except the last is consumed in (amplitude, tau) pairs; a trailing
    unpaired element, if any, is ignored. The final argument is subtracted
    from the summed decay.
    """
    pairs, offset = params[:-1], params[-1]
    components = [
        exponential(t, pairs[k], pairs[k + 1])
        for k in range(0, 2 * (len(pairs) // 2), 2)
    ]
    return jnp.sum(jnp.array(components, dtype=jnp_float), axis=0) - offset

# multi-exponential decay model used for fitting (optional tail-mean correction)
@jit
def fit_multi_exponential(t, *params):
    """Sum of exponential decays used as the fitting model.

    ``params`` is a flat sequence ``(A1, tau1, A2, tau2, ...)``. Each
    (amplitude, tau) pair contributes ``exponential(t, A, tau)``. There is no
    explicit offset parameter.

    When the module-level flag ``use_adjustment`` is truthy, the mean of
    ``decay[-1000:-10]`` (the tail of the curve, excluding its last 10 samples)
    is subtracted from the result.

    NOTE: this function is jit-compiled, so ``use_adjustment`` is read when the
    function is traced, not on every call. Changing the flag after a call with
    the same argument signature has no effect until JAX retraces.
    """
    n_pairs = len(params) // 2
    terms = [exponential(t, params[2 * k], params[2 * k + 1]) for k in range(n_pairs)]
    decay = jnp.sum(jnp.array(terms, dtype=jnp_float), axis=0)

    if not use_adjustment:
        return decay
    # Parameter-less correction: remove the mean level of the curve's tail.
    return decay - jnp.mean(decay[-1000:-10])

# print the parameters of a multi-exponential decay, one decay per line
def print_params(params):
    """Print each decay's amplitude (4 d.p.) and tau (2 d.p.).

    ``params`` is a flat sequence ``(A1, tau1, A2, tau2, ...)``; decays are
    numbered from 1.
    """
    for decay_number, start in enumerate(range(0, len(params), 2), start=1):
        amplitude, tau = params[start], params[start + 1]
        print(f"decay {decay_number}: {amplitude:.4f} {tau:.2f}")

# durbin-watson autocorrelation test
def durbin_watson(residuals):
    """
    Rescaled Durbin-Watson statistic for lag-1 autocorrelation of residuals.

    The classic statistic DW = sum(diff(r)**2) / sum(r**2) lies in [0, 4],
    where 2 means no lag-1 autocorrelation. This function returns
    |DW - 2| / 2, which lies in [0, 1] so that it can be read alongside the
    Shapiro-Wilk p-value.

    0 indicates no correlation between neighbouring residuals; 1 indicates
    perfect (positive or negative) correlation.
    """
    numerator = jnp.sum(jnp.diff(residuals) ** 2)
    denominator = jnp.sum(residuals ** 2)
    return jnp.abs(numerator / denominator - 2) / 2.0

# shapiro-wilk test for normality
def shapiro_wilk(residuals):
    """
    Shapiro-Wilk test for normality of the residuals (``scipy.stats.shapiro``).

    Returns scipy's result object, which unpacks and indexes as
    ``(statistic, p_value)``.

    - The W statistic lies between 0 and 1; values close to 1 indicate
      normality (for fit residuals it is usually between ~0.9 and 1).
    - The p-value is the probability of obtaining a W statistic at least this
      far from normal if the residuals really were normally distributed. It is
      NOT the probability that the data are normal. By a common convention,
      p > 0.05 means normality is not rejected.
    """
    return shapiro(residuals)

# builds a reusable solver function that fits one decay model to the data
def get_solve(maxiter, atol, time_data, noisy_signal, magnitudes):
    """
    Build a BFGS fitting function for one multi-exponential model.

    Args:
        maxiter: maximum number of BFGS update steps per call to ``solve``.
        atol: absolute tolerance for the loss-plateau check. ``jnp.isclose``
            also applies its default relative tolerance.
        time_data, noisy_signal: data the sum-of-squares loss is computed against.
        magnitudes: per-parameter scales. The optimizer works on
            ``params / magnitudes`` so that every parameter is of order 1,
            which helps with ill-conditioned inputs.

    Returns:
        ``solve(params) -> (fitted_params, bfgs_state)``. Parameters go in and
        come out in physical (unscaled) units.

    Design: the single BFGS update step is jit-compiled and driven from a
    plain Python loop, because jit-compiling the whole loop takes far longer.
    The loop stops early once the loss has stayed (approximately) unchanged
    for more than 10 consecutive steps.
    """
    scale_down = jit(lambda p: p / magnitudes)
    scale_up = jit(lambda p: p * magnitudes)

    @jit
    def loss_fn(scaled_params):
        physical = scale_up(scaled_params)
        return jnp.sum((fit_multi_exponential(time_data, *physical) - noisy_signal) ** 2)

    # Only solver.init_state / solver.update are used below; the Python loop
    # decides when to stop (presumably maxiter/tol here only matter to
    # solver.run, which is not called — verify against jaxopt).
    solver = BFGS(
        fun=loss_fn,
        maxiter=2000,
        tol=1e-20,
        max_stepsize=0.5,
        min_stepsize=1e-20,
    )

    init_state = jit(lambda p: solver.init_state(p))
    step = jit(lambda p, s: solver.update(p, s))
    plateau = jit(lambda a, b: jnp.isclose(a, b, atol=atol))

    def solve(params):
        scaled = scale_down(params)
        state = init_state(scaled)

        previous_loss = -1
        repeats = 0
        with suppress_print():
            for _ in range(maxiter):
                scaled, state = step(scaled, state)
                repeats = repeats + 1 if plateau(state.value, previous_loss) else 0
                if repeats > 10:
                    break
                previous_loss = state.value

        return scale_up(scaled), state

    return solve

# produces a bootstrap solver function to produce a distribution of parameter
# sets based on the noisy signal.
# This is used to produce confidence intervals for the parameters.
def get_bootstrap_solve(atol, time_data, noisy_signal, params, max_rounds=1000):
    """
    Build a function that refits ``params`` to one bootstrap resample of the data.

    Each call draws ``len(noisy_signal)`` (time, signal) points with
    replacement and runs BFGS starting from the full-data fit ``params``.
    Repeating this with different keys yields a distribution of parameter sets
    from which confidence intervals can be read. The procedure does not assume
    normally distributed errors, which is useful when comparing fitting
    methods whose parameter distributions need not be normal.

    The core assumption is that the full-signal fit lies close to the global
    minimum of each resample. This usually holds but is not guaranteed,
    especially when the noise has long tails.

    Args:
        atol: absolute tolerance for the convergence check. ``jnp.isclose``
            also applies its default relative tolerance.
        time_data, noisy_signal: full data set to resample from.
        params: full-data best-fit parameters ``(A1, tau1, A2, tau2, ...)``.
            They double as the per-parameter scales, so the optimizer starts
            from a vector of ones.
        max_rounds: safety cap on the number of 10-step update rounds per
            call, so that a fit whose loss never settles cannot loop forever.

    Returns:
        ``solve(key) -> (fitted_params, bfgs_state)`` with parameters in
        physical units.
    """
    # Work in units of the best fit: scaled = physical / params. The start
    # point is therefore all ones, and no separate forward transform is needed.
    backward = jit(lambda x: x * params)
    params_i = jnp.ones_like(params)

    n = len(noisy_signal)

    @jit
    def is_close(x, y):
        return jnp.isclose(x, y, atol=atol)

    # Shared by the two jitted functions below. Both must build identical
    # solvers so that the BFGS state stays compatible between them.
    def make_solver(time_data_bootstrap, noisy_signal_bootstrap):
        def loss_fn(scaled_params):
            physical = backward(scaled_params)
            return jnp.sum((fit_multi_exponential(time_data_bootstrap, *physical) - noisy_signal_bootstrap) ** 2)

        return BFGS(
            fun=loss_fn,
            maxiter=2000,
            tol=1e-20,
            max_stepsize=0.1,
            min_stepsize=1e-20,
        )

    @jit
    def initialize_bootstrap(key):
        # sample bootstrap data indices (with replacement) from the noisy signal
        bootstrap_indices = random.choice(key, n, (n,), replace=True)
        noisy_signal_bootstrap = noisy_signal[bootstrap_indices]
        time_data_bootstrap = time_data[bootstrap_indices]

        solver = make_solver(time_data_bootstrap, noisy_signal_bootstrap)
        state = solver.init_state(params_i)

        return state, noisy_signal_bootstrap, time_data_bootstrap

    # Run 10 updates per call to exchange compile time for faster runtime
    @jit
    def solve_10(scaled, state, noisy_signal_bootstrap, time_data_bootstrap):
        solver = make_solver(time_data_bootstrap, noisy_signal_bootstrap)
        for _ in range(10):
            scaled, state = solver.update(scaled, state)
        return scaled, state

    # Unlike the normal solver, this one stops on convergence: the loss must
    # agree across three consecutive rounds. It also stops if the params go
    # NaN or negative, or once max_rounds rounds have run.
    def solve(key):
        state, noisy_signal_bootstrap, time_data_bootstrap = initialize_bootstrap(key)
        scaled = params_i
        last = -1
        before_last = -1
        with suppress_print():
            for _ in range(max_rounds):
                scaled, state = solve_10(scaled, state, noisy_signal_bootstrap, time_data_bootstrap)
                done = is_close(state.value, last) and is_close(before_last, last)
                before_last = last
                last = state.value
                # quit if any of the params are nan or negative
                if jnp.any(jnp.isnan(scaled)) or jnp.any(scaled < 0):
                    break
                if done:
                    break
        return backward(scaled), state

    return solve

# sort the parameters by the decay time
def sort_params(params):
    """
    Return decay parameters ordered by increasing decay time (tau).

    Accepts either of two layouts:
      - a flat array ``(A1, tau1, A2, tau2, ...)``, as returned by the solvers.
        It is viewed as ``(n_decays, 2)`` rows of (amplitude, tau), sorted by
        tau, and returned flat again.
      - a 2-D array whose column 1 holds tau. Its rows are sorted by that
        column and it is returned 2-D.
    """
    if params.ndim == 1:
        # (A1, tau1, A2, tau2, ...) -> rows of (A, tau), keeping each pair together
        pairs = params.reshape(-1, 2)
        return pairs[pairs[:, 1].argsort()].reshape(-1)
    return params[params[:, 1].argsort()]

In [None]:
## Controls used to select which data file is loaded. If you want to run this
## code, some patient data is attached in a separate file that you can trial
## it on, but you may have to modify the import block below.

# pre or post, with regard to patient data before and after dialysis
post = 0  # 0 for pre, 1 for post
if post:
    prepost = "Post"
else:
    prepost = "Pre"

# position of the patient file in the table inside get_filename_from_index
index = 26

# filled in by later cells with the values to be copied into the spreadsheet
copy_dict = {}

In [None]:

#####################################
#### IMPORT SIGNAL DATA FROM CSV ####
#####################################

def get_filename_from_index(index, prepost):
    """
    Return the CSV path for patient file ``index`` and the chosen session.

    Args:
        index: position in the ``files`` table below (0 <= index < len(files)).
            Must be an ``int``; ``bool`` is rejected.
        prepost: ``"pre"`` or ``"post"`` (case-insensitive). For ``"pre"``,
            every occurrence of ``"Post"`` in the path (folder and file suffix)
            is replaced with ``"Pre"``. Paths that contain no ``"Post"``
            (indices 15-28) are returned unchanged.

    Raises:
        TypeError: if ``index`` is not an int or ``prepost`` is not a string.
        ValueError: if ``index`` is out of range or ``prepost`` is not
            ``"pre"``/``"post"``.
    """
    # Define the file paths
    files = [
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD05MGHI-Post.csv", # 0
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD08MGHI-Post.csv", # 1
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD06FKCC-Post.csv", # 2
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD07FKCC-Post.csv", # 3
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD08FKCC-Post.csv", # 4
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD09FKCC-Post.csv", # 5
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD03FKCC2-Post.csv", # 6
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD04FKCC2-Post.csv", # 7
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD07FKCC2-Post.csv", # 8
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD08FKCC2-Post.csv", # 9
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD02FKCS-Post.csv", # 10
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD03FKCS-Post.csv", # 11
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD04FKCS-Post.csv", # 12
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD05FKCS-Post.csv", # 13
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\Post\B_avgScans_HD06FKCS-Post.csv", # 14
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC01MGHRA-.csv", # 15
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC02MGHRA-.csv", # 16
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC03MGHRA-.csv", # 17
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC04MGHRA-.csv", # 18
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC05MGHRA-.csv", # 19
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC06MGHRA-.csv", # 20
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC07MGHRA-.csv", # 21
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC08MGHRA-.csv", # 22
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC09MGHRA-.csv", # 23
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC10MGHRA-.csv", # 24
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC11MGHRA-.csv", # 25
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC12MGHRA-.csv", # 26
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC13MGHRA-.csv", # 27
        r"c:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC14MGHRA-.csv", # 28
    ]

    # Validate explicitly rather than with `assert`, which is stripped under `python -O`.
    # bool is a subclass of int, so it is rejected separately.
    if not isinstance(index, int) or isinstance(index, bool):
        raise TypeError("index must be an integer")
    if index >= len(files):
        raise ValueError("index must be less than the total number of files: " + str(len(files)))
    if index < 0:
        raise ValueError("index must be greater than or equal to 0")
    if not isinstance(prepost, str):
        raise TypeError("prepost must be either 'pre' or 'post'")
    if prepost.lower() not in ("pre", "post"):
        raise ValueError("prepost must be either 'pre' or 'post'")

    # Get the filename and switch the Post folder/suffix to Pre if necessary
    filename = files[index]
    if prepost.lower() == "pre":
        filename = filename.replace("Post", "Pre")

    return filename


# Define the file path pattern
# file_path_pattern = r"C:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC*MGHRA-.csv"
# file_path_pattern = r"C:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC0[1-8]MGHRA-.csv"
# file_path_pattern = r"C:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\B_avgScans_HC01MGHRA-.csv"
# file_path_pattern = r"C:\Users\omnic\OneDrive\Desktop\HC MGRH preprocesssed\pre_post\B_avgScans_HD07FKCC-Post.csv"

file_path_pattern = get_filename_from_index(index, prepost)
print(file_path_pattern)

# Use glob to find all matching file paths (the pattern may also be a single literal path)
file_paths = glob.glob(file_path_pattern)
if not file_paths:
    # Fail here with a clear message instead of an IndexError on dataframes[0] further down
    raise FileNotFoundError(f"No files matched: {file_path_pattern}")

# List to store dataframes
dataframes = []

# Load and process each file
for file_path in file_paths:
    # skiprows=39 skips the first 39 lines, so line 40 is read as the header row
    df = pd.read_csv(file_path, skiprows=39)
    dataframes.append(df)
    print(f"Loaded file: {file_path}")

# Ensure all DataFrames have the same structure
if not all(df.columns.equals(dataframes[0].columns) for df in dataframes):
    raise ValueError("Files do not have matching columns.")

# Print each file's columns and report any cell that cannot be parsed as a number
for df in dataframes:
    print(df.columns)
    for i, row in enumerate(df.values):
        for j, value in enumerate(row):
            try:
                float(value)
            except (TypeError, ValueError):
                print(f"Non-numeric value at row {i} column {j}: {value}")

print(f"signal length: {len(dataframes[0])}")

# Stack all DataFrames: shape (n_files, n_time_points, 2), e.g. (12, 2725, 2).
# Column 0 is used as time and column 1 as amplitude below.
stacked_data = jnp.stack([jnp.array(df.values) for df in dataframes])

# Normalize each signal (file) by the mean of its first three amplitude values
normalization_factor = jnp.mean(stacked_data[:, :3, 1], axis=1)[:, None]
stacked_data = stacked_data.at[:, :, 1].set(stacked_data[:, :, 1] / normalization_factor)

# Average across the files. Time is multiplied by 1000, presumably s -> ms to
# match the "Time (ms)" axis labels; TODO confirm the CSV time units.
average_data = jnp.mean(stacked_data, axis=0)
time_data = average_data[:, 0] * 1000
noisy_signal = average_data[:, 1]

# Plot the average data
plt.plot(time_data, noisy_signal)
plt.axhline(0, color="black", linewidth=0.5)
plt.xlabel("Time (ms)")
plt.ylabel("Amplitude")
plt.title("Average Data")
plt.show()

# add the name of the file to copy data, between the last underscore and the dot
copy_dict["Name"] = file_path_pattern.split("_")[-1].split(".")[0]

In [None]:

#################
#### SOLVERS ####
#################

## The 1-decay case is the reference example for how the solvers are used.

## 1 DECAY

# Per-parameter scales (amplitude, tau) used to non-dimensionalize the fit
magnitudes_1 = jnp.array([1.0, 80.0])

# Build a solver bound to this data set and scaling
solve_1 = get_solve(
    maxiter=2000,
    atol=1e-20,
    time_data=time_data,
    noisy_signal=noisy_signal,
    magnitudes=magnitudes_1,
)

# Initial guess: amplitude 1.0, tau 50
decay_constants_1 = jnp.array([1.0, 50])
params_1, state_1 = solve_1(decay_constants_1)

# The first call triggers jit compilation of solve_1's internals. Later calls
# from different starting positions reuse the compiled code and run much faster.

## 2 DECAYS

magnitudes_2 = jnp.array([
    0.5, 40.0,
    0.5, 160.0,
])

solve_2 = get_solve(
    maxiter=2000,
    atol=1e-20,
    time_data=time_data,
    noisy_signal=noisy_signal,
    magnitudes=magnitudes_2,
)

param_list = []
state_list = []
scaling_factors = jnp.linspace(1, 2, 10)
amplitude_ratios = jnp.linspace(2, 3, 10)

# Sample starting points around the mono-exponential fit (A, tau):
#  - scale splits tau into a faster (tau / scale) and a slower (tau * scale) component
#  - ratio splits A into A * (ratio - 1) / ratio and A / ratio, which sum to A.
#    Both assignments of these two shares to the fast/slow components are tried.
for scale in tqdm(scaling_factors):
    for ratio in amplitude_ratios:
        ratio_left = (ratio - 1) / ratio
        amplitude_splits = (
            (params_1[0] * ratio_left, params_1[0] / ratio),
            (params_1[0] / ratio, params_1[0] * ratio_left),
        )
        for a_fast, a_slow in amplitude_splits:
            decay_constants_2 = jnp.array([
                a_fast, params_1[1] / scale,
                a_slow, params_1[1] * scale,
            ])
            params_2, state_2 = solve_2(decay_constants_2)
            param_list.append(params_2)
            state_list.append(state_2)

# Select the lowest-loss fit that is physically plausible: all parameters
# positive and total amplitude below 1 (the signal was normalized so that its
# first samples average 1). Implausible candidates are discarded, with a
# message, and the next best one is tried.
positive_found = False
while not positive_found and state_list:
    best_index = int(jnp.argmin(jnp.array([state.value for state in state_list])))
    params_2 = param_list[best_index]
    state_2 = state_list[best_index]
    # params_2[::2] are the amplitudes (A1, A2)
    if jnp.all(params_2 > 0) and jnp.sum(params_2[::2]) < 1:
        positive_found = True
    else:
        param_list.pop(best_index)
        state_list.pop(best_index)
        print("parameters rejected")

if positive_found:
    print("No negative values found in parameters")
    params_2 = sort_params(params_2)
    print_params(params_2)
else:
    print("WARNING: all 2-decay parameter sets were rejected (negative values or total amplitude >= 1); params_2 holds the last rejected candidate")

## 3 DECAYS

magnitudes_3 = jnp.array([
    0.5, 20.0,
    0.5, 50.0,
    0.5, 200.0,
])

solve_3 = get_solve(
    maxiter=2000,
    atol=1e-20,
    time_data=time_data,
    noisy_signal=noisy_signal,
    magnitudes=magnitudes_3,
)

param_list = []
state_list = []
scaling_factors = jnp.linspace(1, 1.5, 10)
amplitude_ratios = jnp.linspace(2, 3.5, 40)

# First sweep: insert a middle component between the two bi-exponential
# components, scaling their taus apart and splitting amplitudes by ratio
for scale in tqdm(scaling_factors):
    for ratio in amplitude_ratios:
        ratio_left = (ratio - 1) / ratio

        decay_constants_3 = jnp.array([
            params_2[0] * ratio_left,               # A1 = A1_prev * (ratio-1)/ratio
            params_2[1] / scale,                    # tau1 = tau1_prev/scale

            (params_2[0] + params_2[2]) / ratio,    # A2 = (A1_prev + A2_prev)/ratio
            (params_2[1] + params_2[3]) / 2,        # tau2 = (tau1_prev + tau2_prev)/2

            params_2[2] * ratio_left,               # A3 = A2_prev * (ratio-1)/ratio
            params_2[3] * scale,                    # tau3 = tau2_prev*scale
        ])

        params_3, state_3 = solve_3(decay_constants_3)

        param_list.append(params_3)
        state_list.append(state_3)

# Second sweep: keep the bi-exponential fit as the two slower components and
# prepend a trial early decay with amplitude a0 and time constant tau0
amplitudes = jnp.linspace(0.001, 0.5, 10)
decay_constants = jnp.linspace(1, 20, 20)
for tau0 in decay_constants:
    for a0 in amplitudes:

        decay_constants_3 = jnp.array([
            a0,             # A1 = trial early-decay amplitude
            tau0,           # tau1 = trial early-decay time constant

            params_2[0],    # A2 = A1_prev
            params_2[1],    # tau2 = tau1_prev

            params_2[2],    # A3 = A2_prev
            params_2[3],    # tau3 = tau2_prev
        ])

        params_3, state_3 = solve_3(decay_constants_3)

        param_list.append(params_3)
        state_list.append(state_3)

# Select the lowest-loss fit with all-positive parameters. Rejected candidates
# are printed and removed before the next best is tried. The loop also stops
# when every candidate has been rejected, instead of crashing on argmin of an
# empty array.
positive_found = False
while not positive_found and state_list:
    best_index = int(jnp.argmin(jnp.array([state.value for state in state_list])))
    params_3 = param_list[best_index]
    state_3 = state_list[best_index]
    if jnp.all(params_3 > 0):
        positive_found = True
    else:
        print("Negative values found in parameters:")
        print_params(params_3)
        param_list.pop(best_index)
        state_list.pop(best_index)

if positive_found:
    print("No negative values found in parameters")
    params_3 = sort_params(params_3)
    print_params(params_3)
else:
    print("All parameter sets had negative values")


In [None]:
print("Guide to understanding statistics:")
print("Autocorrelation: 0 indicates no correlation, 1 indicates a perfect correlation. This only compares each \n\tpoint to the next, so it is not a perfect measure of correlation.")
print("Shapiro-Wilk: p value > 0.05 typically indicates normality, however, this evaluates the distribution as \n\ta whole and does not utilise signal structure or any temporal information.")

# Print the best fit for each model, as explicit (n_decays, params, state) tuples
fits = ((1, params_1, state_1), (2, params_2, state_2), (3, params_3, state_3))
for i, params, state in fits:

    # Print the final parameters and the loss
    print(f"\n{i} decay{'' if i==1 else 's'} with loss of {state.value:.4f}: ")
    print_params(params)
    fit_residual = noisy_signal - fit_multi_exponential(time_data, *params)

    # fit SNR: total amplitude divided by the standard deviation of the residuals
    fit_SNR = jnp.sum(params[::2]) / jnp.std(fit_residual)

    # Print the statistics
    dw = durbin_watson(fit_residual)
    print(f"Autocorrelation: {dw:.4f}")
    sw = shapiro_wilk(fit_residual)
    print(f"Shapiro-Wilk statistic: {sw[0]:.5f}, p value: {sw[1]:.5f}")
    if i > 1:
        # Relative amplitudes: each amplitude divided by the summed amplitude
        # of the components with tau > 20. Faster components are left out of
        # the denominator.
        print("Relative amplitudes:")
        amplitudes = params[::2]
        amp_sum = jnp.sum(amplitudes[params[1::2] > 20])
        amplitudes = amplitudes / amp_sum
        print("[" + ", ".join([f"{amp:.6f}" for amp in amplitudes]) + "]")

    # Add parameters to the copy dict
    if i == 2:
        # Add SNR to the copy dict
        copy_dict["SNR_2"] = fit_SNR
        # Add Tau1, Tau2, RA2 to the copy dict
        copy_dict["Tau1_2"] = params[1]
        copy_dict["Tau2_2"] = params[3]
        copy_dict["RA2_2"] = amplitudes[1]
        # Add statistics to the copy dict
        copy_dict["swp_2"] = sw[1]
        copy_dict["dw_2"] = dw
    elif i == 3:
        # Add SNR to the copy dict
        copy_dict["SNR_3"] = fit_SNR
        # Add all taus and RAs to the copy dict
        copy_dict["Tau1_3"] = params[1]
        copy_dict["Tau2_3"] = params[3]
        copy_dict["Tau3_3"] = params[5]
        copy_dict["RA1_3"] = amplitudes[0]
        copy_dict["RA2_3"] = amplitudes[1]
        copy_dict["RA3_3"] = amplitudes[2]
        # Add statistics to the copy dict
        copy_dict["swp_3"] = sw[1]
        copy_dict["dw_3"] = dw
            


In [None]:

###############################
#### PLOT FITS & RESIDUALS ####
###############################


def low_pass_filter(signal, cutoff_freq, sample_rate):
    """
    Ideal (brick-wall) low-pass filter applied in the frequency domain.

    Every FFT bin whose absolute frequency is strictly below ``cutoff_freq``
    is kept and every other bin is zeroed. The real part of the inverse FFT
    is returned.

    Args:
        signal: signal to filter. Assumed 1-D and uniformly sampled, since the
            bin frequencies are built from its first dimension.
        cutoff_freq: cutoff frequency, in the same units as ``sample_rate``.
        sample_rate: samples per unit time (1 / sample spacing).
    """
    spectrum = fft(signal)
    bin_freqs = jnp.fft.fftfreq(spectrum.shape[0], d=1 / sample_rate)
    passband = jnp.abs(bin_freqs) < cutoff_freq
    kept = spectrum * jnp.where(passband, 1.0, 0.0)
    return jnp.real(ifft(kept))

# Plot a comparison of the three fits, on a linear and then a logarithmic time axis
for log_time_axis in (False, True):
    plt.figure()
    plt.plot(time_data, noisy_signal, label="Data")
    plt.plot(time_data, fit_multi_exponential(time_data, *params_1), label="1 Decay")
    plt.plot(time_data, fit_multi_exponential(time_data, *params_2), label="2 Decays")
    plt.plot(time_data, fit_multi_exponential(time_data, *params_3), label="3 Decays")
    plt.xlabel("Time (ms)")
    plt.ylabel("Amplitude")
    plt.title("Fits")
    if log_time_axis:
        plt.xscale("log")
    plt.legend()
    plt.show()

# Plot the residuals of each fit alongside a low-pass filtered version.
# sample_rate assumes uniformly spaced time_data.
sample_rate = 1/(time_data[1] - time_data[0])
cutoff_freq = 0.05
for i, params in ((1, params_1), (2, params_2), (3, params_3)):

    residuals = noisy_signal - fit_multi_exponential(time_data, *params)
    low_pass_residuals = low_pass_filter(residuals, cutoff_freq, sample_rate)
    label = f"{i} Decay{'' if i == 1 else 's'}"

    plt.figure()
    plt.plot(time_data, residuals, label=label)
    plt.plot(time_data, low_pass_residuals, label=f"{label} Low Pass")
    plt.axhline(0, color="black", linewidth=0.5)
    plt.xlabel("Time (ms)")
    plt.ylabel("Residual amplitude")
    plt.legend(loc="upper right")
    plt.show()


In [None]:

########################################
#### BOOTSTRAP CONFIDENCE INTERVALS ####
########################################

# Define hyperparameters
# number of bootstraps to run
n_bootstrap = 1000
# keys for random number generation
rng, key = random.split(random.PRNGKey(2025), 2)

for dec_index, params in ((2, params_2), (3, params_3)):
    print(f"\n{dec_index} decays:")

    # Run the bootstrap
    bootstrap_params = jnp.zeros((n_bootstrap, len(params)))

    # Define the bootstrap solver
    bootstrap_solve = get_bootstrap_solve(
        atol=1e-20,
        time_data=time_data,
        noisy_signal=noisy_signal,
        params=params,
    )

    start = time()
    for i in tqdm(range(n_bootstrap)):
        rng, key = random.split(rng)
        bootstrap_params = bootstrap_params.at[i].set(bootstrap_solve(key)[0])
    print(f"{i+1} bootstraps complete in {time()-start:.2f} seconds.")

    print(f"shape of bootstrap_params before exclusion: ", bootstrap_params.shape)
    # Exclude non-physical fits: any non-positive parameter, or any tau beyond
    # the end of the recorded time window
    for i in range(len(params)):
        bootstrap_params = bootstrap_params[bootstrap_params[:, i] > 0]
    for i in range(0, len(params), 2):
        bootstrap_params = bootstrap_params[bootstrap_params[:, i+1] < time_data[-1]]
    print(f"shape of bootstrap_params after exclusion: ", bootstrap_params.shape)

    if bootstrap_params.shape[0] == 0:
        print(f"WARNING: every {dec_index}-decay bootstrap fit was excluded; skipping confidence intervals")
        continue

    # Normalize amplitudes: divide each fit's amplitudes by the sum of its
    # amplitudes whose tau > 20. This is vectorized over exactly the rows that
    # survived the exclusion above.
    amplitudes = bootstrap_params[:, ::2]
    taus = bootstrap_params[:, 1::2]
    amp_sum = jnp.sum(jnp.where(taus > 20, amplitudes, 0.0), axis=1, keepdims=True)
    bootstrap_params = bootstrap_params.at[:, ::2].set(amplitudes / amp_sum)

    # plot the bootstrap parameters
    fig, ax = plt.subplots(params.shape[0]//2, 2)
    ax = ax.flatten()
    for j in range(params.shape[0]):
        ax[j].hist(bootstrap_params[:, j], bins=200)
        ax[j].set_title(f"Parameter {j}")
    plt.suptitle("Bootstrap Parameter Distributions")
    plt.tight_layout()
    plt.show()

    # 95% percentile confidence intervals. Displayed names are shifted down by
    # the number of fitted components with tau < 20, so an early decay is
    # labelled 0.
    contains_low_tau = int(jnp.sum(params[1::2] < 20))
    for i in range(params.shape[0]):
        lower = jnp.percentile(bootstrap_params[:, i], 2.5)
        upper = jnp.percentile(bootstrap_params[:, i], 97.5)

        parameter_name = ("RA" if i%2 == 0 else "Tau") + str(i//2 + 1 - contains_low_tau)
        print(f"{parameter_name}: \t{lower:.6f} to {upper:.6f}")

        # add values to the copy dict (keys use the unshifted component number)
        if dec_index == 2 and i%2 == 0 and i//2 + 1 == 2:
            copy_dict["RA2_2_lower"] = lower
            copy_dict["RA2_2_upper"] = upper
        elif dec_index == 3 and i%2 == 0:
            if i//2 + 1 == 2:
                copy_dict["RA2_3_lower"] = lower
                copy_dict["RA2_3_upper"] = upper
            elif i//2 + 1 == 3:
                copy_dict["RA3_3_lower"] = lower
                copy_dict["RA3_3_upper"] = upper
        

In [None]:
# Now to format the copy_dict into a string that can be copied and pasted into google sheets

# The order of the string is important, so we will define the order here
string_order = [
    "Name",

    "SNR_2", "Tau1_2", "Tau2_2", "RA2_2",
    "RA2_2_lower", "RA2_2_upper",
    "swp_2", "dw_2",

    "SNR_3", "Tau1_3", "Tau2_3", "Tau3_3", "RA1_3", "RA2_3", "RA3_3",
    "RA2_3_lower", "RA2_3_upper", "RA3_3_lower", "RA3_3_upper",
    "swp_3", "dw_3",
]

# taus will be to 2 decimal places
# RA values and their lower and upper confidence intervals will be to 6 decimal places
# swp and dw will be to 5 decimal places
# snr will be to 2 decimal places

# jax.Array is the public JAX array type. The private
# jaxlib.xla_extension.ArrayImpl is not stable across jaxlib versions.
import jax

# Convert JAX arrays to plain Python values so they print cleanly
for key in string_order:
    if key in copy_dict and isinstance(copy_dict[key], jax.Array):
        copy_dict[key] = copy_dict[key].tolist()

print(copy_dict)

# Format one field per key. Missing keys become empty fields so the columns
# stay aligned in the spreadsheet.
fields = []
for key in string_order:
    if key not in copy_dict:
        fields.append("")
        print(f"WARNING: {key} not found in copy_dict")
        continue
    value = copy_dict[key]

    if isinstance(value, str):
        fields.append(value)
        continue

    # Accept any real scalar: Python floats, and NumPy/JAX scalars such as the
    # np.float64 p-value returned by scipy's shapiro. Reject anything else.
    try:
        numeric_value = float(value)
    except (TypeError, ValueError):
        raise ValueError("Invalid type for value: " + str(type(value))) from None

    # format the value based on the specification above
    if "RA" in key:
        fields.append(f"{numeric_value:.6f}")
    elif "Tau" in key:
        fields.append(f"{numeric_value:.2f}")
    elif "swp" in key or "dw" in key:
        fields.append(f"{numeric_value:.5f}")
    elif "SNR" in key:
        fields.append(f"{numeric_value:.2f}")
    else:
        fields.append(f"{numeric_value}")

copy_string = ", ".join(fields)

# print the string and make sure not to copy the name
print(copy_string)
print(f"\nsignal length: {len(noisy_signal)}")
