In [91]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import warnings
import copy
import json

warnings.simplefilter("ignore")
pd.set_option('display.max_rows', 300)

In [92]:
# neuron information for its identification
neuron_id_info = ['date', 'slice id', 'state']

In [93]:
def get_early_recordings(grp_data):
    '''
        For each group, keep only the earliest recording
    '''
    ret_data = pd.DataFrame()
    for _, g in grp_data:
        g.sort_values('time stamp', inplace=True)
        ret_data = ret_data.append( g.iloc[0, :] )
    return ret_data

## 1. Data engineering

In [94]:
# load dataset with measures of firing properties
data = pd.read_csv("preprocessed_data_efeatures.csv")
data = data.merge(pd.read_csv('data_selection.csv')[neuron_id_info].drop_duplicates(), on=neuron_id_info, how='inner')

## Current amplitude round off

In [95]:
def round_current(row):
    inc = {'fi':20.0, 'rmih':50.0, 'tburst':50.0}[row['protocol']]
    return round(row['step current']/inc)*inc


# round off for current amplitudes
#data['step current'] = data.apply(round_current, axis=1)
#data['bias current'] = data['bias current'].apply( lambda x : round(x/10.0)*10.0 )
#data['total current'] = data['total current'].apply( lambda x : round(x/20.0)*20.0 )

data['total current'] = data['step current'] + data['bias current']

# 1. F-I curve

In [96]:
# retain only traces selected
data_fi = data.merge(pd.read_csv('data_selection.csv')[['filename', 'file key']], on=['filename', 'file key'], how='inner')

# keep early recordings only
data_fi = get_early_recordings(data_fi.groupby(neuron_id_info + ['step current']))

In [97]:
def ranking(data, increments=np.arange(0, 300, 5), nstep=3, apthresh=4, hyper=False, perc_round=5):
    
    # if we deliver hyperpolarizing current check rebound spikes
    ap_count_attr = 'AP_count_after_stim' if hyper else 'AP_count'
    
    data = data[neuron_id_info + ['step current', ap_count_attr]]
    
    # remove traces with less than 4 spikes
    data.drop(data[data[ap_count_attr] < apthresh].index, inplace=True)
    
    # remake the dataframe
    _data = pd.DataFrame()    

    # get the fi curve for each neuron
    for k, g in data.groupby(neuron_id_info):
        if hyper:
            cur0 = g['step current'].max() # threshold current
        else:
            cur0 = g['step current'].min() # threshold current

        # get normalized step current
        g['step current norm'] = g['step current'].apply(lambda x: round(100 * x / cur0 / perc_round) * perc_round)

        if g['step current norm'].nunique() < g.shape[0]:
            print ("Warning! Colliding current steps.")

        _data = pd.concat([_data, g])

    data = _data
    del _data
    

    rank = pd.DataFrame()
    df_by_step = pd.DataFrame()

    # base increments
    if nstep > 1:
        base_increments = 100 + increments
    else:
        base_increments = [100]
    
    for cur0 in base_increments:
        for cur_inc in increments:
            
                # triplettes of traces
                _data = pd.DataFrame()
                
                if nstep > 1:
                    for i in range(nstep):
                        _data = pd.concat([_data, data[data['step current norm'] == (cur0 + i * cur_inc)]])
                else:
                    _data = pd.concat([_data, data[data['step current norm'] == (cur0 + cur_inc)]])
                    
                # remove neurons without all the triplettes
                for k, g in _data.groupby(neuron_id_info):
                    if g['step current norm'].nunique() < nstep:
                        _data.drop(_data[(_data[neuron_id_info] == k).all(axis=1)].index, inplace=True)
                        
                # add cur step and base
                _data['cur step'] = cur_inc
                _data['cur base'] = cur0
                    
                if _data.shape[0] > 0:
                    df_by_step = pd.concat([df_by_step, _data])                        

                    # analyze by type
                    for x in _data['state'].unique():
                        _rows_sel = _data[_data['state'] == x]
                        sz = _rows_sel.shape[0] 

                        row = {}
                        row['n'] = sz
                        row['state'] = x

                        for i in range(nstep):
                            row['cur' + str(i)] = cur0 + i * cur_inc

                        row['cur step'] = cur_inc
                        rank = rank.append(row, ignore_index=True)
                    

    return rank, df_by_step

In [98]:
rank, df_by_step = ranking(data_fi, nstep=3)    
# remove current steps with fewer trials than 1 or no APs
rank = rank[rank['n'] > 1]
for k, g in rank.groupby('state'):
    g = g.sort_values(by=['n', 'cur step'])
    print (k)
    print (g)
    print ()

6ohda
      n  state  cur0  cur1  cur2  cur step
