In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.signal as spsig
import seaborn as sns

from pasna_analysis import Experiment, utils

experiment_name = '20240619_23'
# Embryos listed here are dropped when the experiment is loaded.
# Entries are embryo ids, taken from the csv file names.
to_exclude = []

# Experiment data lives in <repo>/data/<experiment_name>, one level above the notebook dir.
exp_path = Path.cwd().parent / 'data' / experiment_name
exp = Experiment(
    exp_path,
    first_peak_threshold=0,
    to_exclude=to_exclude,
    dff_strategy='local_minima',
)

In [None]:
"""Plots activity signals for a group of embryos."""

# change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

# squeeze=False keeps `axes` 2-D (one row of two panels per embryo) even when
# only a single embryo is selected; otherwise matplotlib returns a 1-D array
# and the `(left, right)` unpacking below fails.
fig, axes = plt.subplots(
    end - start, 2, figsize=(20, 2 * (end - start)), squeeze=False
)

for (left, right), emb in zip(axes, embryos[start:end]):
    # plot data: left panel = GCaMP (trace.active), right panel = tdTomato (trace.struct)
    time = emb.activity[:, 0] / 60  # first activity column / 60 -> minutes
    trace = emb.trace
    left.plot(time, trace.active, color="green")
    right.plot(time, trace.struct, color="firebrick")

    # set title/labels
    # NOTE(review): these panels plot trace.active / trace.struct and the fixed
    # y-limits below look like raw intensities; confirm "ΔF/F" is the right label.
    left.set_title(f"{emb.name} - GCamP")
    left.set_ylabel("ΔF/F")
    left.set_xlabel("time (mins)")
    right.set_title(f"{emb.name} - tdTomato")
    right.set_ylabel("ΔF/F")
    right.set_xlabel("time (mins)")

    # set y axis bounds (comment out if you want it to autoscale)
    left.set_ylim([200, 600])
    right.set_ylim([400, 900])

    # set tick marks
    x_points = utils.time_scale_list(max(time))
    left.set_xticks(x_points)
    right.set_xticks(x_points)

    # mark identified peaks (peak_times / 60 -> minutes, matching the x axis)
    for peak in trace.peak_times:
        left.axvline(peak / 60, color="k", alpha=0.3)

plt.tight_layout()

In [None]:
"""Plots dff for a group of embryos, mark peaks."""

# change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

fontsize = 15

# squeeze=False so `axes` is always an array; with a single embryo
# plt.subplots would otherwise return a bare Axes, which has no .flatten().
fig, axes = plt.subplots(
    end - start, figsize=(14, 2 * (end - start)), squeeze=False
)
axes = axes.flatten()
for ax, emb in zip(axes, embryos[start:end]):
    # plot data
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    ax.plot(time, trace.dff, linewidth=3, color="k")

    # set title/labels
    ax.set_title(f"{exp.name}{emb.name} - dff", fontsize=fontsize)
    ax.set_ylabel("ΔF/F", fontsize=fontsize)
    ax.set_xlabel("time (mins)", fontsize=fontsize)

    # x axis - trim to before onset & at hatching.
    # `onset_time` is its own name so the group-selection `start` above is not
    # overwritten by the per-embryo onset.
    onset_time = round(time[trace.peak_bounds_indices[0][0]])
    end_time = trace.time[trace.trim_idx] / 60
    ax.set_xlim(onset_time - 20, end_time)  # start plotting 20 mins before onset

    # x axis - hourly ticks, relabelled so that onset is time 0. Labels are
    # derived from the tick positions so both always have the same length.
    hours = np.arange(onset_time, end_time, 60)
    ax.set_xticks(hours)
    ax.set_xticklabels((hours - onset_time).astype(int))
    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    # y axis - trim & tick labels
    ax.set_ylim(-0.1, 1)
    ax.set_yticks([0, 0.8])

    for peak in trace.peak_times:
        ax.axvline(peak / 60, color="k", alpha=0.3)

    # 100 evenly spaced guide lines between onset and end_time.
    # NOTE(review): the original name `hundred_seconds` suggests a 100 s
    # spacing, but linspace gives 100 lines, not a 100 s step; confirm intent.
    hundred_seconds = np.linspace(onset_time, end_time, 100)
    for line in hundred_seconds:
        ax.axvline(line)

plt.tight_layout(pad=3)

In [None]:
# Number of detected peaks (episodes) per embryo; the mean across embryos is printed.
num_episodes = [len(emb.trace.peak_times) for emb in exp.embryos.values()]

print(np.mean(num_episodes))

In [None]:
"""Plots dff for a group of embryos, mark peaks & shade bursts."""

# Change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

fontsize = 15

# squeeze=False + flatten: with one embryo plt.subplots would return a bare
# Axes, which zip() cannot iterate over.
fig, axes = plt.subplots(
    end - start, figsize=(15, 2.5 * (end - start)), squeeze=False
)
axes = axes.flatten()

