## LFP Analysis

In [None]:
# Imports and setup
# Allow jupyter notebook to reload modules

#%load_ext autoreload
from __future__ import annotations

import itertools
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from fooof import FOOOF
from scipy.signal import coherence

from spk2extract.logs import logger
from spk2extract.extraction.oscillations import LfpSignal, get_data

# Setup functions, classes and styles
logger.setLevel("INFO")
sns.set_style("darkgrid")

# Sampling rate in Hz. Passed to LfpSignal and reused by the window/index
# arithmetic in later cells, so it is defined once and used everywhere.
fs = 2000

data_path = Path.home() / "data" / "extracted" / "dk1"
# glob() order is filesystem-dependent; sort so "the first file" is reproducible.
h5_files = sorted(data_path.glob("*dk1*.h5"))
if not h5_files:
    raise FileNotFoundError(f"No files matching '*dk1*.h5' found in {data_path}")
file = h5_files[0]

df_s, df_t, events, event_times = get_data(file)
unique_events = np.unique(events)

# Human-readable meaning of the event codes (presumably the values in `events` — verify).
key = {
    "0": "dug incorrectly",
    "1": "dug correctly",
    "x": "crossed over (no context)",
    "b": "crossed over into the black room",
    "w": "crossed over into the white room",
}

lfp = LfpSignal(df_s, df_t, event_arr=events, ev_times_arr=event_times, fs=fs, filename=file, exclude=["LFP1_AON", "LFP2_AON"])

In [None]:
# Split the LFP channels by recording site (site is encoded in the column name,
# e.g. "LFP1_vHp") and fetch the event windows used by the coherence cell below.
spikes = lfp.spikes

aon_cols = [col for col in spikes.columns if 'AON' in col]
vhp_cols = [col for col in spikes.columns if 'vHp' in col]

# Each frame must hold its own site's channels; otherwise every pair label
# "{aon}_{vhp}" in the coherence legend comes out reversed.
df_aon = spikes[aon_cols]
df_vhp = spikes[vhp_cols]

# NOTE(review): LfpSignal above is built with exclude=["LFP1_AON", "LFP2_AON"];
# if that drops every AON channel there is nothing to pair — fail loudly here.
if not aon_cols or not vhp_cols:
    raise ValueError(
        "Need at least one AON and one vHp channel for coherence; "
        f"got AON={aon_cols}, vHp={vhp_cols}"
    )

# Not read by the cells below (they accumulate per event type); kept for
# any downstream cell that may expect these names.
coherence_sums = {}
coherence_counts = {}
windows = lfp.get_windows(3)

## Plot average AON–vHPC coherence across all trials

In [None]:
# Larger, bold font for plots and axis labels
plt.rcParams.update({'font.size': 20, 'font.weight': 'bold', 'axes.labelweight': 'bold', 'axes.titleweight': 'bold'})

MAX_FREQ_HZ = 100  # upper bound (Hz) of the frequency range averaged and shown


def average_pair_coherence(window_list, df_x, df_y, fs, max_freq=MAX_FREQ_HZ):
    """Average magnitude-squared coherence over windows for every (x, y) channel pair.

    Parameters
    ----------
    window_list : iterable
        Windows whose items 0 and 1 are start/end times in seconds. They are
        converted to row positions with ``int(t * fs)`` (assumes row ``i`` of the
        frames is sample ``i`` — TODO confirm).
    df_x, df_y : pandas.DataFrame
        Every column of ``df_x`` is paired with every column of ``df_y``.
    fs : float
        Sampling rate in Hz, forwarded to ``scipy.signal.coherence``.
    max_freq : float
        Frequency bins above this are dropped.

    Returns
    -------
    freqs : numpy.ndarray or None
        Frequency bins (Hz) shared by all averages; ``None`` if nothing was computed.
    averages : dict[str, numpy.ndarray]
        ``"{col_x}_{col_y}"`` -> mean coherence across windows, in pairing order.

    Raises
    ------
    ValueError
        If windows yield different frequency grids (e.g. some windows are shorter
        than the default segment length), which would make averaging meaningless.
    """
    sums = {}
    counts = {}
    freqs = None
    for window in window_list:
        rows = slice(int(window[0] * fs), int(window[1] * fs))
        win_x = df_x.iloc[rows]
        win_y = df_y.iloc[rows]
        for col_x, col_y in itertools.product(df_x.columns, df_y.columns):
            f, Cxy = coherence(win_x[col_x], win_y[col_y], fs=fs)
            keep = f <= max_freq
            f_kept, Cxy_kept = f[keep], Cxy[keep]

            if freqs is None:
                freqs = f_kept
            elif not np.array_equal(freqs, f_kept):
                raise ValueError(
                    f"Window {window!r} produced a different frequency grid "
                    f"({f_kept.size} vs {freqs.size} bins); cannot average."
                )

            pair_key = f"{col_x}_{col_y}"
            if pair_key not in sums:
                sums[pair_key] = np.zeros_like(Cxy_kept)
                counts[pair_key] = 0
            sums[pair_key] += Cxy_kept
            counts[pair_key] += 1

    averages = {pair_key: total / counts[pair_key] for pair_key, total in sums.items()}
    return freqs, averages


