In [1]:
%matplotlib inline

# Preprocessing and Spike Sorting Tutorial

# Chapter 1: Importing Recording Data and Metadata

In [2]:
import os
import warnings
import glob
import pickle
import _pickle as cPickle
import imp
import git

  import imp


In [3]:
os.environ["SPECTRAL_CONNECTIVITY_ENABLE_GPU"] = "true"

In [4]:
from collections import defaultdict
import time
import json

In [5]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
import pandas as pd
import scipy.signal
from labellines import labelLine, labelLines


In [6]:
from spectral_connectivity import Multitaper, Connectivity

In [7]:
from probeinterface import get_probe
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface import write_prb, read_prb

In [8]:
# Changing the figure size
from matplotlib.pyplot import figure
figure(figsize=(8, 6), dpi=80)

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

The spikeinterface module by itself import only the spikeinterface.core submodule
which is not useful for end user



In [9]:
import spikeinterface

We need to import one by one different submodules separately (preferred).
There are 5 modules:

- :code:`extractors` : file IO
- :code:`toolkit` : processing toolkit for pre-, post-processing, validation, and automatic curation
- :code:`sorters` : Python wrappers of spike sorters
- :code:`comparison` : comparison of spike sorting output
- :code:`widgets` : visualization



In [10]:
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.preprocessing as sp

In [11]:
import spikeinterface.core

We can also import all submodules at once with this
  this internally import core+extractors+toolkit+sorters+comparison+widgets+exporters

This is useful for notebooks but this is a more heavy import because internally many more dependency
are imported (scipy/sklearn/networkx/matplotlib/h5py...)



In [12]:
import spikeinterface.full as si

In [13]:
# Increase size of plot in jupyter

plt.rcParams["figure.figsize"] = (10,6)

# Part 1: Importing Data

## Loading in the Preprocessed LFP 

- Getting the root directory of the Github Repo to base the files off of

In [14]:
git_repo = git.Repo(".", search_parent_directories=True)
git_root = git_repo.git.rev_parse("--show-toplevel")

In [15]:
git_root

'/nancy/projects/reward_competition_extention'

- Getting a list of all the electrophysiological recording files
    - **NOTE**: If your recording file does not end with `.rec` or is in a different directory than `./data` then you must change `glob.glob({./path/to/recording_file.rec})` below. Where you replace `{./path/to/recording_file.rec}` with the path to your recording file without the brackets.

In [16]:
time_range=(1000, 1005)
resampled_frequency = 1000

In [17]:
time_halfbandwidth_product=10

In [18]:
def get_lfp_extractor(recording_path: str, resample_rate: int = 1000):
    """
    Preprocesses a given electrophysiology recording for Local Field Potential.
    """
    trodes_recording = se.read_spikegadgets(recording_path, stream_id="trodes")
    preprocessed_recording = sp.bandpass_filter(trodes_recording, freq_min=0.5, freq_max=300)
    preprocessed_recording = sp.notch_filter(preprocessed_recording, freq=60)
    preprocessed_recording = sp.resample(preprocessed_recording, resample_rate=resample_rate)
    return preprocessed_recording

In [19]:
def compute_multitaper_spectrum(preprocessed_recording, start_frame=0, end_frame=60, time_halfbandwidth_product=3, resampled_frequency=1000):
    """
    Computes the multitaper spectral estimate for a given preprocessed recording.
    """
    traces = preprocessed_recording.get_traces(start_frame=start_frame*1000, end_frame=end_frame*1000)
    m = Multitaper(time_halfbandwidth_product=time_halfbandwidth_product, time_series=traces,
                   sampling_frequency=resampled_frequency)
    c = Connectivity.from_multitaper(m)
    return c

In [20]:
import seaborn as sns

