In [None]:
# Interactive matplotlib figures in the classic Jupyter Notebook UI.
# NOTE(review): this backend is not available in JupyterLab; `%matplotlib widget` (ipympl) is the usual replacement there.
%matplotlib notebook

# Register and preprocess
Run the commands below from this notebook (prefixed with `!`, e.g. `!expipe register ...`) or from a terminal (without the exclamation mark).

```
!expipe register openephys /Users/ehagen/COBRA/COBRA_exp_data/1010_2018-10-29_18-02-16_1/ --user Espen --location expipehell --overwrite
!expipe register process --probe-path /Users/ehagen/COBRA/COBRA_exp_data/neuronexus-32-linear-list.prb --sorter ironclust 1010-291018-1
!expipe register psychopy 1010-291018-1
!expipe register mousexy  1010-291018-1
```

In [None]:
import os
import pathlib

import numpy as np
import pandas
import scipy.signal as ss
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import quantities as pq
import h5py
import neo
import exdir
import exdir.plugins.git_lfs # stupid
import exdir.plugins.quantities
import elephant.current_source_density as csd
from skimage.measure import block_reduce
#from exana.statistics import correlogram

import expipe
from expipe_plugin_cinpla.imports import project

In [None]:
# Enlarge matplotlib's default figure size (6.4 x 4.8 in) by a factor of 1.5.
# (To start from pristine defaults: plt.rcParams.update(**plt.rcParamsDefault);
#  for sharper figures additionally set 'figure.dpi'.)
plt.rcParams['figure.figsize'] = [6.4 * 1.5, 4.8 * 1.5]

In [None]:
# Display the active expipe configuration (sanity check of which settings are in effect).
expipe.config.settings

In [None]:
# Select the recording (expipe action) to analyse: leave exactly one `action_id` uncommented.
#action_id = '1006-121018-03'
#action_id = '1011-301018-1'
#action_id = '1012-090119-01' # flashes_quick
#action_id = '1012-090119-02' # gratings_quick
#action_id = '1012-090119-04' # flashes_quick
#action_id = '1012-090119-05' # gratings_quick
action_id = '1012-090119-06' # gratings
#action_id = '1012-090119-07' # images
#action_id = '1012-090119-08' # sparsenoise
#action_id = '1012-090119-09'  # static gratings
#action_id = '1012-090119-10' # videos
#action_id = '1012-090119-11' # flashes_quick
# NOTE(review): the next line repeats id '1012-090119-11' with a different stimulus label;
# presumably a typo for another id (e.g. '-12') — verify before using it.
#action_id = '1012-090119-11' # gratings_quick
#project_id = PAR.PROJECT_ID
#project = PAR.PROJECT
action = project.actions[action_id]
# Path of the action's exdir data file (the commented line below is the manual equivalent).
# NOTE(review): `_get_data_path` is a private helper of expipe_plugin_cinpla and may change without notice.
from expipe_plugin_cinpla.scripts.utils import _get_data_path
exdir_path = _get_data_path(action)
#exdir_path = os.path.join(str(action._backend.path), 'data', 'main.exdir')
# Electrode shank (channel group) analysed below; plot titles read 'shank {channel_group}'.
channel_group = 0

In [None]:
# Shared analysis parameters; `par` is a short alias for the very same dict.
# (The correlogram cell below reads `corr_bin_width` and `corr_limit`.)
ANALYSIS_PARAMS = dict(
    speed_filter=5 * pq.m / pq.s,
    pos_fs=100 * pq.Hz,
    f_cut=6 * pq.Hz,
    spat_binsize=0.02 * pq.m,
    spat_smoothing=0.025,
    grid_stepsize=0.1 * pq.m,
    box_xlen=1 * pq.m,
    box_ylen=1 * pq.m,
    ang_binsize=4,
    ang_n_avg_bin=4,
    imgformat='.png',
    corr_bin_width=0.001 * pq.s,
    corr_limit=0.1 * pq.s,
    isi_binsize=1 * pq.ms,
    isi_time_limit=100 * pq.ms,
)
par = ANALYSIS_PARAMS

In [None]:
# Fourier parameters for the power spectra (`ax.psd`): segment length and
# overlap between consecutive segments, both in samples.
NFFT, noverlap = 512, 384

# Stimulus epochs

In [None]:
# Read the visual-stimulus epochs of this action and wrap them in a `neo.Epoch`
# with readable labels; `stimulation_type` records which stimulus group was found.


def _strip_suffix(name, suffix):
    """Return `name` without a trailing `suffix`.

    Unlike ``str.rstrip(suffix)``, which removes any trailing *characters* found in
    `suffix` (``'ping.png'.rstrip('.png') == 'pi'``), only the exact suffix is removed.
    """
    if suffix and name.endswith(suffix):
        return name[:-len(suffix)]
    return name


def _basename_labels(paths, suffix):
    """Array of the basenames of the Windows-style `paths`, each without `suffix`."""
    return np.array([_strip_suffix(pathlib.PureWindowsPath(path).parts[-1], suffix)
                     for path in paths])


f = exdir.File(exdir_path, 'r', plugins=[exdir.plugins.quantities, exdir.plugins.git_lfs])
try:
    stim = f['processing']['epochs']['visual_stimulus']
    stim_keys = list(stim.keys())
    grp = stim[stim_keys[0]]
    if stim_keys == ['image']:
        stimulation_type = 'image'
        # nicer label formatting: image file name without directory and extension
        labels = _basename_labels(grp['image'].value, '.png')
        epo = neo.Epoch(times=grp['times'].value,
                        durations=grp['duration'].value*pq.s,
                        labels=labels)
    elif stim_keys == ['grating']:
        stimulation_type = 'grating'
        annotations = dict(frequency=grp['frequency'].value*pq.Hz,
                           orientation=grp['orientation'].value*pq.deg,
                           phase=grp['phase'].value,
                           spatial_frequency=grp['spatial_frequency'].value
                           )
        df = pandas.DataFrame(annotations)
        # nicer label formatting: one LaTeX string of all grating parameters per trial
        labels = np.array([df.loc[i].to_csv().replace(',', ':').replace('\n', ',') for i in range(df.shape[0])])
        # 'spatial_frequency:' must be replaced before its substring 'frequency:'
        labels = [label.replace('spatial_frequency:', r'\omega=') for label in labels]
        labels = [label.replace('frequency:', 'f=') for label in labels]
        labels = [label.replace('orientation:', r'\alpha=') for label in labels]
        labels = [label.replace('phase:', r'\theta=') for label in labels]
        labels = np.array([r'${}$'.format(label) for label in labels])
        epo = neo.Epoch(times=grp['times'].value,
                        durations=grp['duration'].value*pq.s,
                        labels=labels,
                        **annotations)
    elif stim_keys == ['sparsenoise']:
        stimulation_type = 'sparsenoise'
        epo = neo.Epoch(times=grp['times'].value,
                        durations=grp['duration'].value*pq.s,
                        labels=grp['image'].value)
    elif stim_keys == ['movie']:
        stimulation_type = 'movie'
        # nicer label formatting: movie file name without directory and extension
        labels = _basename_labels(grp['movie'].value, '.mp4')
        epo = neo.Epoch(times=grp['times'].value,
                        durations=np.array([30.]*grp['times'].size)*pq.s, #### HACK!!!! ###### fixed 30 s per movie
                        labels=labels)
    elif stim_keys == ['flash']:
        stimulation_type = 'flash'
        # NOTE(review): 0.25 s is added to every flash duration — presumably a post-flash window; confirm.
        epo = neo.Epoch(times=grp['times'].value,
                        durations=(grp['duration'].value + 0.25)*pq.s,
                        labels=grp['color'].value)
    else:
        # fail here instead of leaving `epo` undefined for the cells below
        raise ValueError('unsupported visual stimulus group(s): {}'.format(stim_keys))
finally:
    f.close()

