<a href="https://colab.research.google.com/github/JanBellingrath/Hippocampal_Neuron_Data_Processing/blob/main/Data_Pipeline_Google_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
from os.path import join
import dask
import numpy as np  
import pandas as pd
import scipy
from scipy.io import loadmat
from collections import namedtuple
import mrestimator as mre

import loren_frank_data_processing
from loren_frank_data_processing.core import get_data_filename, logger, get_epochs, get_data_structure
from loren_frank_data_processing.tetrodes import  convert_tetrode_epoch_to_dataframe, get_LFP_dataframe, _get_tetrode_id
from loren_frank_data_processing.neurons import get_neuron_info_path, get_spikes_dataframe, make_neuron_dataframe, convert_neuron_epoch_to_dataframe, _add_to_dict, _get_neuron_id
import loren_frank_data_processing.core

In [None]:
# Modified functions from loren_frank_data_processing; the changes mostly adapt data loading to Google Colab, plus some minor bug fixes.

def get_data_filename_modified(animal, day, file_type):
    '''Build the bare Matlab file name for one animal, day and data type.

    The result has the form ``<short_name><file_type><day>.mat`` with the
    day zero-padded to two digits. No directory is prepended, so the file
    is resolved relative to the current working directory.

    Parameters
    ----------
    animal : namedtuple
        Only its ``short_name`` attribute is read.
    day : int
        Day of recording.
    file_type : str
        Data structure name (e.g. linpos, dio).

    Returns
    -------
    filename : str
        File name without any directory component.
    '''
    return f'{animal.short_name}{file_type}{day:02d}.mat'


def get_spikes_dataframe(neuron_key, animals):
    '''Spike times for a particular neuron.

    Parameters
    ----------
    neuron_key : tuple
        Unique key identifying that neuron. Elements of the tuple are
        (animal_short_name, day, epoch, tetrode_number, neuron_number).
        Key can be retrieved from `make_neuron_dataframe` function.
    animals : dict of named-tuples
        Dictionary containing information about the directory for each
        animal. The key is the animal_short_name.

    Returns
    -------
    spikes_dataframe : pandas.Series
        A series of ones, one per spike, indexed by spike time. The series
        is empty when the spikes file cannot be loaded or when the
        requested epoch/tetrode/neuron is not present in the file.
    '''
    animal, day, epoch, tetrode_number, neuron_number = neuron_key

    # Default used for every "no data" outcome below.
    spike_time = []
    try:
        neuron_file = loadmat(
            get_data_filename_modified(animals[animal], day, 'spikes'))
    except (FileNotFoundError, TypeError):
        # Previously execution fell through to the indexing step with
        # `neuron_file` unbound and died with UnboundLocalError; a missing
        # file is now treated like a neuron without spikes.
        logger.warning('Failed to load file: {0}'.format(
            get_data_filename_modified(animals[animal], day, 'spikes')))
    else:
        try:
            # The spikes structure is indexed 1-based in the key, hence -1.
            seconds = neuron_file['spikes'][0, -1][0, epoch - 1][
                0, tetrode_number - 1][0, neuron_number - 1][0]['data'][0][
                :, 0]
            # `pd.to_timedelta` replaces `TimedeltaIndex(..., unit='s')`,
            # whose `unit` keyword is deprecated in recent pandas.
            spike_time = pd.TimedeltaIndex(
                pd.to_timedelta(seconds, unit='s'), name='time')
        except IndexError:
            spike_time = []

    return pd.Series(
        np.ones_like(spike_time, dtype=int), index=spike_time,
        name='{0}_{1:02d}_{2:02}_{3:03}_{4:03}'.format(*neuron_key))


