In [None]:
# Automatically re-import edited modules (e.g. dewan_h5_git) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from dewan_h5_git.dewan_h5 import DewanH5

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import signal
from scipy.stats import zscore


In [None]:
# Session file to analyse.
# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data directory so the notebook runs on other machines.
H5_PATH = '/opt/dev/repo/sniffing/dewan_h5_git/test/data/mouse200_sess1_D2025_2_24T16_51_53.h5'

# Keep a handle to the DewanH5 object after the `with` block exits; later cells
# read `h5_file.sniff`. NOTE(review): this assumes DewanH5 has loaded the data it
# exposes by the time the context manager closes — confirm in dewan_h5.
with DewanH5(H5_PATH) as h5:
    h5_file = h5

In [None]:
# Band-pass applied to every raw sniff trace: Chebyshev type II of order N=2
# (a band-pass doubles the effective order), at least 40 dB stop-band
# attenuation, pass band 0.01-100 Hz, designed for a 1000 Hz sampling rate and
# returned as second-order sections for use with `sosfiltfilt`.
# NOTE: the name `filter` shadows the Python builtin; it is kept because later
# cells refer to it.
filter = signal.cheby2(
    N=2,
    rs=40,
    Wn=[0.01, 100],
    btype='bandpass',
    output='sos',
    fs=1000,
)

In [None]:
def filter_sniff_traces(sniff_traces, filter, baseline=False, z_score=False, shift=False,
                        fs=1000, baseline_cutoff=1):
    """Filter every trace in a mapping and return the results as pandas Series.

    Each trace is zero-phase filtered with ``filter`` via
    ``scipy.signal.sosfiltfilt``. The optional steps then run in this order:
    baseline high-pass, z-score, shift.

    Parameters
    ----------
    sniff_traces : mapping
        Key -> trace. Each trace must expose ``.index`` (e.g. a ``pd.Series``);
        that index is carried over to the output Series.
    filter : array_like
        Second-order-sections filter (``output='sos'``) applied to every trace.
        The name shadows the builtin but is kept for backward compatibility.
    baseline : bool, default False
        If True, also apply a 2nd-order Butterworth high-pass at
        ``baseline_cutoff`` Hz, removing content below roughly that frequency.
    z_score : bool, default False
        If True, z-score each filtered trace with ``scipy.stats.zscore``.
    shift : bool, default False
        If True, subtract the first sample so each trace starts at exactly 0.
    fs : float, default 1000
        Sampling rate in Hz used to design the baseline high-pass.
    baseline_cutoff : float, default 1
        Cut-off frequency in Hz of the baseline high-pass.

    Returns
    -------
    dict
        Same keys as ``sniff_traces``. Each value is a ``pd.Series`` with the
        original trace's index.
    """
    # Design the baseline high-pass once, and only when it will actually be used.
    baseline_filter = (
        signal.butter(2, baseline_cutoff, 'highpass', output='sos', fs=fs)
        if baseline else None
    )

    filtered_traces = {}
    for name, trace in sniff_traces.items():
        filtered_trace = signal.sosfiltfilt(filter, trace)
        if baseline:
            filtered_trace = signal.sosfiltfilt(baseline_filter, filtered_trace)
        if z_score:
            filtered_trace = zscore(filtered_trace)
        if shift:
            filtered_trace = filtered_trace - filtered_trace[0]
        filtered_traces[name] = pd.Series(filtered_trace, index=trace.index)
    return filtered_traces

In [None]:
# Band-pass every trial's sniff trace, then apply the baseline high-pass and z-score it.
filtered_traces = filter_sniff_traces(
    sniff_traces=h5_file.sniff,
    filter=filter,
    z_score=True,
    baseline=True,
)

In [None]:
# Trial to inspect, and where the analysis window starts relative to index 0.
# Index 0 is where the plots below draw their vertical reference line; the
# `fv` naming suggests final-valve onset (units presumably ms) — confirm.
trial_number = 56
pre_fv_time = -1000

analysis_window = slice(pre_fv_time, None)
raw_data = h5_file.sniff[trial_number].loc[analysis_window]
filtered_trimmed_traces = filtered_traces[trial_number].loc[analysis_window]

