In [None]:
# import KWIKphy
import sys, pickle
import numpy as np
import pandas as pd
import os, itertools
# from KWIKphy.session import Session
import xmltodict

## Loading raw spike data

#### Paths: same as behavior

#### Acquisition parameters (integers or floats):
  - **amplification**: voltage amplification
  - **nbits**: each float in the .dat file is stored on `nbits` bits
  - **nChannels**: total number of channels, needed to read .dat file
  - **offset**: always 0
  - **spikeSamplingRate**: number of acquisitions (samples) per second for the spikes   
    other names: SampleRate, sample_rate   
  - **voltageRange**
  
#### Other
  - **channelGroupList**: dictionary of channel groups (shank), ex: {1: [0,1,2,3,4,...], 2: [8,9,..], ..}
  - **clusterGroup**: dictionary {channel_group: cluster_group}  
       with cluster_group a dictionary {"Good":[list of clu],"Noise":[list of clu],"MUA":...,"Unsorted":...}   
  
#### Nested dictionary { channelGroup: { clu : 1D numpy array} }  
    For each channelGroup, a dictionary  
    In this dictionary, for each clu, the list of the spike times or spike samples
    
   - **spikeSample**: sample of the spikes
   - **spikeTime**: sample/spikeSamplingRate 
   - **spikeIndex**: index in the list, useful to get the waveform in the .kwx file

In [None]:
class Base_RawSpikeData:
    """
    Locate the files of one recording session and load its acquisition parameters.

    The session folder is built as
    <root>/<animal>/Experiments/<experiment>/ where <animal> is the first
    6 characters of the experiment name. Every file of the session is expected
    to be named <experiment>.<extension> inside that folder.

    Attributes set by the constructor:
      - root, experiment, session, animal, sessionPath, fullPath
      - one '<ext>File' attribute per looked-up file (see _add_files)
      - parameters: the user supplied dictionary
      - one attribute per acquisition parameter (gain, nBits, samplingRate,
        nChannels), possibly overridden by `parameters`
      - channelGroup: the 'channel_groups' entry of the prb file
    """

    def __init__(self,rootFolder,experiment,parameters=None,saveAsPickle=True):
        """
        rootFolder:   folder containing the animal folders
        experiment:   name of the session (its first 6 characters are the animal name)
        parameters:   optional dictionary, its entries override the acquisition
                      parameters read from the files (None means no override)
        saveAsPickle: if True, pickle the object at the end of the constructor
        """
        self.root=os.sep+rootFolder.strip(os.sep)
        self.experiment=experiment.strip(os.sep)
        self.session=self.experiment
        #the animal name is taken from the stripped name, so that a leading
        #path separator in `experiment` does not end up in the animal name
        self.animal=self.experiment[:6]

        #paths
        self.sessionPath=os.path.join(self.root,self.animal,"Experiments",self.experiment)
        self.fullPath=os.path.join(self.sessionPath,self.experiment)

        #check if every file is available, add them to the self
        self._add_files()

        #parameters dict (a new dict per instance when none is given,
        #a mutable default argument would be shared between instances)
        self.parameters={} if parameters is None else parameters
        acquisitionParameters=self.read_acquisitionSystem_parameters()
        #user parameters take precedence over the values read from the files,
        #so the override must happen before the attributes are created
        acquisitionParameters.update(self.parameters)
        for key,value in acquisitionParameters.items():
            self.__dict__[key]=value

        self.channelGroup = self.read_probe()

        if saveAsPickle:
            self.save_as_pickle()

    def _add_files(self,files=('prm','prb','kwik','evt.cam','evt.tre')):
        """
        For every extension in `files`, create the attribute '<ext>File'
        (dots removed from the extension, ex: 'evt.cam' -> evtcamFile).
        Its value is the path <fullPath>.<ext> if that file exists, False otherwise.
        """
        for file in files:
            assert isinstance(file,str), 'file types must be a string'
            key=file.replace('.','')+'File'
            path=self.fullPath+'.'+file
            self.__dict__[key]=path if os.path.isfile(path) else False

    def read_acquisitionSystem_parameters(self):
        """
        Read the acquisition parameters from the prm file.
        Returns a dictionary with the keys gain, nBits, samplingRate and nChannels.
        Raises AssertionError if there is no prm file, KeyError if the prm file
        does not define one of the expected variables.
        """
        assert self.prmFile, "No prm file"
        prm=self.prm_reader(self.prmFile)
        defaultParam={
            "gain":prm['voltage_gain'],
            "nBits":prm['nbits'],
            "samplingRate":prm['sample_rate'],
            "nChannels":prm['nchannels']
        }
        return defaultParam

    def get_dict(self):
        """Return the attribute dictionary of the object (not a copy)."""
        return self.__dict__

    def save_as_pickle(self,folder="Analysis",name="rawspikedata.p"):
        """
        Pickle the object in <sessionPath>/<folder>/<name>.
        The folder is created if needed.
        """
        folderPath=os.path.join(self.sessionPath,folder)
        #makedirs(exist_ok=True) does not fail if the folder appears between
        #the existence test and the creation
        os.makedirs(folderPath,exist_ok=True)
        filePath=os.path.join(folderPath,name)
        with open(filePath, "wb" ) as file:
            pickle.dump(self, file)

    def read_probe(self):
        """
        read channels number for each group (shank)
        Returns the 'channel_groups' variable defined in the prb file.
        Raises AssertionError if there is no prb file.
        """
        assert self.prbFile, "No prb file"
        prb=self.prm_reader(self.prbFile)
        return prb['channel_groups']

    @staticmethod
    def prm_reader(prmFile):
        """
        Execute a parameter file (prm or prb, both are python scripts) and
        return the dictionary of the variables it defines.

        The file is executed from its own folder, so that relative paths used
        inside the file are resolved from there; the working directory is
        restored afterwards, even on error.

        Only the namespace of the executed file is returned: the variables are
        not injected in the namespace of the caller, and values defined by a
        previously read file cannot leak into the result.

        WARNING: the file is executed as python code, only read trusted files.
        """
        import runpy
        CWD=os.getcwd()
        #dirname is '' for a bare file name, in which case we stay where we are
        prmFolder=os.path.dirname(prmFile) or os.curdir
        prmName=os.path.basename(prmFile)
        try:
            os.chdir(prmFolder)
            return runpy.run_path(prmName,run_name="__main__")
        finally:
            os.chdir(CWD)


