In [82]:
import numpy as np
import pandas as pd


def raster(event_name,cellid,event_times,rasterfs,PreEndTime=200,PostBegTime=500,split=False,spikes=None):
    """Bin one cell's spikes around every occurrence of an event and average them (PSTH).

    Parameters
    ----------
    event_name : str
        Value of ``event_times['name']`` whose occurrences are rastered.
    cellid : hashable
        Key of the cell in the spike mapping.
    event_times : pandas.DataFrame
        Event table; needs at least the ``'name'`` and ``'start'`` columns.
    rasterfs : float
        Bins per unit of event time (presumably seconds); each bin lasts ``1/rasterfs``.
    PreEndTime, PostBegTime : int
        First and last bin offsets, counted in bins from each event's ``start``.
        The mean rate therefore spans ``PostBegTime - PreEndTime`` bins.
    split : bool
        If True, the first ``n // 2`` repetitions and the remaining ones are averaged
        separately.
    spikes : dict, optional
        Mapping ``cellid -> spike times``. Defaults to the notebook-global ``spike_dict``.

    Returns
    -------
    h : numpy.ndarray
        Raster (repetitions x bins) of the last accumulated group (the second half
        when ``split``). Shorter repetitions are NaN-padded, longer ones truncated.
    m : numpy.ndarray
        Mean binned count, shape ``(PostBegTime-PreEndTime,)``, or
        ``(2, PostBegTime-PreEndTime)`` when ``split``. A half that received no
        repetition (e.g. split with a single repetition, or no matching events) is NaN.
    """
    # Fall back to the notebook-global spike_dict when no spike mapping is passed.
    spike_times = (spike_dict if spikes is None else spikes)[cellid]
    binlen = 1.0/rasterfs
    nbins = PostBegTime - PreEndTime
    events = event_times.loc[event_times['name'] == event_name]
    # Repetitions 0..halfNb-1 form the first half when split is requested.
    halfNb = events.shape[0] // 2

    # NaN instead of np.empty: a half that receives no repetition must read as
    # missing data, not as uninitialised memory.
    m = np.full([2, nbins], np.nan)

    h = np.array([])
    for idx, (_, d) in enumerate(events.iterrows()):
        edges = np.arange(d['start'] + PreEndTime/rasterfs,
                          d['start'] + PostBegTime/rasterfs + binlen, binlen)
        th, _ = np.histogram(spike_times, edges)
        th = np.reshape(th, [1, -1])
        if h.size == 0:
            # First repetition of the current group fixes the number of bins.
            h = th
        elif th.shape[1] < h.shape[1]:
            # Shorter repetition: pad missing bins with NaN so nanmean ignores them.
            row = np.full([1, h.shape[1]], np.nan)
            row[0, :th.shape[1]] = th
            h = np.concatenate((h, row), axis=0)
        else:
            # Same length or longer: drop any extra bins.
            h = np.concatenate((h, th[:, :h.shape[1]]), axis=0)
        if split and idx == halfNb - 1:
            m[0, :] = np.nanmean(h, axis=0)[0:nbins]
            h = np.array([])

    if h.size == 0:
        # Nothing accumulated in the last group: keep the NaN placeholder.
        return h, (m if split else m[0])
    if split:
        m[1, :] = np.nanmean(h, axis=0)[0:nbins]
    else:
        m = np.nanmean(h, axis=0)[0:nbins]
    return h, m