def get_LFP_dataframe_modified(tetrode_key, animals):
    '''Load the LFP trace for a given epoch and tetrode.

    Parameters
    ----------
    tetrode_key : tuple
        Unique key identifying the tetrode. Elements are
        (animal_short_name, day, epoch, tetrode_number).
    animals : dict of named-tuples
        Dictionary containing information about the directory for each
        animal. The key is the animal_short_name.

    Returns
    -------
    LFP : pandas.Series or None
        The 'data' field as floats, indexed by the time axis built by
        `reconstruct_time` and named after the tetrode key. None is
        returned when loading raises FileNotFoundError or TypeError.
    '''
    # NOTE(review): `reconstruct_time` is not among the imports visible at
    # the top of this file - confirm it is defined or imported elsewhere.
    try:
        eeg_file = loadmat(get_LFP_filename_modified(tetrode_key, animals))
        # Take the last entry at each of the three nesting levels.
        recording = eeg_file['eeg'][0, -1][0, -1][0, -1]

        start_time = recording['starttime'][0, 0].item()
        n_samples = recording['data'][0, 0].size
        sampling_rate = float(recording['samprate'][0, 0].squeeze())
        lfp_time = reconstruct_time(start_time, n_samples, sampling_rate)

        potential = recording['data'][0, 0].squeeze().astype(float)
        series_name = '{0}_{1:02d}_{2:02}_{3:03}'.format(*tetrode_key)
        return pd.Series(data=potential, index=lfp_time, name=series_name)
    except (FileNotFoundError, TypeError):
        logger.warning('Failed to load file: {0}'.format(
            get_LFP_filename_modified(tetrode_key, animals)))
    return None
        


def make_neuron_dataframe_modified(animals, neuron_file_names=None):
    '''Information about all recorded neurons such as brain area.

    The index of the dataframe corresponds to the unique key for that neuron
    and can be used to load spiking information.

    Parameters
    ----------
    animals : dict of named-tuples
        Dictionary containing information about the directory for each
        animal. The key is the animal_short_name.
    neuron_file_names : sequence of str, optional
        Cell-info .mat files, paired positionally with the keys of
        `animals` (in the dict's iteration order). Defaults to
        ``['concellinfo.mat', 'Corcellinfo.mat']``, the list that used to
        be hard-coded. Pairing uses `zip`, so surplus animals or surplus
        file names are silently ignored.

    Returns
    -------
    neuron_information : pandas.DataFrame
    '''
    if neuron_file_names is None:
        neuron_file_names = ['concellinfo.mat', 'Corcellinfo.mat']

    neuron_data = [
        (loadmat(file_name), animal_name)
        for animal_name, file_name in zip(animals, neuron_file_names)
    ]

    # NOTE(review): this calls the imported `convert_neuron_epoch_to_dataframe`,
    # not the `_modified` variant defined in this file - confirm that is
    # intended.
    # Days and epochs are numbered from 1, hence the `+ 1` on both indices.
    return pd.concat([
        convert_neuron_epoch_to_dataframe(
            epoch, animal, day_ind + 1, epoch_ind + 1)
        for cellfile, animal in neuron_data
        for day_ind, day in enumerate(cellfile['cellinfo'].T)
        for epoch_ind, epoch in enumerate(day[0].T)
    ]).sort_index()


def get_LFP_filename_modified(tetrode_key, animals):
    '''Build the bare file name of the LFP (eeg) file for one tetrode.

    The result has the form ``<short_name>eeg<day>-<epoch>-<tetrode>.mat``
    with day and tetrode number zero-padded to two digits. No directory is
    prepended.

    Parameters
    ----------
    tetrode_key : tuple
        Unique key identifying the tetrode. Elements are
        (animal_short_name, day, epoch, tetrode_number).
    animals : dict of named-tuples
        Looked up by the animal short name; only the ``short_name``
        attribute of the entry is read.

    Returns
    -------
    filename : str
        File name of the tetrode LFP file.
    '''
    animal_name, day, epoch, tetrode_number = tetrode_key
    short_name = animals[animal_name].short_name
    return f'{short_name}eeg{day:02d}-{epoch}-{tetrode_number:02d}.mat'