## Klusta+Phy spike data (`*.kwik`)

In [None]:
class Unit:
    """
    This class is to implement the basic data structures for detected clusters

    Attributes:
      - spikeSamples: 1D numpy array built from `spikeTimes`
      - type:         cluster type, as given in `cluType`
      - shank:        channel group of the cluster, as given in `cluChGroup`
      - channels:     as given in `cluCh` (-1 by default)
      - NSpikes:      number of elements of spikeSamples
      - id:           as given in `cluId` (-1 by default)
    """

    #locking new attribute creation other than those stated below
    __slots__=('spikeSamples','type','shank','channels','NSpikes','id')

    def __init__(self,spikeTimes, cluType, cluChGroup, cluCh=-1, cluId=-1):
        """
        spikeTimes: 1D sequence or numpy array; a 2D input is accepted only if
                    one of its dimensions is 1 (row or column vector), and is
                    flattened to 1D
        Raises AssertionError if spikeTimes has more than 2 dimensions, or is
        2D without a dimension of size 1.
        """
        if isinstance(spikeTimes,np.ndarray):
            samples=spikeTimes
        else:
            samples=np.asarray(spikeTimes)

        #the structure is checked whatever the input type
        #(numpy arrays are not exempted from the check)
        assert samples.ndim <=2, "spikeTime must be a 1D vector"
        if samples.ndim==2:
            assert 1 in samples.shape, "spikeTime has bad structure"

        #row/column vectors (and scalars) are flattened, otherwise len() would
        #return 1 for a (1,N) array instead of the number of spikes
        if samples.ndim!=1:
            samples=samples.reshape(-1)

        self.spikeSamples=samples
        self.type=cluType
        self.shank=cluChGroup
        self.channels=cluCh
        self.NSpikes=len(self.spikeSamples)
        self.id=cluId

    def __repr__(self):
        return f'<{self.type} unit containing {self.NSpikes} spikes>'