def getTrainTestTimes(event_times,trainNb,testNb):
    """Split stimulus presentations into train and test sets by repetition count.

    A name containing ``.wav`` that occurs exactly ``trainNb`` times in
    ``event_times`` is a training stimulus; one occurring exactly ``testNb`` times
    is a test stimulus. For every ``TRIAL`` row, the row right after it gives the
    stimulus name. The ``end`` of the row 3 positions later is taken as the sound
    onset. The ``start`` of the row 4 positions later is taken as the sound offset.

    Parameters
    ----------
    event_times : pandas.DataFrame
        Event table with ``'name'``, ``'start'`` and ``'end'`` columns. Any index
        is accepted; row offsets are positional.
    trainNb : int
        Number of presentations identifying a training stimulus.
    testNb : int
        Number of presentations identifying a test stimulus.

    Returns
    -------
    (Train_times, Test_times) : tuple of pandas.DataFrame
        Columns ``['name', 'start', 'end']``, one row per presentation, in trial order.

    Raises
    ------
    ValueError
        If no stimulus matches ``trainNb`` or ``testNb``, or if a trial's stimulus
        is neither a train nor a test stimulus.
    """
    # Literal substring match: as a regex, '.' would match any character before 'wav'.
    is_wav = event_times['name'].str.contains('.wav', regex=False, na=False)
    occurences = event_times.loc[is_wav, 'name'].value_counts(sort=True)

    Train_names = set(occurences[occurences == trainNb].index)
    Test_names = set(occurences[occurences == testNb].index)
    if not Train_names or not Test_names:
        raise ValueError('wrong trainNb or testNb')

    names = event_times['name']
    starts = event_times['start']
    ends = event_times['end']

    train_rows, test_rows = [], []
    # Positions (not index labels) of TRIAL rows, so the .iloc offsets below are
    # valid whatever the DataFrame's index is.
    for pos in np.flatnonzero((names == 'TRIAL').to_numpy()):
        name = names.iloc[pos+1]
        record = {'name': name, 'start': ends.iloc[pos+3], 'end': starts.iloc[pos+4]}
        if name in Train_names:
            train_rows.append(record)
        elif name in Test_names:
            test_rows.append(record)
        else:
            raise ValueError('Neither a Test nor a Train stimuli name')

    columns = ['name', 'start', 'end']
    Train_times = pd.DataFrame(train_rows, columns=columns)
    Test_times = pd.DataFrame(test_rows, columns=columns)
    return Train_times,Test_times

In [268]:
def _stim_block(names, stim_dict, PreEndTime, PostBegTime):
    """Stack the sound-period spectrogram of each stimulus as (n_stim, time, channel), scaled by its max."""
    # Channel count is taken from the first stimulus in stim_dict.
    nchan = np.shape(stim_dict[next(iter(stim_dict))])[0]
    block = np.zeros((len(names), PostBegTime - PreEndTime, nchan))
    for idx, name in enumerate(names):
        block[idx, :, :] = np.transpose(stim_dict[name][:, PreEndTime:PostBegTime])
    return block/block.max()