def convert_neuron_epoch_to_dataframe_modified(tetrodes_in_epoch, animal, day,
                                      epoch):
    '''Turn one epoch of the neuron data structure into a tidy DataFrame.

    Every neuron of every tetrode in ``tetrodes_in_epoch[0][0]`` is
    converted to a dict by `_convert_to_dict_modified` and passed through
    `_add_to_dict` together with its tetrode and neuron position. The
    resulting frame gets the tag columns listed below dropped,
    `animal`/`day`/`epoch`/`neuron_id` columns added, and is indexed by
    (animal, day, epoch, tetrode_number, neuron_number).

    Returns None (after a debug log message) when building the frame
    raises AttributeError.
    '''
    columns_to_drop = [
        'ripmodtag', 'thetamodtag', 'runripmodtag', 'postsleepripmodtag',
        'presleepripmodtag', 'runthetamodtag', 'ripmodtag2',
        'runripmodtag2', 'postsleepripmodtag2', 'presleepripmodtag2',
        'ripmodtype', 'runripmodtype', 'postsleepripmodtype',
        'presleepripmodtype', 'FStag', 'ripmodtag3', 'runripmodtag3',
        'ripmodtype3', 'runripmodtype3', 'tag', 'typetag',
        'runripmodtype2', 'tag2', 'ripmodtype2', 'descrip',
    ]
    index_columns = ['animal', 'day', 'epoch',
                     'tetrode_number', 'neuron_number']

    # Empty neuron slots are not filtered out here; they go through
    # `_convert_to_dict_modified` like every other entry.
    neuron_records = []
    for tetrode_ind, tetrode in enumerate(tetrodes_in_epoch[0][0]):
        for neuron_ind, neuron in enumerate(tetrode[0]):
            record = _convert_to_dict_modified(neuron)
            neuron_records.append(
                _add_to_dict(record, tetrode_ind, neuron_ind))

    try:
        neuron_info = pd.DataFrame(neuron_records)
        neuron_info = neuron_info.drop(
            columns_to_drop, axis=1, errors='ignore')
        neuron_info = neuron_info.assign(animal=animal, day=day, epoch=epoch)
        neuron_info = neuron_info.assign(neuron_id=_get_neuron_id)
        # The index identifies each row uniquely.
        return neuron_info.set_index(index_columns).sort_index()
    except AttributeError:
        logger.debug(f'Neuron info {animal}, {day}, {epoch} not processed')
    return None

def _convert_to_dict_modified(struct_array):
    '''Collect the single-valued fields of a struct array into a dict.

    For each name in ``struct_array.dtype.names`` the field content is
    unwrapped with ``.item()``; fields whose content has ``size == 1`` are
    kept and unwrapped once more to a plain value, all others are skipped.

    An empty dict is returned when a TypeError or AttributeError is raised
    along the way (e.g. the input has no ``dtype`` or no named fields).
    '''
    converted = {}
    try:
        for field_name in struct_array.dtype.names:
            content = struct_array[field_name].item()
            if content.size == 1:
                converted[field_name] = content.item()
    except (TypeError, AttributeError):
        # AttributeError was added on top of the original TypeError guard.
        return {}
    return converted

def get_spike_indicator_dataframe_modified(neuron_key, animals, time_function=default_time_functionNEW):
    '''Count the spikes of one neuron per bin of a common time axis.

    Parameters
    ----------
    neuron_key : tuple
        (animal_short_name, day, epoch, tetrode_number, neuron_number),
        passed on to `get_spikes_dataframe`.
    animals : dict of named-tuples
        Passed on to `get_spikes_dataframe`.
    time_function : callable
        Called without arguments to obtain the time axis. The result must
        offer ``.index.total_seconds()``, ``.iloc`` and ``len()``.
        NOTE(review): its exact type is not visible here - confirm against
        `default_time_functionNEW`.

    Returns
    -------
    spike_indicator : pandas object or None
        Spike counts grouped by time bin and reindexed to the full time
        axis with zeros for bins without spikes. None when the spike or
        time index has no ``total_seconds`` (e.g. the neuron has no
        spikes, in which case its index is not a timedelta index).
    '''
    time = time_function()
    spikes_df = get_spikes_dataframe(neuron_key, animals)

    # The two debug prints that used to sit here were removed: the first
    # accessed `spikes_df.index_total_seconds`, which is not a Series
    # attribute, so every call raised AttributeError before any work
    # was done.
    try:
        time_index = np.digitize(spikes_df.index.total_seconds(),
                                 time.index.total_seconds())
    except AttributeError:
        # Raised for empty data, whose index has no `total_seconds`.
        print('No spikes here; data is emtpy')
        return None

    # Spikes after the last bin edge are assigned to the final bin.
    time_index[time_index >= len(time)] = len(time) - 1

    return (spikes_df.groupby(time.iloc[time_index]).sum()
            .reindex(index=time, fill_value=0))