def plot_multi_traces(traces):
    """Plot each trace in its own vertically stacked subplot with a shared x-axis.

    Every panel also gets a black vertical line at x=0 spanning the trace's
    min..max, and a magenta horizontal line at the trace mean. Line colors
    advance through the current matplotlib property cycle, one per trace.

    Parameters
    ----------
    traces : sequence of pd.Series
        Traces to plot; each Series' index is used as the x-axis. An empty
        sequence draws nothing.
    """
    if len(traces) == 0:
        return  # plt.subplots(0) would raise

    colors = plt.rcParams['axes.prop_cycle']()
    # squeeze=False keeps `axs` 2-D even for a single trace, so indexing works for n == 1.
    _, axs = plt.subplots(len(traces), figsize=(10, 10), sharex=True, squeeze=False)
    for ax, trace in zip(axs[:, 0], traces):
        color = next(colors)['color']
        _ = ax.plot(trace, color=color)
        # Series.min/max skip NaN, unlike the builtins.
        _ = ax.vlines(x=0, ymin=trace.min(), ymax=trace.max(), color='k')
        _ = ax.hlines(y=np.mean(trace), xmin=trace.index[0], xmax=trace.index[-1], color='m')

# Raw vs. filtered trace for the selected trial, stacked on a shared time axis.
traces_to_compare = [raw_data, filtered_trimmed_traces]
plot_multi_traces(traces=traces_to_compare)

In [None]:
# Peaks above zero are treated as inhales and peaks below zero as exhales.
# Between each pair of consecutive zero crossings at most one peak is kept: the
# strongest inhale or the strongest exhale, whichever deviates further from 0.

PEAK_MIN_DISTANCE = 50  # minimum samples between peaks (find_peaks `distance`)
PEAK_MIN_HEIGHT = 0.1   # minimum |value| of a peak, in units of the filtered trace


def _strongest_peak(peaks, use_max):
    """Return the most extreme peak of a magnitude-indexed peak Series, or None.

    `peaks` has peak values (magnitudes) as its index and peak times as its
    values, which is the layout of `inhales` / `exhales` below. The result is a
    length-1 Series in the same layout. Selection is positional (argmax/argmin),
    so two peaks with identical magnitudes cannot trigger a duplicate-label
    lookup; the first of the tied peaks wins.
    """
    if peaks.empty:
        return None
    magnitudes = peaks.index.to_numpy()
    position = magnitudes.argmax() if use_max else magnitudes.argmin()
    return peaks.iloc[[position]]


# Zero crossings: index label of the last sample before each sign change.
crossings = filtered_trimmed_traces.index[
    np.flatnonzero(np.diff(np.signbit(filtered_trimmed_traces.to_numpy())))
]

inhale_peak, props = signal.find_peaks(
    filtered_trimmed_traces, distance=PEAK_MIN_DISTANCE, height=PEAK_MIN_HEIGHT)
exhalation_peak, props_2 = signal.find_peaks(
    -filtered_trimmed_traces, distance=PEAK_MIN_DISTANCE, height=PEAK_MIN_HEIGHT)

# Peak Series keyed by peak value (index), holding the peak time (values).
# Positional .iloc is used because find_peaks already returns positions.
inhale_x = filtered_trimmed_traces.index[inhale_peak]
inhale_y = filtered_trimmed_traces.iloc[inhale_peak]
inhales = pd.Series(inhale_x, index=inhale_y)

exhale_x = filtered_trimmed_traces.index[exhalation_peak]
exhale_y = filtered_trimmed_traces.iloc[exhalation_peak]
exhales = pd.Series(exhale_x, index=exhale_y)

true_inhales_x, true_inhales_y, inhales_prev_cross = [], [], []
true_exhales_x, true_exhales_y, exhales_prev_cross = [], [], []

for first_cross, second_cross in zip(crossings[:-1], crossings[1:]):
    # Strongest inhale (largest value) and strongest exhale (most negative value)
    # whose time lies within [first_cross, second_cross].
    possible_inhale = _strongest_peak(inhales[inhales.between(first_cross, second_cross)], use_max=True)
    possible_exhale = _strongest_peak(exhales[exhales.between(first_cross, second_cross)], use_max=False)

    if possible_inhale is None and possible_exhale is None:
        continue
    if possible_inhale is None:
        is_inhale = False
    elif possible_exhale is None:
        is_inhale = True
    else:
        # Both present: keep whichever is further from zero; ties go to the inhale.
        is_inhale = not (abs(possible_exhale.index[0]) > possible_inhale.index[0])

    if is_inhale:
        true_inhales_x.append(possible_inhale.iloc[0])
        true_inhales_y.append(possible_inhale.index[0])
        inhales_prev_cross.append(first_cross)
    else:
        true_exhales_x.append(possible_exhale.iloc[0])
        true_exhales_y.append(possible_exhale.index[0])
        exhales_prev_cross.append(first_cross)

