In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.signal as spsig

from pasna_analysis import Experiment, utils

# Name of the experiment folder, looked up under ../data relative to the working directory.
experiment_name = '20240619_23'
# Embryo ids (based on csv file name) to leave out of the experiment.
to_exclude = []

exp_path = Path.cwd().parent / 'data' / experiment_name
exp = Experiment(
    exp_path,
    first_peak_threshold=0,
    to_exclude=to_exclude,
    dff_strategy='local_minima',
)

In [None]:
%matplotlib ipympl
import json

import ipywidgets as widgets
from IPython.display import display

from pasna_analysis import debounce
from pasna_analysis.interactive_find_peaks import get_initial_values, save_detection_params, local_peak_at, save_add_peak, save_remove_peak


# JSON file inside the experiment folder that stores the peak detection parameters.
pd_params_path = exp_path / 'peak_detection_params.json'

def rerun(btn):
    """Re-run peak detection on every embryo with the current slider values and save them.

    Used as the click handler of the update button; `btn` is the argument
    passed by the button and is not used.
    """
    params = dict(
        mpd=mpd_slider.value,
        order0_min=order_zero_slider.value,
        order1_min=order_one_slider.value,
        prominence=prominence_slider.value,
    )
    # apply the same parameters to all embryos of the experiment
    for embryo in exp.embryos.values():
        embryo.trace.detect_peaks(**params)

    # write the parameters to the params file
    save_detection_params(pd_params_path, **params)

@debounce(0.1)  # presumably rate-limits calls while a slider is dragged; confirm in pasna_analysis
def update(changes):
    """Preview peak detection on the displayed trace using the current slider values.

    Registered as an observer on the sliders. `changes` is the notification
    passed by the widget and is not used; values are read from the sliders.
    Only the red peak markers of the figure are updated.
    """
    order_zero_min = order_zero_slider.value
    order_one_min = order_one_slider.value
    mpd = mpd_slider.value
    prominence_value = prominence_slider.value
    try:
        new_times, new_idxes = trace.detect_peaks(mpd=mpd, order0_min=order_zero_min,
                                                  order1_min=order_one_min,
                                                  prominence=prominence_value)
    except ValueError:
        print('Could not find any peaks with the selected parameters')
        return
    print(new_times/60, new_idxes)
    # Read the amplitudes from the same trace the peaks were detected on. The
    # global `emb` is rebound by loops in later cells, so `emb.trace` may no
    # longer be the trace shown in this figure.
    new_peaks = trace.dff[new_idxes]
    peak_plot.set_data(new_times/60, new_peaks)
    fig.canvas.draw_idle()


# Pick the embryo shown in the interactive figure (index into the experiment's embryos).
embryos = list(exp.embryos.values())
i = 1
emb = embryos[i]
trace = emb.trace

# x values are divided by 60 before plotting
time = emb.activity[:, 0] / 60
peak_times = trace.peak_times / 60

fig, ax = plt.subplots(figsize=(10, 4))
dff_plot = ax.plot(time, trace.dff)[0]

# Fit the y axis to the samples before trim_idx, with a small margin on both sides.
before_hatching = trace.dff[:trace.trim_idx]
min_y = np.min(before_hatching) - 0.01
max_y = np.max(before_hatching) + 0.01
ax.set_ylim(bottom=min_y, top=max_y)

# Red dots mark the currently detected peaks; the callbacks update this line in place.
peak_plot = ax.plot(peak_times, trace.peak_amplitudes, 'r.')[0]
print(peak_times)
print(trace.peak_amplitudes)

ax.set_title(emb.name)
fig.canvas.header_visible = False

