### Algorithm to detect and characterize burst dynamics


In [None]:
# Path to the toolbox
import sys; sys.path.insert(1, '/home/vinicius/storage1/projects/GrayData-Analysis')
import os 

# GDa functions
import GDa.stats.bursting                as     bst
from   GDa.session                       import session
from   GDa.temporal_network              import temporal_network
from   GDa.util                          import smooth

import matplotlib.pyplot                 as     plt
import matplotlib
import GDa.graphics.plot                 as     plot

import numpy                             as     np
import xarray                            as     xr


from   tqdm                              import tqdm
from   sklearn.manifold                  import TSNE
from   scipy                             import stats

In [None]:
# Apply the project's plot configuration and unpack the three size constants it returns
(SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE) = plot.set_plot_config()

In [None]:
# Create the output directory for this notebook's figures (no-op if it already exists).
# exist_ok=True avoids the race between checking for the directory and creating it.
os.makedirs("img/n5.0.2", exist_ok=True)

### Detecting bursts (example)

In [None]:
x = np.array([0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# Stage '1' covers the first 12 samples; stage '2' is its complement
mask = {'1': np.arange(x.size) < 12}
mask['2'] = ~mask['1']

In [None]:
# Two-panel figure: (A) a pre-rendered cartoon image and (B) the example
# spike train `x` split into the two stages defined by `mask`.
fig = plt.figure(figsize=(5, 4), dpi=600)

# Upper grid region holds panel A, lower grid region holds panel B
gs1 = fig.add_gridspec(nrows=1, ncols=1, left=0.05, right=0.95, bottom=0.45, top=0.95)
gs2 = fig.add_gridspec(nrows=1, ncols=1, left=0.05, right=0.95, bottom=0.08, top=0.43)

# Panel A
ax1 = plt.subplot(gs1[0])
# NOTE(review): the image comes from the img/n5.0.0 folder, presumably written
# by another notebook — confirm it exists before running this cell.
png = plt.imread("img/n5.0.0/cartton_act.png")
plt.sca(ax1)
im = plt.imshow(png, interpolation='none')
plt.axis('off')    
# Small margin (in pixels) around the image. The bottom limit is larger than the
# top limit, keeping the y-axis inverted like imshow's default origin='upper'.
pad = 10
plt.xlim(-pad, png.shape[1]+pad)
plt.ylim(png.shape[0]+pad, -pad) 

# Panel B
ax2 = plt.subplot(gs2[0])

plt.sca(ax2)
# One vertical tick per active sample of the example spike train
plt.vlines(np.arange(len(x))[x==1], 0, 1,  color='k', label='spike-train')
#plt.plot(mask['1'], lw=1, label='Stage 1', color='blue')
#plt.plot(mask['2'], lw=1, label='Stage 2', color='red')
# Coloured bars over each stage and a dashed divider at 11.5, i.e. between the
# last sample of mask['1'] (index 11) and the first of mask['2'] (index 12).
# These limits are hard-coded for the 23-sample `x`; update them if `x` changes.
plt.hlines(1,-0.1,11.5,color="blue")
plt.hlines(1,11.5,22.1,color="red")
plt.vlines(11.5,-0.01,0.99,ls="--", color="gray")
plt.xlim(-0.1,22.1)
plt.ylim(-0.01,1.01)
ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax2.spines["left"].set_visible(False)
plt.yticks([])
#plt.legend()

# NOTE(review): plot.Background is a project helper; presumably it adds an
# invisible figure-wide axes used below to place free text — verify in GDa.graphics.plot.
bg = plot.Background(visible=False)
plot.add_panel_letters(fig, axes=[ax1, ax2], fontsize=12,
                       xpos=[0,0], ypos=[0.9, 1.05])
# Stage labels, placed on the background axes
bg.axes.text(0.3, 0.45, "Stage 1", ha='center', fontsize=MEDIUM_SIZE)
bg.axes.text(0.73, 0.45, "Stage 2", ha='center', fontsize=MEDIUM_SIZE)

plt.savefig("img/n5.0.2/example_burst_stats.png")

#### Finding burst durations

In [None]:
# Burst lengths of the example spike train x
burst_lengths = bst.find_activation_sequences(x, dt=None)
print(f'burst lengths = {burst_lengths}')

#### Finding burst durations for segments of the spike train selected by a mask

In [None]:
# Burst lengths inside each stage, keeping bursts that cross the stage boundary
for idx, (key, key_mask) in enumerate(mask.items()):
    lengths = bst.masked_find_activation_sequences(x, key_mask, dt=None, drop_edges=False)
    print(f'Mask {idx}, burst lengths = {lengths}')

#### Finding burst durations for masked segments of the spike train, dropping bursts that cross the boundary between masks

In [None]:
# Burst lengths inside each stage, discarding bursts that cross the stage boundary
for idx, (key, key_mask) in enumerate(mask.items()):
    lengths = bst.masked_find_activation_sequences(x, key_mask, dt=None, drop_edges=True)
    print(f'Mask {idx}, burst lengths = {lengths}')

In [None]:
import numpy  as np
import numba  as nb
import xarray as xr
from   frites.utils   import parallel_func

@nb.jit(nopython=True)
def _nan_pad(x, new_size):
    """Append NaNs to ``x`` until it holds ``new_size`` elements (result is float)."""
    n_missing = new_size - len(x)
    return np.concatenate((x, np.full(n_missing, np.nan)))

@nb.jit(nopython=True)
def find_start_end(array, find_zeros=False):
    """
    Given a binary array, find the start and end indexes of every run of
    consecutive ones (find_zeros=False) or of consecutive zeros (find_zeros=True).
    Start indexes are inclusive and end indexes are exclusive, so each run is
    ``array[start:end]`` and its length is ``end - start``.

    For instance, for the array [0, 1, 1, 1, 0, 0] this returns [[1, 4]] for
    find_zeros=False, and [[0, 1], [4, 6]] for find_zeros=True.

    Parameters
    ----------
    array: array_like
        1-D binary array.
    find_zeros: bool | False
        Whether to find runs of zeros instead of runs of ones.

    Returns
    -------
    Integer matrix with shape [n_seqs, 2] whose rows are the (start, end)
    indexes of each run, where n_seqs is the number of runs found
    (shape [0, 2] when there are none).
    """
    # Pad both ends with the "background" value (0 when looking for ones,
    # 1 when looking for zeros) so runs touching the array edges still
    # produce both a start and an end transition.
    if find_zeros:
        _bounds = np.array([1])
    else:
        _bounds = np.array([0])

    bounded     = np.hstack((_bounds, array, _bounds))
    difs        = np.diff(bounded)
    # +1 marks a run start and -1 a run end if find_zeros is False
    if not find_zeros:
        run_starts, = np.where(difs > 0)
        run_ends,   = np.where(difs < 0)
    # -1 marks a run start and +1 a run end if find_zeros is True
    else:
        run_starts, = np.where(difs < 0)
        run_ends,   = np.where(difs > 0)
    return np.vstack((run_starts,run_ends)).T

@nb.jit(nopython=True)
def find_activation_sequences(spike_train, dt=None):
    """
    Return the length of every activation (run of ones) in a spike train.

    Example: for x = {0111000011000011111} the result is [3, 2, 5],
    multiplied by dt when dt is given.

    Parameters
    ----------
    spike_train: array_like
        The binary spike train.
    dt: int | float | None
        Factor applied to the lengths (e.g. the sampling period, to get
        seconds). Treated as 1 when None.

    Returns
    -------
    act_lengths: array_like
        Array with shape [n_seqs] holding the length of each activation,
        where n_seqs is the number of activations found.
    """
    # Without a scaling factor the lengths stay in samples
    if dt is None:
        dt = 1
    bounds = find_start_end(spike_train)
    starts = bounds[:, 0]
    ends = bounds[:, 1]
    return dt * (ends - starts)

#@nb.jit(nopython=True)
def masked_find_activation_sequences(spike_train, mask, dt=None, drop_edges=False, pad=False):
    """
    Similar to "find_activation_sequences", but only the samples selected by
    ``mask`` are considered when computing the length of the activation sequences.

    Parameters
    ----------
    spike_train: array_like
        1-D binary spike train.
    mask: array_like
        1-D binary mask with the same length as spike_train (converted to bool).
    dt: int | float | None
        If provided, the returned lengths are multiplied by dt (e.g. to express
        them in seconds).
    drop_edges: bool | False
        If True, discard the first burst when the train is already active right
        before the first masked sample, and the last burst when it is still
        active right after the last masked sample (bursts crossing the mask
        limits). Only the first and last masked indexes are checked, so this is
        only meaningful for a contiguous mask.
    pad: bool | False
        If True, NaN-pad the result to len(spike_train)//2 + 1 entries, so that
        trains of the same length yield outputs of the same size.

    Returns
    -------
    act_lengths: array_like
        Array with the length of each activation, shape [n_seqs]; shape
        [len(spike_train)//2 + 1] with NaN in the unused slots when pad=True.
    """
    spike_train = np.asarray(spike_train)
    # Assure that mask is type bool
    mask = np.asarray(mask).astype(np.bool_)

    # Find the size of the activations lengths for the masked spike_train
    act_lengths = find_activation_sequences(spike_train[mask], dt=dt)

    if drop_edges and len(act_lengths) > 0:
        idx, = np.where(mask)
        i, j = idx[0], idx[-1]
        # A burst active on both sides of the first masked sample started
        # outside the mask. Impossible when the mask starts at index 0.
        if i >= 1 and spike_train[i - 1] == 1 and spike_train[i] == 1:
            act_lengths = np.delete(act_lengths, 0)
        # Same for a burst continuing past the last masked sample. Emptiness is
        # re-checked because a single burst may cross both edges, in which case
        # the first deletion already removed it.
        if (len(act_lengths) > 0 and j < len(mask) - 1
                and spike_train[j] == 1 and spike_train[j + 1] == 1):
            act_lengths = np.delete(act_lengths, -1)

    if pad:
        # n samples hold at most ceil(n/2) <= n//2 + 1 bursts, so this width
        # always fits the result.
        act_lengths = _nan_pad(act_lengths, len(spike_train) // 2 + 1)

    return act_lengths

def tensor_find_activation_sequences(spike_train, mask, dt=None, drop_edges=False, pad=False, n_jobs=1):
    """
    A wrapper of "masked_find_activation_sequences" to run on tensor data
    of shape [links, trials, time].

    Parameters
    ----------
    spike_train: array_like
        Binary spike trains with shape [links, trials, time]. xarray inputs
        are converted to plain numpy arrays.
    mask: array_like | dict
        Binary mask applied to the spike-train with shape [trials, time]. For
        more than one mask, a dictionary should be provided where each key maps
        to an array with shape [trials, time].
    dt: int | float | None
        If provided, the returned lengths are multiplied by dt (e.g. to express
        them in seconds).
    drop_edges: bool | False
        If True, bursts crossing the first/last sample of the mask are discarded
        (see "masked_find_activation_sequences").
    pad: bool | False
        If True, each trial's lengths are NaN-padded to time//2 + 1 entries and
        stacked into a float array of shape [trials, time//2 + 1] per link.
        If False, each link yields a list with one variable-length array per trial.
    n_jobs: int | 1
        Number of parallel jobs, forwarded to frites' parallel_func.

    Returns
    -------
    act_lengths: list | dict
        List with one entry per link (see ``pad``). If ``mask`` is a dict, a
        dict with the same keys, each holding such a list.
    """

    # Checking inputs
    assert isinstance(spike_train, (np.ndarray, xr.DataArray))
    assert isinstance(mask, (dict, np.ndarray, xr.DataArray))
    assert spike_train.ndim == 3

    # The numba-compiled helpers only accept plain numpy arrays
    spike_train = np.asarray(spike_train)

    # Number of edges
    n_edges = spike_train.shape[0]
    # Must match the padding width of masked_find_activation_sequences, which
    # pads each 1-D train to len(train)//2 + 1; time is the last axis.
    _new_size = spike_train.shape[-1] // 2 + 1

    # Find the activation sequences for every trial of one edge
    def _edgewise(x, m):
        if not pad:
            # Lengths differ across trials, so keep one array per trial
            return [masked_find_activation_sequences(x[i, ...], m[i, ...],
                                                     drop_edges=drop_edges,
                                                     pad=False, dt=dt)
                    for i in range(x.shape[0])]
        # Float buffer so the NaN padding can be stored
        act_lengths = np.empty((x.shape[0], _new_size))
        for i in range(x.shape[0]):
            act_lengths[i] = masked_find_activation_sequences(x[i, ...], m[i, ...],
                                                              drop_edges=drop_edges,
                                                              pad=True, dt=dt)
        return act_lengths

    # Computed in parallel for each edge
    parallel, p_fun = parallel_func(
        _edgewise, n_jobs=n_jobs, verbose=False,
        total=n_edges)

    if isinstance(mask, dict):
        # Use the same keys (and order) as the mask
        act_lengths = {}
        for key, key_mask in mask.items():
            key_mask = np.asarray(key_mask)
            assert key_mask.ndim == 2
            act_lengths[key] = parallel(p_fun(spike_train[e, ...], key_mask)
                                        for e in range(n_edges))
        return act_lengths

    mask = np.asarray(mask)
    assert mask.ndim == 2
    return parallel(p_fun(spike_train[e, ...], mask) for e in range(n_edges))

In [None]:
# Random binary trains [edges, trials, time] and a mask selecting the first half of each trial
spike_train = np.random.rand(1176, 540, 20) > 0.5
mask = np.zeros((540, 20), dtype=bool)
mask[:, :10] = True

In [None]:
# Start/end indexes of the runs of ones in the example train
find_start_end(x, find_zeros=False)

In [None]:
# Burst lengths of the example train, in samples
find_activation_sequences(x, dt=None)

In [None]:
# Benchmark one masked call on a single (edge 0, trial 0) spike train
%timeit masked_find_activation_sequences(spike_train[0,0], mask[0], dt=None, drop_edges=False)

In [None]:
# Run the tensor wrapper over all edges with NaN padding enabled
tensor_find_activation_sequences(spike_train, mask,
                                 n_jobs=1, pad=True, drop_edges=False, dt=None)

In [None]:
def masked_per_trial():
    """Return the burst lengths of ``spike_train[0, 0]`` under ``mask[0]``.

    Uses the notebook-level ``spike_train`` and ``mask`` globals. On a 1-D
    input, ``np.apply_along_axis`` calls the function once on the whole array,
    forwarding ``mask[0]`` as the extra positional argument.
    """
    return np.apply_along_axis(masked_find_activation_sequences, -1,
                               spike_train[0, 0], mask[0], drop_edges=False,
                               dt=None)

In [None]:
# Mask of the first trial
mask[0, :]

In [None]:
# First ten samples of edge 0, trial 0
spike_train[0, 0, 0:10]

In [None]:
a = np.arange(1, 3)

In [None]:
# Display `a`
a

In [None]:
# Padding width used by masked_find_activation_sequences for one trial
spike_train.shape[-1] // 2 + 1

In [None]:
@nb.jit(nopython=True)
def nan_pad(x, new_size):
    """Append NaNs to ``x`` until it holds ``new_size`` elements.

    Pads the argument ``x`` itself; the concatenation must not use any
    outer variable.
    """
    return np.concatenate( (x, np.nan*np.ones(new_size-len(x))) )

In [None]:
# Padded burst lengths of edge 0, trial 0
masked_find_activation_sequences(spike_train[0, 0], mask[0],
                                 pad=True, drop_edges=False, dt=None)

In [None]:
@nb.jit(nopython=True)
def _nan_pad(x, new_size):
    """Right-pad ``x`` with NaNs up to ``new_size`` entries along its first axis."""
    padding = np.full(new_size - x.shape[0], np.nan)
    return np.concatenate((x, padding))


In [None]:
# _nan_pad concatenates with a 1-D NaN array, so it needs a 1-D input.
# Pad the first (edge, trial) train (20 samples) to 25 entries; it is cast to
# float64 so both concatenated arrays share the dtype.
_nan_pad(spike_train[0, 0].astype(np.float64), 25)

In [None]:
# Display the example spike train `x`
x

In [None]:
# Padded output length for trials 0..539 of edge 0 (all using the mask of trial 0)
s = []
for i in range(540):
    s += [len(masked_find_activation_sequences(spike_train[0, i], mask[0], dt=None,
                                               pad=True, drop_edges=False))]

In [None]:
def _edgewise(x, m):
    """Padded burst lengths for every trial of one edge.

    Parameters
    ----------
    x: array_like
        Binary spike trains of one edge with shape [trials, time].
    m: array_like
        Masks with shape [trials, time].

    Returns
    -------
    act_lengths: np.ndarray
        Float array with shape [trials, time//2 + 1]; unused slots are NaN.
    """
    # Not numba-compiled: masked_find_activation_sequences is a plain Python
    # function (its jit decorator is commented out), which nopython mode cannot call.
    n_trials, n_times = x.shape[0], x.shape[-1]
    # Same width as the padding in masked_find_activation_sequences
    # (len(train)//2 + 1). A float buffer is required to store the NaNs.
    act_lengths = np.zeros((n_trials, n_times // 2 + 1))
    for i in range(n_trials):
        act_lengths[i] = masked_find_activation_sequences(x[i, ...], m[i, ...], drop_edges=True, pad=True, dt=1)
    return act_lengths

In [None]:
# Per-trial padded burst lengths for edge 0
_edgewise(x=spike_train[0], m=mask)