In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import warnings
import copy
import json

# Silence every warning for the rest of the session.
# NOTE(review): this also hides pandas deprecation/future warnings — confirm this is intended.
warnings.simplefilter("ignore")
# Let displayed DataFrames show up to 300 rows before being truncated
pd.set_option('display.max_rows', 300)

In [2]:
# Columns that together identify a neuron; used throughout the notebook as
# merge keys and groupby keys
neuron_id_info = ['date', 'slice id', 'state']

In [3]:
def get_early_recordings(grp_data):
    '''
        For each group, keep only the earliest recording.

        Parameters
        ----------
        grp_data : iterable of (key, DataFrame) pairs, e.g. a pandas GroupBy.
            Every group must have a 'time stamp' column.

        Returns
        -------
        DataFrame with one row per group (the row with the smallest
        'time stamp'), keeping the original row labels and column dtypes.
        An empty DataFrame is returned when there are no groups.
    '''
    # sort_values returns a copy, so the grouped data is never modified;
    # iloc[[0]] keeps the row as a one-row frame so dtypes survive the concat
    earliest = [g.sort_values('time stamp').iloc[[0]] for _, g in grp_data]

    # pd.concat raises on an empty list
    if not earliest:
        return pd.DataFrame()

    # a single concat replaces the removed (and quadratic) DataFrame.append
    return pd.concat(earliest)

## 1. Data engineering

In [4]:
# neurons listed in the selection file (one row per neuron)
selected_neurons = pd.read_csv('data_selection.csv')[neuron_id_info].drop_duplicates()

# measures of firing properties, restricted to the selected neurons
data = pd.read_csv("preprocessed_data_efeatures.csv").merge(
    selected_neurons, how='inner', on=neuron_id_info
)

## Current amplitude round off

In [5]:
def round_current(row):
    '''
        Snap the step current of one trace to the amplitude grid of its protocol.

        `row` must provide 'protocol' ('fi', 'rmih' or 'tburst'; anything else
        raises KeyError) and 'step current'. The grid spacing is 20 for 'fi'
        and 50 for 'rmih' and 'tburst'.
    '''
    grid_by_protocol = {'fi':20.0, 'rmih':50.0, 'tburst':50.0}
    grid = grid_by_protocol[row['protocol']]
    n_increments = round(row['step current'] / grid)
    return n_increments * grid


# step amplitudes: snapped row by row to the grid of each trace's protocol
data['step current'] = data.apply(round_current, axis=1)

# bias amplitudes: snapped to the nearest multiple of 10
data['bias current'] = data['bias current'].apply(lambda bias: round(bias / 10.0) * 10.0)

# total injected current = step + bias (both already rounded)
data['total current'] = data['step current'] + data['bias current']

# 1. F-I curve

In [6]:
# traces listed in the selection file, identified by (filename, file key)
trace_keys = ['filename', 'file key']
selected_traces = pd.read_csv('data_selection.csv')[trace_keys]
fi_traces = data.merge(selected_traces, how='inner', on=trace_keys)

# one trace per neuron and step amplitude: the earliest recording
data_fi = get_early_recordings(fi_traces.groupby(neuron_id_info + ['step current']))