def on_click(e):
    """Manually add or remove a peak with a shift-click on the figure.

    Shift + left click adds a peak near the clicked x position.
    Shift + right click removes every peak within `wlen` samples of it.
    The change is recorded in the JSON params file and the peak markers are redrawn.

    Ignored events (nothing is read or written):
      * shift is not held,
      * the click has no data coordinate (outside the axes),
      * a button other than left/right is used (a hint is printed).
    """
    if e.key != 'shift':
        return
    # xdata is None when the click lands outside the axes; int(None * 10) would raise.
    if e.xdata is None:
        return
    if e.button not in (1, 3):
        print('Right click to remove a peak, left click to add a peak.')
        return

    with open(pd_params_path, 'r') as f:
        config = json.load(f)
    config.setdefault('embryos', {})

    # Convert the clicked x position into a sample index.
    # NOTE(review): the factor 10 assumes 10 samples per plotted x unit — confirm
    # against the acquisition rate.
    x = int(e.xdata*10)
    wlen = 10

    if e.button == 1:
        # NOTE(review): for x < wlen the slice start is negative and wraps around
        # the end of the array; confirm how local_peak_at expects the window.
        window = slice(x - wlen, x + wlen)
        peak = local_peak_at(x, trace.dff[window], wlen)
        new_arr = np.append(trace.peak_idxes, peak)
        new_arr.sort()
        trace.peak_idxes = new_arr
        save_add_peak(emb.name, config, peak, wlen)
    else:
        # right click: drop every peak index inside [x - wlen, x + wlen]
        target = (trace.peak_idxes >= x - wlen) & (trace.peak_idxes <= x + wlen)
        removed = trace.peak_idxes[target].tolist()
        trace.peak_idxes = trace.peak_idxes[~target]
        save_remove_peak(emb.name, config, removed, x, wlen)

    with open(pd_params_path, 'w') as f:
        json.dump(config, f, indent=4)

    # redraw the peak markers from the edited indices
    new_times = time[trace.peak_idxes]
    new_peaks = trace.dff[trace.peak_idxes]
    peak_plot.set_data(new_times, new_peaks)
    fig.canvas.draw_idle()

# Route mouse clicks on the figure to the manual add/remove handler.
fig.canvas.mpl_connect('button_press_event', on_click)

# Initial slider positions come from the params file. The unpacking relies on the
# order of the returned mapping: order zero, order one, mpd, prominence.
order_zero_value, order_one_value, mpd_value, prominence_value = get_initial_values(pd_params_path).values()

mpd_slider = widgets.IntSlider(value=mpd_value, min=40, max=600, step=1)
mpd_box = widgets.HBox([widgets.Label('Minimum peak distance'), mpd_slider])

# Fixed: this slider was initialised from order_one_value instead of order_zero_value.
order_zero_slider = widgets.FloatSlider(value=order_zero_value, min=0, max=0.5, step=0.01)
order_zero_box = widgets.HBox([widgets.Label("Minimum height order zero"), order_zero_slider])

order_one_slider = widgets.FloatSlider(value=order_one_value, min=0, max=0.1, step=0.001,
                                       readout_format='.3f')
order_one_box = widgets.HBox([widgets.Label("Minimum height order one"), order_one_slider])

prominence_slider = widgets.FloatSlider(value=prominence_value, min=0, max=2, step=0.02)
prominence_box = widgets.HBox([widgets.Label("Prominence"), prominence_slider])

update_btn = widgets.Button(description='Update peak detection parameters', icon='repeat',
                            layout=widgets.Layout(width='300px'), button_style='primary')
update_btn.on_click(rerun)

# Observe only the `value` trait; without `names` the callback runs for every
# trait change of the widget, not just when the slider value changes.
for slider in (order_zero_slider, order_one_slider, mpd_slider, prominence_slider):
    slider.observe(update, names='value')

display(update_btn, mpd_box, order_zero_box, order_one_box, prominence_box)

In [None]:
"""Plots activity signals for a group of embryos."""

# change start and end to select different groups
start = 0
print(exp.embryos)
embryos = list(exp.embryos.values())
end = len(embryos)

# squeeze=False keeps `axes` two-dimensional, so every row unpacks into
# (left, right) even when only one embryo is selected.
fig, axes = plt.subplots(
    end - start, 2, figsize=(20, 2 * (end - start)), squeeze=False
)

for (left, right), emb in zip(axes, embryos[start:end]):
    # plot data: active channel on the left, structural channel on the right
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    left.plot(time, trace.active, color="green")
    right.plot(time, trace.struct, color="firebrick")

    # set title/labels
    left.set_title(f"{emb.name} - GCamP")
    left.set_ylabel("ΔF/F")
    left.set_xlabel("time (mins)")
    right.set_title(f"{emb.name} - tdTomato")
    right.set_ylabel("ΔF/F")
    right.set_xlabel("time (mins)")

    # set y axis bounds (comment out if you want it to autoscale)
    left.set_ylim([200, 600])
    right.set_ylim([400, 900])

    # set tick marks
    x_points = utils.time_scale_list(max(time))
    left.set_xticks(x_points)
    right.set_xticks(x_points)

    # mark identified peaks (left panel only)
    for peak in trace.peak_times:
        peak_in_mins = peak / 60
        left.axvline(peak_in_mins, color="k", alpha=0.3)

plt.tight_layout()

In [None]:
"""Plots dff for a group of embryos."""

# change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

