In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import time
import datetime as dt
import os
import seaborn as sns

import random
import string

In [None]:
import spikeinterface.full as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

from typing import Tuple, List
from probeinterface import Probe

def load_recording_from_raw(root: str, sample_base: str, well: Tuple[int, int], time_samplings_to_mask: List[Tuple[float, float]]):
    """Build a single-segment ``si.NumpyRecording`` from per-electrode raw voltage text files.

    For the given well, every existing file
    ``{root}/{sample_base}/{well[0]}-{well[1]}-{Erow}-{Ecol}_voltageRaw.txt`` (or ``.txt.gz``),
    with ``Erow`` and ``Ecol`` in 1..9, becomes one channel with id ``"{Erow}-{Ecol}"``.
    Channels are ordered by ``Erow`` first, then ``Ecol``. If both the ``.txt`` and the
    ``.txt.gz`` file exist, the ``.txt`` file is used.

    Parameters
    ----------
    root : str
        Folder that contains the ``sample_base`` sub-folder.
    sample_base : str
        Recording folder name. It must contain ``{sample_base}.info``, a tab-separated file
        from which the ``SamplingFrequency`` and ``VoltageScale`` rows are read.
    well : (int, int)
        Well (row, column) used as the file-name prefix.
    time_samplings_to_mask : list of (float, float)
        ``(t_start, t_end)`` intervals in seconds whose samples are set to 0 on every channel.

    Returns
    -------
    si.NumpyRecording
        One segment whose traces are the loaded voltages divided by ``abs(VoltageScale)``;
        every channel has ``group`` 0. Assumes each file holds a single column of samples.

    Raises
    ------
    FileNotFoundError
        If no raw voltage file exists for the well.
    ValueError
        If the channel files do not all contain the same number of samples.
    """
    info = pd.read_csv(f'{root}/{sample_base}/{sample_base}.info', index_col=0, names=['index', 'value'], sep='\t')
    # Cast explicitly: if any row of the .info file is non-numeric, the whole 'value' column is
    # parsed as strings, and the arithmetic below would fail on them.
    sampling_frequency = float(info.loc['SamplingFrequency', 'value'])
    voltage_scale = abs(float(info.loc['VoltageScale', 'value']))

    traces_list = []
    channel_ids = []

    # Rows/columns are probed over 1..9, enough for electrode grids up to 9x9 (e.g. 64-electrode
    # MEAs); checking paths of electrodes that do not exist is cheap.
    for Erow in range(1, 10):
        for Ecol in range(1, 10):
            filename = f'{root}/{sample_base}/{well[0]}-{well[1]}-{Erow}-{Ecol}_voltageRaw'
            # Prefer the plain-text file; fall back to the gzipped one.
            for candidate in (f'{filename}.txt', f'{filename}.txt.gz'):
                if os.path.exists(candidate):
                    channel_ids.append(f'{Erow}-{Ecol}')
                    # np.loadtxt decompresses '.gz' files transparently.
                    traces_list.append(np.loadtxt(candidate))
                    break

    if not traces_list:
        raise FileNotFoundError(
            f'No *_voltageRaw.txt[.gz] files for well {well[0]}-{well[1]} in {root}/{sample_base}')

    lengths = {len(trace) for trace in traces_list}
    if len(lengths) != 1:
        raise ValueError(
            f'Channel files for well {well[0]}-{well[1]} have different numbers of samples: {sorted(lengths)}')

    # Stack to (n_channels, n_samples), then transpose to the (n_samples, n_channels)
    # layout that NumpyRecording expects.
    trace_array = np.asarray(traces_list).transpose() / voltage_scale

    for t_start, t_end in time_samplings_to_mask:
        # Seconds -> sample indices (int() truncates toward zero).
        trace_array[int(t_start * sampling_frequency):int(t_end * sampling_frequency), :] = 0

    sample_recording = si.NumpyRecording(
        traces_list=[trace_array],
        sampling_frequency=sampling_frequency,
        channel_ids=np.asarray(channel_ids)
    )

    # All channels belong to a single group (group 0).
    sample_recording.set_property('group', [0] * len(channel_ids))
    sample_recording.is_dumpable = True  # This is necessary for some options later, like spike sorting

    return sample_recording


