In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import scipy.signal as spsig

from pasna_analysis import DataLoader, Experiment, ExperimentConfig, Group, utils

# Per-genotype experiment configuration.
# Keys are experiment folder names under data/<genotype folder>/; `to_exclude`
# lists embryo numbers to drop from analysis (e.g. [3, 6, 9] -> EMB 3, 6, 9).
wt_folder = '25C'
wt_config = {
    '20240611_25C': ExperimentConfig(first_peak_threshold=30, to_exclude=[4, 9, 12, 13, 15, 17, 18, 19, 21, 23]),
    '20240919_25C': ExperimentConfig(first_peak_threshold=30, to_exclude=[2, 5, 6, 7, 11, 14, 16, 18]),
}

vglut_folder = 'vglut'
vglut_config = {
    '20240828-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[3, 6, 9]),
    '20240829-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[6, 8]),
    '20240830-vglutdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[6]),
}

vgat_folder = 'vgat'
vgat_config = {
    '20241008_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[2, 4]),
    '20241009_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[3, 7, 18, 19]),
    '20241010_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[8]),
    '20241011_vgatdf': ExperimentConfig(first_peak_threshold=30, to_exclude=[2, 4, 11, 5, 1, 8]),
}

# Data lives next to the notebook's parent directory: <cwd>/../data/<folder>/<experiment>.
DATA_DIR = Path.cwd().parent / 'data'


def load_experiments(folder, configs, dff_strategy='local_minima'):
    """Build an Experiment for every entry of `configs` found under DATA_DIR / folder.

    Args:
        folder: genotype sub-folder of DATA_DIR (e.g. '25C', 'vglut').
        configs: mapping of experiment folder name -> ExperimentConfig.
        dff_strategy: forwarded to Experiment; every group uses the same
            strategy so results stay comparable.

    Returns:
        dict mapping experiment name -> Experiment, in `configs` order.
    """
    return {
        exp_name: Experiment(
            DataLoader(DATA_DIR / folder / exp_name),
            config.first_peak_threshold,
            config.to_exclude,
            dff_strategy=dff_strategy,
        )
        for exp_name, config in configs.items()
    }


wt_experiments = load_experiments(wt_folder, wt_config)
vglut_experiments = load_experiments(vglut_folder, vglut_config)
vgat_experiments = load_experiments(vgat_folder, vgat_config)

wt = Group('WT', wt_experiments)
vglut = Group('VGluT-', vglut_experiments)
vgat = Group('VGAT-', vgat_experiments)

groups = [wt, vglut, vgat]

In [None]:
'''Plot ΔF/F for every embryo in a group: mark peaks and shade bursts.'''
group = groups[0]
fontsize = 15
minutes_before_onset = 20  # lead-in shown before the first burst

for exp in group.experiments.values():
    # Slice of embryos to plot; narrow these bounds to page through a large experiment.
    emb_start = 0
    emb_end = len(exp.embryos)
    embryos = exp.embryos[emb_start:emb_end]
    n_plots = len(embryos)
    if n_plots == 0:
        continue  # plt.subplots(0) raises; nothing to draw for this experiment

    # squeeze=False always returns a 2-D Axes array, so flatten() also works
    # when the experiment has a single embryo (otherwise a bare Axes is returned).
    fig, axes = plt.subplots(n_plots, squeeze=False, figsize=(35, 6 * n_plots))
    for ax, emb in zip(axes.flatten(), embryos):
        # Column 0 of emb.activity is divided by 60 and plotted as minutes.
        time = emb.activity[:, 0] / 60
        trace = exp.traces[emb.name]
        ax.plot(time, trace.dff, linewidth=3, color='k')

        ax.set_title(f'{exp.name}{emb.name} - dff', fontsize=fontsize)
        ax.set_ylabel('ΔF/F', fontsize=fontsize)
        ax.set_xlabel('time (mins)', fontsize=fontsize)

        # x axis: start 20 min before the first burst (onset), end at the trim
        # point (hatching). Fall back to the first sample if no bursts were found.
        bursts = trace.peak_bounds_indices
        onset = round(time[bursts[0][0]]) if len(bursts) > 0 else round(time[0])
        end_time = trace.time[trace.trim_idx] / 60
        ax.set_xlim(onset - minutes_before_onset, end_time)

        # Ticks every 60 min, relabelled relative to onset so onset reads as 0.
        hours = np.arange(onset, end_time, 60)
        ax.set_xticks(hours)
        ax.set_xticklabels((hours - onset).astype(int))
        ax.tick_params(axis='both', which='major', labelsize=fontsize)

        # y axis: fixed range and sparse ticks so embryos are visually comparable.
        ax.set_ylim(-0.2, 2)
        ax.set_yticks([0, 0.8])

        # Burst boundaries at their exact times (no truncation to whole minutes),
        # so the lines line up with the trace and with the peak markers below.
        for burst in bursts:
            burst_start = time[burst[0]]
            burst_end = time[burst[1]]
            ax.axvline(burst_start, color='blue', alpha=0.5, linewidth=4)
            ax.axvline(burst_end, color='purple', alpha=0.5, linewidth=4)
            ax.axvspan(burst_start, burst_end, color='orchid', alpha=0.1)

        # Peak times are divided by 60 to match the minute-scaled x axis.
        for peak in trace.peak_times:
            ax.axvline(peak / 60, color='steelblue', alpha=0.4, linewidth=3)

    # Lay out each figure; calling plt.tight_layout() after the loop only
    # affected the last figure created.
    fig.tight_layout(pad=3)