for ax, emb in zip(axes, embryos[start:end]):
    # plot data
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    ax.plot(time, trace.dff, linewidth=3, color="black")
    # set title/labels
    ax.set_title(f"{exp.name} {emb.name} - GCamP")
    ax.set_ylabel("ΔF/F")
    ax.set_xlabel("time (mins)")

    # x axis - trim to before onset & at hatching. A separate name keeps the
    # group-selection `start`/`end` above from being overwritten.
    onset_time = round(time[trace.peak_bounds_indices[0][0]])
    end_time = trace.time[trace.trim_idx] / 60
    ax.set_xlim(onset_time - 20, end_time)  # start plotting 20 mins before onset

    # x axis - hourly ticks, relabelled so that onset is time 0
    hours = np.arange(onset_time, end_time, 60)
    ax.set_xticks(hours)
    ax.set_xticklabels((hours - onset_time).astype(int))
    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    # y axis - trim & tick labels
    ax.set_ylim(-0.2, 2)
    ax.set_yticks([0, 0.8])

    # Burst start (blue) / end (purple) with the burst shaded. Exact times in
    # minutes are used; int() would truncate each edge by up to a minute.
    for burst in trace.peak_bounds_indices:
        burst_start = time[burst[0]]
        burst_end = time[burst[1]]
        ax.axvline(burst_start, color="blue", alpha=0.5, linewidth=4)
        ax.axvline(burst_end, color="purple", alpha=0.5, linewidth=4)
        ax.axvspan(burst_start, burst_end, color="orchid", alpha=0.1)

    for peak in trace.peak_times:
        ax.axvline(peak / 60, color="steelblue", alpha=0.4, linewidth=3)

plt.tight_layout(pad=3)

In [None]:
"""Plots activity signals for a group of embryos."""

# Change start and end to select different groups
start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

# squeeze=False + flatten: with one embryo plt.subplots would return a bare
# Axes, which zip() cannot iterate over.
fig, axes = plt.subplots(
    end - start, figsize=(28, 2.5 * (end - start)), squeeze=False
)
axes = axes.flatten()

for ax, emb in zip(axes, embryos[start:end]):
    # plot data
    time = emb.activity[:, 0] / 60
    trace = emb.trace
    ax.plot(time, trace.active, color="black", linewidth=2)

    # set title & labels
    # NOTE(review): trace.active (not trace.dff) is plotted; confirm the "ΔF/F" label.
    ax.set_title(f"{exp.name} {emb.name} - GCamP")
    ax.set_ylabel("ΔF/F")
    ax.set_xlabel("time (mins)")

    # set tick marks
    x_points = utils.time_scale_list(max(time))
    ax.set_xticks(x_points)

    # mark identified peaks (peak_times / 60 -> minutes)
    for peak in trace.peak_times:
        ax.axvline(peak / 60, color="orchid", alpha=0.5, linewidth=3)

    # mark the trim index (used elsewhere in this notebook as the hatching point)
    ax.axvline(time[trace.trim_idx])

plt.tight_layout()

In [None]:
"""Generate single dff trace"""

# pick one embryo by position
embryos = list(exp.embryos.values())
i = 16  # choose emb
emb = embryos[i]
time = emb.activity[:, 0] / 60
trace = emb.trace
dff = trace.dff

fig, ax = plt.subplots()
ax.plot(time, dff, color="#de8f07ff", linewidth=4)

# Title, axis labels and tick labels share one font size.
fontsize = 12
ax.tick_params(axis="both", which="major", labelsize=fontsize)
ax.set_title(f"{exp.name} - {emb.name}", fontsize=fontsize)
ax.set_ylabel("ΔF/F", fontsize=fontsize)
ax.set_xlabel("time (mins)", fontsize=fontsize)

# Window: 35 min before onset up to 300 min after it.
start = round(time[trace.peak_bounds_indices[0][0]])
end_time = trace.time[trace.trim_idx] / 60  # not used for the limits below
ax.set_xlim(start - 35, start + 300)

# Hourly ticks, labelled relative to onset (onset = 0).
tick_offsets = list(range(0, 301, 60))
ax.set_xticks([start + offset for offset in tick_offsets])
ax.set_xticklabels(tick_offsets)

# y axis - trim & tick labels
ax.set_ylim(-0.1, 2)
ax.set_yticks([0, 0.8, 1.6])

# zero baseline
ax.axhline(0, color="red", linewidth=2)

# two reference lines, 230 and 275 min after onset
for offset in (230, 230 + 45):
    ax.axvline(start + offset)

plt.tight_layout()

# plt.savefig('test.svg')

In [None]:
"""Developmental times at first peak."""

data = {"dev_fp": [], "emb": []}

for emb in exp.embryos.values():
    peak_times = emb.trace.peak_times
    # Embryos without any detected peak have no first peak; indexing [0]
    # would raise IndexError, so they are skipped.
    if len(peak_times) == 0:
        continue
    data["dev_fp"].append(emb.get_DT_from_time(peak_times[0]))
    data["emb"].append(emb.get_id())

df = pd.DataFrame(data)

# Set the theme before creating the figure so the new axes pick up the style.
sns.set_theme(style="darkgrid")
fig, ax = plt.subplots()
sns.swarmplot(data=df, x="emb", y="dev_fp", ax=ax)
ax.set_title("Developmental time at first peak")
ax.set_ylabel("Developmental time")
# x is the embryo id (the "emb" column), not an experiment group
ax.set_xlabel("Embryo")
plt.tight_layout()
plt.show()