In [7]:
def ranking(data, increments=np.arange(40, 120, 20), steps=np.arange(20, 320, 20), nstep=3):
    '''
        Count, for every candidate set of `nstep` equally spaced step currents,
        how many usable traces each state provides.

        A candidate is (cur0, cur0 + inc, ..., cur0 + (nstep - 1) * inc) for
        each cur0 in `steps` and each inc in `increments`. For a candidate,
        traces with fewer than 4 spikes are discarded, then every neuron that
        does not have a trace at all `nstep` currents is discarded.

        Parameters
        ----------
        data : DataFrame with the `neuron_id_info` columns, 'step current'
            and 'AP_count'.
        increments : spacings between consecutive currents of a candidate.
        steps : first currents (cur0) of the candidates.
        nstep : number of currents per candidate.

        Returns
        -------
        DataFrame with one row per (candidate, state) that has at least one
        remaining trace. Columns: 'n' (number of remaining traces, i.e. rows),
        'state', 'cur0' ... 'cur<nstep-1>' and 'cur step' (the increment).
        The columns are present even when no candidate qualifies.
    '''
    data = data[neuron_id_info + ['step current', 'AP_count']]
    cur_cols = ['cur' + str(i) for i in range(nstep)]

    # rows are collected in a list and turned into a frame once at the end
    records = []

    for cur0 in steps:
        for cur_inc in increments:
            currents = [cur0 + i * cur_inc for i in range(nstep)]

            # traces of the candidate, ordered by current (this order defines
            # the order in which states are reported)
            subset = pd.concat([data[data['step current'] == cur] for cur in currents])

            # remove traces with less than 4 spikes
            subset = subset[~(subset['AP_count'] < 4)]
            if subset.empty:
                continue

            # remove neurons that do not have a trace at every current
            n_currents = subset.groupby(neuron_id_info)['step current'].transform('nunique')
            subset = subset[~(n_currents < nstep)]

            # analyze by state
            for state in subset['state'].unique():
                record = {'n': int((subset['state'] == state).sum()), 'state': state}
                record.update(zip(cur_cols, currents))
                record['cur step'] = cur_inc
                records.append(record)

    return pd.DataFrame(records, columns=['n', 'state'] + cur_cols + ['cur step'])

In [8]:
rank = ranking(data_fi, nstep=3)

# keep only the current combinations with more than one remaining trace
rank = rank[rank['n'] > 1]

# one table per state, smallest trace counts (then smallest increments) first
for state, state_rank in rank.groupby('state'):
    print (state)
    print (state_rank.sort_values(by=['n', 'cur step']))
    print ()

6ohda
     n  state  cur0  cur1  cur2  cur step
2    3  6ohda    40    80   120        40
85   3  6ohda   260   300   340        40
89   3  6ohda   280   320   360        40
93   3  6ohda   300   340   380        40
4    3  6ohda    40   100   160        60
75   3  6ohda   220   280   340        60
81   3  6ohda   240   300   360        60
87   3  6ohda   260   320   380        60
91   3  6ohda   280   340   400        60
6    3  6ohda    40   120   200        80
53   3  6ohda   160   240   320        80
61   3  6ohda   180   260   340        80
69   3  6ohda   200   280   360        80
77   3  6ohda   220   300   380        80
83   3  6ohda   240   320   400        80
7    3  6ohda    40   140   240       100
14   3  6ohda    60   160   260       100
39   3  6ohda   120   220   320       100
47   3  6ohda   140   240   340       100
55   3  6ohda   160   260   360       100
63   3  6ohda   180   280   380       100
71   3  6ohda   200   300   400       100
8    6  6ohda    60   100   

In [9]:
# Step-current triplet retained for each state
# NOTE(review): presumably picked from the ranking tables printed above — confirm
fi_triplets = {
    '6ohda': [120, 180, 240],
    'control': [200, 240, 280],
}

# select the traces of each state at its three step currents
# (one concat instead of the removed DataFrame.append; dtypes are preserved)
data_fi_selection = pd.concat([
    data_fi[(data_fi['state'] == state) & data_fi['step current'].isin(currents)]
    for state, currents in fi_triplets.items()
])

# drop the traces without 4 spikes at least
data_fi_selection = data_fi_selection[~(data_fi_selection['AP_count'] < 4)]

# remove neurons that do not have a trace at all three step currents
n_currents = data_fi_selection.groupby(neuron_id_info)['step current'].transform('nunique')
data_fi_selection = data_fi_selection[~(n_currents < 3)]

In [10]:
# sanity check: (rows, columns) of the F-I selection
data_fi_selection.shape

(48, 36)

In [11]:
# restrict the F-I traces to the neurons that are part of the F-I selection
kept_neurons = data_fi_selection[neuron_id_info].drop_duplicates()
data_fi = data_fi.merge(kept_neurons, how='inner', on=neuron_id_info)

In [12]:
# mean bias current of each neuron, then summary statistics per state
bias_per_neuron = data_fi_selection.groupby(neuron_id_info)['bias current'].mean().reset_index()
descr = bias_per_neuron.groupby('state').describe()

