In [1]:
import phy
import numpy as np
import pandas as pd
import os
from phy import session
import xmltodict

## Loading raw spike data

#### Paths: same as behavior

#### Acquisition parameters (integers or floats):
  - **amplification**: voltage amplification
  - **nBits**: number of bits used to store each value in the .dat file
  - **nChannels**: total number of channels, needed to read .dat file
  - **offset**: always 0
  - **spikeSamplingRate**: number of samples acquired per second for the spikes   
    other names: SampleRate, sample_rate   
  - **voltageRange**
  
#### Other
  - **channelGroupList**: dictionary of channel groups (shank), ex: {1: [0,1,2,3,4,...], 2: [8,9,..], ..}
  - **clusterGroup**: dictionary {channel_group: cluster_group}  
       with cluster_group a dictionary {"Good":[list of clu],"Noise":[list of clu],"MUA":...,"Unsorted":...}   
  
#### Nested dictionary { channelGroup: { clu : 1D numpy array} }  
    For each channelGroup, a dictionary  
    In this dictionary, for each clu, the list of the spike times or spike sample
    
   - **spikeSample**: sample of the spikes
   - **spikeTime**: sample/spikeSamplingRate 
   - **spikeIndex**: index in the list, useful to get the waveform in the .kwx file

In [None]:
class BaseRawSpikeData:
    """
    Base reader for the raw spike data of one recording session.

    Subclasses provide the format-specific readers:
      - read_acquisitionSystem_parameters(): dict of acquisition parameters
      - read_spikes_sample_and_clu(): tuple (spikeSample, clusterGroup)
      - read_probe(): dict {channelGroup: [channels]} (optional override)

    Arguments:
      rootFolder, rat, experiment: the session folder is
          <rootFolder>/<rat>/Experiments/<experiment>; a leading path separator
          is always put in front of rootFolder.
      parameters: optional dict, its entries override the acquisition
          parameters returned by read_acquisitionSystem_parameters()
      saveAsPickle: if True, the instance dictionary is pickled after loading

    Attributes set by the constructor:
      sessionPath, fullPath, parameters,
      amplification, nBits, nChannels, offset, spikeSamplingRate, voltageRange
          (None when the parameter is not available),
      spikeSample, clusterGroup, channelGroupList,
      spikeTime: {channelGroup: {clu: spikeSample/spikeSamplingRate}}

    Raises:
      ValueError if the spike sampling rate is unknown once the spikes are read.
    """

    # (parameter name, conversion applied when the parameter is available)
    _ACQUISITION_FIELDS = (
        ("amplification", None),
        ("nBits", int),
        ("nChannels", int),
        ("offset", None),
        ("spikeSamplingRate", None),
        ("voltageRange", None),
    )

    def __init__(self,rootFolder,rat,experiment,parameters=None,saveAsPickle=True):
        #clean name of folders (remove unnecessary slash or backslash)
        rootFolder=os.sep+rootFolder.strip(os.sep)
        rat=rat.strip(os.sep)
        experiment=experiment.strip(os.sep)

        #paths
        self.sessionPath=os.path.join(rootFolder,rat,"Experiments",experiment)
        self.fullPath=os.path.join(rootFolder,rat,"Experiments",experiment,experiment)

        #parameters dict (None instead of a shared mutable default)
        self.parameters={} if parameters is None else parameters
        acquisitionParameters=self.read_acquisitionSystem_parameters()
        acquisitionParameters.update(self.parameters)

        #integers or floats
        #each parameter is looked up on its own: one missing entry must not
        #prevent the following ones from being set
        for fieldName,convert in self._ACQUISITION_FIELDS:
            value=acquisitionParameters.get(fieldName)
            if value is not None and convert is not None:
                value=convert(value)
            setattr(self,fieldName,value)

        #spikes dict{ key(channelGroup): {key(clu): spike time or sample} }
        self.spikeSample, self.clusterGroup=self.read_spikes_sample_and_clu()
        self.channelGroupList = self.read_probe()

        if self.spikeSamplingRate is None:
            raise ValueError("spikeSamplingRate is unknown: cannot convert spike samples to spike times")
        sr=float(self.spikeSamplingRate)
        self.spikeTime={
            shank: {clu: samples/sr for clu,samples in clusters.items()}
            for shank,clusters in self.spikeSample.items()
        }

        if saveAsPickle:
            self.save_as_pickle()

    def get_dict(self):
        """Return the instance dictionary (the same object, not a copy)."""
        return self.__dict__

    def save_as_pickle(self,folder="Analysis",name="rawspikedata.p"):
        """
        Pickle the instance dictionary in <sessionPath>/<folder>/<name>.
        The folder is created when it does not exist.
        """
        import pickle
        folderPath=os.path.join(self.sessionPath,folder)
        if not os.path.isdir(folderPath):
            os.makedirs(folderPath)
        filePath=os.path.join(folderPath,name)
        #context manager: the file is flushed and closed even if dump fails
        with open(filePath,"wb") as pickleFile:
            pickle.dump(self.__dict__,pickleFile)

    def read_acquisitionSystem_parameters(self):
        """
        Return the default acquisition parameters.
        "nChannels" has no default: it must come from the `parameters`
        argument or from a subclass.
        """
        defaultParam={
            "amplification":1000,
            "nBits":16,
            "offset":0,
            "spikeSamplingRate":20000,
            "voltageRange":10
        }
        return defaultParam

    def read_spikes_sample_and_clu(self):
        """
        Reader called by the constructor; must return (spikeSample, clusterGroup).
        To be implemented by the subclasses.
        """
        raise NotImplementedError("reading spikes is not implemented in base class")

    def read_spikes_time_and_sample(self):
        """Former name of the spike reader, kept for existing callers."""
        raise NotImplementedError("reading spikes is not implemented in base class")

    def read_probe(self):
        """
        read channels number for each group (shank)
        The nChannels channels are split in consecutive blocks of equal size,
        one block per channel group, in sorted order of the group keys.
        Returns {} when no channel group was read.
        """
        groups=sorted(self.spikeSample)
        if not groups:
            return {}
        if self.nChannels is None:
            raise ValueError("nChannels is unknown: cannot assign channels to the channel groups")
        ch = self.nChannels // len(groups)
        return {key: list(range(index*ch, index*ch +ch)) for index, key in enumerate(groups)}