def getInsOuts(cellidx,event_times,fs_spectro,stim_dict,boolFigure=False):
    """Build network inputs/targets for one cell from the loaded recording.

    Training stimuli are those presented 3 times and test stimuli those presented
    24 times (see ``getTrainTestTimes``). Only the sound period is kept, i.e. the
    bins between the end of the first ``PreStimSilence`` event and the start of
    the event after it. Stimuli are processed in sorted-name order so the rows are
    identical on every run.

    Parameters
    ----------
    cellidx : int
        Position of the cell in the notebook-global ``spike_dict``.
    event_times : pandas.DataFrame
        Event table from the recording loader.
    fs_spectro : float
        Sampling rate of the spectrograms. It is also used as the raster rate,
        because the bin offsets passed to ``raster`` are computed at this rate.
    stim_dict : dict
        Stimulus name -> spectrogram array (channel x time).
    boolFigure : bool
        Plot the two half-split PSTHs of each test sound.

    Returns
    -------
    X : (n_train, sound_time, n_channels) training spectrograms, scaled by their max.
    Y : (n_train, sound_time, 1) training PSTHs, scaled by their max.
    W : (n_test, sound_time, n_channels) test spectrograms, scaled by their max.
    Z : (n_test, sound_time, 2) test PSTHs from the two halves of the repetitions.
    t_12 : (n_test,) correlation between the two halves of each test PSTH.
    """
    print('Fetching ins & outs...\n')
    # fix random seed for reproducibility
    np.random.seed(7)

    # Segregate Train and Tests
    Train_times,Test_times = getTrainTestTimes(event_times,3,24)

    # Compute on a specific cell --> TODO : for all
    cellid = list(spike_dict.keys())[cellidx]

    # Sound period, in bins at fs_spectro: end of PreStimSilence -> start of next event.
    PreStimidx = list(event_times['name']).index('PreStimSilence')
    Endidx = event_times.columns.get_loc('end')
    # NOTE(review): Endidx-1 assumes the 'start' column sits right before 'end';
    # int() truncates, so a time like 2.07 s * 100 may land one bin early.
    PostBegTime = int(event_times.iloc[PreStimidx+1,Endidx-1]*fs_spectro)
    PreEndTime = int(event_times.iloc[PreStimidx,Endidx]*fs_spectro)
    sound_time = PostBegTime - PreEndTime

    # Sorted, so row order does not depend on the interpreter's hash seed.
    train_names = sorted(set(Train_times['name']))
    test_names = sorted(set(Test_times['name']))

    #_________________Training input (X) / Test input (W)_______________________#
    X = _stim_block(train_names, stim_dict, PreEndTime, PostBegTime)
    W = _stim_block(test_names, stim_dict, PreEndTime, PostBegTime)

    #_________________Training output (Y)_______________________#
    Y = np.zeros((len(train_names), sound_time, 1))
    for idx, event_name in enumerate(train_names):
        h, m = raster(event_name, cellid, event_times, fs_spectro, PreEndTime, PostBegTime)
        Y[idx, :, 0] = m[0:sound_time]
    Y = Y/Y.max()

    #_________________Test output (Z)_______________________#
    Z = np.zeros((len(test_names), sound_time, 2))
    # One reliability value per test sound (was hard-coded to 3).
    t_12 = np.zeros(len(test_names))
    for idx, event_name in enumerate(test_names):
        h, m = raster(event_name, cellid, event_times, fs_spectro, PreEndTime, PostBegTime, split=True)
        Z[idx, :, :] = np.swapaxes(m, 0, 1)
        if boolFigure:
            plt.figure()
            plt.plot(Z[idx,:,0])
            plt.plot(Z[idx,:,1])
            plt.title('Split spike rates from sound{}'.format(idx))
        t_12[idx] = np.corrcoef(Z[idx,:,0],Z[idx,:,1])[0,1]
        print('Correlation between sound{}\'s split spike rates: {}'.format(idx,t_12[idx]))

    return X,Y,W,Z,t_12