def load_probe_recording(recording: si.NumpyRecording, type_MEAS: int):
    """Attach a planar grid-MEA geometry to ``recording``, modifying it in place.

    Channel ids must have the form ``"{row}-{col}"`` (1-based), as produced by
    ``load_recording_from_raw``. Contacts are circles of radius 50 um on a square grid with
    a pitch of 350 um when ``type_MEAS == 16`` and 300 um otherwise. The electrode column
    maps to x and the electrode row maps to y, both in ``positions`` and in ``contact_vector``.

    Parameters
    ----------
    recording : si.NumpyRecording
        Recording whose ``location`` and ``contact_vector`` properties are set.
    type_MEAS : int
        Number of electrodes per well; selects the electrode pitch.

    Returns
    -------
    The same ``recording`` object, for convenience (existing callers may ignore it).
    """
    dist_multiplier = 350 if type_MEAS == 16 else 300
    circle_radius = 50

    channel_ids = recording.get_channel_ids()

    positions = np.zeros((len(channel_ids), 2), dtype=float)
    for channel_idx, channel in enumerate(channel_ids):
        row, col = (int(part) for part in str(channel).split('-'))
        # x follows the electrode column, y follows the electrode row.
        positions[channel_idx] = ((col - 1) * dist_multiplier, (row - 1) * dist_multiplier)

    # Channel locations are set so later steps such as peak detection can use them.
    recording.set_channel_locations(locations=positions)

    probe = Probe(ndim=2, si_units='um')
    probe.set_contacts(positions=positions, shapes='circle', shape_params={'radius': circle_radius})
    probe.device_channel_indices = np.arange(len(channel_ids))
    probe.create_auto_shape('rect')

    # NOTE: spikeinterface's ``set_probe`` returns a *new* recording unless ``in_place=True``.
    # Its return value is discarded here, so the geometry reaches ``recording`` through the
    # ``contact_vector`` property written below.
    recording.set_probe(probe)

    # Structured per-contact description, presumably mirroring the layout spikeinterface builds
    # itself in set_probe. It is built from ``positions`` so both agree on which axis is x.
    dtypes = [('probe_index', '<i8'), ('x', '<f8'), ('y', '<f8'), ('contact_shapes', '<U64'),
              ('radius', '<f8'), ('shank_ids', '<U64'), ('contact_ids', '<U64'), ('device_channel_indices', '<i8'),
              ('si_units', '<U64'), ('plane_axis_x_0', '<f8'), ('plane_axis_x_1', '<f8'), ('plane_axis_y_0', '<f8'),
              ('plane_axis_y_1', '<f8')]
    contact_vector = np.asarray(
        [(0, x, y, 'circle', circle_radius, '', '', channel_idx, 'um', 1., 0., 0., 1.)
         for channel_idx, (x, y) in enumerate(positions)],
        dtype=dtypes,
    )

    recording.set_property('contact_vector', contact_vector)
    return recording

In [None]:
def retrieve_peaks(root, sample_base, well, type_MEAS=16):
    """Detect threshold-crossing peaks on every electrode of one well.

    Pipeline:
    1. Load the raw traces.
    2. Attach the MEA grid geometry.
    3. Save a binary copy under ``tmp/``.
    4. Band-pass filter (300-5000 Hz).
    5. Apply a global median common reference.
    6. Run ``detect_peaks`` ('locally_exclusive', ``detect_threshold=5``, ``local_radius_um=450``).

    Parameters
    ----------
    root, sample_base, well
        Forwarded to ``load_recording_from_raw``; ``well`` is a (row, column) tuple of ints.
    type_MEAS : int, default 16
        Electrodes per well, forwarded to ``load_probe_recording``.

    Returns
    -------
    list_peaks : list of float
        Peak times in seconds, grouped by channel (in the recording's channel order) and in
        the order returned by ``detect_peaks`` within each channel.
    list_electrodes : list of str
        Electrode label of each peak, derived from the channel id (e.g. ``'3-2'`` -> ``'32'``).

    Notes
    -----
    The binary folder ``tmp/bin_<yy-mm-dd>_<8 random letters><row>-<col>`` is left on disk.
    """
    session_token = (dt.datetime.now().strftime("%y-%m-%d") + '_'
                     + ''.join(random.choice(string.ascii_letters) for _ in range(8))
                     + f'{well[0]}-{well[1]}')

    recording = load_recording_from_raw(root=root, sample_base=sample_base, well=well, time_samplings_to_mask=[])
    load_probe_recording(recording=recording, type_MEAS=type_MEAS)

    recording_bin = recording.save(n_jobs=16, chunk_duration="1s", folder=f'tmp/bin_{session_token}')

    recording_f = spre.bandpass_filter(recording_bin, freq_min=300, freq_max=5000)

    recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')

    noise_levels = si.get_noise_levels(recording_cmr, return_scaled=False)

    peaks = detect_peaks(recording_cmr,
                         method='locally_exclusive',
                         local_radius_um=450,
                         detect_threshold=5,
                         noise_levels=noise_levels,
                         )

    # The first two fields of a peak record are its sample index and channel index; names are
    # read positionally because they differ between spikeinterface versions.
    sample_field, channel_field = peaks.dtype.names[:2]

    # Map each peak's channel index to its channel id ('row-col') rather than assuming a full
    # 4x4 grid: channels only exist for electrode files that were found, so a missing file would
    # otherwise shift every later label (and channels >= 16 would be dropped).
    channel_ids = recording.get_channel_ids()
    # Stable sort groups peaks by channel while keeping their original order within a channel.
    peaks_by_channel = peaks[np.argsort(peaks[channel_field], kind='stable')]

    list_peaks = (peaks_by_channel[sample_field] / recording.sampling_frequency).tolist()
    list_electrodes = [str(channel_ids[idx]).replace('-', '') for idx in peaks_by_channel[channel_field]]

    return list_peaks, list_electrodes

