## Where is the probe?

In [6]:
import matplotlib.pyplot as plt
from pathlib import Path


import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.widgets as sw

import numpy as np
import  multitaper
from tqdm import tqdm
import pandas as pd
from scipy import signal

# Open Ephys "Record Node" folder for session FNT103 (2024-09-05) that read_openephys loads below.
# NOTE(review): hardcoded absolute /ceph path — consider a configurable DATA_DIR.
INPUT = '/ceph/sjones/projects/FlexiVexi/raw_data/FNT103/2024-09-05T14-53-54/Open-Ephys/2024-09-05_15-53-47/Record Node 103'
# Alternative dataset, kept for reference:
#INPUT = '/ceph/sjones/projects/sequences/NPX_DATA/SP156_all_shanks'

## Have a look at local data

In [7]:
# Load the Open Ephys recording: stream '1', first block.
recording = se.read_openephys(INPUT, block_index=0, stream_id='1')

# Spike-band preprocessing: 400 Hz high-pass, then a global median common reference.
# NOTE: rec1/rec are reassigned by the band-pass cell further down before they are used.
rec1 = si.highpass_filter(recording, freq_min=400)
rec = si.common_reference(rec1, reference="global", operator="median")


In [None]:
# One sub-recording per value of the "group" channel property, keyed by that value.
split_recording_dict = recording.split_by(property="group")


In [None]:
# Work with group 3 from here on.
probe1 = split_recording_dict[3]

# Heat map of 5–15 s of raw traces, channels ordered by depth.
w_ts = sw.plot_traces(
    probe1,
    mode="map",
    time_range=(5, 15),
    show_channel_ids=True,
    order_channel_by_depth=True,
    seconds_per_row=60,
    clim=(-50, 50),
)


Keep 10 s of data (5–15 s) for a quick look at the raw traces.

In [None]:
samp = probe1.sampling_frequency
# Frame indices are sample counts: cast explicitly, since 5 * samp is a float.
traces = probe1.get_traces(start_frame=int(5 * samp), end_frame=int(15 * samp)).T  # -> (channels, samples)

In [None]:
# Shape of the transposed traces: (n_channels, n_samples).
traces.shape

### Look at traces

In [None]:
nChans, nSamps = traces.shape
# Use % formatting: passing the tuple as a second print() argument leaves the %d placeholders unfilled.
print('Data has %d channels and %d samples' % (nChans, nSamps))
plt.plot(np.arange(nSamps) / samp, traces[0, :] / 1000)
plt.xlabel('Time (s)')
# NOTE(review): /1000 -> mV assumes traces are in µV; get_traces returns raw units unless scaled — confirm.
plt.ylabel('LFP (mV)')
plt.title('probe1, channel index 0 (unfiltered)');

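Now, let's look at the spike-band signal by depth. We band-pass filter (300–6000 Hz) to remove slow oscillations and high-frequency noise, then apply a global median common reference.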

In [None]:
# Spike band (300–6000 Hz) followed by a global median common reference.
rec1 = si.bandpass_filter(recording=recording, freq_min=300, freq_max=6000)
rec = si.common_reference(rec1, operator="median", reference="global")

# Depth-ordered heat map of 5–15 s of the referenced recording.
w_ts = sw.plot_traces(
    rec,
    mode="map",
    time_range=(5, 15),
    show_channel_ids=True,
    order_channel_by_depth=True,
)

# The widget object exposes its matplotlib figure and axes.
fig, ax = w_ts.figure, w_ts.ax
fig.set_size_inches(10, 15)  # (width, height) in inches

# Thin out the channel labels: keep every 10th y-tick.
yticks = ax.get_yticks()
new_yticks = yticks[::10]
ax.set_yticks(new_yticks)

plt.show()

In [None]:
# Same 5–15 s window from the filtered + referenced recording; integer frame bounds, -> (channels, samples).
traces = rec.get_traces(start_frame=int(5 * samp), end_frame=int(15 * samp)).T


In [None]:

nChans, nSamps = traces.shape
# % formatting so the placeholders are actually filled.
print('Data has %d channels and %d samples' % (nChans, nSamps))

channel_to_plot = 139  # example row of `traces`, picked by hand
plt.figure(figsize=(30, 6))
plt.plot(np.arange(nSamps) / samp, traces[channel_to_plot, :] / 1000)
plt.xlabel('Time (s)')
# The signal is band-passed 300–6000 Hz (spike band), so it is not LFP.
# NOTE(review): /1000 -> mV assumes traces are in µV — confirm.
plt.ylabel('Voltage (mV)')

### Welch power spectrum

Take 40 s of data (5–45 s), break it into four non-overlapping 10 s windows, compute a Welch PSD in each, and average them.

In [None]:
# Sampling frequency of probe1 (Hz).
samp

In [None]:
# Broadband signal for spectral analysis: high-pass at 1 Hz, low-pass just below Nyquist.
# NOTE: the 1 Hz high-pass also attenuates the 0–1 Hz part of the delta band analysed below.
# Filter probe1 (not the full `recording`) so that row i of `traces` is channel i of probe1.
# The probe map and the per-channel (96-channel) loops below index it that way.
rec_welch = si.bandpass_filter(recording=probe1, freq_min=1, freq_max=samp / 2 - 1)

# 40 s (5–45 s), scaled to microvolts so that the later division by 1e6 yields volts.
traces = rec_welch.get_traces(
    start_frame=int(5 * samp), end_frame=int(45 * samp), return_scaled=True
).T
nChans, nSamps = traces.shape


In [None]:
# Split the nSamps samples into n_windows contiguous, non-overlapping windows.
# Row k of `windows` holds the [start, end) sample indices of window k (stored as floats).
# Any nSamps % n_windows trailing samples are left out.
n_windows = 4
window_samples = nSamps // n_windows
starts = np.arange(n_windows) * window_samples
windows = np.column_stack((starts, starts + window_samples)).astype(float)



In [None]:
def fscale(ns, si=1, one_sided=False):
    """
    Frequency axis (Hz) for an ns-point FFT, with the Nyquist bin kept positive.

    numpy.fft.fftfreq labels the Nyquist bin of an even-length FFT as negative;
    this helper puts it on the positive side instead.

    :param ns: number of samples
    :param si: sampling interval in seconds (shadows the module alias `si` inside this function)
    :param one_sided: if True, return only the non-negative frequencies
    :return: numpy vector of frequencies in Hertz
    """
    positive = np.arange(ns // 2 + 1) / ns / si
    if one_sided:
        return positive
    # The remaining bins are the negated interior positive bins in descending order.
    # DC is never mirrored, and Nyquist is not mirrored when ns is even.
    n_negative = ns - positive.size
    negative = -positive[n_negative:0:-1]
    return np.concatenate((positive, negative), axis=0)


# One-sided frequency axis for a window_samples-long periodogram.
freq = fscale(window_samples, 1 / samp, one_sided=True)

spectra = np.zeros((nChans, len(freq)))

for window in tqdm(np.arange(n_windows)):
    start, end = int(windows[window, 0]), int(windows[window, 1])
    trace = traces[:, start:end]
    # µV -> V is a division by 1e6 (10E6 == 1e7 would be 10x too small in amplitude).
    # NOTE(review): assumes `traces` is in µV — confirm.
    # With nperseg equal to the window length, welch returns one Hann-windowed periodogram per window.
    _, w = signal.welch(
        trace / 1e6, fs=samp, window='hann', nperseg=window_samples,
        detrend='constant', return_onesided=True, scaling='density', axis=-1
    )
    spectra += w

# Average the per-window PSDs (V^2/Hz).
spectrum = spectra / n_windows

In [None]:
plot_range = (freq <= 10) & (freq >= 1)  # frequencies to plot: 1–10 Hz
fig, ax = plt.subplots()

ax.semilogy(freq[plot_range], spectrum[0, plot_range])
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (V$^2$/Hz)')  # welch scaling='density' on a signal in V
ax.set_title('Welch PSD, channel index 0');

In [None]:
# Welch results per channel. Take the channel count from `spectrum` so 'channel' always matches
# its rows; len(probe1.channel_ids) only matches if `traces` came from probe1.
freq_per_channel = {
    'channel': np.arange(spectrum.shape[0]),
    'pxx': spectrum,
    'f': freq,
}



### Obtain multitaper power spectrum

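We want to look at power (in dB) in the delta band (0–4 Hz). We use a multitaper estimate because the data window is short. Averaging over several orthogonal tapers reduces the variance of the spectrum without cutting the data into even shorter segments.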

In [None]:
# Multitaper PSD of channel index 0 (time-bandwidth nw=5). µV -> V is /1e6 (10E6 would be 1e7).
# NOTE(review): assumes `traces` is in µV — confirm.
psd = multitaper.MTSpec(x=traces[0, :] / 1e6, dt=1.0 / samp, nw=5)
pxx, f = psd.spec, psd.freq  # unpack power spectrum and frequency axis
# f >= 0 drops negative frequencies, in case MTSpec returns a two-sided axis.
plot_range = (f <= 10) & (f >= 0)
plt.semilogy(f[plot_range], pxx[plot_range])
plt.xlabel('frequency (Hz)')
plt.ylabel('power (V**2)')

In [None]:
# Single-taper (rectangular-window) periodogram of channel index 0, for comparison with the
# multitaper estimate. µV -> V is /1e6 (10E6 would be 1e7).
# NOTE(review): assumes `traces` is in µV — confirm.
trace = traces[0, :] / 1e6

# One-sided PSD in V^2/Hz, the same 'density' convention as the Welch estimate:
# |X|^2 / (fs * n), doubled for every bin except DC (and Nyquist when n is even).
# Named psd_fft so it does not overwrite the multitaper `psd` object.
frequencies, psd_fft = signal.periodogram(
    trace, fs=samp, window='boxcar', detrend=False,
    return_onesided=True, scaling='density',
)

# Restrict the plot to 0–10 Hz.
plot_range = (frequencies <= 10) & (frequencies >= 0)

plt.semilogy(frequencies[plot_range], psd_fft[plot_range])
plt.xlabel('Frequency (Hz)')
plt.ylabel('PSD (V$^2$/Hz)')
plt.show()

In [None]:
# Delta-band power (bins with 0 <= f <= 4 Hz) of channel index 0, from the multitaper spectrum, in dB.
# NOTE(review): the PSD bins are summed without multiplying by the bin width df. If pxx is a
# density, this dB value is offset by 10*log10(df) from the integrated band power.
band_range = np.logical_and(f >= 0, f <= 4)
power_band = pxx[band_range].sum()
power_db = np.log10(power_band) * 10

power_db

Let's do it for all 96 channels.

In [None]:
# Multitaper PSD (nw=5) for each of the first 96 rows of `traces`.
# µV -> V is /1e6 (10E6 would be 1e7). NOTE(review): assumes `traces` is in µV — confirm.
n_channels = 96
pxx_list = [None] * n_channels
f_list = [None] * n_channels

# tqdm already reports progress, so no per-iteration print.
for i in tqdm(range(n_channels)):
    psd = multitaper.MTSpec(x=traces[i, :] / 1e6, dt=1.0 / samp, nw=5)
    pxx_list[i] = psd.spec
    f_list[i] = psd.freq

In [None]:
# One row per channel. 'pxx' and 'f' hold that channel's full multitaper spectrum arrays.
# The channel index is derived from pxx_list so it cannot drift from the loop above.
freq_per_channel = {
    'channel': np.arange(len(pxx_list)),
    'pxx': pxx_list,
    'f': f_list,
}

# NOTE: this rebinds `freq` (previously the Welch frequency vector) to a DataFrame.
freq = pd.DataFrame(freq_per_channel)

In [None]:
# CSV cells cannot hold arrays: pandas writes str(array), which numpy truncates to
# "[a b c ... x y z]" for long arrays. Write only the scalar columns.
freq.drop(columns=['pxx', 'f']).to_csv('freq.csv')

In [None]:
def get_delta_power(pxx, f, fmin=0.0, fmax=4.0):
    """
    Band power in dB: 10*log10 of the sum of `pxx` over bins with fmin <= f <= fmax.

    The defaults select the delta band (0–4 Hz). The bins are summed without multiplying by
    the bin width, so for a density spectrum the result is offset by 10*log10(df)
    from the integrated band power (a relative measure).

    :param pxx: power spectrum values, indexable by a boolean mask built from `f`
    :param f: frequency (Hz) of each entry of `pxx`
    :param fmin: lower band edge in Hz, inclusive (default 0)
    :param fmax: upper band edge in Hz, inclusive (default 4)
    :return: 10*log10(band sum); -inf if the band sum is 0 (e.g. no bins in range)
    """
    band_range = (f >= fmin) & (f <= fmax)
    power_band = np.sum(pxx[band_range])
    return 10 * np.log10(power_band)

In [None]:
# Delta-band power (dB) for every channel, in the same order as the rows of `freq`.
freq['delta_power'] = list(map(get_delta_power, pxx_list, f_list))

In [None]:
# CSV cells cannot hold arrays: pandas writes str(array), which numpy truncates to
# "[a b c ... x y z]" for long arrays. Write only the scalar columns (channel, delta_power).
freq.drop(columns=['pxx', 'f']).to_csv('freq.csv')

In [None]:
# probe1's contact layout as a DataFrame.
probemap = (
    probe1
    .get_probe()
    .to_dataframe()
)

In [None]:
fig, ax = plt.subplots(figsize=(15, 10))
# Contact layout of probe1, annotated with channel ids.
si.plot_probe_map(probe1, ax=ax, with_channel_ids=True)
ax.set_ylim(bottom=-100, top=9000)

In [None]:
# Display the probe contact table.
probemap

In [None]:
# Attach channel ids and per-channel delta power (dB) to the contact table.
# NOTE(review): assumes row i of probemap corresponds to probe1 channel i — confirm.
# Rows are matched by position. Check the lengths, and assign a plain array so pandas index
# alignment cannot silently insert NaNs (or drop values) on a mismatch.
assert len(freq) == len(probemap), (
    f"delta power computed for {len(freq)} channels but the probe has {len(probemap)} contacts"
)
probemap['channel'] = probe1.channel_ids
probemap['dbs'] = freq['delta_power'].to_numpy()

In [None]:
fig, ax = plt.subplots()

# One marker per contact, coloured by delta power.
sc = ax.scatter(probemap['x'], probemap['y'], c=probemap['dbs'], cmap='viridis', s=50)

# Colour bar for the 'dbs' values (needs the scatter handle `sc`).
cbar = fig.colorbar(sc, ax=ax)
cbar.set_label('Delta power (dB, summed 0–4 Hz PSD of signal in V)')

# probeinterface contact coordinates (µm by default).
ax.set_xlabel('x (µm)')
ax.set_ylabel('y (µm)')

# x-limits presumably chosen by hand to frame this probe's contacts.
ax.set_xlim((100, 450))

In [None]:

# Save the contact table (positions plus the added channel / delta-power columns).
probemap.to_csv(path_or_buf='probemap.csv')