In [1]:
# Input files: temporary outputs in the mountainlab tmp directory.
timeseries = '/tmp/mountainlab-tmp/output_d468ac4041de2645c7ae6f801daea945aabf5b12_timeseries_out.mda'
# NOTE(review): unlike `timeseries`, this path has no ".mda" extension -- confirm it is the real file name.
firings = '/tmp/mountainlab-tmp/output_0454d0e46a65c78c5f4b2d9093a06f85530a2ab5_firings_out'

# Number of timepoints in each extracted clip (snippet).
clip_size = 32

In [2]:
import numpy as np

from mltools import mdaio
from timeserieschunkreader import TimeseriesChunkReader

# Processor identity: name and version string.
processor_name = 'ephys.compute_templates_spk'
processor_version = '0.01'

In [41]:
def compute_templates_spk(*,timeseries,firings,waveforms_out,clip_size=100):
    """
    Extract event clips (snippets) for the events in firings and write them to one waveform file per n-trode.

    Parameters
    ----------
    timeseries : INPUT
        Path of timeseries mda file (MxN) from which to draw the event clips (snippets). M is number of channels, N is number of timepoints.
    firings : INPUT
        Path of firings mda file (RxL) where R>=3 and L is the number of events. First row is the (1-based) channel of each event, second row are timestamps, third row are integer labels.

    waveforms_out : OUTPUT
        Base path of the output files; n-trode n (1-based) is written to waveforms_out + str(n). Each file is a raw int16 array of shape (M, T, K_n) with T=clip_size and K_n the number of events on that n-trode; events too close to either end of the recording are stored as all-zero clips.

    clip_size : int
        (Optional) clip size, aka snippet size, number of timepoints in a single clip

    Returns
    -------
    bool
        The value returned by compute_templates_helper (True once all files are written).
    """
    # The helper requires waveforms_out (keyword-only, no default) and writes the files itself,
    # so just forward every argument and pass its status back to the caller.
    return compute_templates_helper(timeseries=timeseries,firings=firings,
                                    waveforms_out=waveforms_out,clip_size=clip_size)
    
# Does the actual work behind compute_templates_spk: extracts one clip per event and writes one
# raw int16 waveform file per n-trode. Returns True (no templates are returned in memory).
def compute_templates_helper(*,timeseries,firings,waveforms_out,clip_size=100,channels_per_ntrode=4):
    """
    Extract event clips from `timeseries` and write them, grouped by n-trode, to raw binary files.

    Parameters
    ----------
    timeseries : str
        Path of the timeseries .mda file (M x N: channels x timepoints).
    firings : str
        Path of the firings .mda file (R x L, R >= 3). Row 0 is the 1-based channel of each
        event, row 1 its timestamp (other rows are not used here).
    waveforms_out : str
        Base output path; n-trode n (1-based) is written to ``waveforms_out + str(n)``.
    clip_size : int
        Number of timepoints T in each clip.
    channels_per_ntrode : int
        Number of consecutive channels per n-trode (default 4, i.e. tetrodes: channels 1-4 ->
        n-trode 1, channels 5-8 -> n-trode 2, ...).

    Returns
    -------
    bool
        True once every n-trode file has been written.

    Notes
    -----
    Each file holds an int16 array of shape (M, T, K_n) -- all M channels, T timepoints and the
    K_n events of that n-trode -- dumped in C order with ``ndarray.tofile``. A file is written for
    every n-trode, even one with no events. Events whose timestamp t0 does not satisfy
    ``clip_size <= t0 < N - clip_size`` keep an all-zero clip. Events on channels outside 1..M
    are ignored.
    """
    if channels_per_ntrode < 1:
        raise ValueError("channels_per_ntrode must be >= 1")

    X=mdaio.DiskReadMda(timeseries)
    M,N = X.N1(),X.N2()
    F=mdaio.readmda(firings)
    T=clip_size
    # Offset of the event sample inside its clip: the clip starts Tmid samples before t0.
    Tmid = int(np.floor((T + 1) / 2) - 1)
    channels=F[0,:].ravel()
    times=F[1,:].ravel()

    # 1-based n-trode index per event, kept as a numpy array so `== tro` below compares
    # element-wise (a Python list compared to an int is just False and selects nothing).
    which_ntrode = (channels.astype(int) - 1) // channels_per_ntrode + 1
    num_ntrodes = int(np.ceil(M / channels_per_ntrode))

    print("Starting:")
    for tro in range(1, num_ntrodes + 1):
        print("Tetrode: "+str(tro))
        inds_k=np.where(which_ntrode==tro)[0]
        print("Create Waveforms Array: "+str(M)+","+str(T)+","+str(len(inds_k)))
        waveforms = np.zeros((M,T,len(inds_k)),dtype='int16')
        for i,ind_k in enumerate(inds_k): # for each spike on this n-trode
            t0=int(times[ind_k])
            # Conservative bounds check so the clip [t0-Tmid, t0-Tmid+T) stays inside the recording.
            if clip_size<=t0<N-clip_size:
                clip0=X.readChunk(i1=0,N1=M,i2=t0-Tmid,N2=T)
                waveforms[:,:,i] = clip0.astype('int16')
        print("Writing Waveforms to File.")
        waveforms.tofile(waveforms_out+str(tro))

    return True

In [44]:
# Scratch check of the channel -> tetrode mapping: each event's channel (row 0 of firings)
# is looked up in groups of 4 consecutive 1-based channels.
# Load the channels here instead of relying on a `whch` left over from another cell.
# NOTE(review): 48 is the assumed total channel count of this recording -- confirm.
whch = mdaio.readmda(firings)[0, :].ravel()
tetmap = [x + np.arange(1,5) for x in np.arange(0,48-1,4)]
# Kept as an array so comparisons such as `which_tet == n` work element-wise.
which_tet = np.array([np.where(w==tetmap)[0][0]+1 for w in whch])

In [37]:
# Single-argument np.where is shorthand for nonzero(): indices of the True entries.
np.nonzero([True, False, True])[0]

array([0, 2])

In [35]:
# IPython help lookup for np.where (interactive leftover; not needed by the analysis).
?np.where

In [42]:
# Base path for the per-tetrode waveform files: tetrode n is written to waveforms_out + str(n).
# NOTE(review): no cell defined `waveforms_out`; this location/name is a placeholder -- adjust as needed.
waveforms_out = '/tmp/mountainlab-tmp/waveforms.spk.'
# The helper returns True once all files are written (it does not return channel ids).
waveforms_written = compute_templates_helper(timeseries=timeseries,firings=firings,
                                             waveforms_out=waveforms_out,clip_size=clip_size)