### Pavel spike data (Kluster, .clu, .res)

In [None]:
class Kluster_RawSpikeData(BaseRawSpikeData):
    '''
    Data obtained with Kluster (.clu, .res), described by a .xml parameter file.
    Only one cluster group is produced: "Good".
    Clusters 0 and 1 are discarded: their spikes are not loaded.
    '''
    @staticmethod
    def _as_list(node):
        '''
        xmltodict gives a list for repeated elements but the bare element when
        there is only one: always return a list.
        '''
        return node if isinstance(node, list) else [node]

    @staticmethod
    def _channel_id(channel):
        '''
        Channel number of one parsed <channel> element.
        An element carrying xml attributes is parsed as a dict, its text being
        stored under the "#text" key.
        '''
        if isinstance(channel, dict):
            channel = channel["#text"]
        return int(channel)

    def read_acquisitionSystem_parameters(self):
        '''
        Read the acquisition parameters from <fullPath>.xml.
        Every entry of the "acquisitionSystem" section is converted to float,
        "spikeSamplingRate" is a copy of "samplingRate".
        Side effect: sets self.nChGroup (number of channel groups) and
        self.channelGroupList ({1-based group index: [channels]}) from the
        "spikeDetection" section.
        '''
        with open(self.fullPath+'.xml', "rb") as f:    # notice the "rb" mode
            d = xmltodict.parse(f, xml_attribs=True)
        acquisitionParam = d['parameters']['acquisitionSystem']
        acquisitionParam={key: float(value) for key, value in acquisitionParam.items()}
        acquisitionParam["spikeSamplingRate"]=acquisitionParam["samplingRate"]

        #number of shank ("channel group")
        g = self._as_list(d['parameters']['spikeDetection']['channelGroups']['group'])
        self.nChGroup= len(g)
        #a group with a single channel is parsed as a scalar, not as a list:
        #without _as_list the digits of "12" would be read as channels 1 and 2
        self.channelGroupList = {
            index+1: [self._channel_id(channel) for channel in self._as_list(shank['channels']['channel'])]
            for index, shank in enumerate(g)
        }

        return acquisitionParam

    def read_probe(self):
        '''Return the channel groups read from the .xml file.'''
        return self.channelGroupList

    def read_spikes_sample_and_clu(self):
        '''
        A function to load from the .clu and .res files of all the shank
        Removes the spikes from 0 and 1 clusters (noise).
        A shank without .clu file is skipped.

        Returns (spikeSample, clusterGroup):
          spikeSample: {chGroup: {clusterID: 1D array of spike samples}}
          clusterGroup: {chGroup: {"Good": [clusterID, ...]}}
        Raises ValueError if the .clu and .res files of a shank do not hold
        the same number of spikes.
        '''
        spikeSample,clusterGroup={},{}

        for chGroup in range(1,self.nChGroup+1):
            #load clu and res
            cluPath=self.fullPath+'.clu.'+str(chGroup)
            resPath=self.fullPath+'.res.'+str(chGroup)
            if not os.path.exists(cluPath):
                continue
            #read_csv gives one column: flatten so that mask and data are both 1D
            clu = pd.read_csv(cluPath).values.ravel() #header=first line = number of clusters
            res = pd.read_csv(resPath,header=None).values.ravel()
            if len(clu)!=len(res):
                raise ValueError("%s holds %s spikes but %s holds %s"%(cluPath,len(clu),resPath,len(res)))

            #get all cluster id, without 0 and 1
            clusterIDList=[clusterID for clusterID in np.unique(clu) if clusterID not in (0,1)]

            #load in dictionary
            clusterGroup[chGroup]={"Good":clusterIDList}
            spikeSample[chGroup]={clusterID: res[clu==clusterID] for clusterID in clusterIDList}
        return spikeSample,clusterGroup
            