def _convert_to_dict_modified(struct_array):
    '''Return a dict of the struct array's fields that hold one value.

    Each field named in ``struct_array.dtype.names`` is unwrapped with
    ``.item()``; it is kept (unwrapped a second time) only if that content
    has ``size == 1``. A TypeError or AttributeError during the conversion
    yields an empty dict.

    NOTE: this function is defined twice in this file with the same
    behavior; this later definition is the one that stays bound.
    '''
    def _single_valued_fields():
        for field_name in struct_array.dtype.names:
            content = struct_array[field_name].item()
            if content.size == 1:
                yield field_name, content.item()

    try:
        return dict(_single_valued_fields())
    except TypeError:
        return {}
    # AttributeError was added on top of the original TypeError guard.
    except AttributeError:
        return {}

def _get_neuron_id(dataframe):
    '''Unique identifier string for each neuron (row) of the dataframe.

    Joins animal, day, epoch, tetrode number and neuron number with
    underscores; day and epoch are padded to two characters, tetrode and
    neuron number to three.
    '''
    day = dataframe.day.map('{:02d}'.format)
    epoch = dataframe.epoch.map('{:02}'.format)
    tetrode = dataframe.tetrode_number.map('{:03}'.format)
    neuron = dataframe.neuron_number.map('{:03}'.format)
    return (dataframe.animal + '_' + day + '_' + epoch + '_' +
            tetrode + '_' + neuron)

def get_trial_time_modified_3(epoch_key, animals):
    """Time in the recording session in terms of the LFP.

    Returns the LFP time axis of the first tetrode (in the order of the
    tetrode info index) whose LFP file can be loaded. This is useful when
    there are slightly different timings for the recordings and you need a
    common time.

    Parameters
    ----------
    epoch_key : tuple
        Unique key identifying a recording epoch with elements
        (animal, day, epoch)
    animals : dict of named-tuples
        Dictionary containing information about the directory for each
        animal. The key is the animal_short_name.

    Returns
    -------
    time : pandas.Index

    Raises
    ------
    AttributeError
        If no tetrode of the epoch yields LFP data.
    """
    tetrode_info = make_tetrode_dataframe_modified(animals, epoch_key=epoch_key)

    lfp_df = None
    for tetrode_key in tetrode_info.index:
        # The LFP loader unpacks a 4-tuple. The previous code passed a
        # stringified key, which can never unpack correctly. Tuples are now
        # forwarded unchanged; id strings of the form
        # 'animal_day_epoch_tetrode' are parsed the same way neuron id
        # strings are parsed in `generate_spike_indicator_dict`.
        # NOTE(review): assumes the animal short name contains no '_'.
        if isinstance(tetrode_key, str):
            animal, day, epoch, tetrode_number = tetrode_key.split('_')
            tetrode_key = (animal, int(day), int(epoch), int(tetrode_number))

        lfp_df = get_LFP_dataframe_modified(tetrode_key, animals)
        if lfp_df is not None:
            break

    if lfp_df is None:
        # AttributeError matches what the original raised (`None.index`)
        # when every load failed, so existing handlers keep working; an
        # empty tetrode index used to surface as UnboundLocalError.
        raise AttributeError(
            'No LFP data could be loaded for epoch {0}'.format(epoch_key))

    return lfp_df.index.rename("time")

def get_tetrode_info_path_modified(animal):
    '''Build the bare name of the Matlab tetrode info file.

    The argument is formatted as-is and suffixed with ``tetinfo``; neither
    a directory nor a ``.mat`` extension is added.

    Parameters
    ----------
    animal : str
        Text placed in front of ``tetinfo``.
        NOTE(review): presumably the animal short name - the whole
        argument is formatted, not a ``short_name`` attribute, so passing
        the animal namedtuple would embed its repr. Verify against callers.

    Returns
    -------
    filename : str
        Name of the tetrode info file for the given animal.
    '''
    return f'{animal}tetinfo'


In [None]:
#new functions

#the following function sums the number of spikes over the different time_series (one per neuron, initially), to create one time-series array for the collective activity
def sum_time_series_dict(time_series_dict):
    '''Sum several per-neuron time series bin by bin.

    Parameters
    ----------
    time_series_dict : dict
        Maps any key to a 1-D array-like of per-bin values, or to None for
        entries without data (these are skipped).

    Returns
    -------
    res : numpy.ndarray
        Float array as long as the longest series; shorter series only
        contribute to the leading bins. An empty array is returned when
        the dict is empty or holds only None values (previously this case
        raised ValueError from `max` on an empty list).
    '''
    series = [arr for arr in time_series_dict.values() if arr is not None]
    num_bins = max((len(arr) for arr in series), default=0)
    res = np.zeros(num_bins)
    for arr in series:
        res[:len(arr)] += arr
    return res