In [None]:
"""Average time (in minutes) of the peak at index 7 across embryos."""

# NOTE(review): index 7 is the eighth detected peak (the original variable was
# named `seven`); confirm which peak is intended.
PEAK_INDEX = 7

# The previous version named its accumulator `sum`, shadowing the builtin for
# the rest of the session, and also read a stale `time` array from another
# cell into an unused `start`. Both are removed.
peak_mins = []
for emb in exp.embryos.values():
    peak_times = emb.trace.peak_times
    # embryos with too few peaks are skipped
    if len(peak_times) <= PEAK_INDEX:
        continue
    peak_min = peak_times[PEAK_INDEX] / 60
    print(emb.name, peak_min)
    peak_mins.append(peak_min)

# NaN instead of ZeroDivisionError when no embryo has enough peaks
avg = np.mean(peak_mins) if peak_mins else float("nan")

print(avg)

In [None]:
"""Peak amplitudes for each episode."""

num_of_peaks = 13  # only the first `num_of_peaks` peaks of each embryo are plotted
data = {"peak_amp": [], "emb": [], "peak_idx": []}

for emb in exp.embryos.values():
    # zip with range stops after num_of_peaks amplitudes
    for i, amp in zip(range(num_of_peaks), emb.trace.peak_amplitudes):
        data["peak_amp"].append(amp)
        data["emb"].append(emb.name)
        data["peak_idx"].append(i)

amps = pd.DataFrame(data)

# Set the theme before creating the figure so the new axes pick up the style.
sns.set_theme(style="darkgrid")
fig, ax = plt.subplots()
sns.scatterplot(data=amps, x="peak_idx", y="peak_amp", hue="emb", ax=ax)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
ax.set_title("Peak amplitudes")
ax.set_xlabel("Peak number")
ax.set_ylabel("\u0394F/F")

In [None]:
"""Plots AUC, grouped by bins."""

data = {"auc": [], "bin": [], "emb": []}

n_bins = 5
first_bin = 2
bin_width = 0.2
# Lower bin edges in developmental time. They are the same for every embryo,
# so they are built once. This also keeps `bins` defined when no embryo is
# processed (it was previously only assigned inside the loop).
bins = [first_bin + j * bin_width for j in range(n_bins)]

for i, emb in enumerate(exp.embryos.values()):
    trace = emb.trace

    dev_time_at_peaks = emb.get_DT_from_time(trace.peak_times)

    # a fresh copy is passed in case split_in_bins modifies its argument
    bin_idxs = utils.split_in_bins(dev_time_at_peaks, list(bins))
    # presumably a non-ndarray result means the peaks could not be binned;
    # such embryos are skipped
    if not isinstance(bin_idxs, np.ndarray):
        continue
    data["auc"].extend(trace.peak_aucs)
    data["bin"].extend(bin_idxs)
    data["emb"].extend([str(i)] * len(trace.peak_aucs))

# sanity check: the three columns must have equal length for seaborn
print(len(data["auc"]), len(data["bin"]), len(data["emb"]))

fig, ax = plt.subplots()
# Add the upper edge of the last bin to build "lower~upper" labels. `:g`
# avoids floating-point noise (e.g. 2.0 or 2.4000000000000004) in the labels.
bin_edges = bins + [first_bin + bin_width * n_bins]
x_labels = [f"{s:g}~{e:g}" for (s, e) in zip(bin_edges[:-1], bin_edges[1:])]
sns.pointplot(data=data, x="bin", y="auc", linestyle="None", ax=ax)
# NOTE(review): categories absent from the data shift positions, so the
# labels assume every bin index 0..n_bins-1 occurs at least once.
ax.set_xticks(ticks=list(range(n_bins)), labels=x_labels)
ax.set_title(f"Binned AUC - Exp {exp.name}")
ax.set_ylabel("AUC [activity*t]")

In [None]:
"""Plot peak times for all embryos. This should help to isolate the
early peaks."""

UPPER_LIMIT = 15000  # seconds; only peaks earlier than this are shown
# One list of peak times (in minutes) per embryo. The constant is used here
# instead of a duplicated literal.
times = [
    [t / 60 for t in emb.trace.peak_times if t < UPPER_LIMIT]
    for emb in exp.embryos.values()
]

fig, ax = plt.subplots()
# One row per embryo. `peak_mins` avoids overwriting the notebook-wide `time` array.
for i, peak_mins in enumerate(times):
    ax.plot(
        peak_mins, [i] * len(peak_mins), marker=".", linestyle="dashed", linewidth=0.5
    )

# reference line at 30 min
ax.vlines(30, 0, len(exp.embryos), color="k", linewidth=0.3)
ax.set_title(f"Peak times for exp {exp.name}")
x_points = utils.time_scale_list(UPPER_LIMIT / 60)
ax.set_xticks(x_points)
ax.set_xlabel("time (mins)")
ax.set_ylabel("emb")
ax.set_yticks([])

