In [1]:
from __future__ import annotations

from pathlib import Path

import mne
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import mne_connectivity
from cpl_pipeline.extraction.oscillations import LfpSignal, get_data
from cpl_pipeline.logs import logger

# Setup functions, classes and styles
logger.setLevel("INFO")
sns.set_style("darkgrid")
plt.rcParams.update({
    'font.size': 20,
    'font.weight': 'bold',
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
    'text.color': 'white',
    'axes.labelcolor': 'white',
    'xtick.color': 'white',
    'ytick.color': 'white',
    'axes.facecolor': '#282c36',  # dark inner background
    'axes.edgecolor': 'white',
    'figure.facecolor': 'black',  # black outer background
    'grid.color': 'white',
    'grid.alpha': 0.7,
    'grid.linestyle': '--'
})

data_path = Path().home() / "data" / "extracted" / "dk1"
file = list(data_path.glob("*dk1*.h5"))[0]
df_s, df_t, events, event_times = get_data(file)
unique_events = np.unique(events)
key = {
    "0": "dug incorrectly",
    "1": "dug correctly",
    "x": "crossed over (no context)",
    "b": "crossed over into the black room",
    "w": "crossed over into the white room",
}
fs = 2000
lfp = LfpSignal(df_s, df_t, event_arr=events, ev_times_arr=event_times, fs=2000, filename=file,
                exclude=["LFP1_AON", "LFP2_AON"])
lfp.bandpass = (0.1, 100)
save_outer = Path().home() / 'data' / 'figures'
save_outer.mkdir(exist_ok=True, parents=True)

INFO: Filtering data with bandpass (0.1, 100)


# Data Exploration
1) Show the first 4 channels

In [None]:
# Plot each channel ------------------------
num_channels = len(lfp.spikes.columns)
first_4_seconds = int(2 * fs)
ticks = np.arange(0, first_4_seconds + fs, fs)
labs = ["0", "1", "2"]

fig, axes = plt.subplots(num_channels, 1, figsize=(50, 40), sharex=True)
plt.xlabel("Time (ms)")
plt.ylabel("Voltage (uV)")
plt.xticks(ticks, labs)

for i, channel in enumerate(lfp.spikes.columns):
    ax = axes[i]
    ax.plot(lfp.spikes[channel][:first_4_seconds])
    t = lfp.spikes[channel][:first_4_seconds]
    ax.set_ylabel("Voltage (uV)")
    ax.set_title(channel)
    ax.grid(True)

fig.suptitle(f"First 2s of {lfp.filename.stem}", fontsize=50)
plt.savefig(save_outer / 'channels.png', facecolor='black', bbox_inches='tight',)

# Show morlet wavelet transform for a single channel, each event type