# standard error of the mean: std / sqrt(count)
bias_std = descr[('bias current', 'std')]
bias_count = descr[('bias current', 'count')]
descr[('bias current', 'se')] = bias_std / np.sqrt(bias_count)
descr

Unnamed: 0_level_0,bias current,bias current,bias current,bias current,bias current,bias current,bias current,bias current,bias current
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,se
state,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
6ohda,7.0,67.142857,33.523268,20.0,45.0,70.0,85.0,120.0,12.670604
control,9.0,93.333333,56.124861,50.0,50.0,60.0,100.0,200.0,18.708287


# 2. Sag amplitude

In [13]:
# traces recorded with the 'rmih' protocol
rmih_traces = data[data['protocol'] == 'rmih']

# restrict them to the neurons that are part of the F-I selection
fi_neurons = data_fi_selection[neuron_id_info].drop_duplicates()
rmih_traces = rmih_traces.merge(fi_neurons, how='inner', on=neuron_id_info)

# one trace per neuron and step amplitude: the earliest recording
data_sag = get_early_recordings(rmih_traces.groupby(neuron_id_info + ['step current']))

In [14]:
# a single step amplitude (-200) is retained, for comparison with the paper
is_sag_amplitude = data_sag['step current'] == -200
data_sag_selection = data_sag[is_sag_amplitude]

# 3. Burst 

In [15]:
# traces recorded with the 'tburst' protocol
tburst_traces = data[data['protocol'] == 'tburst']

# restrict them to the neurons that are part of the F-I selection
fi_neurons = data_fi_selection[neuron_id_info].drop_duplicates()
tburst_traces = tburst_traces.merge(fi_neurons, how='inner', on=neuron_id_info)

# one trace per neuron, step amplitude and stimulus duration: the earliest recording
data_rebound = get_early_recordings(tburst_traces.groupby(neuron_id_info + ['step current','stim dur']))

In [16]:
# keep the rebound traces whose stimulus duration is 2000
is_long_stim = data_rebound['stim dur'] == 2000
data_rebound_selection = data_rebound[is_long_stim]

In [17]:
# mean bias current of each neuron, then summary statistics per state
bias_per_neuron = data_rebound_selection.groupby(neuron_id_info)['bias current'].mean().reset_index()
descr = bias_per_neuron.groupby('state').describe()

# standard error of the mean: std / sqrt(count)
bias_std = descr[('bias current', 'std')]
bias_count = descr[('bias current', 'count')]
descr[('bias current', 'se')] = bias_std / np.sqrt(bias_count)
descr

Unnamed: 0_level_0,bias current,bias current,bias current,bias current,bias current,bias current,bias current,bias current,bias current
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,se
state,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
6ohda,7.0,68.571429,36.253079,10.0,50.0,70.0,90.0,120.0,13.702376
control,7.0,67.857143,37.177182,40.0,50.0,60.0,62.5,150.0,14.051654


# JSON export

In [18]:
# Names of the e-features exported for each protocol ("old" feature set).
# The 'fi' list is used for the step-current (F-I) selection and the 'rmih'
# list for the sag selection.

efeatures = {
      'fi':[
        'AP_amplitude',
        'AHP_depth',
        'AP_width',
        'AP_count',
        'time_to_first_spike',
        'inv_first_ISI',
        'inv_second_ISI',
        'inv_last_ISI',
        'adaptation_index2',
        'voltage_base',
        'voltage_after_stim',
        'AP_count_before_stim',
        'AP_count_after_stim',
        'max_amp_difference',
        'clustering_index',
        'fast_AHP'
      ],
      'rmih':[
        'voltage_base',
        'AP1_amp_rev',
        'AP2_amp_rev',
        'time_to_first_spike',
        'AP_count_before_stim',
        'AP_count_after_stim',
        'inv_first_ISI',
        'inv_second_ISI',
        'inv_last_ISI',
        'voltage_after_stim',
        'sag_amplitude',
        'voltage_deflection',
      ]
    }

In [19]:
# one (initially empty) feature dictionary per state; filled in by the cells below
json_file_feature = {state: {} for state in ('control', '6ohda')}