In [None]:
"""Compare first interspike interval and average interspike interval"""

avg_ISIs = []
first_ISIs = []
for emb in exp.embryos.values():
    peak_times = emb.trace.peak_times
    # At least 3 peaks are needed: one interval for the first ISI and at least
    # one more for the average of the remaining ISIs. With exactly 2 peaks the
    # average was taken over an empty array (NaN), which made max_val NaN and
    # int(max_val) raise.
    if len(peak_times) < 3:
        continue
    isis = np.diff(peak_times)
    avg_ISIs.append(np.mean(isis[1:]))  # mean of the intervals after the first
    first_ISIs.append(isis[0])

if not avg_ISIs:
    print("No embryo has at least 3 peaks; nothing to plot.")
else:
    max_val = max(np.max(avg_ISIs), np.max(first_ISIs))
    diag = np.arange(0, int(max_val), 50)

    fig, ax = plt.subplots()
    ax.plot(avg_ISIs, first_ISIs, "k.")
    ax.plot(diag, diag, label="y=x")
    ax.plot(diag, 2 * diag, label="y=2*x")
    ax.set_xlabel("average ISI")
    ax.set_ylabel("first ISI")
    ax.set_xlim(0, max_val)
    ax.set_ylim(0, max_val)
    ax.set_title(f"Differences between first and avg intervals - exp {exp.name}")
    ax.legend()

In [None]:
"""Plot the Welch PSD spectrum for each embryo in a group."""

start = 0
embryos = list(exp.embryos.values())
end = len(embryos)

fs = 1 / 6  # sampling rate in Hz

# squeeze=False so a single selected embryo still gives an array of Axes
fig, axes = plt.subplots(
    end - start, figsize=(35, 3 * (end - start)), squeeze=False
)
axes = axes.flatten()
for ax, emb in zip(axes, embryos[start:end]):
    # Use this embryo's own trace. Previously `trace` was whatever an earlier
    # cell left behind, so every panel showed the same embryo's spectrum.
    trace = emb.trace
    dff = trace.dff

    # drop non-finite (NaN/inf) samples before spectral estimation
    finite_dff = dff[np.isfinite(dff)]

    ax.set_title(f"{exp.name} {emb.name} - GCamP")
    freqs, amps = spsig.welch(
        finite_dff, fs=fs, window="hamming", nperseg=1024, scaling="spectrum"
    )
    ax.plot(freqs, amps)
    ax.set(xlim=(0, 0.015))

In [None]:
"""Count mini peaks"""

embryos = list(exp.embryos.values())
i = 1
emb = embryos[i]

dff = emb.trace.dff
time = emb.activity[:, 0] / 60

# local maxima of dff with prominence >= 0.05
minipeak_indices, properties = spsig.find_peaks(dff, prominence=0.05)
minipeak_times = time[minipeak_indices]

fig, ax = plt.subplots()
for peak in minipeak_times:
    ax.axvline(peak, color="purple", alpha=0.05, linewidth=10)
ax.plot(time, dff)

# x_points was previously computed but never applied to the axis
x_points = utils.time_scale_list(max(time))
ax.set_xticks(x_points)
# report the count this cell is named after
ax.set_title(f"{exp.name} {emb.name} - {len(minipeak_indices)} mini peaks")
ax.set_xlabel("time (mins)")

In [None]:
"""Local peaks for each peak."""

# Long-form table: one row per (embryo, peak) with that peak's local-peak count.
data = {"num_local_peaks": [], "peak": [], "emb": []}

fig, ax = plt.subplots()

for emb in exp.embryos.values():
    counts = emb.trace.compute_local_peaks(height=0.03, prominence=0.02)
    for peak_number, count in enumerate(counts):
        data["num_local_peaks"].append(count)
        data["peak"].append(str(peak_number))
        data["emb"].append(emb.name)

sns.pointplot(data=data, x="peak", y="num_local_peaks", linestyle="None", ax=ax)
ax.set_title("Local peaks for each peak")
ax.set_ylabel("Number of local peaks")
ax.set_xlabel("Peak number")

In [None]:
"""Local peak rate."""

i = 5
split_idx = 1250  # frame index separating the "left" and "right" segments

fig, ax = plt.subplots()

embryos = list(exp.embryos.values())
emb = embryos[i]
time = emb.activity[:, 0] / 60
x_points = utils.time_scale_list(max(time))
trace = emb.trace
ax.set_title(f"{emb.name} - dff")
ax.set_ylabel("ΔF/F")
ax.set_xlabel("time (mins)")
ax.set_xticks(x_points)
ax.plot(time, trace.dff)

# local-peak counts before (left) and after (right) split_idx
left, right = trace.compute_all_local_peaks(split_idx, height=0.02, prominence=0.02)

first_peak_idx = trace.peak_idxes[0]
first_peak_time, split_time, end_time = (
    trace.time[idx] for idx in (first_peak_idx, split_idx, trace.trim_idx)
)

# Rates are local peaks per unit of trace.time (divided by 60 for minutes
# below, so presumably seconds).
left_rate = left / (split_time - first_peak_time)
right_rate = right / (end_time - split_time)
print(left_rate, right_rate)