In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
if "__file__" not in dir():
    #the session folder is <ROOT>/<ANIMAL>/Experiments/<SESSION>
    #e.g. "data/Rat034/Experiments/Rat034_2015_03_04_10_04"
    ROOT="/data"
    ANIMAL="MOU035"
    SESSION="MOU035_2014_12_11_11_08"

    #load the session (the constructor also pickles the result by default)
    data=Kluster_RawSpikeData(rootFolder=ROOT,rat=ANIMAL,experiment=SESSION)
    #channels of each shank, then total number of channels
    print(data.channelGroupList,data.nChannels,sep="\n")

### Teresa spike data (klusta, .kwik)

In [None]:
class Klusta_RawSpikeData(BaseRawSpikeData):
    '''
    Data obtained with Klusta Suite (.kwik file), read through a phy Session.
    Clusters are sorted in the 4 default cluster groups:
    Noise(0), MUA(1), Good(2) and Unsorted(3).
    A cluster belonging to a user-created group (ID > 3) is stored in Good:
    the names of such groups were not found in the kwik file.
    '''
    def read_acquisitionSystem_parameters(self):
        '''
        Open the phy session on <fullPath>.kwik (kept in self.session) and
        return the acquisition parameters taken from its model.
        "amplification" and "offset" are fixed values, not read from the file.
        '''
        from phy.session import Session
        self.session=Session(self.fullPath+".kwik")
        model=self.session.model
        return {
            "amplification":1000,
            "nBits":model.metadata["nbits"],
            "offset":0,
            "spikeSamplingRate":model.sample_rate,
            "voltageRange":model.metadata["voltage_gain"],
            "nChannels":model.metadata["nchannels"],
        }

    def read_probe(self):
        '''
        Return the channel groups collected by read_spikes_sample_and_clu().
        Also removes self.session from the instance: this is the last reader
        called by the constructor.
        '''
        del self.session
        return self.channelGroupList

    def read_spikes_sample_and_clu(self):
        '''
        Read, for every channel group (shank) of the session, the samples and
        the indexes of the spikes of each cluster, and the cluster group
        (Noise, MUA, Good, Unsorted) of each cluster. Cluster 0 is skipped.

        Returns (spikeSample, clusterGroup):
          spikeSample: {chgroup: {clu: spike samples}}
          clusterGroup: {chgroup: {cluster group name: [clu, ...]}}
        Side effect: sets self.channelGroupList and self.spikeIndex.
        '''
        GOOD,UNSORTED=2,3
        samplesByShank,groupsByShank,indexByShank={},{},{}
        self.channelGroupList={}

        for position,chgroup in enumerate(self.session.model.channel_groups):
            #no switch for the first group
            #NOTE(review): presumably the session opens on it - confirm
            if position>0:
                self.session.change_channel_group(chgroup)
            model=self.session.model
            self.channelGroupList[chgroup]=list(model.channel_order)

            #default cluster group names (0: 'Noise', 1: 'MUA', 2: 'Good', 3: 'Unsorted')
            groupNames=model.default_cluster_groups
            cluByGroupName={name: [] for name in groupNames.values()}
            shankSamples,shankIndex={},{}

            #for every cluster in the channel_group
            for clu in model.cluster_ids:
                if clu==0:
                    continue
                inCluster=model.spike_clusters==clu
                shankSamples[clu]=model.spike_samples[inCluster]
                shankIndex[clu]=model.spike_ids[inCluster]

                cluGroupID=model.cluster_metadata.group(clu)
                if isinstance(cluGroupID, np.ndarray):
                    print("Warning- cluster group of cluster %s is an array (%s), taking first value"%(clu,cluGroupID))
                    cluGroupID=cluGroupID[0]
                if isinstance(cluGroupID, bytes):
                    print("Warning- cluster group of cluster %s is a bytes, putting it in 'unsorted'"%(clu))
                    cluGroupID=UNSORTED
                # if a group was created (ID>3), put it in "Good"(2)
                if cluGroupID>UNSORTED:
                    cluGroupID=GOOD
                cluByGroupName[groupNames[cluGroupID]].append(clu)

            samplesByShank[chgroup]=shankSamples
            indexByShank[chgroup]=shankIndex
            groupsByShank[chgroup]=cluByGroupName

        self.spikeIndex=indexByShank
        return samplesByShank,groupsByShank
   


In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
if "__file__" not in dir():
    #the session folder is /<root>/<animal>/Experiments/<experiment>
    #e.g. "data/Rat034/Experiments/Rat034_2015_03_04_10_04"
    #(leading and trailing separators are stripped by the constructor)
    root="data"
    animal="MOU102"
    experiment="MOU102_2016_01_21_11_16/"

    #load the session (the constructor also pickles the result by default)
    data=Klusta_RawSpikeData(rootFolder=root,rat=animal,experiment=experiment)
    #channels of each shank
    print(data.channelGroupList)
    