In [33]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import warnings
import copy
import json

# silence every warning category for the rest of the session; note that this
# also hides any deprecation or chained-assignment warnings from later cells
warnings.simplefilter("ignore")
# allow up to 300 rows when a DataFrame is displayed
pd.set_option('display.max_rows', 300)

In [34]:
# columns that together identify a neuron; used throughout the notebook as
# merge keys and as group-by keys
neuron_id_info = ['date', 'slice id', 'state']

In [35]:
def get_early_recordings(grp_data):
    '''
        For each group, keep only the earliest recording.

        Parameters
        ----------
        grp_data : iterable of (key, DataFrame) pairs
            Typically the result of ``DataFrame.groupby``. Every group must
            have a 'time stamp' column.

        Returns
        -------
        pandas.DataFrame
            One row per group: the row with the smallest 'time stamp'
            (first in original order on ties). Row labels, column order and
            dtypes of the input are preserved. An empty DataFrame is returned
            when there are no groups.
    '''
    # Collect the first row of every group and concatenate once.
    # ``DataFrame.append`` no longer exists in pandas >= 2.0 and grew the
    # result quadratically; sorting a copy (no ``inplace``) also avoids
    # mutating the group object handed out by the group-by.
    earliest = [
        g.sort_values('time stamp', kind='stable').iloc[[0]]
        for _, g in grp_data
    ]
    if not earliest:
        return pd.DataFrame()
    return pd.concat(earliest)

## 1. Data engineering

In [36]:
# load dataset with measures of firing properties
data = pd.read_csv("preprocessed_data_efeatures.csv")

# inner-join on the neuron identifiers so that only neurons listed in the
# selection table remain; identifiers are de-duplicated first so the join
# cannot multiply rows
data = pd.merge(
    data,
    pd.read_csv('data_selection.csv').loc[:, neuron_id_info].drop_duplicates(),
    on=neuron_id_info,
    how='inner',
)

## Rounding of current amplitudes

In [37]:
def round_current(row):
    """
    Snap a trace's 'step current' to the grid of its protocol.

    `row` must provide 'protocol' and 'step current'. The grid spacing is
    20.0 for 'fi' and 50.0 for 'rmih' / 'tburst'; any other protocol raises
    KeyError. Rounding uses Python's built-in `round`.
    """
    grid_by_protocol = {'fi': 20.0, 'rmih': 50.0, 'tburst': 50.0}
    grid = grid_by_protocol[row['protocol']]
    n_steps = round(row['step current'] / grid)
    return n_steps * grid


# round off for current amplitudes: step current on the per-protocol grid
# (see round_current), bias current to the nearest 10
data['step current'] = data.apply(round_current, axis=1)
data['bias current'] = data['bias current'].apply( lambda x : round(x/10.0)*10.0 )

# total current is recomputed from the two rounded components. The former
# separate rounding of the recorded 'total current' column was a dead store
# (its result was overwritten by this line immediately) and has been removed.
data['total current'] = data['step current'] + data['bias current']

# 1. F-I curve

In [38]:
# retain only traces selected: inner-join on the (filename, file key) pairs
# listed in the selection table
data_fi = pd.merge(
    data,
    pd.read_csv('data_selection.csv').loc[:, ['filename', 'file key']],
    on=['filename', 'file key'],
    how='inner',
)

# keep early recordings only: one trace per neuron and step current
data_fi = get_early_recordings(
    data_fi.groupby(neuron_id_info + ['step current'])
)