2     3  6ohda   100   110   120        10
33    3  6ohda   110   120   130        10
46    3  6ohda   120   130   140        10
71    3  6ohda   130   140   150        10
82    3  6ohda   140   150   160        10
75    3  6ohda   135   150   165        15
141   3  6ohda   185   200   215        15
36    3  6ohda   110   130   150        20
42    3  6ohda   115   135   155        20
111   3  6ohda   160   180   200        20
137   3  6ohda   180   200   220        20
144   3  6ohda   200   220   240        20
163   3  6ohda   220   240   260        20
222   3  6ohda   350   375   400        25
232   3  6ohda   375   400   425        25
9     3  6ohda   100   130   160        30
43    3  6ohda   115   150   185        35
50    3  6ohda   120   155   190        35
194   3  6ohda   265   300   335        35
52    3  6ohda   120   160   200        40
87    3  6ohda   140   180   220        40
113   3  6ohda   160   200   240        40
139  

In [99]:
df_by_step = df_by_step[(df_by_step['cur base'] == 120) &  (df_by_step['cur step'] == 20) & (df_by_step['state'] == 'control') | (df_by_step['cur base'] == 150) & (df_by_step['cur step'] == 100) & (df_by_step['state'] == '6ohda')] # select steps

In [100]:
# filter the data
cols = neuron_id_info + ['step current'] # columns

# select fi data
data_fi_selection = data_fi[data_fi[cols].isin(df_by_step[cols]).all(axis=1)]

# get normalized step current
for k, r in data_fi_selection.iterrows():
    # get position in the df
    pos = (r[cols] == df_by_step[cols]).all(axis=1)
    
    # set current norm
    data_fi_selection.loc[k, ['step current norm']] = df_by_step.loc[pos, ['step current norm']].to_numpy()[0]

In [101]:
data_fi_selection.shape

(30, 37)

# 2. Sag amplitude

In [102]:
# keep only the trace selection
data_sag = data[data['protocol'] == 'rmih']

# get round step current
data_sag['step current round'] = data_sag.apply(round_current, axis=1)

# neuron filtering
data_sag = data_sag.merge(data_fi_selection[neuron_id_info].drop_duplicates(), on=neuron_id_info, how='inner')

# keep early recordings only
data_sag = get_early_recordings(data_sag.groupby(neuron_id_info + ['step current round', 'stim dur']))

# keep only 0 to negative
data_sag = data_sag[data_sag['step current'] <= 0]

In [103]:
rank, df_by_step = ranking(data_sag, nstep=1, apthresh=2, hyper=True, perc_round=5)    
# remove current steps with fewer trials than 1 or no APs
rank = rank[rank['n'] > 1]
for k, g in rank.groupby('state'):
    g = g.sort_values(by=['n', 'cur step'])
    print (k)
    print (g)
    print ()

6ohda
   n  state  cur0  cur step
9  2  6ohda   100       295
0  3  6ohda   100         0
4  3  6ohda   100       100
7  3  6ohda   100       200

control
    n    state  cur0  cur step
10  2  control   100       295
3   4  control   100        50
5   6  control   100       100
1   7  control   100         0



In [104]:
df_by_step = df_by_step[df_by_step['cur step'] == 0] # select steps

In [105]:
df_by_step.shape

(10, 8)

In [106]:
# filter the data
cols = neuron_id_info + ['step current'] # columns

# select fi data
data_sag_selection = data_sag[data_sag[cols].isin(df_by_step[cols]).all(axis=1)]

# get normalized step current
for k, r in data_sag_selection.iterrows():
    # get position in the df
    pos = (r[cols] == df_by_step[cols]).all(axis=1)
    
    # set current norm
    data_sag_selection.loc[k, ['step current norm']] = df_by_step.loc[pos, ['step current norm']].to_numpy()[0]

In [107]:
data_sag_selection

