In [None]:
import datetime
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import tensorflow as tf
from tensorflow import keras


import keras.backend as K
import os
import sys
import h5py
import math
import csv
import pickle
import glob
import statistics
!pip install detecta
from detecta import detect_peaks



from scipy.signal import butter, filtfilt, lfilter, ellip


# MCS PyData tools
!pip install McsPyDataTools
import McsPy
import McsPy.McsData
from McsPy import ureg, Q_


from datetime import datetime
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras.layers import Dense, Input, LSTM, RepeatVector, TimeDistributed, Dropout
from keras.models import Model
from keras import callbacks
from keras.initializers import VarianceScaling
from sklearn.cluster import KMeans, AgglomerativeClustering, SpectralClustering, AffinityPropagation
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from keras.constraints import UnitNorm, Constraint
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score

# Short aliases for the clustering evaluation metrics imported above.
nmi = normalized_mutual_info_score
ari = adjusted_rand_score

# Colab-only: mount Google Drive at /content/drive (force_remount=True) so recordings stored there can be read.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

In [None]:
def creating_y_labels_from_parameters_file(parameters, classification_task):
    """
    Extract class labels from a parameters array for the requested classification task.

    Arguments:
    parameters -- 2d array with one row per training example. Column 6 holds the cell type
                  ('p' = pyramidal neuron, 'i' = interneuron); columns 8 and 9 hold the
                  excitatory / inhibitory postsynaptic connection values (0 = no connection).
    classification_task -- Desired classification task, one of ['PyramidalvsInter', 'ExcvsInh']

    Returns:
    y -- float array containing the label of every training example, shape = (#training examples, 1)
         'PyramidalvsInter': 0 = pyramidal, 1 = interneuron, 2 = neither
         'ExcvsInh': 0 = excitatory only, 1 = inhibitory only, 3 = both, 2 = neither

    Raises:
    ValueError -- if classification_task is not one of the supported values
    """
    n_examples = parameters.shape[0]
    y = np.empty(shape=(n_examples, 1))

    if classification_task == 'PyramidalvsInter':
        for i in range(n_examples):
            cell_type = parameters[i][6]
            if cell_type == 'p':  # pyramidal neuron
                y[i] = 0
            elif cell_type == 'i':  # interneuron
                y[i] = 1
            else:  # neither pyramidal nor interneuron
                y[i] = 2

    elif classification_task == 'ExcvsInh':
        for i in range(n_examples):
            excitatory = parameters[i][8] != 0
            inhibitory = parameters[i][9] != 0
            if excitatory and not inhibitory:  # excitatory postsynaptic connection only
                y[i] = 0
            elif inhibitory and not excitatory:  # inhibitory postsynaptic connection only
                y[i] = 1
            elif excitatory and inhibitory:  # both connection types
                y[i] = 3
            else:  # neither excitatory nor inhibitory postsynaptic connection
                y[i] = 2

    else:
        # Previously an unknown task fell through and crashed with UnboundLocalError on `y`.
        raise ValueError(f"Unknown classification_task {classification_task!r}; "
                         "valid options are 'PyramidalvsInter' and 'ExcvsInh'")

    return y

def isolate_maximum_electrode(X, length=32):
  """
  For every example, keep only the electrode segment that contains the largest-magnitude sample.

  Each row of X is treated as consecutive electrode segments of `length` samples each.
  The segment holding the sample with the largest absolute value is kept. Ties between
  the positive peak and the negative trough go to the positive peak.

  Arguments:
  X -- 2d array, shape = (#examples, #segments * length)
  length -- number of samples per electrode segment (default 32)

  Returns:
  X_maxelec -- 2d float array, shape = (#examples, length)
  """
  X_maxelec = np.empty(shape=(X.shape[0], length))
  for i in range(X.shape[0]):
    trace = X[i, :]
    maxind = np.argmax(trace)
    minind = np.argmin(trace)
    # Use the positive peak unless the negative trough is strictly larger in magnitude.
    idx = maxind if trace[maxind] >= np.absolute(trace[minind]) else minind
    # Segment index must be based on `length` (a hard-coded 32 breaks other segment sizes).
    elec = idx // length
    X_maxelec[i, :] = X[i, elec*length:(elec+1)*length]
  return X_maxelec

def normalization(Data_train, Data_dev, Data_test, epsilon = 0.001, alignment_training_examples = 'stacked_rows'):
  """
  Normalize the training, dev and test sets feature-wise using training-set statistics only.

  The per-feature mean and mean squared deviation ("variance") are computed on Data_train.
  They are then applied unchanged to Data_dev and Data_test.

  Arguments:
  Data_train -- 2d array containing all training examples (stack of 1d training traces)
  Data_dev -- 2d array containing all dev examples, same feature layout as Data_train
  Data_test -- 2d array containing all test examples, same feature layout as Data_train
  epsilon -- constant added to the divisor to avoid division by zero (default 0.001)
  alignment_training_examples -- how the examples are stacked:
      'stacked_rows' (default): every example is a row, i.e. data.shape = (#examples, #features)
      'stacked_columns': every example is a column, i.e. data.shape = (#features, #examples)

  Returns:
  Normalized_data_train, Normalized_Data_dev, Normalized_Data_test -- normalized float arrays

  Raises:
  ValueError -- if alignment_training_examples is not one of the valid options
  """
  # Axis along which the examples are stacked; the statistics are reduced over it.
  if alignment_training_examples == 'stacked_rows':
    example_axis = 0
  elif alignment_training_examples == 'stacked_columns':
    example_axis = 1
  else:
    raise ValueError(f'Please specify the stacking of your training examples in the argument alignment_training_examples to one of the valid options: stacked_rows, stacked_columns')

  # keepdims=True yields shape (1, #features) resp. (#features, 1) so the statistics broadcast.
  means = np.mean(Data_train, axis=example_axis, keepdims=True)
  centered_train = Data_train - means
  variances = np.mean(np.multiply(centered_train, centered_train), axis=example_axis, keepdims=True)

  # NOTE(review): this divides by the variance, not the standard deviation, so the output does
  # not have unit variance. Kept to preserve existing results; switching to np.sqrt(variances)
  # changes every downstream feature.
  divisor = variances + epsilon

  Normalized_data_train = centered_train / divisor
  Normalized_Data_dev = (Data_dev - means) / divisor
  Normalized_Data_test = (Data_test - means) / divisor

  return Normalized_data_train, Normalized_Data_dev, Normalized_Data_test

