## Comparing different groups of experiments - all



This notebook compares groups of experiments (WT, VGluT-, VGAT-), using the metrics described in each cell.

Each group takes a dictionary of `'experiment name': Experiment`.
To set up the groups, edit the per-group config dictionaries (`wt_config`, `vglut_config`, `vgat_config`) in the next cell — each maps an experiment name to its `ExperimentConfig` — and then add each `Group` object to the `groups` list (last line of the next cell).

In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind 
from scipy import signal as spsig
import seaborn as sns

from pasna_analysis import Experiment, ExperimentConfig, Group, utils

# Per-experiment configuration for each group: experiment folder name -> ExperimentConfig
# (first-peak threshold and the embryo numbers to exclude).
wt_config = {
    '20240611_25C': ExperimentConfig(first_peak_threshold=30, to_exclude=[4,9,12,13,15,17,18,19,21,23]),
    '20240919_25C': ExperimentConfig(first_peak_threshold=30, to_exclude=[2,5,6,7,11,14,16,18])
}

vglut_config = {
    '20240828-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[3,6,9]), # EMB 3,6,9
    '20240829-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[6,8]), # EMB 6,8
    '20240830-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[6]) # EMB 6
}

vgat_config = {
    '20241008_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[2,4]),
    '20241009_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[7,18,19]),
    '20241010_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[8]),
    '20241011_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[2,4,11,5]),
}

# Experiments live under <parent of cwd>/data/<subdir>/<experiment name>.
# NOTE(review): relies on the notebook being run from a directory one level below
# the folder that contains `data` — confirm.
DATA_DIR = Path.cwd().parent.joinpath('data')


def load_experiments(configs, subdir, dff_strategy='local_minima'):
    """Build one Experiment per entry of ``configs``.

    Parameters
    ----------
    configs : dict[str, ExperimentConfig]
        Experiment folder name -> its configuration.
    subdir : str
        Folder under ``DATA_DIR`` holding the experiment folders.
    dff_strategy : str
        Forwarded to ``Experiment``.

    Returns
    -------
    dict[str, Experiment]
        Experiments keyed by the same names as ``configs``.
    """
    return {
        name: Experiment(
            DATA_DIR.joinpath(subdir, name),
            first_peak_threshold=config.first_peak_threshold,
            to_exclude=config.to_exclude,
            dff_strategy=dff_strategy,
        )
        for name, config in configs.items()
    }


wt_experiments = load_experiments(wt_config, '25C')
vglut_experiments = load_experiments(vglut_config, 'vglut')
vgat_experiments = load_experiments(vgat_config, 'vgat')

wt = Group('WT', wt_experiments)
vglut = Group('VGluT-', vglut_experiments)
vgat = Group('VGAT-', vgat_experiments)

groups = [wt, vglut, vgat]

In [None]:
"""Sample size"""

for group in groups:
    num_group = 0
    for exp in group.experiments.values():
        num_exp = 0
        for emb in exp.embryos:
            num_group = num_group + 1
            num_exp = num_exp + 1
        print(exp.name, num_exp)
    print(group.name, num_group)

In [None]:
"""Developmental times at first peak."""

data = {"dev_fp": [], "group": []}

for group in groups:
    for exp in group.experiments.values():
        num = 0
        for emb in exp.embryos.values():
            num = num + 1
            trace = emb.trace
            time_first_peak = trace.peak_times[0]
            dev_time_first_peak = emb.get_DT_from_time(time_first_peak)

            data["dev_fp"].append(dev_time_first_peak)
            data["group"].append(group.name)
df = pd.DataFrame(data)
group_avgs = {"group": [], "fp_avg": []}

for group in groups:
    g = df[df["group"] == group.name]
    group_avgs["group"].append(group.name)
    group_avgs["fp_avg"].append(g["dev_fp"].mean())

font_scale = 1.5
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.swarmplot(data=data, x="group", y="dev_fp", hue="group", size=7)
ax.set_ylabel("Dev time")
ax.set_xlabel("")
ax.set_title("Dev time at first peak")

wt = df[df["group"] == "WT"]
vglut = df[df["group"] == "VGluT-"]
vgat = df[df["group"] == "VGAT-"]