# squeeze=False always returns an array of Axes; with a single embryo the default
# returns a bare Axes object, which has no .flatten().
fig, axes = plt.subplots(end - start, figsize=(14, 2 * (end - start)), squeeze=False)
axes = axes.flatten()
fontsize = 15
for ax, emb in zip(axes, embryos[start:end]):
    # plot data
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    ax.plot(time, trace.dff, linewidth=3, color="k")

    # set title/labels
    ax.set_title(f"{exp.name}{emb.name} - dff", fontsize=fontsize)
    ax.set_ylabel("ΔF/F", fontsize=fontsize)
    ax.set_xlabel("time (mins)", fontsize=fontsize)

    # x axis - trim to before onset & at hatching
    # `onset` is a separate name so the group selector `start` is not overwritten.
    onset = round(time[trace.peak_bounds_indices[0][0]])
    end_time = trace.time[trace.trim_idx] / 60
    ax.set_xlim(onset - 20, end_time)  # start plotting 20 mins before onset

    # x axis - adjust tick labels to make the onset time 0
    hours = np.arange(onset, end_time, 60)
    ax.set_xticks(hours)
    labels = np.arange(0, end_time - onset, 60)
    ax.set_xticklabels(labels)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    # y axis - trim & tick labels
    ax.set_ylim(-0.1, 1)
    ax.set_yticks([0, 0.8])

    # mark identified peaks
    for peak in trace.peak_times:
        peak_in_mins = peak / 60
        ax.axvline(peak_in_mins, color="k", alpha=0.3)

plt.tight_layout(pad=3)

In [None]:
"""Plots activity signals for a group of embryos. Color code blocks of time"""

# Change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

# One colour per hour block; hours beyond the end of this list get no block.
colors = [
    "yellow",
    "orange",
    "plum",
    "royalblue",
    "lightpink",
    "lime",
    "peachpuff",
    "aqua",
    "tomato",
]

# squeeze=False always returns an array of Axes; with a single embryo the default
# returns a bare Axes object, which cannot be iterated.
fig, axes = plt.subplots(end - start, figsize=(28, 2.5 * (end - start)), squeeze=False)

for ax, emb in zip(axes.ravel(), embryos[start:end]):
    # plot data
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    ax.plot(time, trace.active, color="black")

    # set title/labels
    ax.set_title(f"{exp.name} {emb.name} - GCamP")
    ax.set_ylabel("ΔF/F")
    ax.set_xlabel("time (mins)")

    # set tick marks
    x_points = utils.time_scale_list(max(time))
    ax.set_xticks(x_points)

    # mark identified bursts
    for peak in trace.peak_times:
        peak_in_mins = peak / 60
        ax.axvline(peak_in_mins, color="orchid", alpha=0.5, linewidth=3)

    # add 5 min lines after the start of activity
    # `onset` is a separate name so the group selector `start` is not overwritten.
    onset = round(time[trace.peak_bounds_indices[0][0]])
    five_mins = np.arange(onset, max(time), 5)
    for line in five_mins:
        ax.axvline(line, linewidth=3, color="green", alpha=0.3)

    # add hour blocks: a very wide, translucent line centred on each hour
    hours = np.arange(onset, max(time), 60)
    for hour, c in zip(hours, colors):
        ax.axvline(
            hour + 30, color=c, linewidth=230, alpha=0.2
        )  # adjust line width to control color block size

plt.tight_layout()

In [None]:
"""Plots activity signals for a group of embryos."""

# Change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

# squeeze=False always returns an array of Axes; with a single embryo the default
# returns a bare Axes object, which cannot be iterated.
fig, axes = plt.subplots(end - start, figsize=(28, 2.5 * (end - start)), squeeze=False)

for ax, emb in zip(axes.ravel(), embryos[start:end]):
    # plot data
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    ax.plot(time, trace.active, color="black", linewidth=2)

    # set title & labels
    ax.set_title(f"{exp.name} {emb.name} - GCamP")
    ax.set_ylabel("ΔF/F")
    ax.set_xlabel("time (mins)")

    # set tick marks
    x_points = utils.time_scale_list(max(time))
    ax.set_xticks(x_points)

    # trim x and y axis
    ax.set_xlim([0, 500])
    ax.set_ylim([200, 600])

    # mark identified bursts
    for peak in trace.peak_times:
        peak_in_mins = peak / 60
        ax.axvline(peak_in_mins, color="orchid", alpha=0.5, linewidth=3)

plt.tight_layout()

In [None]:
"""Generate single dff trace"""

# select the embryo and pull out its trace
embryos = list(exp.embryos.values())
i = 16  # choose emb
emb = embryos[i]
trace = emb.trace
time = emb.activity[:, 0] / 60