class Klusta_RawSpikeData(Base_RawSpikeData):
    '''
    Data obtained with Klusta
    And curated with phy
    4 default group: noise(0), MUA(1), Good(2) and Unsorted(3)
       User could have created new groups (>=4), but I didn't find the new names in the kwik file. 
       So new groups are merged with Good

    Attributes added to those of Base_RawSpikeData:
      - kwikSession:  Session opened on the kwik file (not pickled)
      - unitDict:     dictionary {channel_group: [list of Unit]}
      - clusterGroup: dictionary {channel_group: {group name: [list of clu]}}
      - is_curated:   True if at least one cluster is not in group 3 (Unsorted)

    NOTE(review): `Session` is not defined in this notebook as it stands (its
    import is commented out at the top), it must be available in the namespace.
    '''
    def __init__(self,*args,**kwargs):
        """
        Same arguments as Base_RawSpikeData.
        The pickle (if requested) is written once the spikes are read, so that
        the saved object contains them.
        """
        import inspect
        #the base constructor would pickle the object before the spikes are
        #read: take the saveAsPickle request out and honour it at the end
        bound=inspect.signature(super().__init__).bind(*args,**kwargs)
        bound.apply_defaults()
        saveAsPickle=bound.arguments['saveAsPickle']
        bound.arguments['saveAsPickle']=False

        super().__init__(*bound.args,**bound.kwargs)
        assert self.kwikFile, 'No kwik file'

        #Reading spiking data off the kwik file
        self.read_spikes_sample_and_clu()

        if saveAsPickle:
            self.save_as_pickle()

    #2 Following methods are defined because original class was not pickle-able
    #============================================
    def __getstate__(self):
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        # (pop with a default: no KeyError if the session was never opened)
        state.pop('kwikSession',None)
        return state

    def __setstate__(self, state):
        # Restore instance attributes
        self.__dict__.update(state)
        # Restore the previously removed entries
        # (the kwik file is opened again, at the path stored in the pickle)
        self.kwikSession= Session(self.kwikFile)
    #=============================================

    def read_acquisitionSystem_parameters(self):
        """
        Open the kwik file and read the acquisition parameters from it.
        If anything fails while reading them, a message is printed and the
        parameters are read from the prm file instead.
        Returns a dictionary {parameter name: value}.
        Raises AssertionError if there is no kwik file.
        """
        #this method runs inside the base constructor, i.e. before the check
        #done in __init__: fail here with a clear message
        assert self.kwikFile, 'No kwik file'
        self.kwikSession= Session(self.kwikFile)

        try:
            metadata=self.kwikSession.model.metadata
            param={
                "gain":              metadata['voltage_gain'],
                "nBits":             metadata["nbits"],
                "samplingRate":      self.kwikSession.model.sample_rate,
                #NOTE(review): voltageRange is read from the 'voltage_gain' key,
                #this looks like a copy/paste - confirm the intended key
                "voltageRange":      metadata["voltage_gain"],
                "nChannels":         metadata["n_channels"]
            }
        except Exception as e:
            print(self.session,"Failure to find info in the kwik file, loading from prmFile")
            print(repr(e))
            param=super().read_acquisitionSystem_parameters()

        return param 

    def read_spikes_sample_and_clu(self):
        """
        Read the spike samples of every cluster of every channel group.
        Sets self.unitDict, self.clusterGroup and self.is_curated.
        Cluster 0 is skipped.
        """
        #chgroup is channel group (shank)
        #cluGroup is the group of the cluster (noise, mua, good...)
        self.is_curated= False
        clusterGroup,unitDict={},{}
        first= True
        for chgroup in self.kwikSession.model.channel_groups:
            #the session is not switched for the first channel group
            #(presumably it is the one loaded when the session is opened - TODO confirm)
            if first:
                first= False
            else:
                self.kwikSession.change_channel_group(chgroup)
            #the model is fetched after the switch
            model=self.kwikSession.model
            unitDict[chgroup]=[]

            #default cluster group names (0: 'Noise', 1: 'MUA', 2: 'Good', 3: 'Unsorted')
            clusterGroup[chgroup]={name:[] for name in model.default_cluster_groups.values()}

            #for every cluster in the channel_group
            for clu in model.cluster_ids:
                if clu == 0:
                    continue
                spikeSample=model.spike_samples[model.spike_clusters==clu]  
                cluGroupID=model.cluster_metadata.group(clu)
                if isinstance(cluGroupID, np.ndarray):
                    print("Warning- cluster group of cluster %s is an array (%s), taking first value"%(clu,cluGroupID))
                    cluGroupID=cluGroupID[0]
                if isinstance(cluGroupID, bytes):
                    print("Warning- cluster group of cluster %s is a bytes, putting it in 'unsorted'"%(clu))
                    cluGroupID = 3
                # if a group was created (ID>3), put it in "Good"(2)
                if cluGroupID>3:
                    cluGroupID=2
                cluGroupName=model.default_cluster_groups[cluGroupID]        
                clusterGroup[chgroup][cluGroupName].append(clu)

                unit=Unit(spikeTimes=spikeSample, cluType=cluGroupName, cluChGroup=chgroup, cluId=clu)
                unitDict[chgroup].append(unit)
                if cluGroupID != 3: #if the cluster is not unsorted
                    self.is_curated= True

        self.unitDict=unitDict
        #the cluster lists per group were built but not kept before
        self.clusterGroup=clusterGroup

    def typed_units(self):
        """
        This function returns a dict of units
        ordered by their type (Good, Noise,...)
        {type: [list of Unit]}, only the types of existing units are present
        """
        typedUnits={}
        for unit in itertools.chain.from_iterable(self.unitDict.values()):
            typedUnits.setdefault(unit.type,[]).append(unit)
        return typedUnits
        

In [None]:
#Demo cell: executed only from inside this notebook.
#The guard skips it when the notebook is loaded with "%run this_notebook"
#(the test is on "__file__" being absent from the current namespace).
if "__file__" not in dir():
    root="/data"
    experiment="Rat173_2018_02_25_12_03"

    #nothing is written to disk here (saveAsPickle=False)
    data=Klusta_RawSpikeData(rootFolder=root,experiment=experiment,saveAsPickle=False)
    