In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as spsig

from pasna_analysis import Experiment, ExperimentConfig, Group

# Experiments are read from <cwd>/../data/<group folder>/<experiment name>.
DATA_DIR = Path.cwd().parent / 'data'
DFF_STRATEGY = 'local_minima'

# Per group: experiment directory name -> ExperimentConfig (peak threshold, excluded embryos).
wt_folder = '25C'
wt_config = {
    '20240611_25C': ExperimentConfig(first_peak_threshold=30, to_exclude=[4,9,12,13,15,17,18,19,21,23]),
    '20240919_25C': ExperimentConfig(first_peak_threshold=30, to_exclude=[2,5,6,7,11,14,16,18])
}

vglut_folder = 'vglut'
vglut_config = {
    '20240828-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[3,6,9]), # EMB 3,6,9
    '20240829-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[6,8]), # EMB 6,8
    '20240830-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[6]) # EMB 6
}

vgat_folder = 'vgat'
vgat_config = {
    '20241008_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[2,4]),
    '20241009_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[3,7,18,19]),
    '20241010_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[8]),
    '20241011_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[2,4,11,5,1,8]),
}


def load_experiments(folder, configs):
    """Load every experiment of one group.

    Args:
        folder: sub-directory of DATA_DIR holding the group's experiment directories.
        configs: mapping of experiment directory name -> ExperimentConfig.

    Returns:
        dict mapping experiment name -> Experiment, in the iteration order of `configs`.
    """
    return {
        exp_name: Experiment(
            DATA_DIR / folder / exp_name,
            config.first_peak_threshold,
            config.to_exclude,
            dff_strategy=DFF_STRATEGY,
        )
        for exp_name, config in configs.items()
    }


wt_experiments = load_experiments(wt_folder, wt_config)
vglut_experiments = load_experiments(vglut_folder, vglut_config)
vgat_experiments = load_experiments(vgat_folder, vgat_config)

wt = Group('WT', wt_experiments)
vglut = Group('VGluT-', vglut_experiments)
vgat = Group('VGAT-', vgat_experiments)

groups = [wt, vglut, vgat]

In [None]:
"""Sample size"""

# Print the number of embryos held by each experiment, then the total per group.
for group in groups:
    group_total = 0
    for exp in group.experiments.values():
        n_embryos = len(exp.embryos)
        print(exp.name, n_embryos)
        group_total += n_embryos
    print(group.name, group_total)

In [None]:
"""Generate all spectrograms plus average spectrogram."""

# For the first group: one figure per experiment showing, for every usable embryo,
# its trimmed ΔF/F trace (left) and STFT magnitude spectrogram (right); then the
# spectrogram magnitude averaged over all usable embryos of the group.

group = groups[0]

duration = 3600  # frames: every trimmed trace is zero-padded / truncated to this length
fs = 1 / 6  # sampling rate in Hz (assumes one frame every 6 s)
fft_size = 600  # STFT window length and FFT size, in frames
noverlap = 3 * fft_size // 4  # 75% overlap between consecutive windows

freqs = []
times = []
Zxxs = []

for exp in group.experiments.values():
    n_embryos = len(exp.embryos)
    if n_embryos == 0:
        continue  # plt.subplots cannot build a 0-row grid
    # squeeze=False keeps `axes` 2-D, so a single-embryo experiment still yields (left, right) rows
    fig, axes = plt.subplots(n_embryos, 2, figsize=(35, 3 * n_embryos), squeeze=False)
    fontsize = 10
    for (left, right), emb in zip(axes, exp.embryos.values()):
        # get trace; copy ΔF/F so the zeroing below does not write into emb.trace.dff
        time = emb.activity[:, 0] / 60
        trace = emb.trace
        dff = np.array(trace.dff, copy=True)
        hatching = trace.trim_idx
        dff[hatching:] = 0  # zero data after hatching

        # align on the onset of the first detected peak; skip embryos without one
        try:
            onset = trace.peak_bounds_indices[0][0]
        except IndexError:
            continue
        start_index = onset - 150  # keep 150 frames before onset
        end_index = hatching + 50  # and 50 (zeroed) frames after hatching
        if start_index < 0 or len(time) < duration:
            continue

        # trim dff and pad it to desired duration with zeros
        dff = dff[start_index:end_index]
        if len(dff) < duration:
            dff = np.pad(dff, (0, duration - len(dff)))
        elif len(dff) > duration:
            print(
                f"desired duration of {duration} is shorter than dff with {len(dff)} frames. warning. truncating dff"
            )
            dff = dff[:duration]
        time = time[:duration]

        # plot trace
        left.plot(time, dff, color="black", linewidth=2)
        left.tick_params(axis="both", which="major", labelsize=fontsize)
        left.set_title(f"{exp.name} - {emb.name}", fontsize=fontsize)
        left.set_ylabel("ΔF/F", fontsize=fontsize)
        left.set_xlabel("time (mins)", fontsize=fontsize)
        left.set_ylim(-0.1, 1)
        left.set_yticks([0, 0.8])
        minute_ticks = np.arange(0, time[-1], 30, int)
        left.set_xticks(minute_ticks, minute_ticks)

        # calculate stft (segment times `t` are in seconds since fs is in Hz)
        f, t, Zxx = spsig.stft(
            dff, fs, nperseg=fft_size, noverlap=noverlap, nfft=fft_size
        )
        freqs.append(f)
        times.append(t)
        Zxxs.append(Zxx)

        # spectrogram, x axis labelled in minutes
        spec = right.pcolormesh(
            t,
            f,
            np.abs(Zxx),
            vmin=0,
            vmax=0.03,
            cmap="plasma",
            shading="nearest",
            snap=True,
        )
        right.set_ylabel("Hz", fontsize=fontsize)
        right.set_xlabel("time (mins)", fontsize=fontsize)
        right.set_xticks(minute_ticks * 60, minute_ticks)
        right.tick_params(axis="both", which="major", labelsize=fontsize)
        right.set_ylim(0, 0.025)
        # attach the colorbar to this row's spectrogram axes explicitly
        fig.colorbar(spec, ax=right)

    fig.tight_layout()