Unnamed: 0.1,Unnamed: 0,AP_amplitude,AHP_depth,AP_duration_half_width,AP_width,AP_count,time_to_first_spike,inv_first_ISI,inv_second_ISI,inv_last_ISI,...,sag_amplitude,voltage_deflection,input_resistance,AP1_amp,AP2_amp,state,slice id,date,step current round,step current norm
3,328,,,,,,2029.5,227.272727,181.818182,111.111111,...,10.643735,-13.403638,280.001297,,,6ohda,Slice1c2,2018-01-05,-100.0,100.0
43,1090,,,,,,2032.1,243.902439,0.0,243.902439,...,10.833482,-9.664112,246.237115,,,control,Slice1c3,2018-02-21,-100.0,100.0
23,420,,,,,,2053.3,243.902439,196.078431,86.956522,...,3.198013,-6.288712,192.19012,,,6ohda,Slice1c2,2018-03-30,-50.0,100.0
34,464,,,,,,2033.7,204.081633,149.253731,114.942529,...,7.048215,-9.472362,291.52769,,,6ohda,Slice1c5,2018-04-04,-50.0,100.0
59,1445,,,,,,2042.7,153.846154,0.0,153.846154,...,1.974889,-18.353629,124.006942,,,control,Slice1c5,2018-06-22,-100.0,100.0
65,1504,,,,,,2046.7,158.730159,111.111111,111.111111,...,4.324457,-8.95598,234.664776,,,control,Slice2c1,2018-06-22,-50.0,100.0
67,1577,,,,,,2111.0,133.333333,0.0,133.333333,...,3.2699,-29.944758,284.388807,,,control,Slice3c2,2018-06-22,-150.0,100.0
72,1646,,,,,,2028.4,181.818182,0.0,181.818182,...,6.40865,-9.122168,142.967741,,,control,Slice2c1,2018-07-16,-100.0,100.0
80,1824,,,,,,2198.7,172.413793,111.111111,111.111111,...,1.57433,-18.612964,247.969661,,,control,Slice2c1,2019-04-09,-100.0,100.0
85,2231,,,,,,2077.2,163.934426,107.526882,107.526882,...,3.054428,-13.654931,295.284562,,,control,Slice1c1,2019-05-09,-50.0,100.0


## JSON export

In [108]:
## old featurs

efeatures = {
      'fi':[
        'AP_amplitude',
        'AHP_depth',
        'AP_width',
        'AP_count',
        'time_to_first_spike',
        'inv_first_ISI',
        'inv_second_ISI',
        'inv_last_ISI',
        'adaptation_index2',
        'voltage_base',
        'voltage_after_stim',
        'AP_count_before_stim',
        'AP_count_after_stim',
        'max_amp_difference',
        'clustering_index',
        'fast_AHP'
      ],
      'rmih':[
        'voltage_base',
        'AP1_amp_rev',
        'AP2_amp_rev',
        'time_to_first_spike',
        'AP_count_before_stim',
        'AP_count_after_stim',
        'inv_first_ISI',
        'inv_second_ISI',
        'inv_last_ISI',
        'voltage_after_stim',
        'sag_amplitude',
        'voltage_deflection',
      ]
    }

In [109]:
json_file_feature = {
    'control':{},
    '6ohda':{}
}

json_file_protocol = {
    'control':{},
    '6ohda':{}
}

In [110]:
def row_to_json(row, thresh=None):
    ret = [ ]
    for col_name in row.columns:
        ret.append(
            {
                "feature":col_name, 
                "val":[row.loc['mean', col_name], row.loc['std', col_name]],
                "n":row.loc['count', col_name],
                "weight":1.0
            }
        )
        
        if thresh:
            ret[-1].update({'threshold':thresh})
    return { "soma.v":ret }

In [111]:
def get_json_output(data, efeature_names, cur_key, thresh=None, key_fmt='', stim_dur_flag=True, stim_amp_flag=True):
    _key_fmt = key_fmt + 'Step'
    if stim_amp_flag:
        _key_fmt += '%i'
    if stim_dur_flag:
        _key_fmt += '_%ims'
        
    json_file_output = {}
    for k, g in data.groupby('state'):
        json_file_output[k] = {}
        
        output = {}
        for kk, gg in g.groupby([cur_key, 'stim dur']):
            gg_descr = gg[efeature_names].describe().loc[['count', 'mean', 'std'], :] 

            _kk = ()
            
            if stim_amp_flag:
                _kk += (kk[0],)
                
            if stim_dur_flag:
                _kk += (kk[1],)

            output.update( {(_key_fmt % _kk):row_to_json(gg_descr, thresh=thresh)} )

        json_file_output[k].update(output)
    return json_file_output

### Features

In [112]:
output = get_json_output(data_fi_selection, efeatures['fi'], 'step current norm', key_fmt='', thresh=-20)

for k in json_file_feature:
    json_file_feature[k].update(output[k])

In [113]:
output = get_json_output(data_sag_selection, efeatures['rmih'], 'step current norm', key_fmt='Sag', thresh=-35)

for k in json_file_feature:
    json_file_feature[k].update(output[k])

### Add the time constants

In [114]:
time_constants = {'6ohda':
                  {"Pulse":{"soma.v":[{"feature":"decay_time_constant_after_stim2","val":[21.04,13.2],"weight":1}]}},
                   'control':
                   {"Pulse":{"soma.v":[{"feature":"decay_time_constant_after_stim2","val":[20.69,10.27],"weight":1}]}}
                 }

for k in json_file_feature.keys():
    json_file_feature[k].update(time_constants[k])

### Input resistance

In [115]:
output = get_json_output(data_sag_selection, [ 'input_resistance' ], 'step current norm', stim_amp_flag=False)
for k in json_file_feature:
    json_file_feature[k].update(output[k])

### AP peaks/amplitudes