In [301]:
def mk_and_fit_conv1Dmodel(X,Y,epochs,batch_size,time_window,kernel_init,lr,activation,loss,optim,validation_split,early_stop=False):
    """Build a single-filter causal Conv1D Keras model and fit it to (X, Y).

    Parameters
    ----------
    X : array, (n_samples, time, channels)
        Inputs; ``X.shape[1:3]`` is the layer's input shape.
    Y : array
        Targets passed straight to ``model.fit``.
    epochs, batch_size : int
        Passed to ``model.fit``.
    time_window : int
        Temporal length of the convolution kernel (``kernel_size``).
    kernel_init : (float, float)
        ``(minval, maxval)`` of the uniform kernel initializer.
    lr : float
        Learning rate given to the optimizer.
    activation : str
        Activation of the Conv1D layer (previously ignored; 'relu' was hard-coded).
    loss : str
        Keras loss name.
    optim : str
        Name of an optimizer class in ``keras.optimizers`` (e.g. 'SGD').
    validation_split : float
        Fraction of samples held out for validation by ``model.fit``.
    early_stop : bool
        If True, stop training after 20 epochs without improvement of the
        monitored metric (previously ignored).

    Returns
    -------
    (model, history) : the fitted Sequential model and its ``History`` object.
    """
    import time

    import keras.callbacks
    import keras.initializers
    import keras.optimizers
    from keras.layers import Conv1D
    from keras.models import Sequential

    model = Sequential()
    model.add(Conv1D(input_shape=np.shape(X)[1:3], filters=1, kernel_size=time_window, strides=1,
                     padding='causal', activation=activation, dilation_rate=1, use_bias=True,
                     kernel_initializer=keras.initializers.RandomUniform(
                         minval=kernel_init[0], maxval=kernel_init[1], seed=None),
                     bias_initializer='random_uniform'))

    # NOTE: newer Keras names this argument `learning_rate`; `lr` matches the Keras 2 API used here.
    optimizer = getattr(keras.optimizers, optim)(lr=lr)
    model.compile(loss=loss, optimizer=optimizer)

    callbacks = [keras.callbacks.EarlyStopping(patience=20, verbose=1)] if early_stop else []

    start_time = time.time()
    history = model.fit(X, Y, validation_split=validation_split, epochs=epochs,
                        batch_size=batch_size, verbose=0, callbacks=callbacks)
    print('Elapsed fitting time : {}'.format(time.time() - start_time))

    return model,history

In [350]:
#Compare the model's predicted PSTHs with the recorded ones and score the fit
def prediction_score_and_plots(W,predicted,Y,Z,onTest=True,fig = True,
                               fit_history=None,fitted_model=None,split_corr=None,cellidx=None):
    """Score predicted PSTHs against the recorded ones and optionally plot them.

    Parameters
    ----------
    W : array
        Test inputs; unused, kept for call compatibility.
    predicted : array, (n_sounds, time, >=1)
        Model output; channel 0 of the last axis is compared.
    Y : array, (n_train, time, 1)
        Training targets; plotted for 5 random sounds when ``onTest`` is False.
    Z : array, (n_test, time, 2)
        Test targets: PSTHs from the two halves of the repetitions.
    onTest : bool
        Score every test sound (True), or only plot training examples (False).
    fig : bool
        Draw the loss curve, the PSTH comparisons and the STRF.
    fit_history, fitted_model, split_corr, cellidx : optional
        Keras History, Keras model, per-test-sound correlation between the two
        halves of Z, and cell index used in titles. When omitted, they fall back to
        the notebook globals ``history``, ``model``, ``t_12`` and ``cellidx``.

    Returns
    -------
    list of float
        For each test sound, ``(r1**2/2 + r2**2/2) / split_corr[i]``, where r1/r2
        correlate the prediction with each half of Z. Empty when ``onTest`` is False.
    """
    import random as rand

    if fig:
        cell_label = cellidx if cellidx is not None else globals().get('cellidx', '?')
        if fit_history is None:
            fit_history = history
        plt.figure()
        plt.plot(fit_history.history['loss'])
        labels = ['loss']
        if 'val_loss' in fit_history.history:
            plt.plot(fit_history.history['val_loss'])
            labels.append('validation loss')
        plt.legend(labels)
        plt.title('loss over training')

    score = []
    if onTest:
        if split_corr is None:
            split_corr = t_12
        for example in range(Z.shape[0]):
            if fig:
                plt.figure()
                plt.plot(Z[example,:,0])
                plt.plot(Z[example,:,1])
                plt.plot(predicted[example,:,0])
                plt.title("individual PSTH of {}th cell from sound {}".format(cell_label,example))
                plt.legend(('output1','output2','prediction'))
            c1 = np.corrcoef(Z[example,:,0],predicted[example,:,0])[0,1]
            c2 = np.corrcoef(Z[example,:,1],predicted[example,:,0])[0,1]
            # Explained variance normalised by the recording's split-half reliability.
            score.append( (c1**2/2 + c2**2/2)/split_corr[example] )
            print('Explained Score for sound {}: {}'.format(example,score[example]))
    else:
        for _ in range(5):
            example = rand.randint(0,Y.shape[0]-1)
            if fig:
                plt.figure()
                plt.plot(Y[example,:,0])
                plt.plot(predicted[example,:,0])
                plt.title("individual PSTH of {}th cell from sound {}".format(cell_label,example))
                plt.legend(('output','prediction'))

    if fig:
        if fitted_model is None:
            fitted_model = model
        # Keras Conv1D kernels are (kernel_size, in_channels, filters). Convolving a
        # unit impulse with a kernel returns the kernel itself, so the STRF is just
        # the first filter's weights laid out as (channel, lag).
        STRF = np.asarray(fitted_model.get_weights()[0])[:, :, 0].T
        plt.figure()
        plt.imshow(STRF)

    return score

