In [60]:
import h5py
import numpy as np
import polars as pl
import plotnine as p9
from scipy.signal import convolve
from tqdm.notebook import tqdm
from scipy.fft import fft, fftfreq, ifft
from scipy import signal
from pathlib import Path
from ipywidgets import widgets as widgets
from IPython.display import display

In [61]:
# Discover every anaesthesia recording in the dataset directory.
# Path.glob yields files in arbitrary, filesystem-dependent order; sort so that
# selecting a recording by index later in the notebook is reproducible.
anaesthesia_datasets = sorted(Path("/datasets/octave/").glob("anaesthetic*.hdf5"))
print(len(anaesthesia_datasets), "datasets found")

4 datasets found


In [84]:
# Sampling frequency (Hz) used for all filter design below; assumed to be the
# rate of the raw recording — TODO confirm against the HDF5 metadata.
fs = 4800
# Notch frequencies: 50 Hz and every 50 Hz harmonic up to 950 Hz (19 values;
# np.arange excludes the 1000 Hz stop). Presumably targets mains interference.
notch_freqs = np.arange(50, 1000, 50)
filter_order = 4  # order passed to signal.butter for the bandpass
low_freqs = 30  # bandpass lower cut-off, Hz
high_freqs = 150  # bandpass upper cut-off, Hz
downsample_factor = 1  # integer decimation factor; 1 keeps the original rate
window_size = 1  # analysis window length, in seconds

In [100]:
# as an example, we will use the first dataset
df_path = anaesthesia_datasets[1]
print(f"Using dataset {df_path}")

Using dataset /datasets/octave/anaesthetic_s03_2024.08.21_15.55.08.hdf5


In [101]:
# Read the whole "RawData/Samples" dataset into memory and transpose it so that
# axis 1 is time, as the preprocessing cell expects (it uses data.shape[1] as
# the signal length). NOTE(review): presumably stored as (samples, channels).
with h5py.File(df_path, mode="r") as recording:
    raw_data = recording["RawData/Samples"][()].T


In [102]:
# Preprocessing: notch out 50 Hz harmonics, bandpass to the band of interest,
# optionally downsample, then lay out non-overlapping analysis windows.
# raw_data is never modified: every filter below returns a new array.
data = raw_data
for notch_freq in notch_freqs:
    # Zero-phase (forward-backward) IIR notch with quality factor 30.
    b, a = signal.iirnotch(notch_freq, 30, fs)
    data = signal.filtfilt(b, a, data)

# Apply the bandpass forward-backward as well, consistent with the notches:
# a causal sosfilt would add phase delay and a zero-initial-state transient
# at the start of the recording (i.e. in the first analysis window).
sos = signal.butter(filter_order, [low_freqs, high_freqs], fs=fs, btype='bandpass', analog=False, output="sos")
data = signal.sosfiltfilt(sos, data)

# signal.decimate always runs its anti-aliasing low-pass, even for q=1, so only
# call it when we actually downsample.
if downsample_factor > 1:
    data = signal.decimate(data, downsample_factor, axis=1)

fs_ds = fs / downsample_factor  # Adjusted sampling frequency
signal_length = data.shape[1]
window_size_samples = int(window_size * fs_ds)
# The +1 keeps the final window when it ends exactly on the last sample
# (e.g. a signal of exactly 2 windows yields starts [0, window_size_samples]).
window_starts = np.arange(0, signal_length - window_size_samples + 1, window_size_samples)
n_windows = len(window_starts)

f"{n_windows=}, {data.shape}, {len(data)=}, {window_size_samples=},"

'n_windows=547, (2, 2628864), len(data)=2, window_size_samples=4800,'

In [103]:
def twitch_kernel(t, tau, alpha=1.0):
    """Evaluate the alpha-function kernel ``alpha * t * exp(-t / tau)`` at times ``t``.

    Args:
        t: array of time points (same units as ``tau``).
        tau: decay time constant.
        alpha: amplitude scale factor.

    Returns:
        Array of kernel values with the same shape as ``t``.
    """
    decay = np.exp(-t / tau)
    return alpha * t * decay

def normalised_twitch_kernel(t, tau):
    """Return ``twitch_kernel(t, tau)`` scaled to unit energy on the grid ``t``.

    The kernel ``k`` is divided by ``sqrt(sum(k**2) * dt)`` with
    ``dt = t[1] - t[0]``, so the Riemann-sum estimate of its squared L2 norm
    is 1. Assumes ``t`` is uniformly spaced with at least two samples.
    """
    raw = twitch_kernel(t, tau)
    energy = np.sum(raw ** 2) * (t[1] - t[0])
    return raw / np.sqrt(energy)