In [39]:
def ranking(data, increments=np.arange(40, 120, 20), nstep=3, apthresh=4, hyper=False):
    '''
        Count, for several candidate current increments, how many traces
        are available per state when each neuron contributes a fixed set
        of current steps measured from its own threshold current.

        Parameters
        ----------
        data : pandas.DataFrame
            Must contain the columns in the module-level `neuron_id_info`,
            'step current', and 'AP_count' (or 'AP_count_after_stim' when
            `hyper` is True). The caller's frame is not modified.
        increments : iterable of numbers
            Candidate spacings between consecutive current steps.
        nstep : int
            Number of steps required per neuron. With ``nstep > 1`` the steps
            are ``cur0 + i * increment`` for ``i = 0 .. nstep-1``; otherwise
            the single step ``cur0 + increment`` is used.
        apthresh : int
            Traces with fewer spikes than this are discarded first.
        hyper : bool
            If True, spikes are counted after the stimulus and the threshold
            current `cur0` is the largest remaining step current of the
            neuron; otherwise it is the smallest.

        Returns
        -------
        rank : pandas.DataFrame
            Columns 'n', 'state', 'cur step': number of selected rows per
            state and increment. Has these columns even when empty.
        df_by_step : pandas.DataFrame
            The selected rows for all increments, with the extra columns
            'step current norm' (step current minus `cur0`) and 'cur step'
            (the increment that selected the row).
    '''
    # if we deliver hyperpolarizing current check rebound spikes
    ap_count_attr = 'AP_count_after_stim' if hyper else 'AP_count'

    cols = neuron_id_info + ['step current', ap_count_attr]

    # remove traces with fewer than `apthresh` spikes. A boolean mask yields
    # a fresh frame (no in-place drop on a column slice) and, unlike a
    # label-based drop, cannot remove unrelated rows sharing an index label.
    # `~(x < t)` rather than `x >= t` so rows with a missing count are kept,
    # as before.
    data = data.loc[~(data[ap_count_attr] < apthresh), cols]

    rank_rows = []   # one record per (increment, state)
    selected = []    # one frame per increment that selected anything
    for cur_inc in increments:
        per_neuron = []

        # get the fi curve for each neuron
        for _, g in data.groupby(neuron_id_info):
            # threshold current
            cur0 = g['step current'].max() if hyper else g['step current'].min()

            # currents this neuron must have been recorded at
            if nstep > 1:
                targets = [cur0 + i * cur_inc for i in range(nstep)]
            else:
                targets = [cur0 + cur_inc]

            # rows are gathered target by target so the result is ordered
            # by step, not by the order of rows inside the group
            steps = pd.concat([g[g['step current'] == t] for t in targets])

            # if it does not have all the steps do not add
            if steps.shape[0] < nstep:
                continue
            elif steps.shape[0] > nstep:
                print('Warning: there are too many traces')

            # get normalized step current
            steps['step current norm'] = steps['step current'] - cur0
            steps['cur step'] = cur_inc
            per_neuron.append(steps)

        # no neuron offers all steps for this increment: nothing to rank
        # (previously this raised KeyError on the missing 'state' column)
        if not per_neuron:
            continue

        _data = pd.concat(per_neuron)
        selected.append(_data)

        # analyze by type
        for x in _data['state'].unique():
            rank_rows.append({
                'n': int((_data['state'] == x).sum()),
                'state': x,
                'cur step': cur_inc,
            })

    # build both results once instead of growing frames row by row
    # (`DataFrame.append` no longer exists in pandas >= 2.0)
    rank = pd.DataFrame(rank_rows, columns=['n', 'state', 'cur step'])
    if selected:
        df_by_step = pd.concat(selected)
    else:
        df_by_step = pd.DataFrame(columns=cols + ['step current norm', 'cur step'])

    return rank, df_by_step

In [40]:
# rank candidate increments for the F-I data, requiring 3 steps per neuron
rank, df_by_step = ranking(data_fi, nstep=3)    
# keep only (state, increment) combinations with more than one selected row
rank = rank[rank['n'] > 1]
# print one table per state, sorted by row count and then by increment
for k, g in rank.groupby('state'):
    g = g.sort_values(by=['n', 'cur step'])
    print (k)
    print (g)
    print ()

6ohda
    n  state  cur step
6   9  6ohda       100
4  21  6ohda        80
2  24  6ohda        60
0  33  6ohda        40

control
    n    state  cur step
5  18  control        80
7  18  control       100
3  36  control        60
1  45  control        40



In [41]:
# keep only the rows selected with a 40 increment (the increment with the
# largest n for both states in the table printed above)
df_by_step = df_by_step[df_by_step['cur step'] == 40] # select steps

In [42]:
# filter the data
cols = neuron_id_info + ['step current'] # columns

# select fi data
# NOTE(review): DataFrame.isin with a DataFrame argument matches on index and
# column labels as well as on values, so a row of data_fi is kept only when
# df_by_step has a row with the same index label and the same `cols` values
# -- confirm this label-aligned behaviour is what is intended
data_fi_selection = data_fi[data_fi[cols].isin(df_by_step[cols]).all(axis=1)]