In [352]:
import os, io, re, scipy.io, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import nems.utilities as nu
import nems.db as nd
import nems.utilities.baphy

# Recording to analyse and the options passed to the baphy loader.
parmfilepath='/auto/data/daq/Tartufo/TAR010/TAR010c16_p_NAT.m'

options={'rasterfs': 100, 'includeprestim': True, 'stimfmt': 'ozgf', 'chancount': 18, 'cellid': 'all', 'pupil': True}
event_times, spike_dict, stim_dict, state_dict = nu.baphy.baphy_load_recording(parmfilepath,options)

# Analysis / model hyper-parameters. Kept as module-level names because
# prediction_score_and_plots reads some of them (e.g. cellidx) as globals and the
# save-name builder below reuses them.
cellidx = 21
epochs = 1000
batch_size = 40
time_window = 10
kernel_init = [0, 0.001]
lr = 0.005
activation = 'relu'
loss = 'poisson'
optim = 'SGD'
validation_split = 0.2
early_stop = False

X,Y,W,Z,t_12 = getInsOuts(cellidx,event_times,options['rasterfs'],stim_dict,False)

model,history = mk_and_fit_conv1Dmodel(X,Y,epochs,batch_size,time_window,kernel_init,lr,
                                       activation,loss,optim,validation_split,early_stop)
predicted = model.predict(W)

score = prediction_score_and_plots(W,predicted,Y,Z,onTest=True,fig = True)

save = False
if save:
    # SAVE MODEL
    # Read the conv layer's own config: stable across Keras versions, unlike
    # Sequential.get_config(), whose layout changed (list vs dict, InputLayer entry).
    conv_layer = model.layers[0]
    layer_config = conv_layer.get_config()
    print(layer_config)
    init_config = layer_config['kernel_initializer']['config']
    save_name = type(conv_layer).__name__ + '_'  + layer_config['activation'] + '_compiler-' + \
        loss + '-' + str(lr) + '_'+ '_ker-' + \
        str(init_config['minval']) + '-' + \
        str(init_config['maxval']) + '_' + \
        'epochs-' + str(epochs) + '_batch-' + str(batch_size)

    save_dir = 'STRF_computation/models_trained/'
    os.makedirs(save_dir, exist_ok=True)
    model.save(save_dir + save_name)

Cached stim: /auto/data/tmp/tstim/NaturalSounds-2-0.5-3-1-White______-100-0-3__8-65dB-ozgf-fs100-ch18-incps1.mat
Spike file: /auto/data/daq/Tartufo/TAR010/sorted/TAR010c16_p_NAT.spk.mat
rounding Trial offset spike times to even number of rasterfs bins
342 trials totaling 2076.12 sec
Creating trial events
Creating trial outcome events
Removing post-response stimuli
Keeping 2394/2394 events that precede responses


In [None]:
# Show which compute devices (CPUs / GPUs) TensorFlow can see, to confirm GPU availability.
from tensorflow.python.client import device_lib

local_devices = device_lib.list_local_devices()
print(local_devices)