# draw the dff trace
fig, ax = plt.subplots()
ax.plot(time, trace.dff, color="black")
ax.set(
    title=f"{exp.name} - {emb.name}",
    ylabel="ΔF/F",
    xlabel="time (mins)",
)

# x axis - runs from 20 mins before the first peak bound up to the trim index
start = round(time[trace.peak_bounds_indices[0][0]])
end_time = trace.time[trace.trim_idx] / 60
ax.set_xlim(left=start - 20, right=end_time)

In [None]:
"""Developmental times at first peak."""

data = {"dev_fp": [], "emb": []}

for emb in exp.embryos.values():
    peak_times = emb.trace.peak_times
    # An embryo without detected peaks has no first peak; peak_times[0] would raise IndexError.
    if len(peak_times) == 0:
        continue
    time_first_peak = peak_times[0]
    dev_time_first_peak = emb.get_DT_from_time(time_first_peak)
    data["dev_fp"].append(dev_time_first_peak)
    data["emb"].append(emb.get_id())

df = pd.DataFrame(data)

# Set the theme before the figure is created so it applies to this figure on the
# first run too, not only on re-runs.
sns.set_theme(style="darkgrid")
fig, ax = plt.subplots()
# Draw on the axes created above instead of relying on the implicit current axes.
sns.swarmplot(data=df, x="emb", y="dev_fp", ax=ax)
ax.set_title("Developmental time at first peak")
ax.set_ylabel("Developmental time")
ax.set_xlabel("Experiment Group")
plt.tight_layout()
plt.show()

In [None]:
"""Peak amplitudes for each episode."""

num_of_peaks = 13
data = {"peak_amp": [], "emb": [], "peak_idx": []}

# collect at most the first `num_of_peaks` amplitudes of every embryo
for emb in exp.embryos.values():
    first_amps = [amp for _, amp in zip(range(num_of_peaks), emb.trace.peak_amplitudes)]
    data["peak_amp"].extend(first_amps)
    data["emb"].extend([emb.name] * len(first_amps))
    data["peak_idx"].extend(range(len(first_amps)))

amps = pd.DataFrame(data)

fig, ax = plt.subplots()
sns.set_theme(style="darkgrid")
# one point per peak, coloured by embryo; the legend is moved outside the axes
ax = sns.scatterplot(data=data, x="peak_idx", y="peak_amp", hue="emb")
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
ax.set_title("Peak amplitudes")
ax.set(xlabel="Peak number", ylabel="\u0394F/F")

In [None]:
"""Plots AUC, grouped by bins."""

data = {"auc": [], "bin": [], "emb": []}

n_bins = 5
first_bin = 2
bin_width = 0.2
# Lower edge of each bin. Built once, before the loop: the values do not depend on
# the embryo, and they are needed for the labels even if the loop never runs.
bins = [first_bin + j * bin_width for j in range(n_bins)]

for i, emb in enumerate(exp.embryos.values()):
    trace = emb.trace

    dev_time_at_peaks = emb.get_DT_from_time(trace.peak_times)

    bin_idxs = utils.split_in_bins(dev_time_at_peaks, bins)
    # skip embryos for which split_in_bins does not return an array
    if not isinstance(bin_idxs, np.ndarray):
        continue
    data["auc"].extend(trace.peak_aucs)
    data["bin"].extend(bin_idxs)
    data["emb"].extend([str(i)] * len(trace.peak_aucs))

# the three columns must have the same length for the plot below
print(len(data["auc"]), len(data["bin"]), len(data["emb"]))

fig, ax = plt.subplots()
# add a last bin point to generate the labels
edges = bins + [first_bin + bin_width * n_bins]
# :g drops floating point noise such as 2.4000000000000004 from the labels
x_labels = [f"{s:g}~{e:g}" for (s, e) in zip(edges[:-1], edges[1:])]
sns.pointplot(data=data, x="bin", y="auc", linestyle="None", ax=ax)
# NOTE(review): the ticks assume every bin 0..n_bins-1 occurs in the data — confirm
ax.set_xticks(ticks=list(range(n_bins)), labels=x_labels)
ax.set_title(f"Binned AUC - Exp {exp.name}")
ax.set_ylabel("AUC [activity*t]")

In [None]:
"""Plot peak times for all embryos. This should help to isolate the 
early peaks."""

# Peaks at or after this time are left out (same unit as trace.peak_times).
UPPER_LIMIT = 15000
times = []
for emb in exp.embryos.values():
    trace = emb.trace
    # use the constant so the filter and the x ticks below cannot drift apart
    times.append([t / 60 for t in trace.peak_times if t < UPPER_LIMIT])