# get normalized step current
# (copied row by row from df_by_step; this creates the 'step current norm'
# column on data_fi_selection on the first assignment)
for k, r in data_fi_selection.iterrows():
    # get position in the df: boolean mask of df_by_step rows whose `cols`
    # values all equal those of the current row
    pos = (r[cols] == df_by_step[cols]).all(axis=1)
    
    # set current norm (first matching row wins)
    data_fi_selection.loc[k, ['step current norm']] = df_by_step.loc[pos, ['step current norm']].to_numpy()[0]

In [43]:
# sanity check: (rows, columns) of the F-I selection
data_fi_selection.shape

(78, 37)

# 2. Sag amplitude

In [44]:
# keep only the traces recorded with the 'rmih' protocol
data_sag = data.loc[data['protocol'] == 'rmih']

# neuron filtering: keep only neurons that are part of the F-I selection
data_sag = pd.merge(
    data_sag,
    data_fi_selection.loc[:, neuron_id_info].drop_duplicates(),
    on=neuron_id_info,
    how='inner',
)

# keep early recordings only: one trace per neuron, step current and
# stimulus duration
data_sag = get_early_recordings(
    data_sag.groupby(neuron_id_info + ['step current', 'stim dur'])
)

# keep only 0 to negative
data_sag = data_sag.loc[data_sag['step current'] <= 0]

In [45]:
# rank candidate (negative) increments for the sag data: a single step per
# neuron, counting spikes after the stimulus, at least 2 of them
rank, df_by_step = ranking(data_sag, increments=[-50, -100, -150, -200], nstep=1, apthresh=2, hyper=True)    
# keep only (state, increment) combinations with more than one selected row
rank = rank[rank['n'] > 1]
# print one table per state, sorted by row count and then by increment
for k, g in rank.groupby('state'):
    g = g.sort_values(by=['n', 'cur step'])
    print (k)
    print (g)
    print ()

6ohda
   n  state  cur step
5  5  6ohda      -150
2  7  6ohda      -100
1  8  6ohda       -50

control
    n    state  cur step
4   6  control      -150
3  10  control      -100
0  12  control       -50



In [46]:
# keep only the rows selected with a -50 increment (the increment with the
# largest n for both states in the table printed above)
df_by_step = df_by_step[df_by_step['cur step'] == -50] # select steps

In [47]:
# filter the data
cols = neuron_id_info + ['step current'] # columns

# select sag data
# NOTE(review): DataFrame.isin with a DataFrame argument matches on index and
# column labels as well as on values; data_sag went through a merge before
# ranking, so confirm its row labels still line up with df_by_step as intended
data_sag_selection = data_sag[data_sag[cols].isin(df_by_step[cols]).all(axis=1)]

# get normalized step current
# (copied row by row from df_by_step; this creates the 'step current norm'
# column on data_sag_selection on the first assignment)
for k, r in data_sag_selection.iterrows():
    # get position in the df: boolean mask of df_by_step rows whose `cols`
    # values all equal those of the current row
    pos = (r[cols] == df_by_step[cols]).all(axis=1)
    
    # set current norm (first matching row wins)
    data_sag_selection.loc[k, ['step current norm']] = df_by_step.loc[pos, ['step current norm']].to_numpy()[0]

In [48]:
# sanity check: (rows, columns) of the sag selection
data_sag_selection.shape

(20, 37)

In [49]:
# inspect the rows ranking() selected for the -50 increment
df_by_step

Unnamed: 0,date,slice id,state,step current,AP_count_after_stim,step current norm,cur step
80,2018-02-21,Slice1c2,control,-100.0,4,-50.0,-50
94,2018-02-21,Slice1c3,control,-150.0,2,-50.0,-50
101,2018-03-21,Slice1c2,control,-100.0,6,-50.0,-50
32,2018-03-30,Slice1c2,6ohda,-100.0,6,-50.0,-50
43,2018-04-04,Slice1c5,6ohda,-100.0,5,-50.0,-50
144,2018-06-01,Slice1c1,control,-200.0,2,-50.0,-50
175,2018-06-22,Slice1c5,control,-150.0,2,-50.0,-50
181,2018-06-22,Slice2c1,control,-100.0,2,-50.0,-50
183,2018-06-22,Slice3c2,control,-200.0,2,-50.0,-50
188,2018-07-16,Slice2c1,control,-150.0,2,-50.0,-50


## JSON export

In [50]:
## old features
# feature names to export, keyed by protocol name ('fi' and 'rmih'); each
# list is passed as the column selection to get_json_output