# average spectrogram magnitude over all embryos that passed the filters above
if not Zxxs:
    print(f"{group.name}: no embryo passed the onset/duration filters; no average spectrogram")
else:
    avg_Zxx = np.mean(np.abs(np.array(Zxxs)), axis=0)
    # every trimmed trace has `duration` frames, so all STFTs share one time/frequency grid
    avg_t, avg_f = times[0], freqs[0]
    # ticks derived from the STFT time axis (s -> min), not from a leftover loop variable
    avg_minute_ticks = np.arange(0, avg_t[-1] / 60, 30, int)

    fontsize = 20
    fig, ax = plt.subplots(figsize=(35, 5))
    spec = ax.pcolormesh(
        avg_t,
        avg_f,
        avg_Zxx,
        vmin=0,
        vmax=0.03,
        cmap="plasma",
        shading="nearest",
        snap=True,
    )
    ax.set_title(f"{group.name} (n={len(Zxxs)})", fontsize=fontsize)
    ax.set_ylabel("Hz", fontsize=fontsize)
    ax.set_xlabel("time (mins)", fontsize=fontsize)
    ax.set_xticks(avg_minute_ticks * 60, avg_minute_ticks)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    ax.set_ylim(0, 0.025)
    colorbar = fig.colorbar(spec, ax=ax)
    colorbar.ax.tick_params(labelsize=fontsize)

plt.show()

In [None]:
"""Generate all average spectrograms."""

# For each group: STFT of every usable embryo trace (aligned on the first peak
# onset, zero-padded/truncated to `duration` frames), then plot the mean
# spectrogram magnitude across those embryos.

duration = 3600  # frames: every trimmed trace is zero-padded / truncated to this length
fs = 1 / 6  # sampling rate in Hz (assumes one frame every 6 s)
fft_size = 600  # STFT window length and FFT size, in frames
noverlap = 3 * fft_size // 4  # 75% overlap between consecutive windows
fontsize = 20

for group in groups:
    freqs = []
    times = []
    Zxxs = []
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            # get trace; copy ΔF/F so the zeroing below does not write into emb.trace.dff
            trace = emb.trace
            dff = np.array(trace.dff, copy=True)
            hatching = trace.trim_idx
            dff[hatching:] = 0  # zero data after hatching

            # align on the onset of the first detected peak; skip embryos without one
            try:
                onset = trace.peak_bounds_indices[0][0]
            except IndexError:
                continue
            start_index = onset - 150  # keep 150 frames before onset
            end_index = hatching + 50  # and 50 (zeroed) frames after hatching
            if start_index < 0 or len(emb.activity) < duration:
                continue

            # trim dff and pad it to desired duration with zeros
            dff = dff[start_index:end_index]
            if len(dff) < duration:
                dff = np.pad(dff, (0, duration - len(dff)))
            elif len(dff) > duration:
                print(
                    f"{group.name} {exp.name} {emb.name} desired duration of {duration} is shorter than dff with {len(dff)} frames. warning. truncating dff"
                )
                dff = dff[:duration]

            # calculate stft (segment times `t` are in seconds since fs is in Hz)
            f, t, Zxx = spsig.stft(
                dff, fs, nperseg=fft_size, noverlap=noverlap, nfft=fft_size
            )
            freqs.append(f)
            times.append(t)
            Zxxs.append(Zxx)

    if not Zxxs:
        print(f"{group.name}: no embryo passed the onset/duration filters; no average spectrogram")
        continue

    avg_Zxx = np.mean(np.abs(np.array(Zxxs)), axis=0)
    # every trimmed trace has `duration` frames, so all STFTs share one time/frequency grid
    avg_t, avg_f = times[0], freqs[0]
    # ticks from the STFT time axis (s -> min); the per-embryo `time` of the last
    # iterated embryo may belong to a skipped, untrimmed recording
    minute_ticks = np.arange(0, avg_t[-1] / 60, 30, int)

    # spectrogram
    fig, ax = plt.subplots(figsize=(35, 5))
    spec = ax.pcolormesh(
        avg_t,
        avg_f,
        avg_Zxx,
        vmin=0,
        vmax=0.035,
        cmap="plasma",
        shading="nearest",
        snap=True,
    )
    ax.set_title(f"{group.name} (n={len(Zxxs)})", fontsize=fontsize)
    ax.set_ylabel("Hz", fontsize=fontsize)
    ax.set_xlabel("time (mins)", fontsize=fontsize)
    ax.set_xticks(minute_ticks * 60, minute_ticks)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    ax.set_ylim(0, 0.025)
    colorbar = fig.colorbar(spec, ax=ax)
    colorbar.ax.tick_params(labelsize=fontsize)
    plt.show()