# one (initially empty) protocol dictionary per state; filled in by the cells below
json_file_protocol = {state: {} for state in ('control', '6ohda')}

In [20]:
def row_to_json(row, thresh=None):
    '''
        Turn a table of summary statistics into a list of feature entries.

        Parameters
        ----------
        row : DataFrame with one column per feature and (at least) the index
            labels 'count', 'mean' and 'std', e.g. a slice of
            DataFrame.describe().
        thresh : optional value stored under 'threshold' in every entry.
            Only None disables it, so a threshold of 0 is kept.

        Returns
        -------
        dict {"soma.v": [entry, ...]} with one entry per column, holding
        'feature' (column name), 'val' ([mean, std]), 'n' (count), and the
        fixed values 'weight' = 1.0 and 't_trace_init' = 500.
    '''
    ret = [ ]
    for col_name in row.columns:
        entry = {
            "feature":col_name,
            "val":[row.loc['mean', col_name], row.loc['std', col_name]],
            "n":row.loc['count', col_name],
            "weight":1.0,
            "t_trace_init":500
        }

        # explicit None test: a plain truthiness test would drop thresh=0
        if thresh is not None:
            entry['threshold'] = thresh

        ret.append(entry)
    return { "soma.v":ret }

In [21]:
def get_json_output(data, efeature_names, thresh=None, key_fmt='', stim_dur_flag=True, stim_amp_flag=True):
    '''
        Build the feature entries of every state.

        The traces are grouped by 'state', then by ('step current',
        'stim dur'); the count/mean/std of the `efeature_names` columns of
        each group are converted with row_to_json.

        Parameters
        ----------
        data : DataFrame with 'state', 'step current', 'stim dur' and the
            feature columns.
        efeature_names : list of feature columns to summarize.
        thresh : forwarded to row_to_json.
        key_fmt : prefix of the entry names.
        stim_dur_flag : if True the entry name ends with '_<stim dur>ms'.
        stim_amp_flag : if True the entry name contains the step current.

        Returns
        -------
        dict {state: {entry name: row_to_json output}}; the entry name is
        '<key_fmt>Step[<step current>][_<stim dur>ms]'. Groups that map to
        the same name overwrite each other (the last one wins).
    '''
    # template of the entry names, e.g. 'Step%i_%ims' when both flags are set
    name_fmt = key_fmt + 'Step' + ('%i' if stim_amp_flag else '') + ('_%ims' if stim_dur_flag else '')
    stat_rows = ['count', 'mean', 'std']

    result = {}
    for state, state_data in data.groupby('state'):
        entries = {}
        for (amp, dur), traces in state_data.groupby(['step current', 'stim dur']):
            stats = traces[efeature_names].describe().loc[stat_rows, :]

            # only the values requested by the flags go into the name
            name_args = ((amp,) if stim_amp_flag else ()) + ((dur,) if stim_dur_flag else ())

            entries[name_fmt % name_args] = row_to_json(stats, thresh=thresh)
        result[state] = entries
    return result

### Features

In [22]:
# features of the F-I selection; -20 is stored as the 'threshold' of every entry
output = get_json_output(data_fi_selection, efeatures['fi'], key_fmt='', thresh=-20)

# merge the new entries into the feature dictionary of each state
# (raises KeyError if `output` has no entry for one of the states)
for k in json_file_feature:
    json_file_feature[k].update(output[k])

In [23]:
# features of the sag selection; entry names are prefixed with 'Sag' and
# -35 is stored as the 'threshold' of every entry
output = get_json_output(data_sag_selection, efeatures['rmih'], key_fmt='Sag', thresh=-35)

# merge the new entries into the feature dictionary of each state
# (raises KeyError if `output` has no entry for one of the states)
for k in json_file_feature:
    json_file_feature[k].update(output[k])

### Add the time constants

In [24]:
# Hard-coded "Pulse" entry (feature decay_time_constant_after_stim2) for each state.
# NOTE(review): 'val' presumably holds [mean, std] like the entries built by
# row_to_json — confirm, and record where these numbers come from.
time_constants = {'6ohda':
                  {"Pulse":{"soma.v":[{"feature":"decay_time_constant_after_stim2","val":[21.04,13.2],"weight":1, "t_trace_init":500 }]}},
                   'control':
                   {"Pulse":{"soma.v":[{"feature":"decay_time_constant_after_stim2","val":[20.69,10.27],"weight":1, "t_trace_init":500}]}}
                 }