# markers: first peak (green), split (black), trim/end (red)
for marker_time, colour in ((first_peak_time, "g"), (split_time, "k"), (end_time, "r")):
    ax.axvline(marker_time / 60, color=colour, alpha=0.8)

plt.tight_layout()

In [None]:
"""Get a spectrogram."""

n = 20
embryos = list(exp.embryos.values())
emb = embryos[n]

fs = 1 / 6  # sampling rate in Hz (one frame every 6 s)

# get trace
trace = emb.trace
# Work on a float copy. Zeroing the post-hatching samples in place on
# trace.dff could modify the embryo's stored dff (if the attribute returns
# the underlying array) and silently change every later cell that uses it.
dff = np.array(trace.dff, dtype=float)
hatching = trace.trim_idx
dff[hatching:] = 0  # zero data after hatching

# trim dff
duration = 3600  # frames
onset = trace.peak_bounds_indices[0][0]
start_index = max(onset - 150, 0)  # clamp: a negative start would wrap around
end_index = hatching + 50

# trim dff and pad it to the desired duration with zeros
dff = dff[start_index:end_index]
if duration > len(dff):
    pad_size = duration - len(dff)
    dff = np.pad(dff, (0, pad_size))
else:
    print(
        f"desired duration of {duration} is shorter than dff with {len(dff)} frames. warning. truncating dff"
    )
    dff = dff[0:duration]
# Time axis (minutes) built from the sample index. It always has exactly
# `duration` points, matching the padded dff; slicing the recording's time
# column could be shorter. It also matches the STFT time axis below, which is
# sample index / fs in seconds.
time = np.arange(duration) / fs / 60

# plot trace
fontsize = 20
plt.figure(figsize=(36, 6))
plt.plot(time, dff, color="black", linewidth=4)
plt.axhline(0, color="red", linewidth=2)

plt.tick_params(axis="both", which="major", labelsize=fontsize)
plt.title(f"{exp.name} - {emb.name}", fontsize=fontsize)
plt.ylabel("ΔF/F", fontsize=fontsize)
plt.xlabel("time (mins)", fontsize=fontsize)

plt.ylim(-0.1, 1)
plt.yticks([0, 0.8])
minute_ticks = np.arange(0, time[-1], 30, int)
plt.xticks(minute_ticks, minute_ticks)

# calculate stft (75% overlap between segments; noverlap must be an integer)
fft_size = 600
noverlap = 3 * fft_size // 4
f, t, Zxx = spsig.stft(dff, fs, nperseg=fft_size, noverlap=noverlap, nfft=fft_size)

# spectrogram
plt.figure(figsize=(36, 6))
spec = plt.pcolormesh(
    t, f, abs(Zxx), vmin=0, vmax=0.03, cmap="plasma", shading="nearest", snap=True
)
plt.title(f"{exp.name} - {emb.name}", fontsize=fontsize)
plt.ylabel("Hz", fontsize=fontsize)
plt.xlabel("time (mins)", fontsize=fontsize)

# stft time axis is in seconds; place ticks at minute positions * 60
second_ticks = [x * 60 for x in minute_ticks]
plt.xticks(second_ticks, minute_ticks)
plt.tick_params(axis="both", which="major", labelsize=fontsize)

plt.ylim(0, 0.025)
colorbar = plt.colorbar(spec)
colorbar.ax.tick_params(labelsize=fontsize)
plt.show()

In [None]:
n = 8
embryos = list(exp.embryos.values())
emb = embryos[n]
# Use this embryo's own trace. The onset below previously read a `trace` left
# over from an earlier cell, i.e. possibly a different embryo.
trace = emb.trace

# step 1 - plot data
time = emb.activity[:, 0] / 60
dff = trace.dff
fig, ax = plt.subplots(figsize=(35, 3))
ax.plot(time, dff, linewidth=2, color="black")
ax.axhline(0, color="red")
ax.set_title(f"{exp.name} {emb.name}")
# guide lines every 20 min from the start of activity
start = round(time[trace.peak_bounds_indices[0][0]])
twenty_min_marks = np.arange(start, max(time), 20)
for line in twenty_min_marks:
    ax.axvline(line, linewidth=3, color="green", alpha=0.3)

# step 3 - calcium activity
calciumOsc = dff  # (here, I'm using dff instead of activity)

# step 6 - compute FFT of the signal. np.fft.fft zero-pads to N samples.
# NOTE(review): a signal longer than N is cropped to its first N samples.
N = 2048
G = np.fft.fft(calciumOsc, N)