In [None]:
# Quick look at the detected stimulus type and the epoch onset times and durations.
stimulation_type , epo.times, epo.durations #, seg.analogsignals

# Spikes
Load all spike trains that are not labelled as noise as a list of `neo.SpikeTrain` objects

In [None]:
# Load every sorted unit of `channel_group` that is not labelled 'noise' as a
# `neo.SpikeTrain` (name = cluster id, description = cluster_group label).
# With `load_waveforms`, each train also carries its spike waveforms, read via neo's ExdirIO.
load_waveforms = True
f = exdir.File(exdir_path, 'r', plugins=[exdir.plugins.quantities, exdir.plugins.git_lfs])
try:
    t_stop = f.attrs['session_duration']
    UnitTimes = f['processing']['electrophysiology']['channel_group_{}'.format(channel_group)]['UnitTimes']
    units = np.sort(list(UnitTimes.keys()))
    if load_waveforms:
        io = neo.ExdirIO(str(exdir_path))
        blk = io.read_block()
        # NOTE(review): neo channel index 0 and segment 0 are assumed to hold the units of
        # `channel_group` in matching order — verify before analysing a group other than 0.
        blk_cluster_ids = np.array([str(unit.annotations['cluster_id'])
                                    for unit in blk.channel_indexes[0].units])
    spiketrains = []
    for unit in units:
        cluster_group = UnitTimes[unit].attrs['cluster_group']
        if cluster_group == 'noise':
            continue
        waveforms = None
        if load_waveforms:
            matches = np.flatnonzero(blk_cluster_ids == unit)
            if matches.size == 0:
                raise KeyError('cluster {} not found in the neo block'.format(unit))
            waveforms = blk.segments[0].spiketrains[matches[0]].waveforms
        spiketrains.append(neo.SpikeTrain(UnitTimes[unit]['times'].value, t_stop=t_stop,
                                          name=unit, description=cluster_group,
                                          waveforms=waveforms))
finally:
    f.close()

# Waveforms

In [None]:
# def some plotting functions
def remove_axis_junk(ax, lines=('right', 'top')):
    """Hide the listed spines of `ax` and keep ticks only on the bottom and left axes.

    Parameters
    ----------
    ax : matplotlib Axes
    lines : iterable of str
        Spine names (e.g. 'right', 'top') to make invisible.
    """
    for loc, spine in ax.spines.items():
        if loc in lines:
            spine.set_color('none')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')


def draw_lineplot(
        ax, data, dt=0.1,
        T=(0, 200),
        scaling_factor=1.,
        vlimround=None,
        label='local',
        scalebar=True,
        unit='mV',
        ylabels=True,
        color='k',
        ztransform=True,
        filter=False,
        filterargs=dict(N=2, Wn=0.02, btype='lowpass')
        ):
    """Draw each row of `data` as a vertically offset trace, one line per channel.

    Parameters
    ----------
    ax : matplotlib Axes
    data : array, shape (n_channels, n_samples)
        Row ``i`` is drawn around ``y = i``.
    dt : float
        Sample interval; the time axis is ``arange(n_samples) * dt``, shifted by
        ``T[0]`` when ``T[0] < 0`` (e.g. for waveform snippets).
    T : (float, float)
        Time window to draw.
    scaling_factor : float
        Divides the automatically chosen `vlimround`.
    vlimround : float or None
        Data value corresponding to one channel spacing. None: the power of two
        nearest to ``max|data|`` within `T` (1 if that maximum is 0).
    label : str
        Legend label, attached to the first trace only.
    scalebar : bool
        Draw a vertical bar of height `vlimround` at the right end of the traces.
    unit : str
        Unit text next to the scale bar.
    ylabels : bool
        Label the y-ticks 'ch. 1', 'ch. 2', ...
    color : matplotlib color
    ztransform : bool
        Subtract each channel's mean before drawing.
    filter : bool
        Apply ``scipy.signal.filtfilt`` with ``scipy.signal.butter(**filterargs)`` along time.
    filterargs : dict

    Returns
    -------
    vlimround : float

    Raises
    ------
    TypeError
        If `T` cannot be compared with the time axis.
    """
    tvec = np.arange(data.shape[1]) * dt
    if T[0] < 0:
        tvec += T[0]
    try:
        tinds = (tvec >= T[0]) & (tvec <= T[1])
    except TypeError as err:
        raise TypeError('cannot select time window T={!r} for data of shape {}'
                        .format(T, data.shape)) from err

    # apply temporal filter
    if filter:
        b, a = ss.butter(**filterargs)
        data = ss.filtfilt(b, a, data, axis=-1)

    # subtract mean in each channel
    if ztransform:
        data = (data.T - data.mean(axis=1)).T

    zvec = np.arange(data.shape[0])
    if vlimround is None:
        vlim = abs(data[:, tinds]).max()
        # flat data would give log2(0) = -inf and a division by zero below
        vlimround = (2. ** np.round(np.log2(vlim)) if vlim > 0 else 1.) / scaling_factor

    for i, z in enumerate(zvec):
        extra = {'label': label} if i == 0 else {}
        ax.plot(tvec[tinds], data[i][tinds] / vlimround + z, lw=.5,
                rasterized=False, clip_on=False, color=color, **extra)
    yticks = list(zvec)
    yticklabels = ['ch. %i' % (i + 1) for i in range(len(zvec))]

    if scalebar:
        t_end = tvec[tinds][-1]
        z_top = zvec[-1]
        # the bar spans one channel spacing (= vlimround data units); works for one channel too
        ax.plot([t_end, t_end], [z_top, z_top - 1], lw=2, color='k', clip_on=False)
        ax.text(t_end, z_top - 0.5,
                '$2^{{{}}}$ {}'.format(int(np.log2(vlimround)), unit),
                color='r', rotation='vertical',
                va='center', zorder=100)

    ax.axis(ax.axis('tight'))
    if ylabels:
        ax.yaxis.set_ticks(yticks)
        ax.yaxis.set_ticklabels(yticklabels)
    else:
        ax.yaxis.set_ticklabels([])
    remove_axis_junk(ax, lines=['right', 'top'])
    ax.set_xlabel(r't (ms)', labelpad=0.1)

    return vlimround

In [None]:
# Mean waveform (black) +/- 2 SD (grey) of each unit on every channel, one panel per unit.
# NOTE(review): dt=1/30 presumably means 30 kHz sampling with time in ms — confirm.
if load_waveforms and len(spiketrains) > 0:
    vlimround = 2**6
    # squeeze=False keeps the axes array iterable even for a single unit
    fig, axes = plt.subplots(1, len(spiketrains), sharey=True, squeeze=False)
    axes = axes[0]
    for i, (ax, spiketrain) in enumerate(zip(axes, spiketrains)):
        mean = spiketrain.waveforms.mean(axis=0)
        std = spiketrain.waveforms.std(axis=0)
        draw_lineplot(ax, mean, dt=1./30, T=(-0.5, 1.5), vlimround=vlimround, unit=mean.dimensionality)
        draw_lineplot(ax, mean+std*2, dt=1./30, T=(-.5, 1.5), vlimround=vlimround, scalebar=False, color='0.5')
        draw_lineplot(ax, mean-std*2, dt=1./30, T=(-.5, 1.5), vlimround=vlimround, scalebar=False, color='0.5')
        ax.set_title('#{}'.format(spiketrain.name))
        if i != 0:
            # shared y axis: only the first panel keeps the channel labels
            plt.setp(ax.get_yticklabels(), visible=False)

# LFP, MUA and CSD
LFPs and MUAs, plus the CSD reconstructed from the LFP, all represented as `neo.AnalogSignal` objects