def normalization_train(Data_train, epsilon = 0.001, alignment_training_examples = 'stacked_rows'):
  """
  Normalize only the training set feature-wise with its own statistics.

  Subtracts the per-feature mean and divides by the per-feature mean squared deviation
  ("variance") plus epsilon, all computed on Data_train.

  Arguments:
  Data_train -- 2d array containing all training examples (stack of 1d training traces)
  epsilon -- constant added to the divisor to avoid division by zero (default 0.001)
  alignment_training_examples -- how the examples are stacked:
      'stacked_rows' (default): every example is a row, i.e. data.shape = (#examples, #features)
      'stacked_columns': every example is a column, i.e. data.shape = (#features, #examples)

  Returns:
  Normalized_data_train -- normalized float array with the same shape as Data_train

  Raises:
  ValueError -- if alignment_training_examples is not one of the valid options
  """
  # Axis along which the examples are stacked; the statistics are reduced over it.
  if alignment_training_examples == 'stacked_rows':
    example_axis = 0
  elif alignment_training_examples == 'stacked_columns':
    example_axis = 1
  else:
    raise ValueError(f'Please specify the stacking of your training examples in the argument alignment_training_examples to one of the valid options: stacked_rows, stacked_columns')

  # keepdims=True yields shape (1, #features) resp. (#features, 1) so the statistics broadcast.
  means = np.mean(Data_train, axis=example_axis, keepdims=True)
  centered_train = Data_train - means
  variances = np.mean(np.multiply(centered_train, centered_train), axis=example_axis, keepdims=True)

  # NOTE(review): divides by the variance, not the standard deviation (output is not unit-variance).
  # Kept to preserve existing results.
  Normalized_data_train = centered_train / (variances + epsilon)

  return Normalized_data_train

def Divide_train_dev_test(X, Y, fraction_list=[0.9, 0.05, 0.05], shuffle=False):
  """
  Split a dataset into consecutive training, development and test subsets.

  Arguments:
  X -- 2d array containing the dataset, one example per row
  Y -- 2d array of labels, shape = (#examples, 1)
  fraction_list -- fractions of the examples for [train, dev, test], e.g. [0.9, 0.05, 0.05].
                   Only the first two entries are read; the test set gets all remaining rows.
  shuffle -- if True, permute X and Y with the same random permutation (np.random global state)
             before splitting

  Returns:
  X_train, Y_train, X_dev, Y_dev, X_test, Y_test -- row slices of X and Y
  """
  if shuffle == True:
    assert X.shape[0] == Y.shape[0]
    order = np.random.permutation(X.shape[0])
    X, Y = X[order, :], Y[order, :]

  # Split boundaries are derived from the row count of X and applied to both arrays.
  n_rows = X.shape[0]
  train_stop = round(fraction_list[0] * n_rows)
  dev_stop = round((fraction_list[0] + fraction_list[1]) * n_rows)

  train_rows = slice(0, train_stop)
  dev_rows = slice(train_stop, dev_stop)
  test_rows = slice(dev_stop, None)

  return (X[train_rows, :], Y[train_rows, :],
          X[dev_rows, :], Y[dev_rows, :],
          X[test_rows, :], Y[test_rows, :])

def acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one mapping of clusters to labels.

    A contingency matrix w[predicted, true] is built. The Hungarian algorithm
    (scipy.optimize.linear_sum_assignment on w.max() - w) then picks the mapping that
    maximizes the number of correctly assigned samples. Requires scipy.

    # Arguments
        y_true: true labels, integer-valued numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, integer-valued numpy.array with shape `(n_samples,)`
    # Return
        accuracy, in [0,1]
    """
    from scipy.optimize import linear_sum_assignment

    true_labels = y_true.astype(np.int64)
    pred_labels = y_pred.astype(np.int64)
    assert pred_labels.size == true_labels.size

    n_labels = max(pred_labels.max(), true_labels.max()) + 1
    contingency = np.zeros((n_labels, n_labels), dtype=np.int64)
    # Unbuffered scatter-add: every (pred, true) occurrence is counted, including repeats.
    np.add.at(contingency, (pred_labels.ravel(), true_labels.ravel()), 1)

    # Minimizing (max - w) is equivalent to maximizing the matched counts.
    rows, cols = linear_sum_assignment(contingency.max() - contingency)
    return contingency[rows, cols].sum() * 1.0 / pred_labels.size

class ClusteringLayer(Layer):
    """
    Clustering layer converts input sample (feature) to soft label, i.e. a vector that represents the probability of the
    sample belonging to each cluster. The probability is calculated with Student's t-distribution.

    # Example
    ```
        model.add(ClusteringLayer(n_clusters=10))
    ```
    # Arguments
        n_clusters: number of clusters.
        weights: list of Numpy array with shape `(n_clusters, n_features)` which represents the initial cluster centers.
        alpha: degrees of freedom parameter in Student's t-distribution. Default to 1.0.
    # Input shape
        2D tensor with shape: `(n_samples, n_features)`.
    # Output shape
        2D tensor with shape: `(n_samples, n_clusters)`.
    """

    def __init__(self, n_clusters, weights=None, alpha=1.0, **kwargs):
        # Accept the legacy `input_dim` keyword by translating it to `input_shape`.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(ClusteringLayer, self).__init__(**kwargs)
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.initial_weights = weights
        # InputSpec is imported directly from tensorflow.keras.layers; no name `layers` is bound
        # in this notebook, so `layers.InputSpec` raised a NameError.
        self.input_spec = InputSpec(ndim=2)

    def build(self, input_shape):
        assert len(input_shape) == 2
        input_dim = input_shape[1]
        self.input_spec = InputSpec(dtype=K.floatx(), shape=(None, input_dim))
        # Trainable cluster centres, one row per cluster.
        self.clusters = self.add_weight(shape=(self.n_clusters, input_dim), initializer='glorot_uniform', name='clusters')
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def call(self, inputs, **kwargs):
        """ Student's t-distribution kernel, as used in the t-SNE algorithm.
         Measures the similarity between embedded point z_i and centroid µ_j:
                 q_ij = (1 + ||z_i - µ_j||^2 / alpha) ^ (-(alpha + 1) / 2), then normalized per sample.
                 q_ij can be interpreted as the probability of assigning sample i to cluster j
                 (i.e., a soft assignment).
        Arguments:
            inputs: the variable containing data, shape=(n_samples, n_features)
        Return:
            q: student's t-distribution, or soft labels for each sample. shape=(n_samples, n_clusters)
        """
        # Squared Euclidean distance of every sample to every centre, passed through the t kernel.
        q = 1.0 / (1.0 + (K.sum(K.square(K.expand_dims(inputs, axis=1) - self.clusters), axis=2) / self.alpha))
        q **= (self.alpha + 1.0) / 2.0
        # Make sure each sample's n_clusters values add up to 1.
        q = K.transpose(K.transpose(q) / K.sum(q, axis=1))
        return q

    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) == 2
        return input_shape[0], self.n_clusters

    def get_config(self):
        # alpha is serialized so a layer rebuilt via from_config keeps its degrees of freedom.
        config = {'n_clusters': self.n_clusters, 'alpha': self.alpha}
        base_config = super(ClusteringLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

def target_distribution(q):
  """
  Compute the auxiliary target distribution P from soft cluster assignments Q.

  Each squared assignment is divided by its cluster's total soft frequency (column sum of q).
  Every row is then renormalized to sum to 1.

  Arguments:
  q -- 2d numpy array of soft assignments, shape = (n_samples, n_clusters)

  Returns:
  p -- 2d numpy array with the same shape as q; each row sums to 1
  """
  cluster_frequency = q.sum(0)
  sharpened = q ** 2 / cluster_frequency
  row_totals = sharpened.sum(1)
  return sharpened / row_totals[:, None]





def gradient_transform(X, sr = 24000):
  """
  Finite-difference gradient of every trace, per microsecond.

  Arguments:
  X -- 2d array, one trace per row, shape = (#examples, #samples)
  sr -- sampling rate in Hz (default 24000)

  Returns:
  X_grad -- 2d float64 array, shape = (#examples, #samples - 1), with
            X_grad[i, j] = (X[i, j+1] - X[i, j]) / time_step, where time_step is the
            sampling interval in microseconds
  """
  time_step = 1 / sr * 10**6 #in us
  # Vectorized forward difference along the sample axis. Casting to float64 first gives the
  # float64 result the element-wise loop produced and avoids integer arithmetic on int input.
  return np.diff(np.asarray(X, dtype=np.float64), axis=1) / time_step

# Plot colour palette: 0-255 RGB triples scaled to the [0, 1] range, which is the format
# matplotlib accepts. Despite the name, it holds 10 colours.
tableau20 = [(r / 255., g / 255., b / 255.) for (r, g, b) in [
  (31, 119, 180), (255, 127, 14), (44, 160, 44), (214, 39, 40), (148, 103, 189),
  (140, 86, 75), (227, 119, 194), (127, 127, 127), (188, 189, 34), (23, 190, 207)]]


def autoencoder(dims, act='relu', init='glorot_uniform'):
    """
    Build a symmetric, fully connected auto-encoder and its encoder half.

    Arguments:
        dims: list of layer widths of the encoder. dims[0] is the input dimension and dims[-1]
            the width of the embedding (hidden) layer. The decoder mirrors the encoder, so the
            auto-encoder has 2*len(dims)-1 layers in total.
        act: activation of the internal encoder/decoder layers; the input, embedding and
            output layers have no activation.
        init: kernel initializer for every Dense layer.
    Returns:
        (ae_model, encoder_model): the full auto-encoder model ('AE') and the encoder model
        ('encoder') that maps the input to the embedding.
    """
    n_stacks = len(dims) - 1
    input_img = keras.Input(shape=(dims[0],), name='input')

    # Encoder: internal layers, then the linear embedding layer features are taken from.
    hidden = input_img
    for layer_idx in range(n_stacks - 1):
        hidden = Dense(dims[layer_idx + 1], activation=act, kernel_initializer=init,
                       name='encoder_%d' % layer_idx)(hidden)
    encoded = Dense(dims[-1], kernel_initializer=init, name='encoder_%d' % (n_stacks - 1))(hidden)

    # Decoder: mirror of the encoder, ending in a linear reconstruction layer.
    hidden = encoded
    for layer_idx in reversed(range(1, n_stacks)):
        hidden = Dense(dims[layer_idx], activation=act, kernel_initializer=init,
                       name='decoder_%d' % layer_idx)(hidden)
    decoded = Dense(dims[0], kernel_initializer=init, name='decoder_0')(hidden)

    ae_model = keras.Model(inputs=input_img, outputs=decoded, name='AE')
    encoder_model = keras.Model(inputs=input_img, outputs=encoded, name='encoder')
    return ae_model, encoder_model


#Functions
def get_channel_data(data_stream, channel_ids = None):
  """
  Return the offset-corrected, scaled signal of one, several or all channels of an analog stream.

  Raw values are shifted by the 'ADZero' field and multiplied by adc_step.magnitude. Both are
  read from channel_infos[0], so all requested channels are assumed to share the same ADC
  settings (TODO confirm for mixed streams). Callers in this notebook treat the result as volts
  before converting it to µV.

  :param data_stream: analog stream exposing channel_data (rows = channels) and dict-like
                      channel_infos keyed by channel id
  :param channel_ids: None or [] (default) for all channels; a single channel id for one channel
                      (1d result); or a sequence of channel ids (2d result, rows in the given order)
  :return: numpy array of shape (#channels, #samples), or (#samples,) for a single channel id
  """
  all_channels = channel_ids is None or (np.ndim(channel_ids) > 0 and len(channel_ids) == 0)
  if all_channels:
    signal = data_stream.channel_data[:, :]
  elif np.ndim(channel_ids) == 0:
    signal = data_stream.channel_data[data_stream.channel_infos[channel_ids].row_index, :]
  else:
    # channel_infos is keyed by individual channel ids, so a list cannot be used as a key
    # (that raised TypeError). Read each row separately; this also avoids the
    # increasing-index requirement of fancy indexing if channel_data is an h5py dataset.
    rows = [data_stream.channel_infos[cid].row_index for cid in channel_ids]
    signal = np.stack([data_stream.channel_data[row, :] for row in rows])

  reference_info = data_stream.channel_infos[0]
  scale = reference_info.adc_step.magnitude
  signal_corrected = (signal - reference_info.get_field('ADZero')) * scale
  return signal_corrected

def plot_analog_stream_channel(analog_stream, channel_idx, from_in_s=0, to_in_s=None, show=True):
    """
    Plot the signal of one AnalogStream channel in µV against time in seconds.

    :param analog_stream: A AnalogStream object
    :param channel_idx: position of the channel in channel_infos (0 <= channel_idx < # channels)
    :param from_in_s: start of the plotted window in s (clamped at 0). Default: 0
    :param to_in_s: end of the plotted window in s (clamped at the recording end).
                    Default: None (= recording duration)
    :param show: If True (default), plt.show() is called. For further plotting, use show=False
    """
    infos = analog_stream.channel_infos
    channel_id = [info.channel_id for info in infos.values()][channel_idx]
    channel_info = infos[channel_id]
    fs = channel_info.sampling_frequency.magnitude

    # Convert the requested window from seconds to sample indices.
    n_samples = analog_stream.channel_data.shape[1]
    start = max(0, int(from_in_s * fs))
    stop = n_samples if to_in_s is None else min(n_samples, int(to_in_s * fs))

    # Timestamps come back as (values, unit); rescale the values to seconds.
    time = analog_stream.get_channel_sample_timestamps(channel_id, start, stop)
    seconds_per_unit = Q_(1, time[1]).to(ureg.s).magnitude
    time_in_sec = time[0] * seconds_per_unit

    # Signal also comes back as (values, unit); rescale the values to µV.
    signal = analog_stream.get_channel_in_range(channel_id, start, stop)
    uV_per_unit = Q_(1, signal[1]).to(ureg.uV).magnitude
    signal_in_uV = signal[0] * uV_per_unit

    plt.figure(figsize=(20,6))
    plt.plot(time_in_sec, signal_in_uV)
    plt.xlabel('Time (%s)' % ureg.s)
    plt.ylabel('Voltage (%s)' % ureg.uV)
    plt.title('Channel %s' % channel_info.info['Label'])
    if show:
        plt.show()

def butter_bandpass(lowcut, highcut, fs, order=2):
    """
    Design a Butterworth band-pass filter.

    Arguments:
    lowcut, highcut -- pass-band edges in Hz
    fs -- sampling rate in Hz
    order -- filter order (default 2)

    Returns:
    (b, a) -- numerator / denominator coefficients (cut-offs normalized to the Nyquist frequency)
    """
    nyquist = fs / 2
    return butter(order, [lowcut / nyquist, highcut / nyquist], btype='bandpass')

def butter_bandpass_high(lowcut, fs, order=5):
    """
    Design a Butterworth high-pass filter (despite the name, btype is 'highpass').

    Arguments:
    lowcut -- cut-off frequency in Hz
    fs -- sampling rate in Hz
    order -- filter order (default 5)

    Returns:
    (b, a) -- numerator / denominator coefficients (cut-off normalized to the Nyquist frequency)
    """
    normalized_cutoff = lowcut / (0.5 * fs)
    return butter(order, normalized_cutoff, btype='highpass')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=2):
    """
    Band-pass filter `data` along axis 0 with a Butterworth filter from butter_bandpass.

    lfilter runs forward only (no zero-phase forward-backward pass as filtfilt would apply).

    Arguments:
    data -- array to filter; samples run along axis 0
    lowcut, highcut -- pass-band edges in Hz
    fs -- sampling rate in Hz
    order -- filter order (default 2)

    Returns:
    filtered array with the same shape as data
    """
    coeffs_b, coeffs_a = butter_bandpass(lowcut, highcut, fs, order=order)
    filtered = lfilter(coeffs_b, coeffs_a, data, axis=0)
    return filtered

def butter_bandpass_filter_high(data, lowcut, fs, order=2):
    """
    High-pass filter `data` along axis 0 with a Butterworth filter from butter_bandpass_high.

    lfilter runs forward only (no zero-phase forward-backward pass as filtfilt would apply).

    Arguments:
    data -- array to filter; samples run along axis 0
    lowcut -- cut-off frequency in Hz
    fs -- sampling rate in Hz
    order -- filter order (default 2)

    Returns:
    filtered array with the same shape as data
    """
    coeffs_b, coeffs_a = butter_bandpass_high(lowcut, fs, order=order)
    filtered = lfilter(coeffs_b, coeffs_a, data, axis=0)
    return filtered


def ellip_filter(data, lowcut, highcut, fs, order=5):
    """
    Zero-phase elliptic band-pass filter along axis 0.

    Design: ellip(order, rp=0.1 dB pass-band ripple, rs=40 dB stop-band attenuation).
    Applied forward and backward with filtfilt using odd extension of
    3 * (max(len(b), len(a)) - 1) samples at each end.

    Arguments:
    data -- array to filter; samples run along axis 0
    lowcut, highcut -- pass-band edges in Hz
    fs -- sampling rate in Hz
    order -- filter order (default 5)

    Returns:
    filtered array with the same shape as data
    """
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]
    b, a = ellip(order, 0.1, 40, band, 'bandpass')
    pad = 3 * (max(len(b), len(a)) - 1)
    return filtfilt(b, a, data, axis=0, padtype='odd', padlen=pad)

def ellip_filter_high(data, lowcut, fs, order=5):
    """
    Zero-phase elliptic high-pass filter along axis 0.

    Design: ellip(order, rp=0.1 dB pass-band ripple, rs=40 dB stop-band attenuation).
    Applied forward and backward with filtfilt using odd extension of
    3 * (max(len(b), len(a)) - 1) samples at each end.

    Arguments:
    data -- array to filter; samples run along axis 0
    lowcut -- cut-off frequency in Hz
    fs -- sampling rate in Hz
    order -- filter order (default 5)

    Returns:
    filtered array with the same shape as data
    """
    normalized_cutoff = lowcut / (0.5 * fs)
    b, a = ellip(order, 0.1, 40, normalized_cutoff, 'highpass')
    pad = 3 * (max(len(b), len(a)) - 1)
    return filtfilt(b, a, data, axis=0, padtype='odd', padlen=pad)



def spike_binning(X, time_window=0.1, fsample=20000, recording_len=None):
  """
  Count the spikes of every electrode in consecutive time bins.

  Inputs:
  X: Standard array of shape Number of spikes x (2 + recording points per spike), e.g. 2 + 64.
     Column 0 is the electrode the spike was recorded on and column 1 the spike time in samples.
     The remaining columns (spike shape) are ignored here.
  time_window: bin width in seconds, default: 0.1 s
  fsample: sampling frequency in Hz (converts spike times from samples to seconds)
  recording_len: total length of the recording in s. If None, the last spike time is rounded
     up to the next full second and used instead.

  Outputs:
  binned_spike_times: array of shape (number of electrodes, recording_len / time_window).
     Row k belongs to the k-th electrode of np.unique(X[:, 0]). Column j counts the spikes with
     time t (in samples) satisfying j*w < t <= (j+1)*w, where w = time_window * fsample.
  """
  #If no recording_len is explicitly given, take the last recorded spike time and round up to generate an artificial recording length
  if recording_len is None:
    print("Make sure the sampling frequency is correct. Using fsample = ", fsample)
    max_spike_time = np.max(X[:,1])
    recording_len = np.ceil(max_spike_time/fsample)

  electrodes = X[:, 0]
  spike_times = X[:, 1]
  electrode_list = np.unique(electrodes)

  # Bin width in samples. It must scale with time_window; using plain `fsample` made every bin
  # 1 s wide while the bin count still assumed time_window-sized bins.
  samples_per_bin = time_window * fsample
  # Small tolerance so e.g. 0.3 s / 0.1 s (2.9999999999999996 in floating point) yields 3 bins.
  n_bins = int(np.floor(recording_len / time_window + 1e-9))
  bin_edges = np.arange(n_bins + 1) * samples_per_bin

  binned_spike_times = np.zeros(shape=(len(electrode_list), n_bins))
  for row, electrode in enumerate(electrode_list):
    times = np.sort(spike_times[electrodes == electrode])
    # Number of spikes at or before each edge; consecutive differences are the counts in
    # the half-open intervals (edge_j, edge_j+1].
    n_at_or_before = np.searchsorted(times, bin_edges, side='right')
    binned_spike_times[row, :] = np.diff(n_at_or_before)

  return binned_spike_times


def binary_mutual_information(X, Y, quantile=75):
  """
  Mutual information (in bits) between the binarized activity of two electrodes.

  Inputs:
  X,Y: binned spike count arrays of two electrodes (e.g. two rows of the spike_binning output)
  quantile: a time bin counts as active (1) when its spike count exceeds the given percentile
     (default 75) of all bins of the same electrode, otherwise as inactive (0)

  Outputs:
  MI: mutual information score for the two electrodes. Empty cells of the 2x2 contingency
     table are set to 1 before the computation to avoid log(0) / NaN.
  """
  x_active = X > np.percentile(X, quantile)
  y_active = Y > np.percentile(Y, quantile)

  # 2x2 contingency table: rows = X inactive/active, columns = Y inactive/active.
  counts = np.array([
    [np.count_nonzero(~x_active & ~y_active), np.count_nonzero(~x_active & y_active)],
    [np.count_nonzero(x_active & ~y_active), np.count_nonzero(x_active & y_active)],
  ])
  counts[counts == 0] = 1

  total = np.sum(counts)
  MI = 0
  for row in range(counts.shape[0]):
    p_x = np.sum(counts[row, :]) / total
    for col in range(counts.shape[1]):
      p_y = np.sum(counts[:, col]) / total
      p_xy = counts[row, col] / total
      MI = MI + (p_xy * math.log2(p_xy / (p_x * p_y)))

  return MI

def get_valid_electrode_pairs(X, max_dist=300, electrode_layout="MCS standard"):
  """
  Find all pairs of recorded electrodes that lie within max_dist of each other.

  Inputs:
  X: Standard array of shape Number of spikes x (2 + recording points per spike); column 0
     holds the electrode the spike was recorded on
  max_dist: maximum Euclidean distance (in um) between two electrodes of a valid pair
  electrode_layout: spatial layout of the electrodes. Currently only the MultiChannelSystems
     60-electrode standard layout ("MCS standard", 200 um spacing) is supported.

  Output:
  electrode_pairs: list of valid electrode pairs (tuples of electrode ids)
  electrode_pairs_idx: list of the corresponding index pairs into np.unique(X[:, 0]) (0-based, sorted)

  Raises:
  ValueError: for an unsupported layout, or if an electrode id is not part of the layout
  """
  if electrode_layout != "MCS standard":
    # Only printing here used to crash later with a NameError on the undefined layout.
    raise ValueError(f"Currently only the MCS standard layout is supported, got {electrode_layout!r}")

  inter_electrode_distance = 200  # um between neighbouring grid positions
  # 8x8 grid of electrode ids; 0 marks the four unused corners.
  MCS_layout = np.array([[0,21,31,41,51,61,71,0],
                    [12,22,32,42,52,62,72,82],
                    [13,23,33,43,53,63,73,83],
                    [14,24,34,44,54,64,74,84],
                    [15,25,35,45,55,65,75,85],
                    [16,26,36,46,56,66,76,86],
                    [17,27,37,47,57,67,77,87],
                    [0,28,38,48,58,68,78,0]])

  electrodes = np.unique(X[:,0])

  # Grid position (row, column) of every recorded electrode, looked up once.
  positions = []
  for electrode in electrodes:
    rows, cols = np.where(MCS_layout == electrode)
    if len(rows) == 0:
      raise ValueError(f"Electrode {electrode} is not part of the {electrode_layout} layout")
    positions.append((rows[0], cols[0]))

  # Keep every pair (i < j) within max_dist; indices i, j are positions in `electrodes`.
  electrode_pairs = []
  electrode_pairs_idx = []
  for i in range(len(electrodes)):
    for j in range(i + 1, len(electrodes)):
      row_dist = abs(positions[i][0] - positions[j][0])
      col_dist = abs(positions[i][1] - positions[j][1])
      euc_dist = np.sqrt((row_dist*inter_electrode_distance)**2 + (col_dist*inter_electrode_distance)**2)
      if euc_dist <= max_dist:
        electrode_pairs.append((electrodes[i], electrodes[j]))
        electrode_pairs_idx.append((i, j))

  return electrode_pairs, electrode_pairs_idx


def calculate_MI_pairs(X, time_window=0.1, fsample=20000, recording_len=None, max_dist=300, electrode_layout="MCS standard", quantile=75):
  """
  Mutual information of every valid (close enough) electrode pair of one recording.

  Inputs:
  X: Standard array of shape Number of spikes x (2 + recording points per spike); column 0 is the
     electrode, column 1 the spike time in samples
  time_window: bin width for spike binning, default: 0.1 s
  fsample: sampling frequency in Hz
  recording_len: total length of the recording (in s); None derives it from the last spike
  max_dist: maximum distance (in um) between electrodes of a valid pair
  electrode_layout: spatial layout of the electrodes (currently only "MCS standard")
  quantile: percentile above which a bin counts as active, see binary_mutual_information

  Outputs:
  electrode_pairs: list of valid electrode pairs (tuples of electrode ids)
  MI_arr: float array with the mutual information of each pair, in the order of electrode_pairs
  """
  binned = spike_binning(X, time_window=time_window, fsample=fsample, recording_len=recording_len)
  electrode_pairs, pair_rows = get_valid_electrode_pairs(X, max_dist=max_dist, electrode_layout=electrode_layout)

  # Rows of `binned` follow np.unique(X[:, 0]), the same ordering pair_rows indexes into.
  MI_arr = np.array(
    [binary_mutual_information(binned[first, :], binned[second, :], quantile=quantile)
     for first, second in pair_rows],
    dtype=float)

  return electrode_pairs, MI_arr


def _plot_condition_errorbar(positions, values, errors, condition_names, ylabel, title):
  """Errorbar plot of one value per condition (shared by the absolute and normalized MI plots)."""
  color = (31 / 255., 119 / 255., 180 / 255.)  # first tableau colour, scaled to [0, 1]
  fig = plt.figure()
  ax = fig.add_axes([1,1,1,1])
  ax.errorbar(positions, values, yerr=errors, capsize=8, mfc=color, mec=color, ecolor=color)
  ax.set_xticks(positions)
  ax.set_xlim(positions[0] - 0.5, positions[-1] + 0.5)
  ax.set_xticklabels(condition_names)
  ax.set_xlabel('Conditions')
  ax.set_ylabel(ylabel)
  ax.set_title(title)
  plt.grid(color=(0.9, 0.9, 0.9),linestyle="-", linewidth=1)
  plt.show()


def do_MI_analysis_across_recordings(data_list = [], condition_names=[], analysis_name=None, time_window=0.1, fsample=20000, max_dist=300, electrode_layout="MCS standard", quantile=75, plot=True, plot_violin=True, normalize=False):
  """
  Compute pairwise electrode mutual information (MI) for several recordings and optionally plot it.

  Inputs:
  data_list: list of dicts, one per condition, with keys 'Raw_spikes' (standard spike array of
     shape Number of spikes x (2 + recording points per spike)), 'Sampling rate' (Hz) and
     'Recording len' (s)
  condition_names: condition labels for the x-axis, e.g. condition_names=["Day 0", "Day 1",...]
  analysis_name: text appended to the plot titles (None -> nothing appended)
  time_window: bin width for spike binning, default: 0.1 s
  fsample: not used; each recording's own dat['Sampling rate'] is used instead (kept for compatibility)
  max_dist: maximum distance (in um) between electrodes of a valid pair
  electrode_layout: spatial layout of the electrodes (currently only "MCS standard")
  quantile: percentile above which a bin counts as active, see binary_mutual_information
  plot: plot mean MI +- standard error per condition
  plot_violin: plot the MI distribution of each condition as a violin
  normalize: if plot is True, additionally plot the mean MI relative to the first condition (in %)

  Outputs (tuple):
  condition_names, MI_mean_arr (mean MI per condition), MI_SE (standard error of the mean per
  condition; NaN if fewer than 2 pairs), electrode_pairs (valid pairs of the LAST recording),
  MI_arr_list (per-condition arrays of pairwise MI)

  Raises:
  ValueError: if data_list is empty
  """
  if len(data_list) == 0:
    raise ValueError("data_list must contain at least one recording")

  # analysis_name defaults to None, and "..." + None would raise a TypeError in the titles.
  title_suffix = "" if analysis_name is None else str(analysis_name)

  MI_arr_list = list()
  for dat in data_list:
    electrode_pairs, MI_arr = calculate_MI_pairs(dat['Raw_spikes'], time_window=time_window, fsample=dat['Sampling rate'], recording_len=dat['Recording len'], max_dist=max_dist, electrode_layout=electrode_layout, quantile=quantile)
    MI_arr_list.append(MI_arr)

  positions = list(range(1, len(MI_arr_list) + 1))

  if plot_violin:
    fig = plt.figure()
    ax = fig.add_axes([1,1,1,1])
    parts = ax.violinplot(MI_arr_list, positions=positions, showmeans=True)
    for pc in parts['bodies']:
      pc.set_edgecolor('black')
      pc.set_alpha(0.3)
    ax.set_xticks(positions)
    ax.set_xlim(positions[0] - 0.5, positions[-1] + 0.5)
    ax.set_xticklabels(condition_names)
    ax.set_xlabel('Conditions')
    ax.set_ylabel('Mutual information in [bits]')
    ax.set_title("Mutual Information Analysis: " + title_suffix)
    plt.show()

  MI_mean_arr = np.array([element.mean() for element in MI_arr_list])
  # statistics.stdev needs at least two values; report NaN instead of raising.
  MI_SE = np.array([statistics.stdev(element) / np.sqrt(len(element)) if len(element) > 1 else np.nan
                    for element in MI_arr_list])

  if plot:
    _plot_condition_errorbar(positions, MI_mean_arr, MI_SE, condition_names,
                             'Mutual information in [bits]',
                             "Mutual Information Analysis: " + title_suffix)
    if normalize:
      # Express everything relative to the first condition (usually the baseline).
      _plot_condition_errorbar(positions, MI_mean_arr/MI_mean_arr[0]*100, MI_SE/MI_mean_arr[0]*100,
                               condition_names, 'Normed Mutual information in [%]',
                               "Relative change in Mutual Information Analysis: " + title_suffix)

  return condition_names, MI_mean_arr, MI_SE, electrode_pairs, MI_arr_list


**1. Filtering**

In [None]:
#load data
file_name = 'name'
path = 'path' #add path (directory containing the .h5 recording)
file_path = os.path.join(path, file_name + '.h5')
os.chdir(path)

big_file = False #In case the file is too big for RAM, set to True to load it in chunks of `step_size` channels
including_stimulation = False #Set to True if the recording contains stimulation events (read from event stream 0, entity 1)

file = McsPy.McsData.RawData(file_path)
electrode_stream = file.recordings[0].analog_streams[0]
if including_stimulation:
  event_stream = file.recordings[0].event_streams[0].event_entity[1]
  all_events = event_stream.get_events()
  all_event_timestamps = event_stream.get_event_timestamps()[0]

# extract basic information
ids = [c.channel_id for c in electrode_stream.channel_infos.values()]
sr = int(electrode_stream.channel_infos[0].sampling_frequency.magnitude)

#Get signal, converted from V to uV; recording_data has shape (#samples, #channels)
scale_factor_for_uV = Q_(1,'volt').to(ureg.uV).magnitude
if big_file:
  step_size = 10
  min_step = 0
  max_step = 60
  for i in range(min_step, max_step, step_size):
    chunk = (get_channel_data(electrode_stream, channel_ids = [j for j in range(i,i+step_size)]) * scale_factor_for_uV).T
    if i == min_step:
      recording_data = chunk
    else:
      recording_data = np.concatenate((recording_data, chunk), axis=1)
    del chunk
    print("iteration", i+step_size, "completed")
    print("recording.shape:", recording_data.shape)

else:
  # Load the raw signal once and reuse it (it was previously read from the file twice).
  signal = get_channel_data(electrode_stream, channel_ids = [])
  recording_data = (signal * scale_factor_for_uV).T


#Filtering
# Sample rate and desired cutoff frequencies (in Hz).
lowcut = 300
highcut = 3000
fsample = sr

# Column-by-column filtering keeps peak memory low for long recordings.
filtered = np.empty(shape=(recording_data.shape[0], recording_data.shape[1]))
for i in range(recording_data.shape[1]):
  filtered[:,i] = butter_bandpass_filter(recording_data[:,i], lowcut, highcut, fsample, order=2)

del recording_data




**2. Spike detection**

In [None]:
reject_ch = [None] # channel numbers (last two label characters) to skip; [None] if there are no damaged channels

length_of_chunck = 360 # Divide the recording into chunks (in s) so the threshold adapts over time
no_chunks = math.ceil(filtered.shape[0]/fsample/length_of_chunck) # number of chunks incl. a trailing partial one (e.g. 60.5 s in 1 s chunks -> 61)
print("Number of chunks:", no_chunks)
spk = list()
spike_times = [[0] for x in range(filtered.shape[1])] # per-column spike times (samples); the leading 0 seeds the 2 ms check
ids = [c.channel_id for c in electrode_stream.channel_infos.values()]


points_pre = 20   # waveform samples kept before the detected minimum
points_post = 44  # waveform samples kept from the minimum onwards

overlap = 20 # extra samples appended to every full chunk (they only enter the median noise estimate)
t = 0 # chunk counter, starting with the first chunk
chunk_samples = fsample*length_of_chunck

# Channel number (last two characters of the MCS label) for every column of `filtered`, looked up once
# instead of on every threshold crossing. With big_file the columns start at channel index min_step.
channel_offset = min_step if big_file else 0
channel_numbers = [int(electrode_stream.channel_infos[ids[e + channel_offset]].info['Label'][-2:])
                   for e in range(filtered.shape[1])]

# NOTE(review): `index` never reaches the last points_post samples of a chunk nor the first points_pre+1
# samples of the next chunk, so spikes lying exactly on a chunk boundary are not detected.
while t < no_chunks: # the trailing partial chunk is processed as well (else-branch)
  chunk_start = t*chunk_samples
  if (t+1)*chunk_samples <= len(filtered):
    chunk = filtered[chunk_start:((t+1)*chunk_samples + overlap - 1)]
  else:
    chunk = filtered[chunk_start:]
  # per-channel noise estimate: median(|x|)/0.6745
  med = np.median(np.absolute(chunk)/0.6745, axis=0)

  for index in range(len(chunk)):
    if index > points_pre and index < chunk_samples-points_post:
      # candidate: below -5*noise (threshold) but not below -30*noise (artefact rejection)
      threshold_cross = chunk[index, :] < -5*med
      threshold_arti = chunk[index, :] > -30*med
      probable_spike = threshold_cross*threshold_arti

      if np.sum(probable_spike > 0):
        spike_time = chunk_start + index # sample index within the whole recording
        for e in range(filtered.shape[1]):
          ch = channel_numbers[e]
          if probable_spike[e] == 1 and not(ch in reject_ch): # threshold crossed on a non-rejected channel
            t_diff = spike_time - spike_times[e][-1]
            # accept only if >2 ms after this channel's previous spike and it is the minimum of its window
            if t_diff > 0.002*fsample and chunk[index, e] == np.min(chunk[(index-points_pre):(index+points_post), e]):
              spike_times[e].append(spike_time)
              # keep the waveform only if all points_pre + points_post samples lie inside the recording
              if (spike_time + points_post) < filtered.shape[0]:
                spk_wave = list(filtered[(spike_time - points_pre):(spike_time + points_post), e])
                spk_wave.insert(0, spike_time)
                spk_wave.insert(0, ch)
                spk.append(spk_wave) # row layout: [channel, spike time, waveform samples...]

  t = t+1
  if t%5 == 0:
    print("Completed chunk", t)

print("Total number of detected spikes:", len(spk))
# 'w' so that re-running the cell overwrites the file instead of appending duplicate spikes;
# the second (min_step > 0) pass of a big_file run still appends to the first pass's file.
write_mode = 'a' if big_file and min_step > 0 else 'w'
with open(path + '/Spikes_' + file_name + '.txt', write_mode, newline='\n') as f:
  csv.writer(f, delimiter=" ").writerows(spk)

In [None]:
data_path = path + '/Spikes_' + file_name + '.txt'
# atleast_2d: a file holding a single spike would otherwise load as a 1-D row and break X[:, 1]
X = np.atleast_2d(np.genfromtxt(data_path, delimiter=' '))

Results = {
  "Filename": file_name,
  "Sampling rate": fsample,
  "Recording len": filtered.shape[0] / fsample, # in seconds
  # rows ordered by spike time (column 1); a stable sort keeps simultaneous spikes in file (channel) order
  "Raw_spikes": X[np.argsort(X[:, 1], kind='stable')],
}

#save
with open(path + '/Spike_File_' + file_name + '.pkl', 'wb') as f:
    pickle.dump(Results, f, -1)

**3. Analyse Spike data**

In [None]:
# Per-recording pickles are read from <path>/Spike_File_<name>.pkl, one per entry of file_names
# (the spike-storing cell writes <path>/Spike_File_<file_name>.pkl).
file_path = path + '/Spike_File_'
file_names = ['Day 1', 'Day 2', 'Day 3', 'Day 4']
dict_list = list()
recording_lens = []
pseudo_cond_names = ["Day 1 / Baseline", "Day 2", "Day 3", "Day 4"]
analysis_name = "Control_graphene"
sr = []
#open files (pickle.load can execute code: only open spike files you produced yourself)
for name in file_names:
  with open(file_path + name + '.pkl', 'rb') as f:
    recording = pickle.load(f)
  dict_list.append(recording)
  recording_lens.append(recording["Recording len"])
  sr.append(int(recording['Sampling rate']))
n_recordings = len(dict_list)

# sorted union of the channels that fired in any recording (column 0 of Raw_spikes)
electrode_list = list(np.unique(np.concatenate([rec['Raw_spikes'][:, 0] for rec in dict_list])))

#Calculate overall firing rates
# res[k, j]: spikes per second of electrode electrode_list[k] in recording j
res = np.zeros(shape=(len(electrode_list), n_recordings))
for k, el in enumerate(electrode_list):
  for j, rec in enumerate(dict_list):
    res[k, j] = np.count_nonzero(rec['Raw_spikes'][:, 0] == el)
res = res / np.array(recording_lens)
# total spikes per second of each recording, shape (1, n_recordings)
rates = np.array([[rec['Raw_spikes'].shape[0] / rec_len for rec, rec_len in zip(dict_list, recording_lens)]])
print("Firing rates:", rates)
rates = (rates / rates[0,0] - 1)*100
print("Relative change in firing rate vs. Day 1:", rates)


**Synchronicity analysis**

In [None]:
dt = 0.1
max_dist = 300

# Use the sampling rate stored with the recordings (list `sr`) instead of a hard-coded 20 kHz;
# a single fsample is passed for all recordings, so they must agree.
if len(set(sr)) != 1:
  raise ValueError("Recordings have different (or no) sampling rates: %s" % (sr,))
condition_names, MI_mean_arr, MI_SE, electrode_pairs, MI_arr_list = do_MI_analysis_across_recordings(data_list = dict_list, condition_names=pseudo_cond_names, analysis_name=analysis_name, plot=True, normalize=False, time_window=dt, fsample=sr[0], max_dist=max_dist, electrode_layout="MCS standard", quantile=75)

**Spike sorting: autoencoder pre-training and clustering**

In [None]:
#Data pre-processing: stack the spikes of all recordings into one matrix (built once, not row-by-row).
# Row layout: [recording index, channel, spike time, waveform samples...]
X = np.concatenate(
  [np.column_stack((np.full(len(rec['Raw_spikes']), p, dtype=float), rec['Raw_spikes']))
   for p, rec in enumerate(dict_list)],
  axis=0)

cond = X[:,0]
elecs = X[:,1]
spike_times = X[:,2]
X = X[:,3:]

#Pre-processing
X_grad = gradient_transform(X)
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_grad)

del X_grad

#initialize training parameters
file_name = "test" # NOTE(review): overwrites the recording name used by the filtering/spike-detection cells
dims = [X_train.shape[-1], 500, 500, 2000, 10]
pretrain_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
pretrain_epochs = 100
batch_size = 4096
save_dir = save_path # NOTE(review): save_path must be defined in an earlier cell
# Keras 3 `save_weights` requires the `.weights.h5` suffix; older tf.keras also writes HDF5 for it
name_save_process = "test" + '.weights.h5'

#initialize & train
auto, encoder = autoencoder(dims)
auto.compile(optimizer=pretrain_optimizer, loss='mse')
cb = tf.keras.callbacks.EarlyStopping(monitor="loss", min_delta=0.0001, patience=10)
history = auto.fit(X_train, X_train, batch_size=batch_size, epochs=pretrain_epochs, callbacks=[cb])
# os.path.join adds the separator that plain string concatenation would miss
auto.save_weights(os.path.join(save_dir, name_save_process))

(331339, 64)


In [None]:
#Cluster the encoder embedding into an arbitrary number of classes for artefact detection
n_clusters = 20
KMEANS_SEED = 0 # fixed seed so the cluster labels are reproducible between runs
kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=KMEANS_SEED)
y_pred = kmeans.fit_predict(encoder.predict(X_train))
# one pass over y_pred; also correct if some label were missing
cluster_labels, cluster_sizes = np.unique(y_pred, return_counts=True)
print(cluster_labels)
for label, size in zip(cluster_labels, cluster_sizes):
  print("Cluster", label, ":", size)

In [None]:
# NOTE(review): repeats the clustering of the previous cell; with a fixed seed, re-running it on the
# same embedding reproduces the same labels instead of silently reshuffling y_pred.
n_clusters = 20
KMEANS_SEED = 0 # fixed seed so the cluster labels are reproducible between runs
kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=KMEANS_SEED)
y_pred = kmeans.fit_predict(encoder.predict(X_train))
cluster_labels, cluster_sizes = np.unique(y_pred, return_counts=True)
print(cluster_labels)
for label, size in zip(cluster_labels, cluster_sizes):
  print("Cluster", label, ":", size)


In [None]:
#plot spike classes: mean waveform +/- one standard deviation, one figure per class
fs_khz = dict_list[0]['Sampling rate'] / 1000 # samples per millisecond
n_wave_samples = X.shape[1]
# sample k lies k/fs after the window start (linspace(0, 3.2, 64) stretched the spacing)
t = np.arange(n_wave_samples) / fs_khz
for j in np.unique(y_pred):
  class_spikes = X[y_pred == j, :]
  mean_wave = np.mean(class_spikes, axis=0)
  sd_wave = np.std(class_spikes, axis=0)
  colour = tableau20[j % 10]

  fig, ax = plt.subplots()
  ax.plot(t, mean_wave, color=colour, label='Class %d' % j)
  ax.fill_between(t, mean_wave - sd_wave, mean_wave + sd_wave, alpha=0.1, edgecolor=colour, facecolor=colour)
  ax.grid(color=(0.9, 0.9, 0.9), linestyle="-", linewidth=1)
  ax.set_xlim(0, n_wave_samples / fs_khz)
  ax.set_ylim(-50, 50)
  ax.legend(loc="upper right")
  ax.set_ylabel("Voltage in [uV]")
  ax.set_xlabel("Time in [ms]")
  plt.show()