# add the "Pulse" entry to the feature dictionary of each state
for k in json_file_feature.keys():
    json_file_feature[k].update(time_constants[k])

### Input resistance

In [25]:
# input resistance computed on the sag selection; with stim_amp_flag=False the
# entry name has no amplitude ('Step_<stim dur>ms') and no threshold is stored
output = get_json_output(data_sag_selection, [ 'input_resistance' ], stim_amp_flag=False)
# merge the new entries into the feature dictionary of each state
for k in json_file_feature:
    json_file_feature[k].update(output[k])

### AP peaks/amplitudes

In [26]:
# Force a very small spread (val[1] = 0.01) on the spike counts outside the stimulus:
#   - 'AP_count_before_stim' in every entry
#   - 'AP_count_after_stim' only in the entries named 'Step...ms'
# (A disabled variant of this cell set the weight of the AP amplitude / peak
#  features, i.e. AP_amplitude, AP1_peak, AP2_peak, AP1_amp, AP2_amp, to 0.5.)
for output in json_file_feature.values():
    for entry_name, entry in output.items():
        is_step_entry = entry_name.startswith('Step') and entry_name.endswith('ms')
        for entry1 in entry['soma.v']:
            feature = entry1['feature']
            if feature == 'AP_count_before_stim' or (is_step_entry and feature == 'AP_count_after_stim'):
                entry1['val'][1] = 0.01


### AHP depth

In [27]:
#for output in json_file_feature.values():
#    for entry_name, entry in output.items():
#        for entry1 in entry['soma.v']:
#            if entry1['feature'].startswith('AHP_depth'):
#                entry1['weight'] = 2

### Protocols

In [28]:
# Hand-written protocols that are added to the protocol dictionary of every state.
# Neither entry has a "holding" stimulus, unlike the protocols built by
# row_to_json_prot.
# NOTE(review): delays/durations are presumably in ms and amplitudes in nA — confirm.
protocol_template = {
    "Pulse": {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": 2000, "amp":-1.0, "duration": 0.5, "totduration": 4800}
        }
    },
    "Step_2000ms": {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": 2000, "amp":-0.01, "duration": 100, "totduration": 4800}
        }
    }
}



In [29]:
def row_to_json_prot(cur_key, stim_amp, stim_dur, stim_hold, prefix='', delay=2000, tpad=800.0):
    '''
        Build the protocol entry of one step stimulus.

        Parameters
        ----------
        cur_key : value used for the amplitude part of the entry name.
        stim_amp : step amplitude; stored divided by 1000 and rounded to 3 decimals.
        stim_dur : step duration, also used in the entry name.
        stim_hold : holding amplitude; stored divided by 1000 and rounded to 3 decimals.
        prefix : prefix of the entry name.
        delay : time before the step starts.
        tpad : time added after the end of the step.

        Returns
        -------
        dict with the single key '<prefix>Step<cur_key>_<stim_dur>ms' mapping
        to a "StepProtocol" with a "step" stimulus and a "holding" stimulus
        that lasts for the whole protocol (delay + stim_dur + tpad).
    '''
    total_duration = stim_dur + delay + tpad
    name = (prefix+"Step%i_%ims")%(cur_key, stim_dur)

    step = {
        "delay": delay,
        "amp": round(stim_amp/1000.0, 3),
        "duration": stim_dur,
        "totduration": total_duration,
    }
    holding = {
        "amp": round(stim_hold/1000.0,3),
        "delay": 0,
        "duration": total_duration,
        "totduration": total_duration,
    }
    return {name: {"type": "StepProtocol", "stimuli": {"step": step, "holding": holding}}}