In [None]:
# Experiment selection: recording-date folder and the MV sub-folder beneath it.
DATE = '2024_05_07'
MV = 500

# Project data root. Set the NANONEURO_DATA_ROOT environment variable to run the notebook on
# another machine without editing this cell.
DATA_ROOT = os.environ.get('NANONEURO_DATA_ROOT', '/data/Proyectos/Nanoneuro_exps_ane')

folder_base = os.path.join(DATA_ROOT, 'raw_files', DATE, str(MV))
folder_df_save = os.path.join(DATA_ROOT, 'results', DATE, str(MV))

os.makedirs(folder_df_save, exist_ok=True)

# Each entry is (condition, treatment, well names, well numbers, recording folder).
# A well number is two digits, the well's row and column; they are converted to the (row, column)
# tuple used in the raw file names.
list_conditions = [#('Condition', 'Treatment', 'Wells', 'Well_num', 'Folder') 
                    ('CTRL', 'PRE',   ['A1', 'A2', 'A3'], ['11', '12', '13'], 'D24_POSTsiembra_P2_(000)'), 
                    ('CTRL', 'POST',  ['A1', 'A2', 'A3'], ['11', '12', '13'], 'D24_POSTsiembra_P2_POSTexperimento(000)'),
                    ('BP', 'PRE',     ['A4', 'A5', 'A6'], ['14', '15', '16'], 'D24_POSTsiembra_P2_(000)'), 
                    ('BP', 'POST',    ['A4', 'A5', 'A6'], ['14', '15', '16'], 'D24_POSTsiembra_P2_POSTexperimento(000)'), 
                    ('LINK1', 'PRE',  ['B1', 'B2', 'B3'], ['21', '22', '23'], 'D24_POSTsiembra_P2_(000)'), 
                    ('LINK1', 'POST', ['B1', 'B2', 'B3'], ['21', '22', '23'], 'D24_POSTsiembra_P2_POSTexperimento(000)'),
                    ('LINK2', 'PRE',  ['C1', 'C2', 'C3'], ['31', '32', '33'], 'D24_POSTsiembra_P2_(000)'), 
                    ('LINK2', 'POST', ['C1', 'C2', 'C3'], ['31', '32', '33'], 'D24_POSTsiembra_P2_POSTexperimento(000)'),
                    ('LINK3', 'PRE',  ['D1', 'D2', 'D3'], ['41', '42', '43'], 'D24_POSTsiembra_P2_(000)'), 
                    ('LINK3', 'POST', ['D1', 'D2', 'D3'], ['41', '42', '43'], 'D24_POSTsiembra_P2_POSTexperimento(000)'),
                    ('BP+LINK1', 'PRE',  ['B4', 'B5', 'B6'], ['24', '25', '26'], 'D25_POSTsiembra_P2(000)'), 
                    ('BP+LINK1', 'POST', ['B4', 'B5', 'B6'], ['24', '25', '26'], 'D25_POSTsiembra_P2_POSTexperiment(001)'),
                    ('BP+LINK2', 'PRE',  ['C4', 'C5', 'C6'], ['34', '35', '36'], 'D25_POSTsiembra_P2(000)'), 
                    ('BP+LINK2', 'POST', ['C4', 'C5', 'C6'], ['34', '35', '36'], 'D25_POSTsiembra_P2_POSTexperiment(001)'),
                    ('BP+LINK3', 'PRE',  ['D4', 'D5', 'D6'], ['44', '45', '46'], 'D25_POSTsiembra_P2(000)'), 
                    ('BP+LINK3', 'POST', ['D4', 'D5', 'D6'], ['44', '45', '46'], 'D25_POSTsiembra_P2_POSTexperiment(001)'),
                   ]

In [None]:
# Detect peaks for every (condition, treatment, well) and collect them in one long-format table.
peak_columns = ['condition', 'treatment', 'well', 'well_num', 'electrode', 'time']