if not windows:
    raise ValueError("lfp.get_windows returned no event windows to plot")

num_keys = len(windows.keys())

# Title for entire figure
fig_title = "Coherence between AON and vHPC"
# squeeze=False always returns a 2-D array, so a single event type works too
# (a bare Axes, or a list wrapping one, has no .flatten()).
fig, axes = plt.subplots(1, num_keys, figsize=(12 * num_keys, 12), sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()

y_min_global = float('inf')
y_max_global = float('-inf')
legend_ax = None  # first subplot that actually has curves

for ax, (window_key, window_list) in zip(axes, windows.items()):
    letter, digit = window_key.split('_')
    correctness = 'Correct' if digit == '1' else 'Incorrect'
    color = 'Black' if letter == 'b' else 'White'
    ax.set_title(f"{correctness}, {color}")

    if not window_list:
        # Axes coordinates keep the note centered whatever the data limits are.
        ax.text(0.5, 0.5, 'No trials', ha='center', va='center', transform=ax.transAxes)
        continue

    f_filtered, pair_averages = average_pair_coherence(window_list, df_aon, df_vhp, fs)
    for pair_key, Cxy_avg in pair_averages.items():
        ax.plot(f_filtered, Cxy_avg, label=pair_key)
        # Track a common y-range so all panels are directly comparable.
        y_min_global = min(y_min_global, np.min(Cxy_avg))
        y_max_global = max(y_max_global, np.max(Cxy_avg))
    if pair_averages and legend_ax is None:
        legend_ax = ax

    ax.set_xlabel('Frequency (Hz)')
    ax.set_xlim(0, MAX_FREQ_HZ)
    ax.grid(True)

# Only apply the shared limits if at least one curve was drawn (otherwise
# they are still +/-inf and set_ylim would raise).
if y_min_global <= y_max_global:
    for ax in axes:
        ax.set_ylim(y_min_global, y_max_global)

# y-label on the first subplot only; the others share its y-axis.
axes[0].set_ylabel('Coherence')

# Every panel plots the same channel pairs, so one legend (built from the
# lines' own labels) describes them all.
if legend_ax is not None:
    legend_ax.legend(loc='upper right')

# Set title for entire figure
fig.suptitle(fig_title, fontsize=30, fontweight='bold')
fig.tight_layout(rect=[0, 0, 1, 0.96])  # Make room for suptitle and legend
savepath = Path.home() / "data" / "figures" / "coherence"
savepath.mkdir(exist_ok=True, parents=True)
# savepath is a directory; write the figure to a file inside it.
fig.savefig(savepath / "aon_vhpc_coherence.png", dpi=300, bbox_inches='tight', pad_inches=0.1)


## Plot the Channels

In [None]:
# Plot each channel ------------------------
# Stacked panels showing the first `plot_seconds` of every channel.
plot_seconds = 1.0
# NOTE: despite its name this is `plot_seconds` worth of samples (currently 1 s);
# the name is kept because the overlay cell below reuses it.
first_4_seconds = int(plot_seconds * fs)

channels = list(lfp.spikes.columns)
num_channels = len(channels)
# squeeze=False so a single channel still yields an indexable array of Axes.
fig, axes = plt.subplots(num_channels, 1, figsize=(50, 40), sharex=True, squeeze=False)
axes = axes[:, 0]

for ax, channel in zip(axes, channels):
    # Positional slice: rows are assumed to be consecutive samples at `fs` — TODO confirm.
    trace = lfp.spikes[channel].iloc[:first_4_seconds]
    # Explicit time axis so the "Time (ms)" label is true, rather than
    # plotting against whatever index the Series carries.
    time_ms = np.arange(len(trace)) * 1000.0 / fs
    ax.plot(time_ms, trace.to_numpy())
    ax.set_ylabel("Voltage (uV)")
    ax.set_title(channel)
    ax.grid(True)

# Shared x-axis: label only the bottom panel.
axes[-1].set_xlabel("Time (ms)")

In [None]:
# Overlay the two vHp channels over the same window as the per-channel plot
# (`first_4_seconds` samples from the previous cell).
vhp1 = lfp.spikes['LFP1_vHp']
vhp2 = lfp.spikes['LFP2_vHp']

# A 1x1 subplots() call returns a single Axes, not an iterable, so use it directly.
fig, ax = plt.subplots(1, 1, figsize=(50, 10))

segments = [
    (vhp1.iloc[:first_4_seconds], 'LFP1_vHp', '-'),
    (vhp2.iloc[:first_4_seconds], 'LFP2_vHp', '--'),
]
for segment, name, style in segments:
    # Samples -> milliseconds so the x-label below is accurate.
    time_ms = np.arange(len(segment)) * 1000.0 / fs
    ax.plot(time_ms, segment.to_numpy(), label=name, linestyle=style)

ax.set_title('Overlay of LFP1_vHp and LFP2_vHp')
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Voltage (uV)")
ax.legend()

plt.tight_layout()
plt.show()


# FOOOF

TypeError: unhashable type: 'slice'