# step 7 - one-sided power spectral density for bins 0 (DC) .. N/2 (Nyquist).
# |G|^2 is real; G * conj(G) has complex dtype, which matplotlib warns about.
P = np.abs(G) ** 2 / N
P = P[: N // 2 + 1]
P[1 : N // 2] *= 2  # double every bin except DC and Nyquist

# step 8 - frequency of every bin in P (bin k is k * fs / N Hz)
dt = 6
fs = 1 / dt
f = np.arange(N // 2 + 1) * fs / N
fig, ax = plt.subplots(figsize=(35, 3))
ax.plot(f[1:], P[1:])  # DC bin skipped
ax.set(xlim=(0, 0.01))
ax.set_xticks(np.arange(0, 0.01, 0.001))

In [None]:
# Compute the FFT and PSD using a segment of the signal, and identify the strongest frequency
n = 0
embryos = list(exp.embryos.values())
emb = embryos[n]

fontsize = 20

dt = 6  # seconds per frame
fs = 1 / dt  # sampling rate in Hz
fft_size = 2048

# Get signal
trace = emb.trace
time = emb.activity[:, 0] / 60
onset = trace.peak_bounds_indices[0][0]
# 200 frames * 6 s = 20 mins before onset. Clamped at 0 because a negative
# start index would silently slice from the end of the array.
start_index = max(onset - 200, 0)
end_index = trace.trim_idx
dff = trace.dff[start_index:end_index]
time = time[0 : len(dff)]

# Compute the FFT (zero-padded to fft_size, or cropped if the signal is longer)
fft_result = np.fft.fft(dff, fft_size)

# Compute the IFFT
ifft_result = np.fft.ifft(fft_result, fft_size)

# Plot of the signal & ifft together. The round trip is real up to rounding
# error, so only the real part is plotted (avoids a ComplexWarning).
plt.figure(figsize=(32, 3))
plt.plot(dff, color="grey")
plt.plot(ifft_result.real, color="blue", linestyle=":")
plt.title(f"dff [grey]\n ifft(fft(dff)) [blue]\nN={fft_size}", fontsize=fontsize)
plt.xlabel("Samples", fontsize=fontsize)
plt.ylabel("ΔF/F", fontsize=fontsize)
plt.show()

# Compute the one-sided PSD for bins 0 .. fft_size/2 - 1. |X|^2 gives a real dtype.
half = fft_size // 2
PSD = np.abs(fft_result) ** 2 / fft_size
PSD = PSD[0:half]
PSD[1:half] = 2 * PSD[1:half]  # double every non-DC bin (Nyquist is outside this range)
# map PSD onto frequency array
Hz = fs * np.arange(0, half) / fft_size  # in Hz
mHz = 1000 * Hz  # in mHz

# Plot PSD (DC bin skipped)
plt.figure(figsize=(32, 3))
plt.plot(mHz[1:], PSD[1:])
plt.xlim(0, 4)
plt.title(f"PSD\nN={fft_size}", fontsize=fontsize)
plt.xlabel("mHz", fontsize=fontsize)
plt.ylabel("PSD", fontsize=fontsize)
plt.tick_params(axis="both", which="major", labelsize=fontsize)

# Find prominent frequencies
global_max_PSD = np.max(PSD[1:])
# Frequency (in mHz, despite the name) of the largest non-DC bin. argmax over
# PSD[1:] is offset by exactly one from the full array (previously +2).
global_max_HZ = mHz[np.argmax(PSD[1:]) + 1]
# local maxima whose prominence exceeds 20% of the global (non-DC) maximum
local_maxima_indices = spsig.find_peaks(PSD, prominence=global_max_PSD * 0.20)[0]
local_maxima_Hz = [Hz[idx] for idx in local_maxima_indices]
local_maxima_mHz = [(PSD[idx], mHz[idx]) for idx in local_maxima_indices]

for amp, freq in local_maxima_mHz:
    plt.vlines(freq, 0, amp, colors="red")
    period = round((1000 / freq) / 60, 1)  # period in minutes
    plt.text(freq, 0.001, f"{period}", fontsize=20)
plt.show()

# Plot dff with strongest freq
if not local_maxima_mHz:
    # previously this divided by zero (corr_freq stayed 0)
    print("No prominent frequency found; skipping the strongest-frequency plot.")
else:
    # first local maximum with the largest PSD amplitude
    max_amp, corr_freq = max(local_maxima_mHz, key=lambda peak: peak[0])
    corr_period = round((1000 / corr_freq) / 60, 1)  # minutes

    plt.figure(figsize=(32, 3))
    plt.plot(time, dff, color="grey")
    plt.xlabel("Time (mins)", fontsize=fontsize)
    plt.ylabel("ΔF/F", fontsize=fontsize)
    x_points = utils.time_scale_list(max(time))
    plt.xticks(x_points)
    plt.title(f"dff [grey]\nstrongest freq [red]\nN={fft_size}", fontsize=fontsize)

    for line in np.arange(0, max(time), corr_period):
        plt.vlines(line, 0, 0.5, color="red")
    plt.show()

In [None]:
# Compute the PSD over one segment of the original signal
n = 0
embryos = list(exp.embryos.values())
emb = embryos[n]

fft_size = 2048
dt = 6  # seconds per frame
fs = 1 / dt
pre_onset_mins = 20  # how much signal before onset to include

fontsize = 15

# Get dff
trace = emb.trace
time = emb.activity[:, 0] / 60
# Trim dff to the relevant data: from `pre_onset_mins` before onset (converted
# to frames; the previous `- 20` was only 20 frames = 2 mins) up to hatching.
# Clamped at 0 so a negative index cannot wrap around.
onset = trace.peak_bounds_indices[0][0]
start_index = max(onset - int(pre_onset_mins * 60 / dt), 0)
end_index = trace.trim_idx  # hatching
dff = trace.dff[start_index:end_index]
time = time[0 : len(dff)]
# center dff around the time axis (0) by removing a quadratic trend
degree = 2
trend = np.polyfit(time, dff, degree)
dff_centered = dff - np.polyval(trend, time)
# apply a Hann window to reduce spectral leakage
hanning = 0.5 - 0.5 * np.cos(2 * np.pi * time / time[-1])
dff_windowed = np.multiply(dff_centered, hanning)

# Plot dff windowed and centered
plt.figure(figsize=(8, 4))
plt.plot(time, dff_windowed, color="midnightblue")
plt.plot(
    time, hanning * 0.25, color="orange"
)  # adjust window amplitude to improve plotting
plt.xlabel("Time (mins)", fontsize=fontsize)
plt.ylabel("ΔF/F", fontsize=fontsize)
x_points = utils.time_scale_list(max(time))
plt.xticks(x_points)
plt.tick_params(axis="both", which="major", labelsize=fontsize)
plt.show()

# np.fft.fft zero-pads its input at the end up to fft_size samples, so no
# manual padding is needed. The former np.pad call padded `dff` (not the
# windowed signal) on both sides and its result was never used. A longer
# signal is cropped instead, so warn about it.
if len(dff_windowed) > fft_size:
    print(
        f"warning: {len(dff_windowed)} samples exceed fft_size={fft_size}; "
        f"the FFT uses only the first {fft_size}"
    )

# Compute the fft
fft_result = np.fft.fft(dff_windowed, fft_size)

# Compute the one-sided PSD (|X|^2 gives a real dtype)
half = fft_size // 2
PSD = np.abs(fft_result) ** 2 / fft_size
PSD = PSD[0:half]
PSD[1:half] = 2 * PSD[1:half]  # double every non-DC bin
# map PSD onto frequency array
Hz = fs * np.arange(0, half) / fft_size  # in Hz
mHz = 1000 * Hz  # in mHz

# Calculate prominent frequencies: local maxima whose prominence exceeds 20%
# of the maximum over bins >= 2
global_max = np.max(PSD[2:])
local_maxima_indices = spsig.find_peaks(PSD, prominence=global_max * 0.20)[0]
local_maxima_Hz = [Hz[idx] for idx in local_maxima_indices]
local_maxima_mHz = [(PSD[idx], mHz[idx]) for idx in local_maxima_indices]

# Plot PSD
plt.figure(figsize=(8, 2))
plt.plot(mHz, PSD, color="black")
plt.ylabel("PSD", fontsize=fontsize)
plt.xlabel("Frequency (mHz)", fontsize=fontsize)
plt.tick_params(axis="both", which="major", labelsize=fontsize)
plt.xlim(0, 10)
y_points = np.arange(0, plt.ylim()[1], 0.2)
plt.yticks(y_points)

# annotate up to four local maxima (zip stops at the shorter list)
colors = ["orchid", "deeppink", "salmon", "forestgreen"]
for local_max, color in zip(local_maxima_mHz, colors):
    amp, freq = local_max
    plt.vlines(freq, 0, amp, color=color)
    period = int((1000 / freq) / 60)  # min
    plt.text(
        freq,
        amp + 0.1,
        f"{round(freq, 1)} mHz,  1/{period} min-1",
        fontsize=10,
        color=color,
        rotation=90,
        rotation_mode="anchor",
        bbox=dict(facecolor="white", edgecolor="none", boxstyle="round", pad=0.5),
    )
plt.show()

In [None]:
# Compute the PSD over segments (chunks) of the original signal
n = 0
embryos = list(exp.embryos.values())
emb = embryos[n]

process_windows = 900  # frames per chunk: 900 * 6 s = 90 mins
fft_size = 4096
dt = 6  # seconds per frame
fs = 1 / dt
pre_onset_mins = 20  # how much signal before onset to include

fontsize = 30
fontsize_small = 15
colors = ["orchid", "deeppink", "salmon", "forestgreen"]

# Get dff
trace = emb.trace
time = emb.activity[:, 0] / 60
# Trim dff to the relevant data: from `pre_onset_mins` before onset (converted
# to frames; the previous `- 20` was only 20 frames = 2 mins) up to hatching.
# Clamped at 0 so a negative index cannot wrap around.
onset = trace.peak_bounds_indices[0][0]
start_index = max(onset - int(pre_onset_mins * 60 / dt), 0)
end_index = trace.trim_idx  # hatching
dff = trace.dff[start_index:end_index]
time = time[0 : len(dff)]
# center dff around the time axis (0) by removing a quadratic trend
degree = 2
trend = np.polyfit(time, dff, degree)
dff_centered = dff - np.polyval(trend, time)
# apply a Hann window to reduce spectral leakage
# NOTE(review): the window spans the whole trimmed signal, not each chunk, so
# chunks near both ends are strongly attenuated; confirm this is intended.
hanning = 0.5 - 0.5 * np.cos(2 * np.pi * time / time[-1])
dff_windowed = np.multiply(dff_centered, hanning)

# Frequency axis is identical for every chunk, so it is computed once.
half = fft_size // 2
Hz = fs * np.arange(0, half) / fft_size  # in Hz
mHz = 1000 * Hz  # in mHz

# Chunks of `process_windows` frames with 50% overlap.
step = process_windows // 2
indices = np.arange(0, len(dff_windowed), step)
index_pairs = list(zip(indices, indices[2:]))

for start, stop in index_pairs:
    # Zero-pad the chunk at the END up to fft_size samples. Padding both sides
    # (as before) made the array longer than fft_size, so the FFT cropped it
    # and the plotted chunk was mostly leading zeros.
    dff_chunk = dff_windowed[start:stop]
    dff_chunk = np.pad(
        dff_chunk, (0, fft_size - len(dff_chunk)), "constant", constant_values=0
    )
    plt.figure(figsize=(32, 3))
    plt.plot(dff_chunk)
    plt.show()
    # Compute the FFT
    fft_result = np.fft.fft(dff_chunk, fft_size)

    # Compute the one-sided PSD (|X|^2 gives a real dtype)
    PSD = np.abs(fft_result) ** 2 / fft_size
    PSD = PSD[0:half]
    PSD[1:half] = 2 * PSD[1:half]  # double every non-DC bin

    # Prominent frequencies: local maxima whose prominence exceeds 20% of the
    # maximum over bins >= 2
    global_max = np.max(PSD[2:])
    local_maxima_indices = spsig.find_peaks(PSD, prominence=global_max * 0.20)[0]
    local_maxima_Hz = [Hz[idx] for idx in local_maxima_indices]
    local_maxima_mHz = [(PSD[idx], mHz[idx]) for idx in local_maxima_indices]

    # Strongest cycle: first local maximum with the largest PSD. Without any
    # local maximum the period is reported as 0 and no cycle lines are drawn
    # (previously this divided by zero).
    if local_maxima_mHz:
        max_amp, corr_freq = max(local_maxima_mHz, key=lambda peak: peak[0])
        corr_period = int((1000 / corr_freq) / 60)  # whole minutes
    else:
        max_amp, corr_freq, corr_period = 0, 0, 0

    # Plot dff with strongest cycle lines
    plt.figure(figsize=(32, 3))
    plt.axvspan(time[start], time[stop], color="lightyellow")
    plt.axhline(0, color="green")
    plt.plot(time, dff, color="midnightblue", linewidth=3)
    plt.title(f"dff {start}:{stop}, {corr_period} mins,", fontsize=fontsize)
    plt.xlabel("Time (mins)", fontsize=fontsize)
    plt.ylabel("ΔF/F", fontsize=fontsize)
    plt.tick_params(axis="both", which="major", labelsize=fontsize)
    x_points = utils.time_scale_list(max(time))
    plt.xticks(x_points)

    # a period truncated to 0 min would make np.arange raise (zero step)
    if corr_period > 0:
        for line in np.arange(time[start], time[stop], corr_period):
            plt.vlines(line, -0.2, 0.5, color="red", linestyles=":")
    plt.show()

    # Plot PSD (bins 0 and 1 skipped)
    plt.figure(figsize=(32, 3))
    plt.plot(
        mHz[2:],
        PSD[2:],
        linewidth=1,
        color="black",
    )
    plt.title(f"PSD {start}:{stop}, {corr_period} mins")
    plt.ylabel("PSD", fontsize=fontsize_small)
    plt.xlabel("Frequency (mHz)", fontsize=fontsize_small)
    plt.tick_params(axis="both", which="major", labelsize=fontsize_small)
    plt.xlim(0, 10)
    y_points = np.linspace(0, plt.ylim()[1], 4)
    plt.yticks(np.round(y_points, 2))

    # Add local peak information (up to four peaks; zip stops at the shorter list)
    for local_max, color in zip(local_maxima_mHz, colors):
        amp, freq = local_max
        plt.vlines(freq, 0, amp, color=color)
        period = int((1000 / freq) / 60)  # min
        plt.text(
            freq,
            amp + (amp / 4),
            f"{round(freq, 1)} mHz,  {period} mins",
            fontsize=fontsize_small * 2 / 3,
            color=color,
            rotation=90,
            rotation_mode="anchor",
            bbox=dict(facecolor="white", edgecolor="none", boxstyle="round", pad=0.5),
        )
    plt.show()

In [None]:
"""Quantifying Frequency: spectrogram"""

embryos = list(exp.embryos.values())
i = 1
emb = embryos[i]
time = emb.activity[:, 0] / 60
x_points = utils.time_scale_list(max(time))
trace = emb.trace
# Sampling rate 1/6 Hz. nfft=16384 is larger than the default segment length,
# i.e. a zero-padded FFT per segment for a finer frequency grid.
freqs, times, spectrogram = spsig.spectrogram(trace.dff, 1 / 6, nfft=16384)

fig, ax = plt.subplots()
ax.pcolormesh(times, freqs, spectrogram, shading="gouraud")
ax.set_ylabel("Frequency [Hz]")
ax.set_xlabel("Time")
ax.set_ylim((0, 0.01))
plt.tight_layout()