## Things to add

### For `check_particle_fit`
1. Change the time unit from seconds to minutes
2. Show all the traces instead of just the accepted ones
3. In addition to the above, add a toggle parameter to look at only the accepted or rejected ones

### For `check_bin_fit`
1. Include approval status in the dataframe (Nick's suggestion: use a dictionary)

### For both functions: add the linear fit

## Prepare Dataset and Specify Parameters (please only edit cells in this section)

In [1]:
# Parameters to specify

# Specify here at what frame NC14 starts
# (traces whose first frame is below this value are dropped before fitting)
nc14_start_frame = 263

# Minimum trace length: only traces with MORE than min_frames frames are kept
min_frames = 40

# Time resolution (unit: second per frame)
time_res_sec = 4.25939
# Same time resolution expressed in minutes per frame
time_res_min = time_res_sec/60

# Number of bins you want to split the full embryo into
num_bins = 42

In [2]:
# Dataset directory: root folder plus the per-construct lists of embryo folders
# (paths in the lists are relative to `dataset_folder`).

dataset_folder = "/mnt/Data1/Nick/transcription_pipeline/"

RBSPWM_datasets = [
    "test_data/2024-02-26/Halo-RBSPWM_embryo01",
    "test_data/2024-02-26/Halo-RBSPWM_embryo02",
    "test_data/2024-05-07/Halo552-RBSPWM_embryo01",
    "test_data/2024-05-07/Halo552-RBSPWM_embryo02",
    "test_data/2024-05-09/Halo552-RBSPWM_embryo01",
]

RBSVar2_datasets = [
    "test_data/2024-07-23/Halo673_RBSVar2_embryo01",
    "test_data/2024-07-25/Halo673_RBSVar2_embryo01",
    "test_data/2024-10-10/Halo673_RBSVar2_embryo01",
    "test_data/2024-10-10/Halo673_RBSVar2_embryo02",
]

MCP_mSG_datasets = [
    "test_data/2024-08-13/MCP-mSG,ParB-mScar_normWindow",
    "test_data/2024-10-31/MCP-mSG_ParB-mScar_RBSPWM_embryo01",
    "test_data/2024-10-31/MCP-mSG_ParB-mScar_RBSPWM_embryo02",
    "test_data/2024-11-05/MCP-mSG_ParB-mScar_RBSPWM_embryo01",
]

# Dataset analysed by the rest of the notebook: change the list / index here
test_dataset_name = f"{dataset_folder}{RBSPWM_datasets[4]}"
print(f"Dataset Path: {test_dataset_name}")

Dataset Path: /mnt/Data1/Nick/transcription_pipeline/test_data/2024-05-09/Halo552-RBSPWM_embryo01


In [3]:
# Import pipeline
from transcription_pipeline import nuclear_pipeline
from transcription_pipeline import preprocessing_pipeline

from transcription_pipeline import spot_pipeline
from transcription_pipeline import fullEmbryo_pipeline

# Importing libraries
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from transcription_pipeline.spot_analysis import compile_data
from transcription_pipeline.utils import plottable

from scipy.signal import medfilt
from skimage.restoration import denoise_tv_chambolle

import numpy as np
from scipy.optimize import least_squares
from scipy.stats import chi2
import pandas as pd
from IPython.display import display
import emcee
import os
from warnings import warn
import tkinter as tk
from tkinter import simpledialog
from tkinter import messagebox

# Create the Tk root window and hide it immediately.
# NOTE(review): presumably only the simpledialog/messagebox pop-ups imported
# above are meant to be visible — confirm against the cells that use them.
root = tk.Tk()
root.withdraw();

`JAVA_HOME` environment variable set to /mnt/Data1/Nick/miniforge3/envs/transcription_pipeline


In [4]:
# Specify how you would want the plots to be shown: Use TkAgg if you use PyCharm, or widget if you use a browser
# Keep exactly one of the two lines below active.

mpl.use('TkAgg')
# %matplotlib widget

## Import Dataset

### Import MS2 Dataset

Detect whether the dataset has already been converted into `zarr` files, i.e. whether there's "previously" processed data. If so, load the previous results.

In [5]:
# True when a 'collated_dataset' folder already exists inside the dataset directory
ms2_import_previous = os.path.isdir(test_dataset_name + '/collated_dataset')
ms2_import_previous

True

In [6]:
# Import the MS2 dataset. `import_previous` follows `ms2_import_previous`, i.e.
# whether a 'collated_dataset' folder was found in the dataset directory.
dataset = preprocessing_pipeline.DataImport(
    name_folder=test_dataset_name,
    trim_series=True,
    working_storage_mode='zarr',
    import_previous=ms2_import_previous, 
)

In [15]:
# Same import with import_previous=False hard-coded.
# NOTE(review): this overwrites `dataset` and ignores `ms2_import_previous`;
# the execution count (In [15]) shows it was run out of order — confirm whether
# it should stay in the notebook.
dataset = preprocessing_pipeline.DataImport(
    name_folder=test_dataset_name,
    trim_series=True,
    working_storage_mode='zarr',
    import_previous=False, 
)

The series in ['Series002', 'Series003', 'Series004', 'Series005'] have inconsistent LaserID, check your imaging settings and metadata.


  prominence = peak_prominences(offsets, [proposed_peak])[0]


In [16]:
# Display the first entry (index 0) of the exported frame metadata as a sanity check
dataset.export_frame_metadata[0]

{'frame': array([[   0,    1,    2, ...,   18,   19,   20],
        [  21,   22,   23, ...,   39,   40,   41],
        [  42,   43,   44, ...,   60,   61,   62],
        ...,
        [5208, 5209, 5210, ..., 5226, 5227, 5228],
        [5229, 5230, 5231, ..., 5247, 5248, 5249],
        [5250, 5251, 5252, ..., 5268, 5269, 5270]]),
 'series': 0,
 'mpp': 0.22194604105571847,
 'mppZ': 0.4196171000000001,
 'x': 0,
 'y': 0,
 'x_um': 0.0386673141867,
 'y_um': -0.06760669878379,
 'axes': ['t', 'c', 'z', 'y', 'x'],
 'coords': {},
 'c': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 'colors': (0.0, 1.0, 0.0),
 'z': array([[ 0,  1,  2, ..., 18, 19, 20],
        [ 0,  1,  2, ..., 18, 19, 20],
        [ 0,  1,  2, ..., 18, 19, 20],
        ...,
        [ 0,  1,  2, ..., 18, 19, 20],
        [ 0,  1,  2

In [17]:
# Display the first entry (index 0) of the exported frame metadata.
# NOTE(review): identical to the preceding display cell — likely a leftover duplicate.
dataset.export_frame_metadata[0]

{'frame': array([[   0,    1,    2, ...,   18,   19,   20],
        [  21,   22,   23, ...,   39,   40,   41],
        [  42,   43,   44, ...,   60,   61,   62],
        ...,
        [5208, 5209, 5210, ..., 5226, 5227, 5228],
        [5229, 5230, 5231, ..., 5247, 5248, 5249],
        [5250, 5251, 5252, ..., 5268, 5269, 5270]]),
 'series': 0,
 'mpp': 0.22194604105571847,
 'mppZ': 0.4196171000000001,
 'x': 0,
 'y': 0,
 'x_um': 0.0386673141867,
 'y_um': -0.06760669878379,
 'axes': ['t', 'c', 'z', 'y', 'x'],
 'coords': {},
 'c': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 'colors': (0.0, 1.0, 0.0),
 'z': array([[ 0,  1,  2, ..., 18, 19, 20],
        [ 0,  1,  2, ..., 18, 19, 20],
        [ 0,  1,  2, ..., 18, 19, 20],
        ...,
        [ 0,  1,  2, ..., 18, 19, 20],
        [ 0,  1,  2

### Import FullEmbryo Dataset

In [None]:
# Import the full-embryo images from the same dataset folder
FullEmbryo_dataset = preprocessing_pipeline.FullEmbryoImport(
    name_folder=test_dataset_name,
    #import_previous=True
)
# Known issue: loading the FullEmbryo dataset currently reads in only the last
# channel (reported to Yovan); saving is left commented out below.
# FullEmbryo_dataset.save()

## Starting a DASK Client for parallel processing

In [27]:
from dask.distributed import LocalCluster, Client

# Start a local Dask cluster on a fixed scheduler port; if that fails, fall back
# to attaching to a scheduler that is assumed to be listening on the same port.
try:
    cluster = LocalCluster(
        host="localhost",
        scheduler_port=37763,
        threads_per_worker=1,
        n_workers=14,
        memory_limit="6GB",
    )

    client = Client(cluster)
except Exception as err:
    # `Exception` (not a bare `except:`) so that a kernel interrupt is not
    # swallowed; the underlying error is shown because "already running" is
    # only the most likely cause, not the only possible one.
    print("Cluster already running")
    print(f"(LocalCluster start-up error: {err!r})")
    client = Client('localhost:37763')

print(client)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 37167 instead


Cluster already running
<Client: 'tcp://127.0.0.1:37763' processes=14 threads=14, memory=78.23 GiB>


In [20]:
# Tear down the Dask client created above.
# NOTE(review): in a top-to-bottom run this executes right after the client is
# created, while later cells still pass `client=client` — confirm this cell is
# meant to be run manually at the end of a session.
client.shutdown()

In [None]:
# Print the URL of the Dask dashboard for monitoring the workers
print(client.dashboard_link)

## Nuclear Tracking

Detect whether the nuclear tracking has been done "previously." If so, load the previous results.

In [None]:
# True when a 'nuclear_analysis_results' folder already exists inside the dataset directory
nuclear_tracking_previous = os.path.isdir(test_dataset_name + '/nuclear_analysis_results')
nuclear_tracking_previous

In [None]:
# Either reload saved nuclear-tracking results or run the tracking from scratch.
if nuclear_tracking_previous:
    # Load nuclear tracking results
    print('Load from previous nuclear tracking results')

    nuclear_tracking = nuclear_pipeline.Nuclear()
    nuclear_tracking.read_results(name_folder=test_dataset_name)

else:
    # Do nuclear tracking and save the results
    print('Do nuclear tracking for the dataset')

    # Uses channel index 0 of the imported dataset and the running Dask client
    nuclear_tracking = nuclear_pipeline.Nuclear(
        data=dataset.channels_full_dataset[0],
        global_metadata=dataset.export_global_metadata[0],
        frame_metadata=dataset.export_frame_metadata[0],
        series_splits=dataset.series_splits,
        series_shifts=dataset.series_shifts,
        search_range_um=1.5,
        stitch=False,
        stitch_max_distance=4,
        stitch_max_frame_distance=2,
        client=client,
        keep_futures=False,
    )

    nuclear_tracking.track_nuclei(
        working_memory_mode="zarr",
        working_memory_folder=test_dataset_name,
        # The log path needs an explicit separator: joining the folder name and
        # "trackpy_log" with "" produced "<dataset>trackpy_log" NEXT TO the
        # dataset folder. This now matches the spot-tracking cell.
        trackpy_log_path=test_dataset_name + "/trackpy_log",
    )
    # Saves tracked nuclear mask as a zarr, and pickles dataframes with segmentation and
    # tracking information.
    nuclear_tracking.save_results(
        name_folder=test_dataset_name, save_array_as=None
    )

## Spot Tracking

Detect whether the spot tracking has been done "previously." If so, load the previous results.

In [7]:
# True when a 'spot_analysis_results' folder already exists inside the dataset directory
spot_tracking_previous = os.path.isdir(test_dataset_name + '/spot_analysis_results')
spot_tracking_previous

True

In [8]:
%%time

if not spot_tracking_previous:
    # No saved results on disk: run spot detection/tracking and store the output
    print('Do spot tracking for the dataset')

    # Analysis object built from channel index 1 of the imported dataset
    spot_tracking = spot_pipeline.Spot(
        data=dataset.channels_full_dataset[1],
        global_metadata=dataset.export_global_metadata[1],
        frame_metadata=dataset.export_frame_metadata[1],
        labels=None,  # nuclear_tracking.reordered_labels,
        expand_distance=3,
        search_range_um=4.2,
        retrack_search_range_um=4.5,
        threshold_factor=1.3,
        memory=3,
        retrack_after_filter=False,
        stitch=True,
        min_track_length=0,
        series_splits=dataset.series_splits,
        series_shifts=dataset.series_shifts,
        keep_bandpass=False,
        keep_futures=False,
        keep_spot_labels=False,
        evaluate=True,
        retrack_by_intensity=True,
        client=client,
    )

    spot_tracking.extract_spot_traces(
        working_memory_folder=test_dataset_name,
        stitch=True,
        retrack_after_filter=True,
        trackpy_log_path=f"{test_dataset_name}/trackpy_log",
    )

    # Saves tracked spot mask as a zarr, and pickles dataframes with spot fitting and
    # quantification information.
    spot_tracking.save_results(name_folder=test_dataset_name, save_array_as=None)

else:
    # Saved results exist: read them back instead of recomputing
    print('Load from spot  tracking results')

    spot_tracking = spot_pipeline.Spot()
    spot_tracking.read_results(name_folder=test_dataset_name)

Load from spot  tracking results
CPU times: user 1.35 s, sys: 412 ms, total: 1.76 s
Wall time: 3.28 s


In [22]:
    # NOTE(review): stand-alone copy of the "run spot tracking" branch, executed out
    # of order (In [22]). Compared with the guarded cell it passes memory=0 and
    # stitch=False to the constructor, and it overwrites both `spot_tracking` and
    # the results saved on disk — confirm whether it should remain in the notebook.
    spot_tracking = spot_pipeline.Spot(
        data=dataset.channels_full_dataset[1],
        global_metadata=dataset.export_global_metadata[1],
        frame_metadata=dataset.export_frame_metadata[1],
        labels=None,#nuclear_tracking.reordered_labels,
        expand_distance=3,
        search_range_um=4.2,
        retrack_search_range_um=4.5,
        threshold_factor=1.3,
        memory=0,
        retrack_after_filter=False,
        stitch=False,
        min_track_length=0,
        series_splits=dataset.series_splits,
        series_shifts=dataset.series_shifts,
        keep_bandpass=False,
        keep_futures=False,
        keep_spot_labels=False,
        evaluate=True,
        retrack_by_intensity=True,
        client=client,
    )
    
    # NOTE(review): stitch=True / retrack_after_filter=True here differ from the
    # constructor arguments above; which one takes precedence is not visible here.
    spot_tracking.extract_spot_traces(
        working_memory_folder=test_dataset_name, 
        stitch=True,
        retrack_after_filter=True,
        trackpy_log_path = test_dataset_name+'/trackpy_log'
    )
    
    # Saves tracked spot mask as a zarr, and pickles dataframes with spot fitting and
    # quantification information.
    spot_tracking.save_results(name_folder=test_dataset_name, save_array_as=None)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


Tracking and filtering:


Preliminary spot filtering: 100%|██████████| 2/2 [00:00<00:00, 129.51it/s]
Tracking: 100%|██████████| 247/247 [00:03<00:00, 73.09it/s] 
Post-tracking spot filtering: 100%|██████████| 2/2 [00:00<00:00, 59.78it/s]


Re-tracking after filtering:


Preliminary spot filtering: 100%|██████████| 2/2 [00:00<00:00, 627.65it/s]
Compiling variations in intensity: 100%|██████████| 9893/9893 [00:02<00:00, 4757.86it/s]
Tracking: 100%|██████████| 247/247 [00:01<00:00, 171.62it/s]
Post-tracking spot filtering: 100%|██████████| 2/2 [00:00<00:00, 280.39it/s]


Stitching pass 1 of 1


Finding track y start position: 100%|██████████| 566/566 [00:00<00:00, 4132.50it/s]
Finding track y end position: 100%|██████████| 566/566 [00:00<00:00, 4156.16it/s]
Finding track x start position: 100%|██████████| 566/566 [00:00<00:00, 4342.23it/s]
Finding track x end position: 100%|██████████| 566/566 [00:00<00:00, 3923.24it/s]
Finding track first and last frames: 100%|██████████| 566/566 [00:00<00:00, 6604.49it/s]
Stitching track nearest neighbors: 100%|██████████| 566/566 [00:02<00:00, 263.94it/s]


Stitching tracks: 100%|██████████| 28/28 [00:00<00:00, 81.66it/s]


Removing duplicate spots from stitching: 100%|██████████| 12338/12338 [00:02<00:00, 5523.10it/s]


### Make Compiled Dataframe

In [9]:
# Per-spot table produced by the spot-tracking step
spot_df = spot_tracking.spot_dataframe

# Drop rows whose particle label is 0 (spots that were not detected)
detected_spots = spot_df.loc[spot_df["particle"] != 0]

# Columns to carry over into the compiled per-particle traces
columns_to_compile = [
    "frame",
    "t_s",
    "intensity_from_neighborhood",
    "intensity_std_error_from_neighborhood",
    "x",
    "y",
]

# Compile traces (no nuclear tracking information is attached)
compiled_dataframe = compile_data.compile_traces(
    detected_spots,
    compile_columns_spot=columns_to_compile,
    nuclear_tracking_dataframe=None,
)

compiled_dataframe.head()

Unnamed: 0,particle,frame,t_s,intensity_from_neighborhood,intensity_std_error_from_neighborhood,x,y
0,2,"[608, 609, 610, 613, 614, 615, 616, 617, 618, ...","[2981.271999359131, 2985.51900100708, 2989.767...","[67.20298245614036, 62.61358125, 92.1778863636...","[54.07608080882359, 53.08599141703332, 48.6846...","[706.3077781092006, 705.0800225329217, 704.078...","[132.22746686990476, 133.33576781054558, 131.9..."
1,3,"[579, 580, 581, 582, 583, 584, 585, 586, 587, ...","[2857.3950004577637, 2861.64400100708, 2865.89...","[368.71108000000004, 162.58609316770185, 207.0...","[48.961304741944936, 48.764177034702456, 49.24...","[814.23756413008, 814.0490753702337, 812.66835...","[165.32374456695854, 165.4022049366289, 164.15..."
2,4,"[640, 641, 642, 643, 644, 645, 647, 648, 649, ...","[3117.5319995880127, 3121.579999923706, 3126.1...","[137.8392151898734, 246.30097297297297, 162.72...","[50.04719298970737, 48.033152885448544, 53.299...","[610.5256101215496, 610.5078280423238, 610.170...","[200.09244408786026, 199.6949788757095, 200.19..."
3,5,"[613, 614, 615, 616, 617, 618, 619, 620, 621, ...","[3002.3190002441406, 3006.1590003967285, 3010....","[223.98077931034484, 149.44997701149424, 102.1...","[52.019684703514685, 46.46238563874822, 49.753...","[789.4493998034413, 790.1299534943532, 790.754...","[111.42483044437226, 110.54379569293525, 110.3..."
4,6,"[583, 585, 586, 588, 590, 591, 592, 593, 594, ...","[2874.9950008392334, 2883.492000579834, 2887.7...","[56.80923595505618, 58.452866666666665, 56.477...","[45.10765132358713, 45.86984351271528, 48.0794...","[830.0205022942665, 828.8195400854084, 827.835...","[152.07172254602511, 153.15739986339034, 152.5..."


## Full Embryo Analysis

In [None]:
# Side-by-side preview of the first surface and mid-plane full-embryo images
fig_embryo, (ax_surf, ax_mid) = plt.subplots(1, 2, figsize=(12, 6))

ax_surf.imshow(FullEmbryo_dataset.channels_full_dataset_surf[0][0, :, :], cmap='gray')
ax_surf.set_title('Full Embryo Surf')

ax_mid.imshow(FullEmbryo_dataset.channels_full_dataset_mid[0][0, :, :], cmap='gray')
ax_mid.set_title('Full Embryo Mid')

fig_embryo.tight_layout()
plt.show()

In [None]:
# Build the full-embryo analysis object from the full-embryo import and the MS2
# dataset; his_channel=0 presumably selects the histone channel — TODO confirm
fullEmbryo = fullEmbryo_pipeline.FullEmbryo(FullEmbryo_dataset, dataset, his_channel=0)

In [None]:
# Locate the AP axis of the embryo and show the diagnostic plots
fullEmbryo.find_ap_axis(make_plots=True)

In [None]:
# Replace `compiled_dataframe` with the output of xy_to_ap.
# NOTE(review): the cell overwrites its own input, so running it twice feeds the
# already-converted frame back in — confirm xy_to_ap tolerates that.
compiled_dataframe = fullEmbryo.xy_to_ap(compiled_dataframe)
compiled_dataframe.head()

### Plot Individual Traces as a Check

In [10]:
# Keep only the traces that contain more than `min_frames` frames
frames_per_trace = compiled_dataframe["frame"].apply(lambda frames: frames.size)
traces_compiled_dataframe = compiled_dataframe[frames_per_trace > min_frames]

In [11]:
# Copied from Nick's Analysis Notebook

# The part of the code for scrolling between plots is taken from https://stackoverflow.com/questions/18390461/scroll-backwards-and-forwards-through-matplotlib-plots

traces = plottable.generate_trace_plot_list(traces_compiled_dataframe)

# Total-variation denoised version of every trace (drawn as the black "TV" line).
tv_denoised_traces = [
    denoise_tv_chambolle(trace[1], weight=1080, max_num_iter=500) for trace in traces
]

# Index of the trace currently shown; changed by the left/right arrow keys.
curr_pos = 0


def _draw_trace(pos):
    """
    Redraw the shared axes `ax` with the trace at index `pos`.

    Plots the raw intensities with error bars, overlays the TV-denoised trace,
    and titles the axes with the particle id, its mean x position and its first
    frame (both looked up in `compiled_dataframe`).

    Parameters:
    pos (int): Index into `traces` / `tv_denoised_traces`.
    Returns:
    tuple: (particle, mean_x, initial_frame) for the displayed trace.
    """
    times = traces[pos][0]
    particle = traces[pos][3]

    ax.cla()
    ax.errorbar(
        times,
        traces[pos][1],
        yerr=traces[pos][2],
        fmt=".",
        elinewidth=1,
    )
    ax.plot(times, tv_denoised_traces[pos], color="k", label="TV")
    ax.set_xlabel("time (s)")
    ax.set_ylabel("Spot intensity (AU)")

    # Row of the compiled dataframe that belongs to this particle
    particle_rows = compiled_dataframe.loc[compiled_dataframe["particle"] == particle]
    mean_x = particle_rows["x"].values[0].mean()
    initial_frame = particle_rows["frame"].values[0][0]

    ax.set_title(f"Particle {particle}, x = {mean_x}, Initial frame {initial_frame}")
    ax.legend()
    return particle, mean_x, initial_frame


def key_event(e):
    """Matplotlib key handler: right/left arrows step through the traces (wrapping around)."""
    global curr_pos

    if e.key == "right":
        curr_pos = curr_pos + 1
    elif e.key == "left":
        curr_pos = curr_pos - 1
    else:
        return
    curr_pos = curr_pos % len(traces)

    _draw_trace(curr_pos)
    fig.canvas.draw()


fig = plt.figure()
fig.canvas.mpl_connect("key_press_event", key_event)

ax = fig.add_subplot(111)

# Initial plot; the returned values are kept as module-level names, as before.
particle, mean_x, initial_frame = _draw_trace(curr_pos)

plt.show()



## Scheme 1 (Fit & Average): Fit Individual MS2 Traces, Then Average the Fits Within Each Bin

### Define Fit Functions

In [16]:
# fit_all_traces: a function that generates the fit for each trace

import numpy as np
from scipy.optimize import least_squares
from scipy.stats import chi2
import pandas as pd
import emcee

# Version with normalization and regularization
def make_half_cycle(basal, t_on, t_dwell, rate, t_interp):
    """
    Evaluate the piecewise-linear "half cycle" model on a time grid.

    The model is `basal` before `t_on`, rises linearly with slope `rate` for a
    duration `t_dwell`, and stays at `basal + rate * t_dwell` afterwards.

    Parameters:
    basal (float): Value before the rise starts.
    t_on (float): Time at which the rise starts.
    t_dwell (float): Duration of the linear rise.
    rate (float): Slope during the rise.
    t_interp (array-like): Time points at which to evaluate the model.
    Returns:
    numpy.ndarray: Model values, same shape as `t_interp`, floating-point dtype.
    """
    t = np.asarray(t_interp)
    # An integer time grid would otherwise give an integer output array and
    # silently truncate every model value.
    if not np.issubdtype(t.dtype, np.inexact):
        t = t.astype(float)

    ramp_end = t_on + t_dwell
    rising = (t >= t_on) & (t < ramp_end)

    half_cycle = np.zeros_like(t)
    half_cycle[t < t_on] = basal
    half_cycle[rising] = basal + rate * (t[rising] - t_on)
    half_cycle[t >= ramp_end] = basal + rate * t_dwell
    return half_cycle

def fit_func(params, MS2, timepoints, t_interp):
    """
    Residuals (model minus data) of the half-cycle model.

    The model is built on `t_interp` from `params` (unpacked positionally into
    make_half_cycle) and linearly interpolated onto `timepoints`.
    """
    model_on_grid = make_half_cycle(*params, t_interp)
    model_at_data = np.interp(timepoints, t_interp, model_on_grid)
    return model_at_data - MS2

def initial_guess(MS2, timepoints):
    """
    Starting parameters [basal, t_on, t_dwell, rate] for the half-cycle fit.

    basal is the first intensity, t_on the first time point, t_dwell two thirds
    of the covered time span, and rate is fixed at 1.
    """
    first_time = timepoints[0]
    time_span = timepoints[-1] - first_time
    return [MS2[0], first_time, (2/3)*time_span, 1]


def fit_half_cycle(MS2, timepoints, t_interp, std_errors, max_nfev=3000):
    """
    Fit the half-cycle model to one trace: least-squares start, then MCMC.

    Parameters are rescaled to comparable magnitudes, a bounded least-squares
    fit gives a starting point, and an emcee ensemble sampler is run from there.
    The reported values are the medians of the posterior samples.

    Parameters:
    MS2 (array-like): Intensities to fit.
    timepoints (array-like): Time of each intensity.
    t_interp (array-like): Grid on which the model is evaluated before
        interpolation onto `timepoints`.
    std_errors (array-like): Standard error of each intensity (used as weights).
    max_nfev (int): Maximum number of function evaluations for least_squares.
    Returns:
    tuple: (basal, t_on, t_dwell, rate, CI) where CI has shape (4, 2) and holds
        the 5th and 95th percentile of each parameter.

    Note: the walkers are initialised with the global NumPy random state, so
    results are not reproducible unless it is seeded by the caller.
    """
    # Initial guess
    x0 = initial_guess(MS2, timepoints)

    # Parameter bounds
    lb = [np.min(MS2), 0, 0, 0]  # Ensure t_dwell is non-negative
    ub = [np.max(MS2), np.max(timepoints), np.max(timepoints), 1e7]

    # Scaling factors to normalize parameters
    scale_factors = np.array([np.max(MS2), np.max(timepoints), np.max(timepoints), 100])

    # Scaled bounds
    lb_scaled = np.array(lb) / scale_factors
    ub_scaled = np.array(ub) / scale_factors
    x0_scaled = np.array(x0) / scale_factors

    # Scaled fit function
    def fit_func_scaled(params, MS2, timepoints, t_interp):
        params_unscaled = params * scale_factors
        return fit_func(params_unscaled, MS2, timepoints, t_interp)

    # Negative log-likelihood function (Gaussian errors + L2 penalty on the scaled parameters)
    def negative_log_likelihood(params, MS2, timepoints, t_interp, std_errors, reg=1e-3):
        residuals = fit_func_scaled(params, MS2, timepoints, t_interp) / std_errors
        regularization = reg * np.sum(params[:]**2)
        nll = 0.5 * np.sum(residuals**2) + regularization
        return nll

    # Residual vector whose half sum of squares equals the negative log-likelihood:
    # 0.5 * sum(r**2) = 0.5 * sum(residuals**2) + reg * sum(params**2)
    def regularized_residuals(params, MS2, timepoints, t_interp, std_errors, reg=1e-3):
        residuals = fit_func_scaled(params, MS2, timepoints, t_interp) / std_errors
        penalty = np.sqrt(2 * reg) * np.ravel(params)
        return np.concatenate([np.ravel(residuals), penalty])

    # Initial parameter estimation using least_squares. least_squares expects the
    # residual VECTOR and minimises 0.5 * sum(r**2) itself; passing the scalar
    # NLL would make it minimise 0.5 * NLL**2 with a single-row Jacobian.
    res = least_squares(regularized_residuals,
                        x0_scaled, bounds=(lb_scaled, ub_scaled),
                        args=(MS2, timepoints, t_interp, std_errors), max_nfev=max_nfev)

    # Define log-probability function for MCMC (flat prior inside the bounds)
    def log_prob(params, MS2, timepoints, t_interp, std_errors, scale_factors, lb_scaled, ub_scaled):
        if np.any(params < lb_scaled) or np.any(params > ub_scaled):
            return -np.inf
        nll = negative_log_likelihood(params, MS2, timepoints, t_interp, std_errors)
        return -nll  # Convert to log-probability

    # MCMC parameters
    nwalkers = 10
    ndim = len(x0_scaled)
    nsteps = 1000
    initial_pos = res.x + 1e-4 * np.random.randn(nwalkers, ndim)
    # The least-squares optimum often sits on a bound, so roughly half of the
    # jittered walkers would start outside the box with log_prob = -inf.
    # Reflect them back inside (and clip as a last resort for very narrow boxes).
    initial_pos = np.where(initial_pos < lb_scaled, 2 * lb_scaled - initial_pos, initial_pos)
    initial_pos = np.where(initial_pos > ub_scaled, 2 * ub_scaled - initial_pos, initial_pos)
    initial_pos = np.clip(initial_pos, lb_scaled, ub_scaled)

    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=(MS2, timepoints,
                                                                    t_interp, std_errors,
                                                                    scale_factors, lb_scaled, ub_scaled))
    # Run a fixed number of MCMC steps (no convergence / acceptance check is made)
    sampler.run_mcmc(initial_pos, nsteps,
                     progress=False, tune=True)

    # Flatten the chain and discard burn-in steps
    flat_samples = sampler.get_chain(discard=200, thin=15, flat=True)

    # Extract and rescale fit parameters
    basal, t_on, t_dwell, rate = np.median(flat_samples, axis=0) * scale_factors

    # Calculate confidence intervals: rows = parameters, columns = (5th, 95th percentile)
    CI = np.percentile(flat_samples, [5, 95], axis=0).T * scale_factors[:, np.newaxis]

    return basal, t_on, t_dwell, rate, CI

def first_derivative(x, y):
    """
    Compute the first discrete derivative of y with respect to x.

    Interior points use the central difference (y[i+1] - y[i-1]) / (x[i+1] - x[i-1]);
    the first and last points use the forward and backward difference.

    Parameters:
    x (array-like): Independent variable data points.
    y (array-like): Dependent variable data points (same length as x).
    Returns:
    numpy.ndarray: Discrete first derivative of y with respect to x (float array,
        same length as y).
    Raises:
    IndexError: If fewer than two points are given.
    """
    # Work in floating point: with integer `y` the output buffer used to inherit
    # the integer dtype and every derivative was truncated towards zero.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    dydx = np.diff(y) / np.diff(x)

    # Use central differences for the interior points and forward/backward differences for the endpoints
    dydx_central = np.zeros_like(y)
    dydx_central[1:-1] = (y[2:] - y[:-2]) / (x[2:] - x[:-2])
    dydx_central[0] = dydx[0]
    dydx_central[-1] = dydx[-1]

    return dydx_central

def mean_sign_intervals(function):
    """
    Replace each constant-sign run of `function` by the mean over that run.

    Parameters:
    function (numpy.ndarray): Sampled values of the function.
    Returns:
    tuple: (mean_values, change_indices) where mean_values has the same length
        as `function` and holds, at every position, the mean of the run that
        position belongs to, and change_indices are the indices at which a new
        run (a sign change) starts.
    """
    # Index of the first sample AFTER each sign change
    change_indices = np.where(np.diff(np.sign(function)) != 0)[0] + 1

    # Run boundaries: start of data, every sign change, end of data
    bounds = np.concatenate(([0], change_indices, [len(function)]))

    run_means = [
        np.mean(function[start:stop]) for start, stop in zip(bounds[:-1], bounds[1:])
    ]

    # Broadcast each run's mean over the length of that run
    return np.repeat(run_means, np.diff(bounds)), change_indices

# Function to generate fits for all traces
def fit_all_traces(traces, tv_denoised_traces):
    """
    Fit half-cycles to all traces in the dataset.

    For every trace, the derivative of the TV-denoised signal is used to find the
    first sign change; only the data points before it are fitted.

    Parameters:
    traces (list): List of traces to fit. Each trace is indexed as
        [time, intensity, intensity error, particle id].
    tv_denoised_traces (list): List of TV denoised traces, in the same order.
    Returns:
    tuple: (fit_results, dataframe).
        fit_results is a list with one entry per trace:
        [timepoints, t_interp, MS2, fitted curve, [basal, t_on, t_dwell, rate, CI]];
        the last two are None when the fit failed and all five are None when no
        sign change was found.
        dataframe has the columns 'particle' and 'fit_results' (same entries).
    """
    # Initialize the list to hold fit results
    fit_results = []

    # Create new dataframe to store fit results
    dataframe = pd.DataFrame(columns=['particle', 'fit_results'])

    for i, trace in enumerate(traces):
        particle = trace[3]

        # Compute the first derivative of TV denoised with respect to time
        dy_dx = first_derivative(trace[0], tv_denoised_traces[i])

        # Locate the sign changes of the derivative
        _, change_indices = mean_sign_intervals(dy_dx)

        # Keep datapoints from before first sign change
        try:
            first_change = change_indices[0]
            timepoints = trace[0][:first_change]
            MS2 = trace[1][:first_change]
            MS2_std = trace[2][:first_change]

            # Interpolate the timepoints
            t_interp = np.linspace(min(timepoints), max(timepoints), 1000)
        except Exception:
            # `Exception` rather than a bare `except:` so that interrupting the
            # kernel still stops the loop.
            print(f"Failed to find derivative sign change for trace {particle}")
            fit_results.append([None, None, None, None, None])
            dataframe.loc[i] = [particle, [None, None, None, None, None]]
            continue

        # Compute the fit values
        try:
            basal, t_on, t_dwell, rate, CI = fit_half_cycle(MS2, timepoints, t_interp, MS2_std)
            fit_result = [timepoints, t_interp, MS2, make_half_cycle(basal, t_on, t_dwell, rate, t_interp),
                          [basal, t_on, t_dwell, rate, CI]]

            fit_results.append(fit_result)
            dataframe.loc[i] = [particle, fit_result]
        except Exception as err:
            # Best effort: record the failure and its reason, then move on.
            print(f"Failed to fit trace {particle} ({err!r})")
            fit_results.append([timepoints, t_interp, MS2, None, None])
            dataframe.loc[i] = [particle, [timepoints, t_interp, MS2, None, None]]
            continue

    return fit_results, dataframe

### Perform Fitting on Ordered Spots

In [None]:
# Restrict to longer traces
traces_compiled_dataframe = compiled_dataframe[
    compiled_dataframe["frame"].apply(lambda x: x.size) > min_frames
]
# Restrict to traces starting at frame nc14_start_frame and above
traces_compiled_dataframe = traces_compiled_dataframe[
    traces_compiled_dataframe["frame"].apply(lambda x: x[0] >= nc14_start_frame)
]

# Order the traces based on the mean x position
traces_compiled_dataframe = traces_compiled_dataframe.sort_values(
    by="x", key=lambda x: x.apply(np.mean)
)

traces = plottable.generate_trace_plot_list(traces_compiled_dataframe)


# Generate TV denoised traces
tv_denoised_traces = [
    denoise_tv_chambolle(trace[1], weight=1080, max_num_iter=500) for trace in traces
]

# Generate fits for all traces
fit_results, dataframe = fit_all_traces(traces, tv_denoised_traces)

print(f"Number of traces: {len(traces)}")

# Show number of traces with valid fits
print(f"Number of traces with valid fits: {sum([result[4] is not None for result in fit_results])}")

# Show number of traces with invalid fits
print(f"Number of traces with invalid fits: {sum([result[4] is None for result in fit_results])}")

# Attach the denoised traces BEFORE merging: row i of `dataframe` belongs to
# traces[i] by construction, so the pairing is carried by the particle id and
# does not depend on the row order of the merged frame.
traces_compiled_dataframe_fits = pd.merge(
    traces_compiled_dataframe,
    dataframe.assign(tv_denoised_trace=tv_denoised_traces),
    on='particle',
    how='inner',
)

# Approval status: every particle starts as approved (1). A scalar broadcast
# also works for an empty frame, unlike sizing a list from `index.max() + 1`.
traces_compiled_dataframe_fits['approval_status'] = 1

In [None]:
# Preview the merged table (traces + fit results + denoised trace + approval status)
traces_compiled_dataframe_fits.head()

### Checking the Fits

In [14]:
# check_particle_fit: a function that checks the fits

### Josh's Fit checking function, updated on 9.4.2024. Added the function to select a certain particle to look at
def check_particle_fit(binned_particles_fitted, show_denoised_plot=False):
    '''
    Interactively review the half-cycle fit of each particle trace.

    One particle is shown at a time: the raw trace with error bars, optionally the
    TV-denoised trace, and the fitted half cycle when a fit is available. The axes
    background encodes the approval status: green = 1, red = -1, white = 0, yellow = 2.
    This function never initialises the statuses; it only displays and edits them.

    Keyboard controls (the figure must have focus):
        left / right : previous / next particle
        a            : set approval_status to 1 (approve)
        r            : set approval_status to -1 (reject)
        c            : set approval_status to 0
        p            : set approval_status to 2
        j            : open a dialog and jump to the row whose 'particle' value is entered

    ARGUMENTS
        binned_particles_fitted: dataframe with one row per particle. Columns read:
            'particle', 't_s', 'intensity_from_neighborhood',
            'intensity_std_error_from_neighborhood', 'x', 'fit_results',
            'approval_status', and 'tv_denoised_trace' when show_denoised_plot is True.
            The 'approval_status' column is modified IN PLACE.
        show_denoised_plot: if True, overlay the TV-denoised trace.

    NOTES
        Rows are addressed by position for both reading and writing, so the dataframe
        index does not have to run from 0 to n-1.
        The initial frame shown in the title is looked up in the notebook-level
        `compiled_dataframe`, not in the dataframe passed in.
    '''

    n_particles = len(binned_particles_fitted)
    if n_particles == 0:
        print('No particles to check')
        return

    fig, ax = plt.subplots()

    particle_index = 0  # position (not index label) of the particle currently shown

    # Axes background colour for each approval status; other values leave the colour unchanged
    status_colors = {
        1: (0.7, 1, 0.7),   # approve
        -1: (1, 0.7, 0.7),  # reject
        0: (1, 1, 1),
        2: (1, 1, 0.7),
    }
    # Approval status written by each key
    status_by_key = {'a': 1, 'r': -1, 'c': 0, 'p': 2}

    # TODO: optionally start at the first particle whose approval_status is still 0

    def update_plot(particle_index):
        ax.clear()
        particle = None  # stays None if the row itself cannot be read
        try:
            row = binned_particles_fitted.iloc[particle_index]  # select the particle by position
            particle = row['particle']

            x = row['t_s']
            y = row['intensity_from_neighborhood']
            y_err = row['intensity_std_error_from_neighborhood']

            # plot the particle trace with error bar along with the denoised trace
            ax.errorbar(x, y, yerr=y_err, fmt=".", elinewidth=1, label='Data')
            if show_denoised_plot:
                ax.plot(x, row['tv_denoised_trace'], color='k', label='TV denoised')

            # plot the half cycle fit; a missing or failed fit (None entries) is simply not drawn
            try:
                _, t_interp, _, half_cycle_fit, fit_params = row['fit_results']
                rate = fit_params[3]
                ax.plot(t_interp, half_cycle_fit, label=f'Fit (slope = {round(rate,2)})', linewidth=3)
            except (KeyError, TypeError, ValueError, IndexError):
                pass

            mean_x = np.mean(row['x'])
            # First frame of this particle, read from the notebook-level compiled_dataframe
            initial_frame = (
                compiled_dataframe.loc[compiled_dataframe["particle"] == particle, "frame"]
                .values[0][0]
            )

            face_color = status_colors.get(row['approval_status'])
            if face_color is not None:
                ax.set_facecolor(face_color)

            ax.set_title(f'Particle #{particle} ({particle_index+1}/{n_particles}), x = {np.round(mean_x, 2)}, Initial frame {initial_frame}')
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Spot intensity (AU)")
            ax.legend()

        except Exception as e:
            # `particle` may be unknown if the failure happened while reading the row
            if particle is None:
                particle = f'at position {particle_index}'
            print(f"Error processing particle {particle}: {e}")

        fig.canvas.draw()

    def on_key(event):
        nonlocal particle_index
        if event.key == 'left':
            particle_index = max(0, particle_index - 1)
        elif event.key == 'right':
            particle_index = min(len(binned_particles_fitted) - 1, particle_index + 1)
        elif event.key in status_by_key:
            # Write by position so that the row being edited is always the row being shown
            status_position = binned_particles_fitted.columns.get_loc('approval_status')
            binned_particles_fitted.iloc[particle_index, status_position] = status_by_key[event.key]
        elif event.key == 'j':
            # Create a Tkinter root window and hide it
            root = tk.Tk()
            root.withdraw()
            try:
                # Ask for input
                input_index = simpledialog.askinteger("Input", "Enter particle index:", minvalue=0)
            finally:
                # Destroy the Tkinter root window even if the dialog fails
                root.destroy()

            # Update the particle index if the requested particle exists
            if input_index is not None:
                matches = np.flatnonzero(
                    (binned_particles_fitted['particle'] == input_index).to_numpy()
                )
                if matches.size > 0:
                    particle_index = int(matches[0])
                else:
                    print(f"Particle {input_index} not found")
        update_plot(particle_index)

    update_plot(particle_index)
    fig.canvas.mpl_connect('key_press_event', on_key)
    plt.show()

Detect whether the trace checking has been done "previously". If so, load the previous results.

In [12]:
# Pickle file that stores the manually checked trace table for the selected dataset
checked_traces_file_path = f'{test_dataset_name}/traces_compiled_dataframe_fits_checked.pkl'

# True when an earlier checking session already wrote that file
checked_traces_previous = os.path.isfile(checked_traces_file_path)
checked_traces_previous

False

In [15]:
if not checked_traces_previous:
    # First session for this dataset: review the freshly fitted traces
    print('Do trace checking for the dataset')
    check_particle_fit(traces_compiled_dataframe_fits)

else:
    # Load the DataFrame from the .pkl file
    print('Load from previous trace checking results, which are shown below.')
    traces_compiled_dataframe_fits_checked = pd.read_pickle(checked_traces_file_path)
    # Review a copy so that edits can later be compared against the loaded results
    traces_compiled_dataframe_fits_checked_temp = traces_compiled_dataframe_fits_checked.copy()
    check_particle_fit(traces_compiled_dataframe_fits_checked_temp)

Do trace checking for the dataset


NameError: name 'traces_compiled_dataframe_fits' is not defined

In [None]:
# Save the checked traces or update the checked traces if any changes are made

if checked_traces_previous:
    # Compare the statuses edited in this session against the ones loaded from disk
    if all(traces_compiled_dataframe_fits_checked_temp.approval_status == traces_compiled_dataframe_fits_checked.approval_status):
        print('No changes made to the trace checking results')
    else:
        # Show the dialog on a hidden root window and destroy it afterwards; otherwise
        # Tkinter creates an empty default root window that stays open.
        dialog_root = tk.Tk()
        dialog_root.withdraw()
        try:
            answer = messagebox.askyesno('Question', 'Changes to the checked traces detected. Save the changes?')
        finally:
            dialog_root.destroy()

        if answer:
            traces_compiled_dataframe_fits_checked_temp.to_pickle(checked_traces_file_path, compression=None)
            # Keep the in-memory table in sync with what was just written to disk
            traces_compiled_dataframe_fits_checked = traces_compiled_dataframe_fits_checked_temp.copy()
            # NOTE(review): unlike the first-session branch below, rows that are no longer
            # approved (approval_status != 1) are kept in the saved table here. Confirm
            # whether the downstream analysis should exclude them.
            print('Checked traces updated')
        else:
            print('No changes made to the trace checking results')

else:
    # First session: keep only the approved traces (approval_status == 1). reset_index()
    # renumbers the rows from 0 and keeps the previous index as an 'index' column.
    traces_compiled_dataframe_fits_checked = traces_compiled_dataframe_fits[traces_compiled_dataframe_fits['approval_status'] == 1].reset_index()
    traces_compiled_dataframe_fits_checked.to_pickle(checked_traces_file_path, compression=None)
    print('Checked traces saved')

### Sort the traces by which bin they are in

In [None]:
# Sorting traces by the bin they belong to, and calculate the average of fit slope for each bin

bin_width = 1/num_bins

# Assign each trace to a bin based on its mean AP position. Each row uses its own 'ap'
# array, so no per-particle lookup is needed and the index does not have to be 0..n-1.
bin_indices = (
    traces_compiled_dataframe_fits_checked['ap'].apply(np.mean) // bin_width
).to_numpy(dtype=float)

# Calculate the number of traces in each bin
bin_counts = np.zeros(num_bins)
for i in range(num_bins):
    bin_counts[i] = np.sum(bin_indices == i)

print(f'bin_counts = {bin_counts}')
print(f'np.sum(bin_counts) = {np.sum(bin_counts)}')

# Calculate the average fit rates for each bin and store particle IDs with rates in each bin
mean_fit_rates = np.zeros(num_bins)
bin_particles_rates = np.zeros(num_bins, dtype=object)
SE_fit_rates = np.zeros(num_bins)
for i in range(num_bins):
    if bin_counts[i] == 0:
        # Empty bin: marked NaN so that it is dropped before plotting
        mean_fit_rates[i] = np.nan
        continue

    in_bin = bin_indices == i

    # Fitted rate of every trace in the bin, multiplied by 60 (the plots label this AU/min).
    # Traces without a valid fit (parameter entry is None) contribute NaN.
    rates = np.asarray(
        60*traces_compiled_dataframe_fits_checked.loc[in_bin, 'fit_results'].apply(
            lambda x: x[4][3] if x[4] is not None else np.nan).values,
        dtype=float,
    )
    particles = (
        traces_compiled_dataframe_fits_checked.loc[in_bin, 'particle']
        .values
        )

    # Store the particle IDs with their rates in each bin for further analysis
    bin_particles_rates[i] = {
        'bin': i,
        'particles': particles,
        'rates': rates
    }

    n_valid = np.count_nonzero(~np.isnan(rates))
    if n_valid == 0:
        # Every fit in this bin failed: no mean and no error can be reported
        mean_fit_rates[i] = np.nan
        SE_fit_rates[i] = np.nan
        continue

    mean_fit_rates[i] = np.nanmean(rates)

    # Standard error of the mean. Only the valid (non-NaN) rates enter nanmean/nanstd,
    # so the sample size must count those same rates, not the failed fits.
    SE_fit_rates[i] = np.nanstd(rates) / np.sqrt(n_valid)

In [None]:
# Prepare the data for plotting

# Bins that have a mean fit rate (empty bins were marked NaN)
not_nan_1 = ~np.isnan(mean_fit_rates)

# Bin number -> AP position (bin index divided by the number of bins)
bin_indices_1 = np.arange(num_bins)[not_nan_1]
ap_positions_1 = bin_indices_1 / num_bins

# Mean rate and its standard error for the retained bins
bin_slopes_1, bin_slope_errs_1 = mean_fit_rates[not_nan_1], SE_fit_rates[not_nan_1]

# Upper y-limit for the plots: 1.5 times the largest mean rate
max_bin_slope_1 = bin_slopes_1.max()
ylim_up = max_bin_slope_1 * 1.5

In [None]:
# Plot the average slope of trace fits for each bin number

# Scheme 1 result: per-bin mean of the single-trace fit rates against AP position,
# with the standard error of the mean as error bars
plt.figure()
plt.errorbar(ap_positions_1, bin_slopes_1, yerr=bin_slope_errs_1, capsize=2, fmt='o')
plt.xlabel('AP Position')
plt.ylabel('Average rate of trace fits (AU/min)')
plt.title('Average rate of trace fits vs. AP position')
# ylim_up is 1.5 times the largest per-bin mean rate (computed in the previous cell)
plt.ylim(0,ylim_up)
plt.show()

## Scheme 2 (Average & Fit): Taking the Average of All MS2 Traces for Each Bin and then Fit to the Averaged Trace

In [None]:
# bin_average_NC14: a function that bins all the traces and takes the average of the traces for each bin

def _frame_wise_bin_average(bin_particles, intensity_column=None):
    '''
    Frame-wise average of the intensity of all particles in ONE bin.

    Returns a dataframe with columns 'frame', 'intensity' (list of the values found at
    that frame), 'average_intensity' and 'std_err_intensity', or an empty list when the
    bin holds no usable particle.
    '''
    if len(bin_particles) == 0:
        return []

    if intensity_column is None:
        # Original behaviour: the intensity is taken from the fourth column of the dataframe.
        # NOTE(review): position-based; confirm which column this is for your dataframe.
        intensity_values = bin_particles.iloc[:, 3]
    else:
        intensity_values = bin_particles[intensity_column]

    # (frames, intensities) per particle; np.atleast_1d guards against 0-dimensional entries
    particle_traces = []
    for frames, intensities in zip(bin_particles['frame'], intensity_values):
        frames = np.atleast_1d(frames)
        if frames.size > 0:
            particle_traces.append((frames, np.atleast_1d(intensities)))
    if not particle_traces:
        return []

    min_frame = int(min(frames.min() for frames, _ in particle_traces))
    max_frame = int(max(frames.max() for frames, _ in particle_traces))

    # NOTE(review): as in the original implementation the frame axis is
    # range(min_frame, max_frame), i.e. the very last frame of the bin is not included.
    n_frames = max_frame - min_frame

    # Collect, for every frame, the intensity of each particle that has data at that frame
    # (one pass over the data; values are appended in particle order).
    intensity_lists = [[] for _ in range(n_frames)]
    for frames, intensities in particle_traces:
        for frame, value in zip(frames, intensities):
            slot = int(frame) - min_frame
            if 0 <= slot < n_frames:
                intensity_lists[slot].append(value)

    # Frames without any data get NaN for both the mean and the standard error
    average_intensity = [np.mean(values) if values else np.nan for values in intensity_lists]
    std_err_intensity = [
        np.std(values)/np.sqrt(len(values)) if values else np.nan for values in intensity_lists
    ]

    return pd.DataFrame({'frame': np.arange(min_frame, max_frame),
                         'intensity': intensity_lists,
                         'average_intensity': average_intensity,
                         'std_err_intensity': std_err_intensity})


def bin_average_NC14(compiled_dataframe, NC14_start_frame, bin_num=num_bins, shift_to_same_start_frame=True, intensity_column=None):
    '''
    A function that bins all the traces and takes the average of the traces for each bin.

    ARGUMENTS
        compiled_dataframe: dataframe with one row per particle; the columns 'frame' and
            'ap' (one array per particle) are read. It is not modified.
        NC14_start_frame: particles whose first frame is earlier than this are discarded.
        bin_num: number of bins, default value is num_bins. Bin k holds the particles whose
            mean 'ap' value lies in [k/bin_num, (k+1)/bin_num).
        shift_to_same_start_frame: if true, shift all NC14 particles so that they start at the same frame
            (frame 0).
        intensity_column: name of the column holding the intensity arrays. The default
            (None) keeps the original behaviour of using the fourth column.

    OUTPUT
        a list of length bin_num, where each element is a pandas dataframe for one bin containing
        the average intensity of all the traces in that bin (columns 'frame', 'intensity',
        'average_intensity', 'std_err_intensity'). The average is taken frame-wise.
        Bins without particles are represented by an empty list.
    '''

    # Use the requested number of bins (not the notebook-level num_bins)
    bin_width = 1/bin_num

    # Keep only NC14 particles. Work on a copy so the caller's dataframe is never modified.
    starts_in_nc14 = compiled_dataframe['frame'].apply(np.min) >= NC14_start_frame
    compiled_dataframe_NC14 = compiled_dataframe[starts_in_nc14].copy()

    # Assign each trace to a bin based on its mean ap position (row by row, so the
    # dataframe index does not need to run from 0 to n-1)
    bin_indices = (
        compiled_dataframe_NC14['ap'].apply(np.mean) // bin_width
    ).to_numpy(dtype=float)
    compiled_dataframe_NC14['bin'] = bin_indices.astype(int)

    # Shift all NC14 particles to the same start frame
    if shift_to_same_start_frame:
        compiled_dataframe_NC14['frame'] = [
            np.asarray(frames) - np.min(frames) for frames in compiled_dataframe_NC14['frame']
        ]

    # sort particles by bins and average each bin frame by frame
    return [
        _frame_wise_bin_average(
            compiled_dataframe_NC14[compiled_dataframe_NC14['bin'] == bin_index],
            intensity_column,
        )
        for bin_index in range(bin_num)
    ]

In [None]:
# Frame-wise average trace for every bin, using only particles that start at or after
# nc14_start_frame (bin count and frame shifting use the function defaults)
bin_average_intensity_list = bin_average_NC14(compiled_dataframe, NC14_start_frame=nc14_start_frame)

In [None]:
# Dump the per-bin average tables as plain text.
# NOTE(review): this prints every bin in full; consider displaying a single bin instead.
print(bin_average_intensity_list)

### Define fit functions
Before running the cell below, please run the first cell in Scheme 1. 

In [None]:
# fit_average_trace: a function that fits the average trace for a single bin

def fit_average_trace(timepoints, MS2, tv_denoised_traces, MS2_std, bin_index):
    """
    The single-trace version of fit_all_traces.

    Only the part of the trace before the first sign change of the derivative of the
    TV-denoised trace is fitted with a half cycle.

    ARGUMENTS
        timepoints: the 'frame' column from intensity_by_frame multiplied by time_res_min
        MS2: the 'average_intensity' column
        tv_denoised_traces: the 'denoised_average_intensity' column
        MS2_std: the 'std_err_intensity' column
        bin_index: index of the bin, used only in the printed messages (shown as bin_index+1)

    OUTPUT
        [timepoints, t_interp, MS2, half_cycle_fit, [basal, t_on, t_dwell, rate, CI]]
        where timepoints and MS2 are the truncated arrays that were fitted.
        If the fit fails, the last two entries are None. If the trace cannot be truncated
        (no sign change found, or nothing left before it), t_interp is None as well.
    """
    # Compute the first derivative of TV denoised with respect to time
    dy_dx = first_derivative(timepoints, tv_denoised_traces)

    # Compute the mean of the first derivative over intervals with constant sign
    mean_dy_dx, change_indices = mean_sign_intervals(dy_dx)

    # Keep datapoints from before first sign change
    try:
        first_change = change_indices[0]
        timepoints = timepoints[:first_change]
        MS2 = MS2[:first_change]
        MS2_std = MS2_std[:first_change]

        # Interpolate the timepoints
        t_interp = np.linspace(min(timepoints), max(timepoints), 1000)
    except Exception:
        print(f"Failed to find derivative sign change for average trace {bin_index+1}")
        # Without a fitting window there is nothing to fit. Return a well-formed
        # "failed fit" result instead of continuing with an undefined t_interp.
        return [timepoints, None, MS2, None, None]

    # Compute the fit values
    try:
        basal, t_on, t_dwell, rate, CI = fit_half_cycle(MS2, timepoints, t_interp, MS2_std)
        half_cycle_fit = make_half_cycle(basal, t_on, t_dwell, rate, t_interp)
    except Exception as e:
        print(f"Failed to fit average trace {bin_index+1}: {e}")
        return [timepoints, t_interp, MS2, None, None]

    return [timepoints, t_interp, MS2, half_cycle_fit, [basal, t_on, t_dwell, rate, CI]]

In [None]:
# fit_bin_averages: a function that fits to the average intensity plot for each bin and outputs the slopes of the fits

def fit_bin_averages(bin_average_intensity_list, time_res_min=time_res_min):
    """
    Fit the averaged trace of every bin.

    ARGUMENTS
        bin_average_intensity_list: one entry per bin. Entries that are dataframes must have
            the columns 'frame', 'average_intensity' and 'std_err_intensity'. Any other entry
            (e.g. the empty list used for bins without particles) is skipped.
        time_res_min: factor that converts frame numbers to time.

    OUTPUT
        fit_results: list with one entry per bin, either the result of fit_average_trace or
            None when the bin was skipped.
        fit_slopes: numpy array with the fitted rate of each bin, NaN where no fit is available.
    """

    bin_num = len(bin_average_intensity_list)

    fit_results = [None]*bin_num
    fit_slopes = [np.nan]*bin_num

    for bin_index, bin_average_intensity in enumerate(bin_average_intensity_list):
        # Bins without data are not stored as dataframes: nothing to fit
        if not isinstance(bin_average_intensity, pd.DataFrame):
            continue

        try:
            bin_average_intensity_denoised = denoise_bin_average_intensity(bin_average_intensity)

            # No valid datapoints, or the denoising step could not add its column
            if (len(bin_average_intensity_denoised) == 0
                    or 'denoised_average_intensity' not in bin_average_intensity_denoised.columns):
                continue

            # Prepare the lists/arrays needed for the fit.
            # .iloc[0] takes the first remaining row whatever its index label is
            # (rows with a NaN average were dropped by the denoising step).
            first_frame = bin_average_intensity_denoised['frame'].iloc[0]
            x = (bin_average_intensity_denoised['frame'].values - first_frame) * time_res_min
            y = bin_average_intensity_denoised['average_intensity'].to_numpy(dtype=float)
            y_err = bin_average_intensity_denoised['std_err_intensity'].to_numpy(dtype=float)
            y_denoised = bin_average_intensity_denoised['denoised_average_intensity'].to_numpy(dtype=float)

            # Generate the fit
            fit_result = fit_average_trace(x, y, y_denoised, y_err, bin_index)
        except Exception as exc:
            # Keep going with the other bins, but do not hide the reason
            print(f"Skipping bin {bin_index}: {exc}")
            continue

        # Store the fit result and the fit slope
        fit_results[bin_index] = fit_result

        # The parameter entry is None when the fit failed; the slope then stays NaN
        fit_params = fit_result[4]
        if fit_params is not None:
            fit_slopes[bin_index] = fit_params[3]

    return fit_results, np.array(fit_slopes)


# A function that denoises the dataframes in bin_average_intensity_list, used in the function fit_bin_averages

def denoise_bin_average_intensity(bin_average_intensity, weight=1080, max_num_iter=500):
    """
    Return a copy of one bin's average-intensity table with a TV-denoised trace added.

    Rows whose 'average_intensity' is NaN are dropped, then the column
    'denoised_average_intensity' is added. The input is not modified.

    ARGUMENTS
        bin_average_intensity: dataframe with an 'average_intensity' column. Anything that is
            not a dataframe (e.g. the empty list used for bins without particles) is returned
            as a copy, unchanged.
        weight, max_num_iter: passed on to denoise_tv_chambolle.

    If the denoising step itself fails, a message is printed and the table is returned
    without the 'denoised_average_intensity' column.
    """

    bin_average_intensity_denoised = bin_average_intensity.copy()

    # Placeholder of a bin without data: nothing to denoise
    if not isinstance(bin_average_intensity_denoised, pd.DataFrame):
        return bin_average_intensity_denoised

    try:
        bin_average_intensity_denoised = bin_average_intensity_denoised.dropna(subset=['average_intensity']) # remove all nan

        before_denoise = bin_average_intensity_denoised['average_intensity'].to_numpy(dtype=float)

        if before_denoise.size == 0:
            # No valid datapoint left: add an empty column without calling the denoiser
            after_denoise = np.array([], dtype=float)
        else:
            # denoise
            after_denoise = denoise_tv_chambolle(before_denoise, weight=weight, max_num_iter=max_num_iter)

        bin_average_intensity_denoised['denoised_average_intensity'] = after_denoise

    except Exception as exc:
        # Best effort, as before, but no longer silent
        print(f"TV denoising of the bin-averaged trace failed: {exc}")

    return bin_average_intensity_denoised

In [None]:
# Fit the averaged trace of every bin (Scheme 2).
# NOTE(review): this rebinds `fit_results`, which the Scheme 1 fitting cell also assigns;
# from here on it holds the per-bin fits, not the per-trace fits.
fit_results, fit_slopes = fit_bin_averages(bin_average_intensity_list)

In [None]:
# Manual inspection of a single bin (hard-coded index 6): rebuild the same arrays that
# fit_bin_averages passes to the fit, then scatter-plot the averaged trace.
bin_average_intensity = bin_average_intensity_list[6]

bin_average_intensity_denoised = denoise_bin_average_intensity(bin_average_intensity)

# Prepare the lists/arrays needed for the fit
# NOTE(review): ['frame'][0] looks up index label 0, which only exists if the first row
# survived the NaN filtering done by denoise_bin_average_intensity.
first_frame = bin_average_intensity_denoised['frame'][0]
# Frame numbers relative to the first frame, scaled by time_res_min
x = (bin_average_intensity_denoised['frame'].values - first_frame) * time_res_min
y = bin_average_intensity_denoised['average_intensity'].values
y_err = bin_average_intensity_denoised['std_err_intensity'].values
y_denoised = bin_average_intensity_denoised['denoised_average_intensity'].values

plt.figure()
plt.scatter(x, y)

In [None]:
# Manually chosen window [head, cut] on the x axis (same units as x) used to select
# the datapoints for a trial fit of the bin inspected above
cut = 2
head = 0

mask = (x >= head) & (x <= cut)

new_x = x[mask]
new_y = y[mask]
new_yerr = y_err[mask]
# Overlay the selected datapoints on the full averaged trace
plt.figure()
plt.scatter(x, y)
plt.scatter(new_x, new_y)

In [None]:
# Dense time axis (1000 points) spanning the selected window
x_interp = np.linspace(min(new_x), max(new_x), 1000)

# Trial half-cycle fit restricted to the selected window; the returned values are
# displayed as the cell output and not stored
fit_half_cycle(new_y, new_x, x_interp, new_yerr)

In [None]:
# Re-fit the bin inspected in the cells above and store the result.
# x, y, y_denoised and y_err were built from bin_average_intensity_list[6], so the fit
# belongs to bin 6. The bin index is stated explicitly here: `bin` is not defined in
# this notebook section (it is the Python builtin unless a loop in another cell left a
# value behind), so using it as the index would store the fit under the wrong bin or fail.
refit_bin_index = 6  # keep in sync with the bin selected for inspection above

# Generate the fit
fit_result = fit_average_trace(x, y, y_denoised, y_err, refit_bin_index)

# Store the fit result
fit_results[refit_bin_index] = fit_result

### Check the fit to each bin

In [None]:
# check_bin_fit: a function that checks the fit for each bin

def check_bin_fit(bin_average_intensity_list, fit_results, time_res_min=time_res_min, show_denoised_plot=False, show_fit=True):
    '''
    Interactively browse the averaged trace of each bin together with its fit.

    Use the left / right arrow keys (with the figure in focus) to move between bins.
    Bins without data are shown as an empty plot titled "(no data)".

    Note that the bin index shown in the title and the total bin number start from 0,
    i.e. if it's 2/10 it means it's actually the third bin out of 11 bins.

    ARGUMENTS
        bin_average_intensity_list: one entry per bin; dataframes with the columns 'frame',
            'average_intensity' and 'std_err_intensity', anything else is treated as "no data".
        fit_results: one entry per bin, in the layout returned by fit_average_trace
            (or None when there is no fit).
        time_res_min: factor that converts frame numbers to minutes.
        show_denoised_plot: if True, overlay the TV-denoised averaged trace.
        show_fit: if True, overlay the fit when one is available.
    '''

    fig, ax = plt.subplots()

    bin_index = 0
    bin_num = len(bin_average_intensity_list)

    # TODO: add an approval status per bin (approve / reject keys), as in check_particle_fit

    def update_plot(bin_index):
        ax.clear()
        try:
            bin_average_intensity = bin_average_intensity_list[bin_index] # select the bin

            bin_average_intensity_denoised = denoise_bin_average_intensity(bin_average_intensity)

            has_data = (
                isinstance(bin_average_intensity_denoised, pd.DataFrame)
                and len(bin_average_intensity_denoised) > 0
            )

            if has_data:
                # Time relative to the first frame of the bin: the same time axis that
                # fit_bin_averages uses for the fit, so data and fit always line up
                frames = bin_average_intensity_denoised['frame'].values
                x = (frames - frames[0]) * time_res_min
                y = bin_average_intensity_denoised['average_intensity'].values
                y_err = bin_average_intensity_denoised['std_err_intensity'].values

                # plot the averaged trace with error bar along with the denoised trace
                ax.errorbar(x, y, yerr=y_err, fmt=".", elinewidth=1, label='Data')

                if show_denoised_plot:
                    y_denoised = bin_average_intensity_denoised['denoised_average_intensity'].values
                    ax.plot(x, y_denoised, color='k', label='TV denoised')

                # plot the half cycle fit; a missing or failed fit (None entries) is not drawn
                if show_fit:
                    try:
                        _, t_interp, _, half_cycle_fit, fit_params = fit_results[bin_index]
                        ax.plot(t_interp, half_cycle_fit, label=f'Fit (slope = {round(fit_params[3],2)})', linewidth=3)
                    except (TypeError, ValueError, IndexError):
                        pass

            # Zero-based bin index over the highest bin index, as described in the docstring
            title = f'Bin #{bin_index}/{bin_num-1}'
            if not has_data:
                title += ' (no data)'
            ax.set_title(title)
            ax.set_xlabel("Time (min)")
            ax.set_xlim(0, 14)
            ax.set_ylabel("Spot intensity (AU)")
            if has_data:
                ax.legend()

        except Exception as e:
            print(f"Error processing bin {bin_index}: {e}")

        fig.canvas.draw()

    def on_key(event):
        nonlocal bin_index
        if event.key == 'left':
            bin_index = max(0, bin_index - 1)
        elif event.key == 'right':
            bin_index = min(len(bin_average_intensity_list) - 1, bin_index + 1)
        update_plot(bin_index)

    update_plot(bin_index)
    fig.canvas.mpl_connect('key_press_event', on_key)
    plt.show()

In [None]:
# Browse the per-bin averaged traces with their fits (left / right arrow keys change the bin)
check_bin_fit(bin_average_intensity_list, fit_results, show_denoised_plot=False, show_fit=True)

### Plot slope vs. AP position

In [None]:
# Prepare the data to be plotted

# Bins for which a slope was fitted (bins without a fit hold NaN in fit_slopes)
not_nan_2 = ~np.isnan(fit_slopes)

bin_indices_2 = np.arange(num_bins)[not_nan_2]
bin_slopes_2 = fit_slopes[not_nan_2]

# extract the widths of confidence intervals for the slopes, stored in bin_slope_errs
# fit_results[bin][-1] is the parameter list, its last entry is CI, and the last entry
# of CI is stored as the two bounds for the slope.
# NOTE(review): the layout of CI is not visible here -- confirm that CI[-1] is the
# (lower, upper) pair of the rate.
# NOTE(review): the loop variable `bin` shadows the Python builtin and stays defined at
# notebook level after this cell has run.
bin_slope_errs_2 = np.zeros((num_bins, 2))
for bin in range(num_bins):
    try:
        bin_slope_err = fit_results[bin][-1][-1][-1]
        bin_slope_errs_2[bin] = bin_slope_err
    except:
        # NOTE(review): bare except -- bins without a fit are skipped silently, but so is
        # any other error; such rows keep the value [0, 0].
        pass

# Convert the interval bounds into distances from the fitted slope and transpose the
# result to shape (2, number of fitted bins) for use as yerr
bin_slope_errs_2 = np.transpose(np.abs(bin_slope_errs_2[not_nan_2] - bin_slopes_2[:, np.newaxis]))

In [None]:
# Plot the bins with data

# Adjust the range of the bins if necessary. Default is start to end
# (start and end are positions within the list of fitted bins, not bin numbers)
start = 0
end = len(bin_indices_2)

# Bin number -> AP position (bin index divided by the number of bins)
ap_positions_2 = bin_indices_2 * 1/num_bins

plt.figure()
plt.errorbar(ap_positions_2[start:end], bin_slopes_2[start:end], yerr=bin_slope_errs_2[:,start:end], capsize=2, fmt='o')
plt.xlabel('AP position')
plt.ylabel('Fit rate of average trace (AU/min)')
plt.title('Fit rate of average trace vs. AP position (with shifting)')
# NOTE(review): ylim_up is computed in the Scheme 1 section, so that section must have
# been run before this cell.
plt.ylim(0, ylim_up)
plt.show()

## Overlaying the results from the two schemes

In [None]:
# Overlay of the rate vs. AP position curves obtained with the two schemes.
# Requires the plotting variables prepared in both the Scheme 1 and Scheme 2 sections.
plt.figure()
plt.xlabel('AP Position')
plt.ylabel('Rate (AU/min)')
plt.title('Rate vs. AP position')
# NOTE(review): fixed upper limit of 120, unlike the data-driven ylim_up used in the other plots
plt.ylim(0, 120)

# Scheme 1: Fit & Average
plt.errorbar(ap_positions_1, bin_slopes_1, yerr=bin_slope_errs_1, capsize=2, fmt='o-', label='fit & average')

# Scheme 2: Average & Fit
plt.errorbar(ap_positions_2[start:end], bin_slopes_2[start:end], yerr=bin_slope_errs_2[:,start:end], capsize=2, fmt='o-', label='average & fit')

plt.legend()
plt.show()