In [None]:
# Load LFP and MUA of `channel_group` as (time x channel) `neo.AnalogSignal`s in mV,
# with columns ordered by electrode index.
# NOTE(review): `neo.AnalogSignal` has no `t_stop` argument; neo presumably keeps it as an annotation.
f = exdir.File(exdir_path, 'r', plugins=[exdir.plugins.quantities, exdir.plugins.git_lfs])
try:
    _ephys = f['processing']['electrophysiology']['channel_group_{}'.format(channel_group)]

    # LFP
    _lfp = _ephys['LFP']
    keys = list(_lfp.keys())
    electrode_idx = [_lfp[key].attrs['electrode_idx'] for key in keys]
    sampling_rate = _lfp[keys[0]].attrs['sample_rate']
    units = _lfp[keys[0]]['data'].attrs['unit']
    # one stored trace per key -> stack as rows, transpose to (time, channel)
    LFP = np.r_[[_lfp[key]['data'].value.flatten() for key in keys]].T
    # optional re-referencing (disabled):
    #LFP = (LFP.T - np.median(np.array(LFP), axis=-1)).T #CMR reference
    #LFP = (LFP.T - LFP[:, 0]).T # use topmost channel as reference
    LFP = LFP[:, np.argsort(electrode_idx)]
    LFP = neo.AnalogSignal(LFP,
                           units=units, t_stop=t_stop, sampling_rate=sampling_rate)
    LFP = LFP.rescale('mV')

    # MUA
    _mua = _ephys['MUA']
    keys = list(_mua.keys())
    try:
        # sort by the MUA's own electrode indices: its keys need not be in LFP order
        mua_electrode_idx = [_mua[key].attrs['electrode_idx'] for key in keys]
    except KeyError:
        # NOTE(review): MUA entries without 'electrode_idx' fall back to the LFP ordering
        mua_electrode_idx = electrode_idx
    sampling_rate = _mua[keys[0]].attrs['sample_rate']
    units = _mua[keys[0]]['data'].attrs['unit']
    MUA = np.r_[[_mua[key]['data'].value.flatten() for key in keys]].T
    MUA = MUA[:, np.argsort(mua_electrode_idx)]
    MUA = neo.AnalogSignal(MUA,
                           units=units, t_stop=t_stop, sampling_rate=sampling_rate)
    MUA = MUA.rescale('mV')
finally:
    f.close()

In [None]:
# CSD estimated from the LFP with the step-iCSD method (elephant).
# Geometry: contacts on a line 25 um apart, step height 25 um each, source diameter 100 um.
# NOTE(review): these numbers are hard-coded — confirm against the probe layout.
coords = 25 * pq.um * np.arange(LFP.shape[1]).reshape(-1, 1)
h = 25 * pq.um * np.ones(coords.size)
diam = 100 * pq.um
CSD = csd.estimate_csd(LFP, coords=coords, h=h, diam=diam, method='StepiCSD')
CSD = CSD.rescale('uA/mm**3')

# MouseXY tracking

In [None]:
# Collect the trackball position entries of this action into `tracks` (entry name -> exdir object).
# NOTE(review): the dict stores exdir objects, not loaded arrays, and the file is closed right
# after; if reading them later fails, load the data into memory before `f.close()`.
f = exdir.File(exdir_path, 'r', plugins=[exdir.plugins.quantities])
tracks = dict()
for key, value in f['processing']['tracking']['trackball']['position'].items():
    tracks[key] = value
f.close()

# Spiketrain analysis

In [None]:
def raster_plot(ax, spiketrains, T=[0., 10.], epo=None):
    '''
    Draw a spike raster with one row per spike train.

    Arguments
    ---------
    ax : matplotlib.axes._subplots.AxesSubplot
        Axes to draw into.
    spiketrains : list of neo.SpikeTrain objects
        Row ``i`` shows train ``i`` as '|' markers; its y tick reads
        '<name> (<description>)'.
    T : length 2 list/tuple of floats
        x-limits (time interval) in seconds.
    epo : None or neo.Epoch object
        If given, stimulus onsets (green) and offsets (red) are drawn as
        vertical lines spanning the tight y-range.
    '''
    tick_labels = ['{} ({})'.format(train.name, train.description)
                   for train in spiketrains]
    for row, train in enumerate(spiketrains):
        ax.plot(train, np.zeros(train.size) + row, 'C0|')

    if epo is not None:
        _, _, y_low, y_high = ax.axis('tight')
        onsets = epo.times
        ax.vlines(onsets, y_low, y_high, 'g')
        ax.vlines(onsets + epo.durations, y_low, y_high, 'r')

    ax.set_yticks(range(len(spiketrains)))
    ax.set_yticklabels(tick_labels)
    ax.set_ylabel('unit id')
    ax.set_xlim(T)
    ax.set_xlabel('t (s)')
    ax.set_title('spike raster')

In [None]:
# Raster of all units between 30 s and 40 s, with stimulus onsets/offsets.
fig, ax = plt.subplots()
raster_plot(ax, spiketrains, epo=epo, T=[30., 40.])