In [21]:
def plot_and_save_spectrum(c, recording_intermediate, ax=None, freq_min=0, freq_max=10, recording_basename=None, channels=None, ymax=None):
    """
    Plots and saves the power spectrum for each tetrode's waveform.
    """

    if ax is None:
        ax = plt.gca()    
    ax.set_xlabel("Frequency")
    ax.set_ylabel("Power")
    if channels is None:
        ax.plot(c.frequencies, c.power().squeeze()[:,:], label=recording_intermediate.get_channel_ids())
    else:
        channel_index = [recording_intermediate.get_channel_ids().tolist().index(item) for item in channels if item in recording_intermediate.get_channel_ids()]
        ax.plot(c.frequencies, c.power().squeeze()[:,channel_index], label=channel_index)

    ax.set_xlim(freq_min,freq_max)
    if ymax is not None:
        ax.set_ylim(0, ymax)
    ax.legend()
    labelLines(ax.get_lines(), zorder=2.5)
    
    return ax

In [22]:
recording_path = "/scratch/back_up/reward_competition_extention/data/omission/2023_06_18/20230618_100636_standard_comp_to_omission_D2_subj_1-4_and_1-1.rec/20230618_100636_standard_comp_to_omission_D2_subj_1_1_t1b2L_box2_merged.rec"

In [23]:
recording_path = "/scratch/back_up/reward_competition_extention/data/omission/2023_06_18/20230618_100636_standard_comp_to_omission_D2_subj_1-4_and_1-1.rec/20230618_100636_standard_comp_to_omission_D2_subj_1_4_t4b3L_box1_merged.rec"

In [24]:
recording_path = "/scratch/back_up/reward_competition_extention/data/omission/2023_06_20/20230620_114347_standard_comp_to_omission_D4_subj_1-2_and_1-1.rec/20230620_114347_standard_comp_to_omission_D4_subj_1-2_t3b3L_box_1_merged.rec"

In [29]:
recording_path = "/scratch/back_up/reward_competition_extention/data/pilot/20221215_145401_comp_amd_om_6_1_and_6_3.rec/20221215_145401_comp_amd_om_6_1_top_4_base_3.rec"

In [36]:
recording_path = "/scratch/back_up/reward_competition_extention/data/standard/2023_06_12/20230612_101430_standard_comp_to_training_D1_subj_1-4_and_1-3.rec/20230612_101430_standard_comp_to_training_D1_subj_1-3_t3b3L_box2_merged.rec"

In [37]:
trodes_recording = se.read_spikegadgets(recording_path, stream_id="trodes")

In [38]:
type(trodes_recording)

spikeinterface.extractors.neoextractors.spikegadgets.SpikeGadgetsRecordingExtractor

In [39]:
lfp_extractor = get_lfp_extractor(recording_path)

In [40]:
c = compute_multitaper_spectrum(lfp_extractor, time_halfbandwidth_product=10, start_frame=600, end_frame=660)

In [41]:
channels

['28', '29', '30', '31']

In [42]:
recording_basename = os.path.splitext(os.path.basename(recording_path))[0]
output_dir_path = "./proc/{}".format(recording_basename)
os.makedirs(output_dir_path, exist_ok=True)

freq_min=0
freq_max=6
for num in range(8):
    channels = [num for num in lfp_extractor.channel_ids[num*4: num*4+4]]
    plot_and_save_spectrum(c, recording_intermediate=lfp_extractor, channels=channels, freq_min=freq_min, freq_max=freq_max)
    plt.title(recording_basename)
    plt.savefig(os.path.join(output_dir_path, "lfp_power_freq_{}_{}_ch_{}.png".format(freq_min, freq_max, "-".join(channels))))
    plt.close()
    # plt.show()

In [43]:
recording_basename = os.path.splitext(os.path.basename(recording_path))[0]
output_dir_path = "./proc/{}".format(recording_basename)
os.makedirs(output_dir_path, exist_ok=True)

freq_min=6
freq_max=12
for num in range(8):
    channels = [num for num in lfp_extractor.channel_ids[num*4: num*4+4]]
    plot_and_save_spectrum(c, recording_intermediate=lfp_extractor, channels=channels, freq_min=freq_min, freq_max=freq_max)
    plt.title(recording_basename)
    plt.savefig(os.path.join(output_dir_path, "lfp_power_freq_{}_{}_ch_{}.png".format(freq_min, freq_max, "-".join(channels))))
    plt.close()
    # plt.show()