efeatures = {
      'fi':[
        'AP_amplitude',
        'AHP_depth',
        'AP_width',
        'AP_count',
        'time_to_first_spike',
        'inv_first_ISI',
        'inv_second_ISI',
        'inv_last_ISI',
        'adaptation_index2',
        'voltage_base',
        'voltage_after_stim',
        'AP_count_before_stim',
        'AP_count_after_stim',
        'max_amp_difference',
        'clustering_index',
        'fast_AHP'
      ],
      'rmih':[
        'voltage_base',
        'AP1_amp_rev',
        'AP2_amp_rev',
        'time_to_first_spike',
        'AP_count_before_stim',
        'AP_count_after_stim',
        'inv_first_ISI',
        'inv_second_ISI',
        'inv_last_ISI',
        'voltage_after_stim',
        'sag_amplitude',
        'voltage_deflection',
      ]
    }

In [51]:
# per-state accumulators for the two JSON exports; each state starts with
# its own empty mapping that later cells fill via dict.update
json_file_feature = {state: {} for state in ('control', '6ohda')}

json_file_protocol = {state: {} for state in ('control', '6ohda')}

In [52]:
def row_to_json(row, thresh=None):
    """
    Convert a table of summary statistics into a feature list.

    Parameters
    ----------
    row : pandas.DataFrame
        Indexed by statistic name and containing at least the rows 'count',
        'mean' and 'std' (e.g. the output of ``DataFrame.describe``); one
        column per feature.
    thresh : number, optional
        If given (anything other than None, including 0), stored under the
        key 'threshold' in every feature entry.

    Returns
    -------
    dict
        ``{"soma.v": [entry, ...]}`` with one entry per column holding
        'feature' (column name), 'val' ([mean, std]), 'n' (count) and
        'weight' (always 1.0).
    """
    entries = []
    for col_name in row.columns:
        entry = {
            "feature": col_name,
            "val": [row.loc['mean', col_name], row.loc['std', col_name]],
            "n": row.loc['count', col_name],
            "weight": 1.0,
        }

        # compare against None explicitly: a plain truthiness test would
        # silently discard a legitimate threshold of 0
        if thresh is not None:
            entry['threshold'] = thresh

        entries.append(entry)
    return {"soma.v": entries}

In [53]:
def get_json_output(data, efeature_names, cur_key, thresh=None, key_fmt='', stim_dur_flag=True, stim_amp_flag=True):
    """
    Summarise features per state and per (current, duration) group.

    `data` is grouped by 'state', then by (`cur_key`, 'stim dur'). For every
    group the count / mean / std of the `efeature_names` columns are passed
    to `row_to_json` together with `thresh`.

    Entry names are built as ``key_fmt + 'Step'``, followed by the integer
    value of `cur_key` when `stim_amp_flag` is set, followed by
    ``'_<stim dur>ms'`` when `stim_dur_flag` is set.

    NOTE(review): grouping always uses both `cur_key` and 'stim dur', even
    when a flag removes that value from the name. Groups that differ only
    in the omitted value map to the same name, and the last one processed
    replaces the earlier ones -- confirm this is intended.

    Returns a dict ``{state: {entry name: row_to_json(...)}}``.
    """
    name_template = (
        key_fmt
        + 'Step'
        + ('%i' if stim_amp_flag else '')
        + ('_%ims' if stim_dur_flag else '')
    )

    per_state = {}
    for state, state_rows in data.groupby('state'):
        entries = {}
        for (amp, dur), rows in state_rows.groupby([cur_key, 'stim dur']):
            stats = rows[efeature_names].describe().loc[['count', 'mean', 'std'], :]

            # only the values whose flag is enabled take part in the name
            name_args = tuple(
                value
                for value, enabled in ((amp, stim_amp_flag), (dur, stim_dur_flag))
                if enabled
            )

            entries[name_template % name_args] = row_to_json(stats, thresh=thresh)

        per_state[state] = entries
    return per_state

### Features

In [54]:
# F-I features, named 'Step<normalized current>_<stim dur>ms'; every entry
# carries threshold -20
output = get_json_output(data_fi_selection, efeatures['fi'], 'step current norm', key_fmt='', thresh=-20)

# merge the per-state results into the feature export
for k in json_file_feature:
    json_file_feature[k].update(output[k])