frames = []
n_peaks_total = 0
for condition, treatment, list_wells, list_wells_num, sample_base in list_conditions:
    for well, well_num in zip(list_wells, list_wells_num):
        print(condition, treatment, well, well_num, sample_base)
        list_peak_times, list_electrodes = retrieve_peaks(root=folder_base, sample_base=sample_base, well=(int(well_num[0]), int(well_num[1])))

        n_peaks = len(list_peak_times)
        frames.append(pd.DataFrame({'condition': [condition] * n_peaks,
                                    'treatment': [treatment] * n_peaks,
                                    'well': [well] * n_peaks,
                                    'well_num': [well_num] * n_peaks,
                                    'electrode': list_electrodes,
                                    'time': list_peak_times,
                                    }, columns=peak_columns))

        n_peaks_total += n_peaks
        print(n_peaks_total)  # running number of peaks collected so far

# Concatenate once at the end. Re-concatenating inside the loop is quadratic, and concatenating
# onto an initially empty frame triggers a pandas FutureWarning and forces object dtypes.
df_peaks = (pd.concat(frames, ignore_index=True) if frames
            else pd.DataFrame(columns=peak_columns))

In [None]:
# Cache the full peak table so later cells can start from the CSV instead of re-detecting peaks.
peaks_csv_path = f'{folder_df_save}/df_peaks_full.csv'
df_peaks.to_csv(peaks_csv_path, index=False)

## Load the saved peak table and remove non-compliant electrodes

In [None]:
# Reload the cached table; re-running from this cell does not require peak detection.
full_peaks_csv = f'{folder_df_save}/df_peaks_full.csv'
df_peaks = pd.read_csv(full_peaks_csv)
df_peaks

In [None]:
# Select the peaks of a single recording (CTRL, PRE, well A1) for the raster plot below.
is_selected = (
    df_peaks['condition'].eq('CTRL')
    & df_peaks['treatment'].eq('PRE')
    & df_peaks['well'].eq('A1')
)
df_peaks_sub = df_peaks.loc[is_selected]

df_peaks_sub

In [None]:
# Raster plot of the selected well: one row per electrode, one tick per detected peak.
fig, ax = plt.subplots(figsize=(30, 6))

# Electrode label 'RC' (row digit, column digit) -> raster row index on a 4x4 grid.
electrode_labels = df_peaks_sub['electrode'].astype(str)
y = [(int(label[0]) - 1) * 4 + (int(label[1]) - 1) for label in electrode_labels]
x = df_peaks_sub['time'].values

ax.scatter(x, y, marker='|', alpha=0.15)
ax.set_yticks(np.arange(16))
ax.set_yticklabels([f'{i // 4 + 1}{i % 4 + 1}' for i in range(16)])
ax.set(xlabel='Time (s)', ylabel='Electrode', title='Detected peaks per electrode')
# To zoom into a time window, add e.g. ax.set_xlim(10, 30).
plt.show()

In [None]:
# Electrodes to exclude, per well, after inspecting the raster plots.
# Wells are unique to a condition, so keying by well is enough. An entry applies to both the
# PRE and POST recordings of that well.

dict_electrode_refuse = {'A1': ['42', '32'],
                         'A2': [],
                         'A3': [],
                         'A4': [],
                         'A5': [],
                         'A6': [],
                         'B1': [],
                         'B2': [],
                         'B3': [],
                         'B4': [],
                         'B5': [],
                         'B6': [],
                         'C1': [],
                         'C2': [],
                         'C3': [],
                         'C4': [],
                         'C5': [],
                         'C6': [],
                         'D1': [],
                         'D2': [],
                         'D3': [],
                         'D4': [],
                         'D5': [],
                         'D6': [],}

# Compare electrode labels as strings. The column is int after the CSV round-trip but str when
# df_peaks comes directly from the detection loop; an int() comparison would then silently
# match nothing.
electrode_labels = df_peaks['electrode'].astype(str)

is_refused = pd.Series(False, index=df_peaks.index)
for well, list_electrodes in dict_electrode_refuse.items():
    if list_electrodes:
        is_refused |= (df_peaks['well'] == well) & electrode_labels.isin([str(e) for e in list_electrodes])

# A label-aligned boolean mask works for any index, not only a 0..n-1 RangeIndex.
df_peaks_sub_filter = df_peaks.loc[~is_refused]

In [None]:
# Full peak table before electrode filtering (compare its length with df_peaks_sub_filter).
df_peaks

In [None]:
# Peak table after dropping the electrodes listed in dict_electrode_refuse.
df_peaks_sub_filter