## Where is the probe?

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path


import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.widgets as sw

import numpy as np
import  multitaper
from tqdm import tqdm
import pandas as pd

# Root of the FlexiVexi raw-data tree; change this one value to run on another mount/machine.
RAW_DATA_DIR = Path('/ceph/sjones/projects/FlexiVexi/raw_data')
# Open Ephys "Record Node" folder of the FNT103 session analysed in this notebook.
# Kept as a str so downstream readers receive exactly the same value as before.
INPUT = str(
    RAW_DATA_DIR
    / 'FNT103'
    / '2024-08-29T15-17-32'
    / 'Open-Ephys'
    / 'FNT103_2024-09-03_15-17-05'
    / 'Record Node 103'
)

## Have a look at local data

In [None]:
# Open the Open Ephys session folder, reading stream '1'.
recording = se.read_openephys(INPUT, stream_id='1')
# Keep only segment 0 as a separate view, inspected in the next cell.
seg1 = recording.select_segments(0)

In [None]:
# Only the last expression of a cell is rendered, so show the segment explicitly.
display(seg1)

# Spike-band preprocessing chain: high-pass at 400 Hz -> detect and drop bad
# channels -> phase-shift correction -> global median common reference.
rec1 = si.highpass_filter(recording, freq_min=400.)
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1)
rec2 = rec1.remove_channels(bad_channel_ids)
print('bad_channel_ids', bad_channel_ids)

rec3 = si.phase_shift(rec2)
rec4 = si.common_reference(rec3, operator="median", reference="global")
rec = rec4
# NOTE(review): `rec` and `bad_channel_ids` are not used by the delta-band analysis
# below, which splits the raw `recording` (a 400 Hz high-pass would remove the
# 0-4 Hz band anyway). Bad channels are therefore still included there — confirm intended.
rec

In [None]:
# Split the raw recording on the per-channel "group" property into a dict keyed by
# group value (presumably one group per probe/shank — confirm).
split_recording_dict = recording.split_by(property="group")


In [None]:
# Group 1 is the sub-recording analysed in the rest of the notebook.
probe1 = split_recording_dict[1]

# Channel-by-time heat map of 5-15 s, channels ordered by depth, colour limits +/-50.
w_ts = sw.plot_traces(
    probe1,
    mode="map",
    time_range=(5, 15),
    show_channel_ids=True,
    order_channel_by_depth=True,
    seconds_per_row=60,
    clim=(-50, 50),
)


Keep 10 s of data (5–15 s, the same window as the trace map above).

In [None]:
samp = probe1.sampling_frequency
# 10 s window from 5 s to 15 s (the span shown in the trace map above).
# sampling_frequency is a float, so make the frame bounds explicit ints.
start_frame = int(5 * samp)
end_frame = int(15 * samp)
# get_traces gives (samples, channels); transpose to (channels, samples).
# return_scaled=True returns microvolts rather than raw ADC counts, which is what the
# mV / V unit conversions and labels further down assume.
traces = probe1.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=True).T

In [None]:
# (n_channels, n_samples) of the extracted window.
np.shape(traces)

In [None]:
nChans, nSamps = traces.shape
# Apply % so the placeholders are filled; passing the tuple as a second print()
# argument would print the raw format string followed by the tuple.
print('Data has %d channels and %d samples' % (nChans, nSamps))
# Channel 0 over the window; / 1000 assumes `traces` is in microvolts (-> mV).
plt.plot(np.arange(nSamps) / samp, traces[0, :] / 1000)
plt.xlabel('Time (s)')
plt.ylabel('LFP (mV)')
plt.title('Channel 0, raw trace')
plt.show()

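We want to look at power, in dB, in the delta band (0–4 Hz). We use a multitaper estimate because the window is short (10 s): averaging over several tapers gives a lower-variance spectrum than a single FFT periodogram of the same data.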

In [None]:
# Multitaper spectrum of channel 0 (nw=5). / 1e6 converts microvolts to volts,
# assuming `traces` is in µV (10E6 would be 1e7 and shift power by 20 dB).
psd = multitaper.MTSpec(x=traces[0, :] / 1e6, dt=1.0 / samp, nw=5)
pxx, f = psd.spec, psd.freq  # power spectrum and its frequency axis
plot_range = (f <= 10) & (f >= 0)  # plot 0-10 Hz only
plt.semilogy(f[plot_range], pxx[plot_range])
plt.xlabel('frequency (Hz)')
plt.ylabel('power (V**2)')
plt.show()

In [None]:
# Define your trace and sampling rate
# Plain FFT periodogram of channel 0, for comparison with the multitaper estimate.
# / 1e6 converts microvolts to volts, assuming `traces` is in µV (10E6 would be 1e7).
trace = traces[0, :] / 1e6
n = len(trace)
dt = 1.0 / samp  # sampling interval (s)

# One-sided FFT and matching frequency axis
fft_values = np.fft.rfft(trace)
frequencies = np.fft.rfftfreq(n, dt)

# Unnormalised periodogram |X|^2 / n. Named psd_fft so it does not overwrite the
# multitaper `psd` object with an array of a different kind.
# NOTE(review): this scaling likely differs from the multitaper spectrum's, so compare
# spectral shape rather than absolute level — confirm.
psd_fft = np.abs(fft_values) ** 2 / n

# Restrict the plot to 0-10 Hz
plot_range = (frequencies <= 10) & (frequencies >= 0)

plt.semilogy(frequencies[plot_range], psd_fft[plot_range])
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power (V^2)')
plt.show()

In [None]:
# Delta-band (0-4 Hz) power of channel 0, in dB, from the multitaper spectrum `pxx`/`f`.
band_range = (0 <= f) & (f <= 4)      # mask of bins inside the band
power_band = pxx[band_range].sum()    # summed power over those bins
power_db = 10 * np.log10(power_band)  # express in decibels

power_db

Now repeat the multitaper estimate for all 96 channels.

In [None]:
# Multitaper spectrum for every channel of the window. The channel count is taken
# from `traces` rather than hard-coded, so no channel is skipped and a probe with a
# different channel count does not raise IndexError. tqdm already reports progress,
# so the loop does not print its index.
n_channels = traces.shape[0]
pxx_list = []
f_list = []

for i in tqdm(range(n_channels)):
    # / 1e6: microvolts -> volts, assuming `traces` is in µV (10E6 would be 1e7).
    psd = multitaper.MTSpec(x=traces[i, :] / 1e6, dt=1.0 / samp, nw=5)
    pxx, f = psd.spec, psd.freq  # power spectrum and its frequency axis
    pxx_list.append(pxx)
    f_list.append(f)

In [None]:
# One row per channel: a 0-based channel index plus that channel's spectrum and
# frequency axis (array-valued cells). The row count follows the computed spectra
# instead of a hard-coded 96, so the columns always have matching lengths.
freq_per_channel = {
    'channel': np.arange(len(pxx_list)),
    'pxx': pxx_list,
    'f': f_list,
}

freq = pd.DataFrame(freq_per_channel)

In [None]:
# Write the per-channel spectra table to freq.csv (the same path is written again
# once 'delta_power' has been added).
# NOTE(review): 'pxx' and 'f' hold whole arrays; to_csv stringifies them and numpy
# abbreviates long arrays with '...', so the spectra are likely not recoverable from this file.
freq.to_csv(path_or_buf='freq.csv')

In [None]:
def get_delta_power(pxx, f, fmin=0.0, fmax=4.0):
    """Return the summed spectral power inside a frequency band, in dB.

    Parameters
    ----------
    pxx : array_like
        Power spectrum values (e.g. ``MTSpec.spec``).
    f : array_like
        Frequency (Hz) of each entry of ``pxx``; same shape as ``pxx``.
    fmin, fmax : float, optional
        Inclusive band edges in Hz. The defaults (0-4 Hz) give the delta band
        used throughout this notebook.

    Returns
    -------
    float
        ``10 * log10(sum(pxx[(f >= fmin) & (f <= fmax)]))``. This is a sum over
        bins, not an integral (no multiplication by the bin width), so values are
        only comparable between spectra with the same frequency resolution. An
        empty band yields ``-inf`` (log10 of 0) with a numpy RuntimeWarning.

    Notes
    -----
    With the default ``fmin=0`` the 0 Hz (DC) bin, if present, is included, so a
    channel's DC offset contributes to the result.
    """
    pxx = np.asarray(pxx)
    f = np.asarray(f)

    # Boolean mask selecting the bins inside [fmin, fmax]
    band_range = (f >= fmin) & (f <= fmax)

    # Total power in the band, then convert to dB
    power_band = np.sum(pxx[band_range])
    return 10 * np.log10(power_band)

In [None]:
# Delta-band power (dB) per channel, in the same order as the rows of `freq`.
freq['delta_power'] = list(map(get_delta_power, pxx_list, f_list))

In [None]:
# Save the per-channel results. The array-valued 'pxx' and 'f' columns are left
# out: to_csv would write them as numpy reprs, which abbreviate long arrays with
# '...', producing lossy text that cannot be parsed back. The CSV therefore holds
# only the scalar columns ('channel', 'delta_power').
freq.drop(columns=['pxx', 'f']).to_csv('freq.csv')

In [None]:
# Contact table of the probe attached to probe1 (provides the 'x'/'y' positions used below).
probemap = (
    probe1
    .get_probe()
    .to_dataframe()
)

In [None]:
fig, ax = plt.subplots(figsize=(15, 10))
# Contact layout of probe1 annotated with channel ids; y range fixed to [-100, 9000].
si.plot_probe_map(probe1, ax=ax, with_channel_ids=True)
ax.set_ylim(bottom=-100, top=9000)

In [None]:
# Show the contact table of probe1 (positions per contact).
probemap

In [None]:
# Attach channel ids and per-channel delta power to the contact table.
# Assigning a plain array is positional (row i <- channel i) and raises on a length
# mismatch, whereas assigning the Series would align on the index and silently fill NaN.
# NOTE(review): assumes probemap rows are in the same order as probe1's channels
# (i.e. the rows of `traces`) — confirm.
probemap['channel'] = probe1.channel_ids
probemap['dbs'] = freq['delta_power'].to_numpy()

In [None]:
fig, ax = plt.subplots()

# Each contact at its probe position, coloured by delta-band power.
sc = ax.scatter(probemap['x'], probemap['y'], c=probemap['dbs'], cmap='viridis', s=50)

# Colour bar for the 'dbs' values (it needs the scatter's mappable `sc`).
cbar = plt.colorbar(sc, ax=ax)
cbar.set_label('Delta power (power in dB from 0 to 4 Hz in a signal in V)')

ax.set_xlabel('x position')
ax.set_ylabel('y position')
ax.set_title('Delta-band (0-4 Hz) power per contact')

# Zoom in on x = 100-450 (probe coordinates).
ax.set_xlim((100, 450))
plt.show()

In [None]:

# Persist the contact table, including channel ids and delta power, to probemap.csv.
probemap.to_csv(path_or_buf='probemap.csv')