In [55]:
# sag features, named 'SagStep<normalized current>_<stim dur>ms'; every entry
# carries threshold -35
output = get_json_output(data_sag_selection, efeatures['rmih'], 'step current norm', key_fmt='Sag', thresh=-35)

# merge the per-state results into the feature export
for k in json_file_feature:
    json_file_feature[k].update(output[k])

### Add the time constants

In [56]:
# hard-coded 'Pulse' feature entries per state (not computed in this
# notebook); unlike the computed entries they carry no 'n' field
time_constants = {'6ohda':
                  {"Pulse":{"soma.v":[{"feature":"decay_time_constant_after_stim2","val":[21.04,13.2],"weight":1}]}},
                   'control':
                   {"Pulse":{"soma.v":[{"feature":"decay_time_constant_after_stim2","val":[20.69,10.27],"weight":1}]}}
                 }

# merge them into the feature export
for k in json_file_feature.keys():
    json_file_feature[k].update(time_constants[k])

### Input resistance

In [57]:
# input resistance from the sag traces, named 'Step_<stim dur>ms' (no amplitude)
# NOTE(review): get_json_output still groups by 'step current' here, so if
# several step currents share a stimulus duration they all map to the same
# entry name and only the last group is kept -- confirm this is intended
output = get_json_output(data_sag_selection, [ 'input_resistance' ], 'step current', stim_amp_flag=False)
for k in json_file_feature:
    json_file_feature[k].update(output[k])

### AP peaks/amplitudes

In [26]:
# post-process every exported feature entry of every state
for output in json_file_feature.values():
    for entry_name, entry in output.items():
        for entry1 in entry['soma.v']:
            if entry1['feature'].startswith(
                ('AP_amplitude', 'AP1_peak', 'AP2_peak', 'AP1_amp', 'AP2_amp')
            ):
                # spike amplitude / peak features get half weight
                entry1['weight'] = 0.5
            elif entry1['feature'] == 'AP_count_before_stim' or (
                entry_name.startswith('Step')
                and entry_name.endswith('ms')
                and entry1['feature'] == 'AP_count_after_stim'
            ):
                # overwrite the std with 0.01: always for the spike count
                # before the stimulus, and for the count after the stimulus
                # only in entries named 'Step...ms' (names with a prefix
                # before 'Step' do not match)
                entry1['val'][1] = 0.01


### AHP depth

In [27]:
# halve the weight of every AHP depth feature, in every entry of every state
for output in json_file_feature.values():
    for entry_name, entry in output.items():
        for entry1 in entry['soma.v']:
            if not entry1['feature'].startswith('AHP_depth'):
                continue
            entry1['weight'] = 0.5

### Protocols

In [58]:
# fixed protocol definitions that are copied into every state's protocol
# export in addition to the ones derived from the data; the names match the
# 'Pulse' and 'Step_2000ms' entries of the feature export
protocol_template = {
    "Pulse": {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": 2000, "amp":-1.0, "duration": 0.5, "totduration": 5000}
        }
    },
    "Step_2000ms": {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": 2000, "amp":-0.01, "duration": 100, "totduration": 5000}
        }
    }
}



In [59]:
def row_to_json_prot(cur_key, stim_amp, stim_dur, stim_hold, prefix='', delay=2000, tpad=1000.0):
    """
    Build a single-entry protocol dictionary for one step stimulus.

    The entry is named ``prefix + 'Step<cur_key>_<stim_dur>ms'`` (both
    numbers formatted as integers). It describes a 'StepProtocol' with a
    'step' stimulus starting at `delay` and lasting `stim_dur`, plus a
    'holding' stimulus that spans the whole trace. The total duration is
    ``stim_dur + delay + tpad``. Both amplitudes are divided by 1000 and
    rounded to three decimals.
    """
    total_duration = stim_dur + delay + tpad
    entry_name = (prefix + "Step%i_%ims") % (cur_key, stim_dur)

    step_stimulus = {
        "delay": delay,
        "amp": round(stim_amp / 1000.0, 3),
        "duration": stim_dur,
        "totduration": total_duration,
    }
    holding_stimulus = {
        "amp": round(stim_hold / 1000.0, 3),
        "delay": 0,
        "duration": total_duration,
        "totduration": total_duration,
    }

    return {
        entry_name: {
            "type": "StepProtocol",
            "stimuli": {"step": step_stimulus, "holding": holding_stimulus},
        }
    }