In [116]:
for output in json_file_feature.values():
    for entry_name, entry in output.items():
        for entry1 in entry['soma.v']:
            if (entry1['feature'].startswith('AP_amplitude') or \
                entry1['feature'].startswith('AP1_peak') or \
                entry1['feature'].startswith('AP2_peak') or \
                entry1['feature'].startswith('AP1_amp') or \
                entry1['feature'].startswith('AP2_amp')):
                entry1['weight'] = 0.75
            elif entry1['feature'] == 'AP_count_before_stim' or entry_name.startswith('Step') and entry_name.endswith('ms') and entry1['feature'] == 'AP_count_after_stim':
                entry1['val'][1] = 0.01


### AHP depth

In [117]:
for output in json_file_feature.values():
    for entry_name, entry in output.items():
        for entry1 in entry['soma.v']:
            if entry1['feature'].startswith('AHP_depth'):
                entry1['weight'] = 2
            elif entry1['feature'].startswith('fast_AHP'):
                entry1['weight'] = 2

### Protocols

In [118]:
protocol_template = {
    "Pulse": {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": 5000, "amp":-1.0, "duration": 0.5, "totduration": 8000}
        }
    },
    "Step_2000ms": {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": 5000, "amp":-0.01, "duration": 100, "totduration": 8000}
        }
    }
}



In [119]:
def row_to_json_prot(cur_key, stim_amp, stim_dur, stim_hold, prefix='', delay=5000, tpad=1000.0):
    tdur = stim_dur + delay + tpad
    return {(prefix+"Step%i_%ims")%(cur_key, stim_dur): {
        "type": "StepProtocol", 
        "stimuli": {
            "step": {"delay": delay, "amp": round(stim_amp/1000.0, 3), "duration": stim_dur, "totduration": tdur},
            "holding": {"amp": round(stim_hold/1000.0,3), "delay": 0, "duration": tdur, "totduration": tdur}
        }}}

In [120]:
def get_json_protocol_output(data, efeature_names, cur_key, prefix='', stim_dur_flag=True, stim_amp_flag=True):        
    json_file_output = {}
    for k, g in data.groupby(['state']):
        json_file_output[k] = {}
        
        hold = g['bias current'].mean()
        
        output = {}
        for kk, gg in g.groupby([cur_key, 'stim dur']):
            output.update(row_to_json_prot(kk[0], gg['step current'].mean(), kk[1], hold, prefix=prefix))
        json_file_output[k].update(output)
    return json_file_output

In [121]:
output = get_json_protocol_output(data_fi_selection, efeatures['fi'], 'step current norm')

for k in json_file_protocol:
    json_file_protocol[k].update(output[k])
    
    
output = get_json_protocol_output(data_sag_selection, efeatures['rmih'], 'step current norm', prefix='Sag')

for k in json_file_protocol:
    json_file_protocol[k].update(output[k])
    
    
for k in json_file_protocol:
    json_file_protocol[k].update(copy.deepcopy(protocol_template))

In [122]:
for k in json_file_feature:
    s = json.dumps(json_file_feature[k], indent=4)
    for i in range(20, 12, -4):
        s = s.replace('\n' + (' ' * i), ' ')
    for i in range(20, 8, -4):
        s = s.replace('\n' + (' ' * i) + '}', ' }')
    with open("%s_features.json"%k, "w") as fo:
        fo.write(s)
        
        
for k in json_file_protocol:
    print (json_file_protocol[k])
    s = json.dumps(json_file_protocol[k], indent=4)
    for i in range(20, 12, -4):
        s = s.replace('\n' + (' ' * i), ' ')
    for i in range(20, 8, -4):
        s = s.replace('\n' + (' ' * i) + '}', ' }')
    with open("%s_protocols.json"%k, "w") as fo:
        fo.write(s)


{'Step120_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 5000, 'amp': 0.191, 'duration': 2000.0, 'totduration': 8000.0}, 'holding': {'amp': 0.074, 'delay': 0, 'duration': 8000.0, 'totduration': 8000.0}}}, 'Step140_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 5000, 'amp': 0.223, 'duration': 2000.0, 'totduration': 8000.0}, 'holding': {'amp': 0.074, 'delay': 0, 'duration': 8000.0, 'totduration': 8000.0}}}, 'Step160_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 5000, 'amp': 0.255, 'duration': 2000.0, 'totduration': 8000.0}, 'holding': {'amp': 0.074, 'delay': 0, 'duration': 8000.0, 'totduration': 8000.0}}}, 'SagStep100_2000ms': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 5000, 'amp': -0.095, 'duration': 2000.0, 'totduration': 8000.0}, 'holding': {'amp': 0.005, 'delay': 0, 'duration': 8000.0, 'totduration': 8000.0}}}, 'Pulse': {'type': 'StepProtocol', 'stimuli': {'step': {'delay': 5000, 'amp': -1.0, 'duration': 0.5, 'totduratio