In [None]:
pre_t = 1
post_t = 0
x = 0
for letter, digit, spikes_window_df in lfp.get_windows(pre_t, post_t):
    
    if spikes_window_df.empty:
        print(f"Skipping empty DataFrame for {letter}, {digit}")
        continue
        
    channel = "LFP3_AON"
    x += 1
    if x < 4:
        continue
    spikes = np.array(spikes_window_df[channel])
    freqs = np.arange(1, 100, 6)
    start_idx = spikes_window_df.index[0]
    end_idx = spikes_window_df.index[-1]

    time_range_in_seconds = (end_idx - start_idx) / fs

    # Generate ticks and labels based on the time range
    ticks = np.linspace(start_idx, end_idx, num=5)  # 5 ticks as an example
    tick_labels = np.linspace(0, time_range_in_seconds, num=5)
    start_idx = spikes_window_df.index[0]
    end_idx = spikes_window_df.index[-1]
    
    info = mne.create_info(ch_names=list(map(str, np.arange(1, 2, 1))), sfreq=fs, ch_types=['eeg'])
    epoch = np.empty((1, 1, len(np.arange(start_idx, end_idx + 1, 1))))
    epoch[0, 0, :] = spikes
    epoch = mne.EpochsArray(epoch, info, verbose=False)
    n_cycles = freqs / 2
    tf_pow = mne.time_frequency.tfr_morlet(epoch, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False,
                                           verbose=False)
    baseline_max = tf_pow.times[-1]
    tf_pow.apply_baseline(mode='zscore', baseline=(0, baseline_max))
    tf_pow.data = np.squeeze(tf_pow.data)
    
    fig, axs = plt.subplots(1, 2, figsize=(20, 10), sharex=True)
    plt.suptitle(f"{channel}: {digit, letter}")

    # Time-frequency plot
    ax = axs[0]
    ax.imshow(tf_pow.data, extent=[start_idx, end_idx, tf_pow.freqs[0], tf_pow.freqs[-1]],
              aspect='auto', origin='lower', cmap='jet')
    ax.set_xticks(ticks)
    ax.set_xticklabels(np.round(tick_labels, 2))
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    # Voltage trace plot
    ax = axs[1]
    ax.plot(np.arange(start_idx, end_idx + 1), spikes)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Voltage (uV)')

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    
# relative_times = [-2000, 1500]
# for event, evt in zip(lfp.events, lfp.event_times):
#     if event in ["0", "1", "00", "01", "10", "11"]:
#         continue
#     elif event == "b":
#         continue
#     time_segment = [int(evt * fs) + relative_times[0], int(evt * fs) + relative_times[1]]
# 
#     info = mne.create_info(ch_names=list(map(str, np.arange(1, 2, 1))), sfreq=fs, ch_types=['eeg'])
#     epoch = np.empty((1, 1, len(np.arange(time_segment[0], time_segment[1], 1))))
#     epoch[0, 0, :] = lfp.spikes[channel][
#                      time_segment[0]:time_segment[1]]  #Format data into [epochs, channels, samples] format
#     epoch = mne.EpochsArray(epoch, info, verbose=False)
#     freqs = np.arange(1, 100, 4)
#     n_cycles = freqs / 2
#     tf_pow = mne.time_frequency.tfr_morlet(epoch, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False,
#                                            verbose=False)
#     baseline_max = tf_pow.times[-1]  # Last time point in epochs data
#     tf_pow.apply_baseline(mode='zscore', baseline=(0, baseline_max))
#     tf_pow.data = np.squeeze(tf_pow.data)
#     fig, ax = plt.subplots(1, 2, figsize=(20, 10))
#     plt.suptitle(f"{channel} {event} {time_segment}")
#     ax[0].plot(lfp.spikes[channel][time_segment[0]:time_segment[1]])
#     ax[0].set_title("Raw Data")
#     ax[0].set_xlabel("Time (ms)")
#     ax[0].set_ylabel("Voltage (uV)")
#     ax[1].imshow(tf_pow.data, extent=[relative_times[0], relative_times[1], tf_pow.freqs[0], tf_pow.freqs[-1]],
#                  aspect='auto', origin='lower', cmap='jet')
#     plt.xlim([0, 1000])
#     plt.xlabel('Time (ms)')
#     plt.ylabel('Frequency (Hz)')

# Compare coherence in time domain

In [2]:

pre_t = 1
post_t = 0
x = 0
for letter, digit, spikes_window_df in lfp.get_windows(pre_t, post_t):
    
    if spikes_window_df.empty:
        print(f"Skipping empty DataFrame for {letter}, {digit}")
        continue
        
    channel = "LFP3_AON"
    channel2 = "LFP1_vHp"
    x += 1
    if x < 4:
        continue
    spikes = np.array(spikes_window_df[channel])
    spikes2 = np.array(spikes_window_df[channel2])

    freqs = np.arange(1, 100, 6)
    start_idx = spikes_window_df.index[0]
    end_idx = spikes_window_df.index[-1]

    time_range_in_seconds = (end_idx - start_idx) / fs

    # Generate ticks and labels based on the time range
    ticks = np.linspace(start_idx, end_idx, num=5)  # 5 ticks as an example
    tick_labels = np.linspace(0, time_range_in_seconds, num=5)
    times = np.arange(start_idx, end_idx + 1, 1) / fs
    
    info = mne.create_info(ch_names=list(map(str, np.arange(1, 3, 1))), sfreq=fs, ch_types=['eeg','eeg'])
    epochs = np.empty((1, 2, len(times)))
    epochs[0,0,:] = spikes
    epochs[0,1,:] = spikes2
    epochs = mne.EpochsArray(epochs, info, verbose=False)

    n_cycles = freqs/2
    con = mne_connectivity.spectral_connectivity_time(epochs, method='coh', sfreq=int(fs), mode='cwt_morlet', freqs=freqs, n_cycles=n_cycles, verbose=True)
    
    coh = con.get_data()
    coh = coh[0,0,1,:,:]
    plt.imshow(np.squeeze(coh), extent=[times[0], times[-1], freqs[0], freqs[-1]], aspect='auto', origin='lower', cmap='jet')
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')


Fmin was not specified. Using fmin=min(freqs)
Fmax was not specified. Using fmax=max(freqs).
Connectivity computation...
   Processing epoch 1 / 1 ...
[Connectivity computation done]


 '1': 1>, so metadata was not modified.
  con = mne_connectivity.spectral_connectivity_time(epochs, method='coh', sfreq=int(fs), mode='cwt_morlet', freqs=freqs, n_cycles=n_cycles, verbose=True)


IndexError: too many indices for array: array is 3-dimensional, but 5 were indexed