# Index = peak time; 'magnitude' = peak value; 'crossing' = zero crossing opening that interval.
true_inhales = pd.DataFrame({'magnitude': true_inhales_y, 'crossing': inhales_prev_cross}, index=true_inhales_x)
true_exhales = pd.DataFrame({'magnitude': true_exhales_y, 'crossing': exhales_prev_cross}, index=true_exhales_x)

In [None]:
# Re-reference time so that 0 becomes the zero crossing that opens the first
# inhale at or after index 0. This only happens when that crossing lies after 0.
# The `fv` naming suggests index 0 is final-valve onset — confirm.
true_inhales_post_fv = true_inhales.loc[0:]

if true_inhales_post_fv.empty:
    print('No inhale detected at or after t=0; skipping time re-alignment.')
else:
    # Read the scalar from the column rather than from a row (`.iloc[0]` on the
    # frame). A row mixing the float 'magnitude' and int 'crossing' is upcast to
    # float, which would silently turn every shifted index below into floats.
    first_crossing = true_inhales_post_fv['crossing'].iloc[0]

    if first_crossing > 0:
        true_inhales.index = true_inhales.index - first_crossing
        true_exhales.index = true_exhales.index - first_crossing
        true_inhales['crossing'] = true_inhales['crossing'] - first_crossing
        true_exhales['crossing'] = true_exhales['crossing'] - first_crossing
        crossings = crossings - first_crossing
        filtered_trimmed_traces.index = filtered_trimmed_traces.index - first_crossing

In [None]:
def _instantaneous_frequency(peak_times, time_units_per_second=1000):
    """Rate implied by each pair of consecutive peaks.

    Parameters
    ----------
    peak_times : sequence of numbers
        Peak times in ascending order. With the default
        ``time_units_per_second=1000`` they are treated as milliseconds
        (presumably the unit of the trace index — confirm).
    time_units_per_second : float, default 1000
        Number of time units in one second, used to convert intervals to Hz.

    Returns
    -------
    midpoints : list
        Time halfway between each pair of consecutive peaks.
    frequencies : list of float
        1 / interval in Hz for each pair, rounded to 2 decimals.
    """
    midpoints, frequencies = [], []
    for earlier, later in zip(peak_times[:-1], peak_times[1:]):
        delta_t = abs(later - earlier)
        frequencies.append(round(1 / (delta_t / time_units_per_second), 2))
        midpoints.append(earlier + (delta_t / 2))
    return midpoints, frequencies


inhale_times, inhale_frequencies = _instantaneous_frequency(true_inhales.index)
exhale_times, exhale_frequencies = _instantaneous_frequency(true_exhales.index)


In [None]:
# Two stacked panels sharing the time axis:
#   top    - filtered trace with zero crossings (green dots) and the selected
#            inhale/exhale peaks (crosses);
#   bottom - instantaneous inhale/exhale frequency.
fig, ax = plt.subplots(2, sharex=True, figsize=(10, 7))
trace_ax, freq_ax = ax

trace_ax.title.set_text(f'Trace w/ Peaks and Crossings | Trial: {trial_number}')
_ = trace_ax.plot(filtered_trimmed_traces, color='c')
_ = trace_ax.vlines(
    x=0,
    ymin=min(filtered_trimmed_traces),
    ymax=max(filtered_trimmed_traces),
    color='k',
)
_ = trace_ax.hlines(
    y=np.mean(filtered_trimmed_traces),
    xmin=filtered_trimmed_traces.index[0],
    xmax=filtered_trimmed_traces.index[-1],
    alpha=0.3,
    color='m',
)
_ = trace_ax.scatter(crossings, np.zeros(len(crossings)), marker="o", color='g')
_ = trace_ax.scatter(true_inhales.index, true_inhales['magnitude'].values, marker='x', color='r', label='inhale')
_ = trace_ax.scatter(true_exhales.index, true_exhales['magnitude'].values, marker='x', color='orange', label='exhale')
_ = trace_ax.legend()

freq_ax.title.set_text('Instantaneous Frequency')
freq_ax.set_ylabel('Frequency (Hz)')
freq_ax.set_xlabel('Time (ms)')
_ = freq_ax.plot(inhale_times, inhale_frequencies, color='red', label='inhale')
_ = freq_ax.plot(exhale_times, exhale_frequencies, color='orange', label='exhale')
_ = freq_ax.legend(['inhale', 'exhale'])