In [None]:
%matplotlib inline
import spikeinterface.full as si
import numpy as np
from os.path import join, split
from glob import glob
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
spikeglx_folder = Path(r'/path/to/bin/file/')

# Pick the reader based on whether the folder holds any *.cbin files.
cbin_files = glob(join(spikeglx_folder, '*.cbin'))
if cbin_files:
    rec = si.read_cbin_ibl(spikeglx_folder)
else:
    # The AP stream id is built from the last character of the folder name
    # (e.g. a folder ending in "0" gives 'imec0.ap').
    probe_suffix = split(spikeglx_folder)[-1][-1]
    rec = si.read_spikeglx(spikeglx_folder, stream_id=f'imec{probe_suffix}.ap')
rec

In [None]:
# Step 1: high-pass filter the raw recording (Bessel filter, float32 output).
print('\nApplying high-pass filter.. ')
rec_filtered = si.highpass_filter(
    rec,
    ftype='bessel',
    dtype='float32',
)

# Step 2: correct the inter-sample phase shift of the filtered recording.
print('Correcting for phase shift.. ')
rec_shifted = si.phase_shift(rec_filtered)

# Step 3: common referencing with the library defaults.
# NOTE(review): this is computed from rec_filtered, not rec_shifted. In the
# cells visible here rec_comref is only used for noisy-channel detection and
# for plotting. `reference='local'` is an alternative that was tried before.
print('Performing common average referencing.. ')
rec_comref = si.common_reference(rec_filtered)

In [None]:
print('Detecting and interpolating over bad channels.. ')

# Detect dead and outside-of-brain channels on the high-pass filtered recording.
# `all_channels` holds one label per channel; 'dead' and 'out' are read from it.
bad_channel_ids, all_channels = si.detect_bad_channels(rec_filtered, seed=42)
prec_dead_ch = np.sum(all_channels == 'dead') / all_channels.shape[0]
print(f'{np.sum(all_channels == "dead")} ({prec_dead_ch*100:.0f}%) dead channels')
dead_channel_ids = rec_filtered.get_channel_ids()[all_channels == 'dead']
out_channel_ids = rec_filtered.get_channel_ids()[all_channels == 'out']

# Detect noisy channels on the common-referenced recording (MAD method).
# Note that bad_channel_ids / all_channels are overwritten by this second call.
bad_channel_ids, all_channels = si.detect_bad_channels(rec_comref, method='mad', seed=42)
prec_noise_ch = np.sum(all_channels == 'noise') / all_channels.shape[0]
print(f'{np.sum(all_channels == "noise")} ({prec_noise_ch*100:.0f}%) noise channels')
noisy_channel_ids = rec_comref.get_channel_ids()[all_channels == 'noise']

# Remove channels that are outside of the brain
rec_no_out = rec_shifted.remove_channels(remove_channel_ids=out_channel_ids)

# Build the list of channels to interpolate. The noise detection ran on a
# recording that still contained the outside-of-brain channels, so a channel can
# be both 'out' and 'noise'. Such a channel no longer exists in rec_no_out and
# must not be passed to the interpolation; channels already listed as dead are
# skipped as well so that no id appears twice.
already_handled_ids = np.concatenate((dead_channel_ids, out_channel_ids))
noisy_only_ids = noisy_channel_ids[~np.isin(noisy_channel_ids, already_handled_ids)]
interpolate_channel_ids = np.concatenate((dead_channel_ids, noisy_only_ids))

# Interpolate over bad channels
rec_interpolated = si.interpolate_bad_channels(rec_no_out, interpolate_channel_ids)

In [None]:
print('Destriping.. ')
if np.unique(rec_interpolated.get_property('group')).shape[0] > 1:
    print('Multi-shank probe detected, doing destriping per shank')

    # split_by returns one sub-recording per value of the 'group' property.
    # Iterate over the values instead of indexing with 0..n-1, so that group
    # labels that are not consecutive integers starting at 0 (e.g. after all
    # channels of one shank were removed) do not raise a KeyError.
    rec_split = rec_interpolated.split_by(property='group')
    rec_destripe = [si.highpass_spatial_filter(rec_shank) for rec_shank in rec_split.values()]

    # The aggregated recording stacks the channels shank after shank, so the ids
    # handed to it have to follow that same order. Using the ids of the
    # un-split recording would mislabel channels whenever the original channel
    # order is not already contiguous per shank.
    shank_ordered_channel_ids = np.concatenate(
        [rec_shank.get_channel_ids() for rec_shank in rec_destripe])
    rec_destriped = si.aggregate_channels(rec_destripe, renamed_channel_ids=shank_ordered_channel_ids)
else:
    # Single group: filter the whole recording in one go.
    rec_destriped = si.highpass_spatial_filter(rec_interpolated)

In [None]:
%matplotlib widget
si.plot_traces({'raw':rec, 'filtered': rec_filtered, 'car': rec_comref, 'destriped': rec_destriped},
               time_range=[1000, 1000.04], color='k', backend='ipywidgets')

In [None]:
# Power spectral density of the destriped recording, estimated from a single
# random chunk of 30000 samples (fixed seed so the chunk is reproducible).
data_chunk = si.get_random_data_chunks(rec_destriped, num_chunks_per_segment=1,
                                       chunk_size=30000, seed=0)

fig, ax = plt.subplots(figsize=(10, 7))
sampling_rate = rec_destriped.sampling_frequency
# presumably data_chunk is (samples, channels), so .T yields one trace per channel
for channel_trace in data_chunk.T:
    p, f = ax.psd(channel_trace, Fs=sampling_rate, color="b")

In [None]:
# Apply notch filters to the spectral peaks.
# freqs[k] is filtered with quality factor qs[k]; the two lists are paired by position.
freqs = [11300, 12640]  # frequency to filter out
qs = [8, 20]  # width of the filter (lower values = wider filter)

# zip() would silently ignore surplus entries, so catch a mismatch explicitly.
if len(freqs) != len(qs):
    raise ValueError('freqs and qs must have the same number of entries')

# Start each chain from the un-notched recording and stack one notch filter per
# frequency. Seeding the chain before the loop keeps both outputs defined even
# when no frequencies are listed (they are then simply the unfiltered inputs).
rec_destriped_notch = rec_destriped
rec_comref_notch = rec_comref
for freq, q in zip(freqs, qs):
    rec_destriped_notch = si.notch_filter(rec_destriped_notch, freq=freq, q=q)
    rec_comref_notch = si.notch_filter(rec_comref_notch, freq=freq, q=q)
        

In [None]:
# Power spectral density after notch filtering, using the same chunk settings
# and seed as the pre-notch spectrum.
data_chunk = si.get_random_data_chunks(rec_destriped_notch, num_chunks_per_segment=1,
                                       chunk_size=30000, seed=0)

fig, ax = plt.subplots(figsize=(10, 7))
sampling_rate = rec_destriped_notch.sampling_frequency
# presumably data_chunk is (samples, channels), so .T yields one trace per channel
for channel_trace in data_chunk.T:
    p, f = ax.psd(channel_trace, Fs=sampling_rate, color="b")

In [None]:
%matplotlib widget
si.plot_traces({'destriped': rec_destriped, 'destriped notch': rec_destriped_notch,
                'car': rec_comref, 'car notch': rec_comref_notch},
               time_range=[1000, 1000.04], color='k', backend='ipywidgets')