In [None]:
"""Generate one average spectrogram per experiment of groups[1] (VGluT-)."""

# For each experiment of the selected group: STFT of every usable embryo trace
# (aligned on the first peak onset, zero-padded/truncated to `duration` frames),
# then plot the mean spectrogram magnitude across that experiment's embryos.

group = groups[1]

duration = 3600  # frames: every trimmed trace is zero-padded / truncated to this length
fs = 1 / 6  # sampling rate in Hz (assumes one frame every 6 s)
fft_size = 600  # STFT window length and FFT size, in frames
noverlap = 3 * fft_size // 4  # 75% overlap between consecutive windows
fontsize = 20

for exp in group.experiments.values():
    freqs = []
    times = []
    Zxxs = []
    for emb in exp.embryos.values():
        # get trace; copy ΔF/F so the zeroing below does not write into emb.trace.dff
        trace = emb.trace
        dff = np.array(trace.dff, copy=True)
        hatching = trace.trim_idx
        dff[hatching:] = 0  # zero data after hatching

        # align on the onset of the first detected peak; skip embryos without one
        try:
            onset = trace.peak_bounds_indices[0][0]
        except IndexError:
            continue
        start_index = onset - 150  # keep 150 frames before onset
        end_index = hatching + 50  # and 50 (zeroed) frames after hatching
        if start_index < 0 or len(emb.activity) < duration:
            continue

        # trim dff and pad it to desired duration with zeros
        dff = dff[start_index:end_index]
        if len(dff) < duration:
            dff = np.pad(dff, (0, duration - len(dff)))
        elif len(dff) > duration:
            print(
                f"{group.name} {exp.name} {emb.name} desired duration of {duration} is shorter than dff with {len(dff)} frames. warning. truncating dff"
            )
            dff = dff[:duration]

        # calculate stft (segment times `t` are in seconds since fs is in Hz)
        f, t, Zxx = spsig.stft(
            dff, fs, nperseg=fft_size, noverlap=noverlap, nfft=fft_size
        )
        freqs.append(f)
        times.append(t)
        Zxxs.append(Zxx)

    if not Zxxs:
        print(f"{group.name} {exp.name}: no embryo passed the onset/duration filters; no average spectrogram")
        continue

    avg_Zxx = np.mean(np.abs(np.array(Zxxs)), axis=0)
    # every trimmed trace has `duration` frames, so all STFTs share one time/frequency grid
    avg_t, avg_f = times[0], freqs[0]
    # ticks from the STFT time axis (s -> min); the per-embryo `time` of the last
    # iterated embryo may belong to a skipped, untrimmed recording
    minute_ticks = np.arange(0, avg_t[-1] / 60, 30, int)

    # spectrogram
    fig, ax = plt.subplots(figsize=(35, 5))
    spec = ax.pcolormesh(
        avg_t,
        avg_f,
        avg_Zxx,
        vmin=0,
        vmax=0.035,
        cmap="plasma",
        shading="nearest",
        snap=True,
    )
    ax.set_title(f"{exp.name} (n={len(Zxxs)})", fontsize=fontsize)
    ax.set_ylabel("Hz", fontsize=fontsize)
    ax.set_xlabel("time (mins)", fontsize=fontsize)
    ax.set_xticks(minute_ticks * 60, minute_ticks)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    ax.set_ylim(0, 0.025)
    colorbar = fig.colorbar(spec, ax=ax)
    colorbar.ax.tick_params(labelsize=fontsize)
    plt.show()