In [104]:
def calculate_spike_freq(windowed_data, fs=4800, tikhonov_lambda=1e-4, tau=0.01,
                         freq_min=20, freq_max=150):
    """Estimate the power spectrum of the deconvolved signal in one window.

    The window is deconvolved with the unit-energy twitch kernel from
    ``normalised_twitch_kernel`` using Tikhonov regularisation in the frequency
    domain, ``S = conj(H) * Y / (conj(H) * H + tikhonov_lambda)``. The power
    spectrum of the result is then restricted to ``[freq_min, freq_max]`` Hz.

    Parameters
    ----------
    windowed_data : array_like
        Signal window with time on the last axis. A 1-D array is a single
        channel. For an N-D array (e.g. ``(channels, samples)`` as passed by
        ``update_plot``), each channel is deconvolved independently and the
        per-channel power spectra are averaged.
    fs : float
        Sampling frequency in Hz.
    tikhonov_lambda : float
        Regularisation added to ``|H|**2``; larger values limit noise
        amplification where the kernel has little energy.
    tau : float
        Twitch-kernel time constant in seconds.
    freq_min, freq_max : float
        Inclusive frequency band (Hz) kept in the output and shown on the plot.

    Returns
    -------
    fig : plotnine.ggplot
        Line plot of power against frequency over the selected band.
    pl_df : polars.DataFrame
        Columns ``frequency`` (Hz) and ``power`` for the selected band.

    Raises
    ------
    ValueError
        If ``windowed_data`` has no samples along its last axis.
    """
    samples = np.atleast_1d(np.asarray(windowed_data, dtype=float))
    n_samples = samples.shape[-1]
    if n_samples == 0:
        raise ValueError("windowed_data must contain at least one sample")
    # One row per channel; time runs along the last axis.
    channels = samples.reshape(-1, n_samples)

    # 50 ms of kernel support sampled at fs.
    t_kernel = np.arange(0, 0.05, 1 / fs)
    kernel = normalised_twitch_kernel(t_kernel, tau)

    # Zero-pad the time axis to the next power of two that also fits the
    # kernel, so the kernel can be embedded in a buffer of the same length.
    n_fft = 2 ** int(np.ceil(np.log2(max(n_samples, len(kernel)))))
    signal_padded = np.zeros((channels.shape[0], n_fft))
    signal_padded[:, :n_samples] = channels
    kernel_padded = np.zeros(n_fft)
    kernel_padded[:len(kernel)] = kernel

    Y = fft(signal_padded, axis=-1)
    H = fft(kernel_padded)
    H_conj = np.conj(H)
    S_hat_tikhonov = (H_conj / (H_conj * H + tikhonov_lambda)) * Y
    s_deconv_tikhonov = np.real(ifft(S_hat_tikhonov, axis=-1))
    # Remove each channel's DC offset before taking its spectrum.
    s_deconv_tikhonov -= s_deconv_tikhonov.mean(axis=-1, keepdims=True)

    # One-sided spectrum (non-negative frequencies), averaged over channels.
    freqs = fftfreq(n_fft, d=1 / fs)[:n_fft // 2]
    S_deconv_fft = np.abs(fft(s_deconv_tikhonov, axis=-1))[:, :n_fft // 2]
    power_spectrum = (S_deconv_fft ** 2).mean(axis=0)

    mask = (freqs >= freq_min) & (freqs <= freq_max)
    pl_df = pl.DataFrame({
        "frequency": freqs[mask],
        "power": power_spectrum[mask],
    })

    fig = (
        p9.ggplot(pl_df, p9.aes(x="frequency", y="power")) +
        p9.geom_line(color='orange') +
        # The y axis uses a linear scale, so it is labelled as plain power.
        p9.labs(x="Frequency (Hz)", y="Power") +
        p9.scale_x_continuous(breaks=np.arange(0, freq_max, 20)) +
        p9.coord_cartesian(xlim=(freq_min, freq_max), ylim=(1e-1, 1e4)) +
        p9.theme_matplotlib() +
        p9.theme(figure_size=(3.5, 2.0), dpi=100, text=p9.element_text(size=8))
    )

    return fig, pl_df

In [None]:

# Define the function that updates the plot
def update_plot(index, tau=0.01):
    """Show the deconvolved power spectrum of analysis window number ``index``.

    Slices all channels of the preprocessed ``data`` over the window starting
    at ``window_starts[index]`` and renders the figure from
    ``calculate_spike_freq``. ``tau`` is the twitch-kernel time constant.
    """
    start = window_starts[index]
    stop = start + window_size_samples
    fig, _ = calculate_spike_freq(data[:, start:stop], fs=fs_ds, tau=tau, tikhonov_lambda=1e-4)
    fig.show()

# Largest valid index into window_starts; shared by the slider and the Play button.
last_window_index = len(window_starts) - 1

# Slider selecting which analysis window to display
window_start_widget = widgets.IntSlider(
    value=0,
    min=0,
    max=last_window_index,
    step=1,
    description='Window:',
    continuous_update=False
)

# Slider for the twitch-kernel time constant tau (seconds) passed to update_plot
tau_widget = widgets.FloatSlider(
    value=0.01,
    min=0.001,
    max=0.1,
    step=0.001,
    description='Tau:',
    continuous_update=False
)

# Play button: while playing, advances the window index by 1 every 100 ms
play = widgets.Play(
    value=0,
    min=0,
    max=last_window_index,
    step=1,
    interval=100,  # milliseconds between steps
    description="Press play",
    disabled=False
)

# Link the Play button's value to the window slider (client-side link)
widgets.jslink((play, 'value'), (window_start_widget, 'value'))

# Re-run update_plot whenever either slider changes; the plot renders into `out`
out = widgets.interactive_output(update_plot, {'index': window_start_widget, 'tau': tau_widget})

# Display everything together
controls = widgets.HBox([play, window_start_widget, tau_widget])
display(widgets.VBox([controls, out]))


VBox(children=(HBox(children=(Play(value=0, description='Press play', max=546), IntSlider(value=0, continuous_…