#this function is used to subdivide the list of all the neuron-ids into more manageable chunks.
def divide_list_of_strings(strings, n):
    '''Split a list into `n` contiguous sublists of near-equal length.

    Parameters
    ----------
    strings : sequence
        Items to distribute; their order is preserved.
    n : int
        Number of sublists to produce; must be at least 1.

    Returns
    -------
    sublists : list of lists
        Exactly `n` lists. When ``len(strings)`` is not a multiple of `n`,
        the first ``len(strings) % n`` sublists hold one extra item (the
        previous implementation ran past the last sublist and raised
        IndexError in that case).

    Raises
    ------
    ValueError
        If `n` is smaller than 1.
    '''
    if n < 1:
        raise ValueError('n must be a positive integer')

    items = list(strings)
    base_size, remainder = divmod(len(items), n)

    sublists = []
    start = 0
    for position in range(n):
        # The leftover items are spread over the leading sublists.
        stop = start + base_size + (1 if position < remainder else 0)
        sublists.append(items[start:stop])
        start = stop
    return sublists


#generate dict of keys of neurons with time series
def generate_spike_indicator_dict(neuron_key_list, animals):
    '''Map each neuron key to its spike indicator values.

    Parameters
    ----------
    neuron_key_list : iterable of str
        Neuron id strings with exactly five '_'-separated fields:
        animal short name, day, epoch, tetrode number, neuron number.
    animals : dict of named-tuples
        Passed on to `get_spike_indicator_dataframe_modified`.

    Returns
    -------
    spike_indicator_dict : dict
        Keys are (animal_short_name, day, epoch, tetrode_number,
        neuron_number) tuples with the numeric parts converted to int.
        Values are the ``.values`` of the spike indicator, or None when
        obtaining them raised AttributeError.
    '''
    indicators = {}
    for key_string in neuron_key_list:
        name, day, epoch, tetrode, neuron = key_string.split("_")
        neuron_key = (name,) + tuple(
            int(field) for field in (day, epoch, tetrode, neuron))
        try:
            values = get_spike_indicator_dataframe_modified(
                neuron_key, animals).values
        except AttributeError:
            # Covers a None result (no `.values`) as well as an
            # AttributeError raised inside the call itself.
            values = None
            print(f"No spike indicator data for neuron: {neuron_key}")
        indicators[neuron_key] = values
    return indicators


def split_neuron_dataframe_informationally(df, split_cols):
    '''Split a DataFrame into one independent copy per group.

    Parameters
    ----------
    df : pandas.DataFrame
        The DataFrame to be split.
    split_cols : list of str
        Column names handed to ``df.groupby``.

    Returns
    -------
    dfs : dict
        Maps each group key produced by ``groupby`` to a copy of that
        group's rows, so editing a piece does not touch `df`.
    '''
    return {
        group_key: group_rows.copy()
        for group_key, group_rows in df.groupby(split_cols)
    }

  
def apply_input_handler_to_split_dataframes(dfs):
    '''
    Run mre.input_handler() on every value of the input dictionary.

    Parameters
    ----------
    dfs : dict
        A dictionary containing split DataFrames.

    Returns
    -------
    output : dict
        A new dictionary with the same keys, each mapped to the result of
        mre.input_handler() for the corresponding value.
    '''
    return {key: mre.input_handler(frame) for key, frame in dfs.items()}

def get_neuron_ids_only(dfs):
    """
    List the neuron ids of every area/subarea group.

    Parameters
    ----------
    dfs : dict
        Maps (area, subarea) pairs to pandas dataframes that have a
        'neuron_id' column.

    Returns
    -------
    neuron_ids : list of lists
        One ``[area, subarea, ids]`` entry per dictionary item, in the
        dictionary's order, where ``ids`` is the 'neuron_id' column as a
        plain list.
    """
    return [
        [area, subarea, frame['neuron_id'].tolist()]
        for (area, subarea), frame in dfs.items()
    ]