print("wt vs vglut", ttest_ind(wt["dev_fp"], vglut["dev_fp"]))
print("wt vs vgat", ttest_ind(wt["dev_fp"], vgat["dev_fp"]))
plt.savefig("wt_vglut_vgat_onsettime.svg")

In [None]:
"""Developmental times when hatching."""

data = {"dev_hatching": [], "group": []}

for group in groups:
    for exp in group.experiments.values():
        num = 0
        for emb in exp.embryos.values():
            num = num + 1
            trace = emb.trace
            time_hatching = trace.time[trace.trim_idx]
            dev_time_first_peak = emb.get_DT_from_time(time_hatching)

            data["dev_hatching"].append(dev_time_first_peak)
            data["group"].append(group.name)
df = pd.DataFrame(data)
group_avgs = {"group": [], "hatch_avg": []}

for group in groups:
    g = df[df["group"] == group.name]
    group_avgs["group"].append(group.name)
    group_avgs["hatch_avg"].append(g["dev_hatching"].mean())

font_scale = 1.5
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.swarmplot(data=data, x="group", y="dev_hatching", hue="group", size=7)
ax.set_ylabel("Dev time")
ax.set_xlabel("")
ax.set_title("Dev time when hatching")

wt = df[df["group"] == "WT"]
vglut = df[df["group"] == "VGluT-"]
vgat = df[df["group"] == "VGAT-"]

print("wt vs vglut", ttest_ind(wt["dev_hatching"], vglut["dev_hatching"]), sep="\n")
print("wt vs vgat", ttest_ind(wt["dev_hatching"], vgat["dev_hatching"]), sep="\n")
plt.savefig("wt_vglut_vgat_onsettime.svg")

In [None]:
"""SNA duration."""