In [30]:
def get_json_protocol_output(data, efeature_names, cur_key, prefix='', stim_dur_flag=True, stim_amp_flag=True):
    '''
        Build the step protocols of every state.

        The traces are grouped by 'state', then by (`cur_key`, 'stim dur');
        each group becomes one protocol built by row_to_json_prot, with the
        group's mean 'step current' as step amplitude and the state's mean
        'bias current' as holding amplitude.

        Parameters
        ----------
        data : DataFrame with 'state', 'bias current', 'step current',
            'stim dur' and the `cur_key` column.
        efeature_names, stim_dur_flag, stim_amp_flag : unused; kept so that
            existing calls keep working.
        cur_key : column whose value names the protocol.
        prefix : prefix of the protocol names.

        Returns
        -------
        dict {state: {protocol name: protocol}}, keyed by the plain state value.
    '''
    json_file_output = {}
    # group by the column name, not by a one-element list: recent pandas
    # versions yield 1-tuples as keys for a list, which would break lookups
    # by state name
    for state, state_data in data.groupby('state'):
        hold = state_data['bias current'].mean()

        protocols = {}
        for (cur, dur), traces in state_data.groupby([cur_key, 'stim dur']):
            protocols.update(row_to_json_prot(cur, traces['step current'].mean(), dur, hold, prefix=prefix))
        json_file_output[state] = protocols
    return json_file_output

In [31]:
# protocols of the F-I selection (no prefix), then of the sag selection ('Sag' prefix)
protocol_sources = (
    (data_fi_selection, efeatures['fi'], ''),
    (data_sag_selection, efeatures['rmih'], 'Sag'),
)
for selection, feature_names, name_prefix in protocol_sources:
    output = get_json_protocol_output(selection, feature_names, 'step current', prefix=name_prefix)

    for k in json_file_protocol:
        json_file_protocol[k].update(output[k])


# every state also gets its own copy of the hand-written template protocols
for k in json_file_protocol:
    json_file_protocol[k].update(copy.deepcopy(protocol_template))

In [32]:
def _compact_json(obj):
    '''
        Serialize `obj` with a 4-space indent, then fold the deepest levels
        back onto one line: line breaks followed by 20 or 16 spaces become a
        single space, and a closing brace indented by 20, 16 or 12 spaces is
        pulled up onto the previous line.
    '''
    s = json.dumps(obj, indent=4)
    # deepest indentation first, so shallower patterns still match afterwards
    for i in range(20, 12, -4):
        s = s.replace('\n' + (' ' * i), ' ')
    for i in range(20, 8, -4):
        s = s.replace('\n' + (' ' * i) + '}', ' }')
    return s


# one features file per state: '<state>_features.json'
for k in json_file_feature:
    s = _compact_json(json_file_feature[k])
    with open("%s_features.json"%k, "w") as fo:
        fo.write(s)


# one protocols file per state: '<state>_protocols.json'
for k in json_file_protocol:
    print (json_file_protocol[k])
    s = _compact_json(json_file_protocol[k])
    with open("%s_protocols.json"%k, "w") as fo:
        fo.write(s)


{'Step200_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': 0.2, 'duration': 2000.0, 'totduration': 4800.0}, 'holding': {'amp': 0.093, 'delay': 0, 'duration': 4800.0, 'totduration': 4800.0}}}, 'Step240_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': 0.24, 'duration': 2000.0, 'totduration': 4800.0}, 'holding': {'amp': 0.093, 'delay': 0, 'duration': 4800.0, 'totduration': 4800.0}}}, 'Step280_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': 0.28, 'duration': 2000.0, 'totduration': 4800.0}, 'holding': {'amp': 0.093, 'delay': 0, 'duration': 4800.0, 'totduration': 4800.0}}}, 'SagStep-200_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': -0.2, 'duration': 2000.0, 'totduration': 4800.0}, 'holding': {'amp': 0.0, 'delay': 0, 'duration': 4800.0, 'totduration': 4800.0}}}, 'Pulse': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': -1.0, 'duration': 0.5, 'totduration': 480

# Selected dataset export

In [33]:
# save the F-I, sag and rebound traces of the selected neurons in a single CSV
# (to_csv also writes the DataFrame index as the first column)
pd.concat([data_fi, data_sag, data_rebound]).to_csv('preprocessed_data_efeatures_selection.csv')