In [60]:
def get_json_protocol_output(data, efeature_names, cur_key, prefix='', stim_dur_flag=True, stim_amp_flag=True):        
    """
    Build the protocol definitions per state.

    `data` is grouped by 'state', then by (`cur_key`, 'stim dur'). Each
    group becomes one protocol entry via `row_to_json_prot`, using the
    group's mean 'step current' as step amplitude and the state's mean
    'bias current' as holding amplitude.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'state', 'bias current', 'step current', 'stim dur'
        and the `cur_key` column.
    efeature_names, stim_dur_flag, stim_amp_flag
        Not used by this function; kept so existing calls keep working.
    cur_key : str
        Column whose value is used in the protocol name.
    prefix : str
        Prepended to every protocol name.

    Returns
    -------
    dict
        ``{state: {protocol name: protocol definition}}`` keyed by the plain
        state value.
    """
    json_file_output = {}
    # group by the column name, not by a one-element list: with a list,
    # pandas >= 2 yields 1-tuples such as ('control',) as keys, and callers
    # index the result with the plain state string
    for state, state_rows in data.groupby('state'):
        # one holding current per state
        hold = state_rows['bias current'].mean()

        protocols = {}
        for kk, gg in state_rows.groupby([cur_key, 'stim dur']):
            protocols.update(
                row_to_json_prot(kk[0], gg['step current'].mean(), kk[1], hold, prefix=prefix)
            )
        json_file_output[state] = protocols
    return json_file_output

In [61]:
# protocols for the F-I steps, named 'Step<normalized current>_<stim dur>ms'
output = get_json_protocol_output(data_fi_selection, efeatures['fi'], 'step current norm')

for k in json_file_protocol:
    json_file_protocol[k].update(output[k])
    
    
# protocols for the sag steps, named 'SagStep<normalized current>_<stim dur>ms'
output = get_json_protocol_output(data_sag_selection, efeatures['rmih'], 'step current norm', prefix='Sag')

for k in json_file_protocol:
    json_file_protocol[k].update(output[k])
    
    
# add the fixed protocols; deep-copied so the states do not share (and
# cannot mutate) the template's nested dictionaries
for k in json_file_protocol:
    json_file_protocol[k].update(copy.deepcopy(protocol_template))

In [62]:
def _compact_json(obj):
    """
    Serialise `obj` with 4-space indentation, then pull deeply nested
    content onto one line.

    Line breaks followed by 20 or 16 spaces are replaced by a single space,
    and closing braces indented by 20, 16 or 12 spaces are joined to the
    previous line. Shared by both exports below so the two files are
    guaranteed to be formatted identically.
    """
    s = json.dumps(obj, indent=4)
    # deepest levels first, so longer runs of spaces are consumed before
    # shorter ones
    for i in range(20, 12, -4):
        s = s.replace('\n' + (' ' * i), ' ')
    for i in range(20, 8, -4):
        s = s.replace('\n' + (' ' * i) + '}', ' }')
    return s


# one '<state>_features.json' file per state
for k in json_file_feature:
    s = _compact_json(json_file_feature[k])
    with open("%s_features.json"%k, "w") as fo:
        fo.write(s)


# one '<state>_protocols.json' file per state (the raw dict is also printed)
for k in json_file_protocol:
    print (json_file_protocol[k])
    s = _compact_json(json_file_protocol[k])
    with open("%s_protocols.json"%k, "w") as fo:
        fo.write(s)


{'Step0_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': 0.152, 'duration': 2000.0, 'totduration': 5000.0}, 'holding': {'amp': 0.065, 'delay': 0, 'duration': 5000.0, 'totduration': 5000.0}}}, 'Step40_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': 0.192, 'duration': 2000.0, 'totduration': 5000.0}, 'holding': {'amp': 0.065, 'delay': 0, 'duration': 5000.0, 'totduration': 5000.0}}}, 'Step80_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': 0.232, 'duration': 2000.0, 'totduration': 5000.0}, 'holding': {'amp': 0.065, 'delay': 0, 'duration': 5000.0, 'totduration': 5000.0}}}, 'SagStep-50_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': -0.133, 'duration': 2000.0, 'totduration': 5000.0}, 'holding': {'amp': 0.003, 'delay': 0, 'duration': 5000.0, 'totduration': 5000.0}}}, 'Pulse': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 2000, 'amp': -1.0, 'duration': 0.5, 'totduration': 