# MEA analysis script for a single recording
This is the analysis script for extracting the trigger channel and the spikes from a single MEA recording. The input data are a .brw file, which contains the stimulus channel of the recording, and a .hdf5 file, which is the output of the HS2 spike-sorting algorithm.
At the end of this script, dataframes containing the spikes sorted by stimulus and the trigger times for each stimulus can be exported either to a zipped pickle file or to a .mat file for further analysis in MATLAB.
First we import all needed libraries:

In [None]:
from importlib import reload

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import qgrid
from ipywidgets import interact, interact_manual, interactive

from MEA_analysis import backbone, stimulus_trace, spike_extractor, spike_plotly, stimulus_and_spikes, single_stimulus

# Interactive matplotlib backend (requires the ipympl extension)
%matplotlib widget

# Reload the modules so that changes to the package source are picked up without restarting the kernel
reload(spike_plotly)
reload(spike_extractor)

### Import the stimulus file
Run the cell and click the file select button. In the pop-up window, choose the stimulus file you would like to use.

In [None]:
trigger_file = backbone.SelectFilesButton("Stimulus")
widgets.VBox([trigger_file, trigger_file.out])

#### Choice:
You can skip the next steps and instead load a previously saved stimulus dataframe with the cell directly above the section "Loading the spikes".

#### Plot the trigger channel and define stimulus borders
This is an interactive plot. You can define a stimulus by selecting a trigger signal, which marks the stimulus begin, and then a second trigger signal, which marks the stimulus end. The stimulus begin is indicated by a yellow dot and the end by a red dot, both plotted into the graph. Depending on how long the recording is, the selection can take a few seconds; check the bottom left of the window to see whether Python is "busy".
Also please note: this overview plot is downsampled to make plotting faster. You can change the downsampling by changing the input parameter of `Test_stimulus.plot_trigger_channel_new("200ms")`. Higher values mean stronger downsampling, but also more artefacts. The artefacts do not matter as long as no whole trigger signals disappear due to the downsampling.

In [None]:
Test_stimulus = stimulus_trace.Stimulus_Extractor(trigger_file.files[0])
Test_stimulus.plot_trigger_channel_new("200ms")
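The following sketch illustrates why strong downsampling can make whole trigger signals disappear. It builds a hypothetical synthetic trigger trace and downsamples it with pandas `resample("200ms")`, once by keeping only the first sample of each bin and once by keeping the maximum of each bin. How `plot_trigger_channel_new` actually aggregates samples is not shown in this notebook, so treat this purely as an illustration of the effect, not as its implementation.

```python
import numpy as np
import pandas as pd

# Hypothetical synthetic trigger channel: 1 s sampled at 1 kHz (one sample per ms),
# zero everywhere except three 5 ms pulses. Real MEA sampling rates are much higher.
n_samples = 1000
trigger = np.zeros(n_samples, dtype=int)
for start in (100, 450, 820):
    trigger[start:start + 5] = 1
trace = pd.Series(trigger, index=pd.to_timedelta(np.arange(n_samples), unit="ms"))


def count_pulses(series):
    """Count rising edges (0 -> 1 transitions) in a binary trace."""
    values = series.to_numpy()
    return int(np.sum(np.diff(values) > 0) + (values[0] > 0))


# Keeping only the first sample of every 200 ms bin misses all three pulses ...
first_sample = trace.resample("200ms").first()
# ... whereas keeping the maximum of every bin preserves each pulse.
bin_max = trace.resample("200ms").max()

print(count_pulses(trace), count_pulses(first_sample), count_pulses(bin_max))  # 3 0 3
```

Whether pulses survive therefore depends both on the bin size and on how each bin is summarised.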


Once you have selected all stimuli, run this cell to correctly identify them. A line will be plotted on top of the trigger signal for each stimulus.

In [None]:
Test_stimulus.get_stim_range_new()
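Conceptually, identifying the trigger times of a stimulus means finding the rising edges of the trigger channel between the selected begin and end. The NumPy sketch below shows one way to do this; `trigger_onsets` and its threshold of 0.5 are hypothetical and are not the implementation used by `get_stim_range_new`.

```python
import numpy as np


def trigger_onsets(channel, begin, end, threshold=0.5):
    """Return absolute sample indices of rising edges in channel[begin:end].

    Hypothetical helper: an onset is the first sample above `threshold`
    whose predecessor was not above it. If the segment already starts
    above threshold, its first sample is counted as an onset.
    """
    segment = np.asarray(channel[begin:end], dtype=float)
    above = segment > threshold
    # Rising edge: sample i+1 is above threshold while sample i is not
    edges = np.flatnonzero(above[1:] & ~above[:-1]) + 1
    if above.size and above[0]:
        edges = np.concatenate(([0], edges))
    # Convert segment-relative indices back to indices in the full recording
    return edges + begin


# Hypothetical trigger trace with three pulses
channel = np.zeros(50, dtype=int)
channel[5:8] = 1
channel[20:23] = 1
channel[40:45] = 1

print(trigger_onsets(channel, 0, 50).tolist())   # [5, 20, 40]
print(trigger_onsets(channel, 10, 50).tolist())  # [20, 40]
```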


Here you see all the information that has been gathered for each stimulus. Next, name the stimuli in the "Stimulus_name" column: double-click a row and enter the name you want. Also enter the values for the stimulus repeat logic and the repeat sublogic in the respective columns.

In [None]:
Stimulus_table = qgrid.show_grid(Test_stimulus.stimuli)
Stimulus_table

Once you are done, run the next cell. The correct names of the stimuli will appear in the plot above.

In [None]:
Test_stimulus.get_changed_names(Stimulus_table)


In [None]:
Test_stimulus = stimulus_trace.Stimulus_Extractor(trigger_file.files[0])
Test_stimulus.load_from_saved(Test_stimulus.recording_folder+"stimulus_overview")

## Loading the spikes
Here we load the spikes. Run the cell and click the button to select the output file of the HS2 spike sorting.

In [None]:
spikes_file = backbone.SelectFilesButton("Spikes")
widgets.VBox([spikes_file, spikes_file.out])

The next cell plots an overview of all the cells found by the spike sorting and the number of spikes for each cell. Below the plot, you will find a number of controls for setting thresholds: how many cells you want to load, and the spike counts above and below which cells with too many or too few spikes are excluded.

In [None]:
reload(spike_extractor)
Spikes = spike_extractor.Recording_spikes(spikes_file.files[0])
thresholds = spike_extractor.Thresholds(Spikes.spikes, Test_stimulus)
overview = spike_plotly.Recording_overview(Spikes.spikes)
overview.plot_basic_recording_information(thresholds)
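As a rough illustration of what the thresholds do, the following pandas sketch keeps only cells within a range of cell indices whose total spike count lies between a lower and an upper bound. The helper `select_cells` is hypothetical, and it assumes that the left/right values select a range of cell indices; the actual logic of `define_subset` and `define_thr` may differ.

```python
import pandas as pd


def select_cells(spikes_df, first_cell, last_cell, min_spikes, max_spikes):
    """Keep cells in [first_cell, last_cell] with min_spikes <= spike count <= max_spikes.

    Hypothetical helper illustrating the idea of the threshold widgets;
    it is not the implementation in spike_extractor.
    """
    in_range = spikes_df["Cell index"].between(first_cell, last_cell)
    plausible_count = spikes_df["Nr of spikes"].between(min_spikes, max_spikes)
    return spikes_df[in_range & plausible_count].reset_index(drop=True)


# Hypothetical spike-count table
spikes_df = pd.DataFrame({
    "Cell index": [0, 1, 2, 3, 4],
    "Nr of spikes": [3, 150, 20000, 800, 90],
})
selected = select_cells(spikes_df, first_cell=0, last_cell=3, min_spikes=50, max_spikes=10000)
print(selected["Cell index"].tolist())  # [1, 3]
```

Cell 0 has too few spikes, cell 2 too many, and cell 4 lies outside the selected index range.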

Once you have made the selection, load the respective spikes by running the next cell. It returns an overview of all loaded cells, the spikes per stimulus, and the location of each cell on the MEA grid.

In [None]:
# Load spikes
reload(stimulus_and_spikes)
# Restrict the loaded cells to the selection made with the threshold widgets above
Spikes.define_subset(thresholds.threshold_left_widget.value, thresholds.threshold_right_widget.value)
Spikes.define_thr(thresholds.threshold_up_widget.value, thresholds.threshold_low_widget.value)
spikes_df = Spikes.get_spikes(True, True)[2]

# Correlate spikes and stimuli
begin_idx, end_idx = stimulus_and_spikes.spikes_and_stimulus(spikes_df, Test_stimulus.stimuli)
stimulus_spikes = stimulus_and_spikes.extract_stimulus_spikes(
    spikes_df, Test_stimulus.stimuli['Stimulus_name'], begin_idx, end_idx,
    np.array(Test_stimulus.stimuli['Begin_Fr'][:], dtype=int))

# Build one row per (cell, stimulus) pair. Collecting the rows in a list and creating
# the DataFrame once is much faster than growing it row by row with .loc.
rows = []
for cell in range(len(spikes_df)):
    for stimulus in range(len(stimulus_spikes)):
        cell_spikes = stimulus_spikes['Spikes'][stimulus][:, cell]
        # Number of valid (unmasked) spikes of this cell during this stimulus
        nr_spikes_new = np.count_nonzero(~cell_spikes.mask)
        # Scale the cell's area by the fraction of its spikes that fall into this stimulus
        area_new = spikes_df['Area'].loc[cell] * (nr_spikes_new / spikes_df['Nr of spikes'].loc[cell])
        rows.append([spikes_df['Cell index'].loc[cell], spikes_df['Centres x'].loc[cell],
                     spikes_df['Centres y'].loc[cell], nr_spikes_new, area_new, stimulus,
                     Test_stimulus.stimuli['Stimulus_name'][stimulus], cell_spikes])

complete_dataframe = pd.DataFrame(rows, columns=['Cell index', 'Centres x', 'Centres y', 'Nr of Spikes',
                                                 'Area', 'Stimulus ID', 'Stimulus name', 'Spikes'])
multi_complete_dataframe = complete_dataframe.set_index(['Cell index', 'Stimulus ID', 'Centres x', 'Centres y',
                                                         'Nr of Spikes', 'Area', 'Stimulus name'])
complete_dataset = qgrid.show_grid(multi_complete_dataframe)
complete_dataset
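The step "Correlate spikes and stimuli" assigns each spike to the stimulus during which it occurred. A minimal sketch of that idea, assuming sorted spike times in samples and stimulus borders given as begin/end frames (as in the columns `Begin_Fr` and `End_Fr`), can use `np.searchsorted`. The helper `split_by_stimulus` is hypothetical and is not the `stimulus_and_spikes` implementation.

```python
import numpy as np


def split_by_stimulus(spike_samples, begin_frames, end_frames):
    """Split one cell's spike times (in samples) into per-stimulus trains.

    Hypothetical helper: spikes in [begin, end) belong to a stimulus and
    are returned relative to that stimulus' begin frame.
    """
    spike_samples = np.sort(np.asarray(spike_samples))
    trains = []
    for begin, end in zip(begin_frames, end_frames):
        # Binary search for the first spike >= begin and the first spike >= end
        lo = np.searchsorted(spike_samples, begin, side="left")
        hi = np.searchsorted(spike_samples, end, side="left")
        trains.append(spike_samples[lo:hi] - begin)
    return trains


spikes = [5, 120, 130, 480, 510, 990]  # hypothetical spike times in samples
trains = split_by_stimulus(spikes, begin_frames=[100, 500], end_frames=[490, 1000])
print([t.tolist() for t in trains])  # [[20, 30, 380], [10, 490]]
```

Binary search keeps this fast for long recordings: each stimulus needs only two `searchsorted` lookups instead of a comparison against every spike.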

You can save this dataframe and the stimulus dataframe here to be able to load them together with other recordings later.

In [None]:
multi_complete_dataframe.to_pickle(Test_stimulus.recording_folder+"spikes_for_overview", compression="zip")
Test_stimulus.stimuli.to_pickle(Test_stimulus.recording_folder+"stimulus_overview", compression="zip")
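The saved files can be read back with `pd.read_pickle`. Because the file names used here have no `.zip` extension, the compression has to be passed explicitly; the default `compression="infer"` would treat the file as uncompressed. The sketch below shows the round trip with a small hypothetical dataframe in a temporary folder.

```python
import os
import tempfile

import pandas as pd

# Hypothetical stand-in for multi_complete_dataframe
df = pd.DataFrame({
    "Cell index": [0, 0, 1],
    "Stimulus ID": [0, 1, 0],
    "Nr of Spikes": [12, 7, 30],
}).set_index(["Cell index", "Stimulus ID"])

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "spikes_for_overview")
    df.to_pickle(path, compression="zip")
    # No ".zip" extension, so compression must be given explicitly
    loaded = pd.read_pickle(path, compression="zip")

print(loaded.equals(df))  # True
```

To combine several recordings, you could load each file this way and concatenate the dataframes with `pd.concat(frames, keys=recording_names)`, which adds the recording name as an extra outer index level (`frames` and `recording_names` are hypothetical lists).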

In [None]:
waveform_window = widgets.Output(layout={'border': '1px solid black'})
waveform_window

This cell plots how many spikes were detected (circle size) at each location on the MEA grid.

In [None]:
reload(spike_plotly)
plottingarray = spike_plotly.ArrayFigure(spikes_df, Spikes)
plottingarray.window = waveform_window
test, table = plottingarray.plot_locations(invert=True)


In [None]:
test
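How `spike_plotly.ArrayFigure` maps spike counts to circle sizes is not shown in this notebook. A common choice, sketched below as an assumption, is to make the circle area (rather than its diameter) proportional to the spike count, so that a cell with four times as many spikes gets a circle with twice the diameter. `marker_sizes` and `max_size` are hypothetical names.

```python
import numpy as np


def marker_sizes(spike_counts, max_size=30.0):
    """Map spike counts to marker diameters so that circle area scales with count."""
    counts = np.asarray(spike_counts, dtype=float)
    if counts.max() == 0:
        return np.zeros_like(counts)
    # Diameter proportional to sqrt(count) makes the area proportional to count
    return max_size * np.sqrt(counts / counts.max())


print(marker_sizes([0, 25, 100]).tolist())  # [0.0, 15.0, 30.0]
```

Plotly can also do this scaling itself via `marker.sizemode="area"` together with `marker.sizeref`.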

## Look at single stimuli
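Next, we can load single stimuli and look at the spike trains. The following cells first flag the cells selected in the array plot above in the column "Array_Selection". Afterwards, run the cell with the stimulus selector and choose which stimulus you want to look at.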

In [None]:
table

In [None]:
indices = plottingarray.indices.to_numpy()

In [None]:
indices

In [None]:
multi_complete_dataframe["Array_Selection"] = "False"
multi_complete_dataframe

In [None]:
# Mark all rows of the selected cells (index level 0 is the cell index)
selected = multi_complete_dataframe.index.get_level_values(0).isin(indices)
multi_complete_dataframe.loc[selected, "Array_Selection"] = "True"

In [None]:
reload(single_stimulus)
stimulus_extr = single_stimulus.Single_stimulus_spikes(multi_complete_dataframe, Test_stimulus)
select_stimulus = backbone.select_stimulus(len(Test_stimulus.stimuli)-1)

interact(stimulus_extr.load_spikes_for_stimulus, stimulus_id = select_stimulus)

### Export spikes to matlab
If you want, you can export the spikes for the stimulus you have selected to a .mat file. The file will be saved in the folder in which the .hdf5 file is located. Its name is based on the stimulus name and stimulus ID.

In [None]:
array_select_spikes = stimulus_extr.dataframe_view.get_changed_df()

In [None]:
array_select_spikes

In [None]:
# Export spike times to MATLAB
from scipy.io import savemat

# Spike times of every cell in seconds. Named spike_times so that the Recording_spikes
# object stored in `Spikes` (used by the array plot above) is not overwritten.
spike_times = list(stimulus_extr.spikes_stimulus["Spikes"] / stimulus_extr.sampling_freq)
# Pad the spike trains into a (max_len, n_cells) matrix; unused entries stay 0
max_len = max(np.shape(times)[0] for times in spike_times)
spikes_array = np.zeros((max_len, len(spike_times)), dtype=float)
for cell, times in enumerate(spike_times):
    spikes_array[:np.shape(times)[0], cell] = times

# Trigger channel of the selected stimulus, starting 10 samples before its begin
channel = np.array(Test_stimulus.channel[stimulus_extr.stimulus_info["Begin_Fr"]-10:stimulus_extr.stimulus_info["End_Fr"]])
channel = channel[:, 0].astype(int)

test_dic = {}
test_dic["spiketimestamps"] = spikes_array
test_dic["Ch_new"] = {}  # nested dict, saved as a MATLAB struct
test_dic["Ch_new"]["trigger_ch"] = channel
test_dic["Ch_new"]["SamplingFrequency"] = stimulus_extr.sampling_freq
test_dic["Cell_idx"] = stimulus_extr.spikes_stimulus.index.get_level_values(0).to_numpy()

savemat(Test_stimulus.recording_folder+Test_stimulus.stimuli.loc[select_stimulus.value]["Stimulus_name"]+
        str(select_stimulus.value)+"_spikes.mat", test_dic)
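If you want to check an exported file without MATLAB, you can read it back with `scipy.io.loadmat`. The sketch below builds a small hypothetical dataset with the same keys as the export cell (`spiketimestamps`, `Ch_new`, `Cell_idx`), pads spike trains of unequal length with zeros, writes it to an in-memory buffer with `savemat` and loads it again. `simplify_cells=True` (available in recent SciPy versions) returns the `Ch_new` struct as a plain Python dict.

```python
import io

import numpy as np
from scipy.io import loadmat, savemat

fs = 20000.0  # hypothetical sampling frequency in Hz
# Hypothetical spike times (in samples) of three cells with different spike counts
spike_samples = [np.array([200, 4000, 9000]), np.array([100]), np.array([50, 60])]

# Convert to seconds and pad into a (max_len, n_cells) matrix; unused entries stay 0
spike_times = [s / fs for s in spike_samples]
max_len = max(len(s) for s in spike_times)
spikes_array = np.zeros((max_len, len(spike_times)))
for cell, times in enumerate(spike_times):
    spikes_array[:len(times), cell] = times

trigger = np.array([0, 0, 1, 1, 0], dtype=int)  # hypothetical trigger trace
mat_dict = {
    "spiketimestamps": spikes_array,
    "Ch_new": {"trigger_ch": trigger, "SamplingFrequency": fs},  # saved as a MATLAB struct
    "Cell_idx": np.array([3, 7, 12]),
}

buffer = io.BytesIO()  # in-memory file instead of a .mat file on disk
savemat(buffer, mat_dict)
buffer.seek(0)
loaded = loadmat(buffer, simplify_cells=True)

print(loaded["spiketimestamps"].shape)  # (3, 3)
```

Zero padding means that a spike at exactly 0 s cannot be distinguished from padding; padding with `np.nan` instead would avoid this, at the cost of handling NaNs on the MATLAB side.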

In [None]:
reload(spike_plotly)
Colours = spike_plotly.Colour_template()
colour_selection = Colours.select_preset_colour()
interact(Colours.pickstimcolour, selected_stimulus=colour_selection)

In [None]:
Colours.changed_selection()
stimulus_extr.Colours = Colours

In [None]:
stimulus_extr.spikes_psth_all()

In [None]:
array_select_spikes = single_stimulus.spikes_psth_all(array_select_spikes, stimulus_extr.trigger_complete,
                                     int(stimulus_extr.stimulus_info["Stimulus_repeat_logic"]),
                                     Test_stimulus.sampling_frequency[0])

In [None]:
reload(single_stimulus)
reload(stimulus_and_spikes)
Quality_df = single_stimulus.calculate_quality_index(array_select_spikes, stimulus_extr.trigger_complete,
                                     int(stimulus_extr.stimulus_info["Stimulus_repeat_logic"]),
                                     Test_stimulus.sampling_frequency[0])

In [None]:
Quality_overview = qgrid.show_grid(Quality_df)
Quality_overview

In [None]:
Quality_df.index.get_level_values(0)

In [None]:
np.where(Quality_df.index.get_level_values(0) == 28)

In [None]:
Quality_df.iloc[11]

In [None]:
stimulus_extr.spikes_stimulus = Quality_overview.get_changed_df()

In [None]:
reload(spike_plotly)
spike_plotly.plot_heatmap_new(Quality_overview.get_changed_df(), stimulus_extr.stimulus_info, Colours)

In [None]:
Quality_overview.get_changed_df()

In [None]:
output_w = widgets.Output(layout={'border': '1px solid black'})
output_w

In [None]:
stimulus_extr.define_output_window(output_w)
stimulus_extr.cells_df = Quality_overview.get_changed_df()
Quality_overview.on('selection_changed', stimulus_extr.plot_raster_whole_stimulus_from_grid_new)

In [None]:
stimulus_extr.plot_raster_whole_stimulus(cell_idx=28)

In [None]:
test = stimulus_extr.spikes_stimulus.reset_index()

In [None]:
reload(spike_plotly)
spike_plotly.plot_qc_locations(stimulus_extr.spikes_stimulus.reset_index())

In [None]:
with stimulus_extr.out_window:
    display(raster_plot)

In [None]:
stimulus_extr.plot_raster_whole_stimulus(233)

In [None]:
test = Quality_df["ISI_x"].loc[13].to_numpy()[0]

In [None]:
test

In [None]:
reload(spike_plotly)
cell_idx = 16  # cell to plot (index level 0 of spikes_stimulus)
repeat_logic = int(stimulus_extr.stimulus_info["Stimulus_repeat_logic"])
repeat_sublogic = int(stimulus_extr.stimulus_info["Stimulus_repeat_sublogic"])
spikes, spiketrains = stimulus_and_spikes.get_spikes_whole_stimulus(Quality_df, stimulus_extr.trigger_complete,
                                                                    cell_idx, repeat_logic,
                                                                    stimulus_extr.sampling_freq)
cell_df = stimulus_extr.spikes_stimulus.loc[cell_idx]
raster_plot = spike_plotly.plot_raster_whole_stimulus_new(cell_df, spiketrains, repeat_logic, repeat_sublogic,
                                                          stimulus_extr.Colours.axcolours,
                                                          stimulus_extr.Colours.LED_names)

In [None]:
raster_plot

In [None]:
cell_df["Gauss_average"].to_numpy()[0]