fig, ax = plt.subplots()
# one row per embryo; loop names chosen so the global `time` array is not overwritten
for row, emb_times in enumerate(times):
    ax.plot(emb_times, [row] * len(emb_times), marker=".", linestyle="dashed", linewidth=0.5)

# reference line at x=30
ax.vlines(30, 0, len(exp.embryos), color="k", linewidth=0.3)
ax.set_title(f"Peak times for exp {exp.name}")
x_points = utils.time_scale_list(UPPER_LIMIT / 60)
ax.set_xticks(x_points)
ax.set_xlabel("time (mins)")
ax.set_ylabel("emb")
ax.set_yticks([])

In [None]:
"""Compare first interspike interval and average interspike interval"""

avg_ISIs = []
first_ISIs = []
for emb in exp.embryos.values():
    peak_times = emb.trace.peak_times
    # At least 3 peaks are needed: the first interval uses peaks 0 and 1, and the
    # average is taken over the intervals after it. With 2 peaks that average is
    # the mean of an empty array (NaN), which breaks int(max_val) below.
    if len(peak_times) < 3:
        continue
    avg_ISIs.append(np.average(np.diff(peak_times[1:])))
    first_ISIs.append(peak_times[1] - peak_times[0])

if not avg_ISIs:
    # np.max raises on empty input, so report instead of plotting
    print("No embryo has at least 3 peaks; nothing to plot.")
else:
    max_val = max(np.max(avg_ISIs), np.max(first_ISIs))
    diag = list(range(0, int(max_val), 50))

    fig, ax = plt.subplots()
    ax.plot(avg_ISIs, first_ISIs, "k.")
    # reference lines
    ax.plot(diag, diag, label="y=x")
    ax.plot(diag, [2 * d for d in diag], label="y=2*x")
    ax.set_xlabel("average ISI")
    ax.set_ylabel("first ISI")
    ax.set_xlim(0, max_val)
    ax.set_ylim(0, max_val)
    ax.set_title(f"Differences between first and avg intervals - exp {exp.name}")
    ax.legend()

In [None]:
"""Plot PSD spectrum for an embryo"""

# change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

# Sampling frequency passed to welch.
# NOTE(review): presumably one sample every 6 s — confirm against the acquisition rate.
fs = 1 / 6

# squeeze=False always returns an array of Axes; with a single embryo the default
# returns a bare Axes object, which has no .flatten().
fig, axes = plt.subplots(end - start, figsize=(35, 3 * (end - start)), squeeze=False)
axes = axes.flatten()
for ax, emb in zip(axes, embryos[start:end]):
    # Use this embryo's own trace. The global `trace` left over from an earlier
    # cell made every subplot show the same spectrum.
    dff = emb.trace.dff

    # drop NaN/inf samples before computing the spectrum
    finite = np.isfinite(dff)
    finite_dff = dff[finite]

    ax.set_title(f"{exp.name} {emb.name} - GCamP")
    freqs, amps = spsig.welch(
        finite_dff, fs=fs, window="hamming", nperseg=1024, scaling="spectrum"
    )
    ax.plot(freqs, amps)
    ax.set(xlim=(0, 0.015), xlabel="Frequency [Hz]")

In [None]:
"""Count mini peaks"""

# select the embryo to inspect
embryos = list(exp.embryos.values())
i = 1
emb = embryos[i]

dff = emb.trace.dff
time = emb.activity[:, 0] / 60

fig, ax = plt.subplots()

# find every local maximum with a prominence of at least 0.05 and shade it
minipeak_indices, properties = spsig.find_peaks(dff, prominence=0.05)
minipeak_times = [time[idx] for idx in minipeak_indices]
for peak in minipeak_times:
    ax.axvline(peak, color="purple", alpha=0.05, linewidth=10)
x_points = utils.time_scale_list(max(time))

# draw the trace on top of the shaded peaks
ax.plot(time, dff)

In [None]:
"""Quantifying Frequency: spectrogram"""

# select the embryo to analyse
embryos = list(exp.embryos.values())
i = 1
emb = embryos[i]
trace = emb.trace
time = emb.activity[:, 0] / 60
x_points = utils.time_scale_list(max(time))

# spectrogram of the dff trace, sampling frequency 1/6, FFT zero-padded to 16384 points
freqs, times, spectrogram = spsig.spectrogram(trace.dff, fs=1 / 6, nfft=16384)

fig, ax = plt.subplots()
ax.pcolormesh(times, freqs, spectrogram, shading="gouraud")
ax.set_ylabel("Frequency [Hz]")
ax.set_xlabel("Time")
# only the lowest frequencies are shown
ax.set_ylim((0, 0.01))
fig.tight_layout()