data = {"group": [], "duration": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            data["group"].append(group.name)
            duration = (trace.time[trace.trim_idx] - trace.peak_times[0]) / 60
            data["duration"].append(duration)

df = pd.DataFrame(data)

font_scale = 1.5
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.swarmplot(data=data, x="group", y="duration", hue="group", size=7)
ax.set_title("SNA duration")
ax.set_ylabel("time (mins)")
ax.set_xlabel("Group")

In [None]:
"""Number of episodes during SNA."""

data = {"group": [], "num_eps": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            data["group"].append(group.name)
            data["num_eps"].append(len(trace.peak_idxes))

df = pd.DataFrame(data)

font_scale = 1.5
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.swarmplot(data=data, x="group", y="num_eps", hue="group")
ax.set_title("Number of episodes")
ax.set_ylabel("# eps")
ax.set_xlabel("Group")

In [None]:
"""CDF of peak developmental times"""

data = {"dev_time": [], "group": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            dev_times = [emb.get_DT_from_time(t) for t in trace.peak_times]
            data["dev_time"].extend(dev_times)
            data["group"].extend([group.name] * len(dev_times))

font_scale = 1.5
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.ecdfplot(data=data, x="dev_time", hue="group")
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
ax.set_xlim([1.9, 2.9])
# ax.set_title('CDF developmental times of peaks')
ax.set_ylabel("Proportion")
ax.set_xlabel("Developmental Time")
plt.savefig("wt_vglut_vgat_devtimepeak.svg")

In [None]:
"""Peak amplitudes for each episode."""

num_of_peaks = 15
data = {"peak_amp": [], "group": [], "peak_idx": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            t = emb.trace
            for i, amp in zip(range(num_of_peaks), t.peak_amplitudes):
                data["peak_amp"].append(amp)
                data["group"].append(group.name)
                data["peak_idx"].append(i)

amps = pd.DataFrame(data)

fig, ax = plt.subplots(figsize=(10, 5))

font_scale = 1.5
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.pointplot(data=data, x="peak_idx", y="peak_amp", hue="group", linestyle="None")
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
# ax.set_title(f'Burst amplitudes')
ax.set_xlabel("Burst #")
ax.set_ylabel("\u0394F/F")
plt.savefig("wt_vglut_vgat_amps.svg")

for peak in range(num_of_peaks):
    wt = amps[(amps["group"] == "WT") & (amps["peak_idx"] == peak)]
    vglut = amps[(amps["group"] == "VGluT-") & (amps["peak_idx"] == peak)]
    _, p_value = ttest_ind(wt["peak_amp"], vglut["peak_amp"])
    print(f"{peak} wt vs vglut, pval = {p_value}")

for peak in range(num_of_peaks):
    wt = amps[(amps["group"] == "WT") & (amps["peak_idx"] == peak)]
    vgat = amps[(amps["group"] == "VGAT-") & (amps["peak_idx"] == peak)]
    _, p_value = ttest_ind(wt["peak_amp"], vgat["peak_amp"])
    print(f"{peak} wt vs vgat, pval = {p_value}")

In [None]:
"""Developmental time for each episode."""

data = {"group": [], "dev_time": [], "idx": []}

for group in groups:
    for exp in group.experiments.values():
        for emb_n, emb in enumerate(exp.embryos.values()):
            trace = emb.trace
            for i, t in zip(range(15), trace.peak_times):
                data["group"].append(group.name)
                data["dev_time"].append(emb.get_DT_from_time(t))
                data["idx"].append(i)

dev_times = pd.DataFrame(data)


f, ax = plt.subplots(figsize=(10, 6))
font_scale = 2
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.pointplot(
    data=dev_times,
    x="idx",
    y="dev_time",
    hue="group",
    alpha=0.7,
    errorbar="ci",
    linestyle="None",
)
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
# ax.set_title('Dev time per burst')
ax.set_xlabel("Burst #")
ax.set_ylabel("Dev time")
plt.savefig("wt_vglut_vgat_devtimepeak.svg")

In [None]:
"""Intervals between each episode."""

data = {"group": [], "interval": [], "idx": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            for i, interval in zip(range(15), trace.peak_intervals):
                data["group"].append(group.name)
                data["interval"].append(interval / 60)
                data["idx"].append(i)

inter = pd.DataFrame(data)

f, ax = plt.subplots(figsize=(10, 6))
font_scale = 2
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.pointplot(data=inter, x="idx", y="interval", hue="group", linestyle="None")
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
# ax.set_title('Intervals by burst')
ax.set_xlabel("Interval #")
ax.set_ylabel("Interval (min)")
plt.savefig("wt_vglut_vgat_interval.svg")

for peak in range(num_of_peaks):
    wt = inter[(inter["group"] == "WT") & (inter["idx"] == peak)]
    vglut = inter[(inter["group"] == "VGluT-") & (inter["idx"] == peak)]
    _, p_value = ttest_ind(wt["interval"], vglut["interval"])
    print(f"{peak} wt vs vglut, pval = {p_value}")

for peak in range(num_of_peaks):
    wt = inter[(inter["group"] == "WT") & (inter["idx"] == peak)]
    vgat = inter[(inter["group"] == "VGAT-") & (inter["idx"] == peak)]
    _, p_value = ttest_ind(wt["interval"], vgat["interval"])
    print(f"{peak} wt vs vgat, pval = {p_value}")

In [None]:
"""Duration of each peak."""

data = {"group": [], "duration": [], "idx": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            for i, duration in zip(range(15), trace.peak_durations):
                data["group"].append(group.name)
                data["duration"].append(duration / 60)
                data["idx"].append(i)

amps = pd.DataFrame(data)

f, ax = plt.subplots(figsize=(10, 6))
font_scale = 2
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.pointplot(data=amps, x="idx", y="duration", hue="group", linestyle="None")
sns.move_legend(
    ax, "lower center", bbox_to_anchor=(0.5, 1.1), ncol=3, title=None, frameon=False
)
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
# ax.set_title('Durations by peak')
ax.set_xlabel("Peak #")
ax.set_ylabel("Duration (min)")
plt.savefig("wt_vglut_vgat_durations.svg")

In [None]:
"""Rise times"""

data = {"group": [], "rise_times": [], "idx": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            for i, rise in zip(range(15), trace.peak_rise_times):
                data["group"].append(group.name)
                data["rise_times"].append(rise)
                data["idx"].append(i)

amps = pd.DataFrame(data)

f, ax = plt.subplots(figsize=(10, 6))
font_scale = 2
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.pointplot(data=amps, x="idx", y="rise_times", hue="group", linestyle="None")
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
# ax.set_title('Peak rise times')
ax.set_xlabel("Peak #")
ax.set_ylabel("Duration (min)")
plt.savefig("wt_vglut_vgat_risetime.svg")

In [None]:
"""Decay times"""

data = {"group": [], "decay_times": [], "idx": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            for i, decay in zip(range(15), trace.peak_decay_times):
                data["group"].append(group.name)
                data["decay_times"].append(decay)
                data["idx"].append(i)

amps = pd.DataFrame(data)

f, ax = plt.subplots(figsize=(10, 6))
font_scale = 2
sns.set_theme(style="whitegrid", palette="colorblind", font_scale=font_scale)
ax = sns.pointplot(
    data=amps, x="idx", y="decay_times", hue="group", errorbar="ci", linestyle="None"
)
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
# ax.set_title('Peak decay times')
ax.set_xlabel("Peak #")
ax.set_ylabel("Duration (min)")
plt.savefig("wt_vglut_vgat_decaytime.svg")

In [None]:
"""Plots AUC, grouped by bins."""

data = {"group": [], "auc": [], "bin": []}

n_bins = 5
first_bin = 2
bin_width = 0.2

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            dev_time_at_peaks = emb.get_DT_from_time(trace.peak_times)
            bins = [first_bin + j * bin_width for j in range(n_bins)]
            bin_idxs = utils.split_in_bins(dev_time_at_peaks, bins)
            data["group"].extend([str(group.name)] * len(trace.peak_aucs))
            data["auc"].extend(trace.peak_aucs)
            data["bin"].extend(bin_idxs)

# print(len(data['group']), len(data['auc']), len(data['bin']))

auc = pd.DataFrame(data)

# add a last bin point to generate the labels
bins.append(first_bin + bin_width * n_bins)

f, ax = plt.subplots(figsize=(10, 6))
x_labels = [f"{s}~{e}" for (s, e) in zip(bins[:-1], bins[1:])]
ax = sns.pointplot(data=data, x="bin", y="auc", hue="group", linestyle="None")
ax.set_xticks(ticks=list(range(n_bins)), labels=x_labels)
ax.set_title(f"Binned AUC - Exp {exp.name}")
ax.set_ylabel("AUC [activity*t]")

In [None]:
"""Local peaks for each peak."""

data = {"num_local_peaks": [], "idx": [], "group": []}

for group in groups:
    for exp in group.experiments.values():
        for emb in exp.embryos.values():
            trace = emb.trace
            local_peaks = trace.compute_local_peaks(height=0.03, prominence=0.02)
            for i, lp in zip(range(15), local_peaks):
                data["num_local_peaks"].append(lp)
                data["idx"].append(i)
                data["group"].append(group.name)

ax = sns.pointplot(
    data=data,
    x="idx",
    y="num_local_peaks",
    hue="group",
    errorbar="ci",
    linestyle="None",
)
sns.move_legend(
    ax,
    "lower center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    title=None,
    frameon=False,
)
ax.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])
ax.set_yticks([2, 4, 6, 8, 10, 12])
ax.set_title("Local peaks for each peak")
ax.set_ylabel("Num local peaks")
ax.set_xlabel("Burst #")

In [None]:
"""Plot representative traces & psd plots for each group"""

reps = {"20240919_25C": 21, "20241011_vgatdf": 3, "20240829-vglutdf": 1}  # index

fig, ax = plt.subplots(len(reps), figsize=(35, 5 * (len(reps))))

index = 0

for group in groups:
    for exp in group.experiments.values():
        if exp.name in reps.keys():
            emb_index = "emb" + str(reps[exp.name])
            emb = exp.embryos[emb_index]

            time = emb.activity[:, 0] / 60
            trace = emb.trace

            # plot trace
            # fig, ax = plt.subplots(figsize=(35, 5))
            ax[index].plot(time, trace.dff, color="black", linewidth=4)
            print(f"{group.name} {exp.name} {emb.name} ")

            # set title/labels
            fontsize = 35
            ax[index].tick_params(axis="both", which="major", labelsize=fontsize)

            ax[index].set_title(f"{group.name}", fontsize=fontsize, x=-0.1, y=0.5)
            ax[index].set_ylabel("Î”F/F", fontsize=fontsize)
            ax[index].set_xlabel("time (mins)", fontsize=fontsize)

            # x axis - trim to before onset & at hatching
            start = round(time[trace.peak_bounds_indices[0][0]])
            end_time = trace.time[trace.trim_idx] / 60
            ax[index].set_xlim(
                start - 20, end_time
            )  # start plotting 20 mins before hatching

            # x axis - adjust tick labels to make the start time 0
            hours = np.arange(start, end_time, 60)
            ax[index].set_xticks(hours)
            labels = np.arange(0, end_time - start, 60)
            ax[index].set_xticklabels(labels)
            ax[index].tick_params(axis="both", which="major", labelsize=fontsize)

            # y axis - trim & tick labels
            ax[index].set_ylim(-0.1, 2)
            amps = np.arange(0, ax[index].get_ylim()[1] + 0.1, 1)
            ax[index].set_yticks(amps)

            # remove boarders
            ax[index].spines["top"].set_visible(False)
            ax[index].spines["right"].set_visible(False)

            index = index + 1

plt.tight_layout(pad=5)

# plt.savefig('test.svg')

In [None]:
"""Plot representative traces & psd plots for each group"""

reps = {
    "20240919_25C": [12, 3],
    "20240829-vglutdf": [0, 1, 3],
    "20241011_vgatdf": [0],
    "20241008_vgatdf": [0],
}  # index

total_embs = 0
for rep in reps:
    for emb in reps[rep]:
        total_embs = total_embs + 1

fig, ax = plt.subplots(2 * total_embs, figsize=(35, 5 * (2 * total_embs)))

index = 0

for rep in reps:
    for group in groups:
        if exp in group.experiments.keys():
            exp = group.experiments[rep]
            for emb_index in reps[rep]:
                emb = "emb" + str(exp.embryos[emb_index])
                print(f"{group.name} {exp.name} {emb.name}")

                time = emb.activity[:, 0] / 60
                trace = emb.trace
                start_index = (
                    trace.peak_bounds_indices[0][0] - 20
                )  # 20 mins before onset
                end_index = trace.trim_idx
                dff = trace.dff[start_index:end_index]
                print(len(dff))
                time = time[start_index:end_index]
                ax[index].plot(time, dff, color="black", linewidth=3)
                ax[index].set_ylim(-0.1, 2)
                index = index + 1

                # step 3 - get time and calcium activity
                t = emb.activity[:, 0] / 60
                calciumOsc = dff  # (here, I'm using dff instead of activity)

                # step 6 - compute FFT of the signal
                N = 512  # number of bins
                G = np.fft.fft(calciumOsc, N)

                # step 7 - obtain and normalize the power spectral density
                P = G * np.conj(G) / N
                P = P[1 : int(N / 2 + 1)]
                P[2 : int(N / 2)] = 2 * P[2 : int(N / 2)]

                # step 8 - map PSD (P) onto frequency array
                dt = 6
                fs = 1 / dt
                vector = np.arange(0, N / 2)
                f = fs * vector / N  # in Hz
                mf = 1000 * f  # in mHz

                ax[index].plot(mf, P, color="black", linewidth=3)
                print(f"{group.name} {exp.name} {emb.name} ")
                fontsize = 25
                ax[index].set_title(
                    f"{group.name}\n{exp.name}\n{emb.name}",
                    fontsize=fontsize,
                    x=-0.1,
                    y=0.5,
                )
                ax[index].tick_params(axis="both", which="major", labelsize=fontsize)
                ax[index].set_ylabel("PSD", fontsize=fontsize)
                ax[index].set_xlabel("mHz", fontsize=fontsize)

                ax[index].set(xlim=(0, 10))
                ticks = np.arange(0, ax[index].get_xlim()[1], 1)
                ax[index].set_xticks(ticks)

                # new step 9 - find local maxima
                global_max = np.max(P)
                local_maxima = spsig.find_peaks(P, prominence=(global_max * 0.20))[
                    0
                ]  # local max are >20% prominence of global max
                for max in local_maxima:
                    psd = P[max]
                    mfreq = mf[max]  # mHz
                    period = round((pow((mfreq / 1000), -1)) / 60, 1)  # min
                    ax[index].axvline(mfreq)
                    # print(f'PSD {psd} | Freq {mfreq} mHz | Periodicity 1/{period} min-1')
                    ax[index].text(mfreq, 0.001, f"{period}", fontsize=20)

                index = index + 1

plt.tight_layout(pad=5)

# plt.savefig('test.svg')