In [None]:
# Interspike-interval (ISI) distributions on logarithmic bins, one panel per unit;
# units whose curation label (`description`) contains 'good' are drawn in green.
nrows = max(1, int(np.ceil(np.sqrt(len(spiketrains)))))
ncols = max(1, int(np.ceil(len(spiketrains) / nrows)))
fig = plt.figure()
gs = GridSpec(nrows, ncols)
fig.subplots_adjust(wspace=0.4, hspace=0.4)
bins = 10**np.linspace(-3, 1, 51) # 50 log-spaced bins between 1 ms and 10 s
# NOTE: overwrites the stimulus `labels` of the epoch cell with the unit names
labels = [spiketrain.name for spiketrain in spiketrains]
for i, sptr in enumerate(spiketrains):
    ax = fig.add_subplot(gs[i // ncols, i % ncols])
    # `name` is only the cluster id; the curation label ('good', ...) is in `description`
    ax.hist(np.diff(sptr), bins=bins, color='g' if 'good' in str(sptr.description) else None)
    ax.semilogx()
    ax.set_xlim(bins.min(), bins.max())
    ax.set_title('unit {} ({})'.format(sptr.name, sptr.description))
    if i % ncols == 0:
        ax.set_ylabel('#')
    if i < len(spiketrains) - ncols:
        # another panel sits below this one: hide its x tick labels
        ax.set_xticklabels([])
    else:
        ax.set_xlabel('ISI (s)')

In [None]:
# Interspike-interval (ISI) distributions on linear bins up to 50 ms, one panel per unit;
# units whose curation label (`description`) contains 'good' are drawn in green.
nrows = max(1, int(np.ceil(np.sqrt(len(spiketrains)))))
ncols = max(1, int(np.ceil(len(spiketrains) / nrows)))
fig = plt.figure()
gs = GridSpec(nrows, ncols)
fig.subplots_adjust(wspace=0.4, hspace=0.4)
bins = np.linspace(1, 50, 51) # 50 linear bins (0.98 ms wide) between 1 ms and 50 ms
for i, sptr in enumerate(spiketrains):
    ax = fig.add_subplot(gs[i // ncols, i % ncols])
    # ISIs converted from s to ms (assumes spike times in seconds)
    ax.hist(np.diff(sptr)*1E3, bins=bins, color='g' if 'good' in str(sptr.description) else None)
    ax.set_title('unit {} ({})'.format(sptr.name, sptr.description))
    if i % ncols == 0:
        ax.set_ylabel('#')
    if i < len(spiketrains) - ncols:
        # another panel sits below this one: hide its x tick labels
        ax.set_xticklabels([])
    else:
        ax.set_xlabel('ISI (ms)')

In [None]:
def correlogram(t1, t2=None, binsize=.001, limit=.02, auto=False,
                density=False):
    """Return crosscorrelogram of two spike trains.

    Essentially, this algorithm subtracts each spike time in `t1`
    from all of `t2` and bins the results with np.histogram, though
    several tweaks were made for efficiency.
    Originally authored by Chris Rodger, copied from OpenElectrophy, licenced
    with CeCill-B. Examples and testing written by exana team.

    Parameters
    ----------
    t1 : np.array, or neo.SpikeTrain
        First spiketrain, raw spike times in seconds.
    t2 : np.array, or neo.SpikeTrain
        Second spiketrain, raw spike times in seconds. Ignored (replaced by
        `t1`) when `auto` is True.
    binsize : float, or quantities.Quantity
        Width of each bar in histogram in seconds.
    limit : float, or quantities.Quantity
        Positive and negative extent of histogram, in seconds. Must be an
        integer multiple of `binsize` (checked at 1e-10 s resolution).
    auto : bool
        If True, then returns autocorrelogram of `t1` and in
        this case `t2` can be None. The bin containing zero lag is then
        set to 0. Default is False.
    density : bool
        If True, then returns the probability density function
        (passed on as ``np.histogram(..., density=True)``).

    See also
    --------
    :func:`numpy.histogram` : The histogram function in use.

    Returns
    -------
    (count, bins) : tuple
        `count` is the count (or density) of spike-time differences per bin;
        `bins` holds the right edge of each bin (same length as `count`).

    Raises
    ------
    ValueError
        If `limit` is not a multiple of `binsize`.

    Note
    ----
    `bins` are relative to `t1`. That is, if `t1` leads `t2`, then
    `count` will peak in a positive time bin.

    Examples
    --------
    >>> t1 = np.arange(0, .5, .1)
    >>> t2 = np.arange(0.1, .6, .1)
    >>> limit = 1
    >>> binsize = .1
    >>> counts, bins = correlogram(t1=t1, t2=t2, binsize=binsize,
    ...                            limit=limit, auto=False)
    >>> counts
    array([0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 0, 0, 0])

    The peak of 5 counts reflects the five pairs in which a `t2` spike
    follows a `t1` spike by 0.1 s, i.e.

    # TODO fix
    # >>> idx = np.argmax(counts)
    # >>> '%.1f, %.1f' % (abs(bins[idx - 1]), bins[idx])
    # '0.0, 0.1'

    The correlogram algorithm is identical to, but computationally faster than
    the histogram of differences of each timepoint, i.e.

    # TODO Fix the doctest
    # >>> diff = [t2 - t for t in t1]
    # >>> counts2, bins = np.histogram(diff, bins=bins)
    # >>> np.array_equal(counts2, counts)
    # True
    """
    if auto: t2 = t1
    # For auto-CCGs, make sure we use the same exact values
    # Otherwise numerical issues may arise when we compensate for zeros later

    # Divisibility check on integers (units of 1e-10 s) to avoid float modulo errors
    if not int(limit * 1e10) % int(binsize * 1e10) == 0:
        raise ValueError(
            'Time limit {} must be a '.format(limit) +
            'multiple of binsize {}'.format(binsize) +
            ' remainder = {}'.format(limit % binsize))
    # For efficiency, `t1` should be no longer than `t2`
    swap_args = False
    if len(t1) > len(t2):
        swap_args = True
        t1, t2 = t2, t1

    # Sort both arguments (this takes negligible time)
    t1 = np.sort(t1)
    t2 = np.sort(t2)

    # Determine the bin edges for the histogram
    # Later we will rely on the symmetry of `bins` for undoing `swap_args`
    limit = float(limit)

    # The numpy.arange method overshoots slightly the edges i.e. binsize + epsilon
    # which leads to inclusion of spikes falling on edges.
    bins = np.arange(-limit, limit + binsize, binsize)

    # Determine the indexes into `t2` that are relevant for each spike in `t1`
    ii2 = np.searchsorted(t2, t1 - limit)
    jj2 = np.searchsorted(t2, t1 + limit)

    # Concatenate the recentered spike times into a big array
    # We have excluded spikes outside of the histogram range to limit
    # memory use here.
    big = np.concatenate([t2[i:j] - t for t, i, j in zip(t1, ii2, jj2)])

    # Actually do the histogram. Note that calls to np.histogram are
    # expensive because it does not assume sorted data.
    count, bins = np.histogram(big, bins=bins, density=density)

    if auto:
        # Compensate for the peak at time zero that results in autocorrelations
        # by subtracting the total number of spikes from that bin. Note
        # possible numerical issue here because 0.0 may fall at a bin edge.
        c_temp, bins_temp = np.histogram([0.], bins=bins)
        bin_containing_zero = np.nonzero(c_temp)[0][0]
        # NOTE: the whole zero-lag bin is zeroed (rather than decremented by len(t1),
        # the commented-out alternative), so genuine near-zero-lag coincidences are dropped too.
        count[bin_containing_zero] = 0#-= len(t1)

    # Finally compensate for the swapping of t1 and t2
    if swap_args:
        # Here we rely on being able to simply reverse `counts`. This is only
        # possible because of the way `bins` was defined (bins = -bins[::-1])
        count = count[::-1]

    # drop the leftmost edge: one right edge per count
    return count, bins[1:]

In [None]:
# Spike-train correlograms in the upper triangle of a unit x unit grid;
# diagonal panels are autocorrelograms (zero-lag bin removed by `correlogram`).
bin_width = par['corr_bin_width'].rescale('s').magnitude
limit = par['corr_limit'].rescale('s').magnitude
fig = plt.figure()
gs = GridSpec(len(spiketrains), len(spiketrains))
for i, t1 in enumerate(spiketrains):
    for j, t2 in enumerate(spiketrains):
        if j < i:
            # lower triangle mirrors the upper one
            continue
        count, bins = correlogram(t1=t1.as_array(), t2=t2.as_array(),
                                  binsize=bin_width, limit=limit,
                                  auto=(i == j))
        ax = fig.add_subplot(gs[i, j])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        # `bins` are right bin edges: draw each bar to the left of its edge
        ax.bar(bins, count, width=-bin_width, align='edge')
        # avoid a singular (0, 0) y-range for empty correlograms
        ax.set_ylim(0, max(count.max(), 1))
        ax.set_xlim([-limit, limit])
        if i == 0:
            ax.set_title('#{}'.format(t2.name), fontsize='medium')
        if i == j:
            ax.set_ylabel('#{}'.format(t1.name))

# Time-series analysis

In [None]:
# def some plotting functions
def remove_axis_junk(ax, lines=('right', 'top')):
    """Hide the listed spines of `ax` and keep ticks only on the bottom and left axes.

    Parameters
    ----------
    ax : matplotlib Axes
    lines : iterable of str
        Spine names (e.g. 'right', 'top') to make invisible.
    """
    for loc, spine in ax.spines.items():
        if loc in lines:
            spine.set_color('none')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')


def draw_lineplot(
        ax, data, dt=0.1,
        T=(0, 200),
        scaling_factor=1.,
        vlimround=None,
        label='local',
        scalebar=True,
        unit='mV',
        ylabels=True,
        color='k',
        ztransform=True,
        filter=False,
        filterargs=dict(N=2, Wn=0.02, btype='lowpass')
        ):
    """Draw each row of `data` as a vertically offset trace, one line per channel.

    Parameters
    ----------
    ax : matplotlib Axes
    data : array, shape (n_channels, n_samples)
        Row ``i`` is drawn around ``y = i``.
    dt : float
        Sample interval; the time axis is ``arange(n_samples) * dt``, shifted by
        ``T[0]`` when ``T[0] < 0``.
    T : (float, float)
        Time window to draw.
    scaling_factor : float
        Divides the automatically chosen `vlimround`.
    vlimround : float or None
        Data value corresponding to one channel spacing. None: the power of two
        nearest to ``max|data|`` within `T` (1 if that maximum is 0).
    label : str
        Legend label, attached to the first trace only.
    scalebar : bool
        Draw a vertical bar of height `vlimround` at the right end of the traces.
    unit : str
        Unit text next to the scale bar.
    ylabels : bool
        Label the y-ticks 'ch. 1', 'ch. 2', ... (ticks are set either way).
    color : matplotlib color
    ztransform : bool
        Subtract each channel's mean before drawing.
    filter : bool
        Apply ``scipy.signal.filtfilt`` with ``scipy.signal.butter(**filterargs)`` along time.
    filterargs : dict

    Returns
    -------
    vlimround : float

    Raises
    ------
    TypeError
        If `T` cannot be compared with the time axis.
    """
    tvec = np.arange(data.shape[1]) * dt
    if T[0] < 0:
        tvec += T[0]
    try:
        tinds = (tvec >= T[0]) & (tvec <= T[1])
    except TypeError as err:
        raise TypeError('cannot select time window T={!r} for data of shape {}'
                        .format(T, data.shape)) from err

    # apply temporal filter
    if filter:
        b, a = ss.butter(**filterargs)
        data = ss.filtfilt(b, a, data, axis=-1)

    # subtract mean in each channel
    if ztransform:
        data = (data.T - data.mean(axis=1)).T

    zvec = np.arange(data.shape[0])
    if vlimround is None:
        vlim = abs(data[:, tinds]).max()
        # flat data would give log2(0) = -inf and a division by zero below
        vlimround = (2. ** np.round(np.log2(vlim)) if vlim > 0 else 1.) / scaling_factor

    for i, z in enumerate(zvec):
        extra = {'label': label} if i == 0 else {}
        ax.plot(tvec[tinds], data[i][tinds] / vlimround + z, lw=.5,
                rasterized=False, clip_on=False, color=color, **extra)
    yticks = list(zvec)
    yticklabels = ['ch. %i' % (i + 1) for i in range(len(zvec))]

    if scalebar:
        t_end = tvec[tinds][-1]
        z_top = zvec[-1]
        # the bar spans one channel spacing (= vlimround data units); works for one channel too
        ax.plot([t_end, t_end], [z_top, z_top - 1], lw=2, color='k', clip_on=False)
        ax.text(t_end, z_top - 0.5,
                '$2^{{{}}}$ {}'.format(int(np.log2(vlimround)), unit),
                color='r', rotation='vertical',
                va='center', zorder=100)

    ax.axis(ax.axis('tight'))
    ax.yaxis.set_ticks(yticks)
    if ylabels:
        ax.yaxis.set_ticklabels(yticklabels)
    else:
        ax.yaxis.set_ticklabels([])
    remove_axis_junk(ax, lines=['right', 'top'])
    ax.set_xlabel(r'time (ms)', labelpad=0.1)

    return vlimround

def draw_imageplot(
        ax, data, dt=0.1,
        T=(0, 200),
        scaling_factor=1.,
        label='local',
        unit='mV',
        ylabels=True,
        vmax=None,
        color='k',
        colorbar=True,
        ztransform=False,
        filter=False,
        filterargs=dict(N=2, Wn=0.02, btype='lowpass')):
    """Draw `data` as a colour image (time on x, channel on y, channel 0 at the bottom).

    Parameters
    ----------
    ax : matplotlib Axes
    data : array, shape (n_channels, n_samples)
    dt : float
        Sample interval; the time axis is ``arange(n_samples) * dt``, shifted by
        ``T[0]`` when ``T[0] < 0``.
    T : (float, float)
        Time window to draw.
    scaling_factor, label, color :
        Not used; accepted for call compatibility with `draw_lineplot`.
    unit : str
        Colour-bar label.
    ylabels : bool
        Label the y-ticks 'ch. 1', 'ch. 2', ... (ticks are set either way).
    vmax : float or None
        Colour limits are (-vmax, vmax); None uses 2.5 standard deviations of `data`.
    colorbar : bool
        Add a colour bar in a new axes just right of `ax`, in the figure owning `ax`.
    ztransform : bool
        Subtract each channel's mean before drawing.
    filter : bool
        Apply ``scipy.signal.filtfilt`` with ``scipy.signal.butter(**filterargs)`` along time.
    filterargs : dict

    Raises
    ------
    TypeError
        If `T` cannot be compared with the time axis.
    """
    tvec = np.arange(data.shape[1]) * dt
    if T[0] < 0:
        tvec += T[0]
    try:
        tinds = (tvec >= T[0]) & (tvec <= T[1])
    except TypeError as err:
        raise TypeError('cannot select time window T={!r} for data of shape {}'
                        .format(T, data.shape)) from err

    # apply temporal filter
    if filter:
        b, a = ss.butter(**filterargs)
        data = ss.filtfilt(b, a, data, axis=-1)

    # subtract mean in each channel
    if ztransform:
        data = (data.T - data.mean(axis=1)).T

    zvec = np.arange(data.shape[0])
    yticks = list(zvec)
    yticklabels = ['ch. %i' % (i + 1) for i in range(len(zvec))]

    # symmetric colour range around zero
    if vmax is None:
        vmin = -data.std()*2.5
        vmax = data.std()*2.5
    else:
        vmin = -vmax

    # plot data as image
    im = ax.imshow(data[:, tinds],
                   extent=(tvec[tinds][0], tvec[tinds][-1], zvec[0]-0.5, zvec[-1]+0.5),
                   vmin=vmin, vmax=vmax,
                   zorder=-1, cmap='bwr_r', interpolation='bilinear',
                   origin='lower')

    if colorbar:
        # use the figure that owns `ax`; the global `fig` may belong to another figure
        owner = ax.figure
        left, _, width, _ = ax.get_position().bounds
        cax = owner.add_axes([left + width + 0.01, 0.3, 0.01, 0.4])
        cbar = owner.colorbar(im, cax=cax)
        cbar.set_label(unit)

    ax.axis(ax.axis('tight'))
    ax.yaxis.set_ticks(yticks)
    if ylabels:
        ax.yaxis.set_ticklabels(yticklabels)
    else:
        ax.yaxis.set_ticklabels([])
    remove_axis_junk(ax, lines=['right', 'top'])
    ax.set_xlabel(r'time (ms)', labelpad=0.1)

    return


In [None]:
# Pairwise Pearson correlation coefficient between LFP channels (channel index 0-based).
fig, ax = plt.subplots(1, 1)
im = ax.imshow(np.corrcoef(LFP.T), vmin=-1, vmax=1)
ax.axis(ax.axis('equal'))
ax.set_xlabel('channel index')
ax.set_ylabel('channel index')
ax.set_title('LFP channel correlation, shank {}'.format(channel_group))
fig.colorbar(im, ax=ax, label='Pearson r')

In [None]:
# LFP traces of all channels between 30 and 40 s (time axis in ms);
# green / red vertical lines mark stimulus onsets / offsets.
# NOTE(review): dt=1. (ms per sample) assumes a 1 kHz LFP sampling rate — confirm.
fig, ax = plt.subplots()
draw_lineplot(ax, np.array(LFP).T, dt=1., T=(30000, 40000), unit=str(LFP.units))
axis = ax.axis()
for edge_times, edge_color in ((epo.times, 'g'), (epo.times + epo.durations, 'r')):
    ax.vlines(edge_times * 1000, axis[2], axis[3], edge_color)
ax.set_title('LFP shank {}'.format(channel_group))

In [None]:
# Power spectral density of every LFP channel, overlaid (NFFT/noverlap from the Fourier params cell).
fig, ax = plt.subplots()
for channel_trace in LFP.T:
    ax.psd(np.array(channel_trace), Fs=float(LFP.sampling_rate.simplified),
           NFFT=NFFT, noverlap=noverlap)

In [None]:
# LFP traces drawn over the CSD image, 30-40 s (time axis in ms);
# green / red vertical lines mark stimulus onsets / offsets.
fig, ax = plt.subplots()
window = (30000, 40000)
draw_lineplot(ax, np.array(LFP).T, dt=1., T=window, unit=str(LFP.units))
draw_imageplot(ax, np.array(CSD).T, dt=1., T=window, unit=str(CSD.units))
axis = ax.axis()
for edge_times, edge_color in ((epo.times, 'g'), (epo.times + epo.durations, 'r')):
    ax.vlines(edge_times * 1000, axis[2], axis[3], edge_color)
ax.set_title('LFP/CSD shank {}'.format(channel_group))

In [None]:
# CSD traces drawn over the CSD image, 30-40 s (time axis in ms);
# green / red vertical lines mark stimulus onsets / offsets.
fig, ax = plt.subplots()
csd_rows = np.array(CSD).T
draw_lineplot(ax, csd_rows, dt=1., T=(30000, 40000), unit=str(CSD.units))
draw_imageplot(ax, csd_rows, dt=1., T=(30000, 40000), unit=str(CSD.units))
axis = ax.axis()
for edge_times, edge_color in ((epo.times, 'g'), (epo.times + epo.durations, 'r')):
    ax.vlines(edge_times * 1000, axis[2], axis[3], edge_color)
ax.set_title('CSD shank {}'.format(channel_group))

In [None]:
# MUA traces of all channels between 30 and 40 s (time axis in ms);
# green / red vertical lines mark stimulus onsets / offsets.
fig, ax = plt.subplots()
draw_lineplot(ax, np.array(MUA).T, dt=1., T=(30000, 40000), unit=str(MUA.units))
axis = ax.axis()
for edge_times, edge_color in ((epo.times, 'g'), (epo.times + epo.durations, 'r')):
    ax.vlines(edge_times * 1000, axis[2], axis[3], edge_color)
ax.set_title('MUA shank {}'.format(channel_group))

In [None]:
# Mean and standard deviation of the MUA (already rectified) on each channel.
# Not identical, but inspired by Senzai et al.,
# Layer-Specific Physiological Features and Interlaminar Interactions in the
# Primary Visual Cortex of the Mouse, Neuron (2018), https://doi.org/10.1016/j.neuron.2018.12.009
fig, ax = plt.subplots(1, 1)
channels = np.arange(MUA.shape[1])
ax.plot(MUA.mean(axis=0), channels, label=r'$\overline{\mathrm{MUA}}$')
ax.plot(MUA.std(axis=0), channels, label=r'$\sigma_\mathrm{MUA}$')
ax.set_yticks(channels)
ax.set_yticklabels(['ch. {}'.format(i + 1) for i in channels])
ax.legend(loc='best')
# keep '&' outside math mode: it is not a valid mathtext symbol
ax.set_xlabel(r'$\overline{\mathrm{MUA}}$ & $\sigma_\mathrm{MUA}$ (' + str(MUA.dimensionality) + ')')
ax.set_title('MUA mean and SD across channels')

## Whitened LFPs

In [None]:
# Compare the band-passed LFP with its whitened (spatially decorrelated) version.
# Filter coefficients: 1st-order Butterworth band-pass, 1-50 Hz. Wn is normalised to
# Nyquist using the sampling rate converted to Hz, so other units (e.g. kHz) also work.
lfp_fs_hz = float(LFP.sampling_rate.rescale('Hz'))
b, a = ss.butter(N=1, Wn=np.array([1., 50.]) / lfp_fs_hz * 2, btype='bandpass')
# band-passed (zero-phase) LFP
fLFP = neo.AnalogSignal(ss.filtfilt(b, a, LFP, axis=0), units=LFP.units, sampling_rate=LFP.sampling_rate)
# Whitening matrix W = E diag(w**-1/2) E^T of the channel covariance M, scaled by the
# number of channels. M is symmetric, so eigh gives real eigenvalues and orthonormal
# eigenvectors (eig may return complex or non-orthogonal ones).
# NOTE(review): a rank-deficient M (e.g. after common referencing) gives w <= 0 -> inf/nan in W.
M = np.cov(fLFP.T)
w, E = np.linalg.eigh(M)
W = (E * w**-0.5) @ E.T * LFP.shape[1]
# compute the whitened filtered LFP, (time x channel)
wfLFP = np.dot(W, fLFP.T).T

fig, ax = plt.subplots(1, 1)
draw_lineplot(ax, np.array(fLFP.T), dt=1., T=(30000, 40000), unit=str(LFP.dimensionality))
axis = ax.axis()
ax.vlines(epo.times*1000, axis[2], axis[3], 'g')
ax.vlines((epo.times+epo.durations)*1000, axis[2], axis[3], 'r')
ax.set_title('band-passed LFP shank {}'.format(channel_group))

fig, ax = plt.subplots(1, 1)
# whitened data has arbitrary units
draw_lineplot(ax, np.array(wfLFP.T), dt=1., T=(30000, 40000), unit='a.u.')
axis = ax.axis()
ax.vlines(epo.times*1000, axis[2], axis[3], 'g')
ax.vlines((epo.times+epo.durations)*1000, axis[2], axis[3], 'r')
ax.set_title('whitened band-passed LFP shank {}'.format(channel_group))

In [None]:
# plot the whitening matrix W (channel x channel) computed in the whitening cell
fig, ax = plt.subplots(1, 1)
im = ax.matshow(W)
ax.axis(ax.axis('equal'))
fig.colorbar(im)
ax.set_title('whitening matrix W')
ax.set_xlabel('channel index')
ax.set_ylabel('channel index')

# CSD/LFP/MUA/wLFP of different trials

In [None]:
# Single-trial, trial-mean and trial-std responses of one signal, one figure per stimulus label.
# Choose the signal and the line-plot scaling (vlimround).
data = 'LFP'  # one of 'CSD', 'LFP', 'fLFP', 'MUA', 'wfLFP', 'wLFP'
if data == 'CSD':
    signal = CSD
    vlimround = 2**6
elif data == 'LFP':
    signal = LFP
    vlimround = 2**-2
elif data == 'fLFP':
    signal = fLFP
    vlimround = 2**-2
elif data == 'MUA':
    signal = MUA
    vlimround = 2**-4
elif data == 'wfLFP':
    signal = wfLFP
    vlimround = 2**6
elif data == 'wLFP':
    signal = wLFP
    vlimround = 2**6
else:
    # fail loudly instead of silently reusing a `signal` left over from another cell
    raise ValueError('unknown data selection: {!r}'.format(data))

prestim_duration = 200*pq.ms

if stimulation_type in ['image', 'grating', 'flash']:
    # show at most the first 10 stimulus labels
    labels_ = np.unique(epo.labels)[:10]
    # baseline window length in samples (depends on the sampling rate, not a fixed 1 kHz)
    n_prestim = int((prestim_duration * signal.sampling_rate).simplified)
    # trial window (prestimulus + first stimulus duration) in samples
    n_samples = int(((prestim_duration + epo.durations) * signal.sampling_rate).simplified[0])
    # time range in ms handed to the plotting helpers, matching the trial window
    T_ms = (-float(prestim_duration.rescale('ms')), float(epo.durations[0].rescale('ms')))
    # single-trial columns: average number of trials per label, capped at 8;
    # two extra columns hold the mean and the std across trials
    n_single = min(epo.labels.size // np.unique(epo.labels).size, 8)
    ncols = n_single + 2

    for label in labels_:
        inds = epo.labels == label
        # trial matrix (time, channel, trial); slices longer than the window are truncated
        container = np.zeros((n_samples, signal.shape[1], inds.sum()))
        for j, (time, duration) in enumerate(zip(epo.times[inds], epo.durations[inds])):
            time_slice = signal.time_slice(time-prestim_duration, time+duration)
            container[:, :, j] = np.asarray(time_slice)[:n_samples]
        # center every channel of every trial on its prestimulus mean
        container = container - container[:n_prestim].mean(axis=0)
        trial_mean = container.mean(axis=-1)
        trial_std = container.std(axis=-1)
        n_trials_label = container.shape[-1]

        # colorbar range shared by all panels of this figure
        vmax = trial_mean.std()*2.5

        fig, axes = plt.subplots(ncols=ncols, nrows=1, sharey=True)
        fig.suptitle(label)
        for i, ax in enumerate(axes):
            if i < n_single:
                if i >= n_trials_label:
                    # this label has fewer trials than there are single-trial columns
                    ax.set_axis_off()
                    continue
                panel, title, is_std = container[:, :, i].T, '# {}'.format(i+1), False
            elif i == n_single:
                panel, title, is_std = trial_mean.T, 'mean', False
            else:
                panel, title, is_std = trial_std.T, 'std', True
            # only the last (std) panel gets the default scale bar and a colorbar
            line_kw = {} if is_std else {'scalebar': False}
            draw_lineplot(ax, panel,
                          dt=1.,
                          T=T_ms,
                          vlimround=vlimround,
                          unit=str(signal.units),
                          **line_kw
                         )
            draw_imageplot(ax, panel,
                           dt=1.,
                           T=T_ms,
                           unit=str(signal.units),
                           colorbar=is_std,
                           vmax=vmax
                          )
            ax.set_title(title)
            if i != 0:
                plt.setp(ax.get_yticklabels(), visible=False)

In [None]:
# compare mean response to different stimuli
if stimulation_type in ['image', 'grating', 'flash']:
    labels_ = np.unique(epo.labels)
    if labels_.size < 10:
        pass
    else:
        labels_ = labels_[:10]

    #mean_data = []
    ncols = labels_.size
    fig, axes = plt.subplots(ncols=ncols, nrows=1, sharey=True)
    for i, (ax, label) in enumerate(zip(axes, labels_)):
        inds = epo.labels == label
        container = np.zeros((int(((prestim_duration+epo.durations)*signal.sampling_rate).simplified[0]), 
                              signal.shape[1], 
                              inds.sum()))
        for j, (time, duration) in enumerate(zip(epo.times[inds], epo.durations[inds])):
            time_slice = signal.time_slice(time-prestim_duration, time+duration)
            if time_slice.shape != container[:, :, j].shape:
                container[:, :, j] = time_slice[:container.shape[0], :]
            else:
                container[:, :, j] = time_slice
        # compute mean across trials
        container = container.mean(axis=-1).T
        # center to prestim duration:
        container = (container.T - container[:, :int(prestim_duration)].mean(axis=-1)).T
        draw_lineplot(ax, container, 
                      dt=1., 
                      T=(-float(prestim_duration), float(duration.rescale('ms'))),
                      vlimround=vlimround,
                      scalebar=False if i < (len(labels_)-1) else True,
                      unit=str(signal.units),
                     )
        draw_imageplot(ax, container, 
                       dt=1., 
                       T=(-float(prestim_duration), float(duration.rescale('ms'))), 
                       unit=str(signal.units),
                       colorbar=False if i < (len(labels_)-1) else True,
                       vmax=container.std()*2.5
                      )
        if stimulation_type == 'grating':
            ax.set_title(label.replace(',\\', '$\n$\\').strip(','))
        else:
            ax.set_title(label.split('\\')[-1])

        if i != 0:
            ax.set_yticklabels([])

In [None]:
# for Alex: Store mean response per stimulus
# compare mean response to different stimuli
if stimulation_type in ['image', 'grating', 'flash']:
    labels_ = np.unique(epo.labels)
    mean_trial_data = []
    for i, label in enumerate(labels_):
        inds = epo.labels == label
        container = np.zeros((int((epo.durations*signal.sampling_rate)[0]), signal.shape[1], inds.sum()))
        for j, (time, duration) in enumerate(zip(epo.times[inds], epo.durations[inds])):
            container[:, :, j] = signal.time_slice(time, time+duration)
        mean_trial_data.append(container.mean(axis=-1))
    mean_trial_data = np.array(mean_trial_data) # n_labels * n_timesamples * n_channels
    f = h5py.File('mean_trial_data.h5', 'w')
    f['data'] = mean_trial_data
    f.close()

In [None]:
# NOTE(review): the store-mean cell above already closes `f`, so this looks like a leftover.
# On a fresh kernel it raises NameError if `f` was never bound (e.g. when the `if` branch
# above did not run and no earlier cell defined `f`). TODO confirm and remove.
f.close()

In [None]:
# time, duration, signal.duration, signal, epo.times, epo.labels

# CSD PCA projection

# Single-unit responses to stimuli

In [None]:
# Per unit: spike count per trial and stimulus label, total count per label, and mean rate
# per label (dotted line: overall mean rate). Rates are in spikes/s.
labels = np.unique(epo.labels)
label_masks = [epo.labels == label for label in labels]
# rows of the count matrix: the largest number of trials of any label, so labels with
# more trials than average cannot index past the end
ntrials = max((int(mask.sum()) for mask in label_masks), default=0)
# total presentation time (s) per label and overall; per-label rates use their own
# duration so they stay correct when labels were not shown equally long
label_durations = np.array([float(epo.durations[mask].sum().rescale('s')) for mask in label_masks])
total_duration = float(epo.durations.sum().rescale('s'))

for spiketrain in spiketrains:
    title = 'unit {} ({})'.format(spiketrain.name, spiketrain.description)
    container = np.zeros((ntrials, labels.size))
    for j, mask in enumerate(label_masks):
        for i, (time, duration) in enumerate(zip(epo.times[mask], epo.durations[mask])):
            container[i, j] += spiketrain.time_slice(time, time+duration).size

    x = np.arange(labels.size)
    fig, axes = plt.subplots(3, 1, sharex=True)
    fig.subplots_adjust(bottom=0.3)
    # spike count per trial (rows) and label (columns); color scale capped at 50 spikes
    ax = axes[0]
    ax.pcolormesh(np.arange(labels.size+1)-0.5, np.arange(ntrials+1)-0.5, container,
                  vmin=0, vmax=50, cmap='gray_r')
    ax.axis(ax.axis('tight'))
    ax.set_xticks(x)
    ax.set_xticklabels([])
    ax.set_ylabel('trial #')
    ax.set_title(title)
    # total spike count
    ax = axes[1]
    ax.bar(x, container.sum(axis=0))
    ax.set_ylabel('tot. spike count')
    ax.axis(ax.axis('tight'))
    # rate per stim
    ax = axes[2]
    ax.bar(x, container.sum(axis=0) / label_durations)
    mean_rate = container.sum() / total_duration
    ax.hlines(mean_rate, -1, len(labels), linestyle=':')
    ax.set_ylabel('rate (1/s)')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=90, fontsize='x-small')
    ax.axis(ax.axis('tight'))

In [None]:
# tuning curves
if stimulation_type == 'grating':
    for spiketrain in spiketrains:
        title = 'unit {} ({})'.format(spiketrain.name, spiketrain.description)
        labels = np.unique(epo.annotations['orientation'])
        ntrials = int(np.ceil(epo.labels.size / labels.size))
        container = np.zeros((ntrials, labels.size))

        for j, label in enumerate(labels):
            inds = epo.annotations['orientation'] == label
            for i, (time, duration) in enumerate(zip(epo.times[inds], epo.durations[inds])):
                container[i, j] += spiketrain.time_slice(time, time+duration).size
        # determine significance level assuming a stationary rate poisson process
        # mean_rate = elephant.statistics.mean_firing_rate(spiketrain)

        fig, axes = plt.subplots(3, 1, sharex=True)
        fig.subplots_adjust(bottom=0.2)
        ax = axes[0]
        im = ax.pcolormesh(np.arange(labels.size+1)-0.5, np.arange(ntrials+1)-0.5, 
                           container, vmin=0, vmax=20, cmap='gray_r')
        ax.axis(ax.axis('tight'))
        ax.set_xticks(np.arange(labels.size))
        ax.set_xticklabels([])
        ax.set_ylabel('trial #')
        ax.set_title(title)
        ax = axes[1]
        ax.bar(np.arange(labels.size), container.sum(axis=0))
        #ax.hlines(mean_rate.magnitude, -1, len(labels), linestyle=':')
        ax.set_xticks(np.arange(labels.size))
        ax.set_xticklabels([str(label).split('\\')[-1] for label in labels], rotation=90, fontsize=12)
        ax.set_ylabel('tot. spike count')
        # rate per stim
        ax = axes[2]
        ax.bar(np.arange(labels.size), (container.sum(axis=0) / epo.durations.sum()*len(labels)).simplified)
        mean_rate = (container.sum() / epo.durations.sum()).simplified  
        ax.hlines(mean_rate.magnitude, -1, len(labels), linestyle=':')
        ax.set_ylabel('rate (1/s)')
        ax.set_xticks(np.arange(labels.size))
        ax.set_xticklabels(labels, rotation=90, fontsize='x-small')
        ax.axis(ax.axis('tight'))
        
        #break

In [None]:
# correlate image sequence with binned spike trains
if stimulation_type in ['sparsenoise']:
    # load h5 file with image data
    im_path = os.path.join(expipe_temp_storage, project_id, 
                          'COBRA_PsychoPy_files', 'files', 'datasets', 
                          'sparse_noise_images', 'imageData.h5')
    f = h5py.File(im_path, 'r')
    im_data = f['data'][:, :epo.size]
    im_dimensions = f['dimensions'].value
    f.close()
    im_shape = im_data.shape
    im_data = im_data.reshape((-1), im_data.shape[-1])
    
    for spiketrain in spiketrains:
        title = 'unit {}'.format(spiketrain.name)
        # count spike events per image
        sptr_binned = np.zeros(im_data.shape[-1], dtype=int)
        for i, (time, duration) in enumerate(zip(epo.times, epo.durations)):
            sptr_binned[i] += spiketrain.time_slice(time, time+duration).size
        corr = np.corrcoef(np.row_stack([sptr_binned, im_data]))[1:, 0].reshape(im_shape[:2])
        plt.figure()
        plt.imshow(corr, cmap='gray', vmin=-corr.std()*2, vmax=corr.std()*2, interpolation='nearest')
        plt.colorbar()
        plt.title(title)

In [None]:
# correlate image sequence with binned spike trains
if stimulation_type in ['image']:
    # load h5 file with image data
    im_path = os.path.join(os.environ['HOME'], 'COBRA', 
                          'COBRA_PsychoPy_files', 'files', 'datasets', 
                          'converted_images', 'imageData.h5')

    f = h5py.File(im_path, 'r')
    im_data = f['data'].value
    im_dimensions = f['dimensions'].value
    f.close()
    im_shape = im_data.shape
    
    # test that downsampling works
    plt.matshow(im_data[:, :, 0], cmap='gray')
    plt.title('original')
    
    
    # downsample image (too high-res for correlation calc)
    im_data = block_reduce(im_data, block_size=(10, 10, 1), func=np.mean)
    plt.matshow(im_data[:, :, 0], cmap='gray')
    plt.title('downsampled')
    im_shape = im_data.shape
    im_data = im_data.reshape((-1), im_data.shape[-1])
    
    # epoch bins
    bins = np.r_[epo.times, [epo.times[-1]+epo.durations[-1]]]
    for spiketrain in spiketrains:
        title = 'unit {} ({})'.format(spiketrain.name, spiketrain.description)
        # count spike events per image showing
        sptr_binned_tmp, _ = np.histogram(spiketrain.times, bins=bins)
        sptr_binned = np.zeros(im_shape[-1], dtype=int)
        # collapse bins with the same image
        for i, label in enumerate(np.unique(epo.labels)):
            inds = epo.labels == label
            sptr_binned[i] = sptr_binned_tmp[inds].sum()
        corr = np.corrcoef(np.row_stack([sptr_binned, im_data]))[1:, 0].reshape(im_shape[:2])
        plt.figure()
        plt.imshow(corr, cmap='gray', vmin=-corr.std()*2, vmax=corr.std()*2, interpolation='nearest')
        plt.colorbar()
        plt.title(title)

In [None]:
# correlate movie data with binned spike trains
if stimulation_type in ['movie']:
    raise NotImplementedError

# PSTH
(peristimulus time histograms)

In [None]:
def plot_psth(st, epoch, fig=None, axes=None, lags=(-0.1*pq.s, 1*pq.s), bin_size=0.01*pq.s, 
              marker='|', color='C0', n_trials=10,  histtype='bar'):
    '''
    Plot spike rasters (top row) and peristimulus time histograms (bottom row),
    one column per unique epoch label, with spikes aligned to each epoch's start.

    Parameters:
    st : neo.SpikeTrain
    epoch : neo.Epoch
    fig : matplotlib Figure, optional
    axes : 2D array of Axes, shape (2, n_labels), optional
        if None, a new figure with these axes is created
    lags : tuple of Quantity scalars
        (start, stop) of the window relative to each epoch start
    bin_size : Quantity scalar
    marker : matplotlib marker used for the raster
    color : mpl color 
    n_trials : int
        number of trials per label to include in PSTH (np.inf for all)
    histtype : str
        passed on to Axes.hist

    Returns:
    fig, axes
    '''
    labels = np.unique(epoch.labels)
    # spike lags below are computed in seconds, so bins and limits are expressed in seconds too
    lag_start = float(lags[0].rescale('s'))
    lag_stop = float(lags[1].rescale('s'))
    # round rather than floor: e.g. 1.1 s // 0.01 s == 109 in floating point, which would
    # yield bins slightly wider than bin_size
    n_bins = max(int(round(float(((lags[1] - lags[0]) / bin_size).simplified))), 1)
    bins = np.linspace(lag_start, lag_stop, n_bins + 1)
    
    if axes is None:
        # squeeze=False keeps axes 2D even for a single label
        fig, axes = plt.subplots(2, len(labels), sharex=True, sharey='row', squeeze=False)
        fig.suptitle('unit {} ({})'.format(st.name, st.description))
    elif fig is None:
        fig = axes[0, 0].figure
    for i, label in enumerate(labels):
        ax_raster, ax_hist = axes[0, i], axes[1, i]
        ax_raster.set_xlim(lag_start, lag_stop)
        ax_hist.set_xlim(lag_start, lag_stop)

        trial_lags = []
        for h, t0 in enumerate(epoch[epoch.labels == label]):
            if h >= n_trials:
                break
            st_ = st.time_slice(t_start=(t0+lags[0]).simplified, 
                                t_stop=(t0+lags[1]).simplified)
            lags_h = np.asarray((st_.times.simplified - t0.simplified).magnitude)
            trial_lags.append(lags_h)
            ax_raster.plot(lags_h, np.zeros(lags_h.size) + h, marker, color=color)

        ax_raster.set_title('{}'.format(label), fontsize='x-small')
        ax_hist.set_xlabel('lag (s)')
        all_lags = np.concatenate(trial_lags) if trial_lags else np.array([])
        ax_hist.hist(all_lags, bins=bins, color=color, histtype=histtype)
        if i == 0:
            ax_raster.set_ylabel('trial #')
            ax_hist.set_ylabel('#')
    return fig, axes

In [None]:
# PSTH of every unit: window from 100 ms before onset to one (first) stimulus duration after,
# using all trials of each label
if stimulation_type in ('image', 'flash'):
    psth_lags = (-0.1*pq.s, epo.durations[0])
    for st in spiketrains:
        plot_psth(st, epo, lags=psth_lags, n_trials=np.inf)

# USB mouse recordings (trackball data)

In [None]:
# cumulative displacement of each USB mouse (trackball) track over time
fig, ax = plt.subplots(1, 1)
for key, value in tracks.items():
    t = value['times'].value
    displacement = value['data'].value.cumsum()
    ax.plot(t, displacement, label=key)
ax.legend(loc='best')
ax.set(title='USB-mice paths', xlabel='time (s